|
1 | 1 | import type { MiddlewareOptions } from '@stacksjs/types' |
| 2 | +import type { EnhancedRequest, MiddlewareHandler, NextFunction } from 'bun-router' |
2 | 3 | import type { Request } from './request' |
3 | 4 | import { userMiddlewarePath } from '@stacksjs/path' |
| 5 | +import { fs } from '@stacksjs/storage' |
| 6 | +import { log } from '@stacksjs/logging' |
4 | 7 |
|
| 8 | +/** |
| 9 | + * Stacks Middleware class - wrapper for bun-router compatible middleware |
| 10 | + * |
| 11 | + * This provides backwards compatibility with the existing Stacks middleware |
| 12 | + * system while integrating with bun-router's middleware pipeline. |
| 13 | + */ |
5 | 14 | export class Middleware implements MiddlewareOptions { |
| 15 | + /** Middleware name for registration */ |
6 | 16 | name: string |
| 17 | + |
| 18 | + /** Execution priority (lower = earlier) */ |
7 | 19 | priority: number |
| 20 | + |
| 21 | + /** Handler function */ |
8 | 22 | handle: (request: Request) => Promise<void> | void |
9 | 23 |
|
10 | 24 | constructor(data: MiddlewareOptions) { |
11 | 25 | this.name = data.name |
12 | 26 | this.priority = data.priority |
13 | 27 | this.handle = data.handle |
14 | 28 | } |
| 29 | + |
| 30 | + /** |
| 31 | + * Convert to bun-router compatible middleware handler |
| 32 | + */ |
| 33 | + toBunMiddleware(): MiddlewareHandler { |
| 34 | + return async (req: EnhancedRequest, next: NextFunction): Promise<Response | null> => { |
| 35 | + try { |
| 36 | + // Import Request class dynamically to avoid circular dependency |
| 37 | + const { Request: StacksRequest } = await import('./request') |
| 38 | + const stacksRequest = StacksRequest.fromEnhancedRequest(req) |
| 39 | + |
| 40 | + // Execute the Stacks middleware |
| 41 | + await this.handle(stacksRequest as unknown as Request) |
| 42 | + |
| 43 | + // Continue to next middleware |
| 44 | + return next() |
| 45 | + } |
| 46 | + catch (error: any) { |
| 47 | + // If middleware throws, return error response |
| 48 | + if (error.status) { |
| 49 | + return Response.json({ error: error.message }, { status: error.status }) |
| 50 | + } |
| 51 | + throw error |
| 52 | + } |
| 53 | + } |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +/** |
| 58 | + * Create a bun-router compatible middleware from a function |
| 59 | + * |
| 60 | + * @example |
| 61 | + * ```ts |
| 62 | + * const authMiddleware = createMiddleware(async (req, next) => { |
| 63 | + * const token = req.headers.get('authorization') |
| 64 | + * if (!token) { |
| 65 | + * return Response.json({ error: 'Unauthorized' }, { status: 401 }) |
| 66 | + * } |
| 67 | + * return next() |
| 68 | + * }) |
| 69 | + * ``` |
| 70 | + */ |
| 71 | +export function createMiddleware( |
| 72 | + handler: MiddlewareHandler, |
| 73 | +): MiddlewareHandler { |
| 74 | + return handler |
15 | 75 | } |
16 | 76 |
|
17 | | -// const readdir = promisify(fs.readdir) |
| 77 | +/** |
| 78 | + * Create a middleware that only runs for specific paths |
| 79 | + * |
| 80 | + * @example |
| 81 | + * ```ts |
| 82 | + * const apiAuthMiddleware = createPathMiddleware('/api', authMiddleware) |
| 83 | + * ``` |
| 84 | + */ |
| 85 | +export function createPathMiddleware( |
| 86 | + pathPrefix: string, |
| 87 | + handler: MiddlewareHandler, |
| 88 | +): MiddlewareHandler { |
| 89 | + return async (req: EnhancedRequest, next: NextFunction): Promise<Response | null> => { |
| 90 | + const url = new URL(req.url) |
| 91 | + if (url.pathname.startsWith(pathPrefix)) { |
| 92 | + return handler(req, next) |
| 93 | + } |
| 94 | + return next() |
| 95 | + } |
| 96 | +} |
| 97 | + |
| 98 | +/** |
| 99 | + * Create a middleware that only runs for specific HTTP methods |
| 100 | + * |
| 101 | + * @example |
| 102 | + * ```ts |
| 103 | + * const postOnlyMiddleware = createMethodMiddleware(['POST', 'PUT'], validationMiddleware) |
| 104 | + * ``` |
| 105 | + */ |
| 106 | +export function createMethodMiddleware( |
| 107 | + methods: string[], |
| 108 | + handler: MiddlewareHandler, |
| 109 | +): MiddlewareHandler { |
| 110 | + const upperMethods = methods.map(m => m.toUpperCase()) |
| 111 | + return async (req: EnhancedRequest, next: NextFunction): Promise<Response | null> => { |
| 112 | + if (upperMethods.includes(req.method.toUpperCase())) { |
| 113 | + return handler(req, next) |
| 114 | + } |
| 115 | + return next() |
| 116 | + } |
| 117 | +} |
| 118 | + |
| 119 | +/** |
| 120 | + * Compose multiple middleware into a single handler |
| 121 | + * |
| 122 | + * @example |
| 123 | + * ```ts |
| 124 | + * const combined = composeMiddleware([ |
| 125 | + * corsMiddleware, |
| 126 | + * authMiddleware, |
| 127 | + * rateLimitMiddleware, |
| 128 | + * ]) |
| 129 | + * ``` |
| 130 | + */ |
| 131 | +export function composeMiddleware( |
| 132 | + middlewares: MiddlewareHandler[], |
| 133 | +): MiddlewareHandler { |
| 134 | + return async (req: EnhancedRequest, next: NextFunction): Promise<Response | null> => { |
| 135 | + let index = -1 |
| 136 | + |
| 137 | + const dispatch = async (i: number): Promise<Response | null> => { |
| 138 | + if (i <= index) { |
| 139 | + throw new Error('next() called multiple times') |
| 140 | + } |
| 141 | + index = i |
18 | 142 |
|
19 | | -async function importMiddlewares(directory: string): Promise<string[]> { |
20 | | - // const middlewares = [] |
21 | | - // TODO: somehow this breaks ./buddy dev |
22 | | - // const files = await readdir(directory) |
| 143 | + const middleware = middlewares[i] |
23 | 144 |
|
24 | | - // for (const file of files) { |
25 | | - // // Dynamically import the middleware |
26 | | - // const imported = await import(path.join(directory, file)) |
27 | | - // middlewares.push(imported.default) |
28 | | - // } |
| 145 | + if (!middleware) { |
| 146 | + return next() |
| 147 | + } |
29 | 148 |
|
30 | | - // return middlewares |
31 | | - return [directory] // fix this: return array of middlewares |
| 149 | + return middleware(req, () => dispatch(i + 1)) |
| 150 | + } |
| 151 | + |
| 152 | + return dispatch(0) |
| 153 | + } |
| 154 | +} |
| 155 | + |
| 156 | +/** |
| 157 | + * Import all middleware files from the user's middleware directory |
| 158 | + */ |
| 159 | +export async function importMiddlewares(directory?: string): Promise<Middleware[]> { |
| 160 | + const middlewareDir = directory || userMiddlewarePath() |
| 161 | + const middlewares: Middleware[] = [] |
| 162 | + |
| 163 | + try { |
| 164 | + if (!fs.existsSync(middlewareDir)) { |
| 165 | + log.debug(`Middleware directory not found: ${middlewareDir}`) |
| 166 | + return middlewares |
| 167 | + } |
| 168 | + |
| 169 | + const files = fs.readdirSync(middlewareDir) |
| 170 | + |
| 171 | + for (const file of files) { |
| 172 | + if (!file.endsWith('.ts') && !file.endsWith('.js')) continue |
| 173 | + |
| 174 | + try { |
| 175 | + const filePath = `${middlewareDir}/${file}` |
| 176 | + const imported = await import(filePath) |
| 177 | + |
| 178 | + if (imported.default) { |
| 179 | + if (imported.default instanceof Middleware) { |
| 180 | + middlewares.push(imported.default) |
| 181 | + } |
| 182 | + else if (typeof imported.default === 'object' && imported.default.handle) { |
| 183 | + middlewares.push(new Middleware(imported.default)) |
| 184 | + } |
| 185 | + } |
| 186 | + } |
| 187 | + catch (error) { |
| 188 | + log.error(`Failed to import middleware from ${file}:`, error) |
| 189 | + } |
| 190 | + } |
| 191 | + |
| 192 | + // Sort by priority |
| 193 | + middlewares.sort((a, b) => a.priority - b.priority) |
| 194 | + |
| 195 | + return middlewares |
| 196 | + } |
| 197 | + catch (error) { |
| 198 | + log.error('Failed to import middlewares:', error) |
| 199 | + return middlewares |
| 200 | + } |
32 | 201 | } |
33 | 202 |
|
| 203 | +/** |
| 204 | + * Get all middleware from the user's middleware directory |
| 205 | + * |
| 206 | + * @deprecated Use importMiddlewares() instead |
| 207 | + */ |
34 | 208 | export async function middlewares(): Promise<string[]> { |
35 | | - return await importMiddlewares(userMiddlewarePath()) |
| 209 | + const middlewareDir = userMiddlewarePath() |
| 210 | + |
| 211 | + try { |
| 212 | + if (!fs.existsSync(middlewareDir)) { |
| 213 | + return [] |
| 214 | + } |
| 215 | + |
| 216 | + const files = fs.readdirSync(middlewareDir) |
| 217 | + return files.filter(f => f.endsWith('.ts') || f.endsWith('.js')) |
| 218 | + } |
| 219 | + catch { |
| 220 | + return [] |
| 221 | + } |
| 222 | +} |
| 223 | + |
| 224 | +/** |
| 225 | + * Convert a Stacks middleware to bun-router format |
| 226 | + */ |
| 227 | +export function convertStacksMiddleware( |
| 228 | + stacksMiddleware: MiddlewareOptions, |
| 229 | +): MiddlewareHandler { |
| 230 | + return new Middleware(stacksMiddleware).toBunMiddleware() |
| 231 | +} |
| 232 | + |
| 233 | +// ============================================================================ |
| 234 | +// COMMON MIDDLEWARE FACTORIES |
| 235 | +// ============================================================================ |
| 236 | + |
| 237 | +/** |
| 238 | + * Create a timing middleware that logs request duration |
| 239 | + */ |
| 240 | +export function createTimingMiddleware(name: string = 'request'): MiddlewareHandler { |
| 241 | + return async (req: EnhancedRequest, next: NextFunction): Promise<Response | null> => { |
| 242 | + const start = performance.now() |
| 243 | + const response = await next() |
| 244 | + const duration = performance.now() - start |
| 245 | + |
| 246 | + log.debug(`[${name}] ${req.method} ${new URL(req.url).pathname} - ${duration.toFixed(2)}ms`) |
| 247 | + |
| 248 | + return response |
| 249 | + } |
| 250 | +} |
| 251 | + |
| 252 | +/** |
| 253 | + * Create a request logging middleware |
| 254 | + */ |
| 255 | +export function createLoggingMiddleware(): MiddlewareHandler { |
| 256 | + return async (req: EnhancedRequest, next: NextFunction): Promise<Response | null> => { |
| 257 | + const url = new URL(req.url) |
| 258 | + const start = Date.now() |
| 259 | + |
| 260 | + log.info(`--> ${req.method} ${url.pathname}`) |
| 261 | + |
| 262 | + const response = await next() |
| 263 | + |
| 264 | + const duration = Date.now() - start |
| 265 | + const status = response?.status || 200 |
| 266 | + |
| 267 | + log.info(`<-- ${req.method} ${url.pathname} ${status} ${duration}ms`) |
| 268 | + |
| 269 | + return response |
| 270 | + } |
| 271 | +} |
| 272 | + |
| 273 | +/** |
| 274 | + * Create a maintenance mode middleware |
| 275 | + */ |
| 276 | +export function createMaintenanceMiddleware( |
| 277 | + options: { |
| 278 | + enabled?: boolean | (() => boolean | Promise<boolean>) |
| 279 | + allowedIPs?: string[] |
| 280 | + message?: string |
| 281 | + } = {}, |
| 282 | +): MiddlewareHandler { |
| 283 | + const { enabled = false, allowedIPs = [], message = 'Service temporarily unavailable' } = options |
| 284 | + |
| 285 | + return async (req: EnhancedRequest, next: NextFunction): Promise<Response | null> => { |
| 286 | + const isEnabled = typeof enabled === 'function' ? await enabled() : enabled |
| 287 | + |
| 288 | + if (!isEnabled) { |
| 289 | + return next() |
| 290 | + } |
| 291 | + |
| 292 | + // Check if IP is allowed |
| 293 | + const clientIP = req.headers.get('x-forwarded-for')?.split(',')[0]?.trim() |
| 294 | + || req.headers.get('x-real-ip') |
| 295 | + || '' |
| 296 | + |
| 297 | + if (allowedIPs.includes(clientIP)) { |
| 298 | + return next() |
| 299 | + } |
| 300 | + |
| 301 | + return Response.json( |
| 302 | + { error: message }, |
| 303 | + { status: 503, headers: { 'Retry-After': '3600' } }, |
| 304 | + ) |
| 305 | + } |
36 | 306 | } |
| 307 | + |
| 308 | +// Re-export types |
| 309 | +export type { MiddlewareHandler, NextFunction, EnhancedRequest } |
0 commit comments