diff --git a/src/solvers/green/bands.jl b/src/solvers/green/bands.jl index 136ddc92..de9d7aa2 100644 --- a/src/solvers/green/bands.jl +++ b/src/solvers/green/bands.jl @@ -44,28 +44,31 @@ function leading_zeros(d::Series{N}) where {N} return 0 end +function trim_and_map(func, d::Series{N}, d´::Series{N}) where {N} + f, f´ = trim(d), trim(d´) + pow = min(f.pow, f´.pow) + t = ntuple(i -> func(f[pow + i - 1], f´[pow + i - 1]), Val(N)) + return Series(t, pow) +end + Base.first(d::Series) = first(d.x) -Base.getindex(d::Series{N,T}, i::Integer) where {N,T} = - d.pow <= i < d.pow + N ? (@inbounds d.x[i-d.pow+1]) : zero(T) +function Base.getindex(d::Series{N,T}, i::Integer) where {N,T} + i´ = i - d.pow + 1 + checkbounds(Bool, d.x, i´) ? (@inbounds d.x[i´]) : zero(T) +end Base.eltype(::Series{<:Any,T}) where {T} = T Base.one(::Type{<:Series{N,T}}) where {N,T<:Number} = Series{N}(one(T)) Base.zero(::Type{<:Series{N,T}}) where {N,T} = Series(zero(SVector{N,T}), 0) Base.iszero(d::Series) = iszero(d.x) -Base.transpose(d::Series) = d - -function Base.:+(d::Series{N}, d´::Series{N}) where {N} - f, f´ = trim(d), trim(d´) - pow = min(f.pow, f´.pow) - t = ntuple(i -> f[pow + i - 1] + f´[pow + i - 1], Val(N)) - Series(t, pow) -end +Base.transpose(d::Series) = d # act as a scalar Base.:+(d::Series) = d -Base.:-(d::Series, d´::Series) = d + (-d´) Base.:-(d::Series) = Series(.-(d.x), d.pow) +Base.:+(d::Series, d´::Series) = trim_and_map(+, d, d´) +Base.:-(d::Series, d´::Series) = trim_and_map(-, d, d´) Base.:*(d::Number, d´::Series) = Series(d * d´.x, d´.pow) Base.:*(d´::Series, d::Number) = Series(d * d´.x, d´.pow) Base.:/(d::Series{N}, d´::Series{N}) where {N} = d * inv(d´) @@ -79,7 +82,7 @@ function Base.:*(d::Series{N}, d´::Series{N}) where {N} return Series(s, pow) end -function Base.inv(d::Series{N}) where {N} +function Base.inv(d::Series) d´ = trim(d) # remove leading zeros iszero(d´) && argerror("Divide by zero") pow = d´.pow @@ -91,10 +94,11 @@ function Base.inv(d::Series{N}) where {N} end # Ud = [x1 0 0 0; x2 x1 0 0; x3 x2 x1 0; x4 x3 x2 x1] + # product of two series d*d´ is Ud * d´.x -function product_matrix(s::SVector{N,T}) where {N,T} +function product_matrix(s::SVector{N}) where {N} t = ntuple(Val(N)) do i - shiftpad(s, i-1) + shiftpad(s, i - 1) end return hcat(t...) end @@ -120,6 +124,25 @@ struct Simplex{D,T,S1,S2,S3,SU<:SMatrix{D,D,T}} VD::T # D!V = |det(U)| end +struct Expansions{N,TC<:NTuple{N},TJ,SJ} + cis::TC + J0::TJ + Jmat::SJ +end + +# Precomputes the Series expansion coefficients for cis, J(z->0) and J(z) +function Expansions(::Val{N´}, ::Type{T}) where {N´,T} # here N´ = N-1 + C = complex(T) + cis = ntuple(n -> C(im)^(n-1)/(factorial(n-1)), Val(N´+1)) + J0 = ntuple(n -> C(-im)^n/(n*factorial(n)), Val(N´)) + Jmat = ntuple(Val(N´*N´)) do ij + j, i = fldmod1(ij, N´) + j > i ? zero(C) : ifelse(isodd(i), 1, -1) * C(im)^(i-j) / (i*factorial(i-j)) + end |> SMatrix{N´,N´,C} + return Expansions(cis, J0, Jmat) +end + + function Simplex(ei::SVector{D´}, kij::SMatrix{D´,D,T}) where {D´,D,T} eij = chop(ei' .- ei) k0 = kij[1, :] @@ -146,7 +169,8 @@ function is_valid_Q(Q, es, ks) for qβ in eachcol(Q) phi = ks * qβ phis = phi' .- phi - for j in axes(es, 2), k in 1:j-1, l in 1:k-1 + for j in axes(es, 2), k in axes(es, 1), l in axes(es, 1) + l != k != j && l != k || continue eʲₖ = es[k,j] eʲₗ = es[l,j] (iszero(eʲₖ) || iszero(eʲₗ)) && continue @@ -168,9 +192,10 @@ function g_simplex(val, ω, dn, s::Simplex{D}) where {D} return g0, gk end -g_simplex(val, ω, dn, s, β::Int) = g_simplex(val, ω, dn, s, s.phi´[:, β]) +g_simplex(::Val{N}, ω, dn, s::Simplex{<:Any,T}, β::Int) where {N,T} = + g_simplex(Expansions(Val(N-1), T), ω, dn, s, s.phi´[:, β]) -function g_simplex(::Val{N}, ω::Number, dn::SVector{D}, s::Simplex{D,T}, ϕ´verts::SVector) where {D,N,T} +function g_simplex(ex::Expansions{N}, ω::Number, dn::SVector{D}, s::Simplex{D,T}, ϕ´verts::SVector) where {D,N,T} # phases ϕverts[j+1] will be perturbed by ϕ´verts[j+1]*dϕ, for j in 0:D # Similartly, ϕedges[j+1,k+1] will be perturbed by ϕ´edges[j+1,k+1]*dϕ ϕ´edges = ϕ´verts' .- ϕ´verts @@ -180,13 +205,13 @@ function g_simplex(::Val{N}, ω::Number, dn::SVector{D}, s::Simplex{D,T}, ϕ´ve Δverts = ω .- s.ei eedges = s.eij zedges = zkj_series.(ϕedges, eedges) - eϕ = cis_series.(ϕverts) + eϕ = cis_series.(ϕverts, Ref(ex)) γα = γα_series(ϕedges, zedges, eedges) # SMatrix{D´,D´} if iszero(eedges) # special case, full energy degeneracy Δ0 = chop(first(Δverts)) eγαJ = iszero(Δ0) ? zero(Series{N,complex(T)}) : im * sum(γα[1,:] .* eϕ) / Δ0 else - J = J_series.(zedges, eedges, transpose(Δverts)) # SMatrix{D´,D´} + J = J_series.(zedges, eedges, transpose(Δverts), Ref(ex)) # SMatrix{D´,D´} eγαJ = sum(γα .* J .* transpose(eϕ)) # manual contraction is slower! end gsum = (-im)^(D+1) * s.VD * trim(chop(eγαJ)) @@ -195,19 +220,15 @@ end zkj_series(ϕ, e) = iszero(e) ? ϕ : ϕ/e -function cis_series(z::Series{N}) where {N} +@inline function cis_series(z::Series{N}, ex) where {N} @assert iszero(z.pow) - c = cis_series(z[0], Val(N)) + c = cis_series(z[0], ex) # Go from dz differential to dϕ return rescale(c, z[1]) end # Series of cis(ϕ) -function cis_series(ϕ::Real, ::Val{N}) where {N} - E₀, Eᵢ = complex(1.0), ntuple(n -> im^n/(factorial(n)), Val(N-1)) - E = cis(ϕ) * Series{N}(E₀, Eᵢ...) - return E -end +cis_series(ϕ::Real, ex) = cis(ϕ) * Series(ex.cis) @inline function γα_series(ϕedges::S, zedges::S, eedges::SMatrix{D´,D´}) where {N,T,D´,S<:SMatrix{D´,D´,Series{N,T}}} # js = ks = SVector{D´}(1:D´) @@ -225,11 +246,11 @@ end return γα end -function α⁻¹_series((j, k), zedges::SMatrix{D´,D´,T}, eedges) where {D´,T<:Series} - x = one(T) +function α⁻¹_series((j, k), zedges::SMatrix{D´,D´,S}, eedges) where {D´,S<:Series} + x = one(S) @inbounds j != k && !iszero(eedges[k, j]) || return x @inbounds for l in 1:D´ - if l != j && !iszero(eedges[l, j]) + if l != j && !iszero(eedges[l, j]) x *= eedges[l, j] if l != k # ekj != 0, already constrained above x *= zedges[l, j] - zedges[k, j] @@ -239,8 +260,8 @@ function α⁻¹_series((j, k), zedges::SMatrix{D´,D´,T}, eedges) where {D´,T return x end -function γ⁻¹_series(j, ϕedges::SMatrix{D´,D´,T}, eedges) where {D´,T<:Series} - x = one(T) +function γ⁻¹_series(j, ϕedges::SMatrix{D´,D´,S}, eedges) where {D´,S<:Series} + x = one(S) @inbounds for l in 1:D´ if l != j && iszero(eedges[l, j]) x *= ϕedges[l, j] @@ -249,35 +270,40 @@ function γ⁻¹_series(j, ϕedges::SMatrix{D´,D´,T}, eedges) where {D´,T<:Se return x end -@inline function J_series(z::Series{N,T}, e, Δ) where {N,T} +@inline function J_series(z::Series{N,T}, e, Δ, ex) where {N,T} iszero(e) && return zero(Series{N,Complex{T}}) - J = J_series(z[0], Δ, Val(N)) + J = J_series(z[0], Δ, ex) # Go from d(zΔ) = dz*Δ differential to dϕ return rescale(J, z[1] * Δ) end # Series of J(zΔ) = cis(zΔ) * [Ci(|z|Δ) - i Si(zΔ)] (variable zΔ for Series) -function J_series(z::T, Δ::T, ::Val{N}) where {N,T<:Number} +function J_series(z::T, Δ::T, ex::Expansions{N}) where {N,T<:Number} C = complex(T) iszero(Δ) && return Series{N}(C(Inf)) zΔ = z * Δ imπ = im * ifelse(Δ > 0, 0, π) # strangely enough, union splitting is faster than stable if iszero(zΔ) J₀ = log(abs(Δ)) + imπ #+ MathConstants.γ + log(|z|) # not needed, cancels out - Jᵢ = ntuple(n -> (-im)^n/(n*factorial(n)), Val(N-1)) + Jᵢ = ex.J0 # = ntuple(n -> (-im)^n/(n*factorial(n)), Val(N-1)) J = Series{N}(J₀, Jᵢ...) - E = cis_series(zΔ, Val(N)) + E = cis_series(zΔ, ex) EJ = E * J else ciszΔ = cis(zΔ) J₀ = cosint(abs(zΔ)) - im*sinint(zΔ) + imπ - J₁ = conj(ciszΔ) .* ntuple(Val(N-1)) do n - ifelse(isodd(n), 1, -1) * sum(m -> im^m * zΔ^(m-n)/(n*factorial(m)), 0:n-1) + if N > 1 + invzΔ = cumprod(ntuple(Returns(1/zΔ), Val(N-1))) + Jᵢ = Tuple(conj(ciszΔ) * (ex.Jmat * SVector(invzΔ))) + # Jᵢ = conj(ciszΔ) .* ntuple(Val(N-1)) do n + # (-1)^(n-1) * sum(m -> im^m * zΔ^(m-n)/(n*factorial(m)), 0:n-1) + # end + J = Series(J₀, Jᵢ...) + else + J = Series(J₀) end - E₀ = ciszΔ - E₁ = ciszΔ .* ntuple(n -> im^n/factorial(n), Val(N-1)) - J = Series(J₀, J₁...) - E = Series(E₀, E₁...) + Eᵢ = ciszΔ .* ex.cis + E = Series(Eᵢ) EJ = E * J end return EJ