[web] Make SamlLogHandler.log() calls asynchronous (#17207)

* [web] Refactor exports in ErrorController

* [web] Make SamlLogHandler.log() async

* [web] await for SamlLogHandler.log() in ErrorController

* [web] await for SamlLogHandler.log() in SAMLMiddleware

* [web] await for SamlLogHandler.log() async controllers

* [web] await for SamlLogHandler.log() in SAMLManager

* [web] Remove explicit wait when testing SAML logs

After making the logs asynchronouse the wait
is no longer needed

* [web] Avoid using async with SamlLogHandler.log on callbacks

* Add expressifyErrorHandler to promise-utils

* Tighten assertion in SAMLMiddlewareTests

Co-authored-by: Jakob Ackermann <jakob.ackermann@overleaf.com>

* Updated SamlLogHandler.log to await for promise

---------

Co-authored-by: Jakob Ackermann <jakob.ackermann@overleaf.com>
GitOrigin-RevId: 3645923fae8096a9ba25dc9087f1a36231528569
This commit is contained in:
Miguel Serrano 2024-02-22 15:59:17 +01:00 committed by Copybot
parent 680c9b9570
commit f4454cfe7e
5 changed files with 163 additions and 117 deletions

View file

@ -10,6 +10,7 @@ module.exports = {
callbackifyAll, callbackifyAll,
callbackifyMultiResult, callbackifyMultiResult,
expressify, expressify,
expressifyErrorHandler,
promiseMapWithLimit, promiseMapWithLimit,
} }
@ -208,6 +209,17 @@ function expressify(fn) {
} }
} }
/**
* Transform an async function into an Error Handling Express middleware
*
* Any error will be passed to the error middlewares via `next()`
*/
function expressifyErrorHandler(fn) {
return (err, req, res, next) => {
fn(err, req, res, next).catch(next)
}
}
/** /**
* Map values in `array` with the async function `fn` * Map values in `array` with the async function `fn`
* *

View file

@ -4,6 +4,8 @@ const {
promisifyClass, promisifyClass,
callbackifyMultiResult, callbackifyMultiResult,
callbackifyAll, callbackifyAll,
expressify,
expressifyErrorHandler,
} = require('../..') } = require('../..')
describe('promisifyAll', function () { describe('promisifyAll', function () {
@ -324,3 +326,23 @@ describe('callbackifyAll', function () {
}) })
}) })
}) })
describe('expressify', function () {
it('should propagate any rejection to the "next" callback', function (done) {
const fn = () => Promise.reject(new Error('rejected'))
expressify(fn)({}, {}, error => {
expect(error.message).to.equal('rejected')
done()
})
})
})
describe('expressifyErrorHandler', function () {
it('should propagate any rejection to the "next" callback', function (done) {
const fn = () => Promise.reject(new Error('rejected'))
expressifyErrorHandler(fn)({}, {}, {}, error => {
expect(error.message).to.equal('rejected')
done()
})
})
})

View file

@ -1,109 +1,115 @@
let ErrorController
const Errors = require('./Errors') const Errors = require('./Errors')
const SessionManager = require('../Authentication/SessionManager') const SessionManager = require('../Authentication/SessionManager')
const SamlLogHandler = require('../SamlLog/SamlLogHandler') const SamlLogHandler = require('../SamlLog/SamlLogHandler')
const HttpErrorHandler = require('./HttpErrorHandler') const HttpErrorHandler = require('./HttpErrorHandler')
const { plainTextResponse } = require('../../infrastructure/Response') const { plainTextResponse } = require('../../infrastructure/Response')
const { expressifyErrorHandler } = require('@overleaf/promise-utils')
module.exports = ErrorController = { function notFound(req, res) {
notFound(req, res) { res.status(404)
res.status(404) res.render('general/404', { title: 'page_not_found' })
res.render('general/404', { title: 'page_not_found' }) }
},
function forbidden(req, res) {
forbidden(req, res) { res.status(403)
res.status(403) res.render('user/restricted')
res.render('user/restricted') }
},
function serverError(req, res) {
serverError(req, res) { res.status(500)
res.status(500) res.render('general/500', { title: 'Server Error' })
res.render('general/500', { title: 'Server Error' }) }
},
async function handleError(error, req, res, next) {
handleError(error, req, res, next) { const shouldSendErrorResponse = !res.headersSent
const shouldSendErrorResponse = !res.headersSent const user = SessionManager.getSessionUser(req.session)
const user = SessionManager.getSessionUser(req.session) req.logger.addFields({ err: error })
req.logger.addFields({ err: error }) // log errors related to SAML flow
// log errors related to SAML flow if (req.session && req.session.saml) {
if (req.session && req.session.saml) { req.logger.setLevel('error')
req.logger.setLevel('error') await SamlLogHandler.promises.log(req, { error })
SamlLogHandler.log(req, { error }) }
} if (error.code === 'EBADCSRFTOKEN') {
if (error.code === 'EBADCSRFTOKEN') { req.logger.addFields({ user })
req.logger.addFields({ user }) req.logger.setLevel('warn')
req.logger.setLevel('warn') if (shouldSendErrorResponse) {
if (shouldSendErrorResponse) { res.sendStatus(403)
res.sendStatus(403) }
} } else if (error instanceof Errors.NotFoundError) {
} else if (error instanceof Errors.NotFoundError) { req.logger.setLevel('warn')
req.logger.setLevel('warn') if (shouldSendErrorResponse) {
if (shouldSendErrorResponse) { notFound(req, res)
ErrorController.notFound(req, res) }
} } else if (
} else if ( error instanceof URIError &&
error instanceof URIError && error.message.match(/^Failed to decode param/)
error.message.match(/^Failed to decode param/) ) {
) { req.logger.setLevel('warn')
req.logger.setLevel('warn') if (shouldSendErrorResponse) {
if (shouldSendErrorResponse) { res.status(400)
res.status(400) res.render('general/500', { title: 'Invalid Error' })
res.render('general/500', { title: 'Invalid Error' }) }
} } else if (error instanceof Errors.ForbiddenError) {
} else if (error instanceof Errors.ForbiddenError) { req.logger.setLevel('warn')
req.logger.setLevel('warn') if (shouldSendErrorResponse) {
if (shouldSendErrorResponse) { forbidden(req, res)
ErrorController.forbidden(req, res) }
} } else if (error instanceof Errors.TooManyRequestsError) {
} else if (error instanceof Errors.TooManyRequestsError) { req.logger.setLevel('warn')
req.logger.setLevel('warn') if (shouldSendErrorResponse) {
if (shouldSendErrorResponse) { res.sendStatus(429)
res.sendStatus(429) }
} } else if (error instanceof Errors.InvalidError) {
} else if (error instanceof Errors.InvalidError) { req.logger.setLevel('warn')
req.logger.setLevel('warn') if (shouldSendErrorResponse) {
if (shouldSendErrorResponse) { res.status(400)
res.status(400) plainTextResponse(res, error.message)
plainTextResponse(res, error.message) }
} } else if (error instanceof Errors.InvalidNameError) {
} else if (error instanceof Errors.InvalidNameError) { req.logger.setLevel('warn')
req.logger.setLevel('warn') if (shouldSendErrorResponse) {
if (shouldSendErrorResponse) { res.status(400)
res.status(400) plainTextResponse(res, error.message)
plainTextResponse(res, error.message) }
} } else if (error instanceof Errors.SAMLSessionDataMissing) {
} else if (error instanceof Errors.SAMLSessionDataMissing) { req.logger.setLevel('warn')
req.logger.setLevel('warn') if (shouldSendErrorResponse) {
if (shouldSendErrorResponse) { HttpErrorHandler.badRequest(req, res, error.message)
HttpErrorHandler.badRequest(req, res, error.message) }
} } else {
} else { req.logger.setLevel('error')
req.logger.setLevel('error') if (shouldSendErrorResponse) {
if (shouldSendErrorResponse) { serverError(req, res)
ErrorController.serverError(req, res) }
} }
} if (!shouldSendErrorResponse) {
if (!shouldSendErrorResponse) { // Pass the error to the default Express error handler, which will close
// Pass the error to the default Express error handler, which will close // the connection.
// the connection. next(error)
next(error) }
} }
},
function handleApiError(err, req, res, next) {
handleApiError(err, req, res, next) { req.logger.addFields({ err })
req.logger.addFields({ err }) if (err instanceof Errors.NotFoundError) {
if (err instanceof Errors.NotFoundError) { req.logger.setLevel('warn')
req.logger.setLevel('warn') res.sendStatus(404)
res.sendStatus(404) } else if (
} else if ( err instanceof URIError &&
err instanceof URIError && err.message.match(/^Failed to decode param/)
err.message.match(/^Failed to decode param/) ) {
) { req.logger.setLevel('warn')
req.logger.setLevel('warn') res.sendStatus(400)
res.sendStatus(400) } else {
} else { req.logger.setLevel('error')
req.logger.setLevel('error') res.sendStatus(500)
res.sendStatus(500) }
} }
},
module.exports = {
notFound,
forbidden,
serverError,
handleError: expressifyErrorHandler(handleError),
handleApiError,
} }

View file

@ -2,8 +2,9 @@ const { SamlLog } = require('../../models/SamlLog')
const SessionManager = require('../Authentication/SessionManager') const SessionManager = require('../Authentication/SessionManager')
const logger = require('@overleaf/logger') const logger = require('@overleaf/logger')
const { err: errSerializer } = require('@overleaf/logger/serializers') const { err: errSerializer } = require('@overleaf/logger/serializers')
const { callbackify } = require('util')
function log(req, data, samlAssertion) { async function log(req, data, samlAssertion) {
let providerId, sessionId let providerId, sessionId
data = data || {} data = data || {}
@ -61,18 +62,17 @@ function log(req, data, samlAssertion) {
'SamlLog JSON.stringify Error' 'SamlLog JSON.stringify Error'
) )
} }
samlLog.save(err => { await samlLog.save()
if (err) {
logger.error({ err, sessionId, providerId }, 'SamlLog Error')
}
})
} catch (err) { } catch (err) {
logger.error({ err, sessionId, providerId }, 'SamlLog Error') logger.error({ err, sessionId, providerId }, 'SamlLog Error')
} }
} }
const SamlLogHandler = { const SamlLogHandler = {
log, log: callbackify(log),
promises: {
log,
},
} }
module.exports = SamlLogHandler module.exports = SamlLogHandler

View file

@ -29,8 +29,8 @@ describe('SamlLogHandler', function () {
}) })
describe('with valid data object', function () { describe('with valid data object', function () {
beforeEach(function () { beforeEach(async function () {
SamlLogHandler.log( await SamlLogHandler.promises.log(
{ {
session: { saml: { universityId: providerId } }, session: { saml: { universityId: providerId } },
sessionID: sessionId, sessionID: sessionId,
@ -54,11 +54,11 @@ describe('SamlLogHandler', function () {
}) })
describe('when a json stringify error occurs', function () { describe('when a json stringify error occurs', function () {
beforeEach(function () { beforeEach(async function () {
const circularRef = {} const circularRef = {}
circularRef.circularRef = circularRef circularRef.circularRef = circularRef
SamlLogHandler.log( await SamlLogHandler.promises.log(
{ {
session: { saml: { universityId: providerId } }, session: { saml: { universityId: providerId } },
sessionID: sessionId, sessionID: sessionId,
@ -81,10 +81,13 @@ describe('SamlLogHandler', function () {
}) })
describe('when logging error occurs', function () { describe('when logging error occurs', function () {
beforeEach(function () { let err
samlLog.save = sinon.stub().yields('error')
SamlLogHandler.log( beforeEach(async function () {
err = new Error()
samlLog.save = sinon.stub().rejects(err)
await SamlLogHandler.promises.log(
{ {
session: { saml: { universityId: providerId } }, session: { saml: { universityId: providerId } },
sessionID: sessionId, sessionID: sessionId,
@ -95,7 +98,10 @@ describe('SamlLogHandler', function () {
it('should log error', function () { it('should log error', function () {
this.logger.error.should.have.been.calledOnce.and.calledWithMatch( this.logger.error.should.have.been.calledOnce.and.calledWithMatch(
{ err: 'error', providerId, sessionId: sessionId.substr(0, 8) }, {
err,
sessionId: sessionId.substr(0, 8),
},
'SamlLog Error' 'SamlLog Error'
) )
}) })