diff --git a/src/antlr/Parser.g b/src/antlr/Parser.g index 8794d00d02a9..5cffb86e7476 100644 --- a/src/antlr/Parser.g +++ b/src/antlr/Parser.g @@ -845,6 +845,7 @@ txnColumnCondition[List conditions] | K_NULL { conditions.add(new ConditionStatement.Raw(lhs, ConditionStatement.Kind.IS_NULL, null)); } ) | (txnConditionKind term)=> op=txnConditionKind t=term { conditions.add(new ConditionStatement.Raw(lhs, op, t)); } + | (txnConditionKind rowDataReference)=> op=txnConditionKind rhs=rowDataReference { conditions.add(new ConditionStatement.Raw(lhs, op, rhs)); } ) | lhs=term op=txnConditionKind rhs=rowDataReference { conditions.add(new ConditionStatement.Raw(lhs, op, rhs)); } ; diff --git a/src/java/org/apache/cassandra/cql3/transactions/ConditionStatement.java b/src/java/org/apache/cassandra/cql3/transactions/ConditionStatement.java index 7f19d834d934..4f02d5dda85a 100644 --- a/src/java/org/apache/cassandra/cql3/transactions/ConditionStatement.java +++ b/src/java/org/apache/cassandra/cql3/transactions/ConditionStatement.java @@ -24,6 +24,7 @@ import org.apache.cassandra.cql3.QueryOptions; import org.apache.cassandra.cql3.VariableSpecifications; import org.apache.cassandra.cql3.terms.Term; +import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.service.accord.txn.TxnCondition; import org.apache.cassandra.service.accord.txn.TxnReference; @@ -41,7 +42,7 @@ public enum Kind GTE(TxnCondition.Kind.GREATER_THAN_OR_EQUAL, TxnCondition.Kind.LESS_THAN_OR_EQUAL), LT(TxnCondition.Kind.LESS_THAN, TxnCondition.Kind.GREATER_THAN), LTE(TxnCondition.Kind.LESS_THAN_OR_EQUAL, TxnCondition.Kind.GREATER_THAN_OR_EQUAL); - + // TODO: Support for IN, CONTAINS, CONTAINS KEY private final TxnCondition.Kind kind; @@ -100,8 +101,19 @@ public ConditionStatement prepare(String keyspace, VariableSpecifications bindVa RowDataReference reference; Term value; boolean reversed = false; - - if (lhs instanceof RowDataReference.Raw) + + if (lhs instanceof RowDataReference.Raw && rhs instanceof RowDataReference.Raw) + { + if (((RowDataReference.Raw) lhs).column() == null) + throw new IllegalStateException(String.format("Row reference (%s) can only be used with IS NULL/IS NOT NULL conditions", lhs.getText())); + if (((RowDataReference.Raw) rhs).column() == null) + throw new IllegalStateException(String.format("Row reference (%s) can only be used with IS NULL/IS NOT NULL conditions", rhs.getText())); + reference = ((RowDataReference.Raw) lhs).prepareAsReceiver(); + value = ((RowDataReference.Raw) rhs).prepareAsReceiver(); + if (!reference.toResultMetadata().type.equals(((RowDataReference) value).toResultMetadata().type)) + throw new InvalidRequestException(String.format("Row reference (%s) must have the same type as row reference (%s)", lhs.getText(), rhs.getText())); + } + else if (lhs instanceof RowDataReference.Raw) { if (((RowDataReference.Raw) lhs).column() == null) throw new IllegalStateException(String.format("Row reference (%s) can only be used with IS NULL/IS NOT NULL conditions", lhs.getText())); @@ -143,10 +155,19 @@ public TxnCondition createCondition(QueryOptions options) case GTE: case LT: case LTE: - // TODO: Support for references on LHS and RHS - TxnReference ref = reference.toTxnReference(options); - checkTrue(ref.kind == TxnReference.Kind.COLUMN, "Condition %s requires COLUMN reference but given %s", kind, ref.kind); - return new TxnCondition.Value(ref.asColumn(), + TxnReference refLHS = reference.toTxnReference(options); + checkTrue(refLHS.kind == TxnReference.Kind.COLUMN, "Condition %s requires COLUMN reference but given %s", kind, refLHS.kind); + if (value instanceof RowDataReference) + { + TxnReference refRHS = ((RowDataReference) value).toTxnReference(options); + checkTrue(refRHS.kind == TxnReference.Kind.COLUMN, "Condition %s requires COLUMN reference but given %s", kind, refRHS.kind); + return new TxnCondition.Reference(refLHS.asColumn(), + kind.toTxnKind(reversed), + refRHS.asColumn(), + options.getProtocolVersion()); + } + + return new TxnCondition.Value(refLHS.asColumn(), kind.toTxnKind(reversed), value.bindAndGet(options), options.getProtocolVersion()); diff --git a/src/java/org/apache/cassandra/service/accord/txn/TxnCondition.java b/src/java/org/apache/cassandra/service/accord/txn/TxnCondition.java index 380566155ff7..ad40789443d2 100644 --- a/src/java/org/apache/cassandra/service/accord/txn/TxnCondition.java +++ b/src/java/org/apache/cassandra/service/accord/txn/TxnCondition.java @@ -42,6 +42,9 @@ import org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.marshal.CollectionType; +import org.apache.cassandra.db.marshal.ListType; +import org.apache.cassandra.db.marshal.MapType; +import org.apache.cassandra.db.marshal.SetType; import org.apache.cassandra.db.marshal.UserType; import org.apache.cassandra.db.partitions.FilteredPartition; import org.apache.cassandra.db.rows.Cell; @@ -618,6 +621,125 @@ public long serializedSize(Value condition, TableMetadatas tables) }; } + public static class Reference extends TxnCondition + { + private static final EnumSet KINDS = EnumSet.of(Kind.EQUAL, Kind.NOT_EQUAL, + Kind.GREATER_THAN, Kind.GREATER_THAN_OR_EQUAL, + Kind.LESS_THAN, Kind.LESS_THAN_OR_EQUAL); + + private final TxnReference.ColumnReference referenceLHS; + private final TxnReference.ColumnReference referenceRHS; + private final ProtocolVersion version; + + public Reference(TxnReference.ColumnReference referenceLHS, Kind kind, TxnReference.ColumnReference referenceRHS, ProtocolVersion version) + { + super(kind); + Invariants.requireArgument(KINDS.contains(kind), "Kind " + kind + " cannot be used with a value condition"); + this.referenceLHS = referenceLHS; + this.referenceRHS = referenceRHS; + this.version = version; + } + + public static EnumSet supported() + { + return EnumSet.copyOf(KINDS); + } + + @Override + public boolean equals(Object o) + { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + Reference reference1 = (Reference) o; + return referenceLHS.equals(reference1.referenceLHS) && referenceRHS.equals(reference1.referenceRHS); + } + + @Override + public void collect(TableMetadatas.Collector collector) + { + referenceLHS.collect(collector); + referenceRHS.collect(collector); + } + + @Override + public int hashCode() + { + return Objects.hash(super.hashCode(), referenceLHS, referenceRHS); + } + + @Override + public String toString() + { + return referenceLHS.toString() + ' ' + kind.symbol + ' ' + referenceRHS.toString(); + } + + public AbstractType getColumnType(TxnReference.ColumnReference reference) + { + ColumnMetadata column = reference.column(); + if (reference.isElementSelection()) + { + if (column.type instanceof ListType) + return ((ListType) column.type).valueComparator(); + else if (column.type instanceof SetType) + return ((SetType) column.type).nameComparator(); + else if (column.type instanceof MapType) + return ((MapType) column.type).valueComparator(); + } + else if (reference.isFieldSelection()) + { + return reference.getFieldSelectionType(); + } + + return column.type; + } + + @Override + public boolean applies(TxnData data) + { + AbstractType typeLHS = getColumnType(referenceLHS); + AbstractType typeRHS = getColumnType(referenceRHS); + + ByteBuffer lhs = referenceLHS.toByteBuffer(data, typeLHS); + ByteBuffer rhs = referenceRHS.toByteBuffer(data, typeRHS); + + if (lhs == null || rhs == null) + return false; + + return kind.operator.isSatisfiedBy(typeLHS, lhs, rhs); + } + + private static final ConditionSerializer serializer = new ConditionSerializer<>() + { + @Override + public void serialize(Reference condition, TableMetadatas tables, DataOutputPlus out) throws IOException + { + TxnReference.serializer.serialize(condition.referenceLHS, tables, out); + TxnReference.serializer.serialize(condition.referenceRHS, tables, out); + out.writeUTF(condition.version.name()); + } + + @Override + public Reference deserialize(TableMetadatas tables, DataInputPlus in, Kind kind) throws IOException + { + TxnReference.ColumnReference referenceLHS = TxnReference.serializer.deserialize(tables, in).asColumn(); + TxnReference.ColumnReference referenceRHS = TxnReference.serializer.deserialize(tables, in).asColumn(); + ProtocolVersion protocolVersion = ProtocolVersion.valueOf(in.readUTF()); + return new Reference(referenceLHS, kind, referenceRHS, protocolVersion); + } + + @Override + public long serializedSize(Reference condition, TableMetadatas tables) + { + long size = 0; + size += TxnReference.serializer.serializedSize(condition.referenceLHS, tables); + size += TxnReference.serializer.serializedSize(condition.referenceRHS, tables); + size += TypeSizes.sizeof(condition.version.name()); + return size; + } + }; + } + public static class BooleanGroup extends TxnCondition { private static final Set KINDS = ImmutableSet.of(Kind.AND, Kind.OR); @@ -698,27 +820,64 @@ public long serializedSize(BooleanGroup condition, TableMetadatas tables) public static final ParameterisedUnversionedSerializer serializer = new ParameterisedUnversionedSerializer<>() { + // TOP_BIT is used to differentiate between Value.Serializer and Reference.Serialzer, + // in order to implement comparison between LET variables. + // The reason we use TOP_BIT is to support users who have been deploying off trunk + // to upgrade nodes without breaking them. Upgrading is safe under the following assumptions: + // 1) `ref op ref` feature is only used after all nodes have been upgraded + // 2) cluster can be mixed mode as long as `ref op ref` is not used + // If a user tries to use `ref op ref` in a mixed mode this will lead to undefined errors, + // where the only recovery process is to force older nodes to upgrade + // See CASSANDRA-21458 + private static final int TOP_BIT = 0x40000000; + @SuppressWarnings("unchecked") @Override public void serialize(TxnCondition condition, TableMetadatas tables, DataOutputPlus out) throws IOException { - out.writeUnsignedVInt32(condition.kind.ordinal()); - condition.kind.serializer().serialize(condition, tables, out); + if (condition instanceof Reference) + { + out.writeUnsignedVInt32(condition.kind.ordinal() | TOP_BIT); + Reference.serializer.serialize((Reference) condition, tables, out); + } + else + { + out.writeUnsignedVInt32(condition.kind.ordinal()); + condition.kind.serializer().serialize(condition, tables, out); + } } @Override public TxnCondition deserialize(TableMetadatas tables, DataInputPlus in) throws IOException { - Kind kind = Kind.values()[in.readUnsignedVInt32()]; - return kind.serializer().deserialize(tables, in, kind); + int flag = in.readUnsignedVInt32(); + if ((flag & TOP_BIT) != 0) + { + Kind kind = Kind.values()[flag ^ TOP_BIT]; + return Reference.serializer.deserialize(tables, in, kind); + } + else + { + Kind kind = Kind.values()[flag]; + return kind.serializer().deserialize(tables, in, kind); + } } @SuppressWarnings("unchecked") @Override public long serializedSize(TxnCondition condition, TableMetadatas tables) { - long size = TypeSizes.sizeofUnsignedVInt(condition.kind.ordinal()); - size += condition.kind.serializer().serializedSize(condition, tables); + long size; + if (condition instanceof Reference) + { + size = TypeSizes.sizeofUnsignedVInt(condition.kind.ordinal() | TOP_BIT); + size += Reference.serializer.serializedSize((Reference) condition, tables); + } + else + { + size = TypeSizes.sizeofUnsignedVInt(condition.kind.ordinal()); + size += condition.kind.serializer().serializedSize(condition, tables); + } return size; } }; diff --git a/src/java/org/apache/cassandra/utils/NullableSerializer.java b/src/java/org/apache/cassandra/utils/NullableSerializer.java index 8392bf19dc85..bddabf407c2e 100644 --- a/src/java/org/apache/cassandra/utils/NullableSerializer.java +++ b/src/java/org/apache/cassandra/utils/NullableSerializer.java @@ -35,7 +35,6 @@ public static void serializeNullable(T value, DataOutputPlus out, Unversione if (value != null) serializer.serialize(value, out); } - public static void serializeNullable(T value, DataOutputPlus out, int version, IVersionedSerializer serializer) throws IOException { out.writeBoolean(value != null); diff --git a/test/distributed/org/apache/cassandra/distributed/test/accord/AccordCQLTestBase.java b/test/distributed/org/apache/cassandra/distributed/test/accord/AccordCQLTestBase.java index f38b84ede287..1cbde6b5f24b 100644 --- a/test/distributed/org/apache/cassandra/distributed/test/accord/AccordCQLTestBase.java +++ b/test/distributed/org/apache/cassandra/distributed/test/accord/AccordCQLTestBase.java @@ -3604,4 +3604,161 @@ public void userSeesInvalidRejection() throws Exception .hasMessage("Attempted to set an element on a list which is null"); }); } + + @Test + public void testLetComparisonWithDifferentTypesFails() throws Throwable + { + test("CREATE TABLE " + qualifiedAccordTableName + " (k text PRIMARY KEY, customer person) WITH " + transactionalMode.asCqlParam(), cluster -> { + try + { + Object personValue = CQLTester.userType("height", 74, "age", 37); + ByteBuffer personBuffer = CQLTester.makeByteBuffer(personValue, null); + String query = "BEGIN TRANSACTION\n" + + "LET k1 = (SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 'first');\n" + + "LET k2 = (SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 'second');\n" + + "IF k1.k < k2.customer.height THEN \n" + + " UPDATE " + qualifiedAccordTableName + " SET customer = ? WHERE k = 'first';\n" + + "END IF\n" + + "COMMIT TRANSACTION"; + + + cluster.coordinator(1).executeWithResult(query, ConsistencyLevel.SERIAL, personBuffer); + fail("Expected exception"); + } + catch (Throwable t) + { + assertEquals(InvalidRequestException.class.getName(), t.getClass().getName()); + assertEquals("Row reference (k1.k) must have the same type as row reference (k2.customer.height)", t.getMessage()); + } + }); + } + + @Test + public void testLetComparisonTransactionStatement() throws Throwable + { + test("CREATE TABLE " + qualifiedAccordTableName + " (k int PRIMARY KEY, v int) WITH " + transactionalMode.asCqlParam(), cluster -> { + String insert = "BEGIN TRANSACTION\n" + + "INSERT INTO " + qualifiedAccordTableName + " (k, v) VALUES (1, 2);\n" + + "INSERT INTO " + qualifiedAccordTableName + " (k, v) VALUES (2, 3);\n" + + "COMMIT TRANSACTION"; + + cluster.coordinator(1).executeWithResult(insert, ConsistencyLevel.SERIAL); + + String query = "BEGIN TRANSACTION\n" + + "LET k1 = (SELECT v FROM " + qualifiedAccordTableName + " WHERE k = 1);\n" + + "LET k2 = (SELECT v FROM " + qualifiedAccordTableName + " WHERE k = 2);\n" + + "IF k1.v IS NOT NULL AND k1.v < k2.v THEN \n" + + " UPDATE " + qualifiedAccordTableName + " SET v = 10 WHERE k = 1;\n" + + "END IF\n" + + "COMMIT TRANSACTION"; + + cluster.coordinator(1).executeWithResult(query, ConsistencyLevel.SERIAL); + + String read = "BEGIN TRANSACTION\n" + + "SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 1;\n" + + "COMMIT TRANSACTION"; + + SimpleQueryResult result = cluster.coordinator(1).executeWithResult(read, ConsistencyLevel.SERIAL); + assertThat(result).hasSize(1).contains(1, 10); + }); + } + + @Test + public void testLetComparisonWithDifferentTypes() throws Throwable + { + test("CREATE TABLE " + qualifiedAccordTableName + " (k int PRIMARY KEY, customer person) WITH " + transactionalMode.asCqlParam(), cluster -> { + Object personValue1 = CQLTester.userType("height", 74, "age", 37); + ByteBuffer personBuffer1 = CQLTester.makeByteBuffer(personValue1, null); + + Object personValue2 = CQLTester.userType("height", 74, "age", 38); + ByteBuffer personBuffer2 = CQLTester.makeByteBuffer(personValue2, null); + + String insert = "BEGIN TRANSACTION\n" + + " INSERT INTO " + qualifiedAccordTableName + " (k, customer) VALUES (?, ?);\n" + + " INSERT INTO " + qualifiedAccordTableName + " (k, customer) VALUES (?, ?);\n" + + "COMMIT TRANSACTION"; + cluster.coordinator(1).executeWithResult(insert, ConsistencyLevel.ANY, 0, personBuffer1, 32, personBuffer2); + + String update = "BEGIN TRANSACTION\n" + + "LET k1 = (SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 0);\n" + + "LET k2 = (SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 32);\n" + + "IF k1.customer.height > k2.k THEN \n" + + " UPDATE " + qualifiedAccordTableName + " SET customer = ? WHERE k = 32;\n" + + "END IF\n" + + "COMMIT TRANSACTION"; + + cluster.coordinator(1).executeWithResult(update, ConsistencyLevel.SERIAL, personBuffer1); + + String read = "BEGIN TRANSACTION\n" + + "SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 32;\n" + + "COMMIT TRANSACTION"; + + SimpleQueryResult result = cluster.coordinator(1).executeWithResult(read, ConsistencyLevel.SERIAL); + assertThat(result).hasSize(1).contains(32, personBuffer1); + }); + } + + @Test + public void testLetComparisonWithUDT() throws Throwable + { + test("CREATE TABLE " + qualifiedAccordTableName + " (k int PRIMARY KEY, customer person) WITH " + transactionalMode.asCqlParam(), cluster -> { + Object personValue1 = CQLTester.userType("height", 74, "age", 37); + ByteBuffer personBuffer1 = CQLTester.makeByteBuffer(personValue1, null); + + Object personValue2 = CQLTester.userType("height", 74, "age", 38); + ByteBuffer personBuffer2 = CQLTester.makeByteBuffer(personValue2, null); + + String insert = "BEGIN TRANSACTION\n" + + " INSERT INTO " + qualifiedAccordTableName + " (k, customer) VALUES (?, ?);\n" + + " INSERT INTO " + qualifiedAccordTableName + " (k, customer) VALUES (?, ?);\n" + + "COMMIT TRANSACTION"; + cluster.coordinator(1).executeWithResult(insert, ConsistencyLevel.ANY, 0, personBuffer1, 1, personBuffer2); + + String update = "BEGIN TRANSACTION\n" + + "LET k1 = (SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 0);\n" + + "LET k2 = (SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 1);\n" + + "IF k1.customer.height = k2.customer.height THEN \n" + + " UPDATE " + qualifiedAccordTableName + " SET customer = ? WHERE k = 1;\n" + + "END IF\n" + + "COMMIT TRANSACTION"; + + cluster.coordinator(1).executeWithResult(update, ConsistencyLevel.SERIAL, personBuffer1); + + String read = "BEGIN TRANSACTION\n" + + "SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 1;\n" + + "COMMIT TRANSACTION"; + + SimpleQueryResult result = cluster.coordinator(1).executeWithResult(read, ConsistencyLevel.SERIAL); + assertThat(result).hasSize(1).contains(1, personBuffer1); + }); + } + + @Test + public void testLetComparisonWithMap() throws Throwable + { + test("CREATE TABLE " + qualifiedAccordTableName + " (k int PRIMARY KEY, v map) WITH " + transactionalMode.asCqlParam(), cluster -> { + String insert = "BEGIN TRANSACTION\n" + + " INSERT INTO " + qualifiedAccordTableName + " (k, v) VALUES (0, {1:3, 2:6});\n" + + " INSERT INTO " + qualifiedAccordTableName + " (k, v) VALUES (1, {0:5});\n" + + "COMMIT TRANSACTION"; + cluster.coordinator(1).executeWithResult(insert, ConsistencyLevel.ANY); + + String update = "BEGIN TRANSACTION\n" + + "LET k1 = (SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 0);\n" + + "LET k2 = (SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 1);\n" + + "IF k1.v[1] < k2.v[0] THEN \n" + + " UPDATE " + qualifiedAccordTableName + " SET v = {1:10} WHERE k = 1;\n" + + "END IF\n" + + "COMMIT TRANSACTION"; + + cluster.coordinator(1).executeWithResult(update, ConsistencyLevel.SERIAL); + + String read = "BEGIN TRANSACTION\n" + + "SELECT v[1] FROM " + qualifiedAccordTableName + " WHERE k = 1;\n" + + "COMMIT TRANSACTION"; + + SimpleQueryResult result = cluster.coordinator(1).executeWithResult(read, ConsistencyLevel.SERIAL); + assertThat(result).hasSize(1).contains(10); + }); + } } diff --git a/test/unit/org/apache/cassandra/service/accord/txn/TxnConditionTest.java b/test/unit/org/apache/cassandra/service/accord/txn/TxnConditionTest.java index 287e921a0228..d64d3813ce07 100644 --- a/test/unit/org/apache/cassandra/service/accord/txn/TxnConditionTest.java +++ b/test/unit/org/apache/cassandra/service/accord/txn/TxnConditionTest.java @@ -471,6 +471,94 @@ public void value() }); } + @Test + public void reference() + { + Gen> typeGen = toGen(new AbstractTypeGenerators.TypeGenBuilder() + .withoutUnsafeEquality() + .build()); + qt().check(rs -> { + AbstractType type = typeGen.next(rs); + TableMetadata metadata = TableMetadata.builder("ks", "tbl") + .addPartitionKeyColumn("pk", type.freeze()) + .addClusteringColumn("ck", type.freeze()) + .addRegularColumn("r", type) + .addStaticColumn("s", type) + .partitioner(Murmur3Partitioner.instance) + .build(); + + ByteBuffer valueLHS = toGen(AbstractTypeGenerators.getTypeSupport(type).bytesGen()).next(rs); + List complexValueLHS = type.isMultiCell() ? split(type, valueLHS) : null; + Clustering clusteringLHS = BufferClustering.make(valueLHS); + SimplePartition partitionLHS = new SimplePartition(metadata, metadata.partitioner.decorateKey(valueLHS)); + + ByteBuffer valueRHS = toGen(AbstractTypeGenerators.getTypeSupport(type).bytesGen()).next(rs); + List complexValueRHS = type.isMultiCell() ? split(type, valueRHS) : null; + Clustering clusteringRHS = BufferClustering.make(valueRHS); + SimplePartition partitionRHS = new SimplePartition(metadata, metadata.partitioner.decorateKey(valueRHS)); + + for (TxnCondition.Kind kind : TxnCondition.Value.supported()) + { + for (ProtocolVersion version : ProtocolVersion.SUPPORTED) + { + for (ColumnMetadata column : metadata.columns()) + { + TxnReference refLHS = TxnReference.column(0, metadata, column); + TxnReference refRHS = TxnReference.column(1, metadata, column); + + TxnCondition.Value value = new TxnCondition.Value(refLHS.asColumn(), kind, valueRHS, version); + TxnCondition.Reference condition = new TxnCondition.Reference(refLHS.asColumn(), kind, refRHS.asColumn(), version); + + partitionLHS.clear().addEmptyAndLive(clusteringLHS); + partitionRHS.clear().addEmptyAndLive(clusteringRHS); + + TxnData dataLHS = TxnData.of(0, new TxnDataKeyValue(partitionLHS.filtered())); + TxnData data = dataLHS.merge(TxnData.of(1, new TxnDataKeyValue(partitionRHS.filtered()))); + + Assertions.assertThat(condition.applies(data)) + .describedAs("column=%s, type=%s, kind=%s", column.name, type.asCQL3Type(), kind.name()) + .isEqualTo(value.applies(dataLHS)); + + if (column.isPrimaryKeyColumn()) continue; + + // with value + if (type.isMultiCell()) + { + partitionLHS.clear() + .add(column.isStatic() ? Clustering.STATIC_CLUSTERING : clusteringLHS) + .addComplex(column, complexValueLHS) + .build(); + + partitionRHS.clear() + .add(column.isStatic() ? Clustering.STATIC_CLUSTERING : clusteringRHS) + .addComplex(column, complexValueRHS) + .build(); + } + else + { + partitionLHS.clear() + .add(column.isStatic() ? Clustering.STATIC_CLUSTERING : clusteringLHS) + .add(column, valueLHS) + .build(); + + partitionRHS.clear() + .add(column.isStatic() ? Clustering.STATIC_CLUSTERING : clusteringRHS) + .add(column, valueRHS) + .build(); + } + + dataLHS = TxnData.of(0, new TxnDataKeyValue(partitionLHS.filtered())); + data = dataLHS.merge(TxnData.of(1, new TxnDataKeyValue(partitionRHS.filtered()))); + + Assertions.assertThat(condition.applies(data)) + .describedAs("column=%s, type=%s, kind=%s", column.name, type.asCQL3Type(), kind.name()) + .isEqualTo(value.applies(dataLHS)); + } + } + } + }); + } + private static List split(AbstractType type, ByteBuffer value) { type = type.unwrap(); @@ -494,13 +582,14 @@ private static void assertExists(TxnData data, TxnReference ref, boolean exists) private Gen txnConditionGen() { return rs -> { - switch (rs.nextInt(1, 5)) + switch (rs.nextInt(1, 6)) { case 0: return TxnCondition.none(); case 1: return new TxnCondition.Exists(TXN_REF_GEN.next(rs), EXISTS_KIND_GEN.next(rs)); case 2: return new TxnCondition.Value(TXN_REF_GEN.next(rs).asColumn(), VALUE_KIND_GEN.next(rs), BYTES_GEN.next(rs), PROTOCOL_VERSION_GEN.next(rs)); - case 3: return new TxnCondition.ColumnConditionsAdapter(CLUSTERING_GEN.next(rs), Gens.lists(BOUND_GEN).ofSizeBetween(0, 3).next(rs)); - case 4: return new TxnCondition.BooleanGroup(BOOLEAN_KIND_GEN.next(rs), Gens.lists(txnConditionGen()).ofSizeBetween(0, 3).next(rs)); + case 3: return new TxnCondition.Reference(TXN_REF_GEN.next(rs).asColumn(), VALUE_KIND_GEN.next(rs), TXN_REF_GEN.next(rs).asColumn(), PROTOCOL_VERSION_GEN.next(rs)); + case 4: return new TxnCondition.ColumnConditionsAdapter(CLUSTERING_GEN.next(rs), Gens.lists(BOUND_GEN).ofSizeBetween(0, 3).next(rs)); + case 5: return new TxnCondition.BooleanGroup(BOOLEAN_KIND_GEN.next(rs), Gens.lists(txnConditionGen()).ofSizeBetween(0, 3).next(rs)); default: throw new AssertionError(); } }; diff --git a/test/unit/org/apache/cassandra/utils/ASTGenerators.java b/test/unit/org/apache/cassandra/utils/ASTGenerators.java index 31df22f11ad2..68244c27dfbd 100644 --- a/test/unit/org/apache/cassandra/utils/ASTGenerators.java +++ b/test/unit/org/apache/cassandra/utils/ASTGenerators.java @@ -1366,27 +1366,32 @@ public Gen build() { Gen boolGen = SourceDSL.booleans().all(); return rnd -> { - var pk = partitionKeyValuesGen.generate(rnd); - var mutation = mutationGen(rs, pk).generate(rnd); + var pk1 = partitionKeyValuesGen.generate(rnd); + var pk2 = partitionKeyValuesGen.generate(rnd); + var mutation = mutationGen(rs, pk1).generate(rnd); - Select select = select(metadata, pk).withLimit(1); - var columns = model.columns(select); + Select select1 = select(metadata, pk1).withLimit(1); + Select select2 = select(metadata, pk2).withLimit(1); + var columns = model.columns(select1); Txn.Builder builder = Txn.builder(); - builder.addLet("r1", select); - Reference ref = Reference.of(Symbol.unknownType("r1")); + builder.addLet("r1", select1); + Reference ref1 = Reference.of(Symbol.unknownType("r1")); + builder.addLet("r2", select1); + Reference ref2 = Reference.of(Symbol.unknownType("r2")); - builder.addReturn(select(metadata, pk)); + builder.addReturn(select(metadata, pk1)); Conditional.Builder condition = Conditional.builder(); for (var col : columns) { if (boolGen.generate(rnd)) continue; - Reference colRef = ref.add(col); + Reference colRef1 = ref1.add(col); + Reference colRef2 = ref2.add(col); if (boolGen.generate(rnd)) - condition.is(colRef, SourceDSL.arbitrary().enumValues(Conditional.Is.Kind.class).generate(rnd)); + condition.is(colRef1, SourceDSL.arbitrary().enumValues(Conditional.Is.Kind.class).generate(rnd)); if (boolGen.generate(rnd)) { - Expression lhs = colRef; + Expression lhs = colRef1; Expression rhs = value(rnd, getTypeSupport(lhs.type()).bytesGen().generate(rnd), lhs.type()); if (boolGen.generate(rnd)) { @@ -1397,6 +1402,11 @@ public Gen build() Conditional.Where.Inequality inequality = SourceDSL.arbitrary().enumValues(Conditional.Where.Inequality.class).generate(rnd); condition.where(lhs, inequality, rhs); } + if (boolGen.generate(rnd)) + { + Conditional.Where.Inequality inequality = SourceDSL.arbitrary().enumValues(Conditional.Where.Inequality.class).generate(rnd); + condition.where(colRef1, inequality, colRef2); + } } if (condition.isEmpty()) condition.is("r1", Conditional.Is.Kind.NotNull);