Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,8 @@
Class<?> retConversion() default void.class;

Class<?>[] argConversions() default {};

boolean critical() default false;

boolean captureCallState() default false;
}
Original file line number Diff line number Diff line change
Expand Up @@ -1411,8 +1411,8 @@ private void generateExternalFunctionInvoker(List<CApiExternalFunctionSignatureD
lines.add("package " + EXFUNC_INVOKER_PACKAGE + ";");
lines.add("");
lines.add("import java.lang.invoke.MethodHandle;");
lines.add("import java.lang.invoke.MethodType;");
lines.add("");
lines.add("import " + NativeSimpleType.class.getCanonicalName() + ";");
lines.add("import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.EnsurePythonObjectNode;");
lines.add("import com.oracle.graal.python.builtins.objects.cext.capi.transitions.CApiTiming;");
lines.add("import com.oracle.graal.python.builtins.objects.function.PArguments;");
Expand Down Expand Up @@ -1600,13 +1600,24 @@ private void generateNativeAccessSupport(Element[] origins) throws IOException {
lines.add("import java.lang.invoke.MethodHandle;");
lines.add("import java.lang.invoke.MethodHandles;");
lines.add("import java.lang.invoke.MethodType;");
lines.add("import java.lang.invoke.VarHandle;");
lines.add("import java.util.ArrayList;");
lines.add("import java.util.List;");
lines.add("import java.util.OptionalLong;");
lines.add("");
lines.add("import static com.oracle.truffle.api.CompilerDirectives.shouldNotReachHere;");
lines.add("import com.oracle.graal.python.PythonLanguage;");
lines.add("import " + NativeSimpleType.class.getCanonicalName() + ";");
lines.add("import com.oracle.graal.python.annotations.PythonOS;");
lines.add("import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;");
lines.add("");
lines.add("public final class " + NATIVE_ACCESS_SUPPORT_IMPL_CLASS_NAME + " extends " + NATIVE_ACCESS_SUPPORT_CLASS_NAME + " {");
lines.add(" private static final MethodHandle OF_ADDRESS;");
lines.add(" private static final boolean WINDOWS = PythonLanguage.getPythonOS() == PythonOS.PLATFORM_WIN32;");
lines.add(" private static final MemoryLayout CAPTURE_STATE_LAYOUT = Linker.Option.captureStateLayout();");
lines.add(" private static final VarHandle ERRNO = captureStateVarHandle(\"errno\");");
lines.add(" private static final VarHandle GET_LAST_ERROR = WINDOWS ? captureStateVarHandle(\"GetLastError\") : null;");
lines.add(" private static final Linker.Option CAPTURE_CALL_STATE_OPTION = WINDOWS ? Linker.Option.captureCallState(\"errno\", \"GetLastError\", \"WSAGetLastError\") : Linker.Option.captureCallState(\"errno\");");
lines.add("");
lines.add(" static {");
lines.add(" try {");
Expand All @@ -1616,6 +1627,15 @@ private void generateNativeAccessSupport(Element[] origins) throws IOException {
lines.add(" }");
lines.add(" }");
lines.add("");
lines.add(" private static VarHandle captureStateVarHandle(String stateName) {");
lines.add(" return CAPTURE_STATE_LAYOUT.varHandle(MemoryLayout.PathElement.groupElement(stateName));");
lines.add(" }");
lines.add("");
lines.add(" private static MethodType createMethodType(boolean captureCallState, FunctionDescriptor functionDescriptor) {");
lines.add(" MethodType methodType = functionDescriptor.toMethodType().insertParameterTypes(0, long.class);");
lines.add(" return captureCallState ? methodType.insertParameterTypes(1, MemorySegment.class) : methodType;");
lines.add(" }");
lines.add("");
lines.add(" @Override");
lines.add(" protected Object createArenaImpl() {");
lines.add(" return Arena.ofShared();");
Expand All @@ -1640,15 +1660,42 @@ private void generateNativeAccessSupport(Element[] origins) throws IOException {
lines.add("");
lines.add(" @Override");
lines.add(" @SuppressWarnings(\"restricted\")");
lines.add(" protected MethodHandle createDowncallHandleImpl(MethodType methodType, boolean critical) {");
lines.add(" FunctionDescriptor functionDescriptor = createFunctionDescriptor(methodType);");
lines.add(" Linker.Option[] options = critical ? new Linker.Option[] { Linker.Option.critical(false) } : new Linker.Option[0];");
lines.add(" MethodHandle methodHandle = Linker.nativeLinker().downcallHandle(functionDescriptor, options);");
lines.add(" protected MethodHandle createDowncallHandleImpl(boolean critical, boolean captureCallState, NativeSimpleType resType, NativeSimpleType[] argTypes) {");
lines.add(" FunctionDescriptor functionDescriptor = createFunctionDescriptor(resType, argTypes);");
lines.add(" MethodType methodType = createMethodType(captureCallState, functionDescriptor);");
lines.add(" List<Linker.Option> options = new ArrayList<>();");
lines.add(" if (critical) {");
lines.add(" options.add(Linker.Option.critical(false));");
lines.add(" }");
lines.add(" if (captureCallState) {");
lines.add(" options.add(CAPTURE_CALL_STATE_OPTION);");
lines.add(" }");
lines.add(" MethodHandle methodHandle = Linker.nativeLinker().downcallHandle(functionDescriptor, options.toArray(Linker.Option[]::new));");
lines.add(" methodHandle = MethodHandles.filterArguments(methodHandle, 0, OF_ADDRESS);");
lines.add(" return methodHandle.asType(methodType);");
lines.add(" }");
lines.add("");
lines.add(" @Override");
lines.add(" protected Object createCapturedCallStateImpl(Object arena) {");
lines.add(" return ((Arena) arena).allocate(CAPTURE_STATE_LAYOUT);");
lines.add(" }");
lines.add("");
lines.add(" @TruffleBoundary");
lines.add(" @Override");
lines.add(" protected int readCapturedErrnoImpl(Object capturedCallState) {");
lines.add(" return (int) ERRNO.get((MemorySegment) capturedCallState, 0L);");
lines.add(" }");
lines.add("");
lines.add(" @TruffleBoundary");
lines.add(" @Override");
lines.add(" protected int readCapturedGetLastErrorImpl(Object capturedCallState) {");
lines.add(" if (GET_LAST_ERROR == null) {");
lines.add(" throw shouldNotReachHere(\"GetLastError is only captured on Windows\");");
lines.add(" }");
lines.add(" return (int) GET_LAST_ERROR.get((MemorySegment) capturedCallState, 0L);");
lines.add(" }");
lines.add("");
lines.add(" @Override");
lines.add(" @SuppressWarnings(\"restricted\")");
lines.add(" protected long createClosureImpl(MethodHandle staticMethodHandle, NativeSimpleType resType, NativeSimpleType[] argTypes, Object arena) {");
lines.add(" FunctionDescriptor functionDescriptor = createFunctionDescriptor(resType, argTypes);");
Expand All @@ -1668,33 +1715,6 @@ private void generateNativeAccessSupport(Element[] origins) throws IOException {
lines.add(" return resType == NativeSimpleType.VOID ? FunctionDescriptor.ofVoid(argLayouts) : FunctionDescriptor.of(asLayout(resType), argLayouts);");
lines.add(" }");
lines.add("");
lines.add(" private static FunctionDescriptor createFunctionDescriptor(MethodType methodType) {");
lines.add(" Class<?>[] parameterTypes = methodType.parameterArray();");
lines.add(" MemoryLayout[] argLayouts = new MemoryLayout[parameterTypes.length - 1];");
lines.add(" for (int i = 1; i < parameterTypes.length; i++) {");
lines.add(" argLayouts[i - 1] = asLayout(parameterTypes[i]);");
lines.add(" }");
lines.add(" Class<?> returnType = methodType.returnType();");
lines.add(" return returnType == void.class ? FunctionDescriptor.ofVoid(argLayouts) : FunctionDescriptor.of(asLayout(returnType), argLayouts);");
lines.add(" }");
lines.add("");
lines.add(" private static MemoryLayout asLayout(Class<?> type) {");
lines.add(" if (type == byte.class) {");
lines.add(" return ValueLayout.JAVA_BYTE;");
lines.add(" } else if (type == short.class) {");
lines.add(" return ValueLayout.JAVA_SHORT;");
lines.add(" } else if (type == int.class) {");
lines.add(" return ValueLayout.JAVA_INT;");
lines.add(" } else if (type == long.class) {");
lines.add(" return ValueLayout.JAVA_LONG;");
lines.add(" } else if (type == float.class) {");
lines.add(" return ValueLayout.JAVA_FLOAT;");
lines.add(" } else if (type == double.class) {");
lines.add(" return ValueLayout.JAVA_DOUBLE;");
lines.add(" }");
lines.add(" throw shouldNotReachHere(\"Unsupported layout carrier: \" + type);");
lines.add(" }");
lines.add("");
lines.add(" private static MemoryLayout asLayout(NativeSimpleType type) {");
lines.add(" return switch (type) {");
lines.add(" case VOID -> throw shouldNotReachHere(\"VOID has no layout\");");
Expand Down Expand Up @@ -1724,7 +1744,6 @@ private void generateDummyNativeAccessSupport(Element[] origins) throws IOExcept
lines.add("package " + NATIVE_ACCESS_PACKAGE + ";");
lines.add("");
lines.add("import java.lang.invoke.MethodHandle;");
lines.add("import java.lang.invoke.MethodType;");
lines.add("");
lines.add("import " + NativeSimpleType.class.getCanonicalName() + ";");
lines.add("");
Expand All @@ -1750,7 +1769,22 @@ private void generateDummyNativeAccessSupport(Element[] origins) throws IOExcept
lines.add(" }");
lines.add("");
lines.add(" @Override");
lines.add(" protected MethodHandle createDowncallHandleImpl(MethodType methodType, boolean critical) {");
lines.add(" protected MethodHandle createDowncallHandleImpl(boolean critical, boolean captureCallState, NativeSimpleType resType, NativeSimpleType[] argTypes) {");
lines.add(" throw unsupported();");
lines.add(" }");
lines.add("");
lines.add(" @Override");
lines.add(" protected Object createCapturedCallStateImpl(Object arena) {");
lines.add(" throw unsupported();");
lines.add(" }");
lines.add("");
lines.add(" @Override");
lines.add(" protected int readCapturedErrnoImpl(Object state) {");
lines.add(" throw unsupported();");
lines.add(" }");
lines.add("");
lines.add(" @Override");
lines.add(" protected int readCapturedGetLastErrorImpl(Object state) {");
lines.add(" throw unsupported();");
lines.add(" }");
lines.add("");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,12 @@
import com.oracle.graal.python.annotations.NativeSimpleType;

public class GenerateNativeDowncallsProcessor extends AbstractProcessor {
private record NativeDowncallDesc(String name, String symbolName, NativeSimpleType returnType, List<NativeSimpleType> argumentTypes, List<String> argumentNames) {
private record NativeDowncallDesc(String name, NativeSimpleType returnType, List<NativeSimpleType> argumentTypes, List<String> argumentNames, boolean critical, boolean captureCallState) {
}

private static final String CAPTURE_CALL_STATE_FIELD = "callState";
private static final String NATIVE_MEMORY_SEGMENT_CLASS_FIELD = "NATIVE_MEMORY_SEGMENT_CLASS";

@Override
public Set<String> getSupportedAnnotationTypes() {
return Set.of(DowncallSignature.class.getName());
Expand Down Expand Up @@ -124,6 +127,9 @@ private void generateInvoker(TypeElement invokerElement) throws IOException, Pro
throw error(invokerElement, "Annotated class does not declare any downcalls");
}

// Indicates if any of the annotations wants to capture the state. In this case, we need to generate appropriate code.
boolean hasCapturedCallStates = hasCapturedCallStates(downcalls);

String packageName = processingEnv.getElementUtils().getPackageOf(invokerElement).getQualifiedName().toString();
String invokerQualifiedName = invokerElement.getQualifiedName().toString();
String invokerTypeRef = invokerQualifiedName.startsWith(packageName + ".") ? invokerQualifiedName.substring(packageName.length() + 1) : invokerQualifiedName;
Expand All @@ -135,28 +141,38 @@ private void generateInvoker(TypeElement invokerElement) throws IOException, Pro
lines.add("// Generated by annotation processor: " + getClass().getName());
lines.add("package " + packageName + ";");
lines.add("");
lines.add("import java.lang.foreign.MemorySegment;");
lines.add("import java.lang.invoke.MethodHandle;");
lines.add("import java.lang.invoke.MethodType;");
lines.add("import java.util.concurrent.atomic.AtomicLongArray;");
lines.add("import java.util.List;");
lines.add("");
lines.add("import com.oracle.graal.python.annotations.NativeSimpleType;");
lines.add("import com.oracle.graal.python.runtime.nativeaccess.NativeAccessSupport;");
lines.add("import com.oracle.graal.python.runtime.nativeaccess.NativeLibrary;");
lines.add("import com.oracle.truffle.api.CompilerDirectives;");
lines.add("import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;");
lines.add("");
lines.add("final class " + className + " extends " + invokerTypeRef + " {");
lines.add(" private static final Class<? extends MemorySegment> " + NATIVE_MEMORY_SEGMENT_CLASS_FIELD + " = MemorySegment.NULL.getClass();");
lines.add(" private final PythonContext context;");
lines.add(" private final AtomicLongArray cachedFunctions = new AtomicLongArray(" + downcalls.size() + ");");
lines.add(" private volatile NativeLibrary nativeLibrary;");
if (hasCapturedCallStates) {
lines.add(" private final ThreadLocal<Object> " + CAPTURE_CALL_STATE_FIELD + ";");
}
lines.add("");

for (NativeDowncallDesc downcall : downcalls) {
NativeDowncallMethodHandleGenerator.emitMethodHandleField(lines, methodHandleName(downcall.name), downcall.returnType, downcall.argumentTypes);
NativeDowncallMethodHandleGenerator.emitMethodHandleField(lines, methodHandleName(downcall.name), downcall.returnType, downcall.argumentTypes, downcall.critical,
downcall.captureCallState);
}

lines.add("");
lines.add(" " + className + "(PythonContext context) {");
lines.add(" this.context = context;");
if (hasCapturedCallStates) {
lines.add(" this." + CAPTURE_CALL_STATE_FIELD + " = ThreadLocal.withInitial(() -> context.ensureNativeContext().getCapturedCallStateTL().get());");
}
lines.add(" }");

for (int i = 0; i < downcalls.size(); i++) {
Expand Down Expand Up @@ -229,13 +245,13 @@ private static NativeDowncallDesc extractDowncall(ExecutableElement method) thro

List<NativeSimpleType> argumentTypes = List.of(argTypes);
List<String> argumentNames = extractArgumentNames(method);
String symbolName = method.getSimpleName().toString();
return new NativeDowncallDesc(
symbolName,
symbolName,
method.getSimpleName().toString(),
annotation.returnType(),
argumentTypes,
argumentNames);
argumentNames,
annotation.critical(),
annotation.captureCallState());
}

private static List<String> extractArgumentNames(ExecutableElement method) throws ProcessingError {
Expand Down Expand Up @@ -289,12 +305,12 @@ private static void emitDowncallMethod(List<String> lines, NativeDowncallDesc do
lines.add(" @TruffleBoundary(allowInlining = true, transferToInterpreterOnException = false)");
lines.add(" @Override");
lines.add(" " + nativeSimpleTypeToJavaType(downcall.returnType) + " " + downcall.name + "(" + typedArgs(downcall.argumentTypes, downcall.argumentNames) + ") {");
lines.add(" long functionPointer = lookup(" + functionIndex + ", " + stringLiteral(downcall.symbolName) + ");");
lines.add(" long functionPointer = lookup(" + functionIndex + ", " + stringLiteral(downcall.name) + ");");
lines.add(" try {");
if (NativeSimpleType.VOID == downcall.returnType) {
lines.add(" " + methodHandleName(downcall.name) + ".invokeExact(" + invokeArgs(downcall.argumentNames) + ");");
lines.add(" " + methodHandleName(downcall.name) + ".invokeExact(" + invokeArgs(downcall) + ");");
} else {
lines.add(" return (" + nativeSimpleTypeToJavaType(downcall.returnType) + ") " + methodHandleName(downcall.name) + ".invokeExact(" + invokeArgs(downcall.argumentNames) + ");");
lines.add(" return (" + nativeSimpleTypeToJavaType(downcall.returnType) + ") " + methodHandleName(downcall.name) + ".invokeExact(" + invokeArgs(downcall) + ");");
}
lines.add(" } catch (Throwable t) {");
lines.add(" throw CompilerDirectives.shouldNotReachHere(t);");
Expand All @@ -314,12 +330,27 @@ private static String typedArgs(List<NativeSimpleType> argTypes, List<String> ar
return String.join(", ", args);
}

private static String invokeArgs(List<String> argNames) {
String args = String.join(", ", argNames);
return args.isEmpty() ? "functionPointer" : "functionPointer, " + args;
private static String invokeArgs(NativeDowncallDesc downcall) {
List<String> allArgs = new ArrayList<>();
allArgs.add("functionPointer");
if (downcall.captureCallState) {
String capturedCallState = String.format("%s.cast(%s.get())", NATIVE_MEMORY_SEGMENT_CLASS_FIELD, CAPTURE_CALL_STATE_FIELD);
allArgs.add(capturedCallState);
}
allArgs.addAll(downcall.argumentNames);
return String.join(", ", allArgs);
}

private static String stringLiteral(String value) {
return "\"" + value.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
}

private static boolean hasCapturedCallStates(List<NativeDowncallDesc> downcalls) {
for (NativeDowncallDesc downcall : downcalls) {
if (downcall.captureCallState) {
return true;
}
}
return false;
}
}
Loading
Loading