From 358fafeb40633949ab67172cc8993b15c0462d1b Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Thu, 7 Dec 2023 15:42:46 -0800 Subject: [PATCH] pgvector --- .../vectorstores/_pgvector_data_models.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/_pgvector_data_models.py b/libs/langchain/langchain/vectorstores/_pgvector_data_models.py index 1a4b60776537b..062d64896b65d 100644 --- a/libs/langchain/langchain/vectorstores/_pgvector_data_models.py +++ b/libs/langchain/langchain/vectorstores/_pgvector_data_models.py @@ -1,12 +1,24 @@ -from typing import Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple import sqlalchemy -from pgvector.sqlalchemy import Vector from sqlalchemy.dialects.postgresql import JSON, UUID from sqlalchemy.orm import Session, relationship from langchain.vectorstores.pgvector import BaseModel +if TYPE_CHECKING: + from pgvector.sqlalchemy import Vector + + +def _import_vector() -> None: + try: + from pgvector.sqlalchemy import Vector + except ImportError: + raise ImportError( + "The `pgvector` library is required to use the PGVectorStore." + ) + return Vector + class CollectionStore(BaseModel): """Collection store.""" @@ -63,7 +75,7 @@ class EmbeddingStore(BaseModel): ) collection = relationship(CollectionStore, back_populates="embeddings") - embedding: Vector = sqlalchemy.Column(Vector(None)) + embedding: Vector = sqlalchemy.Column(_import_vector()(None)) document = sqlalchemy.Column(sqlalchemy.String, nullable=True) cmetadata = sqlalchemy.Column(JSON, nullable=True)