1313//===----------------------------------------------------------------------===//
1414
1515#if DEBUG
16+ import DequeModule
1617import Dispatch
1718import Logging
1819import NIOConcurrencyHelpers
@@ -47,24 +48,15 @@ extension Lambda {
4748 /// - note: This API is designed strictly for local testing and is behind a DEBUG flag
4849 static func withLocalServer(
4950 invocationEndpoint: String ? = nil ,
50- _ body: @escaping ( ) async throws -> Void
51+ _ body: sending @escaping ( ) async throws -> Void
5152 ) async throws {
53+ var logger = Logger ( label: " LocalServer " )
54+ logger. logLevel = Lambda . env ( " LOG_LEVEL " ) . flatMap ( Logger . Level. init) ?? . info
5255
53- // launch the local server and wait for it to be started before running the body
54- try await withThrowingTaskGroup ( of: Void . self) { group in
55- // this call will return when the server calls continuation.resume()
56- try await withCheckedThrowingContinuation { ( continuation: CheckedContinuation < Void , any Error > ) in
57- group. addTask {
58- do {
59- try await LambdaHttpServer ( invocationEndpoint: invocationEndpoint) . start (
60- continuation: continuation
61- )
62- } catch {
63- continuation. resume ( throwing: error)
64- }
65- }
66- }
67- // now that server is started, run the Lambda function itself
56+ try await LambdaHTTPServer . withLocalServer (
57+ invocationEndpoint: invocationEndpoint,
58+ logger: logger
59+ ) {
6860 try await body ( )
6961 }
7062 }
@@ -84,34 +76,38 @@ extension Lambda {
8476/// 1. POST /invoke - the client posts the event to the lambda function
8577///
8678/// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client.
87- private struct LambdaHttpServer {
88- private let logger : Logger
89- private let group : EventLoopGroup
90- private let host : String
91- private let port : Int
79+ private struct LambdaHTTPServer {
9280 private let invocationEndpoint : String
9381
9482 private let invocationPool = Pool < LocalServerInvocation > ( )
9583 private let responsePool = Pool < LocalServerResponse > ( )
9684
97- init ( invocationEndpoint: String ? ) {
98- var logger = Logger ( label: " LocalServer " )
99- logger. logLevel = Lambda . env ( " LOG_LEVEL " ) . flatMap ( Logger . Level. init) ?? . info
100- self . logger = logger
101- self . group = MultiThreadedEventLoopGroup . singleton
102- self . host = " 127.0.0.1 "
103- self . port = 7000
85+ private init (
86+ invocationEndpoint: String ?
87+ ) {
10488 self . invocationEndpoint = invocationEndpoint ?? " /invoke "
10589 }
10690
107- func start( continuation: CheckedContinuation < Void , any Error > ) async throws {
108- let channel = try await ServerBootstrap ( group: self . group)
91+ private enum TaskResult < Result: Sendable > : Sendable {
92+ case closureResult( Swift . Result < Result , any Error > )
93+ case serverReturned( Swift . Result < Void , any Error > )
94+ }
95+
96+ static func withLocalServer< Result: Sendable > (
97+ invocationEndpoint: String ? ,
98+ host: String = " 127.0.0.1 " ,
99+ port: Int = 7000 ,
100+ eventLoopGroup: MultiThreadedEventLoopGroup = . singleton,
101+ logger: Logger ,
102+ _ closure: sending @escaping ( ) async throws -> Result
103+ ) async throws -> Result {
104+ let channel = try await ServerBootstrap ( group: eventLoopGroup)
109105 . serverChannelOption ( . backlog, value: 256 )
110106 . serverChannelOption ( . socketOption( . so_reuseaddr) , value: 1 )
111107 . childChannelOption ( . maxMessagesPerRead, value: 1 )
112108 . bind (
113- host: self . host,
114- port: self . port
109+ host: host,
110+ port: port
115111 ) { channel in
116112 channel. eventLoop. makeCompletedFuture {
117113
@@ -129,8 +125,6 @@ private struct LambdaHttpServer {
129125 }
130126 }
131127
132- // notify the caller that the server is started
133- continuation. resume ( )
134128 logger. info (
135129 " Server started and listening " ,
136130 metadata: [
@@ -139,30 +133,76 @@ private struct LambdaHttpServer {
139133 ]
140134 )
141135
136+ let server = LambdaHTTPServer ( invocationEndpoint: invocationEndpoint)
137+
142138 // We are handling each incoming connection in a separate child task. It is important
143139 // to use a discarding task group here which automatically discards finished child tasks.
144140 // A normal task group retains all child tasks and their outputs in memory until they are
145141 // consumed by iterating the group or by exiting the group. Since, we are never consuming
146142 // the results of the group we need the group to automatically discard them; otherwise, this
147143 // would result in a memory leak over time.
148- try await withThrowingDiscardingTaskGroup { group in
149- try await channel. executeThenClose { inbound in
150- for try await connectionChannel in inbound {
151-
152- group. addTask {
153- logger. trace ( " Handling a new connection " )
154- await self . handleConnection ( channel: connectionChannel)
155- logger. trace ( " Done handling the connection " )
144+ let result = await withTaskGroup ( of: TaskResult< Result> . self , returning: Swift . Result < Result , any Error > . self) { group in
145+
146+ let c = closure
147+ group. addTask {
148+ do {
149+
150+ let result = try await c ( )
151+ return . closureResult( . success( result) )
152+ } catch {
153+ return . closureResult( . failure( error) )
154+ }
155+ }
156+
157+ group. addTask {
158+ do {
159+ try await withThrowingDiscardingTaskGroup { taskGroup in
160+ try await channel. executeThenClose { inbound in
161+ for try await connectionChannel in inbound {
162+
163+ taskGroup. addTask {
164+ logger. trace ( " Handling a new connection " )
165+ await server. handleConnection ( channel: connectionChannel, logger: logger)
166+ logger. trace ( " Done handling the connection " )
167+ }
168+ }
169+ }
156170 }
171+ return . serverReturned( . success( ( ) ) )
172+ } catch {
173+ return . serverReturned( . failure( error) )
174+ }
175+ }
176+
177+ let task1 = await group. next ( ) !
178+ group. cancelAll ( )
179+ let task2 = await group. next ( ) !
180+
181+ switch task1 {
182+ case . closureResult( let result) :
183+ return result
184+
185+ case . serverReturned:
186+ switch task2 {
187+ case . closureResult( let result) :
188+ return result
189+
190+ case . serverReturned:
191+ fatalError ( )
157192 }
158193 }
159194 }
195+
160196 logger. info ( " Server shutting down " )
197+ return try result. get ( )
161198 }
162199
200+
201+
163202 /// This method handles individual TCP connections
164203 private func handleConnection(
165- channel: NIOAsyncChannel < HTTPServerRequestPart , HTTPServerResponsePart >
204+ channel: NIOAsyncChannel < HTTPServerRequestPart , HTTPServerResponsePart > ,
205+ logger: Logger
166206 ) async {
167207
168208 var requestHead : HTTPRequestHead !
@@ -186,12 +226,14 @@ private struct LambdaHttpServer {
186226 // process the request
187227 let response = try await self . processRequest (
188228 head: requestHead,
189- body: requestBody
229+ body: requestBody,
230+ logger: logger
190231 )
191232 // send the responses
192233 try await self . sendResponse (
193234 response: response,
194- outbound: outbound
235+ outbound: outbound,
236+ logger: logger
195237 )
196238
197239 requestHead = nil
@@ -214,15 +256,15 @@ private struct LambdaHttpServer {
214256 /// - body: the HTTP request body
215257 /// - Throws:
216258 /// - Returns: the response to send back to the client or the Lambda function
217- private func processRequest( head: HTTPRequestHead , body: ByteBuffer ? ) async throws -> LocalServerResponse {
259+ private func processRequest( head: HTTPRequestHead , body: ByteBuffer ? , logger : Logger ) async throws -> LocalServerResponse {
218260
219261 if let body {
220- self . logger. trace (
262+ logger. trace (
221263 " Processing request " ,
222264 metadata: [ " URI " : " \( head. method) \( head. uri) " , " Body " : " \( String ( buffer: body) ) " ]
223265 )
224266 } else {
225- self . logger. trace ( " Processing request " , metadata: [ " URI " : " \( head. method) \( head. uri) " ] )
267+ logger. trace ( " Processing request " , metadata: [ " URI " : " \( head. method) \( head. uri) " ] )
226268 }
227269
228270 switch ( head. method, head. uri) {
@@ -237,7 +279,9 @@ private struct LambdaHttpServer {
237279 }
238280 // we always accept the /invoke request and push them to the pool
239281 let requestId = " \( DispatchTime . now ( ) . uptimeNanoseconds) "
240- logger. trace ( " /invoke received invocation " , metadata: [ " requestId " : " \( requestId) " ] )
282+ var logger = logger
283+ logger [ metadataKey: " requestID " ] = " \( requestId) "
284+ logger. trace ( " /invoke received invocation " )
241285 await self . invocationPool. push ( LocalServerInvocation ( requestId: requestId, request: body) )
242286
243287 // wait for the lambda function to process the request
@@ -273,9 +317,9 @@ private struct LambdaHttpServer {
273317 case ( . GET, let url) where url. hasSuffix ( Consts . getNextInvocationURLSuffix) :
274318
275319 // pop the tasks from the queue
276- self . logger. trace ( " /next waiting for /invoke " )
320+ logger. trace ( " /next waiting for /invoke " )
277321 for try await invocation in self . invocationPool {
278- self . logger. trace ( " /next retrieved invocation " , metadata: [ " requestId " : " \( invocation. requestId) " ] )
322+ logger. trace ( " /next retrieved invocation " , metadata: [ " requestId " : " \( invocation. requestId) " ] )
279323 // this call also stores the invocation requestId into the response
280324 return invocation. makeResponse ( status: . accepted)
281325 }
@@ -322,12 +366,13 @@ private struct LambdaHttpServer {
322366
323367 private func sendResponse(
324368 response: LocalServerResponse ,
325- outbound: NIOAsyncChannelOutboundWriter < HTTPServerResponsePart >
369+ outbound: NIOAsyncChannelOutboundWriter < HTTPServerResponsePart > ,
370+ logger: Logger
326371 ) async throws {
327372 var headers = HTTPHeaders ( response. headers ?? [ ] )
328373 headers. add ( name: " Content-Length " , value: " \( response. body? . readableBytes ?? 0 ) " )
329374
330- self . logger. trace ( " Writing response " , metadata: [ " requestId " : " \( response. requestId ?? " " ) " ] )
375+ logger. trace ( " Writing response " , metadata: [ " requestId " : " \( response. requestId ?? " " ) " ] )
331376 try await outbound. write (
332377 HTTPServerResponsePart . head (
333378 HTTPResponseHead (
@@ -350,44 +395,67 @@ private struct LambdaHttpServer {
350395 private final class Pool < T> : AsyncSequence , AsyncIteratorProtocol , Sendable where T: Sendable {
351396 typealias Element = T
352397
353- private let _buffer = Mutex < CircularBuffer < T > > ( . init( ) )
354- private let _continuation = Mutex < CheckedContinuation < T , any Error > ? > ( nil )
398+ struct State {
399+ enum State {
400+ case buffer( Deque < T > )
401+ case continuation( CheckedContinuation < T , any Error > ? )
402+ }
355403
356- /// retrieve the first element from the buffer
357- public func popFirst( ) async -> T ? {
358- self . _buffer. withLock { $0. popFirst ( ) }
404+ var state : State
405+
406+ init ( ) {
407+ self . state = . buffer( [ ] )
408+ }
359409 }
360410
411+ private let lock = Mutex < State > ( . init( ) )
412+
361413 /// enqueue an element, or give it back immediately to the iterator if it is waiting for an element
362414 public func push( _ invocation: T ) async {
363415 // if the iterator is waiting for an element, give it to it
364416 // otherwise, enqueue the element
365- if let continuation = self . _continuation. withLock ( { $0 } ) {
366- self . _continuation. withLock { $0 = nil }
367- continuation. resume ( returning: invocation)
368- } else {
369- self . _buffer. withLock { $0. append ( invocation) }
417+ let maybeContinuation = self . lock. withLock { state -> CheckedContinuation < T , any Error > ? in
418+ switch state. state {
419+ case . continuation( let continuation) :
420+ state. state = . buffer( [ ] )
421+ return continuation
422+
423+ case . buffer( var buffer) :
424+ buffer. append ( invocation)
425+ state. state = . buffer( buffer)
426+ return nil
427+ }
370428 }
429+
430+ maybeContinuation? . resume ( returning: invocation)
371431 }
372432
373433 func next( ) async throws -> T ? {
374-
375434 // exit the async for loop if the task is cancelled
376435 guard !Task. isCancelled else {
377436 return nil
378437 }
379438
380- if let element = await self . popFirst ( ) {
381- return element
382- } else {
383- // we can't return nil if there is nothing to dequeue otherwise the async for loop will stop
384- // wait for an element to be enqueued
385- return try await withCheckedThrowingContinuation { ( continuation: CheckedContinuation < T , any Error > ) in
386- // store the continuation for later, when an element is enqueued
387- self . _continuation. withLock {
388- $0 = continuation
439+ return try await withCheckedThrowingContinuation { ( continuation: CheckedContinuation < T , any Error > ) in
440+ let nextAction = self . lock. withLock { state -> T ? in
441+ switch state. state {
442+ case . buffer( var buffer) :
443+ if let first = buffer. popFirst ( ) {
444+ state. state = . buffer( buffer)
445+ return first
446+ } else {
447+ state. state = . continuation( continuation)
448+ return nil
449+ }
450+
451+ case . continuation:
452+ fatalError ( " Concurrent invocations to next(). This is illigal. " )
389453 }
390454 }
455+
456+ guard let nextAction else { return }
457+
458+ continuation. resume ( returning: nextAction)
391459 }
392460 }
393461
0 commit comments