Skip to content

Commit

Permalink
Merge pull request #6 from pbelitz/loadreduce-list-dict
Browse files Browse the repository at this point in the history
update list and dict to work with loadReduce
  • Loading branch information
matteo-grella authored Mar 13, 2023
2 parents 1b56f50 + 1ad3998 commit 40d0170
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 2 deletions.
9 changes: 7 additions & 2 deletions pickle/pickle.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,13 @@ func (u *Unpickler) findClass(module, name string) (interface{}, error) {
case "OrderedDict":
return &types.OrderedDictClass{}, nil
}

case "builtins":
switch name {
case "list":
return &types.List{}, nil
case "dict":
return &types.Dict{}, nil
}
case "__builtin__":
switch name {
case "object":
Expand All @@ -139,7 +145,6 @@ func (u *Unpickler) findClass(module, name string) (interface{}, error) {
}
return types.NewGenericClass(module, name), nil
}

func (u *Unpickler) read(n int) ([]byte, error) {
buf := make([]byte, n)

Expand Down
12 changes: 12 additions & 0 deletions pickle/pickle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
package pickle

import (
"fmt"
"github.com/nlpodyssey/gopickle/types"
"math/big"
"reflect"
"strings"
"testing"
)
Expand Down Expand Up @@ -637,6 +639,16 @@ func TestByteArrayP5(t *testing.T) {
}
}

func TestFindClass(t *testing.T) {
u := &Unpickler{}
v, _ := u.findClass("builtins", "list")
actual, _ := fmt.Println(reflect.TypeOf(v))
expected, _ := fmt.Println(reflect.TypeOf(&types.List{}))
if actual != expected {
t.Errorf("expected %v, actual: %v", expected, actual)
}
}

// TODO: test BinPersId
// TODO: test Get
// TODO: test BinGet
Expand Down
10 changes: 10 additions & 0 deletions types/dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,13 @@ func (d *Dict) Keys() []interface{} {

return out
}

func (*Dict) Call(args ...interface{}) (interface{}, error) {
if len(args) == 0 {
return NewDict(), nil
}
if len(args) == 1 {
return args[0], nil
}
return nil, fmt.Errorf("Dict: invalid arguments: %#v", args)
}
16 changes: 16 additions & 0 deletions types/dict_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package types

import "testing"

func TestDictCall(t *testing.T) {
d := NewDict()
d.Set("foo", "bar")
args := []interface{}{d}
result, _ := d.Call(args)
resultdict := *result.([]interface{})[0].(*Dict)
actual, _ := resultdict.Get("foo")
expected := "bar"
if actual != expected {
t.Errorf("expected %v, actual: %v", expected, actual)
}
}
12 changes: 12 additions & 0 deletions types/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

package types

import "fmt"

// ListAppender is implemented by any value that exhibits a list-like
// behaviour, allowing arbitrary values to be appended.
type ListAppender interface {
Expand Down Expand Up @@ -47,3 +49,13 @@ func (l *List) Get(i int) interface{} {
func (l *List) Len() int {
return len(*l)
}

func (*List) Call(args ...interface{}) (interface{}, error) {
if len(args) == 0 {
return NewList(), nil
}
if len(args) == 1 {
return args[0], nil
}
return nil, fmt.Errorf("List: invalid arguments: %#v", args)
}
17 changes: 17 additions & 0 deletions types/list_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package types

import (
"testing"
)

func TestCall(t *testing.T) {
list := NewList()
list.Append("foo")
args := []interface{}{list}
result, _ := list.Call(args)
actual := (*result.([]interface{})[0].(*List))[0]
expected := "foo"
if actual != expected {
t.Errorf("expected %v, actual: %v", expected, actual)
}
}

0 comments on commit 40d0170

Please sign in to comment.