diff --git a/common/maps/cache.go b/common/maps/cache.go index 0175974b5..cdc31a684 100644 --- a/common/maps/cache.go +++ b/common/maps/cache.go @@ -13,11 +13,14 @@ package maps -import "sync" +import ( + "sync" +) // Cache is a simple thread safe cache backed by a map. type Cache[K comparable, T any] struct { - m map[K]T + m map[K]T + hasBeenInitialized bool sync.RWMutex } @@ -34,11 +37,16 @@ func (c *Cache[K, T]) Get(key K) (T, bool) { return zero, false } c.RLock() - v, found := c.m[key] + v, found := c.get(key) c.RUnlock() return v, found } +func (c *Cache[K, T]) get(key K) (T, bool) { + v, found := c.m[key] + return v, found +} + // GetOrCreate gets the value for the given key if it exists, or creates it if not. func (c *Cache[K, T]) GetOrCreate(key K, create func() (T, error)) (T, error) { c.RLock() @@ -61,13 +69,49 @@ func (c *Cache[K, T]) GetOrCreate(key K, create func() (T, error)) (T, error) { return v, nil } +// InitAndGet initializes the cache if not already done and returns the value for the given key. +// The init state will be reset on Reset or Drain. +func (c *Cache[K, T]) InitAndGet(key K, init func(get func(key K) (T, bool), set func(key K, value T)) error) (T, error) { + var v T + c.RLock() + if !c.hasBeenInitialized { + c.RUnlock() + if err := func() error { + c.Lock() + defer c.Unlock() + // Double check in case another goroutine has initialized it in the meantime. + if !c.hasBeenInitialized { + err := init(c.get, c.set) + if err != nil { + return err + } + c.hasBeenInitialized = true + } + return nil + }(); err != nil { + return v, err + } + // Reacquire the read lock. + c.RLock() + } + + v = c.m[key] + c.RUnlock() + + return v, nil +} + // Set sets the given key to the given value. func (c *Cache[K, T]) Set(key K, value T) { c.Lock() - c.m[key] = value + c.set(key, value) c.Unlock() } +func (c *Cache[K, T]) set(key K, value T) { + c.m[key] = value +} + // ForEeach calls the given function for each key/value pair in the cache. func (c *Cache[K, T]) ForEeach(f func(K, T)) { c.RLock() @@ -81,6 +125,7 @@ func (c *Cache[K, T]) Drain() map[K]T { c.Lock() m := c.m c.m = make(map[K]T) + c.hasBeenInitialized = false c.Unlock() return m } @@ -94,6 +139,7 @@ func (c *Cache[K, T]) Len() int { func (c *Cache[K, T]) Reset() { c.Lock() c.m = make(map[K]T) + c.hasBeenInitialized = false c.Unlock() } diff --git a/hugolib/content_map_page.go b/hugolib/content_map_page.go index 5e8646b21..8c9e4a31a 100644 --- a/hugolib/content_map_page.go +++ b/hugolib/content_map_page.go @@ -37,7 +37,6 @@ import ( "github.com/gohugoio/hugo/hugolib/doctree" "github.com/gohugoio/hugo/hugolib/pagesfromdata" "github.com/gohugoio/hugo/identity" - "github.com/gohugoio/hugo/lazy" "github.com/gohugoio/hugo/media" "github.com/gohugoio/hugo/output" "github.com/gohugoio/hugo/resources" @@ -925,59 +924,58 @@ func newPageMap(i int, s *Site, mcache *dynacache.Cache, pageTrees *pageTrees) * s: s, } - m.pageReverseIndex = &contentTreeReverseIndex{ - initFn: func(rm map[any]contentNodeI) { - add := func(k string, n contentNodeI) { - existing, found := rm[k] - if found && existing != ambiguousContentNode { - rm[k] = ambiguousContentNode - } else if !found { - rm[k] = n + m.pageReverseIndex = newContentTreeTreverseIndex(func(get func(key any) (contentNodeI, bool), set func(key any, val contentNodeI)) { + add := func(k string, n contentNodeI) { + existing, found := get(k) + if found && existing != ambiguousContentNode { + set(k, ambiguousContentNode) + } else if !found { + set(k, n) + } + } + + w := &doctree.NodeShiftTreeWalker[contentNodeI]{ + Tree: m.treePages, + LockType: doctree.LockTypeRead, + Handle: func(s string, n contentNodeI, match doctree.DimensionFlag) (bool, error) { + p := n.(*pageState) + if p.PathInfo() != nil { + add(p.PathInfo().BaseNameNoIdentifier(), p) } - } + return false, nil + }, + } - w := &doctree.NodeShiftTreeWalker[contentNodeI]{ - Tree: m.treePages, - LockType: doctree.LockTypeRead, - Handle: func(s string, n contentNodeI, match doctree.DimensionFlag) (bool, error) { - p := n.(*pageState) - if p.PathInfo() != nil { - add(p.PathInfo().BaseNameNoIdentifier(), p) - } - return false, nil - }, - } - - if err := w.Walk(context.Background()); err != nil { - panic(err) - } - }, - contentTreeReverseIndexMap: &contentTreeReverseIndexMap{}, - } + if err := w.Walk(context.Background()); err != nil { + panic(err) + } + }) return m } +func newContentTreeTreverseIndex(init func(get func(key any) (contentNodeI, bool), set func(key any, val contentNodeI))) *contentTreeReverseIndex { + return &contentTreeReverseIndex{ + initFn: init, + mm: maps.NewCache[any, contentNodeI](), + } +} + type contentTreeReverseIndex struct { - initFn func(rm map[any]contentNodeI) - *contentTreeReverseIndexMap + initFn func(get func(key any) (contentNodeI, bool), set func(key any, val contentNodeI)) + mm *maps.Cache[any, contentNodeI] } func (c *contentTreeReverseIndex) Reset() { - c.init.ResetWithLock().Unlock() + c.mm.Reset() } func (c *contentTreeReverseIndex) Get(key any) contentNodeI { - c.init.Do(func() { - c.m = make(map[any]contentNodeI) - c.initFn(c.contentTreeReverseIndexMap.m) + v, _ := c.mm.InitAndGet(key, func(get func(key any) (contentNodeI, bool), set func(key any, val contentNodeI)) error { + c.initFn(get, set) + return nil }) - return c.m[key] -} - -type contentTreeReverseIndexMap struct { - init lazy.OnceMore - m map[any]contentNodeI + return v } type sitePagesAssembler struct { diff --git a/hugolib/content_map_test.go b/hugolib/content_map_test.go index bf9920071..a9f719f4a 100644 --- a/hugolib/content_map_test.go +++ b/hugolib/content_map_test.go @@ -17,9 +17,11 @@ import ( "fmt" "path/filepath" "strings" + "sync" "testing" qt "github.com/frankban/quicktest" + "github.com/gohugoio/hugo/identity" ) func TestContentMapSite(t *testing.T) { @@ -396,3 +398,77 @@ irrelevant "https://example.org/en/sitemap.xml", ) } + +func TestContentTreeReverseIndex(t *testing.T) { + t.Parallel() + + c := qt.New(t) + + pageReverseIndex := newContentTreeTreverseIndex( + func(get func(key any) (contentNodeI, bool), set func(key any, val contentNodeI)) { + for i := 0; i < 10; i++ { + key := fmt.Sprint(i) + set(key, &testContentNode{key: key}) + } + }, + ) + + for i := 0; i < 10; i++ { + key := fmt.Sprint(i) + v := pageReverseIndex.Get(key) + c.Assert(v, qt.Not(qt.IsNil)) + c.Assert(v.Path(), qt.Equals, key) + } +} + +// Issue 13019. +func TestContentTreeReverseIndexPara(t *testing.T) { + t.Parallel() + + var wg sync.WaitGroup + + for i := 0; i < 10; i++ { + pageReverseIndex := newContentTreeTreverseIndex( + func(get func(key any) (contentNodeI, bool), set func(key any, val contentNodeI)) { + for i := 0; i < 10; i++ { + key := fmt.Sprint(i) + set(key, &testContentNode{key: key}) + } + }, + ) + + for j := 0; j < 10; j++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + pageReverseIndex.Get(fmt.Sprint(i)) + }(j) + } + } +} + +type testContentNode struct { + key string +} + +func (n *testContentNode) GetIdentity() identity.Identity { + return identity.StringIdentity(n.key) +} + +func (n *testContentNode) ForEeachIdentity(cb func(id identity.Identity) bool) bool { + panic("not supported") +} + +func (n *testContentNode) Path() string { + return n.key +} + +func (n *testContentNode) isContentNodeBranch() bool { + return false +} + +func (n *testContentNode) resetBuildState() { +} + +func (n *testContentNode) MarkStale() { +} diff --git a/lazy/init.go b/lazy/init.go index bef3867a9..7b88a5351 100644 --- a/lazy/init.go +++ b/lazy/init.go @@ -36,7 +36,7 @@ type Init struct { prev *Init children []*Init - init OnceMore + init onceMore out any err error f func(context.Context) (any, error) diff --git a/lazy/once.go b/lazy/once.go index dac689df3..cea096652 100644 --- a/lazy/once.go +++ b/lazy/once.go @@ -24,13 +24,13 @@ import ( // * it can be reset, so the action can be repeated if needed // * it has methods to check if it's done or in progress -type OnceMore struct { +type onceMore struct { mu sync.Mutex lock uint32 done uint32 } -func (t *OnceMore) Do(f func()) { +func (t *onceMore) Do(f func()) { if atomic.LoadUint32(&t.done) == 1 { return } @@ -53,15 +53,15 @@ func (t *OnceMore) Do(f func()) { f() } -func (t *OnceMore) InProgress() bool { +func (t *onceMore) InProgress() bool { return atomic.LoadUint32(&t.lock) == 1 } -func (t *OnceMore) Done() bool { +func (t *onceMore) Done() bool { return atomic.LoadUint32(&t.done) == 1 } -func (t *OnceMore) ResetWithLock() *sync.Mutex { +func (t *onceMore) ResetWithLock() *sync.Mutex { t.mu.Lock() defer atomic.StoreUint32(&t.done, 0) return &t.mu