Skip to content

Commit

Permalink
Bug fix: Fixed the scratch-space problems used by greens()
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnAshburner committed Mar 14, 2024
1 parent c3f43c7 commit f9a7883
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 86 deletions.
8 changes: 4 additions & 4 deletions Artifacts.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
[[pp_lib]]
arch = "x86_64"
os = "linux"
git-tree-sha1 = "b7e7b7b8229d424c912116ba5718a9a8adbcd7b3"
git-tree-sha1 = "a70ba3aeb5136eb997a2f001acea1bd59fde6470"
lazy = true
[[pp_lib.download]]
sha256 = "e871f2093922359f6e0740171500bde700527e0c25535edab79c1843cd65a13b"
sha256 = "c970514e7ffc427d54d8ca7065defcaffe05c1d3de645b5607ba26d80df81e7e"
#url = "file:///home/john/work/julia/pp_lib.tar.gz"
url = "https://raw.githubusercontent.com/spm/PushPull.jl/main/artifacts/pp_lib.tar.gz"

[[pp_lib]]
arch = "x86_64"
os = "windows"
git-tree-sha1 = "b7e7b7b8229d424c912116ba5718a9a8adbcd7b3"
git-tree-sha1 = "a70ba3aeb5136eb997a2f001acea1bd59fde6470"
lazy = true
[[pp_lib.download]]
sha256 = "e871f2093922359f6e0740171500bde700527e0c25535edab79c1843cd65a13b"
sha256 = "c970514e7ffc427d54d8ca7065defcaffe05c1d3de645b5607ba26d80df81e7e"
url = "https://raw.githubusercontent.com/spm/PushPull.jl/main/artifacts/pp_lib.tar.gz"

24 changes: 17 additions & 7 deletions artifacts/C/sparse_operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ void vel2mom_midd(USIZE_t i_start, USIZE_t i_stop, USIZE_t j_start, USIZE_t j_st
float *u, const USIZE_t *d, const float* v,
const int *offset, const int *length, const float *values, const int *indices)
{
USIZE_t k;
SSIZE_t k;
for(k=k_start; k<k_stop; k++)
{
USIZE_t j, kd = d[1]*k;
SSIZE_t j, kd = d[1]*k;
for(j=j_start; j<j_stop; j++)
{
USIZE_t i, jkd = d[0]*(j + kd);
SSIZE_t i, jkd = d[0]*(j + kd);
for(i=i_start; i<i_stop; i++)
{
USIZE_t ijk = jkd + i;
Expand All @@ -21,38 +21,48 @@ void vel2mom_midd(USIZE_t i_start, USIZE_t i_stop, USIZE_t j_start, USIZE_t j_st
}
}

/*
vel2mom_padded1(USIZE_t i, USIZE_t j, USIZE_t k, float *u, const USIZE_t *d, const float *v,
const int *offset, const int *length, const float *values, const int *patch_indices,
const USIZE_t *dp, const int *bnd)
*/
void vel2mom_edge(USIZE_t i_start, USIZE_t i_stop, USIZE_t j_start, USIZE_t j_stop, USIZE_t k_start, USIZE_t k_stop,
float *u, const USIZE_t *d, const float* v,
const int *offset, const int *length, const float *values, const int *patch_indices,
const USIZE_t *dp, const int *bnd)
{
USIZE_t i, j, k;
SSIZE_t i, j, k;
for(k=k_start; k<k_stop; k++)
for(j=j_start; j<j_stop; j++)
for(i=i_start; i<i_stop; i++)
vel2mom_padded1(i,j,k, u, d, v, offset, length, values, patch_indices, dp, bnd);
}


#define MIN(a,b) ((signed)(a)<(signed)(b) ? (a) : (b))
#define MAX(a,b) ((signed)(a)>(signed)(b) ? (a) : (b))

void vel2mom(float *u, const float* v, const USIZE_t *d, const USIZE_t *dp,
const int *offset, const int *length,
const float *values, const int *indices, const int *patch_indices, const int *bnd)
{
USIZE_t rs[3], re[3], i;
SSIZE_t rs[3], re[3], i;
for(i=0; i<3; i++)
{
rs[i] = MIN((dp[i]-1)/2,d[i]);
re[i] = MAX(rs[i],d[i]-(dp[i]-1)/2);
rs[i] = MIN((dp[i]+1)/2,d[i]);
re[i] = MAX(rs[i],(signed)d[i]-(dp[i]+1)/2);
}
vel2mom_edge( 0 , d[0], 0 , d[1], 0 , d[2], u, d, v, offset, length, values, patch_indices, dp, bnd);
if (0)
{
vel2mom_edge( 0 , d[0], 0 , d[1], 0 , rs[2], u, d, v, offset, length, values, patch_indices, dp, bnd);
vel2mom_edge( 0 , d[0], 0 ,rs[1], rs[2], re[2], u, d, v, offset, length, values, patch_indices, dp, bnd);
vel2mom_edge( 0 ,rs[0], rs[1],re[1], rs[2], re[2], u, d, v, offset, length, values, patch_indices, dp, bnd);
vel2mom_midd(rs[0],re[0], rs[1],re[1], rs[2], re[2], u, d, v, offset, length, values, indices);
vel2mom_edge(re[0], d[0], rs[1],re[1], rs[2], re[2], u, d, v, offset, length, values, patch_indices, dp, bnd);
vel2mom_edge( 0 , d[0], re[1], d[1], rs[2], re[2], u, d, v, offset, length, values, patch_indices, dp, bnd);
vel2mom_edge( 0 , d[0], 0 , d[1], re[2], d[2], u, d, v, offset, length, values, patch_indices, dp, bnd);
}
}


2 changes: 1 addition & 1 deletion artifacts/C/sparse_operator_dev.cu
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ __device__ void relax1(float *v, const USIZE_t *d, const float *g, const float *


__device__ void vel2mom1(float *u, const USIZE_t *d, const float *v,
const int *offset, const int *length, const float *values, const int *indices)
const int *offset, const int *length, const float *values, const int *indices)
{
USIZE_t d3 = d[3], j;

Expand Down
Binary file modified artifacts/lib/liba64/sparse_operator.so
Binary file not shown.
Binary file modified artifacts/pp_lib.tar.gz
Binary file not shown.
117 changes: 80 additions & 37 deletions src/operator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,37 +118,72 @@ Take the regularisation operator `L`, and (if necessary) reduce its
dimensions so that it fits with a set of image dimensions `d`.
"""
function reduce2fit!(L::Array{<:Real,4}, d::NTuple{3,Integer})
function reduce2fit!(L::Union{Array{<:Real,4},CuArray{<:Real,4}}, d::NTuple{3,Integer},
bnd::Array{<:Integer} = Int32.([2 1 1; 1 2 1; 1 1 2]))
dp = size(L)
c = Int32.((dp[1:3].+1)./2)
r0 = max.(c.-d,0)
r0 = max.(c.-d,0)

r = [1:r0[1]; (dp[1]+1-r0[1]):dp[1]]
if ~isempty(r)
L[[c[1]],:,:,:] .+= sum(L[r,:,:,:],dims=1)
L[ r ,:,:,:] .= 0
end
for dim=1:3
if bnd[1,dim]==0 || d[1]==1
r = [1:r0[1]; (dp[1]+1-r0[1]):dp[1]]
else
r = []
end
if ~isempty(r)
if bnd[1,dim]==2
w = ones(dp[1])
w[[c[1]-1,c[1]+1]] .= -1
w = reshape(w[r],(length(r),1,1))
L[[c[1]],:,:,dim] .+= sum(L[r,:,:,dim].*w,dims=1)
else
L[[c[1]],:,:,dim] .+= sum(L[r,:,:,dim],dims=1)
end
L[r,:,:,dim] .= 0
end

r = [1:r0[2]; (dp[2]+1-r0[2]):dp[2]]
if ~isempty(r)
L[:,[c[2]],:,:] .+= sum(L[:,r,:,:],dims=2)
L[:, r ,:,:] .= 0
end
if bnd[2,dim]==0 || d[2]==1
r = [1:r0[2]; (dp[2]+1-r0[2]):dp[2]]
else
r = []
end
if ~isempty(r)
if bnd[2,dim]==2
w = ones(dp[2])
w[[c[2]-1,c[2]+1]] .= -1
w = reshape(w[r],(1,length(r),1))
L[:,[c[2]],:,dim] .+= sum(L[:,r,:,dim].*w,dims=2)
else
L[:,[c[2]],:,dim] .+= sum(L[:,r,:,dim],dims=2)
end
L[:,r,:,dim] .= 0
end

r = [1:r0[3]; (dp[3]+1-r0[3]):dp[3]]
if ~isempty(r)
L[:,:,[c[3]],:] .+= sum(L[:,:,r,:],dims=3)
L[:,:, r ,:] .= 0
end
if any(r0.>0)
r1 = (r0[1]+1):(dp[1]-r0[1])
r2 = (r0[2]+1):(dp[2]-r0[2])
r3 = (r0[3]+1):(dp[3]-r0[3])
L = L[r1,r2,r3,:]
if bnd[3,dim]==0 || d[3]==1
r = [1:r0[3]; (dp[3]+1-r0[3]):dp[3]]
else
r = []
end
if ~isempty(r)
if bnd[3,dim]==2
w = ones(dp[3])
w[[c[3]-1,c[3]+1]] .= -1
w = reshape(w[r],(1,1,length(r)))
L[:,:,[c[3]],dim] .+= sum(L[:,:,r,dim].*w,dims=3)
else
L[:,:,[c[3]],dim] .+= sum(L[:,:,r,dim],dims=3)
end
L[:,:, r,dim] .= 0
end
end
return L
msk = sum(L.!=0,dims=4)
i1 = sum(msk,dims=(2,3))[:] .!= 0
i2 = sum(msk,dims=(1,3))[:] .!= 0
i3 = sum(msk,dims=(1,2))[:] .!= 0
return L[i1[:],i2[:],i3[:],:]
end


"""
sparsify(L::Array{<:Real,4}, d::NTuple{3,Integer}, nd=3)
Expand Down Expand Up @@ -194,7 +229,6 @@ where:
"""
function sparsify(L::Array{<:Real,4}, d::NTuple{3,Integer}, nd=3)
#L = reduce2fit!(L,d)
dp = (size(L,1),size(L,2),size(L,3))
@assert(all(rem.(dp,2).==1),"First three dimensions of `L` must be odd")
@assert(size(L,4)==3 || size(L,4)==6, "Incorrectly sized `L`")
Expand Down Expand Up @@ -308,22 +342,26 @@ function greens(L::Union{CuArray{Float32,4},Array{Float32,4}}, d::NTuple{3,Integ

function scratchlen(dl,d)
# Could re-order the fft to reduce memory requirements
d1 = [dl...]
d2 = deepcopy(d1)
len = 0;
d0 = [dl...]
d1 = deepcopy(d0)
len1 = 0
for i=1:3
d2[i] = 2*d[i]
d1[i] = 2*d[i]
len1 = max(len1,prod(d1))
d1[i] = d[i]+1
len = max(len,prod(d2)+prod(d1))
d2[i] = d1[i]
end
return len
len0 = 0
for i=1:3
d0[i] = d[i]+1
len0 = max(len0,prod(d0))
end
return len0,len1
end

dl = size(L)
dl = size(L)
# TODO: Determine the actual maximum amount of memory required
#sl = 2*prod(d.+1)+prod(d.+1)
sl = scratchlen(size(L)[1:3],d)
sl = sum(scratchlen(size(L)[1:3],d))
if isa(L,CuArray)
K = Array{CuArray{Float32,3}}(undef, dl[4])
scratch = CUDA.zeros(ComplexF32,sl)
Expand All @@ -339,7 +377,7 @@ function greens(L::Union{CuArray{Float32,4},Array{Float32,4}}, d::NTuple{3,Integ
dl0 = [dl...]
dl0[dim] = length(r[dim])
L0 = reshape(view(scratch,1:prod(dl0)),dl0...)
o = prod(dl0)
o,unused = scratchlen(dl,d)
dl0[dim] = 2*d[dim]
L1 = reshape(view(scratch,(o+1):(o+prod(dl0))),dl0...)
L1 .= 0
Expand All @@ -353,7 +391,9 @@ function greens(L::Union{CuArray{Float32,4},Array{Float32,4}}, d::NTuple{3,Integ
T = T<:Complex ? T : Complex{T}
#L1 = isa(L,CuArray) ? CUDA.zeros(T, (dl1...)) : zeros(T, (dl1...))
L0,L1 = scratch_array(size(L),d,r,dim)
indices = mod.((1:size(L,dim)).-round(Int,(size(L,dim)+1)/2),2*d[dim]).+1
bc = 0
c = round(Int,(size(L,dim)+1)/2)
indices = mod.((1:size(L,dim)).-c, 2*d[dim]).+1
ind0 = [UnitRange.(1,size(L))...]
ind1 = deepcopy(ind0)
for i=1:length(indices)
Expand Down Expand Up @@ -381,9 +421,11 @@ function greens(L::Union{CuArray{Float32,4},Array{Float32,4}}, d::NTuple{3,Integ
K[i] = padft(view(L, :,:,:,i), d, (r...,))
K[i] .^= (-1)
r[i] = ri
if ~all(isfinite.(Array(view(K[i],1:1))))
CUDA.@allowscalar K[i][1:1] .= 0.0f0
end
end
else
# TODO: Work out why there are three GPU allocations more than expected
if isa(L,CuArray)
F = Array{CuArray{Float32,3}}(undef, dl[4])
else
Expand All @@ -395,6 +437,8 @@ function greens(L::Union{CuArray{Float32,4},Array{Float32,4}}, d::NTuple{3,Integ
end

# Re-use scratch
# Note that dF and t are real, but scratch is complex.
# Need to figure out a way of using the same memory as either real or complex.
d1 = d[1:3].+1
dF = reshape(view(scratch, 1:prod(d1)), d1...)
t = reshape(view(scratch, (length(dF)+1):(2*length(dF))), size(dF)...)
Expand Down Expand Up @@ -445,7 +489,6 @@ end

function kernel(d::NTuple{3,Integer}, vx::Vector{<:Real}=[1,1,1], λ::Vector{<:Real}=[0,1,0,0])
L = registration_operator(vx,λ)
L = reduce2fit!(L,d)
K = greens(L, d)
end

Expand Down
36 changes: 19 additions & 17 deletions src/operator_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,26 @@ end
function vel2mom!(u::Array{Float32,4},
v::Array{Float32,4},
kernel::NamedTuple{(:stride, :d, :nchan, :offset, :length, :values, :indices, :patch_indices),
Tuple{NTuple{3, Int64}, NTuple{3, Int64}, Int64, Matrix{Int32}, Matrix{Int32},
Vector{Float32}, Vector{Int32}, Vector{Int32}}},
Tuple{NTuple{3, Int64}, NTuple{3, Int64}, Int64, Matrix{Int32}, Matrix{Int32},
Vector{Float32}, Vector{Int32}, Vector{Int32}}},
bnd::Array{<:Integer} = Int32.([2 1 1; 1 2 1; 1 1 2]))

global oplib

@assert(all(size(u).==size(v)))
@assert(all(kernel.d .== size(v)[1:3]))
@assert(kernel.nchan == size(v,4))
d = Csize_t.([size(v)..., 1])
dp = Csize_t.([kernel.stride...])
d = Csize_t.([size(v)..., 1])
dp = Csize_t.([kernel.stride...])
bnd = Int32.(bnd[:])

ccall(dlsym(oplib,:vel2mom), Cvoid,
(Ref{Cfloat}, Ptr{Cfloat}, Ptr{Csize_t}, Ptr{Csize_t},
Ptr{Cint}, Ptr{Cint},
Ptr{Cfloat}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}),
pointer(u), pointer(v), pointer(d), pointer(dp),
pointer(kernel.offset), pointer(kernel.length),
pointer(kernel.values), pointer(kernel.indices), pointer(kernel.patch_indices), pointer(bnd[:]))
pointer(kernel.values), pointer(kernel.indices), pointer(kernel.patch_indices), pointer(bnd))
return u
end

Expand All @@ -95,8 +96,9 @@ function vel2mom!(u::CuArray{Float32,4},
bnd::Array{<:Integer} = Int32.([2 1 1; 1 2 1; 1 1 2]))

function run_kernel(fun, threads, r, u, v)
o = UInt64.(first.(r) .- 1)
n = UInt64.(max.((last.(r) .- first.(r)).+1,0))
# Computations using zero-offset
o = UInt64.(first.(r))
n = UInt64.(max.((last.(r) .- first.(r)),0))
n1 = prod(n)
if n1>0
setindex!(CuGlobal{NTuple{3,UInt64}}(opmod,"o"), o)
Expand Down Expand Up @@ -129,34 +131,34 @@ function vel2mom!(u::CuArray{Float32,4},

d = kernel.d
dp = kernel.stride
rs = Int.((dp.+1)./2) # Start of middle block
re = Int.(d.-(dp.-1)./2) # End of middle block
rs = min.(Int.((dp.+1)./2),d) # Start of middle block
re = max.(rs,Int.(d.-(dp.+1)./2)) # End of middle block

if any(re.<rs)
if any(re.<=rs)
# No middle block
r = UnitRange.(1,d)
r = UnitRange.(0,d)
run_kernel(cuVel2momPad, threads_pad, r, u, v)
else
r = (UnitRange(1,d[1]), UnitRange(1,d[2]), UnitRange(1,rs[3]-1))
r = (UnitRange(0,d[1]), UnitRange(0,d[2]), UnitRange(0,rs[3]))
run_kernel(cuVel2momPad, threads_pad, r, u, v)

r = (UnitRange(1,d[1]), UnitRange(1,rs[2]-1), UnitRange(rs[3],re[3]))
r = (UnitRange(0,d[1]), UnitRange(0,rs[2]), UnitRange(rs[3],re[3]))
run_kernel(cuVel2momPad, threads_pad, r, u, v)

r = (UnitRange(1,rs[1]-1), UnitRange(rs[2],re[2]), UnitRange(rs[3],re[3]))
r = (UnitRange(0,rs[1]), UnitRange(rs[2],re[2]), UnitRange(rs[3],re[3]))
run_kernel(cuVel2momPad, threads_pad, r, u, v)

# Middle block
r = UnitRange.(rs,re)
run_kernel(cuVel2mom, threads_nopad, r, u, v)

r = (UnitRange(re[1]+1,d[1]), UnitRange(rs[2],re[2]), UnitRange(rs[3],re[3]))
r = (UnitRange(re[1],d[1]), UnitRange(rs[2],re[2]), UnitRange(rs[3],re[3]))
run_kernel(cuVel2momPad, threads_pad, r, u, v)

r = (UnitRange(1,d[1]), UnitRange(re[2]+1,d[2]), UnitRange(rs[3],re[3]))
r = (UnitRange(0,d[1]), UnitRange(re[2],d[2]), UnitRange(rs[3],re[3]))
run_kernel(cuVel2momPad, threads_pad, r, u, v)

r = (UnitRange(1,d[1]), UnitRange(1,d[2]), UnitRange(re[3]+1,d[3]))
r = (UnitRange(0,d[1]), UnitRange(0,d[2]), UnitRange(re[3],d[3]))
run_kernel(cuVel2momPad, threads_pad, r, u, v)
end
return v
Expand Down
Loading

0 comments on commit f9a7883

Please sign in to comment.