Skip to content

Commit

Permalink
Default RIFE TensorRT to the TensorRT scene change detector
Browse files Browse the repository at this point in the history
  • Loading branch information
NevermindNilas committed Jun 4, 2024
1 parent abd5236 commit a49f85a
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 4 deletions.
1 change: 0 additions & 1 deletion src/segment/animeSegment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch.nn.functional as F

from polygraphy.backend.trt import (
TrtRunner,
engine_from_network,
network_from_onnx_path,
CreateConfig,
Expand Down
133 changes: 130 additions & 3 deletions src/unifiedInterpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def handleModel(self):

self.firstRun = True
if self.sceneChange:
self.sceneChangeProcess = SceneChange(self.half)
self.sceneChangeProcess = SceneChangeTensorRT(self.half)

@torch.inference_mode()
def processFrame(self, frame):
Expand Down Expand Up @@ -613,6 +613,133 @@ def runNumpy(self, frame1, frame2):
frame1 = frame1.astype(np.float16)

inputs = np.ascontiguousarray(np.concatenate((frame0, frame1), 0))
result = self.model.run(None, {"input": inputs})[0][0][0]

return result > 0.93
return self.model.run(None, {"input": inputs})[0][0][0] > 0.93

class SceneChangeTensorRT():
    """Scene change check between two frames, executed through a TensorRT engine.

    On construction the ONNX "scenechange" model is located (or downloaded),
    a serialized TensorRT engine is built next to it if none exists yet, the
    engine is deserialized, fixed input/output buffers are allocated and bound,
    and a number of warm-up inferences are executed.

    Args:
        half: When truthy the engine is built with fp16 enabled and all
            buffers use ``torch.float16``; otherwise ``torch.float32``.
        threshold: Value the first output element must exceed for ``run`` to
            report True. Defaults to 0.93, the value that was previously
            hard-coded.
    """

    # Shape of the single engine input: two frames concatenated on dim 0.
    # NOTE(review): presumably 2 x 3 colour channels - confirm against the model.
    _inputShape = (6, 224, 224)
    # Spatial size every frame is resized to before inference.
    _frameSize = (224, 224)
    # Shape of the output buffer; only element [0][0] is read by ``run``.
    _outputShape = (1, 2)
    # Number of inferences executed once after the engine is loaded.
    _warmupRuns = 50

    def __init__(self, half, threshold=0.93):
        self.half = half
        self.threshold = threshold

        # Imported lazily so polygraphy is only required when this class is
        # actually instantiated.
        from polygraphy.backend.trt import (
            TrtRunner,
            engine_from_network,
            network_from_onnx_path,
            CreateConfig,
            Profile,
            EngineFromBytes,
            SaveEngine,
        )
        from polygraphy.backend.common import BytesFromPath

        self.TrtRunner = TrtRunner
        self.engine_from_network = engine_from_network
        self.network_from_onnx_path = network_from_onnx_path
        self.CreateConfig = CreateConfig
        self.Profile = Profile
        self.EngineFromBytes = EngineFromBytes
        self.SaveEngine = SaveEngine
        self.BytesFromPath = BytesFromPath

        self.handleModel()

    def handleModel(self):
        """Locate the model, build the engine if missing, load it and warm it up."""
        modelPath = self._resolveModelPath()
        trtEngineModelPath = self._enginePathFor(modelPath, self.half)

        if not os.path.exists(trtEngineModelPath):
            self._buildEngine(modelPath, trtEngineModelPath)

        self._loadEngine(trtEngineModelPath)
        self._allocateBuffers()
        self._bindTensors()
        self._warmup()

    @staticmethod
    def _enginePathFor(modelPath, half):
        """Return the engine file path that belongs to ``modelPath``.

        Only the file extension is swapped, so a ".onnx" occurring elsewhere
        in the path (e.g. in a directory name) is left untouched, and a model
        file with a different extension can never map onto itself.
        """
        stem, _ = os.path.splitext(modelPath)
        return stem + ("_fp16.engine" if half else "_fp32.engine")

    def _resolveModelPath(self):
        """Return the local ONNX model path, downloading the model if absent."""
        filename = modelsMap(
            "scenechange",
            self.half,
        )
        localPath = os.path.join(weightsDir, "scenechange", filename)
        if os.path.exists(localPath):
            return localPath

        return downloadModels(
            "scenechange",
            self.half,
        )

    def _buildEngine(self, modelPath, trtEngineModelPath):
        """Build an engine from the ONNX model and save it to ``trtEngineModelPath``."""
        toPrint = f"Engine not found, creating dynamic engine for model: {modelPath}, this may take a while, but it is worth the wait..."
        print(yellow(toPrint))
        logging.info(toPrint)

        # min == opt == max: the engine is built for exactly one input shape.
        profile = [
            self.Profile().add(
                "input",
                min=self._inputShape,
                opt=self._inputShape,
                max=self._inputShape,
            )
        ]

        self.config = self.CreateConfig(
            fp16=self.half,
            profiles=profile,
            preview_features=[],
        )

        builtEngine = self.engine_from_network(
            self.network_from_onnx_path(modelPath),
            config=self.config,
        )
        # SaveEngine is a lazy loader; calling it performs the actual write.
        self.SaveEngine(builtEngine, trtEngineModelPath)()

    def _loadEngine(self, trtEngineModelPath):
        """Deserialize the engine file and create its execution context."""
        with open(trtEngineModelPath, "rb") as f, trt.Runtime(
            trt.Logger(trt.Logger.INFO)
        ) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())
            self.context = self.engine.create_execution_context()

    def _allocateBuffers(self):
        """Create the stream and the fixed input/output tensors used for inference."""
        self.dType = torch.float16 if self.half else torch.float32
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.stream = torch.cuda.Stream()
        self.dummyInput = torch.zeros(
            self._inputShape,
            device=self.device,
            dtype=self.dType,
        )
        self.dummyOutput = torch.zeros(
            self._outputShape,
            device=self.device,
            dtype=self.dType,
        )

        # Index i of this list is bound to I/O tensor i of the engine.
        self.bindings = [self.dummyInput.data_ptr(), self.dummyOutput.data_ptr()]

    def _bindTensors(self):
        """Register buffer addresses (and the input shape) with the context."""
        for index in range(self.engine.num_io_tensors):
            tensorName = self.engine.get_tensor_name(index)
            self.context.set_tensor_address(tensorName, self.bindings[index])
            if self.engine.get_tensor_mode(tensorName) == trt.TensorIOMode.INPUT:
                self.context.set_input_shape(tensorName, self.dummyInput.shape)

    def _warmup(self):
        """Run the engine several times on an all-zero input."""
        with torch.cuda.stream(self.stream):
            for _ in range(self._warmupRuns):
                # In-place zeroing; avoids allocating a new tensor per iteration.
                self.dummyInput.zero_()
                self.context.execute_async_v3(stream_handle=self.stream.cuda_stream)
                self.stream.synchronize()

    def _prepareFrame(self, frame):
        """Resize ``frame`` to the model resolution and cast it to the engine dtype.

        Interpolation is always done in float32; the result has its leading
        dimension squeezed and is then converted to ``self.dType``.
        """
        resized = F.interpolate(frame.float(), size=self._frameSize, mode="bilinear")
        return resized.squeeze(0).to(self.dType)

    @torch.inference_mode()
    def run(self, frame0, frame1):
        """Run the engine on a pair of frames.

        Args:
            frame0: First frame tensor, in a layout ``F.interpolate`` accepts
                for bilinear mode (assumes a leading batch dim of 1 - TODO
                confirm against the caller).
            frame1: Second frame tensor, same layout as ``frame0``.

        Returns:
            bool: True when the first element of the engine output is greater
            than ``self.threshold``.
        """
        # The frames may have been produced on a different CUDA stream; work
        # queued on a non-default stream is not ordered against it
        # automatically, so make our stream wait for it first.
        self.stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self.stream):
            frame0 = self._prepareFrame(frame0)
            frame1 = self._prepareFrame(frame1)

            self.dummyInput.copy_(torch.cat([frame0, frame1], dim=0))
            self.context.execute_async_v3(stream_handle=self.stream.cuda_stream)
            # The output buffer is only valid once the stream has finished.
            self.stream.synchronize()
            return self.dummyOutput[0][0].item() > self.threshold




0 comments on commit a49f85a

Please sign in to comment.