Skip to content

Commit

Permalink
forgot to add slic
Browse files Browse the repository at this point in the history
  • Loading branch information
gauenk committed Dec 3, 2023
1 parent d3b0553 commit 99a800b
Showing 1 changed file with 370 additions and 0 deletions.
370 changes: 370 additions & 0 deletions scripts/slic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,370 @@
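"""
Slic is easy with our packages.

Demo script: benchmarks average, max, SLIC-style, and non-local ("nls")
pooling on a short video clip, then saves the SLIC segmentation overlay
and the pooled videos.
"""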

# -- basic --
import torch as th
import numpy as np
from einops import rearrange,repeat
from easydict import EasyDict as edict
from dev_basics.utils import vid_io

# -- exps --
from dev_basics.utils.misc import set_seed

# -- optical flow --
from dev_basics import flow

# -- data --
import data_hub

# -- non-local opts --
import stnls

# -- benchmarking --
from dev_basics.utils.timer import ExpTimer,TimeIt
from dev_basics.utils.gpu_mem import GpuMemer,MemIt

# -- view segmentation --
from torchvision.utils import draw_segmentation_masks
from skimage.segmentation import mark_boundaries


def load_video(cfg):
    """
    Load one clean clip from the dataset selected by `cfg`.

    Reads `cfg.dset`, `cfg.vid_name` and `cfg.nframes`; the whole `cfg` is
    also forwarded to `data_hub.sets.load`.

    Returns the 'clean' entry of the first matching subsequence with a
    leading axis prepended, moved to "cuda:0" and divided by 255.
    Presumably shaped (1,T,C,H,W) with values in [0,1] -- TODO confirm
    against the dataset.
    """
    # -- the loaders returned alongside the datasets are not used here --
    dsets,_ = data_hub.sets.load(cfg)
    subset = dsets[cfg.dset]

    # -- locate the requested video's subsequence --
    frame_indices = data_hub.filter_subseq(subset,cfg.vid_name,0,cfg.nframes)
    clean = subset[frame_indices[0]]['clean']

    # -- prepend an axis, move to the gpu, rescale --
    return clean[None,:].to("cuda:0")/255.

def append_grid(vid,M,S):
    """
    Append scaled pixel coordinates to a video as two extra channels.

    Args:
        vid: 5-d tensor unpacked as (B,T,F,H,W).
        M: numerator of the coordinate scale.
        S: denominator of the coordinate scale.

    Returns:
        Tensor of shape (B,T,F+2,H,W). Channel F holds (M/S)*x, where x is
        the column index; channel F+1 holds (M/S)*y, where y is the row
        index.

    NOTE(review): the grid is cast to float32 (as in the original code),
    so for a half-precision `vid` the concatenation relies on th.cat's
    type promotion -- confirm the intended output dtype.
    """
    B,T,F,H,W = vid.shape
    dtype,device = vid.dtype,vid.device

    # -- explicit "ij" indexing: same layout as the legacy default, --
    # -- but without the deprecation warning --
    rows = th.arange(0, H, dtype=dtype, device=device)
    cols = th.arange(0, W, dtype=dtype, device=device)
    grid_y, grid_x = th.meshgrid(rows, cols, indexing="ij")

    # -- (2,H,W): x first, then y --
    grid = th.stack((grid_x, grid_y), 0).float()

    # -- broadcast over batch and time without copying --
    grid = grid[None,None].expand(B,T,-1,-1,-1)
    return th.cat([vid,M/S*grid],2)

def run_exp(cfg):
    """
    Run every pooling variant on one clip and report time/memory.

    Steps: seed, load the clip in half precision, append the scaled
    coordinate grid, compute flows, then time and memory-profile each of
    "ave", "max", "slic" and "nls" via `run_pooling`.

    Returns:
        (vid, pooled, seg). `pooled` and `seg` are dicts keyed by pooling
        type; `seg["slic"]` is replaced by the output of `inds2labels`.
    """
    # -- set seed --
    set_seed(cfg.seed)

    # -- read video and attach the coordinate channels --
    vid = append_grid(load_video(cfg).half(),cfg.M,cfg.stride0)
    _,_,_,H,W = vid.shape
    print("vid.shape: ",vid.shape)

    # -- compute flows --
    raw_flows = flow.orun(vid,cfg.flow,ftype="cv2")
    flows = stnls.nn.search_flow(raw_flows.fflow,raw_flows.bflow,
                                 cfg.wt,cfg.stride0)
    flows = flows[:,None].round().int()
    print(flows.shape)

    # -- benchmark pooling --
    timer,memer = ExpTimer(),GpuMemer()
    pooled,seg = {},{}
    for pooling_type in ("ave","max","slic","nls"):
        with TimeIt(timer,pooling_type), MemIt(memer,pooling_type):
            pooled_p,seg_p = run_pooling(cfg,vid,flows,pooling_type)
        pooled[pooling_type] = pooled_p
        seg[pooling_type] = seg_p

    # -- view info --
    print(timer)
    print(memer)

    # -- slic flows to labels for plotting --
    seg["slic"] = inds2labels(seg["slic"],cfg,H,W)

    return vid,pooled,seg

def run_pooling(cfg,vid,flows,pooling_type):
    """
    Pool a video with one of the supported strategies.

    Args:
        cfg: config; `cfg.stride0` is the pooling kernel size. The "slic"
            and "nls" branches forward the whole `cfg` to `run_slic` /
            `run_nls`.
        vid: tensor indexed as (b,t,c,h,w).
        flows: forwarded to `run_slic` / `run_nls`; unused for "ave"/"max".
        pooling_type: one of "ave", "max", "slic", "nls".

    Returns:
        (pooled, seg). `seg` is None for "ave" and "max"; otherwise it is
        whatever `run_slic` / `run_nls` return as their second value.

    Raises:
        ValueError: if `pooling_type` is not recognized.
    """
    # -- kernel is stride0 wide and slides one pixel at a time --
    ksize,stride = cfg.stride0,1
    B = vid.shape[0]
    avg_pool = th.nn.functional.avg_pool2d
    max_pool = th.nn.functional.max_pool2d

    def run_standard(pool_fxn,frames):
        # -- fold time into batch for the 2d pooling op, then unfold --
        frames = frames.reshape(-1,*frames.shape[2:])
        out = pool_fxn(frames, ksize, stride=stride)
        return out.reshape(B,-1,*out.shape[1:])

    if pooling_type == "ave":
        return run_standard(avg_pool,vid),None
    if pooling_type == "max":
        return run_standard(max_pool,vid),None
    if pooling_type == "slic":
        pooled,seg = run_slic(vid,flows,cfg)
        return run_standard(avg_pool,pooled),seg
    if pooling_type == "nls":
        pooled,seg = run_nls(vid,flows,cfg)
        return run_standard(avg_pool,pooled),seg
    raise ValueError("Unknown pooling type: %r" % (pooling_type,))

def run_nls(vid,flows,cfg):
    """
    Non-local-search pooling baseline.

    Searches `vid` against itself with `stnls.search.NonLocalSearch`
    (keeping `cfg.nls_k` matches, full search window forced on), turns the
    negated distances into softmax weights, and aggregates patches with
    `stnls.agg.WeightedPatchSum`.

    Returns:
        (vout, None), where `vout` keeps only the first three channels of
        the aggregated video, laid out as (b,t,c,h,w).
    """
    # -- unpacking also requires the video to be 5-d --
    B,T,F,H,W = vid.shape

    # -- search --
    searcher = stnls.search.NonLocalSearch(cfg.ws,cfg.wt,cfg.ps,cfg.nls_k,
                                           nheads=1,dist_type="l2",
                                           stride0=cfg.stride0,
                                           self_action="anchor_self",
                                           full_ws=True,itype="int")
    dists,flows_k = searcher(vid,vid,flows)

    # -- smaller distance => larger weight --
    weights = th.softmax(-dists,-1)

    # -- aggregate with a patch 1.75x the stride --
    patch_size = int(cfg.stride0*1.75)
    aggregator = stnls.agg.WeightedPatchSum(patch_size,cfg.stride0,itype="int")
    vout = aggregator(vid,weights,flows_k)
    vout = rearrange(vout,'b hd t c h w -> b t (hd c) h w')

    return vout[:,:,:3],None

def run_slic(vid,flows,cfg):
    """
    SLIC-style pooling built from stnls search and scatter primitives.

    Pipeline, in call order: non-local search of `vid` against itself,
    softmax of the negated distances, `scatter_labels`, three
    `scatter_tensor` calls (weights, flows, labels), negation of the
    scattered flows, `scatter_topk` with K0=1, then `slic_pooling`.

    Returns:
        (pooled, s_flows_k): the first three channels of the pooled video
        and the top-1 flows from `scatter_topk`.
    """
    B,T,F,H,W = vid.shape
    stride0,stride1 = cfg.stride0,cfg.stride1

    # -- compute search window --
    searcher = stnls.search.NonLocalSearch(cfg.ws,cfg.wt,cfg.ps,cfg.k,
                                           nheads=1,dist_type="l2",
                                           stride0=stride0,
                                           self_action="anchor_self",
                                           full_ws=cfg.full_ws,itype="int")
    dists,flows_k = searcher(vid,vid,flows)

    # NOTE(review): the result is never read; the call is kept because
    # flow2inds is opaque from here -- confirm it has no side effects.
    inds = stnls.utils.misc.flow2inds(flows_k,stride0)

    # -- scattering top-K=1 --
    K0 = 1
    gather_weights = th.softmax(-dists,-1)
    names,labels = stnls.agg.scatter_labels(flows,flows_k,cfg.ws,cfg.wt,
                                            stride0,stride1,H,W,cfg.full_ws)
    print("[scattering]: ",gather_weights.shape,flows_k.shape,labels.shape)
    gather_labels = labels.reshape_as(gather_weights)

    def scatter(tensor):
        # -- all three scatters share the same flows, labels and geometry --
        return stnls.agg.scatter_tensor(tensor,flows_k,labels,
                                        stride0,stride1,H,W)

    scatter_weights = scatter(gather_weights)
    scatter_flows_k = scatter(flows_k)
    scatter_labels = scatter(gather_labels)
    print("[a]: ",scatter_flows_k.shape,flows_k.shape,scatter_labels.shape)

    # -- topk on the negated flows --
    topk = stnls.agg.scatter_topk(scatter_weights,-scatter_flows_k,
                                  scatter_labels,K0)
    s_weight,s_flows_k,s_labels = topk

    # -- pool --
    pooled = slic_pooling(vid,s_weight,s_flows_k,s_labels,
                          cfg.ps,stride0,stride1,K0)
    print(pooled.shape)

    return pooled[:,:,:3],s_flows_k


def slic_pooling(vid,s_weights,s_flows_k,s_labels,ps,stride0,stride1,K0):
    """
    Scatter the top-K0 weights/flows and aggregate patches.

    Args:
        vid: 5-d tensor unpacked as (B,T,F,H,W).
        s_weights, s_flows_k, s_labels: outputs of `stnls.agg.scatter_topk`;
            reshaped here to (B,HD,T,H,W,K0), (B,HD,T,H,W,K0,3) and
            (B,HD,T*H*W,-1) respectively.
        ps: accepted but ignored; the patch size is recomputed as
            int(stride0*1.75).
        stride0, stride1: passed to `scatter_tensor` in the order
            (stride1, stride0).
        K0: number of entries kept per location.

    Returns:
        Aggregated video laid out as (b,t,(hd c),h,w).
    """
    # -- prepare weights and flows --
    B,T,F,H,W = vid.shape
    HD = s_weights.shape[1]
    topk_weights = s_weights.reshape(B,HD,T,H,W,K0)
    topk_flows = s_flows_k.reshape(B,HD,T,H,W,K0,3)
    topk_labels = s_labels.reshape(B,HD,T*H*W,-1)

    # -- run scatters --
    print("pooling: ",topk_weights.shape,topk_flows.shape,topk_labels.shape)
    weights = stnls.agg.scatter_tensor(topk_weights,topk_flows,topk_labels,
                                       stride1,stride0,H,W)
    flows_k = stnls.agg.scatter_tensor(topk_flows,topk_flows,topk_labels,
                                       stride1,stride0,H,W)
    print(weights.shape,flows_k.shape)

    # -- reshape onto the stride0 grid --
    K = weights.shape[-1]
    nH,nW = (H-1)//stride0+1,(W-1)//stride0+1
    weights = weights.reshape(B,HD,T,nH,nW,K)
    flows_k = flows_k.reshape(B,HD,T,nH,nW,K,3)

    # -- renormalize weights --
    weights = th.softmax(weights,-1)

    # -- aggregate --
    patch_size = int(stride0*1.75)
    aggregator = stnls.agg.WeightedPatchSum(patch_size,stride0,itype="int")
    vout = aggregator(vid,weights,flows_k)
    return rearrange(vout,'b hd t c h w -> b t (hd c) h w')

def inds2labels(s_flows_k,cfg,H,W):
    """
    Convert top-1 flows into integer segmentation labels for plotting.

    Each pixel's label is the flat index of the stride0-grid cell that its
    index lands in: t*nH0*nW0 + (y//stride0)*nW0 + (x//stride0).

    Args:
        s_flows_k: flows laid out as 'b hd (t nh nw) k tr', with nh,nw the
            stride1-grid size; only head 0 and k=0 are used.
        cfg: config providing `stride0` and `stride1`.
        H, W: frame height and width.

    Returns:
        Label tensor indexed as (b,t,nh,nw). Entries outside
        (-100000,100000) are treated as invalid and set to max_valid+1.
    """
    stride0,stride1 = cfg.stride0,cfg.stride1

    # -- grid sizes (ceil division) --
    # The original overwrote nH0,nW0 with floor division (H//stride0),
    # which makes the last column/row alias the labels of the next
    # row/frame whenever H or W is not a multiple of stride0.
    nH0,nW0 = (H-1)//stride0+1,(W-1)//stride0+1
    nH,nW = (H-1)//stride1+1,(W-1)//stride1+1

    # -- flows to indices --
    shape_str = 'b hd (t nh nw) k tr -> b hd t nh nw k tr'
    s_flows_k = rearrange(s_flows_k,shape_str,nh=nH,nw=nW)
    s_inds = stnls.utils.misc.flow2inds(s_flows_k,stride1)
    s_inds = s_inds[:,0,...,0,:].contiguous() # 1 head, 1 k
    # -- return value discarded: called for its effect on s_inds --
    stnls.utils.misc.reflect_inds(s_inds,H,W)

    # -- labels --
    seg_labels = s_inds[...,0]*nH0*nW0
    seg_labels += th.div(s_inds[...,1],stride0,rounding_mode="floor")*nW0
    seg_labels += th.div(s_inds[...,2],stride0,rounding_mode="floor")

    # -- fill invalid --
    valid = th.logical_and(seg_labels<100000,seg_labels>-100000)
    S = seg_labels[valid].max()
    seg_labels[~valid] = S+1

    # -- view --
    print(seg_labels.shape)
    print(seg_labels[0,0,-5:,-5:])

    return seg_labels

def labels2masks(labels):
    """
    Expand an integer label map into one boolean mask per label.

    Args:
        labels: tensor of non-negative labels, any shape.

    Returns:
        Bool tensor of shape (S,*labels.shape) on `labels.device`, where
        S = labels.max()+1 and masks[s] == (labels == s). An empty
        `labels` yields an empty (0,*labels.shape) mask.
    """
    # -- max() is undefined on an empty tensor --
    if labels.numel() == 0:
        return th.zeros((0,*labels.shape),dtype=th.bool,device=labels.device)

    # -- one broadcasted comparison instead of a python loop per label --
    S = int(labels.max())+1
    ids = th.arange(S,device=labels.device).view(-1,*([1]*labels.dim()))
    return labels[None] == ids

def main():
    """
    Run the pooling demo with a fixed config and save the results.

    Saves the SLIC boundary overlay through `vid_io.save_video` with
    arguments ("output/slic","ex"), and each pooled video with
    ("output/slic_pooled/",<pooling type>).
    """

    # -- config --
    cfg = edict()
    cfg.seed = 123
    cfg.dname = "set8"
    cfg.dset = "val"
    cfg.isize = "540_540" # other sizes tried: 128_128, 256_256, 300_300, 400_400, None
    cfg.vid_name = "sunflower"
    cfg.ntype = "g"
    cfg.sigma = 0.1
    cfg.nframes = 5
    cfg.flow = False
    cfg.full_ws = True
    cfg.wt = 0
    cfg.stride0 = 8
    cfg.ws = 2*cfg.stride0-2
    cfg.stride1 = 1
    cfg.k = -1
    cfg.nls_k = 8
    cfg.ps = 1
    cfg.M = 0.1

    # -- run slic --
    vid,pooled,segs = run_exp(cfg)
    vid = vid[:,:,:3] # drop the appended coordinate channels
    labels = segs['slic']

    # -- draw superpixel boundaries on every frame --
    vid = (255*vid).type(th.uint8)
    B,T,F,H,W = vid.shape
    seg = []
    for bi in range(B):
        for ti in range(T):
            vid_bt = rearrange(vid[bi,ti].cpu().numpy(),'tr h w -> h w tr')
            labels_bt = labels[bi,ti].cpu().numpy()
            seg_bt = mark_boundaries(vid_bt,labels_bt)
            seg_bt = rearrange(seg_bt,'h w tr -> tr h w')
            seg.append(th.tensor(seg_bt))
    seg = th.stack(seg).view(B,T,F,H,W)
    print(seg.shape)
    vid_io.save_video(seg,"output/slic","ex")

    # -- save pooled videos --
    for ptype,pooled_p in pooled.items():
        # -- .type() must be called: the bare attribute prints the bound method --
        print(pooled_p.type(),pooled_p.shape,pooled_p.max())
        vid_io.save_video(pooled_p[:,:,:3],"output/slic_pooled/",ptype)


if __name__ == "__main__":
    main()

0 comments on commit 99a800b

Please sign in to comment.