Skip to content

Commit

Permalink
much faster slic; smaller S
Browse files Browse the repository at this point in the history
  • Loading branch information
gauenk committed Dec 2, 2023
1 parent 14f2784 commit d14ac08
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 23 deletions.
1 change: 1 addition & 0 deletions dev/named_full_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def get_unique_index(nl_hi,nl_wi,hi,wi,

# -- divide out stride0 --
vprint("pre: ",ws_i,ws_j)
# if not(oob_i or oob_j):
if not(and_oob):
ws_i = ws_i//stride0
ws_j = ws_j//stride0
Expand Down
65 changes: 48 additions & 17 deletions lib/csrc/agg/scatter_labels_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,83 @@ __device__ __forceinline__
void get_unique_index(int& li, bool& oob,
int nl_hi,int nl_wi,int hi,int wi,
int wsOff_h,int wsOff_w,int time_offset,
int stride1,int ws,int wsHalf,bool full_ws){
int stride0, int stride1,int ws,int wsHalf,bool full_ws){

// -- init --
int ws_i = -1;
int ws_j = -1;

// -- check spatial coordinates --
int num_h = abs(nl_hi - hi)/stride1;
num_h = (nl_hi >= hi) ? num_h : -num_h;
int num_w = abs(nl_wi - wi)/stride1;
num_w = (nl_wi >= wi) ? num_w : -num_w;
int num_h = nl_hi - hi;//stride1;
// num_h = (nl_hi >= hi) ? num_h : -num_h;
int num_w = nl_wi - wi;//stride1;
// num_w = (nl_wi >= wi) ? num_w : -num_w;

// -- check oob --
int wsNum = (ws-1)/stride0+1;
bool oob_i = abs(num_h) > wsHalf;
bool oob_j = abs(num_w) > wsHalf;
oob = (oob_i or oob_j) and full_ws;
bool and_oob = oob_i and oob_j and full_ws;
bool xor_oob = (oob_i or oob_j) and not(oob_i and oob_j) and full_ws;

// -- oob names --
if (oob_i and oob_j){

// -- di,dj --
int di = wsHalf - abs(wsHalf - wsOff_h);
int dj = wsHalf - abs(wsHalf - wsOff_w);
// // -- di,dj --
int adj_h = wsHalf - wsOff_h;
int adj_w = wsHalf - wsOff_w;

// -- small square --
int mi = di + wsHalf*dj;
ws_i = mi % ws;
ws_j = mi / ws + (ws-1);
// int di = wsHalf - abs(wsHalf - wsOff_h);
// int dj = wsHalf - abs(wsHalf - wsOff_w);

// // -- small square --
// int mi = di + wsHalf*dj;
// ws_i = mi % ws;
// ws_j = mi / ws + (ws-1);

// -- only adj --
ws_i = (abs(adj_h)-1)/stride0;
ws_j = (abs(adj_w)-1)/stride0;

}else if (oob_i and not(oob_j)){
ws_j = abs(num_h) - (wsHalf+1);
ws_i = num_w+wsHalf;
}else if (oob_j and not(oob_i)){
ws_j = abs(num_w) - (wsHalf+1) + (wsHalf);
// ws_j = abs(num_w) - (wsHalf+1) + (wsHalf);
ws_j = abs(num_w) - (wsHalf+1);
ws_i = num_h+wsHalf;
}

// -- standard names --
if (not(oob_i or oob_j)){
if (not(oob)){
ws_i = num_h + wsHalf;
ws_j = num_w + wsHalf;
}

// -- standard names --
if (not(and_oob)){
ws_i = ws_i/stride0;
ws_j = ws_j/stride0;
}

// -- get unique index --
li = (ws_i) + (ws_j)*ws + time_offset;
li = oob ? li + ws*ws : li;
if (not(oob_i or oob_j)){
li = (ws_i) + (ws_j)*wsNum + time_offset;
}else if (xor_oob and oob_i){
li = (ws_i) + (ws_j)*wsNum + time_offset + wsNum*wsNum;
}else if (xor_oob and oob_j){
li = (ws_i) + (ws_j)*wsNum + (wsNum/2)*wsNum + time_offset + wsNum*wsNum;
}else if (and_oob){
li = (ws_i) + (ws_j)*(wsNum/2);
li = li + time_offset + wsNum*wsNum + 2*(wsNum/2)*wsNum;
}else{
assert(1==0);
}

// // -- get unique index --
// li = (ws_i) + (ws_j)*ws + time_offset;
// li = oob ? li + ws*ws : li;

}

Expand Down Expand Up @@ -169,7 +200,7 @@ __global__ void scatter_labels_kernel(
get_unique_index(li, oob, nl_patch[1], nl_patch[2],
ref_patch[1], ref_patch[2],
wsOff_h, wsOff_w, time_offset,
stride1, ws, wsHalf, full_ws);
stride0, stride1, ws, wsHalf, full_ws);

// -- assign to sparse matrix --
// if (not(oob)){ return; }
Expand Down
15 changes: 9 additions & 6 deletions lib/stnls/agg/scatter_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ def run(flows,flows_k,ws,wt,stride0,stride1,H,W,full_ws):
W_t = 2*wt+1
# H = nH*stride0
# W = nW*stride0
wsHalf = (ws-1)//2
# wsHalf = (ws-1)//2

# -- number of maximum possible groups a single patch can belong to --
Wt_num = T if wt > 0 else 1
Ws_num = ws*ws
if full_ws: Ws_num += 2*ws*wsHalf + wsHalf**2
# Ws_num = ws*ws
wsNum = (ws-1)//stride0+1
Ws_num = wsNum*wsNum
if full_ws: Ws_num += 2*wsNum*(wsNum//2) + (wsNum//2)**2
S = Wt_num*Ws_num
print(S,ws,wt,stride0,stride1,full_ws)

Expand All @@ -43,9 +45,10 @@ def run(flows,flows_k,ws,wt,stride0,stride1,H,W,full_ws):
stnls_cuda.scatter_labels(flows,flows_k,labels,names,ws,wt,stride0,stride1,full_ws)

# -- check --
# nvalid = (names[...,0] >= 0).float().sum(2)
# if full_ws:
# assert(int(nvalid.sum().item()) == Q*K)
nvalid = (names[...,0] >= 0).float().sum(2)
if full_ws:
print(int(nvalid.sum().item()),Q*K)
# assert(int(nvalid.sum().item()) == Q*K)

return names,labels

Expand Down

0 comments on commit d14ac08

Please sign in to comment.