add enrichment via Ollama multimodal models (e.g. LLaVA)
cdzombak committed Jul 15, 2024
1 parent bbe3fc4 commit b900ece
Showing 10 changed files with 238 additions and 2 deletions.
1 change: 1 addition & 0 deletions Dockerfile-amd64-cpu
@@ -12,6 +12,7 @@ RUN apt-get -y update \
RUN mkdir /app
COPY ./*.py ./requirements.txt /app/
RUN pip install --no-cache-dir -r /app/requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu
COPY ./enrichment-prompts /

WORKDIR /app
RUN curl -f -L -O https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt
1 change: 1 addition & 0 deletions Dockerfile-amd64-cuda
@@ -17,6 +17,7 @@ RUN mkdir /app
COPY ./*.py ./requirements.txt /app/
RUN pip install --no-cache-dir nvidia-tensorrt
RUN pip install --no-cache-dir -r /app/requirements.txt
COPY ./enrichment-prompts /

WORKDIR /app
RUN curl -f -L -O https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt
1 change: 1 addition & 0 deletions Dockerfile-arm64
@@ -12,6 +12,7 @@ RUN apt-get -y update \
RUN mkdir /app
COPY ./*.py ./requirements.txt /app/
RUN pip install --no-cache-dir -r /app/requirements.txt
COPY ./enrichment-prompts /

WORKDIR /app
RUN curl -f -L -O https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt
26 changes: 25 additions & 1 deletion README.md
@@ -4,7 +4,9 @@

`driveway-monitor` accepts an RTSP video stream (or, for testing purposes, a video file) and uses the [YOLOv8 model](https://docs.ultralytics.com/models/yolov8/) to track objects in the video. When an object meets your notification criteria (highly customizable; see "Configuration" below), `driveway-monitor` will notify you via [Ntfy](https://ntfy.sh). The notification includes a snapshot of the object that triggered the notification and provides options to mute notifications for a period of time.

The YOLO computer vision model can run on your CPU or on NVIDIA or Apple Silicon GPUs. It would be possible to use a customized model, and in fact I originally planned to refine my own model based on YOLOv8, but it turned out that the pretrained YOLOv8 model seems to work fine.

Optionally, `driveway-monitor` can also use an instance of [Ollama](https://ollama.com) to provide a detailed description of the object that triggered the notification.

[This short video](doc/ntfy-mute-ui.mov) gives an overview of the end result. A notification is received; clicking the "Mute" button results in another notification with options to extend the mute period or unmute the system. Tapping the notification would open an image of me in my driveway; this isn't shown in the video for privacy reasons.

@@ -52,6 +54,7 @@ services:
image: cdzombak/driveway-monitor:1-amd64-cuda
volumes:
- ./config.json:/config.json:ro
- ./enrichment-prompts:/enrichment-prompts:ro
command:
[
"--debug",
@@ -143,6 +146,20 @@ The prediction process consumes a video stream frame-by-frame and feeds each fra

The tracker process aggregates the model's predictions over time, building tracks that represent the movement of individual objects in the video stream. Every time a track is updated with a prediction from a new frame, the tracker evaluates the track against the notification criteria. If the track meets the criteria, a notification is triggered.

### Enrichment

Enrichment is an optional feature that uses an [Ollama](https://ollama.com) model to generate a more detailed description of the object that triggered a notification. If the Ollama model succeeds, the resulting description is included in the notification's message.

To use enrichment, you'll need a working Ollama setup with a multimodal model installed. `driveway-monitor` does not provide this, since it's not necessary for the core feature set, and honestly it provides little additional value.

The best results I've gotten (which still are not stellar) are using [the LLaVA 13b model](https://ollama.com/library/llava). This usually returns a result in under 3 seconds (when running on a 2080 Ti). On a CPU or less powerful GPU, consider `llava:7b`, [`llava-llama3`](https://ollama.com/library/llava-llama3), or just skip enrichment altogether.

You can change how long `driveway-monitor` waits for Ollama to generate a response by setting `enrichment.timeout_s` in your config. If you use enrichment, I highly recommend setting an aggressive timeout to keep `driveway-monitor` responsive.

Using enrichment requires providing a _prompt file_ for each YOLO object classification (e.g. `car`, `truck`, `person`) you want to enrich. This allows giving different instructions to your Ollama model for people vs. cars, for example. The `enrichment-prompts` directory provides a useful set of prompt files to get you started.

When running `driveway-monitor` in Docker, keep in mind that your enrichment prompt files must be mounted in the container, and the paths in your config file must reflect the paths inside the container.
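Under the hood, enrichment amounts to a single POST to Ollama's `/api/generate` endpoint, carrying the prompt plus a base64-encoded JPEG and asking for a non-streaming JSON reply. A minimal sketch of building that request (the helper function name is illustrative, not part of `driveway-monitor`):

```python
import base64


def build_enrichment_payload(model: str, prompt: str, jpeg_image: bytes,
                             keep_alive: str = "60m") -> dict:
    """Build the JSON body for Ollama's /api/generate endpoint."""
    return {
        "model": model,
        "prompt": prompt,
        # Ollama expects images as base64-encoded strings:
        "images": [base64.b64encode(jpeg_image).decode("ascii")],
        "stream": False,       # wait for the complete response in one reply
        "format": "json",      # ask Ollama to constrain output to JSON
        "keep_alive": keep_alive,
    }


# The actual call would then be something like:
#   resp = requests.post("http://localhost:11434/api/generate",
#                        json=build_enrichment_payload("llava", prompt, img),
#                        timeout=5)
```

The `timeout=` passed to `requests.post` is what `enrichment.timeout_s` controls, which is why it bounds both network time and generation time.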

### Notifier

(Configuration key: `notifier`.)
@@ -174,6 +191,13 @@ The file is a single JSON object containing the following keys, or a subset ther
- `tracker`: Configures the system that builds tracks from the model's detections over time.
- `inactive_track_prune_s`: Specifies the number of seconds after which an inactive track is pruned. This prevents incorrectly adding a new prediction to an old track.
- `track_connect_min_overlap`: Minimum overlap percentage of a prediction box with the average of the last 2 boxes in an existing track for the prediction to be added to that track.
- `enrichment`: Configures the subsystem that enriches notifications via the Ollama API.
- `enable`: Whether to enable enrichment via Ollama. Defaults to `false`.
- `endpoint`: Complete URL to the Ollama `/generate` endpoint, e.g. `http://localhost:11434/api/generate`.
- `keep_alive`: Ask Ollama to keep the model in memory for this long after the request. String, formatted like `60m`. [See the Ollama API docs](https://github.com/ollama/ollama/blob/main/docs/api.md#parameters).
- `model`: The name of the Ollama model to use, e.g. `llava` or `llava:13b`.
- `prompt_files`: Map of `YOLO classification name` → `path`. Each path is a file containing the prompt to give Ollama along with an image of that YOLO classification.
- `timeout_s`: Timeout for the Ollama request, in seconds. This includes connection/network time _and_ the time Ollama takes to generate a response.
- `notifier`: Configures how notifications are sent.
- `debounce_threshold_s`: Specifies the number of seconds to wait after a notification before sending another one for the same type of object.
- `default_priority`: Default priority for notifications. ([See Ntfy docs on Message Priority](https://docs.ntfy.sh/publish/#message-priority).)
12 changes: 12 additions & 0 deletions config.example.json
@@ -35,6 +35,18 @@
},
"image_method": "attach"
},
"enrichment": {
"enable": true,
"endpoint": "https://mygpuserver.tailnet-example.ts.net:11434/api/generate",
"model": "llava",
"keep_alive": "60m",
"timeout_s": 5,
"prompt_files": {
      "car": "enrichment-prompts/llava_prompt_car.txt",
      "truck": "enrichment-prompts/llava_prompt_truck.txt",
      "person": "enrichment-prompts/llava_prompt_person.txt"
}
},
"web": {
"port": 5550,
"external_base_url": "https://mymachine.tailnet-example.ts.net:5559"
56 changes: 56 additions & 0 deletions config.py
@@ -258,5 +258,61 @@ def config_from_file(
# health:
cfg.health_pinger.req_timeout_s = int(cfg.model.liveness_tick_s - 1.0)

# enrichment:
enrichment_dict = cfg_dict.get("enrichment", {})
cfg.notifier.enrichment.enable = enrichment_dict.get(
"enable", cfg.notifier.enrichment.enable
)
if not isinstance(cfg.notifier.enrichment.enable, bool):
raise ConfigValidationError("enrichment.enable must be a bool")
if cfg.notifier.enrichment.enable:
cfg.notifier.enrichment.prompt_files = enrichment_dict.get(
"prompt_files", cfg.notifier.enrichment.prompt_files
)
if not isinstance(cfg.notifier.enrichment.prompt_files, dict):
raise ConfigValidationError("enrichment.prompt_files must be a dict")
for k, v in cfg.notifier.enrichment.prompt_files.items():
if not isinstance(k, str) or not isinstance(v, str):
raise ConfigValidationError(
"enrichment.prompt_files must be a dict of str -> str"
)
try:
with open(v) as f:
f.read()
except Exception as e:
raise ConfigValidationError(
f"enrichment.prompt_files: error reading file '{v}': {e}"
)
cfg.notifier.enrichment.endpoint = enrichment_dict.get(
"endpoint", cfg.notifier.enrichment.endpoint
)
if not cfg.notifier.enrichment.endpoint or not isinstance(
cfg.notifier.enrichment.endpoint, str
):
raise ConfigValidationError("enrichment.endpoint must be a string")
if not (
cfg.notifier.enrichment.endpoint.casefold().startswith("http://")
or cfg.notifier.enrichment.endpoint.casefold().startswith("https://")
):
# noinspection HttpUrlsUsage
raise ConfigValidationError(
"enrichment.endpoint must start with http:// or https://"
)
cfg.notifier.enrichment.model = enrichment_dict.get(
"model", cfg.notifier.enrichment.model
)
if not isinstance(cfg.notifier.enrichment.model, str):
raise ConfigValidationError("enrichment.model must be a string")
cfg.notifier.enrichment.timeout_s = enrichment_dict.get(
"timeout_s", cfg.notifier.enrichment.timeout_s
)
if not isinstance(cfg.notifier.enrichment.timeout_s, (int, float)):
raise ConfigValidationError("enrichment.timeout_s must be a number")
cfg.notifier.enrichment.keep_alive = enrichment_dict.get(
"keep_alive", cfg.notifier.enrichment.keep_alive
)
if not isinstance(cfg.notifier.enrichment.keep_alive, str):
raise ConfigValidationError("enrichment.keep_alive must be a str")

logger.info("config loaded & validated")
return cfg
17 changes: 17 additions & 0 deletions enrichment-prompts/llava_prompt_car.txt
@@ -0,0 +1,17 @@
This is an image of a vehicle, taken from a security camera. Identify the vehicle's most likely type, according to the following rules:

- If it looks like an Amazon delivery vehicle, its type is "Amazon delivery". Notes: Any vehicle with the word "prime" on it is an Amazon delivery vehicle. Any vehicle with Amazon's logo on it is an Amazon delivery vehicle. A vehicle that looks like a passenger car is NOT an Amazon delivery vehicle.
- If it looks like a UPS delivery vehicle, its type is "UPS delivery". Notes: UPS delivery vehicles are painted dark brown. Any light-colored vehicle is NOT a UPS delivery vehicle.
- If it looks like a FedEx delivery vehicle, its type is "FedEx delivery". Note: Any dark-colored vehicle is NOT a FedEx delivery vehicle.
- If it looks like a USPS delivery vehicle, its type is "USPS delivery". Note: Any dark-colored vehicle is NOT a USPS delivery vehicle.
- If it looks like a yellow DHL delivery van, its type is "DHL delivery".
- If it looks like a pizza delivery vehicle, its type is "pizza delivery".
- If it looks like a contractor's truck, plumber's truck, electrician's truck, or a construction vehicle, its type is "contractor".
- If it looks like a pickup truck, its type is "pickup truck".
- If it looks like a sedan, coupe, hatchback, or passenger car, its type is "passenger car".
- If it does not look like any of those, you should describe its type in 3 words or less. Do not include any punctuation or any non-alphanumeric characters.

Your response MUST be a valid JSON object with exactly two keys, "desc" and "error":

- "desc" will contain the vehicle type you identified/described. If you could not identify or describe the vehicle, "desc" is "unknown". If there was no vehicle in the image, "desc" is an empty string ("").
- IF AND ONLY IF you could not identify the vehicle, "error" will describe what went wrong. If you identified the vehicle's type, do not provide any error message.
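These prompt files pin the model to a small response contract: a JSON object with exactly the keys `desc` and `error`, where an empty or `"unknown"` description means nothing useful was identified. A consumer-side sketch of that contract (a hypothetical helper, not code from this commit):

```python
import json
from typing import Optional


def parse_enrichment_response(raw: str) -> Optional[str]:
    """Return a usable description from the model's JSON reply, or None."""
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return None  # the model ignored the "valid JSON" instruction
    desc = parsed.get("desc", "")
    if desc in ("", "unknown"):
        return None  # "" = nothing in frame; "unknown" = could not identify
    return desc
```

A caller can then fall back to the plain YOLO classification whenever this returns `None`, which is effectively what the notifier's enrichment path does.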
14 changes: 14 additions & 0 deletions enrichment-prompts/llava_prompt_person.txt
@@ -0,0 +1,14 @@
This is an image from a security camera. The image contains at least one person.

Identify the person's most likely job, according to these rules:

- If the person is wearing a brown uniform, their job is "UPS delivery".
- If the person is wearing a purple uniform, their job is "FedEx delivery".
- If the person is wearing a blue uniform or a blue vest, their job is "Amazon delivery".
- If the person appears to be wearing some other uniform, you should describe a job their uniform is commonly associated with, in 3 words or less. Do not include any punctuation or any non-alphanumeric characters.
- If the person isn't wearing a uniform commonly associated with a specific job, or you cannot guess their job for any other reason, their job is "unknown".

Your response MUST be a valid JSON object with exactly two keys: "desc" and "error":

- "desc" will contain the job you identified. If you could not identify the person's job, "desc" is "unknown". If there was no person in the image, "desc" is an empty string ("").
- IF AND ONLY IF you could not plausibly guess the person's job, "error" will describe what went wrong. If you made a guess at the person's job, do not provide any error message.
17 changes: 17 additions & 0 deletions enrichment-prompts/llava_prompt_truck.txt
@@ -0,0 +1,17 @@
This is an image of a vehicle, taken from a security camera. Identify the vehicle's most likely type, according to the following rules:

- If it looks like an Amazon delivery vehicle, its type is "Amazon delivery". Notes: Any vehicle with the word "prime" on it is an Amazon delivery vehicle. Any vehicle with Amazon's logo on it is an Amazon delivery vehicle. A vehicle that looks like a passenger car is NOT an Amazon delivery vehicle.
- If it looks like a UPS delivery vehicle, its type is "UPS delivery". Notes: UPS delivery vehicles are painted dark brown. Any light-colored vehicle is NOT a UPS delivery vehicle.
- If it looks like a FedEx delivery vehicle, its type is "FedEx delivery". Note: Any dark-colored vehicle is NOT a FedEx delivery vehicle.
- If it looks like a USPS delivery vehicle, its type is "USPS delivery". Note: Any dark-colored vehicle is NOT a USPS delivery vehicle.
- If it looks like a yellow DHL delivery van, its type is "DHL delivery".
- If it looks like a pizza delivery vehicle, its type is "pizza delivery".
- If it looks like a contractor's truck, plumber's truck, electrician's truck, or a construction vehicle, its type is "contractor".
- If it looks like a pickup truck, its type is "pickup truck".
- If it looks like a sedan, coupe, hatchback, or passenger car, its type is "passenger car".
- If it does not look like any of those, you should describe its type in 3 words or less. Do not include any punctuation or any non-alphanumeric characters.

Your response MUST be a valid JSON object with exactly two keys, "desc" and "error":

- "desc" will contain the vehicle type you identified/described. If you could not identify or describe the vehicle, "desc" is "unknown". If there was no vehicle in the image, "desc" is an empty string ("").
- IF AND ONLY IF you could not identify the vehicle, "error" will describe what went wrong. If you identified the vehicle's type, do not provide any error message.
95 changes: 94 additions & 1 deletion ntfy.py
@@ -1,5 +1,7 @@
import base64
import dataclasses
import datetime
import json
import logging
import multiprocessing
import os.path
@@ -51,8 +53,19 @@ class NtfyRecord:
jpeg_image: Optional[bytes]


@dataclasses.dataclass
class EnrichmentConfig:
enable: bool = False
endpoint: str = ""
keep_alive: str = "240m"
model: str = "llava"
prompt_files: Dict[str, str] = dataclasses.field(default_factory=lambda: {})
timeout_s: float = 5.0


@dataclasses.dataclass
class NtfyConfig:
enrichment: EnrichmentConfig = dataclasses.field(default_factory=EnrichmentConfig)
external_base_url: str = "http://localhost:5550"
log_level: Optional[int] = logging.INFO
topic: str = "driveway-monitor"
@@ -84,9 +97,12 @@ class ObjectNotification(Notification):
event: str
id: str
jpeg_image: Optional[bytes]
enriched_class: Optional[str] = None

def message(self):
        if self.enriched_class:
            # No .capitalize() here: it would lowercase the rest of the
            # model's description (e.g. "UPS delivery" -> "ups delivery").
            return f"Likely: {self.enriched_class}."
        return self.title()

def title(self):
return f"{self.classification} {self.event}".capitalize()
@@ -235,6 +251,82 @@ def _suppress(self, logger, n: ObjectNotification) -> bool:
self._last_notification[n.classification] = n.t
return False

def _enrich(self, logger, n: ObjectNotification) -> ObjectNotification:
if not self._config.enrichment.enable:
return n
if not n.jpeg_image:
return n

prompt_file = self._config.enrichment.prompt_files.get(n.classification)
if not prompt_file:
return n
try:
with open(prompt_file, "r") as f:
enrichment_prompt = f.read()
except Exception as e:
logger.error(f"error reading enrichment prompt file '{prompt_file}': {e}")
return n
if not enrichment_prompt:
return n

try:
resp = requests.post(
self._config.enrichment.endpoint,
json={
"model": self._config.enrichment.model,
"stream": False,
"images": [
base64.b64encode(n.jpeg_image).decode("ascii"),
],
"keep_alive": self._config.enrichment.keep_alive,
"format": "json",
"prompt": enrichment_prompt,
},
timeout=self._config.enrichment.timeout_s,
)
parsed = resp.json()
except requests.Timeout:
logger.error("enrichment request timed out")
return n
except requests.RequestException as e:
logger.error(f"enrichment failed: {e}")
return n

model_resp_str = parsed.get("response")
if not model_resp_str:
logger.error("enrichment response is missing")
return n

try:
model_resp_parsed = json.loads(model_resp_str)
except json.JSONDecodeError as e:
logger.info(f"enrichment model did not produce valid JSON: {e}")
logger.info(f"response: {model_resp_str}")
return n

        # The prompt files require exactly the keys "desc" and "error":
        if "desc" not in model_resp_parsed and "error" not in model_resp_parsed:
logger.info("enrichment model did not produce expected JSON keys")
return n

model_desc = model_resp_parsed.get("desc", "unknown")
if model_desc == "unknown" or model_desc == "":
model_err = model_resp_parsed.get("error")
if not model_err:
model_err = "(no error returned)"
logger.info(
f"enrichment model could not produce a useful description: {model_err}"
)
return n

return ObjectNotification(
t=n.t,
classification=n.classification,
event=n.event,
id=n.id,
jpeg_image=n.jpeg_image,
enriched_class=model_desc,
)

def _run(self):
logger = logging.getLogger(__name__)
logging.basicConfig(level=self._config.log_level, format=LOG_DEFAULT_FMT)
@@ -260,6 +352,7 @@ def _run(self):
jpeg_image=n.jpeg_image,
expires_at=n.t + datetime.timedelta(days=1),
)
n = self._enrich(logger, n)

try:
headers = self._prep_ntfy_headers(n)
