From c14bf7a728b52aacb4c845844b3d8c948e45d826 Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Sat, 6 Jan 2024 23:57:27 +0100 Subject: [PATCH] Update mypy. --- clip_retrieval/clip_client.py | 16 ++++++++-------- clip_retrieval/load_clip.py | 3 +++ requirements-test.txt | 2 +- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/clip_retrieval/clip_client.py b/clip_retrieval/clip_client.py index 6e3c525d..eb536b62 100644 --- a/clip_retrieval/clip_client.py +++ b/clip_retrieval/clip_client.py @@ -4,7 +4,7 @@ import enum import json from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional import requests @@ -55,9 +55,9 @@ def __init__( def query( self, - text: str = None, - image: str = None, - embedding_input: list = None, + text: Optional[str] = None, + image: Optional[str] = None, + embedding_input: Optional[list] = None, ) -> List[Dict]: """ Given text or image/s, search for other captions/images that are semantically similar. @@ -95,10 +95,10 @@ def query( def __search_knn_api__( self, - text: str = None, - image: str = None, - image_url: str = None, - embedding_input: list = None, + text: Optional[str] = None, + image: Optional[str] = None, + image_url: Optional[str] = None, + embedding_input: Optional[list] = None, ) -> List: """ This function is used to send the request to the knn service. diff --git a/clip_retrieval/load_clip.py b/clip_retrieval/load_clip.py index 80e292fe..ed96966f 100644 --- a/clip_retrieval/load_clip.py +++ b/clip_retrieval/load_clip.py @@ -33,6 +33,9 @@ def encode_text(self, text): with autocast(device_type=self.device.type, dtype=self.dtype): return self.inner_model.get_text_features(text) + def forward(self, *args, **kwargs): + return self.inner_model(*args, **kwargs) + class OpenClipWrapper(nn.Module): """ diff --git a/requirements-test.txt b/requirements-test.txt index 125e9f4f..7388a963 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,6 +1,6 @@ img2dataset black==22.3.0 -mypy==0.950 +mypy==1.8.0 pylint==2.13.4 pytest-cov==3.0.0 pytest-xdist==2.5.0