-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathclient.go
81 lines (66 loc) · 1.7 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
package main
import (
	"context"
	"fmt"
	"time"

	"github.com/fatih/color"
	"github.com/ollama/ollama/api"
	"github.com/rs/zerolog/log"
)
// New builds a Client around the given Ollama API client and gives it a
// high-intensity cyan color printer.
func New(ollama *api.Client) *Client {
	client := new(Client)
	client.ollama = ollama
	client.color = color.New(color.FgHiCyan)
	return client
}
// Client pairs an Ollama API client with the color printer used when writing
// model replies to the terminal.
type Client struct {
// ollama is the API client that chat requests are sent through.
ollama *api.Client
// color is the printer used for model replies.
color *color.Color
}
// keepAlive is the KeepAlive value (30 seconds) attached to every chat request.
var keepAlive = &api.Duration{Duration: 30 * time.Second}
// prompt sends msg to model as a single, non-streamed user message and
// returns the chat response.
//
// Chat runs the callback synchronously before it returns, so the response is
// captured in a local variable; no goroutine or channel hand-off is needed,
// and the function cannot block waiting for a response that never arrives.
//
// Request errors, and a request that completes without delivering any
// response, are passed to must. If must returns in the latter case, the
// result is nil.
func (c *Client) prompt(model, msg string) *api.ChatResponse {
	var res *api.ChatResponse

	req := &api.ChatRequest{
		Model:     model,
		Stream:    ptr(false),
		KeepAlive: keepAlive,
		Messages: []api.Message{
			{
				Role:    "user",
				Content: msg,
			},
		},
	}

	must(c.ollama.Chat(context.Background(), req, func(cr api.ChatResponse) error {
		// cr is a fresh copy for each invocation, so keeping its address is
		// safe. With streaming disabled a single response is expected; if
		// more arrive, the last one wins.
		res = &cr
		return nil
	}))

	if res == nil {
		must(fmt.Errorf("chat with model %q returned no response", model))
	}
	return res
}
// TestModel runs every entry of questions against model and returns the
// collected timings.
//
// It first sends a short "hi" prompt and records that response's load
// duration as InitialLoadDuration. Each question is then printed, sent to the
// model, and its reply printed in the client's color. The reported
// TotalDuration of each reply feeds MinDuration, MaxDuration and AvgDuration.
// TotalTestTime is the wall-clock time of the whole call, including the
// initial load.
//
// When questions is empty, AvgDuration, MinDuration and MaxDuration are left
// at zero instead of dividing by zero.
func (c *Client) TestModel(model *Model) *Run {
	start := time.Now()
	run := &Run{Model: model}
	var total time.Duration

	// Load model into memory with simple first query
	log.Info().Str("model", model.Name).Str("size", model.Storage).Msg("Loading into memory...")
	run.InitialLoadDuration = c.prompt(model.Name, "hi").Metrics.LoadDuration

	for _, question := range questions {
		println(question)
		res := c.prompt(model.Name, question)
		took := res.TotalDuration
		total += took

		if took > run.MaxDuration {
			run.MaxDuration = took
		}
		// A zero MinDuration means "not set yet", so the first sample always
		// becomes the minimum.
		if run.MinDuration == 0 || took < run.MinDuration {
			run.MinDuration = took
		}

		c.color.Println("🤖: " + res.Message.Content)
	}

	// Guard the average: an empty question list would otherwise panic with an
	// integer divide by zero.
	if n := len(questions); n > 0 {
		run.AvgDuration = total / time.Duration(n)
	}
	run.TotalTestTime = time.Since(start)
	return run
}