Skip to content

Commit

Permalink
sitepairs test
Browse files Browse the repository at this point in the history
  • Loading branch information
pablosanjose committed Oct 24, 2024
1 parent da8f8ff commit 5b24741
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 35 deletions.
32 changes: 10 additions & 22 deletions src/greenfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,6 @@ getindex!(output, g::GreenSolution, i::Union{Integer,Colon}, j::Union{Integer,Co
getindex!(output, g, i::AnyCellOrbitals, j::AnyCellOrbitals; post = identity) =
(output .= post.(g[i, j]); output)

function getindex!(output::OrbitalSliceMatrix, g, i::AnyCellOrbitals, j::AnyCellOrbitals; kw...)
oi, oj = orbaxes(output)
rows, cols = orbrange(oi, cell(i)), orbrange(oj, cell(j))
v = view(parent(output), rows, cols)
getindex!(v, g, i, j; kw...)
return output
end

# indexing over several cells
getindex!(output, g, ci::AnyOrbitalSlice, cj::AnyOrbitalSlice; kw...) =
getindex_cells!(output, g, cellsdict(ci), cellsdict(cj); kw...)
Expand All @@ -141,25 +133,21 @@ getindex!(output, g, ci::AnyOrbitalSlice, cj::AnyCellOrbitals; kw...) =
getindex!(output, g, ci::AnyCellOrbitals, cj::AnyOrbitalSlice; kw...) =
getindex_cells!(output, g, (ci,), cellsdict(cj); kw...)

# we use the fact that the output AbstractMatrix is populated in the same site order of cis, cjs
function getindex_cells!(output, g, cis, cjs; kw...)
ioffset = joffset = 0
for ci in cis
ilen = length(ci)
irng = ioffset+1:ioffset+ilen
joffset = 0
for cj in cjs
jlen = length(cj)
jrng = joffset+1:joffset+jlen
v = view(output, irng, jrng)
getindex!(v, g, ci, cj; kw...)
joffset += jlen
end
ioffset += ilen
for ci in cis, cj in cjs
getindex!(output, g, ci, cj; kw...) # will typically call the method below
end
return output
end

function getindex!(output::OrbitalSliceMatrix, g, i::AnyCellOrbitals, j::AnyCellOrbitals; kw...)
oi, oj = orbaxes(output)
rows, cols = orbrange(oi, cell(i)), orbrange(oj, cell(j))
v = view(parent(output), rows, cols)
getindex!(v, g, i, j; kw...)
return output
end

## common getindex! shortcut in terms of GreenSlice

# index object g over a slice enconded in gs::GreenSlice, using its preallocated output
Expand Down
1 change: 0 additions & 1 deletion src/specialmatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,6 @@ function Base.view(a::OrbitalSliceMatrix, i::AnyCellSites, j::AnyCellSites = i)
i´, j´ = apply(i, lattice(rowslice)), apply(j, lattice(colslice))
rows = indexcollection(rowslice, i´)
cols = j === i && rowslice === colslice ? rows : indexcollection(colslice, j´)
@show rows, cols
return view(parent(a), rows, cols)
end

Expand Down
28 changes: 16 additions & 12 deletions src/tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,20 +138,24 @@ lengths_to_offsets(f::Function, v) = prepend!(accumulate((i,j) -> i + f(j), v; i


# fast tr(A*B)
trace_prod(A::AbstractMatrix, B::AbstractMatrix) = (check_sizes(A,B); sum(splat(*), zip(transpose(A), B)))
trace_prod(A::Number, B::AbstractMatrix) = A*tr(B)
trace_prod(A::AbstractMatrix, B::Number) = trace_prod(B, A)
trace_prod(A::UniformScaling, B::AbstractMatrix) = A.λ*tr(B)
trace_prod(A::AbstractMatrix, B::UniformScaling) = trace_prod(B, A)
trace_prod(A::Diagonal, B::Diagonal) = sum(i -> A[i] * B[i], axes(A,1))
trace_prod(A::Diagonal, B::AbstractMatrix) = sum(i -> A[i] * B[i, i], axes(A,1))
trace_prod(A::AbstractMatrix, B::Diagonal) = trace_prod(B, A)
trace_prod(A::Diagonal, B::Number) = only(A) * B
trace_prod(A::Number, B::Diagonal) = trace_prod(B, A)
trace_prod(A::Union{SMatrix,UniformScaling,Number}, B::Union{SMatrix,UniformScaling,Number}) =
trace_prod(A::AbstractMatrix, B::AbstractMatrix) = (check_sizes(A,B); unsafe_trace_prod(A, B))

unsafe_trace_prod(A::AbstractMatrix, B::AbstractMatrix) = sum(splat(*), zip(transpose(A), B))
unsafe_trace_prod(A::Number, B::AbstractMatrix) = A*tr(B)
unsafe_trace_prod(A::AbstractMatrix, B::Number) = unsafe_trace_prod(B, A)
unsafe_trace_prod(A::UniformScaling, B::AbstractMatrix) = A.λ*tr(B)
unsafe_trace_prod(A::AbstractMatrix, B::UniformScaling) = unsafe_trace_prod(B, A)
unsafe_trace_prod(A::Diagonal, B::Diagonal) = sum(i -> A[i] * B[i], axes(A,1))
unsafe_trace_prod(A::Diagonal, B::AbstractMatrix) = sum(i -> A[i] * B[i, i], axes(A,1))
unsafe_trace_prod(A::AbstractMatrix, B::Diagonal) = unsafe_trace_prod(B, A)
unsafe_trace_prod(A::Diagonal, B::Number) = only(A) * B
unsafe_trace_prod(A::Number, B::Diagonal) = unsafe_trace_prod(B, A)
unsafe_trace_prod(A::Union{SMatrix,UniformScaling,Number}, B::Union{SMatrix,UniformScaling,Number}) =
tr(A*B)

check_sizes(A,B) = size(A,2) == size(B,1) || throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))"))
check_sizes(A::AbstractMatrix,B::AbstractMatrix) = size(A,2) == size(B,1) ||
throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))"))
check_sizes(_, _) = nothing

# Taken from Base julia, now deprecated there
function permute!!(a, p::AbstractVector{<:Integer})
Expand Down
3 changes: 3 additions & 0 deletions test/test_greenfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,9 @@ end
c = sites(SA[1], 1)
view(gmat, c)
@test (@allocations view(gmat, c)) <= 2
i, j = orbaxes(gmat)
@test g(0.2)[i, j] isa Quantica.OrbitalSliceMatrix
@test gmat[c] isa Matrix
end

function testcond(g0; nambu = false)
Expand Down

0 comments on commit 5b24741

Please sign in to comment.