From 2bec7027ae4d81c4dd3de7cb7ef6b7c58f076c10 Mon Sep 17 00:00:00 2001 From: Tilman Vatteroth Date: Sat, 8 Apr 2023 22:43:13 +0200 Subject: [PATCH] fix: fix comma separated value detection in x-forwarded-proto parsing Signed-off-by: Tilman Vatteroth --- .../utils/determine-current-origin.spec.ts | 151 ++++++++++++++++++ .../src/utils/determine-current-origin.ts | 31 ++-- 2 files changed, 170 insertions(+), 12 deletions(-) create mode 100644 frontend/src/utils/determine-current-origin.spec.ts diff --git a/frontend/src/utils/determine-current-origin.spec.ts b/frontend/src/utils/determine-current-origin.spec.ts new file mode 100644 index 000000000..6fdbb4d2f --- /dev/null +++ b/frontend/src/utils/determine-current-origin.spec.ts @@ -0,0 +1,151 @@ +/* + * SPDX-FileCopyrightText: 2023 The HedgeDoc developers (see AUTHORS file) + * + * SPDX-License-Identifier: AGPL-3.0-only + */ +import { determineCurrentOrigin } from './determine-current-origin' +import * as IsClientSideRenderingModule from './is-client-side-rendering' +import type { NextPageContext } from 'next' +import { Mock } from 'ts-mockery' + +jest.mock('./is-client-side-rendering') +describe('determineCurrentOrigin', () => { + describe('client side', () => { + it('parses a client side origin correctly', () => { + jest.spyOn(IsClientSideRenderingModule, 'isClientSideRendering').mockImplementation(() => true) + const expectedOrigin = 'expectedOrigin' + Object.defineProperty(window, 'location', { value: { origin: expectedOrigin } }) + expect(determineCurrentOrigin(Mock.of({}))).toBe(expectedOrigin) + }) + }) + + describe('server side', () => { + beforeEach(() => { + jest.spyOn(IsClientSideRenderingModule, 'isClientSideRendering').mockImplementation(() => false) + }) + + it("won't return an origin if no request is present", () => { + expect(determineCurrentOrigin(Mock.of({}))).toBeUndefined() + }) + + it("won't return an origin if no headers are present", () => { + expect(determineCurrentOrigin(Mock.of({ req: { headers: undefined } }))).toBeUndefined() + }) + + it("won't return an origin if no host is present", () => { + expect( + determineCurrentOrigin( + Mock.of({ + req: { + headers: {} + } + }) + ) + ).toBeUndefined() + }) + + it('will return an origin for a forwarded host', () => { + expect( + determineCurrentOrigin( + Mock.of({ + req: { + headers: { + 'x-forwarded-host': 'forwardedMockHost', + 'x-forwarded-proto': 'mockProtocol' + } + } + }) + ) + ).toBe('mockProtocol://forwardedMockHost') + }) + + it("will fallback to host header if x-forwarded-host isn't present", () => { + expect( + determineCurrentOrigin( + Mock.of({ + req: { + headers: { + host: 'mockHost', + 'x-forwarded-proto': 'mockProtocol' + } + } + }) + ) + ).toBe('mockProtocol://mockHost') + }) + + it('will prefer x-forwarded-host over host', () => { + expect( + determineCurrentOrigin( + Mock.of({ + req: { + headers: { + 'x-forwarded-host': 'forwardedMockHost', + host: 'mockHost', + 'x-forwarded-proto': 'mockProtocol' + } + } + }) + ) + ).toBe('mockProtocol://forwardedMockHost') + }) + + it('will fallback to http if x-forwarded-proto is missing', () => { + expect( + determineCurrentOrigin( + Mock.of({ + req: { + headers: { + 'x-forwarded-host': 'forwardedMockHost' + } + } + }) + ) + ).toBe('http://forwardedMockHost') + }) + + it('will use the first header if x-forwarded-proto is defined multiple times', () => { + expect( + determineCurrentOrigin( + Mock.of({ + req: { + headers: { + 'x-forwarded-proto': ['mockProtocol1', 'mockProtocol2'], + 'x-forwarded-host': 'forwardedMockHost' + } + } + }) + ) + ).toBe('mockProtocol1://forwardedMockHost') + }) + + it('will use the first header if x-forwarded-host is defined multiple times', () => { + expect( + determineCurrentOrigin( + Mock.of({ + req: { + headers: { + 'x-forwarded-host': ['forwardedMockHost1', 'forwardedMockHost2'] + } + } + }) + ) + ).toBe('http://forwardedMockHost1') + }) + + it('will use the first value if x-forwarded-proto is a comma separated list', () => { + expect( + determineCurrentOrigin( + Mock.of({ + req: { + headers: { + 'x-forwarded-proto': 'mockProtocol1,mockProtocol2', + 'x-forwarded-host': 'forwardedMockHost' + } + } + }) + ) + ).toBe('mockProtocol1://forwardedMockHost') + }) + }) +}) diff --git a/frontend/src/utils/determine-current-origin.ts b/frontend/src/utils/determine-current-origin.ts index 573899f95..fcadf16b5 100644 --- a/frontend/src/utils/determine-current-origin.ts +++ b/frontend/src/utils/determine-current-origin.ts @@ -4,6 +4,8 @@ * SPDX-License-Identifier: AGPL-3.0-only */ import { isClientSideRendering } from './is-client-side-rendering' +import { Optional } from '@mrdrogdrog/optional' +import type { IncomingHttpHeaders } from 'http' import type { NextPageContext } from 'next' /** @@ -18,16 +20,21 @@ export const determineCurrentOrigin = (context: NextPageContext): string | undef if (isClientSideRendering()) { return window.location.origin } - const headers = context.req?.headers - if (headers === undefined) { - return undefined - } - - const protocol = headers['x-forwarded-proto'] ?? 'http' - const host = headers['x-forwarded-host'] ?? headers['host'] - if (host === undefined) { - return undefined - } - - return `${protocol as string}://${host as string}` + return Optional.ofNullable(context.req?.headers) + .flatMap((headers) => buildOriginFromHeaders(headers)) + .orElse(undefined) +} + +const buildOriginFromHeaders = (headers: IncomingHttpHeaders) => { + const rawHost = headers['x-forwarded-host'] ?? headers['host'] + return extractFirstValue(rawHost).map((host) => { + const protocol = extractFirstValue(headers['x-forwarded-proto']).orElse('http') + return `${protocol}://${host}` + }) +} + +const extractFirstValue = (rawValue: string | string[] | undefined): Optional => { + return Optional.ofNullable(rawValue) + .map((value) => (typeof value === 'string' ? value : value[0])) + .map((value) => value.split(',')[0]) }