Skip to content
This repository has been archived by the owner on Oct 30, 2024. It is now read-only.

Commit

Permalink
add: --recursive and --concurrency flags for ingestion
Browse files Browse the repository at this point in the history
  • Loading branch information
iwilltry42 committed May 3, 2024
1 parent 6264e81 commit 4db4074
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 37 deletions.
3 changes: 2 additions & 1 deletion pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
// IngestPathsOpts holds the options passed to IngestPaths.
type IngestPathsOpts struct {
// IgnoreExtensions lists file extensions to ignore; the ingest command
// fills it by splitting a comma-separated string.
IgnoreExtensions []string
// Concurrency is the requested number of concurrent ingestions (per the
// ingest command's flag description); ingestPaths substitutes 10 for any
// value below 1.
Concurrency int
// Recursive controls directory handling: when false, subdirectories of a
// given directory are skipped; when true, they are walked as well.
Recursive bool
}

type Client interface {
Expand All @@ -19,7 +20,7 @@ type Client interface {
GetDataset(ctx context.Context, datasetID string) (*index.Dataset, error)
ListDatasets(ctx context.Context) ([]types.Dataset, error)
Ingest(ctx context.Context, datasetID string, data []byte, opts datastore.IngestOpts) ([]string, error)
IngestPaths(ctx context.Context, datasetID string, opts *IngestPathsOpts, paths ...string) error
IngestPaths(ctx context.Context, datasetID string, opts *IngestPathsOpts, paths ...string) (int, error) // returns number of files ingested
DeleteDocuments(ctx context.Context, datasetID string, documentIDs ...string) error
Retrieve(ctx context.Context, datasetID string, query string) ([]vectorstore.Document, error)
}
79 changes: 48 additions & 31 deletions pkg/client/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import (
"path/filepath"
)

func ingestPaths(ctx context.Context, opts *IngestPathsOpts, ingestionFunc func(path string) error, paths ...string) error {
func ingestPaths(ctx context.Context, opts *IngestPathsOpts, ingestionFunc func(path string) error, paths ...string) (int, error) {

ingestedFilesCount := 0

if opts.Concurrency < 1 {
opts.Concurrency = 10
}
Expand All @@ -19,43 +22,57 @@ func ingestPaths(ctx context.Context, opts *IngestPathsOpts, ingestionFunc func(

for _, p := range paths {
path := p
g.Go(func() error {
// Wait for a free slot, or exit if the context is done
if err := sem.Acquire(ctx, 1); err != nil {
return err
}
defer sem.Release(1)

fileInfo, err := os.Stat(path)
if err != nil {
return fmt.Errorf("failed to get file info for %s: %w", path, err)
}
fileInfo, err := os.Stat(path)
if err != nil {
return ingestedFilesCount, fmt.Errorf("failed to get file info for %s: %w", path, err)
}

if fileInfo.IsDir() {
// Process each file in the directory but skip directories (i.e. don't recurse)
err := filepath.WalkDir(path, func(path string, d os.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
return ingestionFunc(path)
})
if fileInfo.IsDir() {
// Process directory
err = filepath.WalkDir(path, func(subPath string, d os.DirEntry, err error) error {
if err != nil {
return err
}
} else {
// Process single file
err := ingestionFunc(path)
if err != nil {
return err
if d.IsDir() {
if subPath == path {
return nil // Always process the top-level directory
}
if !opts.Recursive {
return filepath.SkipDir // Skip subdirectories if not recursive
}
return nil
}

sp := subPath
g.Go(func() error {
if err := sem.Acquire(ctx, 1); err != nil {
return err
}
defer sem.Release(1)

ingestedFilesCount++
return ingestionFunc(sp)
})
return nil
})
if err != nil {
return ingestedFilesCount, err
}
return nil
})
} else {
// Process a file directly
g.Go(func() error {
if err := sem.Acquire(ctx, 1); err != nil {
return err
}
defer sem.Release(1)

ingestedFilesCount++
return ingestionFunc(path)
})
}
}

// Wait for all goroutines in the group to finish
return g.Wait()
// Wait for all goroutines to finish
return ingestedFilesCount, g.Wait()
}
2 changes: 1 addition & 1 deletion pkg/client/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (c *DefaultClient) Ingest(_ context.Context, datasetID string, data []byte,
return res.Documents, nil
}

func (c *DefaultClient) IngestPaths(ctx context.Context, datasetID string, opts *IngestPathsOpts, paths ...string) error {
func (c *DefaultClient) IngestPaths(ctx context.Context, datasetID string, opts *IngestPathsOpts, paths ...string) (int, error) {
ingestFile := func(path string) error {
content, err := os.ReadFile(path)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/client/standalone.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (c *StandaloneClient) Ingest(ctx context.Context, datasetID string, data []
return c.Datastore.Ingest(ctx, datasetID, data, opts)
}

func (c *StandaloneClient) IngestPaths(ctx context.Context, datasetID string, opts *IngestPathsOpts, paths ...string) error {
func (c *StandaloneClient) IngestPaths(ctx context.Context, datasetID string, opts *IngestPathsOpts, paths ...string) (int, error) {
ingestFile := func(path string) error {
// Gather metadata
finfo, err := os.Stat(path)
Expand Down
8 changes: 5 additions & 3 deletions pkg/cmd/ingest.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ type ClientIngest struct {
Client
Dataset string `usage:"Target Dataset ID" default:"default" env:"KNOW_TARGET_DATASET"`
IgnoreExtensions string `usage:"Comma-separated list of file extensions to ignore" env:"KNOW_INGEST_IGNORE_EXTENSIONS"`
Concurrency int `usage:"Number of concurrent ingestion processes" default:"10" env:"KNOW_INGEST_CONCURRENCY"`
Concurrency int `usage:"Number of concurrent ingestion processes" short:"c" default:"10" env:"KNOW_INGEST_CONCURRENCY"`
Recursive bool `usage:"Recursively ingest directories" short:"r" default:"false" env:"KNOW_INGEST_RECURSIVE"`
}

func (s *ClientIngest) Customize(cmd *cobra.Command) {
Expand All @@ -32,13 +33,14 @@ func (s *ClientIngest) Run(cmd *cobra.Command, args []string) error {
ingestOpts := &client.IngestPathsOpts{
IgnoreExtensions: strings.Split(s.IgnoreExtensions, ","),
Concurrency: s.Concurrency,
Recursive: s.Recursive,
}

err = c.IngestPaths(cmd.Context(), datasetID, ingestOpts, filePath)
filesIngested, err := c.IngestPaths(cmd.Context(), datasetID, ingestOpts, filePath)
if err != nil {
return err
}

fmt.Printf("Ingested %q into dataset %q\n", filePath, datasetID)
fmt.Printf("Ingested %d files from %q into dataset %q\n", filesIngested, filePath, datasetID)
return nil
}

0 comments on commit 4db4074

Please sign in to comment.