Skip to content

Commit

Permalink
feat: allow "negative" text queries
Browse files Browse the repository at this point in the history
  • Loading branch information
ramayer committed Sep 28, 2021
1 parent 6e5d9e6 commit f1c1aef
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
8 changes: 5 additions & 3 deletions rclip/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,12 @@ def ensure_index(self, directory: str):

self._db.commit()

def search(self, query: str, directory: str, top_k: int = 10) -> List[SearchResult]:
def search(self, query: str, directory: str, top_k: int = 10,
positive_phrases: List[str] = [], negative_phrases: List[str] = []) -> List[SearchResult]:
filepaths, features = self._get_features(directory)

sorted_similarities = self._model.compute_similarities_to_text(features, query)
positive_phrases = [query] + positive_phrases
sorted_similarities = self._model.compute_similarities_to_phrase(features, positive_phrases, negative_phrases)

filtered_similarities = filter(
lambda similarity: not self._exclude_dir_regex.match(filepaths[similarity[1]]),
Expand Down Expand Up @@ -159,7 +161,7 @@ def main():
if not args.skip_index:
rclip.ensure_index(current_directory)

result = rclip.search(args.query, current_directory, args.top)
result = rclip.search(args.query, current_directory, args.top, args.plus, args.minus)
if args.filepath_only:
for r in result:
print(r.filepath)
Expand Down
15 changes: 15 additions & 0 deletions rclip/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import clip
import clip.model
import functools
import numpy as np
from PIL import Image
import torch
Expand Down Expand Up @@ -44,3 +45,17 @@ def compute_similarities_to_text(self, item_features: np.ndarray, text: str) ->
sorted_similarities = sorted(zip(similarities, range(item_features.shape[0])), key=lambda x: x[0], reverse=True)

return sorted_similarities

def compute_similarities_to_phrase(self, item_features: np.ndarray,
plus: List[str], minus: List[str]) -> List[Tuple[float, int]]:

plus_features: List[np.ndarray] = [self.compute_text_features(phrase) for phrase in plus]
minus_features: List[np.ndarray] = [-self.compute_text_features(phrase) for phrase in minus]
all_features: List[np.ndarray] = plus_features + minus_features

text_features = functools.reduce(lambda x, y: x+y, all_features)

similarities = (text_features @ item_features.T).squeeze(0).tolist()
sorted_similarities = sorted(zip(similarities, range(item_features.shape[0])), key=lambda x: x[0], reverse=True)

return sorted_similarities
2 changes: 2 additions & 0 deletions rclip/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def top_arg_type(arg: str) -> int:
def init_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument('query')
parser.add_argument('--minus', action='append', help='phrases to subtract from the score')
parser.add_argument('--plus', action='append', help='phrases to add to the score')
parser.add_argument('--top', '-t', type=top_arg_type, default=10, help='number of top results to display')
parser.add_argument('--filepath-only', '-f', action='store_true', default=False, help='outputs only filepaths')
parser.add_argument(
Expand Down

0 comments on commit f1c1aef

Please sign in to comment.