diff --git a/.github/workflows/build-and-deploy-snapshot.yml b/.github/workflows/build-and-deploy-snapshot.yml index eb469084a433..883abb17f982 100644 --- a/.github/workflows/build-and-deploy-snapshot.yml +++ b/.github/workflows/build-and-deploy-snapshot.yml @@ -2,7 +2,7 @@ name: Build and Deploy Snapshot on: push: branches: - - 7.0.x + - main concurrency: group: ${{ github.workflow }}-${{ github.ref }} jobs: @@ -27,7 +27,7 @@ jobs: /**/framework-api-*.zip::zip.name=spring-framework,zip.deployed=false /**/framework-api-*-docs.zip::zip.type=docs /**/framework-api-*-schema.zip::zip.type=schema - build-name: 'spring-framework-7.0.x' + build-name: 'spring-framework-7.1.x' folder: 'deployment-repository' password: ${{ secrets.ARTIFACTORY_PASSWORD }} repository: 'libs-snapshot-local' diff --git a/.github/workflows/release-milestone.yml b/.github/workflows/release-milestone.yml index bc6ab1a9a558..e3379fa2178e 100644 --- a/.github/workflows/release-milestone.yml +++ b/.github/workflows/release-milestone.yml @@ -2,8 +2,8 @@ name: Release Milestone on: push: tags: - - v7.0.0-M[1-9] - - v7.0.0-RC[1-9] + - v7.1.0-M[1-9] + - v7.1.0-RC[1-9] concurrency: group: ${{ github.workflow }}-${{ github.ref }} jobs: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e1ee0bf5e7d6..36ba8c05ba06 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -2,7 +2,7 @@ name: Release on: push: tags: - - v7.0.[0-9]+ + - v7.1.[0-9]+ concurrency: group: ${{ github.workflow }}-${{ github.ref }} jobs: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8d582c654892..0b28403db7f5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ -# Contributing to the Spring Framework +# Contributing to the Spring Framework First off, thank you for taking the time to contribute! :+1: :tada: diff --git a/buildSrc/build.gradle b/buildSrc/build.gradle index 58cbf888cd4b..0b40802d0c3a 100644 --- a/buildSrc/build.gradle +++ b/buildSrc/build.gradle @@ -20,7 +20,7 @@ ext { dependencies { checkstyle "io.spring.javaformat:spring-javaformat-checkstyle:${javaFormatVersion}" implementation "org.jetbrains.kotlin:kotlin-gradle-plugin:${kotlinVersion}" - implementation "org.jetbrains.dokka:dokka-gradle-plugin:2.1.0" + implementation "org.jetbrains.dokka:dokka-gradle-plugin:2.2.0" implementation "com.tngtech.archunit:archunit:1.4.1" implementation "org.gradle:test-retry-gradle-plugin:1.6.2" implementation "io.spring.javaformat:spring-javaformat-gradle-plugin:${javaFormatVersion}" diff --git a/framework-docs/antora.yml b/framework-docs/antora.yml index cb1c780c45b2..c8c78a50db3d 100644 --- a/framework-docs/antora.yml +++ b/framework-docs/antora.yml @@ -31,7 +31,7 @@ asciidoc: spring-org: 'spring-projects' spring-github-org: "https://github.com/{spring-org}" spring-framework-github: "https://github.com/{spring-org}/spring-framework" - spring-framework-code: '{spring-framework-github}/tree/7.0.x' + spring-framework-code: '{spring-framework-github}/tree/main' spring-framework-issues: '{spring-framework-github}/issues' spring-framework-wiki: '{spring-framework-github}/wiki' # Docs diff --git a/framework-docs/modules/ROOT/pages/core/expressions/evaluation.adoc b/framework-docs/modules/ROOT/pages/core/expressions/evaluation.adoc index 137c0496f847..2501e72e4b35 100644 --- a/framework-docs/modules/ROOT/pages/core/expressions/evaluation.adoc +++ b/framework-docs/modules/ROOT/pages/core/expressions/evaluation.adoc @@ -531,7 +531,6 @@ following kinds of expressions cannot be compiled. * Expressions relying on the conversion service * Expressions using custom resolvers * Expressions using overloaded operators -* Expressions using `Optional` with the null-safe or Elvis operator * Expressions using array construction syntax * Expressions using selection or projection * Expressions using bean references diff --git a/framework-docs/modules/ROOT/pages/integration/rest-clients.adoc b/framework-docs/modules/ROOT/pages/integration/rest-clients.adoc index 4dd3ce2daedb..327aa7076f06 100644 --- a/framework-docs/modules/ROOT/pages/integration/rest-clients.adoc +++ b/framework-docs/modules/ROOT/pages/integration/rest-clients.adoc @@ -402,6 +402,27 @@ To serialize only a subset of the object properties, you can specify a {baeldung .toBodilessEntity(); ---- +==== URL encoded Forms + +URL encoded forms, using the `"application/x-www-form-urlencoded"` media type, are useful for sending String key/values over the wire. +This is supported by the `FormHttpMessageConverter`, if the application uses a `MultiValueMap` as source instance +or a target type. + +For example: + +[source,java,indent=0,subs="verbatim"] +---- + MultiValueMap form = new LinkedMultiValueMap<>(); + form.add("project", "Spring Framework"); + form.add("module", "spring-web"); + ResponseEntity response = this.restClient.post() + .contentType(MediaType.APPLICATION_FORM_URLENCODED) + .body(form) + .retrieve() + .toBodilessEntity(); +---- + + ==== Multipart To send multipart data, you need to provide a `MultiValueMap` whose values may be an `Object` for part content, a `Resource` for a file part, or an `HttpEntity` for part content with headers. @@ -419,18 +440,70 @@ For example: headers.setContentType(MediaType.APPLICATION_XML); parts.add("xmlPart", new HttpEntity<>(myBean, headers)); - // send using RestClient.post or RestTemplate.postForEntity + ResponseEntity response = this.restClient.post() + .contentType(MediaType.MULTIPART_FORM_DATA) + .body(parts) + .retrieve() + .toBodilessEntity(); ---- In most cases, you do not have to specify the `Content-Type` for each part. The content type is determined automatically based on the `HttpMessageConverter` chosen to serialize it or, in the case of a `Resource`, based on the file extension. If necessary, you can explicitly provide the `MediaType` with an `HttpEntity` wrapper. -Once the `MultiValueMap` is ready, you can use it as the body of a `POST` request, using `RestClient.post().body(parts)` (or `RestTemplate.postForObject`). +The `Content-Type` is set to `multipart/form-data` by the `MultipartHttpMessageConverter`. +As seen in the previous section, `MultiValueMap` types can also be used for URL encoded forms. +It is preferable to explicitly set the media type in the `Content-Type` or `Accept` HTTP request headers to ensure that the expected +message converter is used. + +`RestClient` can also receive multipart responses. +To decode a multipart response body, use a `ParameterizedTypeReference>`. +The decoded map contains `Part` instances where `FormFieldPart` represents form field values +and `FilePart` represents file parts with a `filename()` and a `transferTo()` method. + +[tabs] +====== +Java:: ++ +[source,java,indent=0,subs="verbatim"] +---- + MultiValueMap result = this.restClient.get() + .uri("https://example.com/upload") + .accept(MediaType.MULTIPART_FORM_DATA) + .retrieve() + .body(new ParameterizedTypeReference<>() {}); + + Part field = result.getFirst("fieldPart"); + if (field instanceof FormFieldPart formField) { + String fieldValue = formField.value(); + } + Part file = result.getFirst("filePart"); + if (file instanceof FilePart filePart) { + filePart.transferTo(Path.of("/tmp/" + filePart.filename())); + } +---- + +Kotlin:: ++ +[source,kotlin,indent=0,subs="verbatim"] +---- + val result = this.restClient.get() + .uri("https://example.com/upload") + .accept(MediaType.MULTIPART_FORM_DATA) + .retrieve() + .body(object : ParameterizedTypeReference>() {}) + + val field = result?.getFirst("fieldPart") + if (field is FormFieldPart) { + val fieldValue = field.value() + } + val file = result?.getFirst("filePart") + if (file is FilePart) { + file.transferTo(Path.of("/tmp/" + file.filename())) + } +---- +====== -If the `MultiValueMap` contains at least one non-`String` value, the `Content-Type` is set to `multipart/form-data` by the `FormHttpMessageConverter`. -If the `MultiValueMap` has `String` values, the `Content-Type` defaults to `application/x-www-form-urlencoded`. -If necessary the `Content-Type` may also be set explicitly. [[rest-request-factories]] === Client Request Factories diff --git a/framework-docs/modules/ROOT/pages/testing/annotations/integration-spring/annotation-mockitobean.adoc b/framework-docs/modules/ROOT/pages/testing/annotations/integration-spring/annotation-mockitobean.adoc index b0ad7707ff3b..591d3a50c356 100644 --- a/framework-docs/modules/ROOT/pages/testing/annotations/integration-spring/annotation-mockitobean.adoc +++ b/framework-docs/modules/ROOT/pages/testing/annotations/integration-spring/annotation-mockitobean.adoc @@ -12,18 +12,20 @@ The annotations can be applied in the following ways. * On a non-static field in a test class or any of its superclasses. * On a non-static field in an enclosing class for a `@Nested` test class or in any class in the type hierarchy or enclosing class hierarchy above the `@Nested` test class. +* On a parameter in the constructor for a test class. * At the type level on a test class or any superclass or implemented interface in the type hierarchy above the test class. * At the type level on an enclosing class for a `@Nested` test class or on any class or interface in the type hierarchy or enclosing class hierarchy above the `@Nested` test class. -When `@MockitoBean` or `@MockitoSpyBean` is declared on a field, the bean to mock or spy -is inferred from the type of the annotated field. If multiple candidates exist in the -`ApplicationContext`, a `@Qualifier` annotation can be declared on the field to help -disambiguate. In the absence of a `@Qualifier` annotation, the name of the annotated -field will be used as a _fallback qualifier_. Alternatively, you can explicitly specify a -bean name to mock or spy by setting the `value` or `name` attribute in the annotation. +When `@MockitoBean` or `@MockitoSpyBean` is declared on a field or constructor parameter, +the bean to mock or spy is inferred from the type of the annotated field or parameter. If +multiple candidates exist in the `ApplicationContext`, a `@Qualifier` annotation can be +declared on the field or parameter to help disambiguate. In the absence of a `@Qualifier` +annotation, the name of the annotated field or parameter will be used as a _fallback +qualifier_. Alternatively, you can explicitly specify a bean name to mock or spy by +setting the `value` or `name` attribute in the annotation. When `@MockitoBean` or `@MockitoSpyBean` is declared at the type level, the type of bean (or beans) to mock or spy must be supplied via the `types` attribute in the annotation – @@ -201,6 +203,82 @@ Kotlin:: <1> Replace the bean named `service` with a Mockito mock. ====== +The following example shows how to use `@MockitoBean` on a constructor parameter for a +by-type lookup. + +[tabs] +====== +Java:: ++ +[source,java,indent=0,subs="verbatim,quotes"] +---- + @SpringJUnitConfig(TestConfig.class) + class BeanOverrideTests { + + private final CustomService customService; + + BeanOverrideTests(@MockitoBean CustomService customService) { // <1> + this.customService = customService; + } + + // tests... + } +---- +<1> Replace the bean with type `CustomService` with a Mockito mock and inject it into + the constructor. + +Kotlin:: ++ +[source,kotlin,indent=0,subs="verbatim,quotes"] +---- + @SpringJUnitConfig(TestConfig::class) + class BeanOverrideTests(@MockitoBean val customService: CustomService) { // <1> + + // tests... + } +---- +<1> Replace the bean with type `CustomService` with a Mockito mock and inject it into + the constructor. +====== + +The following example shows how to use `@MockitoBean` on a constructor parameter for a +by-name lookup. + +[tabs] +====== +Java:: ++ +[source,java,indent=0,subs="verbatim,quotes"] +---- + @SpringJUnitConfig(TestConfig.class) + class BeanOverrideTests { + + private final CustomService customService; + + BeanOverrideTests(@MockitoBean("service") CustomService customService) { // <1> + this.customService = customService; + } + + // tests... + } +---- +<1> Replace the bean named `service` with a Mockito mock and inject it into the + constructor. + +Kotlin:: ++ +[source,kotlin,indent=0,subs="verbatim,quotes"] +---- + @SpringJUnitConfig(TestConfig::class) + class BeanOverrideTests(@MockitoBean("service") val customService: CustomService) { // <1> + + // tests... + } +---- +<1> Replace the bean named `service` with a Mockito mock and inject it into the + constructor. +====== + The following `@SharedMocks` annotation registers two mocks by-type and one mock by-name. [tabs] @@ -375,6 +453,80 @@ Kotlin:: <1> Wrap the bean named `service` with a Mockito spy. ====== +The following example shows how to use `@MockitoSpyBean` on a constructor parameter for +a by-type lookup. + +[tabs] +====== +Java:: ++ +[source,java,indent=0,subs="verbatim,quotes"] +---- + @SpringJUnitConfig(TestConfig.class) + class BeanOverrideTests { + + private final CustomService customService; + + BeanOverrideTests(@MockitoSpyBean CustomService customService) { // <1> + this.customService = customService; + } + + // tests... + } +---- +<1> Wrap the bean with type `CustomService` with a Mockito spy and inject it into the + constructor. + +Kotlin:: ++ +[source,kotlin,indent=0,subs="verbatim,quotes"] +---- + @SpringJUnitConfig(TestConfig::class) + class BeanOverrideTests(@MockitoSpyBean val customService: CustomService) { // <1> + + // tests... + } +---- +<1> Wrap the bean with type `CustomService` with a Mockito spy and inject it into the + constructor. +====== + +The following example shows how to use `@MockitoSpyBean` on a constructor parameter for +a by-name lookup. + +[tabs] +====== +Java:: ++ +[source,java,indent=0,subs="verbatim,quotes"] +---- + @SpringJUnitConfig(TestConfig.class) + class BeanOverrideTests { + + private final CustomService customService; + + BeanOverrideTests(@MockitoSpyBean("service") CustomService customService) { // <1> + this.customService = customService; + } + + // tests... + } +---- +<1> Wrap the bean named `service` with a Mockito spy and inject it into the constructor. + +Kotlin:: ++ +[source,kotlin,indent=0,subs="verbatim,quotes"] +---- + @SpringJUnitConfig(TestConfig::class) + class BeanOverrideTests(@MockitoSpyBean("service") val customService: CustomService) { // <1> + + // tests... + } +---- +<1> Wrap the bean named `service` with a Mockito spy and inject it into the constructor. +====== + The following `@SharedSpies` annotation registers two spies by-type and one spy by-name. [tabs] diff --git a/framework-docs/modules/ROOT/pages/testing/resttestclient.adoc b/framework-docs/modules/ROOT/pages/testing/resttestclient.adoc index 73244a574d57..61bc3eb6bdee 100644 --- a/framework-docs/modules/ROOT/pages/testing/resttestclient.adoc +++ b/framework-docs/modules/ROOT/pages/testing/resttestclient.adoc @@ -142,6 +142,9 @@ provides two alternative ways to verify the response: 1. xref:resttestclient-workflow[Built-in Assertions] extend the request workflow with a chain of expectations 2. xref:resttestclient-assertj[AssertJ Integration] to verify the response via `assertThat()` statements +TIP: See the xref:integration/rest-clients.adoc#rest-message-conversion[HTTP Message Conversion] +section for examples on how to prepare a request with any content, including form data and multipart data. + [[resttestclient.workflow]] @@ -213,6 +216,16 @@ To verify JSON content with https://github.com/jayway/JsonPath[JSONPath]: include-code::./JsonTests[tag=jsonPath,indent=0] +[[resttestclient.multipart]] +==== Multipart Content + +When testing endpoints that return multipart responses, you can decode the body to a +`MultiValueMap` and assert individual parts using the `FormFieldPart` +and `FilePart` subtypes. + +include-code::./MultipartTests[tag=multipart,indent=0] + + [[resttestclient.assertj]] === AssertJ Integration diff --git a/framework-docs/modules/ROOT/pages/testing/testcontext-framework/bean-overriding.adoc b/framework-docs/modules/ROOT/pages/testing/testcontext-framework/bean-overriding.adoc index 055b718feaae..7fbb0ceeb0c2 100644 --- a/framework-docs/modules/ROOT/pages/testing/testcontext-framework/bean-overriding.adoc +++ b/framework-docs/modules/ROOT/pages/testing/testcontext-framework/bean-overriding.adoc @@ -2,8 +2,9 @@ = Bean Overriding in Tests Bean overriding in tests refers to the ability to override specific beans in the -`ApplicationContext` for a test class, by annotating the test class or one or more -non-static fields in the test class. +`ApplicationContext` for a test class, by annotating the test class, one or more +non-static fields in the test class, or one or more parameters in the constructor for the +test class. NOTE: This feature is intended as a less risky alternative to the practice of registering a bean via `@Bean` with the `DefaultListableBeanFactory` @@ -42,9 +43,9 @@ The `spring-test` module registers implementations of the latter two {spring-framework-code}/spring-test/src/main/resources/META-INF/spring.factories[`META-INF/spring.factories` properties file]. -The bean overriding infrastructure searches for annotations on test classes as well as -annotations on non-static fields in test classes that are meta-annotated with -`@BeanOverride` and instantiates the corresponding `BeanOverrideProcessor` which is +The bean overriding infrastructure searches for annotations on test classes, non-static +fields in test classes, and parameters in test class constructors that are meta-annotated +with `@BeanOverride`, and instantiates the corresponding `BeanOverrideProcessor` which is responsible for creating an appropriate `BeanOverrideHandler`. The internal `BeanOverrideBeanFactoryPostProcessor` then uses bean override handlers to diff --git a/framework-docs/modules/ROOT/pages/testing/testcontext-framework/support-classes.adoc b/framework-docs/modules/ROOT/pages/testing/testcontext-framework/support-classes.adoc index 5f77d9e08b4a..a98f0f72f9cf 100644 --- a/framework-docs/modules/ROOT/pages/testing/testcontext-framework/support-classes.adoc +++ b/framework-docs/modules/ROOT/pages/testing/testcontext-framework/support-classes.adoc @@ -179,6 +179,10 @@ If a specific parameter in a constructor for a JUnit Jupiter test class is of ty `ApplicationContext` (or a sub-type thereof) or is annotated or meta-annotated with `@Autowired`, `@Qualifier`, or `@Value`, Spring injects the value for that specific parameter with the corresponding bean or value from the test's `ApplicationContext`. +Similarly, if a specific parameter is annotated with `@MockitoBean` or `@MockitoSpyBean`, +Spring will inject a Mockito mock or spy, respectively — see +xref:testing/annotations/integration-spring/annotation-mockitobean.adoc[`@MockitoBean` and `@MockitoSpyBean`] +for details. Spring can also be configured to autowire all arguments for a test class constructor if the constructor is considered to be _autowirable_. A constructor is considered to be diff --git a/framework-docs/modules/ROOT/pages/testing/webtestclient.adoc b/framework-docs/modules/ROOT/pages/testing/webtestclient.adoc index 1d86864e0d28..b2bd9807be21 100644 --- a/framework-docs/modules/ROOT/pages/testing/webtestclient.adoc +++ b/framework-docs/modules/ROOT/pages/testing/webtestclient.adoc @@ -580,8 +580,8 @@ Kotlin:: [[webtestclient-stream]] ==== Streaming Responses -To test potentially infinite streams such as `"text/event-stream"` or -`"application/x-ndjson"`, start by verifying the response status and headers, and then +To test potentially infinite streams such as `"text/event-stream"`, +`"application/jsonl"` or `"application/x-ndjson"`, start by verifying the response status and headers, and then obtain a `FluxExchangeResult`: [tabs] diff --git a/framework-docs/modules/ROOT/pages/web/webflux/reactive-spring.adoc b/framework-docs/modules/ROOT/pages/web/webflux/reactive-spring.adoc index 99e0d09913a9..54e44b528227 100644 --- a/framework-docs/modules/ROOT/pages/web/webflux/reactive-spring.adoc +++ b/framework-docs/modules/ROOT/pages/web/webflux/reactive-spring.adoc @@ -485,8 +485,8 @@ The `JacksonJsonEncoder` works as follows: * For a multi-value publisher with `application/json`, by default collect the values with `Flux#collectToList()` and then serialize the resulting collection. * For a multi-value publisher with a streaming media type such as -`application/x-ndjson` or `application/stream+x-jackson-smile`, encode, write, and -flush each value individually using a +`application/jsonl`, `application/x-ndjson` or `application/stream+x-jackson-smile`, +encode, write, and flush each value individually using a https://en.wikipedia.org/wiki/JSON_streaming[line-delimited JSON] format. Other streaming media types may be registered with the encoder. * For SSE the `JacksonJsonEncoder` is invoked per event and the output is flushed to ensure @@ -598,7 +598,7 @@ To configure all three in WebFlux, you'll need to supply a pre-configured instan [.small]#xref:web/webmvc/mvc-ann-async.adoc#mvc-ann-async-http-streaming[See equivalent in the Servlet stack]# When streaming to the HTTP response (for example, `text/event-stream`, -`application/x-ndjson`), it is important to send data periodically, in order to +`application/jsonl`, `application/x-ndjson`), it is important to send data periodically, in order to reliably detect a disconnected client sooner rather than later. Such a send could be a comment-only, empty SSE event or any other "no-op" data that would effectively serve as a heartbeat. diff --git a/framework-docs/modules/ROOT/pages/web/webmvc/message-converters.adoc b/framework-docs/modules/ROOT/pages/web/webmvc/message-converters.adoc index d65c68aab7f3..c9e48ff2c82b 100644 --- a/framework-docs/modules/ROOT/pages/web/webmvc/message-converters.adoc +++ b/framework-docs/modules/ROOT/pages/web/webmvc/message-converters.adoc @@ -23,13 +23,17 @@ For all converters, a default media type is used, but you can override it by set By default, this converter supports all text media types(`text/{asterisk}`) and writes with a `Content-Type` of `text/plain`. | `FormHttpMessageConverter` -| An `HttpMessageConverter` implementation that can read and write form data from the HTTP request and response. +| An `HttpMessageConverter` implementation that can read and write URL encoded forms. By default, this converter reads and writes the `application/x-www-form-urlencoded` media type. Form data is read from and written into a `MultiValueMap`. -The converter can also write (but not read) multipart data read from a `MultiValueMap`. -By default, `multipart/form-data` is supported. -Additional multipart subtypes can be supported for writing form data. -Consult the javadoc for `FormHttpMessageConverter` for further details. +`Map` is also supported, but multiple values under the same key will be ignored. + +| `MultipartHttpMessageConverter` +| An `HttpMessageConverter` implementation that can read and write multipart messages. +`MultiValueMap` can be written to multipart messages, converting each part independently using +the configured message converters. Multipart messages can be read into `MultiValueMap`, each value +being a `Part` or one of its subtypes (`FormFieldPart` and `FilePart`). +By default, `multipart/form-data` is supported. Additional multipart subtypes can be supported for writing form data. | `ByteArrayHttpMessageConverter` | An `HttpMessageConverter` implementation that can read and write byte arrays from the HTTP request and response. diff --git a/framework-docs/modules/ROOT/pages/web/webmvc/mvc-ann-async.adoc b/framework-docs/modules/ROOT/pages/web/webmvc/mvc-ann-async.adoc index 17779b8aeb62..137da77f9602 100644 --- a/framework-docs/modules/ROOT/pages/web/webmvc/mvc-ann-async.adoc +++ b/framework-docs/modules/ROOT/pages/web/webmvc/mvc-ann-async.adoc @@ -423,8 +423,8 @@ Reactive return values are handled as follows: * A single-value promise is adapted to, similar to using `DeferredResult`. Examples include `CompletionStage` (JDK), `Mono` (Reactor), and `Single` (RxJava). -* A multi-value stream with a streaming media type (such as `application/x-ndjson` -or `text/event-stream`) is adapted to, similar to using `ResponseBodyEmitter` or +* A multi-value stream with a streaming media type (such as `application/jsonl`, +`application/x-ndjson` or `text/event-stream`) is adapted to, similar to using `ResponseBodyEmitter` or `SseEmitter`. Examples include `Flux` (Reactor) or `Observable` (RxJava). Applications can also return `Flux` or `Observable`. * A multi-value stream with any other media type (such as `application/json`) is adapted diff --git a/framework-docs/src/main/java/org/springframework/docs/testing/resttestclient/multipart/MultipartTests.java b/framework-docs/src/main/java/org/springframework/docs/testing/resttestclient/multipart/MultipartTests.java new file mode 100644 index 000000000000..9f58cd698d50 --- /dev/null +++ b/framework-docs/src/main/java/org/springframework/docs/testing/resttestclient/multipart/MultipartTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.docs.testing.resttestclient.multipart; + +import org.junit.jupiter.api.Test; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.MediaType; +import org.springframework.http.converter.multipart.FilePart; +import org.springframework.http.converter.multipart.FormFieldPart; +import org.springframework.http.converter.multipart.Part; +import org.springframework.test.web.servlet.client.RestTestClient; +import org.springframework.util.MultiValueMap; + +import static org.assertj.core.api.Assertions.assertThat; + +public class MultipartTests { + + RestTestClient client; + + @Test + void multipart() { + // tag::multipart[] + client.get().uri("/upload") + .accept(MediaType.MULTIPART_FORM_DATA) + .exchange() + .expectStatus().isOk() + .expectBody(new ParameterizedTypeReference>() {}) + .value(result -> { + Part field = result.getFirst("fieldPart"); + assertThat(field).isInstanceOfSatisfying(FormFieldPart.class, + formField -> assertThat(formField.value()).isEqualTo("fieldValue")); + Part file = result.getFirst("filePart"); + assertThat(file).isInstanceOfSatisfying(FilePart.class, + filePart -> assertThat(filePart.filename()).isEqualTo("logo.png")); + }); + // end::multipart[] + } + +} diff --git a/framework-platform/framework-platform.gradle b/framework-platform/framework-platform.gradle index 986fed60b25f..eb602f3c25e7 100644 --- a/framework-platform/framework-platform.gradle +++ b/framework-platform/framework-platform.gradle @@ -7,7 +7,7 @@ javaPlatform { } dependencies { - api(platform("com.fasterxml.jackson:jackson-bom:2.20.2")) + api(platform("com.fasterxml.jackson:jackson-bom:2.21.2")) api(platform("io.micrometer:micrometer-bom:1.16.5")) api(platform("io.netty:netty-bom:4.2.12.Final")) api(platform("io.projectreactor:reactor-bom:2025.0.5")) @@ -18,10 +18,10 @@ dependencies { api(platform("org.eclipse.jetty:jetty-bom:12.1.7")) api(platform("org.eclipse.jetty.ee11:jetty-ee11-bom:12.1.7")) api(platform("org.jetbrains.kotlinx:kotlinx-coroutines-bom:1.10.2")) - api(platform("org.jetbrains.kotlinx:kotlinx-serialization-bom:1.9.0")) + api(platform("org.jetbrains.kotlinx:kotlinx-serialization-bom:1.11.0")) api(platform("org.junit:junit-bom:6.0.3")) api(platform("org.mockito:mockito-bom:5.23.0")) - api(platform("tools.jackson:jackson-bom:3.0.4")) + api(platform("tools.jackson:jackson-bom:3.1.1")) constraints { api("com.fasterxml:aalto-xml:1.3.4") @@ -120,7 +120,7 @@ dependencies { api("org.glassfish:jakarta.el:4.0.2") api("org.graalvm.sdk:graal-sdk:22.3.1") api("org.hamcrest:hamcrest:3.0") - api("org.hibernate.orm:hibernate-core:7.2.11.Final") + api("org.hibernate.orm:hibernate-core:7.3.1.Final") api("org.hibernate.validator:hibernate-validator:9.1.0.Final") api("org.hsqldb:hsqldb:2.7.4") api("org.htmlunit:htmlunit:4.21.0") diff --git a/gradle.properties b/gradle.properties index 729750414176..edb3a7bca441 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,10 +1,10 @@ -version=7.0.8-SNAPSHOT +version=7.1.0-SNAPSHOT org.gradle.caching=true org.gradle.jvmargs=-Xmx2048m org.gradle.parallel=true -kotlinVersion=2.2.21 +kotlinVersion=2.3.20 byteBuddyVersion=1.17.6 kotlin.jvm.target.validation.mode=ignore diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/BeanFactory.java b/spring-beans/src/main/java/org/springframework/beans/factory/BeanFactory.java index ad47efb2b2ae..b7ffe33ede7c 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/BeanFactory.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/BeanFactory.java @@ -100,6 +100,7 @@ * @author Rod Johnson * @author Juergen Hoeller * @author Chris Beams + * @author Yanming Zhou * @since 13 April 2001 * @see BeanNameAware#setBeanName * @see BeanClassLoaderAware#setBeanClassLoader @@ -175,6 +176,29 @@ public interface BeanFactory { */ T getBean(String name, Class requiredType) throws BeansException; + /** + * Return an instance, which may be shared or independent, of the specified bean. + *

Behaves the same as {@link #getBean(String)}, but provides a measure of type + * safety by throwing a BeanNotOfRequiredTypeException if the bean is not of the + * required type. This means that ClassCastException can't be thrown on casting + * the result correctly, as can happen with {@link #getBean(String)}. + *

Translates aliases back to the corresponding canonical bean name. + *

Will ask the parent factory if the bean cannot be found in this factory instance. + * @param name the name of the bean to retrieve + * @param typeReference the reference to obtain type the bean must match + * @return an instance of the bean. + * Note that the return value will never be {@code null}. In case of a stub for + * {@code null} from a factory method having been resolved for the requested bean, a + * {@code BeanNotOfRequiredTypeException} against the NullBean stub will be raised. + * Consider using {@link #getBeanProvider(Class)} for resolving optional dependencies. + * @throws NoSuchBeanDefinitionException if there is no such bean definition + * @throws BeanNotOfRequiredTypeException if the bean is not of the required type + * @throws BeansException if the bean could not be created + * @since 7.1 + * @see #getBean(String, Class) + */ + T getBean(String name, ParameterizedTypeReference typeReference) throws BeansException; + /** * Return an instance, which may be shared or independent, of the specified bean. *

Allows for specifying explicit constructor arguments / factory method arguments, diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/BeanNotOfRequiredTypeException.java b/spring-beans/src/main/java/org/springframework/beans/factory/BeanNotOfRequiredTypeException.java index bbc504098aad..bb51dd64b9b5 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/BeanNotOfRequiredTypeException.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/BeanNotOfRequiredTypeException.java @@ -16,14 +16,17 @@ package org.springframework.beans.factory; +import java.lang.reflect.Type; + import org.springframework.beans.BeansException; -import org.springframework.util.ClassUtils; +import org.springframework.core.ResolvableType; /** * Thrown when a bean doesn't match the expected type. * * @author Rod Johnson * @author Juergen Hoeller + * @author Yanming Zhou */ @SuppressWarnings("serial") public class BeanNotOfRequiredTypeException extends BeansException { @@ -32,7 +35,7 @@ public class BeanNotOfRequiredTypeException extends BeansException { private final String beanName; /** The required type. */ - private final Class requiredType; + private final Type genericRequiredType; /** The offending type. */ private final Class actualType; @@ -46,10 +49,22 @@ public class BeanNotOfRequiredTypeException extends BeansException { * the expected type */ public BeanNotOfRequiredTypeException(String beanName, Class requiredType, Class actualType) { - super("Bean named '" + beanName + "' is expected to be of type '" + ClassUtils.getQualifiedName(requiredType) + - "' but was actually of type '" + ClassUtils.getQualifiedName(actualType) + "'"); + this(beanName, (Type) requiredType, actualType); + } + + /** + * Create a new BeanNotOfRequiredTypeException. + * @param beanName the name of the bean requested + * @param requiredType the required type + * @param actualType the actual type returned, which did not match + * the expected type + * @since 7.1 + */ + public BeanNotOfRequiredTypeException(String beanName, Type requiredType, Class actualType) { + super("Bean named '" + beanName + "' is expected to be of type '" + requiredType.getTypeName() + + "' but was actually of type '" + actualType.getTypeName() + "'"); this.beanName = beanName; - this.requiredType = requiredType; + this.genericRequiredType = requiredType; this.actualType = actualType; } @@ -65,7 +80,15 @@ public String getBeanName() { * Return the expected type for the bean. */ public Class getRequiredType() { - return this.requiredType; + return (this.genericRequiredType instanceof Class clazz ? clazz : ResolvableType.forType(this.genericRequiredType).toClass()); + } + + /** + * Return the expected generic type for the bean. + * @since 7.1 + */ + public Type getGenericRequiredType() { + return this.genericRequiredType; } /** diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/BeanRegistrar.java b/spring-beans/src/main/java/org/springframework/beans/factory/BeanRegistrar.java index 4ba835973c82..db73cd8b85ff 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/BeanRegistrar.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/BeanRegistrar.java @@ -19,21 +19,9 @@ import org.springframework.core.env.Environment; /** - * Contract for registering beans programmatically, typically imported with an - * {@link org.springframework.context.annotation.Import @Import} annotation on - * a {@link org.springframework.context.annotation.Configuration @Configuration} - * class. - *

- * @Configuration
- * @Import(MyBeanRegistrar.class)
- * class MyConfiguration {
- * }
- * Can also be applied to an application context via - * {@link org.springframework.context.support.GenericApplicationContext#register(BeanRegistrar...)}. - * + * Contract for registering beans programmatically. Implementations use the + * {@link BeanRegistry} and {@link Environment} to register beans: * - *

Bean registrar implementations use {@link BeanRegistry} and {@link Environment} - * APIs to register beans programmatically in a concise and flexible way. *

  * class MyBeanRegistrar implements BeanRegistrar {
  *
@@ -52,9 +40,55 @@
  *     }
  * }
* + *

{@code BeanRegistrar} implementations are not Spring components: they must have + * a no-arg constructor and cannot rely on dependency injection or any other + * component-model feature. They can be used in two distinct ways depending on the + * application context setup. + * + *

With the {@code @Configuration} model

+ * + *

A {@code BeanRegistrar} must be imported via + * {@link org.springframework.context.annotation.Import @Import} on a + * {@link org.springframework.context.annotation.Configuration @Configuration} class: + * + *

+ * @Configuration
+ * @Import(MyBeanRegistrar.class)
+ * class MyConfiguration {
+ * }
+ * + *

This is the only mechanism that triggers bean registration in the annotation-based + * configuration model. Annotating an implementation with {@code @Configuration} or + * {@code @Component}, or returning an instance from a {@code @Bean} method, registers + * it as a bean but does not invoke its + * {@link #register(BeanRegistry, Environment) register} method. + * + *

When imported, the registrar is invoked in the order it is encountered during + * configuration class processing. It can therefore check for and build on beans that + * have already been defined, but has no visibility into beans that will be registered + * by classes processed later. + * + *

Programmatic usage

+ * + *

A {@code BeanRegistrar} can also be applied directly to a + * {@link org.springframework.context.support.GenericApplicationContext}: + * + *

+ * GenericApplicationContext context = new GenericApplicationContext();
+ * context.register(new MyBeanRegistrar());
+ * context.registerBean("myBean", MyBean.class);
+ * context.refresh();
+ * + *

This mode is primarily intended for fully programmatic application context setups. + * Registrars applied this way are invoked before any {@code @Configuration} class is + * processed. They can therefore observe beans registered programmatically (e.g., via + * one of the {@code GenericApplicationContext#registerBean} methods), but will + * not see any beans defined in {@code @Configuration} classes also + * registered with the context. + * *

A {@code BeanRegistrar} implementing {@link org.springframework.context.annotation.ImportAware} - * can optionally introspect import metadata when used in an import scenario, otherwise the - * {@code setImportMetadata} method is simply not being called. + * can optionally introspect import metadata when used in an import scenario; otherwise + * the {@code setImportMetadata} method is not called. * *

In Kotlin, it is recommended to use {@code BeanRegistrarDsl} instead of * implementing {@code BeanRegistrar}. diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/BeanRegistry.java b/spring-beans/src/main/java/org/springframework/beans/factory/BeanRegistry.java index 07e1eb5da071..cadc878644da 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/BeanRegistry.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/BeanRegistry.java @@ -33,6 +33,7 @@ * programmatic bean registration capabilities. * * @author Sebastien Deleuze + * @author Juergen Hoeller * @since 7.0 */ public interface BeanRegistry { @@ -140,6 +141,28 @@ public interface BeanRegistry { */ void registerBean(String name, ParameterizedTypeReference beanType, Consumer> customizer); + /** + * Determine whether a bean of the given name is already registered. + * @param name the name of the bean + * @since 7.1 + */ + boolean containsBean(String name); + + /** + * Determine whether a bean of the given type is already registered. + * @param beanType the type of the bean + * @since 7.1 + */ + boolean containsBean(Class beanType); + + /** + * Determine whether a bean of the given generics-containing type is + * already registered. + * @param beanType the generics-containing type of the bean + * @since 7.1 + */ + boolean containsBean(ParameterizedTypeReference beanType); + /** * Specification for customizing a bean. diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/ParameterResolutionDelegate.java b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/ParameterResolutionDelegate.java index 4f4dc58622d6..31635d5a5f1a 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/ParameterResolutionDelegate.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/ParameterResolutionDelegate.java @@ -89,6 +89,31 @@ public static boolean isAutowirable(Parameter parameter, int parameterIndex) { AnnotatedElementUtils.hasAnnotation(annotatedParameter, Value.class)); } + /** + * Resolve the dependency for the supplied {@link Parameter} from the + * supplied {@link AutowireCapableBeanFactory}. + *

See {@link #resolveDependency(Parameter, int, String, Class, AutowireCapableBeanFactory)} + * for details. + * @param parameter the parameter whose dependency should be resolved (must not be + * {@code null}) + * @param parameterIndex the index of the parameter in the constructor or method + * that declares the parameter + * @param containingClass the concrete class that contains the parameter; this may + * differ from the class that declares the parameter in that it may be a subclass + * thereof, potentially substituting type variables (must not be {@code null}) + * @param beanFactory the {@code AutowireCapableBeanFactory} from which to resolve + * the dependency (must not be {@code null}) + * @return the resolved object, or {@code null} if none found + * @throws BeansException if dependency resolution failed + * @see #resolveDependency(Parameter, int, String, Class, AutowireCapableBeanFactory) + */ + public static @Nullable Object resolveDependency( + Parameter parameter, int parameterIndex, Class containingClass, AutowireCapableBeanFactory beanFactory) + throws BeansException { + + return resolveDependency(parameter, parameterIndex, null, containingClass, beanFactory); + } + /** * Resolve the dependency for the supplied {@link Parameter} from the * supplied {@link AutowireCapableBeanFactory}. @@ -101,11 +126,13 @@ public static boolean isAutowirable(Parameter parameter, int parameterIndex) { * with {@link Autowired @Autowired} with the {@link Autowired#required required} * flag set to {@code false}. *

If an explicit qualifier is not declared, the name of the parameter - * will be used as the qualifier for resolving ambiguities. + * (or a supplied custom name) will be used as the qualifier for resolving ambiguities. * @param parameter the parameter whose dependency should be resolved (must not be * {@code null}) * @param parameterIndex the index of the parameter in the constructor or method * that declares the parameter + * @param parameterName a custom name for the parameter; or {@code null} to use + * the default parameter name discovery logic * @param containingClass the concrete class that contains the parameter; this may * differ from the class that declares the parameter in that it may be a subclass * thereof, potentially substituting type variables (must not be {@code null}) @@ -113,13 +140,14 @@ public static boolean isAutowirable(Parameter parameter, int parameterIndex) { * the dependency (must not be {@code null}) * @return the resolved object, or {@code null} if none found * @throws BeansException if dependency resolution failed + * @since 7.1 * @see #isAutowirable * @see Autowired#required * @see SynthesizingMethodParameter#forExecutable(Executable, int) * @see AutowireCapableBeanFactory#resolveDependency(DependencyDescriptor, String) */ - public static @Nullable Object resolveDependency( - Parameter parameter, int parameterIndex, Class containingClass, AutowireCapableBeanFactory beanFactory) + public static @Nullable Object resolveDependency(Parameter parameter, int parameterIndex, + @Nullable String parameterName, Class containingClass, AutowireCapableBeanFactory beanFactory) throws BeansException { Assert.notNull(parameter, "Parameter must not be null"); @@ -132,7 +160,7 @@ public static boolean isAutowirable(Parameter parameter, int parameterIndex) { MethodParameter methodParameter = SynthesizingMethodParameter.forExecutable( parameter.getDeclaringExecutable(), parameterIndex); - DependencyDescriptor descriptor = new DependencyDescriptor(methodParameter, required); + DependencyDescriptor descriptor = new NamedParameterDependencyDescriptor(methodParameter, required, parameterName); descriptor.setContainingClass(containingClass); return beanFactory.resolveDependency(descriptor, null); } @@ -171,4 +199,26 @@ private static AnnotatedElement getEffectiveAnnotatedParameter(Parameter paramet return parameter; } + + @SuppressWarnings("serial") + private static class NamedParameterDependencyDescriptor extends DependencyDescriptor { + + private final @Nullable String parameterName; + + NamedParameterDependencyDescriptor(MethodParameter methodParameter, boolean required, @Nullable String parameterName) { + super(methodParameter, required); + this.parameterName = parameterName; + } + + @Override + public @Nullable String getDependencyName() { + return (this.parameterName != null ? this.parameterName : super.getDependencyName()); + } + + @Override + public boolean usesStandardBeanLookup() { + return true; + } + } + } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/AutowiredArguments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/AutowiredArguments.java index f175ddd27034..4618f11b9b45 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/AutowiredArguments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/AutowiredArguments.java @@ -45,7 +45,7 @@ public interface AutowiredArguments { Object value = getObject(index); if (!ClassUtils.isAssignableValue(requiredType, value)) { throw new IllegalArgumentException("Argument type mismatch: expected '" + - ClassUtils.getQualifiedName(requiredType) + "' for value [" + value + "]"); + requiredType.getTypeName() + "' for value [" + value + "]"); } return (T) value; } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/config/YamlProcessor.java b/spring-beans/src/main/java/org/springframework/beans/factory/config/YamlProcessor.java index 480352cfe95f..a31eab3533a9 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/config/YamlProcessor.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/config/YamlProcessor.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Properties; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import org.apache.commons.logging.Log; @@ -194,30 +195,31 @@ protected Yaml createYaml() { } private boolean process(MatchCallback callback, Yaml yaml, Resource resource) { - int count = 0; + AtomicInteger count = new AtomicInteger(); try { if (logger.isDebugEnabled()) { logger.debug("Loading from YAML: " + resource); } - try (Reader reader = new UnicodeReader(resource.getInputStream())) { + resource.consumeContent(inputStream -> { + Reader reader = new UnicodeReader(inputStream); for (Object object : yaml.loadAll(reader)) { if (object != null && process(asMap(object), callback)) { - count++; + count.incrementAndGet(); if (this.resolutionMethod == ResolutionMethod.FIRST_FOUND) { break; } } } if (logger.isDebugEnabled()) { - logger.debug("Loaded " + count + " document" + (count > 1 ? "s" : "") + + logger.debug("Loaded " + count + " document" + (count.get() > 1 ? "s" : "") + " from YAML resource: " + resource); } - } + }); } catch (IOException ex) { handleProcessError(resource, ex); } - return (count > 0); + return (count.get() > 0); } private void handleProcessError(Resource resource, IOException ex) { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/AbstractBeanFactory.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/AbstractBeanFactory.java index 54a5921b546c..1f8c483db66b 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/AbstractBeanFactory.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/AbstractBeanFactory.java @@ -17,6 +17,7 @@ package org.springframework.beans.factory.support; import java.beans.PropertyEditor; +import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -66,6 +67,7 @@ import org.springframework.beans.factory.config.SmartInstantiationAwareBeanPostProcessor; import org.springframework.core.DecoratingClassLoader; import org.springframework.core.NamedThreadLocal; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; import org.springframework.core.convert.ConversionService; import org.springframework.core.log.LogMessage; @@ -201,6 +203,17 @@ public T getBean(String name, Class requiredType) throws BeansException { return doGetBean(name, requiredType, null, false); } + @Override + @SuppressWarnings("unchecked") + public T getBean(String name, ParameterizedTypeReference typeReference) throws BeansException { + Object bean = getBean(name); + Type requiredType = typeReference.getType(); + if (!ResolvableType.forType(requiredType).isInstance(bean)) { + throw new BeanNotOfRequiredTypeException(name, requiredType, bean.getClass()); + } + return (T) bean; + } + @Override public Object getBean(String name, @Nullable Object @Nullable ... args) throws BeansException { return doGetBean(name, null, args, false); @@ -413,7 +426,7 @@ T adaptBeanInstance(String name, Object bean, @Nullable Class requiredTyp catch (TypeMismatchException ex) { if (logger.isTraceEnabled()) { logger.trace("Failed to convert bean '" + name + "' to required type '" + - ClassUtils.getQualifiedName(requiredType) + "'", ex); + requiredType.getTypeName() + "'", ex); } throw new BeanNotOfRequiredTypeException(name, requiredType, bean.getClass()); } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/BeanRegistryAdapter.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/BeanRegistryAdapter.java index 9e00964a06e7..cfc3135ce375 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/BeanRegistryAdapter.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/BeanRegistryAdapter.java @@ -26,6 +26,7 @@ import org.springframework.beans.BeanUtils; import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryUtils; import org.springframework.beans.factory.BeanRegistrar; import org.springframework.beans.factory.BeanRegistry; import org.springframework.beans.factory.ListableBeanFactory; @@ -176,6 +177,22 @@ public void registerBean(String name, ParameterizedTypeReference beanType this.beanRegistry.registerBeanDefinition(name, beanDefinition); } + @Override + public boolean containsBean(String name) { + return this.beanFactory.containsBean(name); + } + + @Override + public boolean containsBean(Class beanType) { + return BeanFactoryUtils.beanNamesForTypeIncludingAncestors(this.beanFactory, beanType).length > 0; + } + + @Override + public boolean containsBean(ParameterizedTypeReference beanType) { + ResolvableType resolvableType = ResolvableType.forType(beanType); + return BeanFactoryUtils.beanNamesForTypeIncludingAncestors(this.beanFactory, resolvableType).length > 0; + } + /** * {@link RootBeanDefinition} subclass for {@code #registerBean} based diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/PropertiesBeanDefinitionReader.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/PropertiesBeanDefinitionReader.java index 60a49a92afa0..b7a5c326dd54 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/PropertiesBeanDefinitionReader.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/PropertiesBeanDefinitionReader.java @@ -17,7 +17,6 @@ package org.springframework.beans.factory.support; import java.io.IOException; -import java.io.InputStream; import java.io.InputStreamReader; import java.util.Enumeration; import java.util.HashMap; @@ -256,14 +255,14 @@ public int loadBeanDefinitions(EncodedResource encodedResource, @Nullable String Properties props = new Properties(); try { - try (InputStream is = encodedResource.getResource().getInputStream()) { + encodedResource.getResource().consumeContent(is -> { if (encodedResource.getEncoding() != null) { getPropertiesPersister().load(props, new InputStreamReader(is, encodedResource.getEncoding())); } else { getPropertiesPersister().load(props, is); } - } + }); int count = registerBeanDefinitions(props, prefix, encodedResource.getResource().getDescription()); if (logger.isDebugEnabled()) { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java index 87fbc1a27512..7a1570c660d3 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java @@ -17,6 +17,7 @@ package org.springframework.beans.factory.support; import java.lang.annotation.Annotation; +import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -64,6 +65,7 @@ * @author Rod Johnson * @author Juergen Hoeller * @author Sam Brannen + * @author Yanming Zhou * @since 06.01.2003 * @see DefaultListableBeanFactory */ @@ -149,6 +151,17 @@ else if (bean instanceof FactoryBean factoryBean) { return (T) bean; } + @Override + @SuppressWarnings("unchecked") + public T getBean(String name, ParameterizedTypeReference typeReference) throws BeansException { + Object bean = getBean(name); + Type requiredType = typeReference.getType(); + if (!ResolvableType.forType(requiredType).isInstance(bean)) { + throw new BeanNotOfRequiredTypeException(name, requiredType, bean.getClass()); + } + return (T) bean; + } + @Override public Object getBean(String name, @Nullable Object @Nullable ... args) throws BeansException { if (!ObjectUtils.isEmpty(args)) { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/xml/XmlBeanDefinitionReader.java b/spring-beans/src/main/java/org/springframework/beans/factory/xml/XmlBeanDefinitionReader.java index b7b4cf63431f..631bb958d9de 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/xml/XmlBeanDefinitionReader.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/xml/XmlBeanDefinitionReader.java @@ -21,6 +21,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import javax.xml.parsers.ParserConfigurationException; @@ -337,12 +338,16 @@ public int loadBeanDefinitions(EncodedResource encodedResource) throws BeanDefin "Detected cyclic loading of " + encodedResource + " - check your import definitions!"); } - try (InputStream inputStream = encodedResource.getResource().getInputStream()) { - InputSource inputSource = new InputSource(inputStream); - if (encodedResource.getEncoding() != null) { - inputSource.setEncoding(encodedResource.getEncoding()); - } - return doLoadBeanDefinitions(inputSource, encodedResource.getResource()); + try { + AtomicInteger count = new AtomicInteger(); + encodedResource.getResource().consumeContent(inputStream -> { + InputSource inputSource = new InputSource(inputStream); + if (encodedResource.getEncoding() != null) { + inputSource.setEncoding(encodedResource.getEncoding()); + } + count.addAndGet(doLoadBeanDefinitions(inputSource, encodedResource.getResource())); + }); + return count.get(); } catch (IOException ex) { throw new BeanDefinitionStoreException( diff --git a/spring-beans/src/main/java/org/springframework/beans/propertyeditors/ClassArrayEditor.java b/spring-beans/src/main/java/org/springframework/beans/propertyeditors/ClassArrayEditor.java index 2224c14d1409..7532425316d9 100644 --- a/spring-beans/src/main/java/org/springframework/beans/propertyeditors/ClassArrayEditor.java +++ b/spring-beans/src/main/java/org/springframework/beans/propertyeditors/ClassArrayEditor.java @@ -84,8 +84,8 @@ public String getAsText() { return ""; } StringJoiner sj = new StringJoiner(","); - for (Class klass : classes) { - sj.add(ClassUtils.getQualifiedName(klass)); + for (Class clazz : classes) { + sj.add(clazz.getTypeName()); } return sj.toString(); } diff --git a/spring-beans/src/main/java/org/springframework/beans/propertyeditors/ClassEditor.java b/spring-beans/src/main/java/org/springframework/beans/propertyeditors/ClassEditor.java index 4d3bfb3de7a6..126f70718b72 100644 --- a/spring-beans/src/main/java/org/springframework/beans/propertyeditors/ClassEditor.java +++ b/spring-beans/src/main/java/org/springframework/beans/propertyeditors/ClassEditor.java @@ -72,12 +72,7 @@ public void setAsText(String text) throws IllegalArgumentException { @Override public String getAsText() { Class clazz = (Class) getValue(); - if (clazz != null) { - return ClassUtils.getQualifiedName(clazz); - } - else { - return ""; - } + return (clazz != null ? clazz.getTypeName() : ""); } } diff --git a/spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanFactoryExtensions.kt b/spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanFactoryExtensions.kt index 608b11a50dc5..3bd4d67bb2ea 100644 --- a/spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanFactoryExtensions.kt +++ b/spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanFactoryExtensions.kt @@ -24,6 +24,7 @@ import org.springframework.core.ResolvableType * This extension is not subject to type erasure and retains actual generic type arguments. * * @author Sebastien Deleuze + * @author Yanming Zhou * @since 5.0 */ inline fun BeanFactory.getBean(): T = @@ -31,14 +32,14 @@ inline fun BeanFactory.getBean(): T = /** * Extension for [BeanFactory.getBean] providing a `getBean("foo")` variant. - * Like the original Java method, this extension is subject to type erasure. + * This extension is not subject to type erasure and retains actual generic type arguments. * * @see BeanFactory.getBean(String, Class) * @author Sebastien Deleuze * @since 5.0 */ inline fun BeanFactory.getBean(name: String): T = - getBean(name, T::class.java) + getBean(name, (object : ParameterizedTypeReference() {})) /** * Extension for [BeanFactory.getBean] providing a `getBean(arg1, arg2)` variant. diff --git a/spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanRegistrarDsl.kt b/spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanRegistrarDsl.kt index aeb68a5be71a..f9eb6242d04a 100644 --- a/spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanRegistrarDsl.kt +++ b/spring-beans/src/main/kotlin/org/springframework/beans/factory/BeanRegistrarDsl.kt @@ -18,8 +18,8 @@ package org.springframework.beans.factory import org.springframework.beans.factory.BeanRegistry.SupplierContext import org.springframework.core.ParameterizedTypeReference -import org.springframework.core.ResolvableType import org.springframework.core.env.Environment +import kotlin.reflect.KClass /** * Contract for registering programmatically beans. @@ -364,6 +364,28 @@ open class BeanRegistrarDsl(private val init: BeanRegistrarDsl.() -> Unit): Bean return registry.registerBean(object: ParameterizedTypeReference() {}, customizer) } + /** + * Determine whether a bean of the given name is already registered. + * @param name the name of the bean + * @since 7.1 + */ + fun containsBean(name: String): Boolean = registry.containsBean(name) + + /** + * Determine whether a bean of the given type is already registered. + * @param beanType the type of the bean + * @since 7.1 + */ + fun containsBean(beanType: KClass<*>): Boolean = registry.containsBean(beanType.java) + + /** + * Determine whether a bean of the given type is already registered. + * @param T the type of the bean + * @since 7.1 + */ + inline fun containsBean(): Boolean = + registry.containsBean(object: ParameterizedTypeReference() {}) + /** * Context available from the bean instance supplier designed to give access diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/DefaultListableBeanFactoryTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/DefaultListableBeanFactoryTests.java index e5e024a9313f..c9831eeaec19 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/DefaultListableBeanFactoryTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/DefaultListableBeanFactoryTests.java @@ -79,6 +79,7 @@ import org.springframework.beans.testfixture.beans.factory.DummyFactory; import org.springframework.core.MethodParameter; import org.springframework.core.Ordered; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.core.annotation.Order; @@ -1682,6 +1683,29 @@ void getBeanByTypeWithAmbiguity() { lbf.getBean(TestBean.class)); } + @Test + void getBeanByNameWithTypeReference() { + RootBeanDefinition bd1 = new RootBeanDefinition(StringTemplate.class); + RootBeanDefinition bd2 = new RootBeanDefinition(NumberTemplate.class); + lbf.registerBeanDefinition("bd1", bd1); + lbf.registerBeanDefinition("bd2", bd2); + + Template stringTemplate = lbf.getBean("bd1", new ParameterizedTypeReference<>() {}); + Template numberTemplate = lbf.getBean("bd2", new ParameterizedTypeReference<>() {}); + + assertThat(stringTemplate).isInstanceOf(StringTemplate.class); + assertThat(numberTemplate).isInstanceOf(NumberTemplate.class); + + assertThatExceptionOfType(BeanNotOfRequiredTypeException.class) + .isThrownBy(() -> lbf.getBean("bd2", new ParameterizedTypeReference>() {})) + .satisfies(ex -> { + assertThat(ex.getBeanName()).isEqualTo("bd2"); + assertThat(ex.getRequiredType()).isEqualTo(Template.class); + assertThat(ex.getActualType()).isEqualTo(NumberTemplate.class); + assertThat(ex.getGenericRequiredType().toString()).endsWith("Template"); + }); + } + @Test void getBeanByTypeWithPrimary() { RootBeanDefinition bd1 = new RootBeanDefinition(TestBean.class); @@ -3872,4 +3896,16 @@ public Class getObjectType() { } } + private static class Template { + + } + + private static class StringTemplate extends Template { + + } + + private static class NumberTemplate extends Template { + + } + } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/ParameterResolutionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/ParameterResolutionTests.java index 8ed67a96fb2c..08ccff1b7597 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/ParameterResolutionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/ParameterResolutionTests.java @@ -45,9 +45,9 @@ class ParameterResolutionTests { @Test void isAutowirablePreconditions() { - assertThatIllegalArgumentException().isThrownBy(() -> - ParameterResolutionDelegate.isAutowirable(null, 0)) - .withMessageContaining("Parameter must not be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> ParameterResolutionDelegate.isAutowirable(null, 0)) + .withMessageContaining("Parameter must not be null"); } @Test @@ -87,29 +87,30 @@ void nonAnnotatedParametersInTopLevelClassConstructorAreNotCandidatesForAutowiri Parameter[] parameters = notAutowirableConstructor.getParameters(); for (int parameterIndex = 0; parameterIndex < parameters.length; parameterIndex++) { Parameter parameter = parameters[parameterIndex]; - assertThat(ParameterResolutionDelegate.isAutowirable(parameter, parameterIndex)).as("Parameter " + parameter + " must not be autowirable").isFalse(); + assertThat(ParameterResolutionDelegate.isAutowirable(parameter, parameterIndex)) + .as("Parameter " + parameter + " must not be autowirable").isFalse(); } } @Test void resolveDependencyPreconditionsForParameter() { assertThatIllegalArgumentException() - .isThrownBy(() -> ParameterResolutionDelegate.resolveDependency(null, 0, null, mock())) - .withMessageContaining("Parameter must not be null"); + .isThrownBy(() -> ParameterResolutionDelegate.resolveDependency(null, 0, null, mock())) + .withMessageContaining("Parameter must not be null"); } @Test void resolveDependencyPreconditionsForContainingClass() { - assertThatIllegalArgumentException().isThrownBy(() -> - ParameterResolutionDelegate.resolveDependency(getParameter(), 0, null, null)) - .withMessageContaining("Containing class must not be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> ParameterResolutionDelegate.resolveDependency(getParameter(), 0, null, null)) + .withMessageContaining("Containing class must not be null"); } @Test void resolveDependencyPreconditionsForBeanFactory() { - assertThatIllegalArgumentException().isThrownBy(() -> - ParameterResolutionDelegate.resolveDependency(getParameter(), 0, getClass(), null)) - .withMessageContaining("AutowireCapableBeanFactory must not be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> ParameterResolutionDelegate.resolveDependency(getParameter(), 0, getClass(), null)) + .withMessageContaining("AutowireCapableBeanFactory must not be null"); } private Parameter getParameter() throws NoSuchMethodException { @@ -133,9 +134,64 @@ void resolveDependencyForAnnotatedParametersInTopLevelClassConstructor() throws parameter, parameterIndex, AutowirableClass.class, beanFactory); assertThat(intermediateDependencyDescriptor.getAnnotatedElement()).isEqualTo(constructor); assertThat(intermediateDependencyDescriptor.getMethodParameter().getParameter()).isEqualTo(parameter); + assertThat(intermediateDependencyDescriptor.usesStandardBeanLookup()).isTrue(); } } + @Test + void resolveDependencyWithCustomParameterNamePreconditionsForParameter() { + assertThatIllegalArgumentException() + .isThrownBy(() -> ParameterResolutionDelegate.resolveDependency(null, 0, "customName", getClass(), mock())) + .withMessageContaining("Parameter must not be null"); + } + + @Test + void resolveDependencyWithCustomParameterNamePreconditionsForContainingClass() { + assertThatIllegalArgumentException() + .isThrownBy(() -> ParameterResolutionDelegate.resolveDependency(getParameter(), 0, "customName", null, mock())) + .withMessageContaining("Containing class must not be null"); + } + + @Test + void resolveDependencyWithCustomParameterNamePreconditionsForBeanFactory() { + assertThatIllegalArgumentException() + .isThrownBy(() -> ParameterResolutionDelegate.resolveDependency(getParameter(), 0, "customName", getClass(), null)) + .withMessageContaining("AutowireCapableBeanFactory must not be null"); + } + + @Test + void resolveDependencyWithNullCustomParameterNameFallsBackToDefaultParameterNameDiscovery() throws Exception { + Constructor constructor = AutowirableClass.class.getConstructor(String.class, String.class, String.class, String.class); + AutowireCapableBeanFactory beanFactory = mock(); + given(beanFactory.resolveDependency(any(), isNull())).willAnswer(invocation -> invocation.getArgument(0)); + + Parameter[] parameters = constructor.getParameters(); + for (int parameterIndex = 0; parameterIndex < parameters.length; parameterIndex++) { + Parameter parameter = parameters[parameterIndex]; + DependencyDescriptor via4ArgMethod = (DependencyDescriptor) ParameterResolutionDelegate.resolveDependency( + parameter, parameterIndex, AutowirableClass.class, beanFactory); + DependencyDescriptor via5ArgMethod = (DependencyDescriptor) ParameterResolutionDelegate.resolveDependency( + parameter, parameterIndex, null, AutowirableClass.class, beanFactory); + assertThat(via5ArgMethod.getDependencyName()).isEqualTo(via4ArgMethod.getDependencyName()); + } + } + + @Test + void resolveDependencyWithCustomParameterName() throws Exception { + Constructor constructor = AutowirableClass.class.getConstructor(String.class, String.class, String.class, String.class); + AutowireCapableBeanFactory beanFactory = mock(); + given(beanFactory.resolveDependency(any(), isNull())).willAnswer(invocation -> invocation.getArgument(0)); + + Parameter parameter = constructor.getParameters()[0]; + DependencyDescriptor descriptor = (DependencyDescriptor) ParameterResolutionDelegate.resolveDependency( + parameter, 0, "customBeanName", AutowirableClass.class, beanFactory); + + assertThat(descriptor.getAnnotatedElement()).isEqualTo(constructor); + assertThat(descriptor.getMethodParameter().getParameter()).isEqualTo(parameter); + assertThat(descriptor.getDependencyName()).isEqualTo("customBeanName"); + assertThat(descriptor.usesStandardBeanLookup()).isTrue(); + } + void autowirableMethod( @Autowired String firstParameter, diff --git a/spring-beans/src/test/kotlin/org/springframework/beans/factory/BeanFactoryExtensionsTests.kt b/spring-beans/src/test/kotlin/org/springframework/beans/factory/BeanFactoryExtensionsTests.kt index fcca43c61f28..750fc095efda 100644 --- a/spring-beans/src/test/kotlin/org/springframework/beans/factory/BeanFactoryExtensionsTests.kt +++ b/spring-beans/src/test/kotlin/org/springframework/beans/factory/BeanFactoryExtensionsTests.kt @@ -21,6 +21,7 @@ import io.mockk.mockk import io.mockk.verify import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test +import org.springframework.core.ParameterizedTypeReference import org.springframework.core.ResolvableType /** @@ -53,7 +54,16 @@ class BeanFactoryExtensionsTests { fun `getBean with String and reified type parameters`() { val name = "foo" bf.getBean(name) - verify { bf.getBean(name, Foo::class.java) } + verify { bf.getBean(name, ofType>()) } + } + + @Test + fun `getBean with String and reified generic type parameters`() { + val name = "foo" + val foo = listOf(Foo()) + every { bf.getBean(name, ofType>>()) } returns foo + assertThat(bf.getBean>("foo")).isSameAs(foo) + verify { bf.getBean(name, ofType>>()) } } @Test diff --git a/spring-context-support/src/main/java/org/springframework/mail/SimpleMailMessage.java b/spring-context-support/src/main/java/org/springframework/mail/SimpleMailMessage.java index e1f8422caa02..3ac352301d8d 100644 --- a/spring-context-support/src/main/java/org/springframework/mail/SimpleMailMessage.java +++ b/spring-context-support/src/main/java/org/springframework/mail/SimpleMailMessage.java @@ -79,7 +79,7 @@ public SimpleMailMessage(SimpleMailMessage original) { this.to = copyOrNull(original.getTo()); this.cc = copyOrNull(original.getCc()); this.bcc = copyOrNull(original.getBcc()); - this.sentDate = original.getSentDate(); + this.sentDate = copyOrNull(original.sentDate); this.subject = original.getSubject(); this.text = original.getText(); } @@ -147,11 +147,11 @@ public void setBcc(String @Nullable ... bcc) { @Override public void setSentDate(@Nullable Date sentDate) { - this.sentDate = sentDate; + this.sentDate = copyOrNull(sentDate); } public @Nullable Date getSentDate() { - return this.sentDate; + return copyOrNull(this.sentDate); } @Override @@ -194,8 +194,8 @@ public void copyTo(MailMessage target) { if (getBcc() != null) { target.setBcc(copy(getBcc())); } - if (getSentDate() != null) { - target.setSentDate(getSentDate()); + if (this.sentDate != null) { + target.setSentDate((Date) this.sentDate.clone()); } if (getSubject() != null) { target.setSubject(getSubject()); @@ -247,6 +247,10 @@ public String toString() { return copy(state); } + private static @Nullable Date copyOrNull(@Nullable Date date) { + return (date != null ? (Date) date.clone() : null); + } + private static String[] copy(String[] state) { return state.clone(); } diff --git a/spring-context-support/src/test/java/org/springframework/mail/SimpleMailMessageTests.java b/spring-context-support/src/test/java/org/springframework/mail/SimpleMailMessageTests.java index 976f8abfbb4f..16e5052bf1e4 100644 --- a/spring-context-support/src/test/java/org/springframework/mail/SimpleMailMessageTests.java +++ b/spring-context-support/src/test/java/org/springframework/mail/SimpleMailMessageTests.java @@ -21,11 +21,16 @@ import java.util.List; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** + * Tests for {@link SimpleMailMessage}. + * * @author Dmitriy Kopylenko * @author Juergen Hoeller * @author Rick Evans @@ -98,6 +103,64 @@ void deepCopyOfStringArrayTypedFieldsOnCopyCtor() { assertThat(copy.getBcc()[0]).isEqualTo("us@mail.org"); } + @Test // gh-36626 + void setSentDateStoresACopy() { + SimpleMailMessage message = new SimpleMailMessage(); + Date sentDate = new Date(1234L); + + message.setSentDate(sentDate); + sentDate.setTime(0L); + + assertThat(message.getSentDate()).isEqualTo(new Date(1234L)); + } + + @Test // gh-36626 + void getSentDateReturnsACopy() { + SimpleMailMessage message = new SimpleMailMessage(); + Date sentDate = new Date(1234L); + message.setSentDate(sentDate); + + Date exportedDate = message.getSentDate(); + exportedDate.setTime(0L); + + assertThat(message.getSentDate()).isEqualTo(new Date(1234L)); + } + + @Test // gh-36626 + void copyConstructorCopiesSentDate() { + Date sentDate = new Date(1234L); + SimpleMailMessage original = new SimpleMailMessage(); + original.setSentDate(sentDate); + + SimpleMailMessage copy = new SimpleMailMessage(original); + sentDate.setTime(0L); + + Date copiedDate = copy.getSentDate(); + assertThat(copiedDate).isNotNull(); + copiedDate.setTime(1L); + + assertThat(original.getSentDate()).isEqualTo(new Date(1234L)); + assertThat(copy.getSentDate()).isEqualTo(new Date(1234L)); + } + + @Test // gh-36626 + void copyToCopiesSentDate() { + SimpleMailMessage source = new SimpleMailMessage(); + source.setSentDate(new Date(1234L)); + + MailMessage target = mock(); + source.copyTo(target); + + ArgumentCaptor dateCaptor = ArgumentCaptor.forClass(Date.class); + verify(target).setSentDate(dateCaptor.capture()); + + Date copiedDate = dateCaptor.getValue(); + assertThat(copiedDate).isNotNull(); + copiedDate.setTime(0L); + + assertThat(source.getSentDate()).isEqualTo(new Date(1234L)); + } + /** * Tests that two equal SimpleMailMessages have equal hash codes. */ diff --git a/spring-context/src/main/java/org/springframework/cache/interceptor/CacheAspectSupport.java b/spring-context/src/main/java/org/springframework/cache/interceptor/CacheAspectSupport.java index 36ce7f8802e8..2b2bd0110ad7 100644 --- a/spring-context/src/main/java/org/springframework/cache/interceptor/CacheAspectSupport.java +++ b/spring-context/src/main/java/org/springframework/cache/interceptor/CacheAspectSupport.java @@ -291,23 +291,6 @@ public void afterSingletonsInstantiated() { this.initialized = true; } - - /** - * Convenience method to return a String representation of this Method - * for use in logging. Can be overridden in subclasses to provide a - * different identifier for the given method. - * @param method the method we're interested in - * @param targetClass class the method is on - * @return log message identifying this method - * @see org.springframework.util.ClassUtils#getQualifiedMethodName - * @deprecated since 6.2.18 with no replacement, for removal in 7.1 - */ - @Deprecated(since = "6.2.18", forRemoval = true) - protected String methodIdentification(Method method, Class targetClass) { - Method specificMethod = ClassUtils.getMostSpecificMethod(method, targetClass); - return ClassUtils.getQualifiedMethodName(specificMethod); - } - protected Collection getCaches( CacheOperationInvocationContext context, CacheResolver cacheResolver) { diff --git a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassBeanDefinitionReader.java b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassBeanDefinitionReader.java index 56ceb8f5ab55..a5f64e71afea 100644 --- a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassBeanDefinitionReader.java +++ b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassBeanDefinitionReader.java @@ -423,12 +423,14 @@ private void loadBeanDefinitionsFromImportBeanDefinitionRegistrars( } private void loadBeanDefinitionsFromBeanRegistrars(MultiValueMap registrars) { - if (!(this.registry instanceof ListableBeanFactory beanFactory)) { - throw new IllegalStateException("Cannot support bean registrars since " + - this.registry.getClass().getName() + " does not implement ListableBeanFactory"); - } - registrars.values().forEach(registrarList -> registrarList.forEach(registrar -> registrar.register(new BeanRegistryAdapter( - this.registry, beanFactory, this.environment, registrar.getClass()), this.environment))); + registrars.values().forEach(registrarList -> registrarList.forEach(registrar -> { + if (!(this.registry instanceof ListableBeanFactory beanFactory)) { + throw new IllegalStateException("Cannot support bean registrars since " + + this.registry.getClass().getName() + " does not implement ListableBeanFactory"); + } + registrar.register(new BeanRegistryAdapter( + this.registry, beanFactory, this.environment, registrar.getClass()), this.environment); + })); } diff --git a/spring-context/src/main/java/org/springframework/context/support/AbstractApplicationContext.java b/spring-context/src/main/java/org/springframework/context/support/AbstractApplicationContext.java index eb1fbc194340..65b2f1ad3b02 100644 --- a/spring-context/src/main/java/org/springframework/context/support/AbstractApplicationContext.java +++ b/spring-context/src/main/java/org/springframework/context/support/AbstractApplicationContext.java @@ -132,6 +132,7 @@ * @author Sam Brannen * @author Sebastien Deleuze * @author Brian Clozel + * @author Yanming Zhou * @since January 21, 2001 * @see #refreshBeanFactory * @see #getBeanFactory @@ -1305,6 +1306,12 @@ public T getBean(String name, Class requiredType) throws BeansException { return getBeanFactory().getBean(name, requiredType); } + @Override + public T getBean(String name, ParameterizedTypeReference typeReference) throws BeansException { + assertBeanFactoryActive(); + return getBeanFactory().getBean(name, typeReference); + } + @Override public Object getBean(String name, @Nullable Object @Nullable ... args) throws BeansException { assertBeanFactoryActive(); diff --git a/spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java b/spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java index 9a75057a7f2d..fc4b08ccb067 100644 --- a/spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java +++ b/spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java @@ -44,6 +44,8 @@ import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.ApplicationContext; +import org.springframework.core.Ordered; +import org.springframework.core.PriorityOrdered; import org.springframework.core.io.ProtocolResolver; import org.springframework.core.io.Resource; import org.springframework.core.io.ResourceLoader; @@ -105,6 +107,10 @@ */ public class GenericApplicationContext extends AbstractApplicationContext implements BeanDefinitionRegistry { + private static final String DEFERRED_REGISTRY_POST_PROCESSOR_BEAN_NAME = + GenericApplicationContext.class.getName() + ".deferredRegistryPostProcessor"; + + private final DefaultListableBeanFactory beanFactory; private @Nullable ResourceLoader resourceLoader; @@ -604,7 +610,13 @@ public void registerBean(@Nullable String beanName, Class beanClass, */ public void register(BeanRegistrar... registrars) { for (BeanRegistrar registrar : registrars) { - new BeanRegistryAdapter(this.beanFactory, getEnvironment(), registrar.getClass()).register(registrar); + DeferredRegistryPostProcessor pp = (DeferredRegistryPostProcessor) + this.beanFactory.getSingleton(DEFERRED_REGISTRY_POST_PROCESSOR_BEAN_NAME); + if (pp == null) { + pp = new DeferredRegistryPostProcessor(); + this.beanFactory.registerSingleton(DEFERRED_REGISTRY_POST_PROCESSOR_BEAN_NAME, pp); + } + pp.addRegistrar(registrar); } } @@ -648,4 +660,31 @@ public RootBeanDefinition cloneBeanDefinition() { } } + + /** + * Internal post-processor for invoking DeferredBeanRegistrars at the end + * of the BeanDefinitionRegistryPostProcessor PriorityOrdered phase, + * right before a potential ConfigurationClassPostProcessor. + */ + private class DeferredRegistryPostProcessor implements BeanDefinitionRegistryPostProcessor, PriorityOrdered { + + private final List registrars = new ArrayList<>(); + + public void addRegistrar(BeanRegistrar registrar) { + this.registrars.add(registrar); + } + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE - 1; // within PriorityOrdered, 1 before ConfigurationClassPostProcessor + } + + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { + for (BeanRegistrar registrar : this.registrars) { + new BeanRegistryAdapter(beanFactory, getEnvironment(), registrar.getClass()).register(registrar); + } + } + } + } diff --git a/spring-context/src/main/java/org/springframework/context/support/ReloadableResourceBundleMessageSource.java b/spring-context/src/main/java/org/springframework/context/support/ReloadableResourceBundleMessageSource.java index 851aeed9da9c..305852097ccd 100644 --- a/spring-context/src/main/java/org/springframework/context/support/ReloadableResourceBundleMessageSource.java +++ b/spring-context/src/main/java/org/springframework/context/support/ReloadableResourceBundleMessageSource.java @@ -60,9 +60,13 @@ * are treated in a slightly different fashion than the "basenames" property of * {@link ResourceBundleMessageSource}. It follows the basic ResourceBundle rule of not * specifying file extension or language codes, but can refer to any Spring resource - * location (instead of being restricted to classpath resources). With a "classpath:" - * prefix, resources can still be loaded from the classpath, but "cacheSeconds" values - * other than "-1" (caching forever) might not work reliably in this case. + * location (instead of being restricted to classpath resources). + * + *

With a "classpath:" prefix, resources can still be loaded from the classpath, + * but "cacheSeconds" values other than "-1" (caching forever) are not expected to + * be effective in this case. As of 7.1, a "classpath*:" prefix is accepted as well, + * loading all classpath resources of the same fully-qualified name: for example, + * "classpath*:/messages.properties" or "classpath*:META-INF/messages.properties". * *

For a typical web application, message files could be placed in {@code WEB-INF}: * for example, a "WEB-INF/messages" basename would find a "WEB-INF/messages.properties", @@ -562,8 +566,8 @@ protected PropertiesHolder refreshProperties(String filename, @Nullable Properti */ protected Properties loadProperties(Resource resource, String filename) throws IOException { Properties props = newProperties(); - try (InputStream inputStream = resource.getInputStream()) { - String resourceFilename = resource.getFilename(); + String resourceFilename = resource.getFilename(); + resource.consumeContent(inputStream -> { if (resourceFilename != null && resourceFilename.endsWith(XML_EXTENSION)) { if (logger.isDebugEnabled()) { logger.debug("Loading properties [" + resource.getFilename() + "]"); @@ -594,8 +598,8 @@ protected Properties loadProperties(Resource resource, String filename) throws I this.propertiesPersister.load(props, inputStream); } } - return props; - } + }); + return props; } /** diff --git a/spring-context/src/main/java/org/springframework/jndi/support/SimpleJndiBeanFactory.java b/spring-context/src/main/java/org/springframework/jndi/support/SimpleJndiBeanFactory.java index cd6f5feb498e..224df23e114d 100644 --- a/spring-context/src/main/java/org/springframework/jndi/support/SimpleJndiBeanFactory.java +++ b/spring-context/src/main/java/org/springframework/jndi/support/SimpleJndiBeanFactory.java @@ -16,6 +16,7 @@ package org.springframework.jndi.support; +import java.lang.reflect.Type; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -59,6 +60,7 @@ * in particular if BeanFactory-style type checking is required. * * @author Juergen Hoeller + * @author Yanming Zhou * @since 2.5 * @see org.springframework.beans.factory.support.DefaultListableBeanFactory * @see org.springframework.context.annotation.CommonAnnotationBeanPostProcessor @@ -132,6 +134,17 @@ public T getBean(String name, Class requiredType) throws BeansException { } } + @Override + @SuppressWarnings("unchecked") + public T getBean(String name, ParameterizedTypeReference typeReference) throws BeansException { + Object bean = getBean(name); + Type requiredType = typeReference.getType(); + if (!ResolvableType.forType(requiredType).isInstance(bean)) { + throw new BeanNotOfRequiredTypeException(name, requiredType, bean.getClass()); + } + return (T) bean; + } + @Override public Object getBean(String name, @Nullable Object @Nullable ... args) throws BeansException { if (args != null) { diff --git a/spring-context/src/test/java/org/springframework/context/annotation/beanregistrar/BeanRegistrarConfigurationTests.java b/spring-context/src/test/java/org/springframework/context/annotation/beanregistrar/BeanRegistrarConfigurationTests.java index 04f40c8c0029..c0ffebbe88fe 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/beanregistrar/BeanRegistrarConfigurationTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/beanregistrar/BeanRegistrarConfigurationTests.java @@ -22,8 +22,10 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.testfixture.beans.factory.BarRegistrar; +import org.springframework.context.testfixture.beans.factory.ConditionalBeanRegistrar; import org.springframework.context.testfixture.beans.factory.FooRegistrar; import org.springframework.context.testfixture.beans.factory.GenericBeanRegistrar; import org.springframework.context.testfixture.beans.factory.ImportAwareBeanRegistrar; @@ -32,17 +34,28 @@ import org.springframework.context.testfixture.beans.factory.SampleBeanRegistrar.Foo; import org.springframework.context.testfixture.beans.factory.SampleBeanRegistrar.Init; import org.springframework.context.testfixture.context.annotation.registrar.BeanRegistrarConfiguration; +import org.springframework.context.testfixture.context.annotation.registrar.ComponentBeanRegistrar; +import org.springframework.context.testfixture.context.annotation.registrar.ComponentBeanRegistrar.IgnoredFromComponent; +import org.springframework.context.testfixture.context.annotation.registrar.ConditionalBeanRegistrarConfiguration; +import org.springframework.context.testfixture.context.annotation.registrar.ConfigurationBeanRegistrar; +import org.springframework.context.testfixture.context.annotation.registrar.ConfigurationBeanRegistrar.BeanBeanRegistrar; +import org.springframework.context.testfixture.context.annotation.registrar.ConfigurationBeanRegistrar.IgnoredFromBean; +import org.springframework.context.testfixture.context.annotation.registrar.ConfigurationBeanRegistrar.IgnoredFromConfiguration; import org.springframework.context.testfixture.context.annotation.registrar.GenericBeanRegistrarConfiguration; import org.springframework.context.testfixture.context.annotation.registrar.ImportAwareBeanRegistrarConfiguration; import org.springframework.context.testfixture.context.annotation.registrar.MultipleBeanRegistrarsConfiguration; +import org.springframework.context.testfixture.context.annotation.registrar.TestBeanConfiguration; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Tests for {@link BeanRegistrar} imported by @{@link org.springframework.context.annotation.Configuration}. * * @author Sebastien Deleuze + * @author Stephane Nicoll */ class BeanRegistrarConfigurationTests { @@ -59,6 +72,36 @@ void beanRegistrar() { assertThat(beanDefinition.getDescription()).isEqualTo("Custom description"); } + @Test + void beanRegistrarIgnoreBeans() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(ConfigurationBeanRegistrar.class); + assertThatNoException().isThrownBy(() -> context.getBean(ConfigurationBeanRegistrar.class)); + assertThatNoException().isThrownBy(() -> context.getBean(BeanBeanRegistrar.class)); + assertThatExceptionOfType(NoSuchBeanDefinitionException.class) + .isThrownBy(() -> context.getBean(IgnoredFromConfiguration.class)); + assertThatExceptionOfType(NoSuchBeanDefinitionException.class) + .isThrownBy(() -> context.getBean(IgnoredFromBean.class)); + } + + @Test + void beanRegistrarWithClasspathScanningIgnoreBeans() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.scan("org.springframework.context.testfixture.context.annotation.registrar"); + context.refresh(); + + assertThatNoException().isThrownBy(() -> context.getBean(ConfigurationBeanRegistrar.class)); + assertThatExceptionOfType(NoSuchBeanDefinitionException.class) + .isThrownBy(() -> context.getBean(IgnoredFromConfiguration.class)); + + assertThatNoException().isThrownBy(() -> context.getBean(BeanBeanRegistrar.class)); + assertThatExceptionOfType(NoSuchBeanDefinitionException.class) + .isThrownBy(() -> context.getBean(IgnoredFromBean.class)); + + assertThatNoException().isThrownBy(() -> context.getBean(ComponentBeanRegistrar.class)); + assertThatExceptionOfType(NoSuchBeanDefinitionException.class) + .isThrownBy(() -> context.getBean(IgnoredFromComponent.class)); + } + @Test void beanRegistrarWithProfile() { AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); @@ -105,4 +148,42 @@ void multipleBeanRegistrars() { assertThat(context.getBean(BarRegistrar.Bar.class)).isNotNull(); } + @Test + void programmaticBeanRegistrarIsInvokedBeforeConfigurationClassPostProcessor() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(TestBeanConfiguration.class); + context.register(new ConditionalBeanRegistrar()); + context.refresh(); + assertThat(context.containsBean("myTestBean")).isFalse(); + } + + @Test + void programmaticBeanRegistrarHandlesProgrammaticRegisteredBean() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(new ConditionalBeanRegistrar()); + context.registerBean("testBean", TestBean.class); + context.refresh(); + assertThat(context.containsBean("myTestBean")).isTrue(); + assertThat(context.getBean("myTestBean")).isInstanceOf(TestBean.class); + } + + @Test + void importedBeanRegistrarWithConditionNotMet() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(ConditionalBeanRegistrarConfiguration.class); + context.register(TestBeanConfiguration.class); + context.refresh(); + assertThat(context.containsBean("myTestBean")).isFalse(); + } + + @Test + void importedBeanRegistrarWithConditionMet() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(TestBeanConfiguration.class); + context.register(ConditionalBeanRegistrarConfiguration.class); + context.refresh(); + assertThat(context.containsBean("myTestBean")).isTrue(); + assertThat(context.getBean("myTestBean")).isInstanceOf(TestBean.class); + } + } diff --git a/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java b/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java index 0694ad4462fe..3a39f0620c0e 100644 --- a/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java +++ b/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java @@ -44,8 +44,10 @@ import org.springframework.beans.factory.support.GenericBeanDefinition; import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor; import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; +import org.springframework.context.testfixture.beans.factory.ConditionalBeanRegistrar; import org.springframework.context.testfixture.beans.factory.ImportAwareBeanRegistrar; import org.springframework.context.testfixture.beans.factory.SampleBeanRegistrar; import org.springframework.core.DecoratingProxy; @@ -645,6 +647,24 @@ void importAwareBeanRegistrar() { assertThat(context.getBean(ImportAwareBeanRegistrar.ClassNameHolder.class).className()).isNull(); } + @Test + void beanRegistrarWithConditionNotMet() { + GenericApplicationContext context = new GenericApplicationContext(); + context.register(new ConditionalBeanRegistrar()); + context.refresh(); + assertThat(context.containsBean("myTestBean")).isFalse(); + } + + @Test + void beanRegistrarWithConditionMet() { + GenericApplicationContext context = new GenericApplicationContext(); + context.register(new ConditionalBeanRegistrar()); + context.registerBean("testBean", TestBean.class); + context.refresh(); + assertThat(context.containsBean("myTestBean")).isTrue(); + assertThat(context.getBean("myTestBean")).isInstanceOf(TestBean.class); + } + private MergedBeanDefinitionPostProcessor registerMockMergedBeanDefinitionPostProcessor(GenericApplicationContext context) { MergedBeanDefinitionPostProcessor bpp = mock(); diff --git a/spring-context/src/test/kotlin/org/springframework/context/annotation/BeanRegistrarDslConfigurationTests.kt b/spring-context/src/test/kotlin/org/springframework/context/annotation/BeanRegistrarDslConfigurationTests.kt index 120e2c229ed8..f00ec5345629 100644 --- a/spring-context/src/test/kotlin/org/springframework/context/annotation/BeanRegistrarDslConfigurationTests.kt +++ b/spring-context/src/test/kotlin/org/springframework/context/annotation/BeanRegistrarDslConfigurationTests.kt @@ -84,7 +84,13 @@ class BeanRegistrarDslConfigurationTests { assertThat(context.getBeanProvider().singleOrNull()).isNotNull } + @Test + fun containsBean() { + AnnotationConfigApplicationContext(ContainsBeanRegistrarKotlinConfiguration::class.java) + } + class Foo + data class Bar(val foo: Foo) data class Baz(val message: String = "") class Init : InitializingBean { @@ -145,4 +151,18 @@ class BeanRegistrarDslConfigurationTests { private class ChainedBeanRegistrar : BeanRegistrarDsl({ register(SampleBeanRegistrar()) }) + + @Configuration + @Import(ContainsBeanRegistrar::class) + internal class ContainsBeanRegistrarKotlinConfiguration + + private class ContainsBeanRegistrar : BeanRegistrarDsl({ + assertThat(containsBean("foo")).isFalse() + assertThat(containsBean(Foo::class)).isFalse() + assertThat(containsBean()).isFalse() + registerBean("foo") + assertThat(containsBean("foo")).isTrue() + assertThat(containsBean(Foo::class)).isTrue() + assertThat(containsBean()).isTrue() + }) } diff --git a/spring-context/src/testFixtures/java/org/springframework/context/testfixture/beans/factory/ConditionalBeanRegistrar.java b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/beans/factory/ConditionalBeanRegistrar.java new file mode 100644 index 000000000000..4aff96075ca9 --- /dev/null +++ b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/beans/factory/ConditionalBeanRegistrar.java @@ -0,0 +1,36 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.testfixture.beans.factory; + +import org.springframework.beans.factory.BeanRegistrar; +import org.springframework.beans.factory.BeanRegistry; +import org.springframework.beans.testfixture.beans.TestBean; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.env.Environment; + +public class ConditionalBeanRegistrar implements BeanRegistrar { + + @Override + public void register(BeanRegistry registry, Environment env) { + if (registry.containsBean("testBean") && + registry.containsBean(TestBean.class) && + registry.containsBean(new ParameterizedTypeReference>() { + })) { + registry.registerBean("myTestBean", TestBean.class); + } + } +} diff --git a/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/annotation/registrar/ComponentBeanRegistrar.java b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/annotation/registrar/ComponentBeanRegistrar.java new file mode 100644 index 000000000000..7f6ba15a13cd --- /dev/null +++ b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/annotation/registrar/ComponentBeanRegistrar.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.testfixture.context.annotation.registrar; + +import org.springframework.beans.factory.BeanRegistrar; +import org.springframework.beans.factory.BeanRegistry; +import org.springframework.core.env.Environment; +import org.springframework.stereotype.Component; + +@Component +public class ComponentBeanRegistrar implements BeanRegistrar { + + @Override + public void register(BeanRegistry registry, Environment env) { + registry.registerBean(IgnoredFromComponent.class); + } + + + public record IgnoredFromComponent() {} + +} diff --git a/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/annotation/registrar/ConditionalBeanRegistrarConfiguration.java b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/annotation/registrar/ConditionalBeanRegistrarConfiguration.java new file mode 100644 index 000000000000..892fd0da1072 --- /dev/null +++ b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/annotation/registrar/ConditionalBeanRegistrarConfiguration.java @@ -0,0 +1,26 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.testfixture.context.annotation.registrar; + +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.context.testfixture.beans.factory.ConditionalBeanRegistrar; + +@Configuration +@Import(ConditionalBeanRegistrar.class) +public class ConditionalBeanRegistrarConfiguration { +} diff --git a/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/annotation/registrar/ConfigurationBeanRegistrar.java b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/annotation/registrar/ConfigurationBeanRegistrar.java new file mode 100644 index 000000000000..5093e43390fb --- /dev/null +++ b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/annotation/registrar/ConfigurationBeanRegistrar.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.testfixture.context.annotation.registrar; + +import org.springframework.beans.factory.BeanRegistrar; +import org.springframework.beans.factory.BeanRegistry; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.env.Environment; + +@Configuration +public class ConfigurationBeanRegistrar implements BeanRegistrar { + + @Override + public void register(BeanRegistry registry, Environment env) { + registry.registerBean(IgnoredFromConfiguration.class); + } + + @Bean + BeanBeanRegistrar beanBeanRegistrar() { + return new BeanBeanRegistrar(); + } + + public static class BeanBeanRegistrar implements BeanRegistrar { + @Override + public void register(BeanRegistry registry, Environment env) { + registry.registerBean(IgnoredFromBean.class); + } + } + + + public record IgnoredFromConfiguration() {} + + public record IgnoredFromBean() {} +} diff --git a/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/annotation/registrar/TestBeanConfiguration.java b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/annotation/registrar/TestBeanConfiguration.java new file mode 100644 index 000000000000..c786f7f4fea6 --- /dev/null +++ b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/annotation/registrar/TestBeanConfiguration.java @@ -0,0 +1,31 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.testfixture.context.annotation.registrar; + +import org.springframework.beans.testfixture.beans.TestBean; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class TestBeanConfiguration { + + @Bean + public TestBean testBean() { + return new TestBean(); + } + +} diff --git a/spring-core-test/src/main/java/org/springframework/aot/agent/MethodReference.java b/spring-core-test/src/main/java/org/springframework/aot/agent/MethodReference.java index 0ac867a17791..966f2297c782 100644 --- a/spring-core-test/src/main/java/org/springframework/aot/agent/MethodReference.java +++ b/spring-core-test/src/main/java/org/springframework/aot/agent/MethodReference.java @@ -20,6 +20,8 @@ import org.jspecify.annotations.Nullable; +import org.springframework.util.ClassUtils; + /** * Reference to a Java method, identified by its owner class and the method name. * @@ -43,7 +45,7 @@ private MethodReference(String className, String methodName) { } public static MethodReference of(Class klass, String methodName) { - return new MethodReference(klass.getCanonicalName(), methodName); + return new MethodReference(ClassUtils.getCanonicalName(klass), methodName); } /** diff --git a/spring-core-test/src/main/java/org/springframework/aot/agent/RecordedInvocation.java b/spring-core-test/src/main/java/org/springframework/aot/agent/RecordedInvocation.java index d56f98d4bffc..19c6ad45e59a 100644 --- a/spring-core-test/src/main/java/org/springframework/aot/agent/RecordedInvocation.java +++ b/spring-core-test/src/main/java/org/springframework/aot/agent/RecordedInvocation.java @@ -25,6 +25,7 @@ import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; /** * Record of an invocation of a method relevant to {@link org.springframework.aot.hint.RuntimeHints}. @@ -181,7 +182,7 @@ public String toString() { else { Class instanceType = (getInstance() instanceof Class clazz) ? clazz : getInstance().getClass(); return "<%s> invocation of <%s> on type <%s> with arguments %s".formatted( - getHintType().hintClassName(), getMethodReference(), instanceType.getCanonicalName(), getArguments()); + getHintType().hintClassName(), getMethodReference(), ClassUtils.getCanonicalName(instanceType), getArguments()); } } diff --git a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java index 568d94b6c03f..020a00064c98 100644 --- a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java +++ b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java @@ -134,9 +134,10 @@ public static Publisher invokeSuspendingFunction( Object arg = args[index]; if (!(parameter.isOptional() && arg == null)) { KType type = parameter.getType(); - if (!type.isMarkedNullable() && + if (!(type.isMarkedNullable() && arg == null) && type.getClassifier() instanceof KClass kClass && - KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass))) { + KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass)) && + !JvmClassMappingKt.getJavaClass(kClass).isInstance(arg)) { arg = box(kClass, arg); } argMap.put(parameter, arg); @@ -166,9 +167,10 @@ public static Publisher invokeSuspendingFunction( private static Object box(KClass kClass, @Nullable Object arg) { KFunction constructor = Objects.requireNonNull(KClasses.getPrimaryConstructor(kClass)); KType type = constructor.getParameters().get(0).getType(); - if (!type.isMarkedNullable() && + if (!(type.isMarkedNullable() && arg == null) && type.getClassifier() instanceof KClass parameterClass && - KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(parameterClass))) { + KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(parameterClass)) && + !JvmClassMappingKt.getJavaClass(parameterClass).isInstance(arg)) { arg = box(parameterClass, arg); } if (!KCallablesJvm.isAccessible(constructor)) { diff --git a/spring-core/src/main/java/org/springframework/core/GenericTypeResolver.java b/spring-core/src/main/java/org/springframework/core/GenericTypeResolver.java index 6ea0eae822cc..083964732733 100644 --- a/spring-core/src/main/java/org/springframework/core/GenericTypeResolver.java +++ b/spring-core/src/main/java/org/springframework/core/GenericTypeResolver.java @@ -160,6 +160,10 @@ public static Type resolveType(Type genericType, @Nullable Class contextClass resolvedTypeVariable = ResolvableType.forVariableBounds(typeVariable); } if (resolvedTypeVariable != ResolvableType.NONE) { + Type type = resolvedTypeVariable.getType(); + if (type instanceof ParameterizedType) { + return resolveType(type, contextClass); + } Class resolved = resolvedTypeVariable.resolve(); if (resolved != null) { return resolved; diff --git a/spring-core/src/main/java/org/springframework/core/ResolvableType.java b/spring-core/src/main/java/org/springframework/core/ResolvableType.java index af01421dc27f..9273abd4d11f 100644 --- a/spring-core/src/main/java/org/springframework/core/ResolvableType.java +++ b/spring-core/src/main/java/org/springframework/core/ResolvableType.java @@ -22,6 +22,7 @@ import java.lang.reflect.Field; import java.lang.reflect.GenericArrayType; import java.lang.reflect.Method; +import java.lang.reflect.Parameter; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.lang.reflect.TypeVariable; @@ -1152,7 +1153,6 @@ public static ResolvableType forClass(Class baseType, Class implementation * @see #forClassWithGenerics(Class, ResolvableType...) */ public static ResolvableType forClassWithGenerics(Class clazz, Class... generics) { - Assert.notNull(clazz, "Class must not be null"); Assert.notNull(generics, "Generics array must not be null"); ResolvableType[] resolvableGenerics = new ResolvableType[generics.length]; for (int i = 0; i < generics.length; i++) { @@ -1281,6 +1281,18 @@ public static ResolvableType forField(Field field, int nestingLevel, @Nullable C return forType(null, new FieldTypeProvider(field), owner.asVariableResolver()).getNested(nestingLevel); } + /** + * Return a {@code ResolvableType} for the specified {@link Parameter}. + *

This is a convenience factory method for scenarios where a {@code Parameter} + * descriptor is already available. + * @param parameter the source parameter + * @return a {@code ResolvableType} for the specified parameter + * @since 7.1 + */ + public static ResolvableType forParameter(Parameter parameter) { + return forMethodParameter(MethodParameter.forParameter(parameter)); + } + /** * Return a {@code ResolvableType} for the specified {@link Constructor} parameter. * @param constructor the source constructor (must not be {@code null}) @@ -1289,7 +1301,6 @@ public static ResolvableType forField(Field field, int nestingLevel, @Nullable C * @see #forConstructorParameter(Constructor, int, Class) */ public static ResolvableType forConstructorParameter(Constructor constructor, int parameterIndex) { - Assert.notNull(constructor, "Constructor must not be null"); return forMethodParameter(new MethodParameter(constructor, parameterIndex)); } @@ -1307,7 +1318,6 @@ public static ResolvableType forConstructorParameter(Constructor constructor, public static ResolvableType forConstructorParameter(Constructor constructor, int parameterIndex, Class implementationClass) { - Assert.notNull(constructor, "Constructor must not be null"); MethodParameter methodParameter = new MethodParameter(constructor, parameterIndex, implementationClass); return forMethodParameter(methodParameter); } @@ -1319,7 +1329,6 @@ public static ResolvableType forConstructorParameter(Constructor constructor, * @see #forMethodReturnType(Method, Class) */ public static ResolvableType forMethodReturnType(Method method) { - Assert.notNull(method, "Method must not be null"); return forMethodParameter(new MethodParameter(method, -1)); } @@ -1333,7 +1342,6 @@ public static ResolvableType forMethodReturnType(Method method) { * @see #forMethodReturnType(Method) */ public static ResolvableType forMethodReturnType(Method method, Class implementationClass) { - Assert.notNull(method, "Method must not be null"); MethodParameter methodParameter = new MethodParameter(method, -1, implementationClass); return forMethodParameter(methodParameter); } @@ -1347,7 +1355,6 @@ public static ResolvableType forMethodReturnType(Method method, Class impleme * @see #forMethodParameter(MethodParameter) */ public static ResolvableType forMethodParameter(Method method, int parameterIndex) { - Assert.notNull(method, "Method must not be null"); return forMethodParameter(new MethodParameter(method, parameterIndex)); } @@ -1363,7 +1370,6 @@ public static ResolvableType forMethodParameter(Method method, int parameterInde * @see #forMethodParameter(MethodParameter) */ public static ResolvableType forMethodParameter(Method method, int parameterIndex, Class implementationClass) { - Assert.notNull(method, "Method must not be null"); MethodParameter methodParameter = new MethodParameter(method, parameterIndex, implementationClass); return forMethodParameter(methodParameter); } diff --git a/spring-core/src/main/java/org/springframework/core/annotation/AbstractMergedAnnotation.java b/spring-core/src/main/java/org/springframework/core/annotation/AbstractMergedAnnotation.java index 0c844b10b9a5..7f2e090147c9 100644 --- a/spring-core/src/main/java/org/springframework/core/annotation/AbstractMergedAnnotation.java +++ b/spring-core/src/main/java/org/springframework/core/annotation/AbstractMergedAnnotation.java @@ -24,6 +24,7 @@ import org.jspecify.annotations.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; /** * Abstract base class for {@link MergedAnnotation} implementations. @@ -215,7 +216,7 @@ private T getRequiredAttributeValue(String attributeName, Class type) { T value = getAttributeValue(attributeName, type); if (value == null) { throw new NoSuchElementException("No attribute named '" + attributeName + - "' present in merged annotation " + getType().getName()); + "' present in merged annotation " + ClassUtils.getCanonicalName(getType())); } return value; } diff --git a/spring-core/src/main/java/org/springframework/core/annotation/AnnotationAttributes.java b/spring-core/src/main/java/org/springframework/core/annotation/AnnotationAttributes.java index a4be0aef1dd6..82941489e6d6 100644 --- a/spring-core/src/main/java/org/springframework/core/annotation/AnnotationAttributes.java +++ b/spring-core/src/main/java/org/springframework/core/annotation/AnnotationAttributes.java @@ -25,6 +25,7 @@ import org.jspecify.annotations.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; /** @@ -126,7 +127,7 @@ public AnnotationAttributes(Class annotationType) { AnnotationAttributes(Class annotationType, boolean validated) { Assert.notNull(annotationType, "'annotationType' must not be null"); this.annotationType = annotationType; - this.displayName = annotationType.getName(); + this.displayName = ClassUtils.getCanonicalName(annotationType); this.validated = validated; } diff --git a/spring-core/src/main/java/org/springframework/core/annotation/AnnotationTypeMapping.java b/spring-core/src/main/java/org/springframework/core/annotation/AnnotationTypeMapping.java index 19f123e6ff10..563341654e1b 100644 --- a/spring-core/src/main/java/org/springframework/core/annotation/AnnotationTypeMapping.java +++ b/spring-core/src/main/java/org/springframework/core/annotation/AnnotationTypeMapping.java @@ -32,6 +32,7 @@ import org.jspecify.annotations.Nullable; import org.springframework.core.annotation.AnnotationTypeMapping.MirrorSets.MirrorSet; +import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; @@ -649,7 +650,7 @@ int resolve(@Nullable Object source, @Nullable A annotation, ValueExtractor throw new AnnotationConfigurationException(String.format( "Different @AliasFor mirror values for annotation [%s]%s; attribute '%s' " + "and its alias '%s' are declared with values of [%s] and [%s].", - getAnnotationType().getName(), on, + ClassUtils.getCanonicalName(getAnnotationType()), on, attributes.get(result).getName(), attribute.getName(), ObjectUtils.nullSafeToString(lastValue), diff --git a/spring-core/src/main/java/org/springframework/core/annotation/AnnotationTypeMappings.java b/spring-core/src/main/java/org/springframework/core/annotation/AnnotationTypeMappings.java index 7c9f595b2300..7f1d396badf7 100644 --- a/spring-core/src/main/java/org/springframework/core/annotation/AnnotationTypeMappings.java +++ b/spring-core/src/main/java/org/springframework/core/annotation/AnnotationTypeMappings.java @@ -28,6 +28,7 @@ import org.jspecify.annotations.Nullable; import org.springframework.lang.Contract; +import org.springframework.util.ClassUtils; import org.springframework.util.ConcurrentReferenceHashMap; /** @@ -124,7 +125,8 @@ private void addIfPossible(Deque queue, @Nullable Annotat AnnotationUtils.rethrowAnnotationConfigurationException(ex); if (failureLogger.isEnabled()) { failureLogger.log("Failed to introspect " + (meta ? "meta-annotation @" : "annotation @") + - annotationType.getName(), (source != null ? source.getAnnotationType() : null), ex); + ClassUtils.getCanonicalName(annotationType), + (source != null ? ClassUtils.getCanonicalName(source.getAnnotationType()) : null), ex); } } } diff --git a/spring-core/src/main/java/org/springframework/core/annotation/AttributeMethods.java b/spring-core/src/main/java/org/springframework/core/annotation/AttributeMethods.java index 0e998b84b1a0..44007afb4fe9 100644 --- a/spring-core/src/main/java/org/springframework/core/annotation/AttributeMethods.java +++ b/spring-core/src/main/java/org/springframework/core/annotation/AttributeMethods.java @@ -26,6 +26,7 @@ import org.jspecify.annotations.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.ConcurrentReferenceHashMap; import org.springframework.util.ReflectionUtils; @@ -143,8 +144,9 @@ void validate(Annotation annotation) { throw ex; } catch (Throwable ex) { - throw new IllegalStateException("Could not obtain annotation attribute value for " + - get(i).getName() + " declared on @" + getName(annotation.annotationType()), ex); + throw new IllegalStateException( + "Could not obtain annotation attribute value for " + get(i).getName() + + " declared on @" + ClassUtils.getCanonicalName(annotation.annotationType()), ex); } } } @@ -305,13 +307,8 @@ static String describe(@Nullable Class annotationType, @Nullable String attri if (attributeName == null) { return "(none)"; } - String in = (annotationType != null ? " in annotation [" + annotationType.getName() + "]" : ""); + String in = (annotationType != null ? " in annotation [" + ClassUtils.getCanonicalName(annotationType) + "]" : ""); return "attribute '" + attributeName + "'" + in; } - private static String getName(Class clazz) { - String canonicalName = clazz.getCanonicalName(); - return (canonicalName != null ? canonicalName : clazz.getName()); - } - } diff --git a/spring-core/src/main/java/org/springframework/core/annotation/RepeatableContainers.java b/spring-core/src/main/java/org/springframework/core/annotation/RepeatableContainers.java index 53d3da336d85..b8462d8f3799 100644 --- a/spring-core/src/main/java/org/springframework/core/annotation/RepeatableContainers.java +++ b/spring-core/src/main/java/org/springframework/core/annotation/RepeatableContainers.java @@ -26,6 +26,7 @@ import org.springframework.lang.Contract; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.ConcurrentReferenceHashMap; import org.springframework.util.ObjectUtils; @@ -312,7 +313,7 @@ private static class ExplicitRepeatableContainer extends RepeatableContainers { if (returnType.componentType() != repeatable) { throw new AnnotationConfigurationException( "Container type [%s] must declare a 'value' attribute for an array of type [%s]" - .formatted(container.getName(), repeatable.getName())); + .formatted(ClassUtils.getCanonicalName(container), ClassUtils.getCanonicalName(repeatable))); } } catch (AnnotationConfigurationException ex) { @@ -321,7 +322,7 @@ private static class ExplicitRepeatableContainer extends RepeatableContainers { catch (Throwable ex) { throw new AnnotationConfigurationException( "Invalid declaration of container type [%s] for repeatable annotation [%s]" - .formatted(container.getName(), repeatable.getName()), ex); + .formatted(ClassUtils.getCanonicalName(container), ClassUtils.getCanonicalName(repeatable)), ex); } this.repeatable = repeatable; this.container = container; @@ -331,7 +332,7 @@ private static class ExplicitRepeatableContainer extends RepeatableContainers { private Class deduceContainer(Class repeatable) { Repeatable annotation = repeatable.getAnnotation(Repeatable.class); Assert.notNull(annotation, () -> "Annotation type must be a repeatable annotation: " + - "failed to resolve container type for " + repeatable.getName()); + "failed to resolve container type for " + ClassUtils.getCanonicalName(repeatable)); return annotation.value(); } diff --git a/spring-core/src/main/java/org/springframework/core/annotation/SynthesizedMergedAnnotationInvocationHandler.java b/spring-core/src/main/java/org/springframework/core/annotation/SynthesizedMergedAnnotationInvocationHandler.java index 24c21acf6af2..d05917169245 100644 --- a/spring-core/src/main/java/org/springframework/core/annotation/SynthesizedMergedAnnotationInvocationHandler.java +++ b/spring-core/src/main/java/org/springframework/core/annotation/SynthesizedMergedAnnotationInvocationHandler.java @@ -136,7 +136,7 @@ private Integer computeHashCode() { private String annotationToString() { String string = this.string; if (string == null) { - StringBuilder builder = new StringBuilder("@").append(getName(this.type)).append('('); + StringBuilder builder = new StringBuilder("@").append(ClassUtils.getCanonicalName(this.type)).append('('); if (this.attributes.size() == 1 && this.attributes.get(0).getName().equals(MergedAnnotation.VALUE)) { // Don't prepend "value=" for an annotation that only declares a "value" attribute. builder.append(toString(getAttributeValue(this.attributes.get(0)))); @@ -208,7 +208,7 @@ private String toString(Object value) { return e.name(); } if (type == Class.class) { - return getName((Class) value) + ".class"; + return ClassUtils.getCanonicalName((Class) value) + ".class"; } return String.valueOf(value); } @@ -218,7 +218,7 @@ private Object getAttributeValue(Method method) { Class type = ClassUtils.resolvePrimitiveIfNecessary(method.getReturnType()); return this.annotation.getValue(attributeName, type).orElseThrow( () -> new NoSuchElementException("No value found for attribute named '" + attributeName + - "' in merged annotation " + getName(this.annotation.getType()))); + "' in merged annotation " + ClassUtils.getCanonicalName(this.annotation.getType()))); }); // Clone non-empty arrays so that users cannot alter the contents of values in our cache. @@ -272,9 +272,4 @@ static A createProxy(MergedAnnotation annotation, Clas return (A) Proxy.newProxyInstance(classLoader, interfaces, handler); } - private static String getName(Class clazz) { - String canonicalName = clazz.getCanonicalName(); - return (canonicalName != null ? canonicalName : clazz.getName()); - } - } diff --git a/spring-core/src/main/java/org/springframework/core/annotation/TypeMappedAnnotation.java b/spring-core/src/main/java/org/springframework/core/annotation/TypeMappedAnnotation.java index 781078f65804..f28f77605825 100644 --- a/spring-core/src/main/java/org/springframework/core/annotation/TypeMappedAnnotation.java +++ b/spring-core/src/main/java/org/springframework/core/annotation/TypeMappedAnnotation.java @@ -450,7 +450,12 @@ private Object getRequiredValue(int attributeIndex, String attributeName) { value = clazz.getName(); } else if (value instanceof String str && type == Class.class) { - value = ClassUtils.resolveClassName(str, getClassLoader()); + try { + value = ClassUtils.forName(str, getClassLoader()); + } + catch (ClassNotFoundException | LinkageError ex) { + throw new TypeNotPresentException(str, ex); + } } else if (value instanceof Class[] classes && type == String[].class) { String[] names = new String[classes.length]; @@ -461,8 +466,14 @@ else if (value instanceof Class[] classes && type == String[].class) { } else if (value instanceof String[] names && type == Class[].class) { Class[] classes = new Class[names.length]; + ClassLoader classLoader = getClassLoader(); for (int i = 0; i < names.length; i++) { - classes[i] = ClassUtils.resolveClassName(names[i], getClassLoader()); + try { + classes[i] = ClassUtils.forName(names[i], classLoader); + } + catch (ClassNotFoundException | LinkageError ex) { + throw new TypeNotPresentException(names[i], ex); + } } value = classes; } @@ -479,7 +490,7 @@ else if (value instanceof MergedAnnotation[] annotations && } if (!type.isInstance(value)) { throw new IllegalArgumentException("Unable to adapt value of type " + - value.getClass().getName() + " to " + type.getName()); + ClassUtils.getCanonicalName(value.getClass()) + " to " + ClassUtils.getCanonicalName(type)); } return (T) value; } @@ -514,8 +525,8 @@ private Object adaptForAttribute(Method attribute, Object value) { } if (!attributeType.isInstance(value)) { throw new IllegalStateException("Attribute '" + attribute.getName() + - "' in annotation " + getType().getName() + " should be compatible with " + - attributeType.getName() + " but a " + value.getClass().getName() + + "' in annotation " + ClassUtils.getCanonicalName(getType()) + " should be compatible with " + + ClassUtils.getCanonicalName(attributeType) + " but a " + ClassUtils.getCanonicalName(value.getClass()) + " value was returned"); } return value; @@ -572,7 +583,7 @@ private int getAttributeIndex(String attributeName, boolean required) { int attributeIndex = (isFiltered(attributeName) ? -1 : this.mapping.getAttributes().indexOf(attributeName)); if (attributeIndex == -1 && required) { throw new NoSuchElementException("No attribute named '" + attributeName + - "' present in merged annotation " + getType().getName()); + "' present in merged annotation " + ClassUtils.getCanonicalName(getType())); } return attributeIndex; } @@ -649,9 +660,10 @@ static MergedAnnotation of( catch (Exception ex) { AnnotationUtils.rethrowAnnotationConfigurationException(ex); if (logger.isEnabled()) { - String type = mapping.getAnnotationType().getName(); + String type = ClassUtils.getCanonicalName(mapping.getAnnotationType()); String item = (mapping.getDistance() == 0 ? "annotation " + type : - "meta-annotation " + type + " from " + mapping.getRoot().getAnnotationType().getName()); + "meta-annotation " + type + " from " + + ClassUtils.getCanonicalName(mapping.getRoot().getAnnotationType())); logger.log("Failed to introspect " + item, source, ex); } return null; diff --git a/spring-core/src/main/java/org/springframework/core/convert/TypeDescriptor.java b/spring-core/src/main/java/org/springframework/core/convert/TypeDescriptor.java index 3759eff1733b..78ba2283b9e6 100644 --- a/spring-core/src/main/java/org/springframework/core/convert/TypeDescriptor.java +++ b/spring-core/src/main/java/org/springframework/core/convert/TypeDescriptor.java @@ -541,7 +541,7 @@ public int hashCode() { public String toString() { StringBuilder builder = new StringBuilder(); for (Annotation ann : getAnnotations()) { - builder.append('@').append(getName(ann.annotationType())).append(' '); + builder.append('@').append(ClassUtils.getCanonicalName(ann.annotationType())).append(' '); } builder.append(getResolvableType()); return builder.toString(); @@ -726,11 +726,6 @@ public static TypeDescriptor map(Class mapType, @Nullable TypeDescriptor keyT return new TypeDescriptor(property).nested(nestingLevel); } - private static String getName(Class clazz) { - String canonicalName = clazz.getCanonicalName(); - return (canonicalName != null ? canonicalName : clazz.getName()); - } - private interface AnnotatedElementSupplier extends Supplier, Serializable { } diff --git a/spring-core/src/main/java/org/springframework/core/env/ProfilesParser.java b/spring-core/src/main/java/org/springframework/core/env/ProfilesParser.java index 9d477c9d7d97..326f6f659c4e 100644 --- a/spring-core/src/main/java/org/springframework/core/env/ProfilesParser.java +++ b/spring-core/src/main/java/org/springframework/core/env/ProfilesParser.java @@ -88,13 +88,10 @@ private static Profiles parseTokens(String expression, StringTokenizer tokens, C } case "!" -> elements.add(not(parseTokens(expression, tokens, Context.NEGATE))); case ")" -> { - Profiles merged = merge(expression, elements, operator); if (context == Context.PARENTHESIS) { - return merged; + return merge(expression, elements, operator); } - elements.clear(); - elements.add(merged); - operator = null; + assertWellFormed(expression, false); } default -> { Profiles value = equals(token); @@ -105,6 +102,7 @@ private static Profiles parseTokens(String expression, StringTokenizer tokens, C } } } + assertWellFormed(expression, context != Context.PARENTHESIS); return merge(expression, elements, operator); } diff --git a/spring-core/src/main/java/org/springframework/core/io/DefaultResourceLoader.java b/spring-core/src/main/java/org/springframework/core/io/DefaultResourceLoader.java index 164d85d9c00a..9ec371e39bae 100644 --- a/spring-core/src/main/java/org/springframework/core/io/DefaultResourceLoader.java +++ b/spring-core/src/main/java/org/springframework/core/io/DefaultResourceLoader.java @@ -16,10 +16,19 @@ package org.springframework.core.io; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.io.SequenceInputStream; import java.net.MalformedURLException; import java.net.URL; +import java.net.URLConnection; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -30,6 +39,7 @@ import org.springframework.util.ClassUtils; import org.springframework.util.ResourceUtils; import org.springframework.util.StringUtils; +import org.springframework.util.function.IOConsumer; /** * Default implementation of the {@link ResourceLoader} interface. @@ -158,6 +168,9 @@ public Resource getResource(String location) { else if (location.startsWith(CLASSPATH_URL_PREFIX)) { return new ClassPathResource(location.substring(CLASSPATH_URL_PREFIX.length()), getClassLoader()); } + else if (location.startsWith(CLASSPATH_ALL_URL_PREFIX)) { + return new ClassPathAllResource(location.substring(CLASSPATH_ALL_URL_PREFIX.length()), getClassLoader()); + } else { try { // Try to parse the location as a URL... @@ -187,6 +200,96 @@ protected Resource getResourceByPath(String path) { } + /** + * A multi-content ClassPathResource handle that can expose the content + * from all matching resources in the classpath. + * @since 7.1 + */ + protected static class ClassPathAllResource extends ClassPathResource { + + public ClassPathAllResource(String path, @Nullable ClassLoader classLoader) { + super(path, classLoader); + } + + @Override + public boolean isFile() { + return false; + } + + @Override + public URL getURL() throws IOException { + throw new FileNotFoundException( + getDescription() + " cannot be resolved to single URL or File - use 'classpath:' instead"); + } + + @Override + public long contentLength() throws IOException { + long combinedLength = 0; + ClassLoader cl = getClassLoader(); + Enumeration urls = (cl != null ? cl.getResources(getPath()) : ClassLoader.getSystemResources(getPath())); + while (urls.hasMoreElements()) { + URLConnection con = urls.nextElement().openConnection(); + long length = con.getContentLengthLong(); + if (length < 0) { + return -1; + } + combinedLength += length; + } + return combinedLength; + } + + @Override + public InputStream getInputStream() throws IOException { + List streams = new ArrayList<>(); + ClassLoader cl = getClassLoader(); + Enumeration urls = (cl != null ? cl.getResources(getPath()) : ClassLoader.getSystemResources(getPath())); + while (urls.hasMoreElements()) { + try { + streams.add(urls.nextElement().openStream()); + } + catch (IOException ex) { + streams.forEach(stream -> { + try { + stream.close(); + } + catch (IOException ex2) { + ex.addSuppressed(ex2); + } + }); + throw ex; + } + } + return switch (streams.size()) { + case 0 -> InputStream.nullInputStream(); + case 1 -> streams.get(0); + default -> new SequenceInputStream(Collections.enumeration(streams)); + }; + } + + @Override + public void consumeContent(IOConsumer consumer) throws IOException { + ClassLoader cl = getClassLoader(); + Enumeration urls = (cl != null ? cl.getResources(getPath()) : ClassLoader.getSystemResources(getPath())); + while (urls.hasMoreElements()) { + try (InputStream inputStream = urls.nextElement().openStream()) { + consumer.accept(inputStream); + } + } + } + + @Override + public Resource createRelative(String relativePath) { + String pathToUse = StringUtils.applyRelativePath(getPath(), relativePath); + return new ClassPathAllResource(pathToUse, getClassLoader()); + } + + @Override + public String getDescription() { + return "'classpath*:' resource [" + getPath() + "]"; + } + } + + /** * ClassPathResource that explicitly expresses a context-relative path * through implementing the ContextResource interface. diff --git a/spring-core/src/main/java/org/springframework/core/io/Resource.java b/spring-core/src/main/java/org/springframework/core/io/Resource.java index 54b7084a43eb..0f4228e6508f 100644 --- a/spring-core/src/main/java/org/springframework/core/io/Resource.java +++ b/spring-core/src/main/java/org/springframework/core/io/Resource.java @@ -30,6 +30,7 @@ import org.jspecify.annotations.Nullable; import org.springframework.util.FileCopyUtils; +import org.springframework.util.function.IOConsumer; /** * Interface for a resource descriptor that abstracts from the actual @@ -156,6 +157,26 @@ default ReadableByteChannel readableChannel() throws IOException { return Channels.newChannel(getInputStream()); } + /** + * Process the contents of this resource through the given consumer callback. + *

The given consumer will be invoked a single time by default - but may + * also be invoked multiple times in case of a multi-content resource handle, + * for example returned from a + * {@link ResourceLoader#getResource getResource("classpath*:...")} call. + * While {@link #getInputStream()} returns a merged sequence of content + * in such a case, this method performs one callback per file content. + * @param consumer a consumer for each InputStream + * @throws IOException in case of general resolution/reading failures + * @since 7.1 + * @see #getInputStream() + * @see ResourceLoader#CLASSPATH_ALL_URL_PREFIX + */ + default void consumeContent(IOConsumer consumer) throws IOException { + try (InputStream inputStream = getInputStream()) { + consumer.accept(inputStream); + } + } + /** * Return the contents of this resource as a byte array. * @return the contents of this resource as byte array diff --git a/spring-core/src/main/java/org/springframework/core/io/ResourceLoader.java b/spring-core/src/main/java/org/springframework/core/io/ResourceLoader.java index 2e281c21c0d5..8b0ef92a7d71 100644 --- a/spring-core/src/main/java/org/springframework/core/io/ResourceLoader.java +++ b/spring-core/src/main/java/org/springframework/core/io/ResourceLoader.java @@ -42,18 +42,47 @@ */ public interface ResourceLoader { - /** Pseudo URL prefix for loading from the class path: "classpath:". */ + /** + * Pseudo URL prefix for loading from the class path: {@value}. + *

This retrieves the "nearest" matching resource in the classpath. + * @see ClassLoader#getResource + */ String CLASSPATH_URL_PREFIX = ResourceUtils.CLASSPATH_URL_PREFIX; + /** + * Pseudo URL prefix for all matching resources from the class path: {@value}. + *

This differs from the common {@link #CLASSPATH_URL_PREFIX "classpath:"} prefix + * in that it retrieves all matching resources for a given path. For example, to + * locate all "messages.properties" files in the root of all deployed JAR files + * you can use the location pattern {@code "classpath*:/messages.properties"}. + *

As of Spring Framework 6.0, the semantics for the {@code "classpath*:"} + * prefix have been expanded to include the module path as well as the class path. + *

As of Spring Framework 7.1, this prefix is supported for {@link #getResource} + * calls as well (exposing a multi-content resource handle), rather than just for + * {@link org.springframework.core.io.support.ResourcePatternResolver#getResources}. + * @since 7.1 (previously only declared on the + * {@link org.springframework.core.io.support.ResourcePatternResolver} sub-interface) + * @see ClassLoader#getResources + * @see Resource#consumeContent + */ + String CLASSPATH_ALL_URL_PREFIX = "classpath*:"; + /** * Return a {@code Resource} handle for the specified resource location. *

The handle should always be a reusable resource descriptor, * allowing for multiple {@link Resource#getInputStream()} calls. - *

* - *

When {@code @MockitoBean} is declared on a field, the bean to mock is inferred - * from the type of the annotated field. If multiple candidates exist in the - * {@code ApplicationContext}, a {@code @Qualifier} annotation can be declared - * on the field to help disambiguate. In the absence of a {@code @Qualifier} - * annotation, the name of the annotated field will be used as a fallback - * qualifier. Alternatively, you can explicitly specify a bean name to mock - * by setting the {@link #value() value} or {@link #name() name} attribute. + *

When {@code @MockitoBean} is declared on a field or parameter, the bean to + * mock is inferred from the type of the annotated field or parameter. If multiple + * candidates exist in the {@code ApplicationContext}, a {@code @Qualifier} annotation + * can be declared on the field or parameter to help disambiguate. In the absence + * of a {@code @Qualifier} annotation, the name of the annotated field or parameter + * will be used as a fallback qualifier. Alternatively, you can explicitly + * specify a bean name to mock by setting the {@link #value() value} or + * {@link #name() name} attribute. * *

When {@code @MockitoBean} is declared at the type level, the type of bean * (or beans) to mock must be supplied via the {@link #types() types} attribute. @@ -116,7 +118,7 @@ * @see org.springframework.test.context.bean.override.mockito.MockitoSpyBean @MockitoSpyBean * @see org.springframework.test.context.bean.override.convention.TestBean @TestBean */ -@Target({ElementType.FIELD, ElementType.TYPE}) +@Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.TYPE}) @Retention(RetentionPolicy.RUNTIME) @Documented @Repeatable(MockitoBeans.class) @@ -135,9 +137,9 @@ /** * Name of the bean to mock. *

If left unspecified, the bean to mock is selected according to the - * configured {@link #types() types} or the annotated field's type, taking - * qualifiers into account if necessary. See the {@linkplain MockitoBean - * class-level documentation} for details. + * configured {@link #types() types} or the type of the annotated field or + * parameter, taking qualifiers into account if necessary. See the + * {@linkplain MockitoBean class-level documentation} for details. * @see #value() */ @AliasFor("value") @@ -148,7 +150,7 @@ *

Defaults to none. *

Each type specified will result in a mock being created and registered * with the {@code ApplicationContext}. - *

Types must be omitted when the annotation is used on a field. + *

Types must be omitted when the annotation is used on a field or parameter. *

When {@code @MockitoBean} also defines a {@link #name name}, this attribute * can only contain a single value. * @return the types to mock diff --git a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideHandler.java b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideHandler.java index c0532f6f52bd..79339d2d514d 100644 --- a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideHandler.java +++ b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideHandler.java @@ -17,6 +17,7 @@ package org.springframework.test.context.bean.override.mockito; import java.lang.reflect.Field; +import java.lang.reflect.Parameter; import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashSet; @@ -58,7 +59,7 @@ class MockitoBeanOverrideHandler extends AbstractMockitoBeanOverrideHandler { MockitoBeanOverrideHandler(ResolvableType typeToMock, MockitoBean mockitoBean) { - this(null, typeToMock, mockitoBean); + this((Field) null, typeToMock, mockitoBean); } MockitoBeanOverrideHandler(@Nullable Field field, ResolvableType typeToMock, MockitoBean mockitoBean) { @@ -67,6 +68,12 @@ class MockitoBeanOverrideHandler extends AbstractMockitoBeanOverrideHandler { mockitoBean.reset(), mockitoBean.extraInterfaces(), mockitoBean.answers(), mockitoBean.serializable()); } + MockitoBeanOverrideHandler(Parameter parameter, ResolvableType typeToMock, MockitoBean mockitoBean) { + this(parameter, typeToMock, (!mockitoBean.name().isBlank() ? mockitoBean.name() : null), + mockitoBean.contextName(), (mockitoBean.enforceOverride() ? REPLACE : REPLACE_OR_CREATE), + mockitoBean.reset(), mockitoBean.extraInterfaces(), mockitoBean.answers(), mockitoBean.serializable()); + } + private MockitoBeanOverrideHandler(@Nullable Field field, ResolvableType typeToMock, @Nullable String beanName, String contextName, BeanOverrideStrategy strategy, MockReset reset, Class[] extraInterfaces, Answers answers, boolean serializable) { @@ -78,6 +85,16 @@ private MockitoBeanOverrideHandler(@Nullable Field field, ResolvableType typeToM this.serializable = serializable; } + private MockitoBeanOverrideHandler(Parameter parameter, ResolvableType typeToMock, @Nullable String beanName, + String contextName, BeanOverrideStrategy strategy, MockReset reset, Class[] extraInterfaces, + Answers answers, boolean serializable) { + + super(parameter, typeToMock, beanName, contextName, strategy, reset); + Assert.notNull(typeToMock, "'typeToMock' must not be null"); + this.extraInterfaces = asClassSet(extraInterfaces); + this.answers = answers; + this.serializable = serializable; + } private static Set> asClassSet(Class[] classes) { if (classes.length == 0) { @@ -158,6 +175,7 @@ public int hashCode() { public String toString() { return new ToStringCreator(this) .append("field", getField()) + .append("parameter", getParameter()) .append("beanType", getBeanType()) .append("beanName", getBeanName()) .append("contextName", getContextName()) diff --git a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideProcessor.java b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideProcessor.java index fe2cdfcbe41a..7bd7468c6acd 100644 --- a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideProcessor.java +++ b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideProcessor.java @@ -18,9 +18,12 @@ import java.lang.annotation.Annotation; import java.lang.reflect.Field; +import java.lang.reflect.Parameter; import java.util.ArrayList; import java.util.List; +import org.jspecify.annotations.Nullable; + import org.springframework.core.ResolvableType; import org.springframework.test.context.bean.override.BeanOverrideHandler; import org.springframework.test.context.bean.override.BeanOverrideProcessor; @@ -56,6 +59,24 @@ else if (overrideAnnotation instanceof MockitoSpyBean mockitoSpyBean) { .formatted(field.getDeclaringClass().getName(), field.getName())); } + @Override + public @Nullable BeanOverrideHandler createHandler(Annotation overrideAnnotation, Class testClass, Parameter parameter) { + if (overrideAnnotation instanceof MockitoBean mockitoBean) { + Assert.state(mockitoBean.types().length == 0, + "The @MockitoBean 'types' attribute must be omitted when declared on a parameter"); + return new MockitoBeanOverrideHandler(parameter, ResolvableType.forParameter(parameter), mockitoBean); + } + else if (overrideAnnotation instanceof MockitoSpyBean mockitoSpyBean) { + Assert.state(mockitoSpyBean.types().length == 0, + "The @MockitoSpyBean 'types' attribute must be omitted when declared on a parameter"); + return new MockitoSpyBeanOverrideHandler(parameter, ResolvableType.forParameter(parameter), mockitoSpyBean); + } + throw new IllegalStateException(""" + Invalid annotation passed to MockitoBeanOverrideProcessor: \ + expected either @MockitoBean or @MockitoSpyBean on parameter '%s' in constructor %s""" + .formatted(parameter.getName(), parameter.getDeclaringExecutable().getName())); + } + @Override public List createHandlers(Annotation overrideAnnotation, Class testClass) { if (overrideAnnotation instanceof MockitoBean mockitoBean) { diff --git a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoResetTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoResetTestExecutionListener.java index 055c4c61beec..d3b0558eef59 100644 --- a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoResetTestExecutionListener.java +++ b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoResetTestExecutionListener.java @@ -113,14 +113,15 @@ private static void resetMocks(ApplicationContext applicationContext, MockReset private static void resetMocks(ConfigurableApplicationContext applicationContext, MockReset reset) { ConfigurableListableBeanFactory beanFactory = applicationContext.getBeanFactory(); - String[] beanNames = beanFactory.getBeanDefinitionNames(); Set instantiatedSingletons = new HashSet<>(Arrays.asList(beanFactory.getSingletonNames())); - for (String beanName : beanNames) { - BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName); - if (beanDefinition.isSingleton() && instantiatedSingletons.contains(beanName)) { - Object bean = getBean(beanFactory, beanName); - if (bean != null && reset == MockReset.get(bean)) { - Mockito.reset(bean); + for (String beanName : beanFactory.getBeanDefinitionNames()) { + if (instantiatedSingletons.contains(beanName)) { + BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName); + if (beanDefinition.isSingleton()) { + Object bean = getBean(beanFactory, beanName); + if (bean != null && reset == MockReset.get(bean)) { + Mockito.reset(bean); + } } } } diff --git a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBean.java b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBean.java index 7ad92818a365..c58790aec3e9 100644 --- a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBean.java +++ b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBean.java @@ -38,6 +38,7 @@ *

  • On a non-static field in an enclosing class for a {@code @Nested} test class * or in any class in the type hierarchy or enclosing class hierarchy above the * {@code @Nested} test class.
  • + *
  • On a parameter in the constructor for the test class.
  • *
  • At the type level on a test class or any superclass or implemented interface * in the type hierarchy above the test class.
  • *
  • At the type level on an enclosing class for a {@code @Nested} test class @@ -45,15 +46,16 @@ * above the {@code @Nested} test class.
  • * * - *

    When {@code @MockitoSpyBean} is declared on a field, the bean to spy is - * inferred from the type of the annotated field. If multiple candidates exist in - * the {@code ApplicationContext}, a {@code @Qualifier} annotation can be declared - * on the field to help disambiguate. In the absence of a {@code @Qualifier} - * annotation, the name of the annotated field will be used as a fallback - * qualifier. Alternatively, you can explicitly specify a bean name to spy - * by setting the {@link #value() value} or {@link #name() name} attribute. If a - * bean name is specified, it is required that a target bean with that name has - * been previously registered in the application context. + *

    When {@code @MockitoSpyBean} is declared on a field or parameter, the bean + * to spy is inferred from the type of the annotated field or parameter. If multiple + * candidates exist in the {@code ApplicationContext}, a {@code @Qualifier} annotation + * can be declared on the field or parameter to help disambiguate. In the absence + * of a {@code @Qualifier} annotation, the name of the annotated field or parameter + * will be used as a fallback qualifier. Alternatively, you can explicitly + * specify a bean name to spy by setting the {@link #value() value} or + * {@link #name() name} attribute. If a bean name is specified, it is required that + * a target bean with that name has been previously registered in the application + * context. * *

    When {@code @MockitoSpyBean} is declared at the type level, the type of bean * (or beans) to spy must be supplied via the {@link #types() types} attribute. @@ -123,7 +125,7 @@ * @see org.springframework.test.context.bean.override.mockito.MockitoBean @MockitoBean * @see org.springframework.test.context.bean.override.convention.TestBean @TestBean */ -@Target({ElementType.FIELD, ElementType.TYPE}) +@Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.TYPE}) @Retention(RetentionPolicy.RUNTIME) @Documented @Repeatable(MockitoSpyBeans.class) @@ -142,9 +144,9 @@ /** * Name of the bean to spy. *

    If left unspecified, the bean to spy is selected according to the - * configured {@link #types() types} or the annotated field's type, taking - * qualifiers into account if necessary. See the {@linkplain MockitoSpyBean - * class-level documentation} for details. + * configured {@link #types() types} or the type of the annotated field or + * parameter, taking qualifiers into account if necessary. See the + * {@linkplain MockitoSpyBean class-level documentation} for details. * @see #value() */ @AliasFor("value") @@ -155,7 +157,7 @@ *

    Defaults to none. *

    Each type specified will result in a spy being created and registered * with the {@code ApplicationContext}. - *

    Types must be omitted when the annotation is used on a field. + *

    Types must be omitted when the annotation is used on a field or parameter. *

    When {@code @MockitoSpyBean} also defines a {@link #name name}, this * attribute can only contain a single value. * @return the types to spy diff --git a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanOverrideHandler.java b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanOverrideHandler.java index e9913ab44e01..d8cddb674898 100644 --- a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanOverrideHandler.java +++ b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanOverrideHandler.java @@ -17,6 +17,7 @@ package org.springframework.test.context.bean.override.mockito; import java.lang.reflect.Field; +import java.lang.reflect.Parameter; import java.lang.reflect.Proxy; import org.jspecify.annotations.Nullable; @@ -49,7 +50,7 @@ class MockitoSpyBeanOverrideHandler extends AbstractMockitoBeanOverrideHandler { MockitoSpyBeanOverrideHandler(ResolvableType typeToSpy, MockitoSpyBean spyBean) { - this(null, typeToSpy, spyBean); + this((Field) null, typeToSpy, spyBean); } MockitoSpyBeanOverrideHandler(@Nullable Field field, ResolvableType typeToSpy, MockitoSpyBean spyBean) { @@ -58,6 +59,12 @@ class MockitoSpyBeanOverrideHandler extends AbstractMockitoBeanOverrideHandler { Assert.notNull(typeToSpy, "typeToSpy must not be null"); } + MockitoSpyBeanOverrideHandler(Parameter parameter, ResolvableType typeToSpy, MockitoSpyBean spyBean) { + super(parameter, typeToSpy, (StringUtils.hasText(spyBean.name()) ? spyBean.name() : null), + spyBean.contextName(), BeanOverrideStrategy.WRAP, spyBean.reset()); + Assert.notNull(typeToSpy, "typeToSpy must not be null"); + } + @Override protected Object createOverrideInstance(String beanName, @Nullable BeanDefinition existingBeanDefinition, diff --git a/spring-test/src/main/java/org/springframework/test/context/junit/jupiter/SpringExtension.java b/spring-test/src/main/java/org/springframework/test/context/junit/jupiter/SpringExtension.java index 8a9f35d410db..26bccfaee4bb 100644 --- a/spring-test/src/main/java/org/springframework/test/context/junit/jupiter/SpringExtension.java +++ b/spring-test/src/main/java/org/springframework/test/context/junit/jupiter/SpringExtension.java @@ -58,6 +58,9 @@ import org.springframework.test.context.MethodInvoker; import org.springframework.test.context.TestContextAnnotationUtils; import org.springframework.test.context.TestContextManager; +import org.springframework.test.context.bean.override.BeanOverride; +import org.springframework.test.context.bean.override.BeanOverrideHandler; +import org.springframework.test.context.bean.override.BeanOverrideUtils; import org.springframework.test.context.event.ApplicationEvents; import org.springframework.test.context.event.RecordApplicationEvents; import org.springframework.test.context.support.PropertyProvider; @@ -370,6 +373,9 @@ public void afterEach(ExtensionContext context) throws Exception { * invoked with a fallback {@link PropertyProvider} that delegates its lookup * to {@link ExtensionContext#getConfigurationParameter(String)}. *

  • The parameter is of type {@link ApplicationContext} or a sub-type thereof.
  • + *
  • The parameter is annotated or meta-annotated with a + * {@link BeanOverride @BeanOverride} composed annotation — for example, + * {@code @MockitoBean} or {@code @MockitoSpyBean}.
  • *
  • The parameter is of type {@link ApplicationEvents} or a sub-type thereof.
  • *
  • {@link ParameterResolutionDelegate#isAutowirable} returns {@code true}.
  • * @@ -396,11 +402,12 @@ public boolean supportsParameter(ParameterContext parameterContext, ExtensionCon extensionContext.getConfigurationParameter(propertyName).orElse(null); return (TestConstructorUtils.isAutowirableConstructor(executable, junitPropertyProvider) || ApplicationContext.class.isAssignableFrom(parameterType) || + isBeanOverride(parameter) || supportsApplicationEvents(parameterType, executable) || ParameterResolutionDelegate.isAutowirable(parameter, parameterContext.getIndex())); } - private boolean supportsApplicationEvents(Class parameterType, Executable executable) { + private static boolean supportsApplicationEvents(Class parameterType, Executable executable) { if (ApplicationEvents.class.isAssignableFrom(parameterType)) { Assert.isTrue(executable instanceof Method, "ApplicationEvents can only be injected into test and lifecycle methods"); @@ -412,9 +419,9 @@ private boolean supportsApplicationEvents(Class parameterType, Executable exe /** * Resolve a value for the {@link Parameter} in the supplied {@link ParameterContext} by * retrieving the corresponding dependency from the test's {@link ApplicationContext}. - *

    Delegates to {@link ParameterResolutionDelegate#resolveDependency}. + *

    Delegates to {@link ParameterResolutionDelegate}. * @see #supportsParameter - * @see ParameterResolutionDelegate#resolveDependency + * @see ParameterResolutionDelegate#resolveDependency(Parameter, int, String, Class, org.springframework.beans.factory.config.AutowireCapableBeanFactory) */ @Override public @Nullable Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) { @@ -428,6 +435,15 @@ private boolean supportsApplicationEvents(Class parameterType, Executable exe } ApplicationContext applicationContext = getApplicationContext(extensionContext); + + // If the parameter is a @BeanOverride with an explicit name, we simply look + // up the bean by name instead of performing full dependency resolution. + if (isBeanOverride(parameter)) { + BeanOverrideHandler handler = BeanOverrideUtils.resolveHandlerForParameter(parameter, testClass); + if (handler != null && handler.getBeanName() != null) { + return applicationContext.getBean(handler.getBeanName()); + } + } return ParameterResolutionDelegate.resolveDependency(parameter, index, testClass, applicationContext.getAutowireCapableBeanFactory()); } @@ -495,6 +511,11 @@ private static boolean isAutowiredTestOrLifecycleMethod(Method method) { return false; } + private static boolean isBeanOverride(Parameter parameter) { + return (parameter.getDeclaringExecutable() instanceof Constructor && + MergedAnnotations.from(parameter).isPresent(BeanOverride.class)); + } + /** * Find the properly scoped {@link ExtensionContext} for the supplied test class. *

    If the supplied {@code ExtensionContext} is already properly scoped, it diff --git a/spring-test/src/main/java/org/springframework/test/context/support/AbstractTestContextBootstrapper.java b/spring-test/src/main/java/org/springframework/test/context/support/AbstractTestContextBootstrapper.java index 06162f25e410..d5a654935db4 100644 --- a/spring-test/src/main/java/org/springframework/test/context/support/AbstractTestContextBootstrapper.java +++ b/spring-test/src/main/java/org/springframework/test/context/support/AbstractTestContextBootstrapper.java @@ -20,11 +20,9 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -80,13 +78,6 @@ */ public abstract class AbstractTestContextBootstrapper implements TestContextBootstrapper { - private static final String IGNORED_DEFAULT_CONFIG_MESSAGE = """ - For test class [%1$s], the following 'default' context configuration %2$s were detected \ - but are currently ignored: %3$s. In Spring Framework 7.1, these %2$s will no longer be ignored. \ - Please update your test configuration accordingly. For details, see: \ - https://docs.spring.io/spring-framework/reference/testing/testcontext-framework/ctx-management/default-config.html"""; - - private final Log logger = LogFactory.getLog(getClass()); private @Nullable BootstrapContext bootstrapContext; @@ -262,12 +253,9 @@ private MergedContextConfiguration buildDefaultMergedContextConfiguration(Class< CacheAwareContextLoaderDelegate cacheAwareContextLoaderDelegate) { List defaultConfigAttributesList = - Collections.singletonList(new ContextConfigurationAttributes(testClass)); - // for 7.1: ContextLoaderUtils.resolveDefaultContextConfigurationAttributes(testClass); - + ContextLoaderUtils.resolveDefaultContextConfigurationAttributes(testClass); MergedContextConfiguration mergedConfig = buildMergedContextConfiguration( testClass, defaultConfigAttributesList, null, cacheAwareContextLoaderDelegate, false); - logWarningForIgnoredDefaultConfig(mergedConfig, cacheAwareContextLoaderDelegate); if (logger.isTraceEnabled()) { logger.trace(String.format( @@ -283,46 +271,6 @@ else if (logger.isDebugEnabled()) { return mergedConfig; } - /** - * In Spring Framework 7.1, we will use the "complete" list of default config attributes. - * In the interim, we log a warning if the "current" detected config differs from the - * "complete" detected config, which signals that some default configuration is currently - * being ignored. - */ - private void logWarningForIgnoredDefaultConfig(MergedContextConfiguration mergedConfig, - CacheAwareContextLoaderDelegate cacheAwareContextLoaderDelegate) { - - if (logger.isWarnEnabled()) { - Class testClass = mergedConfig.getTestClass(); - List completeDefaultConfigAttributesList = - ContextLoaderUtils.resolveDefaultContextConfigurationAttributes(testClass); - MergedContextConfiguration completeMergedConfig = buildMergedContextConfiguration( - testClass, completeDefaultConfigAttributesList, null, - cacheAwareContextLoaderDelegate, false); - - if (!Arrays.equals(mergedConfig.getClasses(), completeMergedConfig.getClasses())) { - Set> currentClasses = new HashSet<>(Arrays.asList(mergedConfig.getClasses())); - String ignoredClasses = Arrays.stream(completeMergedConfig.getClasses()) - .filter(clazz -> !currentClasses.contains(clazz)) - .map(Class::getName) - .collect(Collectors.joining(", ")); - if (!ignoredClasses.isEmpty()) { - logger.warn(IGNORED_DEFAULT_CONFIG_MESSAGE.formatted(testClass.getName(), "classes", ignoredClasses)); - } - } - - if (!Arrays.equals(mergedConfig.getLocations(), completeMergedConfig.getLocations())) { - Set currentLocations = new HashSet<>(Arrays.asList(mergedConfig.getLocations())); - String ignoredLocations = Arrays.stream(completeMergedConfig.getLocations()) - .filter(location -> !currentLocations.contains(location)) - .collect(Collectors.joining(", ")); - if (!ignoredLocations.isEmpty()) { - logger.warn(IGNORED_DEFAULT_CONFIG_MESSAGE.formatted(testClass.getName(), "locations", ignoredLocations)); - } - } - } - } - /** * Build the {@linkplain MergedContextConfiguration merged context configuration} * for the supplied {@link Class testClass}, context configuration attributes, diff --git a/spring-test/src/main/java/org/springframework/test/web/client/MockRestServiceServer.java b/spring-test/src/main/java/org/springframework/test/web/client/MockRestServiceServer.java index 779bc7408d39..639997cc461d 100644 --- a/spring-test/src/main/java/org/springframework/test/web/client/MockRestServiceServer.java +++ b/spring-test/src/main/java/org/springframework/test/web/client/MockRestServiceServer.java @@ -152,7 +152,10 @@ public static MockRestServiceServerBuilder bindTo(RestClient.Builder restClientB * Return a builder for a {@code MockRestServiceServer} that should be used * to reply to the given {@code RestTemplate}. * @since 4.3 + * @deprecated as of 7.1 in favor of {@link #bindTo(RestClient.Builder)}. */ + @Deprecated(since = "7.1", forRemoval = true) + @SuppressWarnings("removal") public static MockRestServiceServerBuilder bindTo(RestTemplate restTemplate) { return new RestTemplateMockRestServiceServerBuilder(restTemplate); } @@ -161,7 +164,10 @@ public static MockRestServiceServerBuilder bindTo(RestTemplate restTemplate) { * Return a builder for a {@code MockRestServiceServer} that should be used * to reply to the {@code RestTemplate} for the given {@code RestGatewaySupport}. * @since 4.3 + * @deprecated as of 7.1 in favor of {@link #bindTo(RestClient.Builder)}. */ + @Deprecated(since = "7.1", forRemoval = true) + @SuppressWarnings("removal") public static MockRestServiceServerBuilder bindTo(RestGatewaySupport restGatewaySupport) { Assert.notNull(restGatewaySupport, "'restGatewaySupport' must not be null"); return new RestTemplateMockRestServiceServerBuilder(restGatewaySupport.getRestTemplate()); @@ -182,7 +188,10 @@ public static MockRestServiceServer createServer(RestClient.Builder clientBuilde * A shortcut for {@code bindTo(restTemplate).build()}. * @param restTemplate the RestTemplate to set up for mock testing * @return the mock server + * @deprecated as of 7.1 in favor of {@link #bindTo(RestClient.Builder)}. */ + @Deprecated(since = "7.1", forRemoval = true) + @SuppressWarnings("removal") public static MockRestServiceServer createServer(RestTemplate restTemplate) { return bindTo(restTemplate).build(); } @@ -191,7 +200,10 @@ public static MockRestServiceServer createServer(RestTemplate restTemplate) { * A shortcut for {@code bindTo(restGateway).build()}. * @param restGateway the REST gateway to set up for mock testing * @return the mock server + * @deprecated as of 7.1 in favor of {@link #bindTo(RestClient.Builder)}. */ + @Deprecated(since = "7.1", forRemoval = true) + @SuppressWarnings("removal") public static MockRestServiceServer createServer(RestGatewaySupport restGateway) { return bindTo(restGateway).build(); } @@ -300,7 +312,7 @@ protected void injectRequestFactory(ClientHttpRequestFactory requestFactory) { } } - + @SuppressWarnings("removal") private static class RestTemplateMockRestServiceServerBuilder extends AbstractMockRestServiceServerBuilder { private final RestTemplate restTemplate; diff --git a/spring-test/src/main/java/org/springframework/test/web/client/match/ContentRequestMatchers.java b/spring-test/src/main/java/org/springframework/test/web/client/match/ContentRequestMatchers.java index 58dbbfb4b57e..8b4964aa87a8 100644 --- a/spring-test/src/main/java/org/springframework/test/web/client/match/ContentRequestMatchers.java +++ b/spring-test/src/main/java/org/springframework/test/web/client/match/ContentRequestMatchers.java @@ -35,6 +35,7 @@ import org.jspecify.annotations.Nullable; import org.w3c.dom.Node; +import org.springframework.core.ResolvableType; import org.springframework.core.io.Resource; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; @@ -175,12 +176,14 @@ public RequestMatcher formDataContains(Map expected) { return formData(multiValueMap, false); } + @SuppressWarnings("unchecked") private RequestMatcher formData(MultiValueMap expectedMap, boolean containsExactly) { return request -> { MockClientHttpRequest mockRequest = (MockClientHttpRequest) request; MockHttpInputMessage message = new MockHttpInputMessage(mockRequest.getBodyAsBytes()); message.getHeaders().putAll(mockRequest.getHeaders()); - MultiValueMap actualMap = new FormHttpMessageConverter().read(null, message); + MultiValueMap actualMap = (MultiValueMap) new FormHttpMessageConverter() + .read(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class), message, null); if (containsExactly) { assertEquals("Form data", expectedMap, actualMap); } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/client/MockMvcClientHttpRequestFactory.java b/spring-test/src/main/java/org/springframework/test/web/servlet/client/MockMvcClientHttpRequestFactory.java index acb23a324448..218b4ebcb8da 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/client/MockMvcClientHttpRequestFactory.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/client/MockMvcClientHttpRequestFactory.java @@ -16,28 +16,40 @@ package org.springframework.test.web.servlet.client; +import java.io.IOException; +import java.io.InputStream; import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.List; import jakarta.servlet.http.Cookie; +import org.springframework.core.ResolvableType; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatusCode; +import org.springframework.http.MediaType; import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.converter.multipart.FilePart; +import org.springframework.http.converter.multipart.MultipartHttpMessageConverter; +import org.springframework.http.converter.multipart.Part; +import org.springframework.mock.http.MockHttpInputMessage; import org.springframework.mock.http.client.MockClientHttpRequest; import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockPart; import org.springframework.test.web.servlet.MockMvc; -import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; +import org.springframework.test.web.servlet.request.AbstractMockHttpServletRequestBuilder; +import org.springframework.test.web.servlet.request.MockMultipartHttpServletRequestBuilder; import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.multipart; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.request; /** @@ -45,10 +57,16 @@ * * @author Rossen Stoyanchev * @author Rob Worsnop + * @author Brian Clozel * @since 7.0 */ public class MockMvcClientHttpRequestFactory implements ClientHttpRequestFactory { + private static final ResolvableType MULTIVALUEMAP_TYPE = ResolvableType.forClassWithGenerics(MultiValueMap.class, + String.class, Part.class); + + private static final MultipartHttpMessageConverter MULTIPART_CONVERTER = new MultipartHttpMessageConverter(); + private final MockMvc mockMvc; @@ -79,11 +97,8 @@ private class MockMvcClientHttpRequest extends MockClientHttpRequest { @Override public ClientHttpResponse executeInternal() { try { - MockHttpServletRequestBuilder servletRequestBuilder = request(getMethod(), getURI()) - .headers(getHeaders()) - .content(getBodyAsBytes()); - - addCookies(servletRequestBuilder); + AbstractMockHttpServletRequestBuilder servletRequestBuilder = + adaptRequest(getMethod(), getURI(), getHeaders(), getBodyAsBytes()); MockHttpServletResponse servletResponse = MockMvcClientHttpRequestFactory.this.mockMvc .perform(servletRequestBuilder) @@ -104,15 +119,61 @@ public ClientHttpResponse executeInternal() { } } - private void addCookies(MockHttpServletRequestBuilder requestBuilder) { - List values = getHeaders().get(HttpHeaders.COOKIE); + private AbstractMockHttpServletRequestBuilder adaptRequest( + HttpMethod httpMethod, URI uri, HttpHeaders headers, byte[] bytes) throws IOException { + + String contentType = headers.getFirst(HttpHeaders.CONTENT_TYPE); + AbstractMockHttpServletRequestBuilder requestBuilder; + + if (StringUtils.hasLength(contentType) && + MediaType.MULTIPART_FORM_DATA.includes(MediaType.parseMediaType(contentType))) { + + MockMultipartHttpServletRequestBuilder multipartRequestBuilder = multipart(httpMethod, uri); + Assert.notNull(bytes, "No multipart content"); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(bytes); + inputMessage.getHeaders().putAll(headers); + + MultiValueMap parts = MULTIPART_CONVERTER.read(MULTIVALUEMAP_TYPE, inputMessage, null); + for (List partValues : parts.values()) { + for (Part part : partValues) { + parsePart(part, multipartRequestBuilder); + } + } + requestBuilder = multipartRequestBuilder; + } + else { + requestBuilder = request(httpMethod, uri); + if (!ObjectUtils.isEmpty(bytes)) { + requestBuilder.content(bytes); + } + } + + requestBuilder.headers(headers); + addCookies(headers, requestBuilder); + + return requestBuilder; + } + + private void parsePart(Part part, MockMultipartHttpServletRequestBuilder multipartRequestBuilder) throws IOException { + try (InputStream content = part.content()) { + byte[] partBytes = content.readAllBytes(); + MockPart mockPart = (part instanceof FilePart filePart ? + new MockPart(part.name(), filePart.filename(), partBytes) : + new MockPart(part.name(), partBytes)); + mockPart.getHeaders().putAll(part.headers()); + multipartRequestBuilder.part(mockPart); + } + } + + private void addCookies(HttpHeaders headers, AbstractMockHttpServletRequestBuilder requestBuilder) { + List values = headers.get(HttpHeaders.COOKIE); if (!ObjectUtils.isEmpty(values)) { values.stream() .flatMap(header -> StringUtils.commaDelimitedListToSet(header).stream()) .map(value -> { - String[] parts = StringUtils.split(value, "="); - Assert.isTrue(parts != null && parts.length == 2, "Invalid cookie: '" + value + "'"); - return new Cookie(parts[0], parts[1]); + String[] cookieParts = StringUtils.split(value, "="); + Assert.isTrue(cookieParts != null && cookieParts.length == 2, "Invalid cookie: '" + value + "'"); + return new Cookie(cookieParts[0], cookieParts[1]); }) .forEach(requestBuilder::cookie); } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java index ac46d9b660c8..fe74c01c312a 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java @@ -46,6 +46,7 @@ import org.springframework.beans.Mergeable; import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.core.ResolvableType; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpInputMessage; import org.springframework.http.HttpMethod; @@ -1038,7 +1039,8 @@ public HttpHeaders getHeaders() { }; try { - return new FormHttpMessageConverter().read(null, message); + return (MultiValueMap) new FormHttpMessageConverter() + .read(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class), message, null); } catch (IOException ex) { throw new IllegalStateException("Failed to parse form data in request body", ex); diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/setup/StubWebApplicationContext.java b/spring-test/src/main/java/org/springframework/test/web/servlet/setup/StubWebApplicationContext.java index 678db294465e..6b1495d7830b 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/setup/StubWebApplicationContext.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/setup/StubWebApplicationContext.java @@ -65,6 +65,7 @@ * * @author Rossen Stoyanchev * @author Juergen Hoeller + * @author Yanming Zhou * @since 3.2 */ class StubWebApplicationContext implements WebApplicationContext { @@ -168,6 +169,11 @@ public T getBean(String name, Class requiredType) throws BeansException { return this.beanFactory.getBean(name, requiredType); } + @Override + public T getBean(String name, ParameterizedTypeReference typeReference) throws BeansException { + return this.beanFactory.getBean(name, typeReference); + } + @Override public Object getBean(String name, @Nullable Object @Nullable ... args) throws BeansException { return this.beanFactory.getBean(name, args); diff --git a/spring-test/src/test/java/org/springframework/test/context/BootstrapUtilsTests.java b/spring-test/src/test/java/org/springframework/test/context/BootstrapUtilsTests.java index 6d9077ec6aeb..fff6c5f3a4b4 100644 --- a/spring-test/src/test/java/org/springframework/test/context/BootstrapUtilsTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/BootstrapUtilsTests.java @@ -146,13 +146,7 @@ void resolveTestContextBootstrapperWithLocalDeclarationThatOverridesMetaBootstra */ @Test // gh-35938 void resolveTestContextBootstrapperWithMetaBootstrapWithAnnotationThatOverridesMetaMetaBootstrapWithAnnotation() { - BootstrapContext bootstrapContext = BootstrapTestUtils.buildBootstrapContext( - MetaAndMetaMetaBootstrapWithAnnotationsClass.class, delegate); - assertThatIllegalStateException() - .isThrownBy(() -> resolveTestContextBootstrapper(bootstrapContext)) - .withMessageContaining("Configuration error: found multiple declarations of @BootstrapWith") - .withMessageContaining(FooBootstrapper.class.getSimpleName()) - .withMessageContaining(BarBootstrapper.class.getSimpleName()); + assertBootstrapper(MetaAndMetaMetaBootstrapWithAnnotationsClass.class, BarBootstrapper.class); } /** diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessorTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessorTests.java index 4a1dd92774fe..7eda38045ba1 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessorTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessorTests.java @@ -401,7 +401,7 @@ private void qualifiedElementIsField(RootBeanDefinition def) { private static AnnotationConfigApplicationContext createContext(Class testClass) { AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); - Set handlers = new LinkedHashSet<>(BeanOverrideTestUtils.findHandlers(testClass)); + Set handlers = new LinkedHashSet<>(BeanOverrideUtils.findHandlersForFields(testClass)); new BeanOverrideContextCustomizer(handlers).customizeContext(context, mock(MergedContextConfiguration.class)); return context; } diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/BeanOverrideHandlerTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/BeanOverrideHandlerTests.java index 12656bbe149f..b343f6193335 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/BeanOverrideHandlerTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/BeanOverrideHandlerTests.java @@ -51,14 +51,14 @@ class BeanOverrideHandlerTests { @Test void forTestClassWithSingleField() { - List handlers = BeanOverrideTestUtils.findHandlers(SingleAnnotation.class); + List handlers = BeanOverrideUtils.findHandlersForFields(SingleAnnotation.class); assertThat(handlers).singleElement().satisfies(hasBeanOverrideHandler( field(SingleAnnotation.class, "message"), String.class, null)); } @Test void forTestClassWithMultipleFields() { - List handlers = BeanOverrideTestUtils.findHandlers(MultipleAnnotations.class); + List handlers = BeanOverrideUtils.findHandlersForFields(MultipleAnnotations.class); assertThat(handlers).hasSize(2) .anySatisfy(hasBeanOverrideHandler( field(MultipleAnnotations.class, "message"), String.class, null)) @@ -68,7 +68,7 @@ void forTestClassWithMultipleFields() { @Test void forTestClassWithMultipleFieldsWithIdenticalMetadata() { - List handlers = BeanOverrideTestUtils.findHandlers(MultipleAnnotationsDuplicate.class); + List handlers = BeanOverrideUtils.findHandlersForFields(MultipleAnnotationsDuplicate.class); assertThat(handlers).hasSize(2) .anySatisfy(hasBeanOverrideHandler( field(MultipleAnnotationsDuplicate.class, "message1"), String.class, "messageBean")) @@ -81,7 +81,7 @@ void forTestClassWithMultipleFieldsWithIdenticalMetadata() { void forTestClassWithCompetingBeanOverrideAnnotationsOnSameField() { Field faultyField = field(MultipleAnnotationsOnSameField.class, "message"); assertThatIllegalStateException() - .isThrownBy(() -> BeanOverrideTestUtils.findHandlers(MultipleAnnotationsOnSameField.class)) + .isThrownBy(() -> BeanOverrideUtils.findHandlersForFields(MultipleAnnotationsOnSameField.class)) .withMessageStartingWith("Multiple @BeanOverride annotations found") .withMessageContaining(faultyField.toString()); } @@ -90,7 +90,7 @@ void forTestClassWithCompetingBeanOverrideAnnotationsOnSameField() { void forTestClassWithStaticBeanOverrideField() { Field staticField = field(StaticBeanOverrideField.class, "message"); assertThatIllegalStateException() - .isThrownBy(() -> BeanOverrideTestUtils.findHandlers(StaticBeanOverrideField.class)) + .isThrownBy(() -> BeanOverrideUtils.findHandlersForFields(StaticBeanOverrideField.class)) .withMessage("@BeanOverride field must not be static: " + staticField); } diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/convention/TestBeanOverrideHandlerTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/convention/TestBeanOverrideHandlerTests.java index 778adb052463..3fd8ff6fdbb0 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/convention/TestBeanOverrideHandlerTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/convention/TestBeanOverrideHandlerTests.java @@ -26,7 +26,7 @@ import org.springframework.core.ResolvableType; import org.springframework.test.context.bean.override.BeanOverrideHandler; import org.springframework.test.context.bean.override.BeanOverrideStrategy; -import org.springframework.test.context.bean.override.BeanOverrideTestUtils; +import org.springframework.test.context.bean.override.BeanOverrideUtils; import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; @@ -44,20 +44,20 @@ class TestBeanOverrideHandlerTests { @Test void beanNameIsSetToNullIfAnnotationNameIsEmpty() { - List handlers = BeanOverrideTestUtils.findHandlers(SampleOneOverride.class); + List handlers = BeanOverrideUtils.findHandlersForFields(SampleOneOverride.class); assertThat(handlers).singleElement().extracting(BeanOverrideHandler::getBeanName).isNull(); } @Test void beanNameIsSetToAnnotationName() { - List handlers = BeanOverrideTestUtils.findHandlers(SampleOneOverrideWithName.class); + List handlers = BeanOverrideUtils.findHandlersForFields(SampleOneOverrideWithName.class); assertThat(handlers).singleElement().extracting(BeanOverrideHandler::getBeanName).isEqualTo("anotherBean"); } @Test void failsWithMissingMethod() { assertThatIllegalStateException() - .isThrownBy(() -> BeanOverrideTestUtils.findHandlers(SampleMissingMethod.class)) + .isThrownBy(() -> BeanOverrideUtils.findHandlersForFields(SampleMissingMethod.class)) .withMessage("No static method found named message() in %s with return type %s", SampleMissingMethod.class.getName(), String.class.getName()); } diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/example/CustomQualifier.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/example/CustomQualifier.java index 6f1e5a73c2e0..47ed8c0c7926 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/example/CustomQualifier.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/example/CustomQualifier.java @@ -24,7 +24,7 @@ import org.springframework.beans.factory.annotation.Qualifier; -@Target({ElementType.FIELD, ElementType.METHOD}) +@Target({ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER}) @Retention(RetentionPolicy.RUNTIME) @Inherited @Qualifier diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanConfigurationErrorTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanConfigurationErrorTests.java index 36cc59d8d075..d7aabf7d2da0 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanConfigurationErrorTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanConfigurationErrorTests.java @@ -26,7 +26,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** - * Tests for {@link MockitoBean @MockitoBean}. + * Tests for {@link MockitoBean @MockitoBean} error scenarios. * * @author Stephane Nicoll * @author Sam Brannen @@ -88,19 +88,88 @@ void cannotOverrideBeanByTypeWithTooManyBeansOfThatType() { List.of("bean1", "bean2")); } + @Test // gh-36096 + void cannotOverrideBeanByNameWithNoSuchBeanNameOnConstructorParameter() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBean("anotherBean", String.class, () -> "example"); + BeanOverrideContextCustomizerTestUtils.customizeApplicationContext(FailureByNameLookupOnConstructorParameter.class, context); + assertThatIllegalStateException() + .isThrownBy(context::refresh) + .withMessage(""" + Unable to replace bean: there is no bean with name 'beanToOverride' and type \ + java.lang.String (as required by parameter 'example' in constructor for %s). \ + If the bean is defined in a @Bean method, make sure the return type is the most \ + specific type possible (for example, the concrete implementation type).""", + FailureByNameLookupOnConstructorParameter.class.getName()); + } + + @Test // gh-36096 + void cannotOverrideBeanByNameWithBeanOfWrongTypeOnConstructorParameter() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBean("beanToOverride", Integer.class, () -> 42); + BeanOverrideContextCustomizerTestUtils.customizeApplicationContext(FailureByNameLookupOnConstructorParameter.class, context); + assertThatIllegalStateException() + .isThrownBy(context::refresh) + .withMessage(""" + Unable to replace bean: there is no bean with name 'beanToOverride' and type \ + java.lang.String (as required by parameter 'example' in constructor for %s). \ + If the bean is defined in a @Bean method, make sure the return type is the most \ + specific type possible (for example, the concrete implementation type).""", + FailureByNameLookupOnConstructorParameter.class.getName()); + } + + @Test // gh-36096 + void cannotOverrideBeanByTypeWithNoSuchBeanTypeOnConstructorParameter() { + GenericApplicationContext context = new GenericApplicationContext(); + BeanOverrideContextCustomizerTestUtils.customizeApplicationContext(FailureByTypeLookupOnConstructorParameter.class, context); + assertThatIllegalStateException() + .isThrownBy(context::refresh) + .withMessage(""" + Unable to override bean: there are no beans of type java.lang.String \ + (as required by parameter 'example' in constructor for %s). \ + If the bean is defined in a @Bean method, make sure the return type is the most \ + specific type possible (for example, the concrete implementation type).""", + FailureByTypeLookupOnConstructorParameter.class.getName()); + } + + @Test // gh-36096 + void cannotOverrideBeanByTypeWithTooManyBeansOfThatTypeOnConstructorParameter() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBean("bean1", String.class, () -> "example1"); + context.registerBean("bean2", String.class, () -> "example2"); + BeanOverrideContextCustomizerTestUtils.customizeApplicationContext(FailureByTypeLookupOnConstructorParameter.class, context); + assertThatIllegalStateException() + .isThrownBy(context::refresh) + .withMessage(""" + Unable to select a bean to override: found 2 beans of type java.lang.String \ + (as required by parameter 'example' in constructor for %s): %s""", + FailureByTypeLookupOnConstructorParameter.class.getName(), + List.of("bean1", "bean2")); + } + static class FailureByTypeLookup { @MockitoBean(enforceOverride = true) String example; - } static class FailureByNameLookup { @MockitoBean(name = "beanToOverride", enforceOverride = true) String example; + } + + static class FailureByTypeLookupOnConstructorParameter { + + FailureByTypeLookupOnConstructorParameter(@MockitoBean(enforceOverride = true) String example) { + } + } + + static class FailureByNameLookupOnConstructorParameter { + FailureByNameLookupOnConstructorParameter(@MockitoBean(name = "beanToOverride", enforceOverride = true) String example) { + } } } diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanNestedTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanNestedTests.java index 37933ae49c32..0b57a4de0264 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanNestedTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanNestedTests.java @@ -16,6 +16,7 @@ package org.springframework.test.context.bean.override.mockito; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -24,8 +25,10 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.TestContextAnnotationUtils; import org.springframework.test.context.junit.jupiter.SpringExtension; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.times; @@ -38,10 +41,6 @@ * @since 6.2 */ @ExtendWith(SpringExtension.class) -// TODO Remove @ContextConfiguration declaration. -// @ContextConfiguration is currently required due to a bug in the TestContext framework. -// See https://github.com/spring-projects/spring-framework/issues/31456 -@ContextConfiguration class MockitoBeanNestedTests { @MockitoBean @@ -50,6 +49,13 @@ class MockitoBeanNestedTests { @Autowired Task task; + @BeforeAll + static void ensureNotAnnotatedWithContextConfiguration() { + boolean hasAnnotation = + TestContextAnnotationUtils.hasAnnotation(MockitoBeanNestedTests.class, ContextConfiguration.class); + assertThat(hasAnnotation).isFalse(); + } + @Test void mockWasInvokedOnce() { task.execute(); diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideHandlerTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideHandlerTests.java index 876e2fbacaad..3b9234fb5bf7 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideHandlerTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideHandlerTests.java @@ -26,7 +26,7 @@ import org.springframework.core.ResolvableType; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.test.context.bean.override.BeanOverrideHandler; -import org.springframework.test.context.bean.override.BeanOverrideTestUtils; +import org.springframework.test.context.bean.override.BeanOverrideUtils; import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -42,13 +42,13 @@ class MockitoBeanOverrideHandlerTests { @Test void beanNameIsSetToNullIfAnnotationNameIsEmpty() { - List list = BeanOverrideTestUtils.findHandlers(SampleOneMock.class); + List list = BeanOverrideUtils.findHandlersForFields(SampleOneMock.class); assertThat(list).singleElement().satisfies(handler -> assertThat(handler.getBeanName()).isNull()); } @Test void beanNameIsSetToAnnotationName() { - List list = BeanOverrideTestUtils.findHandlers(SampleOneMockWithName.class); + List list = BeanOverrideUtils.findHandlersForFields(SampleOneMockWithName.class); assertThat(list).singleElement().satisfies(handler -> assertThat(handler.getBeanName()).isEqualTo("anotherService")); } @@ -194,7 +194,7 @@ private static MockitoBeanOverrideHandler createHandler(Field field) { private MockitoBeanOverrideHandler createHandler(Class clazz) { MockitoBean annotation = AnnotatedElementUtils.getMergedAnnotation(clazz, MockitoBean.class); - return new MockitoBeanOverrideHandler(null, ResolvableType.forClass(annotation.types()[0]), annotation); + return new MockitoBeanOverrideHandler((Field) null, ResolvableType.forClass(annotation.types()[0]), annotation); } diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideProcessorTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideProcessorTests.java index dce2ebe2cb2f..ff0259440952 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideProcessorTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideProcessorTests.java @@ -18,6 +18,7 @@ import java.lang.annotation.Annotation; import java.lang.reflect.Field; +import java.lang.reflect.Parameter; import java.util.List; import org.jspecify.annotations.Nullable; @@ -52,17 +53,17 @@ class CreateHandlerTests { @Test void mockAnnotationCreatesMockitoBeanOverrideHandler() { MockitoBean annotation = AnnotationUtils.synthesizeAnnotation(MockitoBean.class); - BeanOverrideHandler object = processor.createHandler(annotation, TestCase.class, field); + BeanOverrideHandler handler = processor.createHandler(annotation, TestCase.class, field); - assertThat(object).isExactlyInstanceOf(MockitoBeanOverrideHandler.class); + assertThat(handler).isExactlyInstanceOf(MockitoBeanOverrideHandler.class); } @Test void spyAnnotationCreatesMockitoSpyBeanOverrideHandler() { MockitoSpyBean annotation = AnnotationUtils.synthesizeAnnotation(MockitoSpyBean.class); - BeanOverrideHandler object = processor.createHandler(annotation, TestCase.class, field); + BeanOverrideHandler handler = processor.createHandler(annotation, TestCase.class, field); - assertThat(object).isExactlyInstanceOf(MockitoSpyBeanOverrideHandler.class); + assertThat(handler).isExactlyInstanceOf(MockitoSpyBeanOverrideHandler.class); } @Test @@ -102,6 +103,81 @@ static class NameNotSupportedTestCase { } } + @Nested // gh-36096 + class CreateHandlerForParameterTests { + + private final Parameter parameter = TestCase.class.getDeclaredConstructors()[0].getParameters()[0]; + + + @Test + void mockAnnotationCreatesMockitoBeanOverrideHandler() { + MockitoBean annotation = AnnotationUtils.synthesizeAnnotation(MockitoBean.class); + BeanOverrideHandler handler = processor.createHandler(annotation, TestCase.class, parameter); + + assertThat(handler).isExactlyInstanceOf(MockitoBeanOverrideHandler.class); + } + + @Test + void spyAnnotationCreatesMockitoSpyBeanOverrideHandler() { + MockitoSpyBean annotation = AnnotationUtils.synthesizeAnnotation(MockitoSpyBean.class); + BeanOverrideHandler handler = processor.createHandler(annotation, TestCase.class, parameter); + + assertThat(handler).isExactlyInstanceOf(MockitoSpyBeanOverrideHandler.class); + } + + @Test + void otherAnnotationThrows() { + Annotation annotation = parameter.getAnnotation(Nullable.class); + + assertThatIllegalStateException() + .isThrownBy(() -> processor.createHandler(annotation, TestCase.class, parameter)) + .withMessage("Invalid annotation passed to MockitoBeanOverrideProcessor: expected either " + + "@MockitoBean or @MockitoSpyBean on parameter '%s' in constructor %s", + parameter.getName(), parameter.getDeclaringExecutable().getName()); + } + + @Test + void typesAttributeNotSupportedForMockitoBean() { + Parameter parameter = TypesNotSupportedForMockitoBeanTestCase.class + .getDeclaredConstructors()[0].getParameters()[0]; + MockitoBean annotation = parameter.getAnnotation(MockitoBean.class); + + assertThatIllegalStateException() + .isThrownBy(() -> processor.createHandler(annotation, TypesNotSupportedForMockitoBeanTestCase.class, parameter)) + .withMessage("The @MockitoBean 'types' attribute must be omitted when declared on a parameter"); + } + + @Test + void typesAttributeNotSupportedForMockitoSpyBean() { + Parameter parameter = TypesNotSupportedForMockitoSpyBeanTestCase.class + .getDeclaredConstructors()[0].getParameters()[0]; + MockitoSpyBean annotation = parameter.getAnnotation(MockitoSpyBean.class); + + assertThatIllegalStateException() + .isThrownBy(() -> processor.createHandler(annotation, TypesNotSupportedForMockitoSpyBeanTestCase.class, parameter)) + .withMessage("The @MockitoSpyBean 'types' attribute must be omitted when declared on a parameter"); + } + + + static class TestCase { + + TestCase(@MockitoBean @MockitoSpyBean @Nullable Integer number) { + } + } + + static class TypesNotSupportedForMockitoBeanTestCase { + + TypesNotSupportedForMockitoBeanTestCase(@MockitoBean(types = Integer.class) String param) { + } + } + + static class TypesNotSupportedForMockitoSpyBeanTestCase { + + TypesNotSupportedForMockitoSpyBeanTestCase(@MockitoSpyBean(types = Integer.class) String param) { + } + } + } + @Nested class CreateHandlersTests { diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanConfigurationErrorTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanConfigurationErrorTests.java index 38ee10d9615d..71f63337af4b 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanConfigurationErrorTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanConfigurationErrorTests.java @@ -32,7 +32,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** - * Tests for {@link MockitoSpyBean @MockitoSpyBean}. + * Tests for {@link MockitoSpyBean @MockitoSpyBean} error scenarios. * * @author Stephane Nicoll * @author Sam Brannen @@ -113,33 +113,85 @@ void mockitoSpyBeanCannotSpyOnSelfInjectionScopedProxy() { to spy on a scoped proxy, which is not supported."""); } + @Test // gh-36096 + void contextCustomizerCannotBeCreatedWithNoSuchBeanNameOnConstructorParameter() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBean("present", String.class, () -> "example"); + BeanOverrideContextCustomizerTestUtils.customizeApplicationContext(ByNameSingleLookupOnConstructorParameter.class, context); + assertThatIllegalStateException() + .isThrownBy(context::refresh) + .withMessage(""" + Unable to wrap bean: there is no bean with name 'beanToSpy' and type \ + java.lang.String (as required by parameter 'example' in constructor for %s). \ + If the bean is defined in a @Bean method, make sure the return type is the most \ + specific type possible (for example, the concrete implementation type).""", + ByNameSingleLookupOnConstructorParameter.class.getName()); + } + + @Test // gh-36096 + void contextCustomizerCannotBeCreatedWithNoSuchBeanTypeOnConstructorParameter() { + GenericApplicationContext context = new GenericApplicationContext(); + BeanOverrideContextCustomizerTestUtils.customizeApplicationContext(ByTypeSingleLookupOnConstructorParameter.class, context); + assertThatIllegalStateException() + .isThrownBy(context::refresh) + .withMessage(""" + Unable to select a bean to wrap: there are no beans of type java.lang.String \ + (as required by parameter 'example' in constructor for %s). \ + If the bean is defined in a @Bean method, make sure the return type is the most \ + specific type possible (for example, the concrete implementation type).""", + ByTypeSingleLookupOnConstructorParameter.class.getName()); + } + + @Test // gh-36096 + void contextCustomizerCannotBeCreatedWithTooManyBeansOfThatTypeOnConstructorParameter() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBean("bean1", String.class, () -> "example1"); + context.registerBean("bean2", String.class, () -> "example2"); + BeanOverrideContextCustomizerTestUtils.customizeApplicationContext(ByTypeSingleLookupOnConstructorParameter.class, context); + assertThatIllegalStateException() + .isThrownBy(context::refresh) + .withMessage(""" + Unable to select a bean to wrap: found 2 beans of type java.lang.String \ + (as required by parameter 'example' in constructor for %s): %s""", + ByTypeSingleLookupOnConstructorParameter.class.getName(), + List.of("bean1", "bean2")); + } + static class ByTypeSingleLookup { @MockitoSpyBean String example; - } static class ByNameSingleLookup { @MockitoSpyBean("beanToSpy") String example; + } + + static class ByTypeSingleLookupOnConstructorParameter { + ByTypeSingleLookupOnConstructorParameter(@MockitoSpyBean String example) { + } + } + + static class ByNameSingleLookupOnConstructorParameter { + + ByNameSingleLookupOnConstructorParameter(@MockitoSpyBean("beanToSpy") String example) { + } } static class ScopedProxyTestCase { @MockitoSpyBean MyScopedProxy myScopedProxy; - } static class SelfInjectionScopedProxyTestCase { @MockitoSpyBean MySelfInjectionScopedProxy mySelfInjectionScopedProxy; - } @Component("myScopedProxy") diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanOverrideHandlerTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanOverrideHandlerTests.java index 9867a276453b..98f4619fd543 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanOverrideHandlerTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanOverrideHandlerTests.java @@ -24,7 +24,7 @@ import org.springframework.core.ResolvableType; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.test.context.bean.override.BeanOverrideHandler; -import org.springframework.test.context.bean.override.BeanOverrideTestUtils; +import org.springframework.test.context.bean.override.BeanOverrideUtils; import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -38,13 +38,13 @@ class MockitoSpyBeanOverrideHandlerTests { @Test void beanNameIsSetToNullIfAnnotationNameIsEmpty() { - List list = BeanOverrideTestUtils.findHandlers(SampleOneSpy.class); + List list = BeanOverrideUtils.findHandlersForFields(SampleOneSpy.class); assertThat(list).singleElement().satisfies(handler -> assertThat(handler.getBeanName()).isNull()); } @Test void beanNameIsSetToAnnotationName() { - List list = BeanOverrideTestUtils.findHandlers(SampleOneSpyWithName.class); + List list = BeanOverrideUtils.findHandlersForFields(SampleOneSpyWithName.class); assertThat(list).singleElement().satisfies(handler -> assertThat(handler.getBeanName()).isEqualTo("anotherService")); } diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoBeanByNameLookupForConstructorParametersIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoBeanByNameLookupForConstructorParametersIntegrationTests.java new file mode 100644 index 000000000000..e312efb3c80e --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoBeanByNameLookupForConstructorParametersIntegrationTests.java @@ -0,0 +1,175 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.bean.override.mockito.constructor; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.bean.override.example.ExampleService; +import org.springframework.test.context.bean.override.example.RealExampleService; +import org.springframework.test.context.bean.override.mockito.MockitoBean; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; +import org.springframework.test.mockito.MockitoAssertions; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link MockitoBean @MockitoBean} that use by-name lookup + * on constructor parameters. + * + * @author Sam Brannen + * @since 7.1 + * @see gh-36096 + * @see org.springframework.test.context.bean.override.mockito.MockitoBeanByNameLookupTestMethodScopedExtensionContextIntegrationTests + */ +@SpringJUnitConfig +class MockitoBeanByNameLookupForConstructorParametersIntegrationTests { + + final ExampleService service0A; + + final ExampleService service0B; + + final ExampleService service0C; + + final ExampleService nonExisting; + + + MockitoBeanByNameLookupForConstructorParametersIntegrationTests( + @MockitoBean ExampleService s0A, + @MockitoBean(name = "s0B") ExampleService service0B, + @MockitoBean @Qualifier("s0C") ExampleService service0C, + @MockitoBean("nonExistingBean") ExampleService nonExisting) { + + this.service0A = s0A; + this.service0B = service0B; + this.service0C = service0C; + this.nonExisting = nonExisting; + } + + + @Test + void parameterNameIsUsedAsBeanName(ApplicationContext ctx) { + assertThat(this.service0A) + .satisfies(MockitoAssertions::assertIsMock) + .isSameAs(ctx.getBean("s0A")); + + assertThat(this.service0A.greeting()).as("mocked greeting").isNull(); + } + + @Test + void explicitBeanNameOverridesParameterName(ApplicationContext ctx) { + assertThat(this.service0B) + .satisfies(MockitoAssertions::assertIsMock) + .isSameAs(ctx.getBean("s0B")); + + assertThat(this.service0B.greeting()).as("mocked greeting").isNull(); + } + + @Test + void qualifierIsUsedToResolveByName(ApplicationContext ctx) { + assertThat(this.service0C) + .satisfies(MockitoAssertions::assertIsMock) + .isSameAs(ctx.getBean("s0C")); + + assertThat(this.service0C.greeting()).as("mocked greeting").isNull(); + } + + @Test + void mockIsCreatedWhenNoBeanExistsWithProvidedName(ApplicationContext ctx) { + assertThat(this.nonExisting) + .satisfies(MockitoAssertions::assertIsMock) + .isSameAs(ctx.getBean("nonExistingBean")); + + assertThat(this.nonExisting.greeting()).as("mocked greeting").isNull(); + } + + + @Nested + class NestedTests { + + @Autowired + @Qualifier("s0A") + ExampleService localService0A; + + @Autowired + @Qualifier("nonExistingBean") + ExampleService localNonExisting; + + final ExampleService nestedNonExisting; + + + NestedTests(@MockitoBean("nestedNonExistingBean") ExampleService nestedNonExisting) { + this.nestedNonExisting = nestedNonExisting; + } + + + @Test + void mockFromEnclosingClassIsAccessibleViaAutowiring(ApplicationContext ctx) { + assertThat(this.localService0A) + .satisfies(MockitoAssertions::assertIsMock) + .isSameAs(service0A) + .isSameAs(ctx.getBean("s0A")); + + assertThat(this.localService0A.greeting()).as("mocked greeting").isNull(); + } + + @Test + void mockForNonExistingBeanFromEnclosingClassIsAccessibleViaAutowiring(ApplicationContext ctx) { + assertThat(this.localNonExisting) + .satisfies(MockitoAssertions::assertIsMock) + .isSameAs(nonExisting) + .isSameAs(ctx.getBean("nonExistingBean")); + + assertThat(this.localNonExisting.greeting()).as("mocked greeting").isNull(); + } + + @Test + void nestedConstructorParameterIsMockedWhenNoBeanExistsWithProvidedName(ApplicationContext ctx) { + assertThat(this.nestedNonExisting) + .satisfies(MockitoAssertions::assertIsMock) + .isSameAs(ctx.getBean("nestedNonExistingBean")); + + assertThat(this.nestedNonExisting.greeting()).as("mocked greeting").isNull(); + } + } + + + @Configuration(proxyBeanMethods = false) + static class Config { + + @Bean + ExampleService s0A() { + return new RealExampleService("prod s0A"); + } + + @Bean + ExampleService s0B() { + return new RealExampleService("prod s0B"); + } + + @Bean + ExampleService s0C() { + return new RealExampleService("prod s0C"); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoBeanByTypeLookupForConstructorParametersIntegrationRecordTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoBeanByTypeLookupForConstructorParametersIntegrationRecordTests.java new file mode 100644 index 000000000000..958b3c4ccca4 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoBeanByTypeLookupForConstructorParametersIntegrationRecordTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.bean.override.mockito.constructor; + +import org.junit.jupiter.api.Test; + +import org.springframework.test.context.bean.override.example.ExampleService; +import org.springframework.test.context.bean.override.mockito.MockitoBean; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.test.mockito.MockitoAssertions.assertIsMock; + +/** + * Integration tests for {@link MockitoBean @MockitoBean} that use by-type lookup + * on constructor parameters in a Java record. + * + * @author Sam Brannen + * @since 7.1 + * @see gh-36096 + */ +@SpringJUnitConfig +record MockitoBeanByTypeLookupForConstructorParametersIntegrationRecordTests( + @MockitoBean ExampleService exampleService) { + + @Test + void test() { + assertIsMock(this.exampleService); + + when(this.exampleService.greeting()).thenReturn("Mocked greeting"); + + assertThat(this.exampleService.greeting()).isEqualTo("Mocked greeting"); + verify(this.exampleService, times(1)).greeting(); + verifyNoMoreInteractions(this.exampleService); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoBeanByTypeLookupForConstructorParametersIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoBeanByTypeLookupForConstructorParametersIntegrationTests.java new file mode 100644 index 000000000000..c88587bd6aa3 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoBeanByTypeLookupForConstructorParametersIntegrationTests.java @@ -0,0 +1,207 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.bean.override.mockito.constructor; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.NoUniqueBeanDefinitionException; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.annotation.Order; +import org.springframework.test.context.bean.override.example.CustomQualifier; +import org.springframework.test.context.bean.override.example.ExampleService; +import org.springframework.test.context.bean.override.example.RealExampleService; +import org.springframework.test.context.bean.override.mockito.MockitoBean; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; +import org.springframework.test.mockito.MockitoAssertions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.test.mockito.MockitoAssertions.assertIsMock; + +/** + * Integration tests for {@link MockitoBean @MockitoBean} that use by-type lookup + * on constructor parameters. + * + * @author Sam Brannen + * @since 7.1 + * @see gh-36096 + * @see org.springframework.test.context.bean.override.mockito.MockitoBeanByTypeLookupIntegrationTests + */ +@SpringJUnitConfig +class MockitoBeanByTypeLookupForConstructorParametersIntegrationTests { + + final AnotherService serviceIsNotABean; + + final ExampleService anyNameForService; + + final StringBuilder ambiguous; + + final StringBuilder ambiguousMeta; + + + MockitoBeanByTypeLookupForConstructorParametersIntegrationTests( + @MockitoBean AnotherService serviceIsNotABean, + @MockitoBean ExampleService anyNameForService, + @MockitoBean @Qualifier("prefer") StringBuilder ambiguous, + @MockitoBean @CustomQualifier StringBuilder ambiguousMeta) { + + this.serviceIsNotABean = serviceIsNotABean; + this.anyNameForService = anyNameForService; + this.ambiguous = ambiguous; + this.ambiguousMeta = ambiguousMeta; + } + + + @Test + void mockIsCreatedWhenNoCandidateIsFound() { + assertIsMock(this.serviceIsNotABean); + + when(this.serviceIsNotABean.hello()).thenReturn("Mocked hello"); + + assertThat(this.serviceIsNotABean.hello()).isEqualTo("Mocked hello"); + verify(this.serviceIsNotABean, times(1)).hello(); + verifyNoMoreInteractions(this.serviceIsNotABean); + } + + @Test + void overrideIsFoundByType(ApplicationContext ctx) { + assertThat(this.anyNameForService) + .satisfies(MockitoAssertions::assertIsMock) + .isSameAs(ctx.getBean("example")) + .isSameAs(ctx.getBean(ExampleService.class)); + + when(this.anyNameForService.greeting()).thenReturn("Mocked greeting"); + + assertThat(this.anyNameForService.greeting()).isEqualTo("Mocked greeting"); + verify(this.anyNameForService, times(1)).greeting(); + verifyNoMoreInteractions(this.anyNameForService); + } + + @Test + void overrideIsFoundByTypeAndDisambiguatedByQualifier(ApplicationContext ctx) { + assertThat(this.ambiguous) + .satisfies(MockitoAssertions::assertIsMock) + .isSameAs(ctx.getBean("ambiguous2")); + + assertThatExceptionOfType(NoUniqueBeanDefinitionException.class) + .isThrownBy(() -> ctx.getBean(StringBuilder.class)) + .satisfies(ex -> assertThat(ex.getBeanNamesFound()).containsOnly("ambiguous1", "ambiguous2")); + + assertThat(this.ambiguous).isEmpty(); + assertThat(this.ambiguous.substring(0)).isNull(); + verify(this.ambiguous, times(1)).length(); + verify(this.ambiguous, times(1)).substring(anyInt()); + verifyNoMoreInteractions(this.ambiguous); + } + + @Test + void overrideIsFoundByTypeAndDisambiguatedByMetaQualifier(ApplicationContext ctx) { + assertThat(this.ambiguousMeta) + .satisfies(MockitoAssertions::assertIsMock) + .isSameAs(ctx.getBean("ambiguous1")); + + assertThatExceptionOfType(NoUniqueBeanDefinitionException.class) + .isThrownBy(() -> ctx.getBean(StringBuilder.class)) + .satisfies(ex -> assertThat(ex.getBeanNamesFound()).containsOnly("ambiguous1", "ambiguous2")); + + assertThat(this.ambiguousMeta).isEmpty(); + assertThat(this.ambiguousMeta.substring(0)).isNull(); + verify(this.ambiguousMeta, times(1)).length(); + verify(this.ambiguousMeta, times(1)).substring(anyInt()); + verifyNoMoreInteractions(this.ambiguousMeta); + } + + + @Nested + class NestedTests { + + @Autowired + ExampleService localAnyNameForService; + + final NestedService nestedService; + + + NestedTests(@MockitoBean NestedService nestedService) { + this.nestedService = nestedService; + } + + + @Test + void mockFromEnclosingClassConstructorParameterIsAccessibleViaAutowiring(ApplicationContext ctx) { + assertThat(this.localAnyNameForService) + .satisfies(MockitoAssertions::assertIsMock) + .isSameAs(anyNameForService) + .isSameAs(ctx.getBean("example")) + .isSameAs(ctx.getBean(ExampleService.class)); + } + + @Test + void nestedConstructorParameterIsAMock() { + assertIsMock(this.nestedService); + + when(this.nestedService.hello()).thenReturn("Nested hello"); + assertThat(this.nestedService.hello()).isEqualTo("Nested hello"); + verify(this.nestedService).hello(); + verifyNoMoreInteractions(this.nestedService); + } + } + + + public interface AnotherService { + + String hello(); + } + + public interface NestedService { + + String hello(); + } + + @Configuration(proxyBeanMethods = false) + static class Config { + + @Bean("example") + ExampleService bean1() { + return new RealExampleService("Production hello"); + } + + @Bean("ambiguous1") + @Order(1) + @CustomQualifier + StringBuilder bean2() { + return new StringBuilder("bean2"); + } + + @Bean("ambiguous2") + @Order(2) + @Qualifier("prefer") + StringBuilder bean3() { + return new StringBuilder("bean3"); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoSpyBeanByNameLookupForConstructorParametersIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoSpyBeanByNameLookupForConstructorParametersIntegrationTests.java new file mode 100644 index 000000000000..0852a2d41cd1 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoSpyBeanByNameLookupForConstructorParametersIntegrationTests.java @@ -0,0 +1,165 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.bean.override.mockito.constructor; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.bean.override.example.ExampleService; +import org.springframework.test.context.bean.override.example.RealExampleService; +import org.springframework.test.context.bean.override.mockito.MockitoSpyBean; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; +import org.springframework.test.mockito.MockitoAssertions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Integration tests for {@link MockitoSpyBean @MockitoSpyBean} that use by-name + * lookup on constructor parameters. + * + * @author Sam Brannen + * @since 7.1 + * @see gh-36096 + * @see org.springframework.test.context.bean.override.mockito.MockitoSpyBeanByNameLookupTestMethodScopedExtensionContextIntegrationTests + */ +@SpringJUnitConfig +class MockitoSpyBeanByNameLookupForConstructorParametersIntegrationTests { + + final ExampleService service1; + + final ExampleService service2; + + final ExampleService service3; + + + MockitoSpyBeanByNameLookupForConstructorParametersIntegrationTests( + @MockitoSpyBean ExampleService s1, + @MockitoSpyBean("s2") ExampleService service2, + @MockitoSpyBean @Qualifier("s3") ExampleService service3) { + + this.service1 = s1; + this.service2 = service2; + this.service3 = service3; + } + + + @Test + void parameterNameIsUsedAsBeanName(ApplicationContext ctx) { + assertThat(this.service1) + .satisfies(MockitoAssertions::assertIsSpy) + .isSameAs(ctx.getBean("s1")); + + assertThat(this.service1.greeting()).isEqualTo("prod 1"); + verify(this.service1).greeting(); + verifyNoMoreInteractions(this.service1); + } + + @Test + void explicitBeanNameOverridesParameterName(ApplicationContext ctx) { + assertThat(this.service2) + .satisfies(MockitoAssertions::assertIsSpy) + .isSameAs(ctx.getBean("s2")); + + assertThat(this.service2.greeting()).isEqualTo("prod 2"); + verify(this.service2).greeting(); + verifyNoMoreInteractions(this.service2); + } + + @Test + void qualifierIsUsedToResolveByName(ApplicationContext ctx) { + assertThat(this.service3) + .satisfies(MockitoAssertions::assertIsSpy) + .isSameAs(ctx.getBean("s3")); + + assertThat(this.service3.greeting()).isEqualTo("prod 3"); + verify(this.service3).greeting(); + verifyNoMoreInteractions(this.service3); + } + + + @Nested + class NestedTests { + + @Autowired + @Qualifier("s1") + ExampleService localService1; + + final ExampleService nestedSpy; + + + NestedTests(@MockitoSpyBean("s4") ExampleService nestedSpy) { + this.nestedSpy = nestedSpy; + } + + + @Test + void spyFromEnclosingClassIsAccessibleViaAutowiring(ApplicationContext ctx) { + assertThat(this.localService1) + .satisfies(MockitoAssertions::assertIsSpy) + .isSameAs(service1) + .isSameAs(ctx.getBean("s1")); + + assertThat(this.localService1.greeting()).isEqualTo("prod 1"); + verify(this.localService1).greeting(); + verifyNoMoreInteractions(this.localService1); + } + + @Test + void nestedConstructorParameterIsASpy(ApplicationContext ctx) { + assertThat(this.nestedSpy) + .satisfies(MockitoAssertions::assertIsSpy) + .isSameAs(ctx.getBean("s4")); + + assertThat(this.nestedSpy.greeting()).isEqualTo("prod 4"); + verify(this.nestedSpy).greeting(); + verifyNoMoreInteractions(this.nestedSpy); + } + } + + + @Configuration(proxyBeanMethods = false) + static class Config { + + @Bean + ExampleService s1() { + return new RealExampleService("prod 1"); + } + + @Bean + ExampleService s2() { + return new RealExampleService("prod 2"); + } + + @Bean + ExampleService s3() { + return new RealExampleService("prod 3"); + } + + @Bean + ExampleService s4() { + return new RealExampleService("prod 4"); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoSpyBeanByTypeLookupForConstructorParametersIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoSpyBeanByTypeLookupForConstructorParametersIntegrationTests.java new file mode 100644 index 000000000000..ca6b3bb6c877 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/constructor/MockitoSpyBeanByTypeLookupForConstructorParametersIntegrationTests.java @@ -0,0 +1,213 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.bean.override.mockito.constructor; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.annotation.Order; +import org.springframework.test.context.bean.override.example.CustomQualifier; +import org.springframework.test.context.bean.override.example.ExampleService; +import org.springframework.test.context.bean.override.example.RealExampleService; +import org.springframework.test.context.bean.override.mockito.MockitoSpyBean; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; +import org.springframework.test.mockito.MockitoAssertions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatException; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Integration tests for {@link MockitoSpyBean @MockitoSpyBean} that use by-type + * lookup on constructor parameters. + * + * @author Sam Brannen + * @since 7.1 + * @see gh-36096 + * @see org.springframework.test.context.bean.override.mockito.MockitoSpyBeanByTypeLookupIntegrationTests + */ +@SpringJUnitConfig +class MockitoSpyBeanByTypeLookupForConstructorParametersIntegrationTests { + + final ExampleService anyNameForService; + + final StringHolder ambiguous; + + final StringHolder ambiguousMeta; + + + MockitoSpyBeanByTypeLookupForConstructorParametersIntegrationTests( + @MockitoSpyBean ExampleService anyNameForService, + @MockitoSpyBean @Qualifier("prefer") StringHolder ambiguous, + @MockitoSpyBean @CustomQualifier StringHolder ambiguousMeta) { + + this.anyNameForService = anyNameForService; + this.ambiguous = ambiguous; + this.ambiguousMeta = ambiguousMeta; + } + + + @Test + void overrideIsFoundByType(ApplicationContext ctx) { + assertThat(this.anyNameForService) + .satisfies(MockitoAssertions::assertIsSpy) + .isSameAs(ctx.getBean("example")) + .isSameAs(ctx.getBean(ExampleService.class)); + + assertThat(this.anyNameForService.greeting()).isEqualTo("Production hello"); + verify(this.anyNameForService).greeting(); + verifyNoMoreInteractions(this.anyNameForService); + } + + @Test + void overrideIsFoundByTypeAndDisambiguatedByQualifier(ApplicationContext ctx) { + assertThat(this.ambiguous) + .satisfies(MockitoAssertions::assertIsSpy) + .isSameAs(ctx.getBean("ambiguous2")); + + assertThatException() + .isThrownBy(() -> ctx.getBean(StringHolder.class)) + .withMessageEndingWith("but found 2: ambiguous1,ambiguous2"); + + assertThat(this.ambiguous.getValue()).isEqualTo("bean3"); + assertThat(this.ambiguous.size()).isEqualTo(5); + verify(this.ambiguous).getValue(); + verify(this.ambiguous).size(); + verifyNoMoreInteractions(this.ambiguous); + } + + @Test + void overrideIsFoundByTypeAndDisambiguatedByMetaQualifier(ApplicationContext ctx) { + assertThat(this.ambiguousMeta) + .satisfies(MockitoAssertions::assertIsSpy) + .isSameAs(ctx.getBean("ambiguous1")); + + assertThatException() + .isThrownBy(() -> ctx.getBean(StringHolder.class)) + .withMessageEndingWith("but found 2: ambiguous1,ambiguous2"); + + assertThat(this.ambiguousMeta.getValue()).isEqualTo("bean2"); + assertThat(this.ambiguousMeta.size()).isEqualTo(5); + verify(this.ambiguousMeta).getValue(); + verify(this.ambiguousMeta).size(); + verifyNoMoreInteractions(this.ambiguousMeta); + } + + + @Nested + class NestedTests { + + @Autowired + ExampleService localAnyNameForService; + + final AnotherService nestedSpy; + + + NestedTests(@MockitoSpyBean AnotherService nestedSpy) { + this.nestedSpy = nestedSpy; + } + + + @Test + void spyFromEnclosingClassConstructorParameterIsAccessibleViaAutowiring(ApplicationContext ctx) { + assertThat(this.localAnyNameForService) + .satisfies(MockitoAssertions::assertIsSpy) + .isSameAs(anyNameForService) + .isSameAs(ctx.getBean("example")) + .isSameAs(ctx.getBean(ExampleService.class)); + + assertThat(this.localAnyNameForService.greeting()).isEqualTo("Production hello"); + verify(this.localAnyNameForService).greeting(); + verifyNoMoreInteractions(this.localAnyNameForService); + } + + @Test + void nestedConstructorParameterIsASpy(ApplicationContext ctx) { + assertThat(this.nestedSpy) + .satisfies(MockitoAssertions::assertIsSpy) + .isSameAs(ctx.getBean("anotherService")) + .isSameAs(ctx.getBean(AnotherService.class)); + + assertThat(this.nestedSpy.hello()).isEqualTo("Another hello"); + verify(this.nestedSpy).hello(); + verifyNoMoreInteractions(this.nestedSpy); + } + } + + + interface AnotherService { + + String hello(); + } + + static class StringHolder { + + private final String value; + + StringHolder(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + public int size() { + return this.value.length(); + } + } + + @Configuration(proxyBeanMethods = false) + static class Config { + + @Bean("example") + ExampleService bean1() { + return new RealExampleService("Production hello"); + } + + @Bean("ambiguous1") + @Order(1) + @CustomQualifier + StringHolder bean2() { + return new StringHolder("bean2"); + } + + @Bean("ambiguous2") + @Order(2) + @Qualifier("prefer") + StringHolder bean3() { + return new StringHolder("bean3"); + } + + @Bean + AnotherService anotherService() { + return new AnotherService() { + @Override + public String hello() { + return "Another hello"; + } + }; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/ConstructorService01.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/ConstructorService01.java new file mode 100644 index 000000000000..adc28db41fac --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/ConstructorService01.java @@ -0,0 +1,20 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.bean.override.mockito.typelevel; + +interface ConstructorService01 extends Service { +} diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoBeansByNameIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoBeansByNameIntegrationTests.java index 87424d767700..52f576a6a96c 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoBeansByNameIntegrationTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoBeansByNameIntegrationTests.java @@ -48,6 +48,10 @@ @MockitoBean(name = "s2", types = ExampleService.class) class MockitoBeansByNameIntegrationTests { + final ExampleService service0A; + final ExampleService service0B; + final ExampleService service0C; + @Autowired ExampleService s1; @@ -62,6 +66,16 @@ class MockitoBeansByNameIntegrationTests { ExampleService service4; + MockitoBeansByNameIntegrationTests(@MockitoBean ExampleService s0A, + @MockitoBean(name = "s0B") ExampleService service0B, + @MockitoBean @Qualifier("s0C") ExampleService service0C) { + + this.service0A = s0A; + this.service0B = service0B; + this.service0C = service0C; + } + + @BeforeEach void configureMocks() { assertIsMock(s1, "s1"); @@ -86,6 +100,21 @@ void checkMocksAndStandardBean() { @Configuration static class Config { + @Bean + ExampleService s0A() { + return () -> "prod 0A"; + } + + @Bean + ExampleService s0B() { + return () -> "prod 0B"; + } + + @Bean + ExampleService s0C() { + return () -> "prod 0C"; + } + @Bean ExampleService s1() { return () -> "prod 1"; diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoBeansByTypeIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoBeansByTypeIntegrationTests.java index e3e5207b4de7..f6f529220cc2 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoBeansByTypeIntegrationTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoBeansByTypeIntegrationTests.java @@ -63,10 +63,17 @@ class MockitoBeansByTypeIntegrationTests implements MockTestInterface01 { @Autowired Service06 service06; + final ConstructorService01 constructorService01; + @MockitoBean Service07 service07; + MockitoBeansByTypeIntegrationTests(@MockitoBean ConstructorService01 constructorService01) { + this.constructorService01 = constructorService01; + } + + @BeforeEach void configureMocks() { assertIsMock(service01, "service01"); @@ -75,6 +82,7 @@ void configureMocks() { assertIsMock(service04, "service04"); assertIsMock(service05, "service05"); assertIsMock(service06, "service06"); + assertIsMock(constructorService01, "constructorService01"); assertIsMock(service07, "service07"); given(service01.greeting()).willReturn("mock 01"); @@ -83,6 +91,7 @@ void configureMocks() { given(service04.greeting()).willReturn("mock 04"); given(service05.greeting()).willReturn("mock 05"); given(service06.greeting()).willReturn("mock 06"); + given(constructorService01.greeting()).willReturn("mock constructor 01"); given(service07.greeting()).willReturn("mock 07"); } @@ -94,6 +103,7 @@ void checkMocks() { assertThat(service04.greeting()).isEqualTo("mock 04"); assertThat(service05.greeting()).isEqualTo("mock 05"); assertThat(service06.greeting()).isEqualTo("mock 06"); + assertThat(constructorService01.greeting()).isEqualTo("mock constructor 01"); assertThat(service07.greeting()).isEqualTo("mock 07"); } @@ -133,6 +143,7 @@ void configureMocks() { assertIsMock(service04, "service04"); assertIsMock(service05, "service05"); assertIsMock(service06, "service06"); + assertIsMock(constructorService01, "constructorService01"); assertIsMock(service07, "service07"); assertIsMock(service08, "service08"); assertIsMock(service09, "service09"); @@ -157,6 +168,7 @@ void checkMocks() { assertThat(service04.greeting()).isEqualTo("mock 04"); assertThat(service05.greeting()).isEqualTo("mock 05"); assertThat(service06.greeting()).isEqualTo("mock 06"); + assertThat(constructorService01.greeting()).isEqualTo("mock constructor 01"); assertThat(service07.greeting()).isEqualTo("mock 07"); assertThat(service08.greeting()).isEqualTo("mock 08"); assertThat(service09.greeting()).isEqualTo("mock 09"); diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoBeansTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoBeansTests.java index 2f30d8a75151..6eaa8e0123c5 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoBeansTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoBeansTests.java @@ -22,7 +22,7 @@ import org.springframework.core.ResolvableType; import org.springframework.test.context.bean.override.BeanOverrideHandler; -import org.springframework.test.context.bean.override.BeanOverrideTestUtils; +import org.springframework.test.context.bean.override.BeanOverrideUtils; import org.springframework.test.context.bean.override.mockito.MockitoBean; import org.springframework.test.context.bean.override.mockito.MockitoBeans; @@ -44,7 +44,7 @@ void registrationOrderForTopLevelClass() { Stream> mockedServices = getRegisteredMockTypes(MockitoBeansByTypeIntegrationTests.class); assertThat(mockedServices).containsExactly( Service01.class, Service02.class, Service03.class, Service04.class, - Service05.class, Service06.class, Service07.class); + Service05.class, Service06.class, ConstructorService01.class, Service07.class); } @Test @@ -52,14 +52,14 @@ void registrationOrderForNestedClass() { Stream> mockedServices = getRegisteredMockTypes(MockitoBeansByTypeIntegrationTests.NestedTests.class); assertThat(mockedServices).containsExactly( Service01.class, Service02.class, Service03.class, Service04.class, - Service05.class, Service06.class, Service07.class, Service08.class, - Service09.class, Service10.class, Service11.class, Service12.class, + Service05.class, Service06.class, ConstructorService01.class, Service07.class, + Service08.class, Service09.class, Service10.class, Service11.class, Service12.class, Service13.class); } private static Stream> getRegisteredMockTypes(Class testClass) { - return BeanOverrideTestUtils.findAllHandlers(testClass) + return BeanOverrideUtils.findAllHandlers(testClass) .stream() .map(BeanOverrideHandler::getBeanType) .map(ResolvableType::getRawClass); diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoSpyBeansTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoSpyBeansTests.java index 181cc082e4ee..738211e48b92 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoSpyBeansTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/typelevel/MockitoSpyBeansTests.java @@ -22,7 +22,7 @@ import org.springframework.core.ResolvableType; import org.springframework.test.context.bean.override.BeanOverrideHandler; -import org.springframework.test.context.bean.override.BeanOverrideTestUtils; +import org.springframework.test.context.bean.override.BeanOverrideUtils; import org.springframework.test.context.bean.override.mockito.MockitoSpyBean; import org.springframework.test.context.bean.override.mockito.MockitoSpyBeans; @@ -59,7 +59,7 @@ void registrationOrderForNestedClass() { private static Stream> getRegisteredMockTypes(Class testClass) { - return BeanOverrideTestUtils.findAllHandlers(testClass) + return BeanOverrideUtils.findAllHandlers(testClass) .stream() .map(BeanOverrideHandler::getBeanType) .map(ResolvableType::getRawClass); diff --git a/spring-test/src/test/java/org/springframework/test/context/config/ImplicitDefaultConfigClassesBaseTests.java b/spring-test/src/test/java/org/springframework/test/context/config/ImplicitDefaultConfigClassesBaseTests.java index 36c2718f47a7..1828e0df7c9b 100644 --- a/spring-test/src/test/java/org/springframework/test/context/config/ImplicitDefaultConfigClassesBaseTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/config/ImplicitDefaultConfigClassesBaseTests.java @@ -46,7 +46,7 @@ class ImplicitDefaultConfigClassesBaseTests { @Test - void greeting1AndPuzzle1() { + final void greeting1AndPuzzle1() { // This class must NOT be annotated with @SpringJUnitConfig or @ContextConfiguration. assertThat(AnnotatedElementUtils.hasAnnotation(getClass(), ContextConfiguration.class)).isFalse(); diff --git a/spring-test/src/test/java/org/springframework/test/context/config/ImplicitDefaultConfigClassesInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/config/ImplicitDefaultConfigClassesInheritedTests.java index 93699c072cec..7a33a39c8827 100644 --- a/spring-test/src/test/java/org/springframework/test/context/config/ImplicitDefaultConfigClassesInheritedTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/config/ImplicitDefaultConfigClassesInheritedTests.java @@ -41,14 +41,6 @@ class ImplicitDefaultConfigClassesInheritedTests extends ImplicitDefaultConfigCl String greeting2; - // To be removed in favor of base class method in 7.1 - @Test - @Override - void greeting1AndPuzzle1() { - assertThat(greeting1).isEqualTo("TEST 2"); - assertThat(puzzle1).isEqualTo(222); - } - @Test void greeting2() { // This class must NOT be annotated with @SpringJUnitConfig or @ContextConfiguration. @@ -59,8 +51,7 @@ void greeting2() { @Test void greetings(@Autowired List greetings) { - assertThat(greetings).containsExactly("TEST 2"); - // for 7.1: assertThat(greetings).containsExactly("TEST 1", "TEST 2"); + assertThat(greetings).containsExactly("TEST 1", "TEST 2"); } diff --git a/spring-test/src/test/java/org/springframework/test/context/junit/jupiter/nested/DefaultContextConfigurationDetectionWithNestedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit/jupiter/nested/DefaultContextConfigurationDetectionWithNestedTests.java index 615c23558376..8e1d361f3420 100644 --- a/spring-test/src/test/java/org/springframework/test/context/junit/jupiter/nested/DefaultContextConfigurationDetectionWithNestedTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/junit/jupiter/nested/DefaultContextConfigurationDetectionWithNestedTests.java @@ -16,6 +16,7 @@ package org.springframework.test.context.junit.jupiter.nested; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -53,7 +54,6 @@ void test(@Autowired String localGreeting) { } - /** for 7.1: @Nested class NestedTests { @@ -63,7 +63,6 @@ void test(@Autowired String localGreeting) { assertThat(localGreeting).isEqualTo("TEST"); } } - */ @Configuration diff --git a/spring-test/src/test/java/org/springframework/test/web/client/match/MultipartRequestMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/client/match/MultipartRequestMatchersTests.java index c6bc3863728d..834aa0690b78 100644 --- a/spring-test/src/test/java/org/springframework/test/web/client/match/MultipartRequestMatchersTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/client/match/MultipartRequestMatchersTests.java @@ -25,9 +25,10 @@ import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; import org.springframework.http.HttpOutputMessage; import org.springframework.http.MediaType; -import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.http.converter.multipart.MultipartHttpMessageConverter; import org.springframework.mock.http.client.MockClientHttpRequest; import org.springframework.mock.web.MockMultipartFile; import org.springframework.util.LinkedMultiValueMap; @@ -55,6 +56,7 @@ class MultipartRequestMatchersTests { @BeforeEach void setup() { + this.request.setMethod(HttpMethod.POST); this.request.getHeaders().setContentType(MediaType.MULTIPART_FORM_DATA); } @@ -188,7 +190,7 @@ private void writeAndAssertContains() throws IOException { } private void writeForm() throws IOException { - new FormHttpMessageConverter().write(this.input, MediaType.MULTIPART_FORM_DATA, + new MultipartHttpMessageConverter().write(this.input, MediaType.MULTIPART_FORM_DATA, new HttpOutputMessage() { @Override public OutputStream getBody() throws IOException { diff --git a/spring-test/src/test/java/org/springframework/test/web/client/samples/MockMvcClientHttpRequestFactoryTests.java b/spring-test/src/test/java/org/springframework/test/web/client/samples/MockMvcClientHttpRequestFactoryTests.java index e75962a6f832..026fe1dcd307 100644 --- a/spring-test/src/test/java/org/springframework/test/web/client/samples/MockMvcClientHttpRequestFactoryTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/client/samples/MockMvcClientHttpRequestFactoryTests.java @@ -47,7 +47,7 @@ * @author Juergen Hoeller */ @SpringJUnitWebConfig -@SuppressWarnings("deprecation") +@SuppressWarnings({"deprecation", "removal"}) class MockMvcClientHttpRequestFactoryTests { private final RestTemplate template; diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/client/MockMvcRestTestClientTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/client/MockMvcRestTestClientTests.java index a3f01f5cef0c..5a88f1b147e5 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/client/MockMvcRestTestClientTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/client/MockMvcRestTestClientTests.java @@ -17,24 +17,39 @@ package org.springframework.test.web.servlet.client; import java.io.IOException; +import java.nio.charset.StandardCharsets; import jakarta.servlet.http.Cookie; import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.Test; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.core.io.Resource; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.http.converter.multipart.FilePart; +import org.springframework.http.converter.multipart.FormFieldPart; +import org.springframework.http.converter.multipart.Part; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.web.bind.annotation.CookieValue; import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests that use a {@link RestTestClient} configured with a {@link MockMvc} instance * that uses a standalone controller. * * @author Rob Worsnop * @author Sam Brannen - * @since 7.0 + * @author Brian Clozel */ class MockMvcRestTestClientTests { @@ -75,6 +90,45 @@ void withErrorAndBody() { .isEqualTo("some really bad request"); } + @Test + void retrieveMultipart() { + client.get() + .uri("/multipart") + .accept(MediaType.MULTIPART_FORM_DATA) + .exchange() + .expectStatus().isOk() + .expectBody(new ParameterizedTypeReference>() {}) + .value(result -> { + assertThat(result).hasSize(3); + assertThat(result).containsKeys("text1", "text2", "file1"); + assertThat(result.getFirst("text1")).isInstanceOfSatisfying(FormFieldPart.class, + part -> assertThat(part.value()).isEqualTo("a")); + assertThat(result.getFirst("text2")).isInstanceOfSatisfying(FormFieldPart.class, + part -> assertThat(part.value()).isEqualTo("b")); + assertThat(result.getFirst("file1")).isInstanceOfSatisfying(FilePart.class, + part -> assertThat(part.filename()).isEqualTo("file1.txt")); + }); + } + + @Test + void writeMultipart() { + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("text1", "value1"); + parts.add("file1", new ByteArrayResource("filecontent1".getBytes()) { + @Override + public String getFilename() { + return "spring.txt"; + } + }); + + client.post() + .uri("/multipart") + .contentType(MediaType.MULTIPART_FORM_DATA) + .body(parts) + .exchange() + .expectStatus().isOk(); + } + @RestController static class TestController { @@ -94,6 +148,34 @@ void handleErrorWithBody(HttpServletResponse response) throws Exception { response.sendError(400); response.getWriter().write("some really bad request"); } + + @GetMapping(path = "/multipart", produces = MediaType.MULTIPART_FORM_DATA_VALUE) + MultiValueMap writeMultipart() { + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("text1", "a"); + parts.add("text2", "b"); + Resource resource = new ByteArrayResource("Lorem ipsum dolor sit amet".getBytes()) { + @Override + public String getFilename() { + return "file1.txt"; + } + }; + parts.add("file1", resource); + return parts; + } + + @PostMapping(path = "/multipart", consumes = MediaType.MULTIPART_FORM_DATA_VALUE) + ResponseEntity readMultipart(@RequestParam MultiValueMap parts) throws Exception { + assertThat(parts.keySet()).containsOnly("text1", "file1"); + jakarta.servlet.http.Part text1 = parts.get("text1").get(0); + assertThat(text1.getName()).isEqualTo("text1"); + assertThat(text1.getInputStream()).asString(StandardCharsets.UTF_8).isEqualTo("value1"); + jakarta.servlet.http.Part file1 = parts.get("file1").get(0); + assertThat(file1.getName()).isEqualTo("file1"); + assertThat(file1.getSubmittedFileName()).isEqualTo("spring.txt"); + assertThat(file1.getInputStream()).asString(StandardCharsets.UTF_8).isEqualTo("filecontent1"); + return ResponseEntity.ok().build(); + } } } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/client/samples/XmlContentTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/client/samples/XmlContentTests.java index ee12416dc71a..c8b817f6f2a1 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/client/samples/XmlContentTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/client/samples/XmlContentTests.java @@ -31,6 +31,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; +import org.springframework.http.converter.xml.Jaxb2RootElementHttpMessageConverter; import org.springframework.test.web.servlet.client.RestTestClient; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; @@ -59,7 +60,9 @@ class XmlContentTests { """; - private final RestTestClient client = RestTestClient.bindToController(new PersonController()).build(); + private final RestTestClient client = RestTestClient.bindToController(new PersonController()) + .configureServer(server -> server.setMessageConverters(new Jaxb2RootElementHttpMessageConverter())) + .build(); @Test diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/client/standalone/resultmatches/XmlContentAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/client/standalone/resultmatches/XmlContentAssertionTests.java index ba3e9aafecad..f67c455a737f 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/client/standalone/resultmatches/XmlContentAssertionTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/client/standalone/resultmatches/XmlContentAssertionTests.java @@ -28,6 +28,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.converter.xml.Jaxb2RootElementHttpMessageConverter; import org.springframework.stereotype.Controller; import org.springframework.test.web.Person; import org.springframework.test.web.reactive.server.WebTestClient; @@ -58,6 +59,7 @@ class XmlContentAssertionTests { private final WebTestClient testClient = MockMvcWebTestClient.bindToController(new MusicController()) + .messageConverters(new Jaxb2RootElementHttpMessageConverter()) .alwaysExpect(status().isOk()) .alwaysExpect(content().contentType(MediaType.parseMediaType("application/xml;charset=UTF-8"))) .configureClient() diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/client/standalone/resultmatches/XpathAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/client/standalone/resultmatches/XpathAssertionTests.java index ca17a980995d..073ef9329bd5 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/client/standalone/resultmatches/XpathAssertionTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/client/standalone/resultmatches/XpathAssertionTests.java @@ -31,6 +31,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.converter.xml.Jaxb2RootElementHttpMessageConverter; import org.springframework.stereotype.Controller; import org.springframework.test.web.Person; import org.springframework.test.web.reactive.server.WebTestClient; @@ -60,6 +61,7 @@ class XpathAssertionTests { private final WebTestClient testClient = MockMvcWebTestClient.bindToController(new MusicController()) + .messageConverters(new Jaxb2RootElementHttpMessageConverter()) .alwaysExpect(status().isOk()) .alwaysExpect(content().contentType(MediaType.parseMediaType("application/xml;charset=UTF-8"))) .configureClient() diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XmlContentAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XmlContentAssertionTests.java index d4bb5c7be9db..677e66bb890a 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XmlContentAssertionTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XmlContentAssertionTests.java @@ -28,6 +28,7 @@ import org.junit.jupiter.api.Test; import org.springframework.http.MediaType; +import org.springframework.http.converter.xml.Jaxb2RootElementHttpMessageConverter; import org.springframework.stereotype.Controller; import org.springframework.test.web.Person; import org.springframework.test.web.servlet.MockMvc; @@ -65,6 +66,7 @@ class XmlContentAssertionTests { @BeforeEach void setup() { this.mockMvc = standaloneSetup(new MusicController()) + .setMessageConverters(new Jaxb2RootElementHttpMessageConverter()) .defaultRequest(get("/").accept(MediaType.APPLICATION_XML, MediaType.parseMediaType("application/xml;charset=UTF-8"))) .alwaysExpect(status().isOk()) .alwaysExpect(content().contentType(MediaType.parseMediaType("application/xml;charset=UTF-8"))) diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XpathAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XpathAssertionTests.java index 0bcd2f6a456c..58e0d336e4ca 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XpathAssertionTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XpathAssertionTests.java @@ -30,6 +30,7 @@ import org.junit.jupiter.api.Test; import org.springframework.http.MediaType; +import org.springframework.http.converter.xml.Jaxb2RootElementHttpMessageConverter; import org.springframework.stereotype.Controller; import org.springframework.test.web.Person; import org.springframework.test.web.servlet.MockMvc; @@ -67,6 +68,7 @@ class XpathAssertionTests { @BeforeEach void setup() throws Exception { this.mockMvc = standaloneSetup(new MusicController()) + .setMessageConverters(new Jaxb2RootElementHttpMessageConverter()) .defaultRequest(get("/").accept(MediaType.APPLICATION_XML, MediaType.parseMediaType("application/xml;charset=UTF-8"))) .alwaysExpect(status().isOk()) .alwaysExpect(content().contentType(MediaType.parseMediaType("application/xml;charset=UTF-8"))) diff --git a/spring-test/src/test/kotlin/org/springframework/test/context/bean/override/mockito/MockitoBeanByTypeLookupForConstructorParametersIntegrationKotlinDataClassTests.kt b/spring-test/src/test/kotlin/org/springframework/test/context/bean/override/mockito/MockitoBeanByTypeLookupForConstructorParametersIntegrationKotlinDataClassTests.kt new file mode 100644 index 000000000000..537c0c11d2b2 --- /dev/null +++ b/spring-test/src/test/kotlin/org/springframework/test/context/bean/override/mockito/MockitoBeanByTypeLookupForConstructorParametersIntegrationKotlinDataClassTests.kt @@ -0,0 +1,44 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.bean.override.mockito + +import org.junit.jupiter.api.Test + +import org.springframework.test.context.bean.override.example.ExampleService +import org.springframework.test.context.bean.override.mockito.MockitoBean +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig +import org.springframework.test.mockito.MockitoAssertions.assertIsMock + +/** + * Integration tests for [@MockitoBean][MockitoBean] that use by-type lookup + * on constructor parameters in a Kotlin data class. + * + * @author Sam Brannen + * @since 7.1 + * @see gh-36096 + * @see MockitoBeanByTypeLookupForConstructorParametersIntegrationKotlinTests + */ +@SpringJUnitConfig +data class MockitoBeanByTypeLookupForConstructorParametersIntegrationKotlinDataClassTests( + @MockitoBean val exampleService: ExampleService) { + + @Test + fun test() { + assertIsMock(exampleService) + } + +} diff --git a/spring-test/src/test/kotlin/org/springframework/test/context/bean/override/mockito/MockitoBeanByTypeLookupForConstructorParametersIntegrationKotlinTests.kt b/spring-test/src/test/kotlin/org/springframework/test/context/bean/override/mockito/MockitoBeanByTypeLookupForConstructorParametersIntegrationKotlinTests.kt new file mode 100644 index 000000000000..b1b9c5f9391a --- /dev/null +++ b/spring-test/src/test/kotlin/org/springframework/test/context/bean/override/mockito/MockitoBeanByTypeLookupForConstructorParametersIntegrationKotlinTests.kt @@ -0,0 +1,44 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.bean.override.mockito + +import org.junit.jupiter.api.Test + +import org.springframework.test.context.bean.override.example.ExampleService +import org.springframework.test.context.bean.override.mockito.MockitoBean +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig +import org.springframework.test.mockito.MockitoAssertions.assertIsMock + +/** + * Integration tests for [@MockitoBean][MockitoBean] that use by-type lookup + * on constructor parameters in Kotlin. + * + * @author Sam Brannen + * @since 7.1 + * @see gh-36096 + * @see org.springframework.test.context.bean.override.mockito.MockitoBeanByNameLookupTestMethodScopedExtensionContextIntegrationTests + */ +@SpringJUnitConfig +class MockitoBeanByTypeLookupForConstructorParametersIntegrationKotlinTests( + @MockitoBean val exampleService: ExampleService) { + + @Test + fun test() { + assertIsMock(exampleService) + } + +} diff --git a/spring-test/src/test/kotlin/org/springframework/test/web/servlet/MockMvcExtensionsTests.kt b/spring-test/src/test/kotlin/org/springframework/test/web/servlet/MockMvcExtensionsTests.kt index f2b6fcb19157..567529466250 100644 --- a/spring-test/src/test/kotlin/org/springframework/test/web/servlet/MockMvcExtensionsTests.kt +++ b/spring-test/src/test/kotlin/org/springframework/test/web/servlet/MockMvcExtensionsTests.kt @@ -28,6 +28,8 @@ import org.springframework.http.MediaType.APPLICATION_ATOM_XML import org.springframework.http.MediaType.APPLICATION_JSON import org.springframework.http.MediaType.APPLICATION_XML import org.springframework.http.MediaType.TEXT_PLAIN +import org.springframework.http.converter.json.JacksonJsonHttpMessageConverter +import org.springframework.http.converter.xml.Jaxb2RootElementHttpMessageConverter import org.springframework.test.json.JsonCompareMode import org.springframework.test.web.Person import org.springframework.test.web.servlet.setup.MockMvcBuilders @@ -52,7 +54,9 @@ import java.util.Locale */ class MockMvcExtensionsTests { - private val mockMvc = MockMvcBuilders.standaloneSetup(PersonController()).build() + private val mockMvc = MockMvcBuilders.standaloneSetup(PersonController()) + .setMessageConverters(JacksonJsonHttpMessageConverter(), Jaxb2RootElementHttpMessageConverter()) + .build() @Test fun request() { diff --git a/spring-web/src/main/java/org/springframework/http/HttpMethod.java b/spring-web/src/main/java/org/springframework/http/HttpMethod.java index 856031d98e64..422abf282031 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpMethod.java +++ b/spring-web/src/main/java/org/springframework/http/HttpMethod.java @@ -17,6 +17,7 @@ package org.springframework.http; import java.io.Serializable; +import java.util.Locale; import org.jspecify.annotations.Nullable; @@ -29,6 +30,7 @@ * * @author Arjen Poutsma * @author Juergen Hoeller + * @author Sam Brannen * @since 3.0 */ public final class HttpMethod implements Comparable, Serializable { @@ -110,20 +112,23 @@ public static HttpMethod[] values() { /** * Return an {@code HttpMethod} object for the given value. - *

    Note that this lookup is case-sensitive. For predefined constants, - * the method value must be provided in uppercase (e.g., {@code "GET"}, - * {@code "POST"}). If no predefined constant matches, a new {@code HttpMethod} - * instance is returned for the given value as-is. For example, - * {@code HttpMethod.valueOf("GET")} resolves to {@link HttpMethod#GET}, while - * {@code HttpMethod.valueOf("get")} resolves to {@code new HttpMethod("get")}, - * and the two resulting {@code HttpMethod} instances are not + *

    As of Spring Framework 7.1, lookups for predefined constants such as + * {@link HttpMethod#GET GET} are case-insensitive. + *

    If no predefined constant matches, a new {@code HttpMethod} instance is + * returned for the given value as-is. + *

    For example, {@code HttpMethod.valueOf("GET")} and + * {@code HttpMethod.valueOf("get")} both resolve to {@link HttpMethod#GET}. + * Whereas, {@code HttpMethod.valueOf("FOO")} and + * {@code HttpMethod.valueOf("foo")} resolve to {@code new HttpMethod("FOO")} + * and {@code new HttpMethod("foo")}, respectively. In the latter case, the + * two resulting {@code HttpMethod} instances are not * {@linkplain #equals(Object) equal} and do not {@linkplain #matches(String) match}. * @param method the method value as a String * @return the corresponding {@code HttpMethod} */ public static HttpMethod valueOf(String method) { Assert.notNull(method, "Method must not be null"); - return switch (method) { + return switch (method.toUpperCase(Locale.ROOT)) { case "GET" -> GET; case "HEAD" -> HEAD; case "POST" -> POST; diff --git a/spring-web/src/main/java/org/springframework/http/MediaType.java b/spring-web/src/main/java/org/springframework/http/MediaType.java index 1d7ecf7fff9a..578b7882241f 100644 --- a/spring-web/src/main/java/org/springframework/http/MediaType.java +++ b/spring-web/src/main/java/org/springframework/http/MediaType.java @@ -209,6 +209,19 @@ public class MediaType extends MimeType implements Serializable { */ public static final String APPLICATION_NDJSON_VALUE = "application/x-ndjson"; + /** + * Media type for {@code application/jsonl} (JSON Lines). + * @since 7.1 + * @see JSON Lines + */ + public static final MediaType APPLICATION_JSONL; + + /** + * A String equivalent of {@link MediaType#APPLICATION_JSONL}. + * @since 7.1 + */ + public static final String APPLICATION_JSONL_VALUE = "application/jsonl"; + /** * Media type for {@code application/xhtml+xml}. */ @@ -372,6 +385,7 @@ public class MediaType extends MimeType implements Serializable { APPLICATION_GRAPHQL_RESPONSE = new MediaType("application", "graphql-response+json"); APPLICATION_JSON = new MediaType("application", "json"); APPLICATION_NDJSON = new MediaType("application", "x-ndjson"); + APPLICATION_JSONL = new MediaType("application", "jsonl"); APPLICATION_OCTET_STREAM = new MediaType("application", "octet-stream"); APPLICATION_PDF = new MediaType("application", "pdf"); APPLICATION_PROBLEM_JSON = new MediaType("application", "problem+json"); diff --git a/spring-web/src/main/java/org/springframework/http/client/support/HttpAccessor.java b/spring-web/src/main/java/org/springframework/http/client/support/HttpAccessor.java index 7cf850c52dd1..c42b584916ef 100644 --- a/spring-web/src/main/java/org/springframework/http/client/support/HttpAccessor.java +++ b/spring-web/src/main/java/org/springframework/http/client/support/HttpAccessor.java @@ -50,7 +50,10 @@ * @since 3.0 * @see ClientHttpRequestFactory * @see org.springframework.web.client.RestTemplate + * @deprecated since 7.1 with no replacement. */ +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public abstract class HttpAccessor { /** Logger available to subclasses. */ diff --git a/spring-web/src/main/java/org/springframework/http/client/support/InterceptingHttpAccessor.java b/spring-web/src/main/java/org/springframework/http/client/support/InterceptingHttpAccessor.java index a4436fb391eb..75000affea3a 100644 --- a/spring-web/src/main/java/org/springframework/http/client/support/InterceptingHttpAccessor.java +++ b/spring-web/src/main/java/org/springframework/http/client/support/InterceptingHttpAccessor.java @@ -42,7 +42,10 @@ * @see ClientHttpRequestInterceptor * @see InterceptingClientHttpRequestFactory * @see org.springframework.web.client.RestTemplate + * @deprecated since 7.1 with no replacement. */ +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public abstract class InterceptingHttpAccessor extends HttpAccessor { private final List interceptors = new ArrayList<>(); diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/GsonEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/GsonEncoder.java index 7ce4130ee31b..1c80af93bc70 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/json/GsonEncoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/json/GsonEncoder.java @@ -58,7 +58,8 @@ public class GsonEncoder extends AbstractEncoder implements HttpMessageE private static final MimeType[] DEFAULT_JSON_MIME_TYPES = new MimeType[] { MediaType.APPLICATION_JSON, new MediaType("application", "*+json"), - MediaType.APPLICATION_NDJSON + MediaType.APPLICATION_NDJSON, + MediaType.APPLICATION_JSONL }; private final Gson gson; @@ -68,11 +69,12 @@ public class GsonEncoder extends AbstractEncoder implements HttpMessageE /** * Construct a new encoder using a default {@link Gson} instance * and the {@code "application/json"} and {@code "application/*+json"} - * MIME types. The {@code "application/x-ndjson"} is configured for streaming. + * MIME types. The {@code "application/jsonl"} and {@code "application/x-ndjson"} + * are configured for streaming. */ public GsonEncoder() { this(new Gson(), DEFAULT_JSON_MIME_TYPES); - setStreamingMediaTypes(List.of(MediaType.APPLICATION_NDJSON)); + setStreamingMediaTypes(List.of(MediaType.APPLICATION_NDJSON, MediaType.APPLICATION_JSONL)); } /** diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2CodecSupport.java b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2CodecSupport.java index b9f0355cbbb1..fbcebe25f236 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2CodecSupport.java +++ b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2CodecSupport.java @@ -81,7 +81,8 @@ public abstract class Jackson2CodecSupport { private static final List defaultMimeTypes = List.of( MediaType.APPLICATION_JSON, new MediaType("application", "*+json"), - MediaType.APPLICATION_NDJSON); + MediaType.APPLICATION_NDJSON, + MediaType.APPLICATION_JSONL); protected final Log logger = HttpLogging.forLogName(getClass()); diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2JsonEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2JsonEncoder.java index b5db02e692ec..274780fda318 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2JsonEncoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2JsonEncoder.java @@ -63,7 +63,7 @@ public Jackson2JsonEncoder() { public Jackson2JsonEncoder(ObjectMapper mapper, MimeType... mimeTypes) { super(mapper, mimeTypes); - setStreamingMediaTypes(Arrays.asList(MediaType.APPLICATION_NDJSON)); + setStreamingMediaTypes(Arrays.asList(MediaType.APPLICATION_NDJSON, MediaType.APPLICATION_JSONL)); this.ssePrettyPrinter = initSsePrettyPrinter(); } diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/JacksonJsonDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/JacksonJsonDecoder.java index 73cb908d6fe1..5d0dd45e4f70 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/json/JacksonJsonDecoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/json/JacksonJsonDecoder.java @@ -55,7 +55,8 @@ public class JacksonJsonDecoder extends AbstractJacksonDecoder { private static final MimeType[] DEFAULT_JSON_MIME_TYPES = new MimeType[] { MediaType.APPLICATION_JSON, new MediaType("application", "*+json"), - MediaType.APPLICATION_NDJSON + MediaType.APPLICATION_NDJSON, + MediaType.APPLICATION_JSONL }; diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/JacksonJsonEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/JacksonJsonEncoder.java index e8a14d01b030..57b55106834e 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/json/JacksonJsonEncoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/json/JacksonJsonEncoder.java @@ -55,7 +55,8 @@ public class JacksonJsonEncoder extends AbstractJacksonEncoder { private static final MimeType[] DEFAULT_JSON_MIME_TYPES = new MimeType[] { MediaType.APPLICATION_JSON, new MediaType("application", "*+json"), - MediaType.APPLICATION_NDJSON + MediaType.APPLICATION_NDJSON, + MediaType.APPLICATION_JSONL }; @@ -99,7 +100,7 @@ public JacksonJsonEncoder(JsonMapper mapper) { */ public JacksonJsonEncoder(JsonMapper.Builder builder, MimeType... mimeTypes) { super(builder.addMixIn(ProblemDetail.class, ProblemDetailJacksonMixin.class), mimeTypes); - setStreamingMediaTypes(List.of(MediaType.APPLICATION_NDJSON)); + setStreamingMediaTypes(List.of(MediaType.APPLICATION_NDJSON, MediaType.APPLICATION_JSONL)); this.ssePrettyPrinter = initSsePrettyPrinter(); } @@ -110,7 +111,7 @@ public JacksonJsonEncoder(JsonMapper.Builder builder, MimeType... mimeTypes) { */ public JacksonJsonEncoder(JsonMapper mapper, MimeType... mimeTypes) { super(mapper, mimeTypes); - setStreamingMediaTypes(List.of(MediaType.APPLICATION_NDJSON)); + setStreamingMediaTypes(List.of(MediaType.APPLICATION_NDJSON, MediaType.APPLICATION_JSONL)); this.ssePrettyPrinter = initSsePrettyPrinter(); } diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/KotlinSerializationJsonDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/KotlinSerializationJsonDecoder.java index 7198e1992b55..c0b1c1239947 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/json/KotlinSerializationJsonDecoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/json/KotlinSerializationJsonDecoder.java @@ -67,7 +67,8 @@ public class KotlinSerializationJsonDecoder extends KotlinSerializationStringDec private static final MimeType[] DEFAULT_JSON_MIME_TYPES = new MimeType[] { MediaType.APPLICATION_JSON, new MediaType("application", "*+json"), - MediaType.APPLICATION_NDJSON + MediaType.APPLICATION_NDJSON, + MediaType.APPLICATION_JSONL }; /** diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/KotlinSerializationJsonEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/KotlinSerializationJsonEncoder.java index 4c3e22ec054d..0701615a6ea8 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/json/KotlinSerializationJsonEncoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/json/KotlinSerializationJsonEncoder.java @@ -36,7 +36,7 @@ /** * Encode from an {@code Object} stream to a byte stream of JSON objects using * kotlinx.serialization. - * It supports {@code application/json}, {@code application/x-ndjson} and {@code application/*+json} with + * It supports {@code application/json}, {@code application/x-ndjson}, {@code application/jsonl} and {@code application/*+json} with * various character sets, {@code UTF-8} being the default. * *

    As of Spring Framework 7.0, by default it only encodes types annotated with @@ -59,7 +59,8 @@ public class KotlinSerializationJsonEncoder extends KotlinSerializationStringEnc private static final MimeType[] DEFAULT_JSON_MIME_TYPES = new MimeType[] { MediaType.APPLICATION_JSON, new MediaType("application", "*+json"), - MediaType.APPLICATION_NDJSON + MediaType.APPLICATION_NDJSON, + MediaType.APPLICATION_JSONL }; /** @@ -87,7 +88,7 @@ public KotlinSerializationJsonEncoder(Predicate typePredicate) { */ public KotlinSerializationJsonEncoder(Json json) { super(json, DEFAULT_JSON_MIME_TYPES); - setStreamingMediaTypes(List.of(MediaType.APPLICATION_NDJSON)); + setStreamingMediaTypes(List.of(MediaType.APPLICATION_NDJSON, MediaType.APPLICATION_JSONL)); } /** @@ -97,7 +98,7 @@ public KotlinSerializationJsonEncoder(Json json) { */ public KotlinSerializationJsonEncoder(Json json, Predicate typePredicate) { super(json, typePredicate, DEFAULT_JSON_MIME_TYPES); - setStreamingMediaTypes(List.of(MediaType.APPLICATION_NDJSON)); + setStreamingMediaTypes(List.of(MediaType.APPLICATION_NDJSON, MediaType.APPLICATION_JSONL)); } @Override diff --git a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufJsonEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufJsonEncoder.java index 73ef356bd1a5..ef42819cd3c7 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufJsonEncoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufJsonEncoder.java @@ -80,7 +80,7 @@ public ProtobufJsonEncoder(JsonFormat.Printer printer) { @Override public List getStreamingMediaTypes() { - return List.of(MediaType.APPLICATION_NDJSON); + return List.of(MediaType.APPLICATION_NDJSON, MediaType.APPLICATION_JSONL); } @Override diff --git a/spring-web/src/main/java/org/springframework/http/converter/AbstractGenericHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/AbstractGenericHttpMessageConverter.java index a72f1ab1873e..e80ddc4a330c 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/AbstractGenericHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/AbstractGenericHttpMessageConverter.java @@ -117,7 +117,7 @@ public HttpHeaders getHeaders() { } @Override public boolean repeatable() { - return supportsRepeatableWrites(t); + return canWriteRepeatedly(t, contentType); } }); } diff --git a/spring-web/src/main/java/org/springframework/http/converter/AbstractHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/AbstractHttpMessageConverter.java index c1f587eecc81..1c6b5fee0a3b 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/AbstractHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/AbstractHttpMessageConverter.java @@ -304,9 +304,11 @@ else if (MediaType.APPLICATION_OCTET_STREAM.equals(contentType)) { * @return {@code true} if {@code t} can be written repeatedly; * {@code false} otherwise * @since 6.1 + * @deprecated since 7.1 in favor of {@link #canWriteRepeatedly(Object, MediaType)}. */ + @Deprecated(since = "7.1", forRemoval = true) protected boolean supportsRepeatableWrites(T t) { - return false; + return canWriteRepeatedly(t, null); } diff --git a/spring-web/src/main/java/org/springframework/http/converter/AbstractJacksonHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/AbstractJacksonHttpMessageConverter.java index 40d6847b484b..b9f730f0dcf1 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/AbstractJacksonHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/AbstractJacksonHttpMessageConverter.java @@ -285,6 +285,11 @@ public boolean canWrite(ResolvableType type, Class valueClass, @Nullable Medi return this.mapperRegistrations == null || selectMapper(valueClass, mediaType) != null; } + @Override + public boolean canWriteRepeatedly(Object o, @Nullable MediaType contentType) { + return true; + } + /** * Select an ObjectMapper to use, either the main ObjectMapper or another * if the handling for the given Class has been customized through @@ -503,6 +508,7 @@ protected JsonEncoding getJsonEncoding(@Nullable MediaType contentType) { } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(Object o) { return true; } diff --git a/spring-web/src/main/java/org/springframework/http/converter/AbstractKotlinSerializationHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/AbstractKotlinSerializationHttpMessageConverter.java index c49057a7ace5..c0c18525b702 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/AbstractKotlinSerializationHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/AbstractKotlinSerializationHttpMessageConverter.java @@ -120,6 +120,11 @@ public boolean canWrite(ResolvableType type, Class valueClass, @Nullable Medi return serializer(resolvableType, null) != null && canWrite(mediaType); } + @Override + public boolean canWriteRepeatedly(Object o, @Nullable MediaType contentType) { + return true; + } + @Override public final Object read(ResolvableType type, HttpInputMessage inputMessage, @Nullable Map hints) throws IOException, HttpMessageNotReadableException { @@ -194,6 +199,7 @@ protected abstract void writeInternal(Object object, KSerializer seriali } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(Object object) { return true; } diff --git a/spring-web/src/main/java/org/springframework/http/converter/AbstractSmartHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/AbstractSmartHttpMessageConverter.java index 798507fa2030..84e1d02d4b16 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/AbstractSmartHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/AbstractSmartHttpMessageConverter.java @@ -108,7 +108,7 @@ public HttpHeaders getHeaders() { } @Override public boolean repeatable() { - return supportsRepeatableWrites(t); + return canWriteRepeatedly(t, contentType); } }); } diff --git a/spring-web/src/main/java/org/springframework/http/converter/ByteArrayHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/ByteArrayHttpMessageConverter.java index b924827739ea..b893b6a4972c 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/ByteArrayHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/ByteArrayHttpMessageConverter.java @@ -50,6 +50,11 @@ public boolean supports(Class clazz) { return byte[].class == clazz; } + @Override + public boolean canWriteRepeatedly(byte[] bytes, @Nullable MediaType contentType) { + return true; + } + @Override public byte[] readInternal(Class clazz, HttpInputMessage message) throws IOException { long length = message.getHeaders().getContentLength(); @@ -68,6 +73,7 @@ protected void writeInternal(byte[] bytes, HttpOutputMessage outputMessage) thro } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(byte[] bytes) { return true; } diff --git a/spring-web/src/main/java/org/springframework/http/converter/DefaultHttpMessageConverters.java b/spring-web/src/main/java/org/springframework/http/converter/DefaultHttpMessageConverters.java index f4fdd6c87341..97129ce56282 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/DefaultHttpMessageConverters.java +++ b/spring-web/src/main/java/org/springframework/http/converter/DefaultHttpMessageConverters.java @@ -34,10 +34,10 @@ import org.springframework.http.converter.json.JsonbHttpMessageConverter; import org.springframework.http.converter.json.KotlinSerializationJsonHttpMessageConverter; import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.http.converter.multipart.MultipartHttpMessageConverter; import org.springframework.http.converter.protobuf.KotlinSerializationProtobufHttpMessageConverter; import org.springframework.http.converter.smile.JacksonSmileHttpMessageConverter; import org.springframework.http.converter.smile.MappingJackson2SmileHttpMessageConverter; -import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; import org.springframework.http.converter.xml.JacksonXmlHttpMessageConverter; import org.springframework.http.converter.xml.Jaxb2RootElementHttpMessageConverter; import org.springframework.http.converter.xml.MappingJackson2XmlHttpMessageConverter; @@ -76,39 +76,39 @@ public Iterator> iterator() { abstract static class DefaultBuilder { - private static final boolean JACKSON_PRESENT; + static final boolean JACKSON_PRESENT; - private static final boolean JACKSON_2_PRESENT; + static final boolean JACKSON_2_PRESENT; - private static final boolean GSON_PRESENT; + static final boolean GSON_PRESENT; - private static final boolean JSONB_PRESENT; + static final boolean JSONB_PRESENT; - private static final boolean KOTLIN_SERIALIZATION_JSON_PRESENT; + static final boolean KOTLIN_SERIALIZATION_JSON_PRESENT; - private static final boolean JACKSON_XML_PRESENT; + static final boolean JACKSON_XML_PRESENT; - private static final boolean JACKSON_2_XML_PRESENT; + static final boolean JACKSON_2_XML_PRESENT; - private static final boolean JAXB_2_PRESENT; + static final boolean JAXB_2_PRESENT; - private static final boolean JACKSON_SMILE_PRESENT; + static final boolean JACKSON_SMILE_PRESENT; - private static final boolean JACKSON_2_SMILE_PRESENT; + static final boolean JACKSON_2_SMILE_PRESENT; - private static final boolean JACKSON_CBOR_PRESENT; + static final boolean JACKSON_CBOR_PRESENT; - private static final boolean JACKSON_2_CBOR_PRESENT; + static final boolean JACKSON_2_CBOR_PRESENT; - private static final boolean KOTLIN_SERIALIZATION_CBOR_PRESENT; + static final boolean KOTLIN_SERIALIZATION_CBOR_PRESENT; - private static final boolean JACKSON_YAML_PRESENT; + static final boolean JACKSON_YAML_PRESENT; - private static final boolean JACKSON_2_YAML_PRESENT; + static final boolean JACKSON_2_YAML_PRESENT; - private static final boolean KOTLIN_SERIALIZATION_PROTOBUF_PRESENT; + static final boolean KOTLIN_SERIALIZATION_PROTOBUF_PRESENT; - private static final boolean ROME_PRESENT; + static final boolean ROME_PRESENT; boolean registerDefaults; @@ -120,6 +120,8 @@ abstract static class DefaultBuilder { @Nullable HttpMessageConverter resourceRegionConverter; + @Nullable HttpMessageConverter formConverter; + @Nullable Consumer> configurer; @Nullable Consumer>> convertersListConfigurer; @@ -175,6 +177,11 @@ void setStringConverter(HttpMessageConverter stringConverter) { this.stringConverter = stringConverter; } + public void setFormConverter(HttpMessageConverter formConverter) { + checkConverterSupports(formConverter, MediaType.APPLICATION_FORM_URLENCODED); + this.formConverter = formConverter; + } + void setKotlinSerializationJsonConverter(HttpMessageConverter kotlinJsonConverter) { Assert.notNull(kotlinJsonConverter, "kotlinJsonConverter must not be null"); this.kotlinJsonConverter = kotlinJsonConverter; @@ -241,6 +248,9 @@ List> getBaseConverters() { if (this.stringConverter != null) { converters.add(this.stringConverter); } + if (this.formConverter != null) { + converters.add(this.formConverter); + } return converters; } @@ -289,6 +299,9 @@ void detectMessageConverters() { if (this.stringConverter == null) { this.stringConverter = new StringHttpMessageConverter(); } + if (this.formConverter == null) { + this.formConverter = new FormHttpMessageConverter(); + } if (this.kotlinJsonConverter == null) { if (KOTLIN_SERIALIZATION_JSON_PRESENT) { if (this.jsonConverter != null || JACKSON_PRESENT || JACKSON_2_PRESENT || GSON_PRESENT || JSONB_PRESENT) { @@ -401,6 +414,12 @@ public ClientBuilder withStringConverter(HttpMessageConverter stringConverter return this; } + @Override + public ClientBuilder withFormConverter(HttpMessageConverter formMessageConverter) { + setFormConverter(formMessageConverter); + return this; + } + @Override public ClientBuilder withKotlinSerializationJsonConverter(HttpMessageConverter kotlinSerializationJsonConverter) { setKotlinSerializationJsonConverter(kotlinSerializationJsonConverter); @@ -461,6 +480,23 @@ public ClientBuilder configureMessageConvertersList(Consumer> partConverters = new ArrayList<>(this.getCustomConverters()); List> allConverters = new ArrayList<>(this.getCustomConverters()); - partConverters.addAll(this.getCoreConverters()); allConverters.addAll(this.getBaseConverters()); if (this.resourceConverter != null) { allConverters.add(this.resourceConverter); } + if (!this.getCoreConverters().isEmpty()) { + // use separate instances of base converters for multipart + partConverters.addAll(List.of(new ByteArrayHttpMessageConverter(), + new StringHttpMessageConverter(), new ResourceHttpMessageConverter())); + partConverters.addAll(this.getCoreConverters()); + } if (!partConverters.isEmpty() || !allConverters.isEmpty()) { - allConverters.add(new AllEncompassingFormHttpMessageConverter(partConverters)); + allConverters.add(new MultipartHttpMessageConverter(partConverters)); } allConverters.addAll(this.getCoreConverters()); if (this.convertersListConfigurer != null) { @@ -509,6 +550,12 @@ public ServerBuilder withStringConverter(HttpMessageConverter stringConverter return this; } + @Override + public ServerBuilder withFormConverter(HttpMessageConverter formMessageConverter) { + setFormConverter(formMessageConverter); + return this; + } + @Override public ServerBuilder withKotlinSerializationJsonConverter(HttpMessageConverter kotlinSerializationJsonConverter) { setKotlinSerializationJsonConverter(kotlinSerializationJsonConverter); @@ -569,6 +616,20 @@ public ServerBuilder configureMessageConvertersList(Consumer> partConverters = new ArrayList<>(this.getCustomConverters()); List> allConverters = new ArrayList<>(this.getCustomConverters()); - partConverters.addAll(this.getCoreConverters()); allConverters.addAll(this.getBaseConverters()); if (this.resourceConverter != null) { allConverters.add(this.resourceConverter); @@ -586,8 +646,14 @@ public HttpMessageConverters build() { if (this.resourceRegionConverter != null) { allConverters.add(this.resourceRegionConverter); } + if (!this.getCoreConverters().isEmpty()) { + // use separate instances of base converters for multipart + partConverters.addAll(List.of(new ByteArrayHttpMessageConverter(), + new StringHttpMessageConverter(), new ResourceHttpMessageConverter())); + partConverters.addAll(this.getCoreConverters()); + } if (!partConverters.isEmpty() || !allConverters.isEmpty()) { - allConverters.add(new AllEncompassingFormHttpMessageConverter(partConverters)); + allConverters.add(new MultipartHttpMessageConverter(partConverters)); } allConverters.addAll(this.getCoreConverters()); if (this.convertersListConfigurer != null) { diff --git a/spring-web/src/main/java/org/springframework/http/converter/FormHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/FormHttpMessageConverter.java index fd3ad3012f16..a24e1aae8830 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/FormHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/FormHttpMessageConverter.java @@ -16,26 +16,17 @@ package org.springframework.http.converter; -import java.io.FilterOutputStream; import java.io.IOException; -import java.io.OutputStream; import java.net.URLDecoder; import java.net.URLEncoder; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collections; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import org.jspecify.annotations.Nullable; -import org.springframework.core.io.Resource; -import org.springframework.http.ContentDisposition; -import org.springframework.http.HttpEntity; -import org.springframework.http.HttpHeaders; +import org.springframework.core.ResolvableType; import org.springframework.http.HttpInputMessage; import org.springframework.http.HttpOutputMessage; import org.springframework.http.MediaType; @@ -43,297 +34,99 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MimeTypeUtils; import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; import org.springframework.util.StreamUtils; import org.springframework.util.StringUtils; /** - * Implementation of {@link HttpMessageConverter} to read and write 'normal' HTML - * forms and also to write (but not read) multipart data (for example, file uploads). + * Implementation of {@link HttpMessageConverter} to read and + * write URL encoded forms. For multipart support, see the + * {@link org.springframework.http.converter.multipart.MultipartHttpMessageConverter}. * - *

    In other words, this converter can read and write the + *

    This converter can read and write the * {@code "application/x-www-form-urlencoded"} media type as - * {@link MultiValueMap MultiValueMap<String, String>}, and it can also - * write (but not read) the {@code "multipart/form-data"} and - * {@code "multipart/mixed"} media types as - * {@link MultiValueMap MultiValueMap<String, Object>}. - * - *

    Multipart Data

    - * - *

    By default, {@code "multipart/form-data"} is used as the content type when - * {@linkplain #write writing} multipart data. It is also possible to write - * multipart data using other multipart subtypes such as {@code "multipart/mixed"} - * and {@code "multipart/related"}, as long as the multipart subtype is registered - * as a {@linkplain #getSupportedMediaTypes supported media type} and the - * desired multipart subtype is specified as the content type when - * {@linkplain #write writing} the multipart data. Note that {@code "multipart/mixed"} - * is registered as a supported media type by default. - * - *

    When writing multipart data, this converter uses other - * {@link HttpMessageConverter HttpMessageConverters} to write the respective - * MIME parts. By default, basic converters are registered for byte array, - * {@code String}, and {@code Resource}. These can be overridden via - * {@link #setPartConverters} or augmented via {@link #addPartConverter}. + * {@link MultiValueMap MultiValueMap<String, String>}. * *

    Examples

    * *

    The following snippet shows how to submit an HTML form using the - * {@code "multipart/form-data"} content type. + * {@code "application/x-www-form-urlencoded"} content type. * *

      * RestClient restClient = RestClient.create();
    - * // AllEncompassingFormHttpMessageConverter is configured by default
      *
    - * MultiValueMap<String, Object> form = new LinkedMultiValueMap<>();
    + * MultiValueMap<String, String> form = new LinkedMultiValueMap<>();
      * form.add("field 1", "value 1");
      * form.add("field 2", "value 2");
      * form.add("field 2", "value 3");
    - * form.add("field 3", 4);  // non-String form values supported as of 5.1.4
    - *
    - * ResponseEntity<Void> response = restClient.post()
    - *   .uri("https://example.com/myForm")
    - *   .contentType(MULTIPART_FORM_DATA)
    - *   .body(form)
    - *   .retrieve()
    - *   .toBodilessEntity();
    - * - *

    The following snippet shows how to do a file upload using the - * {@code "multipart/form-data"} content type. - * - *

    - * MultiValueMap<String, Object> parts = new LinkedMultiValueMap<>();
    - * parts.add("field 1", "value 1");
    - * parts.add("file", new ClassPathResource("myFile.jpg"));
    - *
    - * ResponseEntity<Void> response = restClient.post()
    - *   .uri("https://example.com/myForm")
    - *   .contentType(MULTIPART_FORM_DATA)
    - *   .body(parts)
    - *   .retrieve()
    - *   .toBodilessEntity();
    - * - *

    The following snippet shows how to do a file upload using the - * {@code "multipart/mixed"} content type. - * - *

    - * MultiValueMap<String, Object> parts = new LinkedMultiValueMap<>();
    - * parts.add("field 1", "value 1");
    - * parts.add("file", new ClassPathResource("myFile.jpg"));
    - *
    - * ResponseEntity<Void> response = restClient.post()
    - *   .uri("https://example.com/myForm")
    - *   .contentType(MULTIPART_MIXED)
    - *   .body(form)
    - *   .retrieve()
    - *   .toBodilessEntity();
    - * - *

    The following snippet shows how to do a file upload using the - * {@code "multipart/related"} content type. - * - *

    - * restClient = restClient.mutate()
    - *   .messageConverters(l -> l.stream()
    -  *    .filter(FormHttpMessageConverter.class::isInstance)
    - *     .map(FormHttpMessageConverter.class::cast)
    - *     .findFirst()
    - *     .orElseThrow(() -> new IllegalStateException("Failed to find FormHttpMessageConverter"))
    - *     .addSupportedMediaTypes(MULTIPART_RELATED);
    - *
    - * MultiValueMap<String, Object> parts = new LinkedMultiValueMap<>();
    - * parts.add("field 1", "value 1");
    - * parts.add("file", new ClassPathResource("myFile.jpg"));
    + * form.add("field 3", 4);
      *
      * ResponseEntity<Void> response = restClient.post()
      *   .uri("https://example.com/myForm")
    - *   .contentType(MULTIPART_RELATED)
    + *   .contentType(MediaType.APPLICATION_FORM_URLENCODED)
      *   .body(form)
      *   .retrieve()
      *   .toBodilessEntity();
    * - *

    Miscellaneous

    - * - *

    Some methods in this class were inspired by - * {@code org.apache.commons.httpclient.methods.multipart.MultipartRequestEntity}. - * * @author Arjen Poutsma * @author Rossen Stoyanchev * @author Juergen Hoeller * @author Sam Brannen + * @author Brian Clozel * @since 3.0 - * @see org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter * @see org.springframework.util.MultiValueMap */ -public class FormHttpMessageConverter implements HttpMessageConverter> { +public class FormHttpMessageConverter implements SmartHttpMessageConverter { /** The default charset used by the converter. */ public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + private static final ResolvableType MULTIVALUE_TYPE = + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class); - private List supportedMediaTypes = new ArrayList<>(); - - private List> partConverters = new ArrayList<>(); + private static final ResolvableType MAP_TYPE = + ResolvableType.forClassWithGenerics(Map.class, String.class, String.class); private Charset charset = DEFAULT_CHARSET; - private @Nullable Charset multipartCharset; - - - public FormHttpMessageConverter() { - this.supportedMediaTypes.add(MediaType.APPLICATION_FORM_URLENCODED); - this.supportedMediaTypes.add(MediaType.MULTIPART_FORM_DATA); - this.supportedMediaTypes.add(MediaType.MULTIPART_MIXED); - this.supportedMediaTypes.add(MediaType.MULTIPART_RELATED); - - this.partConverters.add(new ByteArrayHttpMessageConverter()); - this.partConverters.add(new StringHttpMessageConverter()); - this.partConverters.add(new ResourceHttpMessageConverter()); - - applyDefaultCharset(); - } - - - /** - * Set the list of {@link MediaType} objects supported by this converter. - * @see #addSupportedMediaTypes(MediaType...) - * @see #getSupportedMediaTypes() - */ - public void setSupportedMediaTypes(List supportedMediaTypes) { - Assert.notNull(supportedMediaTypes, "'supportedMediaTypes' must not be null"); - // Ensure internal list is mutable. - this.supportedMediaTypes = new ArrayList<>(supportedMediaTypes); - } - - /** - * Add {@link MediaType} objects to be supported by this converter. - *

    The supplied {@code MediaType} objects will be appended to the list - * of {@linkplain #getSupportedMediaTypes() supported MediaType objects}. - * @param supportedMediaTypes a var-args list of {@code MediaType} objects to add - * @since 5.2 - * @see #setSupportedMediaTypes(List) - */ - public void addSupportedMediaTypes(MediaType... supportedMediaTypes) { - Assert.notNull(supportedMediaTypes, "'supportedMediaTypes' must not be null"); - Assert.noNullElements(supportedMediaTypes, "'supportedMediaTypes' must not contain null elements"); - Collections.addAll(this.supportedMediaTypes, supportedMediaTypes); - } /** * {@inheritDoc} - * @see #setSupportedMediaTypes(List) - * @see #addSupportedMediaTypes(MediaType...) */ @Override public List getSupportedMediaTypes() { - return Collections.unmodifiableList(this.supportedMediaTypes); - } - - /** - * Set the message body converters to use. These converters are used to - * convert objects to MIME parts. - */ - public void setPartConverters(List> partConverters) { - Assert.notEmpty(partConverters, "'partConverters' must not be empty"); - this.partConverters = partConverters; - } - - /** - * Return the {@linkplain #setPartConverters configured converters} for MIME - * parts. - * @since 5.3 - */ - public List> getPartConverters() { - return Collections.unmodifiableList(this.partConverters); - } - - /** - * Add a message body converter. Such a converter is used to convert objects - * to MIME parts. - */ - public void addPartConverter(HttpMessageConverter partConverter) { - Assert.notNull(partConverter, "'partConverter' must not be null"); - this.partConverters.add(partConverter); + return List.of(MediaType.APPLICATION_FORM_URLENCODED); } /** * Set the default character set to use for reading and writing form data when * the request or response {@code Content-Type} header does not explicitly * specify it. - *

    As of 4.3, this is also used as the default charset for the conversion - * of text bodies in a multipart request. - *

    As of 5.0, this is also used for part headers including - * {@code Content-Disposition} (and its filename parameter) unless (the mutually - * exclusive) {@link #setMultipartCharset multipartCharset} is also set, in - * which case part headers are encoded as ASCII and filename is encoded - * with the {@code encoded-word} syntax from RFC 2047. - *

    By default this is set to "UTF-8". + *

    By default, this is set to "UTF-8". */ public void setCharset(@Nullable Charset charset) { if (charset != this.charset) { this.charset = (charset != null ? charset : DEFAULT_CHARSET); - applyDefaultCharset(); } } - /** - * Apply the configured charset as a default to registered part converters. - */ - private void applyDefaultCharset() { - for (HttpMessageConverter candidate : this.partConverters) { - if (candidate instanceof AbstractHttpMessageConverter converter) { - // Only override default charset if the converter operates with a charset to begin with... - if (converter.getDefaultCharset() != null) { - converter.setDefaultCharset(this.charset); - } - } - } - } - - /** - * Set the character set to use when writing multipart data to encode file - * names. Encoding is based on the {@code encoded-word} syntax defined in - * RFC 2047 and relies on {@code MimeUtility} from {@code jakarta.mail}. - *

    As of 5.0 by default part headers, including {@code Content-Disposition} - * (and its filename parameter) will be encoded based on the setting of - * {@link #setCharset(Charset)} or {@code UTF-8} by default. - * @since 4.1.1 - * @see Encoded-Word - */ - public void setMultipartCharset(Charset charset) { - this.multipartCharset = charset; + @Override + public boolean canRead(ResolvableType type, @Nullable MediaType mediaType) { + return canConvert(type, mediaType); } - @Override - public boolean canRead(Class clazz, @Nullable MediaType mediaType) { - if (!MultiValueMap.class.isAssignableFrom(clazz)) { - return false; - } - if (mediaType == null) { - return true; - } - for (MediaType supportedMediaType : getSupportedMediaTypes()) { - if (supportedMediaType.getType().equalsIgnoreCase("multipart")) { - // We can't read multipart, so skip this supported media type. - continue; - } - if (supportedMediaType.includes(mediaType)) { - return true; - } - } - return false; + public boolean canWrite(ResolvableType targetType, Class valueClass, @Nullable MediaType mediaType) { + return canConvert(targetType, mediaType); } - @Override - public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { - if (!MultiValueMap.class.isAssignableFrom(clazz)) { + private boolean canConvert(ResolvableType targetType, @Nullable MediaType mediaType) { + if (!Map.class.isAssignableFrom(targetType.toClass())) { return false; } - if (mediaType == null || MediaType.ALL.equals(mediaType)) { - return true; - } for (MediaType supportedMediaType : getSupportedMediaTypes()) { - if (supportedMediaType.isCompatibleWith(mediaType)) { + if (supportedMediaType.includes(mediaType)) { return true; } } @@ -341,8 +134,8 @@ public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { } @Override - public MultiValueMap read(@Nullable Class> clazz, - HttpInputMessage inputMessage) throws IOException, HttpMessageNotReadableException { + public Object read(ResolvableType type, HttpInputMessage inputMessage, @Nullable Map hints) + throws IOException, HttpMessageNotReadableException { MediaType contentType = inputMessage.getHeaders().getContentType(); Charset charset = (contentType != null && contentType.getCharset() != null ? @@ -367,46 +160,31 @@ public MultiValueMap read(@Nullable Class map, @Nullable MediaType contentType, HttpOutputMessage outputMessage) - throws IOException, HttpMessageNotWritableException { + public void write(Object data, ResolvableType type, @Nullable MediaType contentType, + HttpOutputMessage outputMessage, @Nullable Map hints) throws IOException, HttpMessageNotWritableException { - if (isMultipart(map, contentType)) { - writeMultipart((MultiValueMap) map, contentType, outputMessage); - } - else { - writeForm((MultiValueMap) map, contentType, outputMessage); - } - } + Assert.isInstanceOf(Map.class, data, "data must be of type Map or MultiValueMap"); + contentType = getFormContentType(contentType); + outputMessage.getHeaders().setContentType(contentType); + Charset charset = (contentType.getCharset() != null ? contentType.getCharset() : this.charset); - private boolean isMultipart(MultiValueMap map, @Nullable MediaType contentType) { - if (contentType != null) { - return contentType.getType().equalsIgnoreCase("multipart"); + String serializedForm = ""; + if (data instanceof MultiValueMap formData) { + serializedForm = serializeForm((MultiValueMap) formData, charset); } - for (List values : map.values()) { - for (Object value : values) { - if (value != null && !(value instanceof String)) { - return true; - } - } + else if (data instanceof Map formData) { + serializedForm = serializeForm((Map) formData, charset); } - return false; - } - - private void writeForm(MultiValueMap formData, @Nullable MediaType mediaType, - HttpOutputMessage outputMessage) throws IOException { - - mediaType = getFormContentType(mediaType); - outputMessage.getHeaders().setContentType(mediaType); - - Charset charset = (mediaType.getCharset() != null ? mediaType.getCharset() : this.charset); - - byte[] bytes = serializeForm(formData, charset).getBytes(charset); + byte[] bytes = serializedForm.getBytes(charset); outputMessage.getHeaders().setContentLength(bytes.length); if (outputMessage instanceof StreamingHttpOutputMessage streamingOutputMessage) { @@ -436,288 +214,39 @@ protected MediaType getFormContentType(@Nullable MediaType contentType) { return contentType; } - protected String serializeForm(MultiValueMap formData, Charset charset) { + protected String serializeForm(Map formData, Charset charset) { + StringBuilder builder = new StringBuilder(); + formData.forEach((name, value) -> { + if (name == null) { + Assert.isTrue(ObjectUtils.isEmpty(value), () -> "Null name in form data: " + formData); + return; + } + serializeValue(builder, name, value, charset); + }); + return builder.toString(); + } + + protected String serializeForm(MultiValueMap formData, Charset charset) { StringBuilder builder = new StringBuilder(); formData.forEach((name, values) -> { if (name == null) { Assert.isTrue(CollectionUtils.isEmpty(values), () -> "Null name in form data: " + formData); return; } - values.forEach(value -> { - if (builder.length() != 0) { - builder.append('&'); - } - builder.append(URLEncoder.encode(name, charset)); - if (value != null) { - builder.append('='); - builder.append(URLEncoder.encode(String.valueOf(value), charset)); - } - }); + values.forEach(value -> serializeValue(builder, name, value, charset)); }); - return builder.toString(); } - private void writeMultipart( - MultiValueMap parts, @Nullable MediaType contentType, HttpOutputMessage outputMessage) - throws IOException { - - // If the supplied content type is null, fall back to multipart/form-data. - // Otherwise rely on the fact that isMultipart() already verified the - // supplied content type is multipart. - if (contentType == null) { - contentType = MediaType.MULTIPART_FORM_DATA; - } - - Map parameters = new LinkedHashMap<>(contentType.getParameters().size() + 2); - parameters.putAll(contentType.getParameters()); - - byte[] boundary = generateMultipartBoundary(); - if (!isFilenameCharsetSet()) { - if (!this.charset.equals(StandardCharsets.UTF_8) && - !this.charset.equals(StandardCharsets.US_ASCII)) { - parameters.put("charset", this.charset.name()); - } - } - parameters.put("boundary", new String(boundary, StandardCharsets.US_ASCII)); - - // Add parameters to output content type - contentType = new MediaType(contentType, parameters); - outputMessage.getHeaders().setContentType(contentType); - - if (outputMessage instanceof StreamingHttpOutputMessage streamingOutputMessage) { - boolean repeatable = checkPartsRepeatable(parts); - streamingOutputMessage.setBody(new StreamingHttpOutputMessage.Body() { - @Override - public void writeTo(OutputStream outputStream) throws IOException { - FormHttpMessageConverter.this.writeParts(outputStream, parts, boundary); - writeEnd(outputStream, boundary); - } - @Override - public boolean repeatable() { - return repeatable; - } - }); - } - else { - writeParts(outputMessage.getBody(), parts, boundary); - writeEnd(outputMessage.getBody(), boundary); - } - } - - @SuppressWarnings({"unchecked", "ConstantValue"}) - private boolean checkPartsRepeatable(MultiValueMap map) { - return map.entrySet().stream().allMatch(e -> e.getValue().stream().filter(Objects::nonNull).allMatch(part -> { - HttpHeaders headers = null; - Object body = part; - if (part instanceof HttpEntity entity) { - headers = entity.getHeaders(); - body = entity.getBody(); - Assert.state(body != null, "Empty body for part '" + e.getKey() + "': " + part); - } - HttpMessageConverter converter = findConverterFor(e.getKey(), headers, body); - return (converter instanceof AbstractHttpMessageConverter ahmc && - ((AbstractHttpMessageConverter) ahmc).supportsRepeatableWrites((T) body)); - })); - } - - private @Nullable HttpMessageConverter findConverterFor( - String name, @Nullable HttpHeaders headers, Object body) { - - Class partType = body.getClass(); - MediaType contentType = (headers != null ? headers.getContentType() : null); - for (HttpMessageConverter converter : this.partConverters) { - if (converter.canWrite(partType, contentType)) { - return converter; - } + private void serializeValue(StringBuilder builder, String name, String value, Charset charset) { + if (builder.length() != 0) { + builder.append('&'); } - return null; - } - - /** - * When {@link #setMultipartCharset(Charset)} is configured (i.e. RFC 2047, - * {@code encoded-word} syntax) we need to use ASCII for part headers, or - * otherwise we encode directly using the configured {@link #setCharset(Charset)}. - */ - private boolean isFilenameCharsetSet() { - return (this.multipartCharset != null); - } - - private void writeParts(OutputStream os, MultiValueMap parts, byte[] boundary) throws IOException { - for (Map.Entry> entry : parts.entrySet()) { - String name = entry.getKey(); - for (Object part : entry.getValue()) { - if (part != null) { - writeBoundary(os, boundary); - writePart(name, getHttpEntity(part), os); - writeNewLine(os); - } - } + builder.append(URLEncoder.encode(name, charset)); + if (value != null) { + builder.append('='); + builder.append(URLEncoder.encode(value, charset)); } } - @SuppressWarnings("unchecked") - private void writePart(String name, HttpEntity partEntity, OutputStream os) throws IOException { - Object partBody = partEntity.getBody(); - Assert.state(partBody != null, "Empty body for part '" + name + "': " + partEntity); - HttpHeaders partHeaders = partEntity.getHeaders(); - MediaType partContentType = partHeaders.getContentType(); - HttpMessageConverter converter = findConverterFor(name, partHeaders, partBody); - if (converter != null) { - Charset charset = isFilenameCharsetSet() ? StandardCharsets.US_ASCII : this.charset; - HttpOutputMessage multipartMessage = new MultipartHttpOutputMessage(os, charset); - String filename = getFilename(partBody); - ContentDisposition.Builder cd = ContentDisposition.formData().name(name); - if (filename != null) { - cd.filename(filename, this.multipartCharset); - } - multipartMessage.getHeaders().setContentDisposition(cd.build()); - if (!partHeaders.isEmpty()) { - multipartMessage.getHeaders().putAll(partHeaders); - } - ((HttpMessageConverter) converter).write(partBody, partContentType, multipartMessage); - return; - } - throw new HttpMessageNotWritableException("Could not write request: " + - "no suitable HttpMessageConverter found for request type [" + partBody.getClass().getName() + "]"); - } - - /** - * Generate a multipart boundary. - *

    This implementation delegates to - * {@link MimeTypeUtils#generateMultipartBoundary()}. - */ - protected byte[] generateMultipartBoundary() { - return MimeTypeUtils.generateMultipartBoundary(); - } - - /** - * Return an {@link HttpEntity} for the given part Object. - * @param part the part to return an {@link HttpEntity} for - * @return the part Object itself it is an {@link HttpEntity}, - * or a newly built {@link HttpEntity} wrapper for that part - */ - protected HttpEntity getHttpEntity(Object part) { - return (part instanceof HttpEntity httpEntity ? httpEntity : new HttpEntity<>(part)); - } - - /** - * Return the filename of the given multipart part. This value will be used for the - * {@code Content-Disposition} header. - *

    The default implementation returns {@link Resource#getFilename()} if the part is a - * {@code Resource}, and {@code null} in other cases. Can be overridden in subclasses. - * @param part the part to determine the file name for - * @return the filename, or {@code null} if not known - */ - protected @Nullable String getFilename(Object part) { - if (part instanceof Resource resource) { - return resource.getFilename(); - } - else { - return null; - } - } - - - private void writeBoundary(OutputStream os, byte[] boundary) throws IOException { - os.write('-'); - os.write('-'); - os.write(boundary); - writeNewLine(os); - } - - private static void writeEnd(OutputStream os, byte[] boundary) throws IOException { - os.write('-'); - os.write('-'); - os.write(boundary); - os.write('-'); - os.write('-'); - writeNewLine(os); - } - - private static void writeNewLine(OutputStream os) throws IOException { - os.write('\r'); - os.write('\n'); - } - - - /** - * Implementation of {@link org.springframework.http.HttpOutputMessage} used - * to write a MIME multipart. - */ - private static class MultipartHttpOutputMessage implements HttpOutputMessage { - - private final OutputStream outputStream; - - private final Charset charset; - - private final HttpHeaders headers = new HttpHeaders(); - - private boolean headersWritten = false; - - public MultipartHttpOutputMessage(OutputStream outputStream, Charset charset) { - this.outputStream = new MultipartOutputStream(outputStream); - this.charset = charset; - } - - @Override - public HttpHeaders getHeaders() { - return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); - } - - @Override - public OutputStream getBody() throws IOException { - writeHeaders(); - return this.outputStream; - } - - private void writeHeaders() throws IOException { - if (!this.headersWritten) { - for (Map.Entry> entry : this.headers.headerSet()) { - byte[] headerName = getBytes(entry.getKey()); - for (String headerValueString : entry.getValue()) { - byte[] headerValue = getBytes(headerValueString); - this.outputStream.write(headerName); - this.outputStream.write(':'); - this.outputStream.write(' '); - this.outputStream.write(headerValue); - writeNewLine(this.outputStream); - } - } - writeNewLine(this.outputStream); - this.headersWritten = true; - } - } - - private byte[] getBytes(String name) { - return name.getBytes(this.charset); - } - - } - - - /** - * OutputStream that neither flushes nor closes. - */ - private static class MultipartOutputStream extends FilterOutputStream { - - public MultipartOutputStream(OutputStream out) { - super(out); - } - - @Override - public void write(byte[] b, int off, int let) throws IOException { - this.out.write(b, off, let); - } - - @Override - public void flush() { - } - - @Override - public void close() { - } - } - - } diff --git a/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverter.java index f2419d004a69..b7adaeb34929 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverter.java @@ -108,4 +108,22 @@ T read(Class clazz, HttpInputMessage inputMessage) void write(T t, @Nullable MediaType contentType, HttpOutputMessage outputMessage) throws IOException, HttpMessageNotWritableException; + /** + * Indicates whether this message converter can + * {@linkplain #write(Object, MediaType, HttpOutputMessage) write} the + * given payload multiple times. + *

    This can be used by HTTP client libraries to know whether a message can be + * sent again, for example after an HTTP redirect. The default implementation + * returns {@code false}. This typically returns false if the payload can be read + * only once. + * @param t the object t + * @param contentType the content type to use when writing. + * @return {@code true} if {@code t} can be written repeatedly; + * {@code false} otherwise + * @since 7.1 + */ + default boolean canWriteRepeatedly(T t, @Nullable MediaType contentType) { + return false; + } + } diff --git a/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverters.java b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverters.java index c7b96340f0e2..7b6104f35a22 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverters.java +++ b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverters.java @@ -114,6 +114,14 @@ interface Builder> { */ T withStringConverter(HttpMessageConverter stringMessageConverter); + /** + * Override the default {@code HttpMessageConverter} for URL encoded forms. + * @param formMessageConverter the converter instance to use + * @since 7.1 + * @see FormHttpMessageConverter + */ + T withFormConverter(HttpMessageConverter formMessageConverter); + /** * Override the default String {@code HttpMessageConverter} * with any converter supporting the Kotlin Serialization conversion for JSON. diff --git a/spring-web/src/main/java/org/springframework/http/converter/ObjectToStringHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/ObjectToStringHttpMessageConverter.java index b305180831e7..14a1cfc5bb47 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/ObjectToStringHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/ObjectToStringHttpMessageConverter.java @@ -99,6 +99,11 @@ public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { return canWrite(mediaType) && this.conversionService.canConvert(clazz, String.class); } + @Override + public boolean canWriteRepeatedly(Object o, @Nullable MediaType contentType) { + return true; + } + @Override protected boolean supports(Class clazz) { // should not be called, since we override canRead/Write @@ -137,6 +142,7 @@ protected Long getContentLength(Object obj, @Nullable MediaType contentType) { } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(Object o) { return true; } diff --git a/spring-web/src/main/java/org/springframework/http/converter/ResourceHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/ResourceHttpMessageConverter.java index 6365f10c3f05..0faea6dc8a5c 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/ResourceHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/ResourceHttpMessageConverter.java @@ -71,6 +71,11 @@ public ResourceHttpMessageConverter(boolean supportsReadStreaming) { } + @Override + public boolean canWriteRepeatedly(Resource resource, @Nullable MediaType contentType) { + return !(resource instanceof InputStreamResource); + } + @Override protected boolean supports(Class clazz) { return Resource.class.isAssignableFrom(clazz); @@ -141,8 +146,9 @@ protected MediaType getDefaultContentType(Resource resource) { } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(Resource resource) { - return !(resource instanceof InputStreamResource); + return canWriteRepeatedly(resource, null); } diff --git a/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java index ff64c9f1a73e..31b74eb3e512 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java @@ -104,6 +104,23 @@ public boolean canWrite(@Nullable Type type, @Nullable Class clazz, @Nullable return ResourceRegion.class.isAssignableFrom(typeArgumentClass); } + @Override + public boolean canWriteRepeatedly(Object object, @Nullable MediaType contentType) { + if (object instanceof ResourceRegion resourceRegion) { + return supportsRepeatableWrites(resourceRegion); + } + else { + @SuppressWarnings("unchecked") + Collection regions = (Collection) object; + for (ResourceRegion region : regions) { + if (!supportsRepeatableWrites(region)) { + return false; + } + } + return true; + } + } + @Override protected void writeInternal(Object object, @Nullable Type type, HttpOutputMessage outputMessage) throws IOException, HttpMessageNotWritableException { @@ -140,20 +157,9 @@ protected MediaType getDefaultContentType(Object object) { } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(Object object) { - if (object instanceof ResourceRegion resourceRegion) { - return supportsRepeatableWrites(resourceRegion); - } - else { - @SuppressWarnings("unchecked") - Collection regions = (Collection) object; - for (ResourceRegion region : regions) { - if (!supportsRepeatableWrites(region)) { - return false; - } - } - return true; - } + return canWriteRepeatedly(object, null); } private boolean supportsRepeatableWrites(ResourceRegion region) { diff --git a/spring-web/src/main/java/org/springframework/http/converter/StringHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/StringHttpMessageConverter.java index 5a9e81e90da7..e280c0b4f102 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/StringHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/StringHttpMessageConverter.java @@ -89,6 +89,11 @@ public boolean supports(Class clazz) { return String.class == clazz; } + @Override + public boolean canWriteRepeatedly(String s, @Nullable MediaType contentType) { + return true; + } + @Override protected String readInternal(Class clazz, HttpInputMessage inputMessage) throws IOException { Charset charset = getContentTypeCharset(inputMessage.getHeaders().getContentType()); @@ -161,6 +166,7 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON) || } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(String s) { return true; } diff --git a/spring-web/src/main/java/org/springframework/http/converter/feed/AbstractWireFeedHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/feed/AbstractWireFeedHttpMessageConverter.java index 8b0a5998b073..45b571572b64 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/feed/AbstractWireFeedHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/feed/AbstractWireFeedHttpMessageConverter.java @@ -29,6 +29,7 @@ import com.rometools.rome.io.FeedException; import com.rometools.rome.io.WireFeedInput; import com.rometools.rome.io.WireFeedOutput; +import org.jspecify.annotations.Nullable; import org.springframework.http.HttpInputMessage; import org.springframework.http.HttpOutputMessage; @@ -66,6 +67,11 @@ protected AbstractWireFeedHttpMessageConverter(MediaType supportedMediaType) { } + @Override + public boolean canWriteRepeatedly(T t, @Nullable MediaType contentType) { + return true; + } + @Override @SuppressWarnings("unchecked") protected T readInternal(Class clazz, HttpInputMessage inputMessage) @@ -108,6 +114,7 @@ protected void writeInternal(T wireFeed, HttpOutputMessage outputMessage) } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(T t) { return true; } diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/AbstractJackson2HttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/json/AbstractJackson2HttpMessageConverter.java index 62a8219bb456..22b07ac1928d 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/json/AbstractJackson2HttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/json/AbstractJackson2HttpMessageConverter.java @@ -292,6 +292,11 @@ public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { return false; } + @Override + public boolean canWriteRepeatedly(Object o, @Nullable MediaType contentType) { + return true; + } + /** * Select an ObjectMapper to use, either the main ObjectMapper or another * if the handling for the given Class has been customized through @@ -570,6 +575,7 @@ protected JsonEncoding getJsonEncoding(@Nullable MediaType contentType) { } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(Object o) { return true; } diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/GsonHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/json/GsonHttpMessageConverter.java index 62efb5eee51b..dd617cc1e2e0 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/json/GsonHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/json/GsonHttpMessageConverter.java @@ -24,6 +24,7 @@ import com.google.gson.Gson; import org.jspecify.annotations.Nullable; +import org.springframework.http.MediaType; import org.springframework.util.Assert; /** @@ -86,6 +87,10 @@ public Gson getGson() { return this.gson; } + @Override + public boolean canWriteRepeatedly(Object o, @Nullable MediaType contentType) { + return true; + } @Override protected Object readInternal(Type resolvedType, Reader reader) throws Exception { @@ -108,6 +113,7 @@ protected void writeInternal(Object object, @Nullable Type type, Writer writer) } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(Object o) { return true; } diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/JsonbHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/json/JsonbHttpMessageConverter.java index 53300311cf24..5cc7e5ae98f4 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/json/JsonbHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/json/JsonbHttpMessageConverter.java @@ -26,6 +26,7 @@ import jakarta.json.bind.JsonbConfig; import org.jspecify.annotations.Nullable; +import org.springframework.http.MediaType; import org.springframework.util.Assert; /** @@ -94,6 +95,10 @@ public Jsonb getJsonb() { return this.jsonb; } + @Override + public boolean canWriteRepeatedly(Object o, @Nullable MediaType contentType) { + return true; + } @Override protected Object readInternal(Type resolvedType, Reader reader) throws Exception { @@ -111,6 +116,7 @@ protected void writeInternal(Object object, @Nullable Type type, Writer writer) } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(Object o) { return true; } diff --git a/spring-web/src/main/java/org/springframework/http/converter/multipart/DefaultParts.java b/spring-web/src/main/java/org/springframework/http/converter/multipart/DefaultParts.java new file mode 100644 index 000000000000..293ef04f1784 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/multipart/DefaultParts.java @@ -0,0 +1,297 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.multipart; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.nio.file.StandardOpenOption; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.ContentDisposition; +import org.springframework.http.HttpHeaders; +import org.springframework.util.Assert; + +/** + * Default implementations of {@link Part} and subtypes. + * + * @author Arjen Poutsma + * @author Brian Clozel + */ +abstract class DefaultParts { + + /** + * Create a new {@link FormFieldPart} with the given parameters. + * @param headers the part headers + * @param value the form field value + * @return the created part + */ + public static FormFieldPart formFieldPart(HttpHeaders headers, String value) { + Assert.notNull(headers, "Headers must not be null"); + Assert.notNull(value, "Value must not be null"); + + return new DefaultFormFieldPart(headers, value); + } + + /** + * Create a new {@link Part} or {@link FilePart} based on a flux of data + * buffers. Returns {@link FilePart} if the {@code Content-Disposition} of + * the given headers contains a filename, or a "normal" {@link Part} + * otherwise. + * @param headers the part headers + * @param dataBuffer the content of the part + * @return {@link Part} or {@link FilePart}, depending on {@link HttpHeaders#getContentDisposition()} + */ + public static Part part(HttpHeaders headers, DataBuffer dataBuffer) { + Assert.notNull(headers, "Headers must not be null"); + Assert.notNull(dataBuffer, "DataBuffer must not be null"); + + return partInternal(headers, new DataBufferContent(dataBuffer)); + } + + /** + * Create a new {@link Part} or {@link FilePart} based on the given file. + * Returns {@link FilePart} if the {@code Content-Disposition} of the given + * headers contains a filename, or a "normal" {@link Part} otherwise + * @param headers the part headers + * @param file the file + * @return {@link Part} or {@link FilePart}, depending on {@link HttpHeaders#getContentDisposition()} + */ + public static Part part(HttpHeaders headers, Path file) { + Assert.notNull(headers, "Headers must not be null"); + Assert.notNull(file, "File must not be null"); + + return partInternal(headers, new FileContent(file)); + } + + + private static Part partInternal(HttpHeaders headers, Content content) { + String filename = headers.getContentDisposition().getFilename(); + if (filename != null) { + return new DefaultFilePart(headers, content); + } + else { + return new DefaultPart(headers, content); + } + } + + + /** + * Abstract base class for {@link Part} implementations. + */ + private abstract static class AbstractPart implements Part { + + private final HttpHeaders headers; + + protected AbstractPart(HttpHeaders headers) { + Assert.notNull(headers, "HttpHeaders is required"); + this.headers = headers; + } + + @Override + public String name() { + String name = headers().getContentDisposition().getName(); + Assert.state(name != null, "No part name available"); + return name; + } + + @Override + public HttpHeaders headers() { + return this.headers; + } + } + + + /** + * Default implementation of {@link FormFieldPart}. + */ + private static class DefaultFormFieldPart extends AbstractPart implements FormFieldPart { + + private final String value; + + public DefaultFormFieldPart(HttpHeaders headers, String value) { + super(headers); + this.value = value; + } + + @Override + public InputStream content() { + byte[] bytes = this.value.getBytes(MultipartUtils.charset(headers())); + return new ByteArrayInputStream(bytes); + } + + @Override + public String value() { + return this.value; + } + + @Override + public String toString() { + String name = headers().getContentDisposition().getName(); + if (name != null) { + return "DefaultFormFieldPart{" + name() + "}"; + } + else { + return "DefaultFormFieldPart"; + } + } + } + + + /** + * Default implementation of {@link Part}. + */ + private static class DefaultPart extends AbstractPart { + + protected final Content content; + + public DefaultPart(HttpHeaders headers, Content content) { + super(headers); + this.content = content; + } + + @Override + public InputStream content() throws IOException { + return this.content.content(); + } + + @Override + public void delete() throws IOException { + this.content.delete(); + } + + @Override + public String toString() { + String name = headers().getContentDisposition().getName(); + if (name != null) { + return "DefaultPart{" + name + "}"; + } + else { + return "DefaultPart"; + } + } + } + + + /** + * Default implementation of {@link FilePart}. + */ + private static final class DefaultFilePart extends DefaultPart implements FilePart { + + public DefaultFilePart(HttpHeaders headers, Content content) { + super(headers, content); + } + + @Override + public String filename() { + String filename = headers().getContentDisposition().getFilename(); + Assert.state(filename != null, "No filename found"); + return filename; + } + + @Override + public void transferTo(Path dest) throws IOException { + this.content.transferTo(dest); + } + + @Override + public String toString() { + ContentDisposition contentDisposition = headers().getContentDisposition(); + String name = contentDisposition.getName(); + String filename = contentDisposition.getFilename(); + if (name != null) { + return "DefaultFilePart{" + name + " (" + filename + ")}"; + } + else { + return "DefaultFilePart{(" + filename + ")}"; + } + } + } + + + /** + * Part content abstraction. + */ + private interface Content { + + InputStream content() throws IOException; + + void transferTo(Path dest) throws IOException; + + void delete() throws IOException; + } + + + /** + * {@code Content} implementation based on an in-memory {@code InputStream}. + */ + private static final class DataBufferContent implements Content { + + private final DataBuffer content; + + public DataBufferContent(DataBuffer content) { + this.content = content; + } + + @Override + public InputStream content() { + return this.content.asInputStream(); + } + + @Override + public void transferTo(Path dest) throws IOException { + Files.copy(this.content.asInputStream(), dest, StandardCopyOption.REPLACE_EXISTING); + } + + @Override + public void delete() throws IOException { + } + } + + + /** + * {@code Content} implementation based on a file. + */ + private static final class FileContent implements Content { + + private final Path file; + + public FileContent(Path file) { + this.file = file; + } + + @Override + public InputStream content() throws IOException { + return Files.newInputStream(this.file.toAbsolutePath(), StandardOpenOption.READ); + } + + @Override + public void transferTo(Path dest) throws IOException { + Files.copy(this.file, dest, StandardCopyOption.REPLACE_EXISTING); + } + + @Override + public void delete() throws IOException { + Files.delete(this.file); + } + + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/multipart/FilePart.java b/spring-web/src/main/java/org/springframework/http/converter/multipart/FilePart.java new file mode 100644 index 000000000000..9d5594a3d0e0 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/multipart/FilePart.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.multipart; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; + +/** + * Specialization of {@link Part} that represents an uploaded file received in + * a multipart request. + * + * @author Brian Clozel + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 7.1 + */ +public interface FilePart extends Part { + + /** + * Return the original filename in the client's filesystem. + *

    Note: Please keep in mind this filename is supplied + * by the client and should not be used blindly. In addition to not using + * the directory portion, the file name could also contain characters such + * as ".." and others that can be used maliciously. It is recommended to not + * use this filename directly. Preferably generate a unique one and save + * this one somewhere for reference, if necessary. + * @return the original filename, or the empty String if no file has been chosen + * in the multipart form, or {@code null} if not defined or not available + * @see RFC 7578, Section 4.2 + * @see Unrestricted File Upload + */ + String filename(); + + /** + * Convenience method to copy the content of the file in this part to the + * given destination file. If the destination file already exists, it will + * be truncated first. + *

    The default implementation delegates to {@link #transferTo(Path)}. + * @param dest the target file + * @throws IllegalStateException if the part isn't a file + * @see #transferTo(Path) + */ + default void transferTo(File dest) throws IOException { + transferTo(dest.toPath()); + } + + /** + * Convenience method to copy the content of the file in this part to the + * given destination file. If the destination file already exists, it will + * be truncated first. + * @param dest the target file + * @throws IllegalStateException if the part isn't a file + * @see #transferTo(File) + */ + void transferTo(Path dest) throws IOException; + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/BeanOverrideTestUtils.java b/spring-web/src/main/java/org/springframework/http/converter/multipart/FormFieldPart.java similarity index 57% rename from spring-test/src/test/java/org/springframework/test/context/bean/override/BeanOverrideTestUtils.java rename to spring-web/src/main/java/org/springframework/http/converter/multipart/FormFieldPart.java index added7402c50..8a920fb7da78 100644 --- a/spring-test/src/test/java/org/springframework/test/context/bean/override/BeanOverrideTestUtils.java +++ b/spring-web/src/main/java/org/springframework/http/converter/multipart/FormFieldPart.java @@ -14,24 +14,21 @@ * limitations under the License. */ -package org.springframework.test.context.bean.override; +package org.springframework.http.converter.multipart; -import java.util.List; /** - * Test utilities for Bean Overrides. + * Specialization of {@link Part} for a form field. * - * @author Sam Brannen - * @since 6.2.2 + * @author Brian Clozel + * @author Rossen Stoyanchev + * @since 7.1 */ -public abstract class BeanOverrideTestUtils { +public interface FormFieldPart extends Part { - public static List findHandlers(Class testClass) { - return BeanOverrideHandler.forTestClass(testClass); - } - - public static List findAllHandlers(Class testClass) { - return BeanOverrideHandler.findAllHandlers(testClass); - } + /** + * Return the form field value. + */ + String value(); } diff --git a/spring-web/src/main/java/org/springframework/http/converter/multipart/MultipartHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/multipart/MultipartHttpMessageConverter.java new file mode 100644 index 000000000000..daf70768dd57 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/multipart/MultipartHttpMessageConverter.java @@ -0,0 +1,652 @@ +/* + * Copyright 2026-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.multipart; + +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import org.jspecify.annotations.Nullable; + +import org.springframework.core.ResolvableType; +import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.DataBufferLimitException; +import org.springframework.http.ContentDisposition; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.http.converter.AbstractHttpMessageConverter; +import org.springframework.http.converter.ByteArrayHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.http.converter.ResourceHttpMessageConverter; +import org.springframework.http.converter.SmartHttpMessageConverter; +import org.springframework.http.converter.StringHttpMessageConverter; +import org.springframework.util.Assert; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.MultiValueMap; + +/** + * Implementation of {@link HttpMessageConverter} to read and write + * multipart data (for example, file uploads). + * + *

    This converter can read {@code "multipart/form-data"} + * and {@code "multipart/mixed"} messages as + * {@link MultiValueMap MultiValueMap<String, Part>}, and + * write {@link MultiValueMap MultiValueMap<String, Object>} as + * multipart messages. + * + *

    On Servlet containers, the reading of multipart messages should be + * delegated to the {@link org.springframework.web.multipart.MultipartResolver}. + * + *

    Multipart Data

    + * + *

    By default, {@code "multipart/form-data"} is used as the content type when + * {@linkplain #write writing} multipart data. It is also possible to write + * multipart data using other multipart subtypes such as {@code "multipart/mixed"} + * and {@code "multipart/related"}, as long as the multipart subtype is registered + * as a {@linkplain #getSupportedMediaTypes supported media type} and the + * desired multipart subtype is specified as the content type when + * {@linkplain #write writing} the multipart data. Note that {@code "multipart/mixed"} + * is registered as a supported media type by default. + * + *

    When writing multipart data, this converter uses other + * {@link HttpMessageConverter HttpMessageConverters} to write the respective + * MIME parts. By default, basic converters are registered for byte array, + * {@code String}, and {@code Resource}. This can be set with the main + * {@link #MultipartHttpMessageConverter(Iterable) constructor}. + * + *

    Examples

    + * + *

    The following snippet shows how to submit an HTML form using the + * {@code "multipart/form-data"} content type. + * + *

    + * RestClient restClient = RestClient.create();
    + * // MultipartHttpMessageConverter is configured by default
    + *
    + * MultiValueMap<String, Object> form = new LinkedMultiValueMap<>();
    + * form.add("field 1", "value 1");
    + * form.add("field 2", "value 2");
    + * form.add("field 2", "value 3");
    + * form.add("field 3", 4);
    + *
    + * ResponseEntity<Void> response = restClient.post()
    + *   .uri("https://example.com/myForm")
    + *   .contentType(MULTIPART_FORM_DATA)
    + *   .body(form)
    + *   .retrieve()
    + *   .toBodilessEntity();
    + * + *

    The following snippet shows how to do a file upload using the + * {@code "multipart/form-data"} content type. + * + *

    + * MultiValueMap<String, Object> parts = new LinkedMultiValueMap<>();
    + * parts.add("field 1", "value 1");
    + * parts.add("file", new ClassPathResource("myFile.jpg"));
    + *
    + * ResponseEntity<Void> response = restClient.post()
    + *   .uri("https://example.com/myForm")
    + *   .contentType(MULTIPART_FORM_DATA)
    + *   .body(parts)
    + *   .retrieve()
    + *   .toBodilessEntity();
    + * + *

    The following snippet shows how to decode a multipart response. + * + *

    + * MultiValueMap<String, Part> body = this.restClient.get()
    + * 				.uri("https://example.com/parts/42")
    + * 				.accept(MediaType.MULTIPART_FORM_DATA)
    + * 				.retrieve()
    + * 				.body(new ParameterizedTypeReference<>() {});
    + * + * @author Brian Clozel + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @author Sam Brannen + * @since 7.1 + * @see org.springframework.util.MultiValueMap + */ +public class MultipartHttpMessageConverter implements SmartHttpMessageConverter> { + + private final List> partConverters; + + private @Nullable Path tempDirectory; + + private List supportedMediaTypes = new ArrayList<>(); + + private Charset charset = StandardCharsets.UTF_8; + + private @Nullable Charset multipartCharset; + + private int maxInMemorySize = 256 * 1024; + + private int maxHeadersSize = 10 * 1024; + + private long maxDiskUsagePerPart = -1; + + private int maxParts = -1; + + /** + * Create a new converter instance with the given converter instances for reading and + * writing parts. + * @param converters the converters to use for reading and writing parts + */ + public MultipartHttpMessageConverter(Iterable> converters) { + this.supportedMediaTypes.add(MediaType.MULTIPART_FORM_DATA); + this.supportedMediaTypes.add(MediaType.MULTIPART_MIXED); + this.supportedMediaTypes.add(MediaType.MULTIPART_RELATED); + + this.partConverters = new ArrayList<>(); + converters.forEach(this.partConverters::add); + applyDefaultCharset(); + } + + /** + * Create a new converter instance with default converter instances for reading and + * writing parts. + * @see ByteArrayHttpMessageConverter + * @see StringHttpMessageConverter + * @see ResourceHttpMessageConverter + */ + public MultipartHttpMessageConverter() { + this(List.of( new ByteArrayHttpMessageConverter(), new StringHttpMessageConverter(), + new ResourceHttpMessageConverter())); + } + + /** + * Set the list of {@link MediaType} objects supported by this converter. + * @see #addSupportedMediaTypes(MediaType...) + * @see #getSupportedMediaTypes() + */ + public void setSupportedMediaTypes(List supportedMediaTypes) { + Assert.notNull(supportedMediaTypes, "'supportedMediaTypes' must not be null"); + // Ensure internal list is mutable. + this.supportedMediaTypes = new ArrayList<>(supportedMediaTypes); + } + + /** + * Add {@link MediaType} objects to be supported by this converter. + *

    The supplied {@code MediaType} objects will be appended to the list + * of {@linkplain #getSupportedMediaTypes() supported MediaType objects}. + * @param supportedMediaTypes a var-args list of {@code MediaType} objects to add + * @see #setSupportedMediaTypes(List) + */ + public void addSupportedMediaTypes(MediaType... supportedMediaTypes) { + Assert.notNull(supportedMediaTypes, "'supportedMediaTypes' must not be null"); + Assert.noNullElements(supportedMediaTypes, "'supportedMediaTypes' must not contain null elements"); + Collections.addAll(this.supportedMediaTypes, supportedMediaTypes); + } + + /** + * {@inheritDoc} + * @see #setSupportedMediaTypes(List) + * @see #addSupportedMediaTypes(MediaType...) + */ + @Override + public List getSupportedMediaTypes() { + return Collections.unmodifiableList(this.supportedMediaTypes); + } + + + /** + * Return the configured converters for MIME parts. + */ + public List> getPartConverters() { + return Collections.unmodifiableList(this.partConverters); + } + + /** + * Set the default character set to use for reading and writing form data when + * the request or response {@code Content-Type} header does not explicitly + * specify it. + *

    As of 4.3, this is also used as the default charset for the conversion + * of text bodies in a multipart request. + *

    As of 5.0, this is also used for part headers including + * {@code Content-Disposition} (and its filename parameter) unless (the mutually + * exclusive) {@link #setMultipartCharset multipartCharset} is also set, in + * which case part headers are encoded as ASCII and filename is encoded + * with the {@code encoded-word} syntax from RFC 2047. + *

    By default, this is set to "UTF-8". + */ + public void setCharset(@Nullable Charset charset) { + if (charset != this.charset) { + this.charset = (charset != null ? charset : StandardCharsets.UTF_8); + applyDefaultCharset(); + } + } + + /** + * Apply the configured charset as a default to registered part converters. + */ + private void applyDefaultCharset() { + for (HttpMessageConverter candidate : this.partConverters) { + if (candidate instanceof AbstractHttpMessageConverter converter) { + // Only override default charset if the converter operates with a charset to begin with... + if (converter.getDefaultCharset() != null) { + converter.setDefaultCharset(this.charset); + } + } + } + } + + /** + * Set the character set to use when writing multipart data to encode file + * names. Encoding is based on the {@code encoded-word} syntax defined in + * RFC 2047 and relies on {@code MimeUtility} from {@code jakarta.mail}. + *

    As of 5.0 by default part headers, including {@code Content-Disposition} + * (and its filename parameter) will be encoded based on the setting of + * {@link #setCharset(Charset)} or {@code UTF-8} by default. + * @see Encoded-Word + */ + public void setMultipartCharset(Charset charset) { + this.multipartCharset = charset; + } + + + /** + * Configure the maximum amount of memory that is allowed per headers section of each part. + *

    By default, this is set to 10K. + * @param byteCount the maximum amount of memory for headers + */ + public void setMaxHeadersSize(int byteCount) { + this.maxHeadersSize = byteCount; + } + + /** + * Configure the maximum amount of memory allowed per part. + * When the limit is exceeded: + *

      + *
    • File parts are written to a temporary file. + *
    • Non-file parts are rejected with {@link DataBufferLimitException}. + *
    + *

    By default, this is set to 256K. + * @param maxInMemorySize the in-memory limit in bytes; if set to -1 the entire + * contents will be stored in memory + */ + public void setMaxInMemorySize(int maxInMemorySize) { + this.maxInMemorySize = maxInMemorySize; + } + + /** + * Configure the maximum amount of disk space allowed for file parts. + *

    By default, this is set to -1, meaning that there is no maximum. + *

    Note that this property is ignored when + * {@link #setMaxInMemorySize(int) maxInMemorySize} is set to -1. + */ + public void setMaxDiskUsagePerPart(long maxDiskUsagePerPart) { + this.maxDiskUsagePerPart = maxDiskUsagePerPart; + } + + /** + * Specify the maximum number of parts allowed in a given multipart request. + *

    By default, this is set to -1, meaning that there is no maximum. + */ + public void setMaxParts(int maxParts) { + this.maxParts = maxParts; + } + + @Override + public boolean canRead(ResolvableType elementType, @Nullable MediaType mediaType) { + if (!supportsMediaType(mediaType)) { + return false; + } + if (!MultiValueMap.class.isAssignableFrom(elementType.toClass()) || + (!elementType.hasUnresolvableGenerics() && + !Part.class.isAssignableFrom(elementType.getGeneric(1).toClass()))) { + return false; + } + return true; + } + + private boolean supportsMediaType(@Nullable MediaType mediaType) { + if (mediaType == null) { + return true; + } + for (MediaType supportedMediaType : getSupportedMediaTypes()) { + if (supportedMediaType.includes(mediaType)) { + return true; + } + } + return false; + } + + @Override + public MultiValueMap read(ResolvableType type, HttpInputMessage message, @Nullable Map hints) throws IOException, HttpMessageNotReadableException { + + Charset headersCharset = MultipartUtils.charset(message.getHeaders()); + byte[] boundary = boundary(message, headersCharset); + if (boundary == null) { + throw new HttpMessageNotReadableException("No multipart boundary found in Content-Type: \"" + + message.getHeaders().getContentType() + "\"", message); + } + PartGenerator partListener = new PartGenerator(this.maxInMemorySize, this.maxDiskUsagePerPart, this.maxParts, getTempDirectory()); + new MultipartParser(this.maxHeadersSize, 2 * 1024).parse(message.getBody(), boundary, + headersCharset, partListener); + return partListener.getParts(); + } + + + private static byte @Nullable [] boundary(HttpInputMessage message, Charset headersCharset) { + MediaType contentType = message.getHeaders().getContentType(); + if (contentType != null) { + String boundary = contentType.getParameter("boundary"); + if (boundary != null) { + int len = boundary.length(); + if (len > 2 && boundary.charAt(0) == '"' && boundary.charAt(len - 1) == '"') { + boundary = boundary.substring(1, len - 1); + } + return boundary.getBytes(headersCharset); + } + } + return null; + } + + private Path getTempDirectory() throws IOException { + if (this.tempDirectory == null || !this.tempDirectory.toFile().exists()) { + this.tempDirectory = Files.createTempDirectory("spring-multipart-"); + } + return this.tempDirectory; + } + + @Override + public boolean canWrite(ResolvableType targetType, Class valueClass, @Nullable MediaType mediaType) { + if (!MultiValueMap.class.isAssignableFrom(targetType.toClass())) { + return false; + } + if (mediaType == null || MediaType.ALL.equals(mediaType)) { + return true; + } + for (MediaType supportedMediaType : getSupportedMediaTypes()) { + if (supportedMediaType.isCompatibleWith(mediaType)) { + return true; + } + } + return false; + } + + @Override + @SuppressWarnings("unchecked") + public void write(MultiValueMap map, ResolvableType type, @Nullable MediaType contentType, HttpOutputMessage outputMessage, @Nullable Map hints) throws IOException, HttpMessageNotWritableException { + MultiValueMap parts = (MultiValueMap) map; + + // If the supplied content type is null, fall back to multipart/form-data. + // Otherwise, rely on the fact that isMultipart() already verified the + // supplied content type is multipart. + if (contentType == null) { + contentType = MediaType.MULTIPART_FORM_DATA; + } + + Map parameters = new LinkedHashMap<>(contentType.getParameters().size() + 2); + parameters.putAll(contentType.getParameters()); + + byte[] boundary = MimeTypeUtils.generateMultipartBoundary(); + if (!isFilenameCharsetSet()) { + if (!this.charset.equals(StandardCharsets.UTF_8) && + !this.charset.equals(StandardCharsets.US_ASCII)) { + parameters.put("charset", this.charset.name()); + } + } + parameters.put("boundary", new String(boundary, StandardCharsets.US_ASCII)); + + // Add parameters to output content type + contentType = new MediaType(contentType, parameters); + outputMessage.getHeaders().setContentType(contentType); + + if (outputMessage instanceof StreamingHttpOutputMessage streamingOutputMessage) { + boolean repeatable = checkPartsRepeatable(parts, contentType); + streamingOutputMessage.setBody(new StreamingHttpOutputMessage.Body() { + @Override + public void writeTo(OutputStream outputStream) throws IOException { + MultipartHttpMessageConverter.this.writeParts(outputStream, parts, boundary); + writeEnd(outputStream, boundary); + } + + @Override + public boolean repeatable() { + return repeatable; + } + }); + } + else { + writeParts(outputMessage.getBody(), parts, boundary); + writeEnd(outputMessage.getBody(), boundary); + } + } + + + @SuppressWarnings({"unchecked", "ConstantValue"}) + private boolean checkPartsRepeatable(MultiValueMap map, MediaType contentType) { + return map.entrySet().stream().allMatch(e -> e.getValue().stream().filter(Objects::nonNull).allMatch(part -> { + HttpHeaders headers = null; + Object body = part; + if (part instanceof HttpEntity entity) { + headers = entity.getHeaders(); + body = entity.getBody(); + Assert.state(body != null, "Empty body for part '" + e.getKey() + "': " + part); + } + HttpMessageConverter converter = (HttpMessageConverter) findConverterFor(e.getKey(), headers, body); + return converter != null && converter.canWriteRepeatedly((T) body, contentType); + })); + } + + private @Nullable HttpMessageConverter findConverterFor( + String name, @Nullable HttpHeaders headers, Object body) { + + Class partType = body.getClass(); + MediaType contentType = (headers != null ? headers.getContentType() : null); + for (HttpMessageConverter converter : this.partConverters) { + if (converter.canWrite(partType, contentType)) { + return converter; + } + } + return null; + } + + /** + * When {@link #setMultipartCharset(Charset)} is configured (i.e. RFC 2047, + * {@code encoded-word} syntax) we need to use ASCII for part headers, or + * otherwise we encode directly using the configured {@link #setCharset(Charset)}. + */ + private boolean isFilenameCharsetSet() { + return (this.multipartCharset != null); + } + + private void writeParts(OutputStream os, MultiValueMap parts, byte[] boundary) throws IOException { + for (Map.Entry> entry : parts.entrySet()) { + String name = entry.getKey(); + for (Object part : entry.getValue()) { + if (part != null) { + writeBoundary(os, boundary); + writePart(name, getHttpEntity(part), os); + writeNewLine(os); + } + } + } + } + + @SuppressWarnings("unchecked") + private void writePart(String name, HttpEntity partEntity, OutputStream os) throws IOException { + Object partBody = partEntity.getBody(); + Assert.state(partBody != null, "Empty body for part '" + name + "': " + partEntity); + HttpHeaders partHeaders = partEntity.getHeaders(); + MediaType partContentType = partHeaders.getContentType(); + HttpMessageConverter converter = findConverterFor(name, partHeaders, partBody); + if (converter != null) { + Charset charset = isFilenameCharsetSet() ? StandardCharsets.US_ASCII : this.charset; + HttpOutputMessage multipartMessage = new MultipartHttpOutputMessage(os, charset); + String filename = getFilename(partBody); + ContentDisposition.Builder cd = ContentDisposition.formData().name(name); + if (filename != null) { + cd.filename(filename, this.multipartCharset); + } + multipartMessage.getHeaders().setContentDisposition(cd.build()); + if (!partHeaders.isEmpty()) { + multipartMessage.getHeaders().putAll(partHeaders); + } + ((HttpMessageConverter) converter).write(partBody, partContentType, multipartMessage); + return; + } + throw new HttpMessageNotWritableException("Could not write request: " + + "no suitable HttpMessageConverter found for request type [" + partBody.getClass().getName() + "]"); + } + + /** + * Return an {@link HttpEntity} for the given part Object. + * @param part the part to return an {@link HttpEntity} for + * @return the part Object itself it is an {@link HttpEntity}, + * or a newly built {@link HttpEntity} wrapper for that part + */ + protected HttpEntity getHttpEntity(Object part) { + return (part instanceof HttpEntity httpEntity ? httpEntity : new HttpEntity<>(part)); + } + + /** + * Return the filename of the given multipart part. This value will be used for the + * {@code Content-Disposition} header. + *

    The default implementation returns {@link Resource#getFilename()} if the part is a + * {@code Resource}, and {@code null} in other cases. Can be overridden in subclasses. + * @param part the part to determine the file name for + * @return the filename, or {@code null} if not known + */ + protected @Nullable String getFilename(Object part) { + if (part instanceof Resource resource) { + return resource.getFilename(); + } + else { + return null; + } + } + + + private void writeBoundary(OutputStream os, byte[] boundary) throws IOException { + os.write('-'); + os.write('-'); + os.write(boundary); + writeNewLine(os); + } + + private static void writeEnd(OutputStream os, byte[] boundary) throws IOException { + os.write('-'); + os.write('-'); + os.write(boundary); + os.write('-'); + os.write('-'); + writeNewLine(os); + } + + private static void writeNewLine(OutputStream os) throws IOException { + os.write('\r'); + os.write('\n'); + } + + + /** + * Implementation of {@link org.springframework.http.HttpOutputMessage} used + * to write a MIME multipart. + */ + private static class MultipartHttpOutputMessage implements HttpOutputMessage { + + private final OutputStream outputStream; + + private final Charset charset; + + private final HttpHeaders headers = new HttpHeaders(); + + private boolean headersWritten = false; + + public MultipartHttpOutputMessage(OutputStream outputStream, Charset charset) { + this.outputStream = new MultipartOutputStream(outputStream); + this.charset = charset; + } + + @Override + public HttpHeaders getHeaders() { + return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); + } + + @Override + public OutputStream getBody() throws IOException { + writeHeaders(); + return this.outputStream; + } + + private void writeHeaders() throws IOException { + if (!this.headersWritten) { + for (Map.Entry> entry : this.headers.headerSet()) { + byte[] headerName = getBytes(entry.getKey()); + for (String headerValueString : entry.getValue()) { + byte[] headerValue = getBytes(headerValueString); + this.outputStream.write(headerName); + this.outputStream.write(':'); + this.outputStream.write(' '); + this.outputStream.write(headerValue); + writeNewLine(this.outputStream); + } + } + writeNewLine(this.outputStream); + this.headersWritten = true; + } + } + + private byte[] getBytes(String name) { + return name.getBytes(this.charset); + } + + } + + + /** + * OutputStream that neither flushes nor closes. + */ + private static class MultipartOutputStream extends FilterOutputStream { + + public MultipartOutputStream(OutputStream out) { + super(out); + } + + @Override + public void write(byte[] b, int off, int let) throws IOException { + this.out.write(b, off, let); + } + + @Override + public void flush() { + } + + @Override + public void close() { + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/multipart/MultipartParser.java b/spring-web/src/main/java/org/springframework/http/converter/multipart/MultipartParser.java new file mode 100644 index 000000000000..194ea1631c46 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/multipart/MultipartParser.java @@ -0,0 +1,553 @@ +/* + * Copyright 2026-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.multipart; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.jspecify.annotations.Nullable; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.converter.HttpMessageConversionException; + +/** + * Read a Multipart message as a byte stream and parse its content + * and signals them to the {@link PartListener}. + * + * @author Brian Clozel + * @author Arjen Poutsma + */ +final class MultipartParser { + + private static final Log logger = LogFactory.getLog(MultipartParser.class); + + private final int maxHeadersSize; + + private final int bufferSize; + + /** + * Create a new multipart parser instance. + * + * @param maxHeadersSize the maximum buffered header size + * @param bufferSize the size of the reading buffer + */ + MultipartParser(int maxHeadersSize, int bufferSize) { + this.maxHeadersSize = maxHeadersSize; + this.bufferSize = bufferSize; + } + + /** + * Parses the given stream of bytes into events published to the {@link PartListener}. + * @param input the input stream + * @param boundary the multipart boundary, as found in the {@code Content-Type} header + * @param headersCharset the charset to use for decoding headers + * @param listener a listener for parsed tokens + */ + public void parse(InputStream input, byte[] boundary, Charset headersCharset, PartListener listener) { + + InternalParser internalParser = new InternalParser(boundary, headersCharset, listener); + try { + while (true) { + byte[] read = input.readNBytes(this.bufferSize); + if (read.length == 0) { + break; + } + internalParser.state.data(DefaultDataBufferFactory.sharedInstance.wrap(read)); + } + internalParser.state.complete(); + } + catch (IOException ex) { + internalParser.state.dispose(); + listener.onError(new HttpMessageConversionException("Could not decode multipart message", ex)); + } + } + + private final class InternalParser { + + private final byte[] boundary; + + private final Charset headersCharset; + + private final PartListener listener; + + private State state; + + InternalParser(byte[] boundary, Charset headersCharset, PartListener listener) { + this.boundary = boundary; + this.headersCharset = headersCharset; + this.listener = listener; + this.state = new PreambleState(); + } + + void changeState(State newState, @Nullable DataBuffer remainder) { + if (logger.isTraceEnabled()) { + logger.trace("Changed state: " + this.state + " -> " + newState); + } + this.state.dispose(); + this.state = newState; + if (remainder != null) { + if (remainder.readableByteCount() > 0) { + newState.data(remainder); + } + else { + DataBufferUtils.release(remainder); + } + } + } + + /** + * Concatenates the given array of byte arrays. + */ + private static byte[] concat(byte[]... byteArrays) { + int len = 0; + for (byte[] byteArray : byteArrays) { + len += byteArray.length; + } + byte[] result = new byte[len]; + len = 0; + for (byte[] byteArray : byteArrays) { + System.arraycopy(byteArray, 0, result, len, byteArray.length); + len += byteArray.length; + } + return result; + } + + /** + * Represents the internal state of the {@link MultipartParser}. + * The flow for well-formed multipart messages is shown below: + *

    +		 *     PREAMBLE
    +		 *         |
    +		 *         v
    +		 *  +-->HEADERS--->DISPOSED
    +		 *  |      |
    +		 *  |      v
    +		 *  +----BODY
    +		 *  
    + * For malformed messages the flow ends in DISPOSED. + */ + private interface State { + + byte[] CR_LF = {'\r', '\n'}; + + byte HYPHEN = '-'; + + byte[] TWO_HYPHENS = {HYPHEN, HYPHEN}; + + String HEADER_ENTRY_SEPARATOR = "\\r\\n"; + + void data(DataBuffer buf); + + void complete(); + + default void dispose() { + } + } + + /** + * The initial state of the parser. Looks for the first boundary of the + * multipart message. Note that the first boundary is not necessarily + * prefixed with {@code CR LF}; only the prefix {@code --} is required. + */ + private final class PreambleState implements State { + + private final DataBufferUtils.Matcher firstBoundary; + + + PreambleState() { + this.firstBoundary = DataBufferUtils.matcher(concat(TWO_HYPHENS, InternalParser.this.boundary)); + } + + /** + * Looks for the first boundary in the given buffer. If found, changes + * state to {@link HeadersState}, and passes on the remainder of the + * buffer. + */ + @Override + public void data(DataBuffer buf) { + int endIdx = this.firstBoundary.match(buf); + if (endIdx != -1) { + if (logger.isTraceEnabled()) { + logger.trace("First boundary found @" + endIdx + " in " + buf); + } + DataBuffer preambleBuffer = buf.split(endIdx + 1); + DataBufferUtils.release(preambleBuffer); + changeState(new HeadersState(), buf); + } + else { + DataBufferUtils.release(buf); + } + } + + @Override + public void complete() { + changeState(DisposedState.INSTANCE, null); + InternalParser.this.listener.onError(new HttpMessageConversionException("Could not find first boundary")); + } + + @Override + public String toString() { + return "PREAMBLE"; + } + + } + + /** + * The state of the parser dealing with part headers. Parses header + * buffers into a {@link HttpHeaders} instance, making sure that + * the amount does not exceed {@link #maxHeadersSize}. + */ + private final class HeadersState implements State { + + private final DataBufferUtils.Matcher endHeaders = DataBufferUtils.matcher(concat(CR_LF, CR_LF)); + + private final List buffers = new ArrayList<>(); + + private int byteCount; + + + /** + * First checks whether the multipart boundary leading to this state + * was the final boundary. Then looks for the header-body boundary + * ({@code CR LF CR LF}) in the given buffer. If found, checks whether + * the size of all header buffers does not exceed {@link #maxHeadersSize}, + * converts all buffers collected so far into a {@link HttpHeaders} object + * and changes to {@link BodyState}, passing the remainder of the + * buffer. If the boundary is not found, the buffer is collected if + * its size does not exceed {@link #maxHeadersSize}. + */ + @Override + public void data(DataBuffer buf) { + if (isLastBoundary(buf)) { + if (logger.isTraceEnabled()) { + logger.trace("Last boundary found in " + buf); + } + changeState(DisposedState.INSTANCE, buf); + InternalParser.this.listener.onComplete(); + return; + } + int endIdx = this.endHeaders.match(buf); + if (endIdx != -1) { + if (logger.isTraceEnabled()) { + logger.trace("End of headers found @" + endIdx + " in " + buf); + } + this.byteCount += endIdx; + if (belowMaxHeaderSize(this.byteCount)) { + DataBuffer headerBuf = buf.split(endIdx + 1); + this.buffers.add(headerBuf); + emitHeaders(); + changeState(new BodyState(), buf); + } + } + else { + this.byteCount += buf.readableByteCount(); + if (belowMaxHeaderSize(this.byteCount)) { + this.buffers.add(buf); + } + } + } + + private void emitHeaders() { + HttpHeaders headers = parseHeaders(); + if (logger.isTraceEnabled()) { + logger.trace("Emitting headers: " + headers); + } + InternalParser.this.listener.onHeaders(headers); + } + + /** + * If the given buffer is the first buffer, check whether it starts with {@code --}. + * If it is the second buffer, check whether it makes up {@code --} together with the first buffer. + */ + private boolean isLastBoundary(DataBuffer buf) { + return (this.buffers.isEmpty() && + buf.readableByteCount() >= 2 && + buf.getByte(0) == HYPHEN && buf.getByte(1) == HYPHEN) || + (this.buffers.size() == 1 && + this.buffers.get(0).readableByteCount() == 1 && + this.buffers.get(0).getByte(0) == HYPHEN && + buf.readableByteCount() >= 1 && + buf.getByte(0) == HYPHEN); + } + + /** + * Checks whether the given {@code count} is below or equal to {@link #maxHeadersSize} + * and throws a {@link DataBufferLimitException} if not. + */ + private boolean belowMaxHeaderSize(long count) { + if (count <= MultipartParser.this.maxHeadersSize) { + return true; + } + else { + InternalParser.this.listener.onError( + new HttpMessageConversionException("Part headers exceeded the memory usage limit of " + + MultipartParser.this.maxHeadersSize + " bytes")); + return false; + } + } + + /** + * Parses the list of buffers into a {@link HttpHeaders} instance. + * Converts the joined buffers into a string using ISO=8859-1, and parses + * that string into key and values. + */ + private HttpHeaders parseHeaders() { + if (this.buffers.isEmpty()) { + return HttpHeaders.EMPTY; + } + DataBuffer joined = this.buffers.get(0).factory().join(this.buffers); + this.buffers.clear(); + String string = joined.toString(InternalParser.this.headersCharset); + DataBufferUtils.release(joined); + String[] lines = string.split(HEADER_ENTRY_SEPARATOR); + HttpHeaders result = new HttpHeaders(); + for (String line : lines) { + int idx = line.indexOf(':'); + if (idx != -1) { + String name = line.substring(0, idx); + String value = line.substring(idx + 1); + while (value.startsWith(" ")) { + value = value.substring(1); + } + result.add(name, value); + } + } + return result; + } + + @Override + public void complete() { + changeState(DisposedState.INSTANCE, null); + InternalParser.this.listener.onError(new HttpMessageConversionException("Could not find end of headers")); + } + + @Override + public void dispose() { + this.buffers.forEach(DataBufferUtils::release); + } + + @Override + public String toString() { + return "HEADERS"; + } + + } + + /** + * The state of the parser dealing with multipart bodies. Relays + * data buffers as {@link PartListener#onBody(DataBuffer, boolean)} + * until the boundary is found (or rather: {@code CR LF - - boundary}). + */ + private final class BodyState implements State { + + private final DataBufferUtils.Matcher boundaryMatcher; + + private final int boundaryLength; + + private final Deque queue = new ArrayDeque<>(); + + public BodyState() { + byte[] delimiter = concat(CR_LF, TWO_HYPHENS, InternalParser.this.boundary); + this.boundaryMatcher = DataBufferUtils.matcher(delimiter); + this.boundaryLength = delimiter.length; + } + + /** + * Checks whether the (end of the) needle {@code CR LF - - boundary} + * can be found in {@code buffer}. If found, the needle can overflow into the + * previous buffer, so we calculate the length and slice the current + * and previous buffers accordingly. We then change to {@link HeadersState} + * and pass on the remainder of {@code buffer}. If the needle is not found, we + * enqueue {@code buffer}. + */ + @Override + public void data(DataBuffer buffer) { + int endIdx = this.boundaryMatcher.match(buffer); + if (endIdx != -1) { + DataBuffer boundaryBuffer = buffer.split(endIdx + 1); + if (logger.isTraceEnabled()) { + logger.trace("Boundary found @" + endIdx + " in " + buffer); + } + int len = endIdx - this.boundaryLength + 1 - boundaryBuffer.readPosition(); + if (len > 0) { + // whole boundary in buffer. + // slice off the body part, and flush + DataBuffer body = boundaryBuffer.split(len); + DataBufferUtils.release(boundaryBuffer); + enqueue(body); + flush(); + } + else if (len < 0) { + // boundary spans multiple buffers, and we've just found the end + // iterate over buffers in reverse order + DataBufferUtils.release(boundaryBuffer); + DataBuffer prev; + while ((prev = this.queue.pollLast()) != null) { + int prevByteCount = prev.readableByteCount(); + int prevLen = prevByteCount + len; + if (prevLen >= 0) { + // slice body part of previous buffer, and flush it + DataBuffer body = prev.split(prevLen + prev.readPosition()); + DataBufferUtils.release(prev); + enqueue(body); + flush(); + break; + } + else { + // previous buffer only contains boundary bytes + DataBufferUtils.release(prev); + len += prevByteCount; + } + } + } + else /* if (len == 0) */ { + // buffer starts with complete delimiter, flush out the previous buffers + DataBufferUtils.release(boundaryBuffer); + flush(); + } + + changeState(new HeadersState(), buffer); + } + else { + enqueue(buffer); + } + } + + /** + * Store the given buffer. Emit buffers that cannot contain boundary bytes, + * by iterating over the queue in reverse order, and summing buffer sizes. + * The first buffer that passes the boundary length and subsequent buffers + * are emitted (in the correct, non-reverse order). + */ + private void enqueue(DataBuffer buf) { + this.queue.add(buf); + + int len = 0; + Deque emit = new ArrayDeque<>(); + for (Iterator iterator = this.queue.descendingIterator(); iterator.hasNext(); ) { + DataBuffer previous = iterator.next(); + if (len > this.boundaryLength) { + // addFirst to negate iterating in reverse order + emit.addFirst(previous); + iterator.remove(); + } + len += previous.readableByteCount(); + } + emit.forEach(buffer -> InternalParser.this.listener.onBody(buffer, false)); + } + + private void flush() { + for (Iterator iterator = this.queue.iterator(); iterator.hasNext(); ) { + DataBuffer buffer = iterator.next(); + boolean last = !iterator.hasNext(); + InternalParser.this.listener.onBody(buffer, last); + } + this.queue.clear(); + } + + @Override + public void complete() { + changeState(DisposedState.INSTANCE, null); + String msg = "Could not find end of body (␍␊--" + + new String(InternalParser.this.boundary, StandardCharsets.UTF_8) + + ")"; + InternalParser.this.listener.onError(new HttpMessageConversionException(msg)); + } + + @Override + public void dispose() { + this.queue.forEach(DataBufferUtils::release); + this.queue.clear(); + } + + @Override + public String toString() { + return "BODY"; + } + } + + /** + * The state of the parser when finished, either due to seeing the final + * boundary or to a malformed message. Releases all incoming buffers. + */ + private static final class DisposedState implements State { + + public static final DisposedState INSTANCE = new DisposedState(); + + private DisposedState() { + } + + @Override + public void data(DataBuffer buf) { + DataBufferUtils.release(buf); + } + + @Override + public void complete() { + } + + @Override + public String toString() { + return "DISPOSED"; + } + } + + } + + + /** + * Listen for part events while parsing the inbound stream of data. + */ + interface PartListener { + + /** + * Handle {@link HttpHeaders} for a part. + */ + void onHeaders(HttpHeaders headers); + + /** + * Handle a piece of data for a body part. + * @param buffer a chunk of body + * @param last whether this is the last chunk for the part + */ + void onBody(DataBuffer buffer, boolean last); + + /** + * Handle the completion event for the Multipart message. + */ + void onComplete(); + + /** + * Handle any error thrown during the parsing phase. + */ + void onError(Throwable error); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/multipart/MultipartUtils.java b/spring-web/src/main/java/org/springframework/http/converter/multipart/MultipartUtils.java new file mode 100644 index 000000000000..37a45a778621 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/multipart/MultipartUtils.java @@ -0,0 +1,47 @@ +/* + * Copyright 2026-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.multipart; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; + +/** + * Various static utility methods for dealing with multipart parsing. + * @author Arjen Poutsma + * @author Brian Clozel + */ +abstract class MultipartUtils { + + /** + * Return the character set of the given headers, as defined in the + * {@link HttpHeaders#getContentType()} header. + */ + static Charset charset(HttpHeaders headers) { + MediaType contentType = headers.getContentType(); + if (contentType != null) { + Charset charset = contentType.getCharset(); + if (charset != null) { + return charset; + } + } + return StandardCharsets.UTF_8; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/multipart/Part.java b/spring-web/src/main/java/org/springframework/http/converter/multipart/Part.java new file mode 100644 index 000000000000..a0a8f7fe83cc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/multipart/Part.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.multipart; + +import java.io.IOException; +import java.io.InputStream; + +import org.springframework.http.HttpHeaders; + + +/** + * Representation for a part in a "multipart/form-data" request. + * + *

    The origin of a multipart request may be a browser form in which case each + * part is either a {@link FormFieldPart} or a {@link FilePart}. + * + *

    Multipart requests may also be used outside a browser for data of any + * content type (for example, JSON, PDF, etc). + * + * @author Brian Clozel + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 7.1 + * @see RFC 7578 (multipart/form-data) + * @see RFC 2183 (Content-Disposition) + * @see HTML5 (multipart forms) + */ +public interface Part { + + /** + * Return the name of the part in the multipart form. + * @return the name of the part, never {@code null} or empty + */ + String name(); + + /** + * Return the headers associated with the part. + */ + HttpHeaders headers(); + + /** + * Return the content for this part. + *

    Note that for a {@link FormFieldPart} the content may be accessed + * more easily via {@link FormFieldPart#value()}. + */ + InputStream content() throws IOException; + + /** + * Delete the underlying storage for this part. + */ + default void delete() throws IOException { + + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/multipart/PartGenerator.java b/spring-web/src/main/java/org/springframework/http/converter/multipart/PartGenerator.java new file mode 100644 index 000000000000..9eb43d05ca56 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/multipart/PartGenerator.java @@ -0,0 +1,394 @@ +/* + * Copyright 2026-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.multipart; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.ArrayDeque; +import java.util.List; +import java.util.Queue; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.converter.HttpMessageConversionException; +import org.springframework.util.FastByteArrayOutputStream; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * {@link MultipartParser.PartListener Listen} to a stream of part tokens + * and return a {@code MultiValueMap} as a result. + * + * @author Brian Clozel + * @author Arjen Poutsma + */ +final class PartGenerator implements MultipartParser.PartListener { + + private static final Log logger = LogFactory.getLog(PartGenerator.class); + + private final MultiValueMap parts = new LinkedMultiValueMap<>(); + + private final int maxInMemorySize; + + private final long maxDiskUsagePerPart; + + private final int maxParts; + + private final Path fileStorageDirectory; + + private int partCount; + + private State state; + + + PartGenerator(int maxInMemorySize, long maxDiskUsagePerPart, int maxParts, Path fileStorageDirectory) { + this.maxInMemorySize = maxInMemorySize; + this.maxDiskUsagePerPart = maxDiskUsagePerPart; + this.maxParts = maxParts; + this.fileStorageDirectory = fileStorageDirectory; + this.state = new InitialState(); + } + + /** + * Return the collected parts. + */ + public MultiValueMap getParts() { + return this.parts; + } + + @Override + public void onHeaders(HttpHeaders headers) { + if (isFormField(headers)) { + this.state = new FormFieldState(headers); + } + else { + this.state = new InMemoryState(headers); + } + } + + private static boolean isFormField(HttpHeaders headers) { + MediaType contentType = headers.getContentType(); + return (contentType == null || MediaType.TEXT_PLAIN.equalsTypeAndSubtype(contentType)) && + headers.getContentDisposition().getFilename() == null; + } + + @Override + public void onBody(DataBuffer buffer, boolean last) { + try { + this.state.onBody(buffer, last); + } + catch (Throwable ex) { + deleteParts(); + throw ex; + } + } + + void deleteParts() { + try { + for (List partList : this.parts.values()) { + for (Part part : partList) { + part.delete(); + } + } + } + catch (IOException ex) { + // ignored + } + } + + @Override + public void onComplete() { + if (logger.isTraceEnabled()) { + logger.trace("Finished reading " + this.partCount + " part(s)"); + } + } + + @Override + public void onError(Throwable error) { + deleteParts(); + throw new HttpMessageConversionException("Cannot decode multipart body", error); + } + + void addPart(Part part) { + if (this.maxParts != -1 && this.partCount == this.maxParts) { + throw new HttpMessageConversionException("Maximum number of parts exceeded: " + this.maxParts); + } + try { + this.partCount++; + this.parts.add(part.name(), part); + } + catch (Exception exc) { + throw new HttpMessageConversionException("Part #" + this.partCount + " is unnamed", exc); + } + } + + /** + * Represents the internal state of the {@link PartGenerator} for creating a single {@link Part}. + * {@link State} instances are stateful, and created when a new + * {@link MultipartParser.PartListener#onHeaders(HttpHeaders) headers instance} is accepted. + * The following rules determine which state the creator will have: + *

      + *
    1. If the part is a {@linkplain #isFormField(HttpHeaders) form field}, + * the creator will be in the {@link FormFieldState}.
    2. + *
    3. Otherwise, the creator will initially be in the + * {@link InMemoryState}, but will switch over to {@link FileState} + * when the part byte count exceeds {@link #maxInMemorySize}
    4. + *
    + */ + private interface State { + + /** + * Invoked when a {@link MultipartParser.PartListener#onBody(DataBuffer, boolean)} is received. + */ + void onBody(DataBuffer dataBuffer, boolean last); + + } + + /** + * The initial state of the creator. Throws an exception for {@link #onBody(DataBuffer, boolean)}. + */ + private static final class InitialState implements State { + + private InitialState() { + } + + @Override + public void onBody(DataBuffer dataBuffer, boolean last) { + DataBufferUtils.release(dataBuffer); + throw new HttpMessageConversionException("Body token not expected"); + } + + @Override + public String toString() { + return "INITIAL"; + } + } + + /** + * The creator state when a form field is received. + * Stores all body buffers in memory (up until {@link #maxInMemorySize}). + */ + private final class FormFieldState implements State { + + private final FastByteArrayOutputStream value = new FastByteArrayOutputStream(); + + private final HttpHeaders headers; + + public FormFieldState(HttpHeaders headers) { + this.headers = headers; + } + + @Override + public void onBody(DataBuffer dataBuffer, boolean last) { + int size = this.value.size() + dataBuffer.readableByteCount(); + if (PartGenerator.this.maxInMemorySize == -1 || + size < PartGenerator.this.maxInMemorySize) { + store(dataBuffer); + } + else { + DataBufferUtils.release(dataBuffer); + throw new HttpMessageConversionException("Form field value exceeded the memory usage limit of " + + PartGenerator.this.maxInMemorySize + " bytes"); + } + if (last) { + byte[] bytes = this.value.toByteArrayUnsafe(); + String value = new String(bytes, MultipartUtils.charset(this.headers)); + FormFieldPart formFieldPart = DefaultParts.formFieldPart(this.headers, value); + PartGenerator.this.addPart(formFieldPart); + } + } + + private void store(DataBuffer dataBuffer) { + try { + byte[] bytes = new byte[dataBuffer.readableByteCount()]; + dataBuffer.read(bytes); + this.value.write(bytes); + } + catch (IOException ex) { + throw new HttpMessageConversionException("Cannot store multipart body", ex); + } + finally { + DataBufferUtils.release(dataBuffer); + } + } + + @Override + public String toString() { + return "FORM-FIELD"; + } + } + + /** + * The creator state when not handling a form field. + * Stores all received buffers in a queue. + * If the byte count exceeds {@link #maxInMemorySize}, the creator state + * is changed to {@link FileState}. + */ + private final class InMemoryState implements State { + + private final Queue content = new ArrayDeque<>(); + + private long byteCount; + + private final HttpHeaders headers; + + + public InMemoryState(HttpHeaders headers) { + this.headers = headers; + } + + @Override + public void onBody(DataBuffer dataBuffer, boolean last) { + this.byteCount += dataBuffer.readableByteCount(); + if (PartGenerator.this.maxInMemorySize == -1 || + this.byteCount <= PartGenerator.this.maxInMemorySize) { + this.content.add(dataBuffer); + if (last) { + emitMemoryPart(); + } + } + else { + switchToFile(dataBuffer, last); + } + } + + private void switchToFile(DataBuffer current, boolean last) { + FileState newState = new FileState(this.headers, PartGenerator.this.fileStorageDirectory); + this.content.forEach(newState::writeBuffer); + newState.onBody(current, last); + PartGenerator.this.state = newState; + } + + private void emitMemoryPart() { + byte[] bytes = new byte[(int) this.byteCount]; + int idx = 0; + for (DataBuffer buffer : this.content) { + int len = buffer.readableByteCount(); + buffer.read(bytes, idx, len); + idx += len; + DataBufferUtils.release(buffer); + } + this.content.clear(); + DefaultDataBuffer content = DefaultDataBufferFactory.sharedInstance.wrap(bytes); + Part part = DefaultParts.part(this.headers, content); + PartGenerator.this.addPart(part); + } + + @Override + public String toString() { + return "IN-MEMORY"; + } + } + + /** + * The creator state when writing for a temporary file. + * {@link InMemoryState} initially switches to this state when the byte + * count exceeds {@link #maxInMemorySize}. + */ + private final class FileState implements State { + + private final HttpHeaders headers; + + private final Path file; + + private final OutputStream outputStream; + + private long byteCount; + + + public FileState(HttpHeaders headers, Path folder) { + this.headers = headers; + this.file = createFile(folder); + this.outputStream = createOutputStream(this.file); + } + + @Override + public void onBody(DataBuffer dataBuffer, boolean last) { + this.byteCount += dataBuffer.readableByteCount(); + if (PartGenerator.this.maxDiskUsagePerPart == -1 || this.byteCount <= PartGenerator.this.maxDiskUsagePerPart) { + writeBuffer(dataBuffer); + if (last) { + Part part = DefaultParts.part(this.headers, this.file); + PartGenerator.this.addPart(part); + } + } + else { + try { + this.outputStream.close(); + } + catch (IOException exc) { + // ignored + } + throw new HttpMessageConversionException("Part exceeded the disk usage limit of " + + PartGenerator.this.maxDiskUsagePerPart + " bytes"); + } + } + + private Path createFile(Path directory) { + try { + Path tempFile = Files.createTempFile(directory, null, ".multipart"); + if (logger.isTraceEnabled()) { + logger.trace("Storing multipart data in file " + tempFile); + } + return tempFile; + } + catch (IOException ex) { + throw new UncheckedIOException("Could not create temp file in " + directory, ex); + } + } + + private OutputStream createOutputStream(Path file) { + try { + return Files.newOutputStream(file, StandardOpenOption.WRITE); + } + catch (IOException ex) { + throw new UncheckedIOException("Could not write to temp file " + file, ex); + } + } + + private void writeBuffer(DataBuffer dataBuffer) { + try (InputStream in = dataBuffer.asInputStream()) { + in.transferTo(this.outputStream); + this.outputStream.flush(); + } + catch (IOException exc) { + throw new UncheckedIOException("Could not write to temp file ", exc); + } + finally { + DataBufferUtils.release(dataBuffer); + } + } + + @Override + public String toString() { + return "WRITE-FILE"; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/multipart/package-info.java b/spring-web/src/main/java/org/springframework/http/converter/multipart/package-info.java new file mode 100644 index 000000000000..158445f2c3ec --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/multipart/package-info.java @@ -0,0 +1,7 @@ +/** + * Provides an HttpMessageConverter for Multipart support. + */ +@NullMarked +package org.springframework.http.converter.multipart; + +import org.jspecify.annotations.NullMarked; diff --git a/spring-web/src/main/java/org/springframework/http/converter/protobuf/ProtobufHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/protobuf/ProtobufHttpMessageConverter.java index 8f5eaacbe7c0..8e4b12767e56 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/protobuf/ProtobufHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/protobuf/ProtobufHttpMessageConverter.java @@ -122,6 +122,12 @@ public ProtobufHttpMessageConverter(ExtensionRegistry extensionRegistry) { this(null, extensionRegistry); } + + @Override + public boolean canWriteRepeatedly(Message message, @Nullable MediaType contentType) { + return true; + } + /** * Constructor for a subclass that supports additional formats. * @param formatDelegate delegate to read and write additional formats @@ -255,6 +261,7 @@ private void setProtoHeader(HttpOutputMessage response, Message message) { } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(Message message) { return true; } diff --git a/spring-web/src/main/java/org/springframework/http/converter/support/AllEncompassingFormHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/support/AllEncompassingFormHttpMessageConverter.java index 24a6ad19044d..7ab5491afaf4 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/support/AllEncompassingFormHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/support/AllEncompassingFormHttpMessageConverter.java @@ -16,12 +16,12 @@ package org.springframework.http.converter.support; -import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverters; +import org.springframework.http.converter.multipart.MultipartHttpMessageConverter; /** - * Extension of {@link org.springframework.http.converter.FormHttpMessageConverter}, + * Extension of {@link MultipartHttpMessageConverter}, * adding support for XML, JSON, Smile, CBOR, Protobuf and Yaml based parts when * related libraries are present in the classpath. * @@ -29,17 +29,18 @@ * @author Juergen Hoeller * @author Sebastien Deleuze * @since 3.2 + * @deprecated since 7.1 in favor of {@link MultipartHttpMessageConverter}. */ -public class AllEncompassingFormHttpMessageConverter extends FormHttpMessageConverter { +@Deprecated(since = "7.1", forRemoval = true) +public class AllEncompassingFormHttpMessageConverter extends MultipartHttpMessageConverter { /** * Create a new {@link AllEncompassingFormHttpMessageConverter} instance * that will auto-detect part converters. */ - @SuppressWarnings("removal") public AllEncompassingFormHttpMessageConverter() { - HttpMessageConverters.forClient().registerDefaults().build().forEach(this::addPartConverter); + super(HttpMessageConverters.forClient().registerDefaults().build()); } /** @@ -49,7 +50,7 @@ public AllEncompassingFormHttpMessageConverter() { * @since 7.0 */ public AllEncompassingFormHttpMessageConverter(Iterable> converters) { - converters.forEach(this::addPartConverter); + super(converters); } } diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java index fee6d426eead..989d30a658d6 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java @@ -126,6 +126,11 @@ public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { return (supportedType && canWrite(mediaType)); } + @Override + public boolean canWriteRepeatedly(Object o, @Nullable MediaType contentType) { + return true; + } + @Override protected boolean supports(Class clazz) { // should not be called, since we override canRead/Write @@ -235,6 +240,7 @@ private void setCharset(@Nullable MediaType contentType, Marshaller marshaller) } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(Object o) { return true; } diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java index 459436a4fcfb..0c383e5b9589 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java @@ -114,6 +114,11 @@ public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { return (canWrite(mediaType) && this.marshaller != null && this.marshaller.supports(clazz)); } + @Override + public boolean canWriteRepeatedly(Object o, @Nullable MediaType contentType) { + return true; + } + @Override protected boolean supports(Class clazz) { // should not be called, since we override canRead()/canWrite() @@ -137,6 +142,7 @@ protected void writeToResult(Object o, HttpHeaders headers, Result result) throw } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(Object o) { return true; } diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/SourceHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/SourceHttpMessageConverter.java index 3cc4cb533e4f..80ea8de6acc5 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/xml/SourceHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/SourceHttpMessageConverter.java @@ -148,6 +148,11 @@ public boolean supports(Class clazz) { return SUPPORTED_CLASSES.contains(clazz); } + @Override + public boolean canWriteRepeatedly(T t, @Nullable MediaType contentType) { + return t instanceof DOMSource; + } + @Override @SuppressWarnings("unchecked") protected T readInternal(Class clazz, HttpInputMessage inputMessage) @@ -293,8 +298,9 @@ private void transform(Source source, Result result) throws TransformerException } @Override + @SuppressWarnings("removal") protected boolean supportsRepeatableWrites(T t) { - return t instanceof DOMSource; + return canWriteRepeatedly(t, null); } diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultResponseErrorHandler.java b/spring-web/src/main/java/org/springframework/web/client/DefaultResponseErrorHandler.java index 6e83480dd216..2ede5e8fcf1c 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultResponseErrorHandler.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultResponseErrorHandler.java @@ -58,7 +58,10 @@ * @author Juergen Hoeller * @since 3.0 * @see RestTemplate#setErrorHandler + * @deprecated as of 7.1 in favor of {@link RestClient.ResponseSpec.ErrorHandler} */ +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public class DefaultResponseErrorHandler implements ResponseErrorHandler { private @Nullable List> messageConverters; diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index af7ea9133e0e..0b937410d14b 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -814,6 +814,7 @@ public ResponseSpec onStatus(Predicate statusPredicate, ErrorHan } @Override + @SuppressWarnings("removal") public ResponseSpec onStatus(ResponseErrorHandler errorHandler) { Assert.notNull(errorHandler, "ResponseErrorHandler must not be null"); diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java index 9df03025e8bd..0b93147e7460 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java @@ -148,6 +148,7 @@ public DefaultRestClientBuilder(DefaultRestClientBuilder other) { this.observationConvention = other.observationConvention; } + @SuppressWarnings("removal") public DefaultRestClientBuilder(RestTemplate restTemplate) { Assert.notNull(restTemplate, "RestTemplate must not be null"); @@ -167,6 +168,7 @@ public DefaultRestClientBuilder(RestTemplate restTemplate) { this.observationConvention = restTemplate.getObservationConvention(); } + @SuppressWarnings("removal") private static @Nullable UriBuilderFactory getUriBuilderFactory(RestTemplate restTemplate) { UriTemplateHandler uriTemplateHandler = restTemplate.getUriTemplateHandler(); if (uriTemplateHandler instanceof DefaultUriBuilderFactory builderFactory) { @@ -199,6 +201,7 @@ private static boolean hasRestTemplateDefaults(DefaultUriBuilderFactory factory) factory.shouldParsePath()); } + @SuppressWarnings("removal") private static ClientHttpRequestFactory getRequestFactory(RestTemplate restTemplate) { ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); if (requestFactory instanceof InterceptingClientHttpRequestFactory interceptingClientHttpRequestFactory) { @@ -297,6 +300,7 @@ public RestClient.Builder defaultStatusHandler(Predicate statusP } @Override + @SuppressWarnings("removal") public RestClient.Builder defaultStatusHandler(ResponseErrorHandler errorHandler) { return defaultStatusHandlerInternal(StatusHandler.fromErrorHandler(errorHandler)); } diff --git a/spring-web/src/main/java/org/springframework/web/client/ExtractingResponseErrorHandler.java b/spring-web/src/main/java/org/springframework/web/client/ExtractingResponseErrorHandler.java index 0a3d69581f2f..58e3de3366c2 100644 --- a/spring-web/src/main/java/org/springframework/web/client/ExtractingResponseErrorHandler.java +++ b/spring-web/src/main/java/org/springframework/web/client/ExtractingResponseErrorHandler.java @@ -60,7 +60,10 @@ * @author Arjen Poutsma * @since 5.0 * @see RestTemplate#setErrorHandler(ResponseErrorHandler) + * @deprecated as of 7.1 in favor of {@link RestClient.ResponseSpec.ErrorHandler} */ +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public class ExtractingResponseErrorHandler extends DefaultResponseErrorHandler { private List> messageConverters = Collections.emptyList(); diff --git a/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java b/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java index f386798ad4a8..c6a65aee3831 100644 --- a/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java +++ b/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java @@ -43,7 +43,10 @@ * @since 3.0 * @param the data type * @see RestTemplate + * @deprecated as of 7.1 with no replacement. */ +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public class HttpMessageConverterExtractor implements ResponseExtractor { private final Type responseType; diff --git a/spring-web/src/main/java/org/springframework/web/client/NoOpResponseErrorHandler.java b/spring-web/src/main/java/org/springframework/web/client/NoOpResponseErrorHandler.java index 3864923d33e3..83c71248ba01 100644 --- a/spring-web/src/main/java/org/springframework/web/client/NoOpResponseErrorHandler.java +++ b/spring-web/src/main/java/org/springframework/web/client/NoOpResponseErrorHandler.java @@ -35,7 +35,10 @@ * * @author Stephane Nicoll * @since 6.1.7 + * @deprecated as of 7.1 in favor of {@link RestClient.ResponseSpec.ErrorHandler} */ +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public final class NoOpResponseErrorHandler implements ResponseErrorHandler { @Override diff --git a/spring-web/src/main/java/org/springframework/web/client/RequestCallback.java b/spring-web/src/main/java/org/springframework/web/client/RequestCallback.java index 22b2b4c189f7..9b671c845092 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RequestCallback.java +++ b/spring-web/src/main/java/org/springframework/web/client/RequestCallback.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.lang.reflect.Type; +import java.util.function.Consumer; import org.springframework.http.HttpOutputMessage; import org.springframework.http.client.ClientHttpRequest; @@ -37,8 +38,11 @@ * @author Arjen Poutsma * @since 3.0 * @see RestTemplate#execute + * @deprecated as of 7.1 in favor of {@link RestClient.RequestBodySpec#httpRequest(Consumer)}. */ @FunctionalInterface +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public interface RequestCallback { /** diff --git a/spring-web/src/main/java/org/springframework/web/client/ResponseErrorHandler.java b/spring-web/src/main/java/org/springframework/web/client/ResponseErrorHandler.java index 8f9619073084..b39bb52c51c5 100644 --- a/spring-web/src/main/java/org/springframework/web/client/ResponseErrorHandler.java +++ b/spring-web/src/main/java/org/springframework/web/client/ResponseErrorHandler.java @@ -31,7 +31,9 @@ * * @author Arjen Poutsma * @since 3.0 + * @deprecated as of 7.1 in favor of {@link RestClient.ResponseSpec.ErrorHandler} */ +@Deprecated(since = "7.1", forRemoval = true) public interface ResponseErrorHandler { /** diff --git a/spring-web/src/main/java/org/springframework/web/client/ResponseExtractor.java b/spring-web/src/main/java/org/springframework/web/client/ResponseExtractor.java index 5bedc9040046..318c4347f2b1 100644 --- a/spring-web/src/main/java/org/springframework/web/client/ResponseExtractor.java +++ b/spring-web/src/main/java/org/springframework/web/client/ResponseExtractor.java @@ -37,8 +37,11 @@ * @since 3.0 * @param the data type * @see RestTemplate#execute + * @deprecated as of 7.1 in favor of {@link RestClient.RequestBodySpec#exchange(RestClient.RequestHeadersSpec.ExchangeFunction)}. */ @FunctionalInterface +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public interface ResponseExtractor { /** diff --git a/spring-web/src/main/java/org/springframework/web/client/RestClient.java b/spring-web/src/main/java/org/springframework/web/client/RestClient.java index dc72cc5beb5e..22850d80070e 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestClient.java @@ -189,6 +189,7 @@ static RestClient create(URI baseUrl) { * @return a {@code RestClient} initialized with the {@code restTemplate}'s * configuration */ + @SuppressWarnings("removal") static RestClient create(RestTemplate restTemplate) { return new DefaultRestClientBuilder(restTemplate).build(); } @@ -218,6 +219,7 @@ static RestClient.Builder builder() { * @return a {@code RestClient} builder initialized with {@code restTemplate}'s * configuration */ + @SuppressWarnings("removal") static RestClient.Builder builder(RestTemplate restTemplate) { return new DefaultRestClientBuilder(restTemplate); } @@ -386,7 +388,10 @@ Builder defaultStatusHandler(Predicate statusPredicate, * @param errorHandler the error handler to configure, internally adapted * and integrated into the {@link ResponseSpec.ErrorHandler} chain. * @return this builder + * @deprecated as of 7.1 in favor of {@link #defaultStatusHandler(Predicate, ResponseSpec.ErrorHandler)} */ + @Deprecated(since = "7.1", forRemoval = true) + @SuppressWarnings("removal") Builder defaultStatusHandler(ResponseErrorHandler errorHandler); /** @@ -1012,7 +1017,10 @@ ResponseSpec onStatus(Predicate statusPredicate, * {@link RestClientException}. * @param errorHandler the error handler * @return this builder + * @deprecated as of 7.1 in favor of {@link #onStatus(Predicate, ErrorHandler)} */ + @Deprecated(since = "7.1", forRemoval = true) + @SuppressWarnings("removal") ResponseSpec onStatus(ResponseErrorHandler errorHandler); /** diff --git a/spring-web/src/main/java/org/springframework/web/client/RestOperations.java b/spring-web/src/main/java/org/springframework/web/client/RestOperations.java index 502ea82466a5..0a87791cfb13 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestOperations.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestOperations.java @@ -39,7 +39,10 @@ * @author Juergen Hoeller * @since 3.0 * @see RestTemplate + * @deprecated as of 7.1, in favor of {@link RestClient}. */ +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public interface RestOperations { // GET diff --git a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java index 7b8558a51e30..50e2ef1e06b3 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -100,7 +100,10 @@ * @see RequestCallback * @see ResponseExtractor * @see ResponseErrorHandler + * @deprecated as of 7.1, in favor of {@link RestClient}. For removal in 8.0. */ +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public class RestTemplate extends InterceptingHttpAccessor implements RestOperations { private static final ClientRequestObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultClientRequestObservationConvention(); diff --git a/spring-web/src/main/java/org/springframework/web/client/StatusHandler.java b/spring-web/src/main/java/org/springframework/web/client/StatusHandler.java index 7b822b86df2b..1769051c88da 100644 --- a/spring-web/src/main/java/org/springframework/web/client/StatusHandler.java +++ b/spring-web/src/main/java/org/springframework/web/client/StatusHandler.java @@ -90,7 +90,10 @@ public static StatusHandler of( /** * Create a StatusHandler from a {@link ResponseErrorHandler}. + * @deprecated as of 7.1 in favor of {@link #of(Predicate, RestClient.ResponseSpec.ErrorHandler)} */ + @Deprecated(since = "7.1", forRemoval = true) + @SuppressWarnings("removal") public static StatusHandler fromErrorHandler(ResponseErrorHandler errorHandler) { Assert.notNull(errorHandler, "ResponseErrorHandler must not be null"); @@ -160,7 +163,7 @@ private static String getErrorMessage( return preface + bodyText; } - @SuppressWarnings("NullAway") + @SuppressWarnings({"NullAway", "removal"}) private static Function initBodyConvertFunction( ClientHttpResponse response, byte[] body, List> messageConverters) { diff --git a/spring-web/src/main/java/org/springframework/web/client/support/RestGatewaySupport.java b/spring-web/src/main/java/org/springframework/web/client/support/RestGatewaySupport.java index 3a0efc727811..ac163551d849 100644 --- a/spring-web/src/main/java/org/springframework/web/client/support/RestGatewaySupport.java +++ b/spring-web/src/main/java/org/springframework/web/client/support/RestGatewaySupport.java @@ -32,7 +32,10 @@ * @since 3.0 * @see #setRestTemplate * @see org.springframework.web.client.RestTemplate + * @deprecated as of 7.1, in favor of {@link org.springframework.web.client.RestClient}. */ +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public class RestGatewaySupport { /** Logger available to subclasses. */ diff --git a/spring-web/src/main/java/org/springframework/web/client/support/RestTemplateAdapter.java b/spring-web/src/main/java/org/springframework/web/client/support/RestTemplateAdapter.java index e56d005888f6..a44bdcfcd019 100644 --- a/spring-web/src/main/java/org/springframework/web/client/support/RestTemplateAdapter.java +++ b/spring-web/src/main/java/org/springframework/web/client/support/RestTemplateAdapter.java @@ -45,7 +45,10 @@ * @author Olga Maciaszek-Sharma * @author Brian Clozel * @since 6.1 + * @deprecated as of 7.1 in favor of {@link RestClientAdapter}. */ +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public final class RestTemplateAdapter implements HttpExchangeAdapter { private final RestTemplate restTemplate; diff --git a/spring-web/src/main/java/org/springframework/web/filter/FormContentFilter.java b/spring-web/src/main/java/org/springframework/web/filter/FormContentFilter.java index ea27c055563e..a9a1d2024278 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/FormContentFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/FormContentFilter.java @@ -36,10 +36,10 @@ import jakarta.servlet.http.HttpServletResponse; import org.jspecify.annotations.Nullable; +import org.springframework.core.ResolvableType; import org.springframework.http.HttpInputMessage; import org.springframework.http.MediaType; import org.springframework.http.converter.FormHttpMessageConverter; -import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -48,7 +48,7 @@ /** * {@code Filter} that parses form data for HTTP PUT, PATCH, and DELETE requests - * and exposes it as Servlet request parameters. By default the Servlet spec + * and exposes it as Servlet request parameters. By default, the Servlet spec * only requires this for HTTP POST. * * @author Rossen Stoyanchev @@ -58,12 +58,12 @@ public class FormContentFilter extends OncePerRequestFilter { private static final List HTTP_METHODS = Arrays.asList("PUT", "PATCH", "DELETE"); - private FormHttpMessageConverter formConverter = new AllEncompassingFormHttpMessageConverter(); + private FormHttpMessageConverter formConverter = new FormHttpMessageConverter(); /** * Set the converter to use for parsing form content. - *

    By default this is an instance of {@link AllEncompassingFormHttpMessageConverter}. + *

    By default, this is an instance of {@link FormHttpMessageConverter}. */ public void setFormConverter(FormHttpMessageConverter converter) { Assert.notNull(converter, "FormHttpMessageConverter is required"); @@ -94,6 +94,7 @@ protected void doFilterInternal( } } + @SuppressWarnings("unchecked") private @Nullable MultiValueMap parseIfNecessary(HttpServletRequest request) throws IOException { if (!shouldParse(request)) { return null; @@ -105,7 +106,7 @@ public InputStream getBody() throws IOException { return request.getInputStream(); } }; - return this.formConverter.read(null, inputMessage); + return (MultiValueMap) this.formConverter.read(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class), inputMessage, null); } private boolean shouldParse(HttpServletRequest request) { diff --git a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java index 6dca729698eb..1a1ef3a011ad 100644 --- a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java +++ b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java @@ -316,9 +316,10 @@ private static class KotlinDelegate { Object arg = args[index]; if (!(parameter.isOptional() && arg == null)) { KType type = parameter.getType(); - if (!type.isMarkedNullable() && + if (!(type.isMarkedNullable() && arg == null) && type.getClassifier() instanceof KClass kClass && - KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass))) { + KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass)) && + !JvmClassMappingKt.getJavaClass(kClass).isInstance(arg)) { arg = box(kClass, arg); } argMap.put(parameter, arg); @@ -337,9 +338,10 @@ private static class KotlinDelegate { private static Object box(KClass kClass, @Nullable Object arg) { KFunction constructor = Objects.requireNonNull(KClasses.getPrimaryConstructor(kClass)); KType type = constructor.getParameters().get(0).getType(); - if (!type.isMarkedNullable() && + if (!(type.isMarkedNullable() && arg == null) && type.getClassifier() instanceof KClass parameterClass && - KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(parameterClass))) { + KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(parameterClass)) && + !JvmClassMappingKt.getJavaClass(parameterClass).isInstance(arg)) { arg = box(parameterClass, arg); } if (!KCallablesJvm.isAccessible(constructor)) { diff --git a/spring-web/src/main/java/org/springframework/web/util/DisconnectedClientHelper.java b/spring-web/src/main/java/org/springframework/web/util/DisconnectedClientHelper.java index e4658f914c68..4895e0064eb9 100644 --- a/spring-web/src/main/java/org/springframework/web/util/DisconnectedClientHelper.java +++ b/spring-web/src/main/java/org/springframework/web/util/DisconnectedClientHelper.java @@ -16,7 +16,6 @@ package org.springframework.web.util; -import java.util.LinkedHashSet; import java.util.Locale; import java.util.Set; @@ -27,6 +26,7 @@ import org.springframework.core.NestedExceptionUtils; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; /** * Utility methods to assist with identifying and logging exceptions that @@ -48,7 +48,7 @@ public class DisconnectedClientHelper { Set.of("AbortedException", "ClientAbortException", "EOFException", "EofException", "AsyncRequestNotUsableException"); - private static final Set> EXCLUDED_EXCEPTION_TYPES = new LinkedHashSet<>(4); + private static final Set> EXCLUDED_EXCEPTION_TYPES = CollectionUtils.newLinkedHashSet(4); static { addExcludedExceptionType("org.springframework.web.client.RestClientException"); diff --git a/spring-web/src/main/java/org/springframework/web/util/UriTemplateHandler.java b/spring-web/src/main/java/org/springframework/web/util/UriTemplateHandler.java index 952ae09a5a26..b039ce9bba0d 100644 --- a/spring-web/src/main/java/org/springframework/web/util/UriTemplateHandler.java +++ b/spring-web/src/main/java/org/springframework/web/util/UriTemplateHandler.java @@ -26,7 +26,6 @@ * * @author Rossen Stoyanchev * @since 4.2 - * @see org.springframework.web.client.RestTemplate#setUriTemplateHandler(UriTemplateHandler) */ public interface UriTemplateHandler { diff --git a/spring-web/src/main/kotlin/org/springframework/web/client/RestOperationsExtensions.kt b/spring-web/src/main/kotlin/org/springframework/web/client/RestOperationsExtensions.kt index 8fd277712f2f..2ed06fc368f3 100644 --- a/spring-web/src/main/kotlin/org/springframework/web/client/RestOperationsExtensions.kt +++ b/spring-web/src/main/kotlin/org/springframework/web/client/RestOperationsExtensions.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:Suppress("DEPRECATION", "REMOVAL") + package org.springframework.web.client import org.springframework.core.ParameterizedTypeReference diff --git a/spring-web/src/test/java/org/springframework/http/HttpMethodTests.java b/spring-web/src/test/java/org/springframework/http/HttpMethodTests.java index 1fb6205aad70..ac3a3f0ae4e6 100644 --- a/spring-web/src/test/java/org/springframework/http/HttpMethodTests.java +++ b/spring-web/src/test/java/org/springframework/http/HttpMethodTests.java @@ -57,8 +57,15 @@ void valueOf() { HttpMethod get = HttpMethod.valueOf("GET"); assertThat(get).isSameAs(HttpMethod.GET); - HttpMethod foo = HttpMethod.valueOf("FOO"); - HttpMethod other = HttpMethod.valueOf("FOO"); + get = HttpMethod.valueOf("Get"); + assertThat(get).isSameAs(HttpMethod.GET); + + get = HttpMethod.valueOf("get"); + assertThat(get).isSameAs(HttpMethod.GET); + + HttpMethod foo = HttpMethod.valueOf("foo"); + HttpMethod other = HttpMethod.valueOf("foo"); + assertThat(foo).isNotSameAs(other); assertThat(foo).isEqualTo(other); } @@ -73,4 +80,5 @@ void matches() { assertThat(HttpMethod.GET.matches("GET")).isTrue(); assertThat(HttpMethod.GET.matches("FOO")).isFalse(); } + } diff --git a/spring-web/src/test/java/org/springframework/http/client/support/InterceptingHttpAccessorTests.java b/spring-web/src/test/java/org/springframework/http/client/support/InterceptingHttpAccessorTests.java index bbc41ae24490..ea134551d815 100644 --- a/spring-web/src/test/java/org/springframework/http/client/support/InterceptingHttpAccessorTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/support/InterceptingHttpAccessorTests.java @@ -35,6 +35,7 @@ * * @author Brian Clozel */ +@SuppressWarnings("removal") class InterceptingHttpAccessorTests { @Test diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/DefaultPartHttpMessageReaderTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/DefaultPartHttpMessageReaderTests.java index c77ba81b33fb..54975b30c27a 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/DefaultPartHttpMessageReaderTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/DefaultPartHttpMessageReaderTests.java @@ -80,7 +80,7 @@ void canRead(DefaultPartHttpMessageReader reader) { @ParameterizedDefaultPartHttpMessageReaderTest void simple(DefaultPartHttpMessageReader reader) throws InterruptedException { MockServerHttpRequest request = createRequest( - new ClassPathResource("simple.multipart", getClass()), "simple-boundary"); + "simple.multipart", "simple-boundary"); Flux result = reader.read(forClass(Part.class), request, emptyMap()); @@ -98,7 +98,7 @@ void simple(DefaultPartHttpMessageReader reader) throws InterruptedException { @ParameterizedDefaultPartHttpMessageReaderTest void noHeaders(DefaultPartHttpMessageReader reader) { MockServerHttpRequest request = createRequest( - new ClassPathResource("no-header.multipart", getClass()), "boundary"); + "no-header.multipart", "boundary"); Flux result = reader.read(forClass(Part.class), request, emptyMap()); StepVerifier.create(result) @@ -112,7 +112,7 @@ void noHeaders(DefaultPartHttpMessageReader reader) { @ParameterizedDefaultPartHttpMessageReaderTest void noEndBoundary(DefaultPartHttpMessageReader reader) { MockServerHttpRequest request = createRequest( - new ClassPathResource("no-end-boundary.multipart", getClass()), "boundary"); + "no-end-boundary.multipart", "boundary"); Flux result = reader.read(forClass(Part.class), request, emptyMap()); @@ -128,7 +128,7 @@ void noEndBoundary(DefaultPartHttpMessageReader reader) { @ParameterizedDefaultPartHttpMessageReaderTest void garbage(DefaultPartHttpMessageReader reader) { MockServerHttpRequest request = createRequest( - new ClassPathResource("garbage-1.multipart", getClass()), "boundary"); + "garbage-1.multipart", "boundary"); Flux result = reader.read(forClass(Part.class), request, emptyMap()); @@ -140,7 +140,7 @@ void garbage(DefaultPartHttpMessageReader reader) { @ParameterizedDefaultPartHttpMessageReaderTest void noEndHeader(DefaultPartHttpMessageReader reader) { MockServerHttpRequest request = createRequest( - new ClassPathResource("no-end-header.multipart", getClass()), "boundary"); + "no-end-header.multipart", "boundary"); Flux result = reader.read(forClass(Part.class), request, emptyMap()); StepVerifier.create(result) @@ -151,7 +151,7 @@ void noEndHeader(DefaultPartHttpMessageReader reader) { @ParameterizedDefaultPartHttpMessageReaderTest void noEndBody(DefaultPartHttpMessageReader reader) { MockServerHttpRequest request = createRequest( - new ClassPathResource("no-end-body.multipart", getClass()), "boundary"); + "no-end-body.multipart", "boundary"); Flux result = reader.read(forClass(Part.class), request, emptyMap()); StepVerifier.create(result) @@ -162,7 +162,7 @@ void noEndBody(DefaultPartHttpMessageReader reader) { @ParameterizedDefaultPartHttpMessageReaderTest void cancelPart(DefaultPartHttpMessageReader reader) { MockServerHttpRequest request = createRequest( - new ClassPathResource("simple.multipart", getClass()), "simple-boundary"); + "simple.multipart", "simple-boundary"); Flux result = reader.read(forClass(Part.class), request, emptyMap()); StepVerifier.create(result, 1) @@ -174,7 +174,7 @@ void cancelPart(DefaultPartHttpMessageReader reader) { @ParameterizedDefaultPartHttpMessageReaderTest void cancelBody(DefaultPartHttpMessageReader reader) throws Exception { MockServerHttpRequest request = createRequest( - new ClassPathResource("simple.multipart", getClass()), "simple-boundary"); + "simple.multipart", "simple-boundary"); Flux result = reader.read(forClass(Part.class), request, emptyMap()); CountDownLatch latch = new CountDownLatch(1); @@ -191,7 +191,7 @@ void cancelBody(DefaultPartHttpMessageReader reader) throws Exception { @ParameterizedDefaultPartHttpMessageReaderTest void cancelBodyThenPart(DefaultPartHttpMessageReader reader) { MockServerHttpRequest request = createRequest( - new ClassPathResource("simple.multipart", getClass()), "simple-boundary"); + "simple.multipart", "simple-boundary"); Flux result = reader.read(forClass(Part.class), request, emptyMap()); StepVerifier.create(result, 1) @@ -202,26 +202,26 @@ void cancelBodyThenPart(DefaultPartHttpMessageReader reader) { @ParameterizedDefaultPartHttpMessageReaderTest void firefox(DefaultPartHttpMessageReader reader) throws InterruptedException { - testBrowser(reader, new ClassPathResource("firefox.multipart", getClass()), + testBrowser(reader, "firefox.multipart", "---------------------------18399284482060392383840973206"); } @ParameterizedDefaultPartHttpMessageReaderTest void chrome(DefaultPartHttpMessageReader reader) throws InterruptedException { - testBrowser(reader, new ClassPathResource("chrome.multipart", getClass()), + testBrowser(reader, "chrome.multipart", "----WebKitFormBoundaryEveBLvRT65n21fwU"); } @ParameterizedDefaultPartHttpMessageReaderTest void safari(DefaultPartHttpMessageReader reader) throws InterruptedException { - testBrowser(reader, new ClassPathResource("safari.multipart", getClass()), + testBrowser(reader, "safari.multipart", "----WebKitFormBoundaryG8fJ50opQOML0oGD"); } @Test void tooManyParts() throws InterruptedException { MockServerHttpRequest request = createRequest( - new ClassPathResource("simple.multipart", getClass()), "simple-boundary"); + "simple.multipart", "simple-boundary"); DefaultPartHttpMessageReader reader = new DefaultPartHttpMessageReader(); reader.setMaxParts(1); @@ -241,7 +241,7 @@ void tooManyParts() throws InterruptedException { @ParameterizedDefaultPartHttpMessageReaderTest void quotedBoundary(DefaultPartHttpMessageReader reader) throws InterruptedException { MockServerHttpRequest request = createRequest( - new ClassPathResource("simple.multipart", getClass()), "\"simple-boundary\""); + "simple.multipart", "\"simple-boundary\""); Flux result = reader.read(forClass(Part.class), request, emptyMap()); @@ -259,7 +259,7 @@ void quotedBoundary(DefaultPartHttpMessageReader reader) throws InterruptedExcep @ParameterizedDefaultPartHttpMessageReaderTest void utf8Headers(DefaultPartHttpMessageReader reader) throws InterruptedException { MockServerHttpRequest request = createRequest( - new ClassPathResource("utf8.multipart", getClass()), "\"simple-boundary\""); + "utf8.multipart", "\"simple-boundary\""); Flux result = reader.read(forClass(Part.class), request, emptyMap()); @@ -278,7 +278,7 @@ void utf8Headers(DefaultPartHttpMessageReader reader) throws InterruptedExceptio @Test void exceedHeaderLimit() throws InterruptedException { Flux body = DataBufferUtils - .readByteChannel((new ClassPathResource("files.multipart", getClass()))::readableChannel, bufferFactory, 282); + .readByteChannel(new ClassPathResource("/org/springframework/http/multipart/files.multipart")::readableChannel, bufferFactory, 282); MediaType contentType = new MediaType("multipart", "form-data", singletonMap("boundary", "----WebKitFormBoundaryG8fJ50opQOML0oGD")); MockServerHttpRequest request = MockServerHttpRequest.post("/") @@ -303,7 +303,7 @@ void exceedHeaderLimit() throws InterruptedException { @ParameterizedDefaultPartHttpMessageReaderTest void emptyLastPart(DefaultPartHttpMessageReader reader) throws InterruptedException { MockServerHttpRequest request = createRequest( - new ClassPathResource("empty-part.multipart", getClass()), "LiG0chJ0k7YtLt-FzTklYFgz50i88xJCW5jD"); + "empty-part.multipart", "LiG0chJ0k7YtLt-FzTklYFgz50i88xJCW5jD"); Flux result = reader.read(forClass(Part.class), request, emptyMap()); @@ -317,10 +317,10 @@ void emptyLastPart(DefaultPartHttpMessageReader reader) throws InterruptedExcept } - private void testBrowser(DefaultPartHttpMessageReader reader, Resource resource, String boundary) + private void testBrowser(DefaultPartHttpMessageReader reader, String fileName, String boundary) throws InterruptedException { - MockServerHttpRequest request = createRequest(resource, boundary); + MockServerHttpRequest request = createRequest(fileName, boundary); Flux result = reader.read(forClass(Part.class), request, emptyMap()); CountDownLatch latch = new CountDownLatch(3); @@ -334,7 +334,8 @@ private void testBrowser(DefaultPartHttpMessageReader reader, Resource resource, latch.await(); } - private MockServerHttpRequest createRequest(Resource resource, String boundary) { + private MockServerHttpRequest createRequest(String fileName, String boundary) { + Resource resource = new ClassPathResource("/org/springframework/http/multipart/" + fileName); Flux body = DataBufferUtils .readByteChannel(resource::readableChannel, bufferFactory, BUFFER_SIZE); diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/PartEventHttpMessageReaderTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/PartEventHttpMessageReaderTests.java index 332b8c1d6be9..aa9cab7d61a9 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/PartEventHttpMessageReaderTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/PartEventHttpMessageReaderTests.java @@ -66,7 +66,7 @@ void canRead() { @Test void simple() { MockServerHttpRequest request = createRequest( - new ClassPathResource("simple.multipart", getClass()), "simple-boundary"); + "simple.multipart", "simple-boundary"); Flux result = this.reader.read(forClass(PartEvent.class), request, emptyMap()); @@ -80,7 +80,7 @@ void simple() { @Test void noHeaders() { MockServerHttpRequest request = createRequest( - new ClassPathResource("no-header.multipart", getClass()), "boundary"); + "no-header.multipart", "boundary"); Flux result = this.reader.read(forClass(PartEvent.class), request, emptyMap()); StepVerifier.create(result) @@ -91,7 +91,7 @@ void noHeaders() { @Test void noEndBoundary() { MockServerHttpRequest request = createRequest( - new ClassPathResource("no-end-boundary.multipart", getClass()), "boundary"); + "no-end-boundary.multipart", "boundary"); Flux result = this.reader.read(forClass(PartEvent.class), request, emptyMap()); @@ -103,7 +103,7 @@ void noEndBoundary() { @Test void garbage() { MockServerHttpRequest request = createRequest( - new ClassPathResource("garbage-1.multipart", getClass()), "boundary"); + "garbage-1.multipart", "boundary"); Flux result = this.reader.read(forClass(PartEvent.class), request, emptyMap()); @@ -116,7 +116,7 @@ void garbage() { @Test void noEndHeader() { MockServerHttpRequest request = createRequest( - new ClassPathResource("no-end-header.multipart", getClass()), "boundary"); + "no-end-header.multipart", "boundary"); Flux result = this.reader.read(forClass(PartEvent.class), request, emptyMap()); StepVerifier.create(result) @@ -127,7 +127,7 @@ void noEndHeader() { @Test void noEndBody() { MockServerHttpRequest request = createRequest( - new ClassPathResource("no-end-body.multipart", getClass()), "boundary"); + "no-end-body.multipart", "boundary"); Flux result = this.reader.read(forClass(PartEvent.class), request, emptyMap()); StepVerifier.create(result) @@ -138,7 +138,7 @@ void noEndBody() { @Test void noBody() { MockServerHttpRequest request = createRequest( - new ClassPathResource("no-body.multipart", getClass()), "boundary"); + "no-body.multipart", "boundary"); Flux result = this.reader.read(forClass(PartEvent.class), request, emptyMap()); StepVerifier.create(result) @@ -151,7 +151,7 @@ void noBody() { @Test void cancel() { MockServerHttpRequest request = createRequest( - new ClassPathResource("simple.multipart", getClass()), "simple-boundary"); + "simple.multipart", "simple-boundary"); Flux result = this.reader.read(forClass(PartEvent.class), request, emptyMap()); StepVerifier.create(result, 3) @@ -165,7 +165,7 @@ void cancel() { @Test void firefox() { - MockServerHttpRequest request = createRequest(new ClassPathResource("firefox.multipart", getClass()), + MockServerHttpRequest request = createRequest("firefox.multipart", "---------------------------18399284482060392383840973206"); Flux result = this.reader.read(forClass(PartEvent.class), request, emptyMap()); @@ -187,7 +187,7 @@ void firefox() { @Test void chrome() { - MockServerHttpRequest request = createRequest(new ClassPathResource("chrome.multipart", getClass()), + MockServerHttpRequest request = createRequest("chrome.multipart", "----WebKitFormBoundaryEveBLvRT65n21fwU"); Flux result = this.reader.read(forClass(PartEvent.class), request, emptyMap()); @@ -208,7 +208,7 @@ void chrome() { @Test void safari() { - MockServerHttpRequest request = createRequest(new ClassPathResource("safari.multipart", getClass()), + MockServerHttpRequest request = createRequest("safari.multipart", "----WebKitFormBoundaryG8fJ50opQOML0oGD"); Flux result = this.reader.read(forClass(PartEvent.class), request, emptyMap()); @@ -229,7 +229,7 @@ void safari() { @Test void tooManyParts() { MockServerHttpRequest request = createRequest( - new ClassPathResource("simple.multipart", getClass()), "simple-boundary"); + "simple.multipart", "simple-boundary"); PartEventHttpMessageReader reader = new PartEventHttpMessageReader(); reader.setMaxParts(1); @@ -244,7 +244,7 @@ void tooManyParts() { @Test void partSizeTooLarge() { - MockServerHttpRequest request = createRequest(new ClassPathResource("safari.multipart", getClass()), + MockServerHttpRequest request = createRequest("safari.multipart", "----WebKitFormBoundaryG8fJ50opQOML0oGD"); PartEventHttpMessageReader reader = new PartEventHttpMessageReader(); @@ -262,7 +262,7 @@ void partSizeTooLarge() { @Test void formPartTooLarge() { MockServerHttpRequest request = createRequest( - new ClassPathResource("simple.multipart", getClass()), "simple-boundary"); + "simple.multipart", "simple-boundary"); PartEventHttpMessageReader reader = new PartEventHttpMessageReader(); reader.setMaxInMemorySize(40); @@ -277,7 +277,7 @@ void formPartTooLarge() { @Test void utf8Headers() { MockServerHttpRequest request = createRequest( - new ClassPathResource("utf8.multipart", getClass()), "\"simple-boundary\""); + "utf8.multipart", "\"simple-boundary\""); Flux result = this.reader.read(forClass(PartEvent.class), request, emptyMap()); @@ -290,7 +290,7 @@ void utf8Headers() { @Test void exceedHeaderLimit() { Flux body = DataBufferUtils - .readByteChannel((new ClassPathResource("files.multipart", getClass()))::readableChannel, bufferFactory, + .readByteChannel((new ClassPathResource("/org/springframework/http/multipart/files.multipart"))::readableChannel, bufferFactory, 282); MediaType contentType = new MediaType("multipart", "form-data", @@ -309,7 +309,8 @@ void exceedHeaderLimit() { .verifyComplete(); } - private MockServerHttpRequest createRequest(Resource resource, String boundary) { + private MockServerHttpRequest createRequest(String fileName, String boundary) { + Resource resource = new ClassPathResource("/org/springframework/http/multipart/" + fileName); Flux body = DataBufferUtils .readByteChannel(resource::readableChannel, bufferFactory, BUFFER_SIZE); diff --git a/spring-web/src/test/java/org/springframework/http/converter/DefaultHttpMessageConvertersTests.java b/spring-web/src/test/java/org/springframework/http/converter/DefaultHttpMessageConvertersTests.java index 0c0cb5af427d..7c4ccaf86973 100644 --- a/spring-web/src/test/java/org/springframework/http/converter/DefaultHttpMessageConvertersTests.java +++ b/spring-web/src/test/java/org/springframework/http/converter/DefaultHttpMessageConvertersTests.java @@ -33,9 +33,9 @@ import org.springframework.http.converter.feed.RssChannelHttpMessageConverter; import org.springframework.http.converter.json.JacksonJsonHttpMessageConverter; import org.springframework.http.converter.json.KotlinSerializationJsonHttpMessageConverter; +import org.springframework.http.converter.multipart.MultipartHttpMessageConverter; import org.springframework.http.converter.protobuf.KotlinSerializationProtobufHttpMessageConverter; import org.springframework.http.converter.smile.JacksonSmileHttpMessageConverter; -import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; import org.springframework.http.converter.xml.JacksonXmlHttpMessageConverter; import org.springframework.http.converter.yaml.JacksonYamlHttpMessageConverter; @@ -68,6 +68,13 @@ void failsWhenStringConverterDoesNotSupportMediaType() { .withMessage("converter should support 'text/plain'"); } + @Test + void failsWhenFormConverterDoesNotSupportMediaType() { + assertThatIllegalArgumentException() + .isThrownBy(() -> HttpMessageConverters.forClient().withFormConverter(new CustomHttpMessageConverter()).build()) + .withMessage("converter should support 'application/x-www-form-urlencoded'"); + } + @Test void failsWhenJsonConverterDoesNotSupportMediaType() { assertThatIllegalArgumentException() @@ -116,8 +123,9 @@ class ClientConvertersTests { void defaultConverters() { var converters = HttpMessageConverters.forClient().registerDefaults().build(); assertThat(converters).hasExactlyElementsOfTypes(ByteArrayHttpMessageConverter.class, - StringHttpMessageConverter.class, ResourceHttpMessageConverter.class, - AllEncompassingFormHttpMessageConverter.class, KotlinSerializationJsonHttpMessageConverter.class, + StringHttpMessageConverter.class, FormHttpMessageConverter.class, + ResourceHttpMessageConverter.class, MultipartHttpMessageConverter.class, + KotlinSerializationJsonHttpMessageConverter.class, JacksonJsonHttpMessageConverter.class, JacksonSmileHttpMessageConverter.class, KotlinSerializationCborHttpMessageConverter.class, JacksonCborHttpMessageConverter.class, JacksonYamlHttpMessageConverter.class, JacksonXmlHttpMessageConverter.class, @@ -134,7 +142,7 @@ void disableDefaults() { @Test void multipartConverterContainsOtherConverters() { var converters = HttpMessageConverters.forClient().registerDefaults().build(); - var multipartConverter = findMessageConverter(AllEncompassingFormHttpMessageConverter.class, converters); + var multipartConverter = findMessageConverter(MultipartHttpMessageConverter.class, converters); assertThat(multipartConverter.getPartConverters()).hasExactlyElementsOfTypes( ByteArrayHttpMessageConverter.class, StringHttpMessageConverter.class, @@ -150,7 +158,7 @@ void multipartConverterContainsOtherConverters() { void registerCustomMessageConverter() { var converters = HttpMessageConverters.forClient() .addCustomConverter(new CustomHttpMessageConverter()).build(); - assertThat(converters).hasExactlyElementsOfTypes(CustomHttpMessageConverter.class, AllEncompassingFormHttpMessageConverter.class); + assertThat(converters).hasExactlyElementsOfTypes(CustomHttpMessageConverter.class, MultipartHttpMessageConverter.class); } @Test @@ -159,8 +167,9 @@ void registerCustomMessageConverterAheadOfDefaults() { .addCustomConverter(new CustomHttpMessageConverter()).build(); assertThat(converters).hasExactlyElementsOfTypes( CustomHttpMessageConverter.class, ByteArrayHttpMessageConverter.class, - StringHttpMessageConverter.class, ResourceHttpMessageConverter.class, - AllEncompassingFormHttpMessageConverter.class, KotlinSerializationJsonHttpMessageConverter.class, + StringHttpMessageConverter.class, FormHttpMessageConverter.class, + ResourceHttpMessageConverter.class, MultipartHttpMessageConverter.class, + KotlinSerializationJsonHttpMessageConverter.class, JacksonJsonHttpMessageConverter.class, JacksonSmileHttpMessageConverter.class, KotlinSerializationCborHttpMessageConverter.class, JacksonCborHttpMessageConverter.class, JacksonYamlHttpMessageConverter.class, JacksonXmlHttpMessageConverter.class, @@ -172,7 +181,7 @@ void registerCustomMessageConverterAheadOfDefaults() { void registerCustomConverterInMultipartConverter() { var converters = HttpMessageConverters.forClient().registerDefaults() .addCustomConverter(new CustomHttpMessageConverter()).build(); - var multipartConverter = findMessageConverter(AllEncompassingFormHttpMessageConverter.class, converters); + var multipartConverter = findMessageConverter(MultipartHttpMessageConverter.class, converters); assertThat(multipartConverter.getPartConverters()).hasAtLeastOneElementOfType(CustomHttpMessageConverter.class); } @@ -181,7 +190,7 @@ void shouldConfigureOverridesWhenDefaultOff() { var stringConverter = new StringHttpMessageConverter(); var converters = HttpMessageConverters.forClient().withStringConverter(stringConverter).build(); assertThat(converters).hasExactlyElementsOfTypes( - StringHttpMessageConverter.class, AllEncompassingFormHttpMessageConverter.class); + StringHttpMessageConverter.class, MultipartHttpMessageConverter.class); var configured = findMessageConverter(StringHttpMessageConverter.class, converters); assertThat(configured).isEqualTo(stringConverter); } @@ -251,8 +260,9 @@ void defaultConverters() { var converters = HttpMessageConverters.forServer().registerDefaults().build(); assertThat(converters).hasExactlyElementsOfTypes( ByteArrayHttpMessageConverter.class, StringHttpMessageConverter.class, - ResourceHttpMessageConverter.class, ResourceRegionHttpMessageConverter.class, - AllEncompassingFormHttpMessageConverter.class, KotlinSerializationJsonHttpMessageConverter.class, + FormHttpMessageConverter.class, ResourceHttpMessageConverter.class, + ResourceRegionHttpMessageConverter.class, MultipartHttpMessageConverter.class, + KotlinSerializationJsonHttpMessageConverter.class, JacksonJsonHttpMessageConverter.class, JacksonSmileHttpMessageConverter.class, KotlinSerializationCborHttpMessageConverter.class, JacksonCborHttpMessageConverter.class, JacksonYamlHttpMessageConverter.class, JacksonXmlHttpMessageConverter.class, @@ -269,7 +279,7 @@ void disableDefaults() { @Test void multipartConverterContainsOtherConverters() { var converters = HttpMessageConverters.forServer().registerDefaults().build(); - var multipartConverter = findMessageConverter(AllEncompassingFormHttpMessageConverter.class, converters); + var multipartConverter = findMessageConverter(MultipartHttpMessageConverter.class, converters); assertThat(multipartConverter.getPartConverters()).hasExactlyElementsOfTypes( ByteArrayHttpMessageConverter.class, StringHttpMessageConverter.class, @@ -285,7 +295,7 @@ void multipartConverterContainsOtherConverters() { void registerCustomMessageConverter() { var converters = HttpMessageConverters.forServer() .addCustomConverter(new CustomHttpMessageConverter()).build(); - assertThat(converters).hasExactlyElementsOfTypes(CustomHttpMessageConverter.class, AllEncompassingFormHttpMessageConverter.class); + assertThat(converters).hasExactlyElementsOfTypes(CustomHttpMessageConverter.class, MultipartHttpMessageConverter.class); } @Test @@ -295,8 +305,9 @@ void registerCustomMessageConverterAheadOfDefaults() { assertThat(converters).hasExactlyElementsOfTypes( CustomHttpMessageConverter.class, ByteArrayHttpMessageConverter.class, StringHttpMessageConverter.class, - ResourceHttpMessageConverter.class, ResourceRegionHttpMessageConverter.class, - AllEncompassingFormHttpMessageConverter.class, KotlinSerializationJsonHttpMessageConverter.class, + FormHttpMessageConverter.class, ResourceHttpMessageConverter.class, + ResourceRegionHttpMessageConverter.class, MultipartHttpMessageConverter.class, + KotlinSerializationJsonHttpMessageConverter.class, JacksonJsonHttpMessageConverter.class, JacksonSmileHttpMessageConverter.class, KotlinSerializationCborHttpMessageConverter.class, JacksonCborHttpMessageConverter.class, JacksonYamlHttpMessageConverter.class, JacksonXmlHttpMessageConverter.class, @@ -308,7 +319,7 @@ void registerCustomMessageConverterAheadOfDefaults() { void registerCustomConverterInMultipartConverter() { var converters = HttpMessageConverters.forServer().registerDefaults() .addCustomConverter(new CustomHttpMessageConverter()).build(); - var multipartConverter = findMessageConverter(AllEncompassingFormHttpMessageConverter.class, converters); + var multipartConverter = findMessageConverter(MultipartHttpMessageConverter.class, converters); assertThat(multipartConverter.getPartConverters()).hasAtLeastOneElementOfType(CustomHttpMessageConverter.class); } @@ -317,7 +328,7 @@ void shouldConfigureOverridesWhenDefaultOff() { var stringConverter = new StringHttpMessageConverter(); var converters = HttpMessageConverters.forServer().withStringConverter(stringConverter).build(); assertThat(converters).hasExactlyElementsOfTypes( - StringHttpMessageConverter.class, AllEncompassingFormHttpMessageConverter.class); + StringHttpMessageConverter.class, MultipartHttpMessageConverter.class); var configured = findMessageConverter(StringHttpMessageConverter.class, converters); assertThat(configured).isEqualTo(stringConverter); } diff --git a/spring-web/src/test/java/org/springframework/http/converter/FormHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/FormHttpMessageConverterTests.java index 049340968a68..2e6b89350a02 100644 --- a/spring-web/src/test/java/org/springframework/http/converter/FormHttpMessageConverterTests.java +++ b/spring-web/src/test/java/org/springframework/http/converter/FormHttpMessageConverterTests.java @@ -16,33 +16,16 @@ package org.springframework.http.converter; -import java.io.ByteArrayInputStream; import java.io.IOException; -import java.io.InputStream; -import java.io.StringReader; import java.nio.charset.StandardCharsets; -import java.util.LinkedHashMap; +import java.util.HashMap; import java.util.List; import java.util.Map; -import javax.xml.transform.Source; -import javax.xml.transform.stream.StreamSource; - -import org.apache.tomcat.util.http.fileupload.FileItem; -import org.apache.tomcat.util.http.fileupload.FileUpload; -import org.apache.tomcat.util.http.fileupload.RequestContext; -import org.apache.tomcat.util.http.fileupload.UploadContext; -import org.apache.tomcat.util.http.fileupload.disk.DiskFileItemFactory; import org.junit.jupiter.api.Test; -import org.springframework.core.io.ClassPathResource; -import org.springframework.core.io.Resource; -import org.springframework.http.HttpEntity; -import org.springframework.http.HttpHeaders; +import org.springframework.core.ResolvableType; import org.springframework.http.MediaType; -import org.springframework.http.StreamingHttpOutputMessage; -import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; -import org.springframework.http.converter.xml.SourceHttpMessageConverter; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.testfixture.http.MockHttpInputMessage; @@ -52,86 +35,123 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED; -import static org.springframework.http.MediaType.APPLICATION_JSON; import static org.springframework.http.MediaType.MULTIPART_FORM_DATA; import static org.springframework.http.MediaType.MULTIPART_MIXED; import static org.springframework.http.MediaType.MULTIPART_RELATED; -import static org.springframework.http.MediaType.TEXT_XML; /** - * Tests for {@link FormHttpMessageConverter} and - * {@link AllEncompassingFormHttpMessageConverter}. + * Tests for {@link FormHttpMessageConverter}. * * @author Arjen Poutsma * @author Rossen Stoyanchev * @author Sam Brannen * @author Sebastien Deleuze + * @author Brian Clozel */ class FormHttpMessageConverterTests { - private final FormHttpMessageConverter converter = new AllEncompassingFormHttpMessageConverter(); + private static final ResolvableType LINKED_MULTI_VALUE_MAP = + ResolvableType.forClassWithGenerics(LinkedMultiValueMap.class, String.class, String.class); + + private static final ResolvableType MULTI_VALUE_MAP = + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class); + private static final ResolvableType MAP = + ResolvableType.forClassWithGenerics(Map.class, String.class, String.class); + + private final FormHttpMessageConverter converter = new FormHttpMessageConverter(); @Test - void canRead() { - assertCanRead(MultiValueMap.class, null); - assertCanRead(APPLICATION_FORM_URLENCODED); + void cannotReadToMapsWhenMediaTypeMissing() { + assertCannotRead(MAP, null); + assertCannotRead(MULTI_VALUE_MAP, null); + assertCannotRead(LINKED_MULTI_VALUE_MAP, null); + // without generics + assertCannotRead(ResolvableType.forClass(Map.class), null); + assertCannotRead(ResolvableType.forClass(MultiValueMap.class), null); + assertCannotRead(ResolvableType.forClass(LinkedMultiValueMap.class), null); + } - assertCannotRead(String.class, null); - assertCannotRead(String.class, APPLICATION_FORM_URLENCODED); + @Test + void canReadToMapTypes() { + assertCanRead(MAP, APPLICATION_FORM_URLENCODED); + assertCanRead(MULTI_VALUE_MAP, APPLICATION_FORM_URLENCODED); + assertCanRead(LINKED_MULTI_VALUE_MAP, APPLICATION_FORM_URLENCODED); + // without generics + assertCanRead(ResolvableType.forClass(Map.class), APPLICATION_FORM_URLENCODED); + assertCanRead(ResolvableType.forClass(MultiValueMap.class), APPLICATION_FORM_URLENCODED); + assertCanRead(ResolvableType.forClass(LinkedMultiValueMap.class), APPLICATION_FORM_URLENCODED); } @Test void cannotReadMultipart() { // Without custom multipart types supported - asssertCannotReadMultipart(); + assertCannotReadMultipart(); // Should still be the case with custom multipart types supported - asssertCannotReadMultipart(); + assertCannotReadMultipart(); } @Test void canWrite() { assertCanWrite(APPLICATION_FORM_URLENCODED); - assertCanWrite(MULTIPART_FORM_DATA); - assertCanWrite(MULTIPART_MIXED); - assertCanWrite(MULTIPART_RELATED); - assertCanWrite(new MediaType("multipart", "form-data", UTF_8)); - assertCanWrite(MediaType.ALL); - assertCanWrite(null); + assertCannotWrite(MediaType.ALL); } - @Test - void setSupportedMediaTypes() { - this.converter.setSupportedMediaTypes(List.of(MULTIPART_FORM_DATA)); - assertCannotWrite(MULTIPART_MIXED); - this.converter.setSupportedMediaTypes(List.of(MULTIPART_MIXED)); - assertCanWrite(MULTIPART_MIXED); + @Test + void canWriteMapTypes() { + assertCanWrite(MAP, APPLICATION_FORM_URLENCODED); + assertCanWrite(MULTI_VALUE_MAP, APPLICATION_FORM_URLENCODED); + assertCanWrite(LINKED_MULTI_VALUE_MAP, APPLICATION_FORM_URLENCODED); + // without generics + assertCanWrite(ResolvableType.forClass(Map.class), APPLICATION_FORM_URLENCODED); + assertCanWrite(ResolvableType.forClass(MultiValueMap.class), APPLICATION_FORM_URLENCODED); + assertCanWrite(ResolvableType.forClass(LinkedMultiValueMap.class), APPLICATION_FORM_URLENCODED); } @Test - void addSupportedMediaTypes() { - this.converter.setSupportedMediaTypes(List.of(MULTIPART_FORM_DATA)); + void cannotWriteMultipart() { + assertCannotWrite(MULTIPART_FORM_DATA); assertCannotWrite(MULTIPART_MIXED); - - this.converter.addSupportedMediaTypes(MULTIPART_RELATED); - assertCanWrite(MULTIPART_RELATED); + assertCannotWrite(MULTIPART_RELATED); + assertCannotWrite(new MediaType("multipart", "form-data", UTF_8)); + assertCannotWrite(null); } @Test - void readForm() throws Exception { + @SuppressWarnings("unchecked") + void readFormAsMultiValueMap() throws Exception { String body = "name+1=value+1&name+2=value+2%2B1&name+2=value+2%2B2&name+3"; MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(StandardCharsets.ISO_8859_1)); inputMessage.getHeaders().setContentType( new MediaType("application", "x-www-form-urlencoded", StandardCharsets.ISO_8859_1)); - MultiValueMap result = this.converter.read(null, inputMessage); + Object result = this.converter.read(ResolvableType.forClass(MultiValueMap.class), inputMessage, null); - assertThat(result).as("Invalid result").hasSize(3); - assertThat(result.getFirst("name 1")).as("Invalid result").isEqualTo("value 1"); - List values = result.get("name 2"); + assertThat(result).isInstanceOf(MultiValueMap.class); + MultiValueMap form = (MultiValueMap) result; + assertThat(form).as("Invalid result").hasSize(3); + assertThat(form.getFirst("name 1")).as("Invalid result").isEqualTo("value 1"); + List values = form.get("name 2"); assertThat(values).as("Invalid result").containsExactly("value 2+1", "value 2+2"); - assertThat(result.getFirst("name 3")).as("Invalid result").isNull(); + assertThat(form.getFirst("name 3")).as("Invalid result").isNull(); + } + + @Test + @SuppressWarnings("unchecked") + void readFormAsMap() throws Exception { + String body = "name+1=value+1&name+2=value+2&name+3"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(StandardCharsets.ISO_8859_1)); + inputMessage.getHeaders().setContentType( + new MediaType("application", "x-www-form-urlencoded", StandardCharsets.ISO_8859_1)); + Object result = this.converter.read(ResolvableType.forClass(Map.class), inputMessage, null); + + assertThat(result).isInstanceOf(Map.class); + Map form = (Map) result; + assertThat(form).as("Invalid result").hasSize(3); + assertThat(form.get("name 1")).as("Invalid result").isEqualTo("value 1"); + assertThat(form.get("name 2")).as("Invalid result").isEqualTo("value 2"); + assertThat(form.get("name 3")).as("Invalid result").isNull(); } @Test @@ -156,7 +176,7 @@ void readInvalidFormWithNameWithNoValueThatWontUrlDecode() { } @Test - void writeForm() throws IOException { + void writeFormFromMultiValueMap() throws IOException { MultiValueMap body = new LinkedMultiValueMap<>(); body.set("name 1", "value 1"); body.add("name 2", "value 2+1"); @@ -174,245 +194,30 @@ void writeForm() throws IOException { } @Test - void writeMultipart() throws Exception { - - MultiValueMap parts = new LinkedMultiValueMap<>(); - parts.add("name 1", "value 1"); - parts.add("name 2", "value 2+1"); - parts.add("name 2", "value 2+2"); - parts.add("name 3", null); - - Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); - parts.add("logo", logo); - - // SPR-12108 - Resource utf8 = new ClassPathResource("/org/springframework/http/converter/logo.jpg") { - @Override - public String getFilename() { - return "Hall\u00F6le.jpg"; - } - }; - parts.add("utf8", utf8); - - MyBean myBean = new MyBean(); - myBean.setString("foo"); - HttpHeaders entityHeaders = new HttpHeaders(); - entityHeaders.setContentType(APPLICATION_JSON); - HttpEntity entity = new HttpEntity<>(myBean, entityHeaders); - parts.add("json", entity); - - Map parameters = new LinkedHashMap<>(2); - parameters.put("charset", UTF_8.name()); - parameters.put("foo", "bar"); - - StreamingMockHttpOutputMessage outputMessage = new StreamingMockHttpOutputMessage(); - this.converter.write(parts, new MediaType("multipart", "form-data", parameters), outputMessage); - - final MediaType contentType = outputMessage.getHeaders().getContentType(); - assertThat(contentType.getParameters()).containsKeys("charset", "boundary", "foo"); // gh-21568, gh-25839 - - // see if Commons FileUpload can read what we wrote - FileUpload fileUpload = new FileUpload(); - fileUpload.setFileItemFactory(new DiskFileItemFactory()); - RequestContext requestContext = new MockHttpOutputMessageRequestContext(outputMessage); - List items = fileUpload.parseRequest(requestContext); - assertThat(items).hasSize(6); - FileItem item = items.get(0); - assertThat(item.isFormField()).isTrue(); - assertThat(item.getFieldName()).isEqualTo("name 1"); - assertThat(item.getString()).isEqualTo("value 1"); - - item = items.get(1); - assertThat(item.isFormField()).isTrue(); - assertThat(item.getFieldName()).isEqualTo("name 2"); - assertThat(item.getString()).isEqualTo("value 2+1"); - - item = items.get(2); - assertThat(item.isFormField()).isTrue(); - assertThat(item.getFieldName()).isEqualTo("name 2"); - assertThat(item.getString()).isEqualTo("value 2+2"); - - item = items.get(3); - assertThat(item.isFormField()).isFalse(); - assertThat(item.getFieldName()).isEqualTo("logo"); - assertThat(item.getName()).isEqualTo("logo.jpg"); - assertThat(item.getContentType()).isEqualTo("image/jpeg"); - assertThat(item.getSize()).isEqualTo(logo.getFile().length()); - - item = items.get(4); - assertThat(item.isFormField()).isFalse(); - assertThat(item.getFieldName()).isEqualTo("utf8"); - assertThat(item.getName()).isEqualTo("Hall\u00F6le.jpg"); - assertThat(item.getContentType()).isEqualTo("image/jpeg"); - assertThat(item.getSize()).isEqualTo(logo.getFile().length()); - - item = items.get(5); - assertThat(item.getFieldName()).isEqualTo("json"); - assertThat(item.getContentType()).isEqualTo("application/json"); - - assertThat(outputMessage.wasRepeatable()).isTrue(); - } - - @Test - void writeMultipartWithSourceHttpMessageConverter() throws Exception { - - converter.setPartConverters(List.of( - new StringHttpMessageConverter(), - new ResourceHttpMessageConverter(), - new SourceHttpMessageConverter<>())); - - MultiValueMap parts = new LinkedMultiValueMap<>(); - parts.add("name 1", "value 1"); - parts.add("name 2", "value 2+1"); - parts.add("name 2", "value 2+2"); - parts.add("name 3", null); - - Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); - parts.add("logo", logo); - - // SPR-12108 - Resource utf8 = new ClassPathResource("/org/springframework/http/converter/logo.jpg") { - @Override - public String getFilename() { - return "Hall\u00F6le.jpg"; - } - }; - parts.add("utf8", utf8); - - Source xml = new StreamSource(new StringReader("")); - HttpHeaders entityHeaders = new HttpHeaders(); - entityHeaders.setContentType(TEXT_XML); - HttpEntity entity = new HttpEntity<>(xml, entityHeaders); - parts.add("xml", entity); - - Map parameters = new LinkedHashMap<>(2); - parameters.put("charset", UTF_8.name()); - parameters.put("foo", "bar"); - - StreamingMockHttpOutputMessage outputMessage = new StreamingMockHttpOutputMessage(); - this.converter.write(parts, new MediaType("multipart", "form-data", parameters), outputMessage); - - final MediaType contentType = outputMessage.getHeaders().getContentType(); - assertThat(contentType.getParameters()).containsKeys("charset", "boundary", "foo"); // gh-21568, gh-25839 - - // see if Commons FileUpload can read what we wrote - FileUpload fileUpload = new FileUpload(); - fileUpload.setFileItemFactory(new DiskFileItemFactory()); - RequestContext requestContext = new MockHttpOutputMessageRequestContext(outputMessage); - List items = fileUpload.parseRequest(requestContext); - assertThat(items).hasSize(6); - FileItem item = items.get(0); - assertThat(item.isFormField()).isTrue(); - assertThat(item.getFieldName()).isEqualTo("name 1"); - assertThat(item.getString()).isEqualTo("value 1"); - - item = items.get(1); - assertThat(item.isFormField()).isTrue(); - assertThat(item.getFieldName()).isEqualTo("name 2"); - assertThat(item.getString()).isEqualTo("value 2+1"); - - item = items.get(2); - assertThat(item.isFormField()).isTrue(); - assertThat(item.getFieldName()).isEqualTo("name 2"); - assertThat(item.getString()).isEqualTo("value 2+2"); - - item = items.get(3); - assertThat(item.isFormField()).isFalse(); - assertThat(item.getFieldName()).isEqualTo("logo"); - assertThat(item.getName()).isEqualTo("logo.jpg"); - assertThat(item.getContentType()).isEqualTo("image/jpeg"); - assertThat(item.getSize()).isEqualTo(logo.getFile().length()); - - item = items.get(4); - assertThat(item.isFormField()).isFalse(); - assertThat(item.getFieldName()).isEqualTo("utf8"); - assertThat(item.getName()).isEqualTo("Hall\u00F6le.jpg"); - assertThat(item.getContentType()).isEqualTo("image/jpeg"); - assertThat(item.getSize()).isEqualTo(logo.getFile().length()); - - item = items.get(5); - assertThat(item.getFieldName()).isEqualTo("xml"); - assertThat(item.getContentType()).isEqualTo("text/xml"); - - assertThat(outputMessage.wasRepeatable()).isFalse(); - } - - @Test // SPR-13309 - void writeMultipartOrder() throws Exception { - MyBean myBean = new MyBean(); - myBean.setString("foo"); - - MultiValueMap parts = new LinkedMultiValueMap<>(); - parts.add("part1", myBean); - - HttpHeaders entityHeaders = new HttpHeaders(); - entityHeaders.setContentType(TEXT_XML); - HttpEntity entity = new HttpEntity<>(myBean, entityHeaders); - parts.add("part2", entity); - + void writeFormFromMap() throws IOException { + Map body = new HashMap<>(); + body.put("name 1", "value 1"); + body.put("name 2", "value 2"); MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); - this.converter.setMultipartCharset(UTF_8); - this.converter.write(parts, new MediaType("multipart", "form-data", UTF_8), outputMessage); - - final MediaType contentType = outputMessage.getHeaders().getContentType(); - assertThat(contentType.getParameter("boundary")).as("No boundary found").isNotNull(); - - // see if Commons FileUpload can read what we wrote - FileUpload fileUpload = new FileUpload(); - fileUpload.setFileItemFactory(new DiskFileItemFactory()); - RequestContext requestContext = new MockHttpOutputMessageRequestContext(outputMessage); - List items = fileUpload.parseRequest(requestContext); - assertThat(items).hasSize(2); - - FileItem item = items.get(0); - assertThat(item.isFormField()).isTrue(); - assertThat(item.getFieldName()).isEqualTo("part1"); - assertThat(item.getString()).isEqualTo("{\"string\":\"foo\"}"); - - item = items.get(1); - assertThat(item.isFormField()).isTrue(); - assertThat(item.getFieldName()).isEqualTo("part2"); - - // With developer builds we get: foo - // But on CI server we get: foo - // So... we make a compromise: - assertThat(item.getString()) - .startsWith("foo"); - } - - @Test - void writeMultipartCharset() throws Exception { - MultiValueMap parts = new LinkedMultiValueMap<>(); - Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); - parts.add("logo", logo); - - MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); - this.converter.write(parts, MULTIPART_FORM_DATA, outputMessage); - - MediaType contentType = outputMessage.getHeaders().getContentType(); - Map parameters = contentType.getParameters(); - assertThat(parameters).containsOnlyKeys("boundary"); - - this.converter.setCharset(StandardCharsets.ISO_8859_1); - - outputMessage = new MockHttpOutputMessage(); - this.converter.write(parts, MULTIPART_FORM_DATA, outputMessage); + this.converter.write(body, APPLICATION_FORM_URLENCODED, outputMessage); - parameters = outputMessage.getHeaders().getContentType().getParameters(); - assertThat(parameters).containsOnlyKeys("boundary", "charset"); - assertThat(parameters).containsEntry("charset", "ISO-8859-1"); + assertThat(outputMessage.getBodyAsString(UTF_8)) + .as("Invalid result").isEqualTo("name+2=value+2&name+1=value+1"); + assertThat(outputMessage.getHeaders().getContentType()) + .as("Invalid content-type").isEqualTo(APPLICATION_FORM_URLENCODED); + assertThat(outputMessage.getHeaders().getContentLength()) + .as("Invalid content-length").isEqualTo(outputMessage.getBodyAsBytes().length); } private void assertCanRead(MediaType mediaType) { - assertCanRead(MultiValueMap.class, mediaType); + assertCanRead(MULTI_VALUE_MAP, mediaType); } - private void assertCanRead(Class clazz, MediaType mediaType) { - assertThat(this.converter.canRead(clazz, mediaType)).as(clazz.getSimpleName() + " : " + mediaType).isTrue(); + private void assertCanRead(ResolvableType type, MediaType mediaType) { + assertThat(this.converter.canRead(type, mediaType)).as(type.toClass().getSimpleName() + " : " + mediaType).isTrue(); } - private void asssertCannotReadMultipart() { + private void assertCannotReadMultipart() { assertCannotRead(new MediaType("multipart", "*")); assertCannotRead(MULTIPART_FORM_DATA); assertCannotRead(MULTIPART_MIXED); @@ -420,21 +225,25 @@ private void asssertCannotReadMultipart() { } private void assertCannotRead(MediaType mediaType) { - assertCannotRead(MultiValueMap.class, mediaType); + assertCannotRead(MULTI_VALUE_MAP, mediaType); + } + + private void assertCannotRead(ResolvableType type, MediaType mediaType) { + assertThat(this.converter.canRead(type, mediaType)).as(type + " : " + mediaType).isFalse(); } - private void assertCannotRead(Class clazz, MediaType mediaType) { - assertThat(this.converter.canRead(clazz, mediaType)).as(clazz.getSimpleName() + " : " + mediaType).isFalse(); + private void assertCanWrite(ResolvableType type, MediaType mediaType) { + assertThat(this.converter.canWrite(type, LinkedMultiValueMap.class, mediaType)) + .as(type + " : " + mediaType).isTrue(); } private void assertCanWrite(MediaType mediaType) { - Class clazz = MultiValueMap.class; - assertThat(this.converter.canWrite(clazz, mediaType)).as(clazz.getSimpleName() + " : " + mediaType).isTrue(); + assertCanWrite(MULTI_VALUE_MAP, mediaType); } private void assertCannotWrite(MediaType mediaType) { - Class clazz = MultiValueMap.class; - assertThat(this.converter.canWrite(clazz, mediaType)).as(clazz.getSimpleName() + " : " + mediaType).isFalse(); + assertThat(this.converter.canWrite(MULTI_VALUE_MAP, MultiValueMap.class, mediaType)) + .as(MultiValueMap.class.getSimpleName() + " : " + mediaType).isFalse(); } private void assertInvalidFormIsRejectedWithSpecificException(String body) { @@ -442,80 +251,10 @@ private void assertInvalidFormIsRejectedWithSpecificException(String body) { inputMessage.getHeaders().setContentType( new MediaType("application", "x-www-form-urlencoded", StandardCharsets.ISO_8859_1)); - assertThatThrownBy(() -> this.converter.read(null, inputMessage)) + assertThatThrownBy(() -> this.converter.read(MULTI_VALUE_MAP, inputMessage, null)) .isInstanceOf(HttpMessageNotReadableException.class) .hasCauseInstanceOf(IllegalArgumentException.class) .hasMessage("Could not decode HTTP form payload"); } - - private static class StreamingMockHttpOutputMessage extends MockHttpOutputMessage implements StreamingHttpOutputMessage { - - private boolean repeatable; - - public boolean wasRepeatable() { - return this.repeatable; - } - - @Override - public void setBody(Body body) { - try { - this.repeatable = body.repeatable(); - body.writeTo(getBody()); - } - catch (IOException ex) { - throw new RuntimeException(ex); - } - } - } - - - private static class MockHttpOutputMessageRequestContext implements UploadContext { - - private final MockHttpOutputMessage outputMessage; - - private final byte[] body; - - private MockHttpOutputMessageRequestContext(MockHttpOutputMessage outputMessage) { - this.outputMessage = outputMessage; - this.body = this.outputMessage.getBodyAsBytes(); - } - - @Override - public String getCharacterEncoding() { - MediaType type = this.outputMessage.getHeaders().getContentType(); - return (type != null && type.getCharset() != null ? type.getCharset().name() : null); - } - - @Override - public String getContentType() { - MediaType type = this.outputMessage.getHeaders().getContentType(); - return (type != null ? type.toString() : null); - } - - @Override - public InputStream getInputStream() { - return new ByteArrayInputStream(body); - } - - @Override - public long contentLength() { - return body.length; - } - } - - - public static class MyBean { - - private String string; - - public String getString() { - return this.string; - } - - public void setString(String string) { - this.string = string; - } - } - } diff --git a/spring-web/src/test/java/org/springframework/http/converter/multipart/MultipartHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/multipart/MultipartHttpMessageConverterTests.java new file mode 100644 index 000000000000..8cbc5988ad77 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/multipart/MultipartHttpMessageConverterTests.java @@ -0,0 +1,574 @@ +/* + * Copyright 2026-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.multipart; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.StringReader; +import java.nio.charset.StandardCharsets; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +import javax.xml.transform.Source; +import javax.xml.transform.stream.StreamSource; + +import org.apache.tomcat.util.http.fileupload.FileItem; +import org.apache.tomcat.util.http.fileupload.FileUpload; +import org.apache.tomcat.util.http.fileupload.RequestContext; +import org.apache.tomcat.util.http.fileupload.UploadContext; +import org.apache.tomcat.util.http.fileupload.disk.DiskFileItemFactory; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import org.springframework.core.ResolvableType; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.http.converter.AbstractHttpMessageConverter; +import org.springframework.http.converter.ByteArrayHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConversionException; +import org.springframework.http.converter.ResourceHttpMessageConverter; +import org.springframework.http.converter.StringHttpMessageConverter; +import org.springframework.http.converter.json.JacksonJsonHttpMessageConverter; +import org.springframework.http.converter.xml.SourceHttpMessageConverter; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.testfixture.http.MockHttpInputMessage; +import org.springframework.web.testfixture.http.MockHttpOutputMessage; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Collections.singletonMap; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.springframework.http.MediaType.APPLICATION_JSON; +import static org.springframework.http.MediaType.MULTIPART_FORM_DATA; +import static org.springframework.http.MediaType.MULTIPART_MIXED; +import static org.springframework.http.MediaType.MULTIPART_RELATED; +import static org.springframework.http.MediaType.TEXT_XML; + + +/** + * Tests for {@link MultipartHttpMessageConverter}. + * + * @author Brian Clozel + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Sam Brannen + * @author Sebastien Deleuze + */ +class MultipartHttpMessageConverterTests { + + private MultipartHttpMessageConverter converter = new MultipartHttpMessageConverter( + List.of(new StringHttpMessageConverter(), new ByteArrayHttpMessageConverter(), + new ResourceHttpMessageConverter(), new JacksonJsonHttpMessageConverter()) + ); + + + @Test + void canRead() { + assertCanRead(MULTIPART_FORM_DATA); + assertCanRead(MULTIPART_MIXED); + assertCanRead(MULTIPART_RELATED); + assertCanRead(ResolvableType.forClass(LinkedMultiValueMap.class), MULTIPART_FORM_DATA); + assertCanRead(ResolvableType.forClassWithGenerics(LinkedMultiValueMap.class, String.class, Part.class), MULTIPART_FORM_DATA); + + assertCannotRead(ResolvableType.forClassWithGenerics(LinkedMultiValueMap.class, String.class, Object.class), MULTIPART_FORM_DATA); + } + + @Test + void canWrite() { + assertCanWrite(MULTIPART_FORM_DATA); + assertCanWrite(MULTIPART_MIXED); + assertCanWrite(MULTIPART_RELATED); + assertCanWrite(new MediaType("multipart", "form-data", UTF_8)); + assertCanWrite(MediaType.ALL); + assertCanWrite(null); + assertCanWrite(ResolvableType.forClassWithGenerics(LinkedMultiValueMap.class, String.class, Object.class), MULTIPART_FORM_DATA); + } + + @Test + void setSupportedMediaTypes() { + this.converter.setSupportedMediaTypes(List.of(MULTIPART_FORM_DATA)); + assertCannotWrite(MULTIPART_MIXED); + + this.converter.setSupportedMediaTypes(List.of(MULTIPART_MIXED)); + assertCanWrite(MULTIPART_MIXED); + } + + @Test + void addSupportedMediaTypes() { + this.converter.setSupportedMediaTypes(List.of(MULTIPART_FORM_DATA)); + assertCannotWrite(MULTIPART_MIXED); + + this.converter.addSupportedMediaTypes(MULTIPART_RELATED); + assertCanWrite(MULTIPART_RELATED); + } + + @Test + void applyDefaultCharsetToPartConverters() { + this.converter.getPartConverters().forEach(converter -> { + if (converter instanceof AbstractHttpMessageConverter abstractConverter) { + assertThat(abstractConverter.getDefaultCharset()).isIn(null, StandardCharsets.UTF_8); + } + }); + } + + @Test + void customCharsetAppliedToPartConverters() { + this.converter.setCharset(StandardCharsets.UTF_16); + this.converter.getPartConverters().forEach(converter -> { + if (converter instanceof AbstractHttpMessageConverter abstractConverter) { + assertThat(abstractConverter.getDefaultCharset()).isIn(null, StandardCharsets.UTF_16); + } + }); + } + + + private void assertCanRead(MediaType mediaType) { + assertCanRead(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Part.class), mediaType); + } + + private void assertCanRead(ResolvableType type, MediaType mediaType) { + assertThat(this.converter.canRead(type, mediaType)).as(type + " : " + mediaType).isTrue(); + } + + private void assertCannotRead(ResolvableType type, MediaType mediaType) { + assertThat(this.converter.canRead(type, mediaType)).as(type + " : " + mediaType).isFalse(); + } + + private void assertCanWrite(ResolvableType type, MediaType mediaType) { + assertThat(this.converter.canWrite(type, MultiValueMap.class, mediaType)).as(type + " : " + mediaType).isTrue(); + } + + private void assertCanWrite(MediaType mediaType) { + assertCanWrite(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Object.class), mediaType); + } + + private void assertCannotWrite(MediaType mediaType) { + Class clazz = MultiValueMap.class; + assertThat(this.converter.canWrite(clazz, mediaType)).as(clazz.getSimpleName() + " : " + mediaType).isFalse(); + } + + + @Nested + class ReadingTests { + + @Test + void readMultipartFiles() throws Exception { + MockHttpInputMessage response = createMultipartResponse("files.multipart", "----WebKitFormBoundaryG8fJ50opQOML0oGD"); + MultiValueMap result = converter.read(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Part.class), response, null); + + assertThat(result).containsOnlyKeys("file2"); + assertThat(result.get("file2")).anyMatch(isFilePart("a.txt")) + .anyMatch(isFilePart("b.txt")); + } + + @Test + void readMultipartBrowser() throws Exception { + MockHttpInputMessage response = createMultipartResponse("firefox.multipart", "---------------------------18399284482060392383840973206"); + MultiValueMap result = converter.read(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Part.class), response, null); + + assertThat(result).containsOnlyKeys("file1", "file2", "text1", "text2"); + assertThat(result.get("file1")).anyMatch(isFilePart("a.txt")); + assertThat(result.get("file2")).anyMatch(isFilePart("a.txt")) + .anyMatch(isFilePart("b.txt")); + assertThat(result.get("text1")).anyMatch(isFormData("text1", "a")); + assertThat(result.get("text2")).anyMatch(isFormData("text2", "b")); + } + + @Test + void readMultipartInvalid() throws Exception { + MockHttpInputMessage response = createMultipartResponse("garbage-1.multipart", "boundary"); + assertThatThrownBy(() -> converter.read(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Part.class), response, null)) + .isInstanceOf(HttpMessageConversionException.class).hasMessage("Cannot decode multipart body"); + } + + @Test + void readMultipartMaxPartsExceeded() throws Exception { + MockHttpInputMessage response = createMultipartResponse("files.multipart", "----WebKitFormBoundaryG8fJ50opQOML0oGD"); + converter.setMaxParts(1); + assertThatThrownBy(() -> converter.read(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Part.class), response, null)) + .isInstanceOf(HttpMessageConversionException.class).hasMessage("Maximum number of parts exceeded: 1"); + } + + @Test + void readMultipartToFiles() throws Exception { + MockHttpInputMessage response = createMultipartResponse("files.multipart", "----WebKitFormBoundaryG8fJ50opQOML0oGD"); + converter.setMaxInMemorySize(1); + MultiValueMap result = converter.read(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Part.class), response, null); + assertThat(result).containsOnlyKeys("file2"); + } + + @Test + void readMultipartMaxInMemoryExceeded() throws Exception { + MockHttpInputMessage response = createMultipartResponse("firefox.multipart", "---------------------------18399284482060392383840973206"); + converter.setMaxInMemorySize(1); + assertThatThrownBy(() -> converter.read(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Part.class), response, null)) + .isInstanceOf(HttpMessageConversionException.class).hasMessage("Form field value exceeded the memory usage limit of 1 bytes"); + } + + @Test + void readMultipartMaxDiskUsageExceeded() throws Exception { + MockHttpInputMessage response = createMultipartResponse("firefox.multipart", "---------------------------18399284482060392383840973206"); + converter.setMaxInMemorySize(30); + converter.setMaxDiskUsagePerPart(35); + assertThatThrownBy(() -> converter.read(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Part.class), response, null)) + .isInstanceOf(HttpMessageConversionException.class).hasMessage("Part exceeded the disk usage limit of 35 bytes"); + } + + @Test + void readMultipartUnnamedPart() throws Exception { + MockHttpInputMessage response = createMultipartResponse("simple.multipart", "simple-boundary"); + assertThatThrownBy(() -> converter.read(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Part.class), response, null)) + .isInstanceOf(HttpMessageConversionException.class).hasMessage("Part #1 is unnamed"); + } + + + private MockHttpInputMessage createMultipartResponse(String fileName, String boundary) throws Exception { + InputStream stream = createStream(fileName); + MockHttpInputMessage response = new MockHttpInputMessage(stream); + response.getHeaders().setContentType( + new MediaType("multipart", "form-data", singletonMap("boundary", boundary))); + return response; + } + + private InputStream createStream(String fileName) throws IOException { + Resource resource = new ClassPathResource("/org/springframework/http/multipart/" + fileName); + return resource.getInputStream(); + } + + private Predicate isFilePart(String fileName) { + return part -> part instanceof FilePart filePart && + filePart.filename().equals(fileName); + } + + private Predicate isFormData(String name, String value) { + return part -> part instanceof FormFieldPart formFieldPart && + formFieldPart.name().equals(name) && + formFieldPart.value().equals(value); + } + + } + + @Nested + class WritingTests { + + @Test + void writeMultipart() throws Exception { + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("name 1", "value 1"); + parts.add("name 2", "value 2+1"); + parts.add("name 2", "value 2+2"); + parts.add("name 3", null); + + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + parts.add("logo", logo); + + // SPR-12108 + Resource utf8 = new ClassPathResource("/org/springframework/http/converter/logo.jpg") { + @Override + public String getFilename() { + return "Hall\u00F6le.jpg"; + } + }; + parts.add("utf8", utf8); + + MyBean myBean = new MyBean(); + myBean.setString("foo"); + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(APPLICATION_JSON); + HttpEntity entity = new HttpEntity<>(myBean, entityHeaders); + parts.add("json", entity); + + Map parameters = new LinkedHashMap<>(2); + parameters.put("charset", UTF_8.name()); + parameters.put("foo", "bar"); + + StreamingMockHttpOutputMessage outputMessage = new StreamingMockHttpOutputMessage(); + converter.write(parts, new MediaType("multipart", "form-data", parameters), outputMessage); + + final MediaType contentType = outputMessage.getHeaders().getContentType(); + assertThat(contentType.getParameters()).containsKeys("charset", "boundary", "foo"); // gh-21568, gh-25839 + + // see if Commons FileUpload can read what we wrote + FileUpload fileUpload = new FileUpload(); + fileUpload.setFileItemFactory(new DiskFileItemFactory()); + RequestContext requestContext = new MockHttpOutputMessageRequestContext(outputMessage); + List items = fileUpload.parseRequest(requestContext); + assertThat(items).hasSize(6); + FileItem item = items.get(0); + assertThat(item.isFormField()).isTrue(); + assertThat(item.getFieldName()).isEqualTo("name 1"); + assertThat(item.getString()).isEqualTo("value 1"); + + item = items.get(1); + assertThat(item.isFormField()).isTrue(); + assertThat(item.getFieldName()).isEqualTo("name 2"); + assertThat(item.getString()).isEqualTo("value 2+1"); + + item = items.get(2); + assertThat(item.isFormField()).isTrue(); + assertThat(item.getFieldName()).isEqualTo("name 2"); + assertThat(item.getString()).isEqualTo("value 2+2"); + + item = items.get(3); + assertThat(item.isFormField()).isFalse(); + assertThat(item.getFieldName()).isEqualTo("logo"); + assertThat(item.getName()).isEqualTo("logo.jpg"); + assertThat(item.getContentType()).isEqualTo("image/jpeg"); + assertThat(item.getSize()).isEqualTo(logo.getFile().length()); + + item = items.get(4); + assertThat(item.isFormField()).isFalse(); + assertThat(item.getFieldName()).isEqualTo("utf8"); + assertThat(item.getName()).isEqualTo("Hall\u00F6le.jpg"); + assertThat(item.getContentType()).isEqualTo("image/jpeg"); + assertThat(item.getSize()).isEqualTo(logo.getFile().length()); + + item = items.get(5); + assertThat(item.getFieldName()).isEqualTo("json"); + assertThat(item.getContentType()).isEqualTo("application/json"); + + assertThat(outputMessage.wasRepeatable()).isTrue(); + } + + @Test + void writeMultipartWithSourceHttpMessageConverter() throws Exception { + + converter = new MultipartHttpMessageConverter(List.of( + new StringHttpMessageConverter(), + new ResourceHttpMessageConverter(), + new SourceHttpMessageConverter<>())); + + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("name 1", "value 1"); + parts.add("name 2", "value 2+1"); + parts.add("name 2", "value 2+2"); + parts.add("name 3", null); + + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + parts.add("logo", logo); + + // SPR-12108 + Resource utf8 = new ClassPathResource("/org/springframework/http/converter/logo.jpg") { + @Override + public String getFilename() { + return "Hall\u00F6le.jpg"; + } + }; + parts.add("utf8", utf8); + + Source xml = new StreamSource(new StringReader("")); + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(TEXT_XML); + HttpEntity entity = new HttpEntity<>(xml, entityHeaders); + parts.add("xml", entity); + + Map parameters = new LinkedHashMap<>(2); + parameters.put("charset", UTF_8.name()); + parameters.put("foo", "bar"); + + StreamingMockHttpOutputMessage outputMessage = new StreamingMockHttpOutputMessage(); + converter.write(parts, new MediaType("multipart", "form-data", parameters), outputMessage); + + final MediaType contentType = outputMessage.getHeaders().getContentType(); + assertThat(contentType.getParameters()).containsKeys("charset", "boundary", "foo"); // gh-21568, gh-25839 + + // see if Commons FileUpload can read what we wrote + FileUpload fileUpload = new FileUpload(); + fileUpload.setFileItemFactory(new DiskFileItemFactory()); + RequestContext requestContext = new MockHttpOutputMessageRequestContext(outputMessage); + List items = fileUpload.parseRequest(requestContext); + assertThat(items).hasSize(6); + FileItem item = items.get(0); + assertThat(item.isFormField()).isTrue(); + assertThat(item.getFieldName()).isEqualTo("name 1"); + assertThat(item.getString()).isEqualTo("value 1"); + + item = items.get(1); + assertThat(item.isFormField()).isTrue(); + assertThat(item.getFieldName()).isEqualTo("name 2"); + assertThat(item.getString()).isEqualTo("value 2+1"); + + item = items.get(2); + assertThat(item.isFormField()).isTrue(); + assertThat(item.getFieldName()).isEqualTo("name 2"); + assertThat(item.getString()).isEqualTo("value 2+2"); + + item = items.get(3); + assertThat(item.isFormField()).isFalse(); + assertThat(item.getFieldName()).isEqualTo("logo"); + assertThat(item.getName()).isEqualTo("logo.jpg"); + assertThat(item.getContentType()).isEqualTo("image/jpeg"); + assertThat(item.getSize()).isEqualTo(logo.getFile().length()); + + item = items.get(4); + assertThat(item.isFormField()).isFalse(); + assertThat(item.getFieldName()).isEqualTo("utf8"); + assertThat(item.getName()).isEqualTo("Hall\u00F6le.jpg"); + assertThat(item.getContentType()).isEqualTo("image/jpeg"); + assertThat(item.getSize()).isEqualTo(logo.getFile().length()); + + item = items.get(5); + assertThat(item.getFieldName()).isEqualTo("xml"); + assertThat(item.getContentType()).isEqualTo("text/xml"); + + assertThat(outputMessage.wasRepeatable()).isFalse(); + } + + @Test // SPR-13309 + void writeMultipartOrder() throws Exception { + MyBean myBean = new MyBean(); + myBean.setString("foo"); + + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("part1", myBean); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(APPLICATION_JSON); + HttpEntity entity = new HttpEntity<>(myBean, entityHeaders); + parts.add("part2", entity); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.setMultipartCharset(UTF_8); + converter.write(parts, new MediaType("multipart", "form-data", UTF_8), outputMessage); + + final MediaType contentType = outputMessage.getHeaders().getContentType(); + assertThat(contentType.getParameter("boundary")).as("No boundary found").isNotNull(); + + // see if Commons FileUpload can read what we wrote + FileUpload fileUpload = new FileUpload(); + fileUpload.setFileItemFactory(new DiskFileItemFactory()); + RequestContext requestContext = new MockHttpOutputMessageRequestContext(outputMessage); + List items = fileUpload.parseRequest(requestContext); + assertThat(items).hasSize(2); + + FileItem item = items.get(0); + assertThat(item.isFormField()).isTrue(); + assertThat(item.getFieldName()).isEqualTo("part1"); + assertThat(item.getString()).isEqualTo("{\"string\":\"foo\"}"); + + item = items.get(1); + assertThat(item.isFormField()).isTrue(); + assertThat(item.getFieldName()).isEqualTo("part2"); + + assertThat(item.getString()) + .contains("{\"string\":\"foo\"}"); + } + + @Test + void writeMultipartCharset() throws Exception { + MultiValueMap parts = new LinkedMultiValueMap<>(); + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + parts.add("logo", logo); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(parts, MULTIPART_FORM_DATA, outputMessage); + + MediaType contentType = outputMessage.getHeaders().getContentType(); + Map parameters = contentType.getParameters(); + assertThat(parameters).containsOnlyKeys("boundary"); + + converter.setCharset(StandardCharsets.ISO_8859_1); + + outputMessage = new MockHttpOutputMessage(); + converter.write(parts, MULTIPART_FORM_DATA, outputMessage); + + parameters = outputMessage.getHeaders().getContentType().getParameters(); + assertThat(parameters).containsOnlyKeys("boundary", "charset"); + assertThat(parameters).containsEntry("charset", "ISO-8859-1"); + } + + } + + + private static class StreamingMockHttpOutputMessage extends MockHttpOutputMessage implements StreamingHttpOutputMessage { + + private boolean repeatable; + + public boolean wasRepeatable() { + return this.repeatable; + } + + @Override + public void setBody(Body body) { + try { + this.repeatable = body.repeatable(); + body.writeTo(getBody()); + } + catch (IOException ex) { + throw new RuntimeException(ex); + } + } + } + + + private static class MockHttpOutputMessageRequestContext implements UploadContext { + + private final MockHttpOutputMessage outputMessage; + + private final byte[] body; + + private MockHttpOutputMessageRequestContext(MockHttpOutputMessage outputMessage) { + this.outputMessage = outputMessage; + this.body = this.outputMessage.getBodyAsBytes(); + } + + @Override + public String getCharacterEncoding() { + MediaType type = this.outputMessage.getHeaders().getContentType(); + return (type != null && type.getCharset() != null ? type.getCharset().name() : null); + } + + @Override + public String getContentType() { + MediaType type = this.outputMessage.getHeaders().getContentType(); + return (type != null ? type.toString() : null); + } + + @Override + public InputStream getInputStream() { + return new ByteArrayInputStream(body); + } + + @Override + public long contentLength() { + return body.length; + } + } + + + public static class MyBean { + + private String string; + + public String getString() { + return this.string; + } + + public void setString(String string) { + this.string = string; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/multipart/MultipartParserTests.java b/spring-web/src/test/java/org/springframework/http/converter/multipart/MultipartParserTests.java new file mode 100644 index 000000000000..2110e9ba39b5 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/multipart/MultipartParserTests.java @@ -0,0 +1,282 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.multipart; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; +import java.util.function.Consumer; + +import org.jspecify.annotations.NonNull; +import org.junit.jupiter.api.Test; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.ContentDisposition; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.converter.HttpMessageConversionException; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link MultipartParser}. + * + * @author Brian Clozel + * @author Arjen Poutsma + */ +class MultipartParserTests { + + private static final MediaType TEXT_PLAIN_ASCII = new MediaType("text", "plain", StandardCharsets.US_ASCII); + + @Test + void simple() throws Exception { + TestListener listener = new TestListener(); + parse("simple.multipart", "simple-boundary", listener); + + listener.assertHeader(headers -> assertThat(headers.isEmpty()).isTrue()) + .assertBodyChunk("This is implicitly typed plain ASCII text.\r\nIt does NOT end with a linebreak.") + .assertHeader(headers -> assertThat(headers.getContentType()).isEqualTo(TEXT_PLAIN_ASCII)) + .assertBodyChunk("This is explicitly typed plain ASCII text.\r\nIt DOES end with a linebreak.\r\n") + .assertComplete(); + } + + @Test + void noHeaders() throws Exception { + TestListener listener = new TestListener(); + parse("no-header.multipart", "boundary", listener); + + listener.assertHeader(headers -> assertThat(headers.isEmpty()).isTrue()) + .assertBodyChunk("a") + .assertComplete(); + } + + @Test + void noEndBoundary() throws Exception { + TestListener listener = new TestListener(); + parse("no-end-boundary.multipart", "boundary", listener); + + assertThat(listener.error).isInstanceOf(HttpMessageConversionException.class); + } + + @Test + void garbage() throws Exception { + TestListener listener = new TestListener(); + parse("garbage-1.multipart", "boundary", listener); + + assertThat(listener.error).isInstanceOf(HttpMessageConversionException.class); + } + + @Test + void noEndHeader() throws Exception { + TestListener listener = new TestListener(); + parse("no-end-header.multipart", "boundary", listener); + + assertThat(listener.error).isInstanceOf(HttpMessageConversionException.class); + } + + @Test + void noEndBody() throws Exception { + TestListener listener = new TestListener(); + parse("no-end-body.multipart", "boundary", listener); + + assertThat(listener.error).isInstanceOf(HttpMessageConversionException.class); + } + + @Test + void noBody() throws Exception { + TestListener listener = new TestListener(); + parse("no-body.multipart", "boundary", listener); + + listener.assertHeader(headers -> assertThat(headers.hasHeaderValues("Part", List.of("1"))).isTrue()) + .assertHeader(headers -> assertThat(headers.hasHeaderValues("Part", List.of("2"))).isTrue()) + .assertBodyChunk("a") + .assertComplete(); + } + + @Test + void firefox() throws Exception { + TestListener listener = new TestListener(); + parse("firefox.multipart", + "---------------------------18399284482060392383840973206", listener); + + listener.assertHeadersFormField("text1") + .assertBodyChunk("a") + .assertHeadersFormField("text2") + .assertBodyChunk("b") + .assertHeadersFile("file1", "a.txt") + .assertBodyChunk() + .assertHeadersFile("file2", "a.txt") + .assertBodyChunk() + .assertHeadersFile("file2", "b.txt") + .assertBodyChunk() + .assertComplete(); + } + + @Test + void chrome() throws Exception { + TestListener listener = new TestListener(); + parse("chrome.multipart", + "----WebKitFormBoundaryEveBLvRT65n21fwU", listener); + + listener.assertHeadersFormField("text1") + .assertBodyChunk("a") + .assertHeadersFormField("text2") + .assertBodyChunk("b") + .assertHeadersFile("file1", "a.txt") + .assertBodyChunk() + .assertHeadersFile("file2", "a.txt") + .assertBodyChunk() + .assertHeadersFile("file2", "b.txt") + .assertBodyChunk() + .assertComplete(); + } + + @Test + void safari() throws Exception { + TestListener listener = new TestListener(); + parse("safari.multipart", + "----WebKitFormBoundaryG8fJ50opQOML0oGD", listener); + + listener.assertHeadersFormField("text1") + .assertBodyChunk("a") + .assertHeadersFormField("text2") + .assertBodyChunk("b") + .assertHeadersFile("file1", "a.txt") + .assertBodyChunk() + .assertHeadersFile("file2", "a.txt") + .assertBodyChunk() + .assertHeadersFile("file2", "b.txt") + .assertBodyChunk() + .assertComplete(); + } + + @Test + void utf8Headers() throws Exception { + TestListener listener = new TestListener(); + parse("utf8.multipart", "simple-boundary", listener); + + listener.assertHeader(headers -> + assertThat(headers.hasHeaderValues("Føø", List.of("Bår"))).isTrue()) + .assertBodyChunk("This is plain ASCII text.") + .assertComplete(); + } + + private InputStream createStream(String fileName) throws IOException { + Resource resource = new ClassPathResource("/org/springframework/http/multipart/" + fileName); + return resource.getInputStream(); + } + + private void parse(String fileName, String boundary, MultipartParser.PartListener listener) throws Exception { + try (InputStream input = createStream(fileName)) { + MultipartParser multipartParser = new MultipartParser(10 * 1024, 4 * 1024); + multipartParser.parse(input, boundary.getBytes(UTF_8), StandardCharsets.UTF_8, listener); + } + } + + + static class TestListener implements MultipartParser.PartListener { + + Deque received = new ArrayDeque<>(); + + boolean complete; + + Throwable error; + + @Override + public void onHeaders(@NonNull HttpHeaders headers) { + this.received.add(headers); + } + + @Override + public void onBody(@NonNull DataBuffer buffer, boolean last) { + this.received.add(buffer); + } + + @Override + public void onComplete() { + this.complete = true; + } + + @Override + public void onError(@NonNull Throwable error) { + this.error = error; + } + + TestListener assertHeader(Consumer headersConsumer) { + Object value = received.pollFirst(); + assertThat(value).isInstanceOf(HttpHeaders.class); + headersConsumer.accept((HttpHeaders) value); + return this; + } + + TestListener assertHeadersFormField(String expectedName) { + return assertHeader(headers -> { + ContentDisposition cd = headers.getContentDisposition(); + assertThat(cd.isFormData()).isTrue(); + assertThat(cd.getName()).isEqualTo(expectedName); + }); + } + + TestListener assertHeadersFile(String expectedName, String expectedFilename) { + return assertHeader(headers -> { + ContentDisposition cd = headers.getContentDisposition(); + assertThat(cd.isFormData()).isTrue(); + assertThat(cd.getName()).isEqualTo(expectedName); + assertThat(cd.getFilename()).isEqualTo(expectedFilename); + }); + } + + TestListener assertBodyChunk(Consumer bodyConsumer) { + Object value = received.pollFirst(); + assertThat(value).isInstanceOf(DataBuffer.class); + bodyConsumer.accept((DataBuffer) value); + DataBufferUtils.release((DataBuffer) value); + return this; + } + + TestListener assertBodyChunk(String bodyContent) { + return assertBodyChunk(buffer -> { + String actual = buffer.toString(UTF_8); + assertThat(actual).isEqualTo(bodyContent); + }); + } + + TestListener assertBodyChunk() { + return assertBodyChunk(buffer -> { + }); + } + + TestListener assertLastBodyChunk() { + if (!received.isEmpty()) { + assertThat(received.peek()).isNotInstanceOf(DataBuffer.class); + } + return this; + } + + void assertComplete() { + assertThat(this.complete).isTrue(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerHttpStatusTests.java b/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerHttpStatusTests.java index 085158845e68..fe38c2256171 100644 --- a/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerHttpStatusTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerHttpStatusTests.java @@ -57,6 +57,7 @@ * Tests for {@link DefaultResponseErrorHandler} handling of specific * HTTP status codes. */ +@SuppressWarnings("removal") class DefaultResponseErrorHandlerHttpStatusTests { private final DefaultResponseErrorHandler handler = new DefaultResponseErrorHandler(); diff --git a/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerTests.java b/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerTests.java index 5fbfacd397c6..e0b313a11899 100644 --- a/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerTests.java @@ -44,6 +44,7 @@ * @author Juergen Hoeller * @author Denys Ivano */ +@SuppressWarnings("removal") class DefaultResponseErrorHandlerTests { private final DefaultResponseErrorHandler handler = new DefaultResponseErrorHandler(); diff --git a/spring-web/src/test/java/org/springframework/web/client/ExtractingResponseErrorHandlerTests.java b/spring-web/src/test/java/org/springframework/web/client/ExtractingResponseErrorHandlerTests.java index 51f6b04a63ca..50121cb9049b 100644 --- a/spring-web/src/test/java/org/springframework/web/client/ExtractingResponseErrorHandlerTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/ExtractingResponseErrorHandlerTests.java @@ -44,7 +44,7 @@ * * @author Arjen Poutsma */ -@SuppressWarnings("ALL") +@SuppressWarnings("removal") class ExtractingResponseErrorHandlerTests { private ExtractingResponseErrorHandler errorHandler; diff --git a/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java b/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java index 3ee14938577a..fadf846063cf 100644 --- a/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java @@ -52,6 +52,7 @@ * @author Brian Clozel * @author Sam Brannen */ +@SuppressWarnings("removal") class HttpMessageConverterExtractorTests { private final HttpMessageConverter converter = mock(); diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java index 6f5550ce6a87..c83371ca4f1f 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java @@ -33,7 +33,7 @@ import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter; -import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; +import org.springframework.http.converter.multipart.MultipartHttpMessageConverter; import org.springframework.web.util.DefaultUriBuilderFactory; import static org.assertj.core.api.Assertions.assertThat; @@ -49,7 +49,7 @@ */ class RestClientBuilderTests { - @SuppressWarnings("unchecked") + @SuppressWarnings({"unchecked", "removal"}) @Test void createFromRestTemplate() { JettyClientHttpRequestFactory requestFactory = new JettyClientHttpRequestFactory(); @@ -90,6 +90,7 @@ void createFromRestTemplate() { } @Test + @SuppressWarnings("removal") void defaultUriBuilderFactory() { RestTemplate restTemplate = new RestTemplate(); @@ -154,7 +155,7 @@ void configureMessageConverters() { assertThat(fieldValue("messageConverters", restClient)) .asInstanceOf(InstanceOfAssertFactories.LIST) - .hasExactlyElementsOfTypes(StringHttpMessageConverter.class, AllEncompassingFormHttpMessageConverter.class); + .hasExactlyElementsOfTypes(StringHttpMessageConverter.class, MultipartHttpMessageConverter.class); } @Test diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java index 5d8de466c91c..684b34d30c94 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java @@ -24,6 +24,7 @@ import java.lang.annotation.Target; import java.net.URI; import java.net.URISyntaxException; +import java.nio.file.Files; import java.util.List; import java.util.function.Consumer; import java.util.function.Function; @@ -40,6 +41,8 @@ import org.junit.jupiter.params.provider.MethodSource; import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatusCode; @@ -54,6 +57,9 @@ import org.springframework.http.client.JettyClientHttpRequestFactory; import org.springframework.http.client.ReactorClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.http.converter.multipart.FilePart; +import org.springframework.http.converter.multipart.FormFieldPart; +import org.springframework.http.converter.multipart.Part; import org.springframework.util.CollectionUtils; import org.springframework.util.FastByteArrayOutputStream; import org.springframework.util.FileCopyUtils; @@ -354,6 +360,35 @@ void retrieveJsonEmpty(ClientHttpRequestFactory requestFactory) throws IOExcepti assertThat(result).isNull(); } + @ParameterizedRestClientTest + void retrieveMultipart(ClientHttpRequestFactory requestFactory) throws IOException { + startServer(requestFactory); + Resource resource = new ClassPathResource("simple.multipart", getClass()); + String multipartBody = Files.readString(resource.getFile().toPath()); + + prepareResponse(builder -> builder + .setHeader("Content-Type", "multipart/form-data; boundary=---------------------------testboundary") + .body(multipartBody)); + + MultiValueMap result = this.restClient.get() + .uri("/multipart") + .accept(MediaType.MULTIPART_FORM_DATA) + .retrieve() + .body(new ParameterizedTypeReference<>() {}); + + assertThat(result).hasSize(3); + assertThat(result).containsKeys("text1", "text2", "file1"); + assertThat(result.get("text1").get(0)).isInstanceOfSatisfying(FormFieldPart.class, part -> assertThat(part.value()).isEqualTo("a")); + assertThat(result.get("text2").get(0)).isInstanceOfSatisfying(FormFieldPart.class, part -> assertThat(part.value()).isEqualTo("b")); + assertThat(result.get("file1").get(0)).isInstanceOfSatisfying(FilePart.class, part -> assertThat(part.filename()).isEqualTo("a.txt")); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getTarget()).isEqualTo("/multipart"); + assertThat(request.getHeaders().get(HttpHeaders.ACCEPT)).isEqualTo("multipart/form-data"); + }); + } + @ParameterizedRestClientTest void retrieve404(ClientHttpRequestFactory requestFactory) throws IOException { startServer(requestFactory); diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java index 42903d2d5826..6a99e795b751 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java @@ -53,8 +53,8 @@ import org.springframework.http.client.JettyClientHttpRequestFactory; import org.springframework.http.client.ReactorClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; -import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.converter.json.MappingJacksonValue; +import org.springframework.http.converter.multipart.MultipartHttpMessageConverter; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -82,6 +82,7 @@ * @author Brian Clozel * @author Sam Brannen */ +@SuppressWarnings("removal") class RestTemplateIntegrationTests extends AbstractMockWebServerTests { @Retention(RetentionPolicy.RUNTIME) @@ -350,23 +351,25 @@ private MultiValueMap createMultipartParts() { private void addSupportedMediaTypeToFormHttpMessageConverter(MediaType mediaType) { this.template.getMessageConverters().stream() - .filter(FormHttpMessageConverter.class::isInstance) - .map(FormHttpMessageConverter.class::cast) + .filter(MultipartHttpMessageConverter.class::isInstance) + .map(MultipartHttpMessageConverter.class::cast) .findFirst() - .orElseThrow(() -> new IllegalStateException("Failed to find FormHttpMessageConverter")) + .orElseThrow(() -> new IllegalStateException("Failed to find MultipartHttpMessageConverter")) .addSupportedMediaTypes(mediaType); } @ParameterizedRestTemplateTest void form(ClientHttpRequestFactory clientHttpRequestFactory) { setUpClient(clientHttpRequestFactory); + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED); MultiValueMap form = new LinkedMultiValueMap<>(); form.add("name 1", "value 1"); form.add("name 2", "value 2+1"); form.add("name 2", "value 2+2"); - template.postForLocation(baseUrl + "/form", form); + template.exchange(baseUrl + "/form", POST, new HttpEntity<>(form, headers), Void.class); } @ParameterizedRestTemplateTest diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateObservationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateObservationTests.java index 442fe5593ce2..7b9fc0817d67 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateObservationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateObservationTests.java @@ -55,6 +55,7 @@ * Tests for the client HTTP observations with {@link RestTemplate}. * @author Brian Clozel */ +@SuppressWarnings("removal") class RestTemplateObservationTests { diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java index 241b81464976..cfeb32f892fe 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java @@ -83,7 +83,7 @@ * @author Brian Clozel * @author Sam Brannen */ -@SuppressWarnings("unchecked") +@SuppressWarnings({"unchecked", "removal"}) class RestTemplateTests { private final ClientHttpRequestFactory requestFactory = mock(); diff --git a/spring-web/src/test/java/org/springframework/web/client/support/RestClientAdapterTests.java b/spring-web/src/test/java/org/springframework/web/client/support/RestClientAdapterTests.java index f04ec8662739..b097c958d480 100644 --- a/spring-web/src/test/java/org/springframework/web/client/support/RestClientAdapterTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/support/RestClientAdapterTests.java @@ -80,7 +80,7 @@ * @author Rossen Stoyanchev * @author Brian Clozel */ -@SuppressWarnings("JUnitMalformedDeclaration") +@SuppressWarnings({"JUnitMalformedDeclaration", "removal"}) class RestClientAdapterTests { private final MockWebServer anotherServer = new MockWebServer(); diff --git a/spring-web/src/test/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequestTests.java b/spring-web/src/test/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequestTests.java index 7f5eeee513b0..e84789af192c 100644 --- a/spring-web/src/test/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequestTests.java +++ b/spring-web/src/test/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequestTests.java @@ -24,7 +24,8 @@ import jakarta.servlet.http.Part; import org.junit.jupiter.api.Test; -import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.http.MediaType; +import org.springframework.http.converter.multipart.MultipartHttpMessageConverter; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.multipart.MaxUploadSizeExceededException; @@ -87,7 +88,7 @@ void multipartFileResource() throws IOException { map.add(name, multipartFile.getResource()); MockHttpOutputMessage output = new MockHttpOutputMessage(); - new FormHttpMessageConverter().write(map, null, output); + new MultipartHttpMessageConverter().write(map, null, output); assertThat(output.getBodyAsString(StandardCharsets.UTF_8)).contains(""" Content-Disposition: form-data; name="file"; filename="myFile.txt" @@ -166,6 +167,7 @@ void commonsFileUploadFileCountLimitException() { private static StandardMultipartHttpServletRequest requestWithPart(String name, String disposition, String content) { MockHttpServletRequest request = new MockHttpServletRequest(); + request.setContentType(MediaType.MULTIPART_FORM_DATA_VALUE); MockPart part = new MockPart(name, null, content.getBytes(StandardCharsets.UTF_8)); part.getHeaders().set("Content-Disposition", disposition); request.addPart(part); diff --git a/spring-web/src/test/kotlin/org/springframework/web/client/RestOperationsExtensionsTests.kt b/spring-web/src/test/kotlin/org/springframework/web/client/RestOperationsExtensionsTests.kt index 0771be69dfe7..4a98fce0dad3 100644 --- a/spring-web/src/test/kotlin/org/springframework/web/client/RestOperationsExtensionsTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/web/client/RestOperationsExtensionsTests.kt @@ -34,6 +34,7 @@ import kotlin.reflect.jvm.kotlinFunction * * @author Sebastien Deleuze */ +@Suppress("REMOVAL", "DEPRECATION") class RestOperationsExtensionsTests { val template = mockk() diff --git a/spring-web/src/test/kotlin/org/springframework/web/client/support/KotlinRestTemplateHttpServiceProxyTests.kt b/spring-web/src/test/kotlin/org/springframework/web/client/support/KotlinRestTemplateHttpServiceProxyTests.kt index 4b11f2def0a3..0da04bb11b1b 100644 --- a/spring-web/src/test/kotlin/org/springframework/web/client/support/KotlinRestTemplateHttpServiceProxyTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/web/client/support/KotlinRestTemplateHttpServiceProxyTests.kt @@ -28,7 +28,6 @@ import org.springframework.http.ResponseEntity import org.springframework.util.LinkedMultiValueMap import org.springframework.util.MultiValueMap import org.springframework.web.bind.annotation.* -import org.springframework.web.client.RestTemplate import org.springframework.web.multipart.MultipartFile import org.springframework.web.service.annotation.GetExchange import org.springframework.web.service.annotation.PostExchange @@ -46,6 +45,7 @@ import java.util.* * * @author Olga Maciaszek-Sharma */ +@Suppress("REMOVAL", "DEPRECATION") class KotlinRestTemplateHttpServiceProxyTests { private lateinit var server: MockWebServer @@ -65,7 +65,7 @@ class KotlinRestTemplateHttpServiceProxyTests { } private fun initTestService(): TestService { - val restTemplate = RestTemplate() + val restTemplate = org.springframework.web.client.RestTemplate() restTemplate.uriTemplateHandler = DefaultUriBuilderFactory(server.url("/").toString()) return HttpServiceProxyFactory.builder() .exchangeAdapter(RestTemplateAdapter.create(restTemplate)) diff --git a/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt b/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt index b7da2b669f2e..a3dbd8f91a43 100644 --- a/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt @@ -155,6 +155,13 @@ class InvocableHandlerMethodKotlinTests { Assertions.assertThat(value).isEqualTo(1L) } + @Test + fun valueClassWithNullableAndUnderlyingValue() { + composite.addResolver(StubArgumentResolver(LongValueClass::class.java, 1L)) + val value = getInvocable(ValueClassHandler::valueClassWithNullable.javaMethod!!).invokeForRequest(request, null) + Assertions.assertThat(value).isEqualTo(1L) + } + @Test fun valueClassWithNullable() { composite.addResolver(StubArgumentResolver(LongValueClass::class.java, null)) @@ -215,6 +222,14 @@ class InvocableHandlerMethodKotlinTests { StepVerifier.create(value as Mono).verifyComplete() } + @Test + fun suspendingValueClassWithNullableAndUnderlyingValue() { + composite.addResolver(ContinuationHandlerMethodArgumentResolver()) + composite.addResolver(StubArgumentResolver(LongValueClass::class.java, 1L)) + val value = getInvocable(SuspendingValueClassHandler::valueClassWithNullable.javaMethod!!).invokeForRequest(request, null) + StepVerifier.create(value as Mono).expectNext(1L).verifyComplete() + } + @Test fun suspendingValueClassWithPrivateConstructor() { composite.addResolver(ContinuationHandlerMethodArgumentResolver()) diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/chrome.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/chrome.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/chrome.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/chrome.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/empty-part.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/empty-part.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/empty-part.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/empty-part.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/files.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/files.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/files.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/files.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/firefox.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/firefox.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/firefox.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/firefox.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/garbage-1.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/garbage-1.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/garbage-1.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/garbage-1.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/invalid.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/invalid.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/invalid.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/invalid.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/no-body.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/no-body.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/no-body.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/no-body.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/no-end-body.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/no-end-body.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/no-end-body.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/no-end-body.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/no-end-boundary.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/no-end-boundary.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/no-end-boundary.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/no-end-boundary.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/no-end-header.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/no-end-header.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/no-end-header.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/no-end-header.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/no-header.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/no-header.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/no-header.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/no-header.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/part-no-end-boundary.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/part-no-end-boundary.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/part-no-end-boundary.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/part-no-end-boundary.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/part-no-header.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/part-no-header.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/part-no-header.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/part-no-header.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/safari.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/safari.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/safari.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/safari.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/simple.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/simple.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/simple.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/simple.multipart diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/utf8.multipart b/spring-web/src/test/resources/org/springframework/http/multipart/utf8.multipart similarity index 100% rename from spring-web/src/test/resources/org/springframework/http/codec/multipart/utf8.multipart rename to spring-web/src/test/resources/org/springframework/http/multipart/utf8.multipart diff --git a/spring-web/src/test/resources/org/springframework/web/client/simple.multipart b/spring-web/src/test/resources/org/springframework/web/client/simple.multipart new file mode 100644 index 000000000000..de58b2ea79ca --- /dev/null +++ b/spring-web/src/test/resources/org/springframework/web/client/simple.multipart @@ -0,0 +1,15 @@ +-----------------------------testboundary +Content-Disposition: form-data; name="text1" + +a +-----------------------------testboundary +Content-Disposition: form-data; name="text2" + +b +-----------------------------testboundary +Content-Disposition: form-data; name="file1"; filename="a.txt" +Content-Type: text/plain + +Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer iaculis metus id vestibulum nullam. + +-----------------------------testboundary-- diff --git a/spring-webflux/spring-webflux.gradle b/spring-webflux/spring-webflux.gradle index dad853c54300..16c9be3a271f 100644 --- a/spring-webflux/spring-webflux.gradle +++ b/spring-webflux/spring-webflux.gradle @@ -32,6 +32,7 @@ dependencies { optional("org.jetbrains.kotlin:kotlin-reflect") optional("org.jetbrains.kotlin:kotlin-stdlib") optional("org.jetbrains.kotlinx:kotlinx-coroutines-reactor") + optional("io.micrometer:context-propagation") optional("org.webjars:webjars-locator-lite") optional("tools.jackson.core:jackson-databind") optional("tools.jackson.dataformat:jackson-dataformat-smile") diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/ApiVersionConfigurer.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/ApiVersionConfigurer.java index 2cc00b359043..2be74419aca3 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/config/ApiVersionConfigurer.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/ApiVersionConfigurer.java @@ -205,8 +205,9 @@ public ApiVersionConfigurer detectSupportedVersions(boolean detect) { * {@link #addSupportedVersions} and {@link #detectSupportedVersions}. * @param predicate the predicate to use */ - public void setSupportedVersionPredicate(@Nullable Predicate> predicate) { + public ApiVersionConfigurer setSupportedVersionPredicate(@Nullable Predicate> predicate) { this.supportedVersionPredicate = predicate; + return this; } /** diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java index c06ea62a2aa3..11e312bbe8b9 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java @@ -356,9 +356,10 @@ private static class KotlinDelegate { Object arg = args[index]; if (!(parameter.isOptional() && arg == null)) { KType type = parameter.getType(); - if (!type.isMarkedNullable() && + if (!(type.isMarkedNullable() && arg == null) && type.getClassifier() instanceof KClass kClass && - KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass))) { + KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass)) && + !JvmClassMappingKt.getJavaClass(kClass).isInstance(arg)) { arg = box(kClass, arg); } argMap.put(parameter, arg); @@ -378,9 +379,10 @@ private static class KotlinDelegate { private static Object box(KClass kClass, @Nullable Object arg) { KFunction constructor = Objects.requireNonNull(KClasses.getPrimaryConstructor(kClass)); KType type = constructor.getParameters().get(0).getType(); - if (!type.isMarkedNullable() && + if (!(type.isMarkedNullable() && arg == null) && type.getClassifier() instanceof KClass parameterClass && - KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(parameterClass))) { + KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(parameterClass)) && + !JvmClassMappingKt.getJavaClass(parameterClass).isInstance(arg)) { arg = box(parameterClass, arg); } if (!KCallablesJvm.isAccessible(constructor)) { diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt index d993ea1eacc0..c426db2aabcb 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt @@ -16,6 +16,8 @@ package org.springframework.web.reactive.function.client +import io.micrometer.context.ContextRegistry +import io.micrometer.context.ContextSnapshotFactory import kotlinx.coroutines.Job import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.flow.Flow @@ -25,6 +27,7 @@ import kotlinx.coroutines.withContext import org.reactivestreams.Publisher import org.springframework.core.ParameterizedTypeReference import org.springframework.http.ResponseEntity +import org.springframework.util.ClassUtils import org.springframework.web.reactive.function.client.CoExchangeFilterFunction.Companion.COROUTINE_CONTEXT_ATTRIBUTE import org.springframework.web.reactive.function.client.WebClient.RequestBodySpec import org.springframework.web.reactive.function.client.WebClient.RequestHeadersSpec @@ -237,8 +240,25 @@ suspend inline fun WebClient.ResponseSpec.awaitEntity(): Respo } } +private val contextPropagationPresent = ClassUtils.isPresent("io.micrometer.context.ContextSnapshotFactory", + WebClient::class.java.classLoader) + @PublishedApi internal fun CoroutineContext.toReactorContext(): ReactorContext { - val context = Context.of(COROUTINE_CONTEXT_ATTRIBUTE, this).readOnly() - return (this[ReactorContext.Key]?.context?.putAll(context) ?: context).asCoroutineContext() + var context = Context.of(COROUTINE_CONTEXT_ATTRIBUTE, this) + if (contextPropagationPresent) { + context = ContextPropagationDelegate.captureThreadLocalsInto(context) + } + val readOnlyContext = context.readOnly() + return (this[ReactorContext.Key]?.context?.putAll(readOnlyContext) ?: readOnlyContext).asCoroutineContext() +} + +private object ContextPropagationDelegate { + + private val contextSnapshotFactory = ContextSnapshotFactory.builder() + .contextRegistry(ContextRegistry.getInstance()).build() + + fun captureThreadLocalsInto(context: Context): Context { + return contextSnapshotFactory.captureAll().updateContext(context) + } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/HttpMessageWriterViewTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/HttpMessageWriterViewTests.java index c935e60c7761..71a04a610ebe 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/HttpMessageWriterViewTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/HttpMessageWriterViewTests.java @@ -55,7 +55,8 @@ void supportedMediaTypes() { assertThat(this.view.getSupportedMediaTypes()).containsExactly( MediaType.APPLICATION_JSON, MediaType.parseMediaType("application/*+json"), - MediaType.APPLICATION_NDJSON); + MediaType.APPLICATION_NDJSON, + MediaType.APPLICATION_JSONL); } @Test diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt index 514deffc602b..78943f1013c2 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt @@ -16,6 +16,9 @@ package org.springframework.web.reactive.function.client +import io.micrometer.observation.Observation +import io.micrometer.observation.ObservationHandler +import io.micrometer.observation.ObservationRegistry import io.mockk.every import io.mockk.mockk import io.mockk.slot @@ -30,11 +33,13 @@ import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.reactivestreams.Publisher import org.springframework.core.ParameterizedTypeReference +import org.springframework.core.PropagationContextElement import org.springframework.http.HttpHeaders import org.springframework.http.HttpStatus import org.springframework.http.ResponseEntity import org.springframework.web.reactive.function.client.CoExchangeFilterFunction.Companion.COROUTINE_CONTEXT_ATTRIBUTE import reactor.core.publisher.Flux +import reactor.core.publisher.Hooks import reactor.core.publisher.Mono import java.time.Duration import java.util.concurrent.CompletableFuture @@ -433,9 +438,88 @@ class WebClientExtensionsTests { } } + @Test + fun `awaitExchange preserves parent observation with automatic context propagation`() { + Hooks.enableAutomaticContextPropagation() + try { + val observationRegistry = ObservationRegistry.create() + val contextObservationHandler = ContextObservationHandler() + observationRegistry.observationConfig().observationHandler(contextObservationHandler) + val exchangeFunction = mockk() + val mockResponse = mockk() + every { exchangeFunction.exchange(any()) } returns Mono.just(mockResponse) + every { mockResponse.statusCode() } returns HttpStatus.OK + every { mockResponse.releaseBody() } returns Mono.empty() + + val parent = Observation.start("parent", observationRegistry) + val scope = parent.openScope() + try { + runBlocking(PropagationContextElement()) { + val webClient = WebClient.builder() + .exchangeFunction(exchangeFunction) + .observationRegistry(observationRegistry) + .build() + webClient.get().uri("/path1").awaitExchange { it.statusCode() } + webClient.get().uri("/path2").awaitExchange { it.statusCode() } + } + } finally { + scope.close() + parent.stop() + } + assertThat(contextObservationHandler.parentObservation).containsExactly(true, true) + } finally { + Hooks.disableAutomaticContextPropagation() + } + } + + @Test + fun `awaitBody preserves parent observation with automatic context propagation`() { + Hooks.enableAutomaticContextPropagation() + try { + val observationRegistry = ObservationRegistry.create() + val contextObservationHandler = ContextObservationHandler() + observationRegistry.observationConfig().observationHandler(contextObservationHandler) + val exchangeFunction = mockk() + val mockResponse = mockk() + every { exchangeFunction.exchange(any()) } returns Mono.just(mockResponse) + every { mockResponse.statusCode() } returns HttpStatus.OK + every { mockResponse.bodyToMono(object : ParameterizedTypeReference() {}) } returns Mono.just("body") + + val parent = Observation.start("parent", observationRegistry) + val scope = parent.openScope() + try { + runBlocking(PropagationContextElement()) { + val webClient = WebClient.builder() + .exchangeFunction(exchangeFunction) + .observationRegistry(observationRegistry) + .build() + webClient.get().uri("/path1").retrieve().awaitBody() + webClient.get().uri("/path2").retrieve().awaitBody() + } + } finally { + scope.close() + parent.stop() + } + assertThat(contextObservationHandler.parentObservation).containsExactly(true, true) + } finally { + Hooks.disableAutomaticContextPropagation() + } + } + class Foo private data class FooContextElement(val foo: Foo) : AbstractCoroutineContextElement(FooContextElement) { companion object Key : CoroutineContext.Key } + + private class ContextObservationHandler : ObservationHandler { + + val parentObservation = mutableListOf() + + override fun onStart(context: ClientRequestObservationContext) { + parentObservation.add(context.parentObservation != null) + } + + override fun supportsContext(context: Observation.Context) = context is ClientRequestObservationContext + } } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt index 9e41a2223996..f8041148619c 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt @@ -258,6 +258,14 @@ class InvocableHandlerMethodKotlinTests { assertHandlerResultValue(result, "1") } + @Test + fun valueClassWithNullableAndUnderlyingValue() { + this.resolvers.add(stubResolver(1L, LongValueClass::class.java)) + val method = ValueClassController::valueClassWithNullable.javaMethod!! + val result = invoke(ValueClassController(), method) + assertHandlerResultValue(result, "1") + } + @Test fun valueClassWithNullable() { this.resolvers.add(stubResolver(null, LongValueClass::class.java)) @@ -320,6 +328,14 @@ class InvocableHandlerMethodKotlinTests { assertHandlerResultValue(result, "null") } + @Test + fun suspendingValueClassWithNullableAndUnderlyingValue() { + this.resolvers.add(stubResolver(1L, LongValueClass::class.java)) + val method = SuspendingValueClassController::valueClassWithNullable.javaMethod!! + val result = invoke(SuspendingValueClassController(), method) + assertHandlerResultValue(result, "1") + } + @Test fun suspendingValueClassWithPrivateConstructor() { this.resolvers.add(stubResolver(1L, Long::class.java)) @@ -590,4 +606,4 @@ class InvocableHandlerMethodKotlinTests { } class CustomException(message: String) : Throwable(message) -} \ No newline at end of file +} diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt index b932e5204b42..6bc6aa7e34cf 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt @@ -131,7 +131,6 @@ class CoroutinesIntegrationTests : AbstractRequestMappingIntegrationTests() { assertThat(entity.body).isEqualTo("foobar") } - @Configuration @EnableWebFlux @ComponentScan(resourcePattern = "**/CoroutinesIntegrationTests*") @@ -207,7 +206,6 @@ class CoroutinesIntegrationTests : AbstractRequestMappingIntegrationTests() { } return ResponseEntity.ok().body(strings) } - } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesValueClassIntegrationTest.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesValueClassIntegrationTest.kt new file mode 100644 index 000000000000..b468dd077f7c --- /dev/null +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesValueClassIntegrationTest.kt @@ -0,0 +1,125 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.result.method.annotation + +import kotlinx.coroutines.delay +import org.assertj.core.api.Assertions.assertThat +import org.springframework.context.ApplicationContext +import org.springframework.context.annotation.AnnotationConfigApplicationContext +import org.springframework.context.annotation.ComponentScan +import org.springframework.context.annotation.Configuration +import org.springframework.http.HttpHeaders +import org.springframework.http.HttpStatus +import org.springframework.web.bind.annotation.GetMapping +import org.springframework.web.bind.annotation.RequestParam +import org.springframework.web.bind.annotation.RestController +import org.springframework.web.reactive.config.EnableWebFlux +import org.springframework.web.testfixture.http.server.reactive.bootstrap.HttpServer +import java.util.UUID + +class CoroutinesValueClassIntegrationTest : AbstractRequestMappingIntegrationTests() { + + override fun initApplicationContext(): ApplicationContext { + val context = AnnotationConfigApplicationContext() + context.register(WebConfig::class.java) + context.refresh() + return context + } + + + @ParameterizedHttpServerTest + fun `Suspending handler method with nullable value class request param`(httpServer: HttpServer) { + startServer(httpServer) + + val entity = performGet("/suspend-value-class?value=550e8400-e29b-41d4-a716-446655440000", HttpHeaders.EMPTY, String::class.java) + assertThat(entity.statusCode).isEqualTo(HttpStatus.OK) + assertThat(entity.body).isEqualTo("550e8400-e29b-41d4-a716-446655440000") + } + + @ParameterizedHttpServerTest + fun `Suspending handler method with nullable value class request param omitted`(httpServer: HttpServer) { + startServer(httpServer) + + val entity = performGet("/suspend-value-class", HttpHeaders.EMPTY, String::class.java) + assertThat(entity.statusCode).isEqualTo(HttpStatus.OK) + assertThat(entity.body).isEqualTo("outer-null") + } + + @ParameterizedHttpServerTest + fun `Suspending handler method with non-optional nullable inner value class request param`(httpServer: HttpServer) { + startServer(httpServer) + + val entity = performGet("/suspend-nullable-inner-value-class", HttpHeaders.EMPTY, String::class.java) + assertThat(entity.statusCode).isEqualTo(HttpStatus.OK) + assertThat(entity.body).isEqualTo("inner-null") + } + + @ParameterizedHttpServerTest + fun `Suspending handler method with optional nullable inner value class request param`(httpServer: HttpServer) { + startServer(httpServer) + + val entity = performGet("/suspend-nullable-inner-value-class-optional", HttpHeaders.EMPTY, String::class.java) + assertThat(entity.statusCode).isEqualTo(HttpStatus.OK) + assertThat(entity.body).isEqualTo("outer-null") + } + + + @Configuration + @EnableWebFlux + @ComponentScan(resourcePattern = "**/CoroutinesValueClassIntegrationTest*") + open class WebConfig + + @RestController + class CoroutinesController { + + @GetMapping("/suspend-value-class") + suspend fun suspendingValueClassEndpoint(@RequestParam value: ValueClass?): String { + delay(1) + return when (value) { + null -> "outer-null" + else -> value.value.toString() + } + } + + @GetMapping("/suspend-nullable-inner-value-class") + suspend fun suspendingNullableInnerValueClassEndpoint( + @RequestParam(required = false) value: NullableInnerValueClass + ): String { + delay(1) + return if (value.value == null) "inner-null" else value.value.toString() + } + + @GetMapping("/suspend-nullable-inner-value-class-optional") + suspend fun suspendingOptionalNullableInnerValueClassEndpoint( + @RequestParam(required = false) value: NullableInnerValueClass? + ): String { + delay(1) + return when { + value == null -> "outer-null" + value.value == null -> "inner-null" + else -> value.value.toString() + } + } + } + + @JvmInline + value class ValueClass(val value: UUID) + + @JvmInline + value class NullableInnerValueClass(val value: UUID?) + +} diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java index 3d6356f1e2e0..5d10ec940011 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java @@ -30,7 +30,6 @@ import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; -import jakarta.servlet.http.HttpServletResponseWrapper; import org.jspecify.annotations.Nullable; import org.springframework.beans.BeanUtils; @@ -49,7 +48,6 @@ import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.core.env.ConfigurableEnvironment; import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatusCode; import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; @@ -172,7 +170,7 @@ public abstract class FrameworkServlet extends HttpServletBean implements Applic * HTTP methods supported by {@link jakarta.servlet.http.HttpServlet}. */ private static final Set HTTP_SERVLET_METHODS = - Set.of("DELETE", "HEAD", "GET", "OPTIONS", "POST", "PUT", "TRACE"); + Set.of("DELETE", "HEAD", "GET", "OPTIONS", "PATCH", "POST", "PUT", "TRACE"); /** ServletContext attribute to find the WebApplicationContext in. */ @@ -864,7 +862,7 @@ public void destroy() { /** * Override the parent class implementation in order to intercept requests - * using PATCH or non-standard HTTP methods (WebDAV). + * using non-standard HTTP methods (such as WebDAV). */ @Override protected void service(HttpServletRequest request, HttpServletResponse response) @@ -892,6 +890,18 @@ protected final void doGet(HttpServletRequest request, HttpServletResponse respo processRequest(request, response); } + /** + * Delegate {@code PATCH} requests to {@link #processRequest}. + * @since 7.1 + * @see #doService + */ + @Override + protected final void doPatch(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + processRequest(request, response); + } + /** * Delegate {@code POST} requests to {@link #processRequest}. * @see #doService @@ -943,16 +953,7 @@ protected void doOptions(HttpServletRequest request, HttpServletResponse respons } } - // Use response wrapper in order to always add PATCH to the allowed methods - super.doOptions(request, new HttpServletResponseWrapper(response) { - @Override - public void setHeader(String name, String value) { - if (HttpHeaders.ALLOW.equals(name)) { - value = (StringUtils.hasLength(value) ? value + ", " : "") + HttpMethod.PATCH.name(); - } - super.setHeader(name, value); - } - }); + super.doOptions(request, response); } /** @@ -1154,9 +1155,9 @@ private void publishRequestHandledEvent(HttpServletRequest request, HttpServletR /** * Subclasses must implement this method to do the work of request handling, - * receiving a centralized callback for {@code GET}, {@code POST}, {@code PUT}, - * {@code DELETE}, {@code OPTIONS}, and {@code TRACE} requests as well as for - * requests using non-standard HTTP methods (such as WebDAV). + * receiving a centralized callback for {@code GET}, {@code PATCH}, {@code POST}, + * {@code PUT}, {@code DELETE}, {@code OPTIONS}, and {@code TRACE} requests + * as well as for requests using non-standard HTTP methods (such as WebDAV). *

    The contract is essentially the same as that for the commonly overridden * {@code doGet} or {@code doPost} methods of HttpServlet. *

    This class intercepts calls to ensure that exception handling and diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/ApiVersionConfigurer.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/ApiVersionConfigurer.java index b7bff5573c50..0c13497285d4 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/ApiVersionConfigurer.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/ApiVersionConfigurer.java @@ -205,8 +205,9 @@ public ApiVersionConfigurer detectSupportedVersions(boolean detect) { * {@link #addSupportedVersions} and {@link #detectSupportedVersions}. * @param predicate the predicate to use */ - public void setSupportedVersionPredicate(@Nullable Predicate> predicate) { + public ApiVersionConfigurer setSupportedVersionPredicate(@Nullable Predicate> predicate) { this.supportedVersionPredicate = predicate; + return this; } /** diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExceptionHandlerExceptionResolver.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExceptionHandlerExceptionResolver.java index 879cc78b6ac6..2dcb87750664 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExceptionHandlerExceptionResolver.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExceptionHandlerExceptionResolver.java @@ -36,9 +36,11 @@ import org.springframework.http.HttpStatusCode; import org.springframework.http.MediaType; import org.springframework.http.converter.ByteArrayHttpMessageConverter; +import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverters; import org.springframework.http.converter.StringHttpMessageConverter; -import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; +import org.springframework.http.converter.multipart.MultipartHttpMessageConverter; import org.springframework.ui.ModelMap; import org.springframework.web.ErrorResponse; import org.springframework.web.HttpMediaTypeNotAcceptableException; @@ -293,7 +295,9 @@ private void initMessageConverters() { } this.messageConverters.add(new ByteArrayHttpMessageConverter()); this.messageConverters.add(new StringHttpMessageConverter()); - this.messageConverters.add(new AllEncompassingFormHttpMessageConverter()); + this.messageConverters.add(new FormHttpMessageConverter()); + this.messageConverters.add(new MultipartHttpMessageConverter(HttpMessageConverters.forServer() + .registerDefaults().build())); } private void initExceptionHandlerAdviceCache() { diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java index 1894fb0fe46d..1833702d3711 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java @@ -194,11 +194,11 @@ public boolean isReactiveType(Class type) { } /** - * Attempts to find a concrete {@code MediaType} that can be streamed (as json separated - * by newlines in the response body). This method considers two concrete types - * {@code APPLICATION_NDJSON} and {@code APPLICATION_STREAM_JSON}) as well as any - * subtype of application that has the {@code +x-ndjson} suffix. In the later case, - * the media type MUST be concrete for it to be considered. + * Attempts to find a concrete {@code MediaType} that can be streamed (as JSON payloads + * separated by newlines in the response body). This method considers {@code APPLICATION_JSONL}, + * {@code APPLICATION_NDJSON}, and any subtype of application + * that has the {@code +x-ndjson} suffix. In the latter case, the media type MUST be + * concrete for it to be considered. * *

    For example {@code application/vnd.myapp+x-ndjson} is considered a streaming type * while {@code application/*+x-ndjson} isn't. @@ -223,6 +223,9 @@ public boolean isReactiveType(Class type) { else if (MediaType.APPLICATION_NDJSON.includes(acceptedType)) { return MediaType.APPLICATION_NDJSON; } + else if (MediaType.APPLICATION_JSONL.includes(acceptedType)) { + return MediaType.APPLICATION_JSONL; + } } return null; // not a concrete streaming type } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java index a48695012f0f..7de750f3ed56 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java @@ -51,9 +51,11 @@ import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.core.task.SimpleAsyncTaskExecutor; import org.springframework.http.converter.ByteArrayHttpMessageConverter; +import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverters; import org.springframework.http.converter.StringHttpMessageConverter; -import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; +import org.springframework.http.converter.multipart.MultipartHttpMessageConverter; import org.springframework.ui.ModelMap; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -579,7 +581,9 @@ private void initMessageConverters() { this.messageConverters.add(new ByteArrayHttpMessageConverter()); this.messageConverters.add(new StringHttpMessageConverter()); - this.messageConverters.add(new AllEncompassingFormHttpMessageConverter()); + this.messageConverters.add(new FormHttpMessageConverter()); + this.messageConverters.add(new MultipartHttpMessageConverter(HttpMessageConverters.forServer() + .registerDefaults().build())); } private void initControllerAdviceCache() { diff --git a/spring-webmvc/src/main/resources/org/springframework/web/servlet/config/spring-mvc.xsd b/spring-webmvc/src/main/resources/org/springframework/web/servlet/config/spring-mvc.xsd index 7b8da7ba4fcc..596218fd4e6b 100644 --- a/spring-webmvc/src/main/resources/org/springframework/web/servlet/config/spring-mvc.xsd +++ b/spring-webmvc/src/main/resources/org/springframework/web/servlet/config/spring-mvc.xsd @@ -213,7 +213,7 @@ By default, a SimpleAsyncTaskExecutor is used which does not re-use threads and is not recommended for production. As of 5.0 this executor is also used when a controller returns a reactive type that does streaming - (for example, "text/event-stream" or "application/x-ndjson") for the blocking writes to the + (for example, "text/event-stream", "application/x-ndjson" or "application/jsonl") for the blocking writes to the "jakarta.servlet.ServletOutputStream". ]]> diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/DispatcherServletTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/DispatcherServletTests.java index 26b6cceb9732..1bb17460bacb 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/DispatcherServletTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/DispatcherServletTests.java @@ -801,7 +801,7 @@ protected ConfigurableWebEnvironment createEnvironment() { assertThat(custom.getEnvironment()).isInstanceOf(CustomServletEnvironment.class); } - @Test + @Test // gh-36247 void allowedOptionsIncludesPatchMethod() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "OPTIONS", "/foo"); MockHttpServletResponse response = spy(new MockHttpServletResponse()); @@ -809,7 +809,7 @@ void allowedOptionsIncludesPatchMethod() throws Exception { servlet.setDispatchOptionsRequest(false); servlet.service(request, response); verify(response, never()).getHeader(anyString()); // SPR-10341 - assertThat(response.getHeader("Allow")).isEqualTo("GET, HEAD, POST, PUT, DELETE, TRACE, OPTIONS, PATCH"); + assertThat(response.getHeader("Allow")).isEqualTo("GET, HEAD, PATCH, POST, PUT, DELETE, TRACE, OPTIONS"); } @Test diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java index 24f5f5c99baf..4388a0a08a73 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java @@ -39,7 +39,7 @@ import org.springframework.http.converter.HttpMessageConverters; import org.springframework.http.converter.StringHttpMessageConverter; import org.springframework.http.converter.json.JacksonJsonHttpMessageConverter; -import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; +import org.springframework.http.converter.multipart.MultipartHttpMessageConverter; import org.springframework.scheduling.concurrent.ConcurrentTaskExecutor; import org.springframework.stereotype.Controller; import org.springframework.util.AntPathMatcher; @@ -213,7 +213,7 @@ void requestMappingHandlerAdapter() { assertThat(converters).hasSize(3); assertThat(converters.get(0).getClass()).isEqualTo(StringHttpMessageConverter.class); assertThat(converters.get(1).getClass()).isEqualTo(JacksonJsonHttpMessageConverter.class); - assertThat(converters.get(2).getClass()).isEqualTo(AllEncompassingFormHttpMessageConverter.class); + assertThat(converters.get(2).getClass()).isEqualTo(MultipartHttpMessageConverter.class); JsonMapper jsonMapper = ((JacksonJsonHttpMessageConverter) converters.get(1)).getMapper(); assertThat(jsonMapper.deserializationConfig().isEnabled(MapperFeature.DEFAULT_VIEW_INCLUSION)).isFalse(); assertThat(jsonMapper.deserializationConfig().isEnabled(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)).isFalse(); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java index f101e1c66d19..65e4fd2c1f39 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java @@ -35,6 +35,8 @@ import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; @@ -287,10 +289,11 @@ void writeServerSentEventsWithBuilder() throws Exception { assertThat(emitterHandler.getValuesAsText()).isEqualTo("id:1\ndata:foo\n\nid:2\ndata:bar\n\nid:3\ndata:baz\n\n"); } - @Test - void writeStreamJson() throws Exception { + @ParameterizedTest + @ValueSource(strings = {"application/jsonl", "application/x-ndjson"}) + void writeStreamJson(String mediaType) throws Exception { - this.servletRequest.addHeader("Accept", "application/x-ndjson"); + this.servletRequest.addHeader("Accept", mediaType); Sinks.Many sink = Sinks.many().unicast().onBackpressureBuffer(); ResponseBodyEmitter emitter = handleValue(sink.asFlux(), Flux.class, forClass(Bar.class)); @@ -308,7 +311,7 @@ void writeStreamJson() throws Exception { sink.tryEmitNext(bar2); sink.tryEmitComplete(); - assertThat(message.getHeaders().getContentType()).hasToString("application/x-ndjson"); + assertThat(message.getHeaders().getContentType()).hasToString(mediaType); assertThat(emitterHandler.getValues()).isEqualTo(Arrays.asList(bar1, "\n", bar2, "\n")); } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartIntegrationTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartIntegrationTests.java index fa7eb4bebecd..91e018465dd4 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartIntegrationTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartIntegrationTests.java @@ -50,7 +50,7 @@ import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.ResourceHttpMessageConverter; import org.springframework.http.converter.json.JacksonJsonHttpMessageConverter; -import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; +import org.springframework.http.converter.multipart.MultipartHttpMessageConverter; import org.springframework.stereotype.Controller; import org.springframework.util.FileSystemUtils; import org.springframework.util.LinkedMultiValueMap; @@ -133,12 +133,11 @@ void setup() { converters.add(new ResourceHttpMessageConverter()); converters.add(new JacksonJsonHttpMessageConverter()); - AllEncompassingFormHttpMessageConverter formConverter = new AllEncompassingFormHttpMessageConverter(); - formConverter.setPartConverters(converters); + MultipartHttpMessageConverter converter = new MultipartHttpMessageConverter(converters); this.restClient = RestClient.builder().baseUrl(baseUrl) .requestFactory(new HttpComponentsClientHttpRequestFactory()) - .configureMessageConverters(clientBuilder -> clientBuilder.addCustomConverter(formConverter)) + .configureMessageConverters(clientBuilder -> clientBuilder.addCustomConverter(converter)) .build(); } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestResponseBodyMethodProcessorTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestResponseBodyMethodProcessorTests.java index 0b321d5a6aa8..fc0c968df0c0 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestResponseBodyMethodProcessorTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestResponseBodyMethodProcessorTests.java @@ -49,13 +49,13 @@ import org.springframework.http.ProblemDetail; import org.springframework.http.ResponseEntity; import org.springframework.http.converter.ByteArrayHttpMessageConverter; +import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageNotReadableException; import org.springframework.http.converter.ResourceHttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter; import org.springframework.http.converter.json.JacksonJsonHttpMessageConverter; import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; -import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; import org.springframework.http.converter.xml.JacksonXmlHttpMessageConverter; import org.springframework.http.converter.xml.MappingJackson2XmlHttpMessageConverter; import org.springframework.util.MultiValueMap; @@ -135,7 +135,7 @@ void resolveArgumentRawTypeFromParameterizedType() throws Exception { this.servletRequest.setContent(content.getBytes(StandardCharsets.UTF_8)); this.servletRequest.setContentType(MediaType.APPLICATION_FORM_URLENCODED_VALUE); - List> converters = List.of(new AllEncompassingFormHttpMessageConverter()); + List> converters = List.of(new FormHttpMessageConverter()); RequestResponseBodyMethodProcessor processor = new RequestResponseBodyMethodProcessor(converters); @SuppressWarnings("unchecked") diff --git a/spring-webmvc/src/test/kotlin/org/springframework/web/servlet/mvc/method/annotation/ServletAnnotationControllerHandlerMethodValueClassKotlinTests.kt b/spring-webmvc/src/test/kotlin/org/springframework/web/servlet/mvc/method/annotation/ServletAnnotationControllerHandlerMethodValueClassKotlinTests.kt new file mode 100644 index 000000000000..611d541eca4c --- /dev/null +++ b/spring-webmvc/src/test/kotlin/org/springframework/web/servlet/mvc/method/annotation/ServletAnnotationControllerHandlerMethodValueClassKotlinTests.kt @@ -0,0 +1,119 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.servlet.mvc.method.annotation + +import org.assertj.core.api.Assertions.assertThat +import org.springframework.web.bind.annotation.RequestMapping +import org.springframework.web.bind.annotation.RequestParam +import org.springframework.web.bind.annotation.RestController +import org.springframework.web.context.request.async.WebAsyncUtils +import org.springframework.web.servlet.handler.PathPatternsParameterizedTest +import org.springframework.web.testfixture.servlet.MockHttpServletRequest +import org.springframework.web.testfixture.servlet.MockHttpServletResponse +import java.util.UUID +import java.util.stream.Stream + +class ServletAnnotationControllerHandlerMethodValueClassKotlinTests : AbstractServletHandlerMethodTests() { + + companion object { + @JvmStatic + fun pathPatternsArguments(): Stream { + return Stream.of(true, false) + } + } + + @PathPatternsParameterizedTest + fun suspendingValueClass(usePathPatterns: Boolean) { + initDispatcherServlet(CoroutinesController::class.java, usePathPatterns) + + val request = MockHttpServletRequest("GET", "/suspending-value-class") + request.isAsyncSupported = true + request.addParameter("value", "550e8400-e29b-41d4-a716-446655440000") + val response = MockHttpServletResponse() + servlet.service(request, response) + assertThat(WebAsyncUtils.getAsyncManager(request).concurrentResult).isEqualTo("550e8400-e29b-41d4-a716-446655440000") + } + + @PathPatternsParameterizedTest + fun suspendingValueClassOmitted(usePathPatterns: Boolean) { + initDispatcherServlet(CoroutinesController::class.java, usePathPatterns) + + val request = MockHttpServletRequest("GET", "/suspending-value-class") + request.isAsyncSupported = true + val response = MockHttpServletResponse() + servlet.service(request, response) + assertThat(WebAsyncUtils.getAsyncManager(request).concurrentResult).isEqualTo("outer-null") + } + + @PathPatternsParameterizedTest + fun suspendingNullableInnerValueClass(usePathPatterns: Boolean) { + initDispatcherServlet(CoroutinesController::class.java, usePathPatterns) + + val request = MockHttpServletRequest("GET", "/suspending-nullable-inner-value-class") + request.isAsyncSupported = true + val response = MockHttpServletResponse() + servlet.service(request, response) + assertThat(WebAsyncUtils.getAsyncManager(request).concurrentResult).isEqualTo("inner-null") + } + + @PathPatternsParameterizedTest + fun suspendingOptionalNullableInnerValueClass(usePathPatterns: Boolean) { + initDispatcherServlet(CoroutinesController::class.java, usePathPatterns) + + val request = MockHttpServletRequest("GET", "/suspending-nullable-inner-value-class-optional") + request.isAsyncSupported = true + val response = MockHttpServletResponse() + servlet.service(request, response) + assertThat(WebAsyncUtils.getAsyncManager(request).concurrentResult).isEqualTo("outer-null") + } + + @RestController + class CoroutinesController { + + @Suppress("RedundantSuspendModifier") + @RequestMapping("/suspending-value-class") + suspend fun handleValueClass(@RequestParam value: ValueClass?): String { + return when (value) { + null -> "outer-null" + else -> value.value.toString() + } + } + + @Suppress("RedundantSuspendModifier") + @RequestMapping("/suspending-nullable-inner-value-class") + suspend fun handleNullableInnerValueClass(@RequestParam(required = false) value: NullableInnerValueClass): String { + return value.value?.toString() ?: "inner-null" + } + + @Suppress("RedundantSuspendModifier") + @RequestMapping("/suspending-nullable-inner-value-class-optional") + suspend fun handleOptionalNullableInnerValueClass(@RequestParam(required = false) value: NullableInnerValueClass?): String { + return when { + value == null -> "outer-null" + value.value == null -> "inner-null" + else -> value.value.toString() + } + } + } + + @JvmInline + value class ValueClass(val value: UUID) + + @JvmInline + value class NullableInnerValueClass(val value: UUID?) + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java index 089f7b9c8c28..14e9134b2851 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java @@ -53,7 +53,10 @@ * * @author Rossen Stoyanchev * @since 4.1 + * @deprecated as of 7.1 in favor of {@link RestClientXhrTransport}. */ +@Deprecated(since = "7.1", forRemoval = true) +@SuppressWarnings("removal") public class RestTemplateXhrTransport extends AbstractXhrTransport { private final RestOperations restTemplate; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java index b8180ee44100..59f2d673f83c 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java @@ -67,6 +67,7 @@ * * @author Rossen Stoyanchev */ +@SuppressWarnings("removal") class RestTemplateXhrTransportTests { private static final JacksonJsonSockJsMessageCodec CODEC = new JacksonJsonSockJsMessageCodec(); @@ -201,6 +202,7 @@ private InputStream getInputStream(String content) { } + @SuppressWarnings("removal") private static class TestRestTemplate extends RestTemplate { private Queue responses = new LinkedBlockingDeque<>();