Skip to content

Commit

Permalink
updated named full in dev
Browse files Browse the repository at this point in the history
  • Loading branch information
gauenk committed Dec 2, 2023
1 parent 3b3778c commit 0234029
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 30 deletions.
155 changes: 135 additions & 20 deletions dev/named_full_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,37 +64,80 @@ def set_search_offsets(wsOff_h, wsOff_w, hi, wi,

return wsOff_h,wsOff_w

def vprint(*args,**kwargs):
verbose = True
if verbose:
print(*args,**kwargs)

def get_unique_index(nl_hi,nl_wi,hi,wi,
wsOff_h,wsOff_w,time_offset,
stride1,ws,wsHalf,full_ws):
stride0,stride1,ws,wsHalf,full_ws):

# -- check spatial coordinates --
num_h = (nl_hi - hi)//stride1
num_w = (nl_wi - wi)//stride1
num_h = (nl_hi - hi)#//stride1
num_w = (nl_wi - wi)#//stride1
# num_h = (nl_hi - hi)//stride0
# num_w = (nl_wi - wi)//stride0
vprint("num_h,num_w,wsHalf: ",num_h,num_w,wsHalf)

# -- check oob --
oob_i = abs(num_h) > wsHalf
oob_j = abs(num_w) > wsHalf

# -- oob names --
if oob_i and oob_j:
vprint("case0")
# -- check offset --
adj_h = wsHalf - wsOff_h
adj_w = wsHalf - wsOff_w
vprint("adj_h,adj_w: ",adj_h,adj_w)

# -- di,dj --
di = wsHalf - abs(adj_h)
dj = wsHalf - abs(adj_w)
# di = wsHalf - abs(adj_h)
# dj = wsHalf - abs(adj_w)
di = wsHalf - adj_h if adj_h > 0 else adj_h
dj = wsHalf - adj_w if adj_w > 0 else adj_w

# -- small square --
# mi = di + wsHalf*dj
# ws_i = mi % ws
# ws_j = mi // ws + (ws-1)
mi = di + wsHalf*dj
ws_i = mi % ws
ws_j = mi // ws + (ws-1)
vprint("[case0] ws_i,ws_j: ",ws_i,ws_j)

ws_i = (abs(adj_h)-1)//stride0
ws_j = (abs(adj_w)-1)//stride0

# wsNum = (ws-1)//stride0+1

# # -- check offset --
# adj_h = wsHalf - wsOff_h
# adj_w = wsHalf - wsOff_w
# adj_h = adj_h//stride0
# adj_w = adj_w//stride0

# # -- di,dj --
# di = wsHalf - abs(adj_h)
# dj = wsHalf - abs(adj_w)
# di = di//stride0
# dj = dj//stride0

# # -- small square --
# mi = di + wsHalf*dj
# ws_i = mi % wsNum
# ws_j = mi // wsNum + (wsNum-1)

elif oob_i and not(oob_j):
vprint("case1")
ws_j = abs(num_h) - (wsHalf+1)
ws_i = num_w+wsHalf
elif oob_j and not(oob_i):
ws_j = abs(num_w) - (wsHalf+1) + (wsHalf)
vprint("case2")
# ws_j = abs(num_w) - (wsHalf+1) + (wsHalf)
# ws_i = num_h+wsHalf
ws_j = abs(num_w) - (wsHalf+1)
ws_i = num_h+wsHalf

# -- debug --
Expand All @@ -104,20 +147,66 @@ def get_unique_index(nl_hi,nl_wi,hi,wi,
if not(oob_i or oob_j):
ws_i = num_h + wsHalf
ws_j = num_w + wsHalf
# print(num_h,num_w,wsHalf,ws_i,ws_j)

# -- check oob --
oob = (oob_i or oob_j) and full_ws
wsNum = (ws-1)//stride0+1
xor_oob = (oob_i or oob_j) and not(oob_i and oob_j) and full_ws
and_oob = (oob_i and oob_j) and full_ws

# -- divide out stride0 --
vprint("pre: ",ws_i,ws_j)
if not(and_oob):
ws_i = ws_i//stride0
ws_j = ws_j//stride0
vprint("post: ",ws_i,ws_j)

# -- check --
if not(oob_i and oob_j):
assert (ws_i >= 0) and (ws_i < ws)
assert (ws_j >= 0) and (ws_j < ws)

# -- get unique index --
li = (ws_i) + (ws_j)*ws + time_offset
li = li + ws*ws if oob else li
# wsNum = (ws-1)//stride0+1
# li = (ws_i) + (ws_j)*wsNum + time_offset
# # li = li + wsNum*wsNum if oob else li
# xor_oob = (oob_i or oob_j) and not(oob_i and oob_j)
# li = li + wsNum*wsNum if xor_oob else li
# print(li)
# # 2*(wsNum//2)*wsNum
# li = li + wsNum*wsNum+2*(wsNum//2)*wsNum if (oob_i and oob_j) else li
# print(li)

return li,oob
# -- get unique index --
if not(oob_i or oob_j):
li = (ws_i) + (ws_j)*wsNum + time_offset
# elif xor_oob:
# li = (ws_i) + (ws_j)*wsNum + time_offset + wsNum*wsNum
elif xor_oob and oob_i:
li = (ws_i) + (ws_j)*wsNum + time_offset + wsNum*wsNum
elif xor_oob and oob_j:
li = (ws_i) + (ws_j)*wsNum + (wsNum//2)*wsNum + time_offset + wsNum*wsNum
elif (and_oob):
# ws_i = abs(ws_i-1)//stride0-1
# ws_j = abs(ws_j-1)//stride0-1
# ws_i = ws_i
# ws_j = ws_j % stride0
vprint("[case0] ws_i,ws_j: ",ws_i,ws_j)
li = (ws_i) + (ws_j)*(wsNum//2)
li = li + time_offset + wsNum*wsNum + 2*(wsNum//2)*wsNum
# li = 15
else:
raise ValueError("What?")

# # li = li + wsNum*wsNum if oob else li
# li = li + wsNum*wsNum if xor_oob else li
# print(li)
# # 2*(wsNum//2)*wsNum
# li = li + wsNum*wsNum+2*(wsNum//2)*wsNum if (oob_i and oob_j) else li
# print(li)


return li,and_oob


def get_tlims(ti, T, wt):
Expand All @@ -142,7 +231,10 @@ def fill_names(ti,h_ref,w_ref,ki,ws,wt,stride0,stride1,st_offset,
nl_wi = wi + flows_k[ti][h_ref][w_ref][ki][2]
valid = check_valid(nl_ti,nl_hi,nl_wi,T,H,W)
if not(valid): return
# if not((wi == 0) or (hi == 0)): return
# if (wi < 8) or (hi < 8): return
# if (wi > 56) or (hi > 56): return
# if not((nl_hi == 16) and (nl_wi == 9)): return
# if not((nl_hi == 4) and (nl_wi == 16)): return
# if not((wi == 0) and (hi == 0)): return
# if not((nl_hi == 1) and (nl_wi == 0)): return
# if not((nl_hi == 0) and (nl_wi == 2)): return
Expand Down Expand Up @@ -176,17 +268,18 @@ def fill_names(ti,h_ref,w_ref,ki,ws,wt,stride0,stride1,st_offset,
time_offset = ws_ti*(ws*ws+2*(ws//2)*ws+(ws//2)**2)
li,oob = get_unique_index(nl_hi,nl_wi,hi,wi,
wsOff_h,wsOff_w,time_offset,
stride1,ws,wsHalf,full_ws)
stride0,stride1,ws,wsHalf,full_ws)
# nl_hi,nl_wi,hi,wi,stride1,ws,
# wsHalf,wsOff_h,wsOff_w,time_offset,full_ws)
# ws_i,ws_j,oob = check_oob(ws_i,ws_j,nl_hi,nl_wi,hi,wi,stride1,ws,
# wsHalf,wsOff_h,wsOff_w,full_ws)
# if not(oob): return

# print("Ref/NonLocal: ",(ti,hi,wi),(nl_ti,nl_hi,nl_wi),ws_ti,dt,li,oob,stride1,ws)
print("Ref/NonLocal: ",(ti,hi,wi),(nl_ti,nl_hi,nl_wi),li,ws_ti,dt,oob,stride1,ws)

# -- update --
# assert((ws_ti >= 0) and (ws_ti <= (W_t-1)))
wsNum = (ws-1)//stride0+1
if np.any(names[li,ti,hi,wi]<0):
names[li,ti,hi,wi,...] = 0
names[li,nl_ti,nl_hi,nl_wi,0] = ti
Expand All @@ -195,7 +288,8 @@ def fill_names(ti,h_ref,w_ref,ki,ws,wt,stride0,stride1,st_offset,
if counts[li,nl_ti,nl_hi,nl_wi] == 0:
print("already here.")
exit()
if li < ws*ws and oob:
if li < wsNum*wsNum and oob:
print("li is small to be out of bounds.")
exit()
counts[li,nl_ti,nl_hi,nl_wi] += 1

Expand All @@ -205,14 +299,28 @@ def set_seed(seed):
random.seed(seed)

def main():
ws = 9
# ws = 7
wt = 0
W_t = 2*wt+1
full_ws = True
T,H,W = 5,64,64
stride0,stride1 = 8,1
T,H,W = 1,64,64
stride0 = 8
ws = 2*stride0 + 1
ws = 17
assert( ws>=(2*stride0+1) )
stride1 = 1
W_t_num = T if wt > 0 else 1#min(W_t + 2*wt,T)
S = W_t_num*(ws*ws + 2*(ws//2)*ws + (ws//2)**2)
wsNum = (ws-1)//stride0+1
# wsNum = (ws//2)//stride0+1
print(wsNum)
#S = W_t_num*(ws*ws + 2*(ws//2)*ws + (ws//2)**2)
# S = W_t_num*(wsNum*wsNum + 2*(wsNum//2)*wsNum+(wsNum//2)**2)
S = W_t_num*(wsNum*wsNum + 2*(wsNum//2)*wsNum+(wsNum//2)**2)
print("wsNum,wsNum*wsNum,2*(wsNum//2)*wsNum: ",wsNum,wsNum*wsNum,2*(wsNum//2)*wsNum)
print("S: ",S)
# S = 17
# S = 16
print("ws,wt,S: ",ws,wt,S)
vals = np.zeros((T,H,W,ws,ws))
names = -np.ones((S,T,H,W,3))
counts = -np.ones((S,T,H,W))
Expand All @@ -228,11 +336,18 @@ def main():
fill_names(ti,h_ref,w_ref,ki,ws,wt,stride0,stride1,
st_offset,full_ws,names,counts,flows_k)

print(counts[:,0,2,2])
print(counts[:,0,:3,:3].T)
# print(counts[:,0,2,2])
# print(counts[:,0,:3,:3].T)
# print(counts[:,0,59,3])
print(counts[:,0,16,9])
# for i in range(S):
# print(counts[i,0])
print(np.sum(counts>=0,0).max())
print(np.sum(counts>=0),T*nH*nW*K)
print(np.sum(counts==0),T*nH*nW*K)
print("wsNum,wsNum*wsNum,2*(wsNum//2)*wsNum: ",wsNum,wsNum*wsNum,2*(wsNum//2)*wsNum)
print("S: ",S)


if __name__ == "__main__":
main()
30 changes: 22 additions & 8 deletions lib/stnls/dev/misc/viz_nls_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,22 @@ def get_pix(vid,loc):
for ix in range(2):
for jx in range(2):
i = int(loc[1]+ix)
wi = max(0.,1-abs(1.*i - loc[1]))
wi = 1 - abs( (ix - (abs(loc[1]) % 1) ) )
# wi = max(0.,1-abs(1.*i - loc[1]))

j = int(loc[2]+jx)
wj = max(0.,1-abs(1.*j - loc[2]))
wj = 1 - abs( (jx - (abs(loc[2]) % 1) ) )
# wj = max(0.,1-abs(1.*j - loc[2]))

w = wi * wj
i = bound(i,H)
j = bound(j,W)

# print(w)
pix += w*vid[int(loc[0]),:,i,j]
return pix

def delta_patch(vid0,vid1,loc0,loc1,ps):
def delta_patch(vid0,vid1,loc0,loc1,ps,dist_type="l2"):
delta = 0
poff = -ps//2
for pi in range(ps):
Expand All @@ -80,24 +85,32 @@ def delta_patch(vid0,vid1,loc0,loc1,ps):
loc1_ij = [loc1[0],loc1[1]+pi+poff,loc1[2]+pj+poff]
pix0 = get_pix(vid0,loc0_ij)
pix1 = get_pix(vid1,loc1_ij)
delta += th.sum((pix0-pix1)**2)
if dist_type == "l2":
delta += th.sum((pix0-pix1)**2)
elif dist_type == "prod":
delta += th.sum(pix0*pix1)
else:
raise ValueError(f"Uknown dist_type [{dist_type}]")
return delta


def search_deltas(vid0,vid1,fflow,bflow,loc0,grid,stride1,ws,ps,K=9):
def search_deltas(vid0,vid1,fflow,bflow,loc0,grid,stride1,ws,ps,K=9,dist_type="l2"):
dmap = th.zeros_like(grid[0])*1.
flow = get_pix(fflow,loc0).flip(0)
for wi in range(ws):
for wj in range(ws):
# -- get search location --
off_i = grid[:,wi,wj]
loc1 = [1,]+[loc0[i] + flow[(i-1)] + stride1*off_i[i] for i in range(1,3)]
loc1_i = loc0[1] + flow[0] + stride1*off_i[1]
loc1_j = loc0[2] + flow[1] + stride1*off_i[2]
# loc1 = [1,]+[loc0[i] + flow[(i-1)] + stride1*off_i[i] for i in range(1,3)]
loc1 = [1,loc1_i,loc1_j]
# print([wi,wj],off_i)
# loc1 = [0,]+[loc0[i] + stride1*off_i[i] for i in range(1,3)]
# print(wi,wj,[stride1*off_i[i] for i in range(1,3)])

# -- compute delta ---
dmap[wi,wj] = delta_patch(vid0,vid1,loc0,loc1,ps)
dmap[wi,wj] = delta_patch(vid0,vid1,loc0,loc1,ps,dist_type)
# print(loc1,dmap[wi,wj])
# if off_i[1] == 0 and off_i[2] == 0:
# dmap[wi,wj] = 10000.
Expand All @@ -108,7 +121,8 @@ def search_deltas(vid0,vid1,fflow,bflow,loc0,grid,stride1,ws,ps,K=9):
# dmap = th.log(dmap + eps)
dmap -= dmap.min()
dmap /= dmap.max()
dmap = th.exp(-10*dmap)
sign = -1 if dist_type == "l2" else 1
dmap = th.exp(sign*10*dmap)
# # print(dmap)
dmap -= dmap.min()
dmap /= dmap.max()
Expand Down
3 changes: 1 addition & 2 deletions lib/stnls/nn/accumulate_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,7 @@ def flow_warp(x, flow, interp_mode='bilinear',

# -- resample --
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode,
padding_mode="reflection", align_corners=align_corners,
)
padding_mode="reflection", align_corners=align_corners)

return output

Expand Down

0 comments on commit 0234029

Please sign in to comment.