From b900ecec239d93426f4dc535be9f937139deee53 Mon Sep 17 00:00:00 2001 From: Chris Dzombak Date: Mon, 15 Jul 2024 16:57:08 -0400 Subject: [PATCH] add enrichment via Ollama multimodal models (e.g. LLaVA) --- Dockerfile-amd64-cpu | 1 + Dockerfile-amd64-cuda | 1 + Dockerfile-arm64 | 1 + README.md | 26 +++++- config.example.json | 12 +++ config.py | 56 +++++++++++++ enrichment-prompts/llava_prompt_car.txt | 17 ++++ enrichment-prompts/llava_prompt_person.txt | 14 ++++ enrichment-prompts/llava_prompt_truck.txt | 17 ++++ ntfy.py | 95 +++++++++++++++++++++- 10 files changed, 238 insertions(+), 2 deletions(-) create mode 100644 enrichment-prompts/llava_prompt_car.txt create mode 100644 enrichment-prompts/llava_prompt_person.txt create mode 100644 enrichment-prompts/llava_prompt_truck.txt diff --git a/Dockerfile-amd64-cpu b/Dockerfile-amd64-cpu index a0baa09..7db3271 100644 --- a/Dockerfile-amd64-cpu +++ b/Dockerfile-amd64-cpu @@ -12,6 +12,7 @@ RUN apt-get -y update \ RUN mkdir /app COPY ./*.py ./requirements.txt /app/ RUN pip install --no-cache-dir -r /app/requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu +COPY ./enrichment-prompts / WORKDIR /app RUN curl -f -L -O https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt diff --git a/Dockerfile-amd64-cuda b/Dockerfile-amd64-cuda index d1592a5..d8e2426 100644 --- a/Dockerfile-amd64-cuda +++ b/Dockerfile-amd64-cuda @@ -17,6 +17,7 @@ RUN mkdir /app COPY ./*.py ./requirements.txt /app/ RUN pip install --no-cache-dir nvidia-tensorrt RUN pip install --no-cache-dir -r /app/requirements.txt +COPY ./enrichment-prompts / WORKDIR /app RUN curl -f -L -O https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt diff --git a/Dockerfile-arm64 b/Dockerfile-arm64 index da3c388..dd82711 100644 --- a/Dockerfile-arm64 +++ b/Dockerfile-arm64 @@ -12,6 +12,7 @@ RUN apt-get -y update \ RUN mkdir /app COPY ./*.py ./requirements.txt /app/ RUN pip install --no-cache-dir -r /app/requirements.txt +COPY ./enrichment-prompts / WORKDIR /app RUN curl -f -L -O https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt diff --git a/README.md b/README.md index 3fb2a8d..ee87268 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,9 @@ `driveway-monitor` accepts an RTSP video stream (or, for testing purposes, a video file) and uses the [YOLOv8 model](https://docs.ultralytics.com/models/yolov8/) to track objects in the video. When an object meets your notification criteria (highly customizable; see "Configuration" below), `driveway-monitor` will notify you via [Ntfy](https://ntfy.sh). The notification includes a snapshot of the object that triggered the notification and provides options to mute notifications for a period of time. -The model can run on your CPU or on NVIDIA or Apple Silicon GPUs. It would be possible to use a customized model, and in fact I originally planned to refine my own model based on YOLOv8, but it turned out that the pretrained YOLOv8 model seems to work fine. +The YOLO computer vision model can run on your CPU or on NVIDIA or Apple Silicon GPUs. It would be possible to use a customized model, and in fact I originally planned to refine my own model based on YOLOv8, but it turned out that the pretrained YOLOv8 model seems to work fine. + +Optionally, `driveway-monitor` can also use an instance of [Ollama](https://ollama.com) to provide a detailed description of the object that triggered the notification. [This short video](doc/ntfy-mute-ui.mov) gives an overview of the end result. A notification is received; clicking the "Mute" button results in another notifiation with options to extend the mute time period or unmute the system. Tapping on the notification would open an image of me in my driveway; this isn't shown in the video for privacy reasons. @@ -52,6 +54,7 @@ services: image: cdzombak/driveway-monitor:1-amd64-cuda volumes: - ./config.json:/config.json:ro + - ./enrichment-prompts:/enrichment-prompts:ro command: [ "--debug", @@ -143,6 +146,20 @@ The prediction process consumes a video stream frame-by-frame and feeds each fra The tracker process aggregates the model's predictions over time, building tracks that represent the movement of individual objects in the video stream. Every time a track is updated with a prediction from a new frame, the tracker evaluates the track against the notification criteria. If the track meets the criteria, a notification is triggered. +### Enrichment + +Enrichment is an optional feature that uses an [Ollama](https://ollama.com) model to generate a more detailed description of the object that triggered a notification. If the Ollama model succeeds, the resulting description is included in the notification's message. + +To use enrichment, you'll need a working Ollama setup with a multimodal model installed. `driveway-monitor` does not provide this, since it's not necessary for the core feature set, and honestly it provides little additional value. + +The best results I've gotten (which still are not stellar) are using [the LLaVA 13b model](https://ollama.com/library/llava). This usually returns a result in under 3 seconds (when running on a 2080 Ti). On a CPU or less powerful GPU, consider `llava:7b`, [`llava-llama3`](https://ollama.com/library/llava-llama3), or just skip enrichment altogether. + +You can change the timeout for Ollama enrichment to generate a response by setting `enrichment.timeout_s` in your config. If you want to use enrichment, I highly recommend setting an aggressive timeout to ensure `driveway-monitor`'s responsiveness. + +Using enrichment requires providing a _prompt file_ for each YOLO object classification (e.g. `car`, `truck`, `person`) you want to enrich. This allows giving different instructions to your Ollama model for people vs. cars, for example. The `enrichment_prompts` directory provides a useful set of prompt files to get you started. + +When running `driveway-monitor` in Docker, keep in mind that your enrichment prompt files must be mounted in the container, and the paths in your config file must reflect the paths inside the container. + ### Notifier (Configuration key: `notifier`.) @@ -174,6 +191,13 @@ The file is a single JSON object containing the following keys, or a subset ther - `tracker`: Configures the system that builds tracks from the model's detections over time. - `inactive_track_prune_s`: Specifies the number of seconds after which an inactive track is pruned. This prevents incorrectly adding a new prediction to an old track. - `track_connect_min_overlap`: Minimum overlap percentage of a prediction box with the average of the last 2 boxes in an existing track for the prediction to be added to that track. +- `enrichment`: Configures the subsystem that enriches notifications via the Ollama API. + - `enable`: Whether to enable enrichment via Ollama. Defaults to `false`. + - `endpoint`: Complete URL to the Ollama `/generate` endpoint, e.g. `http://localhost:11434/api/generate`. + - `keep_alive`: Ask Ollama to keep the model in memory for this long after the request. String, formatted like `60m`. [See the Ollama API docs](https://github.com/ollama/ollama/blob/main/docs/api.md#parameters). + - `model`: The name of the Ollama model to use, e.g. `llava` or `llava:13b`. + - `prompt_files`: Map of `YOLO classification name` → `path`. Each path is a file containing the prompt to give Ollama along with an image of that YOLO classification. + - `timeout_s`: Timeout for the Ollama request, in seconds. This includes connection/network time _and_ the time Ollama takes to generate a response. - `notifier`: Configures how notifications are sent. - `debounce_threshold_s`: Specifies the number of seconds to wait after a notification before sending another one for the same type of object. - `default_priority`: Default priority for notifications. ([See Ntfy docs on Message Priority](https://docs.ntfy.sh/publish/#message-priority).) diff --git a/config.example.json b/config.example.json index 8ee8a2f..534c8d3 100644 --- a/config.example.json +++ b/config.example.json @@ -35,6 +35,18 @@ }, "image_method": "attach" }, + "enrichment": { + "enable": true, + "endpoint": "https://mygpuserver.tailnet-example.ts.net:11434/api/generate", + "model": "llava", + "keep_alive": "60m", + "timeout_s": 5, + "prompt_files": { + "car": "enrichment_prompts/llava_prompt_car.txt", + "truck": "enrichment_prompts/llava_prompt_truck.txt", + "person": "enrichment_prompts/llava_prompt_person.txt" + } + }, "web": { "port": 5550, "external_base_url": "https://mymachine.tailnet-example.ts.net:5559" diff --git a/config.py b/config.py index 228a0b5..d39c730 100644 --- a/config.py +++ b/config.py @@ -258,5 +258,61 @@ def config_from_file( # health: cfg.health_pinger.req_timeout_s = int(cfg.model.liveness_tick_s - 1.0) + # enrichment: + enrichment_dict = cfg_dict.get("enrichment", {}) + cfg.notifier.enrichment.enable = enrichment_dict.get( + "enable", cfg.notifier.enrichment.enable + ) + if not isinstance(cfg.notifier.enrichment.enable, bool): + raise ConfigValidationError("enrichment.enable must be a bool") + if cfg.notifier.enrichment.enable: + cfg.notifier.enrichment.prompt_files = enrichment_dict.get( + "prompt_files", cfg.notifier.enrichment.prompt_files + ) + if not isinstance(cfg.notifier.enrichment.prompt_files, dict): + raise ConfigValidationError("enrichment.prompt_files must be a dict") + for k, v in cfg.notifier.enrichment.prompt_files.items(): + if not isinstance(k, str) or not isinstance(v, str): + raise ConfigValidationError( + "enrichment.prompt_files must be a dict of str -> str" + ) + try: + with open(v) as f: + f.read() + except Exception as e: + raise ConfigValidationError( + f"enrichment.prompt_files: error reading file '{v}': {e}" + ) + cfg.notifier.enrichment.endpoint = enrichment_dict.get( + "endpoint", cfg.notifier.enrichment.endpoint + ) + if not cfg.notifier.enrichment.endpoint or not isinstance( + cfg.notifier.enrichment.endpoint, str + ): + raise ConfigValidationError("enrichment.endpoint must be a string") + if not ( + cfg.notifier.enrichment.endpoint.casefold().startswith("http://") + or cfg.notifier.enrichment.endpoint.casefold().startswith("https://") + ): + # noinspection HttpUrlsUsage + raise ConfigValidationError( + "enrichment.endpoint must start with http:// or https://" + ) + cfg.notifier.enrichment.model = enrichment_dict.get( + "model", cfg.notifier.enrichment.model + ) + if not isinstance(cfg.notifier.enrichment.model, str): + raise ConfigValidationError("enrichment.model must be a string") + cfg.notifier.enrichment.timeout_s = enrichment_dict.get( + "timeout_s", cfg.notifier.enrichment.timeout_s + ) + if not isinstance(cfg.notifier.enrichment.timeout_s, (int, float)): + raise ConfigValidationError("enrichment.timeout_s must be a number") + cfg.notifier.enrichment.keep_alive = enrichment_dict.get( + "keep_alive", cfg.notifier.enrichment.keep_alive + ) + if not isinstance(cfg.notifier.enrichment.keep_alive, str): + raise ConfigValidationError("enrichment.keep_alive must be a str") + logger.info("config loaded & validated") return cfg diff --git a/enrichment-prompts/llava_prompt_car.txt b/enrichment-prompts/llava_prompt_car.txt new file mode 100644 index 0000000..561e807 --- /dev/null +++ b/enrichment-prompts/llava_prompt_car.txt @@ -0,0 +1,17 @@ +This is an image of a vehicle, taken from a security camera. Identify the vehicle's most likely type, according to the following rules: + +- If it looks like an Amazon delivery vehicle, its type is "Amazon delivery". Notes: Any vehicle with the word "prime" on it is an Amazon delivery vehicle. Any vehicle with Amazon's logo on it is an Amazon delivery vehicle. A vehicle that looks like a passenger car is NOT an Amazon delivery vehicle. +- If it looks like a UPS delivery vehicle, its type is "UPS delivery". Notes: UPS delivery vehicles are painted dark brown. Any light-colored vehicle is NOT a UPS delivery vehicle. +- If it looks like a FedEx delivery vehicle, its type is "FedEx delivery". Note: Any dark-colored vehicle is NOT a FedEx delivery vehicle. +- If it looks like a USPS delivery vehicle, its type is "USPS delivery". Note: Any dark-colored vehicle is NOT a USPS delivery vehicle. +- If it looks like a yellow DHL delivery van, its type is "DHL delivery". +- If it looks like a pizza delivery vehicle, its type is "pizza delivery". +- If it looks like a contractor's truck, plumber's truck, electrician's truck, or a construction vehicle, its type is "contractor". +- If it looks like a pickup truck, its type is "pickup truck". +- If it looks like a sedan, coupe, hatchback, or passenger car, its type is "passenger car". +- If it does not look like any of those, you should describe its type in 3 words or less. Do not include any punctuation or any non-alphanumeric characters. + +Your response MUST be a valid JSON object with exactly two keys, "desc" and "error": + +- "desc" will contain the vehicle type you identified/described. If you could not identify or describe the vehicle, "desc" is "unknown". If there was no vehicle in the image, "desc" is an empty string (""). +- IF AND ONLY IF you could not identify the vehicle, "error" will describe what went wrong. If you identified the vehicle's type, do not provide any error message. diff --git a/enrichment-prompts/llava_prompt_person.txt b/enrichment-prompts/llava_prompt_person.txt new file mode 100644 index 0000000..3a2baa6 --- /dev/null +++ b/enrichment-prompts/llava_prompt_person.txt @@ -0,0 +1,14 @@ +This is an image from a security camera. The image contains at least one person. + +Identify the person's most likely job, according to these rules: + +- If the person is wearing a brown uniform, their job is "UPS delivery". +- If the person is wearing a purple uniform, their job is "FedEx delivery". +- If the person is wearing a blue uniform or a blue vest, their job is "Amazon delivery". +- If the person appears to be wearing some other uniform, you should describe a job their uniform is commonly associated with, in 3 words or less. Do not include any punctuation or any non-alphanumeric characters. +- If the person isn't wearing a uniform commonly associated with a specific job, or you cannot guess their job for any other reason, their job is "unknown". + +Your response MUST be a valid JSON object with exactly two keys: "desc" and "error": + +- "desc" will contain the job you identified. If you could not identify the person's job, "desc" is "unknown". If there was no person in the image, "desc" is an empty string (""). +- IF AND ONLY IF you could not plausibly guess the person's job, "error" will describe what went wrong. If you made a guess at the person's job, do not provide any error message. diff --git a/enrichment-prompts/llava_prompt_truck.txt b/enrichment-prompts/llava_prompt_truck.txt new file mode 100644 index 0000000..561e807 --- /dev/null +++ b/enrichment-prompts/llava_prompt_truck.txt @@ -0,0 +1,17 @@ +This is an image of a vehicle, taken from a security camera. Identify the vehicle's most likely type, according to the following rules: + +- If it looks like an Amazon delivery vehicle, its type is "Amazon delivery". Notes: Any vehicle with the word "prime" on it is an Amazon delivery vehicle. Any vehicle with Amazon's logo on it is an Amazon delivery vehicle. A vehicle that looks like a passenger car is NOT an Amazon delivery vehicle. +- If it looks like a UPS delivery vehicle, its type is "UPS delivery". Notes: UPS delivery vehicles are painted dark brown. Any light-colored vehicle is NOT a UPS delivery vehicle. +- If it looks like a FedEx delivery vehicle, its type is "FedEx delivery". Note: Any dark-colored vehicle is NOT a FedEx delivery vehicle. +- If it looks like a USPS delivery vehicle, its type is "USPS delivery". Note: Any dark-colored vehicle is NOT a USPS delivery vehicle. +- If it looks like a yellow DHL delivery van, its type is "DHL delivery". +- If it looks like a pizza delivery vehicle, its type is "pizza delivery". +- If it looks like a contractor's truck, plumber's truck, electrician's truck, or a construction vehicle, its type is "contractor". +- If it looks like a pickup truck, its type is "pickup truck". +- If it looks like a sedan, coupe, hatchback, or passenger car, its type is "passenger car". +- If it does not look like any of those, you should describe its type in 3 words or less. Do not include any punctuation or any non-alphanumeric characters. + +Your response MUST be a valid JSON object with exactly two keys, "desc" and "error": + +- "desc" will contain the vehicle type you identified/described. If you could not identify or describe the vehicle, "desc" is "unknown". If there was no vehicle in the image, "desc" is an empty string (""). +- IF AND ONLY IF you could not identify the vehicle, "error" will describe what went wrong. If you identified the vehicle's type, do not provide any error message. diff --git a/ntfy.py b/ntfy.py index 990ebfb..584b7d2 100644 --- a/ntfy.py +++ b/ntfy.py @@ -1,5 +1,7 @@ +import base64 import dataclasses import datetime +import json import logging import multiprocessing import os.path @@ -51,8 +53,19 @@ class NtfyRecord: jpeg_image: Optional[bytes] +@dataclasses.dataclass +class EnrichmentConfig: + enable: bool = False + endpoint: str = "" + keep_alive: str = "240m" + model: str = "llava" + prompt_files: Dict[str, str] = dataclasses.field(default_factory=lambda: {}) + timeout_s: float = 5.0 + + @dataclasses.dataclass class NtfyConfig: + enrichment: EnrichmentConfig = dataclasses.field(default_factory=EnrichmentConfig) external_base_url: str = "http://localhost:5550" log_level: Optional[int] = logging.INFO topic: str = "driveway-monitor" @@ -84,9 +97,12 @@ class ObjectNotification(Notification): event: str id: str jpeg_image: Optional[bytes] + enriched_class: Optional[str] = None def message(self): - return f"{self.classification} {self.event}.".capitalize() + if self.enriched_class: + return f"Likely: {self.enriched_class}.".capitalize() + return self.title() def title(self): return f"{self.classification} {self.event}".capitalize() @@ -235,6 +251,82 @@ def _suppress(self, logger, n: ObjectNotification) -> bool: self._last_notification[n.classification] = n.t return False + def _enrich(self, logger, n: ObjectNotification) -> ObjectNotification: + if not self._config.enrichment.enable: + return n + if not n.jpeg_image: + return n + + prompt_file = self._config.enrichment.prompt_files.get(n.classification) + if not prompt_file: + return n + try: + with open(prompt_file, "r") as f: + enrichment_prompt = f.read() + except Exception as e: + logger.error(f"error reading enrichment prompt file '{prompt_file}': {e}") + return n + if not enrichment_prompt: + return n + + try: + resp = requests.post( + self._config.enrichment.endpoint, + json={ + "model": self._config.enrichment.model, + "stream": False, + "images": [ + base64.b64encode(n.jpeg_image).decode("ascii"), + ], + "keep_alive": self._config.enrichment.keep_alive, + "format": "json", + "prompt": enrichment_prompt, + }, + timeout=self._config.enrichment.timeout_s, + ) + parsed = resp.json() + except requests.Timeout: + logger.error("enrichment request timed out") + return n + except requests.RequestException as e: + logger.error(f"enrichment failed: {e}") + return n + + model_resp_str = parsed.get("response") + if not model_resp_str: + logger.error("enrichment response is missing") + return n + + try: + model_resp_parsed = json.loads(model_resp_str) + except json.JSONDecodeError as e: + logger.info(f"enrichment model did not produce valid JSON: {e}") + logger.info(f"response: {model_resp_str}") + return n + + if "type" not in model_resp_parsed and "error" not in model_resp_parsed: + logger.info("enrichment model did not produce expected JSON keys") + return n + + model_desc = model_resp_parsed.get("desc", "unknown") + if model_desc == "unknown" or model_desc == "": + model_err = model_resp_parsed.get("error") + if not model_err: + model_err = "(no error returned)" + logger.info( + f"enrichment model could not produce a useful description: {model_err}" + ) + return n + + return ObjectNotification( + t=n.t, + classification=n.classification, + event=n.event, + id=n.id, + jpeg_image=n.jpeg_image, + enriched_class=model_desc, + ) + def _run(self): logger = logging.getLogger(__name__) logging.basicConfig(level=self._config.log_level, format=LOG_DEFAULT_FMT) @@ -260,6 +352,7 @@ def _run(self): jpeg_image=n.jpeg_image, expires_at=n.t + datetime.timedelta(days=1), ) + n = self._enrich(logger, n) try: headers = self._prep_ntfy_headers(n)