diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java index 9ee4848c1..80a69a178 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java @@ -77,6 +77,15 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu .apply(getCursor(), mi.getCoordinates().replace(), actual, message, expected); } if (args.size() == 3) { + // When actual is integral but delta is floating-point, use isEqualTo instead of isCloseTo + // to avoid type mismatch (e.g. AbstractLongAssert.isCloseTo requires Offset, not Offset) + if (isIntegralType(actual)) { + return JavaTemplate.builder("assertThat(#{any()}).isEqualTo(#{any()});") + .staticImports(ASSERTJ + ".assertThat") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual, expected); + } maybeAddImport(ASSERTJ, "within", false); return JavaTemplate.builder("assertThat(#{any()}).isCloseTo(#{any()}, within(#{any()}));") .staticImports(ASSERTJ + ".assertThat", ASSERTJ + ".within") @@ -85,10 +94,17 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu .apply(getCursor(), mi.getCoordinates().replace(), actual, expected, args.get(2)); } - maybeAddImport(ASSERTJ, "within", false); - // The assertEquals is using a floating point with a delta argument and a message. Expression message = args.get(3); + if (isIntegralType(actual)) { + return JavaTemplate.builder("assertThat(#{any()}).as(#{any()}).isEqualTo(#{any()});") + .staticImports(ASSERTJ + ".assertThat") + .imports("java.util.function.Supplier") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual, message, expected); + } + maybeAddImport(ASSERTJ, "within", false); return JavaTemplate.builder("assertThat(#{any()}).as(#{any()}).isCloseTo(#{any()}, within(#{any()}));") .staticImports(ASSERTJ + ".assertThat", ASSERTJ + ".within") .imports("java.util.function.Supplier") @@ -107,6 +123,16 @@ private boolean isFloatingPointType(Expression expression) { JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType()); return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float; } + + private boolean isIntegralType(Expression expression) { + JavaType.FullyQualified fq = TypeUtils.asFullyQualified(expression.getType()); + if (fq != null) { + String typeName = fq.getFullyQualifiedName(); + return "java.lang.Long".equals(typeName) || "java.lang.Integer".equals(typeName); + } + JavaType.Primitive p = TypeUtils.asPrimitive(expression.getType()); + return p == JavaType.Primitive.Long || p == JavaType.Primitive.Int; + } }); } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java index e4ea66967..f41a25af7 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java @@ -34,7 +34,7 @@ public class JUnitAssertThrowsToAssertExceptionType extends Recipe { private static final String JUNIT_ASSERTIONS = "org.junit.jupiter.api.Assertions"; - private static final String ASSERTIONS_FOR_CLASS_TYPES = "org.assertj.core.api.AssertionsForClassTypes"; + private static final String ASSERTJ_ASSERTIONS = "org.assertj.core.api.Assertions"; private static final MethodMatcher ASSERT_THROWS_MATCHER = new MethodMatcher(JUNIT_ASSERTIONS + " assertThrows(..)"); @Getter @@ -64,7 +64,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu maybeRemoveImport(JUNIT_ASSERTIONS); maybeRemoveImport(JUNIT_ASSERTIONS + ".assertThrows"); - maybeAddImport(ASSERTIONS_FOR_CLASS_TYPES, "assertThatExceptionOfType"); + maybeAddImport(ASSERTJ_ASSERTIONS, "assertThatExceptionOfType"); List args = mi.getArguments(); @@ -74,7 +74,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu code += ".actual()"; } return JavaTemplate.builder(code) - .staticImports(ASSERTIONS_FOR_CLASS_TYPES + ".assertThatExceptionOfType") + .staticImports(ASSERTJ_ASSERTIONS + ".assertThatExceptionOfType") .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3")) .build() .apply(getCursor(), mi.getCoordinates().replace(), args.get(0), args.get(1)); @@ -85,7 +85,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu code += ".actual()"; } return JavaTemplate.builder(code) - .staticImports(ASSERTIONS_FOR_CLASS_TYPES + ".assertThatExceptionOfType") + .staticImports(ASSERTJ_ASSERTIONS + ".assertThatExceptionOfType") .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3")) .build() .apply(getCursor(), mi.getCoordinates().replace(), args.get(0), args.get(2), args.get(1)); diff --git a/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java b/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java index c05d6a732..d45d73cfb 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java @@ -131,6 +131,14 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocat } } + // Skip transformation when the assertion argument type is wider than the chained method's + // return type (e.g. isEqualTo(1L) with size() returning int -> hasSize(1L) won't compile) + if (!methodToReplaceArgumentIsEmpty && + TypeUtils.asPrimitive(mi.getArguments().get(0).getType()) == JavaType.Primitive.Long && + TypeUtils.asPrimitive(assertThatArg.getType()) == JavaType.Primitive.Int) { + return mi; + } + List arguments = new ArrayList<>(); arguments.add(actual); diff --git a/src/main/java/org/openrewrite/java/testing/junit5/UseAssertSame.java b/src/main/java/org/openrewrite/java/testing/junit5/UseAssertSame.java index f90f19c79..89c8d1674 100644 --- a/src/main/java/org/openrewrite/java/testing/junit5/UseAssertSame.java +++ b/src/main/java/org/openrewrite/java/testing/junit5/UseAssertSame.java @@ -62,6 +62,11 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocat if (binary.getOperator() != J.Binary.Type.Equal && binary.getOperator() != J.Binary.Type.NotEqual) { return mi; } + // Skip null comparisons — defer to assertNull/assertNotNull recipes + if (binary.getLeft().getType() == JavaType.Primitive.Null || + binary.getRight().getType() == JavaType.Primitive.Null) { + return mi; + } List newArguments = new ArrayList<>(); newArguments.add(binary.getLeft()); newArguments.add(binary.getRight()); diff --git a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThatTest.java b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThatTest.java index bf6c321eb..fbb540d09 100644 --- a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThatTest.java +++ b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThatTest.java @@ -401,4 +401,82 @@ public record OwnClass(String a) { ) ); } + + @Test + void assertEqualsWithObjectReturnType() { + //language=java + rewriteRun( + java( + """ + import org.junit.jupiter.api.Test; + + import static org.junit.jupiter.api.Assertions.assertEquals; + + public class MyTest { + @Test + public void test() { + assertEquals("expected", getObject()); + } + private Object getObject() { + return "expected"; + } + } + """, + """ + import org.junit.jupiter.api.Test; + + import static org.assertj.core.api.Assertions.assertThat; + + public class MyTest { + @Test + public void test() { + assertThat(getObject()).isEqualTo("expected"); + } + private Object getObject() { + return "expected"; + } + } + """ + ) + ); + } + + @Test + void longActualWithDoubleDelta() { + //language=java + rewriteRun( + java( + """ + import org.junit.jupiter.api.Test; + + import static org.junit.jupiter.api.Assertions.assertEquals; + + public class MyTest { + @Test + public void test() { + assertEquals(50L, getSummation(), 0.0); + } + private long getSummation() { + return 50L; + } + } + """, + """ + import org.junit.jupiter.api.Test; + + import static org.assertj.core.api.Assertions.assertThat; + + public class MyTest { + @Test + public void test() { + assertThat(getSummation()).isEqualTo(50L); + } + private long getSummation() { + return 50L; + } + } + """ + ) + ); + } } diff --git a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionTypeTest.java b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionTypeTest.java index dab28a59b..06cdffc08 100644 --- a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionTypeTest.java +++ b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionTypeTest.java @@ -57,7 +57,7 @@ void foo() { } """, """ - import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; + import static org.assertj.core.api.Assertions.assertThatExceptionOfType; public class SimpleExpectedExceptionTest { public void throwsExceptionWithSpecificType() { @@ -95,7 +95,7 @@ public void throwsWithMemberReference() { import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; - import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; + import static org.assertj.core.api.Assertions.assertThatExceptionOfType; public class MemberReferenceTest { @@ -126,7 +126,7 @@ public void throwsExceptionWithSpecificType() { } """, """ - import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; + import static org.assertj.core.api.Assertions.assertThatExceptionOfType; public class SimpleExpectedExceptionTest { public void throwsExceptionWithSpecificType() { @@ -157,7 +157,7 @@ public void throwsExceptionWithSpecificType() { } """, """ - import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; + import static org.assertj.core.api.Assertions.assertThatExceptionOfType; public class SimpleExpectedExceptionTest { public void throwsExceptionWithSpecificType() { @@ -190,7 +190,7 @@ public void throwsExceptionWithSpecificType() { } """, """ - import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; + import static org.assertj.core.api.Assertions.assertThatExceptionOfType; public class SimpleExpectedExceptionTest { public void throwsExceptionWithSpecificType() { @@ -226,7 +226,7 @@ NullPointerException exception() { } """, """ - import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; + import static org.assertj.core.api.Assertions.assertThatExceptionOfType; public class SimpleExpectedExceptionTest { public void throwsExceptionWithSpecificType() { @@ -315,7 +315,7 @@ void foo() { } """.formatted(message), """ - import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; + import static org.assertj.core.api.Assertions.assertThatExceptionOfType; public class SimpleExpectedExceptionTest { public void throwsExceptionWithSpecificType() { @@ -353,7 +353,7 @@ public void throwsExceptionWithSpecificType() { import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; - import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; + import static org.assertj.core.api.Assertions.assertThatExceptionOfType; public class SimpleExpectedExceptionTest { public void throwsExceptionWithSpecificType() { @@ -385,7 +385,7 @@ void foo() { } """.formatted(message), """ - import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; + import static org.assertj.core.api.Assertions.assertThatExceptionOfType; public class SimpleExpectedExceptionTest { public void throwsExceptionWithSpecificType() { @@ -422,7 +422,7 @@ void foo() { } """.formatted(message), """ - import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; + import static org.assertj.core.api.Assertions.assertThatExceptionOfType; public class SimpleExpectedExceptionTest { public void throwsExceptionWithSpecificType() { diff --git a/src/test/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertionTest.java b/src/test/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertionTest.java index 2ab17e05f..d0d6dca02 100644 --- a/src/test/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertionTest.java +++ b/src/test/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertionTest.java @@ -689,4 +689,46 @@ void simpleTest(Optional o) { ); } } + + @Test + void mapGetIsEqualToWithPartialWildcardTypeIsNotConverted() { + rewriteRun( + spec -> spec.recipe(new SimplifyChainedAssertJAssertion("get", "isEqualTo", "containsEntry", "java.util.Map")), + //language=java + java( + """ + import java.util.Map; + + import static org.assertj.core.api.Assertions.assertThat; + + class MyTest { + void testMethod(Map map) { + assertThat(map.get("key")).isEqualTo("value"); + } + } + """ + ) + ); + } + + @Test + void sizeIsEqualToLongLiteralIsNotConverted() { + rewriteRun( + spec -> spec.recipe(new SimplifyChainedAssertJAssertion("size", "isEqualTo", "hasSize", "java.util.Collection")), + //language=java + java( + """ + import java.util.List; + + import static org.assertj.core.api.Assertions.assertThat; + + class MyTest { + void testMethod(List list) { + assertThat(list.size()).isEqualTo(1L); + } + } + """ + ) + ); + } } diff --git a/src/test/java/org/openrewrite/java/testing/junit5/UseAssertSameTest.java b/src/test/java/org/openrewrite/java/testing/junit5/UseAssertSameTest.java index 80d7ea11f..162882bd4 100644 --- a/src/test/java/org/openrewrite/java/testing/junit5/UseAssertSameTest.java +++ b/src/test/java/org/openrewrite/java/testing/junit5/UseAssertSameTest.java @@ -189,4 +189,48 @@ public void test() { ) ); } + + @Test + void assertTrueNotNullShouldNotConvert() { + //language=java + rewriteRun( + java( + """ + import org.junit.jupiter.api.Test; + + import static org.junit.jupiter.api.Assertions.assertTrue; + + class MyTest { + + @Test + public void test(Object obj) { + assertTrue(obj != null); + } + } + """ + ) + ); + } + + @Test + void assertFalseNullShouldNotConvert() { + //language=java + rewriteRun( + java( + """ + import org.junit.jupiter.api.Test; + + import static org.junit.jupiter.api.Assertions.assertFalse; + + class MyTest { + + @Test + public void test(Object obj) { + assertFalse(obj == null); + } + } + """ + ) + ); + } }