diff --git a/README.md b/README.md index 2a01f1f..62d68cc 100644 --- a/README.md +++ b/README.md @@ -53,24 +53,27 @@ To create the executable jar with your custom transformer, you should shadow the Note that this tool is not intended to be run by users directly. Rather it is integrated into the [NeoGradle](https://github.com/neoforged/NeoGradle) build process. -It can be invoked as a standalone executable Jar-File. Java 17 is required. +It can be invoked as a standalone executable Jar-File. Java 21 is required. ``` -Usage: jst [-hV] [--in-format=] [--libraries-list=] +Usage: jst [-hV] [--debug] [--in-format=] [--libraries-list=] [--max-queue-depth=] [--out-format=] - [--classpath=]... [--ignore-prefix=]... - [--enable-parchment --parchment-mappings= [--[no-]parchment-javadoc] + [--problems-report=] [--classpath=]... + [--ignore-prefix=]... [--enable-parchment + --parchment-mappings= [--[no-]parchment-javadoc] [--parchment-conflict-prefix=]] [--enable-accesstransformers --access-transformer= [--access-transformer=]... [--access-transformer-validation=]] [--enable-interface-injection [--interface-injection-stubs=] [--interface-injection-marker=] - [--interface-injection-data=]...] INPUT OUTPUT + [--interface-injection-data=]...] [--enable-unpick [--unpick-data=]...] + INPUT OUTPUT INPUT Path to a single Java-file, a source-archive or a folder containing the source to transform. OUTPUT Path to where the resulting source should be placed. --classpath= Additional classpath entries to use. Is combined with --libraries-list. + --debug Print additional debugging information -h, --help Show this help message and exit. --ignore-prefix= Do not apply transformations to paths that start with any of these @@ -89,6 +92,8 @@ Usage: jst [-hV] [--in-format=] [--libraries-list=] --out-format= Specify the format of OUTPUT explicitly. Allows the same options as --in-format. + --problems-report= + Write problems to this report file. -V, --version Print version information and exit. Plugin - parchment --enable-parchment Enable parchment @@ -116,6 +121,10 @@ Plugin - interface-injection injected interfaces --interface-injection-stubs= The path to a zip to save interface stubs in +Plugin - unpick + --enable-unpick Enable unpick + --unpick-data= + The paths to read unpick definition files from ``` ## Licenses diff --git a/api/src/main/java/net/neoforged/jst/api/PsiHelper.java b/api/src/main/java/net/neoforged/jst/api/PsiHelper.java index 3e864c3..35bdfc9 100644 --- a/api/src/main/java/net/neoforged/jst/api/PsiHelper.java +++ b/api/src/main/java/net/neoforged/jst/api/PsiHelper.java @@ -2,11 +2,14 @@ import com.intellij.lang.jvm.types.JvmPrimitiveTypeKind; import com.intellij.psi.PsiClass; +import com.intellij.psi.PsiElement; +import com.intellij.psi.PsiExpression; import com.intellij.psi.PsiMethod; import com.intellij.psi.PsiModifier; import com.intellij.psi.PsiParameter; import com.intellij.psi.PsiParameterListOwner; import com.intellij.psi.PsiPrimitiveType; +import com.intellij.psi.PsiReferenceExpression; import com.intellij.psi.PsiTypeParameter; import com.intellij.psi.PsiTypes; import com.intellij.psi.PsiWhiteSpace; @@ -17,6 +20,7 @@ import com.intellij.psi.util.PsiTreeUtil; import com.intellij.util.containers.ObjectIntHashMap; import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; import java.util.ArrayList; import java.util.Iterator; @@ -242,4 +246,13 @@ public static int getLastLineLength(PsiWhiteSpace psiWhiteSpace) { return psiWhiteSpace.getTextLength(); } } + + @Nullable + public static PsiElement resolve(PsiReferenceExpression expression) { + try { + return expression.resolve(); + } catch (Exception ignored) { + return null; + } + } } diff --git a/build.gradle b/build.gradle index 1d0102a..8f75a38 100644 --- a/build.gradle +++ b/build.gradle @@ -23,7 +23,7 @@ subprojects { java { toolchain { - languageVersion = JavaLanguageVersion.of(17) + languageVersion = JavaLanguageVersion.of(21) } } diff --git a/cli/build.gradle b/cli/build.gradle index a59db0d..efb6d92 100644 --- a/cli/build.gradle +++ b/cli/build.gradle @@ -34,6 +34,7 @@ dependencies { include project(":parchment") include project(":accesstransformers") include project(':interfaceinjection') + include project(':unpick') testImplementation platform("org.junit:junit-bom:$junit_version") testImplementation 'org.junit.jupiter:junit-jupiter' diff --git a/cli/src/main/java/net/neoforged/jst/cli/io/FolderFileSink.java b/cli/src/main/java/net/neoforged/jst/cli/io/FolderFileSink.java index 054b93a..50e5d49 100644 --- a/cli/src/main/java/net/neoforged/jst/cli/io/FolderFileSink.java +++ b/cli/src/main/java/net/neoforged/jst/cli/io/FolderFileSink.java @@ -18,8 +18,9 @@ public void putDirectory(String relativePath) throws IOException { public void putFile(String relativePath, FileTime lastModified, byte[] content) throws IOException { var targetPath = path.resolve(relativePath); - if (targetPath.getParent() != null && !Files.isDirectory(targetPath.getParent())) + if (targetPath.getParent() != null && !Files.isDirectory(targetPath.getParent())) { Files.createDirectories(targetPath.getParent()); + } Files.write(targetPath, content); Files.setLastModifiedTime(targetPath, lastModified); diff --git a/settings.gradle b/settings.gradle index 2bc592c..8c5fa61 100644 --- a/settings.gradle +++ b/settings.gradle @@ -41,3 +41,4 @@ include 'parchment' include 'tests' include 'accesstransformers' include 'interfaceinjection' +include 'unpick' diff --git a/tests/data/unpick/const/def.unpick b/tests/data/unpick/const/def.unpick new file mode 100644 index 0000000..57ec0c9 --- /dev/null +++ b/tests/data/unpick/const/def.unpick @@ -0,0 +1,16 @@ +unpick v3 + +group String + com.example.Constants.VERSION + +group float + @strict + java.lang.Math.PI + java.lang.Math.PI / 3 + +group float + com.example.Constants.FLOAT_CT + +group long + com.example.Constants.LONG_VAL + (com.example.Constants.LONG_VAL + 1) * 2 diff --git a/tests/data/unpick/const/expected/com/example/Constants.java b/tests/data/unpick/const/expected/com/example/Constants.java new file mode 100644 index 0000000..4110cac --- /dev/null +++ b/tests/data/unpick/const/expected/com/example/Constants.java @@ -0,0 +1,9 @@ +package com.example; + +public class Constants { + public static final String VERSION = "1.21.4"; + + public static final float FLOAT_CT = 2.5; + + public static final long LONG_VAL = 34L; +} diff --git a/tests/data/unpick/const/expected/com/stuff/Uses.java b/tests/data/unpick/const/expected/com/stuff/Uses.java new file mode 100644 index 0000000..5614a9b --- /dev/null +++ b/tests/data/unpick/const/expected/com/stuff/Uses.java @@ -0,0 +1,21 @@ +package com.stuff; + +import com.example.Constants; + +public class Uses { + public String fld = Constants.VERSION; + + void run() { + String s = Constants.VERSION + "2"; + + float f = Math.PI; + + f = Math.PI / 3; + + double d = 3.141592653589793d; // PI unpick is strict float so this should not be replaced + + d = Constants.FLOAT_CT; // but the other float unpick isn't so this double literal should be replaced + + System.out.println(Long.toHexString((Constants.LONG_VAL + 1) * 2)); + } +} diff --git a/tests/data/unpick/const/source/com/example/Constants.java b/tests/data/unpick/const/source/com/example/Constants.java new file mode 100644 index 0000000..4110cac --- /dev/null +++ b/tests/data/unpick/const/source/com/example/Constants.java @@ -0,0 +1,9 @@ +package com.example; + +public class Constants { + public static final String VERSION = "1.21.4"; + + public static final float FLOAT_CT = 2.5; + + public static final long LONG_VAL = 34L; +} diff --git a/tests/data/unpick/const/source/com/stuff/Uses.java b/tests/data/unpick/const/source/com/stuff/Uses.java new file mode 100644 index 0000000..6a61a1f --- /dev/null +++ b/tests/data/unpick/const/source/com/stuff/Uses.java @@ -0,0 +1,19 @@ +package com.stuff; + +public class Uses { + public String fld = "1.21.4"; + + void run() { + String s = "1.21.4" + "2"; + + float f = 3.141592653589793f; + + f = 1.0471975511965976f; + + double d = 3.141592653589793d; // PI unpick is strict float so this should not be replaced + + d = 2.5d; // but the other float unpick isn't so this double literal should be replaced + + System.out.println(Long.toHexString(70L)); + } +} diff --git a/tests/data/unpick/flags/def.unpick b/tests/data/unpick/flags/def.unpick new file mode 100644 index 0000000..febf0ad --- /dev/null +++ b/tests/data/unpick/flags/def.unpick @@ -0,0 +1,11 @@ +unpick v3 + +group int Flags + @flags + com.example.Example.FLAG_1 + com.example.Example.FLAG_2 + com.example.Example.FLAG_3 + com.example.Example.FLAG_4 + +target_method com.example.Example applyFlags(I)V + param 0 Flags diff --git a/tests/data/unpick/flags/expected/com/example/Example.java b/tests/data/unpick/flags/expected/com/example/Example.java new file mode 100644 index 0000000..c403eaa --- /dev/null +++ b/tests/data/unpick/flags/expected/com/example/Example.java @@ -0,0 +1,18 @@ +package com.example; + +public class Example { + public static final int + FLAG_1 = 2, + FLAG_2 = 4, + FLAG_3 = 8, + FLAG_4 = 16; + + public static void main(String[] args) { + applyFlags(Example.FLAG_1); + applyFlags(Example.FLAG_1 | Example.FLAG_2); + applyFlags(Example.FLAG_3 | Example.FLAG_4 | Example.FLAG_1 | 129); + applyFlags(-1); + } + + public static void applyFlags(int flags) {} +} diff --git a/tests/data/unpick/flags/source/com/example/Example.java b/tests/data/unpick/flags/source/com/example/Example.java new file mode 100644 index 0000000..9222826 --- /dev/null +++ b/tests/data/unpick/flags/source/com/example/Example.java @@ -0,0 +1,18 @@ +package com.example; + +public class Example { + public static final int + FLAG_1 = 2, + FLAG_2 = 4, + FLAG_3 = 8, + FLAG_4 = 16; + + public static void main(String[] args) { + applyFlags(2); + applyFlags(6); + applyFlags(155); + applyFlags(-1); + } + + public static void applyFlags(int flags) {} +} diff --git a/tests/data/unpick/formats/def.unpick b/tests/data/unpick/formats/def.unpick new file mode 100644 index 0000000..d385a4d --- /dev/null +++ b/tests/data/unpick/formats/def.unpick @@ -0,0 +1,21 @@ +unpick v3 + +group int HEXInt + @format hex +target_method com.example.Example acceptHex(I)V + param 0 HEXInt + +group int BINInt + @format binary +target_method com.example.Example acceptBin(I)V + param 0 BINInt + +group int OCTInt + @format octal +target_method com.example.Example acceptOct(I)V + param 0 OCTInt + +group int CharInt + @format char +target_method com.example.Example acceptChar(C)V + param 0 CharInt diff --git a/tests/data/unpick/formats/expected/com/example/Example.java b/tests/data/unpick/formats/expected/com/example/Example.java new file mode 100644 index 0000000..6f935ce --- /dev/null +++ b/tests/data/unpick/formats/expected/com/example/Example.java @@ -0,0 +1,17 @@ +package com.example; + +public class Example { + + void execute() { + acceptHex(0xA505); + acceptBin(0b1010100111010110000); + acceptOct(017350); + acceptChar('d'); + } + + void acceptHex(int hex) {} + void acceptBin(int b) {} + void acceptOct(int oct) {} + + void acceptChar(char c) {} +} diff --git a/tests/data/unpick/formats/source/com/example/Example.java b/tests/data/unpick/formats/source/com/example/Example.java new file mode 100644 index 0000000..fd6484e --- /dev/null +++ b/tests/data/unpick/formats/source/com/example/Example.java @@ -0,0 +1,17 @@ +package com.example; + +public class Example { + + void execute() { + acceptHex(42245); + acceptBin(347824); + acceptOct(7912); + acceptChar(100); + } + + void acceptHex(int hex) {} + void acceptBin(int b) {} + void acceptOct(int oct) {} + + void acceptChar(char c) {} +} diff --git a/tests/data/unpick/local_variables/def.unpick b/tests/data/unpick/local_variables/def.unpick new file mode 100644 index 0000000..7b17ee5 --- /dev/null +++ b/tests/data/unpick/local_variables/def.unpick @@ -0,0 +1,10 @@ +unpick v3 + +group int Color + @format hex + com.example.Example.RED + com.example.Example.PURPLE + com.example.Example.PINK + +target_method com.example.Example setColor(I)V + param 0 Color diff --git a/tests/data/unpick/local_variables/expected/com/example/Example.java b/tests/data/unpick/local_variables/expected/com/example/Example.java new file mode 100644 index 0000000..4cb8f37 --- /dev/null +++ b/tests/data/unpick/local_variables/expected/com/example/Example.java @@ -0,0 +1,21 @@ +package com.example; + +public class Example { + public static final int + RED = 0xFF0000, + PURPLE = 0x800080, + PINK = 0xFFC0CB; + + public static void acceptColor(int in) { + int color = 0xD7837F; + if (in < 0) { + color = Example.PURPLE; + } else { + color = in == 0x0 ? Example.RED : Example.PINK; + } + + setColor(color); + } + + public static void setColor(int color) {} +} diff --git a/tests/data/unpick/local_variables/source/com/example/Example.java b/tests/data/unpick/local_variables/source/com/example/Example.java new file mode 100644 index 0000000..8413564 --- /dev/null +++ b/tests/data/unpick/local_variables/source/com/example/Example.java @@ -0,0 +1,21 @@ +package com.example; + +public class Example { + public static final int + RED = 0xFF0000, + PURPLE = 0x800080, + PINK = 0xFFC0CB; + + public static void acceptColor(int in) { + int color = 14123903; + if (in < 0) { + color = 8388736; + } else { + color = in == 0 ? 16711680 : 16761035; + } + + setColor(color); + } + + public static void setColor(int color) {} +} diff --git a/tests/data/unpick/returns/def.unpick b/tests/data/unpick/returns/def.unpick new file mode 100644 index 0000000..fcc4cff --- /dev/null +++ b/tests/data/unpick/returns/def.unpick @@ -0,0 +1,9 @@ +unpick v3 + +group int Constant + com.example.Example.ONE + com.example.Example.TWO + com.example.Example.FOUR + +target_method com.example.Example getNumber(ZZ)I + return Constant diff --git a/tests/data/unpick/returns/expected/com/example/Example.java b/tests/data/unpick/returns/expected/com/example/Example.java new file mode 100644 index 0000000..23d63ca --- /dev/null +++ b/tests/data/unpick/returns/expected/com/example/Example.java @@ -0,0 +1,21 @@ +package com.example; + +public class Example { + public static final int + ONE = 1, + TWO = 2, + FOUR = 4; + + public static int getNumber(boolean odd, boolean b) { + int value = 0; + if (odd) { + value = Example.ONE; + } else if (b) { + return Example.FOUR; + } else { + value = Example.TWO; + } + + return value; + } +} diff --git a/tests/data/unpick/returns/source/com/example/Example.java b/tests/data/unpick/returns/source/com/example/Example.java new file mode 100644 index 0000000..8ecd9a4 --- /dev/null +++ b/tests/data/unpick/returns/source/com/example/Example.java @@ -0,0 +1,21 @@ +package com.example; + +public class Example { + public static final int + ONE = 1, + TWO = 2, + FOUR = 4; + + public static int getNumber(boolean odd, boolean b) { + int value = 0; + if (odd) { + value = 1; + } else if (b) { + return 4; + } else { + value = 2; + } + + return value; + } +} diff --git a/tests/data/unpick/scoped/def.unpick b/tests/data/unpick/scoped/def.unpick new file mode 100644 index 0000000..b2ee490 --- /dev/null +++ b/tests/data/unpick/scoped/def.unpick @@ -0,0 +1,14 @@ +unpick v3 + +group int + @scope class com.example.Example + com.example.Example.V1 + com.example.Example.V2 + +group int + @scope method com.Outsider anotherExecute ()V + com.Outsider.DIFFERENT_CONST + +group int + @scope package com.example + com.example.Example.FOUR diff --git a/tests/data/unpick/scoped/expected/com/Outsider.java b/tests/data/unpick/scoped/expected/com/Outsider.java new file mode 100644 index 0000000..847befd --- /dev/null +++ b/tests/data/unpick/scoped/expected/com/Outsider.java @@ -0,0 +1,17 @@ +package com; + +public class Outsider { + private static final int DIFFERENT_CONST = 12; + + public void execute() { + int i = 472; // This should NOT be unpicked to Example.V1 + + int j = 12; // This should NOT be unpicked to DIFFERENT_CONST + + int k = 4; // This should NOT be unpicked to Example.FOUR since it's outside the package + } + + public void anotherExecute() { + int i = Outsider.DIFFERENT_CONST; // This should be replaced with DIFFERENT_CONST + } +} diff --git a/tests/data/unpick/scoped/expected/com/example/Example.java b/tests/data/unpick/scoped/expected/com/example/Example.java new file mode 100644 index 0000000..a94273b --- /dev/null +++ b/tests/data/unpick/scoped/expected/com/example/Example.java @@ -0,0 +1,12 @@ +package com.example; + +public class Example { + private static final int V1 = 472, V2 = 84; + static final int FOUR = 4; + + void execute() { + System.out.println(Example.V1); + + System.out.println(Example.V2); + } +} diff --git a/tests/data/unpick/scoped/expected/com/example/Example2.java b/tests/data/unpick/scoped/expected/com/example/Example2.java new file mode 100644 index 0000000..d091542 --- /dev/null +++ b/tests/data/unpick/scoped/expected/com/example/Example2.java @@ -0,0 +1,5 @@ +package com.example; + +public class Example2 { + private final int fld = Example.FOUR; +} diff --git a/tests/data/unpick/scoped/source/com/Outsider.java b/tests/data/unpick/scoped/source/com/Outsider.java new file mode 100644 index 0000000..c6fa2d7 --- /dev/null +++ b/tests/data/unpick/scoped/source/com/Outsider.java @@ -0,0 +1,17 @@ +package com; + +public class Outsider { + private static final int DIFFERENT_CONST = 12; + + public void execute() { + int i = 472; // This should NOT be unpicked to Example.V1 + + int j = 12; // This should NOT be unpicked to DIFFERENT_CONST + + int k = 4; // This should NOT be unpicked to Example.FOUR since it's outside the package + } + + public void anotherExecute() { + int i = 12; // This should be replaced with DIFFERENT_CONST + } +} diff --git a/tests/data/unpick/scoped/source/com/example/Example.java b/tests/data/unpick/scoped/source/com/example/Example.java new file mode 100644 index 0000000..19f68fe --- /dev/null +++ b/tests/data/unpick/scoped/source/com/example/Example.java @@ -0,0 +1,12 @@ +package com.example; + +public class Example { + private static final int V1 = 472, V2 = 84; + static final int FOUR = 4; + + void execute() { + System.out.println(472); + + System.out.println(84); + } +} diff --git a/tests/data/unpick/scoped/source/com/example/Example2.java b/tests/data/unpick/scoped/source/com/example/Example2.java new file mode 100644 index 0000000..5a63ce2 --- /dev/null +++ b/tests/data/unpick/scoped/source/com/example/Example2.java @@ -0,0 +1,5 @@ +package com.example; + +public class Example2 { + private final int fld = 4; +} diff --git a/tests/data/unpick/statements/def.unpick b/tests/data/unpick/statements/def.unpick new file mode 100644 index 0000000..cc15b23 --- /dev/null +++ b/tests/data/unpick/statements/def.unpick @@ -0,0 +1,11 @@ +unpick v3 + +group int Flags + @flags + com.example.Example.ONE + com.example.Example.TWO + com.example.Example.THREE + com.example.Example.FOUR + +target_method com.example.Example accept(I)V + param 0 Flags diff --git a/tests/data/unpick/statements/expected/com/example/Example.java b/tests/data/unpick/statements/expected/com/example/Example.java new file mode 100644 index 0000000..6c199f1 --- /dev/null +++ b/tests/data/unpick/statements/expected/com/example/Example.java @@ -0,0 +1,19 @@ +package com.example; + +public class Example { + public static final int + ONE = 1 << 0, + TWO = 1 << 1, + THREE = 1 << 2, + FOUR = 1 << 3; + + public static void run(int val) { + accept(((val & Example.FOUR) != 0) ? val | Example.THREE : val | Example.ONE); + + val = Example.TWO; + + accept(val); + } + + public static void accept(int value) {} +} diff --git a/tests/data/unpick/statements/source/com/example/Example.java b/tests/data/unpick/statements/source/com/example/Example.java new file mode 100644 index 0000000..b1bf000 --- /dev/null +++ b/tests/data/unpick/statements/source/com/example/Example.java @@ -0,0 +1,19 @@ +package com.example; + +public class Example { + public static final int + ONE = 1 << 0, + TWO = 1 << 1, + THREE = 1 << 2, + FOUR = 1 << 3; + + public static void run(int val) { + accept(((val & 8) != 0) ? val | 4 : val | 1); + + val = 2; + + accept(val); + } + + public static void accept(int value) {} +} diff --git a/tests/src/test/java/net/neoforged/jst/tests/EmbeddedTest.java b/tests/src/test/java/net/neoforged/jst/tests/EmbeddedTest.java index 8c1a4ab..ca47ff6 100644 --- a/tests/src/test/java/net/neoforged/jst/tests/EmbeddedTest.java +++ b/tests/src/test/java/net/neoforged/jst/tests/EmbeddedTest.java @@ -357,6 +357,44 @@ void testNestedGenericStubs() throws Exception { } } + @Nested + class Unpick { + @Test + void testConst() throws Exception { + runUnpickTest("const"); + } + + @Test + void testFormats() throws Exception { + runUnpickTest("formats"); + } + + @Test + void testScoped() throws Exception { + runUnpickTest("scoped"); + } + + @Test + void testLocalVariables() throws Exception { + runUnpickTest("local_variables"); + } + + @Test + void testReturns() throws Exception { + runUnpickTest("returns"); + } + + @Test + void testFlags() throws Exception { + runUnpickTest("flags"); + } + + @Test + void testStatements() throws Exception { + runUnpickTest("statements"); + } + } + protected final void runInterfaceInjectionTest(String testDirName, Path tempDir, String... additionalArgs) throws Exception { var stub = tempDir.resolve("jst-" + testDirName + "-stub.jar"); testDirName = "interfaceinjection/" + testDirName; @@ -373,6 +411,17 @@ protected final void runInterfaceInjectionTest(String testDirName, Path tempDir, } } + protected final void runUnpickTest(String testDirName, String... additionalArgs) throws Exception { + testDirName = "unpick/" + testDirName; + var testDir = testDataRoot.resolve(testDirName); + var inputPath = testDir.resolve("def.unpick"); + + var args = new ArrayList<>(Arrays.asList("--enable-unpick", "--unpick-data", inputPath.toString())); + args.addAll(Arrays.asList(additionalArgs)); + + runTest(testDirName, UnaryOperator.identity(), args.toArray(String[]::new)); + } + protected final void runATTest(String testDirName, final String... extraArgs) throws Exception { testDirName = "accesstransformer/" + testDirName; var atPath = testDataRoot.resolve(testDirName).resolve("accesstransformer.cfg"); diff --git a/unpick/build.gradle b/unpick/build.gradle new file mode 100644 index 0000000..9871314 --- /dev/null +++ b/unpick/build.gradle @@ -0,0 +1,8 @@ +plugins { + id 'java-library' +} + +dependencies { + implementation project(':api') + implementation 'net.fabricmc.unpick:unpick-format-utils:3.0.0-beta.8' +} diff --git a/unpick/src/main/java/net/neoforged/jst/unpick/NumberType.java b/unpick/src/main/java/net/neoforged/jst/unpick/NumberType.java new file mode 100644 index 0000000..a521629 --- /dev/null +++ b/unpick/src/main/java/net/neoforged/jst/unpick/NumberType.java @@ -0,0 +1,364 @@ +package net.neoforged.jst.unpick; + +import daomephsta.unpick.constantmappers.datadriven.tree.DataType; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Class that is used to compute mathematical operations between number types while operating on {@link Number}s. + *

+ * Each operation has a method that is overridden as needed by each number type to ensure that + * special behaviour is preserved (e.g. we cannot add two bytes as longs and then attempt to cast back + * to byte because we need to make sure that the addition overflows as as byte; the same + * applies to floats and doubles). + */ +public enum NumberType { + BYTE(DataType.BYTE, Byte.class, false) { + @Override + public Number cast(Number in) { + return in.byteValue(); + } + + @Override + public Number divide(Number a, Number b) { + return a.byteValue() / b.byteValue(); + } + + @Override + public Number multiply(Number a, Number b) { + return a.byteValue() * b.byteValue(); + } + + @Override + public Number add(Number a, Number b) { + return a.byteValue() + b.byteValue(); + } + + @Override + public Number subtract(Number a, Number b) { + return a.byteValue() - b.byteValue(); + } + + @Override + public Number modulo(Number a, Number b) { + return a.byteValue() % b.byteValue(); + } + + @Override + public Number or(Number a, Number b) { + return a.byteValue() | b.byteValue(); + } + + @Override + public Number xor(Number a, Number b) { + return a.byteValue() ^ b.byteValue(); + } + + @Override + public Number and(Number a, Number b) { + return a.byteValue() & b.byteValue(); + } + + @Override + public Number lshift(Number a, Number b) { + return a.byteValue() << b.byteValue(); + } + + @Override + public Number rshift(Number a, Number b) { + return a.byteValue() >> b.byteValue(); + } + + @Override + public Number rshiftUnsigned(Number a, Number b) { + return a.byteValue() >>> b.byteValue(); + } + }, + SHORT(DataType.SHORT, Short.class, false, BYTE) { + @Override + public Number cast(Number in) { + return in.shortValue(); + } + + @Override + public Number divide(Number a, Number b) { + return a.shortValue() / b.shortValue(); + } + + @Override + public Number multiply(Number a, Number b) { + return a.shortValue() * b.shortValue(); + } + + @Override + public Number add(Number a, Number b) { + return a.shortValue() + b.shortValue(); + } + + @Override + public Number subtract(Number a, Number b) { + return a.shortValue() - b.shortValue(); + } + + @Override + public Number modulo(Number a, Number b) { + return a.shortValue() % b.shortValue(); + } + + @Override + public Number or(Number a, Number b) { + return a.shortValue() | b.shortValue(); + } + + @Override + public Number xor(Number a, Number b) { + return a.shortValue() ^ b.shortValue(); + } + + @Override + public Number and(Number a, Number b) { + return a.shortValue() & b.shortValue(); + } + + @Override + public Number lshift(Number a, Number b) { + return a.shortValue() << b.shortValue(); + } + + @Override + public Number rshift(Number a, Number b) { + return a.shortValue() >> b.shortValue(); + } + + @Override + public Number rshiftUnsigned(Number a, Number b) { + return a.shortValue() >>> b.shortValue(); + } + }, + INT(DataType.INT, Integer.class, true, SHORT) { + @Override + public long toUnsignedLong(Number number) { + return Integer.toUnsignedLong(number.intValue()); + } + + @Override + public Number cast(Number in) { + return in.intValue(); + } + }, + LONG(DataType.LONG, Long.class, true, INT) { + @Override + public Number negate(Number number) { + return ~number.longValue(); + } + + @Override + public Number cast(Number in) { + return in.longValue(); + } + + @Override + public Number divide(Number a, Number b) { + return a.longValue() / b.longValue(); + } + + @Override + public Number multiply(Number a, Number b) { + return a.longValue() * b.longValue(); + } + + @Override + public Number add(Number a, Number b) { + return a.longValue() + b.longValue(); + } + + @Override + public Number subtract(Number a, Number b) { + return a.longValue() - b.longValue(); + } + + @Override + public Number modulo(Number a, Number b) { + return a.longValue() % b.longValue(); + } + + @Override + public Number or(Number a, Number b) { + return a.longValue() | b.longValue(); + } + + @Override + public Number xor(Number a, Number b) { + return a.longValue() ^ b.longValue(); + } + + @Override + public Number and(Number a, Number b) { + return a.longValue() & b.longValue(); + } + + @Override + public Number lshift(Number a, Number b) { + return a.longValue() << b.longValue(); + } + + @Override + public Number rshift(Number a, Number b) { + return a.longValue() >> b.longValue(); + } + + @Override + public Number rshiftUnsigned(Number a, Number b) { + return a.longValue() >>> b.longValue(); + } + }, + FLOAT(DataType.FLOAT, Float.class, false, INT) { + @Override + public Number cast(Number in) { + return in.floatValue(); + } + + @Override + public Number divide(Number a, Number b) { + return a.floatValue() / b.floatValue(); + } + + @Override + public Number multiply(Number a, Number b) { + return a.floatValue() * b.floatValue(); + } + + @Override + public Number add(Number a, Number b) { + return a.floatValue() + b.floatValue(); + } + + @Override + public Number subtract(Number a, Number b) { + return a.floatValue() - b.floatValue(); + } + + @Override + public Number modulo(Number a, Number b) { + return a.floatValue() % b.floatValue(); + } + }, + DOUBLE(DataType.DOUBLE, Double.class, false, FLOAT) { + @Override + public Number cast(Number in) { + return in.doubleValue(); + } + + @Override + public Number divide(Number a, Number b) { + return a.doubleValue() / b.doubleValue(); + } + + @Override + public Number multiply(Number a, Number b) { + return a.doubleValue() * b.doubleValue(); + } + + @Override + public Number add(Number a, Number b) { + return a.doubleValue() + b.doubleValue(); + } + + @Override + public Number subtract(Number a, Number b) { + return a.doubleValue() - b.doubleValue(); + } + + @Override + public Number modulo(Number a, Number b) { + return a.doubleValue() % b.doubleValue(); + } + }; + + public static final Map, NumberType> TYPES; + static { + var types = new HashMap, NumberType>(); + for (NumberType value : values()) { + types.put(value.classType, value); + } + TYPES = Collections.unmodifiableMap(types); + } + + public final DataType dataType; + public final Class classType; + /** + * Whether this number type can be treated as a bit flag - only {@code true} for {@link #INT} and {@link #LONG}. + */ + public final boolean supportsFlag; + /** + * Number types that can be converted to this type without needing an explicit cast: + *

+ * byte -> short -> int -> long + *

+ * int -> float -> double + */ + public final NumberType[] widenFrom; + + NumberType(DataType dataType, Class classType, boolean supportsFlag, NumberType... widenFrom) { + this.dataType = dataType; + this.classType = classType; + this.supportsFlag = supportsFlag; + this.widenFrom = widenFrom; + } + + public abstract Number cast(Number in); + + public long toUnsignedLong(Number number) { + return number.longValue(); + } + + public Number negate(Number number) { + return ~number.intValue(); + } + + public Number divide(Number a, Number b) { + return a.intValue() / b.intValue(); + } + + public Number multiply(Number a, Number b) { + return a.intValue() * b.intValue(); + } + + public Number add(Number a, Number b) { + return a.intValue() + b.intValue(); + } + + public Number subtract(Number a, Number b) { + return a.intValue() - b.intValue(); + } + + public Number modulo(Number a, Number b) { + return a.intValue() % b.intValue(); + } + + public Number or(Number a, Number b) { + return a.intValue() | b.intValue(); + } + + public Number xor(Number a, Number b) { + return a.intValue() ^ b.intValue(); + } + + public Number and(Number a, Number b) { + return a.intValue() & b.intValue(); + } + + public Number lshift(Number a, Number b) { + return a.intValue() << b.intValue(); + } + + public Number rshift(Number a, Number b) { + return a.intValue() >> b.intValue(); + } + + public Number rshiftUnsigned(Number a, Number b) { + return a.intValue() >>> b.intValue(); + } +} diff --git a/unpick/src/main/java/net/neoforged/jst/unpick/UnpickCollection.java b/unpick/src/main/java/net/neoforged/jst/unpick/UnpickCollection.java new file mode 100644 index 0000000..b68ba64 --- /dev/null +++ b/unpick/src/main/java/net/neoforged/jst/unpick/UnpickCollection.java @@ -0,0 +1,317 @@ +package net.neoforged.jst.unpick; + +import com.intellij.lang.jvm.JvmModifier; +import com.intellij.openapi.util.Key; +import com.intellij.psi.JavaPsiFacade; +import com.intellij.psi.PsiClass; +import com.intellij.psi.PsiClassType; +import com.intellij.psi.PsiElement; +import com.intellij.psi.PsiField; +import com.intellij.psi.PsiJavaFile; +import com.intellij.psi.PsiMethod; +import com.intellij.psi.PsiType; +import com.intellij.psi.PsiTypes; +import com.intellij.psi.search.GlobalSearchScope; +import com.intellij.util.containers.MultiMap; +import daomephsta.unpick.constantmappers.datadriven.tree.DataType; +import daomephsta.unpick.constantmappers.datadriven.tree.GroupDefinition; +import daomephsta.unpick.constantmappers.datadriven.tree.GroupFormat; +import daomephsta.unpick.constantmappers.datadriven.tree.GroupScope; +import daomephsta.unpick.constantmappers.datadriven.tree.Literal; +import daomephsta.unpick.constantmappers.datadriven.tree.TargetField; +import daomephsta.unpick.constantmappers.datadriven.tree.TargetMethod; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.BinaryExpression; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.CastExpression; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.Expression; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.FieldExpression; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.LiteralExpression; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.ParenExpression; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.UnaryExpression; +import net.neoforged.jst.api.PsiHelper; +import net.neoforged.jst.api.TransformContext; +import org.jetbrains.annotations.Nullable; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Predicate; + +public class UnpickCollection { + private static final Key> UNPICK_DEFINITION = Key.create("unpick.method_definition"); + private static final Key UNPICK_FIELD_TARGET = Key.create("unpick.field_target"); + + private final Set possibleTargetNames = new HashSet<>(); + private final MultiMap groups; + + private final List global; + + private final MultiMap byPackage; + private final MultiMap byClass; + + private final MultiMap methodScopes; + + // This list only exists to keep the base elements in memory and prevent them from being GC'd and therefore losing their user data + // JavaPsiFacade#findClass uses a soft key and soft value map + @SuppressWarnings({"FieldCanBeLocal", "MismatchedQueryAndUpdateOfCollection"}) + private final List baseElements; + + public UnpickCollection(TransformContext context, Map> groups, List fields, List methods) { + this.groups = new MultiMap<>(new HashMap<>(groups.size())); + + var facade = context.environment().getPsiFacade(); + var project = context.environment().getPsiManager().getProject(); + + var projectScope = GlobalSearchScope.projectScope(project); + + global = new ArrayList<>(); + + byPackage = new MultiMap<>(); + byClass = new MultiMap<>(); + + methodScopes = new MultiMap<>(new IdentityHashMap<>()); + baseElements = new ArrayList<>(); + + groups.forEach((key, defs) -> { + for (GroupDefinition def : defs) { + var gr = Group.create(def, facade, projectScope); + if (key.isGlobal()) { + global.add(gr); + } else { + this.groups.putValue(key.name(), gr); + + for (var scope : def.scopes()) { + switch (scope) { + case GroupScope.Package(var packageName) -> byPackage.putValue(packageName, gr); + case GroupScope.Class(var cls) -> byClass.putValue(cls, gr); + case GroupScope.Method(var className, var method, var desc) -> { + var cls = facade.findClass(className, projectScope); + if (cls == null) return; + + for (PsiMethod clsMethod : cls.getMethods()) { + if (clsMethod.getName().equals(method) && PsiHelper.getBinaryMethodSignature(clsMethod).equals(desc)) { + methodScopes.putValue(clsMethod, gr); + } + } + } + } + } + } + } + }); + + for (var field : fields) { + var cls = facade.findClass(field.className(), projectScope); + if (cls == null) continue; + + var fld = cls.findFieldByName(field.fieldName(), true); + if (fld != null) { + fld.putUserData(UNPICK_FIELD_TARGET, field); + baseElements.add(fld); + } + } + + for (var method : methods) { + var cls = facade.findClass(method.className(), projectScope); + if (cls == null) continue; + + possibleTargetNames.add(method.methodName()); + + for (PsiMethod clsMethod : cls.getMethods()) { + if (clsMethod.getName().equals(method.methodName()) && PsiHelper.getBinaryMethodSignature(clsMethod).equals(method.methodDesc())) { + clsMethod.putUserData(UNPICK_DEFINITION, Optional.of(method)); + baseElements.add(clsMethod); + } + } + } + } + + public Collection getClassContext(PsiClass cls) { + var clsName = cls.getQualifiedName(); + if (clsName != null) { + return byClass.get(clsName); + } + return List.of(); + } + + + public Collection getPackageContext(PsiJavaFile file) { + return byPackage.get(file.getPackageName()); + } + + public Collection getMethodContext(PsiMethod method) { + return methodScopes.get(method); + } + + public Collection getGlobalContext() { + return global; + } + + @SuppressWarnings("OptionalAssignedToNull") + @Nullable + public TargetMethod getDefinitionsFor(PsiMethod method) { + if (!possibleTargetNames.contains(method.getName())) return null; + var data = method.getUserData(UNPICK_DEFINITION); + if (data == null) { + synchronized (this) { + if (method.getParent() instanceof PsiClass cls) { + for (PsiClass iface : cls.getInterfaces()) { + var met = iface.findMethodBySignature(method, true); + if (met != null) { + var parent = getDefinitionsFor(met); + if (parent != null) { + data = Optional.of(parent); + break; + } + } + } + + if (data == null && cls.getSuperClass() != null) { + var met = cls.getSuperClass().findMethodBySignature(method, true); + if (met != null) { + var parent = getDefinitionsFor(met); + if (parent != null) { + data = Optional.of(parent); + } + } + } + } + + if (data == null) data = Optional.empty(); + method.putUserData(UNPICK_DEFINITION, data); + } + } + return data.orElse(null); + } + + public Collection getGroups(String id) { + return groups.get(id); + } + + public record Group( + DataType data, + boolean strict, + boolean flag, + @Nullable GroupFormat format, + Map constants + ) { + public static Group create(GroupDefinition def, JavaPsiFacade facade, GlobalSearchScope scope) { + var constants = HashMap.newHashMap(def.constants().size()); + for (Expression constant : def.constants()) { + var value = resolveConstant(constant, facade, scope); + constants.put(cast(value, def.dataType()), constant); + } + return new Group( + def.dataType(), + def.strict(), + def.flags(), + def.format(), + constants + ); + } + + private static Object resolveConstant(Expression expression, JavaPsiFacade facade, GlobalSearchScope scope) { + if (expression instanceof FieldExpression fieldEx) { + var clazz = facade.findClass(fieldEx.className, scope); + if (clazz != null) { + for (PsiField field : clazz.getAllFields()) { + if (fieldEx.isStatic != field.hasModifier(JvmModifier.STATIC)) continue; + if (fieldEx.fieldType != null && !sameType(fieldEx.fieldType, field.getType())) continue; + if (fieldEx.fieldName.equals(field.getName())) { + return field.computeConstantValue(); + } + } + } + throw new IllegalArgumentException("Cannot find field named " + fieldEx.className + " of type " + fieldEx.fieldType + " in class " + fieldEx.className); + } else if (expression instanceof LiteralExpression literalExpression) { + return switch (literalExpression.literal) { + case Literal.Character(var ch) -> ch; + case Literal.Integer i -> i.value(); + case Literal.Long l -> l.value(); + case Literal.Float(var f) -> f; + case Literal.Double(var d) -> d; + case Literal.String(var s) -> s; + }; + } else if (expression instanceof ParenExpression parenExpression) { + return resolveConstant(parenExpression.expression, facade, scope); + } else if (expression instanceof CastExpression castExpression) { + return cast(resolveConstant(castExpression.operand, facade, scope), castExpression.castType); + } else if (expression instanceof UnaryExpression unaryExpression) { + var value = (Number) resolveConstant(unaryExpression.operand, facade, scope); + return switch (unaryExpression.operator) { + case NEGATE -> NumberType.TYPES.get(value.getClass()).negate(value); + case BIT_NOT -> value instanceof Long ? ~value.longValue() : ~value.intValue(); + }; + } else if (expression instanceof BinaryExpression binaryExpression) { + var lhs = resolveConstant(binaryExpression.lhs, facade, scope); + var rhs = resolveConstant(binaryExpression.rhs, facade, scope); + + if (lhs instanceof Number l && rhs instanceof Number r) { + var type = NumberType.TYPES.get(l.getClass()); + return switch (binaryExpression.operator) { + case ADD -> type.add(l, r); + case DIVIDE -> type.divide(l, r); + case MODULO -> type.modulo(l, r); + case MULTIPLY -> type.multiply(l, r); + case SUBTRACT -> type.subtract(l, r); + + case BIT_AND -> type.and(l, r); + case BIT_OR -> type.or(l, r); + case BIT_XOR -> type.xor(l, r); + + case BIT_SHIFT_LEFT -> type.lshift(l, r); + case BIT_SHIFT_RIGHT -> type.rshift(l, r); + case BIT_SHIFT_RIGHT_UNSIGNED -> type.rshiftUnsigned(l, r); + }; + } + + if (lhs instanceof String lS && rhs instanceof String rS && binaryExpression.operator == BinaryExpression.Operator.ADD) { + return lS + rS; + } + + throw new IllegalArgumentException("Cannot resolve expression: " + binaryExpression + ". Operands of type " + lhs.getClass() + " and " + rhs.getClass() + " do not support operator " + binaryExpression.operator); + } + + throw new IllegalArgumentException("Unknown Expression of type " + expression.getClass() + ": " + expression); + } + + private static Object cast(Object in, DataType type) { + return switch (type) { + case BYTE -> ((Number) in).byteValue(); + case CHAR -> Character.valueOf((char)((Number) in).byteValue()); + case SHORT -> ((Number) in).shortValue(); + case INT -> ((Number) in).intValue(); + case LONG -> ((Number) in).longValue(); + case FLOAT -> ((Number) in).floatValue(); + case DOUBLE -> ((Number) in).doubleValue(); + case CLASS -> (Class) in; + case STRING -> in.toString(); + }; + } + + private static boolean sameType(DataType type, PsiType fieldType) { + return switch (type) { + case BYTE -> fieldType.equals(PsiTypes.byteType()); + case CHAR -> fieldType.equals(PsiTypes.charType()); + case SHORT -> fieldType.equals(PsiTypes.shortType()); + case INT -> fieldType.equals(PsiTypes.intType()); + case LONG -> fieldType.equals(PsiTypes.longType()); + case FLOAT -> fieldType.equals(PsiTypes.floatType()); + case DOUBLE -> fieldType.equals(PsiTypes.doubleType()); + case CLASS -> ((PsiClassType) fieldType).resolve().getQualifiedName().equals("java.lang.Class"); + case STRING -> ((PsiClassType) fieldType).resolve().getQualifiedName().equals("java.lang.String"); + }; + } + } + + public record TypedKey(DataType type, List scopes, @Nullable String name) { + public boolean isGlobal() { + return name == null && scopes.isEmpty(); + } + } +} diff --git a/unpick/src/main/java/net/neoforged/jst/unpick/UnpickPlugin.java b/unpick/src/main/java/net/neoforged/jst/unpick/UnpickPlugin.java new file mode 100644 index 0000000..0a0dde6 --- /dev/null +++ b/unpick/src/main/java/net/neoforged/jst/unpick/UnpickPlugin.java @@ -0,0 +1,18 @@ +package net.neoforged.jst.unpick; + +import net.neoforged.jst.api.SourceTransformer; +import net.neoforged.jst.api.SourceTransformerPlugin; +import org.jetbrains.annotations.ApiStatus; + +@ApiStatus.Experimental +public class UnpickPlugin implements SourceTransformerPlugin { + @Override + public String getName() { + return "unpick"; + } + + @Override + public SourceTransformer createTransformer() { + return new UnpickTransformer(); + } +} diff --git a/unpick/src/main/java/net/neoforged/jst/unpick/UnpickTransformer.java b/unpick/src/main/java/net/neoforged/jst/unpick/UnpickTransformer.java new file mode 100644 index 0000000..102cbe3 --- /dev/null +++ b/unpick/src/main/java/net/neoforged/jst/unpick/UnpickTransformer.java @@ -0,0 +1,66 @@ +package net.neoforged.jst.unpick; + +import com.intellij.psi.PsiFile; +import daomephsta.unpick.constantmappers.datadriven.parser.v3.UnpickV3Reader; +import daomephsta.unpick.constantmappers.datadriven.tree.GroupDefinition; +import daomephsta.unpick.constantmappers.datadriven.tree.TargetField; +import daomephsta.unpick.constantmappers.datadriven.tree.TargetMethod; +import daomephsta.unpick.constantmappers.datadriven.tree.UnpickV3Visitor; +import net.neoforged.jst.api.Replacements; +import net.neoforged.jst.api.SourceTransformer; +import net.neoforged.jst.api.TransformContext; +import picocli.CommandLine; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +public class UnpickTransformer implements SourceTransformer { + @CommandLine.Option(names = "--unpick-data", description = "The paths to read unpick definition files from") + public List paths = new ArrayList<>(); + + private UnpickCollection collection; + + @Override + public void beforeRun(TransformContext context) { + var groups = new HashMap>(); + var fields = new ArrayList(); + var methods = new ArrayList(); + + for (Path path : paths) { + try (var reader = Files.newBufferedReader(path)) { + new UnpickV3Reader(reader).accept(new UnpickV3Visitor() { + @Override + public void visitGroupDefinition(GroupDefinition groupDefinition) { + groups.computeIfAbsent(new UnpickCollection.TypedKey(groupDefinition.dataType(), groupDefinition.scopes(), groupDefinition.name()), k -> new ArrayList<>()) + .add(groupDefinition); + } + + @Override + public void visitTargetField(TargetField targetField) { + fields.add(targetField); + } + + @Override + public void visitTargetMethod(TargetMethod targetMethod) { + methods.add(targetMethod); + } + }); + } catch (IOException exception) { + context.logger().error("Failed to read unpick definition file: %s", exception.getMessage()); + throw new UncheckedIOException(exception); + } + } + + this.collection = new UnpickCollection(context, groups, fields, methods); + } + + @Override + public void visitFile(PsiFile psiFile, Replacements replacements) { + new UnpickVisitor(psiFile, collection, replacements).visitFile(psiFile); + } +} diff --git a/unpick/src/main/java/net/neoforged/jst/unpick/UnpickVisitor.java b/unpick/src/main/java/net/neoforged/jst/unpick/UnpickVisitor.java new file mode 100644 index 0000000..1225cd5 --- /dev/null +++ b/unpick/src/main/java/net/neoforged/jst/unpick/UnpickVisitor.java @@ -0,0 +1,571 @@ +package net.neoforged.jst.unpick; + +import com.intellij.openapi.util.Key; +import com.intellij.psi.JavaTokenType; +import com.intellij.psi.PsiAssignmentExpression; +import com.intellij.psi.PsiCallExpression; +import com.intellij.psi.PsiClass; +import com.intellij.psi.PsiElement; +import com.intellij.psi.PsiExpression; +import com.intellij.psi.PsiField; +import com.intellij.psi.PsiFile; +import com.intellij.psi.PsiJavaFile; +import com.intellij.psi.PsiJavaToken; +import com.intellij.psi.PsiLiteralExpression; +import com.intellij.psi.PsiLocalVariable; +import com.intellij.psi.PsiMethod; +import com.intellij.psi.PsiMethodCallExpression; +import com.intellij.psi.PsiParameter; +import com.intellij.psi.PsiPrefixExpression; +import com.intellij.psi.PsiRecursiveElementVisitor; +import com.intellij.psi.PsiReferenceExpression; +import com.intellij.psi.PsiReturnStatement; +import com.intellij.psi.PsiVariable; +import daomephsta.unpick.constantmappers.datadriven.tree.GroupFormat; +import daomephsta.unpick.constantmappers.datadriven.tree.Literal; +import daomephsta.unpick.constantmappers.datadriven.tree.TargetMethod; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.BinaryExpression; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.CastExpression; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.Expression; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.ExpressionVisitor; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.FieldExpression; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.LiteralExpression; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.ParenExpression; +import daomephsta.unpick.constantmappers.datadriven.tree.expr.UnaryExpression; +import net.neoforged.jst.api.ImportHelper; +import net.neoforged.jst.api.PsiHelper; +import net.neoforged.jst.api.Replacements; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Stack; +import java.util.function.Predicate; + +public class UnpickVisitor extends PsiRecursiveElementVisitor { + private static final Key UNPICK_WAS_REPLACED = Key.create("unpick.was_replaced"); + + private final PsiFile file; + private final UnpickCollection collection; + private final Replacements replacements; + + public UnpickVisitor(PsiFile file, UnpickCollection collection, Replacements replacements) { + this.file = file; + this.collection = collection; + this.replacements = replacements; + } + + @Nullable + private PsiMethod methodContext; + @Nullable + private PsiField fieldContext; + + @Nullable + private TargetMethod calledMethodContext; + private int currentParameterIndex; + + private final Stack> contextStack = new Stack<>(); + + @Override + public void visitElement(@NotNull PsiElement element) { + Collection additionalContext = List.of(); + + switch (element) { + case PsiJavaFile javaFile -> additionalContext = collection.getPackageContext(javaFile); + case PsiClass cls -> additionalContext = collection.getClassContext(cls); + case PsiMethod met when met.getBody() != null -> { + var oldCtx = this.methodContext; + contextStack.push(collection.getMethodContext(met)); + this.methodContext = met; + met.getBody().acceptChildren(this); + this.methodContext = oldCtx; + contextStack.pop(); + return; + } + case PsiField fld -> { + var oldCtx = this.fieldContext; + this.fieldContext = fld; + fld.acceptChildren(this); + this.fieldContext = oldCtx; + return; + } + + case PsiJavaToken tok -> { + visitToken(tok); + return; + } + + case PsiMethodCallExpression call -> { + PsiElement ref = PsiHelper.resolve(call.getMethodExpression()); + if (ref instanceof PsiMethod met) { + var oldCtx = this.calledMethodContext; + var oldIdx = this.currentParameterIndex; + + // We replace the old context to avoid nesting as it can produce weird artifacts + // when parameter expression are themselves other method calls since "invasive" + // (i.e. number formatting) rules of this group would apply to its parameters + this.calledMethodContext = collection.getDefinitionsFor(met); + for (int i = 0; i < call.getArgumentList().getExpressions().length; i++) { + this.currentParameterIndex = i; + // If any parameter of the method call is directly referencing local var we re-walk the entire method body + // and apply unpick with the context of the method being called to all of its assignments (including the initialiser) + acceptPossibleLocalVarReference(call.getArgumentList().getExpressions()[i]); + } + + this.calledMethodContext = oldCtx; + this.currentParameterIndex = oldIdx; + return; + } + } + case PsiReturnStatement returnStatement when methodContext != null -> { + var contextDefinitions = collection.getDefinitionsFor(methodContext); + if (contextDefinitions != null && contextDefinitions.returnGroup() != null) { + var groups = collection.getGroups(contextDefinitions.returnGroup()); + if (!groups.isEmpty()) { + contextStack.push(groups); + acceptPossibleLocalVarReference(returnStatement.getReturnValue()); + contextStack.pop(); + return; + } + } + } + + case PsiLocalVariable localVar + when localVar.getInitializer() != null && localVar.getInitializer() instanceof PsiMethodCallExpression methodCall -> { + acceptReturnFlow(localVar, methodCall); + return; + } + + case PsiAssignmentExpression assignment -> { + if (assignment.getOperationSign().getTokenType() == JavaTokenType.EQ && assignment.getLExpression() instanceof PsiReferenceExpression ref) { + var referencedVariable = PsiHelper.resolve(ref); + if (referencedVariable instanceof PsiLocalVariable || referencedVariable instanceof PsiParameter) { + if (assignment.getLExpression() instanceof PsiMethodCallExpression methodCall) { + acceptReturnFlow((PsiVariable) referencedVariable, methodCall); + return; + } + } + } + } + + default -> {} + } + + if (additionalContext.isEmpty()) { + element.acceptChildren(this); + } else { + contextStack.push(additionalContext); + element.acceptChildren(this); + contextStack.pop(); + } + } + + private void acceptReturnFlow(PsiVariable variable, PsiMethodCallExpression expression) { + var flowingFrom = expression.resolveMethod(); + if (flowingFrom != null) { + var target = collection.getDefinitionsFor(flowingFrom); + if (target != null && target.returnGroup() != null) { + var groups = collection.getGroups(target.returnGroup()); + if (!groups.isEmpty()) { + contextStack.push(groups); + visitVariableAssignments(variable); + contextStack.pop(); + return; + } + } + } + expression.acceptChildren(this); + } + + private void acceptPossibleLocalVarReference(PsiExpression expression) { + if (expression instanceof PsiReferenceExpression refEx) { + PsiElement resolved = PsiHelper.resolve(refEx); + + if (resolved instanceof PsiLocalVariable || resolved instanceof PsiParameter) { + visitVariableAssignments((PsiVariable) resolved); + return; + } + } + + if (expression != null) { + visitElement(expression); + } + } + + private void visitVariableAssignments(PsiVariable var) { + if (var.getInitializer() != null) { + var.getInitializer().accept(limitedDirectVisitor()); + } + + var body = methodContext == null ? null : methodContext.getBody(); + if (body != null) { + new PsiRecursiveElementVisitor() { + @Override + public void visitElement(@NotNull PsiElement element) { + if (element instanceof PsiAssignmentExpression as) { + if (as.getOperationSign().getTokenType() == JavaTokenType.EQ && as.getLExpression() instanceof PsiReferenceExpression ref && PsiHelper.resolve(ref) == var && as.getRExpression() != null) { + as.getRExpression().accept(limitedDirectVisitor()); + } + return; + } + super.visitElement(element); + } + }.visitElement(body); + } + } + + /** + * {@return an element visitor that visits only tokens outside of call expressions} + * This can be used when there is no need to handle call expressions as they would have already + * been handled or will be handled - for instance, when re-applying unpick for local variable initialisers, + * but with more context. + */ + private PsiRecursiveElementVisitor limitedDirectVisitor() { + return new PsiRecursiveElementVisitor() { + @Override + public void visitElement(@NotNull PsiElement element) { + if (element instanceof PsiCallExpression) { + return; // We do not want to try to further replace constants inside method calls, that's why we're limited to direct elements + } + if (element instanceof PsiJavaToken tok) { + visitToken(tok); + return; + } + super.visitElement(element); + } + }; + } + + private void visitToken(PsiJavaToken tok) { + if (Boolean.TRUE.equals(tok.getUserData(UNPICK_WAS_REPLACED))) return; + + if (tok.getTokenType() == JavaTokenType.STRING_LITERAL) { + var val = tok.getText().substring(1); // Remove starting " + final var finalVal = val.substring(0, val.length() - 1); // Remove leading " + forInScope(group -> { + var ct = group.constants().get(finalVal); + if (ct != null && checkNotRecursive(ct)) { + replacements.replace(tok, write(ct)); + tok.putUserData(UNPICK_WAS_REPLACED, true); + return true; + } + return false; + }); + } else if (tok.getTokenType() == JavaTokenType.INTEGER_LITERAL) { + var text = tok.getText().toLowerCase(Locale.ROOT); + + int val; + if (text.startsWith("0x")) { + val = Integer.parseUnsignedInt(text.substring(2), 16); + } else if (text.startsWith("0b")) { + val = Integer.parseUnsignedInt(text.substring(2), 2); + } else { + val = Integer.parseUnsignedInt(text); + } + + if (isUnaryMinus(tok)) val = -val; + replaceLiteral(tok, val, NumberType.INT); + } else if (tok.getTokenType() == JavaTokenType.LONG_LITERAL) { + var text = removeSuffix(tok.getText(), 'l').toLowerCase(Locale.ROOT); + + long val; + if (text.startsWith("0x")) { + val = Long.parseUnsignedLong(text.substring(2), 16); + } else if (text.startsWith("0b")) { + val = Long.parseUnsignedLong(text.substring(2), 2); + } else { + val = Long.parseUnsignedLong(text); + } + + if (isUnaryMinus(tok)) val = -val; + replaceLiteral(tok, val, NumberType.LONG); + } else if (tok.getTokenType() == JavaTokenType.DOUBLE_LITERAL) { + var val = Double.parseDouble(removeSuffix(tok.getText(), 'd')); + if (isUnaryMinus(tok)) val = -val; + replaceLiteral(tok, val, NumberType.DOUBLE); + } else if (tok.getTokenType() == JavaTokenType.FLOAT_LITERAL) { + var val = Float.parseFloat(removeSuffix(tok.getText(), 'f')); + if (isUnaryMinus(tok)) val = -val; + replaceLiteral(tok, val, NumberType.FLOAT); + } + } + + private void replaceLiteral(PsiJavaToken element, Number number, NumberType type) { + replaceLiteral(element, number, type, false); + } + + private boolean replaceLiteral(PsiJavaToken element, Number number, NumberType type, boolean denyStrict) { + return forInScope(group -> { + // If we need to deny strict conversion (so if this is a conversion) we shall do so + if (group.strict() && denyStrict) return false; + + // We'll try a direct constant first, even if it's a flag + var ct = group.constants().get(number); + if (ct != null && checkNotRecursive(ct)) { + replacements.replace(element, write(ct)); + element.putUserData(UNPICK_WAS_REPLACED, true); + replaceMinus(element); + return true; + } + + // Next, try if this group is a flag and the number type supports flags (ints and longs) we try to generate the flag combination + if (group.flag() && type.supportsFlag) { + // We generate flags for ints based on their long value as + // longs are a superset of ints and as such we can reduce code duplication + var flag = generateFlag(group, number.longValue(), type); + if (flag != null) { + replacements.replace(element, flag); + element.putUserData(UNPICK_WAS_REPLACED, true); + replaceMinus(element); + return true; + } + } + + // As a fallback, if the group has a specific format but the + // value of the token does not have a constant we format the token + if (group.format() != null) { + replacements.replace(element, formatAs(number, group.format())); + replaceMinus(element); + element.putUserData(UNPICK_WAS_REPLACED, true); + return true; + } + + // Finally we try to apply non-strict widening from lower number types + for (NumberType from : type.widenFrom) { + var lower = from.cast(number); + if (lower.doubleValue() == number.doubleValue()) { + if (replaceLiteral(element, lower, from, true)) { + return true; + } + } + } + + return false; + }); + } + + private boolean isUnaryMinus(PsiJavaToken tok) { + return tok.getParent() instanceof PsiLiteralExpression lit && lit.getParent() instanceof PsiPrefixExpression ex && ex.getOperationTokenType() == JavaTokenType.MINUS; + } + + private void replaceMinus(PsiJavaToken tok) { + if (tok.getParent() instanceof PsiLiteralExpression lit && lit.getParent() instanceof PsiPrefixExpression ex && ex.getOperationTokenType() == JavaTokenType.MINUS) { + replacements.remove(ex.getOperationSign()); + } + } + + private boolean checkNotRecursive(Expression expression) { + if (fieldContext != null && fieldContext.getContainingClass() != null && expression instanceof FieldExpression fld) { + return !(fld.className.equals(fieldContext.getContainingClass().getQualifiedName()) && Objects.equals(fld.fieldName, fieldContext.getName())); + } + return true; + } + + private boolean forInScope(Predicate apply) { + if (calledMethodContext != null) { + var paramGroupId = this.calledMethodContext.paramGroups().get(currentParameterIndex); + if (paramGroupId != null) { + var paramGroups = collection.getGroups(paramGroupId); + if (!paramGroups.isEmpty()) { + for (var group : paramGroups) { + if (apply.test(group)) { + return true; + } + } + } + } + } + + if (!contextStack.isEmpty()) { + // Walk and apply the context stack in reverse (e.g. we first apply the method scope, then the class scope and finally the package scope) + for (int i = contextStack.size() - 1; i >= 0; i--) { + for (var group : contextStack.get(i)) { + if (apply.test(group)) { + return true; + } + } + } + } + + for (var group : collection.getGlobalContext()) { + if (apply.test(group)) { + return true; + } + } + + return false; + } + + private String write(Expression expression) { + StringBuilder s = new StringBuilder(); + expression.accept(new ExpressionVisitor() { + @Override + public void visitFieldExpression(FieldExpression fieldExpression) { + var cls = imports().importClass(fieldExpression.className); + s.append(cls).append('.').append(fieldExpression.fieldName); + } + + @Override + public void visitParenExpression(ParenExpression parenExpression) { + s.append('(') + .append(write(parenExpression.expression)) + .append(')'); + } + + @Override + public void visitLiteralExpression(LiteralExpression literalExpression) { + if (literalExpression.literal instanceof Literal.String(String value)) { + s.append('\"').append(value.replace("\"", "\\\"")).append('\"'); + } else if (literalExpression.literal instanceof Literal.Integer i) { + s.append(i.value()); + } else if (literalExpression.literal instanceof Literal.Long l) { + s.append(l.value()).append('l'); + } else if (literalExpression.literal instanceof Literal.Double d) { + s.append(d).append('d'); + } else if (literalExpression.literal instanceof Literal.Float f) { + s.append(f).append('f'); + } + } + + @Override + public void visitCastExpression(CastExpression castExpression) { + s.append('('); + s.append(switch (castExpression.castType) { + case INT -> "int"; + case CHAR -> "char"; + case DOUBLE -> "double"; + case BYTE -> "byte"; + case LONG -> "long"; + case FLOAT -> "float"; + case SHORT -> "short"; + case STRING -> "String"; + case CLASS -> "Class"; + }); + s.append(')'); + s.append(write(castExpression.operand)); + } + + @Override + public void visitBinaryExpression(BinaryExpression binaryExpression) { + s.append(write(binaryExpression.lhs)); + switch (binaryExpression.operator) { + case ADD -> s.append(" + "); + case DIVIDE -> s.append(" / "); + case MODULO -> s.append(" % "); + case MULTIPLY -> s.append(" * "); + case SUBTRACT -> s.append(" - "); + + case BIT_AND -> s.append(" & "); + case BIT_OR -> s.append(" | "); + case BIT_XOR -> s.append(" ^ "); + + case BIT_SHIFT_LEFT -> s.append(" << "); + case BIT_SHIFT_RIGHT -> s.append(" >> "); + case BIT_SHIFT_RIGHT_UNSIGNED -> s.append(" >>> "); + } + s.append(write(binaryExpression.rhs)); + } + + @Override + public void visitUnaryExpression(UnaryExpression unaryExpression) { + switch (unaryExpression.operator) { + case NEGATE -> s.append("!"); + case BIT_NOT -> s.append("~"); + } + s.append(write(unaryExpression.operand)); + } + }); + return s.toString(); + } + + private String formatAs(Number value, GroupFormat format) { + return switch (format) { + case HEX -> { + if (value instanceof Integer) yield "0x" + Integer.toHexString(value.intValue()).toUpperCase(Locale.ROOT); + else if (value instanceof Long) yield "0x" + Long.toHexString(value.longValue()).toUpperCase(Locale.ROOT) + "l"; + else if (value instanceof Double) yield Double.toHexString(value.doubleValue()) + "d"; + else if (value instanceof Float) yield Float.toHexString(value.floatValue()) + "f"; + yield value.toString(); + } + case OCTAL -> { + if (value instanceof Integer) yield "0" + Integer.toOctalString(value.intValue()); + else if (value instanceof Long) yield "0" + Long.toOctalString(value.longValue()) + "l"; + yield value.toString(); + } + case BINARY -> { + if (value instanceof Integer) yield "0b" + Integer.toBinaryString(value.intValue()); + else if (value instanceof Long) yield "0b" + Long.toBinaryString(value.longValue()) + "l"; + yield value.toString(); + } + case CHAR -> "'" + ((char) value.intValue()) + "'"; + + default -> value.toString(); + }; + } + + private ImportHelper imports() { + return ImportHelper.get(file); + } + + @Nullable + private String generateFlag(UnpickCollection.Group group, long val, NumberType type) { + List orConstants = new ArrayList<>(); + long orResidual = getConstantsEncompassing(val, type, group, orConstants); + long negatedLiteral = type.toUnsignedLong(type.negate(val)); + List negatedConstants = new ArrayList<>(); + long negatedResidual = getConstantsEncompassing(negatedLiteral, type, group, negatedConstants); + + boolean negated = negatedResidual == 0 && (orResidual != 0 || negatedConstants.size() < orConstants.size()); + List constants = negated ? negatedConstants : orConstants; + if (constants.isEmpty()) + return null; + + long residual = negated ? negatedResidual : orResidual; + + StringBuilder replacement = new StringBuilder(write(constants.getFirst())); + for (int i = 1; i < constants.size(); i++) { + replacement.append(" | "); + replacement.append(write(constants.get(i))); + } + + if (residual != 0) { + replacement.append(" | ").append(residual); + } + + if (negated) { + return "~" + replacement; + } + + return replacement.toString(); + } + + /** + * Adds the constants that encompass {@code literal} to {@code constantsOut}. + * Returns the residual (bits set in the literal not covered by the returned constants). + */ + private static long getConstantsEncompassing(long literal, NumberType unsign, UnpickCollection.Group group, List constantsOut) { + long residual = literal; + for (var constant : group.constants().entrySet()) { + long val = unsign.toUnsignedLong((Number) constant.getKey()); + if ((val & residual) != 0 && (val & literal) == val) { + residual &= ~val; + constantsOut.add(constant.getValue()); + if (residual == 0) + break; + } + } + return residual; + } + + private static String removeSuffix(String in, char suffix) { + var lastChar = in.charAt(in.length() - 1); + if (lastChar == suffix || lastChar == Character.toUpperCase(suffix)) { + return in.substring(0, in.length() - 1); + } + return in; + } +} diff --git a/unpick/src/main/resources/META-INF/services/net.neoforged.jst.api.SourceTransformerPlugin b/unpick/src/main/resources/META-INF/services/net.neoforged.jst.api.SourceTransformerPlugin new file mode 100644 index 0000000..da41d00 --- /dev/null +++ b/unpick/src/main/resources/META-INF/services/net.neoforged.jst.api.SourceTransformerPlugin @@ -0,0 +1 @@ +net.neoforged.jst.unpick.UnpickPlugin