From 96b9dbdaf1fab10911f16ec370ad8a10e9f47596 Mon Sep 17 00:00:00 2001 From: bjwswang Date: Mon, 11 Sep 2023 11:11:42 +0800 Subject: [PATCH] fix: optimize chromadb configurations Signed-off-by: bjwswang --- examples/embedding/main.go | 2 +- go.mod | 2 +- pkg/vectorstores/chromadb/chroma.go | 18 +++++----- pkg/vectorstores/chromadb/chroma_test.go | 36 +++++++------------ .../chromadb/{embedder.go => embedingfunc.go} | 3 +- pkg/vectorstores/chromadb/options.go | 20 +++-------- 6 files changed, 29 insertions(+), 52 deletions(-) rename pkg/vectorstores/chromadb/{embedder.go => embedingfunc.go} (91%) diff --git a/examples/embedding/main.go b/examples/embedding/main.go index 7c8caba98..7dfd1683d 100644 --- a/examples/embedding/main.go +++ b/examples/embedding/main.go @@ -41,7 +41,7 @@ func main() { panic(fmt.Errorf("error create embedder: %s", err.Error())) } // init vector store - chroma, err := chromadb.New(chromadb.WithScheme("http"), chromadb.WithHost("localhost:8000"), chromadb.WithEmbedder(embedder)) + chroma, err := chromadb.New(chromadb.WithURL("http://localhost:8000"), chromadb.WithEmbedder(embedder)) if err != nil { panic(fmt.Errorf("error create chroma db: %s", err.Error())) } diff --git a/go.mod b/go.mod index 98c4abc5a..b0cac44d1 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/amikos-tech/chroma-go v0.0.0-20230901221218-d0087270239e github.com/go-logr/logr v1.2.0 github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/google/uuid v1.3.0 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.18.1 github.com/r3labs/sse/v2 v2.10.0 @@ -48,7 +49,6 @@ require ( github.com/google/gnostic v0.5.7-v3refs // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/gofuzz v1.1.0 // indirect - github.com/google/uuid v1.3.0 // indirect github.com/imdario/mergo v0.3.12 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/josharian/intern v1.0.0 // indirect diff --git a/pkg/vectorstores/chromadb/chroma.go b/pkg/vectorstores/chromadb/chroma.go index 5003d9007..b6270ed10 100644 --- a/pkg/vectorstores/chromadb/chroma.go +++ b/pkg/vectorstores/chromadb/chroma.go @@ -13,12 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package chromadb import ( "context" "errors" - "fmt" chroma "github.com/amikos-tech/chroma-go" chromaopenapi "github.com/amikos-tech/chroma-go/swagger" @@ -41,8 +41,7 @@ type Store struct { embedder wrappedEmbeddingFunction client *chroma.Client - scheme string - host string + url string // optional nameSpaceKey string @@ -59,16 +58,16 @@ type Store struct { var _ vectorstores.VectorStore = Store{} // New creates a new Store with options for chromadb. -func New(opts ...Option) (vectorstores.VectorStore, error) { +func New(opts ...Option) (Store, error) { s, err := applyClientOptions(opts...) if err != nil { - return nil, err + return Store{}, err } configuration := chromaopenapi.NewConfiguration() configuration.Servers = chromaopenapi.ServerConfigurations{ { - URL: fmt.Sprintf("%s://%s", s.scheme, s.host), + URL: s.url, Description: "Chromadb server url for this store", }, } @@ -77,7 +76,7 @@ func New(opts ...Option) (vectorstores.VectorStore, error) { } if _, err = s.client.Heartbeat(); err != nil { - return nil, err + return Store{}, err } return s, nil @@ -89,9 +88,8 @@ func (s Store) AddDocuments(ctx context.Context, docs []schema.Document, options texts := make([]string, 0, len(docs)) ids := make([]string, len(docs)) - for idx, doc := range docs { + for _, doc := range docs { texts = append(texts, doc.PageContent) - ids[idx] = fmt.Sprintf("%d", idx) } collection, err := s.client.CreateCollection(s.nameSpace, map[string]interface{}{}, true, s.embedder, s.distanceFunc) @@ -184,7 +182,7 @@ func (s Store) getScoreThreshold(opts vectorstores.Options) (float32, error) { return f32, nil } -// FIXME: optimize filter +// FIXME: optimize filter. func (s Store) getFilters(opts vectorstores.Options) map[string]any { filters, ok := opts.Filters.(map[string]any) if !ok { diff --git a/pkg/vectorstores/chromadb/chroma_test.go b/pkg/vectorstores/chromadb/chroma_test.go index 708f160c6..dfdd759fc 100644 --- a/pkg/vectorstores/chromadb/chroma_test.go +++ b/pkg/vectorstores/chromadb/chroma_test.go @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package chromadb import ( @@ -25,46 +26,35 @@ import ( chroma "github.com/amikos-tech/chroma-go" "github.com/google/uuid" "github.com/stretchr/testify/require" - openaiEmbeddings "github.com/tmc/langchaingo/embeddings/openai" "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/vectorstores" ) -func getValues(t *testing.T) (string, string) { +func getValues(t *testing.T) string { t.Helper() - scheme := os.Getenv("CHROMA_SCHEME") - if scheme == "" { - t.Skip("Must set CHROMA_SCHEME to run test") - } - - host := os.Getenv("CHROMA_HOST") - if host == "" { - t.Skip("Must set CHROMA_HOST to run test") + url := os.Getenv("CHROMA_URL") + if url == "" { + t.Skip("Must set CHROMA_URL to run test") } if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" { t.Skip("OPENAI_API_KEY not set") } - if zhipuaiKey := os.Getenv("ZHIPUAI_API_KEY"); zhipuaiKey == "" { - t.Skip("ZHIPUAI_API_KEY not set") - } - - return scheme, host + return url } func TestChromaStoreRest(t *testing.T) { t.Parallel() - scheme, host := getValues(t) + url := getValues(t) e, err := openaiEmbeddings.NewOpenAI() require.NoError(t, err) store, err := New( - WithScheme(scheme), - WithHost(host), + WithURL(url), WithEmbedder(e), WithNameSpace(uuid.New().String()), WithDistanceFunc(chroma.COSINE), @@ -92,13 +82,12 @@ func TestChromaStoreRest(t *testing.T) { func TestChromaStoreRestWithScoreThreshold(t *testing.T) { t.Parallel() - scheme, host := getValues(t) + url := getValues(t) e, err := openaiEmbeddings.NewOpenAI() require.NoError(t, err) store, err := New( - WithScheme(scheme), - WithHost(host), + WithURL(url), WithEmbedder(e), WithNameSpace(uuid.New().String()), WithDistanceFunc(chroma.COSINE), @@ -137,13 +126,12 @@ func TestChromaStoreRestWithScoreThreshold(t *testing.T) { func TestSimilaritySearchWithInvalidScoreThreshold(t *testing.T) { t.Parallel() - scheme, host := getValues(t) + url := getValues(t) e, err := openaiEmbeddings.NewOpenAI() require.NoError(t, err) store, err := New( - WithScheme(scheme), - WithHost(host), + WithURL(url), WithEmbedder(e), WithNameSpace(uuid.New().String()), ) diff --git a/pkg/vectorstores/chromadb/embedder.go b/pkg/vectorstores/chromadb/embedingfunc.go similarity index 91% rename from pkg/vectorstores/chromadb/embedder.go rename to pkg/vectorstores/chromadb/embedingfunc.go index 7fa66d906..b9b70adb4 100644 --- a/pkg/vectorstores/chromadb/embedder.go +++ b/pkg/vectorstores/chromadb/embedingfunc.go @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package chromadb import ( @@ -24,7 +25,7 @@ import ( var _ chroma.EmbeddingFunction = wrappedEmbeddingFunction{} -// wrappedEmbeddingFunction is a wrapper around an embeddings.Embedder to convert langchain embedder to chroma embeddingFunction +// wrappedEmbeddingFunction is a wrapper around an embeddings. type wrappedEmbeddingFunction struct { embeddings.Embedder } diff --git a/pkg/vectorstores/chromadb/options.go b/pkg/vectorstores/chromadb/options.go index e4d61e2b3..34c6ed130 100644 --- a/pkg/vectorstores/chromadb/options.go +++ b/pkg/vectorstores/chromadb/options.go @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package chromadb import ( @@ -60,16 +61,9 @@ func WithNameSpaceKey(nameSpaceKey string) Option { } // WithScheme is an option for setting the scheme of the chromadb server.Must be set. -func WithScheme(scheme string) Option { +func WithURL(url string) Option { return func(p *Store) { - p.scheme = scheme - } -} - -// WithHost is an option for setting the host of the chromadb server.Must be set. -func WithHost(host string) Option { - return func(p *Store) { - p.host = host + p.url = url } } @@ -106,12 +100,8 @@ func applyClientOptions(opts ...Option) (Store, error) { opt(o) } - if o.scheme == "" { - return Store{}, fmt.Errorf("%w: missing scheme", ErrInvalidOptions) - } - - if o.host == "" { - return Store{}, fmt.Errorf("%w: missing host", ErrInvalidOptions) + if o.url == "" { + return Store{}, fmt.Errorf("%w: missing url", ErrInvalidOptions) } if o.embedder.Embedder == nil {