diff --git a/pickle/pickle.go b/pickle/pickle.go index bafefc8..b7501db 100644 --- a/pickle/pickle.go +++ b/pickle/pickle.go @@ -122,7 +122,13 @@ func (u *Unpickler) findClass(module, name string) (interface{}, error) { case "OrderedDict": return &types.OrderedDictClass{}, nil } - + case "builtins": + switch name { + case "list": + return &types.List{}, nil + case "dict": + return &types.Dict{}, nil + } case "__builtin__": switch name { case "object": @@ -139,7 +145,6 @@ func (u *Unpickler) findClass(module, name string) (interface{}, error) { } return types.NewGenericClass(module, name), nil } - func (u *Unpickler) read(n int) ([]byte, error) { buf := make([]byte, n) diff --git a/pickle/pickle_test.go b/pickle/pickle_test.go index 065106d..02a98af 100644 --- a/pickle/pickle_test.go +++ b/pickle/pickle_test.go @@ -5,8 +5,10 @@ package pickle import ( + "fmt" "github.com/nlpodyssey/gopickle/types" "math/big" + "reflect" "strings" "testing" ) @@ -637,6 +639,16 @@ func TestByteArrayP5(t *testing.T) { } } +func TestFindClass(t *testing.T) { + u := &Unpickler{} + v, _ := u.findClass("builtins", "list") + actual, _ := fmt.Println(reflect.TypeOf(v)) + expected, _ := fmt.Println(reflect.TypeOf(&types.List{})) + if actual != expected { + t.Errorf("expected %v, actual: %v", expected, actual) + } +} + // TODO: test BinPersId // TODO: test Get // TODO: test BinGet diff --git a/types/dict.go b/types/dict.go index 1740e8f..a60b784 100644 --- a/types/dict.go +++ b/types/dict.go @@ -78,3 +78,13 @@ func (d *Dict) Keys() []interface{} { return out } + +func (*Dict) Call(args ...interface{}) (interface{}, error) { + if len(args) == 0 { + return NewDict(), nil + } + if len(args) == 1 { + return args[0], nil + } + return nil, fmt.Errorf("Dict: invalid arguments: %#v", args) +} diff --git a/types/dict_test.go b/types/dict_test.go new file mode 100644 index 0000000..fe4b41c --- /dev/null +++ b/types/dict_test.go @@ -0,0 +1,16 @@ +package types + +import "testing" + +func TestDictCall(t *testing.T) { + d := NewDict() + d.Set("foo", "bar") + args := []interface{}{d} + result, _ := d.Call(args) + resultdict := *result.([]interface{})[0].(*Dict) + actual, _ := resultdict.Get("foo") + expected := "bar" + if actual != expected { + t.Errorf("expected %v, actual: %v", expected, actual) + } +} diff --git a/types/list.go b/types/list.go index 7b54bcc..84325aa 100644 --- a/types/list.go +++ b/types/list.go @@ -4,6 +4,8 @@ package types +import "fmt" + // ListAppender is implemented by any value that exhibits a list-like // behaviour, allowing arbitrary values to be appended. type ListAppender interface { @@ -47,3 +49,13 @@ func (l *List) Get(i int) interface{} { func (l *List) Len() int { return len(*l) } + +func (*List) Call(args ...interface{}) (interface{}, error) { + if len(args) == 0 { + return NewList(), nil + } + if len(args) == 1 { + return args[0], nil + } + return nil, fmt.Errorf("List: invalid arguments: %#v", args) +} diff --git a/types/list_test.go b/types/list_test.go new file mode 100644 index 0000000..20a1e9d --- /dev/null +++ b/types/list_test.go @@ -0,0 +1,17 @@ +package types + +import ( + "testing" +) + +func TestCall(t *testing.T) { + list := NewList() + list.Append("foo") + args := []interface{}{list} + result, _ := list.Call(args) + actual := (*result.([]interface{})[0].(*List))[0] + expected := "foo" + if actual != expected { + t.Errorf("expected %v, actual: %v", expected, actual) + } +}