Skip to content

Commit

Permalink
Feat: rwkv improvements (#937)
Browse files Browse the repository at this point in the history
  • Loading branch information
dave-gray101 authored Aug 22, 2023
1 parent 0d6165e commit 901f070
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 150 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# go-llama build artifacts
go-llama
go-llama-stable
/gpt4all
go-stable-diffusion
go-piper
Expand Down
2 changes: 2 additions & 0 deletions api/backend/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
Device: c.AutoGPTQ.Device,
UseTriton: c.AutoGPTQ.Triton,
UseFastTokenizer: c.AutoGPTQ.UseFastTokenizer,
// RWKV
Tokenizer: c.Tokenizer,
}
}

Expand Down
3 changes: 3 additions & 0 deletions api/config/prediction.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,7 @@ type PredictionOptions struct {

// Diffusers
ClipSkip int `json:"clip_skip" yaml:"clip_skip"`

// RWKV (?)
Tokenizer string `json:"tokenizer" yaml:"tokenizer"`
}
29 changes: 27 additions & 2 deletions pkg/backend/llm/rwkv/rwkv.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,15 @@ type LLM struct {
}

func (llm *LLM) Load(opts *pb.ModelOptions) error {
tokenizerFile := opts.Tokenizer
if tokenizerFile == "" {
modelFile := filepath.Base(opts.ModelFile)
tokenizerFile = modelFile + tokenizerSuffix
}
modelPath := filepath.Dir(opts.ModelFile)
modelFile := filepath.Base(opts.ModelFile)
model := rwkv.LoadFiles(opts.ModelFile, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads()))
tokenizerPath := filepath.Join(modelPath, tokenizerFile)

model := rwkv.LoadFiles(opts.ModelFile, tokenizerPath, uint32(opts.GetThreads()))

if model == nil {
return fmt.Errorf("could not load model")
Expand Down Expand Up @@ -68,3 +74,22 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro

return nil
}

// TokenizeString encodes opts.Prompt with the model's tokenizer and returns
// the resulting token IDs (widened to int32) together with their count.
func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
	tokens, err := llm.rwkv.Tokenizer.Encode(opts.Prompt)
	if err != nil {
		// On failure, return the zero response alongside the tokenizer error.
		return pb.TokenizationResponse{}, err
	}

	// Widen each token ID to int32 for the protobuf response.
	ids := make([]int32, 0, len(tokens))
	for _, tok := range tokens {
		ids = append(ids, int32(tok.ID))
	}

	return pb.TokenizationResponse{
		Length: int32(len(ids)),
		Tokens: ids,
	}, nil
}
265 changes: 138 additions & 127 deletions pkg/grpc/proto/backend.pb.go

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pkg/grpc/proto/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ message ModelOptions {
string CLIPModel = 31;
string CLIPSubfolder = 32;
int32 CLIPSkip = 33;

// RWKV
string Tokenizer = 34;
}

message Result {
Expand Down
55 changes: 34 additions & 21 deletions pkg/grpc/proto/backend_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 901f070

Please sign in to comment.