Skip to content

Commit

Permalink
Merge pull request #82 from bjwswang/main
Browse files Browse the repository at this point in the history
fix: optimize chromadb configurations
  • Loading branch information
bjwswang authored Sep 11, 2023
2 parents 6419aab + 96b9dbd commit 1fecdf6
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 52 deletions.
2 changes: 1 addition & 1 deletion examples/embedding/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func main() {
panic(fmt.Errorf("error create embedder: %s", err.Error()))
}
// init vector store
chroma, err := chromadb.New(chromadb.WithScheme("http"), chromadb.WithHost("localhost:8000"), chromadb.WithEmbedder(embedder))
chroma, err := chromadb.New(chromadb.WithURL("http://localhost:8000"), chromadb.WithEmbedder(embedder))
if err != nil {
panic(fmt.Errorf("error create chroma db: %s", err.Error()))
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/amikos-tech/chroma-go v0.0.0-20230901221218-d0087270239e
github.com/go-logr/logr v1.2.0
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.18.1
github.com/r3labs/sse/v2 v2.10.0
Expand Down Expand Up @@ -48,7 +49,6 @@ require (
github.com/google/gnostic v0.5.7-v3refs // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/gofuzz v1.1.0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/imdario/mergo v0.3.12 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
Expand Down
18 changes: 8 additions & 10 deletions pkg/vectorstores/chromadb/chroma.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package chromadb

import (
"context"
"errors"
"fmt"

chroma "github.com/amikos-tech/chroma-go"
chromaopenapi "github.com/amikos-tech/chroma-go/swagger"
Expand All @@ -41,8 +41,7 @@ type Store struct {
embedder wrappedEmbeddingFunction
client *chroma.Client

scheme string
host string
url string

// optional
nameSpaceKey string
Expand All @@ -59,16 +58,16 @@ type Store struct {
var _ vectorstores.VectorStore = Store{}

// New creates a new Store with options for chromadb.
func New(opts ...Option) (vectorstores.VectorStore, error) {
func New(opts ...Option) (Store, error) {
s, err := applyClientOptions(opts...)
if err != nil {
return nil, err
return Store{}, err
}

configuration := chromaopenapi.NewConfiguration()
configuration.Servers = chromaopenapi.ServerConfigurations{
{
URL: fmt.Sprintf("%s://%s", s.scheme, s.host),
URL: s.url,
Description: "Chromadb server url for this store",
},
}
Expand All @@ -77,7 +76,7 @@ func New(opts ...Option) (vectorstores.VectorStore, error) {
}

if _, err = s.client.Heartbeat(); err != nil {
return nil, err
return Store{}, err
}

return s, nil
Expand All @@ -89,9 +88,8 @@ func (s Store) AddDocuments(ctx context.Context, docs []schema.Document, options

texts := make([]string, 0, len(docs))
ids := make([]string, len(docs))
for idx, doc := range docs {
for _, doc := range docs {
texts = append(texts, doc.PageContent)
ids[idx] = fmt.Sprintf("%d", idx)
}

collection, err := s.client.CreateCollection(s.nameSpace, map[string]interface{}{}, true, s.embedder, s.distanceFunc)
Expand Down Expand Up @@ -184,7 +182,7 @@ func (s Store) getScoreThreshold(opts vectorstores.Options) (float32, error) {
return f32, nil
}

// FIXME: optimize filter
// FIXME: optimize filter.
func (s Store) getFilters(opts vectorstores.Options) map[string]any {
filters, ok := opts.Filters.(map[string]any)
if !ok {
Expand Down
36 changes: 12 additions & 24 deletions pkg/vectorstores/chromadb/chroma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package chromadb

import (
Expand All @@ -25,46 +26,35 @@ import (
chroma "github.com/amikos-tech/chroma-go"
"github.com/google/uuid"
"github.com/stretchr/testify/require"

openaiEmbeddings "github.com/tmc/langchaingo/embeddings/openai"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/vectorstores"
)

func getValues(t *testing.T) (string, string) {
func getValues(t *testing.T) string {
t.Helper()

scheme := os.Getenv("CHROMA_SCHEME")
if scheme == "" {
t.Skip("Must set CHROMA_SCHEME to run test")
}

host := os.Getenv("CHROMA_HOST")
if host == "" {
t.Skip("Must set CHROMA_HOST to run test")
url := os.Getenv("CHROMA_URL")
if url == "" {
t.Skip("Must set CHROMA_URL to run test")
}

if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
t.Skip("OPENAI_API_KEY not set")
}

if zhipuaiKey := os.Getenv("ZHIPUAI_API_KEY"); zhipuaiKey == "" {
t.Skip("ZHIPUAI_API_KEY not set")
}

return scheme, host
return url
}

func TestChromaStoreRest(t *testing.T) {
t.Parallel()

scheme, host := getValues(t)
url := getValues(t)
e, err := openaiEmbeddings.NewOpenAI()
require.NoError(t, err)

store, err := New(
WithScheme(scheme),
WithHost(host),
WithURL(url),
WithEmbedder(e),
WithNameSpace(uuid.New().String()),
WithDistanceFunc(chroma.COSINE),
Expand Down Expand Up @@ -92,13 +82,12 @@ func TestChromaStoreRest(t *testing.T) {
func TestChromaStoreRestWithScoreThreshold(t *testing.T) {
t.Parallel()

scheme, host := getValues(t)
url := getValues(t)
e, err := openaiEmbeddings.NewOpenAI()
require.NoError(t, err)

store, err := New(
WithScheme(scheme),
WithHost(host),
WithURL(url),
WithEmbedder(e),
WithNameSpace(uuid.New().String()),
WithDistanceFunc(chroma.COSINE),
Expand Down Expand Up @@ -137,13 +126,12 @@ func TestChromaStoreRestWithScoreThreshold(t *testing.T) {
func TestSimilaritySearchWithInvalidScoreThreshold(t *testing.T) {
t.Parallel()

scheme, host := getValues(t)
url := getValues(t)
e, err := openaiEmbeddings.NewOpenAI()
require.NoError(t, err)

store, err := New(
WithScheme(scheme),
WithHost(host),
WithURL(url),
WithEmbedder(e),
WithNameSpace(uuid.New().String()),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package chromadb

import (
Expand All @@ -24,7 +25,7 @@ import (

var _ chroma.EmbeddingFunction = wrappedEmbeddingFunction{}

// wrappedEmbeddingFunction is a wrapper around an embeddings.Embedder to convert langchain embedder to chroma embeddingFunction
// wrappedEmbeddingFunction is a wrapper around an embeddings.
type wrappedEmbeddingFunction struct {
embeddings.Embedder
}
Expand Down
20 changes: 5 additions & 15 deletions pkg/vectorstores/chromadb/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package chromadb

import (
Expand Down Expand Up @@ -60,16 +61,9 @@ func WithNameSpaceKey(nameSpaceKey string) Option {
}

// WithScheme is an option for setting the scheme of the chromadb server.Must be set.
func WithScheme(scheme string) Option {
func WithURL(url string) Option {
return func(p *Store) {
p.scheme = scheme
}
}

// WithHost is an option for setting the host of the chromadb server.Must be set.
func WithHost(host string) Option {
return func(p *Store) {
p.host = host
p.url = url
}
}

Expand Down Expand Up @@ -106,12 +100,8 @@ func applyClientOptions(opts ...Option) (Store, error) {
opt(o)
}

if o.scheme == "" {
return Store{}, fmt.Errorf("%w: missing scheme", ErrInvalidOptions)
}

if o.host == "" {
return Store{}, fmt.Errorf("%w: missing host", ErrInvalidOptions)
if o.url == "" {
return Store{}, fmt.Errorf("%w: missing url", ErrInvalidOptions)
}

if o.embedder.Embedder == nil {
Expand Down

0 comments on commit 1fecdf6

Please sign in to comment.