-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f100458
commit 5ff2bf5
Showing 50 changed files with 8,776 additions and 1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ GPUtil
spandrel
yt-dlp
requests
accelerate
tqdm
tensorrt
scikit-image
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ kornia
scipy
wmi
spandrel
accelerate
yt-dlp
requests
tqdm
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
import torch | ||
|
||
|
||
class InputPadder:
    """Pads images such that dimensions are divisible by factor.

    Args:
        size: Shape of the tensors that will be padded. Only the last two
            entries are read, as ``(height, width)``.
        divide: Factor that the padded height and width must be divisible by.
        mode: ``"center"`` splits the padding between both sides of each
            axis (the extra pixel, if any, goes to the right/bottom). Any
            other value puts all padding on the right and bottom.
    """

    def __init__(self, size, divide=8, mode="center"):
        self.ht, self.wd = size[-2:]
        # (-n) % divide is the smallest non-negative amount that rounds n up
        # to the next multiple of `divide` (0 when n is already a multiple).
        pad_ht = (-self.ht) % divide
        pad_wd = (-self.wd) % divide
        if mode == "center":
            left = pad_wd // 2
            top = pad_ht // 2
            # Order expected by torch.nn.functional.pad for the last two
            # dims: [left, right, top, bottom].
            self._pad = [left, pad_wd - left, top, pad_ht - top]
        else:
            self._pad = [0, pad_wd, 0, pad_ht]

    def _pad_(self, x):
        """Return ``x`` padded with constant (zero) values on its last two dims."""
        return torch.nn.functional.pad(x, self._pad, mode="constant")

    def pad(self, *inputs):
        """Pad every tensor in ``inputs``; returns a list in the same order."""
        return [self._pad_(x) for x in inputs]

    def _unpad_(self, x):
        """Crop ``x`` back to the original ``(ht, wd)`` window."""
        top, left = self._pad[2], self._pad[0]
        return x[..., top : self.ht + top, left : self.wd + left]

    def unpad(self, *inputs):
        """Undo :meth:`pad` on every tensor in ``inputs``; returns a list."""
        return [self._unpad_(x) for x in inputs]
|
||
|
||
def build_flow_estimator(name, device="cuda", checkpoint=None):
    """Build an optical-flow model and a bidirectional inference function.

    Args:
        name: Estimator name, case-insensitive. One of ``"raft"``,
            ``"raft_small"``, ``"gma"`` or ``"gmflow"``.
        device: Device the model is moved to; inputs to ``infer`` are moved
            there as well.
        checkpoint: Optional path to a ``.pth`` checkpoint. It is only
            honoured by the ``"raft"`` estimator; every other estimator
            always loads its built-in default path.

    Returns:
        ``(model, infer)``. ``infer(I1, I2)`` scales both inputs by 255,
        pads them, runs the model in both directions and returns the
        unpadded forward and backward flow (a list for the RAFT/GMA
        estimators, a tuple for GMFlow).
        NOTE(review): the ``* 255.0`` suggests inputs are expected in
        [0, 1] -- confirm against callers.

    Raises:
        ValueError: If ``name`` is not a known estimator.
    """
    kind = name.lower()

    def load_checkpoint(path):
        # map_location keeps checkpoints that were saved on a CUDA device
        # loadable when `device` is CPU-only.
        return torch.load(path, map_location=device)

    def strip_module_prefix(state_dict):
        # Drop the "module." key prefix (as added by DataParallel wrappers).
        return {k.replace("module.", ""): v for k, v in state_dict.items()}

    if kind in ("raft", "raft_small"):
        import argparse

        from .raft.raft import RAFT

        small = kind == "raft_small"
        args = argparse.Namespace(
            mixed_precision=True, alternate_corr=False, small=small
        )
        model = RAFT(args)

        if small:
            ckpt = "checkpoints/RAFT/raft-small.pth"
        elif checkpoint is None:
            ckpt = "checkpoints/RAFT/raft-sintel.pth"
        else:
            # Path to the checkpoint .pth, modified from original arch for
            # better comp
            ckpt = checkpoint

        model.load_state_dict(strip_module_prefix(load_checkpoint(ckpt)))
        model.to(device).eval()

        @torch.no_grad()
        def infer(I1, I2):
            I1 = I1.to(device) * 255.0
            I2 = I2.to(device) * 255.0
            padder = InputPadder(I1.shape, 8)
            I1, I2 = padder.pad(I1, I2)
            fflow = model(I1, I2, bidirection=False, iters=12)
            bflow = model(I2, I1, bidirection=False, iters=12)
            return padder.unpad(fflow, bflow)

    elif kind == "gma":
        import argparse

        from .gma.network import RAFTGMA

        args = argparse.Namespace(
            mixed_precision=True,
            num_heads=1,
            position_only=False,
            position_and_content=False,
        )
        model = RAFTGMA(args)
        ckpt = "checkpoints/GMA/gma-sintel.pth"
        model.load_state_dict(strip_module_prefix(load_checkpoint(ckpt)))
        model.to(device).eval()

        @torch.no_grad()
        def infer(I1, I2):
            I1 = I1.to(device) * 255.0
            I2 = I2.to(device) * 255.0
            padder = InputPadder(I1.shape, 8)
            I1, I2 = padder.pad(I1, I2)
            # In test mode the model returns a pair; only the second item
            # is used as the flow.
            _, fflow = model(I1, I2, test_mode=True, iters=20)
            _, bflow = model(I2, I1, test_mode=True, iters=20)
            return padder.unpad(fflow, bflow)

    elif kind == "gmflow":
        from .gmflow.gmflow import GMFlow

        model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type="swin",
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        )
        ckpt = "checkpoints/GMFlow/gmflow_sintel-0c07dcb3.pth"
        # This checkpoint stores the weights under the "model" key.
        model.load_state_dict(load_checkpoint(ckpt)["model"])
        model.to(device).eval()

        @torch.no_grad()
        def infer(I1, I2):
            I1 = I1.to(device) * 255.0
            I2 = I2.to(device) * 255.0
            # GMFlow inputs are padded to a multiple of 16 rather than 8.
            padder = InputPadder(I1.shape, 16)
            I1, I2 = padder.pad(I1, I2)

            def run(src, dst):
                results_dict = model(
                    src,
                    dst,
                    attn_splits_list=[2],
                    corr_radius_list=[-1],
                    prop_radius_list=[-1],
                    pred_bidir_flow=False,
                )
                # The last entry of "flow_preds" is taken as the result.
                return results_dict["flow_preds"][-1]

            fflow = run(I1, I2)
            bflow = run(I2, I1)

            fflow, bflow = padder.unpad(fflow, bflow)
            return fflow, bflow

    else:
        # Previously an unknown name fell through every branch and crashed
        # with UnboundLocalError on the return statement below.
        raise ValueError(
            f"Unknown flow estimator {name!r}; expected one of "
            "'raft', 'raft_small', 'gma', 'gmflow'"
        )

    return model, infer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import math | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from .utils.utils import bilinear_sampler, coords_grid | ||
|
||
# from compute_sparse_correlation import compute_sparse_corr, compute_sparse_corr_torch, compute_sparse_corr_mink | ||
|
||
# Optional compiled extension: only the import failure is tolerated here, so
# that unrelated errors (and KeyboardInterrupt/SystemExit) are not swallowed.
try:
    import alt_cuda_corr
except ImportError:
    # alt_cuda_corr is not compiled
    pass
|
||
|
||
class CorrBlock:
    """All-pairs correlation volume with an average-pooled pyramid.

    Args:
        fmap1: Feature map of shape ``(batch, dim, ht, wd)``.
        fmap2: Feature map reshaped the same way as ``fmap1``.
        num_levels: Number of pyramid levels; each level after the first is
            the previous one average-pooled by 2.
        radius: Half-width ``r`` of the ``(2r + 1) x (2r + 1)`` lookup window.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        # Fold every source pixel into the batch axis so that each one owns a
        # (dim, h2, w2) map that can be pooled and sampled independently.
        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample a window around ``coords`` at every pyramid level.

        Args:
            coords: 4-D tensor with 2 channels on dim 1
                (``(batch, 2, h1, w1)``).

        Returns:
            Float tensor with the per-level samples concatenated on dim 1.
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # The offset window is identical for every level, so build it once,
        # directly on the target device, instead of per level on the CPU.
        offsets = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
        # indexing="ij" is the historical default; stating it keeps results
        # unchanged and avoids the deprecation warning on current PyTorch.
        delta = torch.stack(
            torch.meshgrid(offsets, offsets, indexing="ij"), dim=-1
        )
        delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        centroid = coords.reshape(batch * h1 * w1, 1, 1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            # Level i has been pooled i times, so coordinates shrink by 2**i.
            coords_lvl = centroid / 2**i + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            out_pyramid.append(corr.view(batch, h1, w1, -1))

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Return the scaled dot product between all pixel pairs.

        The result has shape ``(batch, ht, wd, 1, ht, wd)`` and is divided by
        ``sqrt(dim)``.
        """
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
||
|
||
class CorrBlockSingleScale(nn.Module):
    """All-pairs correlation volume sampled at a single scale.

    Args:
        fmap1: Feature map of shape ``(batch, dim, ht, wd)``.
        fmap2: Feature map reshaped the same way as ``fmap1``.
        num_levels: Accepted but not used by this class.
        radius: Half-width ``r`` of the ``(2r + 1) x (2r + 1)`` lookup window.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        super().__init__()
        self.radius = radius

        # all pairs correlation
        # Use this class's own static method (looked up on the class, since
        # the instance attribute `corr` assigned below shadows it).
        corr = CorrBlockSingleScale.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        self.corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

    # NOTE(review): overriding __call__ (rather than forward) bypasses
    # nn.Module's hook machinery; kept as-is to preserve behaviour.
    def __call__(self, coords):
        """Sample a window around ``coords`` from the correlation volume.

        Args:
            coords: 4-D tensor with 2 channels on dim 1
                (``(batch, 2, h1, w1)``).

        Returns:
            Float tensor with the window samples on dim 1.
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # Build the offset window directly on the target device.
        offsets = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
        # indexing="ij" is the historical default; stating it keeps results
        # unchanged and avoids the deprecation warning on current PyTorch.
        delta = torch.stack(
            torch.meshgrid(offsets, offsets, indexing="ij"), dim=-1
        )

        centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2)
        delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        coords_lvl = centroid_lvl + delta_lvl

        sampled = bilinear_sampler(self.corr, coords_lvl)
        out = sampled.view(batch, h1, w1, -1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Return the scaled dot product between all pixel pairs.

        The result has shape ``(batch, ht, wd, 1, ht, wd)`` and is divided by
        ``sqrt(dim)``.
        """
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
Oops, something went wrong.