Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Datasource cleanup: introduce some types and avoid pass-by-context #5317

Merged
merged 2 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 8 additions & 17 deletions internal/datasources/rest/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/mindersec/minder/internal/util/schemaupdate"
"github.com/mindersec/minder/internal/util/schemavalidate"
minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1"
"github.com/mindersec/minder/pkg/engine/v1/interfaces"
)

const (
Expand Down Expand Up @@ -69,7 +70,7 @@ func newHandlerFromDef(def *minderv1.RestDataSource_Def) (*restHandler, error) {
}, nil
}

func (h *restHandler) GetArgsSchema() any {
func (h *restHandler) GetArgsSchema() *structpb.Struct {
return h.rawInputSchema
}

Expand All @@ -86,28 +87,18 @@ func (h *restHandler) ValidateArgs(args any) error {
return schemavalidate.ValidateAgainstSchema(h.inputSchema, mapobj)
}

func (h *restHandler) ValidateUpdate(obj any) error {
if obj == nil {
func (h *restHandler) ValidateUpdate(argsSchema *structpb.Struct) error {
if argsSchema == nil {
return errors.New("update schema cannot be nil")
}

switch castedobj := obj.(type) {
case *structpb.Struct:
if _, err := schemavalidate.CompileSchemaFromPB(castedobj); err != nil {
return fmt.Errorf("update validation failed due to invalid schema: %w", err)
}
return schemaupdate.ValidateSchemaUpdate(h.rawInputSchema, castedobj)
case map[string]any:
if _, err := schemavalidate.CompileSchemaFromMap(castedobj); err != nil {
return fmt.Errorf("update validation failed due to invalid schema: %w", err)
}
return schemaupdate.ValidateSchemaUpdateMap(h.rawInputSchema.AsMap(), castedobj)
default:
return errors.New("invalid type")
if _, err := schemavalidate.CompileSchemaFromPB(argsSchema); err != nil {
return fmt.Errorf("update validation failed due to invalid schema: %w", err)
}
return schemaupdate.ValidateSchemaUpdate(h.rawInputSchema, argsSchema)
}

func (h *restHandler) Call(ctx context.Context, args any) (any, error) {
func (h *restHandler) Call(ctx context.Context, _ *interfaces.Result, args any) (any, error) {
argsMap, ok := args.(map[string]any)
if !ok {
return nil, errors.New("args is not a map")
Expand Down
32 changes: 2 additions & 30 deletions internal/datasources/rest/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func Test_restHandler_Call(t *testing.T) {
headers: tt.fields.headers,
parse: tt.fields.parse,
}
got, err := h.Call(context.Background(), tt.args.args)
got, err := h.Call(context.Background(), nil, tt.args.args)
if tt.wantErr {
assert.Error(t, err)
} else {
Expand Down Expand Up @@ -367,7 +367,7 @@ func Test_restHandler_ValidateUpdate(t *testing.T) {
t.Parallel()

type args struct {
updateSchema any
updateSchema *structpb.Struct
}
tests := []struct {
name string
Expand Down Expand Up @@ -408,34 +408,6 @@ func Test_restHandler_ValidateUpdate(t *testing.T) {
},
wantErr: false,
},
{
name: "Valid map[string]any",
inputSchema: map[string]any{
"type": "object",
"properties": map[string]any{"key": map[string]any{"type": "string"}},
},
args: args{
updateSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"key": map[string]any{"type": "string"},
"new_key": map[string]any{"type": "number"},
},
},
},
wantErr: false,
},
{
name: "Invalid type",
inputSchema: map[string]any{
"type": "object",
"properties": map[string]any{"key": map[string]any{"type": "string"}},
},
args: args{
updateSchema: "invalid_type",
},
wantErr: true,
},
{
name: "nil update schema",
inputSchema: map[string]any{
Expand Down
18 changes: 7 additions & 11 deletions internal/datasources/structured/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ import (

"github.com/go-git/go-billy/v5"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/structpb"

minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1"
v1datasources "github.com/mindersec/minder/pkg/datasources/v1"
"github.com/mindersec/minder/pkg/engine/v1/interfaces"
)

const (
Expand Down Expand Up @@ -148,21 +150,15 @@ func parseFile(f billy.File) (any, error) {
}

// Call parses the structured data from the billy filesystem in the context
func (sh *structHandler) Call(ctx context.Context, _ any) (any, error) {
var ctxData v1datasources.Context
var ok bool
if ctxData, ok = ctx.Value(v1datasources.ContextKey{}).(v1datasources.Context); !ok {
return nil, fmt.Errorf("unable to read execution context")
}

if ctxData.Ingest == nil || ctxData.Ingest.Fs == nil {
func (sh *structHandler) Call(_ context.Context, ingest *interfaces.Result, _ any) (any, error) {
if ingest == nil || ingest.Fs == nil {
return nil, fmt.Errorf("filesystem not found in execution context")
}

return parseFileAlternatives(ctxData.Ingest.Fs, sh.Path.GetFileName(), sh.Path.GetAlternatives())
return parseFileAlternatives(ingest.Fs, sh.Path.GetFileName(), sh.Path.GetAlternatives())
}

func (*structHandler) GetArgsSchema() any {
func (*structHandler) GetArgsSchema() *structpb.Struct {
return nil
}

Expand All @@ -172,6 +168,6 @@ func (_ *structHandler) ValidateArgs(any) error {
}

// ValidateUpdate
func (_ *structHandler) ValidateUpdate(any) error {
func (_ *structHandler) ValidateUpdate(*structpb.Struct) error {
return nil
}
35 changes: 12 additions & 23 deletions internal/datasources/structured/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/stretchr/testify/require"

minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1"
v1datasources "github.com/mindersec/minder/pkg/datasources/v1"
"github.com/mindersec/minder/pkg/engine/v1/interfaces"
)

Expand Down Expand Up @@ -166,25 +165,19 @@ func TestNew(t *testing.T) {
func TestCall(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
buildContext func(t *testing.T) context.Context
def *minderv1.StructDataSource_Def
mustErr bool
name string
ingest func(t *testing.T) *interfaces.Result
def *minderv1.StructDataSource_Def
mustErr bool
}{
{
"success",
func(t *testing.T) context.Context {
func(t *testing.T) *interfaces.Result {
t.Helper()
fs := memfs.New()
writeFSFile(t, fs, "./test1.json", []byte("{ \"a\": \"b\"}"))

return context.WithValue(
context.Background(),
v1datasources.ContextKey{},
v1datasources.Context{
Ingest: &interfaces.Result{Fs: fs},
},
)
return &interfaces.Result{Fs: fs}
},
&minderv1.StructDataSource_Def{
Path: &minderv1.StructDataSource_Def_Path{
Expand All @@ -195,31 +188,27 @@ func TestCall(t *testing.T) {
},
{
"no-datasource-context",
func(t *testing.T) context.Context {
func(t *testing.T) *interfaces.Result {
t.Helper()
return context.Background()
return nil
},
&minderv1.StructDataSource_Def{},
true,
},
{"ctx-no-fs",
func(t *testing.T) context.Context {
func(t *testing.T) *interfaces.Result {
t.Helper()
return context.WithValue(
context.Background(),
v1datasources.ContextKey{},
v1datasources.Context{},
)
return &interfaces.Result{}
},
&minderv1.StructDataSource_Def{},
true},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := tc.buildContext(t)
ingest := tc.ingest(t)
handler, err := newHandlerFromDef(tc.def)
require.NoError(t, err)
_, err = handler.Call(ctx, []string{})
_, err = handler.Call(context.Background(), ingest, []string{})
if tc.mustErr {
require.Error(t, err)
return
Expand Down
13 changes: 2 additions & 11 deletions internal/engine/eval/rego/datasources.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package rego

import (
"context"
"fmt"
"strings"

Expand Down Expand Up @@ -48,7 +47,7 @@ func buildFromDataSource(
Name: k,
Decl: types.NewFunction(types.Args(types.A), types.A),
},
func(_ rego.BuiltinContext, obj *ast.Term) (*ast.Term, error) {
func(bctx rego.BuiltinContext, obj *ast.Term) (*ast.Term, error) {
// Convert the AST value back to a Go interface{}
jsonObj, err := ast.JSON(obj.Value)
if err != nil {
Expand All @@ -59,15 +58,7 @@ func buildFromDataSource(
return nil, err
}

// Call the data source function
ctx := context.WithValue(
context.Background(),
v1datasources.ContextKey{},
v1datasources.Context{
Ingest: res,
},
)
ret, err := dsf.Call(ctx, jsonObj)
ret, err := dsf.Call(bctx.Context, res, jsonObj)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/engine/eval/rego/rego_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ allow {
emptyPol := map[string]any{}

// Matches
fdsf.EXPECT().Call(gomock.Any(), gomock.Any()).Return("foo", nil)
fdsf.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any()).Return("foo", nil)
_, err = e.Eval(context.Background(), emptyPol, nil, &interfaces.Result{
Object: map[string]any{
"data": "foo",
Expand All @@ -531,7 +531,7 @@ allow {
require.NoError(t, err, "could not evaluate")

// Doesn't match
fdsf.EXPECT().Call(gomock.Any(), gomock.Any()).Return("bar", nil)
fdsf.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any()).Return("bar", nil)
_, err = e.Eval(context.Background(), emptyPol, nil, &interfaces.Result{
Object: map[string]any{
"data": "bar",
Expand Down
8 changes: 5 additions & 3 deletions pkg/datasources/v1/datasources.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ package v1
import (
"context"

"google.golang.org/protobuf/types/known/structpb"

"github.com/mindersec/minder/pkg/engine/v1/interfaces"
)

Expand Down Expand Up @@ -36,14 +38,14 @@ type DataSourceFuncDef interface {
// ValidateUpdate validates the update to the data source.
// The data source implementation should respect the update and return an error
// if the update is invalid.
ValidateUpdate(obj any) error
ValidateUpdate(obj *structpb.Struct) error
// Call calls the function with the given arguments.
// It is the responsibility of the data source implementation to handle the call.
// It is also the responsibility of the caller to validate the arguments
// before calling the function.
Call(ctx context.Context, args any) (any, error)
Call(ctx context.Context, ingest *interfaces.Result, args any) (any, error)
// GetArgsSchema returns the schema of the arguments.
GetArgsSchema() any
GetArgsSchema() *structpb.Struct
}

// DataSource is the interface that a data source must implement.
Expand Down
16 changes: 9 additions & 7 deletions pkg/datasources/v1/mock/datasources.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading