From a49f85a18772fd03eb77fd832ce3f0640b662c54 Mon Sep 17 00:00:00 2001
From: NevermindNilas
Date: Tue, 4 Jun 2024 04:43:24 +0300
Subject: [PATCH] default RIFE TensorRT to the TensorRT scene change detector

---
 src/segment/animeSegment.py |   1 -
 src/unifiedInterpolate.py   | 123 ++++++++++++++++++++++++++++++++++-
 2 files changed, 120 insertions(+), 4 deletions(-)

diff --git a/src/segment/animeSegment.py b/src/segment/animeSegment.py
index 925f3eb8..abe5eccd 100644
--- a/src/segment/animeSegment.py
+++ b/src/segment/animeSegment.py
@@ -8,7 +8,6 @@
 import torch.nn.functional as F
 
 from polygraphy.backend.trt import (
-    TrtRunner,
     engine_from_network,
     network_from_onnx_path,
     CreateConfig,
diff --git a/src/unifiedInterpolate.py b/src/unifiedInterpolate.py
index 8a4ae953..3e96d4d9 100644
--- a/src/unifiedInterpolate.py
+++ b/src/unifiedInterpolate.py
@@ -396,7 +396,7 @@ def handleModel(self):
         self.firstRun = True
 
         if self.sceneChange:
-            self.sceneChangeProcess = SceneChange(self.half)
+            self.sceneChangeProcess = SceneChangeTensorRT(self.half)
 
     @torch.inference_mode()
     def processFrame(self, frame):
@@ -613,5 +613,122 @@ def runNumpy(self, frame1, frame2):
         frame1 = frame1.astype(np.float16)
 
         inputs = np.ascontiguousarray(np.concatenate((frame0, frame1), 0))
-        result = self.model.run(None, {"input": inputs})[0][0][0]
-        return result > 0.93
+        return self.model.run(None, {"input": inputs})[0][0][0] > 0.93
+
+
+class SceneChangeTensorRT:
+    def __init__(self, half):
+        self.half = half
+
+        from polygraphy.backend.trt import (
+            engine_from_network,
+            network_from_onnx_path,
+            CreateConfig,
+            Profile,
+            SaveEngine,
+        )
+
+        self.engine_from_network = engine_from_network
+        self.network_from_onnx_path = network_from_onnx_path
+        self.CreateConfig = CreateConfig
+        self.Profile = Profile
+        self.SaveEngine = SaveEngine
+
+        self.handleModel()
+
+    def handleModel(self):
+        filename = modelsMap(
+            "scenechange",
+            self.half,
+        )
+
+        if not os.path.exists(os.path.join(weightsDir, "scenechange", filename)):
+            modelPath = downloadModels(
+                "scenechange",
+                self.half,
+            )
+        else:
+            modelPath = os.path.join(weightsDir, "scenechange", filename)
+
+        if self.half:
+            trtEngineModelPath = modelPath.replace(".onnx", "_fp16.engine")
+        else:
+            trtEngineModelPath = modelPath.replace(".onnx", "_fp32.engine")
+
+        if not os.path.exists(trtEngineModelPath):
+            toPrint = f"Engine not found, creating static engine for model: {modelPath}. This may take a while, but it is worth the wait..."
+            print(yellow(toPrint))
+            logging.info(toPrint)
+
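+            # Build the engine once and serialize it next to the ONNX model;
+            # later runs skip the build and deserialize the cached engine below.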
+            profile = [
+                self.Profile().add(
+                    "input",
+                    min=(6, 224, 224),
+                    opt=(6, 224, 224),
+                    max=(6, 224, 224),
+                )
+            ]
+
+            self.config = self.CreateConfig(
+                fp16=self.half,
+                profiles=profile,
+                preview_features=[],
+            )
+
+            self.engine = self.engine_from_network(
+                self.network_from_onnx_path(modelPath),
+                config=self.config,
+            )
+            self.engine = self.SaveEngine(self.engine, trtEngineModelPath)
+            self.engine()
+
+        with open(trtEngineModelPath, "rb") as f, trt.Runtime(
+            trt.Logger(trt.Logger.INFO)
+        ) as runtime:
+            self.engine = runtime.deserialize_cuda_engine(f.read())
+            self.context = self.engine.create_execution_context()
+
+        self.dType = torch.float16 if self.half else torch.float32
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.stream = torch.cuda.Stream()
+        self.dummyInput = torch.zeros(
+            (6, 224, 224),
+            device=self.device,
+            dtype=self.dType,
+        )
+        self.dummyOutput = torch.zeros(
+            (1, 2),
+            device=self.device,
+            dtype=self.dType,
+        )
+
+        # Bind the pre-allocated tensors once; run() only overwrites their contents.
+        self.bindings = [self.dummyInput.data_ptr(), self.dummyOutput.data_ptr()]
+
+        for i in range(self.engine.num_io_tensors):
+            tensor_name = self.engine.get_tensor_name(i)
+            self.context.set_tensor_address(tensor_name, self.bindings[i])
+            if self.engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
+                self.context.set_input_shape(tensor_name, self.dummyInput.shape)
+
+        # Warm-up inferences so lazy CUDA/TensorRT initialization happens
+        # before real frames arrive.
+        with torch.cuda.stream(self.stream):
+            for _ in range(50):
+                self.context.execute_async_v3(stream_handle=self.stream.cuda_stream)
+            self.stream.synchronize()
+
+    @torch.inference_mode()
+    def run(self, frame0, frame1):
+        with torch.cuda.stream(self.stream):
+            frame0 = F.interpolate(frame0.float(), size=(224, 224), mode="bilinear").squeeze(0)
+            frame1 = F.interpolate(frame1.float(), size=(224, 224), mode="bilinear").squeeze(0)
+            if self.half:
+                frame0 = frame0.half()
+                frame1 = frame1.half()
+
+            self.dummyInput.copy_(torch.cat([frame0, frame1], dim=0))
+            self.context.execute_async_v3(stream_handle=self.stream.cuda_stream)
+            self.stream.synchronize()
+            return self.dummyOutput[0][0].item() > 0.93
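
A minimal usage sketch of the new detector: the (1, 3, H, W) frame layout
and [0, 1] value range are assumptions inferred from the F.interpolate /
torch.cat calls in run(), and the variable names are illustrative only.

    import torch

    # Builds the fp16 engine on first use, loads the cached one afterwards.
    detector = SceneChangeTensorRT(half=True)

    # Two consecutive frames from the decode loop; run() resizes them to
    # 224x224 internally, so any H and W should work.
    prevFrame = torch.rand(1, 3, 1080, 1920, device="cuda")
    currFrame = torch.rand(1, 3, 1080, 1920, device="cuda")

    if detector.run(prevFrame, currFrame):
        print("scene change detected; this pair should not be interpolated")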