diff --git a/api/src/main/java/net/neoforged/jst/api/PsiHelper.java b/api/src/main/java/net/neoforged/jst/api/PsiHelper.java index 35bdfc9..d60a5eb 100644 --- a/api/src/main/java/net/neoforged/jst/api/PsiHelper.java +++ b/api/src/main/java/net/neoforged/jst/api/PsiHelper.java @@ -1,6 +1,7 @@ package net.neoforged.jst.api; import com.intellij.lang.jvm.types.JvmPrimitiveTypeKind; +import com.intellij.psi.JavaPsiFacade; import com.intellij.psi.PsiClass; import com.intellij.psi.PsiElement; import com.intellij.psi.PsiExpression; @@ -14,6 +15,7 @@ import com.intellij.psi.PsiTypes; import com.intellij.psi.PsiWhiteSpace; import com.intellij.psi.SyntaxTraverser; +import com.intellij.psi.search.GlobalSearchScope; import com.intellij.psi.util.CachedValueProvider; import com.intellij.psi.util.CachedValuesManager; import com.intellij.psi.util.ClassUtil; @@ -255,4 +257,17 @@ public static PsiElement resolve(PsiReferenceExpression expression) { return null; } } + + @Nullable + public static PsiClass findClass(JavaPsiFacade facade, String name, GlobalSearchScope scope) { + var inner = name.split("\\$"); + var cls = facade.findClass(inner[0], scope); + if (cls == null) return null; + + for (int i = 1; i < inner.length; i++) { + cls = cls.findInnerClassByName(inner[i], true); + if (cls == null) return null; + } + return cls; + } } diff --git a/tests/data/unpick/const/def.unpick b/tests/data/unpick/const/def.unpick index 57ec0c9..3a4150a 100644 --- a/tests/data/unpick/const/def.unpick +++ b/tests/data/unpick/const/def.unpick @@ -2,6 +2,7 @@ unpick v3 group String com.example.Constants.VERSION + com.example.Constants$Inner1$Inner2.VALUE group float @strict diff --git a/tests/data/unpick/const/expected/com/example/Constants.java b/tests/data/unpick/const/expected/com/example/Constants.java index 4110cac..1180a02 100644 --- a/tests/data/unpick/const/expected/com/example/Constants.java +++ b/tests/data/unpick/const/expected/com/example/Constants.java @@ -6,4 +6,10 @@ public class Constants { public static final float FLOAT_CT = 2.5; public static final long LONG_VAL = 34L; + + public static class Inner1 { + public static class Inner2 { + public static final String VALUE = "value"; + } + } } diff --git a/tests/data/unpick/const/expected/com/stuff/Uses.java b/tests/data/unpick/const/expected/com/stuff/Uses.java index 5614a9b..729ee20 100644 --- a/tests/data/unpick/const/expected/com/stuff/Uses.java +++ b/tests/data/unpick/const/expected/com/stuff/Uses.java @@ -1,21 +1,22 @@ package com.stuff; import com.example.Constants; +import com.example.Constants.Inner1.Inner2; public class Uses { public String fld = Constants.VERSION; void run() { - String s = Constants.VERSION + "2"; + String s = Constants.VERSION + "2" + Inner2.VALUE; float f = Math.PI; - f = Math.PI / 3; + f = (Math.PI / 3); double d = 3.141592653589793d; // PI unpick is strict float so this should not be replaced d = Constants.FLOAT_CT; // but the other float unpick isn't so this double literal should be replaced - System.out.println(Long.toHexString((Constants.LONG_VAL + 1) * 2)); + System.out.println(Long.toHexString(((Constants.LONG_VAL + 1) * 2))); } } diff --git a/tests/data/unpick/const/source/com/example/Constants.java b/tests/data/unpick/const/source/com/example/Constants.java index 4110cac..1180a02 100644 --- a/tests/data/unpick/const/source/com/example/Constants.java +++ b/tests/data/unpick/const/source/com/example/Constants.java @@ -6,4 +6,10 @@ public class Constants { public static final float FLOAT_CT = 2.5; public static final long LONG_VAL = 34L; + + public static class Inner1 { + public static class Inner2 { + public static final String VALUE = "value"; + } + } } diff --git a/tests/data/unpick/const/source/com/stuff/Uses.java b/tests/data/unpick/const/source/com/stuff/Uses.java index 6a61a1f..9d8f5bf 100644 --- a/tests/data/unpick/const/source/com/stuff/Uses.java +++ b/tests/data/unpick/const/source/com/stuff/Uses.java @@ -4,7 +4,7 @@ public class Uses { public String fld = "1.21.4"; void run() { - String s = "1.21.4" + "2"; + String s = "1.21.4" + "2" + "value"; float f = 3.141592653589793f; diff --git a/tests/data/unpick/flags/def.unpick b/tests/data/unpick/flags/def.unpick index febf0ad..742f618 100644 --- a/tests/data/unpick/flags/def.unpick +++ b/tests/data/unpick/flags/def.unpick @@ -2,10 +2,21 @@ unpick v3 group int Flags @flags + com.example.Example.FLAG_4 * 21 com.example.Example.FLAG_1 com.example.Example.FLAG_2 com.example.Example.FLAG_3 com.example.Example.FLAG_4 +group int FlagsHex + @flags + @format hex + com.example.Example.FLAG_1 + com.example.Example.FLAG_3 + com.example.Example.FLAG_4 + target_method com.example.Example applyFlags(I)V param 0 Flags + +target_method com.example.Example applyHexFlags(I)V + param 0 FlagsHex diff --git a/tests/data/unpick/flags/expected/com/example/Example.java b/tests/data/unpick/flags/expected/com/example/Example.java index c403eaa..910b9f6 100644 --- a/tests/data/unpick/flags/expected/com/example/Example.java +++ b/tests/data/unpick/flags/expected/com/example/Example.java @@ -12,7 +12,14 @@ public static void main(String[] args) { applyFlags(Example.FLAG_1 | Example.FLAG_2); applyFlags(Example.FLAG_3 | Example.FLAG_4 | Example.FLAG_1 | 129); applyFlags(-1); + + applyFlags(~(Example.FLAG_1 | Example.FLAG_2)); + applyFlags(~Example.FLAG_4); + applyFlags((Example.FLAG_4 * 21) | Example.FLAG_3); + + applyHexFlags(Example.FLAG_3 | Example.FLAG_4 | Example.FLAG_1 | 0x81); } public static void applyFlags(int flags) {} + public static void applyHexFlags(int flags) {} } diff --git a/tests/data/unpick/flags/source/com/example/Example.java b/tests/data/unpick/flags/source/com/example/Example.java index 9222826..01c5db7 100644 --- a/tests/data/unpick/flags/source/com/example/Example.java +++ b/tests/data/unpick/flags/source/com/example/Example.java @@ -12,7 +12,14 @@ public static void main(String[] args) { applyFlags(6); applyFlags(155); applyFlags(-1); + + applyFlags(-7); + applyFlags(-17); + applyFlags(344); + + applyHexFlags(155); } public static void applyFlags(int flags) {} + public static void applyHexFlags(int flags) {} } diff --git a/tests/data/unpick/string_concat/def.unpick b/tests/data/unpick/string_concat/def.unpick new file mode 100644 index 0000000..d8b6a89 --- /dev/null +++ b/tests/data/unpick/string_concat/def.unpick @@ -0,0 +1,6 @@ +unpick v3 + +group String + com.example.Example.I1 + com.example.Example.S + (com.example.Example._S + com.example.Example.D1) + com.example.Example.I1 + com.example.Example.S + com.example.Example._S diff --git a/tests/data/unpick/string_concat/expected/com/example/Example.java b/tests/data/unpick/string_concat/expected/com/example/Example.java new file mode 100644 index 0000000..50d7733 --- /dev/null +++ b/tests/data/unpick/string_concat/expected/com/example/Example.java @@ -0,0 +1,13 @@ +package com.example; + +public class Example { + public static final int I1 = 12; + public static final double D1 = 4.53; + + public static final String S = "ab"; + public static final String _S = "ba"; + + public static void main(String[] args) { + String a = Example.I1 + Example.S, b = (Example._S + Example.D1) + Example.I1, c = Example.S + Example._S; + } +} diff --git a/tests/data/unpick/string_concat/source/com/example/Example.java b/tests/data/unpick/string_concat/source/com/example/Example.java new file mode 100644 index 0000000..7385f7d --- /dev/null +++ b/tests/data/unpick/string_concat/source/com/example/Example.java @@ -0,0 +1,13 @@ +package com.example; + +public class Example { + public static final int I1 = 12; + public static final double D1 = 4.53; + + public static final String S = "ab"; + public static final String _S = "ba"; + + public static void main(String[] args) { + String a = "12ab", b = "ba4.5312", c = "abba"; + } +} diff --git a/tests/data/unpick/strings/def.unpick b/tests/data/unpick/strings/def.unpick new file mode 100644 index 0000000..cf87924 --- /dev/null +++ b/tests/data/unpick/strings/def.unpick @@ -0,0 +1,6 @@ +unpick v3 + +group String + com.example.Example.S + com.example.Example.S2 + com.example.Example.S + com.example.Example.S + "\" adf" diff --git a/tests/data/unpick/strings/expected/com/example/Example.java b/tests/data/unpick/strings/expected/com/example/Example.java new file mode 100644 index 0000000..843ac95 --- /dev/null +++ b/tests/data/unpick/strings/expected/com/example/Example.java @@ -0,0 +1,10 @@ +package com.example; + +public class Example { + public static final String S = "a\nb\"", S2 = "ab\\"; + + public static void main(String[] args) { + String s = Example.S, s1 = Example.S2 + Example.S, s2 = Example.S + "\" adf"; + String noUnpick = "a\\nb\\\""; + } +} diff --git a/tests/data/unpick/strings/source/com/example/Example.java b/tests/data/unpick/strings/source/com/example/Example.java new file mode 100644 index 0000000..9ed40c1 --- /dev/null +++ b/tests/data/unpick/strings/source/com/example/Example.java @@ -0,0 +1,10 @@ +package com.example; + +public class Example { + public static final String S = "a\nb\"", S2 = "ab\\"; + + public static void main(String[] args) { + String s = "a\nb\"", s1 = "ab\\a\nb\"", s2 = "a\nb\"\" adf"; + String noUnpick = "a\\nb\\\""; + } +} diff --git a/tests/src/test/java/net/neoforged/jst/tests/EmbeddedTest.java b/tests/src/test/java/net/neoforged/jst/tests/EmbeddedTest.java index ca47ff6..34372f8 100644 --- a/tests/src/test/java/net/neoforged/jst/tests/EmbeddedTest.java +++ b/tests/src/test/java/net/neoforged/jst/tests/EmbeddedTest.java @@ -393,6 +393,16 @@ void testFlags() throws Exception { void testStatements() throws Exception { runUnpickTest("statements"); } + + @Test + void testStringConcat() throws Exception { + runUnpickTest("string_concat"); + } + + @Test + void testStrings() throws Exception { + runUnpickTest("strings"); + } } protected final void runInterfaceInjectionTest(String testDirName, Path tempDir, String... additionalArgs) throws Exception { diff --git a/unpick/src/main/java/net/neoforged/jst/unpick/NumberType.java b/unpick/src/main/java/net/neoforged/jst/unpick/NumberType.java index a521629..a9bae8d 100644 --- a/unpick/src/main/java/net/neoforged/jst/unpick/NumberType.java +++ b/unpick/src/main/java/net/neoforged/jst/unpick/NumberType.java @@ -361,4 +361,18 @@ public Number rshift(Number a, Number b) { public Number rshiftUnsigned(Number a, Number b) { return a.intValue() >>> b.intValue(); } + + public boolean canWidenFrom(NumberType other) { + if (other == this) return true; + for (NumberType from : widenFrom) { + if (other == from || from.canWidenFrom(other)) { + return true; + } + } + return false; + } + + public static NumberType widest(NumberType a, NumberType b) { + return a.canWidenFrom(b) ? a : b; + } } diff --git a/unpick/src/main/java/net/neoforged/jst/unpick/UnpickCollection.java b/unpick/src/main/java/net/neoforged/jst/unpick/UnpickCollection.java index b68ba64..a445bd1 100644 --- a/unpick/src/main/java/net/neoforged/jst/unpick/UnpickCollection.java +++ b/unpick/src/main/java/net/neoforged/jst/unpick/UnpickCollection.java @@ -12,6 +12,7 @@ import com.intellij.psi.PsiType; import com.intellij.psi.PsiTypes; import com.intellij.psi.search.GlobalSearchScope; +import com.intellij.psi.util.ClassUtil; import com.intellij.util.containers.MultiMap; import daomephsta.unpick.constantmappers.datadriven.tree.DataType; import daomephsta.unpick.constantmappers.datadriven.tree.GroupDefinition; @@ -40,7 +41,6 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.function.Predicate; public class UnpickCollection { private static final Key> UNPICK_DEFINITION = Key.create("unpick.method_definition"); @@ -90,7 +90,7 @@ public UnpickCollection(TransformContext context, Map byPackage.putValue(packageName, gr); case GroupScope.Class(var cls) -> byClass.putValue(cls, gr); case GroupScope.Method(var className, var method, var desc) -> { - var cls = facade.findClass(className, projectScope); + var cls = PsiHelper.findClass(facade, className, projectScope); if (cls == null) return; for (PsiMethod clsMethod : cls.getMethods()) { @@ -106,7 +106,7 @@ public UnpickCollection(TransformContext context, Map getClassContext(PsiClass cls) { - var clsName = cls.getQualifiedName(); + var clsName = ClassUtil.getJVMClassName(cls); if (clsName != null) { return byClass.get(clsName); } @@ -217,7 +217,7 @@ public static Group create(GroupDefinition def, JavaPsiFacade facade, GlobalSear private static Object resolveConstant(Expression expression, JavaPsiFacade facade, GlobalSearchScope scope) { if (expression instanceof FieldExpression fieldEx) { - var clazz = facade.findClass(fieldEx.className, scope); + var clazz = PsiHelper.findClass(facade, fieldEx.className, scope); if (clazz != null) { for (PsiField field : clazz.getAllFields()) { if (fieldEx.isStatic != field.hasModifier(JvmModifier.STATIC)) continue; @@ -252,7 +252,7 @@ private static Object resolveConstant(Expression expression, JavaPsiFacade facad var rhs = resolveConstant(binaryExpression.rhs, facade, scope); if (lhs instanceof Number l && rhs instanceof Number r) { - var type = NumberType.TYPES.get(l.getClass()); + var type = NumberType.widest(NumberType.TYPES.get(l.getClass()), NumberType.TYPES.get(r.getClass())); return switch (binaryExpression.operator) { case ADD -> type.add(l, r); case DIVIDE -> type.divide(l, r); @@ -270,8 +270,8 @@ private static Object resolveConstant(Expression expression, JavaPsiFacade facad }; } - if (lhs instanceof String lS && rhs instanceof String rS && binaryExpression.operator == BinaryExpression.Operator.ADD) { - return lS + rS; + if ((lhs instanceof String || rhs instanceof String) && binaryExpression.operator == BinaryExpression.Operator.ADD) { + return lhs.toString() + rhs.toString(); } throw new IllegalArgumentException("Cannot resolve expression: " + binaryExpression + ". Operands of type " + lhs.getClass() + " and " + rhs.getClass() + " do not support operator " + binaryExpression.operator); @@ -283,7 +283,7 @@ private static Object resolveConstant(Expression expression, JavaPsiFacade facad private static Object cast(Object in, DataType type) { return switch (type) { case BYTE -> ((Number) in).byteValue(); - case CHAR -> Character.valueOf((char)((Number) in).byteValue()); + case CHAR -> Character.valueOf((char)((Number) in).shortValue()); case SHORT -> ((Number) in).shortValue(); case INT -> ((Number) in).intValue(); case LONG -> ((Number) in).longValue(); diff --git a/unpick/src/main/java/net/neoforged/jst/unpick/UnpickVisitor.java b/unpick/src/main/java/net/neoforged/jst/unpick/UnpickVisitor.java index 1225cd5..12aec89 100644 --- a/unpick/src/main/java/net/neoforged/jst/unpick/UnpickVisitor.java +++ b/unpick/src/main/java/net/neoforged/jst/unpick/UnpickVisitor.java @@ -1,6 +1,7 @@ package net.neoforged.jst.unpick; import com.intellij.openapi.util.Key; +import com.intellij.openapi.util.text.StringUtil; import com.intellij.psi.JavaTokenType; import com.intellij.psi.PsiAssignmentExpression; import com.intellij.psi.PsiCallExpression; @@ -21,6 +22,7 @@ import com.intellij.psi.PsiReferenceExpression; import com.intellij.psi.PsiReturnStatement; import com.intellij.psi.PsiVariable; +import com.intellij.psi.util.ClassUtil; import daomephsta.unpick.constantmappers.datadriven.tree.GroupFormat; import daomephsta.unpick.constantmappers.datadriven.tree.Literal; import daomephsta.unpick.constantmappers.datadriven.tree.TargetMethod; @@ -244,10 +246,10 @@ private void visitToken(PsiJavaToken tok) { if (Boolean.TRUE.equals(tok.getUserData(UNPICK_WAS_REPLACED))) return; if (tok.getTokenType() == JavaTokenType.STRING_LITERAL) { - var val = tok.getText().substring(1); // Remove starting " - final var finalVal = val.substring(0, val.length() - 1); // Remove leading " + // Remove starting and leading " and unescape characters to turn from the source representation to the runtime one + final String value = StringUtil.unescapeStringCharacters(StringUtil.unquoteString(tok.getText())); forInScope(group -> { - var ct = group.constants().get(finalVal); + var ct = group.constants().get(value); if (ct != null && checkNotRecursive(ct)) { replacements.replace(tok, write(ct)); tok.putUserData(UNPICK_WAS_REPLACED, true); @@ -306,7 +308,7 @@ private boolean replaceLiteral(PsiJavaToken element, Number number, NumberType t // We'll try a direct constant first, even if it's a flag var ct = group.constants().get(number); if (ct != null && checkNotRecursive(ct)) { - replacements.replace(element, write(ct)); + replacements.replace(element, writeOptionallySurround(ct)); element.putUserData(UNPICK_WAS_REPLACED, true); replaceMinus(element); return true; @@ -337,7 +339,7 @@ private boolean replaceLiteral(PsiJavaToken element, Number number, NumberType t // Finally we try to apply non-strict widening from lower number types for (NumberType from : type.widenFrom) { var lower = from.cast(number); - if (lower.doubleValue() == number.doubleValue()) { + if (lower.doubleValue() == number.doubleValue() && lower.longValue() == number.longValue()) { if (replaceLiteral(element, lower, from, true)) { return true; } @@ -360,7 +362,7 @@ private void replaceMinus(PsiJavaToken tok) { private boolean checkNotRecursive(Expression expression) { if (fieldContext != null && fieldContext.getContainingClass() != null && expression instanceof FieldExpression fld) { - return !(fld.className.equals(fieldContext.getContainingClass().getQualifiedName()) && Objects.equals(fld.fieldName, fieldContext.getName())); + return !(fld.className.equals(ClassUtil.getJVMClassName(fieldContext.getContainingClass())) && Objects.equals(fld.fieldName, fieldContext.getName())); } return true; } @@ -400,12 +402,22 @@ private boolean forInScope(Predicate apply) { return false; } + /** + * Write the given expression, optionally surrounding it with parenthesis if + * it cannot be safely assumed not to need them. + *

+ * Only literals or top-level fields will not be surrounded. + */ + private String writeOptionallySurround(Expression expression) { + return (expression instanceof LiteralExpression || expression instanceof FieldExpression) ? write(expression) : "(" + write(expression) + ")"; + } + private String write(Expression expression) { StringBuilder s = new StringBuilder(); expression.accept(new ExpressionVisitor() { @Override public void visitFieldExpression(FieldExpression fieldExpression) { - var cls = imports().importClass(fieldExpression.className); + var cls = imports().importClass(fieldExpression.className.replace("$", ".")); s.append(cls).append('.').append(fieldExpression.fieldName); } @@ -419,7 +431,7 @@ public void visitParenExpression(ParenExpression parenExpression) { @Override public void visitLiteralExpression(LiteralExpression literalExpression) { if (literalExpression.literal instanceof Literal.String(String value)) { - s.append('\"').append(value.replace("\"", "\\\"")).append('\"'); + s.append('\"').append(StringUtil.escapeStringCharacters(value)).append('\"'); } else if (literalExpression.literal instanceof Literal.Integer i) { s.append(i.value()); } else if (literalExpression.literal instanceof Literal.Long l) { @@ -503,7 +515,11 @@ private String formatAs(Number value, GroupFormat format) { } case CHAR -> "'" + ((char) value.intValue()) + "'"; - default -> value.toString(); + default -> switch (value) { + case Long l -> l + "l"; + case Float f -> f + "f"; + default -> value.toString(); + }; }; } @@ -526,18 +542,29 @@ private String generateFlag(UnpickCollection.Group group, long val, NumberType t long residual = negated ? negatedResidual : orResidual; - StringBuilder replacement = new StringBuilder(write(constants.getFirst())); + boolean paren = false; + + StringBuilder replacement = new StringBuilder(writeOptionallySurround(constants.getFirst())); for (int i = 1; i < constants.size(); i++) { replacement.append(" | "); - replacement.append(write(constants.get(i))); + replacement.append(writeOptionallySurround(constants.get(i))); + paren = true; } if (residual != 0) { - replacement.append(" | ").append(residual); + boolean isLong = residual < Integer.MIN_VALUE || residual > Integer.MAX_VALUE; + + replacement.append(" | "); + + // The formatAs method appends l automatically to any long value + // so if it's an int we downcast it to avoid it + replacement.append(formatAs(isLong ? (Number) residual : (Number) (int) residual, Objects.requireNonNullElse(group.format(), GroupFormat.DECIMAL))); + + paren = true; } if (negated) { - return "~" + replacement; + return "~" + (paren ? ("(" + replacement + ")") : replacement); } return replacement.toString();