Skip to content

Commit

Permalink
fix half precission issues for scene change detect
Browse files Browse the repository at this point in the history
  • Loading branch information
NevermindNilas committed Jun 4, 2024
1 parent a49f85a commit 16a7fae
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/unifiedInterpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def handleModel(self):

self.firstRun = True
if self.sceneChange:
self.sceneChangeProcess = SceneChangeTensorRT(self.half)
self.sceneChangeProcess = SceneChangeTensorRT(self.halfS)

@torch.inference_mode()
def processFrame(self, frame):
Expand Down Expand Up @@ -561,18 +561,20 @@ def __init__(
def loadModel(self):
filename = modelsMap(
"scenechange",
self.half,
half = self.half,
)

if not os.path.exists(os.path.join(weightsDir, "scenechange", filename)):
modelPath = downloadModels(
"scenechange",
self.half,
half = self.half,
)

else:
modelPath = os.path.join(weightsDir, "scenechange", filename)

logging.info(f"Loading scenechange detection model from {modelPath}")

providers = self.ort.get_available_providers()
if "DmlExecutionProvider" in providers:
logging.info(
Expand All @@ -597,10 +599,15 @@ def run(self, frame0, frame1):
frame0 = F.interpolate(frame0.float(), size=(224, 224), mode="bilinear").squeeze(0).cpu().numpy()
frame1 = F.interpolate(frame1.float(), size=(224, 224), mode="bilinear").squeeze(0).cpu().numpy()
else:
frame0 = F.interpolate(frame0.float(), size=(224, 224), mode="bilinear").squeeze(0).half().cpu().numpy()
frame1 = F.interpolate(frame1.float(), size=(224, 224), mode="bilinear").squeeze(0).half().cpu().numpy()
frame0 = F.interpolate(frame0.float(), size=(224, 224), mode="bilinear").squeeze(0).cpu().numpy()
frame1 = F.interpolate(frame1.float(), size=(224, 224), mode="bilinear").squeeze(0).cpu().numpy()

if self.half:
frame0 = frame0.astype(np.float16)
frame1 = frame1.astype(np.float16)

inputs = np.ascontiguousarray(np.concatenate((frame0, frame1), 0))

result = self.model.run(None, {"input": inputs})[0][0][0]
return result > 0.93

Expand Down Expand Up @@ -645,13 +652,13 @@ def __init__(self, half):
def handleModel(self):
filename = modelsMap(
"scenechange",
self.half,
half = self.half,
)

if not os.path.exists(os.path.join(weightsDir, "scenechange", filename)):
modelPath = downloadModels(
"scenechange",
self.half,
half = self.half,
)

else:
Expand Down

0 comments on commit 16a7fae

Please sign in to comment.