tpl: Extend where to iterate over maps

Refactor and extend where to iterate over maps.
This commit is contained in:
Cameron Moore 2016-04-12 20:31:14 -05:00 committed by Bjørn Erik Pedersen
parent 206440eef2
commit 0141a02160
2 changed files with 115 additions and 44 deletions

View file

@ -37,13 +37,11 @@ import (
"time"
"unicode/utf8"
"github.com/spf13/afero"
"github.com/spf13/hugo/hugofs"
"github.com/bep/inflect"
"github.com/spf13/afero"
"github.com/spf13/cast"
"github.com/spf13/hugo/helpers"
"github.com/spf13/hugo/hugofs"
jww "github.com/spf13/jwalterweatherman"
)
@ -771,64 +769,125 @@ func checkCondition(v, mv reflect.Value, op string) (bool, error) {
return false, nil
}
// where returns a filtered subset of a given data type.
func where(seq, key interface{}, args ...interface{}) (r interface{}, err error) {
seqv := reflect.ValueOf(seq)
kv := reflect.ValueOf(key)
var mv reflect.Value
var op string
// parseWhereArgs parses the end arguments to the where function. Return a
// match value and an operator, if one is defined.
func parseWhereArgs(args ...interface{}) (mv reflect.Value, op string, err error) {
switch len(args) {
case 1:
mv = reflect.ValueOf(args[0])
case 2:
var ok bool
if op, ok = args[0].(string); !ok {
return nil, errors.New("operator argument must be string type")
err = errors.New("operator argument must be string type")
return
}
op = strings.TrimSpace(strings.ToLower(op))
mv = reflect.ValueOf(args[1])
default:
return nil, errors.New("can't evaluate the array by no match argument or more than or equal to two arguments")
err = errors.New("can't evaluate the array by no match argument or more than or equal to two arguments")
}
return
}
seqv, isNil := indirect(seqv)
// checkWhereArray handles the where-matching logic when the seqv value is an
// Array or Slice.
func checkWhereArray(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) {
rv := reflect.MakeSlice(seqv.Type(), 0, 0)
for i := 0; i < seqv.Len(); i++ {
var vvv reflect.Value
rvv := seqv.Index(i)
if kv.Kind() == reflect.String {
vvv = rvv
for _, elemName := range path {
var err error
vvv, err = evaluateSubElem(vvv, elemName)
if err != nil {
return nil, err
}
}
} else {
vv, _ := indirect(rvv)
if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
vvv = vv.MapIndex(kv)
}
}
if ok, err := checkCondition(vvv, mv, op); ok {
rv = reflect.Append(rv, rvv)
} else if err != nil {
return nil, err
}
}
return rv.Interface(), nil
}
// checkWhereMap handles the where-matching logic when the seqv value is a Map.
func checkWhereMap(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) {
rv := reflect.MakeMap(seqv.Type())
keys := seqv.MapKeys()
for _, k := range keys {
elemv := seqv.MapIndex(k)
switch elemv.Kind() {
case reflect.Array, reflect.Slice:
r, err := checkWhereArray(elemv, kv, mv, path, op)
if err != nil {
return nil, err
}
switch rr := reflect.ValueOf(r); rr.Kind() {
case reflect.Slice:
if rr.Len() > 0 {
rv.SetMapIndex(k, elemv)
}
}
case reflect.Interface:
elemvv, isNil := indirect(elemv)
if isNil {
continue
}
switch elemvv.Kind() {
case reflect.Array, reflect.Slice:
r, err := checkWhereArray(elemvv, kv, mv, path, op)
if err != nil {
return nil, err
}
switch rr := reflect.ValueOf(r); rr.Kind() {
case reflect.Slice:
if rr.Len() > 0 {
rv.SetMapIndex(k, elemv)
}
}
}
}
}
return rv, nil
}
// where returns a filtered subset of a given data type.
func where(seq, key interface{}, args ...interface{}) (interface{}, error) {
seqv, isNil := indirect(reflect.ValueOf(seq))
if isNil {
return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(seq).Type().String())
}
mv, op, err := parseWhereArgs(args...)
if err != nil {
return nil, err
}
var path []string
kv := reflect.ValueOf(key)
if kv.Kind() == reflect.String {
path = strings.Split(strings.Trim(kv.String(), "."), ".")
}
switch seqv.Kind() {
case reflect.Array, reflect.Slice:
rv := reflect.MakeSlice(seqv.Type(), 0, 0)
for i := 0; i < seqv.Len(); i++ {
var vvv reflect.Value
rvv := seqv.Index(i)
if kv.Kind() == reflect.String {
vvv = rvv
for _, elemName := range path {
vvv, err = evaluateSubElem(vvv, elemName)
if err != nil {
return nil, err
}
}
} else {
vv, _ := indirect(rvv)
if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
vvv = vv.MapIndex(kv)
}
}
if ok, err := checkCondition(vvv, mv, op); ok {
rv = reflect.Append(rv, rvv)
} else if err != nil {
return nil, err
}
}
return rv.Interface(), nil
return checkWhereArray(seqv, kv, mv, path, op)
case reflect.Map:
return checkWhereMap(seqv, kv, mv, path, op)
default:
return nil, fmt.Errorf("can't iterate over %v", seq)
}

View file

@ -18,11 +18,6 @@ import (
"encoding/base64"
"errors"
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cast"
"github.com/spf13/hugo/hugofs"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"html/template"
"math/rand"
"path"
@ -32,6 +27,12 @@ import (
"strings"
"testing"
"time"
"github.com/spf13/afero"
"github.com/spf13/cast"
"github.com/spf13/hugo/hugofs"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
)
type tstNoStringer struct {
@ -1298,6 +1299,17 @@ func TestWhere(t *testing.T) {
key: "B", op: "op", match: "f",
expect: false,
},
{
sequence: map[string]interface{}{
"foo": []interface{}{map[interface{}]interface{}{"a": 1, "b": 2}},
"bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}},
"zap": []interface{}{map[interface{}]interface{}{"a": 5, "b": 6}},
},
key: "b", op: "in", match: slice(3, 4, 5),
expect: map[string]interface{}{
"bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}},
},
},
} {
var results interface{}
var err error