diff --git a/src/java/org/apache/cassandra/db/memtable/ShardBoundaries.java b/src/java/org/apache/cassandra/db/memtable/ShardBoundaries.java index 102d087bb247..6440d5655128 100644 --- a/src/java/org/apache/cassandra/db/memtable/ShardBoundaries.java +++ b/src/java/org/apache/cassandra/db/memtable/ShardBoundaries.java @@ -18,12 +18,18 @@ package org.apache.cassandra.db.memtable; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import com.google.common.annotations.VisibleForTesting; +import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.db.Keyspace; import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; import org.apache.cassandra.tcm.Epoch; @@ -45,8 +51,12 @@ public class ShardBoundaries // - the default partitioner doesn't support splitting // - the keyspace is local system keyspace public static final ShardBoundaries NONE = new ShardBoundaries(EMPTY_TOKEN_ARRAY, Epoch.EMPTY); + private static final Range[] EMPTY_RANGE_ARRAY = new Range[0]; + private static final List EMPTY_BOUNDARIES_SHARDS = Collections.singletonList(0); private final Token[] boundaries; + private final Range[] ranges; + private final List allShards; public final Epoch epoch; @VisibleForTesting @@ -54,6 +64,29 @@ public ShardBoundaries(Token[] boundaries, Epoch epoch) { this.boundaries = boundaries; this.epoch = epoch; + this.ranges = precomputeRanges(); + this.allShards = IntStream.range(0, boundaries.length + 1).boxed().collect(Collectors.toList()); + } + + private Range[] precomputeRanges() + { + if (boundaries.length == 0) + return EMPTY_RANGE_ARRAY; + + Range[] ranges = new Range[boundaries.length + 1]; + int rangeIndex = 0; + PartitionPosition minimum = DatabaseDescriptor.getPartitioner().getMinimumToken().minKeyBound(); + + for (Token boundary : boundaries) + { + PartitionPosition boundaryPosition = boundary.maxKeyBound(); + ranges[rangeIndex++] = new Range<>(minimum, boundaryPosition); + minimum = boundaryPosition; + } + + ranges[rangeIndex] = new Range<>(minimum, DatabaseDescriptor.getPartitioner().getMaximumTokenForSplitting().maxKeyBound()); + + return ranges; } public ShardBoundaries(List boundaries, Epoch epoch) @@ -87,6 +120,20 @@ public int getShardForKey(PartitionPosition key) return getShardForToken(key.getToken()); } + public List getShardsForRange(AbstractBounds keyRange) + { + if (boundaries.length == 0) + return EMPTY_BOUNDARIES_SHARDS; + + // If the keyRange tokens match and are minimum then it represents the entire token ring + // then we need to return all the shards. + if (keyRange.right.isMinimum() && keyRange.left.compareTo(keyRange.right) == 0) + return allShards; + + // Otherwise we need to return all the shards whose range intersects the keyrange + return allShards.stream().filter(s -> ranges[s].intersects(keyRange)).collect(Collectors.toList()); + } + public Token getShardStartBoundary(int shardId) { if (shardId <= 0 || shardId >= boundaries.length) diff --git a/src/java/org/apache/cassandra/dht/Range.java b/src/java/org/apache/cassandra/dht/Range.java index a95249d426b3..7c21127e7c31 100644 --- a/src/java/org/apache/cassandra/dht/Range.java +++ b/src/java/org/apache/cassandra/dht/Range.java @@ -178,9 +178,35 @@ public boolean intersects(AbstractBounds that) return intersects((Range) that); if (that instanceof Bounds) return intersects((Bounds) that); + if (that instanceof ExcludingBounds) + return intersects((ExcludingBounds) that); + if (that instanceof IncludingExcludingBounds) + return intersects((IncludingExcludingBounds) that); throw new UnsupportedOperationException("Intersection is only supported for Bounds and Range objects; found " + that.getClass()); } + public boolean intersects(IncludingExcludingBounds that) + { + if (!isWrapAround() && !that.right.isMinimum() && (this.left.compareTo(that.right) == 0)) + return false; + else if (isWrapAround() && !that.right.isMinimum() && (this.right.compareTo(that.right) == 0)) + return false; + return contains(that.left) || (!that.left.equals(that.right) && intersects(new Range(that.left, that.right))); + } + + public boolean intersects(ExcludingBounds that) + { + if (!isWrapAround() && + ((!that.right.isMinimum() && (this.left.compareTo(that.right) == 0)) || + (this.right.compareTo(that.left) == 0))) + return false; + else if (isWrapAround() && + ((this.left.compareTo(that.left) == 0) || + (!that.right.isMinimum() && (this.right.compareTo(that.right) == 0)))) + return false; + return contains(that.left) || (!that.left.equals(that.right) && intersects(new Range(that.left, that.right))); + } + /** * @param that range to check for intersection * @return true if the given range intersects with this range. diff --git a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java index 1dba97d3ea5c..3423e55ce603 100644 --- a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java +++ b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java @@ -88,6 +88,7 @@ import org.apache.cassandra.index.sai.disk.format.Version; import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig; import org.apache.cassandra.index.sai.memory.MemtableIndexManager; +import org.apache.cassandra.index.sai.memory.ShardedMemtableIndex; import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics; import org.apache.cassandra.index.sai.metrics.IndexMetrics; import org.apache.cassandra.index.sai.utils.IndexIdentifier; @@ -159,7 +160,8 @@ public class StorageAttachedIndex implements Index IndexWriterConfig.OPTIMIZE_FOR, NonTokenizingOptions.CASE_SENSITIVE, NonTokenizingOptions.NORMALIZE, - NonTokenizingOptions.ASCII); + NonTokenizingOptions.ASCII, + ShardedMemtableIndex.SHARDS_OPTION); public static final Set SUPPORTED_TYPES = ImmutableSet.of(CQL3Type.Native.ASCII, CQL3Type.Native.BIGINT, CQL3Type.Native.DATE, CQL3Type.Native.DOUBLE, CQL3Type.Native.FLOAT, CQL3Type.Native.INT, @@ -276,6 +278,10 @@ public static Map validateOptions(Map options, T } IndexTermType indexTermType = IndexTermType.create(target.left, metadata.partitionKeyColumns(), target.right); + String shardsOption = options.get(ShardedMemtableIndex.SHARDS_OPTION); + if (shardsOption != null && indexTermType.isVector()) + throw new InvalidRequestException("A storage-attached index on a vector column does not support sharding"); + AbstractAnalyzer.fromOptions(indexTermType, analysisOptions); IndexWriterConfig config = IndexWriterConfig.fromOptions(null, indexTermType, options); diff --git a/src/java/org/apache/cassandra/index/sai/disk/RowMapping.java b/src/java/org/apache/cassandra/index/sai/disk/RowMapping.java index 5bf64360afc3..19cedd0b937c 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/RowMapping.java +++ b/src/java/org/apache/cassandra/index/sai/disk/RowMapping.java @@ -52,7 +52,7 @@ public class RowMapping public static final RowMapping DUMMY = new RowMapping() { @Override - public Iterator> merge(MemtableIndex index) { return Collections.emptyIterator(); } + public Iterator> merge(MemtableIndex index, PrimaryKey minKey, PrimaryKey maxKey) { return Collections.emptyIterator(); } @Override public void complete() {} @@ -99,11 +99,13 @@ public static RowMapping create(OperationType opType) * * @return an iterator of term -> postings list {@link Pair}s */ - public Iterator> merge(MemtableIndex index) + public Iterator> merge(MemtableIndex index, + PrimaryKey minKey, + PrimaryKey maxKey) { assert complete : "RowMapping is not built."; - Iterator> iterator = index.iterator(); + Iterator>> iterator = index.iterator(minKey.partitionKey(), maxKey.partitionKey()); return new AbstractGuavaIterator<>() { @Override @@ -111,10 +113,10 @@ protected Pair computeNext() { while (iterator.hasNext()) { - Pair pair = iterator.next(); + Pair> pair = iterator.next(); LongArrayList postings = null; - Iterator primaryKeys = pair.right.iterator(); + Iterator primaryKeys = pair.right; while (primaryKeys.hasNext()) { diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java index 795d5c1a0d6b..e8201c800f10 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java @@ -145,7 +145,7 @@ public void complete(Stopwatch stopwatch) throws IOException } else { - final Iterator> iterator = rowMapping.merge(memtable); + final Iterator> iterator = rowMapping.merge(memtable, minKey, maxKey); long cellCount = 0; if (iterator.hasNext()) diff --git a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java index ab24fff087f9..395bfdd7f2db 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java @@ -22,7 +22,6 @@ import java.nio.ByteBuffer; import java.util.Iterator; import java.util.List; -import java.util.concurrent.atomic.LongAdder; import java.util.function.Function; import org.apache.cassandra.db.Clustering; @@ -31,7 +30,6 @@ import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.index.sai.QueryContext; -import org.apache.cassandra.index.sai.StorageAttachedIndex; import org.apache.cassandra.index.sai.disk.format.IndexDescriptor; import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata; import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; @@ -39,96 +37,58 @@ import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.utils.IndexIdentifier; import org.apache.cassandra.index.sai.utils.PrimaryKey; -import org.apache.cassandra.index.sai.utils.PrimaryKeys; import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; -public class MemtableIndex implements MemtableOrdering +public interface MemtableIndex extends MemtableOrdering { - private final MemoryIndex memoryIndex; - private final LongAdder writeCount = new LongAdder(); - private final LongAdder estimatedMemoryUsed = new LongAdder(); - private final Memtable memtable; + long writeCount(); - public MemtableIndex(StorageAttachedIndex index, Memtable memtable) - { - this.memoryIndex = index.termType().isVector() ? new VectorMemoryIndex(index, memtable) : new TrieMemoryIndex(index); - this.memtable = memtable; - } + long estimatedMemoryUsed(); - public long writeCount() - { - return writeCount.sum(); - } + boolean isEmpty(); - public long estimatedMemoryUsed() - { - return estimatedMemoryUsed.sum(); - } + Memtable getMemtable(); - public boolean isEmpty() - { - return memoryIndex.isEmpty(); - } + ByteBuffer getMinTerm(); - public Memtable getMemtable() - { - return memtable; - } + ByteBuffer getMaxTerm(); - public ByteBuffer getMinTerm() - { - return memoryIndex.getMinTerm(); - } + long index(DecoratedKey key, Clustering clustering, ByteBuffer value); - public ByteBuffer getMaxTerm() + /** Implementation only for Vector Indexes */ + default long update(DecoratedKey key, Clustering clustering, ByteBuffer oldValue, ByteBuffer newValue) { - return memoryIndex.getMaxTerm(); + throw new UnsupportedOperationException(); } - public long index(DecoratedKey key, Clustering clustering, ByteBuffer value) - { - if (value == null || (value.remaining() == 0 && memoryIndex.index.termType().skipsEmptyValue())) - return 0; - - long ram = memoryIndex.add(key, clustering, value); - writeCount.increment(); - estimatedMemoryUsed.add(ram); - return ram; - } - - public long update(DecoratedKey key, Clustering clustering, ByteBuffer oldValue, ByteBuffer newValue) - { - return memoryIndex.update(key, clustering, oldValue, newValue); - } + KeyRangeIterator search(QueryContext queryContext, Expression expression, AbstractBounds keyRange); - public KeyRangeIterator search(QueryContext queryContext, Expression expression, AbstractBounds keyRange) - { - return memoryIndex.search(queryContext, expression, keyRange); - } - - public Iterator> iterator() - { - return memoryIndex.iterator(); - } + /** + * NOTE: Returned data may contain keys outside [minKey, maxKey]. + * Bounds are used for shard selection only; implementations + * without sharding may ignore them. + */ + Iterator>> iterator(DecoratedKey min, DecoratedKey max); - public SegmentMetadata.ComponentMetadataMap writeDirect(IndexDescriptor indexDescriptor, - IndexIdentifier indexIdentifier, - Function postingTransformer) throws IOException + /** Implementation only for Vector Indexes */ + default SegmentMetadata.ComponentMetadataMap writeDirect(IndexDescriptor indexDescriptor, + IndexIdentifier indexIdentifier, + Function postingTransformer) throws IOException { - return memoryIndex.writeDirect(indexDescriptor, indexIdentifier, postingTransformer); + throw new UnsupportedOperationException(); } @Override - public CloseableIterator orderBy(QueryContext queryContext, Expression orderer, AbstractBounds keyRange) + default CloseableIterator orderBy(QueryContext queryContext, Expression orderer, AbstractBounds keyRange) { - return memoryIndex.orderBy(queryContext, orderer, keyRange); + throw new UnsupportedOperationException(); } @Override - public CloseableIterator orderResultsBy(QueryContext queryContext, List results, Expression orderer) + default CloseableIterator orderResultsBy(QueryContext queryContext, List results, Expression orderer) { - return memoryIndex.orderResultsBy(queryContext, results, orderer); + throw new UnsupportedOperationException(); } } diff --git a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndexManager.java b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndexManager.java index c2f7087fc963..983ba3255244 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndexManager.java +++ b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndexManager.java @@ -64,7 +64,16 @@ private MemtableIndex initializeMemtableIndex(Memtable mt) // We expect the relevant IndexMemtable to be present most of the time, so only make the // call to computeIfAbsent() if it's not. (see https://bugs.openjdk.java.net/browse/JDK-8161372) return current != null ? current - : liveMemtableIndexMap.computeIfAbsent(mt, memtable -> new MemtableIndex(index, memtable)); + : liveMemtableIndexMap.computeIfAbsent(mt, memtable -> { + String shardsOption = index.getIndexMetadata().options.get(ShardedMemtableIndex.SHARDS_OPTION); + if (shardsOption != null) + { + Integer shardCount = Integer.parseInt(shardsOption); + if (shardCount > 1) + return new ShardedMemtableIndex(index, index.baseCfs(), shardCount, memtable); + } + return new UnshardedMemtableIndex(index, memtable); + }); } public long index(DecoratedKey key, Row row, Memtable mt) diff --git a/src/java/org/apache/cassandra/index/sai/memory/ShardedMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/ShardedMemtableIndex.java new file mode 100644 index 000000000000..2b6f8743b204 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/memory/ShardedMemtableIndex.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.memory; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.atomic.LongAdder; + +import javax.annotation.Nullable; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; + +import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.marshal.AbstractType; +import org.apache.cassandra.db.memtable.Memtable; +import org.apache.cassandra.db.memtable.ShardBoundaries; +import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.index.sai.QueryContext; +import org.apache.cassandra.index.sai.StorageAttachedIndex; +import org.apache.cassandra.index.sai.iterators.KeyRangeConcatIterator; +import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.index.sai.plan.Expression; +import org.apache.cassandra.index.sai.utils.PrimaryKey; +import org.apache.cassandra.index.sai.utils.PrimaryKeys; +import org.apache.cassandra.utils.FBUtilities; +import org.apache.cassandra.utils.MergeIterator; +import org.apache.cassandra.utils.Pair; +import org.apache.cassandra.utils.bytecomparable.ByteComparable; + +import static org.apache.cassandra.config.CassandraRelevantProperties.MEMTABLE_SHARD_COUNT; + +public class ShardedMemtableIndex implements MemtableIndex +{ + private final ShardBoundaries boundaries; + private final MemoryIndex[] shards; + private final AbstractType comparator; + private final StorageAttachedIndex index; + private final LongAdder writeCount = new LongAdder(); + private final LongAdder estimatedMemoryUsed = new LongAdder(); + private final Memtable memtable; + + private static volatile int defaultShardCount = MEMTABLE_SHARD_COUNT.getInt(FBUtilities.getAvailableProcessors()); + public static final String SHARDS_OPTION = "shards"; + + public ShardedMemtableIndex(StorageAttachedIndex index, + Memtable.Owner owner, + Integer shardCountOption, + Memtable memtable) + { + this.index = index; + this.comparator = index.termType().indexType(); + int shardCount = (null == shardCountOption) ? defaultShardCount: shardCountOption; + this.boundaries = owner.localRangeSplits(shardCount); + this.shards = generateShards(boundaries.shardCount(), index); + this.memtable = memtable; + } + + private MemoryIndex[] generateShards(int splits, + StorageAttachedIndex index) + { + MemoryIndex[] generatedShards = new MemoryIndex[splits]; + + for (int shard = 0; shard < boundaries.shardCount(); shard++) + { + generatedShards[shard] = new TrieMemoryIndex(index); + } + + return generatedShards; + } + + @VisibleForTesting + public int shardCount() + { + return shards.length; + } + + public long writeCount() + { + return writeCount.sum(); + } + + public boolean isEmpty() + { + return getMinTerm() == null; + } + + public Memtable getMemtable() + { + return memtable; + } + + public long estimatedMemoryUsed() + { + return estimatedMemoryUsed.sum(); + } + + // Returns the minimum indexed term in the combined memory indexes. + // This can be null if the indexed memtable was empty. Users of the + // {@code MemtableIndex} requiring a non-null minimum term should + // use the {@link MemtableIndex#isEmpty} method. + // Note: Individual index shards can return null here if the index + // didn't receive any terms within the token range of the shard + @Nullable + public ByteBuffer getMinTerm() { + return Arrays.stream(shards) + .map(MemoryIndex::getMinTerm) + .filter(Objects::nonNull) + .min(comparator) + .orElse(null); + } + + // Returns the maximum indexed term in the combined memory indexes. + // This can be null if the indexed memtable was empty. Users of the + // {@code MemtableIndex} requiring a non-null maximum term should + // use the {@link MemtableIndex#isEmpty} method. + // Note: Individual index shards can return null here if the index + // didn't receive any terms within the token range of the shard + @Nullable + public ByteBuffer getMaxTerm() { + return Arrays.stream(shards) + .map(MemoryIndex::getMaxTerm) + .filter(Objects::nonNull) + .max(comparator) + .orElse(null); + } + + public long index(DecoratedKey key, Clustering clustering, ByteBuffer value) + { + if (value == null || (value.remaining() == 0 && index.termType().skipsEmptyValue())) + return 0; + + long ram = shards[boundaries.getShardForKey(key)].add(key, clustering, value); + writeCount.increment(); + estimatedMemoryUsed.add(ram); + return ram; + } + + public KeyRangeIterator search(QueryContext queryContext, Expression expression, AbstractBounds keyRange) + { + List shardsForRange = boundaries.getShardsForRange(keyRange); + KeyRangeConcatIterator.Builder builder = KeyRangeConcatIterator.builder(shardsForRange.size()); + + for (int shard: shardsForRange) + { + assert shards[shard] != null; + builder.add(shards[shard].search(queryContext, expression, keyRange)); + } + + return builder.build(); + } + + public Iterator>> iterator(DecoratedKey min, DecoratedKey max) + { + int minSubrange = min == null ? 0 : boundaries.getShardForKey(min); + int maxSubrange = max == null ? shards.length - 1 : boundaries.getShardForKey(max); + + List>> rangeIterators = new ArrayList<>(maxSubrange - minSubrange + 1); + + for (int i = minSubrange; i <= maxSubrange; i++) + rangeIterators.add(shards[i].iterator()); + + return MergeIterator.get(rangeIterators, + (o1, o2) -> ByteComparable.compare(o1.left, o2.left, + ByteComparable.Version.OSS50), + new PrimaryKeysMergeReducer(rangeIterators.size())); + } + + // The PrimaryKeysMergeReducer receives the range iterators from each of the shards selected based on the + // min and max keys passed to the iterator method. It doesn't strictly do any reduction because the terms in each + // shard are unique. It will receive at most one shard entry per selected shard before getReduced + // is called. + private static class PrimaryKeysMergeReducer extends MergeIterator.Reducer, Pair>> + { + private final Pair[] shardEntriesToMerge; + private final Comparator comparator; + + private ByteComparable term; + + @SuppressWarnings("unchecked") + // The size represents the number of shards that have been selected for the merger + PrimaryKeysMergeReducer(int size) + { + this.shardEntriesToMerge = new Pair[size]; + this.comparator = PrimaryKey::compareTo; + } + + @Override + // Receive the term entry for a shard. This should only be called once for each + // shard before reduction. + public void reduce(int idx, Pair current) + { + Preconditions.checkArgument(shardEntriesToMerge[idx] == null, "Terms should be unique in the memory index"); + + shardEntriesToMerge[idx] = current; + if (current != null && term == null) + term = current.left; + } + + @Override + protected Pair> getReduced() + { + Preconditions.checkArgument(term != null, "The term must exist in memory index"); + + List> keyIterators = new ArrayList<>(shardEntriesToMerge.length); + for (Pair p : shardEntriesToMerge) + if (p != null && p.right != null && !p.right.isEmpty()) + keyIterators.add(p.right.iterator()); + + Iterator primaryKeys = MergeIterator.get(keyIterators, comparator, new MergeIterator.Reducer.Trivial<>()); + return Pair.create(term, primaryKeys); + } + + @Override + protected void onKeyChange() + { + Arrays.fill(shardEntriesToMerge, null); + term = null; + } + } +} diff --git a/src/java/org/apache/cassandra/index/sai/memory/UnshardedMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/UnshardedMemtableIndex.java new file mode 100644 index 000000000000..d0d14ab56b09 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/memory/UnshardedMemtableIndex.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.memory; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.atomic.LongAdder; +import java.util.function.Function; + +import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.memtable.Memtable; +import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.index.sai.QueryContext; +import org.apache.cassandra.index.sai.StorageAttachedIndex; +import org.apache.cassandra.index.sai.disk.format.IndexDescriptor; +import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata; +import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore; +import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.index.sai.plan.Expression; +import org.apache.cassandra.index.sai.utils.IndexIdentifier; +import org.apache.cassandra.index.sai.utils.PrimaryKey; +import org.apache.cassandra.index.sai.utils.PrimaryKeys; +import org.apache.cassandra.utils.CloseableIterator; +import org.apache.cassandra.utils.Pair; +import org.apache.cassandra.utils.bytecomparable.ByteComparable; + +public class UnshardedMemtableIndex implements MemtableIndex +{ + private final MemoryIndex memoryIndex; + private final LongAdder writeCount = new LongAdder(); + private final LongAdder estimatedMemoryUsed = new LongAdder(); + private final Memtable memtable; + + public UnshardedMemtableIndex(StorageAttachedIndex index, Memtable memtable) + { + this.memoryIndex = index.termType().isVector() ? new VectorMemoryIndex(index, memtable) : new TrieMemoryIndex(index); + this.memtable = memtable; + } + + public long writeCount() + { + return writeCount.sum(); + } + + public long estimatedMemoryUsed() + { + return estimatedMemoryUsed.sum(); + } + + public boolean isEmpty() + { + return memoryIndex.isEmpty(); + } + + public Memtable getMemtable() + { + return memtable; + } + + public ByteBuffer getMinTerm() + { + return memoryIndex.getMinTerm(); + } + + public ByteBuffer getMaxTerm() + { + return memoryIndex.getMaxTerm(); + } + + public long index(DecoratedKey key, Clustering clustering, ByteBuffer value) + { + if (value == null || (value.remaining() == 0 && memoryIndex.index.termType().skipsEmptyValue())) + return 0; + + long ram = memoryIndex.add(key, clustering, value); + writeCount.increment(); + estimatedMemoryUsed.add(ram); + return ram; + } + + /** Used only for Vector Indexes */ + public long update(DecoratedKey key, Clustering clustering, ByteBuffer oldValue, ByteBuffer newValue) + { + return memoryIndex.update(key, clustering, oldValue, newValue); + } + + public KeyRangeIterator search(QueryContext queryContext, Expression expression, AbstractBounds keyRange) + { + return memoryIndex.search(queryContext, expression, keyRange); + } + + public Iterator>> iterator(DecoratedKey min, DecoratedKey max) + { + Iterator> memoryIndexIterator = memoryIndex.iterator(); + + return new Iterator<>() + { + @Override + public boolean hasNext() + { + return memoryIndexIterator.hasNext(); + } + + @Override + public Pair> next() + { + Pair p = memoryIndexIterator.next(); + return Pair.create(p.left, p.right.iterator()); + } + }; + } + + /** Used only for Vector Indexes */ + public SegmentMetadata.ComponentMetadataMap writeDirect(IndexDescriptor indexDescriptor, + IndexIdentifier indexIdentifier, + Function postingTransformer) throws IOException + { + return memoryIndex.writeDirect(indexDescriptor, indexIdentifier, postingTransformer); + } + + /** Used only for Vector Indexes */ + @Override + public CloseableIterator orderBy(QueryContext queryContext, Expression orderer, AbstractBounds keyRange) + { + return memoryIndex.orderBy(queryContext, orderer, keyRange); + } + + @Override + public CloseableIterator orderResultsBy(QueryContext queryContext, List results, Expression orderer) + { + return memoryIndex.orderResultsBy(queryContext, results, orderer); + } +} diff --git a/test/microbench/org/apache/cassandra/test/microbench/sai/AbstractMemtableIndexBench.java b/test/microbench/org/apache/cassandra/test/microbench/sai/AbstractMemtableIndexBench.java new file mode 100644 index 000000000000..4632ea6eabea --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/sai/AbstractMemtableIndexBench.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.test.microbench.sai; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.cql3.statements.schema.IndexTarget; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.db.marshal.UTF8Type; +import org.apache.cassandra.index.sai.StorageAttachedIndex; +import org.apache.cassandra.index.sai.memory.MemtableIndex; +import org.apache.cassandra.schema.IndexMetadata; +import org.apache.cassandra.schema.TableMetadata; + +public abstract class AbstractMemtableIndexBench extends CQLTester +{ + private static final int RANDOM_STRING_SIZE = 64 * 1024 * 1024; + private static String keyspace; + + protected MemtableIndex memtableIndex; + protected DecoratedKey[] partitionKeys; + protected ByteBuffer[] terms; + protected StorageAttachedIndex index; + + protected ColumnFamilyStore cfs; + protected String table; + private char[] randomChars = new char[RANDOM_STRING_SIZE]; + + public void setup(int numberOfTerms, int rowsPerPartition) + { + setupServer(); + setupTableAndKeyspace(); + setupCfsAndIndex(); + setupPartitionKeys(numberOfTerms, rowsPerPartition); + setupTerms(numberOfTerms); + } + + public void setupServer() + { + CQLTester.setUpClass(); + DatabaseDescriptor.setAutoSnapshot(false); + } + + public void setupTableAndKeyspace() + { + keyspace = createKeyspace("CREATE KEYSPACE %s with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 } and durable_writes = false"); + table = createTable(keyspace, + "CREATE TABLE %s ( partition_id text, value_text text, PRIMARY KEY(partition_id)) with compression = {'enabled': false}", + "memtable_index"); + execute("use " + keyspace + ";"); + } + + public void setupCfsAndIndex() + { + cfs = Keyspace.open(keyspace).getColumnFamilyStore(table); + cfs.disableAutoCompaction(); + + Map options = new HashMap<>(); + options.put(IndexTarget.CUSTOM_INDEX_OPTION_NAME, + StorageAttachedIndex.class.getCanonicalName()); + options.put("target", "value_text"); + + IndexMetadata indexMetadata = IndexMetadata.fromSchemaMetadata("value_text_idx", IndexMetadata.Kind.CUSTOM, options); + + index = new StorageAttachedIndex(cfs, indexMetadata); + } + + public void setupPartitionKeys(int numberOfTerms, int rowsPerPartition) + { + TableMetadata tableMetadata = cfs.metadata(); + + int numberOfKeys = numberOfTerms / rowsPerPartition; + + partitionKeys = new DecoratedKey[numberOfKeys]; + for (int i = 0; i < numberOfKeys; i++) + partitionKeys[i] = tableMetadata.partitioner.decorateKey(tableMetadata.partitionKeyType.fromString("partition_" + i)); + } + + public void setupTerms(int numberOfTerms) + { + Random random = new Random(); + for (int i = 0; i < RANDOM_STRING_SIZE; i++) + randomChars[i] = (char)('a' + random.nextInt(26)); + + int length = 64; + terms = new ByteBuffer[numberOfTerms]; + for (int i = 0; i < numberOfTerms; i++) + terms[i] = UTF8Type.instance.decompose(generateRandomString(random, length)); + } + + private String generateRandomString(Random random, int length) + { + return new String(randomChars, random.nextInt(RANDOM_STRING_SIZE - length), length); + } +} diff --git a/test/microbench/org/apache/cassandra/test/microbench/sai/MemtableIndexFlushBench.java b/test/microbench/org/apache/cassandra/test/microbench/sai/MemtableIndexFlushBench.java new file mode 100644 index 000000000000..d186e9ec2365 --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/sai/MemtableIndexFlushBench.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.test.microbench.sai; + +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.marshal.UTF8Type; +import org.apache.cassandra.db.memtable.Memtable; +import org.apache.cassandra.index.sai.memory.ShardedMemtableIndex; +import org.apache.cassandra.index.sai.memory.UnshardedMemtableIndex; +import org.apache.cassandra.index.sai.utils.PrimaryKey; +import org.apache.cassandra.schema.TableMetadata; +import org.apache.cassandra.utils.Pair; +import org.apache.cassandra.utils.bytecomparable.ByteComparable; + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Warmup(iterations = 3, time = 2) +@Measurement(iterations = 5, time = 5) +@Fork(value = 1, jvmArgsAppend = { "-Xmx4G", "-Xms4G", "-Djmh.executor=CUSTOM", "-Djmh.executor.class=org.apache.cassandra.test.microbench.FastThreadExecutor"}) +@State(Scope.Benchmark) +public class MemtableIndexFlushBench extends AbstractMemtableIndexBench +{ + // 1 Million partitionKeys + private static final int NUM_PARTITION_KEYS = 1000000; + private static final int SMALL_POOL_SIZE = 8 * 1024; + + @Param({ "1", "4", "8"}) + int shardCount; + + @Param({"50", "100"}) + int numberOfTerms; + + private char[] smallPool = new char[SMALL_POOL_SIZE]; + + @Setup(Level.Trial) + public void setup() + { + setupServer(); + setupTableAndKeyspace(); + setupCfsAndIndex(); + setupPartitionKeys(); + setupIndexesExpressionsAndTerms(); + } + + public void setupPartitionKeys() + { + TableMetadata tableMetadata = cfs.metadata(); + + partitionKeys = new DecoratedKey[NUM_PARTITION_KEYS]; + for (int i = 0; i < NUM_PARTITION_KEYS; i++) + partitionKeys[i] = tableMetadata.partitioner.decorateKey(tableMetadata.partitionKeyType.fromString("partition_" + i)); + } + + public void setupIndexesExpressionsAndTerms() { + Memtable memtable = cfs.getCurrentMemtable(); + memtableIndex = (shardCount > 1) + ? new ShardedMemtableIndex(index, cfs, shardCount, memtable) : + new UnshardedMemtableIndex(index, memtable); + + setupTerms(numberOfTerms); + populateIndexData(); + } + + @Override + public void setupTerms(int numberOfTerms) + { + Random random = new Random(); + for (int i = 0; i < SMALL_POOL_SIZE; i++) + smallPool[i] = (char)('a' + random.nextInt(26)); + + int length = 64; + terms = new ByteBuffer[numberOfTerms]; + for (int i = 0; i < numberOfTerms; i++) + terms[i] = UTF8Type.instance.decompose( + new String(smallPool, random.nextInt(SMALL_POOL_SIZE - length), length)); + } + + private void populateIndexData() + { + int termCount = 0; + + for (int i = 0; i < NUM_PARTITION_KEYS; i++) + { + DecoratedKey partitionKey = partitionKeys[i]; + memtableIndex.index(partitionKey, Clustering.EMPTY, terms[termCount]); + + if (++termCount == numberOfTerms) + { + termCount = 0; + } + } + } + + @Benchmark + public long flushBench() + { + Iterator>> it = memtableIndex.iterator(null, null); + long count = 0; + while (it.hasNext()) + { + Iterator primaryKeyIterator = it.next().right; + while (primaryKeyIterator.hasNext()) + { + primaryKeyIterator.next(); + count++; + } + } + return count; + } +} diff --git a/test/microbench/org/apache/cassandra/test/microbench/sai/MemtableIndexPartitionReadBench.java b/test/microbench/org/apache/cassandra/test/microbench/sai/MemtableIndexPartitionReadBench.java new file mode 100644 index 000000000000..604736d9b095 --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/sai/MemtableIndexPartitionReadBench.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.test.microbench.sai; + +import java.nio.ByteBuffer; +import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import org.apache.cassandra.cql3.Operator; +import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.marshal.UTF8Type; +import org.apache.cassandra.db.memtable.Memtable; +import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.Bounds; +import org.apache.cassandra.index.sai.QueryContext; +import org.apache.cassandra.index.sai.memory.ShardedMemtableIndex; +import org.apache.cassandra.index.sai.memory.UnshardedMemtableIndex; +import org.apache.cassandra.index.sai.plan.Expression; +import org.apache.cassandra.schema.TableMetadata; + +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Warmup(iterations = 6, time = 3) +@Measurement(iterations = 5, time = 5) +@Fork(value = 1, jvmArgsAppend = { "-Xmx4G", "-Xms4G", "-Djmh.executor=CUSTOM", "-Djmh.executor.class=org.apache.cassandra.test.microbench.FastThreadExecutor"}) +@State(Scope.Benchmark) +public class MemtableIndexPartitionReadBench extends AbstractMemtableIndexBench +{ + // 1 Million partitionKeys + private static final int NUM_PARTITION_KEYS = 1000000; + private static final int NUMBER_OF_SEARCHES = 1000; + private static final int SMALL_POOL_SIZE = 8 * 1024; + + @Param({ "1", "4", "8"}) + int shardCount; + + @Param({"50", "100"}) + int numberOfTerms; + + private char[] smallPool = new char[SMALL_POOL_SIZE]; + private Expression[] stringEqualityExpressions; + private QueryContext queryContext; + private AbstractBounds[] keyRanges; + + @State(Scope.Thread) + public static class ThreadState + { + ThreadLocalRandom random = ThreadLocalRandom.current(); + } + + @Setup(Level.Trial) + public void setup() + { + setupServer(); + setupTableAndKeyspace(); + setupCfsAndIndex(); + setupPartitionKeys(); + setupQueryContext(); + setupIndexesExpressionsAndTerms(); + } + + private void setupQueryContext() + { + // setup a dummy query context, interface signature needs it. + queryContext = new QueryContext(null, Long.MAX_VALUE); + } + + public void setupPartitionKeys() + { + TableMetadata tableMetadata = cfs.metadata(); + + partitionKeys = new DecoratedKey[NUM_PARTITION_KEYS]; + for (int i = 0; i < NUM_PARTITION_KEYS; i++) + partitionKeys[i] = tableMetadata.partitioner.decorateKey(tableMetadata.partitionKeyType.fromString("partition_" + i)); + } + + public void setupIndexesExpressionsAndTerms() { + Memtable memtable = cfs.getCurrentMemtable(); + memtableIndex = (shardCount > 1) + ? new ShardedMemtableIndex(index, cfs, shardCount, memtable) : + new UnshardedMemtableIndex(index, memtable); + + setupTerms(numberOfTerms); + populateIndexDataAndKeyRanges(); + populateExpressions(); + } + + @Override + public void setupTerms(int numberOfTerms) + { + Random random = new Random(); + for (int i = 0; i < SMALL_POOL_SIZE; i++) + smallPool[i] = (char)('a' + random.nextInt(26)); + + int length = 64; + terms = new ByteBuffer[numberOfTerms]; + for (int i = 0; i < numberOfTerms; i++) + terms[i] = UTF8Type.instance.decompose( + new String(smallPool, random.nextInt(SMALL_POOL_SIZE - length), length)); + } + + private void populateIndexDataAndKeyRanges() + { + int termCount = 0; + keyRanges = new AbstractBounds[NUM_PARTITION_KEYS]; + + for (int i = 0; i < NUM_PARTITION_KEYS; i++) + { + DecoratedKey partitionKey = partitionKeys[i]; + memtableIndex.index(partitionKey, Clustering.EMPTY, terms[termCount]); + keyRanges[i] = new Bounds<>(partitionKey, partitionKey); + + if (++termCount == numberOfTerms) + { + termCount = 0; + } + } + } + + private void populateExpressions() + { + Random random = new Random(); + stringEqualityExpressions = new Expression[NUMBER_OF_SEARCHES]; + for (int i = 0; i < NUMBER_OF_SEARCHES; i++) + stringEqualityExpressions[i] = Expression.create(index).add(Operator.EQ, terms[random.nextInt(terms.length)]); + } + + @Benchmark + public long stringEqualityPartitionRestrictedRangeSearch(ThreadState state) + { + long size = 0; + memtableIndex.search(queryContext, + stringEqualityExpressions[state.random.nextInt(stringEqualityExpressions.length)], + keyRanges[state.random.nextInt(keyRanges.length)]); + return size; + } +} diff --git a/test/microbench/org/apache/cassandra/test/microbench/sai/MemtableIndexReadBench.java b/test/microbench/org/apache/cassandra/test/microbench/sai/MemtableIndexReadBench.java new file mode 100644 index 000000000000..46787cee2846 --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/sai/MemtableIndexReadBench.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.test.microbench.sai; + +import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import org.apache.cassandra.cql3.Operator; +import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.DataRange; +import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.memtable.Memtable; +import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.Murmur3Partitioner; +import org.apache.cassandra.index.sai.QueryContext; +import org.apache.cassandra.index.sai.memory.ShardedMemtableIndex; +import org.apache.cassandra.index.sai.memory.UnshardedMemtableIndex; +import org.apache.cassandra.index.sai.plan.Expression; + +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Warmup(iterations = 3, time = 2) +@Measurement(iterations = 5, time = 5) +@Fork(value = 1, jvmArgsAppend = { "-Xmx4G", "-Xms4G", "-Djmh.executor=CUSTOM", "-Djmh.executor.class=org.apache.cassandra.test.microbench.FastThreadExecutor"}) +@State(Scope.Benchmark) +public class MemtableIndexReadBench extends AbstractMemtableIndexBench +{ + private static final int NUMBER_OF_SEARCHES = 1000; + private static final AbstractBounds ALL_DATA_RANGE = DataRange.allData(Murmur3Partitioner.instance).keyRange(); + + @Param({ "1", "4", "8"}) + int shardCount; + + @Param({"1000000" }) + protected int numberOfTerms; + + @Param({ "1", "10", "100"}) + protected int rowsPerPartition; + + private Expression[] stringEqualityExpressions; + private QueryContext queryContext; + + @State(Scope.Thread) + public static class ThreadState + { + ThreadLocalRandom random = ThreadLocalRandom.current(); + } + + @Setup(Level.Trial) + public void setup() + { + super.setup(numberOfTerms, rowsPerPartition); + } + + @Setup(Level.Iteration) + public void setupIndexesAndExpressions() { + Memtable memtable = cfs.getCurrentMemtable(); + memtableIndex = (shardCount > 1) + ? new ShardedMemtableIndex(index, cfs, shardCount, memtable): + new UnshardedMemtableIndex(index, memtable); + + populateIndexData(); + populateExpressions(); + // setup a dummy query context, interface signature needs it. + queryContext = new QueryContext(null, Long.MAX_VALUE); + } + + private void populateIndexData() + { + int rowCount = 0; + int keyCount = 0; + for (int i = 0; i < numberOfTerms; i++) + { + memtableIndex.index(partitionKeys[keyCount], Clustering.EMPTY, terms[i]); + if (++rowCount == rowsPerPartition) + { + rowCount = 0; + keyCount++; + } + } + } + + private void populateExpressions() + { + Random random = new Random(); + stringEqualityExpressions = new Expression[NUMBER_OF_SEARCHES]; + for (int i = 0; i < NUMBER_OF_SEARCHES; i++) + stringEqualityExpressions[i] = Expression.create(index).add(Operator.EQ, terms[random.nextInt(terms.length)]); + } + + @Benchmark + public long stringEqualitySearch(ThreadState state) + { + long size = 0; + memtableIndex.search(queryContext, + stringEqualityExpressions[state.random.nextInt(stringEqualityExpressions.length)], + ALL_DATA_RANGE); + return size; + } +} diff --git a/test/microbench/org/apache/cassandra/test/microbench/sai/MemtableIndexWriteBench.java b/test/microbench/org/apache/cassandra/test/microbench/sai/MemtableIndexWriteBench.java new file mode 100644 index 000000000000..164a11e9a525 --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/sai/MemtableIndexWriteBench.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.test.microbench.sai; + +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.commitlog.CommitLog; +import org.apache.cassandra.db.memtable.Memtable; +import org.apache.cassandra.index.sai.memory.ShardedMemtableIndex; +import org.apache.cassandra.index.sai.memory.UnshardedMemtableIndex; + +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Warmup(iterations = 3, time = 2) +@Measurement(iterations = 5, time = 5) +@Fork(value = 1, jvmArgsAppend = { "-Xmx4G", "-Xms4G", "-Djmh.executor=CUSTOM", "-Djmh.executor.class=org.apache.cassandra.test.microbench.FastThreadExecutor"}) +@Threads(8) +@State(Scope.Benchmark) +public class MemtableIndexWriteBench extends AbstractMemtableIndexBench +{ + @Param({ "1", "4", "8"}) + int shardCount; + + @Param({"1000000" }) + protected int numberOfTerms; + + @Param({ "1", "10", "100"}) + protected int rowsPerPartition; + + @State(Scope.Thread) + public static class ThreadState + { + ThreadLocalRandom random = ThreadLocalRandom.current(); + } + + @Setup(Level.Trial) + public void setup() + { + super.setup(numberOfTerms, rowsPerPartition); + } + + @Setup(Level.Iteration) + public void setupIndexes() + { + Memtable memtable = cfs.getCurrentMemtable(); + memtableIndex = (shardCount > 1) + ? new ShardedMemtableIndex(index, cfs, shardCount, memtable): + new UnshardedMemtableIndex(index, memtable); + } + + @Benchmark + public long write(ThreadState state) + { + return memtableIndex.index(partitionKeys[state.random.nextInt(partitionKeys.length)], + Clustering.EMPTY, + terms[state.random.nextInt(terms.length)]); + } + + @TearDown(Level.Trial) + public void teardown() throws InterruptedException + { + CommitLog.instance.shutdownBlocking(); + CQLTester.tearDownClass(); + CQLTester.cleanup(); + } +} diff --git a/test/unit/org/apache/cassandra/dht/RangeIntersectsBoundsTest.java b/test/unit/org/apache/cassandra/dht/RangeIntersectsBoundsTest.java new file mode 100644 index 000000000000..01112263b559 --- /dev/null +++ b/test/unit/org/apache/cassandra/dht/RangeIntersectsBoundsTest.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.dht; + +import org.junit.Test; + +import org.apache.cassandra.dht.RandomPartitioner.BigIntegerToken; + +import org.apache.cassandra.CassandraTestBase; +import org.apache.cassandra.CassandraTestBase.DDDaemonInitialization; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +@DDDaemonInitialization +public class RangeIntersectsBoundsTest extends CassandraTestBase +{ + @Test + public void rangeIntersectsBounds() throws Exception + { + Range all = new Range<>(new BigIntegerToken("0"), new BigIntegerToken("0")); + Range some = new Range<>(new BigIntegerToken("4"), new BigIntegerToken("8")); + Range someWrapped = new Range<>(some.right, some.left); + + // Coda: + // l - matches left token of some range + // r - matches right token of some range + // b - below left token of some range + // a - above right token of some range + // i - inside some range (above left and below right + Bounds lr = new Bounds<>(new BigIntegerToken("4"), new BigIntegerToken("8")); + Bounds br = new Bounds<>(new BigIntegerToken("3"), new BigIntegerToken("8")); + Bounds bi = new Bounds<>(new BigIntegerToken("3"), new BigIntegerToken("7")); + Bounds ba = new Bounds<>(new BigIntegerToken("3"), new BigIntegerToken("9")); + Bounds la = new Bounds<>(new BigIntegerToken("4"), new BigIntegerToken("9")); + Bounds li = new Bounds<>(new BigIntegerToken("4"), new BigIntegerToken("7")); + Bounds ii = new Bounds<>(new BigIntegerToken("5"), new BigIntegerToken("7")); + Bounds ir = new Bounds<>(new BigIntegerToken("5"), new BigIntegerToken("8")); + Bounds bb = new Bounds<>(new BigIntegerToken("2"), new BigIntegerToken("3")); + Bounds aa = new Bounds<>(new BigIntegerToken("9"), new BigIntegerToken("10")); + Bounds bl = new Bounds<>(new BigIntegerToken("3"), new BigIntegerToken("4")); + Bounds ra = new Bounds<>(new BigIntegerToken("8"), new BigIntegerToken("9")); + + assertTrue(all.intersects(lr)); + assertTrue(all.intersects(br)); + assertTrue(all.intersects(bi)); + assertTrue(all.intersects(ba)); + assertTrue(all.intersects(la)); + assertTrue(all.intersects(li)); + assertTrue(all.intersects(ii)); + assertTrue(all.intersects(ir)); + assertTrue(all.intersects(bb)); + assertTrue(all.intersects(aa)); + assertTrue(all.intersects(bl)); + assertTrue(all.intersects(ra)); + + assertTrue(some.intersects(lr)); + assertTrue(some.intersects(br)); + assertTrue(some.intersects(bi)); + assertTrue(some.intersects(ba)); + assertTrue(some.intersects(la)); + assertTrue(some.intersects(li)); + assertTrue(some.intersects(ii)); + assertTrue(some.intersects(ir)); + assertFalse(some.intersects(bb)); + assertFalse(some.intersects(aa)); + assertFalse(some.intersects(bl)); + assertTrue(some.intersects(ra)); + + assertTrue(someWrapped.intersects(lr)); + assertTrue(someWrapped.intersects(br)); + assertTrue(someWrapped.intersects(bi)); + assertTrue(someWrapped.intersects(ba)); + assertTrue(someWrapped.intersects(la)); + assertTrue(someWrapped.intersects(li)); + assertFalse(someWrapped.intersects(ii)); + assertFalse(someWrapped.intersects(ir)); + assertTrue(someWrapped.intersects(bb)); + assertTrue(someWrapped.intersects(aa)); + assertTrue(someWrapped.intersects(bl)); + assertTrue(someWrapped.intersects(ra)); + } + + @Test + public void rangeIntersectsExcludingBounds() throws Exception + { + Range all = new Range<>(new BigIntegerToken("0"), new BigIntegerToken("0")); + Range some = new Range<>(new BigIntegerToken("4"), new BigIntegerToken("8")); + Range someWrapped = new Range<>(some.right, some.left); + + // Coda: + // l - matches left token of some range + // r - matches right token of some range + // b - below left token of some range + // a - above right token of some range + // i - inside some range (above left and below right + ExcludingBounds lr = new ExcludingBounds<>(new BigIntegerToken("4"), new BigIntegerToken("8")); + ExcludingBounds br = new ExcludingBounds<>(new BigIntegerToken("3"), new BigIntegerToken("8")); + ExcludingBounds bi = new ExcludingBounds<>(new BigIntegerToken("3"), new BigIntegerToken("7")); + ExcludingBounds ba = new ExcludingBounds<>(new BigIntegerToken("3"), new BigIntegerToken("9")); + ExcludingBounds la = new ExcludingBounds<>(new BigIntegerToken("4"), new BigIntegerToken("9")); + ExcludingBounds li = new ExcludingBounds<>(new BigIntegerToken("4"), new BigIntegerToken("7")); + ExcludingBounds ii = new ExcludingBounds<>(new BigIntegerToken("5"), new BigIntegerToken("7")); + ExcludingBounds ir = new ExcludingBounds<>(new BigIntegerToken("5"), new BigIntegerToken("8")); + ExcludingBounds bb = new ExcludingBounds<>(new BigIntegerToken("2"), new BigIntegerToken("3")); + ExcludingBounds aa = new ExcludingBounds<>(new BigIntegerToken("9"), new BigIntegerToken("10")); + ExcludingBounds bl = new ExcludingBounds<>(new BigIntegerToken("3"), new BigIntegerToken("4")); + ExcludingBounds ra = new ExcludingBounds<>(new BigIntegerToken("8"), new BigIntegerToken("9")); + + assertTrue(all.intersects(lr)); + assertTrue(all.intersects(br)); + assertTrue(all.intersects(bi)); + assertTrue(all.intersects(ba)); + assertTrue(all.intersects(la)); + assertTrue(all.intersects(li)); + assertTrue(all.intersects(ii)); + assertTrue(all.intersects(ir)); + assertTrue(all.intersects(bb)); + assertTrue(all.intersects(aa)); + assertTrue(all.intersects(bl)); + assertTrue(all.intersects(ra)); + + assertTrue(some.intersects(lr)); + assertTrue(some.intersects(br)); + assertTrue(some.intersects(bi)); + assertTrue(some.intersects(ba)); + assertTrue(some.intersects(la)); + assertTrue(some.intersects(li)); + assertTrue(some.intersects(ii)); + assertTrue(some.intersects(ir)); + assertFalse(some.intersects(bb)); + assertFalse(some.intersects(aa)); + assertFalse(some.intersects(bl)); + assertFalse(some.intersects(ra)); + + assertTrue(someWrapped.intersects(lr)); + assertTrue(someWrapped.intersects(br)); + assertTrue(someWrapped.intersects(bi)); + assertTrue(someWrapped.intersects(ba)); + assertTrue(someWrapped.intersects(la)); + assertTrue(someWrapped.intersects(li)); + assertFalse(someWrapped.intersects(ii)); + assertFalse(someWrapped.intersects(ir)); + assertTrue(someWrapped.intersects(bb)); + assertTrue(someWrapped.intersects(aa)); + assertFalse(someWrapped.intersects(bl)); + assertFalse(someWrapped.intersects(ra)); + + Range range = new Range<>(Murmur3Partitioner.MINIMUM, new Murmur3Partitioner.LongToken(-1)); + ExcludingBounds bounds = new ExcludingBounds<>(new Murmur3Partitioner.LongToken(-3248873570005575792L), Murmur3Partitioner.MINIMUM); + + assertTrue(range.intersects(bounds)); + + range = new Range<>(new Murmur3Partitioner.LongToken(-1), Murmur3Partitioner.MINIMUM); + + assertTrue(range.intersects(bounds)); + + } + + @Test + public void rangeIntersectsIncludingExcludingBounds() + { + Range all = new Range<>(new BigIntegerToken("0"), new BigIntegerToken("0")); + Range some = new Range<>(new BigIntegerToken("4"), new BigIntegerToken("8")); + Range someWrapped = new Range<>(some.right, some.left); + + // Coda: + // l - matches left token of some range + // r - matches right token of some range + // b - below left token of some range + // a - above right token of some range + // i - inside some range (above left and below right + IncludingExcludingBounds lr = new IncludingExcludingBounds<>(new BigIntegerToken("4"), new BigIntegerToken("8")); + IncludingExcludingBounds br = new IncludingExcludingBounds<>(new BigIntegerToken("3"), new BigIntegerToken("8")); + IncludingExcludingBounds bi = new IncludingExcludingBounds<>(new BigIntegerToken("3"), new BigIntegerToken("7")); + IncludingExcludingBounds ba = new IncludingExcludingBounds<>(new BigIntegerToken("3"), new BigIntegerToken("9")); + IncludingExcludingBounds la = new IncludingExcludingBounds<>(new BigIntegerToken("4"), new BigIntegerToken("9")); + IncludingExcludingBounds li = new IncludingExcludingBounds<>(new BigIntegerToken("4"), new BigIntegerToken("7")); + IncludingExcludingBounds ii = new IncludingExcludingBounds<>(new BigIntegerToken("5"), new BigIntegerToken("7")); + IncludingExcludingBounds ir = new IncludingExcludingBounds<>(new BigIntegerToken("5"), new BigIntegerToken("8")); + IncludingExcludingBounds bb = new IncludingExcludingBounds<>(new BigIntegerToken("2"), new BigIntegerToken("3")); + IncludingExcludingBounds aa = new IncludingExcludingBounds<>(new BigIntegerToken("9"), new BigIntegerToken("10")); + IncludingExcludingBounds bl = new IncludingExcludingBounds<>(new BigIntegerToken("3"), new BigIntegerToken("4")); + IncludingExcludingBounds ra = new IncludingExcludingBounds<>(new BigIntegerToken("8"), new BigIntegerToken("9")); + + assertTrue(all.intersects(lr)); + assertTrue(all.intersects(br)); + assertTrue(all.intersects(bi)); + assertTrue(all.intersects(ba)); + assertTrue(all.intersects(la)); + assertTrue(all.intersects(li)); + assertTrue(all.intersects(ii)); + assertTrue(all.intersects(ir)); + assertTrue(all.intersects(bb)); + assertTrue(all.intersects(aa)); + assertTrue(all.intersects(bl)); + assertTrue(all.intersects(ra)); + + assertTrue(some.intersects(lr)); + assertTrue(some.intersects(br)); + assertTrue(some.intersects(bi)); + assertTrue(some.intersects(ba)); + assertTrue(some.intersects(la)); + assertTrue(some.intersects(li)); + assertTrue(some.intersects(ii)); + assertTrue(some.intersects(ir)); + assertFalse(some.intersects(bb)); + assertFalse(some.intersects(aa)); + assertFalse(some.intersects(bl)); + assertTrue(some.intersects(ra)); + + assertTrue(someWrapped.intersects(lr)); + assertTrue(someWrapped.intersects(br)); + assertTrue(someWrapped.intersects(bi)); + assertTrue(someWrapped.intersects(ba)); + assertTrue(someWrapped.intersects(la)); + assertTrue(someWrapped.intersects(li)); + assertFalse(someWrapped.intersects(ii)); + assertFalse(someWrapped.intersects(ir)); + assertTrue(someWrapped.intersects(bb)); + assertTrue(someWrapped.intersects(aa)); + assertFalse(someWrapped.intersects(bl)); + assertTrue(someWrapped.intersects(ra)); + } + + /** + * Test that we handle partial bounds of the type x > n or x >= n which specifically have + * their right value as minimum. + */ + @Test + public void rangeIntersectsPartialBounds() + { + Range range = new Range<>(Murmur3Partitioner.MINIMUM, new Murmur3Partitioner.LongToken(-1L)); + + Bounds boundsMatch = new Bounds<>(new Murmur3Partitioner.LongToken(-2L), Murmur3Partitioner.MINIMUM); + Bounds boundsNoMatch = new Bounds<>(new Murmur3Partitioner.LongToken(0L), Murmur3Partitioner.MINIMUM); + + assertTrue(range.intersects(boundsMatch)); + assertFalse(range.intersects(boundsNoMatch)); + + ExcludingBounds excBoundsMatch = new ExcludingBounds<>(new Murmur3Partitioner.LongToken(-2L), Murmur3Partitioner.MINIMUM); + ExcludingBounds excBoundsNoMatch = new ExcludingBounds<>(new Murmur3Partitioner.LongToken(-1L), Murmur3Partitioner.MINIMUM); + + assertTrue(range.intersects(excBoundsMatch)); + assertFalse(range.intersects(excBoundsNoMatch)); + + IncludingExcludingBounds incExcBoundsMatch = new IncludingExcludingBounds<>(new Murmur3Partitioner.LongToken(-2L), Murmur3Partitioner.MINIMUM); + IncludingExcludingBounds incExcBoundsNoMatch = new IncludingExcludingBounds<>(new Murmur3Partitioner.LongToken(0L), Murmur3Partitioner.MINIMUM); + + assertTrue(range.intersects(incExcBoundsMatch)); + assertFalse(range.intersects(incExcBoundsNoMatch)); + } +} diff --git a/test/unit/org/apache/cassandra/index/sai/cql/StorageAttachedIndexDDLTest.java b/test/unit/org/apache/cassandra/index/sai/cql/StorageAttachedIndexDDLTest.java index 69f639238a84..f99bd6a671a5 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/StorageAttachedIndexDDLTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/StorageAttachedIndexDDLTest.java @@ -266,6 +266,16 @@ public void shouldFailCreateWithInvalidCharactersInColumnName() .hasMessage(String.format(CreateIndexStatement.INVALID_CHARS_CUSTOM_INDEX_TARGET, invalidColumn)); } + @Test + public void shouldRejectShardsOptionOnVectorIndex() + { + createTable("CREATE TABLE %s (id text PRIMARY KEY, val vector)"); + + assertThatThrownBy(() -> executeNet("CREATE INDEX ON %s(val) USING 'sai' WITH OPTIONS = { 'shards' : '4' }")) + .isInstanceOf(InvalidQueryException.class) + .hasMessageContaining("vector column does not support sharding"); + } + @Test public void shouldCreateIndexIfExists() { diff --git a/test/unit/org/apache/cassandra/index/sai/memory/ShardedMemtableIndexTest.java b/test/unit/org/apache/cassandra/index/sai/memory/ShardedMemtableIndexTest.java new file mode 100644 index 000000000000..49be3be44509 --- /dev/null +++ b/test/unit/org/apache/cassandra/index/sai/memory/ShardedMemtableIndexTest.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.memory; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.stream.Collectors; + +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.cql3.statements.schema.IndexTarget; +import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.marshal.Int32Type; +import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.IPartitioner; +import org.apache.cassandra.index.sai.StorageAttachedIndex; +import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.index.sai.plan.Expression; +import org.apache.cassandra.index.sai.utils.PrimaryKey; +import org.apache.cassandra.index.sai.utils.SAIRandomizedTester; +import org.apache.cassandra.schema.IndexMetadata; +import org.apache.cassandra.schema.TableMetadata; +import org.apache.cassandra.utils.Pair; +import org.apache.cassandra.utils.bytecomparable.ByteComparable; +import org.apache.cassandra.utils.bytecomparable.ByteSource; + +import static org.apache.cassandra.config.CassandraRelevantProperties.MEMTABLE_SHARD_COUNT; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class ShardedMemtableIndexTest extends SAIRandomizedTester +{ + private ColumnFamilyStore cfs; + private IPartitioner partitioner; + private StorageAttachedIndex index; + private ShardedMemtableIndex memtableIndex; + private Map keyMap; + private Map rowMap; + + @BeforeClass + public static void setShardCount() { + System.setProperty(MEMTABLE_SHARD_COUNT.getKey(), "8"); + } + + @Before + public void setup() throws Throwable + { + // CQLTester @BeforeClass already sets up server. + // Set up the keyspace and the table. + String keyspace = createKeyspace("CREATE KEYSPACE %s with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 } and durable_writes = false"); + String table = createTable(keyspace, + "CREATE TABLE %s (pk int PRIMARY KEY, val int)", + "memtable_index"); + execute("use " + keyspace + ";"); + + setupCfsAndIndex(keyspace, table); + + partitioner = cfs.getPartitioner(); + keyMap = new TreeMap<>(); + rowMap = new HashMap<>(); + } + + public void setupCfsAndIndex(String keyspace, String table) + { + cfs = Keyspace.open(keyspace).getColumnFamilyStore(table); + cfs.disableAutoCompaction(); + + Map options = new HashMap<>(); + options.put(IndexTarget.CUSTOM_INDEX_OPTION_NAME, + StorageAttachedIndex.class.getCanonicalName()); + options.put("target", "val"); + + IndexMetadata indexMetadata = IndexMetadata.fromSchemaMetadata("val_idx", IndexMetadata.Kind.CUSTOM, options); + + index = new StorageAttachedIndex(cfs, indexMetadata); + } + + @Test + public void onHeapAllocationTest() + { + // Should take the system variable-based shard count here + memtableIndex = new ShardedMemtableIndex(index, cfs, null, cfs.getCurrentMemtable()); + assertEquals(8, memtableIndex.shardCount()); + + assertEquals(0L, memtableIndex.writeCount()); + + for (int row = 0; row < 100; row++) + { + addRow(row, row); + } + + assertTrue(memtableIndex.writeCount() > 0); + } + + @Test + public void randomQueryTest() throws Exception + { + // Should take the system variable-based shard count here + memtableIndex = new ShardedMemtableIndex(index, cfs, null, cfs.getCurrentMemtable()); + assertEquals(8, memtableIndex.shardCount()); + + for (int row = 0; row < getRandom().nextIntBetween(1000, 5000); row++) + { + int pk = getRandom().nextIntBetween(0, 10000); + while (rowMap.containsKey(pk)) + pk = getRandom().nextIntBetween(0, 10000); + int value = getRandom().nextIntBetween(0, 100); + rowMap.put(pk, value); + addRow(pk, value); + } + + List keys = new ArrayList<>(keyMap.keySet()); + + for (int executionCount = 0; executionCount < 1000; executionCount++) + { + Expression expression = generateRandomExpression(index); + + AbstractBounds keyRange = generateRandomBounds(keys, partitioner); + + Set expectedKeys = keyMap.keySet() + .stream() + .filter(keyRange::contains) + .map(keyMap::get) + .filter(pk -> expression.isSatisfiedBy(Int32Type.instance.decompose(rowMap.get(pk)))) + .collect(Collectors.toSet()); + + Set foundKeys = new HashSet<>(); + + try (KeyRangeIterator iterator = memtableIndex.search(null, expression, keyRange)) + { + while (iterator.hasNext()) + { + int key = Int32Type.instance.compose(iterator.next().partitionKey().getKey()); + assertFalse(foundKeys.contains(key)); + foundKeys.add(key); + } + } + + assertEquals(expectedKeys, foundKeys); + } + } + + @Test + public void indexIteratorTest() + { + // Should take the system variable-based shard count here + memtableIndex = new ShardedMemtableIndex(index, cfs, null, cfs.getCurrentMemtable()); + assertEquals(8, memtableIndex.shardCount()); + + Map> terms = buildTermMap(); + + terms.entrySet() + .stream() + .forEach(entry -> entry.getValue() + .forEach(pk -> addRow(Int32Type.instance.compose(pk.getKey()), entry.getKey()))); + + for (int executionCount = 0; executionCount < 1000; executionCount++) + { + // These keys have midrange tokens that select 3 of the 8 shards + DecoratedKey minimum = makeKey(cfs.metadata(), getRandom().nextIntBetween(0, 20000)); + DecoratedKey temp = makeKey(cfs.metadata(), getRandom().nextIntBetween(0, 20000)); + while (temp.compareTo(minimum) <= 0) + temp = makeKey(cfs.metadata(), getRandom().nextIntBetween(0, 20000)); + DecoratedKey maximum = temp; + + Iterator>> iterator = memtableIndex.iterator(minimum, maximum); + + while (iterator.hasNext()) + { + Pair> termPair = iterator.next(); + int term = termFromComparable(termPair.left); + + // The iterator will return keys outside the range of min/max so we need to filter here to + // get the correct keys + List expectedPks = terms.get(term) + .stream() + .filter(pk -> pk.compareTo(minimum) >= 0 && pk.compareTo(maximum) <= 0) + .sorted() + .collect(Collectors.toList()); + + List termPks = new ArrayList<>(); + + while (termPair.right.hasNext()) + { + DecoratedKey pk = termPair.right.next().partitionKey(); + if (pk.compareTo(minimum) >= 0 && pk.compareTo(maximum) <= 0) + termPks.add(pk); + } + + assertEquals(expectedPks, termPks); + } + } + } + + private int termFromComparable(ByteComparable comparable) + { + ByteSource.Peekable peekable = ByteSource.peekable(comparable.asComparableBytes(ByteComparable.Version.OSS50)); + return Int32Type.instance.compose(Int32Type.instance.fromComparableBytes(peekable, ByteComparable.Version.OSS50)); + } + + private Map> buildTermMap() + { + Map> terms = new HashMap<>(); + + for (int count = 0; count < 10000; count++) + { + int term = getRandom().nextIntBetween(0, 100); + Set pks; + if (terms.containsKey(term)) + pks = terms.get(term); + else + { + pks = new HashSet<>(); + terms.put(term, pks); + } + DecoratedKey key = makeKey(cfs.metadata(), getRandom().nextIntBetween(0, 20000)); + while (pks.contains(key)) + key = makeKey(cfs.metadata(), getRandom().nextIntBetween(0, 20000)); + pks.add(key); + } + + return terms; + } + + private void addRow(int pk, int value) + { + DecoratedKey key = makeKey(cfs.metadata(), pk); + memtableIndex.index(key, Clustering.EMPTY, Int32Type.instance.decompose(value)); + keyMap.put(key, pk); + } + + private DecoratedKey makeKey(TableMetadata table, Integer partitionKey) + { + ByteBuffer key = table.partitionKeyType.fromString(partitionKey.toString()); + return table.partitioner.decorateKey(key); + } +} diff --git a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java index 3e512e9e234c..0c396f8747f2 100644 --- a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java +++ b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java @@ -33,7 +33,6 @@ import org.junit.Ignore; import org.junit.Test; -import org.apache.cassandra.cql3.Operator; import org.apache.cassandra.cql3.statements.schema.IndexTarget; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.ColumnFamilyStore; @@ -43,11 +42,7 @@ import org.apache.cassandra.db.marshal.Int32Type; import org.apache.cassandra.db.marshal.UTF8Type; import org.apache.cassandra.dht.AbstractBounds; -import org.apache.cassandra.dht.Bounds; -import org.apache.cassandra.dht.ExcludingBounds; -import org.apache.cassandra.dht.IncludingExcludingBounds; import org.apache.cassandra.dht.Murmur3Partitioner; -import org.apache.cassandra.dht.Range; import org.apache.cassandra.index.sai.StorageAttachedIndex; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.plan.Expression; @@ -110,9 +105,9 @@ public void randomQueryTest() throws Exception for (int executionCount = 0; executionCount < 1000; executionCount++) { - Expression expression = generateRandomExpression(); + Expression expression = generateRandomExpression(this.index); - AbstractBounds keyRange = generateRandomBounds(keys); + AbstractBounds keyRange = generateRandomBounds(keys, Murmur3Partitioner.instance); Set expectedKeys = keyMap.keySet() .stream() @@ -137,62 +132,6 @@ public void randomQueryTest() throws Exception } } - private AbstractBounds generateRandomBounds(List keys) - { - PartitionPosition leftBound = getRandom().nextBoolean() ? Murmur3Partitioner.instance.getMinimumToken().minKeyBound() - : keys.get(getRandom().nextIntBetween(0, keys.size() - 1)).getToken().minKeyBound(); - - PartitionPosition rightBound = getRandom().nextBoolean() ? Murmur3Partitioner.instance.getMinimumToken().minKeyBound() - : keys.get(getRandom().nextIntBetween(0, keys.size() - 1)).getToken().maxKeyBound(); - - AbstractBounds keyRange; - - if (leftBound.isMinimum() && rightBound.isMinimum()) - keyRange = new Range<>(leftBound, rightBound); - else - { - if (AbstractBounds.strictlyWrapsAround(leftBound, rightBound)) - { - PartitionPosition temp = leftBound; - leftBound = rightBound; - rightBound = temp; - } - if (getRandom().nextBoolean()) - keyRange = new Bounds<>(leftBound, rightBound); - else if (getRandom().nextBoolean()) - keyRange = new ExcludingBounds<>(leftBound, rightBound); - else - keyRange = new IncludingExcludingBounds<>(leftBound, rightBound); - } - return keyRange; - } - - private Expression generateRandomExpression() - { - Expression expression = Expression.create(index); - - int equality = getRandom().nextIntBetween(0, 100); - int lower = getRandom().nextIntBetween(0, 75); - int upper = getRandom().nextIntBetween(25, 100); - while (upper <= lower) - upper = getRandom().nextIntBetween(0, 100); - - if (getRandom().nextBoolean()) - expression.add(Operator.EQ, Int32Type.instance.decompose(equality)); - else - { - boolean useLower = getRandom().nextBoolean(); - boolean useUpper = getRandom().nextBoolean(); - if (!useLower && !useUpper) - useLower = useUpper = true; - if (useLower) - expression.add(getRandom().nextBoolean() ? Operator.GT : Operator.GTE, Int32Type.instance.decompose(lower)); - if (useUpper) - expression.add(getRandom().nextBoolean() ? Operator.LT : Operator.LTE, Int32Type.instance.decompose(upper)); - } - return expression; - } - @Test public void shouldAcceptPrefixValuesTest() { diff --git a/test/unit/org/apache/cassandra/index/sai/utils/SAIRandomizedTester.java b/test/unit/org/apache/cassandra/index/sai/utils/SAIRandomizedTester.java index 7cff65acd834..b321280af25c 100644 --- a/test/unit/org/apache/cassandra/index/sai/utils/SAIRandomizedTester.java +++ b/test/unit/org/apache/cassandra/index/sai/utils/SAIRandomizedTester.java @@ -18,6 +18,7 @@ package org.apache.cassandra.index.sai.utils; import java.io.IOException; +import java.util.List; import com.google.common.base.Preconditions; @@ -28,11 +29,22 @@ import org.junit.rules.TestRule; import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.cql3.Operator; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.PartitionPosition; import org.apache.cassandra.db.marshal.Int32Type; +import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.Bounds; +import org.apache.cassandra.dht.ExcludingBounds; +import org.apache.cassandra.dht.IPartitioner; +import org.apache.cassandra.dht.IncludingExcludingBounds; import org.apache.cassandra.dht.Murmur3Partitioner; +import org.apache.cassandra.dht.Range; import org.apache.cassandra.index.sai.SAITester; +import org.apache.cassandra.index.sai.StorageAttachedIndex; import org.apache.cassandra.index.sai.disk.format.IndexDescriptor; import org.apache.cassandra.index.sai.disk.io.IndexFileUtils; +import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.postings.PostingList; import org.apache.cassandra.io.sstable.Descriptor; import org.apache.cassandra.io.sstable.SequenceBasedSSTableId; @@ -199,4 +211,61 @@ public static void shuffle(int[] array) array[randomPosition] = temp; } } + + public static Expression generateRandomExpression(StorageAttachedIndex index) + { + Expression expression = Expression.create(index); + + int equality = getRandom().nextIntBetween(0, 100); + int lower = getRandom().nextIntBetween(0, 75); + int upper = getRandom().nextIntBetween(25, 100); + while (upper <= lower) + upper = getRandom().nextIntBetween(0, 100); + + if (getRandom().nextBoolean()) + expression.add(Operator.EQ, Int32Type.instance.decompose(equality)); + else + { + boolean useLower = getRandom().nextBoolean(); + boolean useUpper = getRandom().nextBoolean(); + if (!useLower && !useUpper) + useLower = useUpper = true; + if (useLower) + expression.add(getRandom().nextBoolean() ? Operator.GT : Operator.GTE, Int32Type.instance.decompose(lower)); + if (useUpper) + expression.add(getRandom().nextBoolean() ? Operator.LT : Operator.LTE, Int32Type.instance.decompose(upper)); + } + return expression; + } + + public static AbstractBounds generateRandomBounds(List keys, + IPartitioner partitioner) + { + PartitionPosition leftBound = getRandom().nextBoolean() ? partitioner.getMinimumToken().minKeyBound() + : keys.get(getRandom().nextIntBetween(0, keys.size() - 1)).getToken().minKeyBound(); + + PartitionPosition rightBound = getRandom().nextBoolean() ? partitioner.getMinimumToken().minKeyBound() + : keys.get(getRandom().nextIntBetween(0, keys.size() - 1)).getToken().maxKeyBound(); + + AbstractBounds keyRange; + + if (leftBound.isMinimum() && rightBound.isMinimum()) + keyRange = new Range<>(leftBound, rightBound); + else + { + if (AbstractBounds.strictlyWrapsAround(leftBound, rightBound)) + { + PartitionPosition temp = leftBound; + leftBound = rightBound; + rightBound = temp; + } + if (getRandom().nextBoolean()) + keyRange = new Bounds<>(leftBound, rightBound); + else if (getRandom().nextBoolean()) + keyRange = new ExcludingBounds<>(leftBound, rightBound); + else + keyRange = new IncludingExcludingBounds<>(leftBound, rightBound); + } + return keyRange; + } }