From 5bec50838c9d5ce097a82d91b7924f22453de1a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Tue, 11 Jul 2023 09:48:57 +0200 Subject: [PATCH] tpl/collections: Fix WordCount (etc.) regression in Where, Sort, Delimit Fixes #11234 --- tpl/collections/collections.go | 5 ++-- tpl/collections/collections_test.go | 5 ++-- tpl/collections/integration_test.go | 41 +++++++++++++++++++++++++++++ tpl/collections/sort.go | 9 ++++--- tpl/collections/sort_test.go | 5 ++-- tpl/collections/where.go | 35 +++++++++++++++--------- tpl/collections/where_test.go | 13 ++++----- 7 files changed, 86 insertions(+), 27 deletions(-) diff --git a/tpl/collections/collections.go b/tpl/collections/collections.go index 04b777dfb..92aa2b9e6 100644 --- a/tpl/collections/collections.go +++ b/tpl/collections/collections.go @@ -16,6 +16,7 @@ package collections import ( + "context" "fmt" "html/template" "math/rand" @@ -99,7 +100,7 @@ func (ns *Namespace) After(n any, l any) (any, error) { // Delimit takes a given list l and returns a string delimited by sep. // If last is passed to the function, it will be used as the final delimiter. -func (ns *Namespace) Delimit(l, sep any, last ...any) (template.HTML, error) { +func (ns *Namespace) Delimit(ctx context.Context, l, sep any, last ...any) (template.HTML, error) { d, err := cast.ToStringE(sep) if err != nil { return "", err @@ -125,7 +126,7 @@ func (ns *Namespace) Delimit(l, sep any, last ...any) (template.HTML, error) { var str string switch lv.Kind() { case reflect.Map: - sortSeq, err := ns.Sort(l) + sortSeq, err := ns.Sort(ctx, l) if err != nil { return "", err } diff --git a/tpl/collections/collections_test.go b/tpl/collections/collections_test.go index 86192c480..43f8377f3 100644 --- a/tpl/collections/collections_test.go +++ b/tpl/collections/collections_test.go @@ -14,6 +14,7 @@ package collections import ( + "context" "errors" "fmt" "html/template" @@ -166,9 +167,9 @@ func TestDelimit(t *testing.T) { var err error if test.last == nil { - result, err = ns.Delimit(test.seq, test.delimiter) + result, err = ns.Delimit(context.Background(), test.seq, test.delimiter) } else { - result, err = ns.Delimit(test.seq, test.delimiter, test.last) + result, err = ns.Delimit(context.Background(), test.seq, test.delimiter, test.last) } c.Assert(err, qt.IsNil, errMsg) diff --git a/tpl/collections/integration_test.go b/tpl/collections/integration_test.go index 80d2f043a..7ef0b6c47 100644 --- a/tpl/collections/integration_test.go +++ b/tpl/collections/integration_test.go @@ -155,3 +155,44 @@ func TestAppendNilsToSliceWithNils(t *testing.T) { } } + +// Issue 11234. +func TestWhereWithWordCount(t *testing.T) { + t.Parallel() + + files := ` +-- config.toml -- +baseURL = 'http://example.com/' +-- layouts/index.html -- +Home: {{ range where site.RegularPages "WordCount" "gt" 50 }}{{ .Title }}|{{ end }} +-- layouts/shortcodes/lorem.html -- +{{ "ipsum " | strings.Repeat (.Get 0 | int) }} + +-- content/p1.md -- +--- +title: "p1" +--- +{{< lorem 100 >}} +-- content/p2.md -- +--- +title: "p2" +--- +{{< lorem 20 >}} +-- content/p3.md -- +--- +title: "p3" +--- +{{< lorem 60 >}} + ` + + b := hugolib.NewIntegrationTestBuilder( + hugolib.IntegrationTestConfig{ + T: t, + TxtarString: files, + }, + ).Build() + + b.AssertFileContent("public/index.html", ` +Home: p1|p3| +`) +} diff --git a/tpl/collections/sort.go b/tpl/collections/sort.go index 4a2106039..2040f8490 100644 --- a/tpl/collections/sort.go +++ b/tpl/collections/sort.go @@ -14,6 +14,7 @@ package collections import ( + "context" "errors" "reflect" "sort" @@ -26,7 +27,7 @@ import ( ) // Sort returns a sorted copy of the list l. -func (ns *Namespace) Sort(l any, args ...any) (any, error) { +func (ns *Namespace) Sort(ctx context.Context, l any, args ...any) (any, error) { if l == nil { return nil, errors.New("sequence must be provided") } @@ -36,6 +37,8 @@ func (ns *Namespace) Sort(l any, args ...any) (any, error) { return nil, errors.New("can't iterate over a nil value") } + ctxv := reflect.ValueOf(ctx) + var sliceType reflect.Type switch seqv.Kind() { case reflect.Array, reflect.Slice: @@ -78,7 +81,7 @@ func (ns *Namespace) Sort(l any, args ...any) (any, error) { v := p.Pairs[i].Value var err error for i, elemName := range path { - v, err = evaluateSubElem(v, elemName) + v, err = evaluateSubElem(ctxv, v, elemName) if err != nil { return nil, err } @@ -108,7 +111,7 @@ func (ns *Namespace) Sort(l any, args ...any) (any, error) { v := p.Pairs[i].Value var err error for i, elemName := range path { - v, err = evaluateSubElem(v, elemName) + v, err = evaluateSubElem(ctxv, v, elemName) if err != nil { return nil, err } diff --git a/tpl/collections/sort_test.go b/tpl/collections/sort_test.go index da9c75d04..1ec95882f 100644 --- a/tpl/collections/sort_test.go +++ b/tpl/collections/sort_test.go @@ -14,6 +14,7 @@ package collections import ( + "context" "fmt" "reflect" "testing" @@ -240,9 +241,9 @@ func TestSort(t *testing.T) { var result any var err error if test.sortByField == nil { - result, err = ns.Sort(test.seq) + result, err = ns.Sort(context.Background(), test.seq) } else { - result, err = ns.Sort(test.seq, test.sortByField, test.sortAsc) + result, err = ns.Sort(context.Background(), test.seq, test.sortByField, test.sortAsc) } if b, ok := test.expect.(bool); ok && !b { diff --git a/tpl/collections/where.go b/tpl/collections/where.go index b20c290fa..2904b7cdd 100644 --- a/tpl/collections/where.go +++ b/tpl/collections/where.go @@ -14,6 +14,7 @@ package collections import ( + "context" "errors" "fmt" "reflect" @@ -24,7 +25,7 @@ import ( ) // Where returns a filtered subset of collection c. -func (ns *Namespace) Where(c, key any, args ...any) (any, error) { +func (ns *Namespace) Where(ctx context.Context, c, key any, args ...any) (any, error) { seqv, isNil := indirect(reflect.ValueOf(c)) if isNil { return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(c).Type().String()) @@ -35,6 +36,8 @@ func (ns *Namespace) Where(c, key any, args ...any) (any, error) { return nil, err } + ctxv := reflect.ValueOf(ctx) + var path []string kv := reflect.ValueOf(key) if kv.Kind() == reflect.String { @@ -43,9 +46,9 @@ func (ns *Namespace) Where(c, key any, args ...any) (any, error) { switch seqv.Kind() { case reflect.Array, reflect.Slice: - return ns.checkWhereArray(seqv, kv, mv, path, op) + return ns.checkWhereArray(ctxv, seqv, kv, mv, path, op) case reflect.Map: - return ns.checkWhereMap(seqv, kv, mv, path, op) + return ns.checkWhereMap(ctxv, seqv, kv, mv, path, op) default: return nil, fmt.Errorf("can't iterate over %v", c) } @@ -275,7 +278,7 @@ func (ns *Namespace) checkCondition(v, mv reflect.Value, op string) (bool, error return false, nil } -func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error) { +func evaluateSubElem(ctx, obj reflect.Value, elemName string) (reflect.Value, error) { if !obj.IsValid() { return zero, errors.New("can't evaluate an invalid value") } @@ -301,12 +304,20 @@ func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error) index := hreflect.GetMethodIndexByName(objPtr.Type(), elemName) if index != -1 { + var args []reflect.Value mt := objPtr.Type().Method(index) + num := mt.Type.NumIn() + maxNumIn := 1 + if num > 1 && mt.Type.In(1).Implements(hreflect.ContextInterface) { + args = []reflect.Value{ctx} + maxNumIn = 2 + } + switch { case mt.PkgPath != "": return zero, fmt.Errorf("%s is an unexported method of type %s", elemName, typ) - case mt.Type.NumIn() > 1: - return zero, fmt.Errorf("%s is a method of type %s but requires more than 1 parameter", elemName, typ) + case mt.Type.NumIn() > maxNumIn: + return zero, fmt.Errorf("%s is a method of type %s but requires more than %d parameter", elemName, typ, maxNumIn) case mt.Type.NumOut() == 0: return zero, fmt.Errorf("%s is a method of type %s but returns no output", elemName, typ) case mt.Type.NumOut() > 2: @@ -316,7 +327,7 @@ func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error) case mt.Type.NumOut() == 2 && !mt.Type.Out(1).Implements(errorType): return zero, fmt.Errorf("%s is a method of type %s returning two values but the second value is not an error type", elemName, typ) } - res := objPtr.Method(mt.Index).Call([]reflect.Value{}) + res := objPtr.Method(mt.Index).Call(args) if len(res) == 2 && !res[1].IsNil() { return zero, fmt.Errorf("error at calling a method %s of type %s: %s", elemName, typ, res[1].Interface().(error)) } @@ -371,7 +382,7 @@ func parseWhereArgs(args ...any) (mv reflect.Value, op string, err error) { // checkWhereArray handles the where-matching logic when the seqv value is an // Array or Slice. -func (ns *Namespace) checkWhereArray(seqv, kv, mv reflect.Value, path []string, op string) (any, error) { +func (ns *Namespace) checkWhereArray(ctxv, seqv, kv, mv reflect.Value, path []string, op string) (any, error) { rv := reflect.MakeSlice(seqv.Type(), 0, 0) for i := 0; i < seqv.Len(); i++ { @@ -385,7 +396,7 @@ func (ns *Namespace) checkWhereArray(seqv, kv, mv reflect.Value, path []string, vvv = rvv for i, elemName := range path { var err error - vvv, err = evaluateSubElem(vvv, elemName) + vvv, err = evaluateSubElem(ctxv, vvv, elemName) if err != nil { continue @@ -417,14 +428,14 @@ func (ns *Namespace) checkWhereArray(seqv, kv, mv reflect.Value, path []string, } // checkWhereMap handles the where-matching logic when the seqv value is a Map. -func (ns *Namespace) checkWhereMap(seqv, kv, mv reflect.Value, path []string, op string) (any, error) { +func (ns *Namespace) checkWhereMap(ctxv, seqv, kv, mv reflect.Value, path []string, op string) (any, error) { rv := reflect.MakeMap(seqv.Type()) keys := seqv.MapKeys() for _, k := range keys { elemv := seqv.MapIndex(k) switch elemv.Kind() { case reflect.Array, reflect.Slice: - r, err := ns.checkWhereArray(elemv, kv, mv, path, op) + r, err := ns.checkWhereArray(ctxv, elemv, kv, mv, path, op) if err != nil { return nil, err } @@ -443,7 +454,7 @@ func (ns *Namespace) checkWhereMap(seqv, kv, mv reflect.Value, path []string, op switch elemvv.Kind() { case reflect.Array, reflect.Slice: - r, err := ns.checkWhereArray(elemvv, kv, mv, path, op) + r, err := ns.checkWhereArray(ctxv, elemvv, kv, mv, path, op) if err != nil { return nil, err } diff --git a/tpl/collections/where_test.go b/tpl/collections/where_test.go index e5ae85e88..1b787daa2 100644 --- a/tpl/collections/where_test.go +++ b/tpl/collections/where_test.go @@ -14,6 +14,7 @@ package collections import ( + "context" "fmt" "html/template" "reflect" @@ -641,9 +642,9 @@ func TestWhere(t *testing.T) { var err error if len(test.op) > 0 { - results, err = ns.Where(test.seq, test.key, test.op, test.match) + results, err = ns.Where(context.Background(), test.seq, test.key, test.op, test.match) } else { - results, err = ns.Where(test.seq, test.key, test.match) + results, err = ns.Where(context.Background(), test.seq, test.key, test.match) } if b, ok := test.expect.(bool); ok && !b { if err == nil { @@ -662,17 +663,17 @@ func TestWhere(t *testing.T) { } var err error - _, err = ns.Where(map[string]int{"a": 1, "b": 2}, "a", []byte("="), 1) + _, err = ns.Where(context.Background(), map[string]int{"a": 1, "b": 2}, "a", []byte("="), 1) if err == nil { t.Errorf("Where called with none string op value didn't return an expected error") } - _, err = ns.Where(map[string]int{"a": 1, "b": 2}, "a", []byte("="), 1, 2) + _, err = ns.Where(context.Background(), map[string]int{"a": 1, "b": 2}, "a", []byte("="), 1, 2) if err == nil { t.Errorf("Where called with more than two variable arguments didn't return an expected error") } - _, err = ns.Where(map[string]int{"a": 1, "b": 2}, "a") + _, err = ns.Where(context.Background(), map[string]int{"a": 1, "b": 2}, "a") if err == nil { t.Errorf("Where called with no variable arguments didn't return an expected error") } @@ -842,7 +843,7 @@ func TestEvaluateSubElem(t *testing.T) { {reflect.ValueOf(map[int]string{1: "foo", 2: "bar"}), "1", false}, {reflect.ValueOf([]string{"foo", "bar"}), "1", false}, } { - result, err := evaluateSubElem(test.value, test.key) + result, err := evaluateSubElem(reflect.ValueOf(context.Background()), test.value, test.key) if b, ok := test.expect.(bool); ok && !b { if err == nil { t.Errorf("[%d] evaluateSubElem didn't return an expected error", i)