From 2f96ef11f902eab2864a592b1aa58ce10fb51e92 Mon Sep 17 00:00:00 2001 From: andrew rumble Date: Wed, 25 Sep 2024 17:22:33 +0100 Subject: [PATCH] Allow ESM OL modules to be loaded Lots of changes to async/await required to allow this. We have to make some changes to handle the fact that modules are loaded async in stages rather than sync (so we can't control when top-level functionality is run in a fine grained way) GitOrigin-RevId: 0127b15bfc4f228a267df3af8519c675e900654e --- .../LinkedFiles/LinkedFilesController.js | 32 ++++--- .../app/src/infrastructure/ExpressLocals.js | 47 +++++----- .../web/app/src/infrastructure/Modules.js | 92 ++++++++++++------- services/web/app/src/router.mjs | 4 +- .../LinkedFiles/LinkedFilesControllerTests.js | 2 +- .../unit/src/Project/ProjectDeleterTests.js | 3 + .../SubscriptionControllerTests.js | 3 + 7 files changed, 112 insertions(+), 71 deletions(-) diff --git a/services/web/app/src/Features/LinkedFiles/LinkedFilesController.js b/services/web/app/src/Features/LinkedFiles/LinkedFilesController.js index 6970bccff8..186c5782f2 100644 --- a/services/web/app/src/Features/LinkedFiles/LinkedFilesController.js +++ b/services/web/app/src/Features/LinkedFiles/LinkedFilesController.js @@ -42,13 +42,16 @@ const { plainTextResponse } = require('../../infrastructure/Response') const ReferencesHandler = require('../References/ReferencesHandler') const EditorRealTimeController = require('../Editor/EditorRealTimeController') const { expressify } = require('@overleaf/promise-utils') +const ProjectOutputFileAgent = require('./ProjectOutputFileAgent') +const ProjectFileAgent = require('./ProjectFileAgent') +const UrlAgent = require('./UrlAgent') async function createLinkedFile(req, res, next) { const { project_id: projectId } = req.params const { name, provider, data, parent_folder_id: parentFolderId } = req.body const userId = SessionManager.getLoggedInUserId(req.session) - const Agent = LinkedFilesController._getAgent(provider) + const Agent = await LinkedFilesController._getAgent(provider) if (Agent == null) { return res.sendStatus(400) } @@ -102,7 +105,7 @@ async function refreshLinkedFile(req, res, next) { const { provider } = linkedFileData const parentFolderId = parentFolder._id - const Agent = LinkedFilesController._getAgent(provider) + const Agent = await LinkedFilesController._getAgent(provider) if (Agent == null) { return res.sendStatus(400) } @@ -144,16 +147,23 @@ async function refreshLinkedFile(req, res, next) { } module.exports = LinkedFilesController = { - Agents: _.extend( - { - url: require('./UrlAgent'), - project_file: require('./ProjectFileAgent'), - project_output_file: require('./ProjectOutputFileAgent'), - }, - Modules.linkedFileAgentsIncludes() - ), + Agents: null, - _getAgent(provider) { + async _cacheAgents() { + if (!LinkedFilesController.Agents) { + LinkedFilesController.Agents = _.extend( + { + url: UrlAgent, + project_file: ProjectFileAgent, + project_output_file: ProjectOutputFileAgent, + }, + await Modules.linkedFileAgentsIncludes() + ) + } + }, + + async _getAgent(provider) { + await LinkedFilesController._cacheAgents() if ( !Object.prototype.hasOwnProperty.call( LinkedFilesController.Agents, diff --git a/services/web/app/src/infrastructure/ExpressLocals.js b/services/web/app/src/infrastructure/ExpressLocals.js index 226d5efcaf..4b2042a29d 100644 --- a/services/web/app/src/infrastructure/ExpressLocals.js +++ b/services/web/app/src/infrastructure/ExpressLocals.js @@ -23,29 +23,31 @@ const { const IEEE_BRAND_ID = Settings.ieeeBrandId let webpackManifest -switch (process.env.NODE_ENV) { - case 'production': - // Only load webpack manifest file in production. - webpackManifest = require('../../../public/manifest.json') - break - case 'development': { - // In dev, fetch the manifest from the webpack container. - loadManifestFromWebpackDevServer() - const intervalHandle = setInterval( - loadManifestFromWebpackDevServer, - 10 * 1000 - ) - addOptionalCleanupHandlerAfterDrainingConnections( - 'refresh webpack manifest', - () => { - clearInterval(intervalHandle) - } - ) - break +function loadManifest() { + switch (process.env.NODE_ENV) { + case 'production': + // Only load webpack manifest file in production. + webpackManifest = require('../../../public/manifest.json') + break + case 'development': { + // In dev, fetch the manifest from the webpack container. + loadManifestFromWebpackDevServer() + const intervalHandle = setInterval( + loadManifestFromWebpackDevServer, + 10 * 1000 + ) + addOptionalCleanupHandlerAfterDrainingConnections( + 'refresh webpack manifest', + () => { + clearInterval(intervalHandle) + } + ) + break + } + default: + // In ci, all entries are undefined. + webpackManifest = {} } - default: - // In ci, all entries are undefined. - webpackManifest = {} } function loadManifestFromWebpackDevServer(done = function () {}) { fetchJson(new URL(`/manifest.json`, Settings.apis.webpack.url), { @@ -72,6 +74,7 @@ function getWebpackAssets(entrypoint, section) { } module.exports = function (webRouter, privateApiRouter, publicApiRouter) { + loadManifest() if (process.env.NODE_ENV === 'development') { // In the dev-env, delay requests until we fetched the manifest once. webRouter.use(function (req, res, next) { diff --git a/services/web/app/src/infrastructure/Modules.js b/services/web/app/src/infrastructure/Modules.js index 5dc4e5b217..81ff2dbf62 100644 --- a/services/web/app/src/infrastructure/Modules.js +++ b/services/web/app/src/infrastructure/Modules.js @@ -4,6 +4,7 @@ const async = require('async') const { promisify } = require('util') const Settings = require('@overleaf/settings') const Views = require('./Views') +const _ = require('lodash') const MODULE_BASE_PATH = Path.join(__dirname, '/../../../modules') @@ -13,27 +14,33 @@ const _hooks = {} const _middleware = {} let _viewIncludes = {} -function modules() { +async function modules() { if (!_modulesLoaded) { - loadModules() + await loadModules() } return _modules } -function loadModules() { +async function loadModulesImpl() { const settingsCheckModule = Path.join( MODULE_BASE_PATH, 'settings-check', 'index.js' ) if (fs.existsSync(settingsCheckModule)) { - require(settingsCheckModule) + await import(settingsCheckModule) } for (const moduleName of Settings.moduleImportSequence || []) { - const loadedModule = require( - Path.join(MODULE_BASE_PATH, moduleName, 'index.js') - ) + let path + if (fs.existsSync(Path.join(MODULE_BASE_PATH, moduleName, 'index.mjs'))) { + path = Path.join(MODULE_BASE_PATH, moduleName, 'index.mjs') + } else { + path = Path.join(MODULE_BASE_PATH, moduleName, 'index.js') + } + const module = await import(path) + const loadedModule = module.default || module + loadedModule.name = moduleName _modules.push(loadedModule) if (loadedModule.viewIncludes) { @@ -52,20 +59,26 @@ function loadModules() { } } _modulesLoaded = true - attachHooks() - attachMiddleware() + await attachHooks() + await attachMiddleware() } -function applyRouter(webRouter, privateApiRouter, publicApiRouter) { - for (const module of modules()) { +const loadModules = _.memoize(loadModulesImpl) + +async function applyRouter(webRouter, privateApiRouter, publicApiRouter) { + for (const module of await modules()) { if (module.router && module.router.apply) { module.router.apply(webRouter, privateApiRouter, publicApiRouter) } } } -function applyNonCsrfRouter(webRouter, privateApiRouter, publicApiRouter) { - for (const module of modules()) { +async function applyNonCsrfRouter( + webRouter, + privateApiRouter, + publicApiRouter +) { + for (const module of await modules()) { if (module.nonCsrfRouter != null) { module.nonCsrfRouter.apply(webRouter, privateApiRouter, publicApiRouter) } @@ -80,7 +93,7 @@ function applyNonCsrfRouter(webRouter, privateApiRouter, publicApiRouter) { } async function start() { - for (const module of modules()) { + for (const module of await modules()) { await module.start?.() } } @@ -89,13 +102,13 @@ function loadViewIncludes(app) { _viewIncludes = Views.compileViewIncludes(app) } -function applyMiddleware(appOrRouter, middlewareName, options) { +async function applyMiddleware(appOrRouter, middlewareName, options) { if (!middlewareName) { throw new Error( 'middleware name must be provided to register module middleware' ) } - for (const module of modules()) { + for (const module of await modules()) { if (module[middlewareName]) { module[middlewareName](appOrRouter, options) } @@ -115,9 +128,9 @@ function moduleIncludesAvailable(view) { return (_viewIncludes[view] || []).length > 0 } -function linkedFileAgentsIncludes() { +async function linkedFileAgentsIncludes() { const agents = {} - for (const module of modules()) { + for (const module of await modules()) { for (const name in module.linkedFileAgents) { const agentFunction = module.linkedFileAgents[name] agents[name] = agentFunction() @@ -126,8 +139,8 @@ function linkedFileAgentsIncludes() { return agents } -function attachHooks() { - for (const module of modules()) { +async function attachHooks() { + for (const module of await modules()) { for (const hook in module.hooks || {}) { const method = module.hooks[hook] attachHook(hook, method) @@ -142,8 +155,8 @@ function attachHook(name, method) { _hooks[name].push(method) } -function attachMiddleware() { - for (const module of modules()) { +async function attachMiddleware() { + for (const module of await modules()) { for (const middleware in module.middleware || {}) { const method = module.middleware[middleware] if (_middleware[middleware] == null) { @@ -155,28 +168,37 @@ function attachMiddleware() { } function fireHook(name, ...rest) { + const adjustedLength = Math.max(rest.length, 1) + const args = rest.slice(0, adjustedLength - 1) + const callback = rest[adjustedLength - 1] + function fire() { + const methods = _hooks[name] || [] + const callMethods = methods.map(method => cb => method(...args, cb)) + async.series(callMethods, function (error, results) { + if (error) { + return callback(error) + } + callback(null, results) + }) + } + // ensure that modules are loaded if we need to fire a hook // this can happen if a script calls a method that fires a hook if (!_modulesLoaded) { loadModules() + .then(() => { + fire() + }) + .catch(err => callback(err)) + } else { + fire() } - const adjustedLength = Math.max(rest.length, 1) - const args = rest.slice(0, adjustedLength - 1) - const callback = rest[adjustedLength - 1] - const methods = _hooks[name] || [] - const callMethods = methods.map(method => cb => method(...args, cb)) - async.series(callMethods, function (error, results) { - if (error) { - return callback(error) - } - callback(null, results) - }) } -function getMiddleware(name) { +async function getMiddleware(name) { // ensure that modules are loaded if we need to call a middleware if (!_modulesLoaded) { - loadModules() + await loadModules() } return _middleware[name] || [] } diff --git a/services/web/app/src/router.mjs b/services/web/app/src/router.mjs index 4bd1047fdd..733118ea2a 100644 --- a/services/web/app/src/router.mjs +++ b/services/web/app/src/router.mjs @@ -351,7 +351,7 @@ async function initialize(webRouter, privateApiRouter, publicApiRouter) { '/user/emails/resend_confirmation', AuthenticationController.requireLogin(), RateLimiterMiddleware.rateLimit(rateLimiters.resendConfirmation), - Modules.middleware('resendConfirmationEmail'), + await Modules.middleware('resendConfirmationEmail'), UserEmailsController.resendConfirmation ) @@ -382,7 +382,7 @@ async function initialize(webRouter, privateApiRouter, publicApiRouter) { '/user/emails/delete', AuthenticationController.requireLogin(), RateLimiterMiddleware.rateLimit(rateLimiters.deleteEmail), - Modules.middleware('userDeleteEmail'), + await Modules.middleware('userDeleteEmail'), UserEmailsController.remove ) webRouter.post( diff --git a/services/web/test/unit/src/LinkedFiles/LinkedFilesControllerTests.js b/services/web/test/unit/src/LinkedFiles/LinkedFilesControllerTests.js index a0a08be251..87457968dd 100644 --- a/services/web/test/unit/src/LinkedFiles/LinkedFilesControllerTests.js +++ b/services/web/test/unit/src/LinkedFiles/LinkedFilesControllerTests.js @@ -61,7 +61,7 @@ describe('LinkedFilesController', function () { '@overleaf/settings': this.settings, }, }) - this.LinkedFilesController._getAgent = sinon.stub().returns(this.Agent) + this.LinkedFilesController._getAgent = sinon.stub().resolves(this.Agent) }) describe('createLinkedFile', function () { diff --git a/services/web/test/unit/src/Project/ProjectDeleterTests.js b/services/web/test/unit/src/Project/ProjectDeleterTests.js index e5dc2ea2d6..6946aa348b 100644 --- a/services/web/test/unit/src/Project/ProjectDeleterTests.js +++ b/services/web/test/unit/src/Project/ProjectDeleterTests.js @@ -140,6 +140,9 @@ describe('ProjectDeleter', function () { } this.ProjectDeleter = SandboxedModule.require(modulePath, { requires: { + '../../infrastructure/Modules': { + promises: { hooks: { fire: sinon.stub().resolves() } }, + }, '../../infrastructure/Features': this.Features, '../Editor/EditorRealTimeController': this.EditorRealTimeController, '../../models/Project': { Project }, diff --git a/services/web/test/unit/src/Subscription/SubscriptionControllerTests.js b/services/web/test/unit/src/Subscription/SubscriptionControllerTests.js index e22cd46d21..c416a14cb7 100644 --- a/services/web/test/unit/src/Subscription/SubscriptionControllerTests.js +++ b/services/web/test/unit/src/Subscription/SubscriptionControllerTests.js @@ -172,6 +172,9 @@ describe('SubscriptionController', function () { recordEventForSession: sinon.stub(), setUserPropertyForUser: sinon.stub(), }), + '../../infrastructure/Modules': { + promises: { hooks: { fire: sinon.stub().resolves() } }, + }, '../../infrastructure/Features': this.Features, '../../util/currency': (this.currency = { formatCurrencyLocalized: sinon.stub(),