Option to return scores for all classes in a prediction #67

Open · wants to merge 22 commits into base: main

22 commits:
fac27c8 · feat: allow all softmax scores to be returned (mihow, Nov 22, 2024)
061bd8a · feat: test API requests in addition to underlying methods (mihow, Nov 22, 2024)
d9d9a7f · fix: initial api tests working with test image server (mihow, Nov 23, 2024)
e4e6076 · test: ensure pipeline requests respect the classification_num_predict… (mihow, Nov 23, 2024)
455650e · feat: update parameter name and API examples (mihow, Nov 23, 2024)
31cacb7 · chore: update dependencies lock file (mihow, Nov 23, 2024)
5b9d9a1 · Try new test workflow (mihow, Nov 26, 2024)
b5d70b4 · fix: run tests in importlib mode (mihow, Nov 26, 2024)
97a1cc2 · feat: add logits, always return all scores, assume calibrated (mihow, Nov 26, 2024)
5336933 · fix: line-lengths (mihow, Dec 6, 2024)
dd70e6a · fix: most deprecation warnings that distracted from test output (mihow, Dec 6, 2024)
8cc7999 · fix: image server for tests (mihow, Dec 6, 2024)
65237cc · feat: update schema to return algorithms and all scores (mihow, Dec 6, 2024)
e08605b · feat: return category map data in API response (mihow, Dec 6, 2024)
7b60b13 · feat: compress response (getting long!) (mihow, Dec 6, 2024)
2cfd591 · fix: save non-moth pipeline addition for another day (mihow, Dec 6, 2024)
afb946f · fix: tests (mihow, Dec 6, 2024)
ed68770 · feat: make port configurable (mihow, Dec 7, 2024)
ed8f27b · feat: update schemas and category map response (mihow, Dec 7, 2024)
1001bb0 · fix: don't print the massive output (mihow, Dec 7, 2024)
90ebd05 · feat: update schemas and endpoints to match new API (mihow, Jan 21, 2025)
2f7287b · feat: return all algorithm & category map details in /info (mihow, Jan 21, 2025)
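
Taken together, these commits make the processing endpoint return scores for every class in a prediction, along with the algorithms and category maps used. A minimal client sketch, assuming a locally running instance: the endpoint path, pipeline slug, and request shape come from the diff below, while the base URL, port, and image URL are illustrative placeholders.

    # Hypothetical client for the updated API. The "/process/" endpoint and the
    # request shape (pipeline name plus a list of {id, url} source images) appear
    # in this PR; the base URL and image URL are assumptions for illustration.
    import requests

    BASE_URL = "http://localhost:2000"  # placeholder; the port is configurable (commit ed68770)

    payload = {
        "pipeline": "quebec_vermont_moths_2023",
        "source_images": [
            {"id": "img-1", "url": "https://example.com/trap/img-1.jpg"},
        ],
    }

    resp = requests.post(f"{BASE_URL}/process/", json=payload)
    resp.raise_for_status()
    result = resp.json()

    # After this PR, each detection's classification should carry scores for
    # all classes in the category map (and logits, per commit 97a1cc2),
    # not only the top label.
    print(result["total_time"], len(result["detections"]))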
40 changes: 30 additions & 10 deletions .github/workflows/test.yml
@@ -1,6 +1,3 @@
-# This workflow will install Python dependencies, run tests and lint with a single version of Python
-# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
-
 name: Run tests
 
 on:
@@ -13,20 +10,43 @@ permissions:
   contents: read
 
 jobs:
-  build:
+  test:
+    name: Run Python Tests
     runs-on: ubuntu-latest
 
    steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - name: Set up Python 3.10
-        uses: actions/setup-python@main # Need latest version to use pyproject.toml instead of requirements.txt
+        uses: actions/setup-python@v5
        with:
           python-version: "3.10"
+          cache: "pip"
+          cache-dependency-path: |
+            poetry.lock
+            pyproject.toml
+
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+        with:
+          version: latest
+          virtualenvs-create: true
+          virtualenvs-in-project: true
+
+      - name: Load cached Poetry virtualenv
+        id: cached-poetry-dependencies
+        uses: actions/cache@v3
+        with:
+          path: .venv
+          key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
 
       - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install .
+        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
+        run: poetry install --no-interaction
 
       - name: Run tests
         run: |
-          pytest
+          # Clean any cached Python files before running tests
+          find . -type d -name "__pycache__" -exec rm -r {} +
+          find . -type f -name "*.pyc" -delete
+          poetry run pytest --import-mode=importlib
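
For contributors mirroring this step locally, a rough Python equivalent of the new test command (a sketch; it assumes the project's dependencies are already installed, e.g. via poetry install):

    # Local approximation of the CI test step above.
    import pathlib
    import shutil
    import sys

    import pytest

    # Mirror the workflow's cleanup of stale bytecode caches.
    for cache_dir in list(pathlib.Path(".").rglob("__pycache__")):
        shutil.rmtree(cache_dir, ignore_errors=True)

    # --import-mode=importlib matches the workflow and avoids test-module
    # name collisions without sys.path manipulation.
    sys.exit(pytest.main(["--import-mode=importlib"]))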
1,996 changes: 1,089 additions & 907 deletions poetry.lock

Large diffs are not rendered by default.

221 changes: 181 additions & 40 deletions trapdata/api/api.py
@@ -1,13 +1,14 @@
 """
-Fast API interface for processing images through the localization and classification pipelines.
+Fast API interface for processing images through the localization and classification
+pipelines.
 """
 
 import enum
 import time
 
 import fastapi
 import pydantic
-from rich import print
+from fastapi.middleware.gzip import GZipMiddleware
 
 from ..common.logs import logger  # noqa: F401
 from . import settings
@@ -23,70 +24,164 @@
     MothClassifierUKDenmark,
 )
 from .models.localization import APIMothDetector
-from .schemas import Detection, SourceImage
+from .schemas import (
+    AlgorithmCategoryMapResponse,
+    AlgorithmConfigResponse,
+    PipelineConfigResponse,
+)
+from .schemas import PipelineRequest as PipelineRequest_
+from .schemas import PipelineResultsResponse as PipelineResponse_
+from .schemas import ProcessingServiceInfoResponse, SourceImage, SourceImageResponse
 
 app = fastapi.FastAPI()
+app.add_middleware(GZipMiddleware)
 
 
-class SourceImageRequest(pydantic.BaseModel):
-    model_config = pydantic.ConfigDict(extra="ignore")
-
-    # @TODO bring over new SourceImage & b64 validation from the lepsAI repo
-    id: str
-    url: str
-    # b64: str | None = None
-
-
-class SourceImageResponse(pydantic.BaseModel):
-    model_config = pydantic.ConfigDict(extra="ignore")
-
-    id: str
-    url: str
-
-
-PIPELINE_CHOICES = {
+CLASSIFIER_CHOICES = {
     "panama_moths_2023": MothClassifierPanama,
     "panama_moths_2024": MothClassifierPanama2024,
     "quebec_vermont_moths_2023": MothClassifierQuebecVermont,
     "uk_denmark_moths_2023": MothClassifierUKDenmark,
     "costa_rica_moths_turing_2024": MothClassifierTuringCostaRica,
     "anguilla_moths_turing_2024": MothClassifierTuringAnguilla,
     "global_moths_2024": MothClassifierGlobal,
+    # "moth_binary": MothClassifierBinary,
 }
-_pipeline_choices = dict(zip(PIPELINE_CHOICES.keys(), list(PIPELINE_CHOICES.keys())))
+_classifier_choices = dict(
+    zip(CLASSIFIER_CHOICES.keys(), list(CLASSIFIER_CHOICES.keys()))
+)
 
-PipelineChoice = enum.Enum("PipelineChoice", _pipeline_choices)
+PipelineChoice = enum.Enum("PipelineChoice", _classifier_choices)
 
 
-class PipelineRequest(pydantic.BaseModel):
-    pipeline: PipelineChoice
-    source_images: list[SourceImageRequest]
-
-
-class PipelineResponse(pydantic.BaseModel):
-    pipeline: PipelineChoice
-    total_time: float
-    source_images: list[SourceImageResponse]
-    detections: list[Detection]
+def make_category_map_response(
+    model: APIMothDetector | APIMothClassifier,
+    default_taxon_rank: str = "SPECIES",
+) -> AlgorithmCategoryMapResponse:
+    categories_sorted_by_index = sorted(model.category_map.items(), key=lambda x: x[0])
+    # as list of dicts:
+    categories_sorted_by_index = [
+        {
+            "index": index,
+            "label": label,
+            "taxon_rank": default_taxon_rank,
+        }
+        for index, label in categories_sorted_by_index
+    ]
+    label_strings_sorted_by_index = [cat["label"] for cat in categories_sorted_by_index]
+    return AlgorithmCategoryMapResponse(
+        data=categories_sorted_by_index,
+        labels=label_strings_sorted_by_index,
+        uri=model.labels_path,
+    )
+
+
+def make_algorithm_response(
+    model: APIMothDetector | APIMothClassifier,
+) -> AlgorithmConfigResponse:
+    category_map = make_category_map_response(model) if model.category_map else None
+    return AlgorithmConfigResponse(
+        name=model.name,
+        key=model.get_key(),
+        task_type=model.task_type,
+        description=model.description,
+        category_map=category_map,
+        uri=model.weights_path,
+    )
+
+
+def make_algorithm_config_response(
+    model: APIMothDetector | APIMothClassifier,
+) -> AlgorithmConfigResponse:
+    category_map = make_category_map_response(model)
+    return AlgorithmConfigResponse(
+        name=model.name,
+        key=model.get_key(),
+        task_type=model.task_type,
+        description=model.description,
+        category_map=category_map,
+        uri=model.weights_path,
+    )
+
+
+def make_pipeline_config_response(
+    Classifier: type[APIMothClassifier],
+) -> PipelineConfigResponse:
+    detector = APIMothDetector(
+        source_images=[],
+    )
+
+    binary_classifier = MothClassifierBinary(
+        source_images=[],
+        detections=[],
+    )
+
+    classifier = Classifier(
+        source_images=[],
+        detections=[],
+        batch_size=settings.classification_batch_size,
+        num_workers=settings.num_workers,
+        terminal=True,
+    )
+
+    return PipelineConfigResponse(
+        name=classifier.name,
+        slug=classifier.get_key(),
+        description=classifier.description,
+        version=1,
+        algorithms=[
+            make_algorithm_config_response(detector),
+            make_algorithm_config_response(binary_classifier),
+            make_algorithm_config_response(classifier),
+        ],
+    )
+
+
+# @TODO This requires loading all models into memory! Can we avoid this?
+PIPELINE_CONFIGS = [
+    make_pipeline_config_response(classifier_class)
+    for classifier_class in CLASSIFIER_CHOICES.values()
+]
+
+
+class PipelineRequest(PipelineRequest_):
+    pipeline: PipelineChoice = pydantic.Field(
+        description=PipelineRequest_.model_fields["pipeline"].description,
+        examples=list(_classifier_choices.keys()),
+    )
+
+
+class PipelineResponse(PipelineResponse_):
+    pipeline: PipelineChoice = pydantic.Field(
+        PipelineChoice,
+        description=PipelineResponse_.model_fields["pipeline"].description,
+        examples=list(_classifier_choices.keys()),
+    )
 
 
 @app.get("/")
 async def root():
     return fastapi.responses.RedirectResponse("/docs")
 
 
-@app.post("/pipeline/process")
-@app.post("/pipeline/process/")
+@app.post(
+    "/pipeline/process/", deprecated=True, tags=["services"]
+)  # old endpoint, deprecated, remove after jan 2025
+@app.post("/process/", tags=["services"])  # new endpoint
 async def process(data: PipelineRequest) -> PipelineResponse:
+    algorithms_used: dict[str, AlgorithmConfigResponse] = {}
+
     # Ensure that the source images are unique, filter out duplicates
     source_images_index = {
         source_image.id: source_image for source_image in data.source_images
     }
     incoming_source_images = list(source_images_index.values())
     if len(incoming_source_images) != len(data.source_images):
         logger.warning(
-            f"Removed {len(data.source_images) - len(incoming_source_images)} duplicate source images"
+            f"Removed {len(data.source_images) - len(incoming_source_images)} "
+            "duplicate source images"
         )
 
     source_image_results = [
@@ -106,6 +201,7 @@ async def process(data: PipelineRequest) -> PipelineResponse:
     )
     detector_results = detector.run()
     num_pre_filter = len(detector_results)
+    algorithms_used[detector.get_key()] = make_algorithm_response(detector)
 
     filter = MothClassifierBinary(
         source_images=source_images,
@@ -114,15 +210,15 @@ async def process(data: PipelineRequest) -> PipelineResponse:
         num_workers=settings.num_workers,
         # single=True if len(detector_results) == 1 else False,
         single=True,  # @TODO solve issues with reading images in multiprocessing
-        filter_results=False,  # Only save results with the positive_binary_label, @TODO make this configurable from request
+        terminal=False,
     )
     filter.run()
-    # all_binary_classifications = filter.results
+    algorithms_used[filter.get_key()] = make_algorithm_response(filter)
 
     # Compare num detections with num moth detections
     num_post_filter = len(filter.results)
     logger.info(
-        f"Binary classifier returned {num_post_filter} out of {num_pre_filter} detections"
+        f"Binary classifier returned {num_post_filter} of {num_pre_filter} detections"
     )
 
     # Filter results based on positive_binary_label
@@ -137,21 +233,25 @@ async def process(data: PipelineRequest) -> PipelineResponse:
             break
 
     logger.info(
-        f"Sending {len(moth_detections)} out of {num_pre_filter} detections to the classifier"
+        f"Sending {len(moth_detections)} of {num_pre_filter} "
+        "detections to the classifier"
     )
 
-    Classifier = PIPELINE_CHOICES[data.pipeline.value]
+    Classifier = CLASSIFIER_CHOICES[str(data.pipeline)]
     classifier: APIMothClassifier = Classifier(
         source_images=source_images,
         detections=moth_detections,
         batch_size=settings.classification_batch_size,
         num_workers=settings.num_workers,
         # single=True if len(filtered_detections) == 1 else False,
         single=True,  # @TODO solve issues with reading images in multiprocessing
+        example_config_param=data.config.example_config_param,
         terminal=True,
     )
     classifier.run()
     end_time = time.time()
     seconds_elapsed = float(end_time - start_time)
+    algorithms_used[classifier.get_key()] = make_algorithm_response(classifier)
 
     # Return all detections, including those that were not classified as moths
     all_detections = classifier.results + non_moth_detections
@@ -160,23 +260,64 @@ async def process(data: PipelineRequest) -> PipelineResponse:
         f"Processed {len(source_images)} images in {seconds_elapsed:.2f} seconds"
     )
     logger.info(f"Returning {len(all_detections)} detections")
-    print(all_detections)
+    # print(all_detections)
 
     # If the number of detections is greater than 100, its suspicious. Log it.
     if len(all_detections) > 100:
         logger.warning(
-            f"Detected {len(all_detections)} detections. This is suspicious and may contain duplicates."
+            f"Detected {len(all_detections)} detections. "
+            "This is suspicious and may contain duplicates."
         )
 
     response = PipelineResponse(
         pipeline=data.pipeline,
+        algorithms=algorithms_used,
         source_images=source_image_results,
         detections=all_detections,
         total_time=seconds_elapsed,
     )
     return response
 
 
+@app.get("/info", tags=["services"])
+async def info() -> ProcessingServiceInfoResponse:
+    info = ProcessingServiceInfoResponse(
+        name="Antenna Inference API",
+        description=(
+            "The primary endpoint for processing images for the Antenna platform. "
+            "This API provides access to multiple detection and classification "
+            "algorithms by multiple labs for processing images of moths."
+        ),
+        pipelines=PIPELINE_CONFIGS,
+        # algorithms=list(algorithm_choices.values()),
+    )
+    return info
+
+
+# Check if the server is online
+@app.get("/livez", tags=["health checks"])
+async def livez():
+    return fastapi.responses.JSONResponse(status_code=200, content={"status": True})
+
+
+# Check if the pipelines are ready to process data
+@app.get("/readyz", tags=["health checks"])
+async def readyz():
+    """
+    Check if the server is ready to process data.
+
+    Returns a list of pipeline slugs that are online and ready to process data.
+    @TODO may need to simplify this to just return True/False. Pipeline algorithms will
+    likely be loaded into memory on-demand when the pipeline is selected.
+    """
+    if _classifier_choices:
+        return fastapi.responses.JSONResponse(
+            status_code=200, content={"status": list(_classifier_choices.keys())}
+        )
+    else:
+        return fastapi.responses.JSONResponse(status_code=503, content={"status": []})
 
 
 # Future methods
 
 # batch processing
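
A side note on the PipelineChoice construction in the diff above: the functional enum API derives both member names and values from the registry keys, so the set of valid pipeline slugs in the OpenAPI schema always tracks CLASSIFIER_CHOICES. A standalone sketch, with the classifier classes stubbed out as strings so it runs on its own:

    # Standalone sketch of the PipelineChoice construction used in api.py.
    # The real CLASSIFIER_CHOICES maps slugs to classifier classes; string
    # stand-ins are used here for illustration.
    import enum

    CLASSIFIER_CHOICES = {
        "panama_moths_2023": "MothClassifierPanama",
        "global_moths_2024": "MothClassifierGlobal",
    }
    _classifier_choices = dict(
        zip(CLASSIFIER_CHOICES.keys(), list(CLASSIFIER_CHOICES.keys()))
    )

    # Functional enum API: member names and values are both the pipeline
    # slugs, so pydantic can validate the `pipeline` field against them.
    PipelineChoice = enum.Enum("PipelineChoice", _classifier_choices)

    assert PipelineChoice("global_moths_2024") is PipelineChoice.global_moths_2024
    assert PipelineChoice.panama_moths_2023.value == "panama_moths_2023"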