Skip to content

Commit

Permalink
feat(speculative-sampling): allow to specify a draft model in the mod…
Browse files Browse the repository at this point in the history
…el config (#1052)

**Description**

This PR fixes #1013.

It adds `draft_model` and `n_draft` to the model YAML config in order to
load models with speculative sampling. This should be compatible as well
with grammars.

example:

```yaml
backend: llama                                                                                                                                                                   
context_size: 1024                                                                                                                                                                        
name: my-model-name
parameters:
  model: foo-bar
n_draft: 16                                                                                                                                                                      
draft_model: model-name
```

---------

Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler authored Sep 14, 2023
1 parent 247d85b commit 8ccf5b2
Show file tree
Hide file tree
Showing 12 changed files with 485 additions and 427 deletions.
2 changes: 2 additions & 0 deletions api/backend/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
Seed: int32(c.Seed),
NBatch: int32(b),
NoMulMatQ: c.NoMulMatQ,
DraftModel: c.DraftModel,
AudioPath: c.VallE.AudioPath,
LoraAdapter: c.LoraAdapter,
LoraBase: c.LoraBase,
Expand Down Expand Up @@ -79,6 +80,7 @@ func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions {
return &pb.PredictOptions{
Temperature: float32(c.Temperature),
TopP: float32(c.TopP),
NDraft: c.NDraft,
TopK: int32(c.TopK),
Tokens: int32(c.Maxtokens),
Threads: int32(c.Threads),
Expand Down
2 changes: 2 additions & 0 deletions api/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ type LLMConfig struct {
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
}

type AutoGPTQ struct {
Expand Down
64 changes: 32 additions & 32 deletions extra/grpc/autogptq/backend_pb2.py

Large diffs are not rendered by default.

64 changes: 32 additions & 32 deletions extra/grpc/bark/backend_pb2.py

Large diffs are not rendered by default.

64 changes: 32 additions & 32 deletions extra/grpc/diffusers/backend_pb2.py

Large diffs are not rendered by default.

64 changes: 32 additions & 32 deletions extra/grpc/exllama/backend_pb2.py

Large diffs are not rendered by default.

64 changes: 32 additions & 32 deletions extra/grpc/huggingface/backend_pb2.py

Large diffs are not rendered by default.

64 changes: 32 additions & 32 deletions extra/grpc/vall-e-x/backend_pb2.py

Large diffs are not rendered by default.

64 changes: 32 additions & 32 deletions extra/grpc/vllm/backend_pb2.py

Large diffs are not rendered by default.

37 changes: 35 additions & 2 deletions pkg/backend/llm/llama/llama.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ import (
type LLM struct {
base.SingleThread

llama *llama.LLama
llama *llama.LLama
draftModel *llama.LLama
}

func (llm *LLM) Load(opts *pb.ModelOptions) error {
Expand Down Expand Up @@ -78,7 +79,27 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
llamaOpts = append(llamaOpts, llama.EnabelLowVRAM)
}

if opts.DraftModel != "" {
// https://github.com/ggerganov/llama.cpp/blob/71ca2fad7d6c0ef95ef9944fb3a1a843e481f314/examples/speculative/speculative.cpp#L40
llamaOpts = append(llamaOpts, llama.SetPerplexity(true))
}

model, err := llama.New(opts.ModelFile, llamaOpts...)

if opts.DraftModel != "" {
// opts.DraftModel is relative to opts.ModelFile, so we need to get the basepath of opts.ModelFile
if !filepath.IsAbs(opts.DraftModel) {
dir := filepath.Dir(opts.ModelFile)
opts.DraftModel = filepath.Join(dir, opts.DraftModel)
}

draftModel, err := llama.New(opts.DraftModel, llamaOpts...)
if err != nil {
return err
}
llm.draftModel = draftModel
}

llm.llama = model

return err
Expand Down Expand Up @@ -162,6 +183,9 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed)))
}

if opts.NDraft != 0 {
predictOptions = append(predictOptions, llama.SetNDraft(int(opts.NDraft)))
}
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))

predictOptions = append(predictOptions, llama.SetFrequencyPenalty(opts.FrequencyPenalty))
Expand All @@ -175,6 +199,9 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
}

func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
if llm.draftModel != nil {
return llm.llama.SpeculativeSampling(llm.draftModel, opts.Prompt, buildPredictOptions(opts)...)
}
return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

Expand All @@ -187,7 +214,13 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
}))

go func() {
_, err := llm.llama.Predict(opts.Prompt, predictOptions...)
var err error
if llm.draftModel != nil {
_, err = llm.llama.SpeculativeSampling(llm.draftModel, opts.Prompt, buildPredictOptions(opts)...)
} else {
_, err = llm.llama.Predict(opts.Prompt, predictOptions...)
}

if err != nil {
fmt.Println("err: ", err)
}
Expand Down
419 changes: 219 additions & 200 deletions pkg/grpc/proto/backend.pb.go

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pkg/grpc/proto/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ message PredictOptions {
float RopeFreqScale = 38;
float NegativePromptScale = 39;
string NegativePrompt = 40;
int32 NDraft = 41;
}

// The response message containing the result
Expand Down Expand Up @@ -116,7 +117,8 @@ message ModelOptions {
string LoraBase = 35;
string LoraAdapter = 36;
bool NoMulMatQ = 37;

string DraftModel = 39;

string AudioPath = 38;
}

Expand Down

0 comments on commit 8ccf5b2

Please sign in to comment.