From 8f08cdd0ac6a2decd5aa5c9c12c0b2c264f9a989 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?=
" - match := []byte(endBodyTag) - replaceTemplate := `%s` - replace := []byte(fmt.Sprintf(replaceTemplate, port, endBodyTag)) - - newcontent := bytes.Replace(b, match, replace, 1) - if len(newcontent) == len(b) { - endBodyTag = "" - replace := []byte(fmt.Sprintf(replaceTemplate, port, endBodyTag)) - match := []byte(endBodyTag) - newcontent = bytes.Replace(b, match, replace, 1) + var idx = -1 + var match tag + // We used to insert the livereload script right before the closing body. + // This does not work when combined with tools such as Turbolinks. + // So we try to inject the script as early as possible. + for _, t := range tags { + idx = bytes.Index(b, t.markup) + if idx != -1 { + match = t + break + } } - if _, err := ft.To().Write(newcontent); err != nil { + c := make([]byte, len(b)) + copy(c, b) + + if idx == -1 { + _, err := ft.To().Write(c) + return err + } + + script := []byte(fmt.Sprintf(``, port)) + + i := idx + if match.appendScript { + i += len(match.markup) + } + + c = append(c[:i], append(script, c[i:]...)...) + + if _, err := ft.To().Write(c); err != nil { helpers.DistinctWarnLog.Println("Failed to inject LiveReload script:", err) } return nil diff --git a/transform/livereloadinject/livereloadinject_test.go b/transform/livereloadinject/livereloadinject_test.go index 413ca7b43..4dd256bb0 100644 --- a/transform/livereloadinject/livereloadinject_test.go +++ b/transform/livereloadinject/livereloadinject_test.go @@ -15,27 +15,45 @@ package livereloadinject import ( "bytes" - "fmt" "strings" "testing" + qt "github.com/frankban/quicktest" "github.com/gohugoio/hugo/transform" ) func TestLiveReloadInject(t *testing.T) { - doTestLiveReloadInject(t, "") - doTestLiveReloadInject(t, "") -} + c := qt.New(t) -func doTestLiveReloadInject(t *testing.T, bodyEndTag string) { - out := new(bytes.Buffer) - in := strings.NewReader(bodyEndTag) + expectBase := `` + apply := func(s string) string { + out := new(bytes.Buffer) + in := strings.NewReader(s) - tr := transform.New(New(1313)) - tr.Apply(out, in) + tr := transform.New(New(1313)) + tr.Apply(out, in) - expected := fmt.Sprintf(`%s`, bodyEndTag) - if out.String() != expected { - t.Errorf("Expected %s got %s", expected, out.String()) + return out.String() } + + c.Run("Head lower", func(c *qt.C) { + c.Assert(apply("
foo"), qt.Equals, "
"+expectBase+"foo") + }) + + c.Run("Head upper", func(c *qt.C) { + c.Assert(apply("
foo"), qt.Equals, "
"+expectBase+"foo") + }) + + c.Run("Body lower", func(c *qt.C) { + c.Assert(apply("foo"), qt.Equals, "foo"+expectBase+"") + }) + + c.Run("Body upper", func(c *qt.C) { + c.Assert(apply("foo"), qt.Equals, "foo"+expectBase+"") + }) + + c.Run("No match", func(c *qt.C) { + c.Assert(apply("
"), qt.Equals, "
") + }) + }