diff --git a/graalpython/com.oracle.graal.python.test.integration/src/com/oracle/graal/python/test/integration/builtin/HashingTest.java b/graalpython/com.oracle.graal.python.test.integration/src/com/oracle/graal/python/test/integration/builtin/HashingTest.java index 17dc9bdc5a..749b357add 100644 --- a/graalpython/com.oracle.graal.python.test.integration/src/com/oracle/graal/python/test/integration/builtin/HashingTest.java +++ b/graalpython/com.oracle.graal.python.test.integration/src/com/oracle/graal/python/test/integration/builtin/HashingTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2026, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * The Universal Permissive License (UPL), Version 1.0 @@ -197,8 +197,9 @@ public void dictEqualTest3() { @Test public void setAndTest() { - assertPrints("{2}\n", "print({2, 3} ^ {3})\n"); - assertPrints("{'c', 'b'}\n", "print({'a', 'c'} ^ frozenset({'a', 'b'}))\n"); - assertPrints("frozenset({'b'})\n", "print(frozenset({'a', 'c'}) ^ {'a', 'b', 'c'})\n"); + String source = "assert {2, 3} ^ {3} == {2}\n" + + "assert {'a', 'c'} ^ frozenset({'a', 'b'}) == {'b', 'c'}\n" + + "assert frozenset({'a', 'c'}) ^ {'a', 'b', 'c'} == frozenset({'b'})\n"; + assertPrints("", source); } } diff --git a/graalpython/com.oracle.graal.python.test/src/tests/test_dict.py b/graalpython/com.oracle.graal.python.test/src/tests/test_dict.py index b87a4082f6..3f6a93465a 100644 --- a/graalpython/com.oracle.graal.python.test/src/tests/test_dict.py +++ b/graalpython/com.oracle.graal.python.test/src/tests/test_dict.py @@ -1549,6 +1549,46 @@ def test_eq_side_effects(): assert key.eq_calls == 1 +def test_keys_xor_side_effects(): + key1 = TrackingKey('foo') + key2 = TrackingKey('foo') + d1 = {key1: 1} + d2 = {key2: 2} + key1.clear_observations() + key2.clear_observations() + + assert d1.keys() ^ d2.keys() == set() + assert key1.eq_calls == 1 + assert key1.hash_calls == 0 + assert key2.eq_calls == 0 + assert key2.hash_calls == 1 + + +def test_items_xor_side_effects(): + log = [] + + class Key: + def __init__(self, name): + self.name = name + + def __hash__(self): + log.append(("hash", self.name)) + return 42 + + def __eq__(self, other): + log.append(("eq", self.name, other.name)) + return True + + key1 = Key('left') + key2 = Key('right') + d1 = {key1: 1} + d2 = {key2: 1} + log.clear() + + assert d1.items() ^ d2.items() == set() + assert log == [("eq", "left", "right"), ("eq", "left", "right")] + + # TODO: GR-40680 # def test_iteration_and_del(): # def test_iter(get_iterable): diff --git a/graalpython/com.oracle.graal.python.test/src/tests/test_set.py b/graalpython/com.oracle.graal.python.test/src/tests/test_set.py index 2f587532ff..af1fceac19 100644 --- a/graalpython/com.oracle.graal.python.test/src/tests/test_set.py +++ b/graalpython/com.oracle.graal.python.test/src/tests/test_set.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved. +# Copyright (c) 2018, 2026, Oracle and/or its affiliates. All rights reserved. # DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. # # The Universal Permissive License (UPL), Version 1.0 @@ -596,29 +596,68 @@ def key1_eq_call(key1, key2): test_op(operator.__and__, key1_eq_call) test_op(operator.__iand__, key1_eq_call) - # TODO: GR-42240 - # - # def symmetric_difference_check(key1, key2): - # assert key1.eq_calls == 0 - # assert key1.hash_calls == 0 - # assert key2.eq_calls == 2 - # assert key2.hash_calls == 0 - # - # test_op(set.symmetric_difference, symmetric_difference_check) - # test_op(operator.__xor__, symmetric_difference_check) - # test_op(operator.__ixor__, symmetric_difference_check) - # - # def symmetric_difference_update_check(key1, key2): - # assert key1.eq_calls == 2 - # assert key1.hash_calls == 0 - # assert key2.eq_calls == 0 - # assert key2.hash_calls == 0 - # - # test_op(set.symmetric_difference_update, symmetric_difference_update_check) - # + def symmetric_difference_check(key1, key2): + assert key1.eq_calls == 0 + assert key1.hash_calls == 0 + assert key2.eq_calls == 2 + assert key2.hash_calls == 0 + + test_op(set.symmetric_difference, symmetric_difference_check) + test_op(operator.__xor__, symmetric_difference_check) + + def symmetric_difference_update_check(key1, key2): + assert key1.eq_calls == 2 + assert key1.hash_calls == 0 + assert key2.eq_calls == 0 + assert key2.hash_calls == 0 + + test_op(operator.__ixor__, symmetric_difference_update_check) + test_op(set.symmetric_difference_update, symmetric_difference_update_check) + # TODO: intersection, intersection_update +def test_symmetric_difference_dict_keys_side_effects(): + def test_op(op, check): + key1 = TrackingKey('foo', hash=42) + key2 = TrackingKey('bar', hash=42) + s = {key1} + d = {key2: 1} + key1.clear_observations() + key2.clear_observations() + op(s, d.keys()) + check(key1, key2) + + def symmetric_difference_check(key1, key2): + assert key1.eq_calls == 0 + assert key1.hash_calls == 0 + assert key2.eq_calls == 2 + assert key2.hash_calls == 1 + + def symmetric_difference_update_check(key1, key2): + assert key1.eq_calls == 2 + assert key1.hash_calls == 0 + assert key2.eq_calls == 0 + assert key2.hash_calls == 1 + + test_op(set.symmetric_difference, symmetric_difference_check) + test_op(set.symmetric_difference_update, symmetric_difference_update_check) + + +def test_symmetric_difference_update_empty_side_effects(): + key1 = TrackingKey('foo', hash=42) + key2 = TrackingKey('bar', hash=42) + s = {key1, key2} + key1.clear_observations() + key2.clear_observations() + + s.symmetric_difference_update(set()) + assert key1.eq_calls == 0 + assert key1.hash_calls == 0 + assert key2.eq_calls == 0 + assert key2.hash_calls == 0 + + def test_pop_side_effects(): class TrackingKey: def __init__(self): diff --git a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingCollectionNodes.java b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingCollectionNodes.java index 0102c81b7b..636edde508 100644 --- a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingCollectionNodes.java +++ b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingCollectionNodes.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2026, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * The Universal Permissive License (UPL), Version 1.0 @@ -319,4 +319,39 @@ static HashingStorage doGeneric(VirtualFrame frame, Node inliningTarget, Object return getHashingStorageNode.getForSets(frame, inliningTarget, other); } } + + /** + * CPython's set symmetric_difference_update uses stored hashes for exact sets and dicts, but + * first materializes other iterables, including dict views, as a temporary set. + */ + @GenerateInline(inlineByDefault = true) + public abstract static class GetSetStorageForXorNode extends PNodeWithContext { + + public abstract HashingStorage execute(VirtualFrame frame, Node inliningTarget, Object iterator); + + @Specialization + static HashingStorage doHashingCollection(PHashingCollection other) { + return other.getDictStorage(); + } + + @Specialization(guards = "!isPHashingCollection(other)") + @InliningCutoff + static HashingStorage doGeneric(VirtualFrame frame, Node inliningTarget, Object other, + @Cached PyObjectGetIter getIter, + @Cached PyIterNextNode nextNode, + @Exclusive @Cached HashingStorageSetItem setStorageItem) { + HashingStorage curStorage = EmptyStorage.INSTANCE; + Object iterator = getIter.execute(frame, inliningTarget, other); + while (true) { + Object key; + try { + key = nextNode.execute(frame, inliningTarget, iterator); + } catch (IteratorExhausted e) { + return curStorage; + } + curStorage = setStorageItem.execute(frame, inliningTarget, curStorage, key, PNone.NONE); + } + } + } + } diff --git a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingStorageNodes.java b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingStorageNodes.java index be5db59202..564d8a7b89 100644 --- a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingStorageNodes.java +++ b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingStorageNodes.java @@ -1370,23 +1370,23 @@ public ResultAndOther(ObjectHashMap result, HashingStorage other) { @GenerateInline @GenerateCached(false) @ImportStatic({PGuards.class}) - public abstract static class HashingStorageXorCallback extends HashingStorageForEachCallback { + public abstract static class HashingStorageXorCallback extends HashingStorageForEachCallback { @Override - public abstract ResultAndOther execute(Frame frame, Node inliningTarget, HashingStorage storage, HashingStorageIterator it, ResultAndOther accumulator); + public abstract EconomicMapStorage execute(Frame frame, Node inliningTarget, HashingStorage storage, HashingStorageIterator it, EconomicMapStorage accumulator); @Specialization - static ResultAndOther doGeneric(Frame frame, Node inliningTarget, HashingStorage storage, HashingStorageIterator it, ResultAndOther acc, + static EconomicMapStorage doGeneric(Frame frame, Node inliningTarget, HashingStorage storage, HashingStorageIterator it, EconomicMapStorage acc, @Cached PutNode putResultNode, - @Cached HashingStorageGetItemWithHash getFromOther, + @Cached ObjectHashMap.RemoveNode removeResultNode, @Cached HashingStorageIteratorKey iterKey, @Cached HashingStorageIteratorValue iterValue, @Cached HashingStorageIteratorKeyHash iterHash) { Object key = iterKey.execute(inliningTarget, storage, it); long hash = iterHash.execute(frame, inliningTarget, storage, it); - Object otherValue = getFromOther.execute(frame, inliningTarget, acc.other, key, hash); - if (otherValue == null) { - putResultNode.put(frame, inliningTarget, acc.result, key, hash, iterValue.execute(inliningTarget, storage, it)); + Object removedValue = removeResultNode.execute(frame, inliningTarget, acc, key, hash); + if (removedValue == null) { + putResultNode.put(frame, inliningTarget, acc, key, hash, iterValue.execute(inliningTarget, storage, it)); } return acc; } @@ -1397,22 +1397,41 @@ static ResultAndOther doGeneric(Frame frame, Node inliningTarget, HashingStorage @GenerateCached(false) @ImportStatic({PGuards.class}) public abstract static class HashingStorageXor extends Node { - public abstract HashingStorage execute(Frame frame, Node inliningTarget, HashingStorage a, HashingStorage b); + abstract HashingStorage execute(Frame frame, Node inliningTarget, HashingStorage left, HashingStorage right, boolean leftMayBeMutated); + + public final HashingStorage executePreservingLeft(Frame frame, Node inliningTarget, HashingStorage left, HashingStorage right) { + return execute(frame, inliningTarget, left, right, false); + } + + public final HashingStorage executeMutatingLeft(Frame frame, Node inliningTarget, HashingStorage left, HashingStorage right) { + return execute(frame, inliningTarget, left, right, true); + } @Specialization - static HashingStorage doIt(Frame frame, Node inliningTarget, HashingStorage aStorage, HashingStorage bStorage, - @Cached HashingStorageForEach forEachA, - @Cached HashingStorageForEach forEachB, - @Cached HashingStorageXorCallback callbackA, - @Cached HashingStorageXorCallback callbackB) { - final EconomicMapStorage result = EconomicMapStorage.createWithSideEffects(); - ObjectHashMap resultMap = result; + static HashingStorage doEconomicLeft(Frame frame, Node inliningTarget, EconomicMapStorage leftStorage, HashingStorage rightStorage, + boolean leftMayBeMutated, + @Exclusive @Cached HashingStorageForEach forEachRight, + @Exclusive @Cached HashingStorageXorCallback callback) { + if (leftStorage == rightStorage) { + return EmptyStorage.INSTANCE; + } + EconomicMapStorage result = leftMayBeMutated ? leftStorage : (EconomicMapStorage) leftStorage.copy(); + forEachRight.execute(frame, inliningTarget, rightStorage, callback, result); + return result; + } - ResultAndOther accA = new ResultAndOther(resultMap, bStorage); - forEachA.execute(frame, inliningTarget, aStorage, callbackA, accA); + @Specialization(replaces = "doEconomicLeft") + static HashingStorage doIt(Frame frame, Node inliningTarget, HashingStorage leftStorage, HashingStorage rightStorage, + @SuppressWarnings("unused") boolean leftMayBeMutated, + @Exclusive @Cached HashingStorageForEach forEachLeft, + @Exclusive @Cached HashingStorageForEach forEachRight, + @Exclusive @Cached HashingStorageTransferItem transferItem, + @Exclusive @Cached HashingStorageXorCallback callback) { + final EconomicMapStorage result = EconomicMapStorage.createWithSideEffects(); - ResultAndOther accB = new ResultAndOther(resultMap, aStorage); - forEachB.execute(frame, inliningTarget, bStorage, callbackB, accB); + HashingStorage copiedLeft = forEachLeft.execute(frame, inliningTarget, leftStorage, transferItem, result); + assert copiedLeft == result; + forEachRight.execute(frame, inliningTarget, rightStorage, callback, result); return result; } diff --git a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/dict/DictViewBuiltins.java b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/dict/DictViewBuiltins.java index a82ffa0a3a..387c1796e2 100644 --- a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/dict/DictViewBuiltins.java +++ b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/dict/DictViewBuiltins.java @@ -52,17 +52,31 @@ import com.oracle.graal.python.builtins.CoreFunctions; import com.oracle.graal.python.builtins.PythonBuiltinClassType; import com.oracle.graal.python.builtins.PythonBuiltins; +import com.oracle.graal.python.builtins.objects.PNone; import com.oracle.graal.python.builtins.objects.PNotImplemented; +import com.oracle.graal.python.builtins.objects.common.EconomicMapStorage; +import com.oracle.graal.python.builtins.objects.common.EmptyStorage; +import com.oracle.graal.python.builtins.objects.common.HashingCollectionNodes.GetSetStorageForXorNode; import com.oracle.graal.python.builtins.objects.common.HashingStorage; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageAddAllToOther; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageCopy; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageDiff; +import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageForEach; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageGetItem; +import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageGetItemWithHash; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageGetIterator; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageGetReverseIterator; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageIntersect; +import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageIterator; +import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageIteratorKey; +import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageIteratorKeyHash; +import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageIteratorValue; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageLen; +import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageSetItem; +import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageTransferItem; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageXor; +import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageForEachCallback; +import com.oracle.graal.python.builtins.objects.common.ObjectHashMap; import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes; import com.oracle.graal.python.builtins.objects.dict.DictViewBuiltinsFactory.ContainedInNodeGen; import com.oracle.graal.python.builtins.objects.dict.PDictView.PDictItemsView; @@ -84,6 +98,7 @@ import com.oracle.graal.python.lib.PyObjectSizeNode; import com.oracle.graal.python.lib.PySequenceContainsNode; import com.oracle.graal.python.lib.RichCmpOp; +import com.oracle.graal.python.nodes.PGuards; import com.oracle.graal.python.nodes.PNodeWithContext; import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode; import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode; @@ -100,6 +115,7 @@ import com.oracle.truffle.api.dsl.GenerateInline; import com.oracle.truffle.api.dsl.GenerateNodeFactory; import com.oracle.truffle.api.dsl.GenerateUncached; +import com.oracle.truffle.api.dsl.ImportStatic; import com.oracle.truffle.api.dsl.NeverDefault; import com.oracle.truffle.api.dsl.NodeFactory; import com.oracle.truffle.api.dsl.Specialization; @@ -533,19 +549,137 @@ static PBaseSet doGeneric(VirtualFrame frame, Object self, Object other, } } + static final class DictItemsXorState { + final PythonLanguage language; + final EconomicMapStorage remaining; + HashingStorage resultStorage = EmptyStorage.INSTANCE; + + DictItemsXorState(PythonLanguage language, EconomicMapStorage remaining) { + this.language = language; + this.remaining = remaining; + } + } + + @GenerateInline + @GenerateCached(false) + abstract static class CopyToEconomicMapNode extends Node { + abstract EconomicMapStorage execute(Frame frame, Node inliningTarget, HashingStorage storage); + + @Specialization + static EconomicMapStorage doEconomic(Node inliningTarget, EconomicMapStorage storage, + @Cached HashingStorageCopy copyNode) { + return (EconomicMapStorage) copyNode.execute(inliningTarget, storage); + } + + @Specialization(replaces = "doEconomic") + static EconomicMapStorage doGeneric(Frame frame, Node inliningTarget, HashingStorage storage, + @Cached HashingStorageForEach forEach, + @Cached HashingStorageTransferItem transferItem) { + EconomicMapStorage result = EconomicMapStorage.createWithSideEffects(); + HashingStorage copied = forEach.execute(frame, inliningTarget, storage, transferItem, result); + assert copied == result; + return result; + } + } + + @GenerateInline + @GenerateCached(false) + public abstract static class DictItemsXorCallback extends HashingStorageForEachCallback { + @Override + public abstract DictItemsXorState execute(Frame frame, Node inliningTarget, HashingStorage storage, HashingStorageIterator it, DictItemsXorState state); + + @Specialization + static DictItemsXorState doGeneric(Frame frame, Node inliningTarget, HashingStorage storage, HashingStorageIterator it, DictItemsXorState state, + @Cached HashingStorageGetItemWithHash getItem, + @Cached ObjectHashMap.RemoveNode removeNode, + @Cached PyObjectRichCompareBool eqNode, + @Cached HashingStorageSetItem setItem, + @Cached HashingStorageIteratorKey iterKey, + @Cached HashingStorageIteratorValue iterValue, + @Cached HashingStorageIteratorKeyHash iterHash) { + Object key = iterKey.execute(inliningTarget, storage, it); + long hash = iterHash.execute(frame, inliningTarget, storage, it); + Object rightValue = iterValue.execute(inliningTarget, storage, it); + Object leftValue = getItem.execute(frame, inliningTarget, state.remaining, key, hash); + boolean delete = leftValue != null && eqNode.execute(frame, inliningTarget, leftValue, rightValue, RichCmpOp.Py_EQ); + if (delete) { + removeNode.execute(frame, inliningTarget, state.remaining, key, hash); + } else { + PTuple pair = PFactory.createTuple(state.language, new Object[]{key, rightValue}); + state.resultStorage = setItem.execute(frame, inliningTarget, state.resultStorage, pair, PNone.NONE); + } + return state; + } + } + + @GenerateInline + @GenerateCached(false) + public abstract static class AddDictItemsToSetCallback extends HashingStorageForEachCallback { + @Override + public abstract HashingStorage execute(Frame frame, Node inliningTarget, HashingStorage storage, HashingStorageIterator it, HashingStorage resultStorage); + + @Specialization + static HashingStorage doGeneric(Frame frame, Node inliningTarget, HashingStorage storage, HashingStorageIterator it, HashingStorage resultStorage, + @Bind PythonLanguage language, + @Cached HashingStorageSetItem setItem, + @Cached HashingStorageIteratorKey iterKey, + @Cached HashingStorageIteratorValue iterValue) { + PTuple pair = PFactory.createTuple(language, new Object[]{iterKey.execute(inliningTarget, storage, it), iterValue.execute(inliningTarget, storage, it)}); + return setItem.execute(frame, inliningTarget, resultStorage, pair, PNone.NONE); + } + } + @Slot(value = SlotKind.nb_xor, isComplex = true) @GenerateNodeFactory + @ImportStatic(PGuards.class) public abstract static class XorNode extends BinaryOpBuiltinNode { @Specialization + static PBaseSet doItemsViews(VirtualFrame frame, PDictItemsView self, PDictItemsView other, + @Bind Node inliningTarget, + @Cached CopyToEconomicMapNode copyToEconomicMapNode, + @Cached HashingStorageForEach forEachRight, + @Cached HashingStorageForEach forEachRemaining, + @Cached DictItemsXorCallback xorCallback, + @Cached AddDictItemsToSetCallback addRemainingCallback, + @Bind PythonLanguage language) { + // CPython has a dedicated dictitems_xor helper: copy the left dict, iterate the right + // dict by key/hash, compare values for matching keys, and emit item tuples. + EconomicMapStorage remaining = copyToEconomicMapNode.execute(frame, inliningTarget, self.getWrappedStorage()); + DictItemsXorState state = new DictItemsXorState(language, remaining); + forEachRight.execute(frame, inliningTarget, other.getWrappedStorage(), xorCallback, state); + HashingStorage result = forEachRemaining.execute(frame, inliningTarget, remaining, addRemainingCallback, state.resultStorage); + return PFactory.createSet(language, result); + } + + @Specialization(guards = "isDictKeysView(self) || isAnySet(self)") + static PBaseSet doKeysViewOrSet(VirtualFrame frame, Object self, Object other, + @Bind Node inliningTarget, + @Shared @Cached GetSetStorageForXorNode getRightStorage, + @Shared @Cached HashingStorageXor xor, + @Bind PythonLanguage language) { + HashingStorage left = getKeysViewOrSetStorage(self); + HashingStorage right = getRightStorage.execute(frame, inliningTarget, other); + return PFactory.createSet(language, xor.executePreservingLeft(frame, inliningTarget, left, right)); + } + + private static HashingStorage getKeysViewOrSetStorage(Object self) { + if (self instanceof PDictKeysView keysView) { + return keysView.getWrappedStorage(); + } + return ((PBaseSet) self).getDictStorage(); + } + + @Specialization(guards = {"!isDictKeysView(self)", "!isAnySet(self)"}) static PBaseSet doGeneric(VirtualFrame frame, Object self, Object other, @Bind Node inliningTarget, - @Cached GetStorageForBinopNode getStorage, - @Cached HashingStorageXor xor, + @Cached SetNodes.ConstructSetNode constructSetNode, + @Shared @Cached GetSetStorageForXorNode getRightStorage, + @Shared @Cached HashingStorageXor xor, @Bind PythonLanguage language) { - HashingStorage left = getStorage.execute(frame, inliningTarget, self); - HashingStorage right = getStorage.execute(frame, inliningTarget, other); - return PFactory.createSet(language, xor.execute(frame, inliningTarget, left, right)); + HashingStorage left = constructSetNode.execute(frame, self).getDictStorage(); + HashingStorage right = getRightStorage.execute(frame, inliningTarget, other); + return PFactory.createSet(language, xor.executeMutatingLeft(frame, inliningTarget, left, right)); } } } diff --git a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/BaseSetBuiltins.java b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/BaseSetBuiltins.java index 4542c92bde..b183fda862 100644 --- a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/BaseSetBuiltins.java +++ b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/BaseSetBuiltins.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2026, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * The Universal Permissive License (UPL), Version 1.0 @@ -405,8 +405,7 @@ static Object doSet(VirtualFrame frame, PBaseSet self, PBaseSet other, @Bind Node inliningTarget, @Cached HashingStorageNodes.HashingStorageXor xorNode, @Cached CreateSetNode createSetNode) { - // TODO: calls __eq__ wrong number of times compared to CPython (GR-42240) - HashingStorage storage = xorNode.execute(frame, inliningTarget, self.getDictStorage(), other.getDictStorage()); + HashingStorage storage = xorNode.executePreservingLeft(frame, inliningTarget, other.getDictStorage(), self.getDictStorage()); return createSetNode.execute(inliningTarget, storage, self); } diff --git a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/FrozenSetBuiltins.java b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/FrozenSetBuiltins.java index e06ca36f1d..d1b18929bf 100644 --- a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/FrozenSetBuiltins.java +++ b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/FrozenSetBuiltins.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, 2025, Oracle and/or its affiliates. + * Copyright (c) 2017, 2026, Oracle and/or its affiliates. * Copyright (c) 2014, Regents of the University of California * * All rights reserved. @@ -40,6 +40,7 @@ import com.oracle.graal.python.builtins.objects.PNone; import com.oracle.graal.python.builtins.objects.common.EmptyStorage; import com.oracle.graal.python.builtins.objects.common.HashingCollectionNodes; +import com.oracle.graal.python.builtins.objects.common.HashingCollectionNodes.GetSetStorageForXorNode; import com.oracle.graal.python.builtins.objects.common.HashingCollectionNodes.GetSetStorageNode; import com.oracle.graal.python.builtins.objects.common.HashingStorage; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes; @@ -52,6 +53,7 @@ import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageIteratorKey; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageIteratorNext; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageXor; +import com.oracle.graal.python.builtins.objects.common.PHashingCollection; import com.oracle.graal.python.builtins.objects.type.TpSlots; import com.oracle.graal.python.builtins.objects.type.TypeNodes; import com.oracle.graal.python.builtins.objects.type.slots.TpSlotHashFun.HashBuiltinNode; @@ -65,8 +67,10 @@ import com.oracle.graal.python.runtime.object.PFactory; import com.oracle.truffle.api.dsl.Bind; import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.Cached.Exclusive; import com.oracle.truffle.api.dsl.Cached.Shared; import com.oracle.truffle.api.dsl.GenerateNodeFactory; +import com.oracle.truffle.api.dsl.ImportStatic; import com.oracle.truffle.api.dsl.NodeFactory; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.VirtualFrame; @@ -195,15 +199,25 @@ static PFrozenSet doSet(@SuppressWarnings("unused") VirtualFrame frame, PFrozenS @Builtin(name = "symmetric_difference", minNumOfPositionalArgs = 2) @GenerateNodeFactory + @ImportStatic(PGuards.class) public abstract static class SymmetricDifferenceNode extends PythonBuiltinNode { @Specialization - static PFrozenSet doSet(@SuppressWarnings("unused") VirtualFrame frame, PFrozenSet self, Object other, + static PFrozenSet doHashingCollection(VirtualFrame frame, PFrozenSet self, PHashingCollection other, + @Bind Node inliningTarget, + @Exclusive @Cached HashingStorageXor xorNode, + @Bind PythonLanguage language) { + HashingStorage result = xorNode.executePreservingLeft(frame, inliningTarget, other.getDictStorage(), self.getDictStorage()); + return PFactory.createFrozenSet(language, result); + } + + @Specialization(guards = "!isPHashingCollection(other)") + static PFrozenSet doGeneric(VirtualFrame frame, PFrozenSet self, Object other, @Bind Node inliningTarget, - @Cached HashingCollectionNodes.GetSetStorageNode getHashingStorage, - @Cached HashingStorageXor xorNode, + @Exclusive @Cached GetSetStorageForXorNode getHashingStorage, + @Exclusive @Cached HashingStorageXor xorNode, @Bind PythonLanguage language) { - HashingStorage result = xorNode.execute(frame, inliningTarget, self.getDictStorage(), getHashingStorage.execute(frame, inliningTarget, other)); + HashingStorage result = xorNode.executeMutatingLeft(frame, inliningTarget, getHashingStorage.execute(frame, inliningTarget, other), self.getDictStorage()); return PFactory.createFrozenSet(language, result); } } diff --git a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/SetBuiltins.java b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/SetBuiltins.java index 6ff2efef46..0f1e5a2384 100644 --- a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/SetBuiltins.java +++ b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/SetBuiltins.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, 2025, Oracle and/or its affiliates. + * Copyright (c) 2017, 2026, Oracle and/or its affiliates. * Copyright (c) 2014, Regents of the University of California * * All rights reserved. @@ -44,6 +44,7 @@ import com.oracle.graal.python.builtins.objects.PNone; import com.oracle.graal.python.builtins.objects.PNotImplemented; import com.oracle.graal.python.builtins.objects.common.HashingCollectionNodes; +import com.oracle.graal.python.builtins.objects.common.HashingCollectionNodes.GetSetStorageForXorNode; import com.oracle.graal.python.builtins.objects.common.HashingCollectionNodes.GetSetStorageNode; import com.oracle.graal.python.builtins.objects.common.HashingStorage; import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageAddAllToOther; @@ -500,7 +501,7 @@ public abstract static class IXorNode extends PythonBinaryBuiltinNode { static Object doSet(VirtualFrame frame, PSet self, PBaseSet other, @Bind Node inliningTarget, @Cached HashingStorageXor xorNode) { - self.setDictStorage(xorNode.execute(frame, inliningTarget, self.getDictStorage(), other.getDictStorage())); + self.setDictStorage(xorNode.executeMutatingLeft(frame, inliningTarget, self.getDictStorage(), other.getDictStorage())); return self; } @@ -513,15 +514,25 @@ Object doOr(Object self, Object other) { @Builtin(name = "symmetric_difference", minNumOfPositionalArgs = 2) @GenerateNodeFactory + @ImportStatic(PGuards.class) public abstract static class SymmetricDifferenceNode extends PythonBuiltinNode { @Specialization - static PSet doSet(VirtualFrame frame, PSet self, Object other, + static PSet doHashingCollection(VirtualFrame frame, PSet self, PHashingCollection other, + @Bind Node inliningTarget, + @Exclusive @Cached HashingStorageXor xorNode, + @Bind PythonLanguage language) { + HashingStorage result = xorNode.executePreservingLeft(frame, inliningTarget, other.getDictStorage(), self.getDictStorage()); + return PFactory.createSet(language, result); + } + + @Specialization(guards = "!isPHashingCollection(other)") + static PSet doGeneric(VirtualFrame frame, PSet self, Object other, @Bind Node inliningTarget, - @Cached GetSetStorageNode getHashingStorage, - @Cached HashingStorageXor xorNode, + @Exclusive @Cached GetSetStorageForXorNode getHashingStorage, + @Exclusive @Cached HashingStorageXor xorNode, @Bind PythonLanguage language) { - HashingStorage result = xorNode.execute(frame, inliningTarget, self.getDictStorage(), getHashingStorage.execute(frame, inliningTarget, other)); + HashingStorage result = xorNode.executeMutatingLeft(frame, inliningTarget, getHashingStorage.execute(frame, inliningTarget, other), self.getDictStorage()); return PFactory.createSet(language, result); } } @@ -540,11 +551,11 @@ static PNone doSet(VirtualFrame frame, PSet self, PNone other) { static PNone doCached(VirtualFrame frame, PSet self, Object[] args, @Bind Node inliningTarget, @Cached("args.length") int len, - @Shared @Cached HashingCollectionNodes.GetSetStorageNode getHashingStorage, + @Shared @Cached GetSetStorageForXorNode getHashingStorage, @Shared @Cached HashingStorageXor xorNode) { HashingStorage result = self.getDictStorage(); for (int i = 0; i < len; i++) { - result = xorNode.execute(frame, inliningTarget, result, getHashingStorage.execute(frame, inliningTarget, args[i])); + result = xorNode.executeMutatingLeft(frame, inliningTarget, result, getHashingStorage.execute(frame, inliningTarget, args[i])); } self.setDictStorage(result); return PNone.NONE; @@ -553,11 +564,11 @@ static PNone doCached(VirtualFrame frame, PSet self, Object[] args, @Specialization(replaces = "doCached") static PNone doSetArgs(VirtualFrame frame, PSet self, Object[] args, @Bind Node inliningTarget, - @Shared @Cached GetSetStorageNode getHashingStorage, + @Shared @Cached GetSetStorageForXorNode getHashingStorage, @Shared @Cached HashingStorageXor xorNode) { HashingStorage result = self.getDictStorage(); for (Object o : args) { - result = xorNode.execute(frame, inliningTarget, result, getHashingStorage.execute(frame, inliningTarget, o)); + result = xorNode.executeMutatingLeft(frame, inliningTarget, result, getHashingStorage.execute(frame, inliningTarget, o)); } self.setDictStorage(result); return PNone.NONE; @@ -570,9 +581,9 @@ static boolean isOther(Object arg) { @Specialization(guards = "isOther(other)") static PNone doSetOther(VirtualFrame frame, PSet self, Object other, @Bind Node inliningTarget, - @Shared @Cached HashingCollectionNodes.GetSetStorageNode getHashingStorage, + @Shared @Cached GetSetStorageForXorNode getHashingStorage, @Shared @Cached HashingStorageXor xorNode) { - HashingStorage result = xorNode.execute(frame, inliningTarget, self.getDictStorage(), getHashingStorage.execute(frame, inliningTarget, other)); + HashingStorage result = xorNode.executeMutatingLeft(frame, inliningTarget, self.getDictStorage(), getHashingStorage.execute(frame, inliningTarget, other)); self.setDictStorage(result); return PNone.NONE; }