diff --git a/docs/release notes/4.0.0-RC.9/pr685.feature.md b/docs/release notes/4.0.0-RC.9/pr685.feature.md new file mode 100644 index 000000000..806e19c9a --- /dev/null +++ b/docs/release notes/4.0.0-RC.9/pr685.feature.md @@ -0,0 +1,8 @@ +### Support storing and loading an empty graph + +**Description** +Allow an empty graph to be stored to and loaded from disk + +**Purpose / Impact** +- Helps to preserve graph metadata (similarity function, features, dimensions) even if the graph is empty + diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 8135bba25..4139a14b6 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -951,7 +951,9 @@ private void loadV4(RandomAccessReader in) throws IOException { } graph.setDegrees(layerDegrees); - graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode)); + if (entryNode != ImmutableGraphIndex.ENTRY_NODE_ABSENT) { + graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode)); + } } @Deprecated @@ -984,7 +986,9 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { graph.markComplete(new NodeAtLevel(0, nodeId)); } - graph.updateEntryNode(new NodeAtLevel(0, entryNode)); + if (entryNode != ImmutableGraphIndex.ENTRY_NODE_ABSENT) { + graph.updateEntryNode(new NodeAtLevel(0, entryNode)); + } graph.setDegrees(List.of(maxDegree)); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java index a4758c493..6482431cd 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java +++ 
b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java @@ -24,6 +24,7 @@ package io.github.jbellis.jvector.graph; +import io.github.jbellis.jvector.graph.disk.OrdinalMapper; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.util.Accountable; import io.github.jbellis.jvector.util.Bits; @@ -35,7 +36,6 @@ import java.io.Closeable; import java.io.IOException; -import java.util.function.Function; /** * Represents a graph-based vector index. Nodes are represented as ints, and edges are @@ -48,6 +48,9 @@ * in a View that should be created per accessing thread. */ public interface ImmutableGraphIndex extends AutoCloseable, Accountable { + /** Sentinel entry-node id meaning the graph has no entry node (e.g., the graph is empty) */ + int ENTRY_NODE_ABSENT = -1; + /** Returns the number of nodes in the graph */ @Deprecated default int size() { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 9ed1a92dd..c1b44e991 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -49,7 +49,6 @@ import java.util.concurrent.atomic.AtomicIntegerArray; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.StampedLock; -import java.util.function.Function; import java.util.stream.IntStream; /** @@ -351,7 +350,7 @@ public double getAverageDegree(int level) { public int getMaxLevel() { for (int lvl = 0; lvl < layers.size(); lvl++) { if (layers.get(lvl).size() == 0) { - return lvl - 1; + return (lvl > 0) ? 
lvl - 1 : 0; } } return layers.size() - 1; @@ -542,8 +541,12 @@ public void save(DataOutput out) throws IOException { } var entryNode = entryPoint.get(); - assert entryNode.level == getMaxLevel(); - out.writeInt(entryNode.node); + if (entryNode != null) { + assert entryNode.level == getMaxLevel(); + out.writeInt(entryNode.node); + } else { + out.writeInt(ENTRY_NODE_ABSENT); + } for (int level = 0; level < layers.size(); level++) { out.writeInt(size(level)); @@ -618,7 +621,9 @@ public static OnHeapGraphIndex load(RandomAccessReader in, int dimension, double } graph.setDegrees(layerDegrees); - graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode)); + if (entryNode != ENTRY_NODE_ABSENT) { + graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode)); + } return graph; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java index a5ff739f3..09d0d0ec0 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java @@ -175,7 +175,7 @@ void writeFooter(ImmutableGraphIndex.View view, long headerOffset) throws IOExce var layerInfo = CommonHeader.LayerInfo.fromGraph(graph, ordinalMapper); var commonHeader = new CommonHeader(version, dimension, - ordinalMapper.oldToNew(view.entryNode().node), + view.entryNode() == null ? 
ImmutableGraphIndex.ENTRY_NODE_ABSENT : ordinalMapper.oldToNew(view.entryNode().node), layerInfo, ordinalMapper.maxOrdinal() + 1); var header = new Header(commonHeader, featureMap); @@ -198,7 +198,7 @@ protected synchronized void writeHeader(ImmutableGraphIndex.View view, long star var layerInfo = CommonHeader.LayerInfo.fromGraph(graph, ordinalMapper); var commonHeader = new CommonHeader(version, dimension, - ordinalMapper.oldToNew(view.entryNode().node), + view.entryNode() == null ? ImmutableGraphIndex.ENTRY_NODE_ABSENT : ordinalMapper.oldToNew(view.entryNode().node), layerInfo, ordinalMapper.maxOrdinal() + 1); var header = new Header(commonHeader, featureMap); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java index 3fb69d967..b445ceb20 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java @@ -50,7 +50,6 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; -import java.util.function.Function; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -95,7 +94,11 @@ private OnDiskGraphIndex(ReaderSupplier readerSupplier, Header header, long neig this.version = header.common.version; this.layerInfo = header.common.layerInfo; this.dimension = header.common.dimension; - this.entryNode = new NodeAtLevel(header.common.layerInfo.size() - 1, header.common.entryNode); + if (header.common.entryNode == ENTRY_NODE_ABSENT) { + this.entryNode = null; + } else { + this.entryNode = new NodeAtLevel(header.common.layerInfo.size() - 1, header.common.entryNode); + } this.idUpperBound = header.common.idUpperBound; this.features = header.features; this.neighborsOffset = neighborsOffset; @@ -435,7 +438,7 @@ public String 
toString() { @Override public int getMaxLevel() { - return entryNode.level; + return entryNode == null ? 0 : entryNode.level; } @Override diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java index eed263a09..248be18ce 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java @@ -543,4 +543,109 @@ public long ramBytesUsed() { throw new UnsupportedOperationException(); } } + + public static class EmptyGraphIndex implements ImmutableGraphIndex { + private final int dimension; + private final List layerInfo; + + public EmptyGraphIndex(int dimension, Random random) { + this.dimension = dimension; + this.layerInfo = List.of(new CommonHeader.LayerInfo(0, 0)); + } + + @Override + public long ramBytesUsed() { + return 0; + } + + @Override + public NodesIterator getNodes(int level) { + return NodesIterator.EMPTY_NODE_ITERATOR; + } + + @Override + public View getView() { + return new View() { + @Override + public void close() throws IOException { + } + + @Override + public int size() { + return 0; + } + + @Override + public void processNeighbors(int level, int node, ScoreFunction scoreFunction, IntMarker visited, + NeighborProcessor neighborProcessor) { + throw new IllegalStateException("Should not be called, empty graph has no nodes"); + } + + @Override + public Bits liveNodes() { + return Bits.ALL; + } + + @Override + public NodesIterator getNeighborsIterator(int level, int node) { + return NodesIterator.EMPTY_NODE_ITERATOR; + } + + @Override + public NodeAtLevel entryNode() { + return null; + } + + @Override + public boolean contains(int level, int node) { + return false; + } + }; + } + + @Override + public int maxDegree() { + return layerInfo.stream().mapToInt(li -> li.degree).max().orElseThrow(); + } + + @Override + public List maxDegrees() { + throw new 
NotImplementedException(); + } + + @Override + public int getDimension() { + return dimension; + } + + @Override + public void close() throws IOException { + + } + + @Override + public boolean isHierarchical() { + return false; + } + + @Override + public int getMaxLevel() { + return layerInfo.size() - 1; + } + + @Override + public int getDegree(int level) { + return layerInfo.get(level).degree; + } + + @Override + public double getAverageDegree(int level) { + throw new NotImplementedException(); + } + + @Override + public int size(int level) { + return layerInfo.get(level).size; + } + }; } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java index 59b248584..9cdb693a3 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java @@ -38,6 +38,7 @@ import static io.github.jbellis.jvector.TestUtil.assertGraphEquals; import static io.github.jbellis.jvector.graph.TestVectorGraph.createRandomFloatVectors; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) @@ -157,6 +158,34 @@ public void testSaveAndLoad() throws IOException { assertGraphEquals(graph, builder.graph); } + @Test + public void testSaveAndLoadEmptyGraph() throws IOException { + int dimension = randomIntBetween(2, 32); + var ravv = MockVectorValues.empty(dimension); + + Supplier newBuilder = () -> + new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f, true); + + var indexDataPath = testDirectory.resolve("index_builder_empty.data"); + var builder = newBuilder.get(); + + var graph = TestUtil.buildSequentially(builder, ravv); + + try (var out = TestUtil.openDataOutputStream(indexDataPath)) { + 
((OnHeapGraphIndex) graph).setAllMutationsCompleted(); + ((OnHeapGraphIndex) graph).save(out); + } + + builder = newBuilder.get(); + try(var readerSupplier = new SimpleMappedReader.Supplier(indexDataPath)) { + builder.load(readerSupplier.get()); + } + + assertEquals(ravv.size(), builder.graph.size(0)); + assertNull(builder.graph.entryNode()); + assertGraphEquals(graph, builder.graph); + } + // Because RandomAccessVectorValues is exposed in such a way that it allows for subsequent additions to the // vector source, we need to ensure that GraphIndexBuilder can handle this. @Test diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/MockVectorValues.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/MockVectorValues.java index a5a23b245..8bb2dfa01 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/MockVectorValues.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/MockVectorValues.java @@ -39,6 +39,10 @@ public static MockVectorValues fromValues(VectorFloat[] values) { return new MockVectorValues(values[0].length(), values); } + public static MockVectorValues empty(int dimension) { + return new MockVectorValues(dimension, new VectorFloat[0]); + } + MockVectorValues(int dimension, VectorFloat[] denseValues) { this.dimension = dimension; this.denseValues = denseValues; diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java index 29a8dca29..8930b720e 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java @@ -22,6 +22,7 @@ import io.github.jbellis.jvector.disk.SimpleMappedReader; import io.github.jbellis.jvector.graph.GraphIndexBuilder; import io.github.jbellis.jvector.graph.GraphSearcher; +import 
io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; import io.github.jbellis.jvector.graph.NodesIterator; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; @@ -60,11 +61,13 @@ public class TestOnDiskGraphIndex extends RandomizedTest { private TestUtil.FullyConnectedGraphIndex fullyConnectedGraph; private TestUtil.RandomlyConnectedGraphIndex randomlyConnectedGraph; + private ImmutableGraphIndex emptyGraph; @Before public void setup() throws IOException { fullyConnectedGraph = new TestUtil.FullyConnectedGraphIndex(0, 6); randomlyConnectedGraph = new TestUtil.RandomlyConnectedGraphIndex(10, 4, getRandom()); + emptyGraph = new TestUtil.EmptyGraphIndex(10, getRandom()); testDirectory = Files.createTempDirectory(this.getClass().getSimpleName()); } @@ -75,7 +78,7 @@ public void tearDown() { @Test public void testSimpleGraphs() throws Exception { - for (var graph : List.of(fullyConnectedGraph, randomlyConnectedGraph)) + for (var graph : List.of(fullyConnectedGraph, randomlyConnectedGraph, emptyGraph)) { var outputPath = testDirectory.resolve("test_graph_" + graph.getClass().getSimpleName()); var ravv = new TestVectorGraph.CircularFloatVectorValues(graph.size(0));