diff --git a/lambda/invoke_loop.go b/lambda/invoke_loop.go index 81402f5f..527f46fe 100644 --- a/lambda/invoke_loop.go +++ b/lambda/invoke_loop.go @@ -53,6 +53,7 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error { lc := lambdacontext.LambdaContext{ AwsRequestID: invoke.id, InvokedFunctionArn: invoke.headers.Get(headerInvokedFunctionARN), + TenantID: invoke.headers.Get(headerTenantID), } if err := parseClientContext(invoke, &lc.ClientContext); err != nil { return reportFailure(invoke, lambdaErrorResponse(err)) diff --git a/lambda/invoke_loop_test.go b/lambda/invoke_loop_test.go index 1f4dbd18..7c1870a8 100644 --- a/lambda/invoke_loop_test.go +++ b/lambda/invoke_loop_test.go @@ -212,14 +212,16 @@ func TestRuntimeAPIContextPlumbing(t *testing.T) { }, nil }) - ts, record := runtimeAPIServer(``, 1) + metadata2 := defaultInvokeMetadata() + metadata2.tenantID = "some-tenant-id" + ts, record := runtimeAPIServer(``, 2, defaultInvokeMetadata(), metadata2) defer ts.Close() endpoint := strings.Split(ts.URL, "://")[1] expectedError := fmt.Sprintf("failed to GET http://%s/2018-06-01/runtime/invocation/next: got unexpected status code: 410", endpoint) assert.EqualError(t, startRuntimeAPILoop(endpoint, handler), expectedError) - expected := ` + expected1 := ` { "Context": { "AwsRequestID": "dummyid", @@ -244,7 +246,35 @@ func TestRuntimeAPIContextPlumbing(t *testing.T) { "Deadline": 22 } ` - assert.JSONEq(t, expected, string(record.responses[0])) + expected2 := ` + { + "Context": { + "AwsRequestID": "dummyid", + "InvokedFunctionArn": "dummyarn", + "TenantID": "some-tenant-id", + "Identity": { + "CognitoIdentityID": "dummyident", + "CognitoIdentityPoolID": "dummypool" + }, + "ClientContext": { + "Client": { + "installation_id": "dummyinstallid", + "app_title": "dummytitle", + "app_version_code": "dummycode", + "app_package_name": "dummyname" + }, + "env": null, + "custom": null + } + }, + "TraceID": "its-xray-time", + "EnvTraceID": "its-xray-time", + "Deadline": 22 + } + ` + + assert.JSONEq(t, expected1, string(record.responses[0])) + assert.JSONEq(t, expected2, string(record.responses[1])) } func TestReadPayload(t *testing.T) { @@ -387,6 +417,7 @@ type eventMetadata struct { deadline string requestID string functionARN string + tenantID string } func defaultInvokeMetadata() eventMetadata { @@ -440,6 +471,9 @@ func runtimeAPIServer(eventPayload string, failAfter int, overrides ...eventMeta w.Header().Add(string(headerClientContext), metadata.clientContext) w.Header().Add(string(headerCognitoIdentity), metadata.cognito) w.Header().Add(string(headerTraceID), metadata.xray) + if metadata.tenantID != "" { + w.Header().Add(string(headerTenantID), metadata.tenantID) + } w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(eventPayload)) case http.MethodPost: diff --git a/lambda/runtime_api_client.go b/lambda/runtime_api_client.go index 1d268cc6..0fa12b4f 100644 --- a/lambda/runtime_api_client.go +++ b/lambda/runtime_api_client.go @@ -24,6 +24,7 @@ const ( headerCognitoIdentity = "Lambda-Runtime-Cognito-Identity" headerClientContext = "Lambda-Runtime-Client-Context" headerInvokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" + headerTenantID = "Lambda-Runtime-Aws-Tenant-Id" headerXRayErrorCause = "Lambda-Runtime-Function-Xray-Error-Cause" trailerLambdaErrorType = "Lambda-Runtime-Function-Error-Type" trailerLambdaErrorBody = "Lambda-Runtime-Function-Error-Body" diff --git a/lambdacontext/context.go b/lambdacontext/context.go index 658d870c..7b9d9c65 100644 --- a/lambdacontext/context.go +++ b/lambdacontext/context.go @@ -77,6 +77,7 @@ type LambdaContext struct { InvokedFunctionArn string //nolint: stylecheck Identity CognitoIdentity ClientContext ClientContext + TenantID string `json:",omitempty"` } // An unexported type to be used as the key for types in this package.