diff --git a/demos/common/stream_client/stream_client.py b/demos/common/stream_client/stream_client.py index f015ce5e70..1ad3984ec4 100644 --- a/demos/common/stream_client/stream_client.py +++ b/demos/common/stream_client/stream_client.py @@ -70,16 +70,26 @@ def write(self, frame): def release(self): self.cv_sink.release() +class ImshowOutputBackend(OutputBackend): + def init(self, sink, fps, width, height): + ... + def write(self, frame): + cv2.imshow("OVMS StreamClient", frame) + cv2.waitKey(1) + def release(self): + cv2.destroyAllWindows() + class StreamClient: class OutputBackends(): ffmpeg = FfmpegOutputBackend() cv2 = CvOutputBackend() + imshow = ImshowOutputBackend() none = OutputBackend() class Datatypes(): fp32 = FP32() uint8 = UINT8() - def __init__(self, *, preprocess_callback = None, postprocess_callback, source, sink : str, ffmpeg_output_width = None, ffmpeg_output_height = None, output_backend :OutputBackend = OutputBackends.ffmpeg, verbose : bool = False, exact : bool = True, benchmark : bool = False): + def __init__(self, *, preprocess_callback = None, postprocess_callback, source, sink: str, ffmpeg_output_width = None, ffmpeg_output_height = None, output_backend: OutputBackend = OutputBackends.ffmpeg, verbose: bool = False, exact: bool = True, benchmark: bool = False, max_inflight_packets: int = 4): """ Parameters ---------- @@ -114,6 +124,7 @@ def __init__(self, *, preprocess_callback = None, postprocess_callback, source, self.benchmark = benchmark self.pq = queue.PriorityQueue() + self.req_q = queue.Queue(max_inflight_packets) def grab_frame(self): success, frame = self.cap.read() @@ -132,18 +143,24 @@ def grab_frame(self): dropped_frames = 0 frames = 0 def callback(self, frame, i, timestamp, result, error): + if error is not None: + if self.benchmark: + self.dropped_frames += 1 + if self.verbose: + print(error) + if i == None: + i = result.get_response().parameters["OVMS_MP_TIMESTAMP"].int64_param + if timestamp == None: + timestamp = result.get_response().parameters["OVMS_MP_TIMESTAMP"].int64_param frame = self.postprocess_callback(frame, result) self.pq.put((i, frame, timestamp)) - if error is not None and self.verbose == True: - print(error) + self.req_q.get() def display(self): i = 0 while True: - if self.pq.empty(): - continue entry = self.pq.get() - if (entry[0] == i and self.exact) or (entry[0] > i and self.exact is not True): + if (entry[0] == i and self.exact and self.streaming_api is not True) or (entry[0] > i and (self.exact is not True or self.streaming_api is True)): if isinstance(entry[1], str) and entry[1] == "EOS": break frame = entry[1] @@ -161,8 +178,10 @@ def display(self): elif self.exact: self.pq.put(entry) + def get_timestamp(self) -> int: + return int(cv2.getTickCount() / cv2.getTickFrequency() * 1e6) - def start(self, *, ovms_address : str, input_name : str, model_name : str, datatype : Datatype = FP32(), batch = True, limit_stream_duration : int = 0, limit_frames : int = 0): + def start(self, *, ovms_address : str, input_name : str, model_name : str, datatype : Datatype = FP32(), batch = True, limit_stream_duration : int = 0, limit_frames : int = 0, streaming_api: bool = False): """ Parameters ---------- @@ -180,12 +199,15 @@ def start(self, *, ovms_address : str, input_name : str, model_name : str, datat Limits how long client could run limit_frames : int Limits how many frames should be processed + streaming_api : bool + Use experimental streaming endpoint """ - self.cap = cv2.VideoCapture(self.source, cv2.CAP_ANY) + self.cap = cv2.VideoCapture(int(self.source) if len(self.source) == 1 and self.source[0].isdigit() else self.source, cv2.CAP_ANY) self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 0) fps = self.cap.get(cv2.CAP_PROP_FPS) triton_client = grpcclient.InferenceServerClient(url=ovms_address, verbose=False) + self.streaming_api = streaming_api display_th = threading.Thread(target=self.display) display_th.start() @@ -199,26 +221,37 @@ def start(self, *, ovms_address : str, input_name : str, model_name : str, datat if self.height is None: self.height = np_test_frame.shape[0] self.output_backend.init(self.sink, fps, self.width, self.height) + + if streaming_api: + triton_client.start_stream(partial(self.callback, None, None, None)) - i = 0 + frame_number = 0 total_time_start = time.time() - while not self.force_exit: - timestamp = time.time() - frame = self.grab_frame() - if frame is not None: - np_frame = np.array([frame], dtype=datatype.dtype()) if batch else np.array(frame, dtype=datatype.dtype()) - inputs=[grpcclient.InferInput(input_name, np_frame.shape, datatype.string())] - inputs[0].set_data_from_numpy(np_frame) - triton_client.async_infer( - model_name=model_name, - callback=partial(self.callback, frame, i, timestamp), - inputs=inputs) - i += 1 - if limit_stream_duration > 0 and time.time() - total_time_start > limit_stream_duration: - break - if limit_frames > 0 and i > limit_frames: - break - self.pq.put((i, "EOS")) + try: + while not self.force_exit: + self.req_q.put(frame_number) + timestamp = time.time() + frame = self.grab_frame() + if frame is not None: + np_frame = np.array([frame], dtype=datatype.dtype()) if batch else np.array(frame, dtype=datatype.dtype()) + inputs=[grpcclient.InferInput(input_name, np_frame.shape, datatype.string())] + inputs[0].set_data_from_numpy(np_frame) + if streaming_api: + triton_client.async_stream_infer(model_name=model_name, inputs=inputs, parameters={"OVMS_MP_TIMESTAMP":self.get_timestamp()}, request_id=str(frame_number)) + else: + triton_client.async_infer( + model_name=model_name, + callback=partial(self.callback, frame, frame_number, timestamp), + inputs=inputs) + frame_number += 1 + if limit_stream_duration > 0 and time.time() - total_time_start > limit_stream_duration: + break + if limit_frames > 0 and frame_number > limit_frames: + break + finally: + self.pq.put((frame_number, "EOS")) + if streaming_api: + triton_client.stop_stream() sent_all_frames = time.time() - total_time_start @@ -227,4 +260,4 @@ def start(self, *, ovms_address : str, input_name : str, model_name : str, datat self.output_backend.release() total_time = time.time() - total_time_start if self.benchmark: - print(f"{{\"inference_time\": {sum(self.inference_time)/i}, \"dropped_frames\": {self.dropped_frames}, \"frames\": {self.frames}, \"fps\": {self.frames/total_time}, \"total_time\": {total_time}, \"sent_all_frames\": {sent_all_frames}}}") + print(f"{{\"inference_time\": {sum(self.inference_time)/frame_number}, \"dropped_frames\": {self.dropped_frames}, \"frames\": {self.frames}, \"fps\": {self.frames/total_time}, \"total_time\": {total_time}, \"sent_all_frames\": {sent_all_frames}}}") diff --git a/demos/mediapipe/holistic_tracking/README.md b/demos/mediapipe/holistic_tracking/README.md index 3dc225f759..c697a02158 100644 --- a/demos/mediapipe/holistic_tracking/README.md +++ b/demos/mediapipe/holistic_tracking/README.md @@ -4,8 +4,7 @@ This guide shows how to implement [MediaPipe](../../../docs/mediapipe.md) graph Example usage of graph that accepts Mediapipe::ImageFrame as a input: -The demo is based on the [upstream Mediapipe holistic demo](https://github.com/google/mediapipe/blob/master/docs/solutions/holistic.md) -and [Mediapipe Iris demo](https://github.com/google/mediapipe/blob/master/docs/solutions/iris.md) +The demo is based on the [upstream Mediapipe holistic demo](https://github.com/google/mediapipe/blob/master/docs/solutions/holistic.md). ## Prepare the server deployment @@ -82,23 +81,6 @@ Results saved to :image_0.jpg ## Output image ![output](output_image.jpg) -## Run client application for iris tracking -In a similar way can be executed the iris image analysis: - -```bash -python mediapipe_holistic_tracking.py --graph_name irisTracking --images_list input_images.txt --grpc_port 9000 -Running demo application. -Start processing: - Graph name: irisTracking -(640, 960, 3) -Iteration 0; Processing time: 77.03 ms; speed 12.98 fps -Results saved to :image_0.jpg -``` - -## Output image -![output](output_image1.jpg) - - ## RTSP Client Mediapipe graph can be used for remote analysis of individual images but the client can use it for a complete video stream processing. diff --git a/demos/mediapipe/holistic_tracking/rtsp_client.py b/demos/mediapipe/holistic_tracking/rtsp_client.py index 434f614f7f..69b5b0fe3b 100755 --- a/demos/mediapipe/holistic_tracking/rtsp_client.py +++ b/demos/mediapipe/holistic_tracking/rtsp_client.py @@ -51,5 +51,5 @@ def postprocess(frame, result): exact = True client = StreamClient(postprocess_callback = postprocess, preprocess_callback=preprocess, output_backend=backend, source=args.input_stream, sink=args.output_stream, exact=exact, benchmark=args.benchmark, verbose=args.verbose) -client.start(ovms_address=args.grpc_address, input_name=args.input_name, model_name=args.model_name, datatype = StreamClient.Datatypes.uint8, batch = False, limit_stream_duration = args.limit_stream_duration, limit_frames = args.limit_frames) +client.start(ovms_address=args.grpc_address, input_name=args.input_name, model_name=args.model_name, datatype = StreamClient.Datatypes.uint8, batch = False, limit_stream_duration = args.limit_stream_duration, limit_frames = args.limit_frames, streaming_api=True)