Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,17 +180,11 @@ def inner():

code = compile(codestr, "<test>", "exec")
assert "module doc" in code.co_consts
assert 1 in code.co_consts
assert "fn doc" not in code.co_consts
for const in code.co_consts:
if type(const) == types.CodeType:
code = const
assert "fn doc" in code.co_consts
assert "this is fun" not in code.co_consts
for const in code.co_consts:
if type(const) == types.CodeType:
code = const
assert "this is fun" in code.co_consts


def test_generator_code_consts():
Expand Down Expand Up @@ -301,8 +295,7 @@ def bar():
bar = foo()
assert bar.__code__ is foo().__code__
i = foo.__code__.co_consts.index(bar.__code__)
# TODO this is currently broken on the DSL interpreter because the code unit in constants is a separate copy
# assert bar.__code__ is foo.__code__.co_consts[i]
assert bar.__code__ is foo.__code__.co_consts[i]
assert bar.__code__ is bar().f_code

foo_copy = types.FunctionType(marshal.loads(marshal.dumps(foo.__code__)), globals=foo.__globals__, closure=foo.__closure__)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved.
# Copyright (c) 2018, 2026, Oracle and/or its affiliates. All rights reserved.
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
#
# The Universal Permissive License (UPL), Version 1.0
Expand Down Expand Up @@ -779,7 +779,6 @@ def test_annotations_in_function():
exec(code,test_globals)
assert len(test_globals['__annotations__']) == 0
assert len(test_globals['fn'].__annotations__) == 0
assert 1 not in test_globals['fn'].__code__.co_consts # the annotation is ignored in function

source = '''def fn():
a:int =1
Expand All @@ -789,7 +788,6 @@ def test_annotations_in_function():
assert len(test_globals['__annotations__']) == 0
assert hasattr(test_globals['fn'], '__annotations__')
assert len(test_globals['fn'].__annotations__) == 0
assert 1 in test_globals['fn'].__code__.co_consts

def test_annotations_in_class():

Expand Down Expand Up @@ -849,58 +847,6 @@ def test_annotations_in_class():
assert test_globals['Style'].__annotations__['_path'] == str
assert '_path' in dir(test_globals['Style'])

def test_negative_float():

def check_const(fn, expected):
for const in fn.__code__.co_consts:
if repr(const) == repr(expected):
return True
else:
return False

def fn1():
return -0.0

assert check_const(fn1, -0.0)


def find_count_in(collection, what):
count = 0;
for item in collection:
if item == what:
count +=1
return count

def test_same_consts():
def fn1(): a = 1; b = 1; return a + b
assert find_count_in(fn1.__code__.co_consts, 1) == 1

def fn2(): a = 'a'; b = 'a'; return a + b
assert find_count_in(fn2.__code__.co_consts, 'a') == 1

def test_tuple_in_const():
def fn1() : return (0,)
assert (0,) in fn1.__code__.co_consts
assert 0 not in fn1.__code__.co_consts

def fn2() : return (1, 2, 3, 1, 2, 3)
assert (1, 2, 3, 1, 2, 3) in fn2.__code__.co_consts
assert 1 not in fn2.__code__.co_consts
assert 2 not in fn2.__code__.co_consts
assert 3 not in fn2.__code__.co_consts
assert find_count_in(fn2.__code__.co_consts, (1, 2, 3, 1, 2, 3)) == 1

def fn3() : a = 1; return (1, 2, 1)
assert (1, 2, 1) in fn3.__code__.co_consts
assert find_count_in(fn3.__code__.co_consts, 1) == 1
assert 2 not in fn3.__code__.co_consts

def fn4() : a = 1; b = (1,2,3); c = 4; return (1, 2, 3, 1, 2, 3)
assert (1, 2, 3) in fn4.__code__.co_consts
assert (1, 2, 3, 1, 2, 3) in fn4.__code__.co_consts
assert 2 not in fn4.__code__.co_consts
assert find_count_in(fn4.__code__.co_consts, 1) == 1
assert find_count_in(fn4.__code__.co_consts, 4) == 1

def test_ComprehensionGeneratorExpr():
def create_list(gen):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1512,6 +1512,11 @@ private void writeBytecodeCodeUnit(BytecodeCodeUnit code) throws IOException {
}

private void writeBytecodeDSLCodeUnit(BytecodeDSLCodeUnit code) throws IOException {
/*
* Nested code units referenced by MakeFunction are stored in co_consts; the
* MakeFunction instruction itself carries only the integer index into this constants
* array.
*/
byte[] serialized = code.getSerialized(context);
writeBytes(serialized);
writeString(code.name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,30 +41,20 @@
package com.oracle.graal.python.builtins.objects.code;

import static com.oracle.graal.python.nodes.StringLiterals.J_EMPTY_STRING;
import static com.oracle.graal.python.util.PythonUtils.EMPTY_OBJECT_ARRAY;
import static com.oracle.graal.python.util.PythonUtils.EMPTY_TRUFFLESTRING_ARRAY;
import static com.oracle.graal.python.util.PythonUtils.toInternedTruffleStringUncached;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.oracle.graal.python.PythonLanguage;
import com.oracle.graal.python.builtins.objects.PNone;
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
import com.oracle.graal.python.builtins.objects.ellipsis.PEllipsis;
import com.oracle.graal.python.builtins.objects.function.Signature;
import com.oracle.graal.python.builtins.objects.generator.PGenerator;
import com.oracle.graal.python.builtins.objects.object.PythonBuiltinObject;
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
import com.oracle.graal.python.compiler.BytecodeCodeUnit;
import com.oracle.graal.python.compiler.CodeUnit;
import com.oracle.graal.python.compiler.OpCodes;
import com.oracle.graal.python.nodes.PRootNode;
import com.oracle.graal.python.nodes.bytecode.PBytecodeGeneratorFunctionRootNode;
import com.oracle.graal.python.nodes.bytecode.PBytecodeGeneratorRootNode;
Expand Down Expand Up @@ -136,8 +126,6 @@ public final class PCode extends PythonBuiltinObject {
// qualified name with which this code object was defined
private TruffleString qualname;

private Map<CodeUnit, PCode> childCode;

// number of first line in Python source code
private int firstlineno = -1;
// is a string encoding the mapping from bytecode offsets to line numbers
Expand Down Expand Up @@ -293,55 +281,23 @@ private static TruffleString[] extractVarnames(RootNode node) {
return EMPTY_TRUFFLESTRING_ARRAY;
}

private Object[] ensureConstants() {
if (constants == null) {
CodeUnit codeUnit = getCodeUnit(getRootNode());
constants = codeUnit != null ? new Object[codeUnit.constants.length] : PythonUtils.EMPTY_OBJECT_ARRAY;
}
return constants;
}

@TruffleBoundary
private Object[] extractConstants(RootNode node) {
RootNode rootNode = rootNodeForExtraction(node);
if (PythonOptions.ENABLE_BYTECODE_DSL_INTERPRETER) {
if (rootNode instanceof PBytecodeDSLRootNode bytecodeDSLRootNode) {
BytecodeDSLCodeUnit co = bytecodeDSLRootNode.getCodeUnit();
List<Object> constants = new ArrayList<>();
for (int i = 0; i < co.constants.length; i++) {
Object constant = convertConstantToPythonSpace(co.constants[i]);
constants.add(constant);
}
return constants.toArray(new Object[0]);
}
} else if (rootNode instanceof PBytecodeRootNode bytecodeRootNode) {
BytecodeCodeUnit co = bytecodeRootNode.getCodeUnit();
Set<Object> bytecodeConstants = new HashSet<>();
for (int bci = 0; bci < co.code.length;) {
OpCodes op = OpCodes.fromOpCode(co.code[bci]);
if (op.quickens != null) {
op = op.quickens;
}
if (op == OpCodes.LOAD_BYTE) {
bytecodeConstants.add(Byte.toUnsignedInt(co.code[bci + 1]));
} else if (op == OpCodes.LOAD_NONE) {
bytecodeConstants.add(PNone.NONE);
} else if (op == OpCodes.LOAD_TRUE) {
bytecodeConstants.add(true);
} else if (op == OpCodes.LOAD_FALSE) {
bytecodeConstants.add(false);
} else if (op == OpCodes.LOAD_ELLIPSIS) {
bytecodeConstants.add(PEllipsis.INSTANCE);
} else if (op == OpCodes.LOAD_INT || op == OpCodes.LOAD_LONG) {
bytecodeConstants.add(co.primitiveConstants[Byte.toUnsignedInt(co.code[bci + 1])]);
} else if (op == OpCodes.LOAD_DOUBLE) {
bytecodeConstants.add(Double.longBitsToDouble(co.primitiveConstants[Byte.toUnsignedInt(co.code[bci + 1])]));
}
bci += op.length();
}
List<Object> constants = new ArrayList<>();
for (int i = 0; i < co.constants.length; i++) {
Object constant = convertConstantToPythonSpace(co.constants[i]);
if (constant != PNone.NONE || !bytecodeConstants.contains(PNone.NONE)) {
constants.add(constant);
}
}
constants.addAll(bytecodeConstants);
return constants.toArray(new Object[0]);
private Object getOrCreateConstant(int index) {
Object[] cachedConstants = ensureConstants();
Object constant = cachedConstants[index];
if (constant == null) {
constant = convertConstantToPythonSpace(index);
cachedConstants[index] = constant;
}
return EMPTY_OBJECT_ARRAY;
return constant;
}

@TruffleBoundary
Expand Down Expand Up @@ -419,9 +375,11 @@ public void fixCoFilename(TruffleString filename) {
* New code objects inherit the filename from parent, so no need to eagerly construct them
* here
*/
if (childCode != null) {
for (PCode code : childCode.values()) {
code.filename = filename;
if (constants != null) {
for (Object constant : constants) {
if (constant instanceof PCode code) {
code.filename = filename;
}
}
}
}
Expand Down Expand Up @@ -525,65 +483,66 @@ public CodeUnit getCodeUnit() {
}

public Object[] getConstants() {
if (constants == null) {
constants = extractConstants(getRootNode());
Object[] cachedConstants = ensureConstants();
for (int i = 0; i < cachedConstants.length; i++) {
getOrCreateConstant(i);
}
return constants;
return cachedConstants;
}

@TruffleBoundary
public PCode getOrCreateChildCode(BytecodeDSLCodeUnit codeUnit) {
PCode code = null;
if (childCode == null) {
childCode = new HashMap<>();
} else {
code = childCode.get(codeUnit);
}
public PCode getOrCreateChildCode(int index, BytecodeDSLCodeUnit codeUnit) {
Object[] cachedConstants = ensureConstants();
PCode code = (PCode) cachedConstants[index];
if (code == null) {
PBytecodeDSLRootNode outerRootNode = (PBytecodeDSLRootNode) getRootNode();
PythonLanguage language = outerRootNode.getLanguage();
RootCallTarget callTarget = language.createCachedCallTarget(l -> codeUnit.createRootNode(PythonContext.get(null), outerRootNode.getSource()), codeUnit);
PBytecodeDSLRootNode rootNode = (PBytecodeDSLRootNode) callTarget.getRootNode();
code = PFactory.createCode(language, callTarget, rootNode.getSignature(), codeUnit, getFilename());
childCode.put(codeUnit, code);
code = createCode(codeUnit);
cachedConstants[index] = code;
}
return code;
}

@TruffleBoundary
public PCode getOrCreateChildCode(BytecodeCodeUnit codeUnit) {
PCode code = null;
if (childCode == null) {
childCode = new HashMap<>();
} else {
code = childCode.get(codeUnit);
}
private PCode createCode(BytecodeDSLCodeUnit codeUnit) {
PBytecodeDSLRootNode outerRootNode = (PBytecodeDSLRootNode) getRootNode();
PythonLanguage language = outerRootNode.getLanguage();
RootCallTarget callTarget = language.createCachedCallTarget(l -> codeUnit.createRootNode(PythonContext.get(null), outerRootNode.getSource()), codeUnit);
PBytecodeDSLRootNode rootNode = (PBytecodeDSLRootNode) callTarget.getRootNode();
return PFactory.createCode(language, callTarget, rootNode.getSignature(), codeUnit, getFilename());
}

public PCode getOrCreateChildCode(int index, BytecodeCodeUnit codeUnit) {
Object[] cachedConstants = ensureConstants();
PCode code = (PCode) cachedConstants[index];
if (code == null) {
PBytecodeRootNode outerRootNode = (PBytecodeRootNode) getRootNodeForExtraction();
PythonLanguage language = outerRootNode.getLanguage();
RootCallTarget callTarget = language.createCachedCallTarget(
l -> PBytecodeRootNode.createMaybeGenerator(language, codeUnit, outerRootNode.getLazySource(), outerRootNode.isInternal()),
codeUnit);
RootNode rootNode = callTarget.getRootNode();
if (rootNode instanceof PBytecodeGeneratorFunctionRootNode generatorRoot) {
rootNode = generatorRoot.getBytecodeRootNode();
}
code = PFactory.createCode(language, callTarget, ((PBytecodeRootNode) rootNode).getSignature(), codeUnit, getFilename());
childCode.put(codeUnit, code);
code = createCode(codeUnit);
cachedConstants[index] = code;
}
return code;
}

@TruffleBoundary
private Object convertConstantToPythonSpace(Object o) {
private PCode createCode(BytecodeCodeUnit codeUnit) {
PBytecodeRootNode outerRootNode = (PBytecodeRootNode) getRootNodeForExtraction();
PythonLanguage language = outerRootNode.getLanguage();
RootCallTarget callTarget = language.createCachedCallTarget(
l -> PBytecodeRootNode.createMaybeGenerator(language, codeUnit, outerRootNode.getLazySource(), outerRootNode.isInternal()), codeUnit);
RootNode rootNode = callTarget.getRootNode();
if (rootNode instanceof PBytecodeGeneratorFunctionRootNode generatorRoot) {
rootNode = generatorRoot.getBytecodeRootNode();
}
return PFactory.createCode(language, callTarget, ((PBytecodeRootNode) rootNode).getSignature(), codeUnit, getFilename());
}

@TruffleBoundary
private Object convertConstantToPythonSpace(int index) {
Object o = getCodeUnit().constants[index];
PythonLanguage language = PythonLanguage.get(null);
if (o instanceof CodeUnit) {
if (PythonOptions.ENABLE_BYTECODE_DSL_INTERPRETER) {
BytecodeDSLCodeUnit code = (BytecodeDSLCodeUnit) o;
return getOrCreateChildCode(code);
return getOrCreateChildCode(index, code);
} else {
BytecodeCodeUnit code = (BytecodeCodeUnit) o;
return getOrCreateChildCode(code);
return getOrCreateChildCode(index, code);
}
} else if (o instanceof BigInteger) {
return PFactory.createInt(language, (BigInteger) o);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2025, 2026, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* The Universal Permissive License (UPL), Version 1.0
Expand Down Expand Up @@ -41,6 +41,8 @@
package com.oracle.graal.python.compiler.bytecode_dsl;

import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;

import com.oracle.graal.python.PythonLanguage;
import com.oracle.graal.python.compiler.Compiler;
Expand Down Expand Up @@ -94,6 +96,8 @@ public static class BytecodeDSLCompilerContext {
public final int futureLineNumber;
public final ParserCallbacksImpl errorCallback;
public final ScopeEnvironment scopeEnvironment;
// Store code units for possible reparses
public final Map<Object, BytecodeDSLCodeUnit> codeUnits = new HashMap<>();

public BytecodeDSLCompilerContext(PythonLanguage language, ModTy mod, Source source, int optimizationLevel,
EnumSet<FutureFeature> futureFeatures, int futureLineNumber, ParserCallbacksImpl errorCallback, ScopeEnvironment scopeEnvironment) {
Expand Down
Loading
Loading