diff --git a/fastfilter/src/main/java/org/fastfilter/Filter.java b/fastfilter/src/main/java/org/fastfilter/Filter.java index 5eedbcb..05bfe43 100644 --- a/fastfilter/src/main/java/org/fastfilter/Filter.java +++ b/fastfilter/src/main/java/org/fastfilter/Filter.java @@ -1,5 +1,7 @@ package org.fastfilter; +import java.nio.ByteBuffer; + /** * An approximate membership filter. */ @@ -14,7 +16,7 @@ public interface Filter { boolean mayContain(long key); /** - * Get the number of bits in thhe filter. + * Get the number of bits in the filter. * * @return the number of bits */ @@ -65,4 +67,22 @@ default long cardinality() { return -1; } + /** + * Get the serialized size of the filter. + * + * @return the size in bytes + */ + default int getSerializedSize() { + return -1; + } + + /** + * Serializes the filter state into the provided {@code ByteBuffer}. + * + * @param buffer the byte buffer where the serialized state of the filter will be written + * @throws UnsupportedOperationException if the operation is not supported by the filter implementation + */ + default void serialize(ByteBuffer buffer) { + throw new UnsupportedOperationException(); + } } diff --git a/fastfilter/src/main/java/org/fastfilter/utils/Hash.java b/fastfilter/src/main/java/org/fastfilter/utils/Hash.java index 6e6b02f..57fc3fe 100644 --- a/fastfilter/src/main/java/org/fastfilter/utils/Hash.java +++ b/fastfilter/src/main/java/org/fastfilter/utils/Hash.java @@ -3,7 +3,6 @@ import java.util.Random; public class Hash { - private static Random random = new Random(); public static void setSeed(long seed) { diff --git a/fastfilter/src/main/java/org/fastfilter/xor/Xor16.java b/fastfilter/src/main/java/org/fastfilter/xor/Xor16.java index bca40a6..8cdc4d8 100644 --- a/fastfilter/src/main/java/org/fastfilter/xor/Xor16.java +++ b/fastfilter/src/main/java/org/fastfilter/xor/Xor16.java @@ -1,5 +1,7 @@ package org.fastfilter.xor; +import java.nio.ByteBuffer; + import org.fastfilter.Filter; import org.fastfilter.utils.Hash; @@ -143,4 +145,55 @@ private int fingerprint(long hash) { return (int) (hash & ((1 << BITS_PER_FINGERPRINT) - 1)); } + private Xor16(int blockLength, int bitCount, long seed, short[] fingerprints) { + this.blockLength = blockLength; + this.bitCount = bitCount; + this.seed = seed; + this.fingerprints = fingerprints; + } + + @Override + public int getSerializedSize() { + return Integer.BYTES + Long.BYTES + Integer.BYTES + fingerprints.length * Short.BYTES; + } + + @Override + public void serialize(ByteBuffer buffer) { + if (buffer.remaining() < getSerializedSize()) { + throw new IllegalArgumentException("Buffer too small"); + } + + buffer.putInt(blockLength); + buffer.putLong(seed); + buffer.putInt(fingerprints.length); + for (final short fp : fingerprints) { + buffer.putShort(fp); + } + } + + public static Xor16 deserialize(ByteBuffer buffer) { + // Check minimum size for header (1 int + 1 long + 1 int for length) + if (buffer.remaining() < Integer.BYTES + Long.BYTES + Integer.BYTES) { + throw new IllegalArgumentException("Buffer too small"); + } + + final int blockLength = buffer.getInt(); + final long seed = buffer.getLong(); + + final int len = buffer.getInt(); + + // Check if buffer has enough bytes for all fingerprints + if (buffer.remaining() < len * Short.BYTES) { + throw new IllegalArgumentException("Buffer too small"); + } + + final short[] fingerprints = new short[len]; + for (int i = 0; i < len; i++) { + fingerprints[i] = buffer.getShort(); + } + + final int bitCount = len * BITS_PER_FINGERPRINT; + + return new Xor16(blockLength, bitCount, seed, fingerprints); + } } diff --git a/fastfilter/src/main/java/org/fastfilter/xor/Xor8.java b/fastfilter/src/main/java/org/fastfilter/xor/Xor8.java index 86ac870..bb3b5ff 100644 --- a/fastfilter/src/main/java/org/fastfilter/xor/Xor8.java +++ b/fastfilter/src/main/java/org/fastfilter/xor/Xor8.java @@ -1,6 +1,7 @@ package org.fastfilter.xor; import java.io.*; +import java.nio.ByteBuffer; import org.fastfilter.Filter; import org.fastfilter.utils.Hash; @@ -187,4 +188,51 @@ public Xor8(InputStream in) { } } + private Xor8(int size, long seed, byte[] fingerprints) { + this.size = size; + this.arrayLength = getArrayLength(size); + this.bitCount = arrayLength * BITS_PER_FINGERPRINT; + this.blockLength = arrayLength / HASHES; + this.seed = seed; + this.fingerprints = fingerprints; + } + + @Override + public int getSerializedSize() { + return Integer.BYTES + Long.BYTES + Integer.BYTES + fingerprints.length * Byte.BYTES; + } + + @Override + public void serialize(ByteBuffer buffer) { + if (buffer.remaining() < getSerializedSize()) { + throw new IllegalArgumentException("Buffer too small"); + } + + buffer.putInt(size); + buffer.putLong(seed); + buffer.putInt(fingerprints.length); + buffer.put(fingerprints); + } + + public static Xor8 deserialize(ByteBuffer buffer) { + // Check minimum size for header (1 int + 1 long + 1 int for length) + if (buffer.remaining() < Integer.BYTES + Long.BYTES + Integer.BYTES) { + throw new IllegalArgumentException("Buffer too small"); + } + + final int size = buffer.getInt(); + final long seed = buffer.getLong(); + + final int len = buffer.getInt(); + + // Check if buffer has enough bytes for all fingerprints + if (buffer.remaining() < len * Byte.BYTES) { + throw new IllegalArgumentException("Buffer too small"); + } + + final byte[] fingerprints = new byte[len]; + buffer.get(fingerprints); + + return new Xor8(size, seed, fingerprints); + } } diff --git a/fastfilter/src/main/java/org/fastfilter/xor/XorBinaryFuse16.java b/fastfilter/src/main/java/org/fastfilter/xor/XorBinaryFuse16.java index aad3c77..db49863 100644 --- a/fastfilter/src/main/java/org/fastfilter/xor/XorBinaryFuse16.java +++ b/fastfilter/src/main/java/org/fastfilter/xor/XorBinaryFuse16.java @@ -1,7 +1,7 @@ package org.fastfilter.xor; +import java.nio.ByteBuffer; import java.util.Arrays; - import org.fastfilter.Filter; import org.fastfilter.utils.Hash; @@ -20,19 +20,25 @@ public class XorBinaryFuse16 implements Filter { private final short[] fingerprints; private long seed; - public XorBinaryFuse16(int segmentCount, int segmentLength) { + private XorBinaryFuse16(int segmentCount, int segmentLength, long seed, short[] fingerprints) { if (segmentLength < 0 || Integer.bitCount(segmentLength) != 1) { throw new IllegalArgumentException("Segment length needs to be a power of 2, is " + segmentLength); } if (segmentCount <= 0) { throw new IllegalArgumentException("Illegal segment count: " + segmentCount); } - this.segmentLength = segmentLength; + this.segmentCount = segmentCount; - this.segmentLengthMask = segmentLength - 1; this.segmentCountLength = segmentCount * segmentLength; - this.arrayLength = (segmentCount + ARITY - 1) * segmentLength; - this.fingerprints = new short[arrayLength]; + this.segmentLength = segmentLength; + this.segmentLengthMask = segmentLength - 1; + this.arrayLength = fingerprints.length; + this.fingerprints = fingerprints; + this.seed = seed; + } + + public XorBinaryFuse16(int segmentCount, int segmentLength) { + this(segmentCount, segmentLength, 0L, new short[(segmentCount + ARITY - 1) * segmentLength]); } public long getBitCount() { @@ -202,14 +208,13 @@ private void addAll(long[] keys) { // if construction doesn't succeed eventually, // then there is likely a problem with the hash function // let us not crash the system: - for(int i = 0; i < fingerprints.length; i++) { - fingerprints[i] = (short)0xFFFF; - } + Arrays.fill(fingerprints, (short) 0xFFFF); return; } - // use a new random numbers + // use a new random number seed = Hash.randomSeed(); } + alone = null; t2count = null; t2hash = null; @@ -261,4 +266,51 @@ private short fingerprint(long hash) { return (short) hash; } -} \ No newline at end of file + @Override + public int getSerializedSize() { + return 2 * Integer.BYTES + Long.BYTES + Integer.BYTES + fingerprints.length * Short.BYTES; + } + + @Override + public void serialize(ByteBuffer buffer) { + if (buffer.remaining() < getSerializedSize()) { + throw new IllegalArgumentException("Buffer too small"); + } + + buffer.putInt(segmentLength); + buffer.putInt(segmentCountLength); + buffer.putLong(seed); + buffer.putInt(fingerprints.length); + for (final short fp : fingerprints) { + buffer.putShort(fp); + } + } + + public static XorBinaryFuse16 deserialize(ByteBuffer buffer) { + // Check minimum size for header (2 ints + 1 long + 1 int for length) + if (buffer.remaining() < 2 * Integer.BYTES + Long.BYTES + Integer.BYTES) { + throw new IllegalArgumentException("Buffer too small"); + } + + final int segmentLength = buffer.getInt(); + final int segmentCountLength = buffer.getInt(); + final long seed = buffer.getLong(); + + final int len = buffer.getInt(); + + // Check if buffer has enough bytes for all fingerprints + if (buffer.remaining() < len * Short.BYTES) { + throw new IllegalArgumentException("Buffer too small"); + } + + final short[] fingerprints = new short[len]; + for (int i = 0; i < len; i++) { + fingerprints[i] = buffer.getShort(); + } + + // Calculate segmentCount from segmentCountLength and segmentLength + final int segmentCount = segmentCountLength / segmentLength; + + return new XorBinaryFuse16(segmentCount, segmentLength, seed, fingerprints); + } +} diff --git a/fastfilter/src/main/java/org/fastfilter/xor/XorBinaryFuse32.java b/fastfilter/src/main/java/org/fastfilter/xor/XorBinaryFuse32.java index d3f6125..760fc3f 100644 --- a/fastfilter/src/main/java/org/fastfilter/xor/XorBinaryFuse32.java +++ b/fastfilter/src/main/java/org/fastfilter/xor/XorBinaryFuse32.java @@ -1,5 +1,6 @@ package org.fastfilter.xor; +import java.nio.ByteBuffer; import java.util.Arrays; import org.fastfilter.Filter; @@ -20,19 +21,25 @@ public class XorBinaryFuse32 implements Filter { private final int[] fingerprints; private long seed; - public XorBinaryFuse32(int segmentCount, int segmentLength) { + private XorBinaryFuse32(int segmentCount, int segmentLength, long seed, int[] fingerprints) { if (segmentLength < 0 || Integer.bitCount(segmentLength) != 1) { throw new IllegalArgumentException("Segment length needs to be a power of 2, is " + segmentLength); } if (segmentCount <= 0) { throw new IllegalArgumentException("Illegal segment count: " + segmentCount); } - this.segmentLength = segmentLength; + this.segmentCount = segmentCount; - this.segmentLengthMask = segmentLength - 1; this.segmentCountLength = segmentCount * segmentLength; - this.arrayLength = (segmentCount + ARITY - 1) * segmentLength; - this.fingerprints = new int[arrayLength]; + this.segmentLength = segmentLength; + this.segmentLengthMask = segmentLength - 1; + this.arrayLength = fingerprints.length; + this.fingerprints = fingerprints; + this.seed = seed; + } + + public XorBinaryFuse32(int segmentCount, int segmentLength) { + this(segmentCount, segmentLength, 0L, new int[(segmentCount + ARITY - 1) * segmentLength]); } public long getBitCount() { @@ -261,4 +268,51 @@ private int fingerprint(long hash) { return (int) (hash ^ (hash >>> 32)); } + @Override + public int getSerializedSize() { + return 2 * Integer.BYTES + Long.BYTES + Integer.BYTES + fingerprints.length * Integer.BYTES; + } + + @Override + public void serialize(ByteBuffer buffer) { + if (buffer.remaining() < getSerializedSize()) { + throw new IllegalArgumentException("Buffer too small"); + } + + buffer.putInt(segmentLength); + buffer.putInt(segmentCountLength); + buffer.putLong(seed); + buffer.putInt(fingerprints.length); + for (final int fp : fingerprints) { + buffer.putInt(fp); + } + } + + public static XorBinaryFuse32 deserialize(ByteBuffer buffer) { + // Check minimum size for header (2 ints + 1 long + 1 int for length) + if (buffer.remaining() < 2 * Integer.BYTES + Long.BYTES + Integer.BYTES) { + throw new IllegalArgumentException("Buffer too small"); + } + + final int segmentLength = buffer.getInt(); + final int segmentCountLength = buffer.getInt(); + final long seed = buffer.getLong(); + + final int len = buffer.getInt(); + + // Check if buffer has enough bytes for all fingerprints + if (buffer.remaining() < len * Integer.BYTES) { + throw new IllegalArgumentException("Buffer too small"); + } + + final int[] fingerprints = new int[len]; + for (int i = 0; i < len; i++) { + fingerprints[i] = buffer.getInt(); + } + + // Calculate segmentCount from segmentCountLength and segmentLength + final int segmentCount = segmentCountLength / segmentLength; + + return new XorBinaryFuse32(segmentCount, segmentLength, seed, fingerprints); + } } diff --git a/fastfilter/src/main/java/org/fastfilter/xor/XorBinaryFuse8.java b/fastfilter/src/main/java/org/fastfilter/xor/XorBinaryFuse8.java index dfd5f45..ea16611 100644 --- a/fastfilter/src/main/java/org/fastfilter/xor/XorBinaryFuse8.java +++ b/fastfilter/src/main/java/org/fastfilter/xor/XorBinaryFuse8.java @@ -1,5 +1,6 @@ package org.fastfilter.xor; +import java.nio.ByteBuffer; import java.util.Arrays; import org.fastfilter.Filter; @@ -20,19 +21,25 @@ public class XorBinaryFuse8 implements Filter { private final byte[] fingerprints; private long seed; - public XorBinaryFuse8(int segmentCount, int segmentLength) { + private XorBinaryFuse8(int segmentCount, int segmentLength, long seed, byte[] fingerprints) { if (segmentLength < 0 || Integer.bitCount(segmentLength) != 1) { throw new IllegalArgumentException("Segment length needs to be a power of 2, is " + segmentLength); } if (segmentCount <= 0) { throw new IllegalArgumentException("Illegal segment count: " + segmentCount); } - this.segmentLength = segmentLength; + this.segmentCount = segmentCount; - this.segmentLengthMask = segmentLength - 1; this.segmentCountLength = segmentCount * segmentLength; - this.arrayLength = (segmentCount + ARITY - 1) * segmentLength; - this.fingerprints = new byte[arrayLength]; + this.segmentLength = segmentLength; + this.segmentLengthMask = segmentLength - 1; + this.arrayLength = fingerprints.length; + this.fingerprints = fingerprints; + this.seed = seed; + } + + public XorBinaryFuse8(int segmentCount, int segmentLength) { + this(segmentCount, segmentLength, 0L, new byte[(segmentCount + ARITY - 1) * segmentLength]); } public long getBitCount() { @@ -261,4 +268,47 @@ private byte fingerprint(long hash) { return (byte) hash; } + @Override + public int getSerializedSize() { + return 2 * Integer.BYTES + Long.BYTES + Integer.BYTES + fingerprints.length * Byte.BYTES; + } + + @Override + public void serialize(ByteBuffer buffer) { + if (buffer.remaining() < getSerializedSize()) { + throw new IllegalArgumentException("Buffer too small"); + } + + buffer.putInt(segmentLength); + buffer.putInt(segmentCountLength); + buffer.putLong(seed); + buffer.putInt(fingerprints.length); + buffer.put(fingerprints); + } + + public static XorBinaryFuse8 deserialize(ByteBuffer buffer) { + // Check minimum size for header (2 ints + 1 long + 1 int for length) + if (buffer.remaining() < 2 * Integer.BYTES + Long.BYTES + Integer.BYTES) { + throw new IllegalArgumentException("Buffer too small"); + } + + final int segmentLength = buffer.getInt(); + final int segmentCountLength = buffer.getInt(); + final long seed = buffer.getLong(); + + final int len = buffer.getInt(); + + // Check if buffer has enough bytes for all fingerprints + if (buffer.remaining() < len * Byte.BYTES) { + throw new IllegalArgumentException("Buffer too small"); + } + + final byte[] fingerprints = new byte[len]; + buffer.get(fingerprints); + + // Calculate segmentCount from segmentCountLength and segmentLength + final int segmentCount = segmentCountLength / segmentLength; + + return new XorBinaryFuse8(segmentCount, segmentLength, seed, fingerprints); + } } diff --git a/fastfilter/src/test/java/org/fastfilter/xor/SerializationTest.java b/fastfilter/src/test/java/org/fastfilter/xor/SerializationTest.java new file mode 100644 index 0000000..df6a204 --- /dev/null +++ b/fastfilter/src/test/java/org/fastfilter/xor/SerializationTest.java @@ -0,0 +1,298 @@ +package org.fastfilter.xor; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.function.Function; +import org.fastfilter.Filter; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class SerializationTest { + + private final String filterName; + private final Function constructor; + private final Function deserializer; + + public SerializationTest(String filterName, + Function constructor, + Function deserializer) { + this.filterName = filterName; + this.constructor = constructor; + this.deserializer = deserializer; + } + + @Parameters(name = "{0}") + public static List filters() { + return List.of( + new Object[] {"Xor8", (Function) Xor8::construct, + (Function) Xor8::deserialize}, + new Object[] {"Xor16", (Function) Xor16::construct, + (Function) Xor16::deserialize}, + new Object[] {"XorBinaryFuse8", (Function) XorBinaryFuse8::construct, + (Function) XorBinaryFuse8::deserialize}, + new Object[] {"XorBinaryFuse16", (Function) XorBinaryFuse16::construct, + (Function) XorBinaryFuse16::deserialize}, + new Object[] {"XorBinaryFuse32", (Function) XorBinaryFuse32::construct, + (Function) XorBinaryFuse32::deserialize} + ); + } + + @Test + public void shouldSerializeAndDeserializeSmallFilter() { + // Arrange + final var keys = new long[]{1L, 2L, 3L, 4L, 5L}; + final var originalFilter = constructor.apply(keys); + final var buffer = ByteBuffer.allocate(originalFilter.getSerializedSize()); + + // Act + originalFilter.serialize(buffer); + buffer.flip(); + final var deserializedFilter = deserializer.apply(buffer); + + // Assert + for (final long key : keys) { + assertTrue("Key " + key + " should be present in deserialized " + filterName + " filter", + deserializedFilter.mayContain(key)); + } + } + + @Test + public void shouldSerializeAndDeserializeMediumFilter() { + // Arrange + final var keys = new long[]{100L, 200L, 300L, 400L, 500L, 600L, 700L, 800L, 900L, 1000L}; + final var originalFilter = constructor.apply(keys); + final var buffer = ByteBuffer.allocate(originalFilter.getSerializedSize()); + + // Act + originalFilter.serialize(buffer); + buffer.flip(); + final var deserializedFilter = deserializer.apply(buffer); + + // Assert + for (final long key : keys) { + assertTrue("Key " + key + " should be present in deserialized " + filterName + " filter", + deserializedFilter.mayContain(key)); + } + assertFalse("Key 50L should not be in " + filterName + " filter", deserializedFilter.mayContain(50L)); + assertFalse("Key 1500L should not be in " + filterName + " filter", deserializedFilter.mayContain(1500L)); + } + + @Test + public void shouldSerializeAndDeserializeLargeFilter() { + // Arrange + final int size = 10000; + final var keys = new long[size]; + for (int i = 0; i < size; i++) { + keys[i] = i * 100L; + } + final var originalFilter = constructor.apply(keys); + final var buffer = ByteBuffer.allocate(originalFilter.getSerializedSize()); + + // Act + originalFilter.serialize(buffer); + buffer.flip(); + final var deserializedFilter = deserializer.apply(buffer); + + // Assert + for (int i = 0; i < size; i++) { + final long key = i * 100L; + assertTrue("Key " + key + " should be present in deserialized " + filterName + " filter", + deserializedFilter.mayContain(key)); + } + // Test some keys that should not be in the filter + assertFalse("Key 1L should not be in filter", deserializedFilter.mayContain(1L)); + assertFalse("Key 50L should not be in filter", deserializedFilter.mayContain(50L)); + assertFalse("Key 99L should not be in filter", deserializedFilter.mayContain(99L)); + } + + @Test + public void shouldPreserveFilterCharacteristicsAfterDeserialization() { + // Arrange + final var keys = new long[]{1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L}; + final var originalFilter = constructor.apply(keys); + final var buffer = ByteBuffer.allocate(originalFilter.getSerializedSize()); + + // Act + originalFilter.serialize(buffer); + buffer.flip(); + final var deserializedFilter = deserializer.apply(buffer); + + // Assert + assertEquals("Bit count should be preserved for " + filterName, + originalFilter.getBitCount(), deserializedFilter.getBitCount()); + assertEquals("Serialized size should be preserved for " + filterName, + originalFilter.getSerializedSize(), deserializedFilter.getSerializedSize()); + } + + @Test + public void shouldHandleMultipleSerializationRounds() { + // Arrange + final var keys = new long[]{10L, 20L, 30L, 40L, 50L}; + final var originalFilter = constructor.apply(keys); + final var buffer1 = ByteBuffer.allocate(originalFilter.getSerializedSize()); + + // Act - First round + originalFilter.serialize(buffer1); + buffer1.flip(); + final var filter1 = deserializer.apply(buffer1); + + // Act - Second round + final var buffer2 = ByteBuffer.allocate(filter1.getSerializedSize()); + filter1.serialize(buffer2); + buffer2.flip(); + final var filter2 = deserializer.apply(buffer2); + + // Assert + for (final long key : keys) { + assertTrue("Key " + key + " should be present after first deserialization of " + filterName, + filter1.mayContain(key)); + assertTrue("Key " + key + " should be present after second deserialization of " + filterName, + filter2.mayContain(key)); + } + } + + @Test + public void shouldThrowExceptionWhenSerializeBufferTooSmall() { + // Arrange + final var keys = new long[]{1L, 2L, 3L, 4L, 5L}; + final var filter = constructor.apply(keys); + final var smallBuffer = ByteBuffer.allocate(filter.getSerializedSize() - 1); + + // Act & Assert + try { + filter.serialize(smallBuffer); + fail("Should have thrown IllegalArgumentException for buffer too small"); + } catch (IllegalArgumentException e) { + assertEquals("Buffer too small", e.getMessage()); + } + } + + @Test + public void shouldThrowExceptionWhenDeserializeBufferTooSmall() { + // Arrange + final var tooSmallBuffer = ByteBuffer.allocate(10); + + // Act & Assert + try { + deserializer.apply(tooSmallBuffer); + fail("Should have thrown IllegalArgumentException for buffer too small"); + } catch (IllegalArgumentException e) { + assertEquals("Buffer too small", e.getMessage()); + } + } + + @Test + public void shouldHandleFilterWithSequentialKeys() { + // Arrange + final int size = 1000; + final var keys = new long[size]; + for (int i = 0; i < size; i++) { + keys[i] = i; + } + final var originalFilter = constructor.apply(keys); + final var buffer = ByteBuffer.allocate(originalFilter.getSerializedSize()); + + // Act + originalFilter.serialize(buffer); + buffer.flip(); + final var deserializedFilter = deserializer.apply(buffer); + + // Assert + for (int i = 0; i < size; i++) { + assertTrue("Sequential key " + i + " should be present in " + filterName, + deserializedFilter.mayContain(i)); + } + assertFalse("Key outside range should not be in " + filterName + " filter", + deserializedFilter.mayContain(size + 1000)); + } + + @Test + public void shouldHandleFilterWithRandomLargeKeys() { + // Arrange + final var keys = new long[]{ + Long.MAX_VALUE - 1, + Long.MAX_VALUE - 100, + Long.MAX_VALUE - 1000, + Long.MAX_VALUE / 2, + Long.MAX_VALUE / 3 + }; + final var originalFilter = constructor.apply(keys); + final var buffer = ByteBuffer.allocate(originalFilter.getSerializedSize()); + + // Act + originalFilter.serialize(buffer); + buffer.flip(); + final var deserializedFilter = deserializer.apply(buffer); + + // Assert + for (final long key : keys) { + assertTrue("Large key " + key + " should be present in " + filterName, + deserializedFilter.mayContain(key)); + } + } + + @Test + public void shouldCorrectlyCalculateSerializedSize() { + // Arrange + final var keys = new long[]{1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L}; + final var filter = constructor.apply(keys); + final int expectedSizeInBytes = filter.getSerializedSize(); + final var buffer = ByteBuffer.allocate(expectedSizeInBytes); + + // Act + filter.serialize(buffer); + + // Assert + assertEquals("Buffer position should equal serialized size for " + filterName, + expectedSizeInBytes, buffer.position()); + assertEquals("Buffer should have no remaining space for " + filterName, + 0, buffer.remaining()); + } + + @Test + public void shouldHandleExactBufferSize() { + // Arrange + final var keys = new long[]{100L, 200L, 300L}; + final var filter = constructor.apply(keys); + final var exactBuffer = ByteBuffer.allocate(filter.getSerializedSize()); + + // Act + filter.serialize(exactBuffer); + exactBuffer.flip(); + final var deserializedFilter = deserializer.apply(exactBuffer); + + // Assert + for (final long key : keys) { + assertTrue("Key " + key + " should be present with exact buffer in " + filterName, + deserializedFilter.mayContain(key)); + } + assertEquals("No bytes should remain in buffer for " + filterName, 0, exactBuffer.remaining()); + } + + @Test + public void shouldHandleLargerBufferThanNeeded() { + // Arrange + final var keys = new long[]{1L, 2L, 3L}; + final var filter = constructor.apply(keys); + final var largeBuffer = ByteBuffer.allocate(filter.getSerializedSize() + 1000); + + // Act + filter.serialize(largeBuffer); + largeBuffer.flip(); + final var deserializedFilter = deserializer.apply(largeBuffer); + + // Assert + for (final long key : keys) { + assertTrue("Key " + key + " should be present with larger buffer in " + filterName, + deserializedFilter.mayContain(key)); + } + } +}