From f4454cfe7ef803e955b912d8560c23dd8699655b Mon Sep 17 00:00:00 2001 From: Miguel Serrano Date: Thu, 22 Feb 2024 15:59:17 +0100 Subject: [PATCH] [web] Make SamlLogHandler.log() calls asynchronous (#17207) * [web] Refactor exports in ErrorController * [web] Make SamlLogHandler.log() async * [web] await for SamlLogHandler.log() in ErrorController * [web] await for SamlLogHandler.log() in SAMLMiddleware * [web] await for SamlLogHandler.log() async controllers * [web] await for SamlLogHandler.log() in SAMLManager * [web] Remove explicit wait when testing SAML logs After making the logs asynchronouse the wait is no longer needed * [web] Avoid using async with SamlLogHandler.log on callbacks * Add expressifyErrorHandler to promise-utils * Tighten assertion in SAMLMiddlewareTests Co-authored-by: Jakob Ackermann * Updated SamlLogHandler.log to await for promise --------- Co-authored-by: Jakob Ackermann GitOrigin-RevId: 3645923fae8096a9ba25dc9087f1a36231528569 --- libraries/promise-utils/index.js | 12 + .../test/unit/PromiseUtilsTests.js | 22 ++ .../src/Features/Errors/ErrorController.js | 210 +++++++++--------- .../src/Features/SamlLog/SamlLogHandler.js | 14 +- .../unit/src/SamlLog/SamlLogHandlerTests.js | 22 +- 5 files changed, 163 insertions(+), 117 deletions(-) diff --git a/libraries/promise-utils/index.js b/libraries/promise-utils/index.js index 5437a36997..5b81a43e14 100644 --- a/libraries/promise-utils/index.js +++ b/libraries/promise-utils/index.js @@ -10,6 +10,7 @@ module.exports = { callbackifyAll, callbackifyMultiResult, expressify, + expressifyErrorHandler, promiseMapWithLimit, } @@ -208,6 +209,17 @@ function expressify(fn) { } } +/** + * Transform an async function into an Error Handling Express middleware + * + * Any error will be passed to the error middlewares via `next()` + */ +function expressifyErrorHandler(fn) { + return (err, req, res, next) => { + fn(err, req, res, next).catch(next) + } +} + /** * Map values in `array` with the async function `fn` * diff --git a/libraries/promise-utils/test/unit/PromiseUtilsTests.js b/libraries/promise-utils/test/unit/PromiseUtilsTests.js index c91121d757..39304cdc44 100644 --- a/libraries/promise-utils/test/unit/PromiseUtilsTests.js +++ b/libraries/promise-utils/test/unit/PromiseUtilsTests.js @@ -4,6 +4,8 @@ const { promisifyClass, callbackifyMultiResult, callbackifyAll, + expressify, + expressifyErrorHandler, } = require('../..') describe('promisifyAll', function () { @@ -324,3 +326,23 @@ describe('callbackifyAll', function () { }) }) }) + +describe('expressify', function () { + it('should propagate any rejection to the "next" callback', function (done) { + const fn = () => Promise.reject(new Error('rejected')) + expressify(fn)({}, {}, error => { + expect(error.message).to.equal('rejected') + done() + }) + }) +}) + +describe('expressifyErrorHandler', function () { + it('should propagate any rejection to the "next" callback', function (done) { + const fn = () => Promise.reject(new Error('rejected')) + expressifyErrorHandler(fn)({}, {}, {}, error => { + expect(error.message).to.equal('rejected') + done() + }) + }) +}) diff --git a/services/web/app/src/Features/Errors/ErrorController.js b/services/web/app/src/Features/Errors/ErrorController.js index eb7eb95bdf..989d76861d 100644 --- a/services/web/app/src/Features/Errors/ErrorController.js +++ b/services/web/app/src/Features/Errors/ErrorController.js @@ -1,109 +1,115 @@ -let ErrorController const Errors = require('./Errors') const SessionManager = require('../Authentication/SessionManager') const SamlLogHandler = require('../SamlLog/SamlLogHandler') const HttpErrorHandler = require('./HttpErrorHandler') const { plainTextResponse } = require('../../infrastructure/Response') +const { expressifyErrorHandler } = require('@overleaf/promise-utils') -module.exports = ErrorController = { - notFound(req, res) { - res.status(404) - res.render('general/404', { title: 'page_not_found' }) - }, - - forbidden(req, res) { - res.status(403) - res.render('user/restricted') - }, - - serverError(req, res) { - res.status(500) - res.render('general/500', { title: 'Server Error' }) - }, - - handleError(error, req, res, next) { - const shouldSendErrorResponse = !res.headersSent - const user = SessionManager.getSessionUser(req.session) - req.logger.addFields({ err: error }) - // log errors related to SAML flow - if (req.session && req.session.saml) { - req.logger.setLevel('error') - SamlLogHandler.log(req, { error }) - } - if (error.code === 'EBADCSRFTOKEN') { - req.logger.addFields({ user }) - req.logger.setLevel('warn') - if (shouldSendErrorResponse) { - res.sendStatus(403) - } - } else if (error instanceof Errors.NotFoundError) { - req.logger.setLevel('warn') - if (shouldSendErrorResponse) { - ErrorController.notFound(req, res) - } - } else if ( - error instanceof URIError && - error.message.match(/^Failed to decode param/) - ) { - req.logger.setLevel('warn') - if (shouldSendErrorResponse) { - res.status(400) - res.render('general/500', { title: 'Invalid Error' }) - } - } else if (error instanceof Errors.ForbiddenError) { - req.logger.setLevel('warn') - if (shouldSendErrorResponse) { - ErrorController.forbidden(req, res) - } - } else if (error instanceof Errors.TooManyRequestsError) { - req.logger.setLevel('warn') - if (shouldSendErrorResponse) { - res.sendStatus(429) - } - } else if (error instanceof Errors.InvalidError) { - req.logger.setLevel('warn') - if (shouldSendErrorResponse) { - res.status(400) - plainTextResponse(res, error.message) - } - } else if (error instanceof Errors.InvalidNameError) { - req.logger.setLevel('warn') - if (shouldSendErrorResponse) { - res.status(400) - plainTextResponse(res, error.message) - } - } else if (error instanceof Errors.SAMLSessionDataMissing) { - req.logger.setLevel('warn') - if (shouldSendErrorResponse) { - HttpErrorHandler.badRequest(req, res, error.message) - } - } else { - req.logger.setLevel('error') - if (shouldSendErrorResponse) { - ErrorController.serverError(req, res) - } - } - if (!shouldSendErrorResponse) { - // Pass the error to the default Express error handler, which will close - // the connection. - next(error) - } - }, - - handleApiError(err, req, res, next) { - req.logger.addFields({ err }) - if (err instanceof Errors.NotFoundError) { - req.logger.setLevel('warn') - res.sendStatus(404) - } else if ( - err instanceof URIError && - err.message.match(/^Failed to decode param/) - ) { - req.logger.setLevel('warn') - res.sendStatus(400) - } else { - req.logger.setLevel('error') - res.sendStatus(500) - } - }, +function notFound(req, res) { + res.status(404) + res.render('general/404', { title: 'page_not_found' }) +} + +function forbidden(req, res) { + res.status(403) + res.render('user/restricted') +} + +function serverError(req, res) { + res.status(500) + res.render('general/500', { title: 'Server Error' }) +} + +async function handleError(error, req, res, next) { + const shouldSendErrorResponse = !res.headersSent + const user = SessionManager.getSessionUser(req.session) + req.logger.addFields({ err: error }) + // log errors related to SAML flow + if (req.session && req.session.saml) { + req.logger.setLevel('error') + await SamlLogHandler.promises.log(req, { error }) + } + if (error.code === 'EBADCSRFTOKEN') { + req.logger.addFields({ user }) + req.logger.setLevel('warn') + if (shouldSendErrorResponse) { + res.sendStatus(403) + } + } else if (error instanceof Errors.NotFoundError) { + req.logger.setLevel('warn') + if (shouldSendErrorResponse) { + notFound(req, res) + } + } else if ( + error instanceof URIError && + error.message.match(/^Failed to decode param/) + ) { + req.logger.setLevel('warn') + if (shouldSendErrorResponse) { + res.status(400) + res.render('general/500', { title: 'Invalid Error' }) + } + } else if (error instanceof Errors.ForbiddenError) { + req.logger.setLevel('warn') + if (shouldSendErrorResponse) { + forbidden(req, res) + } + } else if (error instanceof Errors.TooManyRequestsError) { + req.logger.setLevel('warn') + if (shouldSendErrorResponse) { + res.sendStatus(429) + } + } else if (error instanceof Errors.InvalidError) { + req.logger.setLevel('warn') + if (shouldSendErrorResponse) { + res.status(400) + plainTextResponse(res, error.message) + } + } else if (error instanceof Errors.InvalidNameError) { + req.logger.setLevel('warn') + if (shouldSendErrorResponse) { + res.status(400) + plainTextResponse(res, error.message) + } + } else if (error instanceof Errors.SAMLSessionDataMissing) { + req.logger.setLevel('warn') + if (shouldSendErrorResponse) { + HttpErrorHandler.badRequest(req, res, error.message) + } + } else { + req.logger.setLevel('error') + if (shouldSendErrorResponse) { + serverError(req, res) + } + } + if (!shouldSendErrorResponse) { + // Pass the error to the default Express error handler, which will close + // the connection. + next(error) + } +} + +function handleApiError(err, req, res, next) { + req.logger.addFields({ err }) + if (err instanceof Errors.NotFoundError) { + req.logger.setLevel('warn') + res.sendStatus(404) + } else if ( + err instanceof URIError && + err.message.match(/^Failed to decode param/) + ) { + req.logger.setLevel('warn') + res.sendStatus(400) + } else { + req.logger.setLevel('error') + res.sendStatus(500) + } +} + +module.exports = { + notFound, + forbidden, + serverError, + handleError: expressifyErrorHandler(handleError), + handleApiError, } diff --git a/services/web/app/src/Features/SamlLog/SamlLogHandler.js b/services/web/app/src/Features/SamlLog/SamlLogHandler.js index 9ba554ad0a..24f68aea14 100644 --- a/services/web/app/src/Features/SamlLog/SamlLogHandler.js +++ b/services/web/app/src/Features/SamlLog/SamlLogHandler.js @@ -2,8 +2,9 @@ const { SamlLog } = require('../../models/SamlLog') const SessionManager = require('../Authentication/SessionManager') const logger = require('@overleaf/logger') const { err: errSerializer } = require('@overleaf/logger/serializers') +const { callbackify } = require('util') -function log(req, data, samlAssertion) { +async function log(req, data, samlAssertion) { let providerId, sessionId data = data || {} @@ -61,18 +62,17 @@ function log(req, data, samlAssertion) { 'SamlLog JSON.stringify Error' ) } - samlLog.save(err => { - if (err) { - logger.error({ err, sessionId, providerId }, 'SamlLog Error') - } - }) + await samlLog.save() } catch (err) { logger.error({ err, sessionId, providerId }, 'SamlLog Error') } } const SamlLogHandler = { - log, + log: callbackify(log), + promises: { + log, + }, } module.exports = SamlLogHandler diff --git a/services/web/test/unit/src/SamlLog/SamlLogHandlerTests.js b/services/web/test/unit/src/SamlLog/SamlLogHandlerTests.js index fbb9f51721..5400a1b963 100644 --- a/services/web/test/unit/src/SamlLog/SamlLogHandlerTests.js +++ b/services/web/test/unit/src/SamlLog/SamlLogHandlerTests.js @@ -29,8 +29,8 @@ describe('SamlLogHandler', function () { }) describe('with valid data object', function () { - beforeEach(function () { - SamlLogHandler.log( + beforeEach(async function () { + await SamlLogHandler.promises.log( { session: { saml: { universityId: providerId } }, sessionID: sessionId, @@ -54,11 +54,11 @@ describe('SamlLogHandler', function () { }) describe('when a json stringify error occurs', function () { - beforeEach(function () { + beforeEach(async function () { const circularRef = {} circularRef.circularRef = circularRef - SamlLogHandler.log( + await SamlLogHandler.promises.log( { session: { saml: { universityId: providerId } }, sessionID: sessionId, @@ -81,10 +81,13 @@ describe('SamlLogHandler', function () { }) describe('when logging error occurs', function () { - beforeEach(function () { - samlLog.save = sinon.stub().yields('error') + let err - SamlLogHandler.log( + beforeEach(async function () { + err = new Error() + samlLog.save = sinon.stub().rejects(err) + + await SamlLogHandler.promises.log( { session: { saml: { universityId: providerId } }, sessionID: sessionId, @@ -95,7 +98,10 @@ describe('SamlLogHandler', function () { it('should log error', function () { this.logger.error.should.have.been.calledOnce.and.calledWithMatch( - { err: 'error', providerId, sessionId: sessionId.substr(0, 8) }, + { + err, + sessionId: sessionId.substr(0, 8), + }, 'SamlLog Error' ) })