From 80259c9ee29ecb1fe92482f4552076c617c0fd6e Mon Sep 17 00:00:00 2001 From: Nicholas Broad Date: Tue, 17 Sep 2024 08:28:31 -0700 Subject: [PATCH] feat(python): add cls and mean pooling (#402) --- .../text_embeddings_server/models/default_model.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py index 17ad4589..3c41b6c3 100644 --- a/backends/python/server/text_embeddings_server/models/default_model.py +++ b/backends/python/server/text_embeddings_server/models/default_model.py @@ -43,7 +43,14 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]: kwargs["position_ids"] = batch.position_ids output = self.model(**kwargs) - embedding = output[0][:, 0] + + if self.pooling_mode == "cls": + embedding = output[0][:, 0] + elif self.pooling_mode == "mean": + embedding = output[0].mean(dim=1) + else: + raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend") + cpu_results = embedding.view(-1).tolist() return [