Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
.apply(getCursor(), mi.getCoordinates().replace(), actual, message, expected);
}
if (args.size() == 3) {
// When actual is integral but delta is floating-point, use isEqualTo instead of isCloseTo
// to avoid type mismatch (e.g. AbstractLongAssert.isCloseTo requires Offset<Long>, not Offset<Double>)
if (isIntegralType(actual)) {
return JavaTemplate.builder("assertThat(#{any()}).isEqualTo(#{any()});")
.staticImports(ASSERTJ + ".assertThat")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3"))
.build()
.apply(getCursor(), mi.getCoordinates().replace(), actual, expected);
}
maybeAddImport(ASSERTJ, "within", false);
return JavaTemplate.builder("assertThat(#{any()}).isCloseTo(#{any()}, within(#{any()}));")
.staticImports(ASSERTJ + ".assertThat", ASSERTJ + ".within")
Expand All @@ -85,10 +94,17 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
.apply(getCursor(), mi.getCoordinates().replace(), actual, expected, args.get(2));
}

maybeAddImport(ASSERTJ, "within", false);

// The assertEquals is using a floating point with a delta argument and a message.
Expression message = args.get(3);
if (isIntegralType(actual)) {
return JavaTemplate.builder("assertThat(#{any()}).as(#{any()}).isEqualTo(#{any()});")
.staticImports(ASSERTJ + ".assertThat")
.imports("java.util.function.Supplier")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3"))
.build()
.apply(getCursor(), mi.getCoordinates().replace(), actual, message, expected);
}
maybeAddImport(ASSERTJ, "within", false);
return JavaTemplate.builder("assertThat(#{any()}).as(#{any()}).isCloseTo(#{any()}, within(#{any()}));")
.staticImports(ASSERTJ + ".assertThat", ASSERTJ + ".within")
.imports("java.util.function.Supplier")
Expand All @@ -107,6 +123,16 @@ private boolean isFloatingPointType(Expression expression) {
JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType());
return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float;
}

private boolean isIntegralType(Expression expression) {
JavaType.FullyQualified fq = TypeUtils.asFullyQualified(expression.getType());
if (fq != null) {
String typeName = fq.getFullyQualifiedName();
return "java.lang.Long".equals(typeName) || "java.lang.Integer".equals(typeName);
}
JavaType.Primitive p = TypeUtils.asPrimitive(expression.getType());
return p == JavaType.Primitive.Long || p == JavaType.Primitive.Int;
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
public class JUnitAssertThrowsToAssertExceptionType extends Recipe {

private static final String JUNIT_ASSERTIONS = "org.junit.jupiter.api.Assertions";
private static final String ASSERTIONS_FOR_CLASS_TYPES = "org.assertj.core.api.AssertionsForClassTypes";
private static final String ASSERTJ_ASSERTIONS = "org.assertj.core.api.Assertions";
private static final MethodMatcher ASSERT_THROWS_MATCHER = new MethodMatcher(JUNIT_ASSERTIONS + " assertThrows(..)");

@Getter
Expand Down Expand Up @@ -64,7 +64,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu

maybeRemoveImport(JUNIT_ASSERTIONS);
maybeRemoveImport(JUNIT_ASSERTIONS + ".assertThrows");
maybeAddImport(ASSERTIONS_FOR_CLASS_TYPES, "assertThatExceptionOfType");
maybeAddImport(ASSERTJ_ASSERTIONS, "assertThatExceptionOfType");

List<Expression> args = mi.getArguments();

Expand All @@ -74,7 +74,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
code += ".actual()";
}
return JavaTemplate.builder(code)
.staticImports(ASSERTIONS_FOR_CLASS_TYPES + ".assertThatExceptionOfType")
.staticImports(ASSERTJ_ASSERTIONS + ".assertThatExceptionOfType")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3"))
.build()
.apply(getCursor(), mi.getCoordinates().replace(), args.get(0), args.get(1));
Expand All @@ -85,7 +85,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
code += ".actual()";
}
return JavaTemplate.builder(code)
.staticImports(ASSERTIONS_FOR_CLASS_TYPES + ".assertThatExceptionOfType")
.staticImports(ASSERTJ_ASSERTIONS + ".assertThatExceptionOfType")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3"))
.build()
.apply(getCursor(), mi.getCoordinates().replace(), args.get(0), args.get(2), args.get(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocat
}
}

// Skip transformation when the assertion argument type is wider than the chained method's
// return type (e.g. isEqualTo(1L) with size() returning int -> hasSize(1L) won't compile)
if (!methodToReplaceArgumentIsEmpty &&
TypeUtils.asPrimitive(mi.getArguments().get(0).getType()) == JavaType.Primitive.Long &&
TypeUtils.asPrimitive(assertThatArg.getType()) == JavaType.Primitive.Int) {
return mi;
}

List<Expression> arguments = new ArrayList<>();
arguments.add(actual);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocat
if (binary.getOperator() != J.Binary.Type.Equal && binary.getOperator() != J.Binary.Type.NotEqual) {
return mi;
}
// Skip null comparisons — defer to assertNull/assertNotNull recipes
if (binary.getLeft().getType() == JavaType.Primitive.Null ||
binary.getRight().getType() == JavaType.Primitive.Null) {
return mi;
}
List<Expression> newArguments = new ArrayList<>();
newArguments.add(binary.getLeft());
newArguments.add(binary.getRight());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,4 +401,82 @@ public record OwnClass(String a) {
)
);
}

@Test
void assertEqualsWithObjectReturnType() {
//language=java
rewriteRun(
java(
"""
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class MyTest {
@Test
public void test() {
assertEquals("expected", getObject());
}
private Object getObject() {
return "expected";
}
}
""",
"""
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;

public class MyTest {
@Test
public void test() {
assertThat(getObject()).isEqualTo("expected");
}
private Object getObject() {
return "expected";
}
}
"""
)
);
}

@Test
void longActualWithDoubleDelta() {
//language=java
rewriteRun(
java(
"""
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class MyTest {
@Test
public void test() {
assertEquals(50L, getSummation(), 0.0);
}
private long getSummation() {
return 50L;
}
}
""",
"""
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;

public class MyTest {
@Test
public void test() {
assertThat(getSummation()).isEqualTo(50L);
}
private long getSummation() {
return 50L;
}
}
"""
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void foo() {
}
""",
"""
import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class SimpleExpectedExceptionTest {
public void throwsExceptionWithSpecificType() {
Expand Down Expand Up @@ -95,7 +95,7 @@ public void throwsWithMemberReference() {
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class MemberReferenceTest {

Expand Down Expand Up @@ -126,7 +126,7 @@ public void throwsExceptionWithSpecificType() {
}
""",
"""
import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class SimpleExpectedExceptionTest {
public void throwsExceptionWithSpecificType() {
Expand Down Expand Up @@ -157,7 +157,7 @@ public void throwsExceptionWithSpecificType() {
}
""",
"""
import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class SimpleExpectedExceptionTest {
public void throwsExceptionWithSpecificType() {
Expand Down Expand Up @@ -190,7 +190,7 @@ public void throwsExceptionWithSpecificType() {
}
""",
"""
import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class SimpleExpectedExceptionTest {
public void throwsExceptionWithSpecificType() {
Expand Down Expand Up @@ -226,7 +226,7 @@ NullPointerException exception() {
}
""",
"""
import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class SimpleExpectedExceptionTest {
public void throwsExceptionWithSpecificType() {
Expand Down Expand Up @@ -315,7 +315,7 @@ void foo() {
}
""".formatted(message),
"""
import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class SimpleExpectedExceptionTest {
public void throwsExceptionWithSpecificType() {
Expand Down Expand Up @@ -353,7 +353,7 @@ public void throwsExceptionWithSpecificType() {
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class SimpleExpectedExceptionTest {
public void throwsExceptionWithSpecificType() {
Expand Down Expand Up @@ -385,7 +385,7 @@ void foo() {
}
""".formatted(message),
"""
import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class SimpleExpectedExceptionTest {
public void throwsExceptionWithSpecificType() {
Expand Down Expand Up @@ -422,7 +422,7 @@ void foo() {
}
""".formatted(message),
"""
import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class SimpleExpectedExceptionTest {
public void throwsExceptionWithSpecificType() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -689,4 +689,46 @@ void simpleTest(Optional<String> o) {
);
}
}

@Test
void mapGetIsEqualToWithPartialWildcardTypeIsNotConverted() {
rewriteRun(
spec -> spec.recipe(new SimplifyChainedAssertJAssertion("get", "isEqualTo", "containsEntry", "java.util.Map")),
//language=java
java(
"""
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

class MyTest {
void testMethod(Map<String, ?> map) {
assertThat(map.get("key")).isEqualTo("value");
}
}
"""
)
);
}

@Test
void sizeIsEqualToLongLiteralIsNotConverted() {
rewriteRun(
spec -> spec.recipe(new SimplifyChainedAssertJAssertion("size", "isEqualTo", "hasSize", "java.util.Collection")),
//language=java
java(
"""
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

class MyTest {
void testMethod(List<String> list) {
assertThat(list.size()).isEqualTo(1L);
}
}
"""
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,48 @@ public void test() {
)
);
}

@Test
void assertTrueNotNullShouldNotConvert() {
//language=java
rewriteRun(
java(
"""
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertTrue;

class MyTest {

@Test
public void test(Object obj) {
assertTrue(obj != null);
}
}
"""
)
);
}

@Test
void assertFalseNullShouldNotConvert() {
//language=java
rewriteRun(
java(
"""
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertFalse;

class MyTest {

@Test
public void test(Object obj) {
assertFalse(obj == null);
}
}
"""
)
);
}
}
Loading