Skip to content

Commit

Permalink
gather/scatter works
Browse files Browse the repository at this point in the history
  • Loading branch information
gauenk committed Dec 6, 2023
1 parent 6a7e32c commit 8d790cc
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 38 deletions.
46 changes: 28 additions & 18 deletions lib/stnls/agg/gather_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class NonLocalGatherAddFunction(th.autograd.Function):

@staticmethod
def forward(ctx, vid, weights, flows, ps, strideIn, strideOut,
pt=1, dilation=1, reflect_bounds=True, use_adj=False, itype="float"):
outH=0, outW=0, pt=1, dilation=1, reflect_bounds=True,
use_adj=False, itype="float"):
"""
vid = [BatchSize,nHeads or 1,T,C,H,W]
weights = [BatchSize,nHeads,NumQueries,K]
Expand All @@ -57,14 +58,25 @@ def forward(ctx, vid, weights, flows, ps, strideIn, strideOut,

# -- unpack --
device = weights.device
B,HD,T,nH,nW,K = weights.shape
wshape = weights.shape
vid = vid.contiguous()
flows = get_inds(flows,itype)
B,HD,T,nH,nW,K = weights.shape
B,HD,T,F,inH,inW = vid.shape
if outH == 0:
if strideOut == 1: outH = strideOut*nH
else: outH = inH
if outW == 0:
if strideOut == 1: outW = strideOut*nW
else: outW = inW
nH_in,nW_in = (inH-1)//strideIn+1,(inW-1)//strideIn+1
nH_out,nW_out = (outH-1)//strideOut+1,(outW-1)//strideOut+1
assert (nH == nH_in) and (nW == nW_in)
assert (nH == nH_out) and (nW == nW_out)

# -- allocate --
out_vid = th.zeros_like(vid)
counts = th.zeros_like(vid[0,0,0,0,:,:]).type(th.int)
out_vid = th.zeros((B,HD,T,F,outH,outW),device=device,dtype=th.float)
counts = th.zeros_like(out_vid[0,0,0,0,:,:]).type(th.int)
patch_offset = 0 if use_adj else -(ps//2)

# -- view --
Expand All @@ -82,10 +94,6 @@ def forward(ctx, vid, weights, flows, ps, strideIn, strideOut,

# -- normalize --
H,W = vid.shape[-2:]
# print(counts)
# print(counts.sum(-1))
# print(counts.sum(-2))
# exit()
counts = counts.view((1,1,1,1,H,W))
out_vid = out_vid / (counts+eps)
assert th.all(counts>1e-3)
Expand Down Expand Up @@ -178,16 +186,16 @@ def backward(ctx, grad_out_vid):
grad_flows = None

return grad_in_vid,grad_weights,grad_flows,None,None,None,\
None,None,None,None,None,None,None,None,None
None,None,None,None,None,None,None,None,None,None,None

class NonLocalGatherAdd(th.nn.Module):
# [video -> patches] @ flows

def __init__(self, ps, strideIn, strideOut, pt=1, dilation=1,
reflect_bounds=True, use_adj=False, itype="float"):
def __init__(self, ps, strideIn, strideOut, outH=0, outW=0, pt=1,
dilation=1, reflect_bounds=True, use_adj=False, itype="float"):
super().__init__()
_vars = ["ps","strideIn","strideOut", "pt","dilation",
"reflect_bounds","use_adj","itype"]
_vars = ["ps","strideIn","strideOut","outH","outW",
"pt","dilation","reflect_bounds","use_adj","itype"]
self._vars = _vars
for var in _vars:
setattr(self,var,eval(var))
Expand Down Expand Up @@ -220,11 +228,11 @@ def flops(self, nrefs, chnls_per_head, nheads, k):
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

def _apply(vid, weights, flows, ps, strideIn, strideOut,
pt=1, dilation=1,reflect_bounds=True, use_adj=False):
outH=0, outW=0, pt=1, dilation=1,reflect_bounds=True, use_adj=False):
# wrap "new (2018) apply function
# https://discuss.pytorch.org #13845/17
fxn = NonLocalGatherAddFunction.apply
return fxn(vid,weights,flows,ps, strideIn, strideOut,
return fxn(vid,weights,flows,ps,strideIn,strideOut,outH,outW,
pt,dilation,reflect_bounds,use_adj)

# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
Expand All @@ -234,14 +242,16 @@ def _apply(vid, weights, flows, ps, strideIn, strideOut,
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

def extract_config(cfg,restrict=True):
pairs = {"ps":3,"strideIn":1,"strideOut":1,"pt":1,"dilation":1,
"reflect_bounds":True, "use_adj":False, "itype":"float"}
pairs = {"ps":3,"strideIn":1,"strideOut":1,"outH":0,"outW":0,
"pt":1,"dilation":1,"reflect_bounds":True,
"use_adj":False, "itype":"float"}
return extract_pairs(cfg,pairs,restrict=restrict)

def init(cfg):
cfg = extract_config(cfg,False)
reducer = NonLocalGatherAdd(
cfg.ps, cfg.strideIn, cfg.strideOut, pt=cfg.pt, dilation=cfg.dilation,
cfg.ps, cfg.strideIn, cfg.strideOut,
outH=cfg.outH, outW=cfg.outW, pt=cfg.pt, dilation=cfg.dilation,
reflect_bounds=cfg.reflect_bounds,use_adj=cfg.use_adj,itype=cfg.itype)
return reducer

Expand Down
46 changes: 30 additions & 16 deletions lib/stnls/agg/scatter_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class NonLocalScatterAddFunction(th.autograd.Function):

@staticmethod
def forward(ctx, vid, weights, flows, ps, strideIn, strideOut,
pt=1, dilation=1, reflect_bounds=True, use_adj=False, itype="float"):
outH=0, outW=0, pt=1, dilation=1, reflect_bounds=True,
use_adj=False, itype="float"):
"""
vid = [BatchSize,nHeads or 1,T,C,H,W]
weights = [BatchSize,nHeads,NumQueries,K]
Expand All @@ -61,10 +62,21 @@ def forward(ctx, vid, weights, flows, ps, strideIn, strideOut,
wshape = weights.shape
vid = vid.contiguous()
flows = get_inds(flows,itype)
B,HD,T,nH,nW,K = weights.shape
B,HD,T,F,inH,inW = vid.shape
if outH == 0:
if strideOut == 1: outH = strideOut*nH
else: outH = inH
if outW == 0:
if strideOut == 1: outW = strideOut*nW
else: outW = inW
nH_in,nW_in = (inH-1)//strideIn+1,(inW-1)//strideIn+1
nH_out,nW_out = (outH-1)//strideOut+1,(outW-1)//strideOut+1
assert (nH == nH_in) and (nW == nW_in)
assert (nH == nH_out) and (nW == nW_out)

# -- allocate --
B,HD,T,F,H,W = vid.shape
out_vid = th.zeros((B,HD,T,F,H,W),device=device,dtype=th.float)
out_vid = th.zeros((B,HD,T,F,outH,outW),device=device,dtype=th.float)
counts = th.zeros_like(out_vid[0,0,0,0,:,:]).type(th.int)
patch_offset = 0 if use_adj else -(ps//2)

Expand All @@ -74,8 +86,8 @@ def forward(ctx, vid, weights, flows, ps, strideIn, strideOut,
flows = flows.view(B,HD,Q,K,3)

# -- exec --
fwd_fxn = stnls_cuda.scatter_add_forward
itype_int = flows.dtype == th.int
fwd_fxn = stnls_cuda.scatter_add_forward
fwd_fxn(out_vid, counts, vid, weights, flows,
ps, strideIn, strideOut, pt, dilation,
reflect_bounds, patch_offset, itype_int)
Expand Down Expand Up @@ -165,16 +177,16 @@ def backward(ctx, grad_out_vid):
grad_flows = None

return grad_in_vid,grad_weights,grad_flows,None,None,None,\
None,None,None,None,None,None,None,None,None,None
None,None,None,None,None,None,None,None,None,None,None,None

class NonLocalScatterAdd(th.nn.Module):
# [video -> patches] @ flows

def __init__(self, ps, strideIn, strideOut, pt=1, dilation=1,
reflect_bounds=True, use_adj=False, itype="float"):
def __init__(self, ps, strideIn, strideOut, outH=0, outW=0, pt=1,
dilation=1, reflect_bounds=True, use_adj=False, itype="float"):
super().__init__()
_vars = ["ps","strideIn","strideOut","pt","dilation",
"reflect_bounds","use_adj","itype"]
_vars = ["ps","strideIn","strideOut","outH","outW",
"pt","dilation","reflect_bounds","use_adj","itype"]
self._vars = _vars
for var in _vars:
setattr(self,var,eval(var))
Expand Down Expand Up @@ -206,12 +218,12 @@ def flops(self, nrefs, chnls_per_head, nheads, k):
#
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

def _apply(vid, weights, flows, ps, stride0,
pt=1, dilation=1,reflect_bounds=True, use_adj=False):
def _apply(vid, weights, flows, ps, strideIn, strideOut,
outH=0, outW=0, pt=1, dilation=1,reflect_bounds=True, use_adj=False):
# wrap "new (2018) apply function
# https://discuss.pytorch.org #13845/17
fxn = NonLocalScatterAddFunction.apply
return fxn(vid,weights,flows,ps,stride0,
return fxn(vid,weights,flows,ps,strideIn,strideOut,outH,outW,
pt,dilation,reflect_bounds,use_adj)

# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
Expand All @@ -221,14 +233,16 @@ def _apply(vid, weights, flows, ps, stride0,
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

def extract_config(cfg,restrict=True):
pairs = {"ps":3,"strideIn":1,"strideOut":1,"pt":1,"dilation":1,
"reflect_bounds":True, "use_adj":False, "itype":"float"}
pairs = {"ps":3,"strideIn":1,"strideOut":1,"outH":0,"outW":0,
"pt":1,"dilation":1,"reflect_bounds":True,
"use_adj":False, "itype":"float"}
return extract_pairs(cfg,pairs,restrict=restrict)

def init(cfg):
cfg = extract_config(cfg,False)
reducer = NonLocalScatterAdd(
cfg.ps, cfg.strideIn, cfg.strideOut, pt=cfg.pt, dilation=cfg.dilation,
reflect_bounds=cfg.reflect_bounds,use_adj=cfg.use_adj,itype=cfg.itype)
cfg.ps, cfg.strideIn, cfg.strideOut,
outH=cfg.outH, outW=cfg.outW, pt=cfg.pt, dilation=cfg.dilation,
reflect_bounds=cfg.reflect_bounds, use_adj=cfg.use_adj,itype=cfg.itype)
return reducer

12 changes: 8 additions & 4 deletions lib/stnls/dev/slic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def run_cas(vid,ws,wt,ps,stride0,full_ws,M=0.5,
return mask

def slic_v2_run(vid,ws,wt,ps,stride0,full_ws,M=0.5,
softmax_weight=10.,niters=-1,use_rand=False):
softmax_weight=10.,niters=1,use_rand=False):

# -- config --
use_flow = False
Expand All @@ -119,6 +119,7 @@ def slic_v2_run(vid,ws,wt,ps,stride0,full_ws,M=0.5,
# -- init slic state --
device = vid.device
HD = 1
H,W = vid.shape[-2:]
nH,nW = (H-1)//stride0+1,(W-1)//stride0+1
dists_k = th.ones((B,HD,T,nH,nW,1),device=device,dtype=th.float32)
flows_k = th.zeros((B,HD,T,nH,nW,1,3),device=device,dtype=th.int)
Expand All @@ -133,7 +134,8 @@ def slic_v2_run(vid,ws,wt,ps,stride0,full_ws,M=0.5,

# -- pooling (u_c) --
weights = th.softmax(-softmax_weight*dists_k,-1)
agg = stnls.agg.NonLocalGatherSum(agg_ps,stride0,1,itype="int")
agg = stnls.agg.NonLocalGatherAdd(agg_ps,stride0,1,
outH=nH,outW=nW,itype="int")
pooled = rearrange(agg(vid,weights,flows_k),'b hd t c h w -> b t (hd c) h w')
# print("Delta: ",th.mean((pooled-pooled0)**2).item())
print("[gather] pooled.shape: ",pooled.shape)
Expand Down Expand Up @@ -169,7 +171,7 @@ def slic_v2_run(vid,ws,wt,ps,stride0,full_ws,M=0.5,

# -- pooling (u_c) --
weights = th.softmax(-softmax_weight*dists_k,-1)
agg = stnls.agg.NonLocalScatterSum(agg_ps,1,stride0,itype="int")
agg = stnls.agg.NonLocalScatterAdd(agg_ps,1,stride0,outH=H,outW=W,itype="int")
# agg = stnls.agg.PooledPatchSum(agg_ps,stride0,itype="int")
pooled = rearrange(agg(pooled,weights,flows_k),'b hd t c h w -> b t (hd c) h w')
# print("Delta: ",th.mean((pooled-pooled0)**2).item())
Expand Down Expand Up @@ -255,9 +257,11 @@ def slic_run(vid,ws,wt,ps,stride0,full_ws,M=0.5,softmax_weight=10.,niters=1):

def get_slic_sampling_mask(vid,ws,wt,ps,stride0,full_ws,M=0.5,use_rand=False):
# pooled,dists_k,flows_k,_,_ = slic_run(vid,ws,wt,ps,stride0,full_ws,M)
print("hi.")
pooled,dists_k,flows_k,_,_ = slic_v2_run(vid,ws,wt,ps,stride0,full_ws,M)
vid = append_grid(vid,M,stride0)
mask = get_sampling_mask(vid, pooled, flows_k, ws, wt, ps, stride0, use_rand)
exit()
return mask

def run_scatter_k2q(s_dists,s_flows,s_labels,stride0,T,H,W):
Expand Down Expand Up @@ -573,7 +577,7 @@ def slic_clusters(vid,s_weights,s_flows_k,s_labels,ps,stride0,stride1,K0,
elif pool_method == "wpsum":
# ps = stride0*2
ps = ps + (1 - ps % 2) # ensure odd
agg = stnls.agg.NonLocalGatherSum(ps,stride0,stride0,itype="int")
agg = stnls.agg.NonLocalGatherAdd(ps,stride0,stride0,itype="int")
else:
raise ValueError(f"Uknown pool method [{pool_method}]")
# print(weights[0,0,0].sum(-1))
Expand Down

0 comments on commit 8d790cc

Please sign in to comment.