Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions src/java/org/apache/cassandra/db/memtable/ShardBoundaries.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@
package org.apache.cassandra.db.memtable;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.google.common.annotations.VisibleForTesting;

import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.db.Keyspace;
import org.apache.cassandra.db.PartitionPosition;
import org.apache.cassandra.dht.AbstractBounds;
import org.apache.cassandra.dht.Range;
import org.apache.cassandra.dht.Token;
import org.apache.cassandra.tcm.Epoch;

Expand All @@ -45,15 +51,42 @@ public class ShardBoundaries
// - the default partitioner doesn't support splitting
// - the keyspace is local system keyspace
public static final ShardBoundaries NONE = new ShardBoundaries(EMPTY_TOKEN_ARRAY, Epoch.EMPTY);
private static final Range<PartitionPosition>[] EMPTY_RANGE_ARRAY = new Range[0];
private static final List<Integer> EMPTY_BOUNDARIES_SHARDS = Collections.singletonList(0);

private final Token[] boundaries;
private final Range<PartitionPosition>[] ranges;
private final List<Integer> allShards;
public final Epoch epoch;

@VisibleForTesting
public ShardBoundaries(Token[] boundaries, Epoch epoch)
{
this.boundaries = boundaries;
this.epoch = epoch;
this.ranges = precomputeRanges();
this.allShards = IntStream.range(0, boundaries.length + 1).boxed().collect(Collectors.toList());
}

private Range<PartitionPosition>[] precomputeRanges()
{
if (boundaries.length == 0)
return EMPTY_RANGE_ARRAY;

Range<PartitionPosition>[] ranges = new Range[boundaries.length + 1];
int rangeIndex = 0;
PartitionPosition minimum = DatabaseDescriptor.getPartitioner().getMinimumToken().minKeyBound();

for (Token boundary : boundaries)
{
PartitionPosition boundaryPosition = boundary.maxKeyBound();
ranges[rangeIndex++] = new Range<>(minimum, boundaryPosition);
minimum = boundaryPosition;
}

ranges[rangeIndex] = new Range<>(minimum, DatabaseDescriptor.getPartitioner().getMaximumTokenForSplitting().maxKeyBound());

return ranges;
}

public ShardBoundaries(List<Token> boundaries, Epoch epoch)
Expand Down Expand Up @@ -87,6 +120,20 @@ public int getShardForKey(PartitionPosition key)
return getShardForToken(key.getToken());
}

public List<Integer> getShardsForRange(AbstractBounds<PartitionPosition> keyRange)
{
if (boundaries.length == 0)
return EMPTY_BOUNDARIES_SHARDS;

// If the keyRange tokens match and are minimum then it represents the entire token ring
// then we need to return all the shards.
if (keyRange.right.isMinimum() && keyRange.left.compareTo(keyRange.right) == 0)
return allShards;

// Otherwise we need to return all the shards whose range intersects the keyrange
return allShards.stream().filter(s -> ranges[s].intersects(keyRange)).collect(Collectors.toList());
}

public Token getShardStartBoundary(int shardId)
{
if (shardId <= 0 || shardId >= boundaries.length)
Expand Down
26 changes: 26 additions & 0 deletions src/java/org/apache/cassandra/dht/Range.java
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,35 @@ public boolean intersects(AbstractBounds<T> that)
return intersects((Range<T>) that);
if (that instanceof Bounds)
return intersects((Bounds<T>) that);
if (that instanceof ExcludingBounds)
return intersects((ExcludingBounds<T>) that);
if (that instanceof IncludingExcludingBounds)
return intersects((IncludingExcludingBounds<T>) that);
throw new UnsupportedOperationException("Intersection is only supported for Bounds and Range objects; found " + that.getClass());
}

public boolean intersects(IncludingExcludingBounds<T> that)
{
if (!isWrapAround() && !that.right.isMinimum() && (this.left.compareTo(that.right) == 0))
return false;
else if (isWrapAround() && !that.right.isMinimum() && (this.right.compareTo(that.right) == 0))
return false;
return contains(that.left) || (!that.left.equals(that.right) && intersects(new Range<T>(that.left, that.right)));
}

public boolean intersects(ExcludingBounds<T> that)
{
if (!isWrapAround() &&
((!that.right.isMinimum() && (this.left.compareTo(that.right) == 0)) ||
(this.right.compareTo(that.left) == 0)))
return false;
else if (isWrapAround() &&
((this.left.compareTo(that.left) == 0) ||
(!that.right.isMinimum() && (this.right.compareTo(that.right) == 0))))
return false;
return contains(that.left) || (!that.left.equals(that.right) && intersects(new Range<T>(that.left, that.right)));
}

/**
* @param that range to check for intersection
* @return true if the given range intersects with this range.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
import org.apache.cassandra.index.sai.disk.format.Version;
import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig;
import org.apache.cassandra.index.sai.memory.MemtableIndexManager;
import org.apache.cassandra.index.sai.memory.ShardedMemtableIndex;
import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics;
import org.apache.cassandra.index.sai.metrics.IndexMetrics;
import org.apache.cassandra.index.sai.utils.IndexIdentifier;
Expand Down Expand Up @@ -159,7 +160,8 @@ public class StorageAttachedIndex implements Index
IndexWriterConfig.OPTIMIZE_FOR,
NonTokenizingOptions.CASE_SENSITIVE,
NonTokenizingOptions.NORMALIZE,
NonTokenizingOptions.ASCII);
NonTokenizingOptions.ASCII,
ShardedMemtableIndex.SHARDS_OPTION);

public static final Set<CQL3Type> SUPPORTED_TYPES = ImmutableSet.of(CQL3Type.Native.ASCII, CQL3Type.Native.BIGINT, CQL3Type.Native.DATE,
CQL3Type.Native.DOUBLE, CQL3Type.Native.FLOAT, CQL3Type.Native.INT,
Expand Down Expand Up @@ -276,6 +278,10 @@ public static Map<String, String> validateOptions(Map<String, String> options, T
}

IndexTermType indexTermType = IndexTermType.create(target.left, metadata.partitionKeyColumns(), target.right);
String shardsOption = options.get(ShardedMemtableIndex.SHARDS_OPTION);
if (shardsOption != null && indexTermType.isVector())
throw new InvalidRequestException("A storage-attached index on a vector column does not support sharding");

AbstractAnalyzer.fromOptions(indexTermType, analysisOptions);
IndexWriterConfig config = IndexWriterConfig.fromOptions(null, indexTermType, options);

Expand Down
12 changes: 7 additions & 5 deletions src/java/org/apache/cassandra/index/sai/disk/RowMapping.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class RowMapping
public static final RowMapping DUMMY = new RowMapping()
{
@Override
public Iterator<Pair<ByteComparable, LongArrayList>> merge(MemtableIndex index) { return Collections.emptyIterator(); }
public Iterator<Pair<ByteComparable, LongArrayList>> merge(MemtableIndex index, PrimaryKey minKey, PrimaryKey maxKey) { return Collections.emptyIterator(); }

@Override
public void complete() {}
Expand Down Expand Up @@ -99,22 +99,24 @@ public static RowMapping create(OperationType opType)
*
* @return an iterator of term -> postings list {@link Pair}s
*/
public Iterator<Pair<ByteComparable, LongArrayList>> merge(MemtableIndex index)
public Iterator<Pair<ByteComparable, LongArrayList>> merge(MemtableIndex index,
PrimaryKey minKey,
PrimaryKey maxKey)
{
assert complete : "RowMapping is not built.";

Iterator<Pair<ByteComparable, PrimaryKeys>> iterator = index.iterator();
Iterator<Pair<ByteComparable, Iterator<PrimaryKey>>> iterator = index.iterator(minKey.partitionKey(), maxKey.partitionKey());
return new AbstractGuavaIterator<>()
{
@Override
protected Pair<ByteComparable, LongArrayList> computeNext()
{
while (iterator.hasNext())
{
Pair<ByteComparable, PrimaryKeys> pair = iterator.next();
Pair<ByteComparable, Iterator<PrimaryKey>> pair = iterator.next();

LongArrayList postings = null;
Iterator<PrimaryKey> primaryKeys = pair.right.iterator();
Iterator<PrimaryKey> primaryKeys = pair.right;

while (primaryKeys.hasNext())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public void complete(Stopwatch stopwatch) throws IOException
}
else
{
final Iterator<Pair<ByteComparable, LongArrayList>> iterator = rowMapping.merge(memtable);
final Iterator<Pair<ByteComparable, LongArrayList>> iterator = rowMapping.merge(memtable, minKey, maxKey);

long cellCount = 0;
if (iterator.hasNext())
Expand Down
94 changes: 27 additions & 67 deletions src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Function;

import org.apache.cassandra.db.Clustering;
Expand All @@ -31,104 +30,65 @@
import org.apache.cassandra.db.memtable.Memtable;
import org.apache.cassandra.dht.AbstractBounds;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.StorageAttachedIndex;
import org.apache.cassandra.index.sai.disk.format.IndexDescriptor;
import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata;
import org.apache.cassandra.index.sai.disk.v1.vector.PrimaryKeyWithScore;
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
import org.apache.cassandra.index.sai.plan.Expression;
import org.apache.cassandra.index.sai.utils.IndexIdentifier;
import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.PrimaryKeys;
import org.apache.cassandra.utils.CloseableIterator;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;

public class MemtableIndex implements MemtableOrdering
public interface MemtableIndex extends MemtableOrdering
{
private final MemoryIndex memoryIndex;
private final LongAdder writeCount = new LongAdder();
private final LongAdder estimatedMemoryUsed = new LongAdder();
private final Memtable memtable;
long writeCount();

public MemtableIndex(StorageAttachedIndex index, Memtable memtable)
{
this.memoryIndex = index.termType().isVector() ? new VectorMemoryIndex(index, memtable) : new TrieMemoryIndex(index);
this.memtable = memtable;
}
long estimatedMemoryUsed();

public long writeCount()
{
return writeCount.sum();
}
boolean isEmpty();

public long estimatedMemoryUsed()
{
return estimatedMemoryUsed.sum();
}
Memtable getMemtable();

public boolean isEmpty()
{
return memoryIndex.isEmpty();
}
ByteBuffer getMinTerm();

public Memtable getMemtable()
{
return memtable;
}
ByteBuffer getMaxTerm();

public ByteBuffer getMinTerm()
{
return memoryIndex.getMinTerm();
}
long index(DecoratedKey key, Clustering<?> clustering, ByteBuffer value);

public ByteBuffer getMaxTerm()
/** Implementation only for Vector Indexes */
default long update(DecoratedKey key, Clustering<?> clustering, ByteBuffer oldValue, ByteBuffer newValue)
{
return memoryIndex.getMaxTerm();
throw new UnsupportedOperationException();
}

public long index(DecoratedKey key, Clustering<?> clustering, ByteBuffer value)
{
if (value == null || (value.remaining() == 0 && memoryIndex.index.termType().skipsEmptyValue()))
return 0;

long ram = memoryIndex.add(key, clustering, value);
writeCount.increment();
estimatedMemoryUsed.add(ram);
return ram;
}

public long update(DecoratedKey key, Clustering<?> clustering, ByteBuffer oldValue, ByteBuffer newValue)
{
return memoryIndex.update(key, clustering, oldValue, newValue);
}
KeyRangeIterator search(QueryContext queryContext, Expression expression, AbstractBounds<PartitionPosition> keyRange);

public KeyRangeIterator search(QueryContext queryContext, Expression expression, AbstractBounds<PartitionPosition> keyRange)
{
return memoryIndex.search(queryContext, expression, keyRange);
}

public Iterator<Pair<ByteComparable, PrimaryKeys>> iterator()
{
return memoryIndex.iterator();
}
/**
* NOTE: Returned data may contain keys outside [minKey, maxKey].
* Bounds are used for shard selection only; implementations
* without sharding may ignore them.
*/
Iterator<Pair<ByteComparable, Iterator<PrimaryKey>>> iterator(DecoratedKey min, DecoratedKey max);

public SegmentMetadata.ComponentMetadataMap writeDirect(IndexDescriptor indexDescriptor,
IndexIdentifier indexIdentifier,
Function<PrimaryKey, Integer> postingTransformer) throws IOException
/** Implementation only for Vector Indexes */
default SegmentMetadata.ComponentMetadataMap writeDirect(IndexDescriptor indexDescriptor,
IndexIdentifier indexIdentifier,
Function<PrimaryKey, Integer> postingTransformer) throws IOException
{
return memoryIndex.writeDirect(indexDescriptor, indexIdentifier, postingTransformer);
throw new UnsupportedOperationException();
}

@Override
public CloseableIterator<PrimaryKeyWithScore> orderBy(QueryContext queryContext, Expression orderer, AbstractBounds<PartitionPosition> keyRange)
default CloseableIterator<PrimaryKeyWithScore> orderBy(QueryContext queryContext, Expression orderer, AbstractBounds<PartitionPosition> keyRange)
{
return memoryIndex.orderBy(queryContext, orderer, keyRange);
throw new UnsupportedOperationException();
}

@Override
public CloseableIterator<PrimaryKeyWithScore> orderResultsBy(QueryContext queryContext, List<PrimaryKey> results, Expression orderer)
default CloseableIterator<PrimaryKeyWithScore> orderResultsBy(QueryContext queryContext, List<PrimaryKey> results, Expression orderer)
{
return memoryIndex.orderResultsBy(queryContext, results, orderer);
throw new UnsupportedOperationException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,16 @@ private MemtableIndex initializeMemtableIndex(Memtable mt)
// We expect the relevant IndexMemtable to be present most of the time, so only make the
// call to computeIfAbsent() if it's not. (see https://bugs.openjdk.java.net/browse/JDK-8161372)
return current != null ? current
: liveMemtableIndexMap.computeIfAbsent(mt, memtable -> new MemtableIndex(index, memtable));
: liveMemtableIndexMap.computeIfAbsent(mt, memtable -> {
String shardsOption = index.getIndexMetadata().options.get(ShardedMemtableIndex.SHARDS_OPTION);
if (shardsOption != null)
{
Integer shardCount = Integer.parseInt(shardsOption);
if (shardCount > 1)
return new ShardedMemtableIndex(index, index.baseCfs(), shardCount, memtable);
}
return new UnshardedMemtableIndex(index, memtable);
});
}

public long index(DecoratedKey key, Row row, Memtable mt)
Expand Down
Loading