Skip to content

Commit

Permalink
Merge branch 'optimize' into bandsolve
Browse files Browse the repository at this point in the history
  • Loading branch information
pablosanjose committed Jul 17, 2023
2 parents ef48b84 + 81f0bf6 commit 93fed20
Showing 1 changed file with 68 additions and 42 deletions.
110 changes: 68 additions & 42 deletions src/solvers/green/bands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,28 +44,31 @@ function leading_zeros(d::Series{N}) where {N}
return 0
end

function trim_and_map(func, d::Series{N}, d´::Series{N}) where {N}
f, f´ = trim(d), trim(d´)
pow = min(f.pow, f´.pow)
t = ntuple(i -> func(f[pow + i - 1], f´[pow + i - 1]), Val(N))
return Series(t, pow)
end

Base.first(d::Series) = first(d.x)

Base.getindex(d::Series{N,T}, i::Integer) where {N,T} =
d.pow <= i < d.pow + N ? (@inbounds d.x[i-d.pow+1]) : zero(T)
function Base.getindex(d::Series{N,T}, i::Integer) where {N,T}
= i - d.pow + 1
checkbounds(Bool, d.x, i´) ? (@inbounds d.x[i´]) : zero(T)
end

Base.eltype(::Series{<:Any,T}) where {T} = T

Base.one(::Type{<:Series{N,T}}) where {N,T<:Number} = Series{N}(one(T))
Base.zero(::Type{<:Series{N,T}}) where {N,T} = Series(zero(SVector{N,T}), 0)
Base.iszero(d::Series) = iszero(d.x)
Base.transpose(d::Series) = d

function Base.:+(d::Series{N}, d´::Series{N}) where {N}
f, f´ = trim(d), trim(d´)
pow = min(f.pow, f´.pow)
t = ntuple(i -> f[pow + i - 1] + f´[pow + i - 1], Val(N))
Series(t, pow)
end
Base.transpose(d::Series) = d # act as a scalar

Base.:+(d::Series) = d
Base.:-(d::Series, d´::Series) = d + (-d´)
Base.:-(d::Series) = Series(.-(d.x), d.pow)
Base.:+(d::Series, d´::Series) = trim_and_map(+, d, d´)
Base.:-(d::Series, d´::Series) = trim_and_map(-, d, d´)
Base.:*(d::Number, d´::Series) = Series(d *.x, d´.pow)
Base.:*(d´::Series, d::Number) = Series(d *.x, d´.pow)
Base.:/(d::Series{N}, d´::Series{N}) where {N} = d * inv(d´)
Expand All @@ -79,7 +82,7 @@ function Base.:*(d::Series{N}, d´::Series{N}) where {N}
return Series(s, pow)
end

function Base.inv(d::Series{N}) where {N}
function Base.inv(d::Series)
= trim(d) # remove leading zeros
iszero(d´) && argerror("Divide by zero")
pow =.pow
Expand All @@ -91,10 +94,11 @@ function Base.inv(d::Series{N}) where {N}
end

# Ud = [x1 0 0 0; x2 x1 0 0; x3 x2 x1 0; x4 x3 x2 x1]

# product of two series d*d´ is Ud * d´.x
function product_matrix(s::SVector{N,T}) where {N,T}
function product_matrix(s::SVector{N}) where {N}
t = ntuple(Val(N)) do i
shiftpad(s, i-1)
shiftpad(s, i - 1)
end
return hcat(t...)
end
Expand All @@ -120,6 +124,25 @@ struct Simplex{D,T,S1,S2,S3,SU<:SMatrix{D,D,T}}
VD::T # D!V = |det(U)|
end

struct Expansions{N,TC<:NTuple{N},TJ,SJ}
cis::TC
J0::TJ
Jmat::SJ
end

# Precomputes the Series expansion coefficients for cis, J(z->0) and J(z)
function Expansions(::Val{N´}, ::Type{T}) where {N´,T} # here N´ = N-1
C = complex(T)
cis = ntuple(n -> C(im)^(n-1)/(factorial(n-1)), Val(N´+1))
J0 = ntuple(n -> C(-im)^n/(n*factorial(n)), Val(N´))
Jmat = ntuple(Val(N´*N´)) do ij
j, i = fldmod1(ij, N´)
j > i ? zero(C) : ifelse(isodd(i), 1, -1) * C(im)^(i-j) / (i*factorial(i-j))
end |> SMatrix{N´,N´,C}
return Expansions(cis, J0, Jmat)
end


function Simplex(ei::SVector{D´}, kij::SMatrix{D´,D,T}) where {D´,D,T}
eij = chop(ei' .- ei)
k0 = kij[1, :]
Expand All @@ -146,7 +169,8 @@ function is_valid_Q(Q, es, ks)
forin eachcol(Q)
phi = ks *
phis = phi' .- phi
for j in axes(es, 2), k in 1:j-1, l in 1:k-1
for j in axes(es, 2), k in axes(es, 1), l in axes(es, 1)
l != k != j && l != k || continue
eʲₖ = es[k,j]
eʲₗ = es[l,j]
(iszero(eʲₖ) || iszero(eʲₗ)) && continue
Expand All @@ -168,9 +192,10 @@ function g_simplex(val, ω, dn, s::Simplex{D}) where {D}
return g0, gk
end

g_simplex(val, ω, dn, s, β::Int) = g_simplex(val, ω, dn, s, s.phi´[:, β])
g_simplex(::Val{N}, ω, dn, s::Simplex{<:Any,T}, β::Int) where {N,T} =
g_simplex(Expansions(Val(N-1), T), ω, dn, s, s.phi´[:, β])

function g_simplex(::Val{N}, ω::Number, dn::SVector{D}, s::Simplex{D,T}, ϕ´verts::SVector) where {D,N,T}
function g_simplex(ex::Expansions{N}, ω::Number, dn::SVector{D}, s::Simplex{D,T}, ϕ´verts::SVector) where {D,N,T}
# phases ϕverts[j+1] will be perturbed by ϕ´verts[j+1]*dϕ, for j in 0:D
# Similartly, ϕedges[j+1,k+1] will be perturbed by ϕ´edges[j+1,k+1]*dϕ
ϕ´edges = ϕ´verts' .- ϕ´verts
Expand All @@ -180,13 +205,13 @@ function g_simplex(::Val{N}, ω::Number, dn::SVector{D}, s::Simplex{D,T}, ϕ´ve
Δverts = ω .- s.ei
eedges = s.eij
zedges = zkj_series.(ϕedges, eedges)
= cis_series.(ϕverts)
= cis_series.(ϕverts, Ref(ex))
γα = γα_series(ϕedges, zedges, eedges) # SMatrix{D´,D´}
if iszero(eedges) # special case, full energy degeneracy
Δ0 = chop(first(Δverts))
eγαJ = iszero(Δ0) ? zero(Series{N,complex(T)}) : im * sum(γα[1,:] .* eϕ) / Δ0
else
J = J_series.(zedges, eedges, transpose(Δverts)) # SMatrix{D´,D´}
J = J_series.(zedges, eedges, transpose(Δverts), Ref(ex)) # SMatrix{D´,D´}
eγαJ = sum(γα .* J .* transpose(eϕ)) # manual contraction is slower!
end
gsum = (-im)^(D+1) * s.VD * trim(chop(eγαJ))
Expand All @@ -195,19 +220,15 @@ end

zkj_series(ϕ, e) = iszero(e) ? ϕ : ϕ/e

function cis_series(z::Series{N}) where {N}
@inline function cis_series(z::Series{N}, ex) where {N}
@assert iszero(z.pow)
c = cis_series(z[0], Val(N))
c = cis_series(z[0], ex)
# Go from dz differential to dϕ
return rescale(c, z[1])
end

# Series of cis(ϕ)
function cis_series::Real, ::Val{N}) where {N}
E₀, Eᵢ = complex(1.0), ntuple(n -> im^n/(factorial(n)), Val(N-1))
E = cis(ϕ) * Series{N}(E₀, Eᵢ...)
return E
end
cis_series::Real, ex) = cis(ϕ) * Series(ex.cis)

@inline function γα_series(ϕedges::S, zedges::S, eedges::SMatrix{D´,D´}) where {N,T,D´,S<:SMatrix{D´,D´,Series{N,T}}}
# js = ks = SVector{D´}(1:D´)
Expand All @@ -225,11 +246,11 @@ end
return γα
end

function α⁻¹_series((j, k), zedges::SMatrix{D´,D´,T}, eedges) where {D´,T<:Series}
x = one(T)
function α⁻¹_series((j, k), zedges::SMatrix{D´,D´,S}, eedges) where {D´,S<:Series}
x = one(S)
@inbounds j != k && !iszero(eedges[k, j]) || return x
@inbounds for l in 1:
if l != j && !iszero(eedges[l, j])
if l != j && !iszero(eedges[l, j])
x *= eedges[l, j]
if l != k # ekj != 0, already constrained above
x *= zedges[l, j] - zedges[k, j]
Expand All @@ -239,8 +260,8 @@ function α⁻¹_series((j, k), zedges::SMatrix{D´,D´,T}, eedges) where {D´,T
return x
end

function γ⁻¹_series(j, ϕedges::SMatrix{D´,D´,T}, eedges) where {D´,T<:Series}
x = one(T)
function γ⁻¹_series(j, ϕedges::SMatrix{D´,D´,S}, eedges) where {D´,S<:Series}
x = one(S)
@inbounds for l in 1:
if l != j && iszero(eedges[l, j])
x *= ϕedges[l, j]
Expand All @@ -249,35 +270,40 @@ function γ⁻¹_series(j, ϕedges::SMatrix{D´,D´,T}, eedges) where {D´,T<:Se
return x
end

@inline function J_series(z::Series{N,T}, e, Δ) where {N,T}
@inline function J_series(z::Series{N,T}, e, Δ, ex) where {N,T}
iszero(e) && return zero(Series{N,Complex{T}})
J = J_series(z[0], Δ, Val(N))
J = J_series(z[0], Δ, ex)
# Go from d(zΔ) = dz*Δ differential to dϕ
return rescale(J, z[1] * Δ)
end

# Series of J(zΔ) = cis(zΔ) * [Ci(|z|Δ) - i Si(zΔ)] (variable zΔ for Series)
function J_series(z::T, Δ::T, ::Val{N}) where {N,T<:Number}
function J_series(z::T, Δ::T, ex::Expansions{N}) where {N,T<:Number}
C = complex(T)
iszero(Δ) && return Series{N}(C(Inf))
= z * Δ
imπ = im * ifelse> 0, 0, π) # strangely enough, union splitting is faster than stable
if iszero(zΔ)
J₀ = log(abs(Δ)) + imπ #+ MathConstants.γ + log(|z|) # not needed, cancels out
Jᵢ = ntuple(n -> (-im)^n/(n*factorial(n)), Val(N-1))
Jᵢ = ex.J0 # = ntuple(n -> (-im)^n/(n*factorial(n)), Val(N-1))
J = Series{N}(J₀, Jᵢ...)
E = cis_series(zΔ, Val(N))
E = cis_series(zΔ, ex)
EJ = E * J
else
ciszΔ = cis(zΔ)
J₀ = cosint(abs(zΔ)) - im*sinint(zΔ) + imπ
J₁ = conj(ciszΔ) .* ntuple(Val(N-1)) do n
ifelse(isodd(n), 1, -1) * sum(m -> im^m *^(m-n)/(n*factorial(m)), 0:n-1)
if N > 1
invzΔ = cumprod(ntuple(Returns(1/zΔ), Val(N-1)))
Jᵢ = Tuple(conj(ciszΔ) * (ex.Jmat * SVector(invzΔ)))
# Jᵢ = conj(ciszΔ) .* ntuple(Val(N-1)) do n
# (-1)^(n-1) * sum(m -> im^m * zΔ^(m-n)/(n*factorial(m)), 0:n-1)
# end
J = Series(J₀, Jᵢ...)
else
J = Series(J₀)
end
E₀ = ciszΔ
E₁ = ciszΔ .* ntuple(n -> im^n/factorial(n), Val(N-1))
J = Series(J₀, J₁...)
E = Series(E₀, E₁...)
Eᵢ = ciszΔ .* ex.cis
E = Series(Eᵢ)
EJ = E * J
end
return EJ
Expand Down

0 comments on commit 93fed20

Please sign in to comment.