diff --git a/src/greenfunction.jl b/src/greenfunction.jl index c92bfdcd..a091812d 100644 --- a/src/greenfunction.jl +++ b/src/greenfunction.jl @@ -125,14 +125,6 @@ getindex!(output, g::GreenSolution, i::Union{Integer,Colon}, j::Union{Integer,Co getindex!(output, g, i::AnyCellOrbitals, j::AnyCellOrbitals; post = identity) = (output .= post.(g[i, j]); output) -function getindex!(output::OrbitalSliceMatrix, g, i::AnyCellOrbitals, j::AnyCellOrbitals; kw...) - oi, oj = orbaxes(output) - rows, cols = orbrange(oi, cell(i)), orbrange(oj, cell(j)) - v = view(parent(output), rows, cols) - getindex!(v, g, i, j; kw...) - return output -end - # indexing over several cells getindex!(output, g, ci::AnyOrbitalSlice, cj::AnyOrbitalSlice; kw...) = getindex_cells!(output, g, cellsdict(ci), cellsdict(cj); kw...) @@ -141,25 +133,21 @@ getindex!(output, g, ci::AnyOrbitalSlice, cj::AnyCellOrbitals; kw...) = getindex!(output, g, ci::AnyCellOrbitals, cj::AnyOrbitalSlice; kw...) = getindex_cells!(output, g, (ci,), cellsdict(cj); kw...) -# we use the fact that the output AbstractMatrix is populated in the same site order of cis, cjs function getindex_cells!(output, g, cis, cjs; kw...) - ioffset = joffset = 0 - for ci in cis - ilen = length(ci) - irng = ioffset+1:ioffset+ilen - joffset = 0 - for cj in cjs - jlen = length(cj) - jrng = joffset+1:joffset+jlen - v = view(output, irng, jrng) - getindex!(v, g, ci, cj; kw...) - joffset += jlen - end - ioffset += ilen + for ci in cis, cj in cjs + getindex!(output, g, ci, cj; kw...) # will typically call the method below end return output end +function getindex!(output::OrbitalSliceMatrix, g, i::AnyCellOrbitals, j::AnyCellOrbitals; kw...) + oi, oj = orbaxes(output) + rows, cols = orbrange(oi, cell(i)), orbrange(oj, cell(j)) + v = view(parent(output), rows, cols) + getindex!(v, g, i, j; kw...) + return output +end + ## common getindex! shortcut in terms of GreenSlice # index object g over a slice enconded in gs::GreenSlice, using its preallocated output diff --git a/src/specialmatrices.jl b/src/specialmatrices.jl index 38009064..166ac96e 100644 --- a/src/specialmatrices.jl +++ b/src/specialmatrices.jl @@ -550,7 +550,6 @@ function Base.view(a::OrbitalSliceMatrix, i::AnyCellSites, j::AnyCellSites = i) i´, j´ = apply(i, lattice(rowslice)), apply(j, lattice(colslice)) rows = indexcollection(rowslice, i´) cols = j === i && rowslice === colslice ? rows : indexcollection(colslice, j´) - @show rows, cols return view(parent(a), rows, cols) end diff --git a/src/tools.jl b/src/tools.jl index 4269b1a6..ddb16691 100644 --- a/src/tools.jl +++ b/src/tools.jl @@ -138,20 +138,24 @@ lengths_to_offsets(f::Function, v) = prepend!(accumulate((i,j) -> i + f(j), v; i # fast tr(A*B) -trace_prod(A::AbstractMatrix, B::AbstractMatrix) = (check_sizes(A,B); sum(splat(*), zip(transpose(A), B))) -trace_prod(A::Number, B::AbstractMatrix) = A*tr(B) -trace_prod(A::AbstractMatrix, B::Number) = trace_prod(B, A) -trace_prod(A::UniformScaling, B::AbstractMatrix) = A.λ*tr(B) -trace_prod(A::AbstractMatrix, B::UniformScaling) = trace_prod(B, A) -trace_prod(A::Diagonal, B::Diagonal) = sum(i -> A[i] * B[i], axes(A,1)) -trace_prod(A::Diagonal, B::AbstractMatrix) = sum(i -> A[i] * B[i, i], axes(A,1)) -trace_prod(A::AbstractMatrix, B::Diagonal) = trace_prod(B, A) -trace_prod(A::Diagonal, B::Number) = only(A) * B -trace_prod(A::Number, B::Diagonal) = trace_prod(B, A) -trace_prod(A::Union{SMatrix,UniformScaling,Number}, B::Union{SMatrix,UniformScaling,Number}) = +trace_prod(A::AbstractMatrix, B::AbstractMatrix) = (check_sizes(A,B); unsafe_trace_prod(A, B)) + +unsafe_trace_prod(A::AbstractMatrix, B::AbstractMatrix) = sum(splat(*), zip(transpose(A), B)) +unsafe_trace_prod(A::Number, B::AbstractMatrix) = A*tr(B) +unsafe_trace_prod(A::AbstractMatrix, B::Number) = unsafe_trace_prod(B, A) +unsafe_trace_prod(A::UniformScaling, B::AbstractMatrix) = A.λ*tr(B) +unsafe_trace_prod(A::AbstractMatrix, B::UniformScaling) = unsafe_trace_prod(B, A) +unsafe_trace_prod(A::Diagonal, B::Diagonal) = sum(i -> A[i] * B[i], axes(A,1)) +unsafe_trace_prod(A::Diagonal, B::AbstractMatrix) = sum(i -> A[i] * B[i, i], axes(A,1)) +unsafe_trace_prod(A::AbstractMatrix, B::Diagonal) = unsafe_trace_prod(B, A) +unsafe_trace_prod(A::Diagonal, B::Number) = only(A) * B +unsafe_trace_prod(A::Number, B::Diagonal) = unsafe_trace_prod(B, A) +unsafe_trace_prod(A::Union{SMatrix,UniformScaling,Number}, B::Union{SMatrix,UniformScaling,Number}) = tr(A*B) -check_sizes(A,B) = size(A,2) == size(B,1) || throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))")) +check_sizes(A::AbstractMatrix,B::AbstractMatrix) = size(A,2) == size(B,1) || + throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))")) +check_sizes(_, _) = nothing # Taken from Base julia, now deprecated there function permute!!(a, p::AbstractVector{<:Integer}) diff --git a/test/test_greenfunction.jl b/test/test_greenfunction.jl index 0ae9d21b..b124f70e 100644 --- a/test/test_greenfunction.jl +++ b/test/test_greenfunction.jl @@ -378,6 +378,9 @@ end c = sites(SA[1], 1) view(gmat, c) @test (@allocations view(gmat, c)) <= 2 + i, j = orbaxes(gmat) + @test g(0.2)[i, j] isa Quantica.OrbitalSliceMatrix + @test gmat[c] isa Matrix end function testcond(g0; nambu = false)