| import sys |
| import os |
| sys.path.insert(0, os.path.join(os.environ['HOME'], '.cache/torch', 'RAFT/core')) |
| from raft import RAFT |
| from utils import flow_viz |
| sys.path = sys.path[1:] |
| import torch |
| from cwm.utils import imagenet_unnormalize |
| from torch import nn |
| import argparse |
|
|
|
|
class Args:
    """Configuration object handed to ``RAFT(args)``.

    Settings are class-level defaults; an instance may override them or gain
    extra attributes at runtime.
    NOTE(review): upstream RAFT's constructor checks ``'dropout' in args`` and
    sets missing attributes on the instance — confirm against the vendored copy.

    Iterating yields ``(name, value)`` pairs for every setting (class defaults
    merged with instance overrides), and ``name in args`` tests whether a
    setting exists. Both behave like ``argparse.Namespace``.
    """

    model = os.path.join(os.environ['HOME'], '.cache/torch', 'RAFT/models/raft-sintel.pth')
    small = False
    path = None
    mixed_precision = False
    alternate_corr = False

    def _settings(self):
        """Return all public, non-callable attributes (class defaults + instance overrides)."""
        # dir() covers class attributes (including inherited ones) as well as
        # instance attributes; the old ``self.__dict__`` only saw the latter.
        return {
            name: getattr(self, name)
            for name in dir(self)
            if not name.startswith('_') and not callable(getattr(self, name))
        }

    def __iter__(self):
        """Yield ``(name, value)`` for every setting."""
        yield from self._settings().items()

    def __contains__(self, name):
        """Return True if ``name`` is a setting of this object."""
        # Without this, ``in`` falls back to ``__iter__`` and compares a string
        # against (name, value) tuples, which is always False.
        return name in self._settings()
|
|
class RAFTInterface(nn.Module):
    """Frozen RAFT optical-flow estimator wrapped as an ``nn.Module``.

    The checkpoint named by ``Args.model`` is loaded onto the CPU. The network
    is kept in eval mode, even when ``.train()`` is called on this module or a
    parent, and gradients are disabled for all of its parameters.
    """

    # Spatial size multiple the inputs are padded to before calling RAFT
    # (mirrors RAFT's reference ``InputPadder``).
    _SIZE_MULTIPLE = 8

    def __init__(self):
        super().__init__()
        args = Args()
        # The weights are loaded through a DataParallel wrapper (presumably
        # because the checkpoint keys carry the ``module.`` prefix), then the
        # wrapper is discarded.
        model = torch.nn.DataParallel(RAFT(args))
        model.load_state_dict(torch.load(args.model, map_location=torch.device('cpu')))
        self.model = model.module
        self.model.eval()

        for p in self.model.parameters():
            p.requires_grad = False

    def train(self, mode=True):
        """Set this module's training flag but keep the wrapped RAFT in eval mode.

        The RAFT weights are frozen, so propagating ``.train()`` to it would
        only change its outputs (e.g. via any batch-norm/dropout layers it
        contains) and could update normalization running statistics.
        """
        super().train(mode)
        self.model.eval()
        return self

    @staticmethod
    def prepare_inputs(x):
        """Rescale an image tensor to the 0-255 range fed to RAFT.

        - Values entirely within [0, 1] are multiplied by 255.
        - Tensors containing negative values are un-normalized with
          ``imagenet_unnormalize`` and then multiplied by 255. This assumes the
          un-normalized result lies in [0, 1] — TODO confirm.
        - Anything else is assumed to be 0-255 already and returned unchanged.
        """
        if x.max() <= 1.0 and x.min() >= 0.:
            x = x * 255.
        elif x.min() < 0:
            x = imagenet_unnormalize(x)
            x = x * 255.
        return x

    @staticmethod
    def _pad_to_multiple(x, multiple):
        """Replicate-pad the last two dims of ``x`` up to a multiple of ``multiple``.

        Padding is split as evenly as possible between the two sides, like
        RAFT's ``InputPadder`` in its default ``'sintel'`` mode.

        Returns:
            ``(padded, (top, left))``. ``(top, left)`` is the offset of the
            original content inside ``padded``, used to crop results back.
        """
        height, width = x.shape[-2:]
        pad_h = -height % multiple
        pad_w = -width % multiple
        if pad_h == 0 and pad_w == 0:
            return x, (0, 0)
        if not torch.is_floating_point(x):
            # Replicate padding is not implemented for integer dtypes.
            x = x.float()
        top, left = pad_h // 2, pad_w // 2
        padded = torch.nn.functional.pad(
            x, [left, pad_w - left, top, pad_h - top], mode='replicate')
        return padded, (top, left)

    def forward(self, x0, x1, return_magnitude=False, iters=20):
        """Estimate optical flow from ``x0`` to ``x1`` with the frozen RAFT model.

        Args:
            x0, x1: image tensors of identical shape, presumably (N, 3, H, W) —
                TODO confirm. Any value range accepted by ``prepare_inputs`` works.
                H and W need not be multiples of 8: the inputs are padded and
                the flow is cropped back to (H, W).
            return_magnitude: if True, also return the L2 norm of the flow
                over dim 1.
            iters: number of RAFT refinement iterations (default 20).

        Returns:
            ``flow_up`` (the second output of RAFT in test mode, cropped to the
            input size), or ``(flow_up, flow_magnitude)`` when
            ``return_magnitude`` is True.
        """
        x0 = self.prepare_inputs(x0)
        x1 = self.prepare_inputs(x1)

        height, width = x0.shape[-2:]
        x0, (top, left) = self._pad_to_multiple(x0, self._SIZE_MULTIPLE)
        x1, _ = self._pad_to_multiple(x1, self._SIZE_MULTIPLE)

        with torch.no_grad():
            _, flow_up = self.model(x0, x1, iters=iters, test_mode=True)

        # Remove the padding added above (a no-op view when none was added).
        flow_up = flow_up[..., top:top + height, left:left + width]

        if return_magnitude:
            flow_magnitude = flow_up.norm(p=2, dim=1)
            return flow_up, flow_magnitude

        return flow_up

    def viz(self, flow):
        """Render the first flow field of a batch with ``flow_viz.flow_to_image``.

        ``flow[0]`` is permuted from channel-first (C, H, W) to (H, W, C) and
        moved to a NumPy array first. This assumes C == 2 (x/y flow) — TODO confirm.
        """
        flow_hwc = flow[0].detach().permute(1, 2, 0).cpu().numpy()
        return flow_viz.flow_to_image(flow_hwc)