diff --git a/uast/nodes/node.go b/uast/nodes/node.go index 8fd3ef61..1f97b5ae 100644 --- a/uast/nodes/node.go +++ b/uast/nodes/node.go @@ -948,3 +948,29 @@ func UniqueKey(n Node) Comparable { return unkPtr(ptr) } } + +// ChildrenCount reports the number of immediate children of n. If n is an Array this is +// the length of the array. If n is an Object, each object in a field of n counts as +// one child and each array is counted as its length. +func ChildrenCount(n Node) int { + switch n := n.(type) { + case nil: + return 0 + case Value: + return 0 + case Array: + return len(n) + case Object: + c := 0 + for _, v := range n { + switch v := v.(type) { + case Object: + c++ + case Array: + c += len(v) + } + } + return c + } + return 0 +} diff --git a/uast/nodes/node_test.go b/uast/nodes/node_test.go index 84572f3a..92aba6aa 100644 --- a/uast/nodes/node_test.go +++ b/uast/nodes/node_test.go @@ -38,6 +38,50 @@ func TestClone(t *testing.T) { }, arr) } +func TestChildrenCount(t *testing.T) { + var cases = []struct { + name string + node Node + exp int + }{ + { + name: "value", + node: Int(3), + exp: 0, + }, + { + name: "array", + node: Array{ + Int(1), + Array{Int(2), Int(3)}, + Object{ + "a": Int(1), + "b": Int(2), + }, + }, + exp: 3, + }, + { + name: "object", + node: Object{ + "k1": Int(1), + "k2": Int(2), + "arr": Array{Int(2), Int(3)}, + "obj": Object{ + "a": Int(1), + "b": Int(2), + }, + }, + exp: 3, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + require.Equal(t, c.exp, ChildrenCount(c.node)) + }) + } +} + func TestApply(t *testing.T) { o1 := Object{"v": Int(1)} o2 := Object{"k": o1, "v": Int(2)} diff --git a/uast/uast.go b/uast/uast.go index 226a4ba8..7a54db42 100644 --- a/uast/uast.go +++ b/uast/uast.go @@ -182,7 +182,13 @@ func AsPosition(m nodes.Object) *Position { } // PositionsOf returns a complete positions map for the given UAST node. -func PositionsOf(m nodes.Object) Positions { +// The function will return nil for non-object nodes like arrays and values. To get +// positions for these nodes, PositionsOf should be called on their parent node. +func PositionsOf(n nodes.Node) Positions { + m, ok := n.(nodes.Object) + if !ok { + return nil + } o, _ := m[KeyPos].(nodes.Object) if len(o) == 0 { return nil @@ -216,7 +222,12 @@ func RoleList(roles ...role.Role) nodes.Array { } // RolesOf is a helper for getting node UAST roles (see KeyRoles). -func RolesOf(m nodes.Object) role.Roles { +// The function will returns nil roles array for non-object nodes like arrays and values. +func RolesOf(n nodes.Node) role.Roles { + m, ok := n.(nodes.Object) + if !ok { + return nil + } arr, ok := m[KeyRoles].(nodes.Array) if !ok || len(arr) == 0 { if tp := TypeOf(m); tp == "" || strings.HasPrefix(tp, NS+":") { @@ -234,15 +245,19 @@ func RolesOf(m nodes.Object) role.Roles { } // TokenOf is a helper for getting node token (see KeyToken). -func TokenOf(m nodes.Object) string { - t := m[KeyToken] - s, ok := t.(nodes.String) - if ok { - return string(s) - } - v, _ := t.(nodes.Value) - if v != nil { - return fmt.Sprint(v) +// It returns an empty string if the node is not an object, or there is no token. +func TokenOf(n nodes.Node) string { + switch n := n.(type) { + case nodes.String: + return string(n) + case nodes.Value: + return fmt.Sprint(n) + case nodes.Object: + t := n[KeyToken] + if t == nil { + return "" + } + return TokenOf(t) } return "" }