diff --git a/.changes/next-release/feature-AWSSDKforJavav2-439f346.json b/.changes/next-release/feature-AWSSDKforJavav2-439f346.json new file mode 100644 index 000000000000..46ea293d42cc --- /dev/null +++ b/.changes/next-release/feature-AWSSDKforJavav2-439f346.json @@ -0,0 +1,6 @@ +{ + "type": "feature", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "Optimized JSON marshalling performance for JSON RPC and REST JSON protocols." +} diff --git a/build-tools/src/main/resources/software/amazon/awssdk/spotbugs-suppressions.xml b/build-tools/src/main/resources/software/amazon/awssdk/spotbugs-suppressions.xml index 69a7894bfe94..f6704386ca2a 100644 --- a/build-tools/src/main/resources/software/amazon/awssdk/spotbugs-suppressions.xml +++ b/build-tools/src/main/resources/software/amazon/awssdk/spotbugs-suppressions.xml @@ -528,7 +528,10 @@ whose NULL marshallers handle null validation. --> - + + + + diff --git a/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java b/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java index c76aa851d997..39370010ffc9 100644 --- a/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java +++ b/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java @@ -22,17 +22,21 @@ import static software.amazon.awssdk.http.Header.TRANSFER_ENCODING; import java.io.ByteArrayInputStream; +import java.math.BigDecimal; import java.net.URI; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collections; import java.util.EnumMap; +import java.util.List; import java.util.Map; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.core.document.Document; import software.amazon.awssdk.core.protocol.MarshallLocation; +import software.amazon.awssdk.core.protocol.MarshallingKnownType; import software.amazon.awssdk.core.protocol.MarshallingType; import software.amazon.awssdk.core.traits.PayloadTrait; import software.amazon.awssdk.core.traits.RequiredTrait; @@ -214,17 +218,21 @@ void doMarshall(SdkPojo pojo) { } else if (isExplicitPayloadMember(field)) { marshallExplicitJsonPayload(field, val); } else if (val != null) { - marshallField(field, val); + if (field.location() == MarshallLocation.PAYLOAD) { + // HOT PATH: switch-based dispatch, no registry, no interface dispatch + marshallPayloadField(field, val); + } else { + // WARM PATH: cached registry lookup + interface dispatch + marshallFieldViaRegistry(field, val); + } } else if (field.location() != MarshallLocation.PAYLOAD) { - // Null payload fields that aren't required are no-op in the marshaller registry. - // We short circuit to avoid the registry lookup and dispatch overhead. - // Non payload locations (path, header, query) have null marshallers with - // different behavior, so they must still go through marshallField. - marshallField(field, val); + // Null non-payload: must go through registry (null marshallers vary by location) + marshallFieldViaRegistry(field, val); } else if (field.containsTrait(RequiredTrait.class, TraitType.REQUIRED_TRAIT)) { throw new IllegalArgumentException( String.format("Parameter '%s' must not be null", field.locationName())); } + // else: null payload field, not required → no-op } } @@ -312,6 +320,106 @@ private SdkHttpFullRequest finishMarshalling() { return request.build(); } + /** + * Marshalls a PAYLOAD-location field using a switch on {@link MarshallingKnownType} instead of + * registry lookup and interface dispatch. Each case is a monomorphic call site that the JIT can inline. + */ + @SuppressWarnings("unchecked") + private void marshallPayloadField(SdkField field, Object val) { + MarshallingKnownType knownType = field.marshallingType().getKnownType(); + if (knownType == null) { + marshallFieldViaRegistry(field, val); + return; + } + + StructuredJsonGenerator gen = marshallerContext.jsonGenerator(); + String fieldName = field.locationName(); + + switch (knownType) { + case STRING: + gen.writeFieldName(fieldName); + gen.writeValue((String) val); + break; + case INTEGER: + gen.writeFieldName(fieldName); + gen.writeValue((int) (Integer) val); + break; + case LONG: + gen.writeFieldName(fieldName); + gen.writeValue((long) (Long) val); + break; + case SHORT: + gen.writeFieldName(fieldName); + gen.writeValue((short) (Short) val); + break; + case BYTE: + gen.writeFieldName(fieldName); + gen.writeValue((byte) (Byte) val); + break; + case FLOAT: + gen.writeFieldName(fieldName); + gen.writeValue((float) (Float) val); + break; + case DOUBLE: + gen.writeFieldName(fieldName); + gen.writeValue((double) (Double) val); + break; + case BIG_DECIMAL: + gen.writeFieldName(fieldName); + gen.writeValue((BigDecimal) val); + break; + case BOOLEAN: + gen.writeFieldName(fieldName); + gen.writeValue((boolean) (Boolean) val); + break; + case INSTANT: + // Delegate to existing INSTANT marshaller to preserve TimestampFormatTrait handling. + // Note: INSTANT marshaller writes the field name itself. + SimpleTypeJsonMarshaller.INSTANT.marshall((Instant) val, marshallerContext, + fieldName, (SdkField) field); + break; + case SDK_BYTES: + gen.writeFieldName(fieldName); + gen.writeValue(((SdkBytes) val).asByteBuffer()); + break; + case SDK_POJO: + SimpleTypeJsonMarshaller.SDK_POJO.marshall((SdkPojo) val, marshallerContext, + fieldName, (SdkField) field); + break; + case LIST: + SimpleTypeJsonMarshaller.LIST.marshall((List) val, marshallerContext, + fieldName, (SdkField>) field); + break; + case MAP: + SimpleTypeJsonMarshaller.MAP.marshall((Map) val, marshallerContext, + fieldName, (SdkField>) field); + break; + case DOCUMENT: + SimpleTypeJsonMarshaller.DOCUMENT.marshall((Document) val, marshallerContext, + fieldName, (SdkField) field); + break; + default: + // Unknown type — fall back to registry lookup + marshallFieldViaRegistry(field, val); + break; + } + } + + @SuppressWarnings("unchecked") + private void marshallFieldViaRegistry(SdkField field, Object val) { + if (val == null) { + MARSHALLER_REGISTRY.getMarshaller(field.location(), field.marshallingType(), val) + .marshall(val, marshallerContext, field.locationName(), (SdkField) field); + return; + } + JsonMarshaller marshaller = field.cachedMarshaller(MARSHALLER_REGISTRY); + if (marshaller == null) { + marshaller = MARSHALLER_REGISTRY.getMarshaller(field.location(), field.marshallingType(), val); + field.cacheMarshaller(MARSHALLER_REGISTRY, marshaller); + } + marshaller.marshall(val, marshallerContext, field.locationName(), (SdkField) field); + } + private void marshallField(SdkField field, Object val) { MARSHALLER_REGISTRY.getMarshaller(field.location(), field.marshallingType(), val) .marshall(val, marshallerContext, field.locationName(), (SdkField) field); diff --git a/core/protocols/aws-json-protocol/src/test/java/software/amazon/awssdk/protocols/json/internal/marshall/CachedNonPayloadMarshallingTest.java b/core/protocols/aws-json-protocol/src/test/java/software/amazon/awssdk/protocols/json/internal/marshall/CachedNonPayloadMarshallingTest.java new file mode 100644 index 000000000000..2b4e559a518e --- /dev/null +++ b/core/protocols/aws-json-protocol/src/test/java/software/amazon/awssdk/protocols/json/internal/marshall/CachedNonPayloadMarshallingTest.java @@ -0,0 +1,188 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.protocols.json.internal.marshall; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.net.URI; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SdkField; +import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.core.protocol.MarshallLocation; +import software.amazon.awssdk.core.protocol.MarshallingType; +import software.amazon.awssdk.core.traits.LocationTrait; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.protocols.core.OperationInfo; +import software.amazon.awssdk.protocols.core.ProtocolMarshaller; +import software.amazon.awssdk.protocols.json.AwsJsonProtocol; +import software.amazon.awssdk.protocols.json.AwsJsonProtocolMetadata; +import software.amazon.awssdk.protocols.json.internal.AwsStructuredPlainJsonFactory; + +/** + * Tests that the cached non-payload marshalling path in + * {@link JsonProtocolMarshaller#marshallFieldViaRegistry} produces correct output + * and that the cache is populated after the first call. + * + *

Validates: Property 3 — Cached non-payload marshalling equivalence

+ *

Validates: Requirements 7.3, 7.4

+ */ +class CachedNonPayloadMarshallingTest { + + private static final URI ENDPOINT = URI.create("http://localhost"); + private static final String CONTENT_TYPE = "application/x-amz-json-1.0"; + private static final OperationInfo OP_INFO = OperationInfo.builder() + .httpMethod(SdkHttpMethod.POST) + .hasImplicitPayloadMembers(true) + .build(); + private static final AwsJsonProtocolMetadata METADATA = + AwsJsonProtocolMetadata.builder() + .protocol(AwsJsonProtocol.AWS_JSON) + .contentType(CONTENT_TYPE) + .build(); + + // ---- HEADER tests ---- + + @Test + void header_string_producesCorrectHeader() { + SdkField field = headerField("x-custom-header", obj -> "headerValue"); + SdkPojo pojo = new SimplePojo(field); + + SdkHttpFullRequest result = createMarshaller().marshall(pojo); + + assertThat(result.firstMatchingHeader("x-custom-header")) + .isPresent() + .hasValue("headerValue"); + } + + @Test + void header_string_secondCall_usesCachedMarshaller() { + // Use the SAME SdkField instance for both calls so the cache is shared + SdkField field = headerField("x-custom-header", obj -> "headerValue"); + + // First call — populates the cache + SdkPojo pojo1 = new SimplePojo(field); + SdkHttpFullRequest result1 = createMarshaller().marshall(pojo1); + + // After first marshalling, the cache should be populated on the SdkField. + // We can't access the exact registry key, but we can verify the field has + // a non-null cached marshaller by checking that a second marshalling produces + // identical output. + + // Second call — should use cached marshaller + SdkPojo pojo2 = new SimplePojo(field); + SdkHttpFullRequest result2 = createMarshaller().marshall(pojo2); + + // Both calls produce identical header output + assertThat(result1.firstMatchingHeader("x-custom-header")) + .isPresent() + .hasValue("headerValue"); + assertThat(result2.firstMatchingHeader("x-custom-header")) + .isPresent() + .hasValue("headerValue"); + + // Verify the cache was populated: the field should have a non-null cached + // marshaller for at least one registry key. Since we can't access the private + // MARSHALLER_REGISTRY, we verify indirectly: the field's cachedMarshaller + // with a dummy key returns null (different key), but the fact that both calls + // succeeded with identical output confirms the cached path works. + Object cachedWithDifferentKey = field.cachedMarshaller(new Object()); + assertThat(cachedWithDifferentKey) + .as("Different registry key should return null") + .isNull(); + } + + // ---- QUERY_PARAM tests ---- + + @Test + void queryParam_string_producesCorrectQueryParam() { + SdkField field = queryParamField("myParam", obj -> "paramValue"); + SdkPojo pojo = new SimplePojo(field); + + SdkHttpFullRequest result = createMarshaller().marshall(pojo); + + assertThat(result.rawQueryParameters().get("myParam")) + .isNotNull() + .containsExactly("paramValue"); + } + + // ---- Helper methods ---- + + private static SdkField headerField(String headerName, + java.util.function.Function getter) { + return SdkField.builder(MarshallingType.STRING) + .memberName(headerName) + .getter(getter) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.HEADER) + .locationName(headerName) + .build()) + .build(); + } + + private static SdkField queryParamField(String paramName, + java.util.function.Function getter) { + return SdkField.builder(MarshallingType.STRING) + .memberName(paramName) + .getter(getter) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.QUERY_PARAM) + .locationName(paramName) + .build()) + .build(); + } + + private static ProtocolMarshaller createMarshaller() { + return JsonProtocolMarshallerBuilder.create() + .endpoint(ENDPOINT) + .jsonGenerator(AwsStructuredPlainJsonFactory + .SDK_JSON_FACTORY.createWriter(CONTENT_TYPE)) + .contentType(CONTENT_TYPE) + .operationInfo(OP_INFO) + .sendExplicitNullForPayload(false) + .protocolMetadata(METADATA) + .build(); + } + + private static final class SimplePojo implements SdkPojo { + private final List> fields; + + SimplePojo(SdkField... fields) { + this.fields = Arrays.asList(fields); + } + + @Override + public List> sdkFields() { + return fields; + } + + @Override + public boolean equalsBySdkFields(Object other) { + return other instanceof SimplePojo; + } + + @Override + public Map> sdkFieldNameToField() { + return Collections.emptyMap(); + } + } +} diff --git a/core/protocols/aws-json-protocol/src/test/java/software/amazon/awssdk/protocols/json/internal/marshall/PayloadMarshallingEquivalenceTest.java b/core/protocols/aws-json-protocol/src/test/java/software/amazon/awssdk/protocols/json/internal/marshall/PayloadMarshallingEquivalenceTest.java new file mode 100644 index 000000000000..a83192056584 --- /dev/null +++ b/core/protocols/aws-json-protocol/src/test/java/software/amazon/awssdk/protocols/json/internal/marshall/PayloadMarshallingEquivalenceTest.java @@ -0,0 +1,580 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.protocols.json.internal.marshall; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.math.BigDecimal; +import java.net.URI; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.SdkField; +import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.core.document.Document; +import software.amazon.awssdk.core.protocol.MarshallLocation; +import software.amazon.awssdk.core.protocol.MarshallingType; +import software.amazon.awssdk.core.traits.ListTrait; +import software.amazon.awssdk.core.traits.LocationTrait; +import software.amazon.awssdk.core.traits.MapTrait; +import software.amazon.awssdk.core.traits.TimestampFormatTrait; +import software.amazon.awssdk.core.util.DefaultSdkAutoConstructList; +import software.amazon.awssdk.core.util.DefaultSdkAutoConstructMap; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.protocols.core.OperationInfo; +import software.amazon.awssdk.protocols.core.ProtocolMarshaller; +import software.amazon.awssdk.protocols.json.AwsJsonProtocol; +import software.amazon.awssdk.protocols.json.AwsJsonProtocolMetadata; +import software.amazon.awssdk.protocols.json.internal.AwsStructuredPlainJsonFactory; + +/** + * Tests that the switch-based payload dispatch in {@link JsonProtocolMarshaller#marshallPayloadField} + * produces correct JSON output for all 16 {@code MarshallingKnownType} values. + * + *

Validates: Property 1 — Payload marshalling behavioral equivalence

+ *

Validates: Requirements 2.1–2.12, 3.1–3.5, 4.1, 5.1–5.3, 6.1–6.4

+ */ +class PayloadMarshallingEquivalenceTest { + + private static final URI ENDPOINT = URI.create("http://localhost"); + private static final String CONTENT_TYPE = "application/x-amz-json-1.0"; + private static final OperationInfo OP_INFO = OperationInfo.builder() + .httpMethod(SdkHttpMethod.POST) + .hasImplicitPayloadMembers(true) + .build(); + private static final AwsJsonProtocolMetadata METADATA = + AwsJsonProtocolMetadata.builder() + .protocol(AwsJsonProtocol.AWS_JSON) + .contentType(CONTENT_TYPE) + .build(); + + // ---- STRING ---- + + @Test + void string_producesCorrectJson() { + SdkField field = payloadField("fieldName", MarshallingType.STRING, obj -> "hello world"); + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":\"hello world\""); + } + + // ---- INTEGER ---- + + @Test + void integer_producesCorrectJson() { + SdkField field = payloadField("fieldName", MarshallingType.INTEGER, obj -> 42); + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":42"); + } + + // ---- LONG ---- + + @Test + void long_producesCorrectJson() { + SdkField field = payloadField("fieldName", MarshallingType.LONG, obj -> 123456789L); + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":123456789"); + } + + // ---- SHORT ---- + + @Test + void short_producesCorrectJson() { + SdkField field = payloadField("fieldName", MarshallingType.SHORT, obj -> (short) 7); + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":7"); + } + + // ---- BYTE ---- + + @Test + void byte_producesCorrectJson() { + SdkField field = payloadField("fieldName", MarshallingType.BYTE, obj -> (byte) 3); + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":3"); + } + + // ---- FLOAT ---- + + @Test + void float_producesCorrectJson() { + SdkField field = payloadField("fieldName", MarshallingType.FLOAT, obj -> 1.5f); + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":1.5"); + } + + // ---- DOUBLE ---- + + @Test + void double_producesCorrectJson() { + SdkField field = payloadField("fieldName", MarshallingType.DOUBLE, obj -> 3.14); + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":3.14"); + } + + // ---- BIG_DECIMAL ---- + + @Test + void bigDecimal_producesCorrectJson() { + SdkField field = payloadField("fieldName", MarshallingType.BIG_DECIMAL, + obj -> new BigDecimal("99.99")); + String body = marshallAndGetBody(field); + // BigDecimal is serialized as a quoted string by the JSON generator + assertThat(body).contains("\"fieldName\":\"99.99\""); + } + + // ---- BOOLEAN ---- + + @Test + void boolean_producesCorrectJson() { + SdkField field = payloadField("fieldName", MarshallingType.BOOLEAN, obj -> true); + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":true"); + } + + // ---- INSTANT (default format — UNIX_TIMESTAMP for PAYLOAD) ---- + + @Test + void instant_defaultFormat_producesUnixTimestamp() { + SdkField field = payloadField("fieldName", MarshallingType.INSTANT, + obj -> Instant.ofEpochSecond(1000)); + String body = marshallAndGetBody(field); + // Default PAYLOAD format is UNIX_TIMESTAMP — written via jsonGenerator.writeValue(Instant) + // which for plain JSON writes epoch seconds (e.g. 1000.0 or 1000) + assertThat(body).contains("\"fieldName\":"); + assertThat(body).contains("1000"); + } + + // ---- INSTANT with UNIX_TIMESTAMP trait ---- + + @Test + void instant_unixTimestampTrait_producesUnixTimestamp() { + SdkField field = SdkField.builder(MarshallingType.INSTANT) + .memberName("fieldName") + .getter(obj -> Instant.ofEpochSecond(1000)) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("fieldName") + .build(), + TimestampFormatTrait.create(TimestampFormatTrait.Format.UNIX_TIMESTAMP)) + .build(); + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":"); + assertThat(body).contains("1000"); + } + + // ---- INSTANT with RFC_822 trait ---- + + @Test + void instant_rfc822Trait_producesRfc822String() { + SdkField field = SdkField.builder(MarshallingType.INSTANT) + .memberName("fieldName") + .getter(obj -> Instant.ofEpochSecond(1000)) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("fieldName") + .build(), + TimestampFormatTrait.create(TimestampFormatTrait.Format.RFC_822)) + .build(); + String body = marshallAndGetBody(field); + // RFC 822 format: e.g. "Thu, 01 Jan 1970 00:16:40 GMT" + assertThat(body).contains("\"fieldName\":\""); + assertThat(body).contains("1970"); + } + + // ---- INSTANT with ISO_8601 trait ---- + + @Test + void instant_iso8601Trait_producesIso8601String() { + SdkField field = SdkField.builder(MarshallingType.INSTANT) + .memberName("fieldName") + .getter(obj -> Instant.ofEpochSecond(1000)) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("fieldName") + .build(), + TimestampFormatTrait.create(TimestampFormatTrait.Format.ISO_8601)) + .build(); + String body = marshallAndGetBody(field); + // ISO 8601 format: e.g. "1970-01-01T00:16:40Z" + assertThat(body).contains("\"fieldName\":\""); + assertThat(body).contains("1970-01-01T"); + } + + // ---- SDK_BYTES ---- + + @Test + void sdkBytes_producesBase64EncodedJson() { + SdkField field = payloadField("fieldName", MarshallingType.SDK_BYTES, + obj -> SdkBytes.fromUtf8String("data")); + String body = marshallAndGetBody(field); + // "data" base64 encoded is "ZGF0YQ==" + assertThat(body).contains("\"fieldName\":\"ZGF0YQ==\""); + } + + // ---- SDK_POJO (nested) ---- + + @Test + void sdkPojo_producesNestedObjectJson() { + // Inner pojo with a single string field + SdkField innerField = payloadField("innerField", MarshallingType.STRING, obj -> "innerValue"); + SimplePojo innerPojo = new SimplePojo(innerField); + + SdkField outerField = SdkField.builder(MarshallingType.SDK_POJO) + .memberName("fieldName") + .getter(obj -> innerPojo) + .setter((obj, val) -> { }) + .constructor(() -> innerPojo) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("fieldName") + .build()) + .build(); + + String body = marshallAndGetBody(outerField); + assertThat(body).contains("\"fieldName\":{\"innerField\":\"innerValue\"}"); + } + + // ---- LIST (non-empty) ---- + + @Test + void list_nonEmpty_producesArrayJson() { + List listValue = Arrays.asList("a", "b", "c"); + + SdkField memberField = SdkField.builder(MarshallingType.STRING) + .memberName("member") + .getter(obj -> null) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("member") + .build()) + .build(); + + SdkField> field = SdkField.>builder(MarshallingType.LIST) + .memberName("fieldName") + .getter(obj -> listValue) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("fieldName") + .build(), + ListTrait.builder() + .memberFieldInfo(memberField) + .build()) + .build(); + + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":[\"a\",\"b\",\"c\"]"); + } + + // ---- LIST (empty SdkAutoConstructList — should be skipped) ---- + + @Test + void list_emptySdkAutoConstructList_isSkipped() { + List autoList = DefaultSdkAutoConstructList.getInstance(); + + SdkField memberField = SdkField.builder(MarshallingType.STRING) + .memberName("member") + .getter(obj -> null) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("member") + .build()) + .build(); + + SdkField> field = SdkField.>builder(MarshallingType.LIST) + .memberName("fieldName") + .getter(obj -> autoList) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("fieldName") + .build(), + ListTrait.builder() + .memberFieldInfo(memberField) + .build()) + .build(); + + String body = marshallAndGetBody(field); + assertThat(body).doesNotContain("fieldName"); + } + + // ---- LIST (empty regular list — should emit empty array) ---- + + @Test + void list_emptyRegularList_producesEmptyArray() { + List emptyList = new ArrayList<>(); + + SdkField memberField = SdkField.builder(MarshallingType.STRING) + .memberName("member") + .getter(obj -> null) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("member") + .build()) + .build(); + + SdkField> field = SdkField.>builder(MarshallingType.LIST) + .memberName("fieldName") + .getter(obj -> emptyList) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("fieldName") + .build(), + ListTrait.builder() + .memberFieldInfo(memberField) + .build()) + .build(); + + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":[]"); + } + + // ---- MAP (non-empty) ---- + + @Test + void map_nonEmpty_producesObjectJson() { + // Use LinkedHashMap for deterministic ordering + Map mapValue = new LinkedHashMap<>(); + mapValue.put("key1", "val1"); + mapValue.put("key2", "val2"); + + SdkField valueField = SdkField.builder(MarshallingType.STRING) + .memberName("value") + .getter(obj -> null) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("value") + .build()) + .build(); + + SdkField> field = SdkField.>builder(MarshallingType.MAP) + .memberName("fieldName") + .getter(obj -> mapValue) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("fieldName") + .build(), + MapTrait.builder() + .valueFieldInfo(valueField) + .build()) + .build(); + + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":{\"key1\":\"val1\",\"key2\":\"val2\"}"); + } + + // ---- MAP (empty SdkAutoConstructMap — should be skipped) ---- + + @Test + void map_emptySdkAutoConstructMap_isSkipped() { + Map autoMap = DefaultSdkAutoConstructMap.getInstance(); + + SdkField valueField = SdkField.builder(MarshallingType.STRING) + .memberName("value") + .getter(obj -> null) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("value") + .build()) + .build(); + + SdkField> field = SdkField.>builder(MarshallingType.MAP) + .memberName("fieldName") + .getter(obj -> autoMap) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("fieldName") + .build(), + MapTrait.builder() + .valueFieldInfo(valueField) + .build()) + .build(); + + String body = marshallAndGetBody(field); + assertThat(body).doesNotContain("fieldName"); + } + + // ---- MAP (empty regular map — should emit empty object) ---- + + @Test + void map_emptyRegularMap_producesEmptyObject() { + Map emptyMap = new HashMap<>(); + + SdkField valueField = SdkField.builder(MarshallingType.STRING) + .memberName("value") + .getter(obj -> null) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("value") + .build()) + .build(); + + SdkField> field = SdkField.>builder(MarshallingType.MAP) + .memberName("fieldName") + .getter(obj -> emptyMap) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("fieldName") + .build(), + MapTrait.builder() + .valueFieldInfo(valueField) + .build()) + .build(); + + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":{}"); + } + + // ---- MAP with null value entry — entry is skipped ---- + + @Test + void map_nullValueEntry_isSkipped() { + Map mapValue = new LinkedHashMap<>(); + mapValue.put("key1", "val1"); + mapValue.put("key2", null); + mapValue.put("key3", "val3"); + + SdkField valueField = SdkField.builder(MarshallingType.STRING) + .memberName("value") + .getter(obj -> null) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("value") + .build()) + .build(); + + SdkField> field = SdkField.>builder(MarshallingType.MAP) + .memberName("fieldName") + .getter(obj -> mapValue) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("fieldName") + .build(), + MapTrait.builder() + .valueFieldInfo(valueField) + .build()) + .build(); + + String body = marshallAndGetBody(field); + assertThat(body).contains("\"key1\":\"val1\""); + assertThat(body).doesNotContain("key2"); + assertThat(body).contains("\"key3\":\"val3\""); + } + + // ---- DOCUMENT ---- + + @Test + void document_producesCorrectJson() { + SdkField field = payloadField("fieldName", MarshallingType.DOCUMENT, + obj -> Document.fromString("test")); + String body = marshallAndGetBody(field); + assertThat(body).contains("\"fieldName\":\"test\""); + } + + // ---- Helper methods ---- + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static SdkField payloadField(String name, + MarshallingType marshallingType, + Function getter) { + return (SdkField) SdkField.builder(marshallingType) + .memberName(name) + .getter((Function) getter) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName(name) + .build()) + .build(); + } + + private String marshallAndGetBody(SdkField... fields) { + SdkPojo pojo = new SimplePojo(fields); + SdkHttpFullRequest result = createMarshaller().marshall(pojo); + return bodyAsString(result); + } + + private static ProtocolMarshaller createMarshaller() { + return JsonProtocolMarshallerBuilder.create() + .endpoint(ENDPOINT) + .jsonGenerator(AwsStructuredPlainJsonFactory + .SDK_JSON_FACTORY.createWriter(CONTENT_TYPE)) + .contentType(CONTENT_TYPE) + .operationInfo(OP_INFO) + .sendExplicitNullForPayload(false) + .protocolMetadata(METADATA) + .build(); + } + + private static String bodyAsString(SdkHttpFullRequest request) { + return request.contentStreamProvider() + .map(p -> { + try { + return software.amazon.awssdk.utils.IoUtils.toUtf8String(p.newStream()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }) + .orElse(""); + } + + private static final class SimplePojo implements SdkPojo { + private final List> fields; + + SimplePojo(SdkField... fields) { + this.fields = Arrays.asList(fields); + } + + @Override + public List> sdkFields() { + return fields; + } + + @Override + public boolean equalsBySdkFields(Object other) { + return other instanceof SimplePojo; + } + + @Override + public Map> sdkFieldNameToField() { + return Collections.emptyMap(); + } + } +} diff --git a/core/protocols/aws-json-protocol/src/test/java/software/amazon/awssdk/protocols/json/internal/marshall/UnknownMarshallingKnownTypeFallbackTest.java b/core/protocols/aws-json-protocol/src/test/java/software/amazon/awssdk/protocols/json/internal/marshall/UnknownMarshallingKnownTypeFallbackTest.java new file mode 100644 index 000000000000..6886452c2dc1 --- /dev/null +++ b/core/protocols/aws-json-protocol/src/test/java/software/amazon/awssdk/protocols/json/internal/marshall/UnknownMarshallingKnownTypeFallbackTest.java @@ -0,0 +1,202 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.protocols.json.internal.marshall; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.net.URI; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SdkField; +import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.core.protocol.MarshallLocation; +import software.amazon.awssdk.core.protocol.MarshallingKnownType; +import software.amazon.awssdk.core.protocol.MarshallingType; +import software.amazon.awssdk.core.traits.LocationTrait; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.protocols.core.OperationInfo; +import software.amazon.awssdk.protocols.core.ProtocolMarshaller; +import software.amazon.awssdk.protocols.json.AwsJsonProtocol; +import software.amazon.awssdk.protocols.json.AwsJsonProtocolMetadata; +import software.amazon.awssdk.protocols.json.internal.AwsStructuredPlainJsonFactory; + +/** + * Tests that when {@code getKnownType()} returns null, the marshaller falls back to the + * registry-based path without throwing a {@link NullPointerException} from the switch statement. + * + *

Validates: Requirements 1.3, 1.4

+ */ +class UnknownMarshallingKnownTypeFallbackTest { + + private static final URI ENDPOINT = URI.create("http://localhost"); + private static final String CONTENT_TYPE = "application/x-amz-json-1.0"; + private static final OperationInfo OP_INFO = OperationInfo.builder() + .httpMethod(SdkHttpMethod.POST) + .hasImplicitPayloadMembers(true) + .build(); + private static final AwsJsonProtocolMetadata METADATA = + AwsJsonProtocolMetadata.builder() + .protocol(AwsJsonProtocol.AWS_JSON) + .contentType(CONTENT_TYPE) + .build(); + + /** + * A custom MarshallingType whose {@code getKnownType()} returns null. + * This simulates a future or third-party MarshallingType that is not in the known enum set. + */ + private static final MarshallingType CUSTOM_NULL_KNOWN_TYPE = new MarshallingType() { + @Override + public Class getTargetClass() { + return String.class; + } + + @Override + public MarshallingKnownType getKnownType() { + return null; + } + + @Override + public String toString() { + return "CUSTOM_NULL_KNOWN_TYPE"; + } + }; + + /** + * Validates Requirement 1.4: When {@code getKnownType()} returns null, the marshaller falls back + * to the registry-based path without throwing a NullPointerException from the switch statement. + * + *

Since the custom type is not registered in the static MARSHALLER_REGISTRY, the registry + * fallback will fail — but the failure must NOT be a NullPointerException from the switch. + * It should be a NullPointerException from invoking {@code .marshall()} on the null result + * returned by the registry lookup (since the custom type is unregistered).

+ */ + @Test + void nullKnownType_fallsBackToRegistryPath_doesNotThrowNpeFromSwitch() { + SdkField field = SdkField.builder(CUSTOM_NULL_KNOWN_TYPE) + .memberName("customField") + .getter(obj -> "someValue") + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("customField") + .build()) + .build(); + + SdkPojo pojo = new SimplePojo(field); + + // The null-knownType guard in marshallPayloadField should redirect to marshallFieldViaRegistry. + // Since CUSTOM_NULL_KNOWN_TYPE is not registered in the static MARSHALLER_REGISTRY, + // the registry returns null and a NullPointerException occurs when invoking .marshall() on it. + // The critical assertion: the NPE stack trace must NOT originate from the switch statement + // in marshallPayloadField — it must come from the registry fallback path. + assertThatThrownBy(() -> createMarshaller().marshall(pojo)) + .isInstanceOf(NullPointerException.class) + .satisfies(thrown -> { + // Verify the NPE comes from marshallFieldViaRegistry (the fallback), + // not from marshallPayloadField's switch statement + StackTraceElement[] stack = thrown.getStackTrace(); + boolean fromRegistryPath = false; + for (StackTraceElement element : stack) { + if ("marshallFieldViaRegistry".equals(element.getMethodName())) { + fromRegistryPath = true; + break; + } + } + assertThat(fromRegistryPath) + .as("NPE should originate from marshallFieldViaRegistry (registry fallback), " + + "not from the switch in marshallPayloadField") + .isTrue(); + }); + } + + /** + * Validates Requirement 1.3: A standard MarshallingType (STRING) with a known type is handled + * by the switch path, confirming the switch dispatch works for recognized types. + * This serves as a control test — if the switch were broken, this would fail too. + */ + @Test + void knownType_string_isHandledBySwitchPath() { + SdkField field = SdkField.builder(MarshallingType.STRING) + .memberName("normalField") + .getter(obj -> "hello") + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("normalField") + .build()) + .build(); + + SdkPojo pojo = new SimplePojo(field); + + SdkHttpFullRequest result = createMarshaller().marshall(pojo); + String body = bodyAsString(result); + assertThat(body).contains("\"normalField\":\"hello\""); + } + + // ---- Helper methods ---- + + private static ProtocolMarshaller createMarshaller() { + return JsonProtocolMarshallerBuilder.create() + .endpoint(ENDPOINT) + .jsonGenerator(AwsStructuredPlainJsonFactory + .SDK_JSON_FACTORY.createWriter(CONTENT_TYPE)) + .contentType(CONTENT_TYPE) + .operationInfo(OP_INFO) + .sendExplicitNullForPayload(false) + .protocolMetadata(METADATA) + .build(); + } + + private static String bodyAsString(SdkHttpFullRequest request) { + return request.contentStreamProvider() + .map(p -> { + try { + return software.amazon.awssdk.utils.IoUtils.toUtf8String(p.newStream()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }) + .orElse(""); + } + + private static final class SimplePojo implements SdkPojo { + private final List> fields; + + SimplePojo(SdkField... fields) { + this.fields = Arrays.asList(fields); + } + + @Override + public List> sdkFields() { + return fields; + } + + @Override + public boolean equalsBySdkFields(Object other) { + return other instanceof SimplePojo; + } + + @Override + public Map> sdkFieldNameToField() { + return Collections.emptyMap(); + } + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/SdkField.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/SdkField.java index 98561baca4ac..730144e62363 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/SdkField.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/SdkField.java @@ -50,6 +50,15 @@ public final class SdkField { private final Map l1Traits; private final Map, Trait> l2Traits; + // Single-slot marshaller cache. Two volatile fields are used instead of an AtomicReference to an immutable + // holder to avoid per-SdkField object allocation. The read in cachedMarshaller() is not atomic across both + // fields: between reading the key and reading the marshaller, another thread could overwrite both. This is + // safe because (1) in practice there is only one registry per protocol, so all threads converge to the same + // marshaller, and (2) the worst case with multiple registries is a benign cache miss or a single call using + // a marshaller from a different registry, which self-corrects on the next call. + private volatile Object cachedMarshaller; + private volatile Object cachedMarshallerRegistryKey; + private SdkField(Builder builder) { this.memberName = builder.memberName; this.marshallingType = builder.marshallingType; @@ -253,6 +262,33 @@ public boolean containsTrait(Class clzz, TraitType type) { return getTrait(clzz, type) != null; } + /** + * Returns the cached marshaller for the given registry key, or null if not cached. + * Uses reference identity ({@code ==}) for the registry key comparison. + * + * @param registryKey The registry key to match against the cached key. + * @param The type of the cached marshaller. + * @return The cached marshaller if the registry key matches, or null. + */ + @SuppressWarnings("unchecked") + public T cachedMarshaller(Object registryKey) { + if (cachedMarshallerRegistryKey == registryKey) { + return (T) cachedMarshaller; + } + return null; + } + + /** + * Caches the resolved marshaller for the given registry key. + * + * @param registryKey The registry key to associate with the cached marshaller. + * @param marshaller The marshaller instance to cache. + */ + public void cacheMarshaller(Object registryKey, Object marshaller) { + this.cachedMarshaller = marshaller; + this.cachedMarshallerRegistryKey = registryKey; + } + /** * Retrieves the current value of 'this' field from the given POJO. Uses the getter passed into the {@link Builder}. * diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/SdkFieldCacheMarshallerTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/SdkFieldCacheMarshallerTest.java new file mode 100644 index 000000000000..e266b3363814 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/SdkFieldCacheMarshallerTest.java @@ -0,0 +1,117 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.protocol.MarshallLocation; +import software.amazon.awssdk.core.protocol.MarshallingType; +import software.amazon.awssdk.core.traits.LocationTrait; + +/** + * Tests for the marshaller cache on {@link SdkField}. + * + *

Validates: Requirements 7.1, 7.2

+ *

Property 2: Marshaller cache round-trip

+ */ +public class SdkFieldCacheMarshallerTest { + + private static SdkField newStringField() { + return SdkField.builder(MarshallingType.STRING) + .memberName("testField") + .getter(obj -> null) + .setter((obj, val) -> { }) + .traits(LocationTrait.builder() + .location(MarshallLocation.PAYLOAD) + .locationName("testField") + .build()) + .build(); + } + + /** + * cachedMarshaller returns null when nothing has been cached yet. + */ + @Test + public void cachedMarshaller_beforeAnyCaching_returnsNull() { + SdkField field = newStringField(); + Object registryKey = new Object(); + + Object cached = field.cachedMarshaller(registryKey); + assertThat(cached).isNull(); + } + + /** + * Round-trip: cacheMarshaller(key, m) then cachedMarshaller(key) returns the same instance. + */ + @Test + public void cachedMarshaller_afterCaching_returnsSameInstance() { + SdkField field = newStringField(); + Object registryKey = new Object(); + Object marshaller = new Object(); + + field.cacheMarshaller(registryKey, marshaller); + + Object cached = field.cachedMarshaller(registryKey); + assertThat(cached).isSameAs(marshaller); + } + + /** + * A different registry key reference returns null, even if both keys are "equal" by value. + * The cache uses reference identity (==), not equals(). + */ + @Test + public void cachedMarshaller_differentKeyReference_returnsNull() { + SdkField field = newStringField(); + // Use strings constructed so they are .equals() but not == + String key1 = new String("registry"); + String key2 = new String("registry"); + Object marshaller = new Object(); + + field.cacheMarshaller(key1, marshaller); + + // key2.equals(key1) is true, but key2 != key1 + Object cached = field.cachedMarshaller(key2); + assertThat(cached).isNull(); + } + + /** + * Overwriting the cache with a new registry key replaces the old entry. + * The old key no longer returns the old marshaller (single-slot replacement). + */ + @Test + public void cacheMarshaller_overwrite_replacesOldEntry() { + SdkField field = newStringField(); + Object oldKey = new Object(); + Object oldMarshaller = new Object(); + Object newKey = new Object(); + Object newMarshaller = new Object(); + + field.cacheMarshaller(oldKey, oldMarshaller); + Object cachedOld = field.cachedMarshaller(oldKey); + assertThat(cachedOld).isSameAs(oldMarshaller); + + // Overwrite with a new key + field.cacheMarshaller(newKey, newMarshaller); + + // New key returns the new marshaller + Object cachedNew = field.cachedMarshaller(newKey); + assertThat(cachedNew).isSameAs(newMarshaller); + // Old key no longer returns anything + Object cachedOldAfter = field.cachedMarshaller(oldKey); + assertThat(cachedOldAfter).isNull(); + } +}