Skip to content

Commit

Permalink
Allow to install ollama models from CLI
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler committed Jun 21, 2024
1 parent 25d3fb3 commit 1f68f26
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 6 deletions.
15 changes: 9 additions & 6 deletions core/cli/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

cliContext "github.com/go-skynet/LocalAI/core/cli/context"

"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/startup"
"github.com/rs/zerolog/log"
Expand Down Expand Up @@ -79,13 +80,15 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
return err
}

model := gallery.FindModel(models, modelName, mi.ModelsPath)
if model == nil {
log.Error().Str("model", modelName).Msg("model not found")
return err
}
if !downloader.LooksLikeOCI(modelName) {
model := gallery.FindModel(models, modelName, mi.ModelsPath)
if model == nil {
log.Error().Str("model", modelName).Msg("model not found")
return err
}

log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
}
err = startup.InstallModels(galleries, "", mi.ModelsPath, progressCallback, modelName)
if err != nil {
return err
Expand Down
2 changes: 2 additions & 0 deletions pkg/downloader/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,11 @@ func DownloadFile(url string, filePath, sha string, fileN, total int, downloadSt
}

if strings.HasPrefix(url, OllamaPrefix) {
url = strings.TrimPrefix(url, OllamaPrefix)
return oci.OllamaFetchModel(url, filePath, progressStatus)
}

url = strings.TrimPrefix(url, OCIPrefix)
img, err := oci.GetImage(url, "", nil, nil)
if err != nil {
return fmt.Errorf("failed to get image %q: %v", url, err)
Expand Down
23 changes: 23 additions & 0 deletions pkg/startup/model_preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"

"github.com/go-skynet/LocalAI/embedded"
"github.com/go-skynet/LocalAI/pkg/downloader"
Expand Down Expand Up @@ -52,6 +53,28 @@ func InstallModels(galleries []gallery.Gallery, modelLibraryURL string, modelPat
log.Error().Err(e).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition")
err = errors.Join(err, e)
}
case downloader.LooksLikeOCI(url):
log.Debug().Msgf("[startup] resolved OCI model to download: %s", url)

// convert OCI image name to a file name.
ociName := strings.TrimPrefix(url, downloader.OCIPrefix)
ociName = strings.TrimPrefix(ociName, downloader.OllamaPrefix)
ociName = strings.ReplaceAll(ociName, "/", "__")
ociName = strings.ReplaceAll(ociName, ":", "__")

// check if file exists
if _, e := os.Stat(filepath.Join(modelPath, ociName)); errors.Is(e, os.ErrNotExist) {
modelDefinitionFilePath := filepath.Join(modelPath, ociName)
e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
})
if e != nil {
log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model")
err = errors.Join(err, e)
}
}

log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName)
case downloader.LooksLikeURL(url):
log.Debug().Msgf("[startup] resolved model to download: %s", url)

Expand Down

0 comments on commit 1f68f26

Please sign in to comment.