Implement webdataset input format and related improvements
* webdataset input format
* save in batch
* test files in tests folder
* save metadata as parquet
* use autofaiss in a new clip index
* remove indexing from clip batch and rename to clip inference
rom1504 committed Sep 3, 2021
1 parent 97a17e0 commit 9094ff3
Showing 17 changed files with 658 additions and 281 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,32 @@
name: Continuous integration

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install, lint and unit tests
      run: |
        python3 -m venv .env
        source .env/bin/activate
        pip install --upgrade pip
        pip install -e .
        pip install -r requirements-test.txt
        python -m pytest -v tests -s
5 changes: 4 additions & 1 deletion .gitignore
@@ -8,4 +8,7 @@ myimglist.txt
output_folder
indice_folder
image_folder
cat
cat
embedding_folder
index_folder
indices_paths.json
55 changes: 46 additions & 9 deletions README.md
@@ -5,8 +5,9 @@

Easily compute clip embeddings and build a clip retrieval system with them.

* clip batch allows you to quickly (1500 sample/s on a 3080) compute image and text embeddings and indices
* clip filter allows you to filter out the data using the clip embeddings
* clip inference allows you to quickly (1500 sample/s on a 3080) compute image and text embeddings
* clip index builds efficient indices out of the embeddings
* clip filter allows you to filter out the data using the clip index
* clip back hosts the indices with a simple flask service
* clip service is a simple ui querying the back

@@ -17,7 +18,7 @@ Interested to learn about semantic search in general? You can read my [medium p

pip install clip-retrieval

## clip batch
## clip inference

Get some images in an `example_folder`, for example by doing:
```
@@ -29,15 +30,42 @@ img2dataset --url_list=myimglist.txt --output_folder=image_folder --thread_count
```
You can also put text files with the same names as the images in that folder, to get the text embeddings.

Then run `clip-retrieval batch --dataset_path image_folder --output_folder indice_folder`
Then run `clip-retrieval inference --input_dataset image_folder --output_folder embeddings_folder`

Output folder will contain (see the loading example after this list):
* description_list containing the list of caption line by line
* image_list containing the file path of images line by line
* img_emb.npy containing the image embeddings as numpy
* text_emb.npy containing the text embeddings as numpy
* img_emb/
* img_emb_0.npy containing the image embeddings as numpy
* text_emb/
* text_emb_0.npy containing the text embeddings as numpy
* metadata/
* metadata_0.parquet containing the image paths, captions and metadata
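
The shards written by clip inference can be loaded back with numpy and pandas. A minimal sketch, assuming shard 0 exists and that row i of metadata_0.parquet describes row i of img_emb_0.npy (the same alignment clip back relies on further down):

```python
import numpy as np
import pandas as pd

# Load the first shard of image embeddings and its metadata;
# row i of the parquet file is assumed to describe row i of the matrix.
img_emb = np.load("embeddings_folder/img_emb/img_emb_0.npy")
meta = pd.read_parquet("embeddings_folder/metadata/metadata_0.parquet")
print(img_emb.shape, meta.columns.tolist())
```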

### API

clip_inference turns a set of texts and images into clip embeddings (a minimal Python call is sketched after the parameter list)

* **input_dataset** Path to input dataset. Folder if input_format is files. Bash brace pattern such as "{000..150}.tar" (see https://pypi.org/project/braceexpand/) if webdataset (*required*)
* **output_folder** Folder where the clip embeddings will be saved, as well as metadata (*required*)
* **input_format** files or webdataset (default *files*)
* **cache_path** cache path for webdataset (default *None*)
* **batch_size** Number of items to do the inference on at once (default *256*)
* **num_prepro_workers** Number of processes to do the preprocessing (default *8*)
* **enable_text** Enable text processing (default *True*)
* **enable_image** Enable image processing (default *True*)
* **enable_metadata** Enable metadata processing (default *False*)
* **write_batch_size** Write batch size (default `10**6`)
* **subset_size** Only process a subset of this size (default *None*)
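
The same options can be passed from Python. A minimal sketch, using only the `clip_inference` import shown in cli.py below and the parameters documented above:

```python
from clip_retrieval.clip_inference import clip_inference

# Embed a folder of images (and same-named .txt captions if present).
clip_inference(
    input_dataset="image_folder",
    output_folder="embeddings_folder",
    input_format="files",
    batch_size=256,
    enable_text=True,
    enable_image=True,
)

# For tar shards, pass a brace pattern and switch the input format:
# clip_inference(input_dataset="{000..150}.tar", output_folder="embeddings_folder",
#                input_format="webdataset")
```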

## Clip index

Clip index takes as input the output of clip inference and makes an index out of it using [autofaiss](https://github.com/criteo/autofaiss)

`clip-retrieval index --input_folder embeddings_folder --output_folder index_folder`

The output is a folder containing (see the memory-mapping example after this list):
* image.index containing a brute force faiss index for images
* text.index containing a brute force faiss index for texts
* metadata.arrow containing the metadata in a format that is easy to memory map
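
For instance, the metadata can be memory mapped with pyarrow. A minimal sketch, assuming metadata.arrow is a random-access Arrow IPC file:

```python
import pyarrow as pa

# Memory-map the metadata so it does not have to fit in RAM.
with pa.memory_map("index_folder/metadata.arrow") as source:
    table = pa.ipc.open_file(source).read_all()
print(table.num_rows, table.column_names)
```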

## Clip filter

@@ -48,7 +76,7 @@ Using the `--num_results` or `--threshold` may be helpful to refine the filter
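
As a sketch, the filter can also be invoked from Python via the `clip_filter` import shown in cli.py below. Only `num_results` and `threshold` appear in this excerpt; the `query`, `output_folder` and `indice_folder` parameter names are hypothetical:

```python
from clip_retrieval.clip_filter import clip_filter

# Hypothetical parameter names: copy the 100 best matches for "cat"
# from the indexed dataset into the cat/ folder.
clip_filter(query="cat", output_folder="cat/", indice_folder="indice_folder", num_results=100)
```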

## Clip back

Then run (output_folder is the output of clip batch)
Then run (output_folder is the output of clip index)
```bash
echo '{"example_index": "output_folder"}' > indices_paths.json
clip-retrieval back --port 1234 --indices-paths indices_paths.json
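```

Once the service is running, a knn query can be sent over HTTP. A minimal sketch using `requests`; the JSON fields mirror what `KnnService.post` reads in the clip_back.py diff below, while the `/knn-service` route name is an assumption not shown in this excerpt:

```python
import requests

# Ask example_index for the 4 images nearest to a text query.
payload = {
    "text": "cat",
    "modality": "image",  # search the image index
    "num_images": 4,
    "indice_name": "example_index",
}
results = requests.post("http://localhost:1234/knn-service", json=payload).json()
for r in results:
    print(r.get("text"), r.get("url"))
```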
@@ -93,3 +121,12 @@ source .env/bin/activate
pip install -U pip
pip install -e .
```

to run tests:
```
pip install -r requirements-test.txt
```
then
```
python -m pytest -v tests -s
```
7 changes: 4 additions & 3 deletions clip_retrieval/cli.py
@@ -1,16 +1,17 @@
from clip_retrieval.clip_back import clip_back
from clip_retrieval.clip_batch import clip_batch
from clip_retrieval.clip_inference import clip_inference
from clip_retrieval.clip_filter import clip_filter
from clip_retrieval.clip_index import clip_index
import fire
import logging


def main():
    """Main entry point"""
    fire.Fire(
        {
            "back": clip_back,
            "batch": clip_batch,
            "inference": clip_inference,
            "index": clip_index,
            "filter": clip_filter
        }
    )
85 changes: 56 additions & 29 deletions clip_retrieval/clip_back.py
@@ -5,51 +5,62 @@
import faiss
import torch
import json
import sys
from PIL import Image
from io import BytesIO
from PIL import Image
import base64
import os
import fire
from pathlib import Path
import pandas as pd


class Health(Resource):
    def get(self):
        return "ok"

class IndicesList(Resource):
    def get(self, **kwargs):
        return list(kwargs['indices'].keys())
    def __init__(self, **kwargs):
        super().__init__()
        self.indices = kwargs['indices']


    def get(self):
        return list(self.indices.keys())

class KnnService(Resource):
    def post(self, **kwargs):
        indices_loaded = kwargs['indices_loaded']
        device = kwargs['device']
        model = kwargs['model']
        preprocess = kwargs['preprocess']
    def __init__(self, **kwargs):
        super().__init__()
        self.indices_loaded = kwargs['indices_loaded']
        self.device = kwargs['device']
        self.model = kwargs['model']
        self.preprocess = kwargs['preprocess']


    def post(self):
        json_data = request.get_json(force=True)
        text_input = json_data["text"] if "text" in json_data else None
        image_input = json_data["image"] if "image" in json_data else None
        modality = json_data["modality"]
        num_images = json_data["num_images"]
        indice_name = json_data["indice_name"]
        image_index = indices_loaded[indice_name]["image_index"]
        text_index = indices_loaded[indice_name]["text_index"]
        image_list = indices_loaded[indice_name]["image_list"]
        description_list = indices_loaded[indice_name]["description_list"]
        image_index = self.indices_loaded[indice_name]["image_index"]
        text_index = self.indices_loaded[indice_name]["text_index"]
        image_list = self.indices_loaded[indice_name]["image_list"]
        description_list = self.indices_loaded[indice_name]["description_list"]
        url_list = self.indices_loaded[indice_name]["url_list"]

        if text_input is not None:
            text = clip.tokenize([text_input]).to(device)
            text_features = model.encode_text(text)
            text = clip.tokenize([text_input]).to(self.device)
            text_features = self.model.encode_text(text)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            query = text_features.cpu().detach().numpy().astype("float32")
        if image_input is not None:
            binary_data = base64.b64decode(image_input)
            img_data = BytesIO(binary_data)
            img = Image.open(img_data)
            prepro = preprocess(img).unsqueeze(0)
            image_features = model.encode_image(prepro)
            prepro = self.preprocess(img).unsqueeze(0)
            image_features = self.model.encode_image(prepro)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            query = image_features.cpu().detach().numpy().astype("float32")

@@ -58,13 +69,21 @@ def post(self, **kwargs):
        D, I = index.search(query, num_images)
        results = []
        for d, i in zip(D[0], I[0]):
            path = image_list[i]
            description = description_list[i]
            img = Image.open(path)
            buffered = BytesIO()
            img.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            results.append({"image": img_str, "text": description})
            output = {}
            if image_list is not None:
                path = image_list[i]
                if os.path.exists(path):
                    img = Image.open(path)
                    buffered = BytesIO()
                    img.save(buffered, format="JPEG")
                    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
                    output["image"] = img_str
            if description_list is not None:
                description = description_list[i]
                output["text"] = description
            if url_list is not None:
                output["url"] = url_list[i]
            results.append(output)
        return results


@@ -77,24 +96,32 @@ def clip_back(indices_paths="indices_paths.json", port=1234):
    indices_loaded = {}

    for name, indice_folder in indices.items():
        image_present = os.path.exists(indice_folder+"/image_list")
        text_present = os.path.exists(indice_folder+"/description_list")
        image_present = os.path.exists(indice_folder+"/image.index")
        text_present = os.path.exists(indice_folder+"/text.index")
        data_dir = Path(indice_folder+"/metadata")
        metadata_df = pd.concat(
            pd.read_parquet(parquet_file)
            for parquet_file in data_dir.glob('*.parquet')
        )

        url_list = None
        if "url" in metadata_df:
            url_list = metadata_df["url"].tolist()
        if image_present:
            with open(indice_folder+"/image_list") as f:
                image_list = [x for x in f.read().split("\n") if x !=""]
            image_list = metadata_df["image_path"].tolist()
            image_index = faiss.read_index(indice_folder+"/image.index")
        else:
            image_list = None
            image_index = None
        if text_present:
            with open(indice_folder+"/description_list") as f:
                description_list = [x for x in f.read().split("\n") if x !=""]
            description_list = metadata_df["caption"].tolist()
            text_index = faiss.read_index(indice_folder+"/text.index")
        else:
            description_list = None
            text_index = None
        indices_loaded[name]={
            'image_list': image_list,
            'url_list': url_list,
            'description_list': description_list,
            'image_index': image_index,
            'text_index': text_index
