Skip to content
This repository has been archived by the owner on Oct 30, 2024. It is now read-only.

Commit

Permalink
change: merge archive logic into getClient
Browse files: browse the repository at this point in the history
  • Loading branch information
iwilltry42 committed Jul 1, 2024
1 parent c066081 commit 6598b78
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 52 deletions.
30 changes: 19 additions & 11 deletions pkg/cmd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ import (
)

type Client struct {
Server string `usage:"URL of the Knowledge API Server" default:"" env:"KNOW_SERVER_URL"`
Server string `usage:"URL of the Knowledge API Server" default:"" env:"KNOW_SERVER_URL"`
datastoreArchive string
config.OpenAIConfig
config.DatabaseConfig
config.VectorDBConfig
Expand All @@ -24,30 +25,33 @@ type ClientFlowsConfig struct {
Flow string `usage:"Flow name" env:"KNOW_FLOW"`
}

func (s *Client) getClientFromArchive(archive string) (client.Client, error) {
func (s *Client) loadArchive() error {
if s.datastoreArchive == "" {
return nil
}
// unpack to tempdir
tmpDir, err := os.MkdirTemp(os.TempDir(), "knowledge-retrieve-*")
if err != nil {
return nil, err
return err
}
defer os.RemoveAll(tmpDir)

r, err := zip.OpenReader(archive)
r, err := zip.OpenReader(s.datastoreArchive)
if err != nil {
return nil, err
return err
}
defer r.Close()

if len(r.File) != 2 {
return nil, fmt.Errorf("knowledge archive must contain exactly two files, found %d", len(r.File))
return fmt.Errorf("knowledge archive must contain exactly two files, found %d", len(r.File))
}

dbFile := ""
vectorStoreFile := ""
for _, f := range r.File {
rc, err := f.Open()
if err != nil {
return nil, err
return err
}
defer rc.Close()

Expand All @@ -58,12 +62,12 @@ func (s *Client) getClientFromArchive(archive string) (client.Client, error) {

f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
if err != nil {
return nil, err
return err
}
defer f.Close()

if _, err := io.Copy(f, rc); err != nil {
return nil, err
return err
}
_ = f.Close()
_ = rc.Close()
Expand All @@ -77,17 +81,21 @@ func (s *Client) getClientFromArchive(archive string) (client.Client, error) {
}

if dbFile == "" || vectorStoreFile == "" {
return nil, fmt.Errorf("knowledge archive must contain exactly one .db and one .gob file")
return fmt.Errorf("knowledge archive must contain exactly one .db and one .gob file")
}

s.DSN = types.ArchivePrefix + dbFile
s.VectorDBPath = types.ArchivePrefix + vectorStoreFile

return s.getClient()
return nil
}

func (s *Client) getClient() (client.Client, error) {

if err := s.loadArchive(); err != nil {
return nil, err
}

if s.Server == "" || s.Server == "standalone" {
ds, err := datastore.NewDatastore(s.DSN, s.AutoMigrate == "true", s.VectorDBConfig.VectorDBPath, s.OpenAIConfig)
if err != nil {
Expand Down
17 changes: 3 additions & 14 deletions pkg/cmd/get_dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package cmd
import (
"encoding/json"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/client"

"github.com/spf13/cobra"
)

Expand All @@ -21,18 +19,9 @@ func (s *ClientGetDataset) Customize(cmd *cobra.Command) {
}

func (s *ClientGetDataset) Run(cmd *cobra.Command, args []string) error {
var err error
var c client.Client
if s.Archive != "" {
c, err = s.getClientFromArchive(s.Archive)
if err != nil {
return err
}
} else {
c, err = s.getClient()
if err != nil {
return err
}
c, err := s.getClient()
if err != nil {
return err
}

datasetID := args[0]
Expand Down
18 changes: 4 additions & 14 deletions pkg/cmd/list_datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package cmd
import (
"encoding/json"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/client"

"github.com/spf13/cobra"
)

Expand All @@ -20,18 +18,10 @@ func (s *ClientListDatasets) Customize(cmd *cobra.Command) {
}

func (s *ClientListDatasets) Run(cmd *cobra.Command, args []string) error {
var err error
var c client.Client
if s.Archive != "" {
c, err = s.getClientFromArchive(s.Archive)
if err != nil {
return err
}
} else {
c, err = s.getClient()
if err != nil {
return err
}

c, err := s.getClient()
if err != nil {
return err
}

ds, err := c.ListDatasets(cmd.Context())
Expand Down
16 changes: 3 additions & 13 deletions pkg/cmd/retrieve.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/client"
"github.com/gptscript-ai/knowledge/pkg/datastore"
flowconfig "github.com/gptscript-ai/knowledge/pkg/flows/config"
vserr "github.com/gptscript-ai/knowledge/pkg/vectorstore/errors"
Expand Down Expand Up @@ -32,18 +31,9 @@ func (s *ClientRetrieve) Customize(cmd *cobra.Command) {

func (s *ClientRetrieve) Run(cmd *cobra.Command, args []string) error {

var err error
var c client.Client
if s.Archive != "" {
c, err = s.getClientFromArchive(s.Archive)
if err != nil {
return err
}
} else {
c, err = s.getClient()
if err != nil {
return err
}
c, err := s.getClient()
if err != nil {
return err
}

datasetID := s.Dataset
Expand Down

0 comments on commit 6598b78

Please sign in to comment.