Merge pull request #64 from RolnickLab/feat/2024-release
2024 Release updates
mihow authored Nov 22, 2024
2 parents 266a7a4 + 9188978 commit 0f55d0b
Showing 37 changed files with 3,883 additions and 1,418 deletions.
8 changes: 4 additions & 4 deletions .gitignore
@@ -102,13 +102,13 @@ celerybeat.pid
*.sage.py

# Environments
.env
.venv
env/
venv/
.env*
.venv/
.venv*/
ENV/
env.bak/
venv.bak/
bak/

# Spyder project settings
.spyderproject
69 changes: 58 additions & 11 deletions README.md
@@ -60,6 +60,14 @@ Test the whole backend pipeline without the GUI using this command

```sh
python trapdata/tests/test_pipeline.py
# or
ami test pipeline
```

Run all other tests with:

```sh
ami test all
```

## GUI Usage
@@ -149,19 +157,58 @@ A script is available in the repo source to run the commands above.



## KG Notes for adding new models
## Adding new models

- To add new models, save the pt and json files to:
```
~/Library/Application Support/trapdata/models
```
or wherever you set the appropriate dir in settings.
The json file is simply a dict of species name and index.
1) Create a new inference class in `trapdata/ml/models/classification.py` or `trapdata/ml/models/localization.py`. All models inherit from `InferenceBaseClass`, but more specific base classes exist for classification, localization, and different architectures; choose the most appropriate one to inherit from. The easiest approach is to copy an existing inference class that is similar to the model you are adding (see the sketch below).

2) Upload your model weights and category map to a cloud storage service and make sure the files are publicly accessible via a URL. The weights will be downloaded the first time the model is run. Alternatively, you can manually place the model weights in the configured `USER_DATA_PATH` directory under the subdirectory `USER_DATA_PATH/models/` (on macOS this is `~/Library/Application Support/trapdata/models`). However, the model will not be available to other users unless they also add the weights manually. The category map JSON file is simply a dict of species names and their indexes in your model's final layer; see the existing category maps for examples.

3) Select your model in the GUI settings or set the `SPECIES_CLASSIFICATION_MODEL` setting. If the model inherits from the `SpeciesClassifier` class, it will automatically appear as one of the valid choices.

Then you need to create a class in `trapdata/ml/models/classification.py` or `trapdata/ml/models/localization.py` and add the model details.
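
For illustration, here is a minimal sketch of a new species classifier. The class name, URLs, and attribute names are hypothetical; copy the exact fields from an existing inference class in `trapdata/ml/models/classification.py`.

```python
from trapdata.ml.models.classification import SpeciesClassifier


class ExampleRegionMothClassifier(SpeciesClassifier):
    # Hypothetical sketch: the attribute names are illustrative, copy the
    # exact fields from an existing inference class in this module.
    name = "Example Region Moth Species Classifier 2024"
    description = "Classifier trained on moths from an example region"
    # Publicly accessible URLs; the files are downloaded on first run.
    weights_path = "https://example.com/models/example_region_resnet50.pt"
    labels_path = "https://example.com/models/example_region_category_map.json"
    # The category map is a JSON dict of species name -> index in the model's
    # final layer, e.g. {"Actias luna": 0, "Hyalophora cecropia": 1}
```
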
## Clearing the cache & starting fresh

- To clear the cache:
Remove the index of images and all detections and classifications by deleting the database file. This will not remove the images themselves, only the metadata about them. The database is located in the user data directory.

On macOS:
```
rm ~/Library/Application\ Support/trapdata/trapdata.db
```

On Linux:
```
rm ~/.config/trapdata/trapdata.db
```

On Windows:
```
del %AppData%\trapdata\trapdata.db
```

## Running the web API

The model inference pipeline can be run as a web API using FastAPI. This is what the Antenna platform uses to process images.

To run the API, use the following command:

```sh
ami api
```

View the interactive API docs at http://localhost:2000/
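
As a quick smoke test, the sketch below posts one image to the `/pipeline/process` endpoint defined in `trapdata/api/api.py`. The pipeline slug must be one of the keys in `PIPELINE_CHOICES`; the image id and URL are placeholders, and the `requests` library is assumed to be installed:

```python
import requests

# Any key from PIPELINE_CHOICES in trapdata/api/api.py is a valid pipeline.
# The image id and URL below are placeholders.
payload = {
    "pipeline": "quebec_vermont_moths_2023",
    "source_images": [
        {"id": "example-1", "url": "https://example.com/trap-image.jpg"},
    ],
}

resp = requests.post("http://localhost:2000/pipeline/process", json=payload)
resp.raise_for_status()
result = resp.json()
print(f"{len(result['detections'])} detections in {result['total_time']:.2f}s")
```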


## Web UI demo (Gradio)

A simple web UI is also available to test the inference pipeline. This is a quick way to test models on a remote server via a web browser.

```sh
ami gradio
```

Open http://localhost:7861/

Use ngrok to temporarily expose localhost to the internet:

```sh
ngrok http 7861
```
11 changes: 11 additions & 0 deletions gradio.conf
@@ -0,0 +1,11 @@
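# Supervisor (supervisord) program definition for the Gradio demo UI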
[program:gradio]
directory=/home/ubuntu/ami-data-companion
command=/home/ubuntu/ami-data-companion/.venv/bin/ami gradio
autostart=true
autorestart=true
# stopsignal=KILL
stopasgroup=true
killasgroup=true
stderr_logfile=/var/log/gradio.err.log
stdout_logfile=/var/log/gradio.out.log
# process_name=%(program_name)s_%(process_num)02d
3,007 changes: 2,006 additions & 1,001 deletions poetry.lock

Large diffs are not rendered by default.

17 changes: 12 additions & 5 deletions pyproject.toml
@@ -13,12 +13,11 @@ requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.dependencies]
python = "^3.9"
python = "^3.10"
pillow = "^9.5.0"
python-dateutil = "^2.8.2"
python-dotenv = "^1.0.0"
pydantic = "^1.10.7"
typer = "^0.7.0"
pydantic = "^2.5.0"
rich = "^13.3.3"
pandas = "^1.5.3"
sqlalchemy = ">2.0"
@@ -27,8 +26,8 @@ alembic = "^1.10.2"
psycopg2-binary = { version = "^2.9.5", optional = true }
sentry-sdk = "^1.18.0"
imagesize = "^1.4.1"
torch = "^2.0.0"
torchvision = "^0.15.1"
torch = "^2.1.0"
torchvision = "^0.16.0"
timm = "^0.6.13"
structlog = "^22.3.0"
kivy = { extras = ["base"], version = "^2.3.0" }
@@ -45,6 +44,14 @@ ipython = "^8.11.0"
pytest-cov = "^4.0.0"
pytest-asyncio = "^0.21.0"
pytest = "*"
numpy = "^1.26.2"
pip = "^23.3.1"
pydantic-settings = "^2.1.0"
boto3 = "^1.33.0"
botocore = "^1.33.0"
mypy-boto3-s3 = "^1.29.7"
typer = "^0.12.3"
gradio = "^4.41.0"


[tool.pytest.ini_options]
4 changes: 2 additions & 2 deletions scripts/start_db_container.sh
@@ -4,13 +4,13 @@ set -o errexit
set -o nounset

CONTAINER_NAME=ami-db
HOST_PORT=5432
HOST_PORT=5433
POSTGRES_VERSION=14
POSTGRES_DB=ami

docker run -d -i --name $CONTAINER_NAME -v "$(pwd)/db_data":/var/lib/postgresql/data --restart always -p $HOST_PORT:5432 -e POSTGRES_HOST_AUTH_METHOD=trust -e POSTGRES_DB=$POSTGRES_DB postgres:$POSTGRES_VERSION

docker logs ami-db --tail 100

echo 'Database started, Connection string: "postgresql://postgres@localhost:5432/ami"'
echo "Database started, Connection string: \"postgresql://postgres@localhost:${HOST_PORT}/${POSTGRES_DB}\""
echo "Stop (and destroy) database with 'docker stop $CONTAINER_NAME' && docker remove $CONTAINER_NAME"
13 changes: 7 additions & 6 deletions trapdata/__init__.py
@@ -1,17 +1,18 @@
import sentry_sdk

from .common import constants, utils
from .common.logs import logger
from .db.models.detections import DetectedObject
from .db.models.events import MonitoringSession
from .db.models.images import TrapImage

sentry_sdk.init(
dsn="https://d2f65f945fe343669bbd3be5116d5922@o4503927026876416.ingest.sentry.io/4503927029497856",
traces_sample_rate=1.0,
)
#

# import multiprocessing

from .common import constants, utils
from .common.logs import logger
from .db.models.detections import DetectedObject
from .db.models.events import MonitoringSession
from .db.models.images import TrapImage

__all__ = [
logger,
3 changes: 3 additions & 0 deletions trapdata/api/__init__.py
@@ -0,0 +1,3 @@
from trapdata.settings import read_settings

settings = read_settings()
194 changes: 194 additions & 0 deletions trapdata/api/api.py
@@ -0,0 +1,194 @@
"""
FastAPI interface for processing images through the localization and classification pipelines.
"""

import enum
import time

import fastapi
import pydantic
from rich import print

from ..common.logs import logger # noqa: F401
from . import settings
from .models.classification import (
APIMothClassifier,
MothClassifierBinary,
MothClassifierGlobal,
MothClassifierPanama,
MothClassifierPanama2024,
MothClassifierQuebecVermont,
MothClassifierTuringAnguilla,
MothClassifierTuringCostaRica,
MothClassifierUKDenmark,
)
from .models.localization import APIMothDetector
from .schemas import Detection, SourceImage

app = fastapi.FastAPI()


class SourceImageRequest(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="ignore")

# @TODO bring over new SourceImage & b64 validation from the lepsAI repo
id: str
url: str
# b64: str | None = None


class SourceImageResponse(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="ignore")

id: str
url: str


PIPELINE_CHOICES = {
"panama_moths_2023": MothClassifierPanama,
"panama_moths_2024": MothClassifierPanama2024,
"quebec_vermont_moths_2023": MothClassifierQuebecVermont,
"uk_denmark_moths_2023": MothClassifierUKDenmark,
"costa_rica_moths_turing_2024": MothClassifierTuringCostaRica,
"anguilla_moths_turing_2024": MothClassifierTuringAnguilla,
"global_moths_2024": MothClassifierGlobal,
}
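# Build a str -> str mapping so each enum member's value equals its pipeline
# slug; FastAPI renders these slugs as choices in the interactive docs.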
_pipeline_choices = {slug: slug for slug in PIPELINE_CHOICES}


PipelineChoice = enum.Enum("PipelineChoice", _pipeline_choices)


class PipelineRequest(pydantic.BaseModel):
pipeline: PipelineChoice
source_images: list[SourceImageRequest]


class PipelineResponse(pydantic.BaseModel):
pipeline: PipelineChoice
total_time: float
source_images: list[SourceImageResponse]
detections: list[Detection]


@app.get("/")
async def root():
return fastapi.responses.RedirectResponse("/docs")


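# Both URL forms are registered so requests with and without a trailing slash
# are handled directly, without a redirect.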
@app.post("/pipeline/process")
@app.post("/pipeline/process/")
async def process(data: PipelineRequest) -> PipelineResponse:
# Ensure that the source images are unique, filter out duplicates
source_images_index = {
source_image.id: source_image for source_image in data.source_images
}
incoming_source_images = list(source_images_index.values())
if len(incoming_source_images) != len(data.source_images):
logger.warning(
f"Removed {len(data.source_images) - len(incoming_source_images)} duplicate source images"
)

source_image_results = [
SourceImageResponse(**image.model_dump()) for image in incoming_source_images
]
source_images = [
SourceImage(**image.model_dump()) for image in incoming_source_images
]

start_time = time.time()
detector = APIMothDetector(
source_images=source_images,
batch_size=settings.localization_batch_size,
num_workers=settings.num_workers,
# single=True if len(source_images) == 1 else False,
single=True, # @TODO solve issues with reading images in multiprocessing
)
detector_results = detector.run()
num_pre_filter = len(detector_results)

filter = MothClassifierBinary(
source_images=source_images,
detections=detector_results,
batch_size=settings.classification_batch_size,
num_workers=settings.num_workers,
# single=True if len(detector_results) == 1 else False,
single=True, # @TODO solve issues with reading images in multiprocessing
filter_results=False, # Only save results with the positive_binary_label, @TODO make this configurable from request
)
filter.run()
# all_binary_classifications = filter.results

# Compare num detections with num moth detections
num_post_filter = len(filter.results)
logger.info(
f"Binary classifier returned {num_post_filter} out of {num_pre_filter} detections"
)

# Filter results based on positive_binary_label
moth_detections = []
non_moth_detections = []
for detection in filter.results:
for classification in detection.classifications:
if classification.classification == filter.positive_binary_label:
moth_detections.append(detection)
elif classification.classification == filter.negative_binary_label:
non_moth_detections.append(detection)
break

logger.info(
f"Sending {len(moth_detections)} out of {num_pre_filter} detections to the classifier"
)

Classifier = PIPELINE_CHOICES[data.pipeline.value]
classifier: APIMothClassifier = Classifier(
source_images=source_images,
detections=moth_detections,
batch_size=settings.classification_batch_size,
num_workers=settings.num_workers,
# single=True if len(filtered_detections) == 1 else False,
single=True, # @TODO solve issues with reading images in multiprocessing
)
classifier.run()
end_time = time.time()
seconds_elapsed = float(end_time - start_time)

# Return all detections, including those that were not classified as moths
all_detections = classifier.results + non_moth_detections

logger.info(
f"Processed {len(source_images)} images in {seconds_elapsed:.2f} seconds"
)
logger.info(f"Returning {len(all_detections)} detections")
print(all_detections)

# If the number of detections is greater than 100, it's suspicious. Log it.
if len(all_detections) > 100:
logger.warning(
f"Detected {len(all_detections)} detections. This is suspicious and may contain duplicates."
)

response = PipelineResponse(
pipeline=data.pipeline,
source_images=source_image_results,
detections=all_detections,
total_time=seconds_elapsed,
)
return response


# Future methods

# batch processing
# async def process_batch(data: PipelineRequest) -> PipelineResponse:
# pass

# render image crops and bboxes on top of the original image
# async def render(data: PipelineRequest) -> PipelineResponse:
# pass


if __name__ == "__main__":
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=2000)
