Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move models actions to v2 protocol #35

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func run(opts *Options) error {
if err != nil {
return err
}
rsp, err := client.DeleteModel(opts.Database, opts.Engine, opts.Model)
rsp, err := client.DeleteModels(opts.Database, opts.Engine, []string{opts.Model})
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion examples/get_model/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func run(opts *Options) error {
if err != nil {
return err
}
rsp, err := client.ListModels(opts.Database, opts.Engine)
rsp, err := client.GetModel(opts.Database, opts.Engine, opts.Model)
if err != nil {
return err
}
Expand Down
52 changes: 0 additions & 52 deletions examples/list_model_names/main.go

This file was deleted.

17 changes: 12 additions & 5 deletions examples/load_model/main.go → examples/load_models/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package main

import (
"fmt"
"log"
"os"
"path/filepath"
Expand All @@ -41,16 +42,22 @@ func run(opts *Options) error {
if err != nil {
return err
}
r, err := os.Open(opts.File)

value, err := os.ReadFile(opts.File)
if err != nil {
return err
}
name := sansext(opts.File)
rsp, err := client.LoadModel(opts.Database, opts.Engine, name, r)

model := map[string]string{
sansext(opts.File): string(value),
}

rsp, err := client.LoadModels("hnr-db", "hnr-engine", model)
if err != nil {
return err
return nil
}
rsp.Show()

fmt.Println(rsp)
return nil
}

Expand Down
182 changes: 100 additions & 82 deletions rai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"io"
"io/ioutil"
"math/rand"
"mime"
"mime/multipart"
"net/http"
Expand Down Expand Up @@ -883,110 +884,127 @@ func (c *Client) ListOAuthClients() ([]OAuthClient, error) {
// Models
//

func (c *Client) DeleteModel(
database, engine, name string,
) (*TransactionResult, error) {
return c.DeleteModels(database, engine, []string{name})
func (c *Client) LoadModels(
database, engine string, models map[string]string,
) (*TransactionAsyncResult, error) {
randUint := rand.Uint32()
queries := make([]string, 0)
queryInputs := make(map[string]string)

index := 0
for name, value := range models {
index++
inputName := fmt.Sprintf("input_%d_%d", randUint, index)
queries = append(queries,
fmt.Sprintf(`
def delete:rel:catalog:model["%s"] = rel:catalog:model["%s"]
def insert:rel:catalog:model["%s"] = %s
`, name, name, name, inputName,
),
)
queryInputs[inputName] = value
}

return c.Execute(database, engine, strings.Join(queries, "\n"), queryInputs, false)
}

func (c *Client) DeleteModels(
database, engine string, models []string,
) (*TransactionResult, error) {
var result TransactionResult
tx := Transaction{
Region: c.Region,
Database: database,
Engine: engine,
Mode: "OPEN",
Readonly: false}
data := tx.Payload(makeDeleteModelsAction(models))
err := c.Post(PathTransaction, tx.QueryArgs(), data, &result)
if err != nil {
return nil, err
func (c *Client) LoadModelsAsync(
database, engine string, models map[string]string,
) (*TransactionAsyncResult, error) {
randUint := rand.Uint32()
queries := make([]string, 0)
queryInputs := make(map[string]string)

index := 0
for name, value := range models {
inputName := fmt.Sprintf("input_%d_%d", randUint, index)
queries = append(queries,
fmt.Sprintf(`
def delete:rel:catalog:model["%s"] = rel:catalog:model["%s"]
def insert:rel:catallog:model["%s"] = %s
`, name, name, name, inputName,
),
)
queryInputs[inputName] = value
index++
}
return &result, err

return c.ExecuteAsync(database, engine, strings.Join(queries, "\n"), queryInputs, false)
}

func (c *Client) GetModel(database, engine, model string) (*Model, error) {
var result listModelsResponse
tx := NewTransaction(c.Region, database, engine, "OPEN")
data := tx.Payload(makeListModelsAction())
err := c.Post(PathTransaction, tx.QueryArgs(), data, &result)
// Returns a list of model names for the given database.
func (c *Client) ListModels(database, engine string) ([]string, error) {
outName := fmt.Sprintf("models_%d", rand.Uint32())
query := fmt.Sprintf("def output:%s[name] = rel:catalog:model(name, _)", outName)
resp, err := c.Execute(database, engine, query, nil, true)
if err != nil {
return nil, err
}
// assert len(result.Actions) == 1
for _, item := range result.Actions[0].Result.Models {
if item.Name == model {
return &item, nil

var result ArrowRelation
for _, res := range resp.Results {
// use proto metadata instead
if res.RelationID == fmt.Sprintf("/:output/:%s/String", outName) {
result = res
}
}
return nil, ErrNotFound
}

func (c *Client) LoadModel(
database, engine, name string, r io.Reader,
) (*TransactionResult, error) {
return c.LoadModels(database, engine, map[string]io.Reader{name: r})
}

func (c *Client) LoadModels(
database, engine string, models map[string]io.Reader,
) (*TransactionResult, error) {
var result TransactionResult
tx := Transaction{
Region: c.Region,
Database: database,
Engine: engine,
Mode: "OPEN",
Readonly: false}
actions := []DbAction{}
for name, r := range models {
model, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
models := make([]string, 0)
if len(result.Table) > 0 {
for _, name := range result.Table[0] {
models = append(models, name.(string))
}
action := makeLoadModelAction(name, string(model))
actions = append(actions, action)
}
data := tx.Payload(actions...)
err := c.Post(PathTransaction, tx.QueryArgs(), data, &result)
if err != nil {
return nil, err

return models, nil
}
return &result, nil

return models, nil
}

// Returns a list of model names for the given database.
func (c *Client) ListModelNames(database, engine string) ([]string, error) {
var models listModelsResponse
tx := NewTransaction(c.Region, database, engine, "OPEN")
data := tx.Payload(makeListModelsAction())
err := c.Post(PathTransaction, tx.QueryArgs(), data, &models)
func (c *Client) GetModel(database, engine, model string) (*Model, error) {
outName := fmt.Sprintf("model_%d", rand.Uint32())
query := fmt.Sprintf(`def output:%s = rel:catalog:model["%s"]`, outName, model)
resp, err := c.Execute(database, engine, query, nil, true)
if err != nil {
return nil, err
}
actions := models.Actions
// assert len(actions) == 1
result := []string{}
for _, model := range actions[0].Result.Models {
result = append(result, model.Name)

var result ArrowRelation
for _, res := range resp.Results {
if res.RelationID == fmt.Sprintf("/:output/:%s/String", outName) {
result = res
}
}
return result, nil

if len(result.Table) > 0 {
name := model
value := result.Table[0][0].(string)
return &Model{name, value}, nil
}

return nil, ErrNotFound
}

// Returns the names of models installed in the given database.
func (c *Client) ListModels(database, engine string) ([]Model, error) {
var models listModelsResponse
tx := NewTransaction(c.Region, database, engine, "OPEN")
data := tx.Payload(makeListModelsAction())
err := c.Post(PathTransaction, tx.QueryArgs(), data, &models)
if err != nil {
return nil, err
func (c *Client) DeleteModels(
database, engine string, models []string,
) (*TransactionAsyncResult, error) {
queries := make([]string, 0)
for _, model := range models {
queries = append(queries, fmt.Sprintf(`def delete:rel:catalog:model["%s"] = rel:catalog:model["%s"]`, model, model))
}

return c.Execute(database, engine, strings.Join(queries, "\n"), nil, false)
}

func (c *Client) DeleteModelsAsync(
database, engine string, models []string,
) (*TransactionAsyncResult, error) {
queries := make([]string, 0)
for _, model := range models {
queries = append(queries, fmt.Sprintf(`def delete:rel:catalog:model["%s"] = rel:catalog:model["%s"]`, model, model))
}
actions := models.Actions
// assert len(actions) == 1
return actions[0].Result.Models, nil

return c.ExecuteAsync(database, engine, strings.Join(queries, "\n"), nil, false)
}

//
Expand Down
Loading