mirror of
https://github.com/gohugoio/hugo.git
synced 2024-12-26 05:01:26 +00:00
tpl: Extend where to iterate over maps
Refactor and extend where to iterate over maps.
This commit is contained in:
parent
206440eef2
commit
0141a02160
2 changed files with 115 additions and 44 deletions
|
@ -37,13 +37,11 @@ import (
|
|||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/hugo/hugofs"
|
||||
|
||||
"github.com/bep/inflect"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/spf13/hugo/helpers"
|
||||
"github.com/spf13/hugo/hugofs"
|
||||
jww "github.com/spf13/jwalterweatherman"
|
||||
)
|
||||
|
||||
|
@ -771,64 +769,125 @@ func checkCondition(v, mv reflect.Value, op string) (bool, error) {
|
|||
return false, nil
|
||||
}
|
||||
|
||||
// where returns a filtered subset of a given data type.
|
||||
func where(seq, key interface{}, args ...interface{}) (r interface{}, err error) {
|
||||
seqv := reflect.ValueOf(seq)
|
||||
kv := reflect.ValueOf(key)
|
||||
|
||||
var mv reflect.Value
|
||||
var op string
|
||||
// parseWhereArgs parses the end arguments to the where function. Return a
|
||||
// match value and an operator, if one is defined.
|
||||
func parseWhereArgs(args ...interface{}) (mv reflect.Value, op string, err error) {
|
||||
switch len(args) {
|
||||
case 1:
|
||||
mv = reflect.ValueOf(args[0])
|
||||
case 2:
|
||||
var ok bool
|
||||
if op, ok = args[0].(string); !ok {
|
||||
return nil, errors.New("operator argument must be string type")
|
||||
err = errors.New("operator argument must be string type")
|
||||
return
|
||||
}
|
||||
op = strings.TrimSpace(strings.ToLower(op))
|
||||
mv = reflect.ValueOf(args[1])
|
||||
default:
|
||||
return nil, errors.New("can't evaluate the array by no match argument or more than or equal to two arguments")
|
||||
err = errors.New("can't evaluate the array by no match argument or more than or equal to two arguments")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
seqv, isNil := indirect(seqv)
|
||||
// checkWhereArray handles the where-matching logic when the seqv value is an
|
||||
// Array or Slice.
|
||||
func checkWhereArray(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) {
|
||||
rv := reflect.MakeSlice(seqv.Type(), 0, 0)
|
||||
for i := 0; i < seqv.Len(); i++ {
|
||||
var vvv reflect.Value
|
||||
rvv := seqv.Index(i)
|
||||
if kv.Kind() == reflect.String {
|
||||
vvv = rvv
|
||||
for _, elemName := range path {
|
||||
var err error
|
||||
vvv, err = evaluateSubElem(vvv, elemName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
vv, _ := indirect(rvv)
|
||||
if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
|
||||
vvv = vv.MapIndex(kv)
|
||||
}
|
||||
}
|
||||
|
||||
if ok, err := checkCondition(vvv, mv, op); ok {
|
||||
rv = reflect.Append(rv, rvv)
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return rv.Interface(), nil
|
||||
}
|
||||
|
||||
// checkWhereMap handles the where-matching logic when the seqv value is a Map.
|
||||
func checkWhereMap(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) {
|
||||
rv := reflect.MakeMap(seqv.Type())
|
||||
keys := seqv.MapKeys()
|
||||
for _, k := range keys {
|
||||
elemv := seqv.MapIndex(k)
|
||||
switch elemv.Kind() {
|
||||
case reflect.Array, reflect.Slice:
|
||||
r, err := checkWhereArray(elemv, kv, mv, path, op)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch rr := reflect.ValueOf(r); rr.Kind() {
|
||||
case reflect.Slice:
|
||||
if rr.Len() > 0 {
|
||||
rv.SetMapIndex(k, elemv)
|
||||
}
|
||||
}
|
||||
case reflect.Interface:
|
||||
elemvv, isNil := indirect(elemv)
|
||||
if isNil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch elemvv.Kind() {
|
||||
case reflect.Array, reflect.Slice:
|
||||
r, err := checkWhereArray(elemvv, kv, mv, path, op)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch rr := reflect.ValueOf(r); rr.Kind() {
|
||||
case reflect.Slice:
|
||||
if rr.Len() > 0 {
|
||||
rv.SetMapIndex(k, elemv)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return rv, nil
|
||||
}
|
||||
|
||||
// where returns a filtered subset of a given data type.
|
||||
func where(seq, key interface{}, args ...interface{}) (interface{}, error) {
|
||||
seqv, isNil := indirect(reflect.ValueOf(seq))
|
||||
if isNil {
|
||||
return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(seq).Type().String())
|
||||
}
|
||||
|
||||
mv, op, err := parseWhereArgs(args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var path []string
|
||||
kv := reflect.ValueOf(key)
|
||||
if kv.Kind() == reflect.String {
|
||||
path = strings.Split(strings.Trim(kv.String(), "."), ".")
|
||||
}
|
||||
|
||||
switch seqv.Kind() {
|
||||
case reflect.Array, reflect.Slice:
|
||||
rv := reflect.MakeSlice(seqv.Type(), 0, 0)
|
||||
for i := 0; i < seqv.Len(); i++ {
|
||||
var vvv reflect.Value
|
||||
rvv := seqv.Index(i)
|
||||
if kv.Kind() == reflect.String {
|
||||
vvv = rvv
|
||||
for _, elemName := range path {
|
||||
vvv, err = evaluateSubElem(vvv, elemName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
vv, _ := indirect(rvv)
|
||||
if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
|
||||
vvv = vv.MapIndex(kv)
|
||||
}
|
||||
}
|
||||
if ok, err := checkCondition(vvv, mv, op); ok {
|
||||
rv = reflect.Append(rv, rvv)
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return rv.Interface(), nil
|
||||
return checkWhereArray(seqv, kv, mv, path, op)
|
||||
case reflect.Map:
|
||||
return checkWhereMap(seqv, kv, mv, path, op)
|
||||
default:
|
||||
return nil, fmt.Errorf("can't iterate over %v", seq)
|
||||
}
|
||||
|
|
|
@ -18,11 +18,6 @@ import (
|
|||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/spf13/hugo/hugofs"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"html/template"
|
||||
"math/rand"
|
||||
"path"
|
||||
|
@ -32,6 +27,12 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/spf13/hugo/hugofs"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type tstNoStringer struct {
|
||||
|
@ -1298,6 +1299,17 @@ func TestWhere(t *testing.T) {
|
|||
key: "B", op: "op", match: "f",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
sequence: map[string]interface{}{
|
||||
"foo": []interface{}{map[interface{}]interface{}{"a": 1, "b": 2}},
|
||||
"bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}},
|
||||
"zap": []interface{}{map[interface{}]interface{}{"a": 5, "b": 6}},
|
||||
},
|
||||
key: "b", op: "in", match: slice(3, 4, 5),
|
||||
expect: map[string]interface{}{
|
||||
"bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}},
|
||||
},
|
||||
},
|
||||
} {
|
||||
var results interface{}
|
||||
var err error
|
||||
|
|
Loading…
Reference in a new issue