diff --git a/tpl/collections/apply.go b/tpl/collections/apply.go index c3c3a297b..0b2b00621 100644 --- a/tpl/collections/apply.go +++ b/tpl/collections/apply.go @@ -148,3 +148,15 @@ func indirect(v reflect.Value) (rv reflect.Value, isNil bool) { } return v, false } + +func indirectInterface(v reflect.Value) (rv reflect.Value, isNil bool) { + for ; v.Kind() == reflect.Interface; v = v.Elem() { + if v.IsNil() { + return v, true + } + if v.Kind() == reflect.Interface && v.NumMethod() > 0 { + break + } + } + return v, false +} diff --git a/tpl/collections/collections.go b/tpl/collections/collections.go index ab3d08f5e..103cb3860 100644 --- a/tpl/collections/collections.go +++ b/tpl/collections/collections.go @@ -256,7 +256,9 @@ func (ns *Namespace) In(l interface{}, v interface{}) bool { } default: if isNumber(vv.Kind()) && isNumber(lvv.Kind()) { - if numberToFloat(vv) == numberToFloat(lvv) { + f1, err1 := numberToFloat(vv) + f2, err2 := numberToFloat(lvv) + if err1 == nil && err2 == nil && f1 == f2 { return true } } @@ -277,69 +279,24 @@ func (ns *Namespace) Intersect(l1, l2 interface{}) (interface{}, error) { return make([]interface{}, 0), nil } + var ins *intersector + l1v := reflect.ValueOf(l1) l2v := reflect.ValueOf(l2) switch l1v.Kind() { case reflect.Array, reflect.Slice: + ins = &intersector{r: reflect.MakeSlice(l1v.Type(), 0, 0), seen: make(map[interface{}]bool)} switch l2v.Kind() { case reflect.Array, reflect.Slice: - r := reflect.MakeSlice(l1v.Type(), 0, 0) for i := 0; i < l1v.Len(); i++ { l1vv := l1v.Index(i) for j := 0; j < l2v.Len(); j++ { l2vv := l2v.Index(j) - switch l1vv.Kind() { - case reflect.String: - l2t, err := toString(l2vv) - if err == nil && l1vv.String() == l2t && !ns.In(r.Interface(), l1vv.Interface()) { - r = reflect.Append(r, l1vv) - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - l2t, err := toInt(l2vv) - if err == nil && l1vv.Int() == l2t && !ns.In(r.Interface(), l1vv.Interface()) { - r = reflect.Append(r, l1vv) - } - case reflect.Float32, reflect.Float64: - l2t, err := toFloat(l2vv) - if err == nil && l1vv.Float() == l2t && !ns.In(r.Interface(), l1vv.Interface()) { - r = reflect.Append(r, l1vv) - } - case reflect.Interface: - switch l1vvActual := l1vv.Interface().(type) { - case string: - switch l2vvActual := l2vv.Interface().(type) { - case string: - if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) { - r = reflect.Append(r, l1vv) - } - } - case int, int8, int16, int32, int64: - switch l2vvActual := l2vv.Interface().(type) { - case int, int8, int16, int32, int64: - if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) { - r = reflect.Append(r, l1vv) - } - } - case uint, uint8, uint16, uint32, uint64: - switch l2vvActual := l2vv.Interface().(type) { - case uint, uint8, uint16, uint32, uint64: - if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) { - r = reflect.Append(r, l1vv) - } - } - case float32, float64: - switch l2vvActual := l2vv.Interface().(type) { - case float32, float64: - if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) { - r = reflect.Append(r, l1vv) - } - } - } - } + ins.handleValuePair(l1vv, l2vv) } } - return r.Interface(), nil + return ins.r.Interface(), nil default: return nil, errors.New("can't iterate over " + reflect.ValueOf(l2).Type().String()) } @@ -531,6 +488,41 @@ func (ns *Namespace) Slice(args ...interface{}) []interface{} { return args } +type intersector struct { + r reflect.Value + seen map[interface{}]bool +} + +func (i *intersector) appendIfNotSeen(v reflect.Value) { + vi := v.Interface() + if !i.seen[vi] { + i.r = reflect.Append(i.r, v) + i.seen[vi] = true + } +} + +func (ins *intersector) handleValuePair(l1vv, l2vv reflect.Value) { + switch kind := l1vv.Kind(); { + case kind == reflect.String: + l2t, err := toString(l2vv) + if err == nil && l1vv.String() == l2t { + ins.appendIfNotSeen(l1vv) + } + case isNumber(kind): + f1, err1 := numberToFloat(l1vv) + f2, err2 := numberToFloat(l2vv) + if err1 == nil && err2 == nil && f1 == f2 { + ins.appendIfNotSeen(l1vv) + } + case kind == reflect.Ptr, kind == reflect.Struct: + if l1vv.Interface() == l2vv.Interface() { + ins.appendIfNotSeen(l1vv) + } + case kind == reflect.Interface: + ins.handleValuePair(reflect.ValueOf(l1vv.Interface()), l2vv) + } +} + // Union returns the union of the given sets, l1 and l2. l1 and // l2 must be of the same type and may be either arrays or slices. // If l1 and l2 aren't of the same type then l1 will be returned. @@ -547,105 +539,54 @@ func (ns *Namespace) Union(l1, l2 interface{}) (interface{}, error) { l1v := reflect.ValueOf(l1) l2v := reflect.ValueOf(l2) + var ins *intersector + switch l1v.Kind() { case reflect.Array, reflect.Slice: switch l2v.Kind() { case reflect.Array, reflect.Slice: - r := reflect.MakeSlice(l1v.Type(), 0, 0) + ins = &intersector{r: reflect.MakeSlice(l1v.Type(), 0, 0), seen: make(map[interface{}]bool)} if l1v.Type() != l2v.Type() && l1v.Type().Elem().Kind() != reflect.Interface && l2v.Type().Elem().Kind() != reflect.Interface { - return r.Interface(), nil + return ins.r.Interface(), nil } - var l1vv reflect.Value + var ( + l1vv reflect.Value + isNil bool + ) + for i := 0; i < l1v.Len(); i++ { - l1vv = l1v.Index(i) - if !ns.In(r.Interface(), l1vv.Interface()) { - r = reflect.Append(r, l1vv) + l1vv, isNil = indirectInterface(l1v.Index(i)) + if !isNil { + ins.appendIfNotSeen(l1vv) } } for j := 0; j < l2v.Len(); j++ { l2vv := l2v.Index(j) - switch l1vv.Kind() { - case reflect.String: + switch kind := l1vv.Kind(); { + case kind == reflect.String: l2t, err := toString(l2vv) - if err == nil && !ns.In(r.Interface(), l2t) { - r = reflect.Append(r, reflect.ValueOf(l2t)) + if err == nil { + ins.appendIfNotSeen(reflect.ValueOf(l2t)) } - case reflect.Int: - l2t, err := toInt(l2vv) - if err == nil && !ns.In(r.Interface(), l2t) { - r = reflect.Append(r, reflect.ValueOf(int(l2t))) - } - case reflect.Int8: - l2t, err := toInt(l2vv) - if err == nil && !ns.In(r.Interface(), l2t) { - r = reflect.Append(r, reflect.ValueOf(int8(l2t))) - } - case reflect.Int16: - l2t, err := toInt(l2vv) - if err == nil && !ns.In(r.Interface(), l2t) { - r = reflect.Append(r, reflect.ValueOf(int16(l2t))) - } - case reflect.Int32: - l2t, err := toInt(l2vv) - if err == nil && !ns.In(r.Interface(), l2t) { - r = reflect.Append(r, reflect.ValueOf(int32(l2t))) - } - case reflect.Int64: - l2t, err := toInt(l2vv) - if err == nil && !ns.In(r.Interface(), l2t) { - r = reflect.Append(r, reflect.ValueOf(l2t)) - } - case reflect.Float32: - l2t, err := toFloat(l2vv) - if err == nil && !ns.In(r.Interface(), float32(l2t)) { - r = reflect.Append(r, reflect.ValueOf(float32(l2t))) - } - case reflect.Float64: - l2t, err := toFloat(l2vv) - if err == nil && !ns.In(r.Interface(), l2t) { - r = reflect.Append(r, reflect.ValueOf(l2t)) - } - case reflect.Interface: - switch l1vv.Interface().(type) { - case string: - switch l2vvActual := l2vv.Interface().(type) { - case string: - if !ns.In(r.Interface(), l2vvActual) { - r = reflect.Append(r, l2vv) - } - } - case int, int8, int16, int32, int64: - switch l2vvActual := l2vv.Interface().(type) { - case int, int8, int16, int32, int64: - if !ns.In(r.Interface(), l2vvActual) { - r = reflect.Append(r, l2vv) - } - } - case uint, uint8, uint16, uint32, uint64: - switch l2vvActual := l2vv.Interface().(type) { - case uint, uint8, uint16, uint32, uint64: - if !ns.In(r.Interface(), l2vvActual) { - r = reflect.Append(r, l2vv) - } - } - case float32, float64: - switch l2vvActual := l2vv.Interface().(type) { - case float32, float64: - if !ns.In(r.Interface(), l2vvActual) { - r = reflect.Append(r, l2vv) - } - } + case isNumber(kind): + var err error + l2vv, err = convertNumber(l2vv, kind) + if err == nil { + ins.appendIfNotSeen(l2vv) } + case kind == reflect.Interface, kind == reflect.Struct, kind == reflect.Ptr: + ins.appendIfNotSeen(l2vv) + } } - return r.Interface(), nil + return ins.r.Interface(), nil default: return nil, errors.New("can't iterate over " + reflect.ValueOf(l2).Type().String()) } diff --git a/tpl/collections/collections_test.go b/tpl/collections/collections_test.go index ea23a1de7..46bef9483 100644 --- a/tpl/collections/collections_test.go +++ b/tpl/collections/collections_test.go @@ -258,11 +258,34 @@ func TestIn(t *testing.T) { } } +type page struct { + Title string +} + +func (p page) String() string { + return "p-" + p.Title +} + +type pagesPtr []*page +type pagesVals []page + func TestIntersect(t *testing.T) { t.Parallel() ns := New(&deps.Deps{}) + var ( + p1 = &page{"A"} + p2 = &page{"B"} + p3 = &page{"C"} + p4 = &page{"D"} + + p1v = page{"A"} + p2v = page{"B"} + p3v = page{"C"} + p4v = page{"D"} + ) + for i, test := range []struct { l1, l2 interface{} expect interface{} @@ -280,6 +303,7 @@ func TestIntersect(t *testing.T) { {[]int{2, 4}, []int{1, 2, 4}, []int{2, 4}}, {[]int{1, 2, 4}, []int{3, 6}, []int{}}, {[]float64{2.2, 4.4}, []float64{1.1, 2.2, 4.4}, []float64{2.2, 4.4}}, + // errors {"not array or slice", []string{"a"}, false}, {[]string{"a"}, "not array or slice", false}, @@ -314,8 +338,15 @@ func TestIntersect(t *testing.T) { {[]int64{1, 2, 3}, []interface{}{int64(1), int64(2), int64(2)}, []int64{1, 2}}, {[]float32{1, 2, 3}, []interface{}{float32(1), float32(2), float32(2)}, []float32{1, 2}}, {[]float64{1, 2, 3}, []interface{}{float64(1), float64(2), float64(2)}, []float64{1, 2}}, + + // Structs + {pagesPtr{p1, p4, p2, p3}, pagesPtr{p4, p2, p2}, pagesPtr{p4, p2}}, + {pagesVals{p1v, p4v, p2v, p3v}, pagesVals{p1v, p3v, p3v}, pagesVals{p1v, p3v}}, + {[]interface{}{p1, p4, p2, p3}, []interface{}{p4, p2, p2}, []interface{}{p4, p2}}, + {[]interface{}{p1v, p4v, p2v, p3v}, []interface{}{p1v, p3v, p3v}, []interface{}{p1v, p3v}}, } { - errMsg := fmt.Sprintf("[%d] %v", i, test) + + errMsg := fmt.Sprintf("[%d]", test) result, err := ns.Intersect(test.l1, test.l2) @@ -325,7 +356,9 @@ func TestIntersect(t *testing.T) { } assert.NoError(t, err, errMsg) - assert.Equal(t, test.expect, result, errMsg) + if !reflect.DeepEqual(result, test.expect) { + t.Fatalf("[%d] Got\n%v expected\n%v", i, result, test.expect) + } } } @@ -569,6 +602,18 @@ func TestUnion(t *testing.T) { ns := New(&deps.Deps{}) + var ( + p1 = &page{"A"} + p2 = &page{"B"} + // p3 = &page{"C"} + p4 = &page{"D"} + + p1v = page{"A"} + //p2v = page{"B"} + p3v = page{"C"} + //p4v = page{"D"} + ) + for i, test := range []struct { l1 interface{} l2 interface{} @@ -604,6 +649,7 @@ func TestUnion(t *testing.T) { {[]int16{2, 4}, []interface{}{1, 2, 4}, []int16{2, 4, 1}, false}, {[]int32{2, 4}, []interface{}{1, 2, 4}, []int32{2, 4, 1}, false}, {[]int64{2, 4}, []interface{}{1, 2, 4}, []int64{2, 4, 1}, false}, + {[]float64{2.2, 4.4}, []interface{}{1.1, 2.2, 4.4}, []float64{2.2, 4.4, 1.1}, false}, {[]float32{2.2, 4.4}, []interface{}{1.1, 2.2, 4.4}, []float32{2.2, 4.4, 1.1}, false}, @@ -611,14 +657,21 @@ func TestUnion(t *testing.T) { {[]interface{}{"a", "b", "c", "c"}, []string{"a", "b", "d"}, []interface{}{"a", "b", "c", "d"}, false}, {[]interface{}{}, []string{}, []interface{}{}, false}, {[]interface{}{1, 2}, []int{2, 3}, []interface{}{1, 2, 3}, false}, - {[]interface{}{1, 2}, []int8{2, 3}, []interface{}{1, 2, int8(3)}, false}, + {[]interface{}{1, 2}, []int8{2, 3}, []interface{}{1, 2, 3}, false}, // 28 {[]interface{}{uint(1), uint(2)}, []uint{2, 3}, []interface{}{uint(1), uint(2), uint(3)}, false}, {[]interface{}{1.1, 2.2}, []float64{2.2, 3.3}, []interface{}{1.1, 2.2, 3.3}, false}, + // Structs + {pagesPtr{p1, p4}, pagesPtr{p4, p2, p2}, pagesPtr{p1, p4, p2}, false}, + {pagesVals{p1v}, pagesVals{p3v, p3v}, pagesVals{p1v, p3v}, false}, + {[]interface{}{p1, p4}, []interface{}{p4, p2, p2}, []interface{}{p1, p4, p2}, false}, + {[]interface{}{p1v}, []interface{}{p3v, p3v}, []interface{}{p1v, p3v}, false}, + // errors {"not array or slice", []string{"a"}, false, true}, {[]string{"a"}, "not array or slice", false, true}, } { + errMsg := fmt.Sprintf("[%d] %v", i, test) result, err := ns.Union(test.l1, test.l2) @@ -628,7 +681,9 @@ func TestUnion(t *testing.T) { } assert.NoError(t, err, errMsg) - assert.Equal(t, test.expect, result, errMsg) + if !reflect.DeepEqual(result, test.expect) { + t.Fatalf("[%d] Got\n%v expected\n%v", i, result, test.expect) + } } } diff --git a/tpl/collections/reflect_helpers.go b/tpl/collections/reflect_helpers.go index f07ea978c..69eaa68c4 100644 --- a/tpl/collections/reflect_helpers.go +++ b/tpl/collections/reflect_helpers.go @@ -14,26 +14,96 @@ package collections import ( + "errors" + "fmt" "reflect" + "time" ) -func numberToFloat(v reflect.Value) float64 { +var ( + zero reflect.Value + errorType = reflect.TypeOf((*error)(nil)).Elem() + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() +) + +func numberToFloat(v reflect.Value) (float64, error) { switch kind := v.Kind(); { case isFloat(kind): - return v.Float() + return v.Float(), nil case isInt(kind): - return float64(v.Int()) - case isUInt(kind): - return float64(v.Uint()) + return float64(v.Int()), nil + case isUint(kind): + return float64(v.Uint()), nil case kind == reflect.Interface: return numberToFloat(v.Elem()) default: - panic("Invalid type in numberToFloat") + return 0, fmt.Errorf("Invalid kind %s in numberToFloat", kind) } } +// There are potential overflows in this function, but the downconversion of +// int64 etc. into int8 etc. is coming from the synthetic unit tests for Union etc. +// TODO(bep) We should consider normalizing the slices to int64 etc. +func convertNumber(v reflect.Value, to reflect.Kind) (reflect.Value, error) { + var n reflect.Value + if isFloat(to) { + f, err := toFloat(v) + if err != nil { + return n, err + } + switch to { + case reflect.Float32: + n = reflect.ValueOf(float32(f)) + default: + n = reflect.ValueOf(float64(f)) + } + } else if isInt(to) { + i, err := toInt(v) + if err != nil { + return n, err + } + switch to { + case reflect.Int: + n = reflect.ValueOf(int(i)) + case reflect.Int8: + n = reflect.ValueOf(int8(i)) + case reflect.Int16: + n = reflect.ValueOf(int16(i)) + case reflect.Int32: + n = reflect.ValueOf(int32(i)) + case reflect.Int64: + n = reflect.ValueOf(int64(i)) + } + } else if isUint(to) { + i, err := toUint(v) + if err != nil { + return n, err + } + switch to { + case reflect.Uint: + n = reflect.ValueOf(uint(i)) + case reflect.Uint8: + n = reflect.ValueOf(uint8(i)) + case reflect.Uint16: + n = reflect.ValueOf(uint16(i)) + case reflect.Uint32: + n = reflect.ValueOf(uint32(i)) + case reflect.Uint64: + n = reflect.ValueOf(uint64(i)) + } + + } + + if !n.IsValid() { + return n, errors.New("invalid values") + } + + return n, nil + +} + func isNumber(kind reflect.Kind) bool { - return isInt(kind) || isUInt(kind) || isFloat(kind) + return isInt(kind) || isUint(kind) || isFloat(kind) } func isInt(kind reflect.Kind) bool { @@ -45,7 +115,7 @@ func isInt(kind reflect.Kind) bool { } } -func isUInt(kind reflect.Kind) bool { +func isUint(kind reflect.Kind) bool { switch kind { case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return true diff --git a/tpl/collections/where.go b/tpl/collections/where.go index d8045b301..37be00509 100644 --- a/tpl/collections/where.go +++ b/tpl/collections/where.go @@ -18,7 +18,6 @@ import ( "fmt" "reflect" "strings" - "time" ) // Where returns a filtered subset of a given data type. @@ -404,6 +403,16 @@ func toInt(v reflect.Value) (int64, error) { return -1, errors.New("unable to convert value to int") } +func toUint(v reflect.Value) (uint64, error) { + switch v.Kind() { + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return v.Uint(), nil + case reflect.Interface: + return toUint(v.Elem()) + } + return 0, errors.New("unable to convert value to uint") +} + // toString returns the string value if possible, "" if not. func toString(v reflect.Value) (string, error) { switch v.Kind() { @@ -415,12 +424,6 @@ func toString(v reflect.Value) (string, error) { return "", errors.New("unable to convert value to string") } -var ( - zero reflect.Value - errorType = reflect.TypeOf((*error)(nil)).Elem() - timeType = reflect.TypeOf((*time.Time)(nil)).Elem() -) - func toTimeUnix(v reflect.Value) int64 { if v.Kind() == reflect.Interface { return toTimeUnix(v.Elem())