diff --git a/README.md b/README.md index a11a5faf..85887cf1 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,11 @@ [![Maven Central](https://img.shields.io/maven-central/v/io.github.dfa1.vortex/vortex-reader.svg)](https://central.sonatype.com/artifact/io.github.dfa1.vortex/vortex-reader) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/license/Apache-2.0) +> **Alpha** — not production-ready. APIs will change without notice. + Pure-Java reader/writer for the [Vortex](https://github.com/vortex-data/vortex) columnar file format. 100% Java, no JNI, no `sun.misc.Unsafe`. Uses the FFM API (`MemorySegment`/`Arena`, Java 25+) -for zero-copy memory-mapped reads. Read benchmarks match or beat the Rust JNI on the workloads -tested (Apple M5, JDK 25); see [docs/explanation.md#benchmarks](docs/explanation.md#benchmarks). +for zero-copy memory-mapped reads. | Project | Language | Notes | |---------------------------------------------------------------------|----------|-----------------------------------------| @@ -49,12 +50,15 @@ try (VortexReader vf = VortexReader.open(Path.of("data/example.vortex")); } ``` -> **Lifecycle.** `Chunk` owns a confined `Arena` — close it (try-with-resources -> or `iter.forEachRemaining`) to release the decoded buffers. Full lifecycle -> rules: [docs/explanation.md#memory-model](docs/explanation.md#memory-model). +> **Lifecycle.** `ScanIterator` implements `Iterator` and `Chunk` implements +> `AutoCloseable`. Each chunk owns a confined `Arena`; closing it releases the +> decoded buffers. Calling `iter.next()` while a prior chunk is still open throws +> `IllegalStateException`. Use try-with-resources, or +> `iter.forEachRemaining(c -> ...)` which closes each chunk for you. See +> [docs/explanation.md#memory-model](docs/explanation.md#memory-model). -For more examples (writing, projection, filtering, custom encodings, CLI) see -the documentation below. 
+For more examples — writing, projection, filtering, custom encodings, and the CLI — +see the documentation below. ## Documentation diff --git a/TODO.md b/TODO.md index e0ec684a..da8cf356 100644 --- a/TODO.md +++ b/TODO.md @@ -248,17 +248,8 @@ See [docs/compatibility.md](docs/compatibility.md) for the full encoding support using the 5-symbol generator from `OhlcEncodingInspectionIntegrationTest#writeOhlcMultiSymbol` and assert the global-dict file is smaller than the per-chunk-dict baseline. -- [ ] **FSST symbol-table builder: port `fsst-rs` Algorithm 3** — - `FsstEncoding.Encoder` is a single-pass, bigram-only top-K table. Rust's - `fsst-rs` (used by `vortex-fsst`) implements **Algorithm 3 from the FSST - paper**: 5 generations of iterative training, symbols up to 8 bytes long, - Lossy Perfect Hash Table for O(1) symbol lookup during compression. On the - high-cardinality random ASCII benchmark - (`FileSizeComparisonIntegrationTest#highCardinalityUtf8_javaVsJni`) the gap - is Java 1.75× raw vs Rust 1.18× raw — purely encoder quality, the wire - format and decoder are unchanged. Estimate: ~1 week of work. - Reference: , - . +- [ ] **FSST in CASCADE_CODECS** — `FsstEncoding` exists but is not in the cascade; Rust uses FSST for + `store_and_fwd_flag`. Small gain on taxi (~0.1 MB). 
### `vortex.zstd` known limitations diff --git a/cli/src/main/java/io/github/dfa1/vortex/cli/tui/VortexInspectorTui.java b/cli/src/main/java/io/github/dfa1/vortex/cli/tui/VortexInspectorTui.java index a6c58272..9e282b03 100644 --- a/cli/src/main/java/io/github/dfa1/vortex/cli/tui/VortexInspectorTui.java +++ b/cli/src/main/java/io/github/dfa1/vortex/cli/tui/VortexInspectorTui.java @@ -571,11 +571,8 @@ private void runDictLoad(InspectorTree.Node dictNode) { try (java.lang.foreign.Arena arena = java.lang.foreign.Arena.ofConfined()) { int segIdx = values.segments().getFirst(); SegmentSpec spec = tree.segmentSpecs().get(segIdx); - java.lang.foreign.MemorySegment seg = handle.slice(spec.offset(), spec.length()); io.github.dfa1.vortex.core.array.Array arr = - new io.github.dfa1.vortex.encoding.FlatSegmentDecoder(handle.registry()) - .decode(seg, handle.footer().arraySpecs(), - dtype, values.rowCount(), arena); + handle.decodeFlatSegment(spec, dtype, values.rowCount(), arena); int n = (int) Math.min(arr.length(), DATA_PREVIEW_ROWS); List out = new ArrayList<>(n); for (int i = 0; i < n; i++) { diff --git a/core/pom.xml b/core/pom.xml index b4b90110..03d6e9dd 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -47,6 +47,25 @@ + + + + + org.apache.maven.plugins + maven-jar-plugin + + + publish-test-jar + + test-jar + + + + + + + + + io.github.dfa1.vortex + vortex-inspector + 0.6.0 + +``` + +`./mvnw -pl core,reader,inspector verify` builds the read-only artifact set +without the writer module on the classpath. None of the writer-side encoder +implementations are loaded; `ServiceLoader` resolves only the +standalone decoders in `reader`, falling back to the bifunctional `Encoding` +implementations in `core` for encoding families not yet lifted (ADR 0001 +Phases 2–3). 
+ ## Known wire-format gaps | Item | Introduced | Java status | diff --git a/docs/explanation.md b/docs/explanation.md index 19284a08..fc3d4e9d 100644 --- a/docs/explanation.md +++ b/docs/explanation.md @@ -326,334 +326,21 @@ At decode time the registry maps the ID string from the Layout node to the right Custom encodings can be added at build time: `Registry.builder().register(myEncoding).build()`. Files with unrecognised IDs throw `VortexException` unless the builder enabled `allowUnknown()`. -## Why cascading compression - -Vortex stores each column as a tree of encodings. The leaves are raw memory -segments, the inner nodes describe how those bytes turn back into values. -Without cascading, the writer picks one encoding from a static list and stops. -With cascading, the writer samples the data, lets candidate encodings expose -their open child slots, and recursively picks the best inner encoding for -every child. That recursion is what turns a per-encoding sales-pitch into -real compression. - -Six representative columns, both paths. The same scenarios run as -regression tests so the ratios below stay anchored: - -``` -./mvnw verify -pl integration -am \ - -Dit.test=CompressionShowcaseIntegrationTest -``` - -Numbers below: 1 000 000 rows per column, JDK 25, vortex-java HEAD. - -### The encodings in play, and what they're good for - -Before reading the table, you need to know what the writer's pick means. The -labels in the size table aren't marketing terms — each one is a concrete -encoding with a sweet-spot data shape. - -| Encoding | What it does | "Friendly" data shape | -|------------------|--------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------| -| **Primitive** | Raw little-endian bytes. The baseline. | Truly random data — no overhead, no compression. | -| **FoR** | Frame-of-reference. 
Subtract `min(col)` from every value; store the minimum + the residuals. | Bounded range: timestamps, monotonic IDs, anything where `max - min` is small. | -| **Bitpacked** | Packs N values into the smallest bit-width that fits the maximum value. | Small-range integers (e.g. FoR residuals, dict codes, small counts). | -| **Dict** | Build a values table; store one code per row pointing into it. | Low-cardinality strings or numbers. **Loses** above ~50 % distinct. | -| **ALP** | Adaptive Lossless float compression. Detects a per-column scale + exponent so that values become small integers, then stores the mantissa. | "Physical world" doubles — prices, sensor readings, percentages. **Loses** on truly random F64. | -| **RLE / RunEnd** | Replace each run of identical values with `(run_end_index, value)`. | Long runs of repeated values — status flags, partition IDs, slow-tick counters. | -| **VarBin** | Concatenate all bytes, store offsets per row. Arrow's classic string layout. | Strings of any cardinality — the safe default for Utf8. | -| **FSST** | Builds a per-column symbol table of the most common byte bigrams, then rewrites the strings as 1-byte codes (or escape + literal). | Short, repetitive-ish strings — log lines, identifiers, JSON keys. Truly random strings beat it. | -| **Constant** | Store the scalar once. Done. | All values identical. | - -So "ALP-friendly doubles" = doubles that can be written as `mantissa × 10^-exp` -with a small mantissa. "FoR-friendly" = bounded range. "RLE-friendly" = long -runs. The cascading compressor does the matchmaking; this table tells you -when each match is a win. 
- -### Headline table - -``` -dataset raw bytes no-cascade cascade(3) ratio --------------------------------------------------------------------------------- -monotonic-timestamps 8 000 000 8 000 664 2 501 832 3.20x -low-card-categorical 4 200 000 1 000 921 1 000 921 4.20x -random-doubles 8 000 000 12 751 764 8 000 672 1.00x -alp-friendly-doubles 8 000 000 8 000 748 1 256 208 6.37x -rle-int 4 000 000 4 000 656 1 604 2 493.77x -highcard-strings 6 000 000 17 977 750 10 551 040 0.57x -``` - -ratio = raw / cascade(3). - -Three patterns to notice before we dig in: - -- **Cascading is usually a strict improvement** — never bigger, often dramatically - smaller (RLE: 4 MB → 1.6 KB; ALP-friendly F64: 8 MB → 1.3 MB). -- **Without cascading, the writer can make the file *bigger* than raw.** Static - fallbacks pick the first encoding that "accepts" the dtype, even when it - fails to compress (random F64 → ALP adds overhead; high-cardinality strings - → Dict blows up the values vector). -- **The cascade isn't magic.** Truly random data ends up close to raw size - because there's no structure to exploit. - -### 1. Monotonic timestamps — FoR + Bitpacked - -Sensor / log streams produce strictly increasing UNIX timestamps. Each value -is only ~1 second above the previous one, so the *delta* needs ~1 bit, but -the *absolute* value is 64 bits. - -``` -dtype: I64 -sample data: [1767225600, 1767225601, 1767225602, 1767225603, …] -``` - -**Without cascading** — `Primitive` (raw 8 bytes/row) → 8 MB. - -``` - Primitive(I64) - │ - 8 000 000 bytes -``` - -**With cascading depth 3** — sample shows all-positive small deltas. FoR -subtracts the minimum so residuals fit in a tiny number of bits; the open -residual child cascades into Bitpacked, which packs `~20` bits per row -instead of 64. - -``` - FoR - ref = 1767225600 - │ - Bitpacked - bit_width = 20 - │ - ~2.5 MB packed -``` - -3.2× smaller than raw. 
The same shape covers row IDs, monotonic counters, -millisecond timestamps, anything with bounded local deltas. - -### 2. Low-cardinality categorical — Dict - -A column of ticker symbols repeats the same 5 distinct strings across a -million rows. - -``` -dtype: Utf8 -sample data: ["AAPL", "MSFT", "NVDA", "GOOGL", "AMZN", "AAPL", "MSFT", …] -``` - -**Without cascading** — the default codec list picks Dict before VarBin, so -even the no-cascade path catches this one. - -``` - Dict - / \ - Values Codes - (5 strings) (1 M × U8) - ~20 B ~1 MB -``` - -**With cascading depth 3** — same shape. Dict already wins; codes are tight -(1 byte each because cardinality ≤ 256). 4.2× smaller than raw with no -recursive work needed. - -The new cardinality gate (added in 0.6) only kicks in *above* 50 % distinct; -this column is at 0.0005 % distinct, far below the gate. - -### 3. Random doubles — Primitive (no compression possible) - -The worst case. Truly random F64 values have no exploitable structure. - -``` -dtype: F64 -sample data: [0.733, 0.642, 0.218, 0.875, 0.157, 0.488, 0.999, …] -``` - -**Without cascading** — the static list tries ALP first. ALP detects no -common scale factor, falls back to its uncompressed path, and the resulting -file is *bigger* than raw (12.7 MB vs 8 MB) thanks to per-row mantissa / -exponent bookkeeping. - -``` - ALP (degenerate) - / \ \ - Encoded Patch Patch - (no-op) idx values - 8 MB ~4.7 MB -``` - -**With cascading depth 3** — the cost-based selector measures ALP, sees it's -worse than primitive, and picks `Primitive` raw. - -``` - Primitive(F64) - │ - 8 000 000 bytes -``` - -Lesson: cascading is the cheapest insurance against the writer making the -file *larger*. Without it, "first match" can lose to raw bytes. - -### 4. Slowly-varying doubles — ALP + Bitpacked - -Stock prices, sensor readings, and most "physical world" doubles drift -slowly. They're representable as `mantissa × 10^-exp` with a small mantissa, -which is exactly what ALP is for. 
- -``` -dtype: F64 -sample data: [100.05, 100.04, 100.06, 100.07, 100.05, 100.03, …] -``` - -**Without cascading** — ALP picks the right shape, but its mantissa child -gets emitted as raw `Primitive(I64)` → ~8 MB. - -``` - ALP - e=2 f=1 - │ - Primitive(I64) - │ - 8 000 000 bytes -``` - -**With cascading depth 3** — the same ALP outer, but its mantissa child -cascades into FoR + Bitpacked. The mantissa range fits in ~10 bits. - -``` - ALP - e=2 f=1 - │ - FoR - │ - Bitpacked - bit_width = 10 - │ - ~1.25 MB -``` - -6.4× smaller than raw. This is where Vortex really shines vs Parquet's -fixed page-level codecs: nested arithmetic encodings stack cleanly. - -### 5. Run-encoded ints — RunEnd - -Long runs of the same value: status flags, monotonic counters that tick -slowly, partition IDs. - -``` -dtype: I32 -sample data: [1,1,1,…(10 000)…,1, 2,2,2,…(10 000)…,2, 3,3,…] -``` - -**Without cascading** — `Primitive` → 4 MB. - -``` - Primitive(I32) - │ - 4 000 000 bytes -``` - -**With cascading depth 3** — the run structure is detected, RunEnd encodes -each run as `(end_index, value)` and both children compress. - -``` - RunEnd - / \ - Run ends Values - (100 ints) (100 ints) - ~400 B ~400 B - ↓ ↓ - Bitpacked Bitpacked - (or FoR) (or FoR) -``` - -**2 493×** smaller than raw — the run is so long that the actual payload -collapses to ~1.6 KB. - -### 6. High-cardinality strings — Dict fail / FSST partial win - -A million all-distinct random 6-character strings — the kind of column that -turns into a tar pit for dictionary-style encodings. - -``` -dtype: Utf8 -sample data: ["wkzqof", "tdmgxh", "ablrpe", "yvcjsi", …] -``` - -**Without cascading** — Dict is the first acceptor for Utf8 in the default -codec list. It builds a dictionary nearly as big as the input plus a 4-byte -code per row. - -``` - Dict - / \ - Values Codes - (1 M strings) (1 M × U32) - ~10 MB ~4 MB - = 18 MB (3× raw!) 
-``` - -**With cascading depth 3** — the new (0.6) cardinality gate in `DictEncoding` -detects > 50 % distinct on the sample, returns `notApplicable`, and the -cascade rotates to FSST. - -``` - FSST - ╭────────┼────────╮ - Symbol Symbol Compressed - table lengths payload - (~2 KB) (~255 B) (~5 MB) - │ - ↓ cascade - uncompressed_lens codes_offsets - (Constant: ~4 B) (FoR+Bitpacked) - ~5 MB -``` - -Result: 10.5 MB — still larger than the 6 MB raw, but **42 % smaller than -no-cascade**. Truly random short strings are hard for any encoder (Rust hits -the same wall on this input). Java's FSST symbol-table builder is also less -aggressive than Rust's for now — see `TODO.md`. - -### Takeaways - -- **Always use `WriteOptions.cascading(...)` unless you have a reason not - to.** The default is `cascading(0)` for legacy compatibility; we'll - probably flip the default in a future major. -- **Cascade depth 3** is the sweet spot in practice. Deeper costs more - encode CPU for tiny diminishing returns; shallower misses key combos - like ALP→FoR→Bitpacked. -- **The encoding tree is the API.** If you `vortex inspect ` you'll - see the exact structure of each column, with sizes per node. No black - box. -- **Codec choice is data-dependent.** No single encoding is "best". The - point of cascading is to let the writer admit when it's wrong and try - again. - ## Benchmarks -JMH throughput (ops/s = full-file scans per second). Higher is better. - -**Apples-to-apples.** The Java and JNI numbers in the OHLC and big-file tables -below come from reading the **same on-disk file**, written once by `vortex-jni` -(Rust-chosen encodings) and opened by both decoders. Differences are pure -decoder cost — same bytes in, same row count out. 
See -[`RustVsJavaReadBenchmark.@Setup`](../performance/src/main/java/io/github/dfa1/vortex/performance/RustVsJavaReadBenchmark.java) -— `sharedBenchFile` is written by `writeJni(...)` and shared across every -`jniRead*` and `javaRead*` method (`javaReadCascading` is the one exception: -it reads a Java-written file with `WriteOptions.cascading(3)`). +JMH throughput (ops/s = full-file scans per second). Higher is better. Numbers +re-measured 2026-06-08 against commit `051a794`. -**Environment:** Apple M5, OpenJDK 25, 3 warmup × 3 s, 5 measurement × 5 s, fork 1. -Numbers below re-measured 2026-06-11. +**Environment:** Apple M5, OpenJDK 25, 5 warmup × 3 s, 10 measurement × 5 s, fork 1. ### OHLC read — 10 M rows, 58.9 MB (Rust-written file, single-column projection) -| Benchmark | Java (ops/s) | JNI/Rust (ops/s) | Java speedup | -|-----------------------------|----------------|------------------|--------------| -| close (F64/ALP) | 68.8 ± 0.2 | 50.2 ± 1.2 | **1.4×** | -| volume (I64/bitpacked) | 118.9 ± 0.9 | 50.1 ± 2.6 | **2.4×** | -| symbol (varbin) | 104.8 ± 5.1 | 9.7 ± 0.5 | **10.8×** | -| cascading (depth 3, volume) | 86.7 ± 2.5 | n/a | — | +| Benchmark | Java (ops/s) | JNI/Rust (ops/s) | Java speedup | +|---------------------|---------------|------------------|--------------| +| close (F64/ALP) | 61.0 ± 5.8 | 47.9 ± 0.7 | **1.3×** | +| volume (I64/bitpacked) | 104.8 ± 5.1 | 48.4 ± 1.7 | **2.2×** | +| symbol (varbin) | 97.8 ± 1.8 | 9.2 ± 0.4 | **10.6×** | +| cascading (depth 3, volume) | 80.9 ± 1.2 | n/a | — | ### OHLC write — 10 M rows diff --git a/integration/src/test/java/io/github/dfa1/vortex/integration/AllowUnknownIntegrationTest.java b/integration/src/test/java/io/github/dfa1/vortex/integration/AllowUnknownIntegrationTest.java index b7fbcd20..70724981 100644 --- a/integration/src/test/java/io/github/dfa1/vortex/integration/AllowUnknownIntegrationTest.java +++ b/integration/src/test/java/io/github/dfa1/vortex/integration/AllowUnknownIntegrationTest.java @@ 
-7,7 +7,7 @@ import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.UnknownArray; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.reader.VortexReader; import org.apache.arrow.c.ArrowArray; import org.apache.arrow.c.ArrowSchema; @@ -82,7 +82,7 @@ void allowUnknown_emptyRegistry_allColumnsReturnUnknownArray(@TempDir Path tmp) var totalRows = new AtomicLong(); var allUnknown = new AtomicBoolean(true); var chunkCount = new AtomicLong(); - try (VortexReader vf = VortexReader.open(file, Registry.builder().allowUnknown().build()); + try (VortexReader vf = VortexReader.open(file, ReadRegistry.builder().allowUnknown().build()); var iter = vf.scan(io.github.dfa1.vortex.reader.ScanOptions.all())) { iter.forEachRemaining(c -> { totalRows.addAndGet(c.rowCount()); @@ -109,7 +109,7 @@ void strictMode_emptyRegistry_throwsVortexException(@TempDir Path tmp) throws IO // When / Then — strict mode throws rather than returning UnknownArray assertThatThrownBy(() -> { - try (VortexReader vf = VortexReader.open(file, Registry.empty()); + try (VortexReader vf = VortexReader.open(file, ReadRegistry.empty()); var iter = vf.scan(io.github.dfa1.vortex.reader.ScanOptions.all())) { iter.forEachRemaining(c -> {}); } @@ -126,7 +126,7 @@ void allowUnknown_loadAllRegistry_noUnknownArrayForSupportedEncodings(@TempDir P var chunkCount = new AtomicLong(); var anyUnknown = new AtomicBoolean(false); try (VortexReader vf = VortexReader.open(file, - Registry.builder().registerServiceLoaded().allowUnknown().build()); + ReadRegistry.builder().registerServiceLoaded().allowUnknown().build()); var iter = vf.scan(io.github.dfa1.vortex.reader.ScanOptions.all())) { iter.forEachRemaining(c -> { chunkCount.incrementAndGet(); diff --git a/integration/src/test/java/io/github/dfa1/vortex/integration/CompressionShowcaseIntegrationTest.java 
b/integration/src/test/java/io/github/dfa1/vortex/integration/CompressionShowcaseIntegrationTest.java deleted file mode 100644 index 2243fff9..00000000 --- a/integration/src/test/java/io/github/dfa1/vortex/integration/CompressionShowcaseIntegrationTest.java +++ /dev/null @@ -1,187 +0,0 @@ -package io.github.dfa1.vortex.integration; - -import io.github.dfa1.vortex.core.DType; -import io.github.dfa1.vortex.core.PType; -import io.github.dfa1.vortex.writer.VortexWriter; -import io.github.dfa1.vortex.writer.WriteOptions; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.TempDir; - -import java.io.IOException; -import java.nio.channels.FileChannel; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.StandardOpenOption; -import java.time.LocalDate; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Random; - -import static org.assertj.core.api.Assertions.assertThat; - -/// Regression guard for the cascading compressor's encoding choices. -/// -/// Anchors the size ratios documented in -/// `docs/explanation.md#why-cascading-compression`. If a refactor silently -/// swaps the encoding picked for a representative column shape (e.g. ALP -/// gives up on slowly-varying doubles, Dict eats high-cardinality Utf8), -/// the matching assertion here fails and the lookup table in the docs gets -/// re-grounded. -/// -/// Lower bounds are deliberately loose: catastrophic regressions only, -/// not 5 % drift. -class CompressionShowcaseIntegrationTest { - - private static final int N = 1_000_000; - - @Test - void monotonicTimestamps_compressVia_FoR_Bitpacked(@TempDir Path tmp) throws IOException { - // Given — UNIX seconds incrementing by 1 per row. Cascade should pick - // FoR (subtract base) then Bitpacked on the small residuals. 
- long base = LocalDate.of(2026, 1, 1).toEpochDay() * 86_400L; - long[] data = new long[N]; - for (int i = 0; i < N; i++) { - data[i] = base + i; - } - - long cascadeSize = writeCascade(tmp, "monotonic", new DType.Primitive(PType.I64, false), data); - long rawBytes = (long) N * 8; - - // Then — at least 2.5x smaller than raw I64. Real ratio on M5 is ~3.2x. - assertThat((double) rawBytes / cascadeSize).isGreaterThan(2.5); - } - - @Test - void lowCardCategorical_compressVia_Dict(@TempDir Path tmp) throws IOException { - // Given — 5 distinct ticker symbols cycled across all rows. Even with - // cascade off Dict wins (first acceptor for Utf8); cascade can't make - // it worse. - String[] tickers = {"AAPL", "MSFT", "NVDA", "GOOGL", "AMZN"}; - String[] data = new String[N]; - for (int i = 0; i < N; i++) { - data[i] = tickers[i % tickers.length]; - } - - long cascadeSize = writeCascade(tmp, "lowcard", new DType.Utf8(false), data); - long rawBytes = totalUtf8Bytes(data); - - // Then — at least 3x smaller. Real ratio ~4.2x. - assertThat((double) rawBytes / cascadeSize).isGreaterThan(3.0); - } - - @Test - void randomDoubles_fallbackTo_Primitive(@TempDir Path tmp) throws IOException { - // Given — uniform random F64. No structure; the cascade must measure - // alternatives, find none beats raw, and emit Primitive. - Random rng = new Random(42); - double[] data = new double[N]; - for (int i = 0; i < N; i++) { - data[i] = rng.nextDouble(); - } - - long cascadeSize = writeCascade(tmp, "random-f64", new DType.Primitive(PType.F64, false), data); - long rawBytes = (long) N * 8; - - // Then — within 5 % of raw. No compression possible but cascade must - // not make it bigger. - assertThat(cascadeSize).isLessThan((long) (rawBytes * 1.05)); - } - - @Test - void alpFriendlyDoubles_compressVia_ALP_FoR_Bitpacked(@TempDir Path tmp) throws IOException { - // Given — slowly varying prices around 100, two-decimal precision. 
- // ALP detects the scale factor; its mantissa child cascades to - // FoR + Bitpacked. - Random rng = new Random(42); - double[] data = new double[N]; - double price = 100.00; - for (int i = 0; i < N; i++) { - price += (rng.nextDouble() - 0.5) * 0.02; - data[i] = Math.round(price * 100.0) / 100.0; - } - - long cascadeSize = writeCascade(tmp, "alp", new DType.Primitive(PType.F64, false), data); - long rawBytes = (long) N * 8; - - // Then — at least 5x smaller. Real ratio ~6.4x. - assertThat((double) rawBytes / cascadeSize).isGreaterThan(5.0); - } - - @Test - void rleInt_compressVia_RunEnd(@TempDir Path tmp) throws IOException { - // Given — 10k-row runs of the same value. RunEnd reduces N rows to - // ~N/run_length pairs; both children compress further via Bitpacked. - int[] data = new int[N]; - int run = 10_000; - int v = 1; - for (int i = 0; i < N; i++) { - if (i % run == 0) { - v += 1; - } - data[i] = v; - } - - long cascadeSize = writeCascade(tmp, "rle", new DType.Primitive(PType.I32, false), data); - long rawBytes = (long) N * 4; - - // Then — at least 1000x smaller. Real ratio ~2493x. - assertThat((double) rawBytes / cascadeSize).isGreaterThan(1000.0); - } - - @Test - void highCardStrings_routeTo_FSST_via_dictGate(@TempDir Path tmp) throws IOException { - // Given — all-distinct random 6-char strings. Dict has > 50 % distinct - // on the sample so its cardinality gate returns notApplicable and the - // cascade rotates to FSST. - Random rng = new Random(42); - String[] data = new String[N]; - byte[] buf = new byte[6]; - for (int i = 0; i < N; i++) { - for (int k = 0; k < 6; k++) { - buf[k] = (byte) ('a' + rng.nextInt(26)); - } - data[i] = new String(buf); - } - - long noCascadeSize = writeNoCascade(tmp, "hicard-flat", new DType.Utf8(false), data); - long cascadeSize = writeCascade(tmp, "hicard-cas", new DType.Utf8(false), data); - - // Then — FSST cascade strictly beats the no-cascade Dict fallback. 
- // (Both lose to raw bytes on truly random short strings; that's - // fundamental, not a bug.) - assertThat(cascadeSize).isLessThan(noCascadeSize); - } - - // ── Helpers ─────────────────────────────────────────────────────────────── - - private static long writeCascade(Path dir, String name, DType dtype, Object data) throws IOException { - return write(dir.resolve(name + ".cas.vtx"), dtype, data, - WriteOptions.cascading(3).withGlobalDict(false)); - } - - private static long writeNoCascade(Path dir, String name, DType dtype, Object data) throws IOException { - return write(dir.resolve(name + ".flat.vtx"), dtype, data, - WriteOptions.defaults().withGlobalDict(false)); - } - - private static long write(Path file, DType dtype, Object data, WriteOptions opts) throws IOException { - DType.Struct schema = new DType.Struct(List.of("v"), List.of(dtype), false); - try (FileChannel ch = FileChannel.open(file, - StandardOpenOption.CREATE, StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING); - VortexWriter writer = VortexWriter.create(ch, schema, opts)) { - Map col = new HashMap<>(); - col.put("v", data); - writer.writeChunk(col); - } - return Files.size(file); - } - - private static long totalUtf8Bytes(String[] data) { - long total = 0; - for (String s : data) { - total += s.length(); - } - return total; - } -} diff --git a/integration/src/test/java/io/github/dfa1/vortex/integration/FileSizeComparisonIntegrationTest.java b/integration/src/test/java/io/github/dfa1/vortex/integration/FileSizeComparisonIntegrationTest.java index 67ed8531..3e1a5401 100644 --- a/integration/src/test/java/io/github/dfa1/vortex/integration/FileSizeComparisonIntegrationTest.java +++ b/integration/src/test/java/io/github/dfa1/vortex/integration/FileSizeComparisonIntegrationTest.java @@ -6,7 +6,7 @@ import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.array.LongArray; -import io.github.dfa1.vortex.encoding.Registry; +import 
io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.reader.VortexReader; import io.github.dfa1.vortex.writer.VortexWriter; import io.github.dfa1.vortex.writer.WriteOptions; @@ -201,7 +201,7 @@ void fileSizeComparison(@TempDir Path tmp) throws IOException { // Then — Java file is readable with correct row count var totalRows = new java.util.concurrent.atomic.AtomicLong(); - try (VortexReader reader = VortexReader.open(javaFile, Registry.loadAll()); + try (VortexReader reader = VortexReader.open(javaFile, ReadRegistry.loadAll()); var iter = reader.scan(io.github.dfa1.vortex.reader.ScanOptions.columns("volume"))) { iter.forEachRemaining(c -> totalRows.addAndGet(c.column("volume").length())); } @@ -231,7 +231,7 @@ void withZstd_smallerFile_and_readable(@TempDir Path tmp) throws IOException { // Then — Zstd file is readable with correct row count var totalRows = new java.util.concurrent.atomic.AtomicLong(); - try (VortexReader reader = VortexReader.open(withZstd, Registry.loadAll()); + try (VortexReader reader = VortexReader.open(withZstd, ReadRegistry.loadAll()); var iter = reader.scan(io.github.dfa1.vortex.reader.ScanOptions.columns("volume"))) { iter.forEachRemaining(c -> totalRows.addAndGet(c.column("volume").length())); } @@ -303,7 +303,7 @@ void highCardinalityUtf8_javaVsJni(@TempDir Path tmp) throws IOException { // Then — Java file is readable and row count matches var totalRows = new java.util.concurrent.atomic.AtomicLong(); - try (VortexReader reader = VortexReader.open(javaFile, Registry.loadAll()); + try (VortexReader reader = VortexReader.open(javaFile, ReadRegistry.loadAll()); var iter = reader.scan(io.github.dfa1.vortex.reader.ScanOptions.columns("s"))) { iter.forEachRemaining(c -> totalRows.addAndGet(c.column("s").length())); } diff --git a/integration/src/test/java/io/github/dfa1/vortex/integration/InspectForTest.java b/integration/src/test/java/io/github/dfa1/vortex/integration/InspectForTest.java index de952b5c..9df849e4 100644 --- 
a/integration/src/test/java/io/github/dfa1/vortex/integration/InspectForTest.java +++ b/integration/src/test/java/io/github/dfa1/vortex/integration/InspectForTest.java @@ -1,7 +1,7 @@ package io.github.dfa1.vortex.integration; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.inspect.VortexInspector; import io.github.dfa1.vortex.reader.VortexReader; import org.junit.jupiter.api.Disabled; @@ -14,7 +14,7 @@ class InspectForTest { @Test void inspect() throws Exception { for (String f : new String[]{"/tmp/for.vortex", "/tmp/rle.vortex"}) { - try (VortexReader r = VortexReader.open(Path.of(f), Registry.loadAll())) { + try (VortexReader r = VortexReader.open(Path.of(f), ReadRegistry.loadAll())) { System.out.println("=== " + f + " ==="); System.out.println(VortexInspector.inspect(r)); try (var iter = r.scan(io.github.dfa1.vortex.reader.ScanOptions.all())) { diff --git a/integration/src/test/java/io/github/dfa1/vortex/integration/JavaWritesRustReadsIntegrationTest.java b/integration/src/test/java/io/github/dfa1/vortex/integration/JavaWritesRustReadsIntegrationTest.java index 4e66919f..dc603757 100644 --- a/integration/src/test/java/io/github/dfa1/vortex/integration/JavaWritesRustReadsIntegrationTest.java +++ b/integration/src/test/java/io/github/dfa1/vortex/integration/JavaWritesRustReadsIntegrationTest.java @@ -10,22 +10,22 @@ import dev.vortex.jni.NativeLoader; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; -import io.github.dfa1.vortex.encoding.BoolEncoding; -import io.github.dfa1.vortex.encoding.ByteBoolEncoding; -import io.github.dfa1.vortex.encoding.ConstantEncoding; -import io.github.dfa1.vortex.encoding.FsstEncoding; +import io.github.dfa1.vortex.writer.encode.BoolEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.ByteBoolEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.ConstantEncodingEncoder; +import 
io.github.dfa1.vortex.writer.encode.FsstEncodingEncoder; import io.github.dfa1.vortex.encoding.ListData; -import io.github.dfa1.vortex.encoding.ListEncoding; +import io.github.dfa1.vortex.writer.encode.ListEncodingEncoder; import io.github.dfa1.vortex.encoding.ListViewData; -import io.github.dfa1.vortex.encoding.ListViewEncoding; -import io.github.dfa1.vortex.encoding.NullEncoding; -import io.github.dfa1.vortex.encoding.RleEncoding; -import io.github.dfa1.vortex.encoding.RunEndEncoding; -import io.github.dfa1.vortex.encoding.SparseEncoding; -import io.github.dfa1.vortex.encoding.VarBinEncoding; -import io.github.dfa1.vortex.encoding.VarBinViewEncoding; -import io.github.dfa1.vortex.encoding.ZigZagEncoding; -import io.github.dfa1.vortex.encoding.ZstdEncoding; +import io.github.dfa1.vortex.writer.encode.ListViewEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.NullEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.RleEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.RunEndEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.SparseEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.VarBinEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.VarBinViewEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.ZigZagEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.ZstdEncodingEncoder; import io.github.dfa1.vortex.writer.VortexWriter; import io.github.dfa1.vortex.writer.WriteOptions; import org.apache.arrow.memory.BufferAllocator; @@ -503,7 +503,7 @@ void javaWriter_jniReader_utf8Column(@TempDir Path tmp) throws IOException { String[] data = {"apple", "banana", "cherry", "date", "elderberry"}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, STRING_SCHEMA, WriteOptions.defaults(), - List.of(new VarBinEncoding()))) { + List.of(new VarBinEncodingEncoder()))) { // When sut.writeChunk(Map.of("s", data)); } @@ -520,7 +520,7 @@ void 
javaWriter_jniReader_fsstUtf8Column(@TempDir Path tmp) throws IOException { String[] data = {"apple", "banana", "cherry", "apricot", "avocado", "almond"}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, STRING_SCHEMA, WriteOptions.defaults(), - List.of(new FsstEncoding()))) { + List.of(new FsstEncodingEncoder()))) { // When sut.writeChunk(Map.of("s", data)); } @@ -537,7 +537,7 @@ void javaWriter_jniReader_varBinViewUtf8Column_inlined(@TempDir Path tmp) throws String[] data = {"hi", "yo", "ok", "abc", "short", "exactly12ok!"}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, STRING_SCHEMA, WriteOptions.defaults(), - List.of(new VarBinViewEncoding()))) { + List.of(new VarBinViewEncodingEncoder()))) { // When sut.writeChunk(Map.of("s", data)); } @@ -554,7 +554,7 @@ void javaWriter_jniReader_varBinViewUtf8Column_referenced(@TempDir Path tmp) thr String[] data = {"this is long text", "another long string", "yet another long one", "thirteenchars!"}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, STRING_SCHEMA, WriteOptions.defaults(), - List.of(new VarBinViewEncoding()))) { + List.of(new VarBinViewEncodingEncoder()))) { // When sut.writeChunk(Map.of("s", data)); } @@ -571,7 +571,7 @@ void javaWriter_jniReader_varBinViewUtf8Column_mixed(@TempDir Path tmp) throws I String[] data = {"short", "this is a longer string", "hi", "medium length ok", "x", "another longer string here"}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, STRING_SCHEMA, WriteOptions.defaults(), - List.of(new VarBinViewEncoding()))) { + List.of(new VarBinViewEncodingEncoder()))) { // When sut.writeChunk(Map.of("s", data)); } @@ -590,7 +590,7 @@ void prop_varBinView_utf8_roundTripsViaRust(String[] data) 
throws IOException { Path file = tmp.resolve("pbt_varbinview_utf8.vtx"); try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, STRING_SCHEMA, WriteOptions.defaults(), - List.of(new VarBinViewEncoding()))) { + List.of(new VarBinViewEncodingEncoder()))) { sut.writeChunk(Map.of("s", data)); } String[] decoded = readStringColumn(file, "s"); @@ -851,7 +851,7 @@ void javaWriter_rustReader_zigzag_i32(@TempDir Path tmp) throws IOException { int[] data = {-1000, -1, 0, 1, 127, -127, Integer.MIN_VALUE / 2, Integer.MAX_VALUE / 2}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I32_SCHEMA, WriteOptions.defaults(), - List.of(new ZigZagEncoding()))) { + List.of(new ZigZagEncodingEncoder()))) { // When sut.writeChunk(Map.of("v", data)); } @@ -868,7 +868,7 @@ void javaWriter_rustReader_zigzag_i64(@TempDir Path tmp) throws IOException { long[] data = {Long.MIN_VALUE / 2, -1L, 0L, 1L, Long.MAX_VALUE / 2}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, TS_SCHEMA, WriteOptions.defaults(), - List.of(new ZigZagEncoding()))) { + List.of(new ZigZagEncodingEncoder()))) { // When sut.writeChunk(Map.of("ts", data)); } @@ -885,7 +885,7 @@ void javaWriter_rustReader_runEnd_i32(@TempDir Path tmp) throws IOException { int[] data = {10, 10, 10, 20, 20, 30, 30, 30, 30}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I32_SCHEMA, WriteOptions.defaults(), - List.of(new RunEndEncoding()))) { + List.of(new RunEndEncodingEncoder()))) { // When sut.writeChunk(Map.of("v", data)); } @@ -902,7 +902,7 @@ void javaWriter_rustReader_runEnd_i64(@TempDir Path tmp) throws IOException { long[] data = {100L, 100L, 200L, 200L, 200L, 300L}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, 
StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, TS_SCHEMA, WriteOptions.defaults(), - List.of(new RunEndEncoding()))) { + List.of(new RunEndEncodingEncoder()))) { // When sut.writeChunk(Map.of("ts", data)); } @@ -919,7 +919,7 @@ void javaWriter_rustReader_rle_i32(@TempDir Path tmp) throws IOException { int[] data = {5, 5, 5, 7, 7, 5, 5, 5, 5, 9}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I32_SCHEMA, WriteOptions.defaults(), - List.of(new RleEncoding()))) { + List.of(new RleEncodingEncoder()))) { // When sut.writeChunk(Map.of("v", data)); } @@ -936,7 +936,7 @@ void javaWriter_rustReader_rle_i64(@TempDir Path tmp) throws IOException { long[] data = {1L, 1L, 1L, 2L, 2L, 2L, 3L, 3L}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, TS_SCHEMA, WriteOptions.defaults(), - List.of(new RleEncoding()))) { + List.of(new RleEncodingEncoder()))) { // When sut.writeChunk(Map.of("ts", data)); } @@ -953,7 +953,7 @@ void javaWriter_rustReader_constant_i32(@TempDir Path tmp) throws IOException { int[] data = {42, 42, 42, 42, 42}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I32_SCHEMA, WriteOptions.defaults(), - List.of(new ConstantEncoding()))) { + List.of(new ConstantEncodingEncoder()))) { // When sut.writeChunk(Map.of("v", data)); } @@ -970,7 +970,7 @@ void javaWriter_rustReader_sparse_i32(@TempDir Path tmp) throws IOException { int[] data = {0, 0, 7, 0, 0, 0, 13, 0, 0, 0}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I32_SCHEMA, WriteOptions.defaults(), - List.of(new SparseEncoding()))) { + List.of(new SparseEncodingEncoder()))) { // When sut.writeChunk(Map.of("v", data)); } @@ -987,7 +987,7 @@ void javaWriter_rustReader_bool_boolEncoding(@TempDir Path 
tmp) throws IOExcepti boolean[] data = {true, false, true, true, false, false, true}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, BOOL_SCHEMA, WriteOptions.defaults(), - List.of(new BoolEncoding()))) { + List.of(new BoolEncodingEncoder()))) { // When sut.writeChunk(Map.of("b", data)); } @@ -1006,7 +1006,7 @@ void javaWriter_rustReader_bool_byteBoolEncoding(@TempDir Path tmp) throws IOExc boolean[] data = {false, true, false, true, true}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, BOOL_SCHEMA, WriteOptions.defaults(), - List.of(new ByteBoolEncoding()))) { + List.of(new ByteBoolEncodingEncoder()))) { // When sut.writeChunk(Map.of("b", data)); } @@ -1023,7 +1023,7 @@ void javaWriter_rustReader_nullColumn(@TempDir Path tmp) throws IOException { long rowCount = 7L; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, NULL_SCHEMA, WriteOptions.defaults(), - List.of(new NullEncoding()))) { + List.of(new NullEncodingEncoder()))) { // When — data is ignored by NullEncoding; pass long[] to satisfy arrayLength sut.writeChunk(Map.of("n", new long[(int) rowCount])); } @@ -1040,7 +1040,7 @@ void javaWriter_rustReader_zstd_i64(@TempDir Path tmp) throws IOException { long[] data = {1L, 2L, 3L, 4L, 1000L, 9999L, Long.MAX_VALUE}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, TS_SCHEMA, WriteOptions.defaults(), - List.of(new ZstdEncoding()))) { + List.of(new ZstdEncodingEncoder()))) { // When sut.writeChunk(Map.of("ts", data)); } @@ -1057,7 +1057,7 @@ void javaWriter_rustReader_zstd_utf8(@TempDir Path tmp) throws IOException { String[] data = {"hello", "world", "from", "zstd"}; try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut 
= VortexWriter.create(ch, STRING_SCHEMA, WriteOptions.defaults(), - List.of(new ZstdEncoding()))) { + List.of(new ZstdEncodingEncoder()))) { // When sut.writeChunk(Map.of("s", data)); } @@ -1076,7 +1076,7 @@ void javaWriter_rustReader_list_i64(@TempDir Path tmp) throws IOException { ListData data = new ListData(elements, offsets, 4L); try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, LIST_I64_SCHEMA, WriteOptions.defaults(), - List.of(new ListEncoding()))) { + List.of(new ListEncodingEncoder()))) { // When sut.writeChunk(Map.of("items", data)); } @@ -1096,7 +1096,7 @@ void javaWriter_rustReader_listView_i64(@TempDir Path tmp) throws IOException { ListViewData data = new ListViewData(elements, offsets, sizes, 4L); try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, LIST_I64_SCHEMA, WriteOptions.defaults(), - List.of(new ListViewEncoding()))) { + List.of(new ListViewEncodingEncoder()))) { // When sut.writeChunk(Map.of("items", data)); } diff --git a/integration/src/test/java/io/github/dfa1/vortex/integration/OhlcEncodingInspectionIntegrationTest.java b/integration/src/test/java/io/github/dfa1/vortex/integration/OhlcEncodingInspectionIntegrationTest.java index e8ddff1d..75302542 100644 --- a/integration/src/test/java/io/github/dfa1/vortex/integration/OhlcEncodingInspectionIntegrationTest.java +++ b/integration/src/test/java/io/github/dfa1/vortex/integration/OhlcEncodingInspectionIntegrationTest.java @@ -4,7 +4,7 @@ import dev.vortex.api.VortexWriter; import dev.vortex.arrow.ArrowAllocation; import dev.vortex.jni.NativeLoader; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.inspect.VortexInspector; import io.github.dfa1.vortex.reader.VortexReader; import org.apache.arrow.c.ArrowArray; @@ -125,7 +125,7 @@ void 
inspect_ohlcFile_showsColumnEncodings(@TempDir Path tmp) throws IOException // When String report; - try (VortexReader vf = VortexReader.open(file, Registry.loadAll())) { + try (VortexReader vf = VortexReader.open(file, ReadRegistry.loadAll())) { report = VortexInspector.inspect(vf); } diff --git a/integration/src/test/java/io/github/dfa1/vortex/integration/PcoFixtureInspectionIntegrationTest.java b/integration/src/test/java/io/github/dfa1/vortex/integration/PcoFixtureInspectionIntegrationTest.java index da3214d5..302585ad 100644 --- a/integration/src/test/java/io/github/dfa1/vortex/integration/PcoFixtureInspectionIntegrationTest.java +++ b/integration/src/test/java/io/github/dfa1/vortex/integration/PcoFixtureInspectionIntegrationTest.java @@ -3,7 +3,7 @@ import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.Layout; import io.github.dfa1.vortex.core.SegmentSpec; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.fbs.Array; import io.github.dfa1.vortex.fbs.ArrayNode; import io.github.dfa1.vortex.fbs.Buffer; @@ -45,7 +45,7 @@ class PcoFixtureInspectionIntegrationTest { }; private static void inspect(Path file, StringBuilder out) throws Exception { - try (VortexReader vf = VortexReader.open(file, Registry.empty())) { + try (VortexReader vf = VortexReader.open(file, ReadRegistry.empty())) { out.append("dtype: ").append(formatDType(vf.dtype())).append('\n'); out.append("size: ").append(vf.fileSize()).append(" bytes\n"); diff --git a/integration/src/test/java/io/github/dfa1/vortex/integration/RustJavaReaderComparisonIntegrationTest.java b/integration/src/test/java/io/github/dfa1/vortex/integration/RustJavaReaderComparisonIntegrationTest.java index 8f034d2d..66fcca43 100644 --- a/integration/src/test/java/io/github/dfa1/vortex/integration/RustJavaReaderComparisonIntegrationTest.java +++ 
b/integration/src/test/java/io/github/dfa1/vortex/integration/RustJavaReaderComparisonIntegrationTest.java @@ -18,7 +18,7 @@ import io.github.dfa1.vortex.core.array.MaskedArray; import io.github.dfa1.vortex.core.array.ShortArray; import io.github.dfa1.vortex.core.array.VarBinArray; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.inspect.VortexInspector; import io.github.dfa1.vortex.reader.VortexReader; import io.github.dfa1.vortex.reader.Chunk; @@ -213,7 +213,7 @@ private static Stats javaStats(Path file) throws Exception { Map numSums = new LinkedHashMap<>(); Map strLenSums = new LinkedHashMap<>(); long rowCount = 0; - try (VortexReader reader = VortexReader.open(file, Registry.loadAll()); + try (VortexReader reader = VortexReader.open(file, ReadRegistry.loadAll()); var iter = reader.scan(io.github.dfa1.vortex.reader.ScanOptions.all())) { // Skip extension columns: Rust's stats path reports them under their logical // type (timestamp etc.), so summing their storage longs would diverge from diff --git a/integration/src/test/java/io/github/dfa1/vortex/integration/RustWritesJavaReadsIntegrationTest.java b/integration/src/test/java/io/github/dfa1/vortex/integration/RustWritesJavaReadsIntegrationTest.java index 883b077b..6cbbc3df 100644 --- a/integration/src/test/java/io/github/dfa1/vortex/integration/RustWritesJavaReadsIntegrationTest.java +++ b/integration/src/test/java/io/github/dfa1/vortex/integration/RustWritesJavaReadsIntegrationTest.java @@ -14,7 +14,7 @@ import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; import io.github.dfa1.vortex.core.array.LongArray; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.reader.VortexReader; import org.apache.arrow.c.ArrowArray; import org.apache.arrow.c.ArrowSchema; @@ -153,7 +153,7 @@ private static double[] 
values(JavaChunk chunk) { // ── Java read helpers ───────────────────────────────────────────────────── private static String firstI64Column(Path file) throws IOException { - try (var vf = VortexReader.open(file, Registry.empty())) { + try (var vf = VortexReader.open(file, ReadRegistry.empty())) { if (vf.dtype() instanceof DType.Struct struct) { for (int i = 0; i < struct.fieldNames().size(); i++) { if (struct.fieldTypes().get(i) instanceof DType.Primitive(PType pt, boolean _) && pt == PType.I64) { @@ -189,7 +189,7 @@ private static long[] readJniLongColumn(Path file, String column) throws IOExcep } private static long[] readJavaLongColumn(Path file, String column) throws IOException { - try (var vf = VortexReader.open(file, Registry.loadAll()); + try (var vf = VortexReader.open(file, ReadRegistry.loadAll()); var iter = vf.scan(io.github.dfa1.vortex.reader.ScanOptions.columns(column))) { var longs = new ArrayList(); iter.forEachRemaining(c -> { @@ -233,7 +233,7 @@ void jniWriter_javaReader_singleChunk(@TempDir Path tmp) throws IOException { writeJni(file, ids, vals); // When / Then - try (var vf = VortexReader.open(file, Registry.loadAll())) { + try (var vf = VortexReader.open(file, ReadRegistry.loadAll())) { List results = scanAll(vf); assertThat(results).hasSize(1); assertThat(results.getFirst().rowCount()).isEqualTo(3L); @@ -253,7 +253,7 @@ void jniWriter_javaReader_multipleChunks(@TempDir Path tmp) throws IOException { } // When / Then — JNI may merge small batches; verify total rows and values - try (var vf = VortexReader.open(file, Registry.loadAll())) { + try (var vf = VortexReader.open(file, ReadRegistry.loadAll())) { List results = scanAll(vf); long totalRows = results.stream().mapToLong(JavaChunk::rowCount).sum(); assertThat(totalRows).isEqualTo(5L); @@ -271,7 +271,7 @@ void jniWriter_javaReader_columnProjection(@TempDir Path tmp) throws IOException writeJni(file, new long[]{10L, 20L}, new double[]{0.1, 0.2}); // When / Then - try (var vf = 
VortexReader.open(file, Registry.loadAll())) { + try (var vf = VortexReader.open(file, ReadRegistry.loadAll())) { List results = scanAll(vf, io.github.dfa1.vortex.reader.ScanOptions.columns("id")); assertThat(results).hasSize(1); assertThat(results.getFirst().columns()).containsKey("id"); @@ -295,7 +295,7 @@ void jniWriter_javaReader_fewUniqueF64Values(@TempDir Path tmp) throws IOExcepti writeJni(file, ids, vals); // When / Then - try (var vf = VortexReader.open(file, Registry.loadAll())) { + try (var vf = VortexReader.open(file, ReadRegistry.loadAll())) { List results = scanAll(vf, io.github.dfa1.vortex.reader.ScanOptions.columns("value")); long total = results.stream().mapToLong(JavaChunk::rowCount).sum(); assertThat(total).isEqualTo(n); @@ -350,7 +350,7 @@ void jniWriter_nullableColumn_decodesWithoutError(@TempDir Path tmp) throws IOEx } // When / Then — decodes without error, correct row count, correct values - try (var vf = VortexReader.open(file, Registry.loadAll())) { + try (var vf = VortexReader.open(file, ReadRegistry.loadAll())) { List results = scanAll(vf); long totalRows = results.stream().mapToLong(JavaChunk::rowCount).sum(); assertThat(totalRows).isEqualTo(n); @@ -467,7 +467,7 @@ void jniWriter_javaReader_f16_primitiveRoundTrip(@TempDir Path tmp) throws IOExc } // When - try (var vf = VortexReader.open(file, Registry.loadAll())) { + try (var vf = VortexReader.open(file, ReadRegistry.loadAll())) { List results = scanAll(vf); // Then — correct dtype, correct values diff --git a/integration/src/test/java/io/github/dfa1/vortex/integration/VortexInspectorIntegrationTest.java b/integration/src/test/java/io/github/dfa1/vortex/integration/VortexInspectorIntegrationTest.java index eab56ba6..8017f890 100644 --- a/integration/src/test/java/io/github/dfa1/vortex/integration/VortexInspectorIntegrationTest.java +++ b/integration/src/test/java/io/github/dfa1/vortex/integration/VortexInspectorIntegrationTest.java @@ -4,7 +4,7 @@ import dev.vortex.api.VortexWriter; 
import dev.vortex.arrow.ArrowAllocation; import dev.vortex.jni.NativeLoader; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.inspect.VortexInspector; import io.github.dfa1.vortex.reader.VortexReader; import org.apache.arrow.c.ArrowArray; @@ -74,7 +74,7 @@ void inspect_showsFileInfoAndEncodings(@TempDir Path tmp) throws IOException { // When String report; - try (VortexReader vf = VortexReader.open(file, Registry.loadAll())) { + try (VortexReader vf = VortexReader.open(file, ReadRegistry.loadAll())) { report = VortexInspector.inspect(vf); } diff --git a/jdbc/src/test/java/io/github/dfa1/vortex/jdbc/JdbcImporterTest.java b/jdbc/src/test/java/io/github/dfa1/vortex/jdbc/JdbcImporterTest.java index 187ad637..fb691b47 100644 --- a/jdbc/src/test/java/io/github/dfa1/vortex/jdbc/JdbcImporterTest.java +++ b/jdbc/src/test/java/io/github/dfa1/vortex/jdbc/JdbcImporterTest.java @@ -116,7 +116,7 @@ void roundTripsTemporalSqlTypesViaExtensions(@TempDir Path tmp) throws Exception // Then — schema declares the three extension dtypes try (VortexReader reader = VortexReader.open(vortex, - io.github.dfa1.vortex.encoding.Registry.loadAll())) { + io.github.dfa1.vortex.reader.ReadRegistry.loadAll())) { DType.Struct schema = (DType.Struct) reader.dtype(); assertThat(schema.fieldTypes().get(1)) .isEqualTo(io.github.dfa1.vortex.extension.DateExtension.INSTANCE.dtype(false)); @@ -175,7 +175,7 @@ void roundTripsNullableExtensionColumns(@TempDir Path tmp) throws Exception { // Then — schema declares nullable=true on every ext column; data round-trips // with row 2 marked invalid in each MaskedArray try (VortexReader reader = VortexReader.open(vortex, - io.github.dfa1.vortex.encoding.Registry.loadAll())) { + io.github.dfa1.vortex.reader.ReadRegistry.loadAll())) { DType.Struct schema = (DType.Struct) reader.dtype(); assertThat(((DType.Extension) schema.fieldTypes().get(1)).nullable()).isTrue(); 
assertThat(((DType.Extension) schema.fieldTypes().get(2)).nullable()).isTrue(); @@ -258,7 +258,7 @@ void roundTripsUuidColumnViaExtension(@TempDir Path tmp) throws Exception { // Then — column maps to vortex.uuid extension; values round-trip exactly try (VortexReader reader = VortexReader.open(vortex, - io.github.dfa1.vortex.encoding.Registry.loadAll())) { + io.github.dfa1.vortex.reader.ReadRegistry.loadAll())) { DType.Struct schema = (DType.Struct) reader.dtype(); assertThat(schema.fieldTypes().get(1)) .isEqualTo(io.github.dfa1.vortex.extension.UuidExtension.INSTANCE.dtype(false)); diff --git a/performance/src/main/java/io/github/dfa1/vortex/performance/ParquetVsVortexReadBenchmark.java b/performance/src/main/java/io/github/dfa1/vortex/performance/ParquetVsVortexReadBenchmark.java index 0245b0a3..480159ca 100644 --- a/performance/src/main/java/io/github/dfa1/vortex/performance/ParquetVsVortexReadBenchmark.java +++ b/performance/src/main/java/io/github/dfa1/vortex/performance/ParquetVsVortexReadBenchmark.java @@ -8,7 +8,7 @@ import dev.hardwood.schema.ColumnProjection; import io.github.dfa1.vortex.core.array.DoubleArray; import io.github.dfa1.vortex.core.array.IntArray; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.reader.VortexReader; import io.github.dfa1.vortex.parquet.ParquetImporter; import io.github.dfa1.vortex.reader.Chunk; @@ -74,7 +74,7 @@ public class ParquetVsVortexReadBenchmark { private Path parquetFile; private Path vortexFile; - private Registry registry; + ReadRegistry registry; private static Path downloadIfAbsent(Path dest, String url) throws Exception { if (Files.exists(dest)) { @@ -96,7 +96,7 @@ private static Path downloadIfAbsent(Path dest, String url) throws Exception { @Setup(Level.Trial) public void setup() throws Exception { - registry = Registry.loadAll(); + registry = ReadRegistry.loadAll(); synchronized (SETUP_LOCK) { if (sharedParquetFile == null) { 
String override = System.getProperty("bench.parquet"); diff --git a/performance/src/main/java/io/github/dfa1/vortex/performance/RustVsJavaReadBenchmark.java b/performance/src/main/java/io/github/dfa1/vortex/performance/RustVsJavaReadBenchmark.java index 979cc4a7..ad0bf1a2 100644 --- a/performance/src/main/java/io/github/dfa1/vortex/performance/RustVsJavaReadBenchmark.java +++ b/performance/src/main/java/io/github/dfa1/vortex/performance/RustVsJavaReadBenchmark.java @@ -13,7 +13,7 @@ import io.github.dfa1.vortex.core.array.DoubleArray; import io.github.dfa1.vortex.core.array.LongArray; import io.github.dfa1.vortex.core.array.VarBinArray; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.reader.VortexReader; import io.github.dfa1.vortex.reader.Chunk; import io.github.dfa1.vortex.writer.VortexWriter; @@ -135,7 +135,7 @@ public class RustVsJavaReadBenchmark { private Path benchFile; private Path cascadingFile; - private Registry registry; + ReadRegistry registry; private BufferAllocator allocator; private static double round(double v) { @@ -144,7 +144,7 @@ private static double round(double v) { @Setup(Level.Trial) public void setup() throws IOException { - registry = Registry.loadAll(); + registry = ReadRegistry.loadAll(); allocator = ArrowAllocation.rootAllocator(); synchronized (FILE_LOCK) { diff --git a/performance/src/main/java/io/github/dfa1/vortex/performance/RustWritesJavaReadsBigFileBenchmark.java b/performance/src/main/java/io/github/dfa1/vortex/performance/RustWritesJavaReadsBigFileBenchmark.java index 65c9b1d5..bc5c1435 100644 --- a/performance/src/main/java/io/github/dfa1/vortex/performance/RustWritesJavaReadsBigFileBenchmark.java +++ b/performance/src/main/java/io/github/dfa1/vortex/performance/RustWritesJavaReadsBigFileBenchmark.java @@ -10,7 +10,7 @@ import dev.vortex.jni.NativeLoader; import io.github.dfa1.vortex.core.array.Array; import 
io.github.dfa1.vortex.core.array.ArraySegments; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.reader.VortexReader; import io.github.dfa1.vortex.reader.Chunk; import org.apache.arrow.c.ArrowArray; @@ -95,12 +95,12 @@ public class RustWritesJavaReadsBigFileBenchmark { private Path benchFile; private boolean ownFile; - private Registry registry; + ReadRegistry registry; private BufferAllocator allocator; @Setup(Level.Trial) public void setup() throws IOException { - registry = Registry.loadAll(); + registry = ReadRegistry.loadAll(); allocator = ArrowAllocation.rootAllocator(); String externalFile = System.getProperty("vortex.bench.bigfile"); diff --git a/performance/src/main/java/io/github/dfa1/vortex/performance/TaxiLayoutInspector.java b/performance/src/main/java/io/github/dfa1/vortex/performance/TaxiLayoutInspector.java index f635a79b..6d9331b2 100644 --- a/performance/src/main/java/io/github/dfa1/vortex/performance/TaxiLayoutInspector.java +++ b/performance/src/main/java/io/github/dfa1/vortex/performance/TaxiLayoutInspector.java @@ -8,7 +8,7 @@ import dev.vortex.api.Session; import dev.vortex.arrow.ArrowAllocation; import dev.vortex.jni.NativeLoader; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.inspect.VortexInspector; import io.github.dfa1.vortex.reader.VortexReader; import io.github.dfa1.vortex.parquet.ImportOptions; @@ -140,7 +140,7 @@ public static void main(String[] args) throws Exception { } private static void inspect(Path path) throws IOException { - try (VortexReader r = VortexReader.open(path, Registry.loadAll())) { + try (VortexReader r = VortexReader.open(path, ReadRegistry.loadAll())) { System.out.println(VortexInspector.inspect(r)); } } diff --git a/pom.xml b/pom.xml index ae1c9507..30dd3db2 100644 --- a/pom.xml +++ b/pom.xml @@ -85,10 +85,22 @@ vortex-core ${project.version} + + 
io.github.dfa1.vortex + vortex-core + ${project.version} + test-jar + + + io.github.dfa1.vortex + vortex-reader + ${project.version} + io.github.dfa1.vortex vortex-reader ${project.version} + test-jar io.github.dfa1.vortex diff --git a/reader/pom.xml b/reader/pom.xml index f2faece3..e8e19450 100644 --- a/reader/pom.xml +++ b/reader/pom.xml @@ -25,6 +25,14 @@ flatbuffers-java + + + io.github.dfa1.vortex + vortex-core + test-jar + test + + org.junit.jupiter junit-jupiter @@ -41,4 +49,23 @@ test + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + publish-test-jar + + test-jar + + + + + + diff --git a/core/src/main/java/io/github/dfa1/vortex/encoding/FlatSegmentDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/FlatSegmentDecoder.java similarity index 84% rename from core/src/main/java/io/github/dfa1/vortex/encoding/FlatSegmentDecoder.java rename to reader/src/main/java/io/github/dfa1/vortex/reader/FlatSegmentDecoder.java index 2ec053cc..9dd47d0e 100644 --- a/core/src/main/java/io/github/dfa1/vortex/encoding/FlatSegmentDecoder.java +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/FlatSegmentDecoder.java @@ -1,9 +1,14 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.reader; -import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.ArrayStats; +import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.encoding.EncodingId; import io.github.dfa1.vortex.fbs.Buffer; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.reader.decode.KnownArrayNode; +import io.github.dfa1.vortex.reader.decode.UnknownArrayNode; import java.lang.foreign.MemorySegment; import java.lang.foreign.SegmentAllocator; @@ -12,22 +17,21 @@ import java.util.List; /// Parses a flat segment from the memory-mapped file region and dispatches to the -/// appropriate {@link Encoding} via 
the registry. +/// appropriate decoder via the {@link ReadRegistry}. /// ///

Flat segment wire format: /// {@code buffer_data... | FlatBuffer(Array) | u32 LE = FlatBuffer byte length} /// -///

Registry is pure dispatch ({@code register}/{@code decode}); this class owns -/// all file-format knowledge: FlatBuffer parsing, buffer-offset arithmetic, and -/// encoding-spec lookup. +///

{@link ReadRegistry} is pure dispatch; this class owns all file-format knowledge: +/// FlatBuffer parsing, buffer-offset arithmetic, and encoding-spec lookup. public final class FlatSegmentDecoder { - private final Registry registry; + private final ReadRegistry registry; /// Creates a decoder backed by the given registry. /// - /// @param registry the registry used to dispatch to concrete {@link Encoding} impls - public FlatSegmentDecoder(Registry registry) { + /// @param registry the registry used to dispatch to concrete decoder impls + public FlatSegmentDecoder(ReadRegistry registry) { this.registry = registry; } @@ -80,7 +84,6 @@ private static ArrayNode convertArrayNode( bufferIndices[i] = fbs.buffers(i); } - // metadataAsByteBuffer() returns duplicate with position=vectorStart; slice to normalize to 0 ByteBuffer rawMeta = fbs.metadataAsByteBuffer(); ByteBuffer meta = (rawMeta != null) ? rawMeta.slice() : null; ArrayStats stats = ArrayStats.fromFbs(fbs.stats()); diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/ReadRegistry.java b/reader/src/main/java/io/github/dfa1/vortex/reader/ReadRegistry.java new file mode 100644 index 00000000..e111d661 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/ReadRegistry.java @@ -0,0 +1,189 @@ +package io.github.dfa1.vortex.reader; + +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.UnknownArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.reader.decode.EncodingDecoder; +import io.github.dfa1.vortex.reader.decode.KnownArrayNode; +import io.github.dfa1.vortex.reader.decode.UnknownArrayNode; + +import java.lang.foreign.MemorySegment; +import java.util.HashMap; +import java.util.Map; +import 
java.util.ServiceLoader; + +/// Read-side registry: maps {@link EncodingId} to {@link EncodingDecoder} implementations. +/// +///

Instances are immutable after construction. Build one via {@link #builder()} or +/// via the {@link #loadAll()} and {@link #empty()} convenience factories. +public final class ReadRegistry { + + private final Map decoders; + private final boolean allowUnknown; + + private ReadRegistry(Map decoders, boolean allowUnknown) { + this.decoders = Map.copyOf(decoders); + this.allowUnknown = allowUnknown; + } + + /// Loads all service-discovered {@link EncodingDecoder} implementations. + /// + /// @return an immutable {@link ReadRegistry} populated with all service-loaded decoders + public static ReadRegistry loadAll() { + return builder().registerServiceLoaded().build(); + } + + /// Creates an empty registry with no decoders registered. + /// + /// @return a new empty immutable {@link ReadRegistry} + public static ReadRegistry empty() { + return builder().build(); + } + + /// Returns a new {@link Builder}. + /// + /// @return a fresh builder + public static Builder builder() { + return new Builder(); + } + + /// Returns whether passthrough decode for unknown encoding ids is enabled. + /// + /// @return {@code true} if unknown encodings are silently wrapped as + /// {@link io.github.dfa1.vortex.core.array.UnknownArray} + public boolean isAllowUnknown() { + return allowUnknown; + } + + /// Returns {@code true} if a decoder is registered for the given id. + /// + /// @param encodingId the encoding id to query + /// @return {@code true} if a decoder is registered + public boolean hasDecoder(EncodingId encodingId) { + return decoders.containsKey(encodingId); + } + + /// Decodes the array described by {@code ctx}. 
+ /// + /// @param ctx the decode context + /// @return the decoded {@link Array} + public Array decode(DecodeContext ctx) { + ArrayNode node = ctx.node(); + EncodingDecoder decoder = switch (node) { + case KnownArrayNode k -> decoders.get(k.encodingId()); + case UnknownArrayNode _ -> null; + }; + if (decoder != null) { + return decoder.decode(ctx); + } + if (allowUnknown) { + return decodeUnknown(ctx, node); + } + String id = switch (node) { + case KnownArrayNode k -> k.encodingId().id(); + case UnknownArrayNode u -> u.rawEncodingId(); + }; + throw new VortexException("no decoder registered for " + id); + } + + /// Decodes the array described by {@code ctx} and returns its primary backing segment. + /// + /// @param ctx the decode context + /// @return the primary {@link MemorySegment} of the decoded array + public MemorySegment decodeAsSegment(DecodeContext ctx) { + ArrayNode node = ctx.node(); + EncodingDecoder decoder = switch (node) { + case KnownArrayNode k -> decoders.get(k.encodingId()); + case UnknownArrayNode _ -> null; + }; + if (decoder != null) { + return ArraySegments.of(decoder.decode(ctx)); + } + String id = switch (node) { + case KnownArrayNode k -> k.encodingId().id(); + case UnknownArrayNode u -> u.rawEncodingId(); + }; + throw new VortexException("no decoder registered for " + id + " (or encoding has no primary segment)"); + } + + private static UnknownArray decodeUnknown(DecodeContext ctx, ArrayNode node) { + String rawId = switch (node) { + case KnownArrayNode k -> k.encodingId().id(); + case UnknownArrayNode u -> u.rawEncodingId(); + }; + MemorySegment[] bufs = new MemorySegment[node.bufferIndices().length]; + for (int i = 0; i < bufs.length; i++) { + bufs[i] = ctx.buffer(i); + } + Array[] children = new Array[node.children().length]; + for (int i = 0; i < children.length; i++) { + ArrayNode childNode = node.children()[i]; + DecodeContext childCtx = new DecodeContext( + childNode, ctx.dtype(), ctx.rowCount(), + ctx.segmentBuffers(), 
ctx.registry(), ctx.arena()); + children[i] = decodeUnknown(childCtx, childNode); + } + return new UnknownArray( + rawId, ctx.dtype(), ctx.rowCount(), + node.metadata(), bufs, children); + } + + /// Builder for {@link ReadRegistry}. + /// + /// Not thread-safe. Build once, use everywhere — the produced {@link ReadRegistry} is immutable. + public static final class Builder { + + private final Map decoders = new HashMap<>(); + private boolean allowUnknown = false; + + private Builder() { + } + + /// Registers a decoder. + /// + /// @param decoder the {@link EncodingDecoder} to register + /// @return this builder, for chaining + /// @throws VortexException if a decoder for the same id is already registered + public Builder register(EncodingDecoder decoder) { + EncodingDecoder old = decoders.put(decoder.encodingId(), decoder); + if (old != null) { + throw new VortexException("decoder %s already registered".formatted(decoder.encodingId())); + } + return this; + } + + /// Registers every {@link EncodingDecoder} discovered via {@link ServiceLoader}. + /// + /// @return this builder, for chaining + /// @throws VortexException if a service-loaded entry collides with one already registered + public Builder registerServiceLoaded() { + for (EncodingDecoder decoder : ServiceLoader.load(EncodingDecoder.class)) { + register(decoder); + } + return this; + } + + /// Enable passthrough decode for unknown encoding ids. + /// + ///

Default is strict: unknown ids throw {@link VortexException}. When enabled, unknown + /// nodes are wrapped as {@link io.github.dfa1.vortex.core.array.UnknownArray}. + /// Mirrors Rust {@code VortexSession::allow_unknown()}. + /// + /// @return this builder, for chaining + public Builder allowUnknown() { + this.allowUnknown = true; + return this; + } + + /// Builds an immutable {@link ReadRegistry}. + /// + /// @return the immutable registry + public ReadRegistry build() { + return new ReadRegistry(decoders, allowUnknown); + } + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/ScanIterator.java b/reader/src/main/java/io/github/dfa1/vortex/reader/ScanIterator.java index 6561bbc0..b4bc82ec 100644 --- a/reader/src/main/java/io/github/dfa1/vortex/reader/ScanIterator.java +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/ScanIterator.java @@ -22,8 +22,6 @@ import io.github.dfa1.vortex.core.array.StructArray; import io.github.dfa1.vortex.core.array.VarBinArray; import io.github.dfa1.vortex.encoding.EncodingId; -import io.github.dfa1.vortex.encoding.Registry; -import io.github.dfa1.vortex.encoding.FlatSegmentDecoder; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; @@ -71,7 +69,7 @@ public final class ScanIterator implements Iterator, AutoCloseable { private static final ValueLayout.OfLong LE_LONG = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); private final VortexHandle file; - private final Registry registry; + private final ReadRegistry registry; private final ScanOptions options; private List chunks; @@ -83,7 +81,7 @@ public final class ScanIterator implements Iterator, AutoCloseable { private Chunk openChunk; private boolean closed; - public ScanIterator(VortexHandle file, Registry registry, ScanOptions options) { + public ScanIterator(VortexHandle file, ReadRegistry registry, ScanOptions options) { this.file = file; this.registry = registry; this.options = options; @@ -478,8 +476,7 @@ private 
Array decodeFlat(Layout flat, DType dtype, SegmentAllocator arena) { } int segIdx = flat.segments().getFirst(); SegmentSpec spec = file.footer().segmentSpecs().get(segIdx); - MemorySegment seg = file.slice(spec.offset(), spec.length()); - return new FlatSegmentDecoder(registry).decode(seg, file.footer().arraySpecs(), dtype, flat.rowCount(), arena); + return file.decodeFlatSegment(spec, dtype, flat.rowCount(), arena); } private Array decodeDictLayout(Layout dictLayout, DType dtype, SegmentAllocator arena) { diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHandle.java b/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHandle.java index 163ff16d..1a14a317 100644 --- a/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHandle.java +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHandle.java @@ -3,10 +3,12 @@ import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.Footer; import io.github.dfa1.vortex.core.Layout; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.core.SegmentSpec; +import io.github.dfa1.vortex.core.array.Array; import java.io.Closeable; import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; /// Common interface for handles to a Vortex file, regardless of storage backend. /// @@ -23,6 +25,17 @@ public interface VortexHandle extends Closeable { long fileSize(); + /// Typed accessor for the common pattern "slice a flat segment by its {@link SegmentSpec} + /// and decode the encoded array contained therein." Replaces the raw {@link #slice} + /// escape hatch for read-side consumers (scan, inspector, TUI). 
+ /// + /// @param spec the segment spec to read from + /// @param dtype logical type of the decoded array + /// @param rowCount number of logical rows in the segment + /// @param arena allocator for decode output; lifetime matches the caller's chunk epoch + /// @return the decoded array + Array decodeFlatSegment(SegmentSpec spec, DType dtype, long rowCount, SegmentAllocator arena); + /// Returns a read-only view of bytes `[offset, offset+length)` within the file. /// Writes through the returned segment throw `UnsupportedOperationException`. /// @@ -43,7 +56,7 @@ public interface VortexHandle extends Closeable { ScanIterator scan(ScanOptions options); - /// Returns the {@link Registry} this handle was opened with. + /// Returns the {@link ReadRegistry} this handle was opened with. /// ///

Internal escape hatch. Exposed for tooling /// (e.g. the inspector's dictionary preview) that needs to decode an @@ -52,7 +65,7 @@ public interface VortexHandle extends Closeable { /// without deprecation. /// /// @return the registry used to resolve encoding ids during scan - Registry registry(); + ReadRegistry registry(); @Override void close(); diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHttpReader.java b/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHttpReader.java index f7040e2b..7446a417 100644 --- a/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHttpReader.java +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHttpReader.java @@ -5,7 +5,6 @@ import io.github.dfa1.vortex.core.Layout; import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.VortexFormat; -import io.github.dfa1.vortex.encoding.Registry; import io.github.dfa1.vortex.fbs.Postscript; import java.io.IOException; @@ -44,12 +43,12 @@ public final class VortexHttpReader implements VortexHandle { private final Footer footer; private final DType dtype; private final Layout layout; - private final Registry registry; + private final ReadRegistry registry; private VortexHttpReader( URI uri, long fileSize, int version, Footer footer, DType dtype, Layout layout, - Registry registry + ReadRegistry registry ) { this.uri = uri; this.arena = Arena.ofConfined(); @@ -62,10 +61,10 @@ private VortexHttpReader( } public static VortexHttpReader open(URI uri) throws IOException { - return open(uri, Registry.loadAll()); + return open(uri, ReadRegistry.loadAll()); } - public static VortexHttpReader open(URI uri, Registry registry) throws IOException { + public static VortexHttpReader open(URI uri, ReadRegistry registry) throws IOException { // Single suffix Range request — Content-Range response header gives us fileSize. // Avoids a separate HEAD round trip. 
TailFetch tf = fetchTail(uri); @@ -217,6 +216,17 @@ public long fileSize() { // ── HTTP helpers ────────────────────────────────────────────────────────── + @Override + public io.github.dfa1.vortex.core.array.Array decodeFlatSegment( + io.github.dfa1.vortex.core.SegmentSpec spec, + DType dtype, long rowCount, + java.lang.foreign.SegmentAllocator arenaOut + ) { + MemorySegment seg = slice(spec.offset(), spec.length()); + return new FlatSegmentDecoder(registry) + .decode(seg, footer.arraySpecs(), dtype, rowCount, arenaOut); + } + /// Fetches bytes `[offset, offset+length)` via HTTP Range and returns them /// as an off-heap [MemorySegment] tied to this reader's [Arena]. @Override @@ -239,7 +249,7 @@ public ScanIterator scan(ScanOptions options) { } @Override - public Registry registry() { + public ReadRegistry registry() { return registry; } diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/VortexReader.java b/reader/src/main/java/io/github/dfa1/vortex/reader/VortexReader.java index f386eeb0..ddbd6a4d 100644 --- a/reader/src/main/java/io/github/dfa1/vortex/reader/VortexReader.java +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/VortexReader.java @@ -7,7 +7,6 @@ import io.github.dfa1.vortex.core.SegmentSpec; import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.VortexFormat; -import io.github.dfa1.vortex.encoding.Registry; import java.io.IOException; import java.lang.foreign.Arena; @@ -38,12 +37,12 @@ public final class VortexReader implements VortexHandle { private final Footer footer; private final DType dtype; private final Layout layout; - private final Registry registry; + private final ReadRegistry registry; private VortexReader( Arena arena, MemorySegment fileSegment, long fileSize, int version, Footer footer, DType dtype, Layout layout, - Registry registry + ReadRegistry registry ) { this.arena = arena; this.fileSegment = fileSegment; @@ -58,10 +57,10 @@ private VortexReader( /// Open a Vortex file. 
Memory-maps the entire file; all subsequent reads /// are zero-copy slices. Call [#close()] when done. public static VortexReader open(Path path) throws IOException { - return open(path, Registry.loadAll()); + return open(path, ReadRegistry.loadAll()); } - public static VortexReader open(Path path, Registry registry) throws IOException { + public static VortexReader open(Path path, ReadRegistry registry) throws IOException { Arena arena = Arena.ofConfined(); try (var channel = FileChannel.open(path, StandardOpenOption.READ)) { long size = channel.size(); @@ -80,7 +79,7 @@ public static VortexReader open(Path path, Registry registry) throws IOException } private static VortexReader parse( - MemorySegment seg, long size, Arena arena, Registry registry + MemorySegment seg, long size, Arena arena, ReadRegistry registry ) { long bodyBytes = size - VortexFormat.TRAILER_SIZE; var trailerSeg = seg.asSlice(bodyBytes, VortexFormat.TRAILER_SIZE); @@ -162,7 +161,7 @@ public ScanIterator scan(ScanOptions options) { } @Override - public Registry registry() { + public ReadRegistry registry() { return registry; } @@ -233,6 +232,17 @@ private ArrayStats readFlatStats(Layout flat) { return ArrayStats.fromFbs(root.stats()); } + @Override + public io.github.dfa1.vortex.core.array.Array decodeFlatSegment( + io.github.dfa1.vortex.core.SegmentSpec spec, + DType dtype, long rowCount, + java.lang.foreign.SegmentAllocator arena + ) { + MemorySegment seg = fileSegment.asSlice(spec.offset(), spec.length()).asReadOnly(); + return new FlatSegmentDecoder(registry) + .decode(seg, footer.arraySpecs(), dtype, rowCount, arena); + } + /// Zero-copy read-only slice of the memory-mapped file. 
@Override public MemorySegment slice(long offset, long length) { diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpEncodingDecoder.java new file mode 100644 index 00000000..c684da65 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpEncodingDecoder.java @@ -0,0 +1,170 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.ALPMetadata; +import io.github.dfa1.vortex.proto.PatchesMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.alp}. 
+public final class AlpEncodingDecoder implements EncodingDecoder { + private static final double[] F10_F64 = {1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19, 1e20, 1e21, 1e22, 1e23}; + private static final double[] IF10_F64 = {1e-0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15, 1e-16, 1e-17, 1e-18, 1e-19, 1e-20, 1e-21, 1e-22, 1e-23}; + private static final float[] F10_F32 = {1e0f, 1e1f, 1e2f, 1e3f, 1e4f, 1e5f, 1e6f, 1e7f, 1e8f, 1e9f, 1e10f}; + private static final float[] IF10_F32 = {1e-0f, 1e-1f, 1e-2f, 1e-3f, 1e-4f, 1e-5f, 1e-6f, 1e-7f, 1e-8f, 1e-9f, 1e-10f}; + private static final DType I64_DTYPE = new DType.Primitive(PType.I64, false); + private static final DType I32_DTYPE = new DType.Primitive(PType.I32, false); + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public AlpEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ALP; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return p.ptype() == PType.F64 || p.ptype() == PType.F32; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + ALPMetadata meta; + if (rawMeta == null || !rawMeta.hasRemaining()) { + meta = new ALPMetadata(0, 0, null); + } else { + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = ALPMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_ALP, "invalid metadata", e); + } + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_ALP, "expected primitive dtype, got " + ctx.dtype()); + } + + int expE = meta.exp_e(); + int expF = meta.exp_f(); + PType ptype = p.ptype(); + long n = ctx.rowCount(); + + return switch (ptype) { + case 
F64 -> decodeF64(ctx, meta, expE, expF, n); + case F32 -> decodeF32(ctx, meta, expE, expF, n); + default -> throw new VortexException(EncodingId.VORTEX_ALP, "unsupported dtype " + ptype); + }; + } + + private static Array decodeF64(DecodeContext ctx, ALPMetadata meta, int expE, int expF, long n) { + double df = F10_F64[expF]; + double de = IF10_F64[expE]; + + MemorySegment src = ctx.decodeChildSegment(0, I64_DTYPE, n); + MemorySegment buf = src.isReadOnly() ? ctx.arena().allocate(n * 8, 8) : src; + if (src.isReadOnly()) { + long srcCap = SegmentBroadcast.capacity(src, 8); + if (srcCap == n) { + for (long i = 0; i < n; i++) { + buf.setAtIndex(PTypeIO.LE_DOUBLE, i, (double) src.getAtIndex(PTypeIO.LE_LONG, i) * df * de); + } + } else { + for (long i = 0; i < n; i++) { + buf.setAtIndex(PTypeIO.LE_DOUBLE, i, (double) src.getAtIndex(PTypeIO.LE_LONG, i % srcCap) * df * de); + } + } + } else { + for (long i = 0; i < n; i++) { + buf.setAtIndex(PTypeIO.LE_DOUBLE, i, (double) buf.getAtIndex(PTypeIO.LE_LONG, i) * df * de); + } + } + + if (meta.patches() != null) { + applyPatches(ctx, meta.patches(), buf, 8); + } + + return new DoubleArray(ctx.dtype(), n, buf.asReadOnly()); + } + + private static Array decodeF32(DecodeContext ctx, ALPMetadata meta, int expE, int expF, long n) { + float df = F10_F32[expF]; + float de = IF10_F32[expE]; + + MemorySegment src32 = ctx.decodeChildSegment(0, I32_DTYPE, n); + MemorySegment buf32 = src32.isReadOnly() ? 
ctx.arena().allocate(n * 4, 4) : src32; + if (src32.isReadOnly()) { + long srcCap = SegmentBroadcast.capacity(src32, 4); + if (srcCap == n) { + for (long i = 0; i < n; i++) { + buf32.setAtIndex(PTypeIO.LE_FLOAT, i, (float) src32.getAtIndex(PTypeIO.LE_INT, i) * df * de); + } + } else { + for (long i = 0; i < n; i++) { + buf32.setAtIndex(PTypeIO.LE_FLOAT, i, (float) src32.getAtIndex(PTypeIO.LE_INT, i % srcCap) * df * de); + } + } + } else { + for (long i = 0; i < n; i++) { + buf32.setAtIndex(PTypeIO.LE_FLOAT, i, (float) buf32.getAtIndex(PTypeIO.LE_INT, i) * df * de); + } + } + + if (meta.patches() != null) { + applyPatches(ctx, meta.patches(), buf32, 4); + } + + return new FloatArray(ctx.dtype(), n, buf32.asReadOnly()); + } + + private static void applyPatches(DecodeContext ctx, PatchesMetadata pm, MemorySegment out, int elemBytes) { + long numPatches = pm.len(); + long offset = pm.offset(); + PType idxPtype = PType.fromOrdinal(pm.indices_ptype().value()); + int idxBytes = idxPtype.byteSize(); + + MemorySegment idxSeg = ctx.decodeChildSegment(1, new DType.Primitive(idxPtype, false), numPatches); + MemorySegment valSeg = ctx.decodeChildSegment(2, ctx.dtype(), numPatches); + + long idxCap = SegmentBroadcast.capacity(idxSeg, idxBytes); + long valCap = SegmentBroadcast.capacity(valSeg, elemBytes); + if (idxCap >= numPatches && valCap >= numPatches) { + for (long i = 0; i < numPatches; i++) { + long absIdx = readUnsigned(idxSeg, i * idxBytes, idxPtype) - offset; + MemorySegment.copy(valSeg, i * elemBytes, out, absIdx * elemBytes, elemBytes); + } + } else { + for (long i = 0; i < numPatches; i++) { + long absIdx = readUnsigned(idxSeg, (i % idxCap) * idxBytes, idxPtype) - offset; + MemorySegment.copy(valSeg, (i % valCap) * elemBytes, out, absIdx * elemBytes, elemBytes); + } + } + } + + private static long readUnsigned(MemorySegment seg, long off, PType ptype) { + return switch (ptype) { + case U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, off)); + case U16 -> 
Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, off)); + case U32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, off)); + case U64 -> seg.get(PTypeIO.LE_LONG, off); + default -> throw new VortexException(EncodingId.VORTEX_ALP, "non-unsigned patch index ptype " + ptype); + }; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpRdEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpRdEncodingDecoder.java new file mode 100644 index 00000000..be318f15 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpRdEncodingDecoder.java @@ -0,0 +1,175 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.ALPRDMetadata; +import io.github.dfa1.vortex.proto.PatchesMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.alprd}. +public final class AlpRdEncodingDecoder implements EncodingDecoder { + private static final DType U16_DTYPE = new DType.Primitive(PType.U16, false); + private static final DType U32_DTYPE = new DType.Primitive(PType.U32, false); + private static final DType U64_DTYPE = new DType.Primitive(PType.U64, false); + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public AlpRdEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ALPRD; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return p.ptype() == PType.F32 || p.ptype() == PType.F64; + } + + @Override + public Array decode(DecodeContext ctx) { + ALPRDMetadata meta = parseMeta(ctx); + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_ALPRD, + "expected primitive dtype, got " + ctx.dtype()); + } + + int rightBitWidth = meta.right_bit_width(); + int dictLen = meta.dict_len(); + short[] dict = new short[dictLen]; + for (int i = 0; i < dictLen; i++) { + dict[i] = (short) (meta.dict().get(i) & 0xFFFF); + } + + long n = ctx.rowCount(); + PType ptype = p.ptype(); + + return switch (ptype) { + case F64 -> decodeF64(ctx, meta, dict, rightBitWidth, n); + case F32 -> decodeF32(ctx, meta, dict, rightBitWidth, n); + default -> throw new VortexException(EncodingId.VORTEX_ALPRD, "unsupported dtype " + ptype); + }; + } + + private static Array decodeF64(DecodeContext ctx, ALPRDMetadata meta, short[] dict, int rightBitWidth, long n) { + MemorySegment leftSeg = ctx.decodeChildSegment(0, U16_DTYPE, n); + MemorySegment rightSeg = ctx.decodeChildSegment(1, U64_DTYPE, n); + long leftCap = SegmentBroadcast.capacity(leftSeg, 2); + long rightCap = SegmentBroadcast.capacity(rightSeg, 8); + MemorySegment out = ctx.arena().allocate(n * Long.BYTES, Long.BYTES); + + for (long i = 0; i < n; i++) { + int code = Short.toUnsignedInt(leftSeg.getAtIndex(PTypeIO.LE_SHORT, i % leftCap)); + long leftBits = (long) (dict[code] & 0xFFFF) << rightBitWidth; + long rightBits = rightSeg.getAtIndex(PTypeIO.LE_LONG, i % rightCap); + out.setAtIndex(PTypeIO.LE_LONG, i, leftBits | rightBits); + } + + if (meta.patches() != null) { + applyPatchesF64(ctx, meta.patches(), out, rightSeg, rightCap, rightBitWidth); + } + + return new 
DoubleArray(ctx.dtype(), n, out.asReadOnly()); + } + + private static Array decodeF32(DecodeContext ctx, ALPRDMetadata meta, short[] dict, int rightBitWidth, long n) { + MemorySegment leftSeg = ctx.decodeChildSegment(0, U16_DTYPE, n); + MemorySegment rightSeg = ctx.decodeChildSegment(1, U32_DTYPE, n); + long leftCap = SegmentBroadcast.capacity(leftSeg, 2); + long rightCap = SegmentBroadcast.capacity(rightSeg, 4); + MemorySegment out = ctx.arena().allocate(n * Integer.BYTES, Integer.BYTES); + + for (long i = 0; i < n; i++) { + int code = Short.toUnsignedInt(leftSeg.getAtIndex(PTypeIO.LE_SHORT, i % leftCap)); + int leftBits = (dict[code] & 0xFFFF) << rightBitWidth; + int rightBits = rightSeg.getAtIndex(PTypeIO.LE_INT, i % rightCap); + out.setAtIndex(PTypeIO.LE_INT, i, leftBits | rightBits); + } + + if (meta.patches() != null) { + applyPatchesF32(ctx, meta.patches(), out, rightSeg, rightCap, rightBitWidth); + } + + return new FloatArray(ctx.dtype(), n, out.asReadOnly()); + } + + private static void applyPatchesF64(DecodeContext ctx, PatchesMetadata pm, + MemorySegment out, MemorySegment rightSeg, long rightCap, int rightBitWidth) { + long numPatches = pm.len(); + long offset = pm.offset(); + PType idxPtype = PType.fromOrdinal(pm.indices_ptype().value()); + + MemorySegment idxSeg = ctx.decodeChildSegment(2, new DType.Primitive(idxPtype, false), numPatches); + MemorySegment valSeg = ctx.decodeChildSegment(3, U16_DTYPE, numPatches); + int idxBytes = idxPtype.byteSize(); + long valCap = SegmentBroadcast.capacity(valSeg, 2); + + for (long j = 0; j < numPatches; j++) { + long absIdx = readUnsigned(idxSeg, SegmentBroadcast.elementOffset(idxSeg, j, idxBytes), idxPtype) - offset; + short actualLeftU16 = valSeg.getAtIndex(PTypeIO.LE_SHORT, j % valCap); + long leftBits = (long) (actualLeftU16 & 0xFFFF) << rightBitWidth; + long rightBits = rightSeg.getAtIndex(PTypeIO.LE_LONG, absIdx % rightCap); + out.setAtIndex(PTypeIO.LE_LONG, absIdx, leftBits | rightBits); + } + } + + private 
static void applyPatchesF32(DecodeContext ctx, PatchesMetadata pm, + MemorySegment out, MemorySegment rightSeg, long rightCap, int rightBitWidth) { + long numPatches = pm.len(); + long offset = pm.offset(); + PType idxPtype = PType.fromOrdinal(pm.indices_ptype().value()); + + MemorySegment idxSeg = ctx.decodeChildSegment(2, new DType.Primitive(idxPtype, false), numPatches); + MemorySegment valSeg = ctx.decodeChildSegment(3, U16_DTYPE, numPatches); + int idxBytes = idxPtype.byteSize(); + long valCap = SegmentBroadcast.capacity(valSeg, 2); + + for (long j = 0; j < numPatches; j++) { + long absIdx = readUnsigned(idxSeg, SegmentBroadcast.elementOffset(idxSeg, j, idxBytes), idxPtype) - offset; + short actualLeftU16 = valSeg.getAtIndex(PTypeIO.LE_SHORT, j % valCap); + int leftBits = (actualLeftU16 & 0xFFFF) << rightBitWidth; + int rightBits = rightSeg.getAtIndex(PTypeIO.LE_INT, absIdx % rightCap); + out.setAtIndex(PTypeIO.LE_INT, (int) absIdx, leftBits | rightBits); + } + } + + private static long readUnsigned(MemorySegment seg, long off, PType ptype) { + return switch (ptype) { + case U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, off)); + case U16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, off)); + case U32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, off)); + case U64 -> seg.get(PTypeIO.LE_LONG, off); + default -> throw new VortexException(EncodingId.VORTEX_ALPRD, + "non-unsigned patch index ptype " + ptype); + }; + } + + private static ALPRDMetadata parseMeta(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null || !rawMeta.hasRemaining()) { + return new ALPRDMetadata(0, 0, java.util.List.of(), + io.github.dfa1.vortex.proto.PType.fromValue(PType.U16.ordinal()), null); + } + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + return ALPRDMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_ALPRD, "invalid metadata", e); + } 
+ } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ArrayNode.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ArrayNode.java new file mode 100644 index 00000000..2f41113e --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ArrayNode.java @@ -0,0 +1,42 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.ArrayStats; +import io.github.dfa1.vortex.encoding.EncodingId; + +import java.nio.ByteBuffer; + +/// Encoded array node as stored in a Flat layout segment. +/// In-file representation before decoding; mirrors the Go ArrayNode struct. +/// +/// Sealed: a node is either [KnownArrayNode] (id resolves to an [EncodingId]) or +/// [UnknownArrayNode] (id is an arbitrary string only meaningful for +/// {@link ReadRegistry#isAllowUnknown()} passthrough decode). +public sealed interface ArrayNode permits KnownArrayNode, UnknownArrayNode { + + /// Short factory for the common case: a node whose encoding id is well-known. + /// Mostly used by tests and helper code that converts an {@code EncodeNode} tree back into + /// an {@code ArrayNode} tree. 
+ /// + /// @param encodingId the well-known encoding identifier + /// @param metadata encoding-specific metadata bytes, or {@code null} + /// @param children child nodes + /// @param bufferIndices segment buffer indices for this node + /// @param stats optional zone-map statistics + /// @return a {@link KnownArrayNode} with the given fields + static ArrayNode of(EncodingId encodingId, ByteBuffer metadata, ArrayNode[] children, + int[] bufferIndices, ArrayStats stats) { + return new KnownArrayNode(encodingId, metadata, children, bufferIndices, stats); + } + + /// @return encoding-specific metadata bytes, or {@code null} + ByteBuffer metadata(); + + /// @return child nodes + ArrayNode[] children(); + + /// @return segment buffer indices for this node + int[] bufferIndices(); + + /// @return optional zone-map statistics + ArrayStats stats(); +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/BitpackedEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/BitpackedEncodingDecoder.java new file mode 100644 index 00000000..dcb1db55 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/BitpackedEncodingDecoder.java @@ -0,0 +1,524 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.BitPackedMetadata; +import io.github.dfa1.vortex.proto.PatchesMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import 
java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code fastlanes.bitpacked}. +public final class BitpackedEncodingDecoder implements EncodingDecoder { + private static final int[] FL_ORDER = {0, 4, 2, 6, 1, 5, 3, 7}; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public BitpackedEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_BITPACKED; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return switch (p.ptype()) { + case I8, I16, I32, I64, U8, U16, U32, U64 -> true; + default -> false; + }; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + // proto3 elides default-valued fields, so BitPackedMetadata(0, 0, null) serialises + // to a 0-byte payload and the writer skips the empty vector. Treat absent metadata + // as all-defaults rather than rejecting — happens when bit_width=0 (constant + // residuals nested under FoR / RLE). 
+ BitPackedMetadata meta; + if (rawMeta == null || !rawMeta.hasRemaining()) { + meta = new BitPackedMetadata(0, 0, null); + } else { + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = BitPackedMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.FASTLANES_BITPACKED, "invalid metadata", e); + } + } + + int bitWidth = meta.bit_width(); + int offset = meta.offset(); + PType ptype = ((DType.Primitive) ctx.dtype()).ptype(); + int typeBits = ptype.byteSize() * 8; + long rowCount = ctx.rowCount(); + + MemorySegment packed = ctx.buffer(0); + MemorySegment output = ctx.arena().allocate(rowCount * ptype.byteSize()); + fastlanesUnpackToSeg(packed, bitWidth, offset, typeBits, rowCount, output); + + if (meta.patches() != null) { + applyPatches(ctx, meta.patches(), output, ptype.byteSize()); + } + + return switch (ptype) { + case I64, U64 -> new LongArray(ctx.dtype(), rowCount, output); + case I32, U32 -> new IntArray(ctx.dtype(), rowCount, output); + case I16, U16 -> new ShortArray(ctx.dtype(), rowCount, output); + case I8, U8 -> new ByteArray(ctx.dtype(), rowCount, output); + default -> throw new VortexException(EncodingId.FASTLANES_BITPACKED, "unsupported ptype " + ptype); + }; + } + + private static void fastlanesUnpackToSeg( + MemorySegment buf, int bitWidth, int offset, int typeBits, long rowCount, + MemorySegment output) { + if (bitWidth == 0) { + return; + } + switch (typeBits) { + case 8 -> unpackLoop8(buf, bitWidth, offset, rowCount, output); + case 16 -> unpackLoop16(buf, bitWidth, offset, rowCount, output); + case 32 -> unpackLoop32(buf, bitWidth, offset, rowCount, output); + case 64 -> unpackLoop64(buf, bitWidth, offset, rowCount, output); + default -> + throw new VortexException(EncodingId.FASTLANES_BITPACKED, "unsupported typeBits: " + typeBits); + } + } + + private static void unpackLoop8(MemorySegment buf, int bitWidth, int offset, long rowCount, MemorySegment 
out) {
+        // 8-bit variant: each 1024-element block is handled as 8 rows of 128 one-byte lanes.
+        final int lanes = 128;
+        long totalElems = rowCount + offset;
+        int blockCount = (int) ((totalElems + 1023) / 1024);
+        long bitMask = (1L << bitWidth) - 1L;
+
+        // Per-row constants, computed once: the input word and bit shift where the row's
+        // value starts, how many of its bits spill into the next word (rem), the masks for
+        // both parts, and the row's output offset inside a block.
+        int[] shifts = new int[8];
+        int[] remainingBits = new int[8];
+        int[] currentBits = new int[8];
+        long[] loMasks = new long[8];
+        long[] hiMasks = new long[8];
+        long[] currWordByteBase = new long[8];
+        long[] nextWordByteBase = new long[8];
+        long[] outRowByteOff = new long[8];
+        for (int row = 0; row < 8; row++) {
+            int currWord = (row * bitWidth) / 8;
+            int nextWord = ((row + 1) * bitWidth) / 8;
+            shifts[row] = (row * bitWidth) % 8;
+            int rem = (nextWord > currWord) ? ((row + 1) * bitWidth) % 8 : 0;
+            remainingBits[row] = rem;
+            int curr = bitWidth - rem;
+            currentBits[row] = curr;
+            loMasks[row] = rem > 0 ? (1L << curr) - 1L : 0L;
+            hiMasks[row] = rem > 0 ? (1L << rem) - 1L : 0L;
+            currWordByteBase[row] = (long) lanes * currWord;
+            nextWordByteBase[row] = rem > 0 ? (long) lanes * nextWord : 0L;
+            int o = row / 8;
+            int s = row % 8;
+            outRowByteOff[row] = FL_ORDER[o] * 16 + s * 128;
+        }
+
+        long blockByteOff = 0L;
+        // Each block consumes 128 * bitWidth input bytes.
+        long blockByteStride = 128L * bitWidth;
+        for (int block = 0; block < blockCount; block++, blockByteOff += blockByteStride) {
+            // Output index of the block's first element once the leading `offset` elements
+            // are dropped; negative for the first block when offset > 0.
+            int blockLogicStart = block * 1024 - offset;
+            boolean fullBlock = blockLogicStart >= 0 && (long) blockLogicStart + 1023L < rowCount;
+
+            if (fullBlock) {
+                // Fast path: every element of the block lands inside [0, rowCount).
+                for (int row = 0; row < 8; row++) {
+                    int shift = shifts[row];
+                    int rem = remainingBits[row];
+                    int curr = currentBits[row];
+                    long outBase = blockLogicStart + outRowByteOff[row];
+                    long wordBase = blockByteOff + currWordByteBase[row];
+                    if (rem > 0) {
+                        // Value straddles two input words: low part from the current word,
+                        // high part from the next one.
+                        long hiBase = blockByteOff + nextWordByteBase[row];
+                        long loMask = loMasks[row];
+                        long hiMask = hiMasks[row];
+                        for (int lane = 0; lane < lanes; lane++) {
+                            long lo = (Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, wordBase + lane)) >>> shift) & loMask;
+                            long hi = Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, hiBase + lane)) & hiMask;
+                            out.set(ValueLayout.JAVA_BYTE, outBase + lane, (byte) (lo | (hi << curr)));
+                        }
+                    } else {
+                        for (int lane = 0; lane < lanes; lane++) {
+                            out.set(ValueLayout.JAVA_BYTE, outBase + lane,
+                                (byte) ((Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, wordBase + lane)) >>> shift) & bitMask));
+                        }
+                    }
+                }
+            } else {
+                // Partial block: same extraction, but elements whose output index falls
+                // outside [0, rowCount) are skipped.
+                for (int row = 0; row < 8; row++) {
+                    int shift = shifts[row];
+                    int rem = remainingBits[row];
+                    int curr = currentBits[row];
+                    int o = row / 8;
+                    int s = row % 8;
+                    int baseIdx = blockLogicStart + FL_ORDER[o] * 16 + s * 128;
+                    long wordBase = blockByteOff + currWordByteBase[row];
+                    long hiBase = rem > 0 ? blockByteOff + nextWordByteBase[row] : 0L;
+                    long loMask = loMasks[row];
+                    long hiMask = hiMasks[row];
+                    for (int lane = 0; lane < lanes; lane++) {
+                        int logicalIdx = baseIdx + lane;
+                        if (logicalIdx < 0 || logicalIdx >= rowCount) {
+                            continue;
+                        }
+                        long src = Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, wordBase + lane));
+                        long value;
+                        if (rem > 0) {
+                            long lo = (src >>> shift) & loMask;
+                            long hi = Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, hiBase + lane)) & hiMask;
+                            value = lo | (hi << curr);
+                        } else {
+                            value = (src >>> shift) & bitMask;
+                        }
+                        out.set(ValueLayout.JAVA_BYTE, logicalIdx, (byte) value);
+                    }
+                }
+            }
+        }
+    }
+
+    /// 16-bit variant of the unpack loop: each 1024-element block is handled as 16 rows of
+    /// 64 two-byte lanes, read and written through {@code PTypeIO.LE_SHORT}.
+    ///
+    /// @param buf      packed input
+    /// @param bitWidth bits per packed value
+    /// @param offset   number of leading elements of the first block to skip
+    /// @param rowCount number of elements to produce
+    /// @param out      destination segment, addressed in bytes ({@code index * 2})
+    private static void unpackLoop16(MemorySegment buf, int bitWidth, int offset, long rowCount, MemorySegment out) {
+        final int lanes = 64;
+        long totalElems = rowCount + offset;
+        int blockCount = (int) ((totalElems + 1023) / 1024);
+        long bitMask = (1L << bitWidth) - 1L;
+
+        // Per-row constants; byte offsets are pre-multiplied by the 2-byte element size.
+        int[] shifts = new int[16];
+        int[] remainingBits = new int[16];
+        int[] currentBits = new int[16];
+        long[] loMasks = new long[16];
+        long[] hiMasks = new long[16];
+        long[] currWordByteBase = new long[16];
+        long[] nextWordByteBase = new long[16];
+        long[] outRowByteOff = new long[16];
+        for (int row = 0; row < 16; row++) {
+            int currWord = (row * bitWidth) / 16;
+            int nextWord = ((row + 1) * bitWidth) / 16;
+            shifts[row] = (row * bitWidth) % 16;
+            int rem = (nextWord > currWord) ? ((row + 1) * bitWidth) % 16 : 0;
+            remainingBits[row] = rem;
+            int curr = bitWidth - rem;
+            currentBits[row] = curr;
+            loMasks[row] = rem > 0 ? (1L << curr) - 1L : 0L;
+            hiMasks[row] = rem > 0 ? (1L << rem) - 1L : 0L;
+            currWordByteBase[row] = (long) lanes * currWord * 2L;
+            nextWordByteBase[row] = rem > 0 ? (long) lanes * nextWord * 2L : 0L;
+            int o = row / 8;
+            int s = row % 8;
+            outRowByteOff[row] = (long) (FL_ORDER[o] * 16 + s * 128) * 2L;
+        }
+
+        long blockByteOff = 0L;
+        long blockByteStride = 128L * bitWidth;
+        for (int block = 0; block < blockCount; block++, blockByteOff += blockByteStride) {
+            int blockLogicStart = block * 1024 - offset;
+            boolean fullBlock = blockLogicStart >= 0 && (long) blockLogicStart + 1023L < rowCount;
+            long blockOutByteBase = (long) blockLogicStart * 2L;
+
+            if (fullBlock) {
+                // Fast path: no per-element bounds check.
+                for (int row = 0; row < 16; row++) {
+                    int shift = shifts[row];
+                    int rem = remainingBits[row];
+                    int curr = currentBits[row];
+                    long outBase = blockOutByteBase + outRowByteOff[row];
+                    long wordBase = blockByteOff + currWordByteBase[row];
+                    if (rem > 0) {
+                        long hiBase = blockByteOff + nextWordByteBase[row];
+                        long loMask = loMasks[row];
+                        long hiMask = hiMasks[row];
+                        long laneOff = 0L;
+                        for (int lane = 0; lane < lanes; lane++, laneOff += 2L) {
+                            long lo = (Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, wordBase + laneOff)) >>> shift) & loMask;
+                            long hi = Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, hiBase + laneOff)) & hiMask;
+                            out.set(PTypeIO.LE_SHORT, outBase + laneOff, (short) (lo | (hi << curr)));
+                        }
+                    } else {
+                        long laneOff = 0L;
+                        for (int lane = 0; lane < lanes; lane++, laneOff += 2L) {
+                            out.set(PTypeIO.LE_SHORT, outBase + laneOff,
+                                (short) ((Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, wordBase + laneOff)) >>> shift) & bitMask));
+                        }
+                    }
+                }
+            } else {
+                // Partial block: skip elements outside [0, rowCount).
+                for (int row = 0; row < 16; row++) {
+                    int shift = shifts[row];
+                    int rem = remainingBits[row];
+                    int curr = currentBits[row];
+                    int o = row / 8;
+                    int s = row % 8;
+                    int baseIdx = blockLogicStart + FL_ORDER[o] * 16 + s * 128;
+                    long wordBase = blockByteOff + currWordByteBase[row];
+                    long hiBase = rem > 0 ? blockByteOff + nextWordByteBase[row] : 0L;
+                    long loMask = loMasks[row];
+                    long hiMask = hiMasks[row];
+                    for (int lane = 0; lane < lanes; lane++) {
+                        int logicalIdx = baseIdx + lane;
+                        if (logicalIdx < 0 || logicalIdx >= rowCount) {
+                            continue;
+                        }
+                        long src = Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, wordBase + (long) lane * 2));
+                        long value;
+                        if (rem > 0) {
+                            long lo = (src >>> shift) & loMask;
+                            long hi = Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, hiBase + (long) lane * 2)) & hiMask;
+                            value = lo | (hi << curr);
+                        } else {
+                            value = (src >>> shift) & bitMask;
+                        }
+                        out.set(PTypeIO.LE_SHORT, (long) logicalIdx * 2, (short) value);
+                    }
+                }
+            }
+        }
+    }
+
+    /// 32-bit variant of the unpack loop: each 1024-element block is handled as 32 rows of
+    /// 32 four-byte lanes, read and written through {@code PTypeIO.LE_INT}.
+    ///
+    /// @param buf      packed input
+    /// @param bitWidth bits per packed value
+    /// @param offset   number of leading elements of the first block to skip
+    /// @param rowCount number of elements to produce
+    /// @param out      destination segment, addressed in bytes ({@code index * 4})
+    private static void unpackLoop32(MemorySegment buf, int bitWidth, int offset, long rowCount, MemorySegment out) {
+        final int lanes = 32;
+        long totalElems = rowCount + offset;
+        int blockCount = (int) ((totalElems + 1023) / 1024);
+        long bitMask = (1L << bitWidth) - 1L;
+
+        // Per-row constants; byte offsets are pre-multiplied by the 4-byte element size.
+        int[] shifts = new int[32];
+        int[] remainingBits = new int[32];
+        int[] currentBits = new int[32];
+        long[] loMasks = new long[32];
+        long[] hiMasks = new long[32];
+        long[] currWordByteBase = new long[32];
+        long[] nextWordByteBase = new long[32];
+        long[] outRowByteOff = new long[32];
+        for (int row = 0; row < 32; row++) {
+            int currWord = (row * bitWidth) / 32;
+            int nextWord = ((row + 1) * bitWidth) / 32;
+            shifts[row] = (row * bitWidth) % 32;
+            int rem = (nextWord > currWord) ? ((row + 1) * bitWidth) % 32 : 0;
+            remainingBits[row] = rem;
+            int curr = bitWidth - rem;
+            currentBits[row] = curr;
+            loMasks[row] = rem > 0 ? (1L << curr) - 1L : 0L;
+            hiMasks[row] = rem > 0 ? (1L << rem) - 1L : 0L;
+            currWordByteBase[row] = (long) lanes * currWord * 4L;
+            nextWordByteBase[row] = rem > 0 ? (long) lanes * nextWord * 4L : 0L;
+            int o = row / 8;
+            int s = row % 8;
+            outRowByteOff[row] = (long) (FL_ORDER[o] * 16 + s * 128) * 4L;
+        }
+
+        long blockByteOff = 0L;
+        long blockByteStride = 128L * bitWidth;
+        for (int block = 0; block < blockCount; block++, blockByteOff += blockByteStride) {
+            int blockLogicStart = block * 1024 - offset;
+            boolean fullBlock = blockLogicStart >= 0 && (long) blockLogicStart + 1023L < rowCount;
+            long blockOutByteBase = (long) blockLogicStart * 4L;
+
+            if (fullBlock) {
+                // Fast path: no per-element bounds check.
+                for (int row = 0; row < 32; row++) {
+                    int shift = shifts[row];
+                    int rem = remainingBits[row];
+                    int curr = currentBits[row];
+                    long outBase = blockOutByteBase + outRowByteOff[row];
+                    long wordBase = blockByteOff + currWordByteBase[row];
+                    if (rem > 0) {
+                        long hiBase = blockByteOff + nextWordByteBase[row];
+                        long loMask = loMasks[row];
+                        long hiMask = hiMasks[row];
+                        long laneOff = 0L;
+                        for (int lane = 0; lane < lanes; lane++, laneOff += 4L) {
+                            long lo = (Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, wordBase + laneOff)) >>> shift) & loMask;
+                            long hi = Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, hiBase + laneOff)) & hiMask;
+                            out.set(PTypeIO.LE_INT, outBase + laneOff, (int) (lo | (hi << curr)));
+                        }
+                    } else {
+                        long laneOff = 0L;
+                        for (int lane = 0; lane < lanes; lane++, laneOff += 4L) {
+                            out.set(PTypeIO.LE_INT, outBase + laneOff,
+                                (int) ((Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, wordBase + laneOff)) >>> shift) & bitMask));
+                        }
+                    }
+                }
+            } else {
+                // Partial block: skip elements outside [0, rowCount).
+                for (int row = 0; row < 32; row++) {
+                    int shift = shifts[row];
+                    int rem = remainingBits[row];
+                    int curr = currentBits[row];
+                    int o = row / 8;
+                    int s = row % 8;
+                    int baseIdx = blockLogicStart + FL_ORDER[o] * 16 + s * 128;
+                    long wordBase = blockByteOff + currWordByteBase[row];
+                    long hiBase = rem > 0 ? blockByteOff + nextWordByteBase[row] : 0L;
+                    long loMask = loMasks[row];
+                    long hiMask = hiMasks[row];
+                    for (int lane = 0; lane < lanes; lane++) {
+                        int logicalIdx = baseIdx + lane;
+                        if (logicalIdx < 0 || logicalIdx >= rowCount) {
+                            continue;
+                        }
+                        long src = Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, wordBase + (long) lane * 4));
+                        long value;
+                        if (rem > 0) {
+                            long lo = (src >>> shift) & loMask;
+                            long hi = Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, hiBase + (long) lane * 4)) & hiMask;
+                            value = lo | (hi << curr);
+                        } else {
+                            value = (src >>> shift) & bitMask;
+                        }
+                        out.set(PTypeIO.LE_INT, (long) logicalIdx * 4, (int) value);
+                    }
+                }
+            }
+        }
+    }
+
+    /// 64-bit variant of the unpack loop: each 1024-element block is handled as 64 rows of
+    /// 16 eight-byte lanes, read and written through {@code PTypeIO.LE_LONG}.
+    ///
+    /// @param buf      packed input
+    /// @param bitWidth bits per packed value
+    /// @param offset   number of leading elements of the first block to skip
+    /// @param rowCount number of elements to produce
+    /// @param out      destination segment, addressed in bytes ({@code index * 8})
+    private static void unpackLoop64(MemorySegment buf, int bitWidth, int offset, long rowCount, MemorySegment out) {
+        final int lanes = 16;
+        long totalElems = rowCount + offset;
+        int blockCount = (int) ((totalElems + 1023) / 1024);
+        // 1L << 64 wraps to 1L in Java, so the full-width mask is spelled out explicitly.
+        long bitMask = bitWidth == 64 ? -1L : (1L << bitWidth) - 1L;
+
+        // Per-row constants; byte offsets are pre-multiplied by the 8-byte element size.
+        int[] shifts = new int[64];
+        int[] remainingBits = new int[64];
+        int[] currentBits = new int[64];
+        long[] loMasks = new long[64];
+        long[] hiMasks = new long[64];
+        long[] currWordByteBase = new long[64];
+        long[] nextWordByteBase = new long[64];
+        long[] outRowByteOff = new long[64];
+        for (int row = 0; row < 64; row++) {
+            int currWord = (row * bitWidth) / 64;
+            int nextWord = ((row + 1) * bitWidth) / 64;
+            shifts[row] = (row * bitWidth) % 64;
+            int rem = (nextWord > currWord) ? ((row + 1) * bitWidth) % 64 : 0;
+            remainingBits[row] = rem;
+            int curr = bitWidth - rem;
+            currentBits[row] = curr;
+            loMasks[row] = rem > 0 ? (1L << curr) - 1L : 0L;
+            hiMasks[row] = rem > 0 ? (1L << rem) - 1L : 0L;
+            currWordByteBase[row] = (long) lanes * currWord * 8L;
+            nextWordByteBase[row] = rem > 0 ? (long) lanes * nextWord * 8L : 0L;
+            int o = row / 8;
+            int s = row % 8;
+            outRowByteOff[row] = (long) (FL_ORDER[o] * 16 + s * 128) * 8L;
+        }
+
+        long blockByteOff = 0L;
+        long blockByteStride = 128L * bitWidth;
+        for (int block = 0; block < blockCount; block++, blockByteOff += blockByteStride) {
+            int blockLogicStart = block * 1024 - offset;
+            boolean fullBlock = blockLogicStart >= 0 && (long) blockLogicStart + 1023L < rowCount;
+            long blockOutByteBase = (long) blockLogicStart * 8L;
+
+            if (fullBlock) {
+                // Fast path: no per-element bounds check.
+                for (int row = 0; row < 64; row++) {
+                    int shift = shifts[row];
+                    int rem = remainingBits[row];
+                    int curr = currentBits[row];
+                    long outBase = blockOutByteBase + outRowByteOff[row];
+                    long wordBase = blockByteOff + currWordByteBase[row];
+                    if (rem > 0) {
+                        long hiBase = blockByteOff + nextWordByteBase[row];
+                        long loMask = loMasks[row];
+                        long hiMask = hiMasks[row];
+                        long laneOff = 0L;
+                        for (int lane = 0; lane < lanes; lane++, laneOff += 8L) {
+                            long lo = (buf.get(PTypeIO.LE_LONG, wordBase + laneOff) >>> shift) & loMask;
+                            long hi = buf.get(PTypeIO.LE_LONG, hiBase + laneOff) & hiMask;
+                            out.set(PTypeIO.LE_LONG, outBase + laneOff, lo | (hi << curr));
+                        }
+                    } else {
+                        long laneOff = 0L;
+                        for (int lane = 0; lane < lanes; lane++, laneOff += 8L) {
+                            out.set(PTypeIO.LE_LONG, outBase + laneOff,
+                                (buf.get(PTypeIO.LE_LONG, wordBase + laneOff) >>> shift) & bitMask);
+                        }
+                    }
+                }
+            } else {
+                // Partial block: skip elements outside [0, rowCount).
+                for (int row = 0; row < 64; row++) {
+                    int shift = shifts[row];
+                    int rem = remainingBits[row];
+                    int curr = currentBits[row];
+                    int o = row / 8;
+                    int s = row % 8;
+                    int baseIdx = blockLogicStart + FL_ORDER[o] * 16 + s * 128;
+                    long wordBase = blockByteOff + currWordByteBase[row];
+                    long hiBase = rem > 0 ? blockByteOff + nextWordByteBase[row] : 0L;
+                    long loMask = loMasks[row];
+                    long hiMask = hiMasks[row];
+                    for (int lane = 0; lane < lanes; lane++) {
+                        int logicalIdx = baseIdx + lane;
+                        if (logicalIdx < 0 || logicalIdx >= rowCount) {
+                            continue;
+                        }
+                        long src = buf.get(PTypeIO.LE_LONG, wordBase + (long) lane * 8);
+                        long value;
+                        if (rem > 0) {
+                            long lo = (src >>> shift) & loMask;
+                            long hi = buf.get(PTypeIO.LE_LONG, hiBase + (long) lane * 8) & hiMask;
+                            value = lo | (hi << curr);
+                        } else {
+                            value = (src >>> shift) & bitMask;
+                        }
+                        out.set(PTypeIO.LE_LONG, (long) logicalIdx * 8, value);
+                    }
+                }
+            }
+        }
+    }
+
+    /// Overwrites patched elements of {@code out} with values from the patch children.
+    ///
+    /// Child 0 is decoded as the non-nullable index array (ptype taken from the metadata)
+    /// and child 1 as the value array with the parent's dtype. Each stored index has the
+    /// patch offset subtracted before it is used as an output position.
+    ///
+    /// @param ctx       decode context used to decode the two children
+    /// @param pm        patch metadata (count, offset, index ptype)
+    /// @param out       unpacked output segment to patch in place
+    /// @param elemBytes size of one output element in bytes
+    /// @throws VortexException if an adjusted index falls outside {@code [0, rowCount)} or
+    ///         the index ptype is not unsigned
+    private static void applyPatches(DecodeContext ctx, PatchesMetadata pm,
+                                     MemorySegment out, int elemBytes) {
+        long numPatches = pm.len();
+        if (numPatches == 0) {
+            return;
+        }
+        long offset = pm.offset();
+        PType idxPtype = ptypeFromProto(pm.indices_ptype());
+
+        MemorySegment idxSeg = ctx.decodeChildSegment(0, new DType.Primitive(idxPtype, false), numPatches);
+        MemorySegment valSeg = ctx.decodeChildSegment(1, ctx.dtype(), numPatches);
+
+        int idxBytes = idxPtype.byteSize();
+        long n = ctx.rowCount();
+        for (long i = 0; i < numPatches; i++) {
+            // Element offsets go through SegmentBroadcast rather than plain i * size.
+            long absIdx = readUnsignedIdx(idxSeg, SegmentBroadcast.elementOffset(idxSeg, i, idxBytes), idxPtype) - offset;
+            if (absIdx < 0 || absIdx >= n) {
+                throw new VortexException(EncodingId.FASTLANES_BITPACKED,
+                    "patch index " + absIdx + " out of range [0," + n + ")");
+            }
+            MemorySegment.copy(valSeg, SegmentBroadcast.elementOffset(valSeg, i, elemBytes),
+                out, absIdx * elemBytes, elemBytes);
+        }
+    }
+
+    /// Reads one patch index at byte offset {@code off}, zero-extending U8/U16/U32 to
+    /// {@code long}. U64 is returned as the raw {@code long}, so values above
+    /// {@code Long.MAX_VALUE} come back negative.
+    ///
+    /// @throws VortexException if {@code ptype} is not an unsigned integer type
+    private static long readUnsignedIdx(MemorySegment seg, long off, PType ptype) {
+        return switch (ptype) {
+            case U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, off));
+            case U16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, off));
+            case U32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, off));
+            case U64 -> seg.get(PTypeIO.LE_LONG, off);
+            default -> throw new VortexException(EncodingId.FASTLANES_BITPACKED,
+                "non-unsigned patch index ptype " + ptype);
+        };
+    }
+
+    /// Converts the proto ptype enum to the core {@link PType} via its numeric value.
+    private static PType ptypeFromProto(io.github.dfa1.vortex.proto.PType proto) {
+        return PType.fromOrdinal(proto.value());
+    }
+}
diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/BoolEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/BoolEncodingDecoder.java new file mode 100644 index 00000000..f545de58 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/BoolEncodingDecoder.java @@ -0,0 +1,35 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Read-only decoder for {@code vortex.bool} (bit-packed boolean arrays, LSB first). +/// +///

ADR 0001 Phase 2: first encoding lifted into a standalone {@link EncodingDecoder}
+/// implementation in the {@code reader} module. The corresponding write-side encode
+/// path continues to live on {@link io.github.dfa1.vortex.encoding.BoolEncoding} in
+/// {@code core}; that file is peeled into a {@code BoolEncodingEncoder} in
+/// {@code writer} during Phase 3.
+public final class BoolEncodingDecoder implements EncodingDecoder {
+
+    /// Public no-arg constructor required by {@link java.util.ServiceLoader}.
+    public BoolEncodingDecoder() {
+    }
+
+    /// Identifies this decoder as the handler for {@code vortex.bool}.
+    @Override
+    public EncodingId encodingId() {
+        return EncodingId.VORTEX_BOOL;
+    }
+
+    /// Accepts boolean dtypes only.
+    @Override
+    public boolean accepts(DType dtype) {
+        return dtype instanceof DType.Bool;
+    }
+
+    /// Wraps buffer 0 in a {@link BoolArray} without copying or transforming it.
+    ///
+    /// NOTE(review): the node metadata is not consulted here — confirm that no bit offset
+    /// has to be applied to the buffer.
+    @Override
+    public Array decode(DecodeContext ctx) {
+        final DType dtype = ctx.dtype();
+        final long rows = ctx.rowCount();
+        return new BoolArray(dtype, rows, ctx.buffer(0));
+    }
+}
diff --git a/core/src/main/java/io/github/dfa1/vortex/encoding/ByteBoolEncoding.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoder.java similarity index 52% rename from core/src/main/java/io/github/dfa1/vortex/encoding/ByteBoolEncoding.java rename to reader/src/main/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoder.java index 6fa711a4..79827a95 100644 --- a/core/src/main/java/io/github/dfa1/vortex/encoding/ByteBoolEncoding.java +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoder.java @@ -1,24 +1,19 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.reader.decode; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.encoding.EncodingId; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; -/// Encoding for {@code vortex.bytebool} — one byte per boolean element (0 = false, non-zero = true). -/// -///

Buffer 0: raw byte values, length = rowCount. -/// Metadata: empty. -/// Child slot 0: validity array (optional, only when nullable with explicit validity bitmap). -/// -///

Decode: pack the byte buffer into a bit-packed {@link BoolArray} (LSB-first), -/// matching the layout of {@code vortex.bool}. -public final class ByteBoolEncoding implements Encoding { +/// Read-only decoder for {@code vortex.bytebool} — packs the input byte buffer into the +/// bit-packed {@link BoolArray} layout used by {@code vortex.bool}. +public final class ByteBoolEncodingDecoder implements EncodingDecoder { - /// Creates a new {@code ByteBoolEncoding} instance. - public ByteBoolEncoding() { + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public ByteBoolEncodingDecoder() { } @Override @@ -31,16 +26,6 @@ public boolean accepts(DType dtype) { return dtype instanceof DType.Bool; } - @Override - public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { - boolean[] bools = (boolean[]) data; - MemorySegment seg = ctx.arena().allocate(bools.length); - for (int i = 0; i < bools.length; i++) { - seg.set(ValueLayout.JAVA_BYTE, i, bools[i] ? (byte) 1 : (byte) 0); - } - return EncodeResult.simple(encodingId(), seg); - } - @Override public Array decode(DecodeContext ctx) { long n = ctx.rowCount(); diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ChunkedEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ChunkedEncodingDecoder.java new file mode 100644 index 00000000..db5218e1 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ChunkedEncodingDecoder.java @@ -0,0 +1,126 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import 
io.github.dfa1.vortex.core.array.IntArray;
+import io.github.dfa1.vortex.core.array.LongArray;
+import io.github.dfa1.vortex.core.array.ShortArray;
+import io.github.dfa1.vortex.core.array.StructArray;
+import io.github.dfa1.vortex.encoding.EncodingId;
+import io.github.dfa1.vortex.encoding.SegmentBroadcast;
+
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.SegmentAllocator;
+import java.lang.foreign.ValueLayout;
+import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.List;
+
+/// Read-only decoder for {@code vortex.chunked}.
+///
+/// Child 0 holds {@code nchunks + 1} U64 chunk offsets and children {@code 1..nchunks}
+/// are the chunks. Every chunk is decoded with the parent dtype and the results are
+/// concatenated into one array (primitive buffers are copied; structs are concatenated
+/// field by field).
+public final class ChunkedEncodingDecoder implements EncodingDecoder {
+
+    private static final ValueLayout.OfLong LE_LONG =
+        ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+
+    /// Public no-arg constructor required by {@link java.util.ServiceLoader}.
+    public ChunkedEncodingDecoder() {
+    }
+
+    @Override
+    public EncodingId encodingId() {
+        return EncodingId.VORTEX_CHUNKED;
+    }
+
+    /// Accepts primitive and struct dtypes, the two shapes {@code concat} can handle.
+    @Override
+    public boolean accepts(DType dtype) {
+        return dtype instanceof DType.Primitive || dtype instanceof DType.Struct;
+    }
+
+    /// Decodes every chunk and concatenates them.
+    ///
+    /// @param ctx decode context for the chunked node
+    /// @return the concatenated array with {@code ctx.rowCount()} rows
+    /// @throws VortexException if the node has no children, if consecutive offsets
+    ///         decrease, or if the dtype cannot be concatenated
+    @Override
+    public Array decode(DecodeContext ctx) {
+        int nchildren = ctx.node().children().length;
+        if (nchildren < 1) {
+            throw new VortexException(EncodingId.VORTEX_CHUNKED,
+                "needs at least one child (chunk offsets)");
+        }
+        int nchunks = nchildren - 1;
+        long[] offsets = readOffsets(ctx, nchunks);
+
+        DType dtype = ctx.dtype();
+        List<Array> chunks = new ArrayList<>(nchunks);
+        for (int i = 0; i < nchunks; i++) {
+            long chunkLen = offsets[i + 1] - offsets[i];
+            // Offsets come from the file; a decreasing pair would otherwise hand a negative
+            // row count to the child decoder.
+            if (chunkLen < 0) {
+                throw new VortexException(EncodingId.VORTEX_CHUNKED,
+                    "chunk " + i + " has negative length " + chunkLen);
+            }
+            chunks.add(ctx.decodeChild(i + 1, dtype, chunkLen));
+        }
+
+        return concat(chunks, dtype, ctx.rowCount(), ctx.arena());
+    }
+
+    /// Reads the {@code nchunks + 1} chunk offsets from child 0 as little-endian longs.
+    /// Indices wrap modulo the segment's element capacity as reported by
+    /// {@code SegmentBroadcast.capacity}.
+    private static long[] readOffsets(DecodeContext ctx, int nchunks) {
+        DType u64 = new DType.Primitive(PType.U64, false);
+        MemorySegment offsetsBuf = ctx.decodeChildSegment(0, u64, nchunks + 1L);
+        long cap = SegmentBroadcast.capacity(offsetsBuf, 8);
+        long[] offsets = new long[nchunks + 1];
+        for (int i = 0; i <= nchunks; i++) {
+            offsets[i] = offsetsBuf.get(LE_LONG, (i % cap) * 8);
+        }
+        return offsets;
+    }
+
+    /// Concatenates already-decoded chunks of the given dtype.
+    ///
+    /// @throws VortexException for dtypes other than primitive or struct
+    private static Array concat(List<Array> chunks, DType dtype, long totalRows, SegmentAllocator arena) {
+        if (dtype instanceof DType.Primitive pt) {
+            return concatPrimitive(chunks, pt, dtype, totalRows, arena);
+        }
+        if (dtype instanceof DType.Struct struct) {
+            return concatStruct(chunks, struct, totalRows, arena);
+        }
+        throw new VortexException(EncodingId.VORTEX_CHUNKED,
+            "concat not supported for dtype: " + dtype);
+    }
+
+    /// Copies each chunk's backing segment back-to-back into one new segment of
+    /// {@code totalRows * byteSize} bytes and wraps it read-only in the matching array type.
+    private static Array concatPrimitive(
+        List<Array> chunks, DType.Primitive pt, DType dtype, long totalRows, SegmentAllocator arena
+    ) {
+        PType ptype = pt.ptype();
+        MemorySegment combined = arena.allocate(totalRows * ptype.byteSize());
+        long byteOffset = 0;
+        for (Array chunk : chunks) {
+            // NOTE(review): this copies src.byteSize() bytes, which assumes the segment holds
+            // one element per row — confirm for chunks backed by a shorter (broadcast) segment.
+            MemorySegment src = ArraySegments.of(chunk);
+            MemorySegment.copy(src, 0, combined, byteOffset, src.byteSize());
+            byteOffset += src.byteSize();
+        }
+        MemorySegment ro = combined.asReadOnly();
+        return switch (ptype) {
+            case I64, U64 -> new LongArray(dtype, totalRows, ro);
+            case I32, U32 -> new IntArray(dtype, totalRows, ro);
+            case F64 -> new DoubleArray(dtype, totalRows, ro);
+            case F32 -> new FloatArray(dtype, totalRows, ro);
+            case I16, U16 -> new ShortArray(dtype, totalRows, ro);
+            case I8, U8 -> new ByteArray(dtype, totalRows, ro);
+            default -> throw new VortexException(EncodingId.VORTEX_CHUNKED,
+                "unsupported ptype for concat: " + ptype);
+        };
+    }
+
+    /// Concatenates struct chunks field by field: for each field, the field arrays of all
+    /// chunks are collected and passed back through {@code concat}.
+    private static StructArray concatStruct(
+        List<Array> chunks, DType.Struct struct, long totalRows, SegmentAllocator arena
+    ) {
+        int nfields = struct.fieldTypes().size();
+        List<Array> concatFields = new ArrayList<>(nfields);
+        for (int f = 0; f < nfields; f++) {
+            DType fieldDtype = struct.fieldTypes().get(f);
+            List<Array> fieldChunks = new ArrayList<>(chunks.size());
+            for (Array chunk : chunks) {
+                fieldChunks.add(((StructArray) chunk).field(f));
+            }
+            concatFields.add(concat(fieldChunks, fieldDtype, totalRows, arena));
+        }
+        return new StructArray(struct, totalRows, concatFields);
+    }
+}
diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ConstantEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ConstantEncodingDecoder.java new file mode 100644 index 00000000..05aea1bb --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ConstantEncodingDecoder.java @@ -0,0 +1,173 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.GenericArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.NullArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; + +/// Read-only decoder for {@code vortex.constant}. +public final class ConstantEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}.
+    public ConstantEncodingDecoder() {
+    }
+
+    @Override
+    public EncodingId encodingId() {
+        return EncodingId.VORTEX_CONSTANT;
+    }
+
+    /// Reports primitive dtypes only.
+    ///
+    /// NOTE(review): {@code decode} also handles null, string/binary, bool, decimal and
+    /// extension dtypes — confirm whether this narrower answer is intentional.
+    @Override
+    public boolean accepts(DType dtype) {
+        return dtype instanceof DType.Primitive;
+    }
+
+    /// Decodes a constant array from the scalar stored in buffer 0.
+    ///
+    /// Primitive results are backed by a single-element segment even though the array
+    /// reports {@code rowCount} rows; string, bool and decimal results are materialised
+    /// for every row. Extension dtypes are decoded as their storage dtype and re-wrapped.
+    ///
+    /// @param ctx decode context; buffer 0 holds the serialised {@code ScalarValue}
+    /// @return the decoded constant array
+    /// @throws VortexException if the scalar cannot be parsed, the dtype or ptype is not
+    ///         supported, a decimal scalar has no bytes, or a string payload does not fit
+    ///         the 32-bit offsets
+    @Override
+    public Array decode(DecodeContext ctx) {
+        MemorySegment scalarBuf = ctx.buffer(0);
+        ScalarValue scalar;
+        try {
+            scalar = ScalarValue.decode(scalarBuf, 0, scalarBuf.byteSize());
+        } catch (IOException e) {
+            throw new VortexException(EncodingId.VORTEX_CONSTANT, "invalid scalar value", e);
+        }
+
+        long n = ctx.rowCount();
+
+        if (ctx.dtype() instanceof DType.Null) {
+            return new NullArray(ctx.dtype(), n);
+        }
+
+        if (ctx.dtype() instanceof DType.Utf8 || ctx.dtype() instanceof DType.Binary) {
+            return decodeString(ctx, scalar, n);
+        }
+
+        if (ctx.dtype() instanceof DType.Bool) {
+            return decodeBool(ctx, scalar, n);
+        }
+
+        if (ctx.dtype() instanceof DType.Decimal) {
+            return decodeDecimal(ctx, scalar, n);
+        }
+
+        if (ctx.dtype() instanceof DType.Extension ext) {
+            // Decode against the storage dtype, then expose the same bytes under the
+            // extension dtype.
+            var storageCtx = new DecodeContext(ctx.node(), ext.storageDType(), ctx.rowCount(),
+                ctx.segmentBuffers(), ctx.registry(), ctx.arena());
+            Array storage = decode(storageCtx);
+            return new GenericArray(ctx.dtype(), n, ArraySegments.of(storage));
+        }
+
+        if (!(ctx.dtype() instanceof DType.Primitive p)) {
+            throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported dtype " + ctx.dtype());
+        }
+
+        PType ptype = p.ptype();
+        int elemBytes = ptype.byteSize();
+        long rawBits = scalarToRawBits(scalar, ptype);
+
+        // Only one element is written; the array is still created with n rows.
+        MemorySegment outSeg = ctx.arena().allocate(elemBytes);
+        ByteBuffer out = outSeg.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN);
+        writeRaw(out, ptype, rawBits);
+
+        MemorySegment ro = outSeg.asReadOnly();
+        return switch (ptype) {
+            case I64, U64 -> new LongArray(ctx.dtype(), n, ro);
+            case I32, U32 -> new IntArray(ctx.dtype(), n, ro);
+            case F64 -> new DoubleArray(ctx.dtype(), n, ro);
+            case F32 -> new FloatArray(ctx.dtype(), n, ro);
+            case I16, U16 -> new ShortArray(ctx.dtype(), n, ro);
+            case I8, U8 -> new ByteArray(ctx.dtype(), n, ro);
+            default -> throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported ptype " + ptype);
+        };
+    }
+
+    /// Repeats the scalar's raw bytes {@code n} times into a new segment.
+    ///
+    /// @throws VortexException if the scalar carries no {@code bytes_value}
+    private static Array decodeDecimal(DecodeContext ctx, ScalarValue scalar, long n) {
+        byte[] elemBytes = scalar.bytes_value();
+        if (elemBytes == null) {
+            // Without the bytes the element width is unknown; fail with context instead of
+            // a bare NullPointerException.
+            throw new VortexException(EncodingId.VORTEX_CONSTANT,
+                "decimal scalar has no bytes_value");
+        }
+        int elemLen = elemBytes.length;
+        MemorySegment outSeg = ctx.arena().allocate(n * elemLen);
+        MemorySegment elemSeg = MemorySegment.ofArray(elemBytes);
+        for (long i = 0; i < n; i++) {
+            MemorySegment.copy(elemSeg, 0L, outSeg, i * elemLen, elemLen);
+        }
+        return new GenericArray(ctx.dtype(), n, outSeg.asReadOnly());
+    }
+
+    /// Builds a bit-packed boolean buffer of {@code ceil(n / 8)} bytes: every byte is
+    /// {@code 0xFF} when the scalar is true, otherwise the buffer is left as allocated.
+    /// A missing {@code bool_value} is treated as false.
+    private static Array decodeBool(DecodeContext ctx, ScalarValue scalar, long n) {
+        boolean value = scalar.bool_value() != null && scalar.bool_value();
+        long numBytes = (n + 7) >>> 3;
+        MemorySegment seg = ctx.arena().allocate(numBytes);
+        if (value) {
+            for (long i = 0; i < numBytes; i++) {
+                seg.set(ValueLayout.JAVA_BYTE, i, (byte) 0xFF);
+            }
+        }
+        return new BoolArray(ctx.dtype(), n, seg.asReadOnly());
+    }
+
+    /// Materialises {@code n} copies of the scalar string as a {@link VarBinArray} with
+    /// 32-bit offsets. The payload is {@code string_value} encoded as UTF-8, else
+    /// {@code bytes_value}, else empty.
+    ///
+    /// @throws VortexException if {@code n * length} does not fit in a 32-bit offset
+    private static Array decodeString(DecodeContext ctx, ScalarValue scalar, long n) {
+        byte[] strBytes = scalar.string_value() != null
+            ? scalar.string_value().getBytes(StandardCharsets.UTF_8)
+            : (scalar.bytes_value() != null ? scalar.bytes_value() : new byte[0]);
+
+        int strLen = strBytes.length;
+
+        // Offsets are written as I32; previously a larger total wrapped silently in the
+        // (int) cast and produced corrupt offsets.
+        if (strLen != 0 && n > Integer.MAX_VALUE / strLen) {
+            throw new VortexException(EncodingId.VORTEX_CONSTANT,
+                "constant string payload too large for I32 offsets: " + n + " x " + strLen + " bytes");
+        }
+
+        // Wrap the source bytes once instead of once per row.
+        MemorySegment strSeg = MemorySegment.ofArray(strBytes);
+        MemorySegment bytesSeg = ctx.arena().allocate(n * strLen);
+        for (long i = 0; i < n; i++) {
+            MemorySegment.copy(strSeg, 0L, bytesSeg, i * strLen, strLen);
+        }
+
+        MemorySegment offsetsSeg = ctx.arena().allocate((n + 1) * 4L, 4);
+        for (long i = 0; i <= n; i++) {
+            offsetsSeg.setAtIndex(PTypeIO.LE_INT, i, (int) (i * strLen));
+        }
+
+        return new VarBinArray(ctx.dtype(), n, bytesSeg.asReadOnly(), offsetsSeg.asReadOnly(), PType.I32);
+    }
+
+    /// Returns the scalar's value as raw bits, taking the first populated field in the
+    /// order int64, uint64, f32, f64, and {@code 0} when none is set. The {@code ptype}
+    /// argument is currently not consulted.
+    private static long scalarToRawBits(ScalarValue scalar, PType ptype) {
+        if (scalar.int64_value() != null) {
+            return scalar.int64_value();
+        }
+        if (scalar.uint64_value() != null) {
+            return scalar.uint64_value();
+        }
+        if (scalar.f32_value() != null) {
+            return Float.floatToRawIntBits(scalar.f32_value());
+        }
+        if (scalar.f64_value() != null) {
+            return Double.doubleToRawLongBits(scalar.f64_value());
+        }
+        return 0L;
+    }
+
+    /// Writes the low {@code ptype.byteSize()} bytes of {@code rawBits} to {@code buf}.
+    ///
+    /// @throws VortexException if the ptype width is not 1, 2, 4 or 8 bytes
+    private static void writeRaw(ByteBuffer buf, PType ptype, long rawBits) {
+        switch (ptype.byteSize()) {
+            case 1 -> buf.put((byte) rawBits);
+            case 2 -> buf.putShort((short) rawBits);
+            case 4 -> buf.putInt((int) rawBits);
+            case 8 -> buf.putLong(rawBits);
+            default -> throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported ptype " + ptype);
+        }
+    }
+}
diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DateTimePartsEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DateTimePartsEncodingDecoder.java new file mode 100644 index 00000000..b274348c --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DateTimePartsEncodingDecoder.java @@ -0,0 +1,58 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import
io.github.dfa1.vortex.core.array.GenericArray;
+import io.github.dfa1.vortex.encoding.EncodingId;
+import io.github.dfa1.vortex.proto.DateTimePartsMetadata;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+import java.nio.ByteBuffer;
+
+/// Read-only decoder for {@code vortex.datetimeparts}.
+///
+/// The three children (days, seconds, subseconds) are decoded with the ptypes named in
+/// the metadata and returned unmerged as the children of a {@link GenericArray}.
+public final class DateTimePartsEncodingDecoder implements EncodingDecoder {
+
+    /// Public no-arg constructor required by {@link java.util.ServiceLoader}.
+    public DateTimePartsEncodingDecoder() {
+    }
+
+    @Override
+    public EncodingId encodingId() {
+        return EncodingId.VORTEX_DATETIMEPARTS;
+    }
+
+    /// Accepts extension dtypes only.
+    @Override
+    public boolean accepts(DType dtype) {
+        return dtype instanceof DType.Extension;
+    }
+
+    /// Decodes the three part arrays.
+    ///
+    /// @param ctx decode context; metadata must be present and non-empty
+    /// @return a {@link GenericArray} with no buffers and children
+    ///         {@code {days, seconds, subseconds}}
+    /// @throws VortexException if the metadata is missing or cannot be parsed
+    @Override
+    public Array decode(DecodeContext ctx) {
+        ByteBuffer meta = ctx.metadata();
+        if (meta == null || meta.remaining() == 0) {
+            throw new VortexException(EncodingId.VORTEX_DATETIMEPARTS, "missing metadata");
+        }
+        DateTimePartsMetadata decoded;
+        try {
+            // duplicate() so parsing does not disturb the caller's buffer position
+            MemorySegment metaSeg = MemorySegment.ofBuffer(meta.duplicate());
+            decoded = DateTimePartsMetadata.decode(metaSeg, 0, metaSeg.byteSize());
+        } catch (IOException e) {
+            // Keep the original exception as the cause so the parse failure's stack trace
+            // is not lost; the message is unchanged.
+            throw new VortexException(EncodingId.VORTEX_DATETIMEPARTS, "invalid metadata: " + e.getMessage(), e);
+        }
+
+        PType daysPtype = PType.fromOrdinal(decoded.days_ptype().value());
+        PType secondsPtype = PType.fromOrdinal(decoded.seconds_ptype().value());
+        PType subsecondsPtype = PType.fromOrdinal(decoded.subseconds_ptype().value());
+        boolean nullable = ctx.dtype().nullable();
+
+        // Only the days child inherits the parent's nullability; the other two are decoded
+        // as non-nullable.
+        Array days = ctx.decodeChild(0, new DType.Primitive(daysPtype, nullable), ctx.rowCount());
+        Array seconds = ctx.decodeChild(1, new DType.Primitive(secondsPtype, false), ctx.rowCount());
+        Array subseconds = ctx.decodeChild(2, new DType.Primitive(subsecondsPtype, false), ctx.rowCount());
+
+        return new GenericArray(ctx.dtype(), ctx.rowCount(), new MemorySegment[0],
+            new Array[]{days, seconds, subseconds});
+    }
+}
diff --git
a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalBytePartsEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalBytePartsEncodingDecoder.java new file mode 100644 index 00000000..a348a0a3 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalBytePartsEncodingDecoder.java @@ -0,0 +1,63 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.GenericArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.DecimalBytePartsMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.decimal_byte_parts}. +public final class DecimalBytePartsEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public DecimalBytePartsEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DECIMAL_BYTE_PARTS; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Decimal; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer meta = ctx.metadata(); + if (meta == null || meta.remaining() == 0) { + throw new VortexException(EncodingId.VORTEX_DECIMAL_BYTE_PARTS, "missing metadata"); + } + DecimalBytePartsMetadata decoded; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(meta.duplicate()); + decoded = DecimalBytePartsMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_DECIMAL_BYTE_PARTS, "invalid metadata: " + e.getMessage()); + } + + int lowerPartCount = decoded.lower_part_count(); + if (lowerPartCount != 0) { + throw new VortexException(EncodingId.VORTEX_DECIMAL_BYTE_PARTS, + "lower_part_count > 0 not supported, got " + lowerPartCount); + } + + PType mspPtype = PType.fromOrdinal(decoded.zeroth_child_ptype().value()); + boolean nullable = ctx.dtype().nullable(); + DType mspDtype = new DType.Primitive(mspPtype, nullable); + ArrayNode mspNode = ctx.node().children()[0]; + DecodeContext mspCtx = new DecodeContext( + mspNode, mspDtype, ctx.rowCount(), + ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array mspArray = ctx.registry().decode(mspCtx); + return new GenericArray(ctx.dtype(), ctx.rowCount(), new MemorySegment[0], + new Array[]{mspArray}); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalEncodingDecoder.java new file mode 100644 index 00000000..dd7939c3 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalEncodingDecoder.java @@ -0,0 +1,67 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import 
io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.GenericArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.DecimalMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.decimal}. +public final class DecimalEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public DecimalEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DECIMAL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Decimal; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer meta = ctx.metadata(); + if (meta == null || meta.remaining() == 0) { + throw new VortexException(EncodingId.VORTEX_DECIMAL, "missing metadata"); + } + DecimalMetadata decoded; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(meta.duplicate()); + decoded = DecimalMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_DECIMAL, "invalid metadata: " + e.getMessage()); + } + int valuesType = decoded.values_type(); + int byteWidth = decimalTypeByteWidth(valuesType); + MemorySegment buffer = ctx.buffer(0); + long expected = ctx.rowCount() * byteWidth; + if (buffer.byteSize() < expected) { + throw new VortexException(EncodingId.VORTEX_DECIMAL, + "buffer too small: expected %d bytes, got %d".formatted(expected, buffer.byteSize())); + } + return new GenericArray(ctx.dtype(), ctx.rowCount(), buffer); + } + + private static int decimalTypeByteWidth(int valuesType) { + return switch (valuesType) { + case 0 -> 1; + case 1 -> 2; + case 2 -> 4; + case 3 -> 8; + case 4 -> 16; + case 5 -> 32; + default -> throw new VortexException(EncodingId.VORTEX_DECIMAL, + "unknown 
DecimalType value: " + valuesType); + }; + } +} diff --git a/core/src/main/java/io/github/dfa1/vortex/encoding/DecodeContext.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecodeContext.java similarity index 87% rename from core/src/main/java/io/github/dfa1/vortex/encoding/DecodeContext.java rename to reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecodeContext.java index 78540e5b..2b8774c3 100644 --- a/core/src/main/java/io/github/dfa1/vortex/encoding/DecodeContext.java +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecodeContext.java @@ -1,16 +1,17 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.reader.decode; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.reader.ReadRegistry; import java.lang.foreign.MemorySegment; import java.lang.foreign.SegmentAllocator; import java.nio.ByteBuffer; -/// Decoding context passed to each [Encoding]. +/// Decoding context passed to each {@link EncodingDecoder}. /// -/// Buffers are `MemorySegment` slices materialized from the file's segment table; -/// children are decoded recursively via [#decodeChild(int)]. +///

Buffers are {@link MemorySegment} slices materialized from the file's segment table; +/// children are decoded recursively via {@link #decodeChild(int)}. /// The arena is scoped to one chunk epoch — all decode output allocated from it is /// valid until the next chunk is opened. /// @@ -18,16 +19,17 @@ /// @param dtype logical type expected for the decoded array /// @param rowCount number of logical rows to decode /// @param segmentBuffers all segment buffers for the current flat segment, indexed by segment position -/// @param registry encoding registry used for recursive child decoding +/// @param registry read registry used for recursive child decoding /// @param arena allocator for decode output; lifetime matches the current chunk epoch public record DecodeContext( ArrayNode node, DType dtype, long rowCount, MemorySegment[] segmentBuffers, - Registry registry, + ReadRegistry registry, SegmentAllocator arena ) { + /// Recursively decode child {@code i} using this context's dtype and row count. /// /// @param i zero-based child index within this node's children array @@ -75,7 +77,7 @@ public MemorySegment decodeChildSegment(int i, DType dtype, long rowCount) { return registry.decodeAsSegment(childCtx); } - /// Return the buffer at position `i` in this node's bufferIndices. + /// Returns the buffer at position {@code i} in this node's bufferIndices. /// /// @param i zero-based index into this node's {@code bufferIndices} array /// @return the {@link MemorySegment} for the referenced segment buffer @@ -85,7 +87,7 @@ public MemorySegment buffer(int i) { /// Returns the encoding-specific metadata bytes for this node, or {@code null} if absent. 
/// - /// @return the metadata {@link java.nio.ByteBuffer}, or {@code null} + /// @return the metadata {@link ByteBuffer}, or {@code null} public ByteBuffer metadata() { return node.metadata(); } diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DeltaEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DeltaEncodingDecoder.java new file mode 100644 index 00000000..6a6a5b50 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DeltaEncodingDecoder.java @@ -0,0 +1,198 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.DeltaMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code fastlanes.delta}. +public final class DeltaEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public DeltaEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_DELTA; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return switch (p.ptype()) { + case I8, I16, I32, I64, U8, U16, U32, U64 -> true; + default -> false; + }; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + DeltaMetadata meta; + if (rawMeta == null || !rawMeta.hasRemaining()) { + meta = new DeltaMetadata(0L, 0); + } else { + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = DeltaMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.FASTLANES_DELTA, "invalid metadata", e); + } + } + + PType ptype = ((DType.Primitive) ctx.dtype()).ptype(); + long rowCount = ctx.rowCount(); + int typeBits = typeBits(ptype); + int lanes = lanes(ptype); + long mask = typeMask(ptype); + + long deltasLen = meta.deltas_len(); + int offset = meta.offset(); + + if (deltasLen == 0L) { + MemorySegment empty = ctx.arena().allocate(0); + return switch (ptype) { + case I64, U64 -> new LongArray(ctx.dtype(), 0L, empty); + case I32, U32 -> new IntArray(ctx.dtype(), 0L, empty); + case I16, U16 -> new ShortArray(ctx.dtype(), 0L, empty); + case I8, U8 -> new ByteArray(ctx.dtype(), 0L, empty); + default -> throw new VortexException(EncodingId.FASTLANES_DELTA, "unsupported ptype: " + ptype); + }; + } + + long basesLen = (deltasLen / FL_CHUNK_SIZE) * lanes; + DType dtype = ctx.dtype(); + + long[] basesAll = readLongs(ctx.decodeChildSegment(0, dtype, basesLen), (int) basesLen, ptype); + long[] deltasAll = readLongs(ctx.decodeChildSegment(1, dtype, deltasLen), (int) deltasLen, ptype); + + int numChunks = (int) (deltasLen / FL_CHUNK_SIZE); + long[] decoded = new long[(int) deltasLen]; + long[] untransposedChunk = new long[FL_CHUNK_SIZE]; + long[] chunkBases = new 
long[lanes]; + long[] chunkDeltas = new long[FL_CHUNK_SIZE]; + long[] chunkUndelta = new long[FL_CHUNK_SIZE]; + + for (int chunk = 0; chunk < numChunks; chunk++) { + int basesOff = chunk * lanes; + int deltaOff = chunk * FL_CHUNK_SIZE; + + System.arraycopy(basesAll, basesOff, chunkBases, 0, lanes); + System.arraycopy(deltasAll, deltaOff, chunkDeltas, 0, FL_CHUNK_SIZE); + + undeltaChunk(chunkDeltas, chunkBases, lanes, typeBits, mask, chunkUndelta); + + for (int i = 0; i < FL_CHUNK_SIZE; i++) { + untransposedChunk[transposeIndex(i)] = chunkUndelta[i]; + } + System.arraycopy(untransposedChunk, 0, decoded, deltaOff, FL_CHUNK_SIZE); + } + + long[] result = new long[(int) rowCount]; + System.arraycopy(decoded, offset, result, 0, (int) rowCount); + + MemorySegment seg = fromLongs(result, ptype, ctx.arena()); + return switch (ptype) { + case I64, U64 -> new LongArray(ctx.dtype(), rowCount, seg); + case I32, U32 -> new IntArray(ctx.dtype(), rowCount, seg); + case I16, U16 -> new ShortArray(ctx.dtype(), rowCount, seg); + case I8, U8 -> new ByteArray(ctx.dtype(), rowCount, seg); + default -> throw new VortexException(EncodingId.FASTLANES_DELTA, "unsupported ptype: " + ptype); + }; + } + + private static void undeltaChunk(long[] deltas, long[] bases, int lanes, int typeBits, long mask, long[] out) { + for (int lane = 0; lane < lanes; lane++) { + long prev = bases[lane] & mask; + for (int row = 0; row < typeBits; row++) { + int idx = iterateIndex(row, lane); + long next = ((deltas[idx] & mask) + prev) & mask; + out[idx] = next; + prev = next; + } + } + } + + private static long[] readLongs(MemorySegment buf, int count, PType ptype) { + long[] out = new long[count]; + int elemSize = ptype.byteSize(); + long cap = SegmentBroadcast.capacity(buf, elemSize); + for (int i = 0; i < count; i++) { + long off = (i % cap) * elemSize; + out[i] = switch (ptype) { + case I8 -> buf.get(ValueLayout.JAVA_BYTE, off); + case U8 -> Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, off)); + case 
I16 -> buf.get(PTypeIO.LE_SHORT, off); + case U16 -> Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, off)); + case I32 -> buf.get(PTypeIO.LE_INT, off); + case U32 -> Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, off)); + case I64, U64 -> buf.get(PTypeIO.LE_LONG, off); + default -> throw new VortexException(EncodingId.FASTLANES_DELTA, "unsupported ptype: " + ptype); + }; + } + return out; + } + + private static final int FL_CHUNK_SIZE = 1024; + + private static final int[] FL_ORDER = {0, 4, 2, 6, 1, 5, 3, 7}; + + private static int transposeIndex(int idx) { + int lane = idx % 16; + int order = (idx / 16) % 8; + int row = idx / 128; + return lane * 64 + FL_ORDER[order] * 8 + row; + } + + private static int iterateIndex(int row, int lane) { + int o = row / 8; + int s = row % 8; + return FL_ORDER[o] * 16 + s * 128 + lane; + } + + private static int lanes(PType ptype) { + return FL_CHUNK_SIZE / (ptype.byteSize() * 8); + } + + private static int typeBits(PType ptype) { + return ptype.byteSize() * 8; + } + + private static long typeMask(PType ptype) { + int bits = ptype.byteSize() * 8; + return bits == 64 ? 
-1L : (1L << bits) - 1; + } + + private static MemorySegment fromLongs(long[] longs, PType ptype, SegmentAllocator arena) { + if (ptype == PType.I64 || ptype == PType.U64) { + MemorySegment dst = arena.allocate((long) longs.length * 8); + MemorySegment.copy(MemorySegment.ofArray(longs), ValueLayout.JAVA_LONG, 0L, dst, PTypeIO.LE_LONG, 0L, longs.length); + return dst; + } + int n = longs.length; + long elemSize = ptype.byteSize(); + MemorySegment seg = arena.allocate(n * elemSize); + for (int i = 0; i < n; i++) { + PTypeIO.set(seg, i * elemSize, ptype, longs[i]); + } + return seg; + } + +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DictEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DictEncodingDecoder.java new file mode 100644 index 00000000..d1f7b26c --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DictEncodingDecoder.java @@ -0,0 +1,402 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.DictMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.dict}. 
+public final class DictEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public DictEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DICT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive || dtype instanceof DType.Utf8; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer meta = ctx.metadata(); + + if (ctx.dtype() instanceof DType.Utf8) { + if (ctx.node().children().length == 0) { + if (meta == null || !meta.hasRemaining()) { + throw new VortexException(EncodingId.VORTEX_DICT, "missing metadata for legacy utf8 dict"); + } + return decodeUtf8DictLegacy(ctx, meta); + } + if (meta == null || !meta.hasRemaining()) { + throw new VortexException(EncodingId.VORTEX_DICT, "missing metadata for utf8 dict"); + } + return decodeUtf8DictProto(ctx, meta.duplicate()); + } + + if (meta == null || !meta.hasRemaining()) { + throw new VortexException(EncodingId.VORTEX_DICT, "missing metadata"); + } + + if (meta.remaining() == 1) { + return decodeLegacyJava(ctx, meta.get(0)); + } + return decodeRustProto(ctx, meta.duplicate()); + } + + private static Array decodeLegacyJava(DecodeContext ctx, byte codeTypeByte) { + PType codePType = PType.fromOrdinal(Byte.toUnsignedInt(codeTypeByte)); + PType valPType = ((DType.Primitive) ctx.dtype()).ptype(); + int elemSize = valPType.byteSize(); + long rowCount = ctx.rowCount(); + + MemorySegment valuesBuf = ctx.segmentBuffers()[ctx.node().children()[0].bufferIndices()[0]]; + + DType codesDtype = new DType.Primitive(codePType, false); + MemorySegment codesBuf = ctx.decodeChildSegment(1, codesDtype, rowCount); + + MemorySegment out = ctx.arena().allocate(rowCount * (long) elemSize); + switch (codePType) { + case U8 -> expandU8(codesBuf, valuesBuf, out, rowCount, elemSize); + case U16 -> expandU16(codesBuf, valuesBuf, out, rowCount, elemSize); + case U32 
-> expandU32(codesBuf, valuesBuf, out, rowCount, elemSize); + default -> { + for (long i = 0; i < rowCount; i++) { + long code = readCode(codesBuf, codePType, i); + MemorySegment.copy(valuesBuf, code * elemSize, out, i * elemSize, elemSize); + } + } + } + return typedArray(ctx.dtype(), valPType, rowCount, out.asReadOnly()); + } + + private static Array decodeRustProto(DecodeContext ctx, ByteBuffer metaBuf) { + DictMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(metaBuf); + meta = DictMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_DICT, "invalid proto metadata", e); + } + + PType codePType = PType.fromOrdinal(meta.codes_ptype().value()); + long valuesLen = meta.values_len(); + long rowCount = ctx.rowCount(); + PType valPType = ((DType.Primitive) ctx.dtype()).ptype(); + int elemSize = valPType.byteSize(); + + DType codesDtype = new DType.Primitive(codePType, false); + MemorySegment codesBuf = ctx.decodeChildSegment(0, codesDtype, rowCount); + MemorySegment valuesBuf = ctx.decodeChildSegment(1, ctx.dtype(), valuesLen); + + MemorySegment out = ctx.arena().allocate(rowCount * (long) elemSize); + switch (codePType) { + case U8 -> expandU8(codesBuf, valuesBuf, out, rowCount, elemSize); + case U16 -> expandU16(codesBuf, valuesBuf, out, rowCount, elemSize); + case U32 -> expandU32(codesBuf, valuesBuf, out, rowCount, elemSize); + default -> throw new VortexException(EncodingId.VORTEX_DICT, "unexpected code type: " + codePType); + } + return typedArray(ctx.dtype(), valPType, rowCount, out.asReadOnly()); + } + + private static Array decodeUtf8DictLegacy(DecodeContext ctx, ByteBuffer meta) { + PType codePType = PType.fromOrdinal(Byte.toUnsignedInt(meta.get(0))); + long n = ctx.rowCount(); + + MemorySegment dictBytes = ctx.buffer(0); + MemorySegment dictOffsets = ctx.buffer(1); + MemorySegment codes = ctx.buffer(2); + + return VarBinArray.ofDict(ctx.dtype(), n, + dictBytes, 
dictOffsets, PType.I64, + codes, codePType); + } + + private static Array decodeUtf8DictProto(DecodeContext ctx, ByteBuffer metaBuf) { + DictMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(metaBuf); + meta = DictMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_DICT, "invalid utf8 dict proto metadata", e); + } + PType codePType = PType.fromOrdinal(meta.codes_ptype().value()); + long dictSize = meta.values_len(); + long n = ctx.rowCount(); + + DType codesDtype = new DType.Primitive(codePType, false); + MemorySegment codesBuf = ctx.decodeChildSegment(0, codesDtype, n); + + Array valuesArr = ctx.decodeChild(1, ctx.dtype(), dictSize); + VarBinArray varBinValues = (VarBinArray) valuesArr; + MemorySegment dictBytes = varBinValues.bytesSegment(); + MemorySegment dictOffsets = varBinValues.offsetsSegment(); + + return VarBinArray.ofDict(ctx.dtype(), n, + dictBytes, dictOffsets, PType.I64, + codesBuf, codePType); + } + + private static long readCode(MemorySegment buf, PType codePType, long i) { + long cap = SegmentBroadcast.capacity(buf, codePType.byteSize()); + long idx = i % cap; + return switch (codePType) { + case U8 -> Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, idx)); + case U16 -> Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, idx * 2)); + case U32 -> Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, idx * 4)); + default -> throw new VortexException(EncodingId.VORTEX_DICT, "unexpected code type: " + codePType); + }; + } + + private static void expandU8(MemorySegment codes, MemorySegment values, MemorySegment out, long rowCount, int elemSize) { + long codesCap = SegmentBroadcast.capacity(codes, 1); + long valuesCap = SegmentBroadcast.capacity(values, elemSize); + boolean fast = codesCap >= rowCount && valuesCap > 1; + switch (elemSize) { + case 8 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = 
Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i)); + out.setAtIndex(PTypeIO.LE_LONG, i, values.getAtIndex(PTypeIO.LE_LONG, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i % codesCap)); + out.setAtIndex(PTypeIO.LE_LONG, i, values.getAtIndex(PTypeIO.LE_LONG, code % valuesCap)); + } + } + } + case 4 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i)); + out.setAtIndex(PTypeIO.LE_INT, i, values.getAtIndex(PTypeIO.LE_INT, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i % codesCap)); + out.setAtIndex(PTypeIO.LE_INT, i, values.getAtIndex(PTypeIO.LE_INT, code % valuesCap)); + } + } + } + case 2 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i)); + out.setAtIndex(PTypeIO.LE_SHORT, i, values.getAtIndex(PTypeIO.LE_SHORT, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i % codesCap)); + out.setAtIndex(PTypeIO.LE_SHORT, i, values.getAtIndex(PTypeIO.LE_SHORT, code % valuesCap)); + } + } + } + case 1 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i)); + out.set(ValueLayout.JAVA_BYTE, i, values.get(ValueLayout.JAVA_BYTE, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i % codesCap)); + out.set(ValueLayout.JAVA_BYTE, i, values.get(ValueLayout.JAVA_BYTE, code % valuesCap)); + } + } + } + default -> { + if (fast) { + for (long i = 0, outOff = 0; i < rowCount; i++, outOff += elemSize) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i)); + MemorySegment.copy(values, code * elemSize, out, outOff, 
elemSize); + } + } else { + for (long i = 0, outOff = 0; i < rowCount; i++, outOff += elemSize) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i % codesCap)); + MemorySegment.copy(values, (code % valuesCap) * elemSize, out, outOff, elemSize); + } + } + } + } + } + + private static void expandU16(MemorySegment codes, MemorySegment values, MemorySegment out, long rowCount, int elemSize) { + long codesCap = SegmentBroadcast.capacity(codes, 2); + long valuesCap = SegmentBroadcast.capacity(values, elemSize); + boolean fast = codesCap >= rowCount && valuesCap > 1; + switch (elemSize) { + case 8 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, i * 2)); + out.setAtIndex(PTypeIO.LE_LONG, i, values.getAtIndex(PTypeIO.LE_LONG, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, (i % codesCap) * 2)); + out.setAtIndex(PTypeIO.LE_LONG, i, values.getAtIndex(PTypeIO.LE_LONG, code % valuesCap)); + } + } + } + case 4 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, i * 2)); + out.setAtIndex(PTypeIO.LE_INT, i, values.getAtIndex(PTypeIO.LE_INT, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, (i % codesCap) * 2)); + out.setAtIndex(PTypeIO.LE_INT, i, values.getAtIndex(PTypeIO.LE_INT, code % valuesCap)); + } + } + } + case 2 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, i * 2)); + out.setAtIndex(PTypeIO.LE_SHORT, i, values.getAtIndex(PTypeIO.LE_SHORT, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, (i % codesCap) * 2)); + out.setAtIndex(PTypeIO.LE_SHORT, i, values.getAtIndex(PTypeIO.LE_SHORT, code % valuesCap)); + } + } + 
} + case 1 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, i * 2)); + out.set(ValueLayout.JAVA_BYTE, i, values.get(ValueLayout.JAVA_BYTE, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, (i % codesCap) * 2)); + out.set(ValueLayout.JAVA_BYTE, i, values.get(ValueLayout.JAVA_BYTE, code % valuesCap)); + } + } + } + default -> { + if (fast) { + for (long i = 0, outOff = 0; i < rowCount; i++, outOff += elemSize) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, i * 2)); + MemorySegment.copy(values, code * elemSize, out, outOff, elemSize); + } + } else { + for (long i = 0, outOff = 0; i < rowCount; i++, outOff += elemSize) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, (i % codesCap) * 2)); + MemorySegment.copy(values, (code % valuesCap) * elemSize, out, outOff, elemSize); + } + } + } + } + } + + private static void expandU32(MemorySegment codes, MemorySegment values, MemorySegment out, long rowCount, int elemSize) { + long codesCap = SegmentBroadcast.capacity(codes, 4); + long valuesCap = SegmentBroadcast.capacity(values, elemSize); + boolean fast = codesCap >= rowCount && valuesCap > 1; + switch (elemSize) { + case 8 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, i * 4)); + out.setAtIndex(PTypeIO.LE_LONG, i, values.getAtIndex(PTypeIO.LE_LONG, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, (i % codesCap) * 4)); + out.setAtIndex(PTypeIO.LE_LONG, i, values.getAtIndex(PTypeIO.LE_LONG, code % valuesCap)); + } + } + } + case 4 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, i * 4)); + out.setAtIndex(PTypeIO.LE_INT, i, values.getAtIndex(PTypeIO.LE_INT, code)); + } + } else 
{ + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, (i % codesCap) * 4)); + out.setAtIndex(PTypeIO.LE_INT, i, values.getAtIndex(PTypeIO.LE_INT, code % valuesCap)); + } + } + } + case 2 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, i * 4)); + out.setAtIndex(PTypeIO.LE_SHORT, i, values.getAtIndex(PTypeIO.LE_SHORT, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, (i % codesCap) * 4)); + out.setAtIndex(PTypeIO.LE_SHORT, i, values.getAtIndex(PTypeIO.LE_SHORT, code % valuesCap)); + } + } + } + case 1 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, i * 4)); + out.set(ValueLayout.JAVA_BYTE, i, values.get(ValueLayout.JAVA_BYTE, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, (i % codesCap) * 4)); + out.set(ValueLayout.JAVA_BYTE, i, values.get(ValueLayout.JAVA_BYTE, code % valuesCap)); + } + } + } + default -> { + if (fast) { + for (long i = 0, outOff = 0; i < rowCount; i++, outOff += elemSize) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, i * 4)); + MemorySegment.copy(values, code * elemSize, out, outOff, elemSize); + } + } else { + for (long i = 0, outOff = 0; i < rowCount; i++, outOff += elemSize) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, (i % codesCap) * 4)); + MemorySegment.copy(values, (code % valuesCap) * elemSize, out, outOff, elemSize); + } + } + } + } + } + + private static Array typedArray(DType dtype, PType ptype, long n, MemorySegment seg) { + return switch (ptype) { + case I64, U64 -> new LongArray(dtype, n, seg); + case I32, U32 -> new IntArray(dtype, n, seg); + case F64 -> new DoubleArray(dtype, n, seg); + case F32 -> new FloatArray(dtype, n, seg); + case I16, U16 -> 
new ShortArray(dtype, n, seg); + case I8, U8 -> new ByteArray(dtype, n, seg); + default -> throw new VortexException(EncodingId.VORTEX_DICT, "unsupported ptype " + ptype); + }; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/EncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/EncodingDecoder.java new file mode 100644 index 00000000..533ac254 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/EncodingDecoder.java @@ -0,0 +1,30 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Read-side decoding interface. Implementations live in the {@code reader} module and +/// are discovered via {@link java.util.ServiceLoader}. +/// +///

Register via {@link io.github.dfa1.vortex.reader.ReadRegistry} — implementations +/// are discoverable via {@link java.util.ServiceLoader}. +public interface EncodingDecoder { + + /// Returns the wire identifier of this decoder. + /// + /// @return the wire identifier + EncodingId encodingId(); + + /// Returns whether this decoder handles the given dtype. + /// + /// @param dtype the dtype to test + /// @return {@code true} if this decoder can handle arrays of {@code dtype} + boolean accepts(DType dtype); + + /// Decodes an array node from the file using the provided context. + /// + /// @param ctx decoding context containing buffers, dtype, row count, and child registry + /// @return decoded array + Array decode(DecodeContext ctx); +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ExtEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ExtEncodingDecoder.java new file mode 100644 index 00000000..9bd002d4 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ExtEncodingDecoder.java @@ -0,0 +1,37 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Read-only decoder for {@code vortex.ext} — unwraps the storage-array child. +public final class ExtEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ExtEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_EXT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Extension; + } + + @Override + public Array decode(DecodeContext ctx) { + if (!(ctx.dtype() instanceof DType.Extension ext)) { + throw new VortexException(EncodingId.VORTEX_EXT, "expected extension dtype, got " + ctx.dtype()); + } + long n = ctx.rowCount(); + ArrayNode childNode = ctx.node().children()[0]; + DecodeContext childCtx = new DecodeContext( + childNode, ext.storageDType(), n, + ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + return ctx.registry().decode(childCtx); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FixedSizeListEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FixedSizeListEncodingDecoder.java new file mode 100644 index 00000000..046acfd7 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FixedSizeListEncodingDecoder.java @@ -0,0 +1,51 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.FixedSizeListArray; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Read-only decoder for {@code vortex.fixed_size_list}. +public final class FixedSizeListEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public FixedSizeListEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_FIXED_SIZE_LIST; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.FixedSizeList; + } + + @Override + public Array decode(DecodeContext ctx) { + if (!(ctx.dtype() instanceof DType.FixedSizeList fsl)) { + throw new VortexException(EncodingId.VORTEX_FIXED_SIZE_LIST, + "expected DType.FixedSizeList, got " + ctx.dtype()); + } + + int nchildren = ctx.node().children().length; + if (nchildren < 1 || nchildren > 2) { + throw new VortexException(EncodingId.VORTEX_FIXED_SIZE_LIST, + "expected 1 or 2 children, got " + nchildren); + } + + long outerLen = ctx.rowCount(); + long elemLen = outerLen * fsl.fixedSize(); + DType elementType = fsl.elementType(); + + ArrayNode elemNode = ctx.node().children()[0]; + var elemCtx = new DecodeContext( + elemNode, elementType, elemLen, + ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array elements = ctx.registry().decode(elemCtx); + + return new FixedSizeListArray(fsl, outerLen, elements); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FrameOfReferenceEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FrameOfReferenceEncodingDecoder.java new file mode 100644 index 00000000..6a3d9ea4 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FrameOfReferenceEncodingDecoder.java @@ -0,0 +1,130 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.IntArray; +import 
io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code fastlanes.for} (Frame of Reference). +public final class FrameOfReferenceEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public FrameOfReferenceEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_FOR; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive p && !p.ptype().isFloating(); + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null || !rawMeta.hasRemaining()) { + throw new VortexException(EncodingId.FASTLANES_FOR, "missing metadata"); + } + ScalarValue scalar; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + scalar = ScalarValue.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.FASTLANES_FOR, "invalid metadata", e); + } + + Array encoded = ctx.decodeChild(0); + + BoolArray validity = null; + Array rawEncoded = encoded; + if (encoded instanceof MaskedArray masked) { + rawEncoded = masked.inner(); + validity = masked.validity(); + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.FASTLANES_FOR, "expected primitive dtype, got " + ctx.dtype()); + } + + long ref = referenceValue(scalar); + if (ref == 0L) { + return validity != null ? 
new MaskedArray(rawEncoded, validity) : rawEncoded; + } + + MemorySegment src = ArraySegments.of(rawEncoded); + long n = ctx.rowCount(); + MemorySegment dst = applyReference(src, n, p.ptype(), ref, ctx.arena()); + Array result = switch (p.ptype()) { + case I64, U64 -> new LongArray(ctx.dtype(), n, dst); + case I32, U32 -> new IntArray(ctx.dtype(), n, dst); + case F64 -> new DoubleArray(ctx.dtype(), n, dst); + case I16, U16 -> new ShortArray(ctx.dtype(), n, dst); + case I8, U8 -> new ByteArray(ctx.dtype(), n, dst); + default -> throw new VortexException(EncodingId.FASTLANES_FOR, "unsupported ptype " + p.ptype()); + }; + return validity != null ? new MaskedArray(result, validity) : result; + } + + private static long referenceValue(ScalarValue scalar) { + if (scalar.int64_value() != null) { + return scalar.int64_value(); + } + if (scalar.uint64_value() != null) { + return scalar.uint64_value(); + } + return 0L; + } + + private static MemorySegment applyReference(MemorySegment src, long n, PType ptype, long ref, SegmentAllocator arena) { + int wordBytes = ptype.byteSize(); + MemorySegment dst = arena.allocate(n * wordBytes); + switch (ptype) { + case I8, U8 -> { + for (long off = 0, end = n; off < end; off++) { + byte v = src.get(ValueLayout.JAVA_BYTE, off); + dst.set(ValueLayout.JAVA_BYTE, off, (byte) (v + (byte) ref)); + } + } + case I16, U16 -> { + for (long off = 0, end = n * 2; off < end; off += 2) { + short v = src.get(PTypeIO.LE_SHORT, off); + dst.set(PTypeIO.LE_SHORT, off, (short) (v + (short) ref)); + } + } + case I32, U32 -> { + for (long off = 0, end = n * 4; off < end; off += 4) { + int v = src.get(PTypeIO.LE_INT, off); + dst.set(PTypeIO.LE_INT, off, v + (int) ref); + } + } + case I64, U64 -> { + for (long off = 0, end = n * 8; off < end; off += 8) { + long v = src.get(PTypeIO.LE_LONG, off); + dst.set(PTypeIO.LE_LONG, off, v + ref); + } + } + default -> throw new VortexException(EncodingId.FASTLANES_FOR, "unsupported ptype " + ptype); + } + return dst; + } 
+} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FsstEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FsstEncodingDecoder.java new file mode 100644 index 00000000..edeae832 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FsstEncodingDecoder.java @@ -0,0 +1,115 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.FSSTMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.fsst}. +public final class FsstEncodingDecoder implements EncodingDecoder { + + private static final int ESCAPE = 0xFF; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public FsstEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_FSST; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null) { + throw new VortexException(EncodingId.VORTEX_FSST, "missing metadata"); + } + FSSTMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = FSSTMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_FSST, "invalid metadata", e); + } + + PType uncompLenPType = PType.fromOrdinal(meta.uncompressed_lengths_ptype().value()); + PType codesOffPType = PType.fromOrdinal(meta.codes_offsets_ptype().value()); + + long n = ctx.rowCount(); + + MemorySegment symbolsBuf = ctx.buffer(0); + MemorySegment symbolLensBuf = ctx.buffer(1); + MemorySegment compressedBytes = ctx.buffer(2); + + MemorySegment uncompLensSeg = ctx.decodeChildSegment(0, new DType.Primitive(uncompLenPType, false), n); + MemorySegment codesOffsetsSeg = ctx.decodeChildSegment(1, new DType.Primitive(codesOffPType, false), n + 1); + long uncompLensCap = SegmentBroadcast.capacity(uncompLensSeg, uncompLenPType.byteSize()); + long codesOffCap = SegmentBroadcast.capacity(codesOffsetsSeg, codesOffPType.byteSize()); + + long totalUncompressed = 0L; + for (long i = 0; i < n; i++) { + totalUncompressed += readUnsigned(uncompLensSeg, i % uncompLensCap, uncompLenPType); + } + + MemorySegment outBytes = ctx.arena().allocate(totalUncompressed); + MemorySegment outOffsets = ctx.arena().allocate((n + 1) * 4L, 4); + outOffsets.setAtIndex(PTypeIO.LE_INT, 0, 0); + + long outPos = 0L; + for (long i = 0; i < n; i++) { + long cStart = readUnsigned(codesOffsetsSeg, i % codesOffCap, codesOffPType); + long cEnd = readUnsigned(codesOffsetsSeg, (i + 1) % 
codesOffCap, codesOffPType); + outPos = decompressString(compressedBytes, symbolsBuf, symbolLensBuf, + cStart, cEnd, outBytes, outPos); + outOffsets.setAtIndex(PTypeIO.LE_INT, i + 1, (int) outPos); + } + + return new VarBinArray(ctx.dtype(), n, outBytes.asReadOnly(), outOffsets.asReadOnly(), PType.I32); + } + + private static long decompressString( + MemorySegment compressed, MemorySegment symbols, MemorySegment symLens, + long start, long end, MemorySegment out, long outPos + ) { + for (long j = start; j < end; j++) { + int b = Byte.toUnsignedInt(compressed.get(ValueLayout.JAVA_BYTE, j)); + if (b == ESCAPE) { + out.set(ValueLayout.JAVA_BYTE, outPos++, compressed.get(ValueLayout.JAVA_BYTE, ++j)); + } else { + int symLen = Byte.toUnsignedInt(symLens.get(ValueLayout.JAVA_BYTE, b)); + long sym = symbols.getAtIndex(PTypeIO.LE_LONG, b); + for (int k = 0; k < symLen; k++) { + out.set(ValueLayout.JAVA_BYTE, outPos++, (byte) (sym >>> (k * 8))); + } + } + } + return outPos; + } + + private static long readUnsigned(MemorySegment seg, long idx, PType ptype) { + return switch (ptype) { + case U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, idx)); + case U16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, idx * 2)); + case U32 -> Integer.toUnsignedLong(seg.getAtIndex(PTypeIO.LE_INT, idx)); + case I32 -> seg.getAtIndex(PTypeIO.LE_INT, idx); + case I64, U64 -> seg.getAtIndex(PTypeIO.LE_LONG, idx); + default -> throw new VortexException(EncodingId.VORTEX_FSST, "unsupported ptype " + ptype); + }; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/KnownArrayNode.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/KnownArrayNode.java new file mode 100644 index 00000000..53e8ee7d --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/KnownArrayNode.java @@ -0,0 +1,22 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.ArrayStats; +import io.github.dfa1.vortex.encoding.EncodingId; + 
+import java.nio.ByteBuffer; + +/// Array node whose encoding id is well-known to this build (an {@link EncodingId} enum constant). +/// +/// @param encodingId well-known encoding id +/// @param metadata encoding-specific metadata bytes, or {@code null} +/// @param children child nodes +/// @param bufferIndices segment buffer indices +/// @param stats optional zone-map statistics +public record KnownArrayNode( + EncodingId encodingId, + ByteBuffer metadata, + ArrayNode[] children, + int[] bufferIndices, + ArrayStats stats +) implements ArrayNode { +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListEncodingDecoder.java new file mode 100644 index 00000000..5e51f82d --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListEncodingDecoder.java @@ -0,0 +1,64 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ListArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.ListMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; + +/// Read-only decoder for {@code vortex.list}. +public final class ListEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ListEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_LIST; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.List; + } + + @Override + public Array decode(DecodeContext ctx) { + if (!(ctx.dtype() instanceof DType.List listDtype)) { + throw new VortexException(EncodingId.VORTEX_LIST, + "expected DType.List, got " + ctx.dtype()); + } + + int nchildren = ctx.node().children().length; + if (nchildren < 2 || nchildren > 3) { + throw new VortexException(EncodingId.VORTEX_LIST, + "expected 2 or 3 children, got " + nchildren); + } + + ListMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(ctx.metadata().duplicate()); + meta = ListMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_LIST, "invalid metadata", e); + } + + long elementsLen = meta.elements_len(); + PType offsetPtype = PType.fromOrdinal(meta.offset_ptype().value()); + long outerLen = ctx.rowCount(); + + DType elementDtype = listDtype.elementType(); + DType offsetsDtype = new DType.Primitive(offsetPtype, false); + + Array elements = ctx.decodeChild(0, elementDtype, elementsLen); + Array offsets = ctx.decodeChild(1, offsetsDtype, outerLen + 1); + + return new ListArray(listDtype, outerLen, elements, offsets); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListViewEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListViewEncodingDecoder.java new file mode 100644 index 00000000..a14cc068 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListViewEncodingDecoder.java @@ -0,0 +1,67 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import 
io.github.dfa1.vortex.core.array.ListViewArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.ListViewMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; + +/// Read-only decoder for {@code vortex.listview}. +public final class ListViewEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public ListViewEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_LISTVIEW; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.List; + } + + @Override + public Array decode(DecodeContext ctx) { + if (!(ctx.dtype() instanceof DType.List listDtype)) { + throw new VortexException(EncodingId.VORTEX_LISTVIEW, + "expected DType.List, got " + ctx.dtype()); + } + + int nchildren = ctx.node().children().length; + if (nchildren < 3 || nchildren > 4) { + throw new VortexException(EncodingId.VORTEX_LISTVIEW, + "expected 3 or 4 children, got " + nchildren); + } + + ListViewMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(ctx.metadata().duplicate()); + meta = ListViewMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_LISTVIEW, "invalid metadata", e); + } + + long elementsLen = meta.elements_len(); + PType offsetPtype = PType.fromOrdinal(meta.offset_ptype().value()); + PType sizePtype = PType.fromOrdinal(meta.size_ptype().value()); + long outerLen = ctx.rowCount(); + + DType elementDtype = listDtype.elementType(); + DType offsetsDtype = new DType.Primitive(offsetPtype, false); + DType sizesDtype = new DType.Primitive(sizePtype, false); + + Array elements = ctx.decodeChild(0, elementDtype, elementsLen); + Array offsets = ctx.decodeChild(1, offsetsDtype, outerLen); + Array sizes = ctx.decodeChild(2, sizesDtype, outerLen); + + return new ListViewArray(listDtype, outerLen, 
elements, offsets, sizes); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/MaskedEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/MaskedEncodingDecoder.java new file mode 100644 index 00000000..3490d9e2 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/MaskedEncodingDecoder.java @@ -0,0 +1,53 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Read-only decoder for {@code vortex.masked} — payload child + optional validity bitmap child. +public final class MaskedEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public MaskedEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_MASKED; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public Array decode(DecodeContext ctx) { + if (ctx.node().bufferIndices().length != 0) { + throw new VortexException(EncodingId.VORTEX_MASKED, + "expected 0 buffers, got " + ctx.node().bufferIndices().length); + } + int numChildren = ctx.node().children().length; + if (numChildren < 1 || numChildren > 2) { + throw new VortexException(EncodingId.VORTEX_MASKED, + "expected 1 or 2 children, got " + numChildren); + } + + Array child = ctx.decodeChild(0, ctx.dtype().withNullable(false), ctx.rowCount()); + + BoolArray validity = null; + if (numChildren == 2) { + Array validityArray = ctx.decodeChild(1, new DType.Bool(false), ctx.rowCount()); + if (!(validityArray instanceof BoolArray ba)) { + throw new VortexException(EncodingId.VORTEX_MASKED, + "validity child decoded to unexpected type: " + 
validityArray.getClass().getSimpleName()); + } + validity = ba; + } + + return new MaskedArray(child, validity); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoder.java new file mode 100644 index 00000000..cc4a9ddb --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoder.java @@ -0,0 +1,29 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.NullArray; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Read-only decoder for {@code vortex.null} (all-null arrays). +public final class NullEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public NullEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_NULL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Null; + } + + @Override + public Array decode(DecodeContext ctx) { + return new NullArray(ctx.dtype(), ctx.rowCount()); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoder.java new file mode 100644 index 00000000..93324b1a --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoder.java @@ -0,0 +1,124 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import 
io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.PatchedMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.patched}. +public final class PatchedEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public PatchedEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_PATCHED; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null || !rawMeta.hasRemaining()) { + throw new VortexException(EncodingId.VORTEX_PATCHED, "missing metadata"); + } + + long nPatches; + long nLanes; + long offset; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + PatchedMetadata meta = PatchedMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + nPatches = Integer.toUnsignedLong(meta.n_patches()); + nLanes = Integer.toUnsignedLong(meta.n_lanes()); + offset = Integer.toUnsignedLong(meta.offset()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_PATCHED, "invalid metadata", e); + } + + if (nLanes == 0) { + throw new VortexException(EncodingId.VORTEX_PATCHED, "n_lanes must be > 0"); + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_PATCHED, + "expected primitive dtype, got " + ctx.dtype()); + } + + PType ptype = p.ptype(); + long n = ctx.rowCount(); + long nChunks = (n + offset + 1023) / 1024; + int 
elemBytes = ptype.byteSize(); + + MemorySegment innerSeg = ctx.decodeChildSegment(0, ctx.dtype(), n); + MemorySegment laneOffsetsSeg = ctx.decodeChildSegment(1, + new DType.Primitive(PType.U32, false), nChunks * nLanes + 1); + MemorySegment patchIndicesSeg = ctx.decodeChildSegment(2, + new DType.Primitive(PType.U16, false), nPatches); + MemorySegment patchValuesSeg = ctx.decodeChildSegment(3, ctx.dtype(), nPatches); + + MemorySegment out = ctx.arena().allocate(n * elemBytes); + SegmentBroadcast.broadcastCopy(innerSeg, out, n, elemBytes); + + if (nPatches > 0) { + applyPatches(out, n, nChunks, nLanes, offset, elemBytes, + laneOffsetsSeg, patchIndicesSeg, patchValuesSeg); + } + + return switch (ptype) { + case I8, U8 -> new ByteArray(ctx.dtype(), n, out); + case I16, U16 -> new ShortArray(ctx.dtype(), n, out); + case I32, U32 -> new IntArray(ctx.dtype(), n, out); + case I64, U64 -> new LongArray(ctx.dtype(), n, out); + case F32 -> new FloatArray(ctx.dtype(), n, out); + case F64 -> new DoubleArray(ctx.dtype(), n, out); + default -> throw new VortexException(EncodingId.VORTEX_PATCHED, + "unsupported ptype: " + ptype); + }; + } + + private static void applyPatches( + MemorySegment out, long n, long nChunks, long nLanes, long offset, int elemBytes, + MemorySegment laneOffsets, MemorySegment patchIndices, MemorySegment patchValues + ) { + long laneCap = SegmentBroadcast.capacity(laneOffsets, 4); + long idxCap = SegmentBroadcast.capacity(patchIndices, 2); + long valCap = SegmentBroadcast.capacity(patchValues, elemBytes); + for (long chunk = 0; chunk < nChunks; chunk++) { + long start = Integer.toUnsignedLong( + laneOffsets.getAtIndex(PTypeIO.LE_INT, (chunk * nLanes) % laneCap)); + long stop = Integer.toUnsignedLong( + laneOffsets.getAtIndex(PTypeIO.LE_INT, (chunk * nLanes + nLanes) % laneCap)); + + for (long i = start; i < stop; i++) { + long physicalIdx = chunk * 1024 + + Short.toUnsignedLong(patchIndices.getAtIndex(PTypeIO.LE_SHORT, i % idxCap)); + if (physicalIdx < 
offset || physicalIdx >= offset + n) { + continue; + } + long outputIdx = physicalIdx - offset; + MemorySegment.copy(patchValues, (i % valCap) * elemBytes, out, outputIdx * elemBytes, elemBytes); + } + } + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoder.java new file mode 100644 index 00000000..64b15e06 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoder.java @@ -0,0 +1,786 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.LeBitReader; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.PcoBin; +import io.github.dfa1.vortex.encoding.PcoTansDecoder; +import io.github.dfa1.vortex.proto.PcoChunkInfo; +import io.github.dfa1.vortex.proto.PcoMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/// Read-only decoder for {@code vortex.pco} — port of pcodec. 
+public final class PcoEncodingDecoder implements EncodingDecoder { + static final byte PCO_FORMAT_MAJOR = 0x04; + static final byte PCO_FORMAT_MINOR = 0x01; + static final int BITS_TO_ENCODE_OFFSET_BITS_64 = 7; + static final int BITS_TO_ENCODE_OFFSET_BITS_32 = 6; + static final int BITS_TO_ENCODE_OFFSET_BITS_16 = 5; + + private static final ValueLayout.OfLong LE_LONG = + ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public PcoEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_PCO; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public Array decode(DecodeContext ctx) { + PcoMetadata meta = parseMeta(ctx); + validateHeader(meta); + + DType dtype = ctx.dtype(); + if (!(dtype instanceof DType.Primitive dt)) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco decode requires Primitive dtype, got: " + dtype); + } + PType ptype = dt.ptype(); + int dtypeSize = dtypeSize(ptype); + + long n = ctx.rowCount(); + + BoolArray validity = null; + long validCount = n; + if (ctx.node().children().length > 0) { + Array validityArr = ctx.decodeChild(0, new DType.Bool(false), n); + if (!(validityArr instanceof BoolArray ba)) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco validity child must be Bool, got: " + validityArr.getClass().getSimpleName()); + } + validity = ba; + validCount = 0; + for (long i = 0; i < n; i++) { + if (validity.getBoolean(i)) { + validCount++; + } + } + } + + MemorySegment rawLatents = ctx.arena().allocate(validCount * Long.BYTES); + + int nChunks = meta.chunks().size(); + int bufIdx = 0; + long rawByteOffset = 0L; + + long[] batchLowers1 = new long[PcoTansDecoder.BATCH_N]; + int[] batchOffsetBits1 = new int[PcoTansDecoder.BATCH_N]; + long[] batchLowers2 = new long[PcoTansDecoder.BATCH_N]; + int[] batchOffsetBits2 = new 
int[PcoTansDecoder.BATCH_N]; + + for (int c = 0; c < nChunks; c++) { + PcoChunkInfo chunkInfo = meta.chunks().get(c); + MemorySegment chunkMetaBuf = ctx.buffer(bufIdx++); + PcoChunkMeta chunkMeta = readChunkMeta(chunkMetaBuf, dtypeSize); + + int mode = chunkMeta.mode(); + int deltaVariant = chunkMeta.deltaVariant(); + long chunkStartOffset = rawByteOffset; + + int chunkN = 0; + for (int p = 0; p < chunkInfo.pages().size(); p++) { + chunkN += chunkInfo.pages().get(p).n_values(); + } + + if (deltaVariant == 3) { + PcoTansDecoder primaryTans = PcoTansDecoder.build( + chunkMeta.ansSizeLog(), chunkMeta.bins()); + for (int p = 0; p < chunkInfo.pages().size(); p++) { + int pageN = chunkInfo.pages().get(p).n_values(); + MemorySegment pageBuf = ctx.buffer(bufIdx++); + rawByteOffset = decodeConv1Page( + primaryTans, chunkMeta.ansSizeLog(), + chunkMeta.conv1Weights().length, + chunkMeta.conv1Quantization(), chunkMeta.conv1Bias(), + chunkMeta.conv1Weights(), + dtypeSize, pageBuf, pageN, + rawLatents, rawByteOffset, + batchLowers1, batchOffsetBits1); + } + } else if (deltaVariant == 2) { + if (mode != 0) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco Lookback delta with non-Classic mode " + mode + " not yet implemented"); + } + PcoTansDecoder deltaTans = PcoTansDecoder.build( + chunkMeta.deltaAnsSizeLog(), chunkMeta.deltaBins()); + PcoTansDecoder primaryTans = PcoTansDecoder.build( + chunkMeta.ansSizeLog(), chunkMeta.bins()); + int stateN = 1 << chunkMeta.stateNLog(); + int windowN = 1 << chunkMeta.windowNLog(); + long mid = typeMid(dtypeSize); + long mask = typeMask(dtypeSize); + for (int p = 0; p < chunkInfo.pages().size(); p++) { + int pageN = chunkInfo.pages().get(p).n_values(); + MemorySegment pageBuf = ctx.buffer(bufIdx++); + rawByteOffset = decodeLookbackPage( + deltaTans, chunkMeta.deltaAnsSizeLog(), + primaryTans, chunkMeta.ansSizeLog(), + stateN, windowN, mid, mask, + dtypeSize, pageBuf, pageN, + rawLatents, rawByteOffset, ctx.arena(), + batchLowers1, 
batchOffsetBits1, + batchLowers2, batchOffsetBits2); + } + } else if (mode == 0 || mode == 4) { + int primaryDtypeSize = (mode == 4) ? 32 : dtypeSize; + PcoTansDecoder tans = PcoTansDecoder.build(chunkMeta.ansSizeLog(), chunkMeta.bins()); + for (int p = 0; p < chunkInfo.pages().size(); p++) { + int pageN = chunkInfo.pages().get(p).n_values(); + MemorySegment pageBuf = ctx.buffer(bufIdx++); + rawByteOffset = decodeClassicPage(tans, chunkMeta.ansSizeLog(), + chunkMeta.deltaOrder(), primaryDtypeSize, + pageBuf, pageN, rawLatents, rawByteOffset, + batchLowers1, batchOffsetBits1); + } + if (mode == 4) { + combineDict(chunkMeta.dict(), chunkN, rawLatents, chunkStartOffset); + } + } else { + long base = chunkMeta.base(); + int primaryAnsSizeLog = chunkMeta.ansSizeLog(); + int secondaryAnsSizeLog = chunkMeta.secondaryAnsSizeLog(); + PcoTansDecoder primaryTans = PcoTansDecoder.build(primaryAnsSizeLog, chunkMeta.bins()); + PcoTansDecoder secondaryTans = PcoTansDecoder.build(secondaryAnsSizeLog, chunkMeta.secondaryBins()); + int deltaOrder = chunkMeta.deltaOrder(); + int secondaryDeltaOrder = chunkMeta.secondaryUsesDelta() ? 
deltaOrder : 0; + + MemorySegment rawAdjs = ctx.arena().allocate((long) chunkN * Long.BYTES); + long adjByteOffset = 0L; + for (int p = 0; p < chunkInfo.pages().size(); p++) { + int pageN = chunkInfo.pages().get(p).n_values(); + MemorySegment pageBuf = ctx.buffer(bufIdx++); + decodeIntMultPage(primaryTans, primaryAnsSizeLog, deltaOrder, + secondaryTans, secondaryAnsSizeLog, secondaryDeltaOrder, + dtypeSize, pageBuf, pageN, + rawLatents, rawByteOffset, + rawAdjs, adjByteOffset, + batchLowers1, batchOffsetBits1, + batchLowers2, batchOffsetBits2); + rawByteOffset += (long) pageN * Long.BYTES; + adjByteOffset += (long) pageN * Long.BYTES; + } + + if (mode == 1) { + long mask = typeMask(dtypeSize); + for (int i = 0; i < chunkN; i++) { + long off = chunkStartOffset + (long) i * Long.BYTES; + long mult = rawLatents.get(LE_LONG, off); + long adj = rawAdjs.get(LE_LONG, (long) i * Long.BYTES); + rawLatents.set(LE_LONG, off, (mult * base + adj) & mask); + } + } else if (mode == 2) { + combineFloatMult(ptype, base, chunkN, rawLatents, chunkStartOffset, rawAdjs); + } else { + combineFloatQuant(ptype, chunkMeta.quantizeK(), chunkN, rawLatents, chunkStartOffset, rawAdjs); + } + } + } + + int elemBytes = ptype.byteSize(); + MemorySegment compactOut = ctx.arena().allocate(validCount * elemBytes); + for (long i = 0; i < validCount; i++) { + long latent = rawLatents.get(LE_LONG, i * Long.BYTES); + PTypeIO.set(compactOut, i * elemBytes, ptype, fromLatentOrdered(latent, ptype)); + } + + if (validity == null) { + return toArray(dtype, n, compactOut); + } + + MemorySegment fullOut = ctx.arena().allocate(n * elemBytes); + long srcOff = 0; + for (long i = 0; i < n; i++) { + if (validity.getBoolean(i)) { + MemorySegment.copy(compactOut, srcOff, fullOut, i * elemBytes, elemBytes); + srcOff += elemBytes; + } + } + DType nonNullDtype = new DType.Primitive(ptype, false); + return new MaskedArray(toArray(nonNullDtype, n, fullOut), validity); + } + + private static long 
decodeClassicPage(PcoTansDecoder tans, int ansSizeLog, int deltaOrder, + int primaryDtypeSize, MemorySegment pageBuf, int pageN, + MemorySegment rawLatents, long rawByteOffset, + long[] batchLowers, int[] batchOffsetBits) { + LeBitReader pageReader = new LeBitReader(pageBuf); + + long[] moments = new long[deltaOrder]; + for (int m = 0; m < deltaOrder; m++) { + moments[m] = pageReader.readBits(primaryDtypeSize); + } + + int[] stateIdxs = new int[PcoTansDecoder.ANS_INTERLEAVING]; + for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) { + stateIdxs[i] = (int) pageReader.readBits(ansSizeLog); + } + pageReader.alignToByte(); + + int decodedN = pageN - deltaOrder; + tans.decodePage(pageReader, stateIdxs, decodedN, rawLatents, rawByteOffset, + batchLowers, batchOffsetBits); + + if (deltaOrder > 0) { + applyConsecutiveDelta(rawLatents, rawByteOffset, pageN, moments, primaryDtypeSize); + } + + return rawByteOffset + (long) pageN * Long.BYTES; + } + + private static void decodeIntMultPage( + PcoTansDecoder primaryTans, int primaryAnsSizeLog, int deltaOrder, + PcoTansDecoder secondaryTans, int secondaryAnsSizeLog, int secondaryDeltaOrder, + int dtypeSize, MemorySegment pageBuf, int pageN, + MemorySegment rawMults, long multsOffset, + MemorySegment rawAdjs, long adjsOffset, + long[] batchLowersP, int[] batchOffsetBitsP, + long[] batchLowersS, int[] batchOffsetBitsS) { + LeBitReader pageReader = new LeBitReader(pageBuf); + + long[] primaryMoments = new long[deltaOrder]; + for (int m = 0; m < deltaOrder; m++) { + primaryMoments[m] = pageReader.readBits(dtypeSize); + } + int[] primaryStateIdxs = new int[PcoTansDecoder.ANS_INTERLEAVING]; + for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) { + primaryStateIdxs[i] = (int) pageReader.readBits(primaryAnsSizeLog); + } + + long[] secondaryMoments = new long[secondaryDeltaOrder]; + for (int m = 0; m < secondaryDeltaOrder; m++) { + secondaryMoments[m] = pageReader.readBits(dtypeSize); + } + int[] secondaryStateIdxs = new 
int[PcoTansDecoder.ANS_INTERLEAVING]; + for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) { + secondaryStateIdxs[i] = (int) pageReader.readBits(secondaryAnsSizeLog); + } + + pageReader.alignToByte(); + + int nRemaining = pageN; + long primaryPos = multsOffset; + long secondaryPos = adjsOffset; + + while (nRemaining > 0) { + int batchN = Math.min(nRemaining, PcoTansDecoder.BATCH_N); + int primaryPreDeltaN = Math.clamp(nRemaining - deltaOrder, 0, batchN); + int secondaryPreDeltaN = Math.clamp(nRemaining - secondaryDeltaOrder, 0, batchN); + + primaryTans.decodeBatch(pageReader, primaryStateIdxs, primaryPreDeltaN, + batchLowersP, batchOffsetBitsP, rawMults, primaryPos); + secondaryTans.decodeBatch(pageReader, secondaryStateIdxs, secondaryPreDeltaN, + batchLowersS, batchOffsetBitsS, rawAdjs, secondaryPos); + + primaryPos += (long) batchN * Long.BYTES; + secondaryPos += (long) batchN * Long.BYTES; + nRemaining -= batchN; + } + + if (deltaOrder > 0) { + applyConsecutiveDelta(rawMults, multsOffset, pageN, primaryMoments, dtypeSize); + } + if (secondaryDeltaOrder > 0) { + applyConsecutiveDelta(rawAdjs, adjsOffset, pageN, secondaryMoments, dtypeSize); + } + } + + private static long decodeLookbackPage( + PcoTansDecoder deltaTans, int deltaAnsSizeLog, + PcoTansDecoder primaryTans, int primaryAnsSizeLog, + int stateN, int windowN, long mid, long mask, + int dtypeSize, MemorySegment pageBuf, int pageN, + MemorySegment rawLatents, long latentsOffset, + SegmentAllocator arena, + long[] batchLowersD, int[] batchOffsetBitsD, + long[] batchLowersP, int[] batchOffsetBitsP) { + if (pageN < stateN) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco corrupt lookback page: stateN " + stateN + " exceeds pageN " + pageN); + } + LeBitReader pageReader = new LeBitReader(pageBuf); + + int[] deltaStateIdxs = new int[PcoTansDecoder.ANS_INTERLEAVING]; + for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) { + deltaStateIdxs[i] = (int) pageReader.readBits(deltaAnsSizeLog); + 
} + + long[] initialState = new long[stateN]; + for (int m = 0; m < stateN; m++) { + initialState[m] = pageReader.readBits(dtypeSize); + } + int[] primaryStateIdxs = new int[PcoTansDecoder.ANS_INTERLEAVING]; + for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) { + primaryStateIdxs[i] = (int) pageReader.readBits(primaryAnsSizeLog); + } + pageReader.alignToByte(); + + int decodeN = pageN - stateN; + if (decodeN > 1 << 23) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco corrupt lookback page: decodeN " + decodeN + " exceeds max 8388608"); + } + MemorySegment rawLookbacks = arena.allocate((long) decodeN * Long.BYTES); + MemorySegment rawResiduals = arena.allocate((long) decodeN * Long.BYTES); + + int remaining = decodeN; + long dPos = 0L; + long pPos = 0L; + while (remaining > 0) { + int batchN = Math.min(remaining, PcoTansDecoder.BATCH_N); + deltaTans.decodeBatch(pageReader, deltaStateIdxs, batchN, + batchLowersD, batchOffsetBitsD, rawLookbacks, dPos); + primaryTans.decodeBatch(pageReader, primaryStateIdxs, batchN, + batchLowersP, batchOffsetBitsP, rawResiduals, pPos); + dPos += (long) batchN * Long.BYTES; + pPos += (long) batchN * Long.BYTES; + remaining -= batchN; + } + + for (int i = 0; i < decodeN; i++) { + long off = (long) i * Long.BYTES; + rawResiduals.set(LE_LONG, off, (rawResiduals.get(LE_LONG, off) ^ mid) & mask); + } + + for (int i = 0; i < stateN; i++) { + rawLatents.set(LE_LONG, latentsOffset + (long) i * Long.BYTES, initialState[i] & mask); + } + + if (stateN > windowN) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco corrupt lookback: stateN " + stateN + " exceeds windowN " + windowN); + } + long[] window = new long[windowN + decodeN]; + for (int i = 0; i < stateN; i++) { + window[windowN - stateN + i] = initialState[i] & mask; + } + for (int i = 0; i < decodeN; i++) { + int lb = (int) rawLookbacks.get(LE_LONG, (long) i * Long.BYTES); + if (lb < 1 || lb > windowN) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco 
corrupt lookback index " + lb + " not in [1, " + windowN + "]"); + } + long decoded = (rawResiduals.get(LE_LONG, (long) i * Long.BYTES) + window[windowN + i - lb]) & mask; + window[windowN + i] = decoded; + rawLatents.set(LE_LONG, latentsOffset + (long) (stateN + i) * Long.BYTES, decoded); + } + + return latentsOffset + (long) pageN * Long.BYTES; + } + + private static long decodeConv1Page( + PcoTansDecoder tans, int ansSizeLog, + int order, int quantization, long bias, long[] weights, + int dtypeSize, MemorySegment pageBuf, int pageN, + MemorySegment rawLatents, long latentsOffset, + long[] batchLowers, int[] batchOffsetBits) { + LeBitReader pageReader = new LeBitReader(pageBuf); + + long[] state = new long[order]; + for (int i = 0; i < order; i++) { + state[i] = pageReader.readBits(dtypeSize); + } + int[] stateIdxs = new int[PcoTansDecoder.ANS_INTERLEAVING]; + for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) { + stateIdxs[i] = (int) pageReader.readBits(ansSizeLog); + } + pageReader.alignToByte(); + + int decodeN = pageN - order; + long mid = typeMid(dtypeSize); + long mask = typeMask(dtypeSize); + + for (int i = 0; i < order; i++) { + rawLatents.set(LE_LONG, latentsOffset + (long) i * Long.BYTES, state[i]); + } + + tans.decodePage(pageReader, stateIdxs, decodeN, rawLatents, + latentsOffset + (long) order * Long.BYTES, + batchLowers, batchOffsetBits); + + for (int i = order; i < pageN; i++) { + long off = latentsOffset + (long) i * Long.BYTES; + rawLatents.set(LE_LONG, off, (rawLatents.get(LE_LONG, off) ^ mid) & mask); + } + + for (int i = order; i < pageN; i++) { + long pred = predictConv1(rawLatents, latentsOffset, i, order, + weights, bias, quantization, mask, dtypeSize); + long off = latentsOffset + (long) i * Long.BYTES; + rawLatents.set(LE_LONG, off, (rawLatents.get(LE_LONG, off) + pred) & mask); + } + + return latentsOffset + (long) pageN * Long.BYTES; + } + + private static long predictConv1(MemorySegment seg, long baseOff, int pos, int order, + 
long[] weights, long bias, int quantization, long mask, int dtypeSize) { + long s = (dtypeSize == 16) ? (int) bias : bias; + for (int k = 0; k < order; k++) { + long w = (dtypeSize == 16) ? (int) weights[k] : weights[k]; + long l = seg.get(LE_LONG, baseOff + (long) (pos - order + k) * Long.BYTES); + s += w * l; + } + if (s < 0) { + s = 0; + } + return (s >> quantization) & mask; + } + + private static long fromLatentOrdered(long latent, PType ptype) { + return switch (ptype) { + case I16 -> latent ^ 0x8000L; + case I32 -> latent ^ 0x80000000L; + case I64 -> latent ^ Long.MIN_VALUE; + case F32 -> { + long l32 = latent & 0xFFFFFFFFL; + yield (l32 & 0x80000000L) != 0 ? l32 ^ 0x80000000L : l32 ^ 0xFFFFFFFFL; + } + case F64 -> (latent & Long.MIN_VALUE) != 0 ? latent ^ Long.MIN_VALUE : ~latent; + default -> latent; + }; + } + + private static void applyConsecutiveDelta(MemorySegment rawLatents, long offset, + int pageN, long[] moments, int dtypeSize) { + long mid = typeMid(dtypeSize); + long mask = typeMask(dtypeSize); + + for (int i = 0; i < pageN; i++) { + long byteOff = offset + (long) i * Long.BYTES; + rawLatents.set(LE_LONG, byteOff, (rawLatents.get(LE_LONG, byteOff) ^ mid) & mask); + } + + for (int m = moments.length - 1; m >= 0; m--) { + long moment = moments[m] & mask; + for (int i = 0; i < pageN; i++) { + long byteOff = offset + (long) i * Long.BYTES; + long tmp = rawLatents.get(LE_LONG, byteOff); + rawLatents.set(LE_LONG, byteOff, moment); + moment = (moment + tmp) & mask; + } + } + } + + private static long typeMid(int dtypeSize) { + return switch (dtypeSize) { + case 64 -> Long.MIN_VALUE; + case 32 -> 0x80000000L; + case 16 -> 0x8000L; + default -> throw new VortexException(EncodingId.VORTEX_PCO, + "pco: invalid dtypeSize " + dtypeSize); + }; + } + + private static long typeMask(int dtypeSize) { + return switch (dtypeSize) { + case 64 -> -1L; + case 32 -> 0xFFFFFFFFL; + case 16 -> 0xFFFFL; + default -> throw new VortexException(EncodingId.VORTEX_PCO, + "pco: 
invalid dtypeSize " + dtypeSize); + }; + } + + private static int dtypeSize(PType ptype) { + return switch (ptype) { + case I16, U16 -> 16; + case I32, U32, F32 -> 32; + case I64, U64, F64 -> 64; + default -> throw new VortexException(EncodingId.VORTEX_PCO, + "pco: unsupported ptype " + ptype); + }; + } + + private static int bitsToEncodeOffsetBits(int dtypeSize) { + return switch (dtypeSize) { + case 64 -> BITS_TO_ENCODE_OFFSET_BITS_64; + case 32 -> BITS_TO_ENCODE_OFFSET_BITS_32; + case 16 -> BITS_TO_ENCODE_OFFSET_BITS_16; + default -> throw new VortexException(EncodingId.VORTEX_PCO, + "pco: invalid dtypeSize " + dtypeSize); + }; + } + + private static Array toArray(DType dtype, long n, MemorySegment out) { + PType ptype = ((DType.Primitive) dtype).ptype(); + return switch (ptype) { + case I16, U16 -> new ShortArray(dtype, n, out); + case I32, U32 -> new IntArray(dtype, n, out); + case F32 -> new FloatArray(dtype, n, out); + case I64, U64 -> new LongArray(dtype, n, out); + case F64 -> new DoubleArray(dtype, n, out); + default -> throw new VortexException(EncodingId.VORTEX_PCO, + "pco: unsupported ptype " + ptype); + }; + } + + private static float intFloatFromLatentF32(long l) { + long mid = 0x80000000L; + boolean negative = (l < mid); + long absInt = negative ? (0x7FFFFFFFL - l) : (l ^ 0x80000000L); + long gpi = 1L << 24; + float absFloat = (absInt < gpi) ? (float) absInt + : Float.intBitsToFloat(0x4B800000 + (int) (absInt - gpi)); + return negative ? -absFloat : absFloat; + } + + private static double intFloatFromLatentF64(long l) { + boolean negative = (l >= 0); + long absInt = negative ? (Long.MAX_VALUE - l) : (l ^ Long.MIN_VALUE); + long gpi = 1L << 53; + double absFloat = (absInt < gpi) ? (double) absInt + : Double.longBitsToDouble(0x4340000000000000L + (absInt - gpi)); + return negative ? 
-absFloat : absFloat; + } + + private static long toLatentOrderedF32(float f) { + int bits = Float.floatToRawIntBits(f); + if ((bits & 0x80000000) != 0) { + return (~bits) & 0xFFFFFFFFL; + } else { + return (bits ^ 0x80000000) & 0xFFFFFFFFL; + } + } + + private static long toLatentOrderedF64(double d) { + long bits = Double.doubleToRawLongBits(d); + if ((bits & Long.MIN_VALUE) != 0) { + return ~bits; + } else { + return bits ^ Long.MIN_VALUE; + } + } + + private static void combineFloatMult(PType ptype, long baseLatent, int chunkN, + MemorySegment rawLatents, long multsOffset, MemorySegment rawAdjs) { + if (ptype == PType.F32) { + float baseFloat = Float.intBitsToFloat((int) fromLatentOrdered(baseLatent, PType.F32)); + for (int i = 0; i < chunkN; i++) { + long off = multsOffset + (long) i * Long.BYTES; + long mult = rawLatents.get(LE_LONG, off); + long adj = rawAdjs.get(LE_LONG, (long) i * Long.BYTES); + long unadjusted = toLatentOrderedF32(intFloatFromLatentF32(mult) * baseFloat); + rawLatents.set(LE_LONG, off, (unadjusted + adj) & 0xFFFFFFFFL); + } + } else { + double baseDouble = Double.longBitsToDouble(fromLatentOrdered(baseLatent, PType.F64)); + for (int i = 0; i < chunkN; i++) { + long off = multsOffset + (long) i * Long.BYTES; + long mult = rawLatents.get(LE_LONG, off); + long adj = rawAdjs.get(LE_LONG, (long) i * Long.BYTES); + long unadjusted = toLatentOrderedF64(intFloatFromLatentF64(mult) * baseDouble); + rawLatents.set(LE_LONG, off, unadjusted + adj); + } + } + } + + private static void combineFloatQuant(PType ptype, int k, int chunkN, + MemorySegment rawLatents, long multsOffset, MemorySegment rawAdjs) { + if (ptype == PType.F32) { + long signCutoff = 0x80000000L >>> k; + long lowestKBitsMax = (1L << k) - 1L; + for (int i = 0; i < chunkN; i++) { + long off = multsOffset + (long) i * Long.BYTES; + long quantum = rawLatents.get(LE_LONG, off); + long adj = rawAdjs.get(LE_LONG, (long) i * Long.BYTES); + long lowestKBits = (quantum >= signCutoff) ? 
adj : (lowestKBitsMax - adj); + rawLatents.set(LE_LONG, off, (quantum << k) + lowestKBits); + } + } else { + long signCutoff = Long.MIN_VALUE >>> k; + long lowestKBitsMax = (1L << k) - 1L; + for (int i = 0; i < chunkN; i++) { + long off = multsOffset + (long) i * Long.BYTES; + long quantum = rawLatents.get(LE_LONG, off); + long adj = rawAdjs.get(LE_LONG, (long) i * Long.BYTES); + boolean isPos = Long.compareUnsigned(quantum, signCutoff) >= 0; + long lowestKBits = isPos ? adj : (lowestKBitsMax - adj); + rawLatents.set(LE_LONG, off, (quantum << k) + lowestKBits); + } + } + } + + private static void combineDict(long[] dict, int chunkN, + MemorySegment rawLatents, long offset) { + for (int i = 0; i < chunkN; i++) { + long off = offset + (long) i * Long.BYTES; + int idx = (int) rawLatents.get(LE_LONG, off); + if (idx < 0 || idx >= dict.length) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco dict index " + idx + " out of range [0, " + dict.length + ")"); + } + rawLatents.set(LE_LONG, off, dict[idx]); + } + } + + private static PcoChunkMeta readChunkMeta(MemorySegment buf, int dtypeSize) { + LeBitReader r = new LeBitReader(buf); + + int modeNibble = (int) r.readBits(4); + long base = 0L; + int quantizeK = 0; + long[] dict = null; + if (modeNibble == 1 || modeNibble == 2) { + base = r.readBits(dtypeSize); + } else if (modeNibble == 3) { + quantizeK = (int) r.readBits(8); + } else if (modeNibble == 4) { + int nUnique = (int) r.readBits(25); + if (nUnique > 1 << 16) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco dict nUnique " + nUnique + " exceeds max 65536"); + } + r.alignToByte(); + dict = new long[nUnique]; + for (int i = 0; i < nUnique; i++) { + dict[i] = r.readBits(dtypeSize); + } + } else if (modeNibble != 0) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco mode " + modeNibble + " not yet implemented " + + "(Classic=0, IntMult=1, FloatMult=2, FloatQuant=3, Dict=4 supported)"); + } + + int deltaVariant = (int) r.readBits(4); + int 
deltaOrder = 0; + boolean secondaryUsesDelta = false; + int windowNLog = 0; + int stateNLog = 0; + int conv1Quantization = 0; + long conv1Bias = 0L; + long[] conv1Weights = new long[0]; + if (deltaVariant == 0) { + // NoOp + } else if (deltaVariant == 1) { + deltaOrder = (int) r.readBits(3); + secondaryUsesDelta = r.readBits(1) != 0; + } else if (deltaVariant == 2) { + windowNLog = 1 + (int) r.readBits(5); + if (windowNLog > 24) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco lookback windowNLog " + windowNLog + " exceeds max 24"); + } + stateNLog = (int) r.readBits(4); + secondaryUsesDelta = r.readBits(1) != 0; + } else if (deltaVariant == 3) { + if (dtypeSize == 64) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco Conv1 delta not supported for 64-bit dtypes (I64/U64/F64)"); + } + conv1Quantization = (int) r.readBits(5); + conv1Bias = r.readBits(64) ^ Long.MIN_VALUE; + int conv1Order = 1 + (int) r.readBits(5); + conv1Weights = new long[conv1Order]; + for (int i = 0; i < conv1Order; i++) { + conv1Weights[i] = (int) (r.readBits(32) ^ 0x80000000L); + } + } else { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco delta variant " + deltaVariant + " not yet implemented " + + "(NoOp=0, Consecutive=1, Lookback=2, Conv1=3 supported)"); + } + + int deltaAnsSizeLog = 0; + PcoBin[] deltaBins = new PcoBin[0]; + if (deltaVariant == 2) { + deltaAnsSizeLog = (int) r.readBits(4); + int nDeltaBins = (int) r.readBits(15); + deltaBins = readBins(r, nDeltaBins, deltaAnsSizeLog, 32); + } + + int primaryDtypeSize = (modeNibble == 4) ? 
32 : dtypeSize; + int ansSizeLog = (int) r.readBits(4); + int nBins = (int) r.readBits(15); + PcoBin[] bins = readBins(r, nBins, ansSizeLog, primaryDtypeSize); + + int secondaryAnsSizeLog = 0; + PcoBin[] secondaryBins = new PcoBin[0]; + if (modeNibble == 1 || modeNibble == 2 || modeNibble == 3) { + secondaryAnsSizeLog = (int) r.readBits(4); + int nSecondaryBins = (int) r.readBits(15); + secondaryBins = readBins(r, nSecondaryBins, secondaryAnsSizeLog, dtypeSize); + } + r.alignToByte(); + + return new PcoChunkMeta(modeNibble, base, quantizeK, dict, + deltaVariant, deltaOrder, secondaryUsesDelta, + windowNLog, stateNLog, deltaAnsSizeLog, deltaBins, + conv1Quantization, conv1Bias, conv1Weights, + ansSizeLog, bins, secondaryAnsSizeLog, secondaryBins); + } + + private static PcoBin[] readBins(LeBitReader r, int nBins, int ansSizeLog, int dtypeSize) { + PcoBin[] bins = new PcoBin[nBins]; + int offsetBitsWidth = bitsToEncodeOffsetBits(dtypeSize); + for (int b = 0; b < nBins; b++) { + int weight = (int) r.readBits(ansSizeLog) + 1; + long lower = r.readBits(dtypeSize); + int offsetBits = (int) r.readBits(offsetBitsWidth); + bins[b] = new PcoBin(weight, lower, offsetBits); + } + return bins; + } + + private static PcoMetadata parseMeta(DecodeContext ctx) { + ByteBuffer raw = ctx.metadata(); + if (raw == null) { + throw new VortexException(EncodingId.VORTEX_PCO, "missing PcoMetadata"); + } + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(raw.duplicate()); + return PcoMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_PCO, + "invalid PcoMetadata: " + e.getMessage()); + } + } + + private static void validateHeader(PcoMetadata meta) { + byte[] header = meta.header(); + if (header.length < 2) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco header too short: " + header.length + " bytes"); + } + if (header[0] != PCO_FORMAT_MAJOR || header[1] != PCO_FORMAT_MINOR) { + throw new 
VortexException(EncodingId.VORTEX_PCO, + String.format("unsupported pco format version %02x.%02x (expected %02x.%02x)", + header[0] & 0xFF, header[1] & 0xFF, + PCO_FORMAT_MAJOR & 0xFF, PCO_FORMAT_MINOR & 0xFF)); + } + } + + private record PcoChunkMeta(int mode, long base, int quantizeK, long[] dict, + int deltaVariant, int deltaOrder, boolean secondaryUsesDelta, + int windowNLog, int stateNLog, int deltaAnsSizeLog, PcoBin[] deltaBins, + int conv1Quantization, long conv1Bias, long[] conv1Weights, + int ansSizeLog, PcoBin[] bins, + int secondaryAnsSizeLog, PcoBin[] secondaryBins) { + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PrimitiveEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PrimitiveEncodingDecoder.java new file mode 100644 index 00000000..743ff060 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PrimitiveEncodingDecoder.java @@ -0,0 +1,62 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.Float16Array; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.EncodingId; + +import java.lang.foreign.MemorySegment; + +/// Read-only decoder for {@code vortex.primitive} — raw little-endian primitive arrays. +public final class PrimitiveEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public PrimitiveEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_PRIMITIVE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public Array decode(DecodeContext ctx) { + MemorySegment buf = ctx.buffer(0); + long n = ctx.rowCount(); + DType dt = ctx.dtype(); + PType ptype = ((DType.Primitive) dt).ptype(); + Array values = switch (ptype) { + case I64, U64 -> new LongArray(dt, n, buf); + case I32, U32 -> new IntArray(dt, n, buf); + case F64 -> new DoubleArray(dt, n, buf); + case F32 -> new FloatArray(dt, n, buf); + case I16, U16 -> new ShortArray(dt, n, buf); + case I8, U8 -> new ByteArray(dt, n, buf); + case F16 -> new Float16Array(dt, n, buf); + }; + if (ctx.node().children().length == 1) { + Array va = ctx.decodeChild(0, new DType.Bool(false), n); + if (!(va instanceof BoolArray validity)) { + throw new VortexException(EncodingId.VORTEX_PRIMITIVE, + "validity child decoded to unexpected type: " + va.getClass().getSimpleName()); + } + return new MaskedArray(values, validity); + } + return values; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/RleEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/RleEncodingDecoder.java new file mode 100644 index 00000000..d2012c25 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/RleEncodingDecoder.java @@ -0,0 +1,228 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.Float16Array; +import 
io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.RLEMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code fastlanes.rle}. +public final class RleEncodingDecoder implements EncodingDecoder { + + private static final int FL_CHUNK_SIZE = 1024; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public RleEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_RLE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive p && !p.ptype().isFloating(); + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + RLEMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = RLEMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.FASTLANES_RLE, "invalid metadata", e); + } + + long valuesLen = meta.values_len(); + long indicesLen = meta.indices_len(); + PType indicesPtype = PType.fromOrdinal(meta.indices_ptype().value()); + long offsetsLen = meta.values_idx_offsets_len(); + PType offsetsPtype = PType.fromOrdinal(meta.values_idx_offsets_ptype().value()); + int offset = (int) meta.offset(); + + long rowCount = ctx.rowCount(); + if (rowCount == 0 || indicesLen == 0) { + return emptyArray(ctx); + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { 
+ throw new VortexException(EncodingId.FASTLANES_RLE, "expected Primitive dtype, got " + ctx.dtype()); + } + PType ptype = p.ptype(); + + DType valuesDtype = new DType.Primitive(ptype, false); + DType indicesDtype = new DType.Primitive(indicesPtype, false); + DType offsetsDtype = new DType.Primitive(offsetsPtype, false); + + Array indicesRaw = ctx.decodeChild(1, indicesDtype, indicesLen); + + BoolArray indicesValidity = null; + Array indicesArr = indicesRaw; + if (indicesRaw instanceof MaskedArray masked) { + indicesArr = masked.inner(); + indicesValidity = masked.validity(); + } + + long[] values = readLongs(ctx.decodeChildSegment(0, valuesDtype, valuesLen), (int) valuesLen, ptype); + int[] indices = readIndices(ArraySegments.of(indicesArr), (int) indicesLen, indicesPtype); + long[] valuesIdxOffsets = readUnsignedLongs(ctx.decodeChildSegment(2, offsetsDtype, offsetsLen), (int) offsetsLen, offsetsPtype); + + int numChunks = (int) (indicesLen / FL_CHUNK_SIZE); + int chunkEnd = (int) ((offset + rowCount + FL_CHUNK_SIZE - 1) / FL_CHUNK_SIZE); + chunkEnd = Math.min(chunkEnd, numChunks); + + long[] decoded = new long[chunkEnd * FL_CHUNK_SIZE]; + long firstOffset = valuesLen > 0 ? valuesIdxOffsets[0] : 0L; + + for (int chunkIdx = 0; chunkIdx < chunkEnd; chunkIdx++) { + long valueIdxOffset = valuesIdxOffsets[chunkIdx] - firstOffset; + long nextValueIdxOffset = (chunkIdx + 1 < numChunks) + ? (valuesIdxOffsets[chunkIdx + 1] - firstOffset) + : valuesLen; + int numChunkValues = (int) (nextValueIdxOffset - valueIdxOffset); + + int chunkBase = chunkIdx * FL_CHUNK_SIZE; + if (numChunkValues <= 1) { + long fillVal = numChunkValues == 1 ? 
values[(int) valueIdxOffset] : 0L; + for (int i = 0; i < FL_CHUNK_SIZE; i++) { + decoded[chunkBase + i] = fillVal; + } + } else { + for (int i = 0; i < FL_CHUNK_SIZE; i++) { + int idx = indices[chunkBase + i]; + if (idx >= numChunkValues) { + idx = numChunkValues - 1; + } + decoded[chunkBase + i] = values[(int) valueIdxOffset + idx]; + } + } + } + + MemorySegment seg = fromLongs(decoded, offset, (int) rowCount, ptype, ctx.arena()); + Array result = toArray(ctx.dtype(), rowCount, seg, ptype); + if (indicesValidity == null) { + return result; + } + int validityBytes = (int) ((rowCount + 7) / 8); + MemorySegment validityBuf = ctx.arena().allocate(validityBytes); + for (long j = 0; j < rowCount; j++) { + if (indicesValidity.getBoolean(offset + j)) { + int byteIdx = (int) (j >>> 3); + byte current = validityBuf.get(ValueLayout.JAVA_BYTE, byteIdx); + validityBuf.set(ValueLayout.JAVA_BYTE, byteIdx, (byte) (current | (1 << (j & 7)))); + } + } + BoolArray outputValidity = new BoolArray(new DType.Bool(false), rowCount, validityBuf); + return new MaskedArray(result, outputValidity); + } + + private static Array emptyArray(DecodeContext ctx) { + MemorySegment empty = ctx.arena().allocate(0); + DType dt = ctx.dtype(); + PType ptype = ((DType.Primitive) dt).ptype(); + return toArray(dt, 0L, empty, ptype); + } + + private static Array toArray(DType dtype, long n, MemorySegment seg, PType ptype) { + return switch (ptype) { + case I64, U64 -> new LongArray(dtype, n, seg); + case I32, U32 -> new IntArray(dtype, n, seg); + case I16, U16 -> new ShortArray(dtype, n, seg); + case I8, U8 -> new ByteArray(dtype, n, seg); + case F64 -> new DoubleArray(dtype, n, seg); + case F32 -> new FloatArray(dtype, n, seg); + case F16 -> new Float16Array(dtype, n, seg); + }; + } + + private static long[] readLongs(MemorySegment buf, int count, PType ptype) { + long[] out = new long[count]; + int elemSize = ptype.byteSize(); + long cap = SegmentBroadcast.capacity(buf, elemSize); + for (int i = 0; i < 
count; i++) { + long off = (i % cap) * elemSize; + out[i] = switch (ptype) { + case I8 -> buf.get(ValueLayout.JAVA_BYTE, off); + case U8 -> Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, off)); + case I16 -> buf.get(PTypeIO.LE_SHORT, off); + case U16, F16 -> Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, off)); + case I32 -> buf.get(PTypeIO.LE_INT, off); + case U32 -> Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, off)); + case I64, U64 -> buf.get(PTypeIO.LE_LONG, off); + case F32 -> Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, off)); + case F64 -> buf.get(PTypeIO.LE_LONG, off); + }; + } + return out; + } + + private static int[] readIndices(MemorySegment buf, int count, PType indicesPtype) { + int[] out = new int[count]; + int elemSize = indicesPtype.byteSize(); + long cap = SegmentBroadcast.capacity(buf, elemSize); + switch (indicesPtype) { + case U8 -> { + for (int i = 0; i < count; i++) { + out[i] = Byte.toUnsignedInt(buf.get(ValueLayout.JAVA_BYTE, i % cap)); + } + } + case U16 -> { + for (int i = 0; i < count; i++) { + out[i] = Short.toUnsignedInt(buf.get(PTypeIO.LE_SHORT, (i % cap) * 2)); + } + } + default -> + throw new VortexException(EncodingId.FASTLANES_RLE, "unsupported indices ptype: " + indicesPtype); + } + return out; + } + + private static long[] readUnsignedLongs(MemorySegment buf, int count, PType ptype) { + long[] out = new long[count]; + int elemSize = ptype.byteSize(); + long cap = SegmentBroadcast.capacity(buf, elemSize); + for (int i = 0; i < count; i++) { + long off = (i % cap) * elemSize; + out[i] = switch (ptype) { + case U8 -> Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, off)); + case U16 -> Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, off)); + case U32 -> Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, off)); + case U64 -> buf.get(PTypeIO.LE_LONG, off); + default -> + throw new VortexException(EncodingId.FASTLANES_RLE, "unsupported offsets ptype: " + ptype); + }; + } + return out; + } + + private static MemorySegment 
fromLongs(long[] decoded, int offset, int count, PType ptype, SegmentAllocator arena) { + int elemSize = ptype.byteSize(); + MemorySegment seg = arena.allocate((long) count * elemSize); + for (int i = 0; i < count; i++) { + PTypeIO.set(seg, (long) i * elemSize, ptype, decoded[offset + i]); + } + return seg; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/RunEndEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/RunEndEncodingDecoder.java new file mode 100644 index 00000000..ca07300d --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/RunEndEncodingDecoder.java @@ -0,0 +1,271 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.RunEndMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.runend}. +public final class RunEndEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public RunEndEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_RUNEND; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive p && !p.ptype().isFloating(); + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null) { + throw new VortexException(EncodingId.VORTEX_RUNEND, "missing metadata"); + } + + RunEndMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = RunEndMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_RUNEND, "invalid metadata", e); + } + + PType endsPtype = PType.fromOrdinal(meta.ends_ptype().value()); + long numRuns = meta.num_runs(); + long offset = meta.offset(); + + long n = ctx.rowCount(); + DType endsDtype = new DType.Primitive(endsPtype, false); + Array endsArr = ctx.decodeChild(0, endsDtype, numRuns); + + if (ctx.dtype() instanceof DType.Utf8 || ctx.dtype() instanceof DType.Binary) { + Array valuesArr = ctx.decodeChild(1, ctx.dtype(), numRuns); + return expandStrings(endsArr, (VarBinArray) valuesArr, endsPtype, numRuns, offset, n, ctx.dtype(), ctx.arena()); + } + + if (ctx.dtype() instanceof DType.Bool) { + Array valuesArr = ctx.decodeChild(1, ctx.dtype(), numRuns); + return expandBool(endsArr, (BoolArray) valuesArr, endsPtype, numRuns, offset, n, ctx.dtype(), ctx.arena()); + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_RUNEND, "expected primitive dtype, got " + ctx.dtype()); + } + PType valuePtype = p.ptype(); + + return expand(ArraySegments.of(endsArr), ctx.decodeChildSegment(1, ctx.dtype(), numRuns), + endsPtype, valuePtype, numRuns, offset, n, ctx.dtype(), ctx.arena()); + } + + private static Array expand( + MemorySegment endsSeg, MemorySegment valuesSeg, + PType endsPtype, PType valuePtype, + long numRuns, 
long offset, long n, + DType dtype, SegmentAllocator arena + ) { + MemorySegment out = arena.allocate(n * valuePtype.byteSize()); + switch (valuePtype) { + case I8, U8 -> expandByte(endsSeg, valuesSeg, endsPtype, numRuns, offset, n, out); + case I16, U16 -> expandShort(endsSeg, valuesSeg, endsPtype, numRuns, offset, n, out); + case I32, U32 -> expandInt(endsSeg, valuesSeg, endsPtype, numRuns, offset, n, out); + case I64, U64 -> expandLong(endsSeg, valuesSeg, endsPtype, numRuns, offset, n, out); + default -> throw new VortexException(EncodingId.VORTEX_RUNEND, "unsupported ptype " + valuePtype); + } + MemorySegment ro = out.asReadOnly(); + return switch (valuePtype) { + case I64, U64 -> new LongArray(dtype, n, ro); + case I32, U32 -> new IntArray(dtype, n, ro); + case I16, U16 -> new ShortArray(dtype, n, ro); + case I8, U8 -> new ByteArray(dtype, n, ro); + default -> throw new VortexException(EncodingId.VORTEX_RUNEND, "unsupported ptype " + valuePtype); + }; + } + + private static void expandByte(MemorySegment endsSeg, MemorySegment valuesSeg, + PType endsPtype, long numRuns, long offset, long n, MemorySegment out) { + long endsCap = SegmentBroadcast.capacity(endsSeg, endsPtype.byteSize()); + long valCap = SegmentBroadcast.capacity(valuesSeg, 1); + long logicalPos = 0L, outPos = 0L; + for (long run = 0; run < numRuns && outPos < n; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + byte rawValue = valuesSeg.get(ValueLayout.JAVA_BYTE, run % valCap); + long writeEnd = Math.min(runEnd, offset + n); + for (long lp = Math.max(logicalPos, offset); lp < writeEnd; lp++, outPos++) { + out.set(ValueLayout.JAVA_BYTE, outPos, rawValue); + } + logicalPos = runEnd; + } + } + + private static void expandShort(MemorySegment endsSeg, MemorySegment valuesSeg, + PType endsPtype, long numRuns, long offset, long n, MemorySegment out) { + long endsCap = SegmentBroadcast.capacity(endsSeg, endsPtype.byteSize()); + long valCap = SegmentBroadcast.capacity(valuesSeg, 
2); + long logicalPos = 0L, outPos = 0L; + for (long run = 0; run < numRuns && outPos < n; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + short rawValue = valuesSeg.get(PTypeIO.LE_SHORT, (run % valCap) * 2); + long writeEnd = Math.min(runEnd, offset + n); + for (long lp = Math.max(logicalPos, offset); lp < writeEnd; lp++, outPos++) { + out.set(PTypeIO.LE_SHORT, outPos * 2, rawValue); + } + logicalPos = runEnd; + } + } + + private static void expandInt(MemorySegment endsSeg, MemorySegment valuesSeg, + PType endsPtype, long numRuns, long offset, long n, MemorySegment out) { + long endsCap = SegmentBroadcast.capacity(endsSeg, endsPtype.byteSize()); + long valCap = SegmentBroadcast.capacity(valuesSeg, 4); + long logicalPos = 0L, outPos = 0L; + for (long run = 0; run < numRuns && outPos < n; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + int rawValue = valuesSeg.get(PTypeIO.LE_INT, (run % valCap) * 4); + long writeEnd = Math.min(runEnd, offset + n); + for (long lp = Math.max(logicalPos, offset); lp < writeEnd; lp++, outPos++) { + out.set(PTypeIO.LE_INT, outPos * 4, rawValue); + } + logicalPos = runEnd; + } + } + + private static void expandLong(MemorySegment endsSeg, MemorySegment valuesSeg, + PType endsPtype, long numRuns, long offset, long n, MemorySegment out) { + long endsCap = SegmentBroadcast.capacity(endsSeg, endsPtype.byteSize()); + long valCap = SegmentBroadcast.capacity(valuesSeg, 8); + long logicalPos = 0L, outPos = 0L; + for (long run = 0; run < numRuns && outPos < n; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + long rawValue = valuesSeg.get(PTypeIO.LE_LONG, (run % valCap) * 8); + long writeEnd = Math.min(runEnd, offset + n); + for (long lp = Math.max(logicalPos, offset); lp < writeEnd; lp++, outPos++) { + out.set(PTypeIO.LE_LONG, outPos * 8, rawValue); + } + logicalPos = runEnd; + } + } + + private static Array expandBool( + Array endsArr, BoolArray valuesArr, + PType 
endsPtype, long numRuns, long offset, long n, + DType dtype, SegmentAllocator arena + ) { + MemorySegment endsSeg = ArraySegments.of(endsArr); + long endsCap = SegmentBroadcast.capacity(endsSeg, endsPtype.byteSize()); + long numBytes = (n + 7) >>> 3; + MemorySegment out = arena.allocate(numBytes); + + long outIdx = 0; + long logicalPos = 0; + for (long run = 0; run < numRuns && outIdx < n; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + boolean val = valuesArr.getBoolean(run); + long lo = Math.max(logicalPos, offset); + long hi = Math.min(runEnd, offset + n); + for (long lp = lo; lp < hi; lp++, outIdx++) { + if (val) { + long byteIdx = outIdx >>> 3; + byte cur = out.get(ValueLayout.JAVA_BYTE, byteIdx); + out.set(ValueLayout.JAVA_BYTE, byteIdx, (byte) (cur | (1 << (outIdx & 7)))); + } + } + logicalPos = runEnd; + } + return new BoolArray(dtype, n, out.asReadOnly()); + } + + private static Array expandStrings( + Array endsArr, VarBinArray valuesArr, + PType endsPtype, long numRuns, long offset, long n, + DType dtype, SegmentAllocator arena + ) { + MemorySegment endsSeg = ArraySegments.of(endsArr); + long endsCap = SegmentBroadcast.capacity(endsSeg, endsPtype.byteSize()); + MemorySegment valBytes = valuesArr.bytesSegment(); + MemorySegment valOffsets = valuesArr.offsetsSegment(); + PType valOffPtype = valuesArr.offsetsPtype(); + + long totalBytes = 0; + long logicalPos = 0; + for (long run = 0; run < numRuns; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + long lo = Math.max(logicalPos, offset); + long hi = Math.min(runEnd, offset + n); + long count = Math.max(0, hi - lo); + long strLen = readVarBinOffset(valOffsets, run + 1, valOffPtype) + - readVarBinOffset(valOffsets, run, valOffPtype); + totalBytes += count * strLen; + logicalPos = runEnd; + } + + MemorySegment outBytes = arena.allocate(totalBytes > 0 ? 
totalBytes : 1); + MemorySegment outOffsets = arena.allocate((n + 1) * 4L, 4); + outOffsets.setAtIndex(PTypeIO.LE_INT, 0, 0); + + long bytePos = 0; + long outIdx = 0; + logicalPos = 0; + for (long run = 0; run < numRuns && outIdx < n; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + long lo = Math.max(logicalPos, offset); + long hi = Math.min(runEnd, offset + n); + if (hi > lo) { + long strStart = readVarBinOffset(valOffsets, run, valOffPtype); + long strEnd = readVarBinOffset(valOffsets, run + 1, valOffPtype); + long strLen = strEnd - strStart; + for (long lp = lo; lp < hi; lp++, outIdx++) { + if (strLen > 0) { + MemorySegment.copy(valBytes, strStart, outBytes, bytePos, strLen); + bytePos += strLen; + } + outOffsets.setAtIndex(PTypeIO.LE_INT, outIdx + 1, (int) bytePos); + } + } + logicalPos = runEnd; + } + + return new VarBinArray(dtype, n, outBytes.asReadOnly(), outOffsets.asReadOnly(), PType.I32); + } + + private static long readUnsigned(MemorySegment seg, long i, PType ptype) { + return switch (ptype) { + case U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, i)); + case U16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, i * 2)); + case U32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, i * 4)); + case U64 -> seg.get(PTypeIO.LE_LONG, i * 8); + default -> throw new VortexException(EncodingId.VORTEX_RUNEND, "non-unsigned ends ptype " + ptype); + }; + } + + private static long readVarBinOffset(MemorySegment seg, long i, PType ptype) { + return switch (ptype) { + case I32, U32 -> Integer.toUnsignedLong(seg.getAtIndex(PTypeIO.LE_INT, i)); + case I64, U64 -> seg.getAtIndex(PTypeIO.LE_LONG, i); + default -> throw new VortexException(EncodingId.VORTEX_RUNEND, "unsupported offset ptype " + ptype); + }; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/SequenceEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/SequenceEncodingDecoder.java new file mode 100644 index 
00000000..d11fa421 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/SequenceEncodingDecoder.java @@ -0,0 +1,140 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.Float16Array; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.SequenceMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.sequence} — {@code A[i] = base + i * multiplier}. +public final class SequenceEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public SequenceEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_SEQUENCE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer metaBuf = ctx.metadata(); + if (metaBuf == null || !metaBuf.hasRemaining()) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "missing metadata"); + } + SequenceMetadata meta; + try { + MemorySegment seg = MemorySegment.ofBuffer(metaBuf.duplicate()); + meta = SequenceMetadata.decode(seg, 0, seg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "invalid metadata", e); + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "expected primitive dtype, got " + ctx.dtype()); + } + + long n = ctx.rowCount(); + PType pt = p.ptype(); + return switch (pt) { + case I8, I16, I32, I64, U8, U16, U32, U64 -> decodeInteger(meta, pt, n, ctx.dtype(), ctx.arena()); + case F32 -> decodeF32(meta, n, ctx.dtype(), ctx.arena()); + case F64 -> decodeF64(meta, n, ctx.dtype(), ctx.arena()); + case F16 -> decodeF16(meta, n, ctx.dtype(), ctx.arena()); + }; + } + + private static Array decodeInteger( + SequenceMetadata meta, PType pt, long n, DType dtype, SegmentAllocator arena + ) { + long base = signedValue(meta.base()); + long mul = signedValue(meta.multiplier()); + int elemBytes = pt.byteSize(); + MemorySegment seg = arena.allocate(n * elemBytes); + for (long i = 0; i < n; i++) { + long v = base + i * mul; + switch (pt) { + case I8, U8 -> seg.set(ValueLayout.JAVA_BYTE, i, (byte) v); + case I16, U16 -> seg.setAtIndex(PTypeIO.LE_SHORT, i, (short) v); + case I32, U32 -> seg.setAtIndex(PTypeIO.LE_INT, i, (int) v); + case I64, U64 -> seg.setAtIndex(PTypeIO.LE_LONG, i, v); + default -> throw new IllegalStateException("unreachable"); + } + } + return switch (pt) { + case I64, U64 -> new 
LongArray(dtype, n, seg); + case I32, U32 -> new IntArray(dtype, n, seg); + case I16, U16 -> new ShortArray(dtype, n, seg); + case I8, U8 -> new ByteArray(dtype, n, seg); + default -> throw new VortexException(EncodingId.VORTEX_SEQUENCE, "unsupported ptype " + pt); + }; + } + + private static Array decodeF32(SequenceMetadata meta, long n, DType dtype, SegmentAllocator arena) { + float base = meta.base().f32_value(); + float mul = meta.multiplier().f32_value(); + MemorySegment seg = arena.allocate(n * 4L); + for (long i = 0; i < n; i++) { + seg.setAtIndex(PTypeIO.LE_FLOAT, i, base + i * mul); + } + return new FloatArray(dtype, n, seg); + } + + private static Array decodeF64(SequenceMetadata meta, long n, DType dtype, SegmentAllocator arena) { + double base = meta.base().f64_value(); + double mul = meta.multiplier().f64_value(); + MemorySegment seg = arena.allocate(n * 8L); + for (long i = 0; i < n; i++) { + seg.setAtIndex(PTypeIO.LE_DOUBLE, i, base + i * mul); + } + return new DoubleArray(dtype, n, seg); + } + + private static Array decodeF16(SequenceMetadata meta, long n, DType dtype, SegmentAllocator arena) { + short baseShort = (short) meta.base().f16_value().longValue(); + short mulShort = (short) meta.multiplier().f16_value().longValue(); + float base = Float.float16ToFloat(baseShort); + float mul = Float.float16ToFloat(mulShort); + MemorySegment seg = arena.allocate(n * 2L); + for (long i = 0; i < n; i++) { + seg.setAtIndex(PTypeIO.LE_SHORT, i, Float.floatToFloat16(base + i * mul)); + } + return new Float16Array(dtype, n, seg); + } + + private static long signedValue(ScalarValue sv) { + if (sv == null) { + return 0L; + } + if (sv.int64_value() != null) { + return sv.int64_value(); + } + if (sv.uint64_value() != null) { + return sv.uint64_value(); + } + return 0L; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/SparseEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/SparseEncodingDecoder.java new file 
mode 100644 index 00000000..9d4507ad --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/SparseEncodingDecoder.java @@ -0,0 +1,259 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.SparseMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/// Read-only decoder for {@code vortex.sparse}. +public final class SparseEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public SparseEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_SPARSE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null || !rawMeta.hasRemaining()) { + throw new VortexException(EncodingId.VORTEX_SPARSE, "missing metadata"); + } + SparseMetadata sparseMeta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + sparseMeta = SparseMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_SPARSE, "invalid metadata", e); + } + + PatchesMetadata patches = sparseMeta.patches(); + long numPatches = patches.len(); + long offset = patches.offset(); + PType indicesPtype = PType.fromOrdinal(patches.indices_ptype().value()); + + long n = ctx.rowCount(); + + if (ctx.dtype() instanceof DType.Utf8 || ctx.dtype() instanceof DType.Binary) { + return decodeVarBin(ctx, n, numPatches, offset, indicesPtype); + } + + if (ctx.dtype() instanceof DType.Bool) { + return decodeBool(ctx, n, numPatches, offset, indicesPtype); + } + + if (!(ctx.dtype() instanceof DType.Primitive)) { + throw new VortexException(EncodingId.VORTEX_SPARSE, "expected primitive dtype, got " + ctx.dtype()); + } + PType valuePtype = ((DType.Primitive) ctx.dtype()).ptype(); + + MemorySegment fillBuf = ctx.buffer(0); + ScalarValue fillScalar; + try { + fillScalar = ScalarValue.decode(fillBuf, 0, fillBuf.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_SPARSE, "invalid fill value", e); + } + + int elemBytes = valuePtype.byteSize(); + MemorySegment out = ctx.arena().allocate(n * elemBytes); + fillSegment(out, n, valuePtype, fillScalar); + + if (numPatches > 0) { + DType indicesDtype = new DType.Primitive(indicesPtype, false); + applyPatches(out, n, valuePtype, + 
ctx.decodeChildSegment(0, indicesDtype, numPatches), + ctx.decodeChildSegment(1, ctx.dtype(), numPatches), + indicesPtype, numPatches, offset); + } + + return switch (valuePtype) { + case I64, U64 -> new LongArray(ctx.dtype(), n, out); + case I32, U32 -> new IntArray(ctx.dtype(), n, out); + case F64 -> new DoubleArray(ctx.dtype(), n, out); + case F32 -> new FloatArray(ctx.dtype(), n, out); + case I16, U16 -> new ShortArray(ctx.dtype(), n, out); + case I8, U8 -> new ByteArray(ctx.dtype(), n, out); + default -> throw new VortexException(EncodingId.VORTEX_SPARSE, "unsupported ptype " + valuePtype); + }; + } + + private static Array decodeBool( + DecodeContext ctx, long n, long numPatches, long offset, PType indicesPtype + ) { + long numBytes = (n + 7) >>> 3; + MemorySegment out = ctx.arena().allocate(numBytes); + if (numPatches > 0) { + DType indicesDtype = new DType.Primitive(indicesPtype, false); + MemorySegment idxSeg = ctx.decodeChildSegment(0, indicesDtype, numPatches); + BoolArray bools = (BoolArray) ctx.decodeChild(1, ctx.dtype(), numPatches); + int idxBytes = indicesPtype.byteSize(); + for (long i = 0; i < numPatches; i++) { + if (bools.getBoolean(i)) { + long pos = readUnsignedIdx(idxSeg, SegmentBroadcast.elementOffset(idxSeg, i, idxBytes), indicesPtype) - offset; + long byteIdx = pos >>> 3; + byte cur = out.get(ValueLayout.JAVA_BYTE, byteIdx); + out.set(ValueLayout.JAVA_BYTE, byteIdx, (byte) (cur | (1 << (pos & 7)))); + } + } + } + return new BoolArray(ctx.dtype(), n, out); + } + + private static Array decodeVarBin( + DecodeContext ctx, long n, long numPatches, long offset, PType indicesPtype + ) { + MemorySegment outOffsets = ctx.arena().allocate((n + 1) * 4L, 4); + if (numPatches == 0) { + MemorySegment outBytes = ctx.arena().allocate(1); + return new VarBinArray(ctx.dtype(), n, outBytes, outOffsets, PType.I32); + } + + DType indicesDtype = new DType.Primitive(indicesPtype, false); + MemorySegment idxSeg = ctx.decodeChildSegment(0, indicesDtype, 
numPatches); + VarBinArray varBin = (VarBinArray) ctx.decodeChild(1, ctx.dtype(), numPatches); + MemorySegment valBytes = varBin.bytesSegment(); + MemorySegment valOffsets = varBin.offsetsSegment(); + PType valOffPtype = varBin.offsetsPtype(); + + int idxBytes = indicesPtype.byteSize(); + long totalBytes = 0; + for (long i = 0; i < numPatches; i++) { + totalBytes += readVarBinOffset(valOffsets, i + 1, valOffPtype) + - readVarBinOffset(valOffsets, i, valOffPtype); + } + + MemorySegment outBytes = ctx.arena().allocate(Math.max(1, totalBytes)); + long patchCursor = 0; + long bytePos = 0; + for (long pos = 0; pos < n; pos++) { + if (patchCursor < numPatches) { + long patchPos = readUnsignedIdx(idxSeg, SegmentBroadcast.elementOffset(idxSeg, patchCursor, idxBytes), indicesPtype) - offset; + if (patchPos == pos) { + long strStart = readVarBinOffset(valOffsets, patchCursor, valOffPtype); + long strEnd = readVarBinOffset(valOffsets, patchCursor + 1, valOffPtype); + long strLen = strEnd - strStart; + if (strLen > 0) { + MemorySegment.copy(valBytes, strStart, outBytes, bytePos, strLen); + bytePos += strLen; + } + patchCursor++; + } + } + outOffsets.setAtIndex(PTypeIO.LE_INT, pos + 1, (int) bytePos); + } + + return new VarBinArray(ctx.dtype(), n, outBytes, outOffsets, PType.I32); + } + + private static long readVarBinOffset(MemorySegment seg, long i, PType ptype) { + return switch (ptype) { + case I32, U32 -> Integer.toUnsignedLong(seg.getAtIndex(PTypeIO.LE_INT, i)); + case I64, U64 -> seg.getAtIndex(PTypeIO.LE_LONG, i); + default -> throw new VortexException(EncodingId.VORTEX_SPARSE, "unsupported offset ptype " + ptype); + }; + } + + private static void fillSegment(MemorySegment out, long n, PType ptype, ScalarValue scalar) { + long fillLong = scalarToLong(scalar); + ByteBuffer bb = out.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + for (long i = 0; i < n; i++) { + writeElem(bb, ptype, fillLong); + } + } + + private static void applyPatches( + MemorySegment out, long n, 
PType valuePtype, + MemorySegment idxSeg, MemorySegment valSeg, + PType idxPtype, long numPatches, long offset + ) { + int elemBytes = valuePtype.byteSize(); + int idxBytes = idxPtype.byteSize(); + ByteBuffer outBuf = out.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + for (long i = 0; i < numPatches; i++) { + long idx = readUnsignedIdx(idxSeg, SegmentBroadcast.elementOffset(idxSeg, i, idxBytes), idxPtype) - offset; + if (idx < 0 || idx >= n) { + throw new VortexException(EncodingId.VORTEX_SPARSE, + "patch index " + idx + " out of range [0," + n + ")"); + } + long val = readElem(valSeg, SegmentBroadcast.elementOffset(valSeg, i, elemBytes), valuePtype); + outBuf.position((int) (idx * elemBytes)); + writeElem(outBuf, valuePtype, val); + } + } + + private static long readUnsignedIdx(MemorySegment seg, long off, PType ptype) { + return switch (ptype) { + case U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, off)); + case U16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, off)); + case U32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, off)); + case U64 -> seg.get(PTypeIO.LE_LONG, off); + default -> throw new VortexException(EncodingId.VORTEX_SPARSE, "non-unsigned index ptype " + ptype); + }; + } + + private static long readElem(MemorySegment seg, long off, PType ptype) { + return switch (ptype) { + case I8, U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, off)); + case I16, U16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, off)); + case I32, U32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, off)); + case I64, U64, F32, F64 -> seg.get(PTypeIO.LE_LONG, off); + default -> throw new UnsupportedOperationException("vortex.sparse: unsupported ptype " + ptype); + }; + } + + private static void writeElem(ByteBuffer bb, PType ptype, long bits) { + switch (ptype) { + case I8, U8 -> bb.put((byte) bits); + case I16, U16 -> bb.putShort((short) bits); + case I32, U32 -> bb.putInt((int) bits); + case I64, U64, F32, F64 -> bb.putLong(bits); + default -> 
throw new UnsupportedOperationException("vortex.sparse: unsupported ptype " + ptype); + } + } + + private static long scalarToLong(ScalarValue scalar) { + if (scalar.int64_value() != null) { + return scalar.int64_value(); + } + if (scalar.uint64_value() != null) { + return scalar.uint64_value(); + } + if (scalar.f32_value() != null) { + return Float.floatToRawIntBits(scalar.f32_value()); + } + if (scalar.f64_value() != null) { + return Double.doubleToRawLongBits(scalar.f64_value()); + } + return 0L; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/StructEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/StructEncodingDecoder.java new file mode 100644 index 00000000..7f630840 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/StructEncodingDecoder.java @@ -0,0 +1,105 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.core.array.StructArray; +import io.github.dfa1.vortex.encoding.EncodingId; + +import java.util.ArrayList; +import java.util.List; + +/// Read-only decoder for {@code vortex.struct}. +public final class StructEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public StructEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_STRUCT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Struct; + } + + @Override + public Array decode(DecodeContext ctx) { + int numChildren = ctx.node().children().length; + + if (ctx.dtype() instanceof DType.Struct structDtype) { + int nfields = structDtype.fieldTypes().size(); + if (numChildren != nfields && numChildren != nfields + 1) { + throw new VortexException(EncodingId.VORTEX_STRUCT, + "expected %d or %d children for struct dtype, got %d" + .formatted(nfields, nfields + 1, numChildren)); + } + boolean hasValidity = (numChildren == nfields + 1); + int fieldOffset = hasValidity ? 1 : 0; + + BoolArray structValidity = null; + if (hasValidity) { + ArrayNode validityNode = ctx.node().children()[0]; + var validityCtx = new DecodeContext(validityNode, new DType.Bool(false), + ctx.rowCount(), ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array va = ctx.registry().decode(validityCtx); + if (!(va instanceof BoolArray ba)) { + throw new VortexException(EncodingId.VORTEX_STRUCT, + "struct validity decoded to unexpected type: " + va.getClass().getSimpleName()); + } + structValidity = ba; + } + + if (nfields == 1) { + DType fieldDtype = structDtype.fieldTypes().getFirst(); + ArrayNode fieldNode = ctx.node().children()[fieldOffset]; + var fieldCtx = new DecodeContext(fieldNode, fieldDtype.withNullable(false), + ctx.rowCount(), ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array field = ctx.registry().decode(fieldCtx); + return structValidity != null ? 
new MaskedArray(field, structValidity) : field; + } + + List fieldArrays = new ArrayList<>(nfields); + for (int i = 0; i < nfields; i++) { + ArrayNode fieldNode = ctx.node().children()[fieldOffset + i]; + DType fieldDtype = structDtype.fieldTypes().get(i); + var fieldCtx = new DecodeContext(fieldNode, fieldDtype.withNullable(false), + ctx.rowCount(), ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array field = ctx.registry().decode(fieldCtx); + fieldArrays.add(structValidity != null ? new MaskedArray(field, structValidity) : field); + } + return new StructArray(structDtype, ctx.rowCount(), fieldArrays); + } + + if (numChildren == 1) { + ArrayNode valuesNode = ctx.node().children()[0]; + var valuesCtx = new DecodeContext( + valuesNode, ctx.dtype(), ctx.rowCount(), + ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + return ctx.registry().decode(valuesCtx); + } else if (numChildren == 2) { + ArrayNode validityNode = ctx.node().children()[0]; + var validityCtx = new DecodeContext(validityNode, new DType.Bool(false), + ctx.rowCount(), ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array va = ctx.registry().decode(validityCtx); + if (!(va instanceof BoolArray validity)) { + throw new VortexException(EncodingId.VORTEX_STRUCT, + "scalar wrapper validity decoded to unexpected type: " + va.getClass().getSimpleName()); + } + ArrayNode valuesNode = ctx.node().children()[1]; + var valuesCtx = new DecodeContext( + valuesNode, ctx.dtype().withNullable(false), ctx.rowCount(), + ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array values = ctx.registry().decode(valuesCtx); + return new MaskedArray(values, validity); + } else { + throw new VortexException(EncodingId.VORTEX_STRUCT, + "unexpected child count " + numChildren + " for scalar wrapper"); + } + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/UnknownArrayNode.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/UnknownArrayNode.java new file mode 100644 index 
00000000..ca958b45 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/UnknownArrayNode.java @@ -0,0 +1,24 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.ArrayStats; + +import java.nio.ByteBuffer; + +/// Array node whose encoding id is not a recognised {@link io.github.dfa1.vortex.encoding.EncodingId}. +/// Produced when a file uses an encoding this build does not know about. Decoded as +/// {@link io.github.dfa1.vortex.core.array.UnknownArray} when +/// {@link ReadRegistry#isAllowUnknown()} is set; otherwise the decode call throws. +/// +/// @param rawEncodingId the raw encoding id string from the file +/// @param metadata encoding-specific metadata bytes, or {@code null} +/// @param children child nodes +/// @param bufferIndices segment buffer indices +/// @param stats optional zone-map statistics +public record UnknownArrayNode( + String rawEncodingId, + ByteBuffer metadata, + ArrayNode[] children, + int[] bufferIndices, + ArrayStats stats +) implements ArrayNode { +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinEncodingDecoder.java new file mode 100644 index 00000000..b979332e --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinEncodingDecoder.java @@ -0,0 +1,65 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.VarBinMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.varbin}. 
+public final class VarBinEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public VarBinEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_VARBIN; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null) { + throw new VortexException(EncodingId.VORTEX_VARBIN, "missing metadata"); + } + VarBinMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = VarBinMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_VARBIN, "invalid metadata", e); + } + + PType offsetsPtype = PType.fromOrdinal(meta.offsets_ptype().value()); + DType offsetsDtype = new DType.Primitive(offsetsPtype, false); + long n = ctx.rowCount(); + + MemorySegment offsets = ctx.decodeChildSegment(0, offsetsDtype, n + 1); + + int offBytes = offsetsPtype.byteSize(); + long offCap = SegmentBroadcast.capacity(offsets, offBytes); + if (offCap < n + 1) { + MemorySegment materialized = ctx.arena().allocate((n + 1) * (long) offBytes, offBytes); + SegmentBroadcast.broadcastCopy(offsets, materialized, n + 1, offBytes); + offsets = materialized; + } + + MemorySegment bytes = ctx.buffer(0); + + return new VarBinArray(ctx.dtype(), n, bytes, offsets, offsetsPtype); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinViewEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinViewEncodingDecoder.java new file mode 100644 index 00000000..ea43a1d1 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinViewEncodingDecoder.java @@ -0,0 +1,81 @@ +package io.github.dfa1.vortex.reader.decode; + +import 
io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; + +import java.lang.foreign.MemorySegment; + +/// Read-only decoder for {@code vortex.varbinview} (Apache Arrow StringView/BinaryView). +public final class VarBinViewEncodingDecoder implements EncodingDecoder { + + private static final int MAX_INLINED_SIZE = 12; + private static final int VIEW_SIZE = 16; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public VarBinViewEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_VARBINVIEW; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public Array decode(DecodeContext ctx) { + if (!(ctx.dtype() instanceof DType.Utf8 || ctx.dtype() instanceof DType.Binary)) { + throw new VortexException(EncodingId.VORTEX_VARBINVIEW, + "expected Utf8/Binary dtype, got " + ctx.dtype()); + } + + int numBufs = ctx.node().bufferIndices().length; + if (numBufs < 1) { + throw new VortexException(EncodingId.VORTEX_VARBINVIEW, + "expected at least 1 buffer (views), got 0"); + } + + MemorySegment viewsBuf = ctx.buffer(numBufs - 1); + MemorySegment[] dataBufs = new MemorySegment[numBufs - 1]; + for (int i = 0; i < dataBufs.length; i++) { + dataBufs[i] = ctx.buffer(i); + } + + long n = ctx.rowCount(); + + long totalBytes = 0; + for (long i = 0; i < n; i++) { + long size = Integer.toUnsignedLong(viewsBuf.get(PTypeIO.LE_INT, i * VIEW_SIZE)); + totalBytes += size; + } + + MemorySegment outBytes = ctx.arena().allocate(totalBytes > 0 ? 
totalBytes : 1); + MemorySegment outOffsets = ctx.arena().allocate((n + 1) * Long.BYTES, Long.BYTES); + + long bytePos = 0; + outOffsets.setAtIndex(PTypeIO.LE_LONG, 0, 0L); + for (long i = 0; i < n; i++) { + long viewOff = i * VIEW_SIZE; + long size = Integer.toUnsignedLong(viewsBuf.get(PTypeIO.LE_INT, viewOff)); + if (size <= MAX_INLINED_SIZE) { + MemorySegment.copy(viewsBuf, viewOff + 4, outBytes, bytePos, size); + } else { + int bufferIndex = viewsBuf.get(PTypeIO.LE_INT, viewOff + 8); + long srcOffset = Integer.toUnsignedLong(viewsBuf.get(PTypeIO.LE_INT, viewOff + 12)); + MemorySegment.copy(dataBufs[bufferIndex], srcOffset, outBytes, bytePos, size); + } + bytePos += size; + outOffsets.setAtIndex(PTypeIO.LE_LONG, i + 1, bytePos); + } + + return new VarBinArray(ctx.dtype(), n, outBytes.asReadOnly(), outOffsets, PType.I64); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoder.java new file mode 100644 index 00000000..62996194 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoder.java @@ -0,0 +1,127 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.VariantArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.VariantMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Read-only decoder for {@code vortex.variant}. +public final class VariantEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public VariantEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_VARIANT; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public Array decode(DecodeContext ctx) { + DType shreddedDtype = parseShreddedDtype(ctx.metadata()); + + int numChildren = ctx.node().children().length; + if (numChildren < 1 || numChildren > 2) { + throw new VortexException(EncodingId.VORTEX_VARIANT, + "expected 1 or 2 children, got " + numChildren); + } + + Array coreStorage = ctx.decodeChild(0, ctx.dtype(), ctx.rowCount()); + + Array shredded = null; + if (shreddedDtype != null && numChildren >= 2) { + shredded = ctx.decodeChild(1, shreddedDtype, ctx.rowCount()); + } + + return new VariantArray(ctx.dtype(), ctx.rowCount(), coreStorage, shredded); + } + + private static DType parseShreddedDtype(ByteBuffer rawMeta) { + if (rawMeta == null || !rawMeta.hasRemaining()) { + return null; + } + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + VariantMetadata meta = VariantMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + if (meta.shredded_dtype() == null) { + return null; + } + return dtypeFromProto(meta.shredded_dtype()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_VARIANT, "invalid metadata", e); + } + } + + static DType dtypeFromProto(io.github.dfa1.vortex.proto.DType proto) { + if (proto.null_() != null) { + return new DType.Null(true); + } + if (proto.bool() != null) { + return new DType.Bool(proto.bool().nullable()); + } + if (proto.primitive() != null) { + return new DType.Primitive( + PType.values()[proto.primitive().type().value()], + proto.primitive().nullable()); + } + if (proto.decimal() != null) { + return new DType.Decimal( + (byte) proto.decimal().precision(), + (byte) proto.decimal().scale(), + proto.decimal().nullable()); + } + if (proto.utf8() != null) { + return new DType.Utf8(proto.utf8().nullable()); + } + if (proto.binary() 
!= null) { + return new DType.Binary(proto.binary().nullable()); + } + if (proto.struct() != null) { + var s = proto.struct(); + var names = new ArrayList(s.names().size()); + var types = new ArrayList(s.dtypes().size()); + names.addAll(s.names()); + for (io.github.dfa1.vortex.proto.DType child : s.dtypes()) { + types.add(dtypeFromProto(child)); + } + return new DType.Struct(List.copyOf(names), List.copyOf(types), s.nullable()); + } + if (proto.list() != null) { + return new DType.List( + dtypeFromProto(proto.list().element_type()), + proto.list().nullable()); + } + if (proto.fixed_size_list() != null) { + return new DType.FixedSizeList( + dtypeFromProto(proto.fixed_size_list().element_type()), + proto.fixed_size_list().size(), + proto.fixed_size_list().nullable()); + } + if (proto.extension() != null) { + return new DType.Extension( + proto.extension().id(), + dtypeFromProto(proto.extension().storage_dtype()), + ByteBuffer.wrap(proto.extension().metadata() != null ? proto.extension().metadata() : new byte[0]).asReadOnlyBuffer(), + false); + } + if (proto.variant() != null) { + return new DType.Variant(proto.variant().nullable()); + } + throw new VortexException("unsupported proto DType"); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZigZagEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZigZagEncodingDecoder.java new file mode 100644 index 00000000..f206e617 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZigZagEncodingDecoder.java @@ -0,0 +1,95 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import 
io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +/// Read-only decoder for {@code vortex.zigzag} — zigzag-decoded signed integers. +public final class ZigZagEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public ZigZagEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ZIGZAG; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + PType pt = p.ptype(); + return pt == PType.I8 || pt == PType.I16 || pt == PType.I32 || pt == PType.I64; + } + + @Override + public Array decode(DecodeContext ctx) { + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_ZIGZAG, "expected primitive dtype, got " + ctx.dtype()); + } + PType signed = p.ptype(); + PType unsigned = toUnsigned(signed); + long n = ctx.rowCount(); + + MemorySegment src = ctx.decodeChildSegment(0, new DType.Primitive(unsigned, false), n); + int elemBytes = signed.byteSize(); + long srcCap = SegmentBroadcast.capacity(src, elemBytes); + MemorySegment dst = ctx.arena().allocate(n * elemBytes); + + return switch (signed) { + case I8 -> { + for (long i = 0; i < n; i++) { + int u = Byte.toUnsignedInt(src.get(ValueLayout.JAVA_BYTE, i % srcCap)); + dst.set(ValueLayout.JAVA_BYTE, i, (byte) ((u >>> 1) ^ -(u & 1))); + } + yield new ByteArray(ctx.dtype(), n, dst); + } + case I16 -> { + for (long i = 0; i < n; i++) { + int u = Short.toUnsignedInt(src.get(PTypeIO.LE_SHORT, (i % srcCap) * 2)); + dst.set(PTypeIO.LE_SHORT, i * 2, (short) ((u >>> 1) ^ -(u & 1))); + } + yield new ShortArray(ctx.dtype(), n, dst); + } + case I32 -> { + for (long i = 0; i < n; i++) { 
+ int u = src.get(PTypeIO.LE_INT, (i % srcCap) * 4); + dst.set(PTypeIO.LE_INT, i * 4, (u >>> 1) ^ -(u & 1)); + } + yield new IntArray(ctx.dtype(), n, dst); + } + case I64 -> { + for (long i = 0; i < n; i++) { + long u = src.get(PTypeIO.LE_LONG, (i % srcCap) * 8); + dst.set(PTypeIO.LE_LONG, i * 8, (u >>> 1) ^ -(u & 1)); + } + yield new LongArray(ctx.dtype(), n, dst); + } + default -> throw new VortexException(EncodingId.VORTEX_ZIGZAG, "unreachable"); + }; + } + + private static PType toUnsigned(PType signed) { + return switch (signed) { + case I8 -> PType.U8; + case I16 -> PType.U16; + case I32 -> PType.U32; + case I64 -> PType.U64; + default -> throw new VortexException(EncodingId.VORTEX_ZIGZAG, "not a signed integer: " + signed); + }; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZstdEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZstdEncodingDecoder.java new file mode 100644 index 00000000..c57d6fe7 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZstdEncodingDecoder.java @@ -0,0 +1,256 @@ +package io.github.dfa1.vortex.reader.decode; + +import com.github.luben.zstd.ZstdDecompressCtx; +import io.airlift.compress.v3.zstd.ZstdDecompressor; +import io.airlift.compress.v3.zstd.ZstdJavaDecompressor; +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.Float16Array; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import 
io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ZstdMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.zstd}. +public final class ZstdEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public ZstdEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ZSTD; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive || dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null) { + throw new VortexException(EncodingId.VORTEX_ZSTD, "missing metadata"); + } + ZstdMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = ZstdMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_ZSTD, "invalid metadata", e); + } + boolean hasDictionary = meta.dictionary_size() != 0; + + BoolArray validity = null; + if (ctx.node().children().length > 0) { + Array validityArray = ctx.decodeChild(0, new DType.Bool(false), ctx.rowCount()); + if (!(validityArray instanceof BoolArray ba)) { + throw new VortexException(EncodingId.VORTEX_ZSTD, + "validity child decoded to unexpected type: " + validityArray.getClass().getSimpleName()); + } + validity = ba; + } + + int frameCount = meta.frames().size(); + long totalUncompressed = 0; + for (int i = 0; i < frameCount; i++) { + totalUncompressed += meta.frames().get(i).uncompressed_size(); + } + + MemorySegment decompressed = hasDictionary + ? 
decompressFramesWithDict(ctx, meta, frameCount, totalUncompressed) + : decompressFrames(ctx, meta, frameCount, totalUncompressed); + + if (validity == null) { + return buildArray(ctx.dtype(), ctx.rowCount(), decompressed, ctx); + } else { + return buildNullableArray(ctx.dtype(), ctx.rowCount(), decompressed, validity, ctx); + } + } + + private static Array buildNullableArray( + DType dtype, long rowCount, MemorySegment validValues, BoolArray validity, DecodeContext ctx + ) { + Array child; + if (dtype instanceof DType.Primitive dt) { + child = buildScatteredPrimitive(dt, rowCount, validValues, validity, ctx); + } else if (dtype instanceof DType.Utf8 || dtype instanceof DType.Binary) { + child = buildScatteredVarBin(dtype, rowCount, validValues, validity, ctx); + } else { + throw new VortexException(EncodingId.VORTEX_ZSTD, "unsupported nullable dtype: " + dtype); + } + return new MaskedArray(child, validity); + } + + private static Array buildScatteredPrimitive( + DType.Primitive dt, long rowCount, MemorySegment validValues, BoolArray validity, DecodeContext ctx + ) { + int byteSize = dt.ptype().byteSize(); + MemorySegment out = ctx.arena().allocate(rowCount * byteSize); + long readPos = 0; + for (long i = 0; i < rowCount; i++) { + if (validity.getBoolean(i)) { + MemorySegment.copy(validValues, readPos, out, i * byteSize, byteSize); + readPos += byteSize; + } + } + DType.Primitive nonNull = new DType.Primitive(dt.ptype(), false); + return buildPrimitive(nonNull, rowCount, out); + } + + private static VarBinArray buildScatteredVarBin( + DType dtype, long rowCount, MemorySegment validValues, BoolArray validity, DecodeContext ctx + ) { + long totalDataBytes = 0; + long scanPos = 0; + for (long i = 0; i < rowCount; i++) { + if (validity.getBoolean(i)) { + int len = validValues.get(PTypeIO.LE_INT, scanPos); + scanPos += 4L + len; + totalDataBytes += len; + } + } + + MemorySegment values = ctx.arena().allocate(totalDataBytes > 0 ? 
totalDataBytes : 1); + MemorySegment offsets = ctx.arena().allocate((rowCount + 1) * 4L, 4); + offsets.setAtIndex(PTypeIO.LE_INT, 0, 0); + + long readPos = 0; + long dataPos = 0; + for (long i = 0; i < rowCount; i++) { + if (validity.getBoolean(i)) { + int len = validValues.get(PTypeIO.LE_INT, readPos); + readPos += 4; + MemorySegment.copy(validValues, readPos, values, dataPos, len); + readPos += len; + dataPos += len; + } + offsets.setAtIndex(PTypeIO.LE_INT, i + 1, (int) dataPos); + } + + return new VarBinArray(dtype.withNullable(false), rowCount, values, offsets, PType.I32); + } + + private static MemorySegment decompressFramesWithDict( + DecodeContext ctx, + ZstdMetadata meta, + int frameCount, + long totalUncompressed + ) { + MemorySegment out = ctx.arena().allocate(totalUncompressed); + byte[] dictBytes = ctx.buffer(0).toArray(ValueLayout.JAVA_BYTE); + try (ZstdDecompressCtx zctx = new ZstdDecompressCtx()) { + zctx.loadDict(dictBytes); + long outOffset = 0; + for (int i = 0; i < frameCount; i++) { + byte[] compressed = ctx.buffer(i + 1).toArray(ValueLayout.JAVA_BYTE); + int uncompSize = (int) meta.frames().get(i).uncompressed_size(); + byte[] temp = new byte[uncompSize]; + int written = zctx.decompressByteArray(temp, 0, uncompSize, compressed, 0, compressed.length); + if (written != uncompSize) { + throw new VortexException(EncodingId.VORTEX_ZSTD, + "frame " + i + ": expected " + uncompSize + " bytes, got " + written); + } + MemorySegment.copy(MemorySegment.ofArray(temp), 0, out, outOffset, uncompSize); + outOffset += uncompSize; + } + } catch (VortexException e) { + throw e; + } catch (Exception e) { + throw new VortexException(EncodingId.VORTEX_ZSTD, "dict decompression failed", e); + } + return out; + } + + private static MemorySegment decompressFrames( + DecodeContext ctx, + ZstdMetadata meta, + int frameCount, + long totalUncompressed + ) { + MemorySegment out = ctx.arena().allocate(totalUncompressed); + ZstdDecompressor decompressor = new 
ZstdJavaDecompressor(); + long outOffset = 0; + for (int i = 0; i < frameCount; i++) { + MemorySegment frameSeg = ctx.buffer(i); + byte[] compressed = frameSeg.toArray(ValueLayout.JAVA_BYTE); + int uncompSize = (int) meta.frames().get(i).uncompressed_size(); + byte[] temp = new byte[uncompSize]; + int written = decompressor.decompress(compressed, 0, compressed.length, temp, 0, uncompSize); + if (written != uncompSize) { + throw new VortexException(EncodingId.VORTEX_ZSTD, + "frame " + i + ": expected " + uncompSize + " bytes, got " + written); + } + MemorySegment.copy(MemorySegment.ofArray(temp), 0, out, outOffset, uncompSize); + outOffset += uncompSize; + } + return out; + } + + private static Array buildArray(DType dtype, long n, MemorySegment decompressed, DecodeContext ctx) { + if (dtype instanceof DType.Primitive dt) { + return buildPrimitive(dt, n, decompressed); + } + if (dtype instanceof DType.Utf8 || dtype instanceof DType.Binary) { + return buildVarBin(dtype, n, decompressed, ctx); + } + throw new VortexException(EncodingId.VORTEX_ZSTD, "unsupported dtype: " + dtype); + } + + private static Array buildPrimitive(DType.Primitive dt, long n, MemorySegment decompressed) { + PType ptype = dt.ptype(); + return switch (ptype) { + case I64, U64 -> new LongArray(dt, n, decompressed); + case I32, U32 -> new IntArray(dt, n, decompressed); + case F64 -> new DoubleArray(dt, n, decompressed); + case F32 -> new FloatArray(dt, n, decompressed); + case I16, U16 -> new ShortArray(dt, n, decompressed); + case I8, U8 -> new ByteArray(dt, n, decompressed); + case F16 -> new Float16Array(dt, n, decompressed); + }; + } + + private static VarBinArray buildVarBin(DType dtype, long n, MemorySegment decompressed, DecodeContext ctx) { + long totalDataBytes = 0; + long pos = 0; + for (long i = 0; i < n; i++) { + int len = decompressed.get(PTypeIO.LE_INT, pos); + pos += 4 + len; + totalDataBytes += len; + } + + MemorySegment values = ctx.arena().allocate(totalDataBytes); + 
MemorySegment offsets = ctx.arena().allocate((n + 1) * 4L, 4); + offsets.setAtIndex(PTypeIO.LE_INT, 0, 0); + + pos = 0; + long dataPos = 0; + for (long i = 0; i < n; i++) { + int len = decompressed.get(PTypeIO.LE_INT, pos); + pos += 4; + MemorySegment.copy(decompressed, pos, values, dataPos, len); + pos += len; + dataPos += len; + offsets.setAtIndex(PTypeIO.LE_INT, i + 1, (int) dataPos); + } + + return new VarBinArray(dtype, n, values, offsets, PType.I32); + } +} diff --git a/reader/src/main/resources/META-INF/services/io.github.dfa1.vortex.reader.decode.EncodingDecoder b/reader/src/main/resources/META-INF/services/io.github.dfa1.vortex.reader.decode.EncodingDecoder new file mode 100644 index 00000000..781a8a16 --- /dev/null +++ b/reader/src/main/resources/META-INF/services/io.github.dfa1.vortex.reader.decode.EncodingDecoder @@ -0,0 +1,33 @@ +io.github.dfa1.vortex.reader.decode.AlpEncodingDecoder +io.github.dfa1.vortex.reader.decode.AlpRdEncodingDecoder +io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder +io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder +io.github.dfa1.vortex.reader.decode.ByteBoolEncodingDecoder +io.github.dfa1.vortex.reader.decode.ChunkedEncodingDecoder +io.github.dfa1.vortex.reader.decode.ConstantEncodingDecoder +io.github.dfa1.vortex.reader.decode.DateTimePartsEncodingDecoder +io.github.dfa1.vortex.reader.decode.DecimalBytePartsEncodingDecoder +io.github.dfa1.vortex.reader.decode.DecimalEncodingDecoder +io.github.dfa1.vortex.reader.decode.DeltaEncodingDecoder +io.github.dfa1.vortex.reader.decode.DictEncodingDecoder +io.github.dfa1.vortex.reader.decode.ExtEncodingDecoder +io.github.dfa1.vortex.reader.decode.FixedSizeListEncodingDecoder +io.github.dfa1.vortex.reader.decode.FrameOfReferenceEncodingDecoder +io.github.dfa1.vortex.reader.decode.FsstEncodingDecoder +io.github.dfa1.vortex.reader.decode.ListEncodingDecoder +io.github.dfa1.vortex.reader.decode.ListViewEncodingDecoder 
+io.github.dfa1.vortex.reader.decode.MaskedEncodingDecoder +io.github.dfa1.vortex.reader.decode.NullEncodingDecoder +io.github.dfa1.vortex.reader.decode.PatchedEncodingDecoder +io.github.dfa1.vortex.reader.decode.PcoEncodingDecoder +io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder +io.github.dfa1.vortex.reader.decode.RleEncodingDecoder +io.github.dfa1.vortex.reader.decode.RunEndEncodingDecoder +io.github.dfa1.vortex.reader.decode.SequenceEncodingDecoder +io.github.dfa1.vortex.reader.decode.SparseEncodingDecoder +io.github.dfa1.vortex.reader.decode.StructEncodingDecoder +io.github.dfa1.vortex.reader.decode.VarBinEncodingDecoder +io.github.dfa1.vortex.reader.decode.VariantEncodingDecoder +io.github.dfa1.vortex.reader.decode.VarBinViewEncodingDecoder +io.github.dfa1.vortex.reader.decode.ZigZagEncodingDecoder +io.github.dfa1.vortex.reader.decode.ZstdEncodingDecoder diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/LayoutDepthBombSecurityTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/LayoutDepthBombSecurityTest.java index c1256797..6e6f0738 100644 --- a/reader/src/test/java/io/github/dfa1/vortex/reader/LayoutDepthBombSecurityTest.java +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/LayoutDepthBombSecurityTest.java @@ -3,7 +3,7 @@ import com.google.flatbuffers.FlatBufferBuilder; import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.VortexFormat; -import io.github.dfa1.vortex.encoding.Registry; + import io.github.dfa1.vortex.fbs.ArraySpec; import io.github.dfa1.vortex.fbs.Footer; import io.github.dfa1.vortex.fbs.Layout; @@ -38,7 +38,7 @@ */ class LayoutDepthBombSecurityTest { - private static final Registry REGISTRY = Registry.empty(); + private static final ReadRegistry REGISTRY = ReadRegistry.empty(); @Test void deeplyNestedLayout_throwsVortexException(@TempDir Path tmp) throws Exception { diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/MalformedFooterSecurityTest.java 
b/reader/src/test/java/io/github/dfa1/vortex/reader/MalformedFooterSecurityTest.java index 0f91f8a2..cf9789f3 100644 --- a/reader/src/test/java/io/github/dfa1/vortex/reader/MalformedFooterSecurityTest.java +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/MalformedFooterSecurityTest.java @@ -3,7 +3,7 @@ import com.google.flatbuffers.FlatBufferBuilder; import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.VortexFormat; -import io.github.dfa1.vortex.encoding.Registry; + import io.github.dfa1.vortex.fbs.ArraySpec; import io.github.dfa1.vortex.fbs.Footer; import io.github.dfa1.vortex.fbs.Layout; @@ -41,7 +41,7 @@ */ class MalformedFooterSecurityTest { - private static final Registry REGISTRY = Registry.empty(); + private static final ReadRegistry REGISTRY = ReadRegistry.empty(); static Stream outOfBoundsSpecs() { // SegmentSpec FlatBuffer schema: offset is u64, length is u32. Negative or diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/MalformedTrailerSecurityTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/MalformedTrailerSecurityTest.java index e1b5d3f2..21f17339 100644 --- a/reader/src/test/java/io/github/dfa1/vortex/reader/MalformedTrailerSecurityTest.java +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/MalformedTrailerSecurityTest.java @@ -2,7 +2,7 @@ import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.VortexFormat; -import io.github.dfa1.vortex.encoding.Registry; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -31,7 +31,7 @@ */ class MalformedTrailerSecurityTest { - private static final Registry REGISTRY = Registry.empty(); + private static final ReadRegistry REGISTRY = ReadRegistry.empty(); @Test void fileSmallerThanTrailer_throwsVortexException(@TempDir Path tmp) throws Exception { diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/ReadRegistryTest.java 
b/reader/src/test/java/io/github/dfa1/vortex/reader/ReadRegistryTest.java new file mode 100644 index 00000000..4b92f0e0 --- /dev/null +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/ReadRegistryTest.java @@ -0,0 +1,120 @@ +package io.github.dfa1.vortex.reader; + +import io.github.dfa1.vortex.core.ArrayStats; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.UnknownArray; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.reader.decode.UnknownArrayNode; +import org.junit.jupiter.api.Test; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class ReadRegistryTest { + + @Test + void decodeUnknownEncodingThrowsByDefault() { + // Given + ReadRegistry sut = ReadRegistry.empty(); + ArrayNode node = new UnknownArrayNode("some.unknown", + ByteBuffer.allocate(0), new ArrayNode[0], new int[0], ArrayStats.empty()); + DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0L, + new MemorySegment[0], sut, Arena.ofAuto()); + + // When / Then + assertThatThrownBy(() -> sut.decode(ctx)) + .isInstanceOf(VortexException.class) + .hasMessageContaining("some.unknown"); + } + + @Test + void decodeKnownEncodingWithoutDecoderThrowsByDefault() { + // Given — EncodingId is known but no decoder registered for it + ReadRegistry sut = ReadRegistry.empty(); + ArrayNode node = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, + ByteBuffer.allocate(0), new ArrayNode[0], new int[0], ArrayStats.empty()); + DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0L, + new MemorySegment[0], sut, Arena.ofAuto()); + + // When / Then + 
assertThatThrownBy(() -> sut.decode(ctx)) + .isInstanceOf(VortexException.class) + .hasMessageContaining("vortex.primitive"); + } + + @Test + void decodeKnownEncodingWithoutDecoderReturnsUnknownArrayWhenAllowed() { + // Given + ReadRegistry sut = ReadRegistry.builder().allowUnknown().build(); + ArrayNode node = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, + ByteBuffer.allocate(0), new ArrayNode[0], new int[0], ArrayStats.empty()); + DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0L, + new MemorySegment[0], sut, Arena.ofAuto()); + + // When + Array result = sut.decode(ctx); + + // Then + assertThat(result).isInstanceOf(UnknownArray.class); + assertThat(((UnknownArray) result).encodingId()).isEqualTo("vortex.primitive"); + } + + @Test + void decodeUnknownEncodingReturnsUnknownArrayWhenAllowed() { + // Given + ReadRegistry sut = ReadRegistry.builder().allowUnknown().build(); + ByteBuffer metadata = ByteBuffer.wrap(new byte[]{1, 2, 3}); + MemorySegment buf = Arena.ofAuto().allocate(4); + buf.set(java.lang.foreign.ValueLayout.JAVA_INT, 0, 42); + ArrayNode node = new UnknownArrayNode("some.unknown", + metadata, new ArrayNode[0], new int[]{0}, ArrayStats.empty()); + DecodeContext ctx = new DecodeContext(node, DTypes.I32, 5L, + new MemorySegment[]{buf}, sut, Arena.ofAuto()); + + // When + Array result = sut.decode(ctx); + + // Then + assertThat(result).isInstanceOf(UnknownArray.class); + UnknownArray unknown = (UnknownArray) result; + assertThat(unknown.encodingId()).isEqualTo("some.unknown"); + assertThat(unknown.dtype()).isEqualTo(DTypes.I32); + assertThat(unknown.length()).isEqualTo(5L); + assertThat(unknown.metadata()).isEqualTo(metadata); + assertThat(unknown.buffers()).hasSize(1); + assertThat(unknown.buffers()[0].get(java.lang.foreign.ValueLayout.JAVA_INT, 0)).isEqualTo(42); + assertThat(unknown.children()).isEmpty(); + } + + @Test + void decodeUnknownEncodingWrapsChildrenAsUnknown() { + // Given + ReadRegistry sut = 
ReadRegistry.builder().allowUnknown().build(); + // Child uses a known id; allow-unknown still wraps it unknown because + // its parent is unknown — mirrors Rust decode_foreign in vortex-array/src/serde.rs:380. + ArrayNode child = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, + ByteBuffer.allocate(0), new ArrayNode[0], new int[0], ArrayStats.empty()); + ArrayNode parent = new UnknownArrayNode("some.unknown", + ByteBuffer.allocate(0), new ArrayNode[]{child}, new int[0], ArrayStats.empty()); + DecodeContext ctx = new DecodeContext(parent, DTypes.I32, 0L, + new MemorySegment[0], sut, Arena.ofAuto()); + + // When + Array result = sut.decode(ctx); + + // Then + UnknownArray unknown = (UnknownArray) result; + assertThat(unknown.children()).hasSize(1); + assertThat(unknown.children()[0]).isInstanceOf(UnknownArray.class); + assertThat(((UnknownArray) unknown.children()[0]).encodingId()).isEqualTo("vortex.primitive"); + assertThat(sut.isAllowUnknown()).isTrue(); + } +} diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/VortexReaderTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/VortexReaderTest.java index 7110096a..780d598f 100644 --- a/reader/src/test/java/io/github/dfa1/vortex/reader/VortexReaderTest.java +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/VortexReaderTest.java @@ -6,12 +6,10 @@ import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.EmptyArray; import io.github.dfa1.vortex.core.array.UnknownArray; -import io.github.dfa1.vortex.encoding.DecodeContext; -import io.github.dfa1.vortex.encoding.EncodeContext; -import io.github.dfa1.vortex.encoding.EncodeResult; -import io.github.dfa1.vortex.encoding.Encoding; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.reader.decode.EncodingDecoder; import io.github.dfa1.vortex.encoding.EncodingId; -import io.github.dfa1.vortex.encoding.Registry; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; 
import org.junit.jupiter.params.ParameterizedTest; @@ -30,10 +28,10 @@ class VortexReaderTest { // --- trailer / magic validation --- - private static Registry buildUniversalStubRegistry() { - var b = Registry.builder(); + private static ReadRegistry buildUniversalStubRegistry() { + var b = ReadRegistry.builder(); for (EncodingId encodingId : EncodingId.values()) { - b.register(new Encoding() { + b.register(new EncodingDecoder() { @Override public EncodingId encodingId() { return encodingId; @@ -44,11 +42,6 @@ public boolean accepts(DType dtype) { return false; } - @Override - public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { - throw new UnsupportedOperationException(); - } - @Override public Array decode(DecodeContext ctx) { return EmptyArray.of(ctx.dtype()); @@ -173,13 +166,13 @@ void scan_withNoDecoders_reachesDecodeStep(String name) throws URISyntaxExceptio Path path = fixtureFile(name); // When / Then — layout traversal succeeds; decode fails only on missing decoder - try (var sut = VortexReader.open(path, Registry.empty()); + try (var sut = VortexReader.open(path, ReadRegistry.empty()); var iter = sut.scan(ScanOptions.all())) { // Decode now happens in next(), not hasNext() — hasNext() is side-effect-free. 
assertThat(iter.hasNext()).isTrue(); assertThatThrownBy(iter::next) .isInstanceOf(VortexException.class) - .hasMessageContaining("no encoding registered"); + .hasMessageContaining("no decoder registered"); } } @@ -194,7 +187,7 @@ void scan_withNoDecoders_reachesDecodeStep(String name) throws URISyntaxExceptio void scan_withNoDecoders_allowUnknown_returnsUnknownArray(String name) throws URISyntaxException, IOException { // Given — empty registry + allowUnknown: every leaf decodes to a passthrough UnknownArray Path path = fixtureFile(name); - var registry = Registry.builder().allowUnknown().build(); + var registry = ReadRegistry.builder().allowUnknown().build(); // When try (var sut = VortexReader.open(path, registry); diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/ZipBombSecurityTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/ZipBombSecurityTest.java index 1825bd21..3b24d214 100644 --- a/reader/src/test/java/io/github/dfa1/vortex/reader/ZipBombSecurityTest.java +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/ZipBombSecurityTest.java @@ -5,9 +5,9 @@ import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; import io.github.dfa1.vortex.core.array.LongArray; -import io.github.dfa1.vortex.encoding.ConstantEncoding; -import io.github.dfa1.vortex.encoding.Registry; -import io.github.dfa1.vortex.encoding.PrimitiveEncoding; + +import io.github.dfa1.vortex.reader.decode.ConstantEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import io.github.dfa1.vortex.fbs.ArrayNode; import io.github.dfa1.vortex.fbs.ArraySpec; import io.github.dfa1.vortex.fbs.Footer; @@ -56,7 +56,7 @@ void attack1_constantEncoding_inflatedFlatRowCount(@TempDir Path tmp) throws Exc // Given — 10M rows: 80 MB if fix reverted (clean AssertionError, not JVM crash) long claimedRows = 10_000_000L; Path bomb = buildConstantBomb(tmp, claimedRows); - var registry = Registry.builder().register(new 
ConstantEncoding()).build(); + var registry = ReadRegistry.builder().register(new ConstantEncodingDecoder()).build(); // When try (var reader = VortexReader.open(bomb, registry); @@ -84,7 +84,7 @@ void attack2_dictLayout_inflatedCodesRowCount(@TempDir Path tmp) throws Exceptio // Given — 100 rows: no OOM risk even if fix reverted; loop hits OOB on index 1 // → IndexOutOfBoundsException (not VortexException) → clean assertion failure Path bomb = buildDictBomb(tmp, 100L); - var registry = Registry.builder().register(new PrimitiveEncoding()).build(); + var registry = ReadRegistry.builder().register(new PrimitiveEncodingDecoder()).build(); // When / Then — VortexException before any O(n) allocation. Decode now runs // in next() (hasNext() is side-effect free), so the validation throws there. diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoderTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoderTest.java new file mode 100644 index 00000000..cabcb796 --- /dev/null +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoderTest.java @@ -0,0 +1,56 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.reader.ReadRegistry; + +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.EncodingId; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +class ByteBoolEncodingDecoderTest { + + static Stream cases() { + return Stream.of( + Arguments.of("all false", new byte[]{0, 0, 0}, new boolean[]{false, false, false}), + Arguments.of("all true", new byte[]{1, 42, (byte) 0xFF}, new 
boolean[]{true, true, true}), + Arguments.of("mixed", new byte[]{0, 1, 0, 1}, new boolean[]{false, true, false, true}), + Arguments.of("empty", new byte[]{}, new boolean[]{}) + ); + } + + private static DecodeContext buildCtx(byte[] byteValues) { + MemorySegment buf = MemorySegment.ofArray(byteValues); + ArrayNode node = ArrayNode.of(EncodingId.VORTEX_BYTEBOOL, null, new ArrayNode[0], new int[]{0}, null); + ReadRegistry registry = TestRegistry.ofDecoders(new ByteBoolEncodingDecoder()); + return new DecodeContext(node, DTypes.BOOL, byteValues.length, new MemorySegment[]{buf}, registry, + Arena.ofAuto()); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("cases") + void decode_byteBool_packsToBitArray(String name, byte[] input, boolean[] expected) { + // Given + DecodeContext ctx = buildCtx(input); + var sut = new ByteBoolEncodingDecoder(); + + // When + var result = sut.decode(ctx); + + // Then + assertThat(result).isInstanceOf(BoolArray.class); + assertThat(result.length()).isEqualTo(expected.length); + BoolArray boolArr = (BoolArray) result; + for (int i = 0; i < expected.length; i++) { + assertThat(boolArr.getBoolean(i)).as("index %d", i).isEqualTo(expected[i]); + } + } +} diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/decode/DecodeTestHelper.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/DecodeTestHelper.java new file mode 100644 index 00000000..d67fac41 --- /dev/null +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/DecodeTestHelper.java @@ -0,0 +1,43 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.reader.ReadRegistry; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.List; + +/// Utilities for wrapping encode output into a {@link DecodeContext} for round-trip tests. 
+/// +/// Public so writer/ test trees can reuse via the reader test-jar. +public final class DecodeTestHelper { + + private DecodeTestHelper() { + } + + /// Wraps a writer's {@link EncodeResult} into a {@link DecodeContext} for round-trip assertions. + /// + /// @param result writer output + /// @param rowCount logical row count + /// @param dtype decoded dtype + /// @param registry registry used for nested decode dispatch + /// @return decode context ready for {@link EncodingDecoder#decode} + public static DecodeContext toDecodeContext( + EncodeResult result, long rowCount, DType dtype, ReadRegistry registry + ) { + List buffers = result.buffers(); + MemorySegment[] segments = buffers.toArray(new MemorySegment[0]); + ArrayNode root = toArrayNode(result.rootNode()); + return new DecodeContext(root, dtype, rowCount, segments, registry, Arena.ofAuto()); + } + + private static ArrayNode toArrayNode(EncodeNode enc) { + ArrayNode[] children = new ArrayNode[enc.children().length]; + for (int i = 0; i < children.length; i++) { + children[i] = toArrayNode(enc.children()[i]); + } + return ArrayNode.of(enc.encodingId(), enc.metadata(), children, enc.bufferIndices(), null); + } +} diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoderTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoderTest.java new file mode 100644 index 00000000..1ca3fd72 --- /dev/null +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoderTest.java @@ -0,0 +1,34 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.reader.ReadRegistry; + +import io.github.dfa1.vortex.core.array.NullArray; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.EncodingId; +import org.junit.jupiter.api.Test; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; + +import static org.assertj.core.api.Assertions.assertThat; + +class 
NullEncodingDecoderTest { + + @Test + void decode_nullArray_returnsNullArrayWithCorrectLength() { + // Given + long rowCount = 42L; + ArrayNode node = ArrayNode.of(EncodingId.VORTEX_NULL, null, new ArrayNode[0], new int[0], null); + DecodeContext ctx = new DecodeContext(node, DTypes.NULL, rowCount, new MemorySegment[0], + ReadRegistry.empty(), Arena.ofAuto()); + var sut = new NullEncodingDecoder(); + + // When + var result = sut.decode(ctx); + + // Then + assertThat(result).isInstanceOf(NullArray.class); + assertThat(result.length()).isEqualTo(rowCount); + assertThat(result.dtype()).isEqualTo(DTypes.NULL); + } +} diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoderTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoderTest.java new file mode 100644 index 00000000..58f1aef2 --- /dev/null +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoderTest.java @@ -0,0 +1,165 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.reader.ReadRegistry; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.PatchedMetadata; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class PatchedEncodingDecoderTest { + + private static final 
PatchedEncodingDecoder SUT = new PatchedEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(SUT, new PrimitiveEncodingDecoder()); + + private static ByteBuffer patchedMeta(int nPatches, int nLanes, int offset) { + return ByteBuffer.wrap(new PatchedMetadata(nPatches, nLanes, offset).encode()); + } + + private static MemorySegment i32Segment(int... values) { + MemorySegment seg = MemorySegment.ofArray(new byte[values.length * 4]); + ByteBuffer bb = seg.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + for (int v : values) { + bb.putInt(v); + } + return seg; + } + + private static MemorySegment u32Segment(int... values) { + return i32Segment(values); + } + + private static MemorySegment u16Segment(short... values) { + MemorySegment seg = MemorySegment.ofArray(new byte[values.length * 2]); + ByteBuffer bb = seg.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + for (short v : values) { + bb.putShort(v); + } + return seg; + } + + private static MemorySegment i64Segment(long... 
values) { + MemorySegment seg = MemorySegment.ofArray(new byte[values.length * 8]); + ByteBuffer bb = seg.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + for (long v : values) { + bb.putLong(v); + } + return seg; + } + + private static Array decode(int n, int[] innerI32, int[] laneOffsets, short[] patchIndices, int[] patchValues) { + return decode(new DType.Primitive(PType.I32, false), n, + i32Segment(innerI32), u32Segment(laneOffsets), + u16Segment(patchIndices), i32Segment(patchValues), + laneOffsets.length - 1); + } + + private static Array decode(DType dtype, int n, + MemorySegment inner, MemorySegment laneOffsets, + MemorySegment patchIndices, MemorySegment patchValues, + int nLanes) { + int nPatches = (int) (patchIndices.byteSize() / 2); + ByteBuffer meta = patchedMeta(nPatches, nLanes, 0); + + MemorySegment[] segments = {inner, laneOffsets, patchIndices, patchValues}; + + ArrayNode innerNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0}, null); + ArrayNode laneNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{1}, null); + ArrayNode idxNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{2}, null); + ArrayNode valNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{3}, null); + + ArrayNode patchedNode = ArrayNode.of(EncodingId.VORTEX_PATCHED, meta, + new ArrayNode[]{innerNode, laneNode, idxNode, valNode}, new int[]{}, null); + + DecodeContext ctx = new DecodeContext(patchedNode, dtype, n, segments, REGISTRY, Arena.ofAuto()); + return SUT.decode(ctx); + } + + @Test + void decode_noPatches_returnsInnerUnchanged() { + int n = 4; + int[] inner = {10, 20, 30, 40}; + Array sut = decode(n, inner, new int[]{0, 0}, new short[]{}, new int[]{}); + + assertThat(sut).isInstanceOf(IntArray.class); + MemorySegment seg = ArraySegments.of(sut); + for (int i = 0; i < n; i++) { + assertThat(seg.getAtIndex(PTypeIO.LE_INT, i)).as("index %d", 
i).isEqualTo(inner[i]); + } + } + + @Test + void decode_singlePatch_overwrites() { + Array sut = decode(4, new int[]{10, 20, 30, 40}, new int[]{0, 1}, new short[]{2}, new int[]{99}); + + MemorySegment seg = ArraySegments.of(sut); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 0)).isEqualTo(10); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 1)).isEqualTo(20); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 2)).isEqualTo(99); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 3)).isEqualTo(40); + } + + @Test + void decode_multiplePatches_allApplied() { + Array sut = decode(4, new int[]{0, 0, 0, 0}, new int[]{0, 2}, new short[]{0, 3}, new int[]{1, 7}); + + MemorySegment seg = ArraySegments.of(sut); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 0)).isEqualTo(1); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 1)).isEqualTo(0); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 2)).isEqualTo(0); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 3)).isEqualTo(7); + } + + @ParameterizedTest + @ValueSource(ints = {1, 2, 1020, 1023, 1024, 1025, 2048}) + void decode_variousLengths_noPatches(int n) { + int[] inner = new int[n]; + Array sut = decode(n, inner, new int[]{0, 0}, new short[]{}, new int[]{}); + + MemorySegment seg = ArraySegments.of(sut); + for (int i = 0; i < n; i++) { + assertThat(seg.getAtIndex(PTypeIO.LE_INT, i)).as("index %d", i).isZero(); + } + } + + @Test + void decode_i64_singlePatch() { + DType dtype = new DType.Primitive(PType.I64, false); + Array sut = decode(dtype, 3, i64Segment(100L, 200L, 300L), u32Segment(0, 1), + u16Segment((short) 1), i64Segment(999L), 1); + + assertThat(sut).isInstanceOf(LongArray.class); + MemorySegment seg = ArraySegments.of(sut); + assertThat(seg.getAtIndex(PTypeIO.LE_LONG, 0)).isEqualTo(100L); + assertThat(seg.getAtIndex(PTypeIO.LE_LONG, 1)).isEqualTo(999L); + assertThat(seg.getAtIndex(PTypeIO.LE_LONG, 2)).isEqualTo(300L); + } + + @Test + void decode_missingMetadata_throws() { + ArrayNode innerNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new 
ArrayNode[0], new int[]{0}, null); + ArrayNode patchedNode = ArrayNode.of(EncodingId.VORTEX_PATCHED, null, + new ArrayNode[]{innerNode, innerNode, innerNode, innerNode}, new int[]{}, null); + MemorySegment seg = i32Segment(1, 2, 3); + DecodeContext ctx = new DecodeContext(patchedNode, new DType.Primitive(PType.I32, false), 3, + new MemorySegment[]{seg}, ReadRegistry.empty(), Arena.ofAuto()); + + assertThatThrownBy(() -> SUT.decode(ctx)).hasMessageContaining("missing metadata"); + } +} diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/PcoEncodingTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoderTest.java similarity index 50% rename from core/src/test/java/io/github/dfa1/vortex/encoding/PcoEncodingTest.java rename to reader/src/test/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoderTest.java index dc462dcc..685a6337 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/PcoEncodingTest.java +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoderTest.java @@ -1,13 +1,17 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.reader.ReadRegistry; -import io.github.dfa1.vortex.proto.PcoChunkInfo; -import io.github.dfa1.vortex.proto.PcoMetadata; -import io.github.dfa1.vortex.proto.PcoPageInfo; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.array.LongArray; import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.encoding.EncodingId; + +import io.github.dfa1.vortex.proto.PcoChunkInfo; +import io.github.dfa1.vortex.proto.PcoMetadata; +import io.github.dfa1.vortex.proto.PcoPageInfo; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -25,21 +29,21 @@ import static 
org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class PcoEncodingTest { +class PcoEncodingDecoderTest { + + private static final PcoEncodingDecoder SUT = new PcoEncodingDecoder(); private static ByteBuffer validMetaBuffer() { - PcoMetadata meta = new PcoMetadata(new byte[]{PcoEncoding.PCO_FORMAT_MAJOR, PcoEncoding.PCO_FORMAT_MINOR}, java.util.List.of()); + PcoMetadata meta = new PcoMetadata(new byte[]{PcoEncodingDecoder.PCO_FORMAT_MAJOR, PcoEncodingDecoder.PCO_FORMAT_MINOR}, java.util.List.of()); return ByteBuffer.wrap(meta.encode()); } private static DecodeContext ctxWith(ByteBuffer meta, DType dtype, long rowCount, MemorySegment[] buffers) { ArrayNode node = ArrayNode.of(EncodingId.VORTEX_PCO, meta, new ArrayNode[0], bufferIndices(buffers.length), null); - return new DecodeContext(node, dtype, rowCount, buffers, Registry.empty(), Arena.ofAuto()); + return new DecodeContext(node, dtype, rowCount, buffers, ReadRegistry.empty(), Arena.ofAuto()); } - /// Build a nullable DecodeContext: validity buffer at index 0, pco buffers at indices 1..N. - /// Validity is a bit-packed Bool array (LSB-first, 1=valid). private static DecodeContext ctxWithValidity(ByteBuffer meta, DType dtype, long rowCount, MemorySegment validityBuf, MemorySegment[] pcoBuffers) { MemorySegment[] allBuffers = new MemorySegment[1 + pcoBuffers.length]; @@ -56,7 +60,7 @@ private static DecodeContext ctxWithValidity(ByteBuffer meta, DType dtype, long ArrayNode pcoNode = ArrayNode.of(EncodingId.VORTEX_PCO, meta, new ArrayNode[]{validityNode}, pcoBufferIndices, null); - Registry registry = TestRegistry.of(new BoolEncoding()); + ReadRegistry registry = TestRegistry.ofDecoders(new BoolEncodingDecoder()); return new DecodeContext(pcoNode, dtype, rowCount, allBuffers, registry, Arena.ofAuto()); } @@ -76,30 +80,17 @@ private static MemorySegment segmentOf(byte... 
bytes) { return seg; } - /// Build a PcoMetadata proto with one chunk containing one page of {@code nValues} values. private static ByteBuffer metaWithOneChunk(int nValues) { PcoMetadata meta = new PcoMetadata( - new byte[]{PcoEncoding.PCO_FORMAT_MAJOR, PcoEncoding.PCO_FORMAT_MINOR}, + new byte[]{PcoEncodingDecoder.PCO_FORMAT_MAJOR, PcoEncodingDecoder.PCO_FORMAT_MINOR}, java.util.List.of(new PcoChunkInfo(java.util.List.of(new PcoPageInfo(nValues))))); return ByteBuffer.wrap(meta.encode()); } - /// Chunk-meta bytes for Classic mode, Consecutive delta at {@code order}, ansSizeLog=0, nBins=0. - /// - /// Bit layout (LSB-first per byte): - /// byte0: mode_nibble=0, delta_nibble=1 - /// byte1: order (3b), secondary_uses_delta=0 (1b), ansSizeLog=0 (4b) - /// byte2–3: nBins=0 (15b), align padding private static MemorySegment chunkMetaConsecutive(int order) { - return segmentOf( - (byte) 0x10, // mode=0, delta_variant=1 - (byte) order, // order[2:0], secondary=0, ansSizeLog=0 (order ≤ 7) - (byte) 0x00, // nBins bits16-23 = 0 - (byte) 0x00 // nBins bits24-30 = 0, padding - ); + return segmentOf((byte) 0x10, (byte) order, (byte) 0x00, (byte) 0x00); } - /// Page bytes: {@code order} LE-U64 moments, then 4 zero ANS-state slots (0 bits each). private static MemorySegment pageWithMoments(long... moments) { byte[] buf = new byte[moments.length * Long.BYTES]; java.nio.ByteBuffer bb = java.nio.ByteBuffer.wrap(buf).order(java.nio.ByteOrder.LITTLE_ENDIAN); @@ -109,43 +100,32 @@ private static MemorySegment pageWithMoments(long... moments) { return segmentOf(buf); } - /// Build chunk meta for Classic mode + Conv1 delta by packing bits LSB-first. - /// - /// Layout: mode(4b)=0, delta(4b)=3, quantization(5b), bias_latent(64b), - /// order-1(5b), weights[order*32b], ansSizeLog(4b)=0, nBins(15b)=0, align. - /// bias_latent = bias ^ Long.MIN_VALUE; each weight_latent = weight ^ 0x80000000L. 
private static MemorySegment chunkMetaConv1(int quantization, long biasLatent, int order, long[] weightLatents) { java.util.BitSet bits = new java.util.BitSet(); int pos = 0; - // mode nibble = 0 (Classic) pos += 4; - // delta nibble = 3 (Conv1): bits = 0b0011 bits.set(pos); bits.set(pos + 1); pos += 4; - // quantization (5 bits) for (int i = 0; i < 5; i++) { if (((quantization >> i) & 1) != 0) { bits.set(pos); } pos++; } - // bias latent (64 bits, LSB first) for (int i = 0; i < 64; i++) { if (((biasLatent >> i) & 1L) != 0L) { bits.set(pos); } pos++; } - // order-1 (5 bits) for (int i = 0; i < 5; i++) { if ((((order - 1) >> i) & 1) != 0) { bits.set(pos); } pos++; } - // weight latents (order × 32 bits) for (long wl : weightLatents) { for (int i = 0; i < 32; i++) { if (((wl >> i) & 1L) != 0L) { @@ -154,9 +134,7 @@ private static MemorySegment chunkMetaConv1(int quantization, long biasLatent, pos++; } } - // ansSizeLog (4 bits) = 0 pos += 4; - // nBins (15 bits) = 0 pos += 15; int byteLen = (pos + 7) / 8; byte[] buf = new byte[byteLen]; @@ -168,20 +146,10 @@ private static MemorySegment chunkMetaConv1(int quantization, long biasLatent, return segmentOf(buf); } - /// Chunk-meta bytes for Classic mode + Lookback delta with windowNLog=1 (windowN=2), stateNLog=0 (stateN=1), - /// deltaAnsSizeLog=0, primaryAnsSizeLog=0, no bins. - /// - /// Bit layout: - /// byte0: mode=0[3:0], delta=2[7:4] → 0x20 - /// bytes 1-6: windowNLog-1(5b)=0, stateNLog(4b)=0, secondary(1b)=0, - /// deltaAnsSizeLog(4b)=0, nDeltaBins(15b)=0, - /// primaryAnsSizeLog(4b)=0, nBins(15b)=0, align → all 0x00 private static MemorySegment chunkMetaLookback() { return segmentOf((byte) 0x20, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00); } - /// Page bytes for Lookback with stateN=1, U64, deltaAnsSizeLog=0, primaryAnsSizeLog=0. - /// Format: 8 bytes (one 64-bit initial state). No ANS state bits (sizeLog=0). No decoded bits. 
private static MemorySegment lookbackPage(long initialState) { byte[] buf = new byte[Long.BYTES]; java.nio.ByteBuffer.wrap(buf).order(java.nio.ByteOrder.LITTLE_ENDIAN).putLong(initialState); @@ -189,28 +157,10 @@ private static MemorySegment lookbackPage(long initialState) { } @Nested - class EncodingIdTest { - + class EncodingIdNested { @Test void encodingId_isVortexPco() { - // Given / When / Then - assertThat(new PcoEncoding().encodingId()).isEqualTo(EncodingId.VORTEX_PCO); - } - } - - @Nested - class Encode { - - @Test - void encode_throwsVortexException() { - // Given - var sut = new PcoEncoding(); - DType dtype = new DType.Primitive(PType.I64, false); - - // When / Then - assertThatThrownBy(() -> sut.encode(dtype, new long[]{1L, 2L, 3L}, EncodeTestHelper.testCtx())) - .isInstanceOf(VortexException.class) - .hasMessageContaining("not implemented"); + assertThat(SUT.encodingId()).isEqualTo(EncodingId.VORTEX_PCO); } } @@ -219,51 +169,35 @@ class Decode { @Test void decode_nullMetadata_throwsMissingMeta() { - // Given - var sut = new PcoEncoding(); DecodeContext ctx = ctxWith(null, new DType.Primitive(PType.I64, false), 0, new MemorySegment[0]); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("missing PcoMetadata"); } @Test void decode_invalidHeaderVersion_throwsUnsupported() { - // Given - var sut = new PcoEncoding(); PcoMetadata meta = new PcoMetadata(new byte[]{0x03, 0x00}, java.util.List.of()); DecodeContext ctx = ctxWith(ByteBuffer.wrap(meta.encode()), new DType.Primitive(PType.I64, false), 0, new MemorySegment[0]); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("unsupported pco format version 03.00"); } @Test void decode_nonPrimitiveDtype_throws() { - // Given - var sut = new PcoEncoding(); DecodeContext ctx = 
ctxWith(validMetaBuffer(), new DType.Utf8(false), 0, new MemorySegment[0]); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("Primitive dtype"); } @Test void decode_unsupportedPtype_throws() { - // Given — F16 not supported by pco - var sut = new PcoEncoding(); DecodeContext ctx = ctxWith(validMetaBuffer(), new DType.Primitive(PType.F16, false), 0, new MemorySegment[0]); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("unsupported ptype"); } @@ -271,53 +205,25 @@ void decode_unsupportedPtype_throws() { @ParameterizedTest @EnumSource(value = PType.class, names = {"I16", "U16", "I32", "U32", "F32", "I64", "U64", "F64"}) void decode_zeroChunks_returnsEmptyArray(PType ptype) { - // Given — valid metadata with 0 chunks, 0 rows, any supported ptype - var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith(validMetaBuffer(), new DType.Primitive(ptype, false), 0, - new MemorySegment[0]); - - // When - var result = sut.decode(ctx); - - // Then + DecodeContext ctx = ctxWith(validMetaBuffer(), new DType.Primitive(ptype, false), 0, new MemorySegment[0]); + var result = SUT.decode(ctx); assertThat(result.length()).isZero(); } @Test void decode_consecutiveDelta_order1_singleValue_decodes() { - // Given — U64 sequence [42] encoded with Classic mode, Consecutive delta order=1. - // With pageN=1 and order=1, decodedN=0: the single output value is the moment itself. 
- var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{chunkMetaConsecutive(1), pageWithMoments(42L)}); - - // When - var result = sut.decode(ctx); - - // Then + var result = SUT.decode(ctx); assertThat(result.length()).isEqualTo(1); assertThat(((LongArray) result).getLong(0)).isEqualTo(42L); } @Test void decode_consecutiveDelta_order2_twoValues_decodes() { - // Given — U64 sequence [10, 17] encoded with Consecutive delta order=2. - // With pageN=2, order=2: decodedN=0; moments=[m0=10, m1=delta1=7]. - // Expected reconstruction: [m0, m0+m1] = [10, 17]. - var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith( - metaWithOneChunk(2), - new DType.Primitive(PType.U64, false), - 2, + DecodeContext ctx = ctxWith(metaWithOneChunk(2), new DType.Primitive(PType.U64, false), 2, new MemorySegment[]{chunkMetaConsecutive(2), pageWithMoments(10L, 7L)}); - - // When - var result = sut.decode(ctx); - - // Then + var result = SUT.decode(ctx); assertThat(result.length()).isEqualTo(2); assertThat(((LongArray) result).getLong(0)).isEqualTo(10L); assertThat(((LongArray) result).getLong(1)).isEqualTo(17L); @@ -325,22 +231,12 @@ void decode_consecutiveDelta_order2_twoValues_decodes() { @Test void decode_multiPage_singleChunk_decodes() { - // Given — 1 chunk, 2 pages each containing 1 value (Consecutive order=1). - // buffers: [chunkMeta, page0, page1]; page0 moment=10→value 10, page1 moment=20→value 20. 
- var sut = new PcoEncoding(); PcoMetadata meta = new PcoMetadata( - new byte[]{PcoEncoding.PCO_FORMAT_MAJOR, PcoEncoding.PCO_FORMAT_MINOR}, + new byte[]{PcoEncodingDecoder.PCO_FORMAT_MAJOR, PcoEncodingDecoder.PCO_FORMAT_MINOR}, java.util.List.of(new PcoChunkInfo(java.util.List.of(new PcoPageInfo(1), new PcoPageInfo(1))))); - DecodeContext ctx = ctxWith( - ByteBuffer.wrap(meta.encode()), - new DType.Primitive(PType.U64, false), - 2, + DecodeContext ctx = ctxWith(ByteBuffer.wrap(meta.encode()), new DType.Primitive(PType.U64, false), 2, new MemorySegment[]{chunkMetaConsecutive(1), pageWithMoments(10L), pageWithMoments(20L)}); - - // When - var result = sut.decode(ctx); - - // Then + var result = SUT.decode(ctx); assertThat(result.length()).isEqualTo(2); assertThat(((LongArray) result).getLong(0)).isEqualTo(10L); assertThat(((LongArray) result).getLong(1)).isEqualTo(20L); @@ -348,26 +244,15 @@ void decode_multiPage_singleChunk_decodes() { @Test void decode_multiChunk_decodes() { - // Given — 2 chunks each with 1 page containing 1 value (Consecutive order=1). - // buffers: [chunkMeta0, page0, chunkMeta1, page1]; values=[100, 200]. 
- var sut = new PcoEncoding(); PcoMetadata meta = new PcoMetadata( - new byte[]{PcoEncoding.PCO_FORMAT_MAJOR, PcoEncoding.PCO_FORMAT_MINOR}, + new byte[]{PcoEncodingDecoder.PCO_FORMAT_MAJOR, PcoEncodingDecoder.PCO_FORMAT_MINOR}, java.util.List.of( new PcoChunkInfo(java.util.List.of(new PcoPageInfo(1))), new PcoChunkInfo(java.util.List.of(new PcoPageInfo(1))))); - DecodeContext ctx = ctxWith( - ByteBuffer.wrap(meta.encode()), - new DType.Primitive(PType.U64, false), - 2, - new MemorySegment[]{ - chunkMetaConsecutive(1), pageWithMoments(100L), + DecodeContext ctx = ctxWith(ByteBuffer.wrap(meta.encode()), new DType.Primitive(PType.U64, false), 2, + new MemorySegment[]{chunkMetaConsecutive(1), pageWithMoments(100L), chunkMetaConsecutive(1), pageWithMoments(200L)}); - - // When - var result = sut.decode(ctx); - - // Then + var result = SUT.decode(ctx); assertThat(result.length()).isEqualTo(2); assertThat(((LongArray) result).getLong(0)).isEqualTo(100L); assertThat(((LongArray) result).getLong(1)).isEqualTo(200L); @@ -379,27 +264,16 @@ class DecodeNullable { @Test void decode_nullable_someNulls_scattersCorrectly() { - // Given — U64 sequence: 3 total rows, validity=[true,false,true], valid values=[100,200]. - // Validity bits LSB-first: bit0=1, bit1=0, bit2=1 → byte 0x05. - // Pco encodes only valid values: 1 chunk, 2 pages of nValues=1 (Consecutive order=1). 
- var sut = new PcoEncoding(); PcoMetadata meta = new PcoMetadata( - new byte[]{PcoEncoding.PCO_FORMAT_MAJOR, PcoEncoding.PCO_FORMAT_MINOR}, + new byte[]{PcoEncodingDecoder.PCO_FORMAT_MAJOR, PcoEncodingDecoder.PCO_FORMAT_MINOR}, java.util.List.of(new PcoChunkInfo(java.util.List.of(new PcoPageInfo(1), new PcoPageInfo(1))))); - MemorySegment validityBuf = segmentOf((byte) 0x05); // bits: 1,0,1 + MemorySegment validityBuf = segmentOf((byte) 0x05); DecodeContext ctx = ctxWithValidity( - ByteBuffer.wrap(meta.encode()), - new DType.Primitive(PType.U64, true), - 3, - validityBuf, + ByteBuffer.wrap(meta.encode()), new DType.Primitive(PType.U64, true), 3, validityBuf, new MemorySegment[]{chunkMetaConsecutive(1), pageWithMoments(100L), pageWithMoments(200L)}); + var result = SUT.decode(ctx); - // When - var result = sut.decode(ctx); - - // Then — MaskedArray with 3 slots; positions 0 and 2 valid, position 1 null assertThat(result).isInstanceOf(MaskedArray.class); - assertThat(result.length()).isEqualTo(3); MaskedArray masked = (MaskedArray) result; assertThat(masked.isValid(0)).isTrue(); assertThat(masked.isValid(1)).isFalse(); @@ -410,23 +284,12 @@ void decode_nullable_someNulls_scattersCorrectly() { @Test void decode_nullable_allNull_returnsAllZeroed() { - // Given — 2 total rows, validity=[false,false], validCount=0. Pco has 0 chunks. - // Validity bits LSB-first: 0x00. 
- var sut = new PcoEncoding(); MemorySegment validityBuf = segmentOf((byte) 0x00); - DecodeContext ctx = ctxWithValidity( - validMetaBuffer(), - new DType.Primitive(PType.U64, true), - 2, - validityBuf, - new MemorySegment[0]); - - // When - var result = sut.decode(ctx); + DecodeContext ctx = ctxWithValidity(validMetaBuffer(), new DType.Primitive(PType.U64, true), 2, + validityBuf, new MemorySegment[0]); + var result = SUT.decode(ctx); - // Then — MaskedArray, length 2, both null, values zeroed assertThat(result).isInstanceOf(MaskedArray.class); - assertThat(result.length()).isEqualTo(2); MaskedArray masked = (MaskedArray) result; assertThat(masked.isValid(0)).isFalse(); assertThat(masked.isValid(1)).isFalse(); @@ -436,23 +299,12 @@ void decode_nullable_allNull_returnsAllZeroed() { @Test void decode_nullable_allValid_returnsMaskedWithAllValues() { - // Given — 2 total rows, validity=[true,true], valid values=[10,20]. - // Validity bits: 0x03. - var sut = new PcoEncoding(); - MemorySegment validityBuf = segmentOf((byte) 0x03); // bits: 1,1 - DecodeContext ctx = ctxWithValidity( - metaWithOneChunk(2), - new DType.Primitive(PType.U64, true), - 2, - validityBuf, - new MemorySegment[]{chunkMetaConsecutive(2), pageWithMoments(10L, 10L)}); + MemorySegment validityBuf = segmentOf((byte) 0x03); + DecodeContext ctx = ctxWithValidity(metaWithOneChunk(2), new DType.Primitive(PType.U64, true), 2, + validityBuf, new MemorySegment[]{chunkMetaConsecutive(2), pageWithMoments(10L, 10L)}); + var result = SUT.decode(ctx); - // When - var result = sut.decode(ctx); - - // Then — MaskedArray, all valid, values [10, 20] assertThat(result).isInstanceOf(MaskedArray.class); - assertThat(result.length()).isEqualTo(2); MaskedArray masked = (MaskedArray) result; assertThat(masked.isValid(0)).isTrue(); assertThat(masked.isValid(1)).isTrue(); @@ -466,27 +318,14 @@ class DecodeConv1 { @Test void decode_conv1_order1_zeroPrediction_statePassedThrough() { - // Given — I32, pageN=2, order=1, bias=0, 
weight=0 → prediction always 0. - // State raw = value ^ 0x80000000: for value=5, state_raw=0x80000005. - // Residual from degenerate tANS=0 → decoded = 0 ^ mid_i32(0x80000000) = 0x80000000 - // → fromLatentOrdered(0x80000000, I32) = 0x80000000 ^ 0x80000000 = 0. - // Expected output: [5, 0]. - var sut = new PcoEncoding(); - long biasLatent = Long.MIN_VALUE; // encodes bias=0: raw ^ MIN_VALUE = 0 - long weightLatent = 0x80000000L; // encodes weight=0: (int)(raw ^ 0x80000000) = 0 + long biasLatent = Long.MIN_VALUE; + long weightLatent = 0x80000000L; MemorySegment chunkMeta = chunkMetaConv1(0, biasLatent, 1, new long[]{weightLatent}); - // Page: 1 × 32-bit state = 0x80000005 (encodes value 5), no residual bits. MemorySegment page = segmentOf((byte) 0x05, (byte) 0x00, (byte) 0x00, (byte) 0x80); - DecodeContext ctx = ctxWith( - metaWithOneChunk(2), - new DType.Primitive(PType.I32, false), - 2, + DecodeContext ctx = ctxWith(metaWithOneChunk(2), new DType.Primitive(PType.I32, false), 2, new MemorySegment[]{chunkMeta, page}); + var result = SUT.decode(ctx); - // When - var result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(2); assertThat(((io.github.dfa1.vortex.core.array.IntArray) result).getInt(0)).isEqualTo(5); assertThat(((io.github.dfa1.vortex.core.array.IntArray) result).getInt(1)).isZero(); @@ -498,59 +337,29 @@ class DecodeLookback { @Test void decode_lookback_corruptIndexZero_throwsVortexException() { - // Given — Classic+Lookback, windowN=2, stateN=1, degenerate ANS (0 bins). - // Degenerate tANS always outputs lower=0; lb=0 is out of [1, windowN=2]. - // pageN=2: stateN=1 initial value + 1 decoded value with lb=0. 
- var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith( - metaWithOneChunk(2), - new DType.Primitive(PType.U64, false), - 2, + DecodeContext ctx = ctxWith(metaWithOneChunk(2), new DType.Primitive(PType.U64, false), 2, new MemorySegment[]{chunkMetaLookback(), lookbackPage(0L)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("corrupt lookback index 0"); } @Test void decode_lookback_stateNExceedsPageN_throwsVortexException() { - // Given — stateN=2 (stateNLog=1) but pageN=1 → decodeN = 1-2 = -1 → corrupt. - // chunkMeta bit layout (all LE, LSB-first): - // byte0: mode=0[3:0], delta=2[7:4] → 0x20 - // byte1: windowNLog-1(5b)=0, stateNLog[0](1b)=1 → 0x20 - // byte2: stateNLog[1..3](3b)=0, secondary(1b)=0, deltaAnsSizeLog(4b)=0 → 0x00 - // bytes3-6: nDeltaBins(15b)=0, primaryAnsSizeLog(4b)=0, nBins(15b)=0 → 0x00 - var sut = new PcoEncoding(); MemorySegment chunkMeta = segmentOf( (byte) 0x20, (byte) 0x20, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{chunkMeta, segmentOf((byte) 0x00)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("stateN"); } @Test void decode_lookback_singleInitialValue_returnsIt() { - // Given — pageN=1, stateN=1, decodeN=0: only the initial state value; no decoded values. 
- var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{chunkMetaLookback(), lookbackPage(42L)}); - - // When - var result = sut.decode(ctx); - - // Then + var result = SUT.decode(ctx); assertThat(result.length()).isEqualTo(1); assertThat(((LongArray) result).getLong(0)).isEqualTo(42L); } @@ -558,21 +367,12 @@ void decode_lookback_singleInitialValue_returnsIt() { @Nested class DecodeLookbackDecodeN { - @Test void lookback_decodeNExceedsMax_throwsVortexException() { - // Given — stateN=1, pageN=(1<<23)+2 → decodeN=(1<<23)+1 > cap 1<<23. - // Check fires before arena.allocate; page only needs 8 bytes (1×64-bit initial state). - var sut = new PcoEncoding(); int pageN = (1 << 23) + 2; - DecodeContext ctx = ctxWith( - metaWithOneChunk(pageN), - new DType.Primitive(PType.U64, false), - pageN, + DecodeContext ctx = ctxWith(metaWithOneChunk(pageN), new DType.Primitive(PType.U64, false), pageN, new MemorySegment[]{chunkMetaLookback(), segmentOf(new byte[8])}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("decodeN"); } @@ -580,28 +380,14 @@ void lookback_decodeNExceedsMax_throwsVortexException() { @Nested class DecodeLookbackStateNWindow { - @Test void lookback_stateNExceedsWindowN_throwsVortexException() { - // Given — windowNLog=1 (windowN=2), stateNLog=2 (stateN=4) → stateN > windowN. 
- // pageN=4 (≥ stateN=4 passes the pageN sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("stateN"); } @@ -609,25 +395,13 @@ void lookback_stateNExceedsWindowN_throwsVortexException() { @Nested class DecodeLookbackWindowNLog { - @Test void lookback_windowNLogExceedsMax_throwsVortexException() { - // Given — mode=0 (Classic), delta=2 (Lookback), windowNLog=25 > max 24. - // Bit layout (LSB-first after byte0): - // byte0: mode=0[3:0], delta=2[7:4] → 0x20 - // byte1: windowNLog-1(5b)=24=0b11000 → 0x18; stateNLog bits start at bit5 - // bytes2-6: all 0x00 - var sut = new PcoEncoding(); MemorySegment chunkMeta = segmentOf( (byte) 0x20, (byte) 0x18, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{chunkMeta, segmentOf((byte) 0x00)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("windowNLog"); } @@ -635,30 +409,17 @@ void lookback_windowNLogExceedsMax_throwsVortexException() { @Nested class DecodeDict { - @Test void dict_nUniqueExceedsMax_throwsVortexException() { - // Given — mode=4 (Dict), nUnique=65537 > max 65536. 
- // Bit layout (LSB-first): mode[3:0]=4, nUnique[28:4]=65537 - // combined = 4 | (65537 << 4) = 0x100014 - // bytes: 0x14, 0x00, 0x10, 0x00, 0x00 - var sut = new PcoEncoding(); - MemorySegment chunkMeta = segmentOf( - (byte) 0x14, (byte) 0x00, (byte) 0x10, (byte) 0x00, (byte) 0x00); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + MemorySegment chunkMeta = segmentOf((byte) 0x14, (byte) 0x00, (byte) 0x10, (byte) 0x00, (byte) 0x00); + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{chunkMeta, segmentOf((byte) 0x00)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("nUnique"); } } - /// Adversarial coverage: malformed inputs must throw VortexException — never AIOOBE, NPE, or OOM. @Nested class Adversarial { @@ -680,115 +441,64 @@ static Stream pageBytesProvider() { }).limit(50); } - /// Random chunk-meta bytes — any exception must be a VortexException, not a JVM crash exception. @ParameterizedTest @MethodSource("chunkMetaBytesProvider") void randomChunkMetaBytes_neverThrowsJvmException(byte[] chunkMetaBytes) { - // Given — valid pco header + 1 chunk with 1 page of 1 value; garbage chunk-meta bytes. - var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{segmentOf(chunkMetaBytes), segmentOf((byte) 0x00)}); - - // When / Then — either succeeds or throws VortexException; never AIOOBE/NPE/OOM try { - sut.decode(ctx); + SUT.decode(ctx); } catch (VortexException ignored) { - // expected — malformed input } } - /// Random page bytes after a valid Classic-mode chunk meta — must not crash the JVM. 
@ParameterizedTest @MethodSource("pageBytesProvider") void randomPageBytes_classicMode_neverThrowsJvmException(byte[] pageBytes) { - // Given — Classic mode, delta=NoOp, ansSizeLog=0, nBins=0 chunk meta. - var sut = new PcoEncoding(); - // byte0: mode=0 (bits3:0), deltaVariant=0 (bits7:4) → 0x00 - // byte1: ansSizeLog=0 (bits3:0), nBins low bits = 0 - // bytes 2-3: nBins high bits = 0 - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{segmentOf((byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00), segmentOf(pageBytes)}); - - // When / Then try { - sut.decode(ctx); + SUT.decode(ctx); } catch (VortexException ignored) { - // expected — malformed page data } } - /// Invalid mode nibbles (5–15) must produce a VortexException naming the mode number. @ParameterizedTest @ValueSource(ints = {5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}) void invalidModeNibble_throwsVortexException(int modeNibble) { - // Given — chunk meta with unsupported mode nibble in bits[3:0]. - var sut = new PcoEncoding(); - // bits[3:0] = modeNibble, delta nibble doesn't matter (won't be reached) byte modeByte = (byte) (modeNibble & 0x0F); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, - new MemorySegment[]{ - segmentOf(modeByte, (byte) 0x00, (byte) 0x00, (byte) 0x00), + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, + new MemorySegment[]{segmentOf(modeByte, (byte) 0x00, (byte) 0x00, (byte) 0x00), segmentOf((byte) 0x00)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("pco mode " + modeNibble); } - /// Invalid delta variants (4–15) must produce a VortexException naming the variant number. 
@ParameterizedTest @ValueSource(ints = {4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}) void invalidDeltaVariant_throwsVortexException(int deltaVariant) { - // Given — Classic mode (nibble=0) + invalid delta nibble in bits[7:4]. - var sut = new PcoEncoding(); - // byte0: bits[3:0]=mode=0, bits[7:4]=deltaVariant byte modeDeltaByte = (byte) ((deltaVariant & 0x0F) << 4); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, - new MemorySegment[]{ - segmentOf(modeDeltaByte, (byte) 0x00, (byte) 0x00, (byte) 0x00), + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, + new MemorySegment[]{segmentOf(modeDeltaByte, (byte) 0x00, (byte) 0x00, (byte) 0x00), segmentOf((byte) 0x00)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("delta variant " + deltaVariant); } - /// Conv1 delta with 64-bit dtype must throw VortexException; pcodec only supports 16/32-bit Conv1. @ParameterizedTest @EnumSource(value = PType.class, names = {"I64", "U64", "F64"}) void conv1Delta_with64BitDtype_throwsVortexException(PType ptype) { - // Given — Conv1 delta variant (nibble=3 in bits[7:4]), Classic mode (nibble=0 in bits[3:0]). - // byte0: bits[3:0]=0 (Classic), bits[7:4]=3 (Conv1) → 0x30 - // Remaining bytes: conv1 bit fields (don't matter — error fires before parsing them). 
- var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(ptype, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(ptype, false), 1, new MemorySegment[]{ segmentOf((byte) 0x30, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00), segmentOf((byte) 0x00)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("Conv1"); } diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/decode/TestDecodeContexts.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/TestDecodeContexts.java new file mode 100644 index 00000000..eb656ccd --- /dev/null +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/TestDecodeContexts.java @@ -0,0 +1,79 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.reader.ReadRegistry; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; + +/// Fluent builder for {@link DecodeContext} used in encoding tests. +/// +/// Defaults: rowCount=0, segments=[], registry=empty, arena=Arena.global(). +/// +/// Public so writer/ test trees can reuse via the reader test-jar. +public final class TestDecodeContexts { + + private final ArrayNode node; + private final DType dtype; + private long rowCount = 0; + private MemorySegment[] segments = new MemorySegment[0]; + private ReadRegistry registry = ReadRegistry.empty(); + private Arena arena = Arena.global(); + + private TestDecodeContexts(ArrayNode node, DType dtype) { + this.node = node; + this.dtype = dtype; + } + + /// Creates a new builder for the given node and dtype. 
+ /// + /// @param node the array node + /// @param dtype the logical type + /// @return a new builder + public static TestDecodeContexts of(ArrayNode node, DType dtype) { + return new TestDecodeContexts(node, dtype); + } + + /// Sets the row count. + /// + /// @param n row count + /// @return this builder + public TestDecodeContexts rowCount(long n) { + this.rowCount = n; + return this; + } + + /// Sets the segment buffers. + /// + /// @param segs segment buffers + /// @return this builder + public TestDecodeContexts segments(MemorySegment... segs) { + this.segments = segs; + return this; + } + + /// Sets the registry. + /// + /// @param reg registry + /// @return this builder + public TestDecodeContexts registry(ReadRegistry reg) { + this.registry = reg; + return this; + } + + /// Sets the arena. + /// + /// @param a arena + /// @return this builder + public TestDecodeContexts arena(Arena a) { + this.arena = a; + return this; + } + + /// Builds the {@link DecodeContext}. + /// + /// @return a new decode context + public DecodeContext build() { + return new DecodeContext(node, dtype, rowCount, segments, registry, arena); + } +} diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/decode/TestRegistry.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/TestRegistry.java new file mode 100644 index 00000000..23561c2a --- /dev/null +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/TestRegistry.java @@ -0,0 +1,37 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.reader.ReadRegistry; + +/// Static factories for {@link ReadRegistry} instances used in decode tests. +/// +/// Public so writer/ test trees can reuse via the reader test-jar. +public final class TestRegistry { + + private TestRegistry() { + } + + /// Builds a {@link ReadRegistry} containing only the supplied decoders. 
+ /// + /// @param decoders decoders to register + /// @return registry instance + public static ReadRegistry ofDecoders(EncodingDecoder... decoders) { + var b = ReadRegistry.builder(); + for (EncodingDecoder d : decoders) { + b.register(d); + } + return b.build(); + } + + /// Builds a {@link ReadRegistry} containing the supplied decoder plus a + /// {@link PrimitiveEncodingDecoder} fallback (so child decodes of primitive segments work). + /// + /// @param sut decoder under test + /// @return registry instance + public static ReadRegistry withPrimitive(EncodingDecoder sut) { + var b = ReadRegistry.builder().register(sut); + if (!(sut instanceof PrimitiveEncodingDecoder)) { + b.register(new PrimitiveEncodingDecoder()); + } + return b.build(); + } +} diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoderTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoderTest.java new file mode 100644 index 00000000..e3cc8053 --- /dev/null +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoderTest.java @@ -0,0 +1,142 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.reader.ReadRegistry; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.NullArray; +import io.github.dfa1.vortex.core.array.VariantArray; +import io.github.dfa1.vortex.encoding.EncodingId; + +import io.github.dfa1.vortex.proto.Primitive; +import io.github.dfa1.vortex.proto.VariantMetadata; +import org.junit.jupiter.api.Test; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class VariantEncodingDecoderTest { + + private static final DType VARIANT_DTYPE = 
new DType.Variant(false); + private static final int N = 3; + + private static final VariantEncodingDecoder SUT = new VariantEncodingDecoder(); + + private static ByteBuffer variantMetaWithShredded(io.github.dfa1.vortex.proto.DType shredded) { + return ByteBuffer.wrap(new VariantMetadata(shredded).encode()); + } + + private static ArrayNode primitiveChildNode(int segIdx) { + return ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{segIdx}, null); + } + + private static ArrayNode nullChildNode() { + return ArrayNode.of(EncodingId.VORTEX_NULL, null, new ArrayNode[0], new int[]{}, null); + } + + private static MemorySegment i32Segment(int... values) { + MemorySegment seg = MemorySegment.ofArray(new byte[values.length * 4]); + ByteBuffer bb = seg.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + for (int v : values) { + bb.putInt(v); + } + return seg; + } + + @org.junit.jupiter.api.Test + void decode_withoutShredded_returnsCoreStorageOnly() { + ArrayNode coreNode = nullChildNode(); + ArrayNode variantNode = ArrayNode.of(EncodingId.VORTEX_VARIANT, null, + new ArrayNode[]{coreNode}, new int[]{}, null); + + ReadRegistry registry = TestRegistry.ofDecoders(SUT, new NullEncodingDecoder()); + DecodeContext ctx = new DecodeContext(variantNode, VARIANT_DTYPE, N, + new MemorySegment[0], registry, Arena.ofAuto()); + + Array result = SUT.decode(ctx); + + assertThat(result).isInstanceOf(VariantArray.class); + VariantArray va = (VariantArray) result; + assertThat(va.dtype()).isEqualTo(VARIANT_DTYPE); + assertThat(va.length()).isEqualTo(N); + assertThat(va.coreStorage()).isInstanceOf(NullArray.class); + assertThat(va.shredded()).isNull(); + } + + @Test + void decode_withShredded_decodesSecondChild() { + io.github.dfa1.vortex.proto.DType shreddedProto = io.github.dfa1.vortex.proto.DType.ofPrimitive( + new Primitive(io.github.dfa1.vortex.proto.PType.I32, false)); + ByteBuffer meta = variantMetaWithShredded(shreddedProto); + + ArrayNode coreNode = 
nullChildNode(); + ArrayNode shreddedNode = primitiveChildNode(0); + ArrayNode variantNode = ArrayNode.of(EncodingId.VORTEX_VARIANT, meta, + new ArrayNode[]{coreNode, shreddedNode}, new int[]{}, null); + + MemorySegment[] segments = {i32Segment(1, 2, 3)}; + ReadRegistry registry = TestRegistry.ofDecoders(SUT, new NullEncodingDecoder(), new PrimitiveEncodingDecoder()); + DecodeContext ctx = new DecodeContext(variantNode, VARIANT_DTYPE, N, + segments, registry, Arena.ofAuto()); + + Array result = SUT.decode(ctx); + + assertThat(result).isInstanceOf(VariantArray.class); + VariantArray va = (VariantArray) result; + assertThat(va.shredded()).isNotNull(); + assertThat(va.shredded().dtype()).isEqualTo(new DType.Primitive(PType.I32, false)); + assertThat(va.shredded().length()).isEqualTo(N); + } + + @Test + void decode_emptyMetadata_noShredded() { + ArrayNode coreNode = nullChildNode(); + ArrayNode variantNode = ArrayNode.of(EncodingId.VORTEX_VARIANT, ByteBuffer.allocate(0), + new ArrayNode[]{coreNode}, new int[]{}, null); + + ReadRegistry registry = TestRegistry.ofDecoders(SUT, new NullEncodingDecoder()); + DecodeContext ctx = new DecodeContext(variantNode, VARIANT_DTYPE, N, + new MemorySegment[0], registry, Arena.ofAuto()); + + Array result = SUT.decode(ctx); + + VariantArray va = (VariantArray) result; + assertThat(va.shredded()).isNull(); + } + + @Test + void decode_nullableDtype_preservedOnResult() { + DType nullableVariant = new DType.Variant(true); + ArrayNode coreNode = nullChildNode(); + ArrayNode variantNode = ArrayNode.of(EncodingId.VORTEX_VARIANT, null, + new ArrayNode[]{coreNode}, new int[]{}, null); + + ReadRegistry registry = TestRegistry.ofDecoders(SUT, new NullEncodingDecoder()); + DecodeContext ctx = new DecodeContext(variantNode, nullableVariant, N, + new MemorySegment[0], registry, Arena.ofAuto()); + + VariantArray va = (VariantArray) SUT.decode(ctx); + + assertThat(va.dtype()).isEqualTo(nullableVariant); + assertThat(va.dtype().nullable()).isTrue(); + 
} + + @Test + void decode_wrongChildCount_throws() { + ArrayNode variantNode = ArrayNode.of(EncodingId.VORTEX_VARIANT, null, + new ArrayNode[0], new int[]{}, null); + + ReadRegistry registry = TestRegistry.ofDecoders(SUT); + DecodeContext ctx = new DecodeContext(variantNode, VARIANT_DTYPE, N, + new MemorySegment[0], registry, Arena.ofAuto()); + + assertThatThrownBy(() -> SUT.decode(ctx)) + .hasMessageContaining("expected 1 or 2 children"); + } +} diff --git a/writer/pom.xml b/writer/pom.xml index a96b858f..009f0ae2 100644 --- a/writer/pom.xml +++ b/writer/pom.xml @@ -25,9 +25,22 @@ flatbuffers-java + + + io.github.dfa1.vortex + vortex-core + test-jar + test + + + io.github.dfa1.vortex + vortex-reader + test + io.github.dfa1.vortex vortex-reader + test-jar test diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java b/writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java index 65f638f6..d820f9d8 100644 --- a/writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java @@ -4,30 +4,32 @@ import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.VortexFormat; -import io.github.dfa1.vortex.encoding.AlpEncoding; -import io.github.dfa1.vortex.encoding.BitpackedEncoding; -import io.github.dfa1.vortex.encoding.BoolEncoding; -import io.github.dfa1.vortex.encoding.CascadingCompressor; -import io.github.dfa1.vortex.encoding.ConstantEncoding; import io.github.dfa1.vortex.encoding.DateTimePartsData; -import io.github.dfa1.vortex.encoding.DateTimePartsEncoding; -import io.github.dfa1.vortex.encoding.DictEncoding; import io.github.dfa1.vortex.encoding.EncodeContext; import io.github.dfa1.vortex.encoding.EncodeNode; import io.github.dfa1.vortex.encoding.EncodeResult; -import io.github.dfa1.vortex.encoding.Encoding; +import io.github.dfa1.vortex.encoding.EncodingEncoder; import 
io.github.dfa1.vortex.encoding.EncodingId; -import io.github.dfa1.vortex.encoding.ExtEncoding; -import io.github.dfa1.vortex.encoding.FixedSizeListEncoding; -import io.github.dfa1.vortex.encoding.Registry; -import io.github.dfa1.vortex.encoding.FrameOfReferenceEncoding; import io.github.dfa1.vortex.encoding.ListData; import io.github.dfa1.vortex.encoding.ListViewData; -import io.github.dfa1.vortex.encoding.PrimitiveEncoding; -import io.github.dfa1.vortex.encoding.RleEncoding; -import io.github.dfa1.vortex.encoding.RunEndEncoding; -import io.github.dfa1.vortex.encoding.VarBinEncoding; -import io.github.dfa1.vortex.encoding.ZstdEncoding; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.writer.encode.AlpEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.BitpackedEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.BoolEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.CascadingCompressor; +import io.github.dfa1.vortex.writer.encode.ConstantEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.DateTimePartsEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.DictEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.ExtEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.FixedSizeListEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.FrameOfReferenceEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.FsstEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.MaskedEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.PrimitiveEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.RleEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.RunEndEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.VarBinEncodingEncoder; +import io.github.dfa1.vortex.writer.encode.ZstdEncodingEncoder; import io.github.dfa1.vortex.fbs.ArraySpec; import io.github.dfa1.vortex.fbs.Extension; import io.github.dfa1.vortex.fbs.Footer; @@ -77,10 +79,10 @@ 
public final class VortexWriter implements Closeable { // Kept low: global dict hurts high-cardinality F64 columns (ALP codes beat U16 dict codes). private static final int GLOBAL_DICT_MAX_CARDINALITY = 2_048; - private static final List DEFAULT_CODECS = List.of( - new AlpEncoding(), new PrimitiveEncoding(), new BoolEncoding(), new DictEncoding(), - new VarBinEncoding(), new ExtEncoding(), - new io.github.dfa1.vortex.encoding.FixedSizeListEncoding()); + private static final List DEFAULT_CODECS = List.of( + new AlpEncodingEncoder(), new PrimitiveEncodingEncoder(), new BoolEncodingEncoder(), + new DictEncodingEncoder(), new VarBinEncodingEncoder(), new ExtEncodingEncoder(), + new FixedSizeListEncodingEncoder()); // Base cascade codec list — no Zstd. Zstd is appended (before PrimitiveEncoding) when // WriteOptions.enableZstd() is true. See WriteOptions.withZstd(boolean) for the tradeoff. @@ -88,10 +90,11 @@ public final class VortexWriter implements Closeable { private final WritableByteChannel channel; private final DType.Struct schema; private final WriteOptions options; - private final List encodings; - private final Registry defaultRegistry; - private final List cascadeCodecs; - private final Registry cascadeRegistry; + private final List encodings; + private final Map defaultEncoders; + private final Registry extensionRegistry; + private final List cascadeCodecs; + private final Map cascadeEncoders; private final List segs = new ArrayList<>(); private final Map> colChunks = new LinkedHashMap<>(); private final Map encodingIdx = new LinkedHashMap<>(); @@ -105,30 +108,34 @@ public final class VortexWriter implements Closeable { private boolean firstChunkSeen = false; private VortexWriter( - WritableByteChannel channel, DType.Struct schema, WriteOptions options, List encodings + WritableByteChannel channel, DType.Struct schema, WriteOptions options, List encodings ) { this.channel = channel; this.schema = schema; this.options = options; this.encodings = encodings; - 
this.defaultRegistry = buildWriterRegistry(encodings); + this.defaultEncoders = buildEncoderMap(encodings); + this.extensionRegistry = buildExtensionRegistry(); this.cascadeCodecs = buildCascadeCodecs(options); - this.cascadeRegistry = buildWriterRegistry(this.cascadeCodecs); + this.cascadeEncoders = buildEncoderMap(this.cascadeCodecs); for (String name : schema.fieldNames()) { colChunks.put(name, new ArrayList<>()); } } - /// Builds the writer's registry: the explicit encoding list plus every - /// {@link io.github.dfa1.vortex.extension.Extension} discovered via ServiceLoader, - /// so {@code writeChunk} can auto-route {@code Collection} inputs through - /// the matching extension impl — including third-party extensions outside - /// {@link io.github.dfa1.vortex.extension.Extension#findKnown}. - private static Registry buildWriterRegistry(List encodings) { - Registry.Builder b = Registry.builder(); - for (Encoding e : encodings) { - b.register(e); + /// Builds the encoder map from the given encoder list. + private static Map buildEncoderMap(List encoders) { + Map map = new java.util.HashMap<>(); + for (EncodingEncoder e : encoders) { + map.put(e.encodingId(), e); } + return Map.copyOf(map); + } + + /// Builds a registry containing only service-loaded extensions, so {@code writeChunk} + /// can auto-route {@code Collection} inputs through the matching extension impl. 
+ private static Registry buildExtensionRegistry() { + Registry.Builder b = Registry.builder(); for (io.github.dfa1.vortex.extension.Extension ext : java.util.ServiceLoader.load(io.github.dfa1.vortex.extension.Extension.class)) { b.register(ext); @@ -136,8 +143,8 @@ private static Registry buildWriterRegistry(List encodings) { return b.build(); } - private static List buildCascadeCodecs(WriteOptions options) { - List codecs = new ArrayList<>(); + private static List buildCascadeCodecs(WriteOptions options) { + List codecs = new ArrayList<>(); // Extension-dtype dispatch order matters: findPrimitiveEncoding picks the first // accepting codec. DateTimePartsEncoding goes first because it consumes // pre-decomposed DateTimePartsData (Parquet importer path); when the data is @@ -146,39 +153,52 @@ private static List buildCascadeCodecs(WriteOptions options) { // ExtEncoding which cascades the storage child through FoR/Bitpacked/RLE/ALP. // FixedSizeListEncoding handles UUID-style fixed-size byte storage downstream // of ExtEncoding. - codecs.add(new DateTimePartsEncoding()); - codecs.add(new ExtEncoding()); - codecs.add(new FixedSizeListEncoding()); - codecs.add(new ConstantEncoding()); - codecs.add(new AlpEncoding()); - codecs.add(new FrameOfReferenceEncoding()); - codecs.add(new RunEndEncoding()); - codecs.add(new RleEncoding()); - codecs.add(new DictEncoding()); - codecs.add(new BitpackedEncoding()); - // FsstEncoding sits between Dict and VarBin. 
Today's non-primitive dispatch + codecs.add(new DateTimePartsEncodingEncoder()); + codecs.add(new ExtEncodingEncoder()); + codecs.add(new FixedSizeListEncodingEncoder()); + codecs.add(new ConstantEncodingEncoder()); + codecs.add(new AlpEncodingEncoder()); + codecs.add(new FrameOfReferenceEncodingEncoder()); + codecs.add(new RunEndEncodingEncoder()); + codecs.add(new RleEncodingEncoder()); + codecs.add(new DictEncodingEncoder()); + codecs.add(new BitpackedEncodingEncoder()); + // FsstEncodingEncoder sits between Dict and VarBin. Today's non-primitive dispatch // (CascadingCompressor.findPrimitiveEncoding) is first-match, so Dict still // wins for Utf8; FSST only fires when Dict is excluded (cascade nested re-runs // via spliceResult's notApplicable retry). Listing it here matches Rust which // uses FSST for high-cardinality short strings (e.g. taxi store_and_fwd_flag). - codecs.add(new io.github.dfa1.vortex.encoding.FsstEncoding()); - codecs.add(new VarBinEncoding()); + codecs.add(new FsstEncodingEncoder()); + codecs.add(new VarBinEncodingEncoder()); if (options.enableZstd()) { - codecs.add(new ZstdEncoding()); + codecs.add(new ZstdEncodingEncoder()); } - codecs.add(new PrimitiveEncoding()); - codecs.add(new BoolEncoding()); + codecs.add(new PrimitiveEncodingEncoder()); + codecs.add(new BoolEncodingEncoder()); return List.copyOf(codecs); } + /// Creates a {@link VortexWriter} using the default encoder set. + /// + /// @param channel the channel to write to + /// @param schema the struct schema for the file + /// @param options write options + /// @return a new writer public static VortexWriter create( WritableByteChannel channel, DType.Struct schema, WriteOptions options ) { return new VortexWriter(channel, schema, options, DEFAULT_CODECS); } + /// Creates a {@link VortexWriter} with a custom encoder list. 
+ /// + /// @param channel the channel to write to + /// @param schema the struct schema for the file + /// @param options write options + /// @param encodings custom encoder list + /// @return a new writer public static VortexWriter create( - WritableByteChannel channel, DType.Struct schema, WriteOptions options, List encodings + WritableByteChannel channel, DType.Struct schema, WriteOptions options, List encodings ) { // Custom encoding list: disable global dict — using DEFAULT_CODECS for values/codes behind the scenes // would violate the user's expectation that only their encoding list is used. @@ -317,7 +337,7 @@ public void writeChunk(Map columns) throws IOException { if (colDtype instanceof DType.Extension extDtype && data instanceof java.util.Collection coll) { io.github.dfa1.vortex.extension.Extension impl = io.github.dfa1.vortex.extension.ExtensionId.parse(extDtype.extensionId()) - .map(defaultRegistry::lookup) + .map(extensionRegistry::lookup) .orElse(null); if (impl != null) { data = impl.encodeAll(extDtype, coll); @@ -386,31 +406,31 @@ private int writeSegment(DType dtype, Object data) throws IOException { /// Writes a segment, optionally forcing a specific {@code encodingOverride} instead of /// the configured cascade. Used by the global Utf8 dictionary path where the values - /// segment must be flat varbin — the cascade would otherwise re-pick {@link - /// io.github.dfa1.vortex.encoding.DictEncoding} and wrap the dictionary in another - /// dict (which the reader cannot unwrap). - private int writeSegment(DType dtype, Object data, Encoding encodingOverride) throws IOException { - // Non-extension nullable columns (Primitive, Utf8) wrap with MaskedEncoding here. - // Extension columns route through ExtEncoding.encode which itself delegates to - // MaskedEncoding when its storage data is NullableData — handled inside ExtEncoding. 
+ /// segment must be flat varbin — the cascade would otherwise re-pick + /// {@link DictEncodingEncoder} and wrap the dictionary in another dict (which the reader + /// cannot unwrap). + private int writeSegment(DType dtype, Object data, EncodingEncoder encodingOverride) throws IOException { + // Non-extension nullable columns (Primitive, Utf8) wrap with MaskedEncodingEncoder here. + // Extension columns route through ExtEncodingEncoder.encode which itself delegates to + // MaskedEncodingEncoder when its storage data is NullableData — handled inside ExtEncoding. if (encodingOverride == null && data instanceof io.github.dfa1.vortex.core.array.NullableData && !(dtype instanceof DType.Extension)) { - encodingOverride = new io.github.dfa1.vortex.encoding.MaskedEncoding(); + encodingOverride = new MaskedEncodingEncoder(); } try (Arena arena = Arena.ofConfined()) { EncodeResult result; if (encodingOverride != null) { - EncodeContext encodeCtx = EncodeContext.of(arena, this.defaultRegistry); + EncodeContext encodeCtx = EncodeContext.of(arena, defaultEncoders); result = encodingOverride.encode(dtype, data, encodeCtx); } else if (options.allowedCascading() > 0) { - EncodeContext encodeCtx = EncodeContext.ofDepth(options.allowedCascading(), arena, cascadeRegistry); + EncodeContext encodeCtx = EncodeContext.ofDepth(options.allowedCascading(), arena, cascadeEncoders); CascadingCompressor compressor = new CascadingCompressor(cascadeCodecs); result = compressor.encode(dtype, data, encodeCtx); } else { - Encoding encoding = findEncoding(dtype); - EncodeContext encodeCtx = EncodeContext.of(arena, this.defaultRegistry); - result = encoding.encode(dtype, data, encodeCtx); + EncodingEncoder encoder = findEncoder(dtype); + EncodeContext encodeCtx = EncodeContext.of(arena, defaultEncoders); + result = encoder.encode(dtype, data, encodeCtx); } // Register all encoding IDs found in the node tree registerEncodingIds(result.rootNode()); @@ -449,13 +469,13 @@ private void 
registerEncodingIds(EncodeNode node) { } } - private Encoding findEncoding(DType dtype) { - for (Encoding c : encodings) { + private EncodingEncoder findEncoder(DType dtype) { + for (EncodingEncoder c : encodings) { if (c.accepts(dtype)) { return c; } } - throw new UnsupportedOperationException("no encoding for dtype: " + dtype); + throw new UnsupportedOperationException("no encoder for dtype: " + dtype); } private void write(MemorySegment seg) throws IOException { @@ -749,7 +769,7 @@ private void writeGlobalDictUtf8Column(String colName, DType.Utf8 dtype, List encodeF64((double[]) data, ctx); + case F32 -> encodeF32((float[]) data, ctx); + default -> throw new UnsupportedOperationException("ALP encode not supported for " + ptype); + }; + } + + @Override + public CascadeStep encodeCascade(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + if (ptype == PType.F64) { + return encodeCascadeF64((double[]) data, ctx); + } + return CascadeStep.terminal(encode(dtype, data, ctx)); + } + + private static int[] findExponentsF64(double[] values) { + int sampleLen = Math.min(SAMPLE_SIZE, values.length); + int bestExpE = 0, bestExpF = 0, bestExceptions = sampleLen + 1; + + outer: + for (int expE = 0; expE <= MAX_EXPONENT_F64; expE++) { + for (int expF = 0; expF <= MAX_EXPONENT_F64; expF++) { + double ef = F10_F64[expE]; + double iff = IF10_F64[expF]; + double df = F10_F64[expF]; + double de = IF10_F64[expE]; + int exceptions = 0; + for (int i = 0; i < sampleLen; i++) { + double enc = values[i] * ef * iff; + if (!Double.isFinite(enc) || (double) Math.round(enc) * df * de != values[i]) { + exceptions++; + } + } + if (exceptions < bestExceptions) { + bestExceptions = exceptions; + bestExpE = expE; + bestExpF = expF; + if (bestExceptions == 0) { + break outer; + } + } + } + } + return new int[]{bestExpE, bestExpF}; + } + + private static AlpF64Data computeF64(double[] values) { + int n = values.length; + int[] exps = 
findExponentsF64(values); + int expE = exps[0], expF = exps[1]; + double ef = F10_F64[expE]; + double iff = IF10_F64[expF]; + double df = F10_F64[expF]; + double de = IF10_F64[expE]; + + long[] encodedArr = new long[n]; + var patchIndices = new ArrayList(); + var patchValues = new ArrayList(); + + double min = Double.MAX_VALUE, max = -Double.MAX_VALUE; + for (int i = 0; i < n; i++) { + double v = values[i]; + double enc = v * ef * iff; + long encoded; + if (Double.isFinite(enc) && (double) (encoded = Math.round(enc)) * df * de == v) { + encodedArr[i] = encoded; + } else { + encodedArr[i] = 0L; + patchIndices.add(i); + patchValues.add(v); + } + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + + byte[] statsMin = n > 0 ? scalarF64(min) : null; + byte[] statsMax = n > 0 ? scalarF64(max) : null; + return new AlpF64Data(expE, expF, encodedArr, patchIndices, patchValues, statsMin, statsMax); + } + + private static EncodeResult encodeF64(double[] values, EncodeContext ctx) { + AlpF64Data d = computeF64(values); + int n = values.length; + + MemorySegment encodedBuf = ctx.arena().allocate((long) n * 8, 8); + for (int i = 0; i < n; i++) { + encodedBuf.setAtIndex(PTypeIO.LE_LONG, i, d.encodedArr()[i]); + } + + EncodeNode encodedNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + + if (d.patchIndices().isEmpty()) { + byte[] metaBytes = new ALPMetadata(d.expE(), d.expF(), null).encode(); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ALP, + ByteBuffer.wrap(metaBytes), new EncodeNode[]{encodedNode}, new int[0]); + return new EncodeResult(root, List.of(encodedBuf), d.statsMin(), d.statsMax()); + } + + int numPatches = d.patchIndices().size(); + MemorySegment idxBuf = ctx.arena().allocate((long) numPatches * 4, 4); + MemorySegment valBuf = ctx.arena().allocate((long) numPatches * 8, 8); + for (int i = 0; i < numPatches; i++) { + idxBuf.setAtIndex(PTypeIO.LE_INT, i, d.patchIndices().get(i)); + valBuf.setAtIndex(PTypeIO.LE_DOUBLE, i, 
d.patchValues().get(i)); + } + + PatchesMetadata patches = buildPatchesMeta(numPatches); + byte[] metaBytes = new ALPMetadata(d.expE(), d.expF(), patches).encode(); + + EncodeNode idxNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode valNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 2); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ALP, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{encodedNode, idxNode, valNode}, + new int[0]); + return new EncodeResult(root, List.of(encodedBuf, idxBuf, valBuf), d.statsMin(), d.statsMax()); + } + + private static CascadeStep encodeCascadeF64(double[] values, EncodeContext ctx) { + AlpF64Data d = computeF64(values); + if (d.patchIndices().isEmpty()) { + byte[] metaBytes = new ALPMetadata(d.expE(), d.expF(), null).encode(); + EncodeNode partialRoot = new EncodeNode(EncodingId.VORTEX_ALP, + ByteBuffer.wrap(metaBytes), new EncodeNode[1], new int[0]); + ChildSlot slot = new ChildSlot(I64_DTYPE, d.encodedArr(), 0); + return new CascadeStep(partialRoot, List.of(), List.of(slot), d.statsMin(), d.statsMax(), true); + } + + int numPatches = d.patchIndices().size(); + MemorySegment idxBuf = ctx.arena().allocate((long) numPatches * 4, 4); + MemorySegment valBuf = ctx.arena().allocate((long) numPatches * 8, 8); + for (int i = 0; i < numPatches; i++) { + idxBuf.setAtIndex(PTypeIO.LE_INT, i, d.patchIndices().get(i)); + valBuf.setAtIndex(PTypeIO.LE_DOUBLE, i, d.patchValues().get(i)); + } + + PatchesMetadata patches = buildPatchesMeta(numPatches); + byte[] metaBytes = new ALPMetadata(d.expE(), d.expF(), patches).encode(); + + EncodeNode idxNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode valNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode partialRoot = new EncodeNode(EncodingId.VORTEX_ALP, + ByteBuffer.wrap(metaBytes), new EncodeNode[]{null, idxNode, valNode}, new int[0]); + ChildSlot slot = new ChildSlot(I64_DTYPE, d.encodedArr(), 0); + return new CascadeStep(partialRoot, 
List.of(idxBuf, valBuf), List.of(slot), d.statsMin(), d.statsMax(), true); + } + + private static int[] findExponentsF32(float[] values) { + int sampleLen = Math.min(SAMPLE_SIZE, values.length); + int bestExpE = 0, bestExpF = 0, bestExceptions = sampleLen + 1; + + outer: + for (int expE = 0; expE <= MAX_EXPONENT_F32; expE++) { + for (int expF = 0; expF <= MAX_EXPONENT_F32; expF++) { + float ef = F10_F32[expE]; + float iff = IF10_F32[expF]; + float df = F10_F32[expF]; + float de = IF10_F32[expE]; + int exceptions = 0; + for (int i = 0; i < sampleLen; i++) { + float enc = values[i] * ef * iff; + if (!Float.isFinite(enc) || (float) Math.round(enc) * df * de != values[i]) { + exceptions++; + } + } + if (exceptions < bestExceptions) { + bestExceptions = exceptions; + bestExpE = expE; + bestExpF = expF; + if (bestExceptions == 0) { + break outer; + } + } + } + } + return new int[]{bestExpE, bestExpF}; + } + + private static EncodeResult encodeF32(float[] values, EncodeContext ctx) { + int n = values.length; + int[] exps = findExponentsF32(values); + int expE = exps[0], expF = exps[1]; + float ef = F10_F32[expE]; + float iff = IF10_F32[expF]; + float df = F10_F32[expF]; + float de = IF10_F32[expE]; + + int[] encodedArr = new int[n]; + var patchIndices = new ArrayList(); + var patchValues = new ArrayList(); + + float min = Float.MAX_VALUE, max = -Float.MAX_VALUE; + for (int i = 0; i < n; i++) { + float v = values[i]; + float enc = v * ef * iff; + int encoded; + if (Float.isFinite(enc) && (float) (encoded = Math.round(enc)) * df * de == v) { + encodedArr[i] = encoded; + } else { + encodedArr[i] = 0; + patchIndices.add(i); + patchValues.add(v); + } + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + + byte[] statsMin = n > 0 ? scalarF32(min) : null; + byte[] statsMax = n > 0 ? 
scalarF32(max) : null; + + MemorySegment encodedBuf = ctx.arena().allocate((long) n * 4, 4); + for (int i = 0; i < n; i++) { + encodedBuf.setAtIndex(PTypeIO.LE_INT, i, encodedArr[i]); + } + + EncodeNode encodedNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + + if (patchIndices.isEmpty()) { + byte[] metaBytes = new ALPMetadata(expE, expF, null).encode(); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ALP, + ByteBuffer.wrap(metaBytes), new EncodeNode[]{encodedNode}, new int[0]); + return new EncodeResult(root, List.of(encodedBuf), statsMin, statsMax); + } + + int numPatches = patchIndices.size(); + MemorySegment idxBuf = ctx.arena().allocate((long) numPatches * 4, 4); + MemorySegment valBuf = ctx.arena().allocate((long) numPatches * 4, 4); + for (int i = 0; i < numPatches; i++) { + idxBuf.setAtIndex(PTypeIO.LE_INT, i, patchIndices.get(i)); + valBuf.setAtIndex(PTypeIO.LE_FLOAT, i, patchValues.get(i)); + } + + PatchesMetadata patches = new PatchesMetadata( + numPatches, + 0L, + io.github.dfa1.vortex.proto.PType.fromValue(PType.U32.ordinal()), + null, null, null); + byte[] metaBytes = new ALPMetadata(expE, expF, patches).encode(); + + EncodeNode idxNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode valNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 2); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ALP, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{encodedNode, idxNode, valNode}, + new int[0]); + return new EncodeResult(root, List.of(encodedBuf, idxBuf, valBuf), statsMin, statsMax); + } + + private static PatchesMetadata buildPatchesMeta(int numPatches) { + return new PatchesMetadata( + numPatches, + 0L, + io.github.dfa1.vortex.proto.PType.fromValue(PType.U32.ordinal()), + null, null, null); + } + + private static byte[] scalarF64(double v) { + return ScalarValue.ofF64Value(v).encode(); + } + + private static byte[] scalarF32(float v) { + return ScalarValue.ofF32Value(v).encode(); + } + + private record AlpF64Data(int expE, 
int expF, long[] encodedArr, + List patchIndices, List patchValues, + byte[] statsMin, byte[] statsMax) { + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoder.java new file mode 100644 index 00000000..01829d24 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoder.java @@ -0,0 +1,328 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.ALPRDMetadata; +import io.github.dfa1.vortex.proto.PatchesMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/// Write-only encoder for {@code vortex.alprd}. +public final class AlpRdEncodingEncoder implements EncodingEncoder { + private static final DType U16_DTYPE = new DType.Primitive(PType.U16, false); + private static final DType U32_DTYPE = new DType.Primitive(PType.U32, false); + private static final DType U64_DTYPE = new DType.Primitive(PType.U64, false); + + private static final int SAMPLE_SIZE = 512; + private static final int MAX_CUT = 16; + private static final int MAX_DICT_SIZE = 8; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public AlpRdEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ALPRD; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return p.ptype() == PType.F32 || p.ptype() == PType.F64; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + return switch (ptype) { + case F64 -> encodeF64((double[]) data, ctx); + case F32 -> encodeF32((float[]) data, ctx); + default -> throw new UnsupportedOperationException("ALP-RD encode not supported for " + ptype); + }; + } + + private static EncodeResult encodeF64(double[] values, EncodeContext ctx) { + int n = values.length; + if (n == 0) { + return emptyResult(U64_DTYPE, ctx); + } + + int sampleLen = Math.min(SAMPLE_SIZE, n); + Dictionary64 best = findBestDictionaryF64(values, sampleLen); + + Map lookup = buildLookup(best.dict); + long rightMask = -1L >>> (64 - best.rightBitWidth); + + short[] leftCodes = new short[n]; + long[] rightParts = new long[n]; + List excPos = new ArrayList<>(); + List excVals = new ArrayList<>(); + + for (int i = 0; i < n; i++) { + long bits = Double.doubleToRawLongBits(values[i]); + short leftU16 = (short) (bits >>> best.rightBitWidth); + rightParts[i] = bits & rightMask; + Short code = lookup.get(leftU16); + if (code != null) { + leftCodes[i] = code; + } else { + leftCodes[i] = 0; + excPos.add((long) i); + excVals.add(leftU16); + } + } + + return buildEncodeResult( + best.dict, best.rightBitWidth, leftCodes, rightParts, + U64_DTYPE, excPos, excVals, ctx); + } + + private static Dictionary64 findBestDictionaryF64(double[] values, int sampleLen) { + double bestEstSize = Double.MAX_VALUE; + int bestRightBw = 48; + short[] bestDict = new short[]{0}; + + for (int p = 1; p <= MAX_CUT; p++) { + int rightBw = 64 - p; + Map counts = new HashMap<>(); + for (int i = 0; i < sampleLen; i++) { + long 
bits = Double.doubleToRawLongBits(values[i]); + short leftU16 = (short) (bits >>> rightBw); + counts.merge(leftU16, 1, Integer::sum); + } + short[] dict = topKByCount(counts); + int excCount = countExceptionsF64(values, sampleLen, dict, rightBw); + int maxCode = dict.length - 1; + int leftBw = maxCode == 0 ? 1 : (Integer.SIZE - Integer.numberOfLeadingZeros(maxCode)); + double estSize = rightBw + leftBw + (double) (excCount * 32) / sampleLen; + if (estSize < bestEstSize) { + bestEstSize = estSize; + bestRightBw = rightBw; + bestDict = dict; + } + } + return new Dictionary64(bestDict, bestRightBw); + } + + private static int countExceptionsF64(double[] values, int sampleLen, short[] dict, int rightBw) { + Map dictSet = new HashMap<>(); + for (short d : dict) { + dictSet.put(d, Boolean.TRUE); + } + int count = 0; + for (int i = 0; i < sampleLen; i++) { + long bits = Double.doubleToRawLongBits(values[i]); + short leftU16 = (short) (bits >>> rightBw); + if (!dictSet.containsKey(leftU16)) { + count++; + } + } + return count; + } + + private static EncodeResult encodeF32(float[] values, EncodeContext ctx) { + int n = values.length; + if (n == 0) { + return emptyResult(U32_DTYPE, ctx); + } + + int sampleLen = Math.min(SAMPLE_SIZE, n); + Dictionary32 best = findBestDictionaryF32(values, sampleLen); + + Map lookup = buildLookup(best.dict); + int rightMask = -1 >>> (32 - best.rightBitWidth); + + short[] leftCodes = new short[n]; + int[] rightParts = new int[n]; + List excPos = new ArrayList<>(); + List excVals = new ArrayList<>(); + + for (int i = 0; i < n; i++) { + int bits = Float.floatToRawIntBits(values[i]); + short leftU16 = (short) (bits >>> best.rightBitWidth); + rightParts[i] = bits & rightMask; + Short code = lookup.get(leftU16); + if (code != null) { + leftCodes[i] = code; + } else { + leftCodes[i] = 0; + excPos.add((long) i); + excVals.add(leftU16); + } + } + + return buildEncodeResult( + best.dict, best.rightBitWidth, leftCodes, rightParts, + U32_DTYPE, excPos, 
excVals, ctx); + } + + private static Dictionary32 findBestDictionaryF32(float[] values, int sampleLen) { + double bestEstSize = Double.MAX_VALUE; + int bestRightBw = 16; + short[] bestDict = new short[]{0}; + + for (int p = 1; p <= MAX_CUT; p++) { + int rightBw = 32 - p; + Map counts = new HashMap<>(); + for (int i = 0; i < sampleLen; i++) { + int bits = Float.floatToRawIntBits(values[i]); + short leftU16 = (short) (bits >>> rightBw); + counts.merge(leftU16, 1, Integer::sum); + } + short[] dict = topKByCount(counts); + int excCount = countExceptionsF32(values, sampleLen, dict, rightBw); + int maxCode = dict.length - 1; + int leftBw = maxCode == 0 ? 1 : (Integer.SIZE - Integer.numberOfLeadingZeros(maxCode)); + double estSize = rightBw + leftBw + (double) (excCount * 32) / sampleLen; + if (estSize < bestEstSize) { + bestEstSize = estSize; + bestRightBw = rightBw; + bestDict = dict; + } + } + return new Dictionary32(bestDict, bestRightBw); + } + + private static int countExceptionsF32(float[] values, int sampleLen, short[] dict, int rightBw) { + Map dictSet = new HashMap<>(); + for (short d : dict) { + dictSet.put(d, Boolean.TRUE); + } + int count = 0; + for (int i = 0; i < sampleLen; i++) { + int bits = Float.floatToRawIntBits(values[i]); + short leftU16 = (short) (bits >>> rightBw); + if (!dictSet.containsKey(leftU16)) { + count++; + } + } + return count; + } + + private static short[] topKByCount(Map counts) { + List> sorted = new ArrayList<>(counts.entrySet()); + sorted.sort((a, b) -> b.getValue() - a.getValue()); + int dictSize = Math.min(sorted.size(), MAX_DICT_SIZE); + short[] dict = new short[dictSize]; + for (int i = 0; i < dictSize; i++) { + dict[i] = sorted.get(i).getKey(); + } + return dict; + } + + private static Map buildLookup(short[] dict) { + Map lookup = new HashMap<>(); + for (short i = 0; i < dict.length; i++) { + lookup.put(dict[i], i); + } + return lookup; + } + + private static EncodeResult buildEncodeResult( + short[] dict, int rightBitWidth, 
+ short[] leftCodes, Object rightPartsData, DType rightDtype, + List excPos, List excVals, EncodeContext ctx) { + + EncodingEncoder bp = ctx.lookupEncoder(EncodingId.FASTLANES_BITPACKED); + EncodeResult leftResult = bp.encode(U16_DTYPE, leftCodes, ctx); + EncodeResult rightResult = bp.encode(rightDtype, rightPartsData, ctx); + + List allBuffers = new ArrayList<>(leftResult.buffers()); + int leftBufCount = allBuffers.size(); + allBuffers.addAll(rightResult.buffers()); + + EncodeNode leftNode = EncodeNode.remapBufferIndices(leftResult.rootNode(), 0); + EncodeNode rightNode = EncodeNode.remapBufferIndices(rightResult.rootNode(), leftBufCount); + + List dictList = new ArrayList<>(dict.length); + for (short d : dict) { + dictList.add(d & 0xFFFF); + } + + EncodeNode[] children; + PatchesMetadata patchesMeta = null; + if (excPos.isEmpty()) { + children = new EncodeNode[]{leftNode, rightNode}; + } else { + long[] excPosArr = excPos.stream().mapToLong(Long::longValue).toArray(); + short[] excValsArr = new short[excVals.size()]; + for (int i = 0; i < excVals.size(); i++) { + excValsArr[i] = excVals.get(i); + } + + EncodeResult idxResult = bp.encode(U64_DTYPE, excPosArr, ctx); + EncodeResult valResult = bp.encode(U16_DTYPE, excValsArr, ctx); + + int idxOffset = allBuffers.size(); + allBuffers.addAll(idxResult.buffers()); + int idxBufCount = idxResult.buffers().size(); + allBuffers.addAll(valResult.buffers()); + + EncodeNode idxNode = EncodeNode.remapBufferIndices(idxResult.rootNode(), idxOffset); + EncodeNode valNode = EncodeNode.remapBufferIndices(valResult.rootNode(), idxOffset + idxBufCount); + + patchesMeta = new PatchesMetadata( + (long) excPos.size(), + 0L, + io.github.dfa1.vortex.proto.PType.fromValue(PType.U64.ordinal()), + null, null, null); + children = new EncodeNode[]{leftNode, rightNode, idxNode, valNode}; + } + + byte[] metaBytes = new ALPRDMetadata( + rightBitWidth, + dict.length, + dictList, + io.github.dfa1.vortex.proto.PType.fromValue(PType.U16.ordinal()), + 
patchesMeta + ).encode(); + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_ALPRD, ByteBuffer.wrap(metaBytes), children, new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private static EncodeResult emptyResult(DType rightDtype, EncodeContext ctx) { + EncodingEncoder bp = ctx.lookupEncoder(EncodingId.FASTLANES_BITPACKED); + EncodeResult leftResult = bp.encode(U16_DTYPE, new short[0], ctx); + EncodeResult rightResult = bp.encode(rightDtype, + rightDtype.equals(U32_DTYPE) ? new int[0] : new long[0], ctx); + + List allBuffers = new ArrayList<>(leftResult.buffers()); + int leftBufCount = allBuffers.size(); + allBuffers.addAll(rightResult.buffers()); + + EncodeNode leftNode = EncodeNode.remapBufferIndices(leftResult.rootNode(), 0); + EncodeNode rightNode = EncodeNode.remapBufferIndices(rightResult.rootNode(), leftBufCount); + + byte[] metaBytes = new ALPRDMetadata( + 48, + 0, + List.of(), + io.github.dfa1.vortex.proto.PType.fromValue(PType.U16.ordinal()), + null).encode(); + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_ALPRD, ByteBuffer.wrap(metaBytes), + new EncodeNode[]{leftNode, rightNode}, new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private record Dictionary64(short[] dict, int rightBitWidth) { + } + + private record Dictionary32(short[] dict, int rightBitWidth) { + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoder.java new file mode 100644 index 00000000..e65ca31c --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoder.java @@ -0,0 +1,227 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; 
+import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.BitPackedMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only encoder for {@code fastlanes.bitpacked}. +public final class BitpackedEncodingEncoder implements EncodingEncoder { + private static final int[] FL_ORDER = {0, 4, 2, 6, 1, 5, 3, 7}; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public BitpackedEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_BITPACKED; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return switch (p.ptype()) { + case I8, I16, I32, I64, U8, U16, U32, U64 -> true; + default -> false; + }; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + long[] longs = toLongs(data, ptype); + int n = longs.length; + int typeBits = ptype.byteSize() * 8; + long typeMask = typeMask(typeBits); + boolean unsign = isUnsigned(ptype); + + long signedMin = 0L; + long signedMax = 0L; + long maxUnsigned = 0L; + int bitWidth = 0; + + if (n > 0) { + signedMin = longs[0]; + signedMax = longs[0]; + for (long v : longs) { + if (unsign ? Long.compareUnsigned(v, signedMin) < 0 : v < signedMin) { + signedMin = v; + } + if (unsign ? 
Long.compareUnsigned(v, signedMax) > 0 : v > signedMax) { + signedMax = v; + } + long uv = v & typeMask; + if (Long.compareUnsigned(uv, maxUnsigned) > 0) { + maxUnsigned = uv; + } + } + bitWidth = maxUnsigned == 0L ? 0 : (Long.SIZE - Long.numberOfLeadingZeros(maxUnsigned)); + } + + MemorySegment packed = packFastLanes(longs, n, bitWidth, typeBits, ctx.arena()); + + byte[] metaBytes = new BitPackedMetadata(bitWidth, 0, null).encode(); + + byte[] statsMin = n > 0 ? statsBytes(ptype, signedMin) : null; + byte[] statsMax = n > 0 ? statsBytes(ptype, signedMax) : null; + + EncodeNode root = new EncodeNode(EncodingId.FASTLANES_BITPACKED, ByteBuffer.wrap(metaBytes), + new EncodeNode[0], new int[]{0}); + return new EncodeResult(root, List.of(packed), statsMin, statsMax); + } + + private static MemorySegment packFastLanes(long[] values, int n, int bitWidth, int typeBits, Arena arena) { + if (bitWidth == 0 || n == 0) { + return MemorySegment.ofArray(new byte[0]); + } + int lanes = 1024 / typeBits; + int wordBytes = typeBits / 8; + int blockCount = (n + 1023) / 1024; + long typeMask = typeMask(typeBits); + MemorySegment seg = arena.allocate((long) blockCount * 128 * bitWidth); + + for (int block = 0; block < blockCount; block++) { + int blockByteOff = block * 128 * bitWidth; + int blockStart = block * 1024; + + for (int row = 0; row < typeBits; row++) { + int currWord = (row * bitWidth) / typeBits; + int nextWord = ((row + 1) * bitWidth) / typeBits; + int shift = (row * bitWidth) % typeBits; + int remainingBits = (nextWord > currWord) ? ((row + 1) * bitWidth) % typeBits : 0; + int currentBits = bitWidth - remainingBits; + + for (int lane = 0; lane < lanes; lane++) { + int o = row / 8; + int s = row % 8; + int logicalIdx = blockStart + FL_ORDER[o] * 16 + s * 128 + lane; + long value = (logicalIdx < n) ? 
(values[logicalIdx] & typeMask) : 0L; + + int wordOff = blockByteOff + (lanes * currWord + lane) * wordBytes; + long existing = readWordFromSeg(seg, wordOff, typeBits); + existing |= (value << shift) & typeMask; + writeWordToSeg(seg, wordOff, existing, typeBits); + + if (remainingBits > 0) { + int hiWordOff = blockByteOff + (lanes * nextWord + lane) * wordBytes; + long existingHi = readWordFromSeg(seg, hiWordOff, typeBits); + existingHi |= (value >>> currentBits) & typeMask; + writeWordToSeg(seg, hiWordOff, existingHi, typeBits); + } + } + } + } + return seg; + } + + private static long[] toLongs(Object data, PType ptype) { + return switch (ptype) { + case I8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Byte.toUnsignedLong(arr[i]); + } + yield r; + } + case I16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Short.toUnsignedLong(arr[i]); + } + yield r; + } + case I32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Integer.toUnsignedLong(arr[i]); + } + yield r; + } + case I64, U64 -> (long[]) data; + default -> throw new VortexException(EncodingId.FASTLANES_BITPACKED, "unsupported ptype: " + ptype); + }; + } + + private static long typeMask(int typeBits) { + return typeBits == 64 ? 
-1L : (1L << typeBits) - 1L; + } + + private static boolean isUnsigned(PType ptype) { + return switch (ptype) { + case U8, U16, U32, U64 -> true; + default -> false; + }; + } + + private static byte[] statsBytes(PType ptype, long value) { + if (isUnsigned(ptype)) { + return ScalarValue.ofUint64Value(value).encode(); + } + return ScalarValue.ofInt64Value(value).encode(); + } + + private static long readWordFromSeg(MemorySegment seg, int off, int typeBits) { + return switch (typeBits) { + case 8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, off)); + case 16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, off)); + case 32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, off)); + case 64 -> seg.get(PTypeIO.LE_LONG, off); + default -> + throw new VortexException(EncodingId.FASTLANES_BITPACKED, "unsupported typeBits: " + typeBits); + }; + } + + private static void writeWordToSeg(MemorySegment seg, int off, long value, int typeBits) { + switch (typeBits) { + case 8 -> seg.set(ValueLayout.JAVA_BYTE, off, (byte) value); + case 16 -> seg.set(PTypeIO.LE_SHORT, off, (short) value); + case 32 -> seg.set(PTypeIO.LE_INT, off, (int) value); + case 64 -> seg.set(PTypeIO.LE_LONG, off, value); + default -> + throw new VortexException(EncodingId.FASTLANES_BITPACKED, "unsupported typeBits: " + typeBits); + } + } +} diff --git a/core/src/main/java/io/github/dfa1/vortex/encoding/BoolEncoding.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoder.java similarity index 69% rename from core/src/main/java/io/github/dfa1/vortex/encoding/BoolEncoding.java rename to writer/src/main/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoder.java index 19292a07..cfd0501d 100644 --- a/core/src/main/java/io/github/dfa1/vortex/encoding/BoolEncoding.java +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoder.java @@ -1,19 +1,26 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import 
io.github.dfa1.vortex.core.DType; -import io.github.dfa1.vortex.core.array.Array; -import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; import io.github.dfa1.vortex.proto.ScalarValue; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; -/// Encoding for `vortex.bool` — bit-packed boolean arrays (LSB first). -public final class BoolEncoding implements Encoding { +/// Write-only encoder for {@code vortex.bool} (bit-packed boolean arrays, LSB first). +/// +///

ADR 0001 Phase 3: first encoding lifted into a standalone {@link EncodingEncoder} +/// implementation in the {@code writer} module. The corresponding read-side decode path +/// lives on {@link io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder} in +/// {@code reader}. +public final class BoolEncodingEncoder implements EncodingEncoder { - /// Creates a new {@code BoolEncoding} instance. - public BoolEncoding() { + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public BoolEncodingEncoder() { } private static MemorySegment encodeBool(boolean[] data, Arena arena) { @@ -65,9 +72,4 @@ public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { : null; return EncodeResult.simple(encodingId(), encodeBool(bools, ctx.arena()), statsMin, statsMax); } - - @Override - public Array decode(DecodeContext ctx) { - return new BoolArray(ctx.dtype(), ctx.rowCount(), ctx.buffer(0)); - } } diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoder.java new file mode 100644 index 00000000..f496dddc --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoder.java @@ -0,0 +1,38 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +/// Write-only encoder for {@code vortex.bytebool} — one byte per boolean element. +public final class ByteBoolEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ByteBoolEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_BYTEBOOL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Bool; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + boolean[] bools = (boolean[]) data; + MemorySegment seg = ctx.arena().allocate(bools.length); + for (int i = 0; i < bools.length; i++) { + seg.set(ValueLayout.JAVA_BYTE, i, bools[i] ? (byte) 1 : (byte) 0); + } + return EncodeResult.simple(encodingId(), seg); + } +} diff --git a/core/src/main/java/io/github/dfa1/vortex/encoding/CascadingCompressor.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/CascadingCompressor.java similarity index 83% rename from core/src/main/java/io/github/dfa1/vortex/encoding/CascadingCompressor.java rename to writer/src/main/java/io/github/dfa1/vortex/writer/encode/CascadingCompressor.java index 0f7cead2..37752202 100644 --- a/core/src/main/java/io/github/dfa1/vortex/encoding/CascadingCompressor.java +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/CascadingCompressor.java @@ -1,27 +1,36 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.ChildSlot; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.StructData; import java.lang.foreign.MemorySegment; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Set; /// Cascading compressor: evaluates multiple encodings on a sample and picks the one /// producing the smallest output. 
With {@code allowedCascading > 0}, also recurses /// into open child slots (e.g. ALP → Bitpacked for F64 columns). /// -///

Encodings that override {@link Encoding#encodeCascade} expose intermediate +///

Encodings that override {@link EncodingEncoder#encodeCascade} expose intermediate /// representations as open children; encodings that use the default are terminal. /// At depth 0 only terminal encodings are considered. public final class CascadingCompressor { - private final List encodings; + private final List encodings; - /// Constructs a {@code CascadingCompressor} with the given candidate encodings. + /// Constructs a {@code CascadingCompressor} with the given candidate encoders. /// - /// @param encodings candidate encodings evaluated during compression - public CascadingCompressor(List encodings) { + /// @param encodings candidate encoders evaluated during compression + public CascadingCompressor(List encodings) { this.encodings = List.copyOf(encodings); } @@ -54,8 +63,6 @@ private static Object sliceSample(Object data, int n) { }; } - // ── Size measurement (on sample) ────────────────────────────────────────── - private static long primitiveBytes(DType dtype, int n) { if (!(dtype instanceof DType.Primitive p)) { return (long) n * 8; @@ -66,19 +73,17 @@ private static long primitiveBytes(DType dtype, int n) { /// Entry point: encode {@code data} using the best cascading strategy. /// ///

Cascade parameters (depth, sampling, exclusions) are taken from {@code ctx}. - /// Use {@link EncodeContext#ofDepth(int, java.lang.foreign.Arena, Registry)} + /// Use {@link EncodeContext#ofDepth(int, java.lang.foreign.Arena, java.util.Map)} /// to build a context with cascade depth set. /// /// @param dtype the logical type of the data to encode /// @param data input data in the format expected by the candidate encodings - /// @param ctx encoding context supplying the arena, registry, and cascade parameters + /// @param ctx encoding context supplying the arena, encoder map, and cascade parameters /// @return the {@link EncodeResult} produced by the winning encoding public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { return encodeWithCtx(dtype, data, ctx); } - // ── Full-data run + buffer splicing ─────────────────────────────────────── - private EncodeResult encodeWithCtx(DType dtype, Object data, EncodeContext ctx) { if (dtype instanceof DType.Struct structDtype) { return encodeStruct(structDtype, (StructData) data, ctx); @@ -100,9 +105,9 @@ private EncodeResult encodeWithCtx(DType dtype, Object data, EncodeContext ctx) Object sample = (sampleSize < n) ? 
sliceSample(data, sampleSize) : data; long bestSampleSize = primitiveBytes(dtype, sampleSize); - Encoding winner = null; + EncodingEncoder winner = null; - for (Encoding enc : encodings) { + for (EncodingEncoder enc : encodings) { if (!enc.accepts(dtype) || ctx.excluded().contains(enc.encodingId())) { continue; } @@ -129,9 +134,7 @@ private EncodeResult encodeWithCtx(DType dtype, Object data, EncodeContext ctx) return spliceResult(winner, dtype, data, ctx); } - // ── Struct encoding ─────────────────────────────────────────────────────── - - private long measureStep(Encoding enc, CascadeStep step, EncodeContext ctx) { + private long measureStep(EncodingEncoder enc, CascadeStep step, EncodeContext ctx) { long total = step.ownedBytes(); for (ChildSlot slot : step.openChildren()) { EncodeContext childCtx = ctx.withDecrementedDepth().withExcluded(enc.encodingId()); @@ -140,12 +143,10 @@ private long measureStep(Encoding enc, CascadeStep step, EncodeContext ctx) { return total; } - // ── Helpers ─────────────────────────────────────────────────────────────── - private long measureBestChild(DType dtype, Object data, EncodeContext ctx) { int n = dataLength(data); long best = primitiveBytes(dtype, n); - for (Encoding enc : encodings) { + for (EncodingEncoder enc : encodings) { if (!enc.accepts(dtype) || ctx.excluded().contains(enc.encodingId())) { continue; } @@ -161,7 +162,7 @@ private long measureBestChild(DType dtype, Object data, EncodeContext ctx) { return best; } - private EncodeResult spliceResult(Encoding winner, DType dtype, Object data, EncodeContext ctx) { + private EncodeResult spliceResult(EncodingEncoder winner, DType dtype, Object data, EncodeContext ctx) { CascadeStep step = winner.encodeCascade(dtype, data, ctx); if (!step.applicable()) { @@ -209,8 +210,8 @@ private EncodeResult encodeStruct(DType.Struct dtype, StructData data, EncodeCon return new EncodeResult(root, List.copyOf(allBuffers), null, null); } - private Encoding findPrimitiveEncoding(DType 
dtype, java.util.Set excluded) { - for (Encoding enc : encodings) { + private EncodingEncoder findPrimitiveEncoding(DType dtype, Set excluded) { + for (EncodingEncoder enc : encodings) { if (excluded.contains(enc.encodingId())) { continue; } @@ -220,7 +221,7 @@ private Encoding findPrimitiveEncoding(DType dtype, java.util.Set ex } // Fall through to any accepting encoding (still honouring exclusions so that // spliceResult's notApplicable retry rotates to the next candidate). - for (Encoding enc : encodings) { + for (EncodingEncoder enc : encodings) { if (excluded.contains(enc.encodingId())) { continue; } @@ -228,6 +229,6 @@ private Encoding findPrimitiveEncoding(DType dtype, java.util.Set ex return enc; } } - throw new UnsupportedOperationException("no encoding for dtype: " + dtype); + throw new UnsupportedOperationException("no encoder for dtype: " + dtype); } } diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoder.java new file mode 100644 index 00000000..1b915c2a --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoder.java @@ -0,0 +1,93 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; + + +import io.github.dfa1.vortex.encoding.ChunkedData; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; + +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + + + + + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.chunked}. 
+public final class ChunkedEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public ChunkedEncodingEncoder() { + } + + private static final List FALLBACK = List.of( + new PrimitiveEncodingEncoder(), new VarBinEncodingEncoder(), new BoolEncodingEncoder(), + new NullEncodingEncoder(), new ByteBoolEncodingEncoder(), new StructEncodingEncoder()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_CHUNKED; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive || dtype instanceof DType.Struct; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + ChunkedData cd = (ChunkedData) data; + List chunks = cd.chunks(); + long[] chunkLengths = cd.chunkLengths(); + int nchunks = chunks.size(); + if (nchunks == 0) { + throw new VortexException(EncodingId.VORTEX_CHUNKED, "at least one chunk required"); + } + + long[] offsets = new long[nchunks + 1]; + offsets[0] = 0; + for (int i = 0; i < nchunks; i++) { + offsets[i + 1] = offsets[i] + chunkLengths[i]; + } + + DType u64 = new DType.Primitive(PType.U64, false); + EncodeResult offsetsResult = ctx.lookupEncoder(EncodingId.VORTEX_PRIMITIVE).encode(u64, offsets, ctx); + + List allBuffers = new ArrayList<>(offsetsResult.buffers()); + EncodeNode[] children = new EncodeNode[nchunks + 1]; + children[0] = offsetsResult.rootNode(); + + EncodingEncoder inner = findEncoding(dtype); + for (int i = 0; i < nchunks; i++) { + EncodeResult chunkResult = inner.encode(dtype, chunks.get(i), ctx); + int bufOffset = allBuffers.size(); + children[i + 1] = EncodeNode.remapBufferIndices(chunkResult.rootNode(), bufOffset); + allBuffers.addAll(chunkResult.buffers()); + } + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_CHUNKED, + ByteBuffer.wrap(new byte[0]), + children, + new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + 
private static EncodingEncoder findEncoding(DType dtype) { + for (EncodingEncoder enc : FALLBACK) { + if (enc.accepts(dtype)) { + return enc; + } + } + throw new UnsupportedOperationException("no fallback encoding for dtype: " + dtype); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoder.java new file mode 100644 index 00000000..95864d27 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoder.java @@ -0,0 +1,103 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.lang.foreign.MemorySegment; + +/// Write-only encoder for {@code vortex.constant}. +public final class ConstantEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ConstantEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_CONSTANT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_CONSTANT, "encode only supports Primitive dtype, got " + dtype); + } + PType ptype = p.ptype(); + if (!isConstant(data, ptype)) { + throw new VortexException(EncodingId.VORTEX_CONSTANT, "not a constant array"); + } + long firstRaw = readFirstRaw(data, ptype); + ScalarValue scalar = buildScalar(ptype, firstRaw); + return EncodeResult.simple(EncodingId.VORTEX_CONSTANT, MemorySegment.ofArray(scalar.encode())); + } + + @Override + public CascadeStep encodeCascade(DType dtype, Object data, EncodeContext encodeCtx) { + if (!isConstant(data, ((DType.Primitive) dtype).ptype())) { + return CascadeStep.notApplicable(); + } + return CascadeStep.terminal(encode(dtype, data, encodeCtx)); + } + + private static long readFirstRaw(Object data, PType ptype) { + return switch (ptype) { + case I8, U8 -> ((byte[]) data).length > 0 ? ((byte[]) data)[0] : 0L; + case I16, U16 -> ((short[]) data).length > 0 ? ((short[]) data)[0] : 0L; + case I32, U32 -> ((int[]) data).length > 0 ? ((int[]) data)[0] : 0L; + case I64, U64 -> ((long[]) data).length > 0 ? ((long[]) data)[0] : 0L; + case F32 -> ((float[]) data).length > 0 ? Float.floatToRawIntBits(((float[]) data)[0]) : 0L; + case F64 -> ((double[]) data).length > 0 ? 
Double.doubleToRawLongBits(((double[]) data)[0]) : 0L; + default -> throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported ptype: " + ptype); + }; + } + + private static boolean isConstant(Object data, PType ptype) { + long firstRaw = readFirstRaw(data, ptype); + int len = switch (ptype) { + case I8, U8 -> ((byte[]) data).length; + case I16, U16 -> ((short[]) data).length; + case I32, U32 -> ((int[]) data).length; + case I64, U64 -> ((long[]) data).length; + case F32 -> ((float[]) data).length; + case F64 -> ((double[]) data).length; + default -> throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported ptype: " + ptype); + }; + for (int i = 1; i < len; i++) { + long raw = switch (ptype) { + case I8, U8 -> ((byte[]) data)[i]; + case I16, U16 -> ((short[]) data)[i]; + case I32, U32 -> ((int[]) data)[i]; + case I64, U64 -> ((long[]) data)[i]; + case F32 -> Float.floatToRawIntBits(((float[]) data)[i]); + case F64 -> Double.doubleToRawLongBits(((double[]) data)[i]); + default -> throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported ptype: " + ptype); + }; + if (raw != firstRaw) { + return false; + } + } + return true; + } + + private static ScalarValue buildScalar(PType ptype, long rawBits) { + return switch (ptype) { + case U8, U16, U32, U64 -> ScalarValue.ofUint64Value(rawBits); + case I8, I16, I32, I64 -> ScalarValue.ofInt64Value(rawBits); + case F32 -> ScalarValue.ofF32Value(Float.intBitsToFloat((int) rawBits)); + case F64 -> ScalarValue.ofF64Value(Double.longBitsToDouble(rawBits)); + default -> throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported ptype: " + ptype); + }; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoder.java new file mode 100644 index 00000000..b686d7a4 --- /dev/null +++ 
b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoder.java @@ -0,0 +1,157 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.ChildSlot; +import io.github.dfa1.vortex.encoding.DateTimePartsData; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.TimeUnit; +import io.github.dfa1.vortex.proto.DateTimePartsMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.datetimeparts}. +public final class DateTimePartsEncodingEncoder implements EncodingEncoder { + + private static final long SECONDS_PER_DAY = 86_400L; + private static final DType I64 = new DType.Primitive(PType.I64, false); + private static final DType I64_NULLABLE = new DType.Primitive(PType.I64, true); + private static final io.github.dfa1.vortex.proto.PType I64_PROTO = + io.github.dfa1.vortex.proto.PType.fromValue(PType.I64.ordinal()); + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public DateTimePartsEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DATETIMEPARTS; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Extension; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.Extension ext = (DType.Extension) dtype; + DateTimePartsData d = (DateTimePartsData) data; + + ByteBuffer extMeta = ext.metadata(); + if (extMeta == null || extMeta.remaining() < 3) { + throw new VortexException(EncodingId.VORTEX_DATETIMEPARTS, + "extension metadata missing or too short"); + } + byte[] extBytes = new byte[extMeta.remaining()]; + extMeta.duplicate().get(extBytes); + TimeUnit unit = TimeUnit.fromTag(extBytes[0]); + + long divisor = unit.divisor(); + long ticksPerDay = SECONDS_PER_DAY * divisor; + int n = d.timestamps().length; + + long[] days = new long[n]; + long[] seconds = new long[n]; + long[] subseconds = new long[n]; + + for (int i = 0; i < n; i++) { + long ts = d.timestamps()[i]; + long dval = ts / ticksPerDay; + long rem = ts % ticksPerDay; + if (rem < 0) { + rem += ticksPerDay; + dval--; + } + days[i] = dval; + seconds[i] = rem / divisor; + subseconds[i] = rem % divisor; + } + + DType daysDtype = d.nullable() ? 
I64_NULLABLE : I64; + + EncodingEncoder primEnc = ctx.lookupEncoder(EncodingId.VORTEX_PRIMITIVE); + EncodeResult daysResult = primEnc.encode(daysDtype, days, ctx); + EncodeResult secondsResult = primEnc.encode(I64, seconds, ctx); + EncodeResult subsecondsResult = primEnc.encode(I64, subseconds, ctx); + + List allBuffers = new ArrayList<>(); + allBuffers.addAll(daysResult.buffers()); + allBuffers.addAll(secondsResult.buffers()); + allBuffers.addAll(subsecondsResult.buffers()); + + int off1 = daysResult.buffers().size(); + int off2 = off1 + secondsResult.buffers().size(); + + EncodeNode daysNode = EncodeNode.remapBufferIndices(daysResult.rootNode(), 0); + EncodeNode secondsNode = EncodeNode.remapBufferIndices(secondsResult.rootNode(), off1); + EncodeNode subsecondsNode = EncodeNode.remapBufferIndices(subsecondsResult.rootNode(), off2); + + byte[] metaBytes = new DateTimePartsMetadata(I64_PROTO, I64_PROTO, I64_PROTO).encode(); + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_DATETIMEPARTS, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{daysNode, secondsNode, subsecondsNode}, + new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + @Override + public CascadeStep encodeCascade(DType dtype, Object data, EncodeContext encodeCtx) { + if (!(data instanceof DateTimePartsData d)) { + return CascadeStep.notApplicable(); + } + DType.Extension ext = (DType.Extension) dtype; + ByteBuffer extMeta = ext.metadata(); + byte[] extBytes = new byte[extMeta.remaining()]; + extMeta.duplicate().get(extBytes); + TimeUnit unit = TimeUnit.fromTag(extBytes[0]); + + long divisor = unit.divisor(); + long ticksPerDay = SECONDS_PER_DAY * divisor; + int n = d.timestamps().length; + + long[] days = new long[n]; + long[] seconds = new long[n]; + long[] subseconds = new long[n]; + + for (int i = 0; i < n; i++) { + long ts = d.timestamps()[i]; + long dval = ts / ticksPerDay; + long rem = ts % ticksPerDay; + if (rem < 0) { + rem += ticksPerDay; + dval--; + 
} + days[i] = dval; + seconds[i] = rem / divisor; + subseconds[i] = rem % divisor; + } + + byte[] metaBytes = new DateTimePartsMetadata(I64_PROTO, I64_PROTO, I64_PROTO).encode(); + + EncodeNode partialRoot = new EncodeNode( + EncodingId.VORTEX_DATETIMEPARTS, + ByteBuffer.wrap(metaBytes), + new EncodeNode[3], + new int[0]); + + DType daysDtype = d.nullable() ? I64_NULLABLE : I64; + List children = List.of( + new ChildSlot(daysDtype, days, 0), + new ChildSlot(I64, seconds, 1), + new ChildSlot(I64, subseconds, 2)); + + return new CascadeStep(partialRoot, List.of(), children, null, null, true); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoder.java new file mode 100644 index 00000000..4f3f9804 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoder.java @@ -0,0 +1,49 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.DecimalBytePartsMetadata; + +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only encoder for {@code vortex.decimal_byte_parts}. +public final class DecimalBytePartsEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public DecimalBytePartsEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DECIMAL_BYTE_PARTS; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Decimal; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.Decimal d = (DType.Decimal) dtype; + long[] longs = (long[]) data; + DType mspDtype = new DType.Primitive(PType.I64, d.nullable()); + EncodeResult mspResult = ctx.lookupEncoder(EncodingId.VORTEX_PRIMITIVE).encode(mspDtype, longs, ctx); + + DecimalBytePartsMetadata proto = new DecimalBytePartsMetadata( + io.github.dfa1.vortex.proto.PType.fromValue(PType.I64.ordinal()), + 0); + ByteBuffer metaBuf = ByteBuffer.wrap(proto.encode()); + + EncodeNode mspNode = EncodeNode.remapBufferIndices(mspResult.rootNode(), 0); + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_DECIMAL_BYTE_PARTS, metaBuf, new EncodeNode[]{mspNode}, new int[]{}); + return new EncodeResult(root, List.copyOf(mspResult.buffers()), null, null); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoder.java new file mode 100644 index 00000000..2e34e3f9 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoder.java @@ -0,0 +1,79 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.DecimalMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only 
encoder for {@code vortex.decimal}. +public final class DecimalEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public DecimalEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DECIMAL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Decimal; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.Decimal d = (DType.Decimal) dtype; + MemorySegment seg = (MemorySegment) data; + int valuesType = valuesType(d.precision()); + int bw = byteWidth(valuesType); + if (seg.byteSize() % bw != 0) { + throw new VortexException(EncodingId.VORTEX_DECIMAL, + "buffer size %d not multiple of byteWidth %d".formatted(seg.byteSize(), bw)); + } + ByteBuffer metaBuf = ByteBuffer.wrap(new DecimalMetadata(valuesType).encode()); + EncodeNode node = new EncodeNode(EncodingId.VORTEX_DECIMAL, metaBuf, new EncodeNode[0], new int[]{0}); + return new EncodeResult(node, List.of(seg), null, null); + } + + private static int valuesType(byte precision) { + if (precision <= 2) { + return 0; + } + if (precision <= 4) { + return 1; + } + if (precision <= 9) { + return 2; + } + if (precision <= 18) { + return 3; + } + if (precision <= 38) { + return 4; + } + return 5; + } + + private static int byteWidth(int valuesType) { + return switch (valuesType) { + case 0 -> 1; + case 1 -> 2; + case 2 -> 4; + case 3 -> 8; + case 4 -> 16; + case 5 -> 32; + default -> throw new VortexException(EncodingId.VORTEX_DECIMAL, + "unknown valuesType: " + valuesType); + }; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoder.java new file mode 100644 index 00000000..d0089056 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoder.java @@ -0,0 
+1,240 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.DeltaMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only encoder for {@code fastlanes.delta}. +public final class DeltaEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public DeltaEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_DELTA; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return switch (p.ptype()) { + case I8, I16, I32, I64, U8, U16, U32, U64 -> true; + default -> false; + }; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + long[] longs = toLongs(data, ptype); + int n = longs.length; + int typeBits = typeBits(ptype); + int lanes = lanes(ptype); + long mask = typeMask(ptype); + boolean unsign = isUnsigned(ptype); + + long minVal = 0L, maxVal = 0L; + if (n > 0) { + minVal = longs[0]; + maxVal = longs[0]; + for (int i = 1; i < n; i++) { + long v = longs[i]; + if (unsign ? Long.compareUnsigned(v, minVal) < 0 : v < minVal) { + minVal = v; + } + if (unsign ? 
Long.compareUnsigned(v, maxVal) > 0 : v > maxVal) { + maxVal = v; + } + } + } + + int numChunks = n == 0 ? 0 : (n + FL_CHUNK_SIZE - 1) / FL_CHUNK_SIZE; + long paddedLen = (long) numChunks * FL_CHUNK_SIZE; + int basesLen = numChunks * lanes; + + long[] basesAll = new long[basesLen]; + long[] deltasAll = new long[(int) paddedLen]; + long[] chunkBuf = new long[FL_CHUNK_SIZE]; + long[] transposed = new long[FL_CHUNK_SIZE]; + long[] chunkBases = new long[lanes]; + long[] chunkDelta = new long[FL_CHUNK_SIZE]; + + for (int chunk = 0; chunk < numChunks; chunk++) { + int start = chunk * FL_CHUNK_SIZE; + int end = Math.min(start + FL_CHUNK_SIZE, n); + for (int i = start; i < end; i++) { + chunkBuf[i - start] = longs[i] & mask; + } + for (int i = end - start; i < FL_CHUNK_SIZE; i++) { + chunkBuf[i] = 0L; + } + for (int i = 0; i < FL_CHUNK_SIZE; i++) { + transposed[i] = chunkBuf[transposeIndex(i)]; + } + int basesOff = chunk * lanes; + System.arraycopy(transposed, 0, basesAll, basesOff, lanes); + System.arraycopy(basesAll, basesOff, chunkBases, 0, lanes); + deltaChunk(transposed, chunkBases, lanes, typeBits, mask, chunkDelta); + System.arraycopy(chunkDelta, 0, deltasAll, chunk * FL_CHUNK_SIZE, FL_CHUNK_SIZE); + } + + MemorySegment basesSeg = fromLongs(basesAll, ptype, ctx.arena()); + MemorySegment deltasSeg = fromLongs(deltasAll, ptype, ctx.arena()); + + byte[] metaBytes = new DeltaMetadata(paddedLen, 0).encode(); + + byte[] statsMin = n > 0 ? statsBytes(ptype, minVal) : null; + byte[] statsMax = n > 0 ? 
statsBytes(ptype, maxVal) : null; + + EncodeNode basesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode deltasNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode root = new EncodeNode(EncodingId.FASTLANES_DELTA, ByteBuffer.wrap(metaBytes), + new EncodeNode[]{basesNode, deltasNode}, new int[0]); + return new EncodeResult(root, List.of(basesSeg, deltasSeg), statsMin, statsMax); + } + + private static void deltaChunk(long[] transposed, long[] bases, int lanes, int typeBits, long mask, long[] out) { + for (int lane = 0; lane < lanes; lane++) { + long prev = bases[lane] & mask; + for (int row = 0; row < typeBits; row++) { + int idx = iterateIndex(row, lane); + long next = transposed[idx] & mask; + out[idx] = (next - prev) & mask; + prev = next; + } + } + } + + private static long[] toLongs(Object data, PType ptype) { + return switch (ptype) { + case I8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Byte.toUnsignedLong(arr[i]); + } + yield r; + } + case I16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Short.toUnsignedLong(arr[i]); + } + yield r; + } + case I32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Integer.toUnsignedLong(arr[i]); + } + yield r; + } + case I64, U64 -> (long[]) data; + default -> throw new VortexException(EncodingId.FASTLANES_DELTA, 
"unsupported ptype: " + ptype); + }; + } + + private static boolean isUnsigned(PType ptype) { + return switch (ptype) { + case U8, U16, U32, U64 -> true; + default -> false; + }; + } + + private static byte[] statsBytes(PType ptype, long value) { + if (isUnsigned(ptype)) { + return ScalarValue.ofUint64Value(value).encode(); + } + return ScalarValue.ofInt64Value(value).encode(); + } + + private static final int FL_CHUNK_SIZE = 1024; + + private static final int[] FL_ORDER = {0, 4, 2, 6, 1, 5, 3, 7}; + + private static int transposeIndex(int idx) { + int lane = idx % 16; + int order = (idx / 16) % 8; + int row = idx / 128; + return lane * 64 + FL_ORDER[order] * 8 + row; + } + + private static int iterateIndex(int row, int lane) { + int o = row / 8; + int s = row % 8; + return FL_ORDER[o] * 16 + s * 128 + lane; + } + + private static int lanes(PType ptype) { + return FL_CHUNK_SIZE / (ptype.byteSize() * 8); + } + + private static int typeBits(PType ptype) { + return ptype.byteSize() * 8; + } + + private static long typeMask(PType ptype) { + int bits = ptype.byteSize() * 8; + return bits == 64 ? 
-1L : (1L << bits) - 1; + } + + private static MemorySegment fromLongs(long[] longs, PType ptype, SegmentAllocator arena) { + if (ptype == PType.I64 || ptype == PType.U64) { + MemorySegment dst = arena.allocate((long) longs.length * 8); + MemorySegment.copy(MemorySegment.ofArray(longs), ValueLayout.JAVA_LONG, 0L, dst, PTypeIO.LE_LONG, 0L, longs.length); + return dst; + } + int n = longs.length; + long elemSize = ptype.byteSize(); + MemorySegment seg = arena.allocate(n * elemSize); + for (int i = 0; i < n; i++) { + PTypeIO.set(seg, i * elemSize, ptype, longs[i]); + } + return seg; + } + +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoder.java new file mode 100644 index 00000000..5bd4dc9c --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoder.java @@ -0,0 +1,317 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.ChildSlot; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.DictMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.VarBinMetadata; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.LinkedHashMap; +import java.util.List; + +/// Write-only encoder for {@code vortex.dict}. 
+public final class DictEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public DictEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DICT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive || dtype instanceof DType.Utf8; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (dtype instanceof DType.Utf8) { + return encodeUtf8((String[]) data, ctx); + } + DictData d = buildDictData(dtype, data, ctx); + PType codePType = d.codePType(); + int codeBytes = codePType.byteSize(); + + MemorySegment codesBuf = ctx.arena().allocate((long) d.len() * codeBytes); + for (int i = 0; i < d.len(); i++) { + writeCodeToSeg(codesBuf, codePType, i, readCodeFromArr(d.codesArr(), codePType, i)); + } + + ByteBuffer meta = ByteBuffer.allocate(1).put(0, (byte) codePType.ordinal()); + EncodeNode valuesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode codesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode rootNode = new EncodeNode( + EncodingId.VORTEX_DICT, meta, + new EncodeNode[]{valuesNode, codesNode}, + new int[0]); + + return new EncodeResult(rootNode, List.of(d.valuesBuf(), codesBuf), null, null); + } + + @Override + public CascadeStep encodeCascade(DType dtype, Object data, EncodeContext ctx) { + if (dtype instanceof DType.Utf8) { + return CascadeStep.terminal(encodeUtf8((String[]) data, ctx)); + } + DictData d = buildDictData(dtype, data, ctx); + PType codePType = d.codePType(); + + ByteBuffer meta = ByteBuffer.allocate(1).put(0, (byte) codePType.ordinal()); + EncodeNode valuesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode partialRoot = new EncodeNode( + EncodingId.VORTEX_DICT, meta, + new EncodeNode[]{valuesNode, null}, + new int[0]); + + DType codesDtype = new DType.Primitive(codePType, false); + ChildSlot 
slot = new ChildSlot(codesDtype, d.codesArr(), 1); + return new CascadeStep(partialRoot, List.of(d.valuesBuf()), List.of(slot), null, null, true); + } + + private static EncodeResult encodeUtf8(String[] strings, EncodeContext ctx) { + int n = strings.length; + + var valueMap = new LinkedHashMap(); + for (String s : strings) { + valueMap.computeIfAbsent(s, _ -> valueMap.size()); + } + + int dictSize = valueMap.size(); + PType codePType = codePType(dictSize); + int codeBytes = codePType.byteSize(); + + byte[][] dictByteArrays = new byte[dictSize][]; + int j = 0; + long totalDictBytes = 0; + for (String s : valueMap.keySet()) { + dictByteArrays[j] = s.getBytes(StandardCharsets.UTF_8); + totalDictBytes += dictByteArrays[j].length; + j++; + } + + Arena arena = ctx.arena(); + MemorySegment dictBytesBuf = arena.allocate(totalDictBytes > 0 ? totalDictBytes : 1); + MemorySegment dictOffsetsBuf = arena.allocate((long) (dictSize + 1) * Long.BYTES, Long.BYTES); + + long pos = 0; + dictOffsetsBuf.setAtIndex(PTypeIO.LE_LONG, 0, 0L); + for (int i = 0; i < dictSize; i++) { + MemorySegment.copy(MemorySegment.ofArray(dictByteArrays[i]), 0, dictBytesBuf, pos, dictByteArrays[i].length); + pos += dictByteArrays[i].length; + dictOffsetsBuf.setAtIndex(PTypeIO.LE_LONG, i + 1, pos); + } + + MemorySegment codesBuf = arena.allocate((long) n * codeBytes); + for (int i = 0; i < n; i++) { + writeCodeToSeg(codesBuf, codePType, i, valueMap.get(strings[i])); + } + + byte[] metaBytes = new DictMetadata( + dictSize, + io.github.dfa1.vortex.proto.PType.fromValue(codePType.ordinal()), + null, + null + ).encode(); + + byte[] varBinMetaBytes = new VarBinMetadata( + io.github.dfa1.vortex.proto.PType.fromValue(PType.I64.ordinal()) + ).encode(); + + EncodeNode offsetsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode valuesNode = new EncodeNode(EncodingId.VORTEX_VARBIN, + ByteBuffer.wrap(varBinMetaBytes), + new EncodeNode[]{offsetsNode}, + new int[]{0}); + EncodeNode codesNode = 
EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 2); + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_DICT, ByteBuffer.wrap(metaBytes), + new EncodeNode[]{codesNode, valuesNode}, + new int[0]); + + String minStr = valueMap.keySet().stream().min(String::compareTo).orElse(null); + String maxStr = valueMap.keySet().stream().max(String::compareTo).orElse(null); + byte[] statsMin = minStr != null ? ScalarValue.ofStringValue(minStr).encode() : null; + byte[] statsMax = maxStr != null ? ScalarValue.ofStringValue(maxStr).encode() : null; + return new EncodeResult(root, List.of(dictBytesBuf, dictOffsetsBuf, codesBuf), statsMin, statsMax); + } + + private static DictData buildDictData(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + var valueMap = new LinkedHashMap(); + int len = arrayLength(data, ptype); + for (int i = 0; i < len; i++) { + Object v = readElement(data, ptype, i); + valueMap.computeIfAbsent(v, _ -> valueMap.size()); + } + + int dictSize = valueMap.size(); + PType codePType = codePType(dictSize); + int codeBytes = codePType.byteSize(); + + Object uniqueArray = buildUniqueArray(ptype, valueMap.keySet(), dictSize); + MemorySegment valuesBuf = PTypeIO.copyArray(ptype, uniqueArray, dictSize); + + MemorySegment codesBuf = ctx.arena().allocate((long) len * codeBytes); + for (int i = 0; i < len; i++) { + Object v = readElement(data, ptype, i); + int code = valueMap.get(v); + writeCodeToSeg(codesBuf, codePType, i, code); + } + + Object codesArr = switch (codePType) { + case U8 -> { + byte[] a = new byte[len]; + for (int i = 0; i < len; i++) { + a[i] = codesBuf.get(ValueLayout.JAVA_BYTE, i); + } + yield a; + } + case U16 -> { + short[] a = new short[len]; + for (int i = 0; i < len; i++) { + a[i] = codesBuf.get(PTypeIO.LE_SHORT, (long) i * 2); + } + yield a; + } + default -> { + int[] a = new int[len]; + for (int i = 0; i < len; i++) { + a[i] = codesBuf.get(PTypeIO.LE_INT, (long) i * 4); + } + yield a; + } + }; + 
return new DictData(valuesBuf, codesArr, codePType, len); + } + + private static PType codePType(int dictSize) { + if (dictSize <= 256) { + return PType.U8; + } + if (dictSize <= 65536) { + return PType.U16; + } + return PType.U32; + } + + private static int arrayLength(Object data, PType ptype) { + return switch (ptype) { + case I8, U8 -> ((byte[]) data).length; + case I16, U16 -> ((short[]) data).length; + case I32, U32 -> ((int[]) data).length; + case I64, U64 -> ((long[]) data).length; + case F32 -> ((float[]) data).length; + case F64 -> ((double[]) data).length; + case F16 -> ((short[]) data).length; + }; + } + + private static Object readElement(Object data, PType ptype, int i) { + return switch (ptype) { + case I8, U8 -> ((byte[]) data)[i]; + case I16, U16, F16 -> ((short[]) data)[i]; + case I32, U32 -> ((int[]) data)[i]; + case I64, U64 -> ((long[]) data)[i]; + case F32 -> ((float[]) data)[i]; + case F64 -> ((double[]) data)[i]; + }; + } + + private static Object buildUniqueArray(PType ptype, Iterable uniques, int dictSize) { + return switch (ptype) { + case I8, U8 -> { + byte[] a = new byte[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Byte) v; + } + yield a; + } + case I16, U16 -> { + short[] a = new short[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Short) v; + } + yield a; + } + case I32, U32 -> { + int[] a = new int[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Integer) v; + } + yield a; + } + case I64, U64 -> { + long[] a = new long[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Long) v; + } + yield a; + } + case F32 -> { + float[] a = new float[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Float) v; + } + yield a; + } + case F64 -> { + double[] a = new double[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Double) v; + } + yield a; + } + case F16 -> { + short[] a = new short[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = 
(Short) v; + } + yield a; + } + }; + } + + private static void writeCodeToSeg(MemorySegment seg, PType codePType, int idx, int code) { + switch (codePType) { + case U8 -> seg.set(ValueLayout.JAVA_BYTE, idx, (byte) code); + case U16 -> seg.set(PTypeIO.LE_SHORT, (long) idx * 2, (short) code); + case U32 -> seg.set(PTypeIO.LE_INT, (long) idx * 4, code); + default -> throw new VortexException(EncodingId.VORTEX_DICT, "unexpected code type: " + codePType); + } + } + + private static int readCodeFromArr(Object arr, PType codePType, int i) { + return switch (codePType) { + case U8 -> Byte.toUnsignedInt(((byte[]) arr)[i]); + case U16 -> Short.toUnsignedInt(((short[]) arr)[i]); + default -> ((int[]) arr)[i]; + }; + } + + private record DictData(MemorySegment valuesBuf, Object codesArr, PType codePType, int len) { + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoder.java new file mode 100644 index 00000000..e5e81329 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoder.java @@ -0,0 +1,75 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.NullableData; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.ChildSlot; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +import java.util.List; + +/// Write-only encoder for {@code vortex.ext} — wraps a storage-array encode in an ext node. +public final class ExtEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ExtEncodingEncoder() { + } + + private static final List STORAGE_FALLBACK = List.of( + new PrimitiveEncodingEncoder(), + new FixedSizeListEncodingEncoder()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_EXT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Extension; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Extension ext)) { + throw new VortexException(EncodingId.VORTEX_EXT, "expected extension dtype, got " + dtype); + } + DType storage = ext.storageDType(); + EncodeResult childResult; + if (data instanceof NullableData) { + childResult = new MaskedEncodingEncoder().encode(storage, data, ctx); + } else { + EncodingEncoder storageEncoder = null; + for (EncodingEncoder enc : STORAGE_FALLBACK) { + if (enc.accepts(storage)) { + storageEncoder = enc; + break; + } + } + if (storageEncoder == null) { + throw new VortexException(EncodingId.VORTEX_EXT, "no storage encoder for " + storage); + } + childResult = storageEncoder.encode(storage, data, ctx); + } + EncodeNode root = new EncodeNode(EncodingId.VORTEX_EXT, null, new EncodeNode[]{childResult.rootNode()}, new int[0]); + return new EncodeResult(root, childResult.buffers(), childResult.statsMin(), childResult.statsMax()); + } + + @Override + public CascadeStep encodeCascade(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Extension ext)) { + throw new VortexException(EncodingId.VORTEX_EXT, "expected extension dtype, got " + dtype); + } + if (data instanceof NullableData) { + return CascadeStep.terminal(encode(dtype, data, ctx)); + } + EncodeNode partialRoot = new EncodeNode(EncodingId.VORTEX_EXT, null, new EncodeNode[1], new int[0]); + ChildSlot slot = new ChildSlot(ext.storageDType(), data, 0); + return new CascadeStep(partialRoot, List.of(), List.of(slot), null, null, true); + } +} diff --git 
a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoder.java new file mode 100644 index 00000000..3987469f --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoder.java @@ -0,0 +1,71 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; + + +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; + +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.FixedSizeListData; + + + + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.fixed_size_list}. +public final class FixedSizeListEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public FixedSizeListEncodingEncoder() { + } + + private static final List FALLBACK = List.of( + new PrimitiveEncodingEncoder(), new VarBinEncodingEncoder(), new BoolEncodingEncoder(), + new NullEncodingEncoder(), new ByteBoolEncodingEncoder()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_FIXED_SIZE_LIST; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.FixedSizeList; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.FixedSizeList fsl = (DType.FixedSizeList) dtype; + FixedSizeListData fsd = (FixedSizeListData) data; + DType elementType = fsl.elementType(); + EncodingEncoder inner = findEncoding(elementType); + + EncodeResult elemResult = inner.encode(elementType, fsd.elements(), ctx); + + List allBuffers = new ArrayList<>(elemResult.buffers()); + EncodeNode elemNode = EncodeNode.remapBufferIndices(elemResult.rootNode(), 0); + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_FIXED_SIZE_LIST, + ByteBuffer.wrap(new byte[0]), + new EncodeNode[]{elemNode}, + new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private static EncodingEncoder findEncoding(DType dtype) { + for (EncodingEncoder enc : FALLBACK) { + if (enc.accepts(dtype)) { + return enc; + } + } + throw new UnsupportedOperationException("no fallback encoding for dtype: " + dtype); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoder.java new file mode 100644 index 00000000..47db9bde --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoder.java @@ -0,0 +1,167 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import 
io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.ChildSlot; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only encoder for {@code fastlanes.for} (Frame of Reference). +public final class FrameOfReferenceEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public FrameOfReferenceEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_FOR; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive p && !p.ptype().isFloating(); + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.FASTLANES_FOR, "expected primitive dtype, got " + dtype); + } + PType ptype = p.ptype(); + long[] longs = toLongs(data, ptype); + int n = longs.length; + + long ref = computeRef(longs, n); + MemorySegment residuals = toResidualBuffer(longs, ref, ptype, ctx); + ByteBuffer meta = buildForMeta(ref, ptype); + + EncodeNode child = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode root = new EncodeNode(EncodingId.FASTLANES_FOR, meta, new EncodeNode[]{child}, new int[0]); + return new EncodeResult(root, List.of(residuals), null, null); + } + + @Override + public CascadeStep encodeCascade(DType dtype, Object data, EncodeContext encodeCtx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new 
VortexException(EncodingId.FASTLANES_FOR, "expected primitive dtype, got " + dtype); + } + PType ptype = p.ptype(); + long[] longs = toLongs(data, ptype); + int n = longs.length; + + long ref = computeRef(longs, n); + ByteBuffer meta = buildForMeta(ref, ptype); + + EncodeNode partialRoot = new EncodeNode(EncodingId.FASTLANES_FOR, meta, new EncodeNode[1], new int[0]); + ChildSlot slot = new ChildSlot(dtype, residualsAsNativeArray(longs, ref, ptype), 0); + return new CascadeStep(partialRoot, List.of(), List.of(slot), null, null, true); + } + + private static long computeRef(long[] longs, int n) { + long ref = n > 0 ? longs[0] : 0L; + for (long v : longs) { + if (v < ref) { + ref = v; + } + } + return ref; + } + + private static ByteBuffer buildForMeta(long ref, PType ptype) { + boolean unsigned = switch (ptype) { + case U8, U16, U32, U64 -> true; + default -> false; + }; + ScalarValue scalar = unsigned ? ScalarValue.ofUint64Value(ref) : ScalarValue.ofInt64Value(ref); + return ByteBuffer.wrap(scalar.encode()); + } + + private static Object residualsAsNativeArray(long[] longs, long ref, PType ptype) { + int n = longs.length; + return switch (ptype) { + case I8, U8 -> { + byte[] r = new byte[n]; + for (int i = 0; i < n; i++) { + r[i] = (byte) (longs[i] - ref); + } + yield r; + } + case I16, U16 -> { + short[] r = new short[n]; + for (int i = 0; i < n; i++) { + r[i] = (short) (longs[i] - ref); + } + yield r; + } + case I32, U32 -> { + int[] r = new int[n]; + for (int i = 0; i < n; i++) { + r[i] = (int) (longs[i] - ref); + } + yield r; + } + case I64, U64 -> { + long[] r = new long[n]; + for (int i = 0; i < n; i++) { + r[i] = longs[i] - ref; + } + yield r; + } + default -> throw new VortexException(EncodingId.FASTLANES_FOR, "unsupported ptype: " + ptype); + }; + } + + private static long[] toLongs(Object data, PType ptype) { + return switch (ptype) { + case I8, U8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) 
{ + r[i] = ptype == PType.U8 ? Byte.toUnsignedLong(arr[i]) : arr[i]; + } + yield r; + } + case I16, U16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = ptype == PType.U16 ? Short.toUnsignedLong(arr[i]) : arr[i]; + } + yield r; + } + case I32, U32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = ptype == PType.U32 ? Integer.toUnsignedLong(arr[i]) : arr[i]; + } + yield r; + } + case I64, U64 -> (long[]) data; + default -> throw new VortexException(EncodingId.FASTLANES_FOR, "unsupported ptype: " + ptype); + }; + } + + private static MemorySegment toResidualBuffer(long[] longs, long ref, PType ptype, EncodeContext ctx) { + int n = longs.length; + int elemBytes = ptype.byteSize(); + MemorySegment seg = ctx.arena().allocate((long) n * elemBytes, elemBytes); + for (int i = 0; i < n; i++) { + long r = longs[i] - ref; + PTypeIO.set(seg, (long) i * elemBytes, ptype, r); + } + return seg; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoder.java new file mode 100644 index 00000000..5841237a --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoder.java @@ -0,0 +1,158 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.FSSTMetadata; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; 
+import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; + +/// Write-only encoder for {@code vortex.fsst}. +public final class FsstEncodingEncoder implements EncodingEncoder { + + private static final int MAX_SYMBOLS = 255; + private static final int BIGRAM_COUNT = 65536; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public FsstEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_FSST; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + String[] strings = (String[]) data; + int n = strings.length; + + byte[][] byteArrays = new byte[n][]; + for (int i = 0; i < n; i++) { + byteArrays[i] = strings[i].getBytes(StandardCharsets.UTF_8); + } + + int[] freq = new int[BIGRAM_COUNT]; + for (byte[] b : byteArrays) { + for (int i = 0; i + 1 < b.length; i++) { + freq[(Byte.toUnsignedInt(b[i]) << 8) | Byte.toUnsignedInt(b[i + 1])]++; + } + } + long[] ranked = new long[BIGRAM_COUNT]; + for (int i = 0; i < BIGRAM_COUNT; i++) { + ranked[i] = ((long) freq[i] << 16) | i; + } + Arrays.sort(ranked); + + int numSymbols = 0; + int[] codeForBigram = new int[BIGRAM_COUNT]; + Arrays.fill(codeForBigram, -1); + long[] symbolValues = new long[MAX_SYMBOLS]; + for (int rank = BIGRAM_COUNT - 1; rank >= 0 && numSymbols < MAX_SYMBOLS; rank--) { + int bg = (int) (ranked[rank] & 0xFFFF); + if (freq[bg] == 0) { + break; + } + codeForBigram[bg] = numSymbols; + int hi = bg >>> 8; + int lo = bg & 0xFF; + symbolValues[numSymbols] = hi | ((long) lo << 8); + numSymbols++; + } + + byte[][] compressed = new byte[n][]; + for (int i = 0; i < n; i++) { + compressed[i] = compressString(byteArrays[i], codeForBigram); + } + + Arena arena = ctx.arena(); + + MemorySegment symBuf = 
arena.allocate(Math.max(numSymbols * 8L, 1), 8); + for (int i = 0; i < numSymbols; i++) { + symBuf.setAtIndex(PTypeIO.LE_LONG, i, symbolValues[i]); + } + + MemorySegment symLenBuf = arena.allocate(Math.max(numSymbols, 1)); + for (int i = 0; i < numSymbols; i++) { + symLenBuf.set(ValueLayout.JAVA_BYTE, i, (byte) 2); + } + + int totalCompressed = 0; + for (byte[] c : compressed) { + totalCompressed += c.length; + } + MemorySegment compBuf = arena.allocate(Math.max(totalCompressed, 1)); + long pos = 0; + for (byte[] c : compressed) { + MemorySegment.copy(MemorySegment.ofArray(c), 0, compBuf, pos, c.length); + pos += c.length; + } + + MemorySegment uncompLenBuf = arena.allocate(Math.max(n * 4L, 1), 4); + for (int i = 0; i < n; i++) { + uncompLenBuf.setAtIndex(PTypeIO.LE_INT, i, byteArrays[i].length); + } + + MemorySegment codesOffBuf = arena.allocate((long) (n + 1) * 4, 4); + long off = 0; + codesOffBuf.setAtIndex(PTypeIO.LE_INT, 0, 0); + for (int i = 0; i < n; i++) { + off += compressed[i].length; + codesOffBuf.setAtIndex(PTypeIO.LE_INT, i + 1, (int) off); + } + + byte[] metaBytes = new FSSTMetadata( + io.github.dfa1.vortex.proto.PType.fromValue(PType.I32.ordinal()), + io.github.dfa1.vortex.proto.PType.fromValue(PType.I32.ordinal()) + ).encode(); + + EncodeNode uncompLensNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 3); + EncodeNode codesOffNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 4); + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_FSST, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{uncompLensNode, codesOffNode}, + new int[]{0, 1, 2}); + + return new EncodeResult(root, + List.of(symBuf, symLenBuf, compBuf, uncompLenBuf, codesOffBuf), + null, null); + } + + private static byte[] compressString(byte[] input, int[] codeForBigram) { + byte[] out = new byte[input.length * 2]; + int outLen = 0; + int i = 0; + while (i < input.length) { + if (i + 1 < input.length) { + int bg = (Byte.toUnsignedInt(input[i]) << 8) | Byte.toUnsignedInt(input[i + 1]); 
+ int code = codeForBigram[bg]; + if (code >= 0) { + out[outLen++] = (byte) code; + i += 2; + continue; + } + } + out[outLen++] = (byte) 0xFF; + out[outLen++] = input[i]; + i++; + } + return Arrays.copyOf(out, outLen); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoder.java new file mode 100644 index 00000000..40af0fd2 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoder.java @@ -0,0 +1,88 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; + + +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; + +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.ListData; + +import io.github.dfa1.vortex.encoding.PTypeIO; + + +import io.github.dfa1.vortex.proto.ListMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.list}. +public final class ListEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ListEncodingEncoder() { + } + + private static final List FALLBACK = List.of( + new PrimitiveEncodingEncoder(), new VarBinEncodingEncoder(), new BoolEncodingEncoder(), + new NullEncodingEncoder(), new ByteBoolEncodingEncoder()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_LIST; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.List; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.List listDtype = (DType.List) dtype; + ListData ld = (ListData) data; + DType elementType = listDtype.elementType(); + EncodingEncoder elemEncoding = findEncoding(elementType); + EncodeResult elemResult = elemEncoding.encode(elementType, ld.elements(), ctx); + + List allBuffers = new ArrayList<>(elemResult.buffers()); + int elemBufCount = allBuffers.size(); + EncodeNode elemNode = EncodeNode.remapBufferIndices(elemResult.rootNode(), 0); + + long nOffsets = ld.outerLen() + 1; + MemorySegment offsetsBuf = ctx.arena().allocate(nOffsets * Long.BYTES, Long.BYTES); + for (int i = 0; i < nOffsets; i++) { + offsetsBuf.setAtIndex(PTypeIO.LE_LONG, i, ld.offsets()[i]); + } + allBuffers.add(offsetsBuf); + EncodeNode offsetsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, elemBufCount); + + long elementsLen = ld.offsets()[(int) ld.outerLen()]; + byte[] metaBytes = new ListMetadata( + elementsLen, + io.github.dfa1.vortex.proto.PType.fromValue(PType.I64.ordinal()) + ).encode(); + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_LIST, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{elemNode, offsetsNode}, + new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private static EncodingEncoder findEncoding(DType dtype) { + for (EncodingEncoder enc : FALLBACK) { + if (enc.accepts(dtype)) { + return enc; + } + } + throw new UnsupportedOperationException("no fallback encoding for dtype: " + dtype); + } +} diff --git 
a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoder.java new file mode 100644 index 00000000..9f5c152a --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoder.java @@ -0,0 +1,97 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; + + +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; + +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.ListViewData; + +import io.github.dfa1.vortex.encoding.PTypeIO; + + +import io.github.dfa1.vortex.proto.ListViewMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.listview}. +public final class ListViewEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ListViewEncodingEncoder() { + } + + private static final List FALLBACK = List.of( + new PrimitiveEncodingEncoder(), new VarBinEncodingEncoder(), new BoolEncodingEncoder(), + new NullEncodingEncoder(), new ByteBoolEncodingEncoder()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_LISTVIEW; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.List; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.List listDtype = (DType.List) dtype; + ListViewData lvd = (ListViewData) data; + DType elementType = listDtype.elementType(); + EncodingEncoder elemEncoding = findEncoding(elementType); + EncodeResult elemResult = elemEncoding.encode(elementType, lvd.elements(), ctx); + + List allBuffers = new ArrayList<>(elemResult.buffers()); + int elemBufCount = allBuffers.size(); + EncodeNode elemNode = EncodeNode.remapBufferIndices(elemResult.rootNode(), 0); + + long n = lvd.outerLen(); + + MemorySegment offsetsBuf = ctx.arena().allocate(n * Integer.BYTES, Integer.BYTES); + for (int i = 0; i < n; i++) { + offsetsBuf.setAtIndex(PTypeIO.LE_INT, i, lvd.offsets()[i]); + } + allBuffers.add(offsetsBuf); + EncodeNode offsetsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, elemBufCount); + + MemorySegment sizesBuf = ctx.arena().allocate(n * Integer.BYTES, Integer.BYTES); + for (int i = 0; i < n; i++) { + sizesBuf.setAtIndex(PTypeIO.LE_INT, i, lvd.sizes()[i]); + } + allBuffers.add(sizesBuf); + EncodeNode sizesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, elemBufCount + 1); + + long elementsLen = java.lang.reflect.Array.getLength(lvd.elements()); + byte[] metaBytes = new ListViewMetadata( + elementsLen, + io.github.dfa1.vortex.proto.PType.fromValue(PType.I32.ordinal()), + io.github.dfa1.vortex.proto.PType.fromValue(PType.I32.ordinal()) + ).encode(); + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_LISTVIEW, + ByteBuffer.wrap(metaBytes), + new 
EncodeNode[]{elemNode, offsetsNode, sizesNode}, + new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private static EncodingEncoder findEncoding(DType dtype) { + for (EncodingEncoder enc : FALLBACK) { + if (enc.accepts(dtype)) { + return enc; + } + } + throw new UnsupportedOperationException("no fallback encoding for dtype: " + dtype); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoder.java new file mode 100644 index 00000000..93653168 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoder.java @@ -0,0 +1,79 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.NullableData; + +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; + +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + + + + +import java.lang.foreign.MemorySegment; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.masked}. Wraps the payload encode in a values + validity +/// pair driven by a {@link NullableData} carrier. +public final class MaskedEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public MaskedEncodingEncoder() { + } + + private static final List INNER_FALLBACK = List.of( + new PrimitiveEncodingEncoder(), + new VarBinEncodingEncoder(), + new FixedSizeListEncodingEncoder()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_MASKED; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(data instanceof NullableData nd)) { + throw new VortexException(EncodingId.VORTEX_MASKED, + "expected NullableData, got " + (data == null ? "null" : data.getClass().getName())); + } + DType nonNullable = dtype.withNullable(false); + EncodingEncoder inner = pickInner(nonNullable); + EncodeResult valuesResult = inner.encode(nonNullable, nd.values(), ctx); + EncodeResult validityResult = new BoolEncodingEncoder().encode(new DType.Bool(false), nd.validity(), ctx); + + int valuesBufCount = valuesResult.buffers().size(); + EncodeNode validityNode = EncodeNode.remapBufferIndices(validityResult.rootNode(), valuesBufCount); + + List buffers = new ArrayList<>(valuesBufCount + validityResult.buffers().size()); + buffers.addAll(valuesResult.buffers()); + buffers.addAll(validityResult.buffers()); + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_MASKED, + null, + new EncodeNode[]{valuesResult.rootNode(), validityNode}, + new int[0]); + return new EncodeResult(root, buffers, valuesResult.statsMin(), valuesResult.statsMax()); + } + + private static EncodingEncoder pickInner(DType nonNullable) { + for (EncodingEncoder e : INNER_FALLBACK) { + if (e.accepts(nonNullable)) { + return e; + } + } + throw new VortexException(EncodingId.VORTEX_MASKED, + "no inner encoding for " + nonNullable); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoder.java new file mode 100644 index 00000000..d8f427f2 --- 
/dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoder.java @@ -0,0 +1,34 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +import java.util.List; + +/// Write-only encoder for {@code vortex.null} (all-null arrays). +public final class NullEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public NullEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_NULL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Null; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + EncodeNode root = new EncodeNode(EncodingId.VORTEX_NULL, null, new EncodeNode[0], new int[0]); + return new EncodeResult(root, List.of(), null, null); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PatchedEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PatchedEncodingEncoder.java new file mode 100644 index 00000000..0e968582 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PatchedEncodingEncoder.java @@ -0,0 +1,31 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Write-only encoder for {@code vortex.patched} — currently throws (not implemented). 
+public final class PatchedEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public PatchedEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_PATCHED; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + throw new VortexException(EncodingId.VORTEX_PATCHED, "encode not yet implemented"); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoder.java new file mode 100644 index 00000000..78b37af9 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoder.java @@ -0,0 +1,32 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Write-side stub for {@code vortex.pco} — encode is not yet implemented. +public final class PcoEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public PcoEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_PCO; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + throw new VortexException(EncodingId.VORTEX_PCO, + "encode not implemented — pco encode port pending"); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoder.java new file mode 100644 index 00000000..21cb1f5b --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoder.java @@ -0,0 +1,292 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; + +/// Write-only encoder for {@code vortex.primitive} — raw little-endian primitive arrays. +public final class PrimitiveEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public PrimitiveEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_PRIMITIVE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + MemorySegment seg = encodePrimitive(ptype, data, ctx.arena()); + byte[] min = null; + byte[] max = null; + byte[][] stats = computeStats(ptype, data); + if (stats != null) { + min = stats[0]; + max = stats[1]; + } + return EncodeResult.simple(EncodingId.VORTEX_PRIMITIVE, seg, min, max); + } + + private static MemorySegment encodePrimitive(PType ptype, Object data, Arena arena) { + return switch (ptype) { + case I8, U8 -> MemorySegment.ofArray((byte[]) data); + case I16, U16, F16 -> { + short[] arr = (short[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 2, 2); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_SHORT, i, arr[i]); + } + yield seg; + } + case I32, U32 -> { + int[] arr = (int[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 4, 4); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_INT, i, arr[i]); + } + yield seg; + } + case I64, U64 -> { + long[] arr = (long[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 8, 8); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_LONG, i, arr[i]); + } + yield seg; + } + case F32 -> { + float[] arr = (float[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 4, 4); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_FLOAT, i, arr[i]); + } + yield seg; + } + case F64 -> { + double[] arr = (double[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 8, 8); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_DOUBLE, i, arr[i]); + } + yield seg; + } + }; + } + + private static byte[][] 
computeStats(PType ptype, Object data) { + return switch (ptype) { + case I8 -> { + byte[] arr = (byte[]) data; + if (arr.length == 0) { + yield null; + } + long min = arr[0], max = arr[0]; + for (byte v : arr) { + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + yield new byte[][]{scalarI64(min), scalarI64(max)}; + } + case I16 -> { + short[] arr = (short[]) data; + if (arr.length == 0) { + yield null; + } + long min = arr[0], max = arr[0]; + for (short v : arr) { + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + yield new byte[][]{scalarI64(min), scalarI64(max)}; + } + case I32 -> { + int[] arr = (int[]) data; + if (arr.length == 0) { + yield null; + } + long min = arr[0], max = arr[0]; + for (int v : arr) { + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + yield new byte[][]{scalarI64(min), scalarI64(max)}; + } + case I64 -> { + long[] arr = (long[]) data; + if (arr.length == 0) { + yield null; + } + long min = arr[0], max = arr[0]; + for (long v : arr) { + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + yield new byte[][]{scalarI64(min), scalarI64(max)}; + } + case U8 -> { + byte[] arr = (byte[]) data; + if (arr.length == 0) { + yield null; + } + long min = Byte.toUnsignedInt(arr[0]), max = Byte.toUnsignedInt(arr[0]); + for (byte v : arr) { + long uv = Byte.toUnsignedInt(v); + if (uv < min) { + min = uv; + } + if (uv > max) { + max = uv; + } + } + yield new byte[][]{scalarU64(min), scalarU64(max)}; + } + case U16 -> { + short[] arr = (short[]) data; + if (arr.length == 0) { + yield null; + } + long min = Short.toUnsignedInt(arr[0]), max = Short.toUnsignedInt(arr[0]); + for (short v : arr) { + long uv = Short.toUnsignedInt(v); + if (uv < min) { + min = uv; + } + if (uv > max) { + max = uv; + } + } + yield new byte[][]{scalarU64(min), scalarU64(max)}; + } + case U32 -> { + int[] arr = (int[]) data; + if (arr.length == 0) { + yield null; + } + long min = Integer.toUnsignedLong(arr[0]), 
max = Integer.toUnsignedLong(arr[0]); + for (int v : arr) { + long uv = Integer.toUnsignedLong(v); + if (uv < min) { + min = uv; + } + if (uv > max) { + max = uv; + } + } + yield new byte[][]{scalarU64(min), scalarU64(max)}; + } + case U64 -> { + long[] arr = (long[]) data; + if (arr.length == 0) { + yield null; + } + long min = arr[0], max = arr[0]; + for (long v : arr) { + if (Long.compareUnsigned(v, min) < 0) { + min = v; + } + if (Long.compareUnsigned(v, max) > 0) { + max = v; + } + } + yield new byte[][]{scalarU64(min), scalarU64(max)}; + } + case F32 -> { + float[] arr = (float[]) data; + if (arr.length == 0) { + yield null; + } + float min = arr[0], max = arr[0]; + for (float v : arr) { + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + yield new byte[][]{scalarF32(min), scalarF32(max)}; + } + case F64 -> { + double[] arr = (double[]) data; + if (arr.length == 0) { + yield null; + } + double min = arr[0], max = arr[0]; + for (double v : arr) { + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + yield new byte[][]{scalarF64(min), scalarF64(max)}; + } + case F16 -> { + short[] arr = (short[]) data; + if (arr.length == 0) { + yield null; + } + float min = Float.float16ToFloat(arr[0]), max = Float.float16ToFloat(arr[0]); + for (short v : arr) { + float fv = Float.float16ToFloat(v); + if (fv < min) { + min = fv; + } + if (fv > max) { + max = fv; + } + } + yield new byte[][]{scalarF32(min), scalarF32(max)}; + } + }; + } + + private static byte[] scalarI64(long v) { + return ScalarValue.ofInt64Value(v).encode(); + } + + private static byte[] scalarU64(long v) { + return ScalarValue.ofUint64Value(v).encode(); + } + + private static byte[] scalarF32(float v) { + return ScalarValue.ofF32Value(v).encode(); + } + + private static byte[] scalarF64(double v) { + return ScalarValue.ofF64Value(v).encode(); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoder.java 
b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoder.java new file mode 100644 index 00000000..e3a0e0c9 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoder.java @@ -0,0 +1,256 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.RLEMetadata; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only encoder for {@code fastlanes.rle}. +public final class RleEncodingEncoder implements EncodingEncoder { + + private static final int FL_CHUNK_SIZE = 1024; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public RleEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_RLE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive p && !p.ptype().isFloating(); + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.FASTLANES_RLE, "encode only supports Primitive dtype, got " + dtype); + } + PType ptype = p.ptype(); + long[] longs = toLongs(data, ptype); + int n = longs.length; + + if (n == 0) { + return encodeEmpty(ctx); + } + + int numChunks = (n + FL_CHUNK_SIZE - 1) / FL_CHUNK_SIZE; + int paddedLen = numChunks * FL_CHUNK_SIZE; + + long[] globalValues = new long[paddedLen]; + short[] globalIndices = new short[paddedLen]; + long[] valuesIdxOffsets = new long[numChunks]; + + long[] chunkInput = new long[FL_CHUNK_SIZE]; + long[] chunkValues = new long[FL_CHUNK_SIZE]; + short[] chunkIndices = new short[FL_CHUNK_SIZE]; + + int globalValuesCount = 0; + + for (int chunk = 0; chunk < numChunks; chunk++) { + int chunkStart = chunk * FL_CHUNK_SIZE; + int chunkEnd = Math.min(chunkStart + FL_CHUNK_SIZE, n); + int chunkLen = chunkEnd - chunkStart; + + System.arraycopy(longs, chunkStart, chunkInput, 0, chunkLen); + long lastVal = longs[chunkEnd - 1]; + for (int i = chunkLen; i < FL_CHUNK_SIZE; i++) { + chunkInput[i] = lastVal; + } + + int numChunkValues = rleEncode(chunkInput, chunkValues, chunkIndices); + + valuesIdxOffsets[chunk] = globalValuesCount; + System.arraycopy(chunkValues, 0, globalValues, globalValuesCount, numChunkValues); + globalValuesCount += numChunkValues; + + System.arraycopy(chunkIndices, 0, globalIndices, chunkStart, FL_CHUNK_SIZE); + } + + MemorySegment valuesSeg = fromLongs(globalValues, globalValuesCount, ptype, ctx.arena()); + MemorySegment indicesSeg = toIndicesSeg(globalIndices, paddedLen, ctx.arena()); + MemorySegment offsetsSeg = 
fromLongsU64(valuesIdxOffsets, numChunks, ctx.arena()); + + PType indicesPtype = PType.U16; + PType offsetsPtype = PType.U64; + + byte[] metaBytes = new RLEMetadata( + globalValuesCount, + paddedLen, + io.github.dfa1.vortex.proto.PType.fromValue(indicesPtype.ordinal()), + numChunks, + io.github.dfa1.vortex.proto.PType.fromValue(offsetsPtype.ordinal()), + 0L + ).encode(); + + EncodeNode valuesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode indicesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode offsetsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 2); + EncodeNode root = new EncodeNode( + EncodingId.FASTLANES_RLE, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{valuesNode, indicesNode, offsetsNode}, + new int[0]); + return new EncodeResult(root, List.of(valuesSeg, indicesSeg, offsetsSeg), null, null); + } + + private static int rleEncode(long[] input, long[] chunkValues, short[] chunkIndices) { + short posVal = 0; + int valIdx = 1; + long prev = input[0]; + chunkValues[0] = prev; + chunkIndices[0] = 0; + + for (int i = 1; i < FL_CHUNK_SIZE; i++) { + long cur = input[i]; + if (cur != prev) { + chunkValues[valIdx] = cur; + valIdx++; + posVal++; + prev = cur; + } + chunkIndices[i] = posVal; + } + return valIdx; + } + + private static EncodeResult encodeEmpty(EncodeContext ctx) { + MemorySegment empty = ctx.arena().allocate(0); + PType indicesPtype = PType.U16; + PType offsetsPtype = PType.U64; + byte[] metaBytes = new RLEMetadata( + 0L, + 0L, + io.github.dfa1.vortex.proto.PType.fromValue(indicesPtype.ordinal()), + 0L, + io.github.dfa1.vortex.proto.PType.fromValue(offsetsPtype.ordinal()), + 0L + ).encode(); + EncodeNode valuesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode indicesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode offsetsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 2); + EncodeNode root = new EncodeNode( + EncodingId.FASTLANES_RLE, + ByteBuffer.wrap(metaBytes), + 
new EncodeNode[]{valuesNode, indicesNode, offsetsNode}, + new int[0]); + return new EncodeResult(root, List.of(empty, empty, empty), null, null); + } + + private static long[] toLongs(Object data, PType ptype) { + return switch (ptype) { + case I8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Byte.toUnsignedLong(arr[i]); + } + yield r; + } + case I16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Short.toUnsignedLong(arr[i]); + } + yield r; + } + case I32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Integer.toUnsignedLong(arr[i]); + } + yield r; + } + case I64, U64 -> (long[]) data; + case F32 -> { + float[] arr = (float[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Float.floatToRawIntBits(arr[i]); + } + yield r; + } + case F64 -> { + double[] arr = (double[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Double.doubleToRawLongBits(arr[i]); + } + yield r; + } + case F16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Short.toUnsignedLong(arr[i]); + } + yield r; + } + }; + } + + private static MemorySegment fromLongs(long[] values, int count, PType ptype, SegmentAllocator arena) { + int elemSize = ptype.byteSize(); + 
MemorySegment seg = arena.allocate((long) count * elemSize); + for (int i = 0; i < count; i++) { + PTypeIO.set(seg, (long) i * elemSize, ptype, values[i]); + } + return seg; + } + + private static MemorySegment fromLongsU64(long[] values, int count, SegmentAllocator arena) { + MemorySegment seg = arena.allocate((long) count * 8); + for (int i = 0; i < count; i++) { + seg.setAtIndex(PTypeIO.LE_LONG, i, values[i]); + } + return seg; + } + + private static MemorySegment toIndicesSeg(short[] indices, int count, SegmentAllocator arena) { + MemorySegment seg = arena.allocate((long) count * 2); + for (int i = 0; i < count; i++) { + seg.setAtIndex(PTypeIO.LE_SHORT, i, indices[i]); + } + return seg; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoder.java new file mode 100644 index 00000000..b8829b55 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoder.java @@ -0,0 +1,108 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.RunEndMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.runend}. +public final class RunEndEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public RunEndEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_RUNEND; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive p && !p.ptype().isFloating(); + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_RUNEND, "encode only supports Primitive dtype, got " + dtype); + } + PType ptype = p.ptype(); + int n = arrayLength(data, ptype); + + List ends = new ArrayList<>(); + List values = new ArrayList<>(); + if (n > 0) { + long runVal = readLong(data, ptype, 0); + for (int i = 1; i < n; i++) { + long cur = readLong(data, ptype, i); + if (cur != runVal) { + ends.add(i); + values.add(runVal); + runVal = cur; + } + } + ends.add(n); + values.add(runVal); + } + + int numRuns = ends.size(); + + MemorySegment endsBuf = ctx.arena().allocate((long) numRuns * 4, 4); + for (int i = 0; i < numRuns; i++) { + endsBuf.setAtIndex(PTypeIO.LE_INT, i, ends.get(i)); + } + + int elemBytes = ptype.byteSize(); + MemorySegment valuesBuf = ctx.arena().allocate((long) numRuns * elemBytes, elemBytes); + for (int i = 0; i < numRuns; i++) { + PTypeIO.set(valuesBuf, (long) i * elemBytes, ptype, values.get(i)); + } + + byte[] metaBytes = new RunEndMetadata( + io.github.dfa1.vortex.proto.PType.fromValue(PType.U32.ordinal()), + numRuns, + 0L + ).encode(); + + EncodeNode endsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode valuesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_RUNEND, ByteBuffer.wrap(metaBytes), + new EncodeNode[]{endsNode, valuesNode}, new int[0]); + return new EncodeResult(root, List.of(endsBuf, valuesBuf), null, null); + } + + private static int arrayLength(Object data, PType ptype) { + return switch (ptype) { + case I8, U8 -> ((byte[]) data).length; + case I16, U16 -> 
((short[]) data).length; + case I32, U32 -> ((int[]) data).length; + case I64, U64 -> ((long[]) data).length; + default -> throw new VortexException(EncodingId.VORTEX_RUNEND, "unsupported ptype: " + ptype); + }; + } + + private static long readLong(Object data, PType ptype, int i) { + return switch (ptype) { + case I8 -> ((byte[]) data)[i]; + case U8 -> Byte.toUnsignedLong(((byte[]) data)[i]); + case I16 -> ((short[]) data)[i]; + case U16 -> Short.toUnsignedLong(((short[]) data)[i]); + case I32 -> ((int[]) data)[i]; + case U32 -> Integer.toUnsignedLong(((int[]) data)[i]); + case I64, U64 -> ((long[]) data)[i]; + default -> throw new VortexException(EncodingId.VORTEX_RUNEND, "unsupported ptype: " + ptype); + }; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoder.java new file mode 100644 index 00000000..00e19ea3 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoder.java @@ -0,0 +1,138 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.SequenceMetadata; + +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only encoder for {@code vortex.sequence} — arithmetic sequences as (base, multiplier). +public final class SequenceEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public SequenceEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_SEQUENCE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "encode only supports Primitive dtype, got " + dtype); + } + PType pt = p.ptype(); + return switch (pt) { + case I8, I16, I32, I64, U8, U16, U32, U64 -> encodeInteger(pt, data); + case F32 -> encodeF32((float[]) data); + case F64 -> encodeF64((double[]) data); + case F16 -> encodeF16((short[]) data); + }; + } + + private static EncodeResult encodeInteger(PType pt, Object data) { + int n = intArrayLength(pt, data); + long base = 0; + long multiplier = 0; + if (n > 0) { + base = readLong(pt, data, 0); + multiplier = n > 1 ? readLong(pt, data, 1) - base : 0; + for (int i = 2; i < n; i++) { + long expected = base + (long) i * multiplier; + if (readLong(pt, data, i) != expected) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "not an arithmetic sequence at index " + i); + } + } + } + ScalarValue baseScalar = buildIntScalar(pt, base); + ScalarValue mulScalar = buildIntScalar(pt, multiplier); + return buildResult(baseScalar, mulScalar); + } + + private static EncodeResult encodeF32(float[] data) { + float base = data.length > 0 ? data[0] : 0f; + float mul = data.length > 1 ? data[1] - base : 0f; + for (int i = 2; i < data.length; i++) { + if (data[i] != base + (float) i * mul) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "not an arithmetic sequence at index " + i); + } + } + return buildResult(ScalarValue.ofF32Value(base), ScalarValue.ofF32Value(mul)); + } + + private static EncodeResult encodeF64(double[] data) { + double base = data.length > 0 ? data[0] : 0.0; + double mul = data.length > 1 ? 
data[1] - base : 0.0; + for (int i = 2; i < data.length; i++) { + if (data[i] != base + (double) i * mul) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "not an arithmetic sequence at index " + i); + } + } + return buildResult(ScalarValue.ofF64Value(base), ScalarValue.ofF64Value(mul)); + } + + private static EncodeResult encodeF16(short[] data) { + short baseShort = data.length > 0 ? data[0] : 0; + float baseF = Float.float16ToFloat(baseShort); + float mulF = data.length > 1 ? Float.float16ToFloat(data[1]) - baseF : 0f; + short mulShort = Float.floatToFloat16(mulF); + for (int i = 2; i < data.length; i++) { + short expected = Float.floatToFloat16(baseF + (float) i * mulF); + if (data[i] != expected) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "not an arithmetic sequence at index " + i); + } + } + return buildResult( + ScalarValue.ofF16Value(Short.toUnsignedLong(baseShort)), + ScalarValue.ofF16Value(Short.toUnsignedLong(mulShort))); + } + + private static EncodeResult buildResult(ScalarValue base, ScalarValue mul) { + SequenceMetadata meta = new SequenceMetadata(base, mul); + ByteBuffer metaBuf = ByteBuffer.wrap(meta.encode()); + EncodeNode node = new EncodeNode(EncodingId.VORTEX_SEQUENCE, metaBuf, new EncodeNode[0], new int[]{}); + return new EncodeResult(node, List.of(), null, null); + } + + private static ScalarValue buildIntScalar(PType pt, long value) { + return switch (pt) { + case U8, U16, U32, U64 -> ScalarValue.ofUint64Value(value); + default -> ScalarValue.ofInt64Value(value); + }; + } + + private static int intArrayLength(PType pt, Object data) { + return switch (pt) { + case I8, U8 -> ((byte[]) data).length; + case I16, U16 -> ((short[]) data).length; + case I32, U32 -> ((int[]) data).length; + case I64, U64 -> ((long[]) data).length; + default -> throw new VortexException(EncodingId.VORTEX_SEQUENCE, "unsupported ptype: " + pt); + }; + } + + private static long readLong(PType pt, Object data, int i) { + return switch (pt) { + 
case I8, U8 -> ((byte[]) data)[i]; + case I16, U16 -> ((short[]) data)[i]; + case I32, U32 -> ((int[]) data)[i]; + case I64, U64 -> ((long[]) data)[i]; + default -> throw new VortexException(EncodingId.VORTEX_SEQUENCE, "unsupported ptype: " + pt); + }; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoder.java new file mode 100644 index 00000000..2c5e142e --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoder.java @@ -0,0 +1,149 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.SparseMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.sparse}. +public final class SparseEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public SparseEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_SPARSE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_SPARSE, + "encode only supports Primitive dtype, got " + dtype); + } + PType ptype = p.ptype(); + int n = arrayLength(data, ptype); + + List patchIdx = new ArrayList<>(); + List patchBits = new ArrayList<>(); + for (int i = 0; i < n; i++) { + long bits = readBits(data, ptype, i); + if (bits != 0L) { + patchIdx.add(i); + patchBits.add(bits); + } + } + + int numPatches = patchIdx.size(); + PType idxPtype = chooseIdxPtype(n); + + ScalarValue fillScalar = zeroScalar(ptype); + byte[] fillBytes = fillScalar.encode(); + MemorySegment fillBuf = ctx.arena().allocate(fillBytes.length); + MemorySegment.copy(MemorySegment.ofArray(fillBytes), 0, fillBuf, 0, fillBytes.length); + + MemorySegment idxBuf = buildIdxBuf(patchIdx, idxPtype, numPatches, ctx); + MemorySegment valBuf = buildValBuf(patchBits, ptype, numPatches, ctx); + + PatchesMetadata patchesMeta = new PatchesMetadata( + numPatches, + 0L, + io.github.dfa1.vortex.proto.PType.fromValue(idxPtype.ordinal()), + null, + null, + null + ); + byte[] metaBytes = new SparseMetadata(patchesMeta).encode(); + + EncodeNode idxNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode valNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 2); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_SPARSE, ByteBuffer.wrap(metaBytes), + new EncodeNode[]{idxNode, valNode}, new int[]{0}); + return new EncodeResult(root, List.of(fillBuf, idxBuf, valBuf), null, null); + } + + private static int arrayLength(Object data, PType ptype) { + return switch (ptype) { + case I8, U8 -> ((byte[]) data).length; + case I16, U16 -> ((short[]) 
data).length; + case I32, U32 -> ((int[]) data).length; + case I64, U64 -> ((long[]) data).length; + case F32 -> ((float[]) data).length; + case F64 -> ((double[]) data).length; + default -> throw new VortexException(EncodingId.VORTEX_SPARSE, "unsupported ptype: " + ptype); + }; + } + + private static long readBits(Object data, PType ptype, int i) { + return switch (ptype) { + case I8 -> ((byte[]) data)[i]; + case U8 -> Byte.toUnsignedLong(((byte[]) data)[i]); + case I16 -> ((short[]) data)[i]; + case U16 -> Short.toUnsignedLong(((short[]) data)[i]); + case I32 -> ((int[]) data)[i]; + case U32 -> Integer.toUnsignedLong(((int[]) data)[i]); + case I64, U64 -> ((long[]) data)[i]; + case F32 -> Float.floatToRawIntBits(((float[]) data)[i]); + case F64 -> Double.doubleToRawLongBits(((double[]) data)[i]); + default -> throw new VortexException(EncodingId.VORTEX_SPARSE, "unsupported ptype: " + ptype); + }; + } + + private static PType chooseIdxPtype(int n) { + if (n <= 0xFF) { + return PType.U8; + } else if (n <= 0xFFFF) { + return PType.U16; + } else { + return PType.U32; + } + } + + private static ScalarValue zeroScalar(PType ptype) { + return switch (ptype) { + case I8, I16, I32, I64 -> ScalarValue.ofInt64Value(0L); + case U8, U16, U32, U64 -> ScalarValue.ofUint64Value(0L); + case F32 -> ScalarValue.ofF32Value(0.0f); + case F64 -> ScalarValue.ofF64Value(0.0); + default -> throw new VortexException(EncodingId.VORTEX_SPARSE, "unsupported ptype: " + ptype); + }; + } + + private static MemorySegment buildIdxBuf(List patchIdx, PType idxPtype, int numPatches, EncodeContext ctx) { + int elemBytes = idxPtype.byteSize(); + MemorySegment seg = ctx.arena().allocate(Math.max(1L, (long) numPatches * elemBytes), elemBytes); + for (int i = 0; i < numPatches; i++) { + PTypeIO.set(seg, (long) i * elemBytes, idxPtype, patchIdx.get(i)); + } + return seg; + } + + private static MemorySegment buildValBuf(List patchBits, PType ptype, int numPatches, EncodeContext ctx) { + int elemBytes = 
ptype.byteSize(); + MemorySegment seg = ctx.arena().allocate(Math.max(1L, (long) numPatches * elemBytes), elemBytes); + for (int i = 0; i < numPatches; i++) { + PTypeIO.set(seg, (long) i * elemBytes, ptype, patchBits.get(i)); + } + return seg; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoder.java new file mode 100644 index 00000000..5b1661f5 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoder.java @@ -0,0 +1,74 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; + + +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; + +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + + +import io.github.dfa1.vortex.encoding.StructData; + + +import java.lang.foreign.MemorySegment; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.struct}. +public final class StructEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public StructEncodingEncoder() { + } + + private static final List FALLBACK = List.of( + new PrimitiveEncodingEncoder(), new VarBinEncodingEncoder(), new BoolEncodingEncoder(), + new NullEncodingEncoder(), new ByteBoolEncodingEncoder()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_STRUCT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Struct; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.Struct sdtype = (DType.Struct) dtype; + StructData sd = (StructData) data; + List fields = sd.fieldArrays(); + List fieldTypes = sdtype.fieldTypes(); + if (fields.size() != fieldTypes.size()) { + throw new VortexException(EncodingId.VORTEX_STRUCT, + "fieldArrays length %d != fieldTypes length %d".formatted(fields.size(), fieldTypes.size())); + } + List allBuffers = new ArrayList<>(); + EncodeNode[] children = new EncodeNode[fields.size()]; + for (int i = 0; i < fields.size(); i++) { + DType fieldDtype = fieldTypes.get(i); + EncodeResult fieldResult = findEncoding(fieldDtype).encode(fieldDtype, fields.get(i), ctx); + int bufOffset = allBuffers.size(); + children[i] = EncodeNode.remapBufferIndices(fieldResult.rootNode(), bufOffset); + allBuffers.addAll(fieldResult.buffers()); + } + EncodeNode root = new EncodeNode(EncodingId.VORTEX_STRUCT, null, children, new int[0]); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private static EncodingEncoder findEncoding(DType dtype) { + for (EncodingEncoder enc : FALLBACK) { + if (enc.accepts(dtype)) { + return enc; + } + } + throw new UnsupportedOperationException("no fallback encoding for dtype: " + dtype); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoder.java new file mode 100644 index 00000000..b1731029 --- /dev/null +++ 
b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoder.java @@ -0,0 +1,84 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.VarBinMetadata; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.List; + +/// Write-only encoder for {@code vortex.varbin}. +public final class VarBinEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public VarBinEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_VARBIN; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + String[] strings = (String[]) data; + int n = strings.length; + + byte[][] byteArrays = new byte[n][]; + int totalBytes = 0; + for (int i = 0; i < n; i++) { + byteArrays[i] = strings[i].getBytes(StandardCharsets.UTF_8); + totalBytes += byteArrays[i].length; + } + + Arena arena = ctx.arena(); + MemorySegment bytesBuf = arena.allocate(totalBytes > 0 ? 
totalBytes : 1); + MemorySegment offsetsBuf = arena.allocate((long) (n + 1) * Long.BYTES, Long.BYTES); + + long pos = 0; + offsetsBuf.setAtIndex(PTypeIO.LE_LONG, 0, 0L); + for (int i = 0; i < n; i++) { + MemorySegment.copy(MemorySegment.ofArray(byteArrays[i]), 0, bytesBuf, pos, byteArrays[i].length); + pos += byteArrays[i].length; + offsetsBuf.setAtIndex(PTypeIO.LE_LONG, i + 1, pos); + } + + byte[] metaBytes = new VarBinMetadata(io.github.dfa1.vortex.proto.PType.fromValue(PType.I64.ordinal())).encode(); + + String minStr = null; + String maxStr = null; + for (String s : strings) { + if (s == null) { + continue; + } + if (minStr == null || s.compareTo(minStr) < 0) { + minStr = s; + } + if (maxStr == null || s.compareTo(maxStr) > 0) { + maxStr = s; + } + } + byte[] statsMin = minStr != null ? ScalarValue.ofStringValue(minStr).encode() : null; + byte[] statsMax = maxStr != null ? ScalarValue.ofStringValue(maxStr).encode() : null; + + EncodeNode offsetsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_VARBIN, ByteBuffer.wrap(metaBytes), + new EncodeNode[]{offsetsNode}, new int[]{0}); + return new EncodeResult(root, List.of(bytesBuf, offsetsBuf), statsMin, statsMax); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoder.java new file mode 100644 index 00000000..99971563 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoder.java @@ -0,0 +1,84 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import 
io.github.dfa1.vortex.encoding.PTypeIO; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.charset.StandardCharsets; +import java.util.List; + +/// Write-only encoder for {@code vortex.varbinview}. +public final class VarBinViewEncodingEncoder implements EncodingEncoder { + + private static final int MAX_INLINED_SIZE = 12; + private static final int VIEW_SIZE = 16; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public VarBinViewEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_VARBINVIEW; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + String[] strings = (String[]) data; + int n = strings.length; + + byte[][] bytes = new byte[n][]; + int totalDataBytes = 0; + for (int i = 0; i < n; i++) { + bytes[i] = strings[i].getBytes(StandardCharsets.UTF_8); + if (bytes[i].length > MAX_INLINED_SIZE) { + totalDataBytes += bytes[i].length; + } + } + + Arena arena = ctx.arena(); + boolean hasDataBuf = totalDataBytes > 0; + MemorySegment dataBuf = arena.allocate(hasDataBuf ? totalDataBytes : 1); + MemorySegment viewsBuf = arena.allocate(n > 0 ? 
(long) n * VIEW_SIZE : 1); + + int dataOffset = 0; + for (int i = 0; i < n; i++) { + byte[] b = bytes[i]; + long viewOff = (long) i * VIEW_SIZE; + viewsBuf.set(PTypeIO.LE_INT, viewOff, b.length); + if (b.length <= MAX_INLINED_SIZE) { + MemorySegment.copy(MemorySegment.ofArray(b), 0, viewsBuf, viewOff + 4, b.length); + } else { + MemorySegment.copy(MemorySegment.ofArray(b), 0, viewsBuf, viewOff + 4, 4); + viewsBuf.set(PTypeIO.LE_INT, viewOff + 8, 0); + viewsBuf.set(PTypeIO.LE_INT, viewOff + 12, dataOffset); + MemorySegment.copy(MemorySegment.ofArray(b), 0, dataBuf, dataOffset, b.length); + dataOffset += b.length; + } + } + + int[] bufIndices; + List buffers; + if (hasDataBuf) { + bufIndices = new int[]{0, 1}; + buffers = List.of(dataBuf, viewsBuf); + } else { + bufIndices = new int[]{0}; + buffers = List.of(viewsBuf); + } + + EncodeNode root = new EncodeNode(EncodingId.VORTEX_VARBINVIEW, null, new EncodeNode[0], bufIndices); + return new EncodeResult(root, buffers, null, null); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VariantEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VariantEncodingEncoder.java new file mode 100644 index 00000000..368d6501 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VariantEncodingEncoder.java @@ -0,0 +1,31 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Write-only encoder for {@code vortex.variant} — currently throws (not implemented). +public final class VariantEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public VariantEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_VARIANT; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + throw new VortexException(EncodingId.VORTEX_VARIANT, "encode not yet implemented"); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoder.java new file mode 100644 index 00000000..a1d69f7b --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoder.java @@ -0,0 +1,84 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.util.List; + +/// Write-only encoder for {@code vortex.zigzag} — signed integers as zigzag-encoded unsigned values. +public final class ZigZagEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ZigZagEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ZIGZAG; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + PType pt = p.ptype(); + return pt == PType.I8 || pt == PType.I16 || pt == PType.I32 || pt == PType.I64; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + PType signed = ((DType.Primitive) dtype).ptype(); + MemorySegment seg = switch (signed) { + case I8 -> { + byte[] arr = (byte[]) data; + MemorySegment s = ctx.arena().allocate(arr.length); + for (int i = 0; i < arr.length; i++) { + byte v = arr[i]; + s.set(ValueLayout.JAVA_BYTE, i, (byte) ((v << 1) ^ (v >> 7))); + } + yield s; + } + case I16 -> { + short[] arr = (short[]) data; + MemorySegment s = ctx.arena().allocate((long) arr.length * 2, 2); + for (int i = 0; i < arr.length; i++) { + short v = arr[i]; + s.setAtIndex(PTypeIO.LE_SHORT, i, (short) ((v << 1) ^ (v >> 15))); + } + yield s; + } + case I32 -> { + int[] arr = (int[]) data; + MemorySegment s = ctx.arena().allocate((long) arr.length * 4, 4); + for (int i = 0; i < arr.length; i++) { + int v = arr[i]; + s.setAtIndex(PTypeIO.LE_INT, i, (v << 1) ^ (v >> 31)); + } + yield s; + } + case I64 -> { + long[] arr = (long[]) data; + MemorySegment s = ctx.arena().allocate((long) arr.length * 8, 8); + for (int i = 0; i < arr.length; i++) { + long v = arr[i]; + s.setAtIndex(PTypeIO.LE_LONG, i, (v << 1) ^ (v >> 63)); + } + yield s; + } + default -> throw new VortexException(EncodingId.VORTEX_ZIGZAG, "unsupported ptype: " + signed); + }; + EncodeNode child = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZIGZAG, null, new EncodeNode[]{child}, new int[0]); + return new EncodeResult(root, List.of(seg), null, null); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java 
b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java new file mode 100644 index 00000000..7b3b75a1 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java @@ -0,0 +1,159 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.airlift.compress.v3.zstd.ZstdCompressor; +import io.airlift.compress.v3.zstd.ZstdJavaCompressor; +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ZstdFrameMetadata; +import io.github.dfa1.vortex.proto.ZstdMetadata; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; + +/// Write-only encoder for {@code vortex.zstd}. +public final class ZstdEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ZstdEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ZSTD; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive || dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (dtype instanceof DType.Primitive dt) { + return encodePrimitive(dt, data); + } + if (dtype instanceof DType.Utf8 || dtype instanceof DType.Binary) { + return encodeVarBin((String[]) data); + } + throw new VortexException(EncodingId.VORTEX_ZSTD, "unsupported dtype: " + dtype); + } + + private static EncodeResult encodePrimitive(DType.Primitive dt, Object data) { + MemorySegment raw = primitiveToLeBytes(dt.ptype(), data, Arena.ofAuto()); + long n = primitiveLength(dt.ptype(), data); + byte[] rawBytes = raw.toArray(ValueLayout.JAVA_BYTE); + return buildResult(rawBytes, n); + } + + private static EncodeResult encodeVarBin(String[] strings) { + byte[] raw = buildLengthPrefixed(strings); + return buildResult(raw, strings.length); + } + + private static EncodeResult buildResult(byte[] raw, long n) { + byte[] compressed = compress(raw); + byte[] meta = new ZstdMetadata( + 0, + java.util.List.of(new ZstdFrameMetadata(raw.length, n)) + ).encode(); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZSTD, ByteBuffer.wrap(meta), + new EncodeNode[0], new int[]{0}); + return new EncodeResult(root, List.of(MemorySegment.ofArray(compressed)), null, null); + } + + private static byte[] compress(byte[] input) { + ZstdCompressor compressor = new ZstdJavaCompressor(); + byte[] out = new byte[compressor.maxCompressedLength(input.length)]; + int len = compressor.compress(input, 0, input.length, out, 0, out.length); + return Arrays.copyOf(out, len); + } + + private static MemorySegment primitiveToLeBytes(PType ptype, Object data, Arena arena) { + return switch (ptype) { + case I8, U8 -> 
MemorySegment.ofArray((byte[]) data); + case I16, U16, F16 -> { + short[] arr = (short[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 2, 2); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_SHORT, i, arr[i]); + } + yield seg; + } + case I32, U32 -> { + int[] arr = (int[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 4, 4); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_INT, i, arr[i]); + } + yield seg; + } + case I64, U64 -> { + long[] arr = (long[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 8, 8); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_LONG, i, arr[i]); + } + yield seg; + } + case F32 -> { + float[] arr = (float[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 4, 4); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_FLOAT, i, arr[i]); + } + yield seg; + } + case F64 -> { + double[] arr = (double[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 8, 8); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_DOUBLE, i, arr[i]); + } + yield seg; + } + }; + } + + private static long primitiveLength(PType ptype, Object data) { + return switch (ptype) { + case I8, U8 -> ((byte[]) data).length; + case I16, U16, F16 -> ((short[]) data).length; + case I32, U32 -> ((int[]) data).length; + case F32 -> ((float[]) data).length; + case I64, U64 -> ((long[]) data).length; + case F64 -> ((double[]) data).length; + }; + } + + private static byte[] buildLengthPrefixed(String[] strings) { + int total = 0; + byte[][] encoded = new byte[strings.length][]; + for (int i = 0; i < strings.length; i++) { + encoded[i] = strings[i].getBytes(StandardCharsets.UTF_8); + total += 4 + encoded[i].length; + } + try (Arena scratch = Arena.ofConfined()) { + MemorySegment seg = scratch.allocate(total > 0 ? 
total : 1); + long pos = 0; + for (byte[] bytes : encoded) { + seg.set(PTypeIO.LE_INT, pos, bytes.length); + pos += 4; + MemorySegment.copy(MemorySegment.ofArray(bytes), 0, seg, pos, bytes.length); + pos += bytes.length; + } + return seg.asSlice(0, total).toArray(ValueLayout.JAVA_BYTE); + } + } +} diff --git a/writer/src/main/resources/META-INF/services/io.github.dfa1.vortex.encoding.EncodingEncoder b/writer/src/main/resources/META-INF/services/io.github.dfa1.vortex.encoding.EncodingEncoder new file mode 100644 index 00000000..19a1dab3 --- /dev/null +++ b/writer/src/main/resources/META-INF/services/io.github.dfa1.vortex.encoding.EncodingEncoder @@ -0,0 +1,33 @@ +io.github.dfa1.vortex.writer.encode.AlpEncodingEncoder +io.github.dfa1.vortex.writer.encode.AlpRdEncodingEncoder +io.github.dfa1.vortex.writer.encode.BitpackedEncodingEncoder +io.github.dfa1.vortex.writer.encode.BoolEncodingEncoder +io.github.dfa1.vortex.writer.encode.ByteBoolEncodingEncoder +io.github.dfa1.vortex.writer.encode.ChunkedEncodingEncoder +io.github.dfa1.vortex.writer.encode.ConstantEncodingEncoder +io.github.dfa1.vortex.writer.encode.DateTimePartsEncodingEncoder +io.github.dfa1.vortex.writer.encode.DecimalBytePartsEncodingEncoder +io.github.dfa1.vortex.writer.encode.DecimalEncodingEncoder +io.github.dfa1.vortex.writer.encode.DeltaEncodingEncoder +io.github.dfa1.vortex.writer.encode.DictEncodingEncoder +io.github.dfa1.vortex.writer.encode.ExtEncodingEncoder +io.github.dfa1.vortex.writer.encode.FixedSizeListEncodingEncoder +io.github.dfa1.vortex.writer.encode.FrameOfReferenceEncodingEncoder +io.github.dfa1.vortex.writer.encode.FsstEncodingEncoder +io.github.dfa1.vortex.writer.encode.ListEncodingEncoder +io.github.dfa1.vortex.writer.encode.ListViewEncodingEncoder +io.github.dfa1.vortex.writer.encode.MaskedEncodingEncoder +io.github.dfa1.vortex.writer.encode.NullEncodingEncoder +io.github.dfa1.vortex.writer.encode.PatchedEncodingEncoder +io.github.dfa1.vortex.writer.encode.PcoEncodingEncoder 
+io.github.dfa1.vortex.writer.encode.PrimitiveEncodingEncoder +io.github.dfa1.vortex.writer.encode.RleEncodingEncoder +io.github.dfa1.vortex.writer.encode.RunEndEncodingEncoder +io.github.dfa1.vortex.writer.encode.SequenceEncodingEncoder +io.github.dfa1.vortex.writer.encode.SparseEncodingEncoder +io.github.dfa1.vortex.writer.encode.StructEncodingEncoder +io.github.dfa1.vortex.writer.encode.VarBinEncodingEncoder +io.github.dfa1.vortex.writer.encode.VariantEncodingEncoder +io.github.dfa1.vortex.writer.encode.VarBinViewEncodingEncoder +io.github.dfa1.vortex.writer.encode.ZigZagEncodingEncoder +io.github.dfa1.vortex.writer.encode.ZstdEncodingEncoder diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/BitpackedEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/BitpackedEncodingTest.java index 2c7b873c..cd1c0c9e 100644 --- a/writer/src/test/java/io/github/dfa1/vortex/writer/BitpackedEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/BitpackedEncodingTest.java @@ -2,8 +2,9 @@ import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; -import io.github.dfa1.vortex.encoding.BitpackedEncoding; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.writer.encode.BitpackedEncodingEncoder; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder; import io.github.dfa1.vortex.reader.VortexReader; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -27,8 +28,8 @@ class BitpackedEncodingTest { List.of(new DType.Primitive(PType.I32, false)), false); - private static Registry bitpackedRegistry() { - return Registry.builder().register(new BitpackedEncoding()).build(); + private static ReadRegistry bitpackedRegistry() { + return ReadRegistry.builder().register(new BitpackedEncodingDecoder()).build(); } @Test @@ -39,7 +40,7 @@ void roundTrip_positiveIntegers(@TempDir Path tmp) throws IOException { 
try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I32_SCHEMA, WriteOptions.defaults(), - List.of(new BitpackedEncoding()))) { + List.of(new BitpackedEncodingEncoder()))) { // When sut.writeChunk(Map.of("value", data)); } @@ -58,7 +59,7 @@ void roundTrip_allSameValue(@TempDir Path tmp) throws IOException { try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I32_SCHEMA, WriteOptions.defaults(), - List.of(new BitpackedEncoding()))) { + List.of(new BitpackedEncodingEncoder()))) { // When sut.writeChunk(Map.of("value", data)); } @@ -77,7 +78,7 @@ void roundTrip_negativeIntegers(@TempDir Path tmp) throws IOException { try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I32_SCHEMA, WriteOptions.defaults(), - List.of(new BitpackedEncoding()))) { + List.of(new BitpackedEncodingEncoder()))) { // When sut.writeChunk(Map.of("value", data)); } @@ -99,7 +100,7 @@ void roundTrip_multipleChunks(@TempDir Path tmp) throws IOException { try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I32_SCHEMA, WriteOptions.defaults(), - List.of(new BitpackedEncoding()))) { + List.of(new BitpackedEncodingEncoder()))) { // When sut.writeChunk(Map.of("value", chunk1)); sut.writeChunk(Map.of("value", chunk2)); @@ -123,7 +124,7 @@ void roundTrip_bitWidths(int maxVal, @TempDir Path tmp) throws IOException { try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I32_SCHEMA, WriteOptions.defaults(), - List.of(new BitpackedEncoding()))) { + List.of(new BitpackedEncodingEncoder()))) { // When sut.writeChunk(Map.of("value", data)); } diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/DeltaEncodingTest.java 
b/writer/src/test/java/io/github/dfa1/vortex/writer/DeltaEncodingTest.java index a8bf37e9..5af74491 100644 --- a/writer/src/test/java/io/github/dfa1/vortex/writer/DeltaEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/DeltaEncodingTest.java @@ -2,8 +2,8 @@ import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; -import io.github.dfa1.vortex.encoding.DeltaEncoding; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.writer.encode.DeltaEncodingEncoder; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.reader.VortexReader; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -27,10 +27,10 @@ class DeltaEncodingTest { List.of(new DType.Primitive(PType.I64, false)), false); - private static Registry deltaRegistry() { - return Registry.builder() - .register(new DeltaEncoding()) - .register(new io.github.dfa1.vortex.encoding.PrimitiveEncoding()) + private static ReadRegistry deltaRegistry() { + return ReadRegistry.builder() + .register(new io.github.dfa1.vortex.reader.decode.DeltaEncodingDecoder()) + .register(new io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder()) .build(); } @@ -42,7 +42,7 @@ void roundTrip_monotonicIncreasing(@TempDir Path tmp) throws IOException { try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I64_SCHEMA, WriteOptions.defaults(), - List.of(new DeltaEncoding()))) { + List.of(new DeltaEncodingEncoder()))) { // When sut.writeChunk(Map.of("ts", data)); } @@ -61,7 +61,7 @@ void roundTrip_monotonicDecreasing(@TempDir Path tmp) throws IOException { try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I64_SCHEMA, WriteOptions.defaults(), - List.of(new DeltaEncoding()))) { + List.of(new DeltaEncodingEncoder()))) { // When sut.writeChunk(Map.of("ts", data)); } @@ -80,7 
+80,7 @@ void roundTrip_allSameValue(@TempDir Path tmp) throws IOException { try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I64_SCHEMA, WriteOptions.defaults(), - List.of(new DeltaEncoding()))) { + List.of(new DeltaEncodingEncoder()))) { // When sut.writeChunk(Map.of("ts", data)); } @@ -99,7 +99,7 @@ void roundTrip_singleElement(@TempDir Path tmp) throws IOException { try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I64_SCHEMA, WriteOptions.defaults(), - List.of(new DeltaEncoding()))) { + List.of(new DeltaEncodingEncoder()))) { // When sut.writeChunk(Map.of("ts", data)); } @@ -118,7 +118,7 @@ void roundTrip_mixedDeltas(@TempDir Path tmp) throws IOException { try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I64_SCHEMA, WriteOptions.defaults(), - List.of(new DeltaEncoding()))) { + List.of(new DeltaEncodingEncoder()))) { // When sut.writeChunk(Map.of("ts", data)); } @@ -140,7 +140,7 @@ void roundTrip_multipleChunks(@TempDir Path tmp) throws IOException { try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I64_SCHEMA, WriteOptions.defaults(), - List.of(new DeltaEncoding()))) { + List.of(new DeltaEncodingEncoder()))) { // When sut.writeChunk(Map.of("ts", chunk1)); sut.writeChunk(Map.of("ts", chunk2)); @@ -165,7 +165,7 @@ void roundTrip_sequentialTimestamps(int n, @TempDir Path tmp) throws IOException try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, I64_SCHEMA, WriteOptions.defaults(), - List.of(new DeltaEncoding()))) { + List.of(new DeltaEncodingEncoder()))) { // When sut.writeChunk(Map.of("ts", data)); } diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/DictEncodingTest.java 
b/writer/src/test/java/io/github/dfa1/vortex/writer/DictEncodingTest.java index cdee97a4..c48a44c6 100644 --- a/writer/src/test/java/io/github/dfa1/vortex/writer/DictEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/DictEncodingTest.java @@ -4,9 +4,8 @@ import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; -import io.github.dfa1.vortex.encoding.DictEncoding; -import io.github.dfa1.vortex.encoding.Registry; -import io.github.dfa1.vortex.encoding.PrimitiveEncoding; +import io.github.dfa1.vortex.writer.encode.DictEncodingEncoder; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.reader.VortexReader; import io.github.dfa1.vortex.reader.Chunk; import io.github.dfa1.vortex.reader.ScanOptions; @@ -33,10 +32,10 @@ class DictEncodingTest { List.of(new DType.Primitive(PType.I32, false)), false); - private static Registry dictRegistry() { - return Registry.builder() - .register(new DictEncoding()) - .register(new PrimitiveEncoding()) + private static ReadRegistry dictRegistry() { + return ReadRegistry.builder() + .register(new io.github.dfa1.vortex.reader.decode.DictEncodingDecoder()) + .register(new io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder()) .build(); } @@ -48,7 +47,7 @@ void roundTrip_lowCardinality_valuesExpandCorrectly(@TempDir Path tmp) throws IO try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, SCHEMA, WriteOptions.defaults(), - List.of(new DictEncoding()))) { + List.of(new DictEncodingEncoder()))) { // When sut.writeChunk(Map.of("category", data)); } @@ -69,7 +68,7 @@ void roundTrip_singleUniqueValue_u8Codes(@TempDir Path tmp) throws IOException { try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, SCHEMA, WriteOptions.defaults(), - List.of(new DictEncoding()))) 
{ + List.of(new DictEncodingEncoder()))) { // When sut.writeChunk(Map.of("category", data)); } @@ -89,7 +88,7 @@ void roundTrip_multipleChunks(@TempDir Path tmp) throws IOException { try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); var sut = VortexWriter.create(ch, SCHEMA, WriteOptions.defaults(), - List.of(new DictEncoding()))) { + List.of(new DictEncodingEncoder()))) { // When sut.writeChunk(Map.of("category", chunk1)); sut.writeChunk(Map.of("category", chunk2)); diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/GlobalDictUtf8Test.java b/writer/src/test/java/io/github/dfa1/vortex/writer/GlobalDictUtf8Test.java index d1caaa5f..955cb732 100644 --- a/writer/src/test/java/io/github/dfa1/vortex/writer/GlobalDictUtf8Test.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/GlobalDictUtf8Test.java @@ -1,7 +1,7 @@ package io.github.dfa1.vortex.writer; import io.github.dfa1.vortex.core.DType; -import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.reader.VortexReader; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -53,7 +53,7 @@ void lowCardinality_utf8_acrossChunks_usesGlobalDict(@TempDir Path tmp) throws I assertThat(size).as("global dict for 5 chunks of 1000 strings").isLessThan(8_000L); // And values round-trip exactly across all chunks. - try (var vf = VortexReader.open(file, Registry.loadAll())) { + try (var vf = VortexReader.open(file, ReadRegistry.loadAll())) { List got = readAllStrings(vf, "status"); assertThat(got).hasSize(rowsPerChunk * chunkCount); // Spot-check: first row of chunk c starts with dict[c % 3], etc. @@ -82,7 +82,7 @@ void highCardinality_utf8_fallsBackToPerChunk(@TempDir Path tmp) throws IOExcept } // Then — file is readable, all rows round-trip (correctness, not size). 
- try (var vf = VortexReader.open(file, Registry.loadAll())) { + try (var vf = VortexReader.open(file, ReadRegistry.loadAll())) { List got = readAllStrings(vf, "status"); assertThat(got).hasSize(rows); for (int i = 0; i < rows; i++) { @@ -108,7 +108,7 @@ void utf8_globalDict_disabled_byOptions(@TempDir Path tmp) throws IOException { } // Then - try (var vf = VortexReader.open(file, Registry.loadAll())) { + try (var vf = VortexReader.open(file, ReadRegistry.loadAll())) { List got = readAllStrings(vf, "status"); assertThat(got).containsExactly(data); } diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/VortexWriterTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/VortexWriterTest.java index 312d621e..abfa305c 100644 --- a/writer/src/test/java/io/github/dfa1/vortex/writer/VortexWriterTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/VortexWriterTest.java @@ -5,9 +5,7 @@ import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; import io.github.dfa1.vortex.core.array.LongArray; -import io.github.dfa1.vortex.encoding.AlpEncoding; -import io.github.dfa1.vortex.encoding.Registry; -import io.github.dfa1.vortex.encoding.PrimitiveEncoding; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.reader.VortexReader; import io.github.dfa1.vortex.reader.Chunk; import io.github.dfa1.vortex.reader.ScanOptions; @@ -50,10 +48,10 @@ private static List snapshotAll(VortexReader vf, ScanOptions opts return snapshots; } - private static Registry primitiveRegistry() { - return Registry.builder() - .register(new AlpEncoding()) - .register(new PrimitiveEncoding()) + private static ReadRegistry primitiveRegistry() { + return ReadRegistry.builder() + .register(new io.github.dfa1.vortex.reader.decode.AlpEncodingDecoder()) + .register(new io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder()) .build(); } @@ -81,8 +79,8 @@ void 
writeChunk_autoroutesExtensionCollectionViaSpecExtension(@TempDir Path tmp) } // Then — read back through DateExtension.decodeAll and assert end-to-end equality. - // Registry.loadAll() picks up PrimitiveEncoding (storage) plus DateExtension. - try (var vf = VortexReader.open(file, Registry.loadAll()); + // ReadRegistry.loadAll() picks up PrimitiveEncoding (storage) plus DateExtension. + try (var vf = VortexReader.open(file, ReadRegistry.loadAll()); var iter = vf.scan(ScanOptions.all())) { assertThat(iter.hasNext()).isTrue(); try (Chunk chunk = iter.next()) { @@ -113,7 +111,7 @@ void writeChunk_roundTripsTimeExtension(@TempDir Path tmp) throws IOException { } // Then - try (var vf = VortexReader.open(file, Registry.loadAll()); + try (var vf = VortexReader.open(file, ReadRegistry.loadAll()); var iter = vf.scan(ScanOptions.all())) { assertThat(iter.hasNext()).isTrue(); try (Chunk chunk = iter.next()) { @@ -142,7 +140,7 @@ void writeChunk_roundTripsTimestampExtension(@TempDir Path tmp) throws IOExcepti } // Then - try (var vf = VortexReader.open(file, Registry.loadAll()); + try (var vf = VortexReader.open(file, ReadRegistry.loadAll()); var iter = vf.scan(ScanOptions.all())) { assertThat(iter.hasNext()).isTrue(); try (Chunk chunk = iter.next()) { @@ -166,7 +164,7 @@ void chunkAs_mismatchedDomainType_throws(@TempDir Path tmp) throws IOException { List.of(java.time.LocalDate.of(2026, 6, 10)))); } - try (var vf = VortexReader.open(file, Registry.loadAll()); + try (var vf = VortexReader.open(file, ReadRegistry.loadAll()); var iter = vf.scan(ScanOptions.all())) { try (Chunk chunk = iter.next()) { // When / Then — the accessor must fail-fast, not return a wrongly-cast list @@ -195,7 +193,7 @@ void writeChunk_roundTripsUuidExtension(@TempDir Path tmp) throws IOException { } // Then - try (var vf = VortexReader.open(file, Registry.loadAll()); + try (var vf = VortexReader.open(file, ReadRegistry.loadAll()); var iter = vf.scan(ScanOptions.all())) { 
assertThat(iter.hasNext()).isTrue(); try (Chunk chunk = iter.next()) { @@ -233,7 +231,7 @@ void writeChunk_cascadeCompressesTimestampExtensionStorage(@TempDir Path tmp) th .isLessThan(flatSize); // And — cascaded file still round-trips back to the same Instants - try (var vf = VortexReader.open(cascadedFile, Registry.loadAll()); + try (var vf = VortexReader.open(cascadedFile, ReadRegistry.loadAll()); var iter = vf.scan(ScanOptions.all())) { try (Chunk chunk = iter.next()) { assertThat(chunk.as("events", java.time.Instant.class)) diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/ZoneMapPruningTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/ZoneMapPruningTest.java index d12f025e..2131458a 100644 --- a/writer/src/test/java/io/github/dfa1/vortex/writer/ZoneMapPruningTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/ZoneMapPruningTest.java @@ -2,8 +2,8 @@ import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; -import io.github.dfa1.vortex.encoding.Registry; -import io.github.dfa1.vortex.encoding.PrimitiveEncoding; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import io.github.dfa1.vortex.reader.VortexReader; import io.github.dfa1.vortex.reader.RowFilter; import io.github.dfa1.vortex.reader.ScanOptions; @@ -60,8 +60,8 @@ private static long[] range(long from, long to) { return arr; } - private static Registry primitiveRegistry() { - return Registry.builder().register(new PrimitiveEncoding()).build(); + private static ReadRegistry primitiveRegistry() { + return ReadRegistry.builder().register(new PrimitiveEncodingDecoder()).build(); } @Test diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/AlpEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpEncodingEncoderTest.java similarity index 50% rename from core/src/test/java/io/github/dfa1/vortex/encoding/AlpEncodingTest.java rename to 
writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpEncodingEncoderTest.java index 5878e433..b11e7840 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/AlpEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpEncodingEncoderTest.java @@ -1,10 +1,22 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; import io.github.dfa1.vortex.proto.ALPMetadata; import io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.reader.decode.AlpEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -17,24 +29,22 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.within; -class AlpEncodingTest { +class AlpEncodingEncoderTest { + private static final AlpEncodingEncoder ENCODER = new AlpEncodingEncoder(); + private static final AlpEncodingDecoder DECODER = new AlpEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Decode { private static DecodeContext buildAlpCtxF64( - int expE, int expF, - long[] 
encodedVals, - long[] patchIndices, - double[] patchValues + int expE, int expF, long[] encodedVals, + long[] patchIndices, double[] patchValues ) { PatchesMetadata pm = patchIndices != null - ? new PatchesMetadata( - (long) patchIndices.length, - 0L, - io.github.dfa1.vortex.proto.PType.U32, - null, null, null) + ? new PatchesMetadata((long) patchIndices.length, 0L, + io.github.dfa1.vortex.proto.PType.U32, null, null, null) : null; byte[] metaBytes = new ALPMetadata(expE, expF, pm).encode(); @@ -56,24 +66,16 @@ private static DecodeContext buildAlpCtxF64( for (long v : patchIndices) { ib.putInt((int) v); } - byte[] valBuf = new byte[patchValues.length * 8]; ByteBuffer vb = ByteBuffer.wrap(valBuf).order(ByteOrder.LITTLE_ENDIAN); for (double v : patchValues) { vb.putDouble(v); } - - ArrayNode idxNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, - new ArrayNode[0], new int[]{1}, ArrayStats.empty()); - ArrayNode valNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, - new ArrayNode[0], new int[]{2}, ArrayStats.empty()); - + ArrayNode idxNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{1}, ArrayStats.empty()); + ArrayNode valNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{2}, ArrayStats.empty()); children = new ArrayNode[]{encNode, idxNode, valNode}; segments = new MemorySegment[]{ - MemorySegment.ofArray(encBuf), - MemorySegment.ofArray(idxBuf), - MemorySegment.ofArray(valBuf) - }; + MemorySegment.ofArray(encBuf), MemorySegment.ofArray(idxBuf), MemorySegment.ofArray(valBuf)}; } else { children = new ArrayNode[]{encNode}; segments = new MemorySegment[]{MemorySegment.ofArray(encBuf)}; @@ -82,91 +84,59 @@ private static DecodeContext buildAlpCtxF64( ArrayNode alpNode = ArrayNode.of(EncodingId.VORTEX_ALP, ByteBuffer.wrap(metaBytes), children, new int[0], ArrayStats.empty()); - Registry registry = TestRegistry.of(new AlpEncoding(), new PrimitiveEncoding()); - - return new DecodeContext(alpNode, 
DTypes.F64, encodedVals.length, segments, registry, java.lang.foreign.Arena.global()); + return new DecodeContext(alpNode, DTypes.F64, encodedVals.length, segments, REGISTRY, java.lang.foreign.Arena.global()); } - private static DecodeContext buildAlpCtxF32( - int expE, int expF, - int[] encodedVals - ) { + private static DecodeContext buildAlpCtxF32(int expE, int expF, int[] encodedVals) { byte[] metaBytes = new ALPMetadata(expE, expF, null).encode(); - byte[] encBuf = new byte[encodedVals.length * 4]; ByteBuffer bb = ByteBuffer.wrap(encBuf).order(ByteOrder.LITTLE_ENDIAN); for (int v : encodedVals) { bb.putInt(v); } - - ArrayNode encNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, - new ArrayNode[0], new int[]{0}, ArrayStats.empty()); - - ArrayNode alpNode = ArrayNode.of(EncodingId.VORTEX_ALP, - ByteBuffer.wrap(metaBytes), new ArrayNode[]{encNode}, new int[0], ArrayStats.empty()); - + ArrayNode encNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0}, ArrayStats.empty()); + ArrayNode alpNode = ArrayNode.of(EncodingId.VORTEX_ALP, ByteBuffer.wrap(metaBytes), + new ArrayNode[]{encNode}, new int[0], ArrayStats.empty()); MemorySegment[] segments = {MemorySegment.ofArray(encBuf)}; - - Registry registry = TestRegistry.of(new AlpEncoding(), new PrimitiveEncoding()); - - return new DecodeContext(alpNode, DTypes.F32, encodedVals.length, segments, registry, java.lang.foreign.Arena.global()); + return new DecodeContext(alpNode, DTypes.F32, encodedVals.length, segments, REGISTRY, java.lang.foreign.Arena.global()); } @Test void decode_f64_noPatches() { - // Given — encode 1.23 with exp_e=2, exp_f=1: encoded = round(1.23 * 100 / 10) = 12 - // decode: 12 * 10^1 * 10^-2 = 12 * 10 * 0.01 = 1.2 - // Use exp_e=0, exp_f=2: encoded = round(1.23 * 1 / 100) ... 
let's use known values - // encode: value * F10[e] * IF10[f] then round; decode: encoded * F10[f] * IF10[e] - // Use e=2, f=0: encoded = round(1.23 * 100 * 1.0) = 123; decode = 123 * 1.0 * 0.01 = 1.23 int expE = 2, expF = 0; long[] encoded = {123L, 456L, 789L}; double[] expected = {1.23, 4.56, 7.89}; DecodeContext ctx = buildAlpCtxF64(expE, expF, encoded, null, null); - AlpEncoding sut = new AlpEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(encoded.length); - var layout = PTypeIO.LE_DOUBLE; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, (long) i * 8)) .as("index %d", i).isCloseTo(expected[i], within(1e-9)); } } @Test void decode_f64_withPatches() { - // Given — 5 values, encoded = [100, 0, 200, 0, 300] with e=2,f=0 → [1.0, 0.0, 2.0, 0.0, 3.0] - // patches at [1, 3] with real values [Double.NaN, Double.POSITIVE_INFINITY] int expE = 2, expF = 0; long[] encoded = {100L, 0L, 200L, 0L, 300L}; long[] patchIndices = {1L, 3L}; double[] patchValues = {Double.NaN, Double.POSITIVE_INFINITY}; DecodeContext ctx = buildAlpCtxF64(expE, expF, encoded, patchIndices, patchValues); - AlpEncoding sut = new AlpEncoding(); - - // When - Array result = sut.decode(ctx); - - // Then - var layout = PTypeIO.LE_DOUBLE; - assertThat(ArraySegments.of(result).get(layout, 0L)).isCloseTo(1.0, within(1e-9)); - assertThat(ArraySegments.of(result).get(layout, 8L)).isNaN(); - assertThat(ArraySegments.of(result).get(layout, 16L)).isCloseTo(2.0, within(1e-9)); - assertThat(ArraySegments.of(result).get(layout, 24L)).isInfinite(); - assertThat(ArraySegments.of(result).get(layout, 32L)).isCloseTo(3.0, within(1e-9)); + Array result = DECODER.decode(ctx); + + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 0L)).isCloseTo(1.0, within(1e-9)); + 
assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 8L)).isNaN(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 16L)).isCloseTo(2.0, within(1e-9)); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 24L)).isInfinite(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 32L)).isCloseTo(3.0, within(1e-9)); } @ParameterizedTest @CsvSource({"0,0", "1,0", "2,1", "3,2", "4,3"}) void decode_f64_exponentCombinations(int expE, int expF) { - // Given — encode 42 with given exponents, then verify round-trip - // encode: encoded = round(42.0 * F10[e] * IF10[f]) double value = 42.0; double[] f10 = {1e0, 1e1, 1e2, 1e3, 1e4}; double[] if10 = {1e-0, 1e-1, 1e-2, 1e-3, 1e-4}; @@ -174,34 +144,23 @@ void decode_f64_exponentCombinations(int expE, int expF) { long[] encoded = {encVal}; DecodeContext ctx = buildAlpCtxF64(expE, expF, encoded, null, null); - AlpEncoding sut = new AlpEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then - var layout = PTypeIO.LE_DOUBLE; - double decoded = ArraySegments.of(result).get(layout, 0L); + double decoded = ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 0L); assertThat(decoded).isCloseTo(value, within(1e-6)); } @Test void decode_f32_noPatches() { - // Given — e=1, f=0: decode = encoded * 1.0 * 0.1 int expE = 1, expF = 0; int[] encoded = {10, 25, 100}; float[] expected = {1.0f, 2.5f, 10.0f}; DecodeContext ctx = buildAlpCtxF32(expE, expF, encoded); - AlpEncoding sut = new AlpEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then - var layout = PTypeIO.LE_FLOAT; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 4)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, (long) i * 4)) .as("index %d", i).isCloseTo(expected[i], within(1e-6f)); } } @@ -212,80 +171,55 @@ class Encode { @Test void encode_f32_roundTrip_noPatches() { - // Given float[] 
values = {1.0f, 2.5f, 3.75f, 10.0f, 0.1f}; - AlpEncoding sut = new AlpEncoding(); - - Registry registry = TestRegistry.withPrimitive(sut); - // When - EncodeResult encoded = sut.encode(DTypes.F32, values, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F32, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.F32, values, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F32, REGISTRY); + Array result = DECODER.decode(ctx); - // Then - var layout = PTypeIO.LE_FLOAT; for (int i = 0; i < values.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 4)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, (long) i * 4)) .as("index %d", i).isCloseTo(values[i], within(1e-6f)); } } @Test void encode_f32_roundTrip_withPatches() { - // Given — Float.NaN and Float.POSITIVE_INFINITY can't be ALP-encoded; must become patches float[] values = {1.0f, Float.NaN, 2.5f, Float.POSITIVE_INFINITY, 3.0f}; - AlpEncoding sut = new AlpEncoding(); - - Registry registry = TestRegistry.withPrimitive(sut); - - // When - EncodeResult encoded = sut.encode(DTypes.F32, values, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F32, registry); - Array result = sut.decode(ctx); - - // Then - var layout = PTypeIO.LE_FLOAT; - assertThat(ArraySegments.of(result).get(layout, 0L)).isCloseTo(1.0f, within(1e-6f)); - assertThat(ArraySegments.of(result).get(layout, 4L)).isNaN(); - assertThat(ArraySegments.of(result).get(layout, 8L)).isCloseTo(2.5f, within(1e-6f)); - assertThat(ArraySegments.of(result).get(layout, 12L)).isInfinite(); - assertThat(ArraySegments.of(result).get(layout, 16L)).isCloseTo(3.0f, within(1e-6f)); + + EncodeResult encoded = ENCODER.encode(DTypes.F32, values, EncodeTestHelper.testCtx()); + DecodeContext ctx = 
DecodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F32, REGISTRY); + Array result = DECODER.decode(ctx); + + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, 0L)).isCloseTo(1.0f, within(1e-6f)); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, 4L)).isNaN(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, 8L)).isCloseTo(2.5f, within(1e-6f)); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, 12L)).isInfinite(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, 16L)).isCloseTo(3.0f, within(1e-6f)); } @Test void encode_f64_roundTrip_noPatches() { - // Given double[] values = {1.23, 4.56, 7.89, 0.001, 100.0}; - AlpEncoding sut = new AlpEncoding(); - - Registry registry = TestRegistry.withPrimitive(sut); - // When - EncodeResult encoded = sut.encode(DTypes.F64, values, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F64, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.F64, values, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F64, REGISTRY); + Array result = DECODER.decode(ctx); - // Then - var layout = PTypeIO.LE_DOUBLE; for (int i = 0; i < values.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, (long) i * 8)) .as("index %d", i).isCloseTo(values[i], within(1e-9)); } } @Test void encode_f64_metadata_expE_isNonZero() throws Exception { - // Given — 2-decimal values force ALP to pick exp_e=2 (×100); if tag drifts, exp_e reads as 0 double[] values = {1.23, 4.56, 7.89}; - AlpEncoding sut = new AlpEncoding(); - // When - EncodeResult result = sut.encode(DTypes.F64, values, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.F64, values, EncodeTestHelper.testCtx()); MemorySegment metaSeg = 
MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); ALPMetadata meta = ALPMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.exp_e()).isGreaterThan(0); } } diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoderTest.java new file mode 100644 index 00000000..1b263227 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoderTest.java @@ -0,0 +1,58 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.ALPRDMetadata; +import io.github.dfa1.vortex.reader.decode.AlpRdEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.within; + +class AlpRdEncodingEncoderTest { + + @Test + void encode_f64_roundTrip() { + // Given + double[] values = {0.1, 0.2, 0.3, 0.4, 0.5}; + var encoder = new AlpRdEncodingEncoder(); + var decoder = new AlpRdEncodingDecoder(); + ReadRegistry registry = TestRegistry.ofDecoders(decoder, new BitpackedEncodingDecoder(), new PrimitiveEncodingDecoder()); + + // When + EncodeResult encoded = encoder.encode(DTypes.F64, values, EncodeTestHelper.testCtx()); + DecodeContext ctx = 
DecodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F64, registry); + var result = decoder.decode(ctx); + + // Then + for (int i = 0; i < values.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, (long) i * 8)) + .as("index %d", i).isCloseTo(values[i], within(1e-9)); + } + } + + @Test + void encode_f64_metadata_rightBitWidth_isNonZero() throws Exception { + // Given — ALPRD splits F64 into left+right parts; right_bit_width>0 means real encoding happened + // if tag drifts, right_bit_width reads as 0 (proto3 default) and right parts are all zero + double[] values = {0.1, 0.2, 0.3, 0.4, 0.5}; + var sut = new AlpRdEncodingEncoder(); + + // When + EncodeResult result = sut.encode(DTypes.F64, values, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + ALPRDMetadata meta = ALPRDMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + + // Then + assertThat(meta.right_bit_width()).isGreaterThan(0); + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedConstantPatchesBroadcastTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedConstantPatchesBroadcastTest.java new file mode 100644 index 00000000..c9bd5075 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedConstantPatchesBroadcastTest.java @@ -0,0 +1,83 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.proto.BitPackedMetadata; +import 
io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder; +import org.junit.jupiter.api.Test; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import static org.assertj.core.api.Assertions.assertThat; + +/// Regression for the IOOB crash in `BitpackedEncoding.applyPatches` (and sibling +/// `SparseEncoding`, `AlpEncoding`, `PatchedEncoding`, etc.) when a patches child is +/// encoded with [io.github.dfa1.vortex.encoding.ConstantEncoding]. +class BitpackedConstantPatchesBroadcastTest { + + @Test + void bitpackedDecode_withConstantPatchesValues_broadcastsValueAcrossPatches() { + long n = 10; + long numPatches = 3; + long constantPatchValue = 42L; + + byte[] packed = new byte[128]; + + ScalarValue idxScalar = ScalarValue.ofUint64Value(2L); + byte[] idxScalarBytes = idxScalar.encode(); + + ScalarValue valScalar = ScalarValue.ofInt64Value(constantPatchValue); + byte[] valScalarBytes = valScalar.encode(); + + PatchesMetadata patches = new PatchesMetadata(numPatches, 0, + io.github.dfa1.vortex.proto.PType.U32, null, null, null); + BitPackedMetadata meta = new BitPackedMetadata(1, 0, patches); + ByteBuffer metaBuf = ByteBuffer.wrap(meta.encode()).order(ByteOrder.LITTLE_ENDIAN); + + try (Arena arena = Arena.ofConfined()) { + MemorySegment packedSeg = arena.allocate(packed.length, 8); + MemorySegment.copy(MemorySegment.ofArray(packed), 0, packedSeg, 0, packed.length); + MemorySegment idxBufSeg = arena.allocate(idxScalarBytes.length, 1); + MemorySegment.copy(MemorySegment.ofArray(idxScalarBytes), 0, idxBufSeg, 0, idxScalarBytes.length); + MemorySegment valBufSeg = arena.allocate(valScalarBytes.length, 1); + MemorySegment.copy(MemorySegment.ofArray(valScalarBytes), 0, valBufSeg, 0, valScalarBytes.length); + + ArrayNode idxChild = ArrayNode.of(EncodingId.VORTEX_CONSTANT, null, + new 
ArrayNode[0], new int[]{1}, null); + ArrayNode valChild = ArrayNode.of(EncodingId.VORTEX_CONSTANT, null, + new ArrayNode[0], new int[]{2}, null); + ArrayNode root = ArrayNode.of(EncodingId.FASTLANES_BITPACKED, metaBuf, + new ArrayNode[]{idxChild, valChild}, new int[]{0}, null); + + DType dtype = new DType.Primitive(PType.I64, false); + ReadRegistry registry = ReadRegistry.loadAll(); + DecodeContext ctx = new DecodeContext(root, dtype, n, + new MemorySegment[]{packedSeg, idxBufSeg, valBufSeg}, + registry, Arena.ofAuto()); + + Array result = new BitpackedEncodingDecoder().decode(ctx); + + assertThat(result.length()).isEqualTo(n); + MemorySegment data = ArraySegments.of(result); + assertThat(data.getAtIndex(PTypeIO.LE_LONG, 2)).isEqualTo(constantPatchValue); + for (long i = 0; i < n; i++) { + if (i == 2) { + continue; + } + assertThat(data.getAtIndex(PTypeIO.LE_LONG, i)).as("non-patched index %d", i).isZero(); + } + } + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoderTest.java new file mode 100644 index 00000000..ea9a5567 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoderTest.java @@ -0,0 +1,88 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.BitPackedMetadata; +import 
io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +/// Property: encode then decode is lossless for unsigned integer types. +class BitpackedEncodingEncoderTest { + + private static final BitpackedEncodingEncoder ENCODER = new BitpackedEncodingEncoder(); + private static final BitpackedEncodingDecoder DECODER = new BitpackedEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER); + + static Stream u32Arrays() { + return Stream.of( + Arguments.of("empty", new int[]{}), + Arguments.of("single", new int[]{0}), + Arguments.of("all-zeros", new int[]{0, 0, 0, 0, 0}), + Arguments.of("small-values", new int[]{1, 2, 3, 4, 5, 6, 7}), + Arguments.of("mixed", new int[]{0, 7, 63, 255, 1023, 65535}), + Arguments.of("max-unsigned", new int[]{-1, -1, -1}) // 0xFFFFFFFF + ); + } + + static Stream u64Arrays() { + return Stream.of( + Arguments.of("empty", new long[]{}), + Arguments.of("single", new long[]{0L}), + Arguments.of("small-values", new long[]{1L, 2L, 3L, 4L, 5L}), + Arguments.of("large-values", new long[]{0L, 0xFFFFL, 0xFFFFFFL, 0xFFFFFFFFL}) + ); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("u32Arrays") + void encodeDecode_u32_isLossless(String name, int[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.U32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.U32, REGISTRY); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); + } + } + + 
@ParameterizedTest(name = "{0}") + @MethodSource("u64Arrays") + void encodeDecode_u64_isLossless(String name, long[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.U64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.U64, REGISTRY); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + } + } + + @Test + void encode_i32_metadata_bitWidth_isNonZero() throws Exception { + // Given — max value 5 needs 3 bits; if tag drifts, bit_width reads as 0 (proto3 default) + int[] data = {1, 2, 3, 4, 5}; + + EncodeResult result = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + BitPackedMetadata meta = BitPackedMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + + assertThat(meta.bit_width()).isGreaterThan(0); + } +} diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/BitpackedEncodingPatchesTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingPatchesTest.java similarity index 68% rename from core/src/test/java/io/github/dfa1/vortex/encoding/BitpackedEncodingPatchesTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingPatchesTest.java index 0ab812af..7fe87cde 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/BitpackedEncodingPatchesTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingPatchesTest.java @@ -1,10 +1,21 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.BitPackedMetadata; -import io.github.dfa1.vortex.proto.PatchesMetadata; import io.github.dfa1.vortex.core.ArrayStats; import 
io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.BitPackedMetadata; +import io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -17,21 +28,21 @@ class BitpackedEncodingPatchesTest { + private static final BitpackedEncodingEncoder ENCODER = new BitpackedEncodingEncoder(); + private static final BitpackedEncodingDecoder DECODER = new BitpackedEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); + @Nested class Decode { @Test void decode_appliesPatches_overridingBitPackedValues() { - // Given — bit-pack [10, 20, 30, 40, 50] via the production encoder (bitWidth = 6), - // then attach synthetic patches metadata that rewrites indices [1, 3] with [777, 999]. int[] base = {10, 20, 30, 40, 50}; - BitpackedEncoding sut = new BitpackedEncoding(); - EncodeResult packed = sut.encode(DTypes.I32, base, EncodeTestHelper.testCtx()); + EncodeResult packed = ENCODER.encode(DTypes.I32, base, EncodeTestHelper.testCtx()); MemorySegment packedSeg = packed.buffers().getFirst(); byte[] packedBytes = packedSeg.toArray(java.lang.foreign.ValueLayout.JAVA_BYTE); - // Build new BitPackedMetadata that re-uses the packed bytes but advertises patches. 
PatchesMetadata patches = new PatchesMetadata(2L, 0L, io.github.dfa1.vortex.proto.PType.U32, null, null, null); byte[] metaBytes = new BitPackedMetadata(6, 0, patches).encode(); @@ -48,8 +59,7 @@ void decode_appliesPatches_overridingBitPackedValues() { ArrayNode bpNode = ArrayNode.of(EncodingId.FASTLANES_BITPACKED, ByteBuffer.wrap(metaBytes), new ArrayNode[]{idxNode, valNode}, - new int[]{0}, - ArrayStats.empty()); + new int[]{0}, ArrayStats.empty()); MemorySegment[] segments = { MemorySegment.ofArray(packedBytes), @@ -57,15 +67,11 @@ void decode_appliesPatches_overridingBitPackedValues() { MemorySegment.ofArray(valBuf) }; - Registry registry = TestRegistry.of(new BitpackedEncoding(), new PrimitiveEncoding()); - DecodeContext ctx = new DecodeContext( - bpNode, DTypes.I32, base.length, segments, registry, Arena.global()); + bpNode, DTypes.I32, base.length, segments, REGISTRY, Arena.global()); - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(base.length); assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, 0L)).isEqualTo(10); assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, 4L)).isEqualTo(777); diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoderTest.java new file mode 100644 index 00000000..d3be412f --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoderTest.java @@ -0,0 +1,73 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import 
io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.lang.foreign.ValueLayout; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +/// Property: encode then decode is lossless for boolean arrays of all lengths (including non-multiples of 8). +class BoolEncodingEncoderTest { + + static Stream boolArrays() { + return Stream.of( + new boolean[]{}, + new boolean[]{false}, + new boolean[]{true}, + new boolean[]{false, true, false, true, false, true, false, true}, + new boolean[]{true, true, true, false, false, false, true, false, true}, + new boolean[]{false, false, false, false, false, false, false}, + new boolean[]{true, true, true, true, true, true, true, true, true} + ); + } + + @ParameterizedTest + @MethodSource("boolArrays") + void encodeDecode_isLossless(boolean[] data) { + // Given + var encoder = new BoolEncodingEncoder(); + var decoder = new BoolEncodingDecoder(); + ReadRegistry registry = TestRegistry.ofDecoders(decoder); + + // When + EncodeResult encoded = encoder.encode(DTypes.BOOL, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.BOOL, registry); + Array result = decoder.decode(ctx); + + // Then + assertThat(result).isInstanceOf(BoolArray.class); + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + byte byteVal = ArraySegments.of(result).get(ValueLayout.JAVA_BYTE, i / 8); + boolean decoded = ((byteVal >>> (i % 8)) & 1) == 1; + assertThat(decoded).as("index %d", i).isEqualTo(data[i]); + } + } + + @ParameterizedTest + @MethodSource("boolArrays") + void encodedSize_isPackedBits(boolean[] data) { + // Given + 
var sut = new BoolEncodingEncoder(); + + // When + EncodeResult encoded = sut.encode(DTypes.BOOL, data, EncodeTestHelper.testCtx()); + + // Then — bit-packed: ceiling(n/8) bytes, always ≤ n bytes raw + long totalBytes = encoded.buffers().stream().mapToLong(java.lang.foreign.MemorySegment::byteSize).sum(); + assertThat(totalBytes).isEqualTo((long) (data.length + 7) / 8); + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoderTest.java new file mode 100644 index 00000000..23391ae7 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoderTest.java @@ -0,0 +1,53 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.reader.decode.ByteBoolEncodingDecoder; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +class ByteBoolEncodingEncoderTest { + + static Stream boolArrays() { + return Stream.of( + new boolean[]{}, + new boolean[]{false}, + new boolean[]{true}, + new boolean[]{true, false, true, false, true}, + new boolean[]{false, false, false, false}, + new boolean[]{true, true, true, true, true, true, true, true, true} + ); + } + + @ParameterizedTest + @MethodSource("boolArrays") + void encodeDecode_isLossless(boolean[] data) { + // Given 
+ var encoder = new ByteBoolEncodingEncoder(); + var decoder = new ByteBoolEncodingDecoder(); + ReadRegistry registry = TestRegistry.ofDecoders(decoder); + + // When + EncodeResult encoded = encoder.encode(DTypes.BOOL, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.BOOL, registry); + Array result = decoder.decode(ctx); + + // Then + assertThat(result.length()).isEqualTo(data.length); + BoolArray boolArr = (BoolArray) result; + for (int i = 0; i < data.length; i++) { + assertThat(boolArr.getBoolean(i)).as("index %d", i).isEqualTo(data[i]); + } + } +} diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/CascadingCompressorTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/CascadingCompressorTest.java similarity index 71% rename from core/src/test/java/io/github/dfa1/vortex/encoding/CascadingCompressorTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/CascadingCompressorTest.java index 5f09856c..55283640 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/CascadingCompressorTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/CascadingCompressorTest.java @@ -1,25 +1,44 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.array.DoubleArray; import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.reader.ReadRegistry; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import 
java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; class CascadingCompressorTest { - private static final List ALL_CODECS = List.of( - new AlpEncoding(), new FrameOfReferenceEncoding(), new DictEncoding(), - new BitpackedEncoding(), new PrimitiveEncoding()); + private static final List ALL_CODECS = List.of( + new AlpEncodingEncoder(), new FrameOfReferenceEncodingEncoder(), new DictEncodingEncoder(), + new BitpackedEncodingEncoder(), new PrimitiveEncodingEncoder()); + + private static Map toMap(List codecs) { + Map map = new HashMap<>(); + for (EncodingEncoder enc : codecs) { + map.put(enc.encodingId(), enc); + } + return Map.copyOf(map); + } private static EncodeContext ctx(int depth) { - return EncodeContext.ofDepth(depth, Arena.ofAuto(), Registry.of(ALL_CODECS)); + return EncodeContext.ofDepth(depth, Arena.ofAuto(), toMap(ALL_CODECS)); } @Nested @@ -37,7 +56,7 @@ void depth0_picksTerminalOnly_forF64() { // When EncodeResult result = sut.encode(DTypes.F64, values, ctx(0)); - // Then - result should be a valid non-null encode result + // Then assertThat(result).isNotNull(); assertThat(result.rootNode()).isNotNull(); assertThat(result.buffers()).isNotEmpty(); @@ -45,7 +64,7 @@ void depth0_picksTerminalOnly_forF64() { @Test void depth1_alpPlusBitpacked_producesSmallResult_forF64() { - // Given: OHLC-style prices — ALP-encodable, small range → bitpackable residuals + // Given double[] values = new double[4096]; for (int i = 0; i < values.length; i++) { values[i] = 100.0 + (i % 50) * 0.01; @@ -55,28 +74,27 @@ void depth1_alpPlusBitpacked_producesSmallResult_forF64() { // When EncodeResult result = sut.encode(DTypes.F64, values, ctx(1)); - // Then - cascaded result should be smaller than raw primitive (4096 * 8 = 32768 bytes) + // Then long totalBytes = 
result.buffers().stream().mapToLong(MemorySegment::byteSize).sum(); assertThat(totalBytes).isLessThan(4096L * 8); } @Test void excludedEncodings_areSkipped() { - // Given: exclude AlpEncoding — should not be selected for DTypes.F64 + // Given double[] values = new double[512]; for (int i = 0; i < values.length; i++) { values[i] = i * 1.5; } - // Depth=1 but ALP excluded via context EncodeContext encodeCtx = new EncodeContext( - Arena.ofAuto(), Registry.of(ALL_CODECS), + Arena.ofAuto(), toMap(ALL_CODECS), 1, Set.of(EncodingId.VORTEX_ALP), 42L, 64, 0.1); CascadingCompressor sut = new CascadingCompressor(ALL_CODECS); // When EncodeResult result = sut.encode(DTypes.F64, values, encodeCtx); - // Then - ALP node should not appear in the tree + // Then assertThat(containsEncoding(result.rootNode(), EncodingId.VORTEX_ALP)).isFalse(); } @@ -106,9 +124,9 @@ void alpBitpacked_f64() { CascadingCompressor sut = new CascadingCompressor(ALL_CODECS); // When - Registry registry = Registry.of(ALL_CODECS); - EncodeResult result = sut.encode(DTypes.F64, values, EncodeContext.ofDepth(1, Arena.ofAuto(), registry)); - DecodeContext decodeCtx = EncodeTestHelper.toDecodeContext(result, values.length, DTypes.F64, registry); + ReadRegistry registry = ReadRegistry.loadAll(); + EncodeResult result = sut.encode(DTypes.F64, values, EncodeContext.ofDepth(1, Arena.ofAuto(), toMap(ALL_CODECS))); + DecodeContext decodeCtx = DecodeTestHelper.toDecodeContext(result, values.length, DTypes.F64, registry); DoubleArray decoded = (DoubleArray) registry.decode(decodeCtx); // Then @@ -119,7 +137,7 @@ void alpBitpacked_f64() { @Test void forBitpacked_i64() { - // Given: integers in narrow range → FoR reduces to small residuals → bitpackable + // Given long[] values = new long[1024]; for (int i = 0; i < values.length; i++) { values[i] = 1_000_000L + (i % 200); @@ -127,9 +145,9 @@ void forBitpacked_i64() { CascadingCompressor sut = new CascadingCompressor(ALL_CODECS); // When - Registry registry = 
Registry.of(ALL_CODECS); - EncodeResult result = sut.encode(DTypes.I64, values, EncodeContext.ofDepth(1, Arena.ofAuto(), registry)); - DecodeContext decodeCtx = EncodeTestHelper.toDecodeContext(result, values.length, DTypes.I64, registry); + ReadRegistry registry = ReadRegistry.loadAll(); + EncodeResult result = sut.encode(DTypes.I64, values, EncodeContext.ofDepth(1, Arena.ofAuto(), toMap(ALL_CODECS))); + DecodeContext decodeCtx = DecodeTestHelper.toDecodeContext(result, values.length, DTypes.I64, registry); LongArray decoded = (LongArray) registry.decode(decodeCtx); // Then @@ -140,7 +158,7 @@ void forBitpacked_i64() { @Test void dictBitpacked_i32() { - // Given: low-cardinality int column + // Given int[] values = new int[2048]; for (int i = 0; i < values.length; i++) { values[i] = i % 10; @@ -148,9 +166,9 @@ void dictBitpacked_i32() { CascadingCompressor sut = new CascadingCompressor(ALL_CODECS); // When - Registry registry = Registry.of(ALL_CODECS); - EncodeResult result = sut.encode(DTypes.I32, values, EncodeContext.ofDepth(1, Arena.ofAuto(), registry)); - DecodeContext decodeCtx = EncodeTestHelper.toDecodeContext(result, values.length, DTypes.I32, registry); + ReadRegistry registry = ReadRegistry.loadAll(); + EncodeResult result = sut.encode(DTypes.I32, values, EncodeContext.ofDepth(1, Arena.ofAuto(), toMap(ALL_CODECS))); + DecodeContext decodeCtx = DecodeTestHelper.toDecodeContext(result, values.length, DTypes.I32, registry); io.github.dfa1.vortex.core.array.IntArray decoded = (io.github.dfa1.vortex.core.array.IntArray) registry.decode(decodeCtx); diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ChunkedEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoderTest.java similarity index 63% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ChunkedEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoderTest.java index 59030e0f..76773a2e 100644 
--- a/core/src/test/java/io/github/dfa1/vortex/encoding/ChunkedEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoderTest.java @@ -1,10 +1,22 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.ChunkedData; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.reader.decode.ChunkedEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -17,11 +29,16 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class ChunkedEncodingTest { +class ChunkedEncodingEncoderTest { private static final ValueLayout.OfLong LE_LONG = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + private static final ChunkedEncodingEncoder ENCODER = new ChunkedEncodingEncoder(); + private static final ChunkedEncodingDecoder DECODER = new ChunkedEncodingDecoder(); + private static final PrimitiveEncodingEncoder PRIM_ENCODER = new PrimitiveEncodingEncoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); + private static 
ArrayNode toArrayNode(EncodeNode enc) { ArrayNode[] children = new ArrayNode[enc.children().length]; for (int i = 0; i < children.length; i++) { @@ -39,23 +56,15 @@ class Encode { @Test void roundTrip_twoChunks_i64_preservesValues() { - // Given long[] chunk0 = {10L, 20L, 30L}; long[] chunk1 = {40L, 50L}; DType i64 = new DType.Primitive(PType.I64, false); - var sut = new ChunkedEncoding(); - Registry registry = Registry.builder() - .register(sut) - .register(new PrimitiveEncoding()) - .build(); ChunkedData data = new ChunkedData(List.of(chunk0, chunk1), new long[]{3, 2}); - // When - EncodeResult encoded = sut.encode(i64, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, 5L, i64, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(i64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, 5L, i64, REGISTRY); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(5); assertThat(ArraySegments.of(result).get(LE_LONG, 0L)).isEqualTo(10L); assertThat(ArraySegments.of(result).get(LE_LONG, 8L)).isEqualTo(20L); @@ -66,28 +75,19 @@ void roundTrip_twoChunks_i64_preservesValues() { @Test void encodeNode_hasNoDirectBuffers_offsetsAsFirstChild() { - // Given long[] chunk0 = {1L, 2L}; DType i64 = new DType.Primitive(PType.I64, false); - var sut = new ChunkedEncoding(); ChunkedData data = new ChunkedData(List.of(chunk0), new long[]{2}); - // When - EncodeResult result = sut.encode(i64, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(i64, data, EncodeTestHelper.testCtx()); - // Then assertThat(result.rootNode().bufferIndices()).isEmpty(); - assertThat(result.rootNode().children()).hasSize(2); // offsets + 1 chunk - assertThat(result.buffers()).hasSize(2); // offsets buf + chunk buf + assertThat(result.rootNode().children()).hasSize(2); + assertThat(result.buffers()).hasSize(2); } @Test void 
mismatchedLengths_throws() { - // Given - var sut = new ChunkedEncoding(); - DType i64 = new DType.Primitive(PType.I64, false); - - // When / Then assertThatThrownBy(() -> new ChunkedData(List.of(new long[]{1L}), new long[]{1, 2})) .isInstanceOf(IllegalArgumentException.class); } @@ -98,37 +98,21 @@ class Decode { @Test void roundTrip_twoChunks_concatenatesValues() { - // Given long[] chunk0 = {10L, 20L, 30L}; long[] chunk1 = {40L, 50L}; DType i64 = new DType.Primitive(PType.I64, false); DType u64 = new DType.Primitive(PType.U64, false); - var sut = new ChunkedEncoding(); - Registry registry = Registry.builder() - .register(sut) - .register(new PrimitiveEncoding()) - .build(); - - // Build chunk_offsets segment: [0, 3, 5] as U64 LE - EncodeResult offsetsResult = new PrimitiveEncoding().encode(u64, new long[]{0L, 3L, 5L}, EncodeTestHelper.testCtx()); - // Build chunk0 segment - EncodeResult chunk0Result = new PrimitiveEncoding().encode(i64, chunk0, EncodeTestHelper.testCtx()); - // Build chunk1 segment - EncodeResult chunk1Result = new PrimitiveEncoding().encode(i64, chunk1, EncodeTestHelper.testCtx()); + EncodeResult offsetsResult = PRIM_ENCODER.encode(u64, new long[]{0L, 3L, 5L}, EncodeTestHelper.testCtx()); + EncodeResult chunk0Result = PRIM_ENCODER.encode(i64, chunk0, EncodeTestHelper.testCtx()); + EncodeResult chunk1Result = PRIM_ENCODER.encode(i64, chunk1, EncodeTestHelper.testCtx()); - // Collect all buffers: [offsets_buf, chunk0_buf, chunk1_buf] MemorySegment[] allBufs = { offsetsResult.buffers().getFirst(), chunk0Result.buffers().getFirst(), chunk1Result.buffers().getFirst() }; - // Build ArrayNode tree: - // root: ChunkedEncoding, children=[offsetsNode, chunk0Node, chunk1Node], buffers=[] - // offsetsNode: PrimitiveEncoding, bufferIndices=[0] - // chunk0Node: PrimitiveEncoding, bufferIndices=[1] - // chunk1Node: PrimitiveEncoding, bufferIndices=[2] ArrayNode offsetsNode = toArrayNode(offsetsResult.rootNode()); ArrayNode chunk0Node = 
toArrayNode(remapped(chunk0Result.rootNode(), 1)); ArrayNode chunk1Node = toArrayNode(remapped(chunk1Result.rootNode(), 2)); @@ -137,12 +121,9 @@ void roundTrip_twoChunks_concatenatesValues() { new ArrayNode[]{offsetsNode, chunk0Node, chunk1Node}, new int[]{}, null); - DecodeContext ctx = new DecodeContext(root, i64, 5L, allBufs, registry, Arena.ofAuto()); + DecodeContext ctx = new DecodeContext(root, i64, 5L, allBufs, REGISTRY, Arena.ofAuto()); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(5); assertThat(ArraySegments.of(result).get(LE_LONG, 0L)).isEqualTo(10L); assertThat(ArraySegments.of(result).get(LE_LONG, 8L)).isEqualTo(20L); @@ -153,18 +134,12 @@ void roundTrip_twoChunks_concatenatesValues() { @Test void singleChunk_returnsSameValues() { - // Given long[] data = {1L, 2L, 3L}; DType i64 = new DType.Primitive(PType.I64, false); DType u64 = new DType.Primitive(PType.U64, false); - Registry registry = Registry.builder() - .register(new ChunkedEncoding()) - .register(new PrimitiveEncoding()) - .build(); - - EncodeResult offsetsResult = new PrimitiveEncoding().encode(u64, new long[]{0L, 3L}, EncodeTestHelper.testCtx()); - EncodeResult chunkResult = new PrimitiveEncoding().encode(i64, data, EncodeTestHelper.testCtx()); + EncodeResult offsetsResult = PRIM_ENCODER.encode(u64, new long[]{0L, 3L}, EncodeTestHelper.testCtx()); + EncodeResult chunkResult = PRIM_ENCODER.encode(i64, data, EncodeTestHelper.testCtx()); MemorySegment[] allBufs = { offsetsResult.buffers().getFirst(), @@ -176,12 +151,9 @@ void singleChunk_returnsSameValues() { new ArrayNode[]{toArrayNode(offsetsResult.rootNode()), toArrayNode(remapped(chunkResult.rootNode(), 1))}, new int[]{}, null); - DecodeContext ctx = new DecodeContext(root, i64, 3L, allBufs, registry, Arena.ofAuto()); - - // When - Array result = new ChunkedEncoding().decode(ctx); + DecodeContext ctx = new DecodeContext(root, i64, 3L, allBufs, REGISTRY, 
Arena.ofAuto()); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(3); for (int i = 0; i < 3; i++) { assertThat(ArraySegments.of(result).get(LE_LONG, (long) i * 8)).isEqualTo(data[i]); @@ -190,16 +162,11 @@ void singleChunk_returnsSameValues() { @Test void noChildren_throws() { - // Given DType i64 = new DType.Primitive(PType.I64, false); - Registry registry = Registry.builder() - .register(new ChunkedEncoding()) - .build(); ArrayNode root = ArrayNode.of(EncodingId.VORTEX_CHUNKED, null, new ArrayNode[]{}, new int[]{}, null); - DecodeContext ctx = new DecodeContext(root, i64, 0L, new MemorySegment[]{}, registry, Arena.ofAuto()); + DecodeContext ctx = new DecodeContext(root, i64, 0L, new MemorySegment[]{}, REGISTRY, Arena.ofAuto()); - // When / Then - assertThatThrownBy(() -> new ChunkedEncoding().decode(ctx)) + assertThatThrownBy(() -> DECODER.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("at least one child"); } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ConstantEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoderTest.java similarity index 52% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ConstantEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoderTest.java index c0342f93..11fb304b 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ConstantEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoderTest.java @@ -1,4 +1,4 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; @@ -8,6 +8,15 @@ import io.github.dfa1.vortex.core.array.IntArray; import io.github.dfa1.vortex.core.array.LongArray; import io.github.dfa1.vortex.core.array.ShortArray; +import 
io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.reader.decode.ConstantEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -21,7 +30,11 @@ /// Property: encode then decode is lossless for constant (all-equal) arrays. /// Property: decode allocates O(1) memory regardless of rowCount. -class ConstantEncodingTest { +class ConstantEncodingEncoderTest { + + private static final ConstantEncodingEncoder ENCODER = new ConstantEncodingEncoder(); + private static final ConstantEncodingDecoder DECODER = new ConstantEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER); @Nested class Decode { @@ -31,18 +44,14 @@ void decode_largeRowCount_bufferStaysConstantSize() { // Given — 10M rows would allocate 80 MB if the decoder materializes every element; // the correct impl stores exactly one element regardless of logical length. 
long rowCount = 10_000_000L; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); // When - EncodeResult encoded = sut.encode(DTypes.I64, new long[]{42L}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I64, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.I64, new long[]{42L}, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); // Then assertThat(result.length()).isEqualTo(rowCount); - // Constant encoding must not materialize the full array: the backing buffer must - // hold exactly one element. Before fix: buffer is rowCount * 8 bytes. assertThat(ArraySegments.of(result).byteSize()) .as("constant encoding must not allocate O(rowCount) memory") .isEqualTo(Long.BYTES); @@ -75,65 +84,43 @@ static Stream i64ConstantArrays() { @ParameterizedTest @MethodSource("i32ConstantArrays") void encodeDecode_i32_isLossless(int[] data) { - // Given - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - var le = PTypeIO.LE_INT; - - // When - EncodeResult encoded = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, REGISTRY); + Array result = DECODER.decode(ctx); - // Then — buffer holds one element; logical length is n assertThat(result.length()).isEqualTo(data.length); assertThat(ArraySegments.of(result).byteSize()).isEqualTo(Integer.BYTES); - assertThat(ArraySegments.of(result).get(le, 0L)).isEqualTo(data[0]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, 0L)).isEqualTo(data[0]); } 
@ParameterizedTest @MethodSource("i64ConstantArrays") void encodeDecode_i64_isLossless(long[] data) { - // Given - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - var le = PTypeIO.LE_LONG; - - // When - EncodeResult encoded = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); - // Then — buffer holds one element; logical length is n assertThat(result.length()).isEqualTo(data.length); assertThat(ArraySegments.of(result).byteSize()).isEqualTo(Long.BYTES); - assertThat(ArraySegments.of(result).get(le, 0L)).isEqualTo(data[0]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, 0L)).isEqualTo(data[0]); } } /// ConstantEncoding stores 1 element in the buffer but reports length=N. /// Primitive Array accessors must broadcast that single element across every - /// logical index, not OOB. Regression-guard for commit ed658b7 (added the - /// broadcast semantic) and 051a794 (fast-path branch-split — preserve broadcast - /// only on the slow path where `elementCount != length`). + /// logical index, not OOB. 
@Nested class Broadcast { @ParameterizedTest @ValueSource(longs = {1, 2, 10, 1_000, 131_072, 1_000_000L}) void i64_getLong_returnsConstantAtEveryIndex(long rowCount) { - // Given long constant = 0xDEADBEEFCAFEBABEL; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.I64, new long[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I64, registry); + EncodeResult encoded = ENCODER.encode(DTypes.I64, new long[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I64, REGISTRY); - // When - LongArray result = (LongArray) sut.decode(ctx); + LongArray result = (LongArray) DECODER.decode(ctx); - // Then — getLong at first, last, and arbitrary midpoints all return the constant. - // Catches: missing modulo (OOB or wrong value) and accidental skip of broadcast branch. assertThat(result.length()).isEqualTo(rowCount); assertThat(result.getLong(0)).isEqualTo(constant); assertThat(result.getLong(rowCount - 1)).isEqualTo(constant); @@ -144,39 +131,26 @@ void i64_getLong_returnsConstantAtEveryIndex(long rowCount) { @Test void i64_fold_broadcastsAcrossAllRows() { - // Given — fold is the hot path for column aggregates. Must use the broadcast - // branch when elementCount != length, otherwise fold returns wrong sum. 
long rowCount = 1_000_000L; long constant = 7L; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.I64, new long[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I64, registry); + EncodeResult encoded = ENCODER.encode(DTypes.I64, new long[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I64, REGISTRY); - // When - LongArray result = (LongArray) sut.decode(ctx); + LongArray result = (LongArray) DECODER.decode(ctx); long sum = result.fold(0L, Long::sum); - // Then — every row contributes the constant; total is rowCount * constant. - // Pre-fix (no modulo) the fold would OOB on row 1; post-bug-fix without branch-split - // the result is correct but ~25% slower. assertThat(sum).isEqualTo(rowCount * constant); } @Test void i32_getInt_broadcastsAcrossEveryIndex() { - // Given long rowCount = 10_000L; int constant = -123_456; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.I32, new int[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I32, registry); + EncodeResult encoded = ENCODER.encode(DTypes.I32, new int[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I32, REGISTRY); - // When - IntArray result = (IntArray) sut.decode(ctx); + IntArray result = (IntArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(rowCount); assertThat(result.getInt(0)).isEqualTo(constant); assertThat(result.getInt(rowCount - 1)).isEqualTo(constant); @@ -185,40 +159,29 @@ void i32_getInt_broadcastsAcrossEveryIndex() { @Test void f64_getDouble_broadcastsAcrossEveryIndex() { - // Given long rowCount = 10_000L; double constant = 
3.141592653589793; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.F64, new double[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.F64, registry); + EncodeResult encoded = ENCODER.encode(DTypes.F64, new double[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.F64, REGISTRY); - // When - DoubleArray result = (DoubleArray) sut.decode(ctx); + DoubleArray result = (DoubleArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(rowCount); assertThat(result.getDouble(0)).isEqualTo(constant); assertThat(result.getDouble(rowCount - 1)).isEqualTo(constant); - // Iterative double sum drifts (~1e-10 per 10K rows) — use tolerance, not strict equality. assertThat(result.fold(0.0, Double::sum)) .isCloseTo(rowCount * constant, org.assertj.core.data.Offset.offset(1e-6)); } @Test void f32_getFloat_broadcastsAcrossEveryIndex() { - // Given long rowCount = 10_000L; float constant = 2.71828f; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.F32, new float[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.F32, registry); + EncodeResult encoded = ENCODER.encode(DTypes.F32, new float[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.F32, REGISTRY); - // When - FloatArray result = (FloatArray) sut.decode(ctx); + FloatArray result = (FloatArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(rowCount); assertThat(result.getFloat(0)).isEqualTo(constant); assertThat(result.getFloat(rowCount - 1)).isEqualTo(constant); @@ -226,18 +189,13 @@ void f32_getFloat_broadcastsAcrossEveryIndex() { @Test void 
i16_getShort_broadcastsAcrossEveryIndex() { - // Given long rowCount = 10_000L; short constant = (short) -12345; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.I16, new short[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I16, registry); + EncodeResult encoded = ENCODER.encode(DTypes.I16, new short[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I16, REGISTRY); - // When - ShortArray result = (ShortArray) sut.decode(ctx); + ShortArray result = (ShortArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(rowCount); assertThat(result.getShort(0)).isEqualTo(constant); assertThat(result.getShort(rowCount - 1)).isEqualTo(constant); @@ -245,18 +203,13 @@ void i16_getShort_broadcastsAcrossEveryIndex() { @Test void i8_getByte_broadcastsAcrossEveryIndex() { - // Given long rowCount = 10_000L; byte constant = (byte) -42; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.I8, new byte[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I8, registry); + EncodeResult encoded = ENCODER.encode(DTypes.I8, new byte[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I8, REGISTRY); - // When - ByteArray result = (ByteArray) sut.decode(ctx); + ByteArray result = (ByteArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(rowCount); assertThat(result.getByte(0)).isEqualTo(constant); assertThat(result.getByte(rowCount - 1)).isEqualTo(constant); diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/DateTimePartsEncodingTest.java 
b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoderTest.java similarity index 66% rename from core/src/test/java/io/github/dfa1/vortex/encoding/DateTimePartsEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoderTest.java index 68e8851b..f941d4b5 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/DateTimePartsEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoderTest.java @@ -1,11 +1,24 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.DateTimePartsMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.array.GenericArray; import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DateTimePartsData; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.encoding.TimeUnit; +import io.github.dfa1.vortex.proto.DateTimePartsMetadata; +import io.github.dfa1.vortex.reader.decode.DateTimePartsEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -19,16 +32,19 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class 
DateTimePartsEncodingTest { +class DateTimePartsEncodingEncoderTest { + + private static final DateTimePartsEncodingEncoder ENCODER = new DateTimePartsEncodingEncoder(); + private static final DateTimePartsEncodingDecoder DECODER = new DateTimePartsEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); private static final DType EXT_TIMESTAMP_MS = timestampDType(TimeUnit.Milliseconds); private static final DType EXT_TIMESTAMP_NS = timestampDType(TimeUnit.Nanoseconds); private static DType timestampDType(TimeUnit unit) { - // Rust hand-rolled: byte[0]=unit tag, bytes[1-2]=tz_len u16 LE (0 = no tz) ByteBuffer meta = ByteBuffer.allocate(3).order(ByteOrder.LITTLE_ENDIAN); meta.put((byte) unit.ordinal()); - meta.putShort((short) 0); // no timezone + meta.putShort((short) 0); meta.flip(); return new DType.Extension("vortex.timestamp", new DType.Primitive(PType.I64, false), meta, false); @@ -42,45 +58,26 @@ private static ArrayNode toArrayNode(EncodeNode node) { return ArrayNode.of(node.encodingId(), node.metadata(), children, node.bufferIndices(), ArrayStats.empty()); } - private static Registry registry() { - return Registry.builder() - .register(new DateTimePartsEncoding()) - .register(new PrimitiveEncoding()) - .build(); - } - @Nested class Encode { @Test void accepts_extensionDtype_true() { - // Given - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - - // When / Then - assertThat(sut.accepts(EXT_TIMESTAMP_MS)).isTrue(); + assertThat(ENCODER.accepts(EXT_TIMESTAMP_MS)).isTrue(); } @Test void accepts_primitiveDtype_false() { - // Given - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.I64)).isFalse(); + assertThat(ENCODER.accepts(DTypes.I64)).isFalse(); } @Test void encode_producesThreeChildren_noBuffersAtRoot() { - // Given long[] timestamps = {0L, 86_400_000L}; DateTimePartsData data = new DateTimePartsData(timestamps, 
false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); - // Then assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_DATETIMEPARTS); assertThat(result.rootNode().bufferIndices()).isEmpty(); assertThat(result.rootNode().children()).hasSize(3); @@ -88,14 +85,11 @@ void encode_producesThreeChildren_noBuffersAtRoot() { @Test void encode_missingMetadata_throws() { - // Given DType noMeta = new DType.Extension("vortex.timestamp", new DType.Primitive(PType.I64, false), null, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); DateTimePartsData data = new DateTimePartsData(new long[]{0L}, false); - // When / Then - assertThatThrownBy(() -> sut.encode(noMeta, data, EncodeTestHelper.testCtx())) + assertThatThrownBy(() -> ENCODER.encode(noMeta, data, EncodeTestHelper.testCtx())) .hasMessageContaining("extension metadata missing"); } } @@ -105,22 +99,17 @@ class Decode { @Test void roundTrip_milliseconds_preservesDaysSecondsSubseconds() { - // Given - // 1970-01-02 01:02:03.456 UTC in millis long msPerDay = 86_400_000L; long ts = msPerDay + (3723L * 1000L) + 456L; long[] timestamps = {ts}; DateTimePartsData data = new DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), EXT_TIMESTAMP_MS, 1, bufs, registry(), Arena.global()); - GenericArray decoded = (GenericArray) sut.decode(ctx); + toArrayNode(result.rootNode()), EXT_TIMESTAMP_MS, 1, bufs, REGISTRY, Arena.global()); + 
GenericArray decoded = (GenericArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(1); LongArray days = (LongArray) decoded.child(0); LongArray seconds = (LongArray) decoded.child(1); @@ -132,22 +121,17 @@ void roundTrip_milliseconds_preservesDaysSecondsSubseconds() { @Test void roundTrip_nanoseconds_preservesSubsecondPrecision() { - // Given - // 1970-01-02 01:02:03.456789123 UTC in nanos long nsPerDay = 86_400_000_000_000L; long ts = nsPerDay + (3723L * 1_000_000_000L) + 456_789_123L; long[] timestamps = {ts}; DateTimePartsData data = new DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(EXT_TIMESTAMP_NS, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(EXT_TIMESTAMP_NS, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), EXT_TIMESTAMP_NS, 1, bufs, registry(), Arena.global()); - GenericArray decoded = (GenericArray) sut.decode(ctx); + toArrayNode(result.rootNode()), EXT_TIMESTAMP_NS, 1, bufs, REGISTRY, Arena.global()); + GenericArray decoded = (GenericArray) DECODER.decode(ctx); - // Then LongArray days = (LongArray) decoded.child(0); LongArray seconds = (LongArray) decoded.child(1); LongArray subseconds = (LongArray) decoded.child(2); @@ -158,19 +142,15 @@ void roundTrip_nanoseconds_preservesSubsecondPrecision() { @Test void roundTrip_epoch_allZero() { - // Given long[] timestamps = {0L}; DateTimePartsData data = new DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new 
DecodeContext( - toArrayNode(result.rootNode()), EXT_TIMESTAMP_MS, 1, bufs, registry(), Arena.global()); - GenericArray decoded = (GenericArray) sut.decode(ctx); + toArrayNode(result.rootNode()), EXT_TIMESTAMP_MS, 1, bufs, REGISTRY, Arena.global()); + GenericArray decoded = (GenericArray) DECODER.decode(ctx); - // Then LongArray days = (LongArray) decoded.child(0); LongArray seconds = (LongArray) decoded.child(1); LongArray subseconds = (LongArray) decoded.child(2); @@ -181,20 +161,16 @@ void roundTrip_epoch_allZero() { @Test void roundTrip_multipleTimestamps_allRowsPreserved() { - // Given long msPerDay = 86_400_000L; long[] timestamps = {0L, msPerDay, msPerDay + 1000L, msPerDay + 1001L}; DateTimePartsData data = new DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), EXT_TIMESTAMP_MS, 4, bufs, registry(), Arena.global()); - GenericArray decoded = (GenericArray) sut.decode(ctx); + toArrayNode(result.rootNode()), EXT_TIMESTAMP_MS, 4, bufs, REGISTRY, Arena.global()); + GenericArray decoded = (GenericArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(4); LongArray days = (LongArray) decoded.child(0); assertThat(days.getLong(0)).isEqualTo(0L); @@ -209,20 +185,16 @@ void roundTrip_multipleTimestamps_allRowsPreserved() { @ParameterizedTest @EnumSource(value = TimeUnit.class, names = {"Nanoseconds", "Microseconds", "Milliseconds", "Seconds"}) void roundTrip_allUnits_epochIsZero(TimeUnit unit) { - // Given DType dtype = timestampDType(unit); long[] timestamps = {0L}; DateTimePartsData data = new DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new 
DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), dtype, 1, bufs, registry(), Arena.global()); - GenericArray decoded = (GenericArray) sut.decode(ctx); + toArrayNode(result.rootNode()), dtype, 1, bufs, REGISTRY, Arena.global()); + GenericArray decoded = (GenericArray) DECODER.decode(ctx); - // Then LongArray days = (LongArray) decoded.child(0); LongArray seconds = (LongArray) decoded.child(1); LongArray subseconds = (LongArray) decoded.child(2); @@ -233,21 +205,16 @@ void roundTrip_allUnits_epochIsZero(TimeUnit unit) { @Test void encode_metadata_ptypes_areI64() throws Exception { - // Given — DateTimeParts always encodes days/seconds/subseconds as I64 (ordinal=7) - // if any tag drifts, the corresponding ptype reads as 0 (U8) which is proto3 default long[] timestamps = {0L, 86_400_000L}; DateTimePartsData data = new DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); - DateTimePartsMetadata meta = - DateTimePartsMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); + EncodeResult result = ENCODER.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + DateTimePartsMetadata meta = DateTimePartsMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then - assertThat(meta.days_ptype().value()).isEqualTo(7); // I64 - assertThat(meta.seconds_ptype().value()).isEqualTo(7); // I64 - 
assertThat(meta.subseconds_ptype().value()).isEqualTo(7); // I64 + assertThat(meta.days_ptype().value()).isEqualTo(7); + assertThat(meta.seconds_ptype().value()).isEqualTo(7); + assertThat(meta.subseconds_ptype().value()).isEqualTo(7); } } } diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoderTest.java new file mode 100644 index 00000000..597d3d8e --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoderTest.java @@ -0,0 +1,80 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.GenericArray; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.DecimalBytePartsMetadata; +import io.github.dfa1.vortex.reader.decode.DecimalBytePartsEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class DecimalBytePartsEncodingEncoderTest { + + @Test + void roundTrip_longArray_preservesMspValues() { + // Given + long[] values = {1L, -2L, 3L}; + DType dtype = new DType.Decimal((byte) 18, (byte) 0, false); + var encoder = new DecimalBytePartsEncodingEncoder(); + var decoder = new DecimalBytePartsEncodingDecoder(); + ReadRegistry registry = TestRegistry.ofDecoders(decoder, new PrimitiveEncodingDecoder()); + + // 
When + EncodeResult encoded = encoder.encode(dtype, values, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, values.length, dtype, registry); + GenericArray result = (GenericArray) decoder.decode(ctx); + + // Then + assertThat(result.length()).isEqualTo(values.length); + Array msp = result.child(0); + assertThat(msp.length()).isEqualTo(values.length); + for (int i = 0; i < values.length; i++) { + assertThat(ArraySegments.of(msp).get(PTypeIO.LE_LONG, (long) i * 8)).isEqualTo(values[i]); + } + } + + @Test + void encodeNode_hasNoBuffers_andOneMspChild() { + // Given + long[] values = {10L, 20L}; + DType dtype = new DType.Decimal((byte) 18, (byte) 0, false); + var sut = new DecimalBytePartsEncodingEncoder(); + + // When + EncodeResult result = sut.encode(dtype, values, EncodeTestHelper.testCtx()); + + // Then + assertThat(result.rootNode().bufferIndices()).isEmpty(); + assertThat(result.rootNode().children()).hasSize(1); + assertThat(result.buffers()).hasSize(1); + } + + @Test + void metadata_zerothChildPtype_isI64_lowerPartCountIsZero() throws Exception { + // Given + long[] values = {42L}; + DType dtype = new DType.Decimal((byte) 18, (byte) 0, false); + var sut = new DecimalBytePartsEncodingEncoder(); + + // When + EncodeResult result = sut.encode(dtype, values, EncodeTestHelper.testCtx()); + + // Then + byte[] metaBytes = new byte[result.rootNode().metadata().remaining()]; + result.rootNode().metadata().duplicate().get(metaBytes); + DecimalBytePartsMetadata meta = + DecimalBytePartsMetadata.decode(java.lang.foreign.MemorySegment.ofArray(metaBytes), 0, metaBytes.length); + assertThat(meta.zeroth_child_ptype().value()).isEqualTo(7); // I64 ordinal + assertThat(meta.lower_part_count()).isEqualTo(0); + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoderTest.java new file mode 100644 index 
00000000..b2763184 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoderTest.java @@ -0,0 +1,99 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.encoding.TestSegments; +import io.github.dfa1.vortex.proto.DecimalMetadata; +import io.github.dfa1.vortex.reader.decode.DecimalEncodingDecoder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class DecimalEncodingEncoderTest { + + private static final DecimalEncodingEncoder ENCODER = new DecimalEncodingEncoder(); + private static final DecimalEncodingDecoder DECODER = new DecimalEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER); + + @Test + void roundTrip_i64Precision_preservesBuffer() { + long[] values = {100L, -200L, 300L}; + MemorySegment input = TestSegments.leLongs(values); + DType dtype = new DType.Decimal((byte) 18, (byte) 2, false); + + EncodeResult encoded = ENCODER.encode(dtype, input, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, 
values.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(values.length); + for (int i = 0; i < values.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).isEqualTo(values[i]); + } + } + + @Test + void accepts_decimalDtype_true_primitiveReturnsFalse() { + assertThat(ENCODER.accepts(new DType.Decimal((byte) 18, (byte) 2, false))).isTrue(); + assertThat(ENCODER.accepts(new DType.Primitive(PType.I64, false))).isFalse(); + } + + @ParameterizedTest(name = "precision={0} → valuesType={1}") + @CsvSource({ + "1, 0", + "2, 0", + "3, 1", + "4, 1", + "5, 2", + "9, 2", + "10, 3", + "18, 3", + "19, 4", + "38, 4", + "39, 5", + }) + void valuesType_matchesPrecisionBoundaries(int precision, int expectedValuesType) throws Exception { + int byteWidth = switch (expectedValuesType) { + case 0 -> 1; + case 1 -> 2; + case 2 -> 4; + case 3 -> 8; + case 4 -> 16; + default -> 32; + }; + MemorySegment input = Arena.ofAuto().allocate(byteWidth); + DType dtype = new DType.Decimal((byte) precision, (byte) 0, false); + + EncodeResult encoded = ENCODER.encode(dtype, input, EncodeTestHelper.testCtx()); + + byte[] metaBytes = new byte[encoded.rootNode().metadata().remaining()]; + encoded.rootNode().metadata().duplicate().get(metaBytes); + DecimalMetadata meta = DecimalMetadata.decode(java.lang.foreign.MemorySegment.ofArray(metaBytes), 0, metaBytes.length); + assertThat(meta.values_type()).isEqualTo(expectedValuesType); + } + + @Test + void invalidBufferSize_throws() { + MemorySegment input = Arena.ofAuto().allocate(7); + DType dtype = new DType.Decimal((byte) 18, (byte) 0, false); + + assertThatThrownBy(() -> ENCODER.encode(dtype, input, EncodeTestHelper.testCtx())) + .isInstanceOf(VortexException.class) + .hasMessageContaining("not multiple of byteWidth"); + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoderTest.java 
b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoderTest.java new file mode 100644 index 00000000..e7403982 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoderTest.java @@ -0,0 +1,110 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.DeltaMetadata; +import io.github.dfa1.vortex.reader.decode.DeltaEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.lang.foreign.MemorySegment; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +class DeltaEncodingEncoderTest { + + private static final DeltaEncodingEncoder ENCODER = new DeltaEncodingEncoder(); + private static final DeltaEncodingDecoder DECODER = new DeltaEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); + + static Stream i64Arrays() { + return Stream.of( + new long[]{0}, + new long[]{Long.MIN_VALUE}, + new long[]{0, 1, 2, 3, 4, 5, 6, 7}, + new long[]{100, 200, 300, 400, 500}, + new long[]{-100, -50, 0, 50, 100}, + new long[]{1000, 999, 998, 997, 996} + ); + } + + static Stream i32Arrays() { + return Stream.of( + 
new int[]{0}, + new int[]{Integer.MIN_VALUE}, + new int[]{0, 1, 2, 3, 4, 5, 6, 7}, + new int[]{10, 20, 30, 40, 50}, + new int[]{-5, -4, -3, -2, -1, 0} + ); + } + + static Stream monotoneI64Arrays() { + return Stream.of( + Arguments.of("ascending-1", new long[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), + Arguments.of("ascending-100", new long[]{0, 100, 200, 300, 400, 500, 600, 700, 800, 900}), + Arguments.of("descending", new long[]{1000, 999, 998, 997, 996, 995, 994, 993, 992, 991}) + ); + } + + @ParameterizedTest + @MethodSource("i64Arrays") + void encodeDecode_i64_isLossless(long[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + } + } + + @ParameterizedTest + @MethodSource("i32Arrays") + void encodeDecode_i32_isLossless(int[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, REGISTRY); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); + } + } + + @ParameterizedTest(name = "{0}") + @MethodSource("monotoneI64Arrays") + void encodeDecode_monotoneI64_isLossless(String name, long[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); + + 
assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + } + } + + @Test + void encode_i64_metadata_deltasLen_isNonZero() throws Exception { + long[] data = {10L, 20L, 30L, 40L, 50L}; + + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + MemorySegment metaSeg = MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + DeltaMetadata meta = DeltaMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + + assertThat(meta.deltas_len()).isGreaterThan(0); + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoderTest.java new file mode 100644 index 00000000..2d1f9bd4 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoderTest.java @@ -0,0 +1,162 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.DictMetadata; +import io.github.dfa1.vortex.reader.decode.DictEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.VarBinEncodingDecoder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import 
org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.lang.foreign.MemorySegment; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +class DictEncodingEncoderTest { + + private static final DictEncodingEncoder ENCODER = new DictEncodingEncoder(); + private static final DictEncodingDecoder DECODER = new DictEncodingDecoder(); + private static final ReadRegistry PRIM_REG = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); + private static final ReadRegistry UTF8_REG = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder(), new VarBinEncodingDecoder()); + + static Stream i32Arrays() { + return Stream.of( + new int[]{0}, + new int[]{1, 2, 3}, + new int[]{0, 1, 2, 0, 1, 2, 0, 1, 2}, + new int[]{42, 42, 42, 42, 42}, + new int[]{Integer.MIN_VALUE, Integer.MAX_VALUE, 0, Integer.MIN_VALUE, Integer.MAX_VALUE} + ); + } + + static Stream i64Arrays() { + return Stream.of( + new long[]{0L}, + new long[]{Long.MIN_VALUE, Long.MAX_VALUE, 0L, Long.MIN_VALUE}, + new long[]{1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L} + ); + } + + static Stream repetitiveI32Arrays() { + return Stream.of( + Arguments.of("binary-100", repeat(new int[]{0, 1}, 50)), + Arguments.of("single-value-50", repeat(new int[]{42}, 50)), + Arguments.of("three-values-60", repeat(new int[]{10, 20, 30}, 20)) + ); + } + + static Stream utf8Arrays() { + return Stream.of( + Arguments.of("single", new String[]{"hello"}), + Arguments.of("all-unique", new String[]{"a", "b", "c"}), + Arguments.of("repeated", new String[]{"AAPL", "GOOG", "AAPL", "MSFT", "GOOG", "AAPL"}), + Arguments.of("unicode", new String[]{"café", "naïve", "café", "résumé", "naïve"}), + Arguments.of("empty-string", new String[]{"", "x", "", "y", ""}) + ); + } + + private static int[] repeat(int[] pattern, int times) { + int[] result = new int[pattern.length * times]; + for (int i = 0; i < times; i++) { + System.arraycopy(pattern, 0, 
result, i * pattern.length, pattern.length); + } + return result; + } + + @SuppressWarnings("SameParameterValue") + private static String[] repeat(String[] pattern, int times) { + String[] result = new String[pattern.length * times]; + for (int i = 0; i < times; i++) { + System.arraycopy(pattern, 0, result, i * pattern.length, pattern.length); + } + return result; + } + + @ParameterizedTest + @MethodSource("i32Arrays") + void encodeDecode_i32_isLossless(int[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, PRIM_REG); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); + } + } + + @ParameterizedTest + @MethodSource("i64Arrays") + void encodeDecode_i64_isLossless(long[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, PRIM_REG); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + } + } + + @ParameterizedTest(name = "{0}") + @MethodSource("repetitiveI32Arrays") + void encodedSize_lowCardinality_compressesWellVsRaw(String name, int[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + + long encodedBytes = encoded.buffers().stream().mapToLong(MemorySegment::byteSize).sum(); + long rawBytes = (long) data.length * 4; + assertThat(encodedBytes).isLessThan(rawBytes); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("utf8Arrays") + void 
encodeDecode_utf8_isLossless(String name, String[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.UTF8, UTF8_REG); + Array result = DECODER.decode(ctx); + + assertThat(result).isInstanceOf(VarBinArray.class); + VarBinArray arr = (VarBinArray) result; + assertThat(arr.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + String actual = arr.getString(i); + assertThat(actual).as("index %d", i).isEqualTo(data[i]); + } + } + + @Test + void encodedSize_lowCardinalityUtf8_compressesWellVsRaw() { + String[] symbols = {"AAPL", "GOOG", "MSFT"}; + String[] data = repeat(symbols, 1000); + + EncodeResult encoded = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + + long encodedBytes = encoded.buffers().stream().mapToLong(MemorySegment::byteSize).sum(); + long rawBytes = 3000L * 5; + assertThat(encodedBytes).isLessThan(rawBytes); + } + + @Test + void encode_utf8_metadata_valuesLen_matchesUniqueCount() throws Exception { + String[] data = {"apple", "banana", "apple", "banana", "apple"}; + + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + DictMetadata meta = DictMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + + assertThat(meta.values_len()).isEqualTo(2); + } +} diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ExtEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoderTest.java similarity index 71% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ExtEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoderTest.java index 5db5bed7..6f74ef9b 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ExtEncodingTest.java +++ 
b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoderTest.java @@ -1,8 +1,21 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.ChildSlot; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.encoding.TestSegments; +import io.github.dfa1.vortex.reader.decode.ExtEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -11,59 +24,49 @@ import static org.assertj.core.api.Assertions.assertThat; -class ExtEncodingTest { +class ExtEncodingEncoderTest { + + private static final ExtEncodingEncoder ENCODER = new ExtEncodingEncoder(); + private static final ExtEncodingDecoder DECODER = new ExtEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(new PrimitiveEncodingDecoder(), DECODER); @Nested class Encode { @Test void accepts_extensionDtype_returnsTrue() { - // Given DType extDType = new DType.Extension("vortex.timestamp", new DType.Primitive(PType.I64, false), null, false); - var sut = new ExtEncoding(); - - // When / Then - assertThat(sut.accepts(extDType)).isTrue(); + assertThat(ENCODER.accepts(extDType)).isTrue(); } @Test void accepts_primitiveDtype_returnsFalse() { - // Given - var sut = new ExtEncoding(); - - // When / Then - 
assertThat(sut.accepts(new DType.Primitive(PType.I64, false))).isFalse(); + assertThat(ENCODER.accepts(new DType.Primitive(PType.I64, false))).isFalse(); } @Test void encode_extensionWrappingI64_roundTrips() { - // Given long[] data = {100L, 200L, 300L, 400L}; DType storageDType = new DType.Primitive(PType.I64, false); DType extDType = new DType.Extension("vortex.timestamp", storageDType, null, false); - var sut = new ExtEncoding(); - // When - EncodeResult result = sut.encode(extDType, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(extDType, data, EncodeTestHelper.testCtx()); - // Then — root is ext, child is primitive assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_EXT); assertThat(result.rootNode().children()).hasSize(1); assertThat(result.rootNode().children()[0].encodingId()).isEqualTo(EncodingId.VORTEX_PRIMITIVE); - // Decode back - Registry registry = TestRegistry.of(new PrimitiveEncoding(), new ExtEncoding()); ArrayNode rootNode = encodeNodeToArrayNode(result.rootNode()); DecodeContext ctx = new DecodeContext( rootNode, extDType, data.length, result.buffers().toArray(MemorySegment[]::new), - registry, Arena.ofAuto()); - var decoded = sut.decode(ctx); + REGISTRY, Arena.ofAuto()); + var decoded = DECODER.decode(ctx); assertThat(decoded).isInstanceOf(LongArray.class); + LongArray longArray = (LongArray) decoded; for (int i = 0; i < data.length; i++) { - LongArray longArray = (LongArray) decoded; assertThat(longArray.getLong(i)).isEqualTo(data[i]); } } @@ -82,16 +85,12 @@ class Cascade { @Test void encodeCascade_exposesStorageAsOpenChild() { - // Given — extension wraps an I64 storage column long[] data = {100L, 200L, 300L, 400L}; DType storageDType = new DType.Primitive(PType.I64, false); DType extDType = new DType.Extension("vortex.timestamp", storageDType, null, false); - var sut = new ExtEncoding(); - // When - CascadeStep step = sut.encodeCascade(extDType, data, EncodeTestHelper.testCtx()); + CascadeStep 
step = ENCODER.encodeCascade(extDType, data, EncodeTestHelper.testCtx()); - // Then — non-terminal step with one open slot for the storage child assertThat(step.applicable()).isTrue(); assertThat(step.isTerminal()).isFalse(); assertThat(step.openChildren()).hasSize(1); @@ -105,12 +104,8 @@ void encodeCascade_exposesStorageAsOpenChild() { @Test void encodeCascade_rejectsNonExtensionDtype() { - // Given - var sut = new ExtEncoding(); - - // When / Then org.assertj.core.api.Assertions.assertThatThrownBy(() -> - sut.encodeCascade(new DType.Primitive(PType.I64, false), new long[]{1L}, + ENCODER.encodeCascade(new DType.Primitive(PType.I64, false), new long[]{1L}, EncodeTestHelper.testCtx())) .isInstanceOf(io.github.dfa1.vortex.core.VortexException.class) .hasMessageContaining("expected extension dtype"); @@ -122,33 +117,24 @@ class Decode { @Test void decode_extensionWrappingI64_returnsStorageArray() { - // Given long[] values = {10L, 20L, 30L, 40L}; MemorySegment buf = TestSegments.leLongs(values); DType storageDType = new DType.Primitive(PType.I64, false); DType extDType = new DType.Extension("vortex.timestamp", storageDType, null, false); - // child node: vortex.primitive with buffer index 0 ArrayNode primitiveNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0}, null); - // parent node: vortex.ext, no buffers, one child ArrayNode extNode = ArrayNode.of(EncodingId.VORTEX_EXT, null, new ArrayNode[]{primitiveNode}, new int[0], null); - Registry registry = TestRegistry.of(new PrimitiveEncoding(), new ExtEncoding()); - DecodeContext ctx = new DecodeContext( - extNode, extDType, values.length, new MemorySegment[]{buf}, registry, Arena.ofAuto()); - - var sut = new ExtEncoding(); + extNode, extDType, values.length, new MemorySegment[]{buf}, REGISTRY, Arena.ofAuto()); - // When - var result = sut.decode(ctx); + var result = DECODER.decode(ctx); - // Then assertThat(result).isInstanceOf(LongArray.class); 
assertThat(result.length()).isEqualTo(values.length); + LongArray longArray = (LongArray) result; for (int i = 0; i < values.length; i++) { - LongArray longArray = (LongArray) result; assertThat(longArray.getLong(i)).isEqualTo(values[i]); } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/FixedSizeListEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoderTest.java similarity index 67% rename from core/src/test/java/io/github/dfa1/vortex/encoding/FixedSizeListEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoderTest.java index 4745c9a2..e69a1072 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/FixedSizeListEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoderTest.java @@ -1,9 +1,21 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.array.FixedSizeListArray; import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.FixedSizeListData; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.reader.decode.FixedSizeListEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -13,8 +25,11 @@ import static 
org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class FixedSizeListEncodingTest { +class FixedSizeListEncodingEncoderTest { + private static final FixedSizeListEncodingEncoder ENCODER = new FixedSizeListEncodingEncoder(); + private static final FixedSizeListEncodingDecoder DECODER = new FixedSizeListEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); private static ArrayNode toArrayNode(EncodeNode node) { ArrayNode[] children = new ArrayNode[node.children().length]; @@ -24,47 +39,28 @@ private static ArrayNode toArrayNode(EncodeNode node) { return ArrayNode.of(node.encodingId(), node.metadata(), children, node.bufferIndices(), ArrayStats.empty()); } - private static Registry registry() { - return Registry.builder() - .register(new FixedSizeListEncoding()) - .register(new PrimitiveEncoding()) - .build(); - } - @Nested class Encode { @Test void accepts_fixedSizeListDtype_true() { - // Given - FixedSizeListEncoding sut = new FixedSizeListEncoding(); DType.FixedSizeList dtype = new DType.FixedSizeList(DTypes.I32, 3, false); - - // When / Then - assertThat(sut.accepts(dtype)).isTrue(); + assertThat(ENCODER.accepts(dtype)).isTrue(); } @Test void accepts_primitiveDtype_false() { - // Given - FixedSizeListEncoding sut = new FixedSizeListEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.I32)).isFalse(); + assertThat(ENCODER.accepts(DTypes.I32)).isFalse(); } @Test void encode_producesOneChild_noBuffers() { - // Given DType.FixedSizeList dtype = new DType.FixedSizeList(DTypes.I32, 2, false); int[] elements = {1, 2, 3, 4}; FixedSizeListData data = new FixedSizeListData(elements, 2); - FixedSizeListEncoding sut = new FixedSizeListEncoding(); - // When - EncodeResult result = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); - // Then 
assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_FIXED_SIZE_LIST); assertThat(result.rootNode().bufferIndices()).isEmpty(); assertThat(result.rootNode().children()).hasSize(1); @@ -76,20 +72,16 @@ class Decode { @Test void roundTrip_i32Elements_preservesValues() { - // Given DType.FixedSizeList dtype = new DType.FixedSizeList(DTypes.I32, 3, false); int[] elements = {10, 20, 30, 40, 50, 60}; FixedSizeListData data = new FixedSizeListData(elements, 2); - FixedSizeListEncoding sut = new FixedSizeListEncoding(); - // When - EncodeResult result = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), dtype, 2, bufs, registry(), Arena.global()); - FixedSizeListArray decoded = (FixedSizeListArray) sut.decode(ctx); + toArrayNode(result.rootNode()), dtype, 2, bufs, REGISTRY, Arena.global()); + FixedSizeListArray decoded = (FixedSizeListArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(2); assertThat(decoded.fixedSize()).isEqualTo(3); IntArray elems = (IntArray) decoded.elements(); @@ -101,20 +93,16 @@ void roundTrip_i32Elements_preservesValues() { @Test void roundTrip_fixedSizeOne_preservesValues() { - // Given DType.FixedSizeList dtype = new DType.FixedSizeList(DTypes.I32, 1, false); int[] elements = {7, 8, 9}; FixedSizeListData data = new FixedSizeListData(elements, 3); - FixedSizeListEncoding sut = new FixedSizeListEncoding(); - // When - EncodeResult result = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), dtype, 3, bufs, registry(), Arena.global()); - FixedSizeListArray 
decoded = (FixedSizeListArray) sut.decode(ctx); + toArrayNode(result.rootNode()), dtype, 3, bufs, REGISTRY, Arena.global()); + FixedSizeListArray decoded = (FixedSizeListArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(3); assertThat(decoded.fixedSize()).isEqualTo(1); IntArray elems = (IntArray) decoded.elements(); @@ -125,14 +113,11 @@ void roundTrip_fixedSizeOne_preservesValues() { @Test void decode_wrongDtype_throws() { - // Given - FixedSizeListEncoding sut = new FixedSizeListEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_FIXED_SIZE_LIST, null, new ArrayNode[0], new int[0], ArrayStats.empty()); - DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0, new MemorySegment[0], registry(), Arena.global()); + DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0, new MemorySegment[0], REGISTRY, Arena.global()); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> DECODER.decode(ctx)) .hasMessageContaining("DType.FixedSizeList"); } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/FrameOfReferenceEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoderTest.java similarity index 60% rename from core/src/test/java/io/github/dfa1/vortex/encoding/FrameOfReferenceEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoderTest.java index 6ea12e5e..0ebba22d 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/FrameOfReferenceEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoderTest.java @@ -1,4 +1,4 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; @@ -6,7 +6,20 @@ import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; import 
io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.FrameOfReferenceEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -20,8 +33,11 @@ import static org.assertj.core.api.Assertions.assertThat; -class FrameOfReferenceEncodingTest { +class FrameOfReferenceEncodingEncoderTest { + private static final FrameOfReferenceEncodingEncoder ENCODER = new FrameOfReferenceEncodingEncoder(); + private static final FrameOfReferenceEncodingDecoder DECODER = new FrameOfReferenceEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Decode { @@ -43,115 +59,74 @@ private static DecodeContext buildForContext( } ArrayNode childNode = ArrayNode.of( - EncodingId.VORTEX_PRIMITIVE, - null, - new ArrayNode[0], - new int[]{0}, - ArrayStats.empty() - ); - + EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0}, ArrayStats.empty()); ArrayNode forNode = ArrayNode.of( - EncodingId.FASTLANES_FOR, - ByteBuffer.wrap(metaBytes), - new ArrayNode[]{childNode}, - new int[0], - ArrayStats.empty() - ); + EncodingId.FASTLANES_FOR, ByteBuffer.wrap(metaBytes), + new 
ArrayNode[]{childNode}, new int[0], ArrayStats.empty()); MemorySegment[] segments = {MemorySegment.ofArray(childBytes)}; - - Registry registry = TestRegistry.of(new FrameOfReferenceEncoding(), new PrimitiveEncoding()); - - return new DecodeContext(forNode, dtype, residuals.length, segments, registry, java.lang.foreign.Arena.global()); + return new DecodeContext(forNode, dtype, residuals.length, segments, REGISTRY, java.lang.foreign.Arena.global()); } @Test void decode_i64_addsReferenceToResiduals() { - // Given long reference = 1000L; long[] residuals = {0, 1, 2, 3, 4}; long[] expected = {1000, 1001, 1002, 1003, 1004}; DecodeContext ctx = buildForContext(DTypes.I64, reference, residuals, PType.I64); - FrameOfReferenceEncoding sut = new FrameOfReferenceEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(residuals.length); - var layout = PTypeIO.LE_LONG; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) - .as("index %d", i) - .isEqualTo(expected[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)) + .as("index %d", i).isEqualTo(expected[i]); } } @Test void decode_i32_addsReferenceToResiduals() { - // Given long reference = -100L; long[] residuals = {0, 5, 10, 15}; int[] expected = {-100, -95, -90, -85}; DecodeContext ctx = buildForContext(DTypes.I32, reference, residuals, PType.I32); - FrameOfReferenceEncoding sut = new FrameOfReferenceEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(residuals.length); - var layout = PTypeIO.LE_INT; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 4)) - .as("index %d", i) - .isEqualTo(expected[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)) + .as("index %d", 
i).isEqualTo(expected[i]); } } @Test void decode_zeroReference_returnsChildUnchanged() { - // Given — reference == 0, should skip the add entirely long[] residuals = {7, 8, 9}; DecodeContext ctx = buildForContext(DTypes.I64, 0L, residuals, PType.I64); - FrameOfReferenceEncoding sut = new FrameOfReferenceEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then — values unchanged - var layout = PTypeIO.LE_LONG; for (int i = 0; i < residuals.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)).isEqualTo(residuals[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).isEqualTo(residuals[i]); } } @ParameterizedTest @ValueSource(longs = {Long.MIN_VALUE, Long.MAX_VALUE, -1L, 1L}) void decode_wrappingAdd_i64(long reference) { - // Given — wrapping arithmetic: MAX + 1 wraps to MIN long[] residuals = {1L}; DecodeContext ctx = buildForContext(DTypes.I64, reference, residuals, PType.I64); - FrameOfReferenceEncoding sut = new FrameOfReferenceEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then - var layout = PTypeIO.LE_LONG; - long got = ArraySegments.of(result).get(layout, 0L); + long got = ArraySegments.of(result).get(PTypeIO.LE_LONG, 0L); assertThat(got).isEqualTo(residuals[0] + reference); } @Test void decode_nullableResiduals_returnsMaskedArrayWithCorrectValues() { - // Given — 4 I32 residuals; positions 1 and 3 are null (validity: 0b00000101 = 0x05) - // Residuals: [0, 0, 5, 0], reference: 100 → valid outputs: [100, ?, 105, ?] 
long reference = 100L; long[] residuals = {0, 0, 5, 0}; - MemorySegment validitySeg = MemorySegment.ofArray(new byte[]{0x05}); // bits 0,2 + MemorySegment validitySeg = MemorySegment.ofArray(new byte[]{0x05}); byte[] residualBytes = new byte[residuals.length * 4]; ByteBuffer bb = ByteBuffer.wrap(residualBytes).order(ByteOrder.LITTLE_ENDIAN); @@ -167,26 +142,23 @@ void decode_nullableResiduals_returnsMaskedArrayWithCorrectValues() { ArrayNode forNode = ArrayNode.of( EncodingId.FASTLANES_FOR, ByteBuffer.wrap(metaBytes), new ArrayNode[]{primNode}, new int[0], ArrayStats.empty()); - Registry registry = TestRegistry.of(new FrameOfReferenceEncoding(), new PrimitiveEncoding(), new BoolEncoding()); + ReadRegistry registry = TestRegistry.ofDecoders( + new FrameOfReferenceEncodingDecoder(), new PrimitiveEncodingDecoder(), new BoolEncodingDecoder()); MemorySegment[] segments = {MemorySegment.ofArray(residualBytes), validitySeg}; DecodeContext ctx = new DecodeContext( forNode, DTypes.I32, residuals.length, segments, registry, java.lang.foreign.Arena.global()); - FrameOfReferenceEncoding sut = new FrameOfReferenceEncoding(); - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then — MaskedArray; reference added to valid positions only assertThat(result).isInstanceOf(MaskedArray.class); MaskedArray masked = (MaskedArray) result; assertThat(masked.isValid(0)).isTrue(); assertThat(masked.isValid(1)).isFalse(); assertThat(masked.isValid(2)).isTrue(); assertThat(masked.isValid(3)).isFalse(); - var layout = PTypeIO.LE_INT; - assertThat(ArraySegments.of(masked.inner()).get(layout, 0L)).isEqualTo(100); - assertThat(ArraySegments.of(masked.inner()).get(layout, 8L)).isEqualTo(105); + assertThat(ArraySegments.of(masked.inner()).get(PTypeIO.LE_INT, 0L)).isEqualTo(100); + assertThat(ArraySegments.of(masked.inner()).get(PTypeIO.LE_INT, 8L)).isEqualTo(105); } } @@ -215,40 +187,26 @@ static Stream i32Arrays() { @ParameterizedTest @MethodSource("i64Arrays") 
void encodeDecode_i64_isLossless(long[] data) { - // Given - var sut = new FrameOfReferenceEncoding(); - Registry registry = TestRegistry.withPrimitive(sut); - var le = PTypeIO.LE_LONG; - - // When - EncodeResult encoded = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(data.length); for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); } } @ParameterizedTest @MethodSource("i32Arrays") void encodeDecode_i32_isLossless(int[] data) { - // Given - var sut = new FrameOfReferenceEncoding(); - Registry registry = TestRegistry.withPrimitive(sut); - var le = PTypeIO.LE_INT; - - // When - EncodeResult encoded = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, REGISTRY); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(data.length); for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); } } } diff --git 
a/core/src/test/java/io/github/dfa1/vortex/encoding/FsstEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoderTest.java similarity index 60% rename from core/src/test/java/io/github/dfa1/vortex/encoding/FsstEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoderTest.java index cdd8eb20..7b464193 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/FsstEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoderTest.java @@ -1,9 +1,21 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.FSSTMetadata; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.FSSTMetadata; +import io.github.dfa1.vortex.reader.decode.FsstEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -20,7 +32,11 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class FsstEncodingTest { +class FsstEncodingEncoderTest { + + private static final FsstEncodingEncoder ENCODER = new FsstEncodingEncoder(); + private 
static final FsstEncodingDecoder DECODER = new FsstEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Encode { @@ -56,50 +72,30 @@ private static String[] repeat(String s, int n) { @Test void accepts_utf8_true() { - // Given - var sut = new FsstEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.UTF8)).isTrue(); + assertThat(ENCODER.accepts(DTypes.UTF8)).isTrue(); } @Test void accepts_binary_true() { - // Given - var sut = new FsstEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.BINARY)).isTrue(); + assertThat(ENCODER.accepts(DTypes.BINARY)).isTrue(); } @Test void accepts_primitive_false() { - // Given - var sut = new FsstEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.I32)).isFalse(); + assertThat(ENCODER.accepts(DTypes.I32)).isFalse(); } @ParameterizedTest(name = "{0}") - @MethodSource("io.github.dfa1.vortex.encoding.FsstEncodingTest$Encode#stringArrays") + @MethodSource("stringArrays") void encode_thenDecode_roundtripsAllStrings(String name, String[] values) { - // Given - var sut = new FsstEncoding(); Arena arena = Arena.ofAuto(); - // When - EncodeResult result = sut.encode(DTypes.UTF8, values, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.UTF8, values, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); ArrayNode node = toArrayNode(result.rootNode()); - Registry registry = Registry.builder() - .register(new PrimitiveEncoding()) - .register(sut) - .build(); - DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, values.length, bufs, registry, arena); - var decoded = (VarBinArray) sut.decode(ctx); - - // Then + DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, values.length, bufs, REGISTRY, arena); + var decoded = (VarBinArray) DECODER.decode(ctx); + assertThat(decoded.length()).isEqualTo(values.length); for (int i = 0; i < values.length; i++) 
{ assertThat(decoded.getString(i)).as("index %d", i).isEqualTo(values[i]); @@ -116,31 +112,26 @@ private static DecodeContext buildCtx( ) { Arena arena = Arena.ofAuto(); - // Buffer 0: symbol table (8 bytes per symbol, LE u64) MemorySegment symBuf = arena.allocate(Math.max(symbols.length * 8L, 1), 8); for (int i = 0; i < symbols.length; i++) { symBuf.setAtIndex(PTypeIO.LE_LONG, i, symbols[i]); } - // Buffer 1: symbol lengths (1 byte per symbol) MemorySegment symLenBuf = arena.allocate(Math.max(symLens.length, 1)); for (int i = 0; i < symLens.length; i++) { symLenBuf.set(ValueLayout.JAVA_BYTE, i, symLens[i]); } - // Buffer 2: compressed bytes MemorySegment compBuf = arena.allocate(Math.max(compressed.length, 1)); for (int i = 0; i < compressed.length; i++) { compBuf.set(ValueLayout.JAVA_BYTE, i, compressed[i]); } - // Buffer 3: uncompressed_lengths (I32) MemorySegment uncompLenBuf = arena.allocate((long) uncompLens.length * Integer.BYTES, Integer.BYTES); for (int i = 0; i < uncompLens.length; i++) { uncompLenBuf.setAtIndex(PTypeIO.LE_INT, i, uncompLens[i]); } - // Buffer 4: codes_offsets (I32, n+1 elements) MemorySegment codesOffBuf = arena.allocate((long) codesOffsets.length * Integer.BYTES, Integer.BYTES); for (int i = 0; i < codesOffsets.length; i++) { codesOffBuf.setAtIndex(PTypeIO.LE_INT, i, codesOffsets[i]); @@ -158,103 +149,47 @@ private static DecodeContext buildCtx( EncodingId.VORTEX_FSST, ByteBuffer.wrap(metaBytes), new ArrayNode[]{uncompLensNode, codesOffNode}, new int[]{0, 1, 2}, null); - return new DecodeContext(root, DTypes.UTF8, n, segs, buildRegistry(), arena); - } - - private static Registry buildRegistry() { - return Registry.builder().register(new PrimitiveEncoding()).build(); + return new DecodeContext(root, DTypes.UTF8, n, segs, REGISTRY, arena); } @Test void decode_singleByteSymbol_decompressesCorrectly() { - // Given: symbol 0 = 'A' (LE u64 = 0x41, length 1); string "AA" → codes [0, 0] - var sut = new FsstEncoding(); DecodeContext ctx = 
buildCtx( - new long[]{0x41L}, - new byte[]{1}, - new byte[]{0x00, 0x00}, - new int[]{2}, - new int[]{0, 2}, - 1 - ); - - // When - var result = sut.decode(ctx); - - // Then - assertThat(result).isInstanceOf(VarBinArray.class); + new long[]{0x41L}, new byte[]{1}, new byte[]{0x00, 0x00}, + new int[]{2}, new int[]{0, 2}, 1); + var result = DECODER.decode(ctx); VarBinArray vba = (VarBinArray) result; - assertThat(vba.length()).isEqualTo(1); assertThat(vba.getBytes(0)).isEqualTo("AA".getBytes(StandardCharsets.UTF_8)); } @Test void decode_escapeByte_decompressesCorrectly() { - // Given: no symbols; string "XY" → ESCAPE 'X' ESCAPE 'Y' - var sut = new FsstEncoding(); DecodeContext ctx = buildCtx( - new long[0], - new byte[0], - new byte[]{(byte) 0xFF, 0x58, (byte) 0xFF, 0x59}, - new int[]{2}, - new int[]{0, 4}, - 1 - ); - - // When - var result = sut.decode(ctx); - - // Then - assertThat(result).isInstanceOf(VarBinArray.class); + new long[0], new byte[0], new byte[]{(byte) 0xFF, 0x58, (byte) 0xFF, 0x59}, + new int[]{2}, new int[]{0, 4}, 1); + var result = DECODER.decode(ctx); VarBinArray vba = (VarBinArray) result; - assertThat(vba.length()).isEqualTo(1); assertThat(vba.getBytes(0)).isEqualTo("XY".getBytes(StandardCharsets.UTF_8)); } @Test void decode_multiByteSymbol_decompressesCorrectly() { - // Given: symbol 0 = "ab" (LE u64 = 0x6261, length 2); string "ab" → code [0] - var sut = new FsstEncoding(); DecodeContext ctx = buildCtx( - new long[]{0x6261L}, - new byte[]{2}, - new byte[]{0x00}, - new int[]{2}, - new int[]{0, 1}, - 1 - ); - - // When - var result = sut.decode(ctx); - - // Then - assertThat(result).isInstanceOf(VarBinArray.class); + new long[]{0x6261L}, new byte[]{2}, new byte[]{0x00}, + new int[]{2}, new int[]{0, 1}, 1); + var result = DECODER.decode(ctx); VarBinArray vba = (VarBinArray) result; - assertThat(vba.length()).isEqualTo(1); assertThat(vba.getBytes(0)).isEqualTo("ab".getBytes(StandardCharsets.UTF_8)); } @Test void 
decode_multipleStrings_decompressesAll() { - // Given: symbol 0 = 'H'; strings ["H", "HH", "!"] where "!" uses ESCAPE - // compressed: [0x00] | [0x00, 0x00] | [0xFF, 0x21] - var sut = new FsstEncoding(); DecodeContext ctx = buildCtx( - new long[]{0x48L}, - new byte[]{1}, + new long[]{0x48L}, new byte[]{1}, new byte[]{0x00, 0x00, 0x00, (byte) 0xFF, 0x21}, - new int[]{1, 2, 1}, - new int[]{0, 1, 3, 5}, - 3 - ); - - // When - var result = sut.decode(ctx); - - // Then - assertThat(result).isInstanceOf(VarBinArray.class); + new int[]{1, 2, 1}, new int[]{0, 1, 3, 5}, 3); + var result = DECODER.decode(ctx); VarBinArray vba = (VarBinArray) result; - assertThat(vba.length()).isEqualTo(3); assertThat(vba.getBytes(0)).isEqualTo("H".getBytes(StandardCharsets.UTF_8)); assertThat(vba.getBytes(1)).isEqualTo("HH".getBytes(StandardCharsets.UTF_8)); assertThat(vba.getBytes(2)).isEqualTo("!".getBytes(StandardCharsets.UTF_8)); @@ -262,15 +197,9 @@ void decode_multipleStrings_decompressesAll() { @Test void decode_missingMetadata_throwsVortexException() { - // Given - var sut = new FsstEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_FSST, null, new ArrayNode[0], new int[0], null); - DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, 0, new MemorySegment[0], - buildRegistry(), Arena.ofAuto()); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .isInstanceOf(VortexException.class); + DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, 0, new MemorySegment[0], REGISTRY, Arena.ofAuto()); + assertThatThrownBy(() -> DECODER.decode(ctx)).isInstanceOf(VortexException.class); } } @@ -279,19 +208,12 @@ class Metadata { @Test void encode_metadata_ptypes_areI32() throws Exception { - // Given — FsstEncoding always uses I32 (ordinal=6) for both ptype fields - // if either tag drifts, the ptype reads as 0 (U8) which is proto3 default String[] data = {"hello", "world", "hello", "fsst"}; - var sut = new FsstEncoding(); - - // When - EncodeResult result = 
sut.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); - FSSTMetadata meta = - FSSTMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); - - // Then - assertThat(meta.uncompressed_lengths_ptype().value()).isEqualTo(6); // I32 - assertThat(meta.codes_offsets_ptype().value()).isEqualTo(6); // I32 + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + FSSTMetadata meta = FSSTMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + assertThat(meta.uncompressed_lengths_ptype().value()).isEqualTo(6); + assertThat(meta.codes_offsets_ptype().value()).isEqualTo(6); } } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ListEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoderTest.java similarity index 63% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ListEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoderTest.java index 975b556b..2a09fefc 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ListEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoderTest.java @@ -1,10 +1,22 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.ListMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.array.IntArray; import io.github.dfa1.vortex.core.array.ListArray; import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import 
io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.ListData; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.ListMetadata; +import io.github.dfa1.vortex.reader.decode.ListEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -14,8 +26,11 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class ListEncodingTest { +class ListEncodingEncoderTest { + private static final ListEncodingEncoder ENCODER = new ListEncodingEncoder(); + private static final ListEncodingDecoder DECODER = new ListEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); private static ArrayNode toArrayNode(EncodeNode node) { ArrayNode[] children = new ArrayNode[node.children().length]; @@ -25,46 +40,27 @@ private static ArrayNode toArrayNode(EncodeNode node) { return ArrayNode.of(node.encodingId(), node.metadata(), children, node.bufferIndices(), ArrayStats.empty()); } - private static Registry registry() { - return Registry.builder() - .register(new ListEncoding()) - .register(new PrimitiveEncoding()) - .build(); - } - @Nested class Encode { @Test void accepts_listDtype_true() { - // Given - ListEncoding sut = new ListEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.LIST_I32)).isTrue(); + assertThat(ENCODER.accepts(DTypes.LIST_I32)).isTrue(); } @Test void accepts_primitiveDtype_false() { - // Given - ListEncoding sut = new ListEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.I32)).isFalse(); + assertThat(ENCODER.accepts(DTypes.I32)).isFalse(); } @Test void 
encode_producesTwoChildren_noBuffers() { - // Given int[] elements = {1, 2, 3, 4, 5}; long[] offsets = {0, 2, 5}; ListData data = new ListData(elements, offsets, 2); - ListEncoding sut = new ListEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); - // Then assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_LIST); assertThat(result.rootNode().bufferIndices()).isEmpty(); assertThat(result.rootNode().children()).hasSize(2); @@ -76,20 +72,16 @@ class Decode { @Test void roundTrip_i32Elements_preservesValues() { - // Given int[] elements = {10, 20, 30, 40, 50}; long[] offsets = {0, 2, 3, 5}; ListData data = new ListData(elements, offsets, 3); - ListEncoding sut = new ListEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), DTypes.LIST_I32, 3, bufs, registry(), Arena.global()); - ListArray decoded = (ListArray) sut.decode(ctx); + toArrayNode(result.rootNode()), DTypes.LIST_I32, 3, bufs, REGISTRY, Arena.global()); + ListArray decoded = (ListArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(3); IntArray decodedElems = (IntArray) decoded.elements(); assertThat(decodedElems.length()).isEqualTo(5); @@ -105,20 +97,16 @@ void roundTrip_i32Elements_preservesValues() { @Test void roundTrip_emptyLists_preservesOffsets() { - // Given int[] elements = {}; long[] offsets = {0, 0, 0}; ListData data = new ListData(elements, offsets, 2); - ListEncoding sut = new ListEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = 
ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), DTypes.LIST_I32, 2, bufs, registry(), Arena.global()); - ListArray decoded = (ListArray) sut.decode(ctx); + toArrayNode(result.rootNode()), DTypes.LIST_I32, 2, bufs, REGISTRY, Arena.global()); + ListArray decoded = (ListArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(2); assertThat(decoded.elements().length()).isEqualTo(0); assertThat(decoded.offsets().length()).isEqualTo(3); @@ -126,20 +114,16 @@ void roundTrip_emptyLists_preservesOffsets() { @Test void roundTrip_singleList_preservesValues() { - // Given int[] elements = {7, 8, 9}; long[] offsets = {0, 3}; ListData data = new ListData(elements, offsets, 1); - ListEncoding sut = new ListEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), DTypes.LIST_I32, 1, bufs, registry(), Arena.global()); - ListArray decoded = (ListArray) sut.decode(ctx); + toArrayNode(result.rootNode()), DTypes.LIST_I32, 1, bufs, REGISTRY, Arena.global()); + ListArray decoded = (ListArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(1); IntArray decodedElems = (IntArray) decoded.elements(); assertThat(decodedElems.length()).isEqualTo(3); @@ -150,31 +134,23 @@ void roundTrip_singleList_preservesValues() { @Test void decode_wrongDtype_throws() { - // Given - ListEncoding sut = new ListEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_LIST, null, new ArrayNode[0], new int[0], ArrayStats.empty()); - DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0, new MemorySegment[0], registry(), 
Arena.global()); + DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0, new MemorySegment[0], REGISTRY, Arena.global()); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .hasMessageContaining("DType.List"); + assertThatThrownBy(() -> DECODER.decode(ctx)).hasMessageContaining("DType.List"); } @Test void decode_wrongChildCount_throws() { - // Given - ListEncoding sut = new ListEncoding(); ArrayNode child = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[0], ArrayStats.empty()); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_LIST, java.nio.ByteBuffer.wrap(new byte[0]), new ArrayNode[]{child}, new int[0], ArrayStats.empty()); - DecodeContext ctx = new DecodeContext(node, DTypes.LIST_I32, 0, new MemorySegment[0], registry(), Arena.global()); + DecodeContext ctx = new DecodeContext(node, DTypes.LIST_I32, 0, new MemorySegment[0], REGISTRY, Arena.global()); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .hasMessageContaining("expected 2 or 3 children"); + assertThatThrownBy(() -> DECODER.decode(ctx)).hasMessageContaining("expected 2 or 3 children"); } } @@ -183,19 +159,14 @@ class Metadata { @Test void encode_metadata_elementsLen_matchesElementCount() throws Exception { - // Given — 5 elements total across 2 outer lists - // if tag drifts, elements_len reads as 0 and decode allocates wrong-sized arrays int[] elements = {1, 2, 3, 4, 5}; long[] offsets = {0L, 2L, 5L}; ListData data = new ListData(elements, offsets, 2); - ListEncoding sut = new ListEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); - ListMetadata meta = - ListMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + var metaSeg = 
java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + ListMetadata meta = ListMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.elements_len()).isEqualTo(5); } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ListViewEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoderTest.java similarity index 68% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ListViewEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoderTest.java index 539ccfa5..8f4f8af4 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ListViewEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoderTest.java @@ -1,9 +1,22 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.ListViewMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.array.IntArray; import io.github.dfa1.vortex.core.array.ListViewArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.ListViewData; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestDecodeContexts; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.ListViewMetadata; +import io.github.dfa1.vortex.reader.decode.ListViewEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ 
-13,8 +26,11 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class ListViewEncodingTest { +class ListViewEncodingEncoderTest { + private static final ListViewEncodingEncoder ENCODER = new ListViewEncodingEncoder(); + private static final ListViewEncodingDecoder DECODER = new ListViewEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); private static ArrayNode toArrayNode(EncodeNode node) { ArrayNode[] children = new ArrayNode[node.children().length]; @@ -24,44 +40,28 @@ private static ArrayNode toArrayNode(EncodeNode node) { return ArrayNode.of(node.encodingId(), node.metadata(), children, node.bufferIndices(), ArrayStats.empty()); } - private static Registry registry() { - return TestRegistry.of(new ListViewEncoding(), new PrimitiveEncoding()); - } - @Nested class Encode { @Test void accepts_listDtype_true() { - // Given - ListViewEncoding sut = new ListViewEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.LIST_I32)).isTrue(); + assertThat(ENCODER.accepts(DTypes.LIST_I32)).isTrue(); } @Test void accepts_primitiveDtype_false() { - // Given - ListViewEncoding sut = new ListViewEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.I32)).isFalse(); + assertThat(ENCODER.accepts(DTypes.I32)).isFalse(); } @Test void encode_producesThreeChildren_noBuffers() { - // Given int[] elements = {1, 2, 3, 4, 5}; int[] offsets = {0, 2}; int[] sizes = {2, 3}; ListViewData data = new ListViewData(elements, offsets, sizes, 2); - ListViewEncoding sut = new ListViewEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); - // Then assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_LISTVIEW); assertThat(result.rootNode().bufferIndices()).isEmpty(); 
assertThat(result.rootNode().children()).hasSize(3); @@ -73,21 +73,17 @@ class Decode { @Test void roundTrip_i32Elements_preservesValues() { - // Given int[] elements = {10, 20, 30, 40, 50}; int[] offsets = {0, 2, 3}; int[] sizes = {2, 1, 2}; ListViewData data = new ListViewData(elements, offsets, sizes, 3); - ListViewEncoding sut = new ListViewEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), DTypes.LIST_I32, 3, bufs, registry(), Arena.global()); - ListViewArray decoded = (ListViewArray) sut.decode(ctx); + toArrayNode(result.rootNode()), DTypes.LIST_I32, 3, bufs, REGISTRY, Arena.global()); + ListViewArray decoded = (ListViewArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(3); IntArray decodedElems = (IntArray) decoded.elements(); assertThat(decodedElems.length()).isEqualTo(5); @@ -102,21 +98,17 @@ void roundTrip_i32Elements_preservesValues() { @Test void roundTrip_emptyLists_preservesZeroSizes() { - // Given int[] elements = {}; int[] offsets = {0, 0}; int[] sizes = {0, 0}; ListViewData data = new ListViewData(elements, offsets, sizes, 2); - ListViewEncoding sut = new ListViewEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), DTypes.LIST_I32, 2, bufs, registry(), Arena.global()); - ListViewArray decoded = (ListViewArray) sut.decode(ctx); + toArrayNode(result.rootNode()), DTypes.LIST_I32, 2, bufs, REGISTRY, Arena.global()); + ListViewArray decoded = 
(ListViewArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(2); assertThat(decoded.elements().length()).isEqualTo(0); assertThat(decoded.offsets().length()).isEqualTo(2); @@ -125,21 +117,17 @@ void roundTrip_emptyLists_preservesZeroSizes() { @Test void roundTrip_singleList_preservesValues() { - // Given int[] elements = {7, 8, 9}; int[] offsets = {0}; int[] sizes = {3}; ListViewData data = new ListViewData(elements, offsets, sizes, 1); - ListViewEncoding sut = new ListViewEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), DTypes.LIST_I32, 1, bufs, registry(), Arena.global()); - ListViewArray decoded = (ListViewArray) sut.decode(ctx); + toArrayNode(result.rootNode()), DTypes.LIST_I32, 1, bufs, REGISTRY, Arena.global()); + ListViewArray decoded = (ListViewArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(1); IntArray decodedElems = (IntArray) decoded.elements(); assertThat(decodedElems.length()).isEqualTo(3); @@ -150,31 +138,23 @@ void roundTrip_singleList_preservesValues() { @Test void decode_wrongDtype_throws() { - // Given - ListViewEncoding sut = new ListViewEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_LISTVIEW, null, new ArrayNode[0], new int[0], ArrayStats.empty()); - DecodeContext ctx = TestDecodeContexts.of(node, DTypes.I32).registry(registry()).build(); + DecodeContext ctx = TestDecodeContexts.of(node, DTypes.I32).registry(REGISTRY).build(); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .hasMessageContaining("DType.List"); + assertThatThrownBy(() -> DECODER.decode(ctx)).hasMessageContaining("DType.List"); } @Test void decode_wrongChildCount_throws() { - // Given - ListViewEncoding sut = 
new ListViewEncoding(); ArrayNode child = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[0], ArrayStats.empty()); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_LISTVIEW, java.nio.ByteBuffer.wrap(new byte[0]), new ArrayNode[]{child}, new int[0], ArrayStats.empty()); - DecodeContext ctx = TestDecodeContexts.of(node, DTypes.LIST_I32).registry(registry()).build(); + DecodeContext ctx = TestDecodeContexts.of(node, DTypes.LIST_I32).registry(REGISTRY).build(); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .hasMessageContaining("expected 3 or 4 children"); + assertThatThrownBy(() -> DECODER.decode(ctx)).hasMessageContaining("expected 3 or 4 children"); } } @@ -183,21 +163,16 @@ class Metadata { @Test void encode_metadata_elementsLen_matchesElementCount() throws Exception { - // Given — 5 elements across 2 outer lists - // if tag drifts, elements_len reads as 0 and decode allocates wrong-sized arrays int[] elements = {1, 2, 3, 4, 5}; int[] offsets = {0, 2}; int[] sizes = {2, 3}; ListViewData data = new ListViewData(elements, offsets, sizes, 2); - ListViewEncoding sut = new ListViewEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); java.nio.ByteBuffer metaBuf = result.rootNode().metadata().duplicate(); java.lang.foreign.MemorySegment metaSeg = java.lang.foreign.MemorySegment.ofBuffer(metaBuf); ListViewMetadata meta = ListViewMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.elements_len()).isEqualTo(5); } } diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoderTest.java new file mode 100644 index 00000000..f7939b52 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoderTest.java @@ -0,0 
+1,162 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.MaskedEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import org.junit.jupiter.api.Test; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.ArrayList; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class MaskedEncodingEncoderTest { + + private static final MaskedEncodingDecoder DECODER = new MaskedEncodingDecoder(); + private static final PrimitiveEncodingEncoder PRIM_ENCODER = new PrimitiveEncodingEncoder(); + private static final BoolEncodingEncoder BOOL_ENCODER = new BoolEncodingEncoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder(), new BoolEncodingDecoder()); + + private static EncodeResult maskedResult(int[] values, boolean[] validity) { + DType i32 = new DType.Primitive(PType.I32, false); + EncodeResult childResult = PRIM_ENCODER.encode(i32, values, EncodeTestHelper.testCtx()); + 
+ List allBuffers = new ArrayList<>(childResult.buffers()); + EncodeNode[] children; + + if (validity == null) { + children = new EncodeNode[]{childResult.rootNode()}; + } else { + DType boolDtype = new DType.Bool(false); + EncodeResult validityResult = BOOL_ENCODER.encode(boolDtype, validity, EncodeTestHelper.testCtx()); + EncodeNode remapped = EncodeNode.remapBufferIndices( + validityResult.rootNode(), childResult.buffers().size()); + allBuffers.addAll(validityResult.buffers()); + children = new EncodeNode[]{childResult.rootNode(), remapped}; + } + + EncodeNode maskedNode = new EncodeNode( + EncodingId.VORTEX_MASKED, null, children, new int[]{}); + return new EncodeResult(maskedNode, allBuffers, null, null); + } + + @Test + void oneChild_noValidity_allPositionsValid() { + DType i32Nullable = new DType.Primitive(PType.I32, true); + EncodeResult ctx = maskedResult(new int[]{10, 20, 30}, null); + + Array result = DECODER.decode(DecodeTestHelper.toDecodeContext(ctx, 3L, i32Nullable, REGISTRY)); + + assertThat(result).isInstanceOf(MaskedArray.class); + MaskedArray masked = (MaskedArray) result; + assertThat(masked.length()).isEqualTo(3); + assertThat(masked.isValid(0)).isTrue(); + assertThat(masked.isValid(1)).isTrue(); + assertThat(masked.isValid(2)).isTrue(); + } + + @Test + void twoChildren_withValidity_masksNulls() { + DType i32Nullable = new DType.Primitive(PType.I32, true); + EncodeResult ctx = maskedResult(new int[]{1, 2, 3, 4, 5}, + new boolean[]{true, false, true, false, true}); + + Array result = DECODER.decode(DecodeTestHelper.toDecodeContext(ctx, 5L, i32Nullable, REGISTRY)); + + MaskedArray masked = (MaskedArray) result; + assertThat(masked.length()).isEqualTo(5); + assertThat(masked.isValid(0)).isTrue(); + assertThat(masked.isValid(1)).isFalse(); + assertThat(masked.isValid(2)).isTrue(); + assertThat(masked.isValid(3)).isFalse(); + assertThat(masked.isValid(4)).isTrue(); + } + + @Test + void dtype_isNullable() { + DType i32Nullable = new 
DType.Primitive(PType.I32, true); + EncodeResult ctx = maskedResult(new int[]{1, 2, 3}, null); + + Array result = DECODER.decode(DecodeTestHelper.toDecodeContext(ctx, 3L, i32Nullable, REGISTRY)); + + assertThat(result.dtype().nullable()).isTrue(); + } + + @Test + void inner_containsChildValues() { + DType i32Nullable = new DType.Primitive(PType.I32, true); + EncodeResult ctx = maskedResult(new int[]{7, 8, 9}, null); + + MaskedArray result = (MaskedArray) DECODER.decode(DecodeTestHelper.toDecodeContext(ctx, 3L, i32Nullable, REGISTRY)); + IntArray inner = (IntArray) result.inner(); + + assertThat(ArraySegments.of(inner).get(PTypeIO.LE_INT, 0L)).isEqualTo(7); + assertThat(ArraySegments.of(inner).get(PTypeIO.LE_INT, 4L)).isEqualTo(8); + assertThat(ArraySegments.of(inner).get(PTypeIO.LE_INT, 8L)).isEqualTo(9); + } + + @Test + void buffersPresentThrows() { + DType i32Nullable = new DType.Primitive(PType.I32, true); + + EncodeNode childNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode maskedNode = new EncodeNode( + EncodingId.VORTEX_MASKED, null, + new EncodeNode[]{childNode}, new int[]{1}); + MemorySegment dummyBuf = Arena.ofAuto().allocate(4); + EncodeResult result = new EncodeResult(maskedNode, List.of(dummyBuf, dummyBuf), null, null); + + assertThatThrownBy(() -> DECODER.decode(DecodeTestHelper.toDecodeContext(result, 1L, i32Nullable, REGISTRY))) + .isInstanceOf(VortexException.class) + .hasMessageContaining("expected 0 buffers"); + } + + @Test + void zeroChildrenThrows() { + DType i32Nullable = new DType.Primitive(PType.I32, true); + + EncodeNode maskedNode = new EncodeNode( + EncodingId.VORTEX_MASKED, null, new EncodeNode[]{}, new int[]{}); + EncodeResult result = new EncodeResult(maskedNode, List.of(), null, null); + + assertThatThrownBy(() -> DECODER.decode(DecodeTestHelper.toDecodeContext(result, 0L, i32Nullable, REGISTRY))) + .isInstanceOf(VortexException.class) + .hasMessageContaining("expected 1 or 2 children"); + } + + @Test + void 
threeChildrenThrows() { + DType i32Nullable = new DType.Primitive(PType.I32, true); + + DType i32 = new DType.Primitive(PType.I32, false); + EncodeResult childResult = PRIM_ENCODER.encode(i32, new int[]{1}, EncodeTestHelper.testCtx()); + EncodeNode childNode = childResult.rootNode(); + EncodeNode maskedNode = new EncodeNode( + EncodingId.VORTEX_MASKED, null, + new EncodeNode[]{childNode, childNode, childNode}, new int[]{}); + List bufs = new ArrayList<>(childResult.buffers()); + EncodeResult result = new EncodeResult(maskedNode, bufs, null, null); + + assertThatThrownBy(() -> DECODER.decode(DecodeTestHelper.toDecodeContext(result, 1L, i32Nullable, REGISTRY))) + .isInstanceOf(VortexException.class) + .hasMessageContaining("expected 1 or 2 children"); + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoderTest.java new file mode 100644 index 00000000..b080585d --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoderTest.java @@ -0,0 +1,58 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.NullArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.NullEncodingDecoder; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; + +import static org.assertj.core.api.Assertions.assertThat; + +class NullEncodingEncoderTest { + + @Nested + class Encode { + + @Test + void encode_producesEmptyNode() { + // Given + var 
sut = new NullEncodingEncoder(); + + // When + EncodeResult result = sut.encode(DTypes.NULL, null, EncodeTestHelper.testCtx()); + + // Then + assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_NULL); + assertThat(result.rootNode().children()).isEmpty(); + assertThat(result.buffers()).isEmpty(); + } + + @Test + void encode_thenDecode_roundTrips() { + // Given + long rowCount = 10L; + var encoder = new NullEncodingEncoder(); + var decoder = new NullEncodingDecoder(); + + // When + EncodeResult encoded = encoder.encode(DTypes.NULL, null, EncodeTestHelper.testCtx()); + ArrayNode node = ArrayNode.of(encoded.rootNode().encodingId(), null, new ArrayNode[0], new int[0], null); + DecodeContext ctx = new DecodeContext(node, DTypes.NULL, rowCount, new MemorySegment[0], + ReadRegistry.empty(), Arena.ofAuto()); + + // Then + var decoded = decoder.decode(ctx); + assertThat(decoded).isInstanceOf(NullArray.class); + assertThat(decoded.length()).isEqualTo(rowCount); + } + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoderTest.java new file mode 100644 index 00000000..3779e503 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoderTest.java @@ -0,0 +1,22 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class PcoEncodingEncoderTest { + + @Test + void encode_throwsVortexException() { + var sut = new PcoEncodingEncoder(); + DType dtype = new DType.Primitive(PType.I64, false); + + assertThatThrownBy(() -> sut.encode(dtype, new long[]{1L, 2L, 3L}, EncodeTestHelper.testCtx())) + 
.isInstanceOf(VortexException.class) + .hasMessageContaining("not implemented"); + } +} diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/PrimitiveEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoderTest.java similarity index 63% rename from core/src/test/java/io/github/dfa1/vortex/encoding/PrimitiveEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoderTest.java index 8cb6ea37..4eb5f611 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/PrimitiveEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoderTest.java @@ -1,4 +1,4 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; @@ -7,6 +7,18 @@ import io.github.dfa1.vortex.core.array.ArraySegments; import io.github.dfa1.vortex.core.array.IntArray; import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.encoding.TestSegments; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -18,8 +30,11 @@ import static org.assertj.core.api.Assertions.assertThat; -/// Property: encode then decode is lossless for all 
primitive types and array sizes. -class PrimitiveEncodingTest { +class PrimitiveEncodingEncoderTest { + + private static final PrimitiveEncodingEncoder ENCODER = new PrimitiveEncodingEncoder(); + private static final PrimitiveEncodingDecoder DECODER = new PrimitiveEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER); @Nested class Encode { @@ -57,78 +72,52 @@ static Stream doubleArrays() { @ParameterizedTest @MethodSource("longArrays") void encodeDecode_i64_isLossless(long[] data) { - // Given DType dtype = new DType.Primitive(PType.I64, false); - var sut = new PrimitiveEncoding(); - Registry registry = TestRegistry.of(sut); - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); - // Then — roundtrip lossless assertThat(result.length()).isEqualTo(data.length); - var le = PTypeIO.LE_LONG; for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 8)).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).isEqualTo(data[i]); } } @ParameterizedTest @MethodSource("intArrays") void encodeDecode_i32_isLossless(int[] data) { - // Given DType dtype = new DType.Primitive(PType.I32, false); - var sut = new PrimitiveEncoding(); - Registry registry = TestRegistry.of(sut); - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext 
ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); - // Then — roundtrip lossless assertThat(result.length()).isEqualTo(data.length); - var le = PTypeIO.LE_INT; for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 4)).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).isEqualTo(data[i]); } } @ParameterizedTest @MethodSource("doubleArrays") void encodeDecode_f64_isLossless(double[] data) { - // Given DType dtype = new DType.Primitive(PType.F64, false); - var sut = new PrimitiveEncoding(); - Registry registry = TestRegistry.of(sut); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry); - Array result = sut.decode(ctx); - - // Then — roundtrip lossless assertThat(result.length()).isEqualTo(data.length); - var le = PTypeIO.LE_DOUBLE; for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 8)).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, (long) i * 8)).isEqualTo(data[i]); } } @ParameterizedTest @MethodSource("longArrays") void encodedSize_equalsBytesInBuffer(long[] data) { - // Given DType dtype = new DType.Primitive(PType.I64, false); - var sut = new PrimitiveEncoding(); - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); - // Then — no compression: wire size = n * elemBytes - long totalBytes = encoded.buffers().stream().mapToLong(java.lang.foreign.MemorySegment::byteSize).sum(); + long 
totalBytes = encoded.buffers().stream().mapToLong(MemorySegment::byteSize).sum(); assertThat(totalBytes).isEqualTo((long) data.length * 8); } } @@ -138,17 +127,16 @@ class Decode { @Test void decode_withValidityChild_returnsMaskedArray() { - // Given — 4 I32 values; positions 1 and 3 are null (validity bitmap: 0b00000101 = 0x05) - int[] raw = {10, 0, 20, 0}; // garbage at null positions + int[] raw = {10, 0, 20, 0}; MemorySegment valuesSeg = TestSegments.leInts(raw); - MemorySegment validitySeg = MemorySegment.ofArray(new byte[]{0x05}); // bits 0,2 set + MemorySegment validitySeg = MemorySegment.ofArray(new byte[]{0x05}); ArrayNode validityNode = ArrayNode.of( EncodingId.VORTEX_BOOL, null, new ArrayNode[0], new int[]{1}, ArrayStats.empty()); ArrayNode primNode = ArrayNode.of( EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[]{validityNode}, new int[]{0}, ArrayStats.empty()); - Registry registry = TestRegistry.of(new PrimitiveEncoding(), new BoolEncoding()); + ReadRegistry registry = TestRegistry.ofDecoders(new PrimitiveEncodingDecoder(), new BoolEncodingDecoder()); DType dtype = new DType.Primitive(PType.I32, false); DecodeContext ctx = new DecodeContext( @@ -156,12 +144,8 @@ void decode_withValidityChild_returnsMaskedArray() { new MemorySegment[]{valuesSeg, validitySeg}, registry, Arena.global()); - PrimitiveEncoding sut = new PrimitiveEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then — returns MaskedArray; only valid positions are usable assertThat(result).isInstanceOf(MaskedArray.class); MaskedArray masked = (MaskedArray) result; assertThat(masked.inner()).isInstanceOf(IntArray.class); @@ -176,27 +160,19 @@ void decode_withValidityChild_returnsMaskedArray() { @Test void decode_noValidityChild_returnsPlainArray() { - // Given — 3 I32 values; no validity child int[] raw = {1, 2, 3}; MemorySegment valuesSeg = TestSegments.leInts(raw); ArrayNode primNode = ArrayNode.of( EncodingId.VORTEX_PRIMITIVE, null, new 
ArrayNode[0], new int[]{0}, ArrayStats.empty()); - Registry registry = TestRegistry.of(new PrimitiveEncoding()); - DType dtype = new DType.Primitive(PType.I32, false); DecodeContext ctx = new DecodeContext( primNode, dtype, raw.length, new MemorySegment[]{valuesSeg}, - registry, Arena.global()); - - PrimitiveEncoding sut = new PrimitiveEncoding(); - - // When - Array result = sut.decode(ctx); + REGISTRY, Arena.global()); - // Then — plain array, not MaskedArray + Array result = DECODER.decode(ctx); assertThat(result).isInstanceOf(IntArray.class); } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/RandomAccessTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/RandomAccessTest.java similarity index 51% rename from core/src/test/java/io/github/dfa1/vortex/encoding/RandomAccessTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/RandomAccessTest.java index a2a6266c..7430db51 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/RandomAccessTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/RandomAccessTest.java @@ -1,8 +1,19 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.FrameOfReferenceEncodingDecoder; +import 
io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -13,17 +24,6 @@ import static org.assertj.core.api.Assertions.assertThat; /// Verifies that decoded arrays support correct random-order element access. -/// -/// The O(1) random-access claim in docs/explanation.md holds only if every -/// position decodes independently. This test catches encoder bugs where -/// value[N] accidentally depends on value[N-1] — e.g. an unapplied delta, -/// an off-by-one in residual accumulation, or a stateful read cursor that -/// advances forward and gives wrong values when accessed out of order. -/// -/// Data: {@code value[i] = i * 1_000_003 + 7} — every position unique, -/// prime multiplier prevents accidental aliasing. -/// Access orders: reverse (N-1…0) and seeded random — the combination -/// catches symmetric bugs that reverse alone might miss. 
class RandomAccessTest { private static final int N = 1024; @@ -31,37 +31,36 @@ class RandomAccessTest { static Stream encodings() { return Stream.of( - Arguments.of("Primitive", new PrimitiveEncoding(), DTypes.I64), - Arguments.of("BitpackedU64", new BitpackedEncoding(), DTypes.U64), - Arguments.of("FrameOfReference", new FrameOfReferenceEncoding(), DTypes.I64) + Arguments.of("Primitive", new PrimitiveEncodingEncoder(), new PrimitiveEncodingDecoder(), DTypes.I64), + Arguments.of("BitpackedU64", new BitpackedEncodingEncoder(), new BitpackedEncodingDecoder(), DTypes.U64), + Arguments.of("FrameOfReference", new FrameOfReferenceEncodingEncoder(), new FrameOfReferenceEncodingDecoder(), DTypes.I64) ); } @ParameterizedTest(name = "{0}") @MethodSource("encodings") - void randomOrderAccess_matchesForwardOrder(String name, Encoding sut, DType dtype) { - // Given — unique value per position, no two rows the same + void randomOrderAccess_matchesForwardOrder(String name, + EncodingEncoder encoder, io.github.dfa1.vortex.reader.decode.EncodingDecoder decoder, DType dtype) { long[] original = new long[N]; for (int i = 0; i < N; i++) { original[i] = (long) i * 1_000_003L + 7L; } - Registry registry = TestRegistry.withPrimitive(sut); - EncodeResult encoded = sut.encode(dtype, original, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, N, dtype, registry); + ReadRegistry registry = decoder instanceof PrimitiveEncodingDecoder + ? 
TestRegistry.ofDecoders(decoder) + : TestRegistry.ofDecoders(decoder, new PrimitiveEncodingDecoder()); + EncodeResult encoded = encoder.encode(dtype, original, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, N, dtype, registry); - // When - Array array = sut.decode(ctx); + Array array = decoder.decode(ctx); LongArray result = (LongArray) array; - // Then — reverse order for (int i = N - 1; i >= 0; i--) { assertThat(result.getLong(i)) .as("reverse access at index %d", i) .isEqualTo(original[i]); } - // Then — random order (seeded for reproducibility) Random rng = new Random(SEED); for (int check = 0; check < N; check++) { int i = rng.nextInt(N); diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/RleEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoderTest.java similarity index 55% rename from core/src/test/java/io/github/dfa1/vortex/encoding/RleEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoderTest.java index f7adef03..09afba71 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/RleEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoderTest.java @@ -1,6 +1,5 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.RLEMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; @@ -8,6 +7,22 @@ import io.github.dfa1.vortex.core.array.ArraySegments; import io.github.dfa1.vortex.core.array.IntArray; import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import 
io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.reader.decode.KnownArrayNode; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.RLEMetadata; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.RleEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -20,14 +35,11 @@ import static org.assertj.core.api.Assertions.assertThat; -class RleEncodingTest { +class RleEncodingEncoderTest { - private static Registry registry() { - return Registry.builder() - .register(new RleEncoding()) - .register(new PrimitiveEncoding()) - .build(); - } + private static final RleEncodingEncoder ENCODER = new RleEncodingEncoder(); + private static final RleEncodingDecoder DECODER = new RleEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); private static KnownArrayNode toArrayNode(EncodeNode enc) { ArrayNode[] children = new ArrayNode[enc.children().length]; @@ -42,53 +54,35 @@ class Encode { @Test void roundTrip_empty_i32() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); - - // When - EncodeResult encoded = sut.encode(dtype, new int[0], EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, 0, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, new int[0], EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, 0, dtype, REGISTRY); + 
Array result = DECODER.decode(ctx); assertThat(result.length()).isZero(); } @Test void roundTrip_singleElement_i32() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int[] data = {42}; - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isEqualTo(1); assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, 0)).isEqualTo(42); } @Test void roundTrip_constantArray_i32() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int n = 2048; int[] data = new int[n]; for (int i = 0; i < n; i++) { data[i] = 99; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, n, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isEqualTo(n); for (int i = 0; i < n; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(99); @@ -97,41 +91,27 @@ void roundTrip_constantArray_i32() { @Test void roundTrip_classicRunLengthData_i32() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int[] data = {1, 1, 1, 2, 2, 3}; - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = 
EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry()); - Array result = sut.decode(ctx); - - // Then - assertThat(result.length()).isEqualTo(data.length); - int[] expected = {1, 1, 1, 2, 2, 3}; - for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(expected[i]); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); } } @Test void roundTrip_multipleChunks_i32() { - // Given: spans 3 chunks - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int n = 3000; int[] data = new int[n]; for (int i = 0; i < n; i++) { data[i] = i / 100; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, n, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isEqualTo(n); for (int i = 0; i < n; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(i / 100); @@ -140,18 +120,11 @@ void roundTrip_multipleChunks_i32() { @Test void roundTrip_i64() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I64, false); long[] data = {100L, 100L, 200L, 300L, 300L, 300L}; - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, 
data.length, dtype, registry()); - Array result = sut.decode(ctx); - - // Then - assertThat(result.length()).isEqualTo(data.length); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); for (int i = 0; i < data.length; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); } @@ -160,20 +133,14 @@ void roundTrip_i64() { @ParameterizedTest @ValueSource(ints = {1, 512, 1023, 1024, 1025, 2048, 2049}) void roundTrip_variousLengths_i32(int n) { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int[] data = new int[n]; for (int i = 0; i < n; i++) { data[i] = i / 50; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, n, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isEqualTo(n); for (int i = 0; i < n; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(i / 50); @@ -182,42 +149,27 @@ void roundTrip_variousLengths_i32(int n) { @Test void roundTrip_allDifferent_u16() { - // Given: worst case — every consecutive value is unique (no compression) - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.U16, false); short[] data = new short[256]; for (int i = 0; i < 256; i++) { data[i] = (short) i; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry()); - Array result 
= sut.decode(ctx); - - // Then - assertThat(result.length()).isEqualTo(data.length); - var le = PTypeIO.LE_SHORT; + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); for (int i = 0; i < data.length; i++) { - assertThat(Short.toUnsignedInt(ArraySegments.of(result).get(le, (long) i * 2))) + assertThat(Short.toUnsignedInt(ArraySegments.of(result).get(PTypeIO.LE_SHORT, (long) i * 2))) .as("index %d", i).isEqualTo(i); } } @Test void roundTrip_negativeValues_i32() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int[] data = {-3, -3, -1, -1, 0, 0, 5}; - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry()); - Array result = sut.decode(ctx); - - // Then - assertThat(result.length()).isEqualTo(data.length); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); for (int i = 0; i < data.length; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); } @@ -229,67 +181,46 @@ class Decode { @Test void decode_exactlyOneChunk_correctLength() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int[] data = new int[1024]; for (int i = 0; i < 1024; i++) { data[i] = i / 10; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, 1024, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + 
DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, 1024, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isEqualTo(1024); } @Test void decode_crossesChunkBoundary_correctValues() { - // Given: values span the chunk boundary at element 1024 - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int n = 2048; int[] data = new int[n]; for (int i = 0; i < n; i++) { data[i] = i / 100; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, registry()); - Array result = sut.decode(ctx); - - // Then — verify values near the chunk boundary + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, n, dtype, REGISTRY); + Array result = DECODER.decode(ctx); for (int i = 1000; i < 1048; i++) { - assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)) - .as("index %d", i).isEqualTo(i / 100); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(i / 100); } } @Test void decode_nullableIndices_returnsMaskedArrayWithCorrectValidity() { - // Given — encode [10, 10, 20, 20]; then inject a validity bitmap into the indices node - // so positions 1 and 3 are null. Valid outputs: [10, null, 20, null]. 
- var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int[] data = {10, 10, 20, 20}; - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); - // Original buffers: [values_buf, indices_buf, offsets_buf] List originalBufs = new ArrayList<>(encoded.buffers()); - // Validity bitmap: bits 0 and 2 set → 0b00000101 = 0x05 → positions 0,2 valid MemorySegment validityBuf = MemorySegment.ofArray(new byte[]{0x05}); - originalBufs.add(validityBuf); // index 3 + originalBufs.add(validityBuf); MemorySegment[] segments = originalBufs.toArray(new MemorySegment[0]); - // Rebuild the ArrayNode tree from the encode result KnownArrayNode origRoot = toArrayNode(encoded.rootNode()); - // RLE root children: [values(0), indices(1), offsets(2)] KnownArrayNode origIndices = (KnownArrayNode) origRoot.children()[1]; - // Wrap indices with a validity child pointing to buffer 3 ArrayNode validityNode = ArrayNode.of( EncodingId.VORTEX_BOOL, null, new ArrayNode[0], new int[]{3}, ArrayStats.empty()); ArrayNode nullableIndices = ArrayNode.of( @@ -300,17 +231,11 @@ void decode_nullableIndices_returnsMaskedArrayWithCorrectValidity() { new ArrayNode[]{origRoot.children()[0], nullableIndices, origRoot.children()[2]}, origRoot.bufferIndices(), ArrayStats.empty()); - Registry reg = Registry.builder() - .register(sut) - .register(new PrimitiveEncoding()) - .register(new BoolEncoding()) - .build(); + ReadRegistry reg = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder(), new BoolEncodingDecoder()); DecodeContext ctx = new DecodeContext(root, dtype, data.length, segments, reg, Arena.ofAuto()); - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then — MaskedArray; valid at positions 0 and 2 assertThat(result).isInstanceOf(MaskedArray.class); MaskedArray masked = (MaskedArray) result; assertThat(masked.isValid(0)).isTrue(); @@ 
-324,21 +249,15 @@ void decode_nullableIndices_returnsMaskedArrayWithCorrectValidity() { @Test void decode_partialLastChunk_correctLength() { - // Given: 1500 elements — two chunks (1024 full + 476 partial) - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int n = 1500; int[] data = new int[n]; for (int i = 0; i < n; i++) { data[i] = i / 100; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, n, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isEqualTo(n); for (int i = 0; i < n; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(i / 100); @@ -347,19 +266,14 @@ void decode_partialLastChunk_correctLength() { @Test void encode_i32_metadata_valuesLen_matchesRunCount() throws Exception { - // Given — 2 distinct runs; if tag drifts, values_len reads as 0 (proto3 default) int[] data = {1, 1, 1, 2, 2, 2}; - RleEncoding sut = new RleEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); - RLEMetadata meta = - RLEMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); + EncodeResult result = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + var metaSeg = MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + RLEMetadata meta = RLEMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.values_len()).isEqualTo(2); assertThat(meta.indices_len()).isGreaterThan(0); - 
assertThat(meta.indices_ptype().value()).isEqualTo(1); // U16 + assertThat(meta.indices_ptype().value()).isEqualTo(1); } } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/RunEndEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoderTest.java similarity index 62% rename from core/src/test/java/io/github/dfa1/vortex/encoding/RunEndEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoderTest.java index 1fb239ee..a14bb58f 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/RunEndEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoderTest.java @@ -1,11 +1,23 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.RunEndMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.RunEndMetadata; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.RunEndEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -18,7 +30,11 @@ import static 
org.assertj.core.api.Assertions.assertThat; -class RunEndEncodingTest { +class RunEndEncodingEncoderTest { + + private static final RunEndEncodingEncoder ENCODER = new RunEndEncodingEncoder(); + private static final RunEndEncodingDecoder DECODER = new RunEndEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Decode { @@ -27,8 +43,8 @@ private static DecodeContext buildCtx( DType dtype, long rowCount, long[] ends, long[] values, PType endsPtype, long offset ) { - byte[] metaBytes = new RunEndMetadata(io.github.dfa1.vortex.proto.PType.fromValue(endsPtype.ordinal()), ends.length, offset) - .encode(); + byte[] metaBytes = new RunEndMetadata( + io.github.dfa1.vortex.proto.PType.fromValue(endsPtype.ordinal()), ends.length, offset).encode(); byte[] endsBuf = toLEBytes(ends, endsPtype); byte[] valBuf = toLEBytes(values, PType.I64); @@ -47,9 +63,7 @@ private static DecodeContext buildCtx( MemorySegment.ofArray(valBuf) }; - Registry registry = TestRegistry.of(new RunEndEncoding(), new PrimitiveEncoding()); - - return new DecodeContext(reNode, dtype, rowCount, segments, registry, java.lang.foreign.Arena.global()); + return new DecodeContext(reNode, dtype, rowCount, segments, REGISTRY, java.lang.foreign.Arena.global()); } private static byte[] toLEBytes(long[] values, PType ptype) { @@ -70,61 +84,41 @@ private static byte[] toLEBytes(long[] values, PType ptype) { @Test void decode_singleRun_fillsAllElements() { - // Given — 1 run: ends=[5], values=[42]; output = [42, 42, 42, 42, 42] long[] ends = {5L}; long[] values = {42L}; DecodeContext ctx = buildCtx(DTypes.I64, 5, ends, values, PType.U32, 0L); - RunEndEncoding sut = new RunEndEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(5L); - var layout = PTypeIO.LE_LONG; for (int i = 0; i < 5; i++) { - 
assertThat(ArraySegments.of(result).get(layout, (long) i * 8)).isEqualTo(42L); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).isEqualTo(42L); } } @Test void decode_multipleRuns_expandsCorrectly() { - // Given — runs: [0,2)=10, [2,5)=20, [5,7)=30 - // ends=[2,5,7], values=[10,20,30] long[] ends = {2L, 5L, 7L}; long[] values = {10L, 20L, 30L}; DecodeContext ctx = buildCtx(DTypes.I64, 7, ends, values, PType.U32, 0L); - RunEndEncoding sut = new RunEndEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then long[] expected = {10, 10, 20, 20, 20, 30, 30}; - var layout = PTypeIO.LE_LONG; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)) .as("index %d", i).isEqualTo(expected[i]); } } @Test void decode_withOffset_skipsLogicalElements() { - // Given — logical array: [0,3)=10, [3,6)=20; offset=2, rowCount=3 - // output elements [2..5): [10, 20, 20] long[] ends = {3L, 6L}; long[] values = {10L, 20L}; DecodeContext ctx = buildCtx(DTypes.I64, 3, ends, values, PType.U32, 2L); - RunEndEncoding sut = new RunEndEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then long[] expected = {10L, 20L, 20L}; - var layout = PTypeIO.LE_LONG; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)) .as("index %d", i).isEqualTo(expected[i]); } } @@ -146,37 +140,25 @@ static Stream i64Arrays() { @ParameterizedTest @MethodSource("i64Arrays") void encodeDecode_i64_isLossless(long[] data) { - // Given - var sut = new RunEndEncoding(); - Registry registry = TestRegistry.withPrimitive(sut); - var le = PTypeIO.LE_LONG; + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx 
= DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); - // When - EncodeResult encoded = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, registry); - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(data.length); for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); } } @Test void encode_i64_metadata_numRuns_andEndsPtype() throws Exception { - // Given — 3 runs; if tag drifts, num_runs reads as 0 or ends_ptype reads as 0 (U8) long[] data = {1L, 1L, 2L, 2L, 3L}; - RunEndEncoding sut = new RunEndEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - RunEndMetadata meta = - RunEndMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + RunEndMetadata meta = RunEndMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.num_runs()).isEqualTo(3); - assertThat(meta.ends_ptype().value()).isEqualTo(2); // U32 + assertThat(meta.ends_ptype().value()).isEqualTo(2); } } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/SequenceEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoderTest.java similarity index 69% rename from core/src/test/java/io/github/dfa1/vortex/encoding/SequenceEncodingTest.java rename to 
writer/src/test/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoderTest.java index 7e5981b4..9487a7eb 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/SequenceEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoderTest.java @@ -1,4 +1,4 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.VortexException; @@ -8,8 +8,16 @@ import io.github.dfa1.vortex.core.array.FloatArray; import io.github.dfa1.vortex.core.array.IntArray; import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.proto.ScalarValue; import io.github.dfa1.vortex.proto.SequenceMetadata; +import io.github.dfa1.vortex.reader.decode.SequenceEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -24,7 +32,10 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class SequenceEncodingTest { +class SequenceEncodingEncoderTest { + + private static final SequenceEncodingEncoder ENCODER = new SequenceEncodingEncoder(); + private static final SequenceEncodingDecoder DECODER = new SequenceEncodingDecoder(); @Nested class Encode { @@ -32,27 +43,21 @@ class Encode { private static DecodeContext encodeResultToCtx(EncodeResult result, DType dtype, long n) { ByteBuffer meta = result.rootNode().metadata(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_SEQUENCE, meta, new ArrayNode[0], 
new int[0], null); - return new DecodeContext(node, dtype, n, new MemorySegment[0], Registry.empty(), Arena.ofAuto()); + return new DecodeContext(node, dtype, n, new MemorySegment[0], ReadRegistry.empty(), Arena.ofAuto()); } @Test void encodingId_isVortexSequence() { - // Given / When / Then - assertThat(new SequenceEncoding().encodingId()).isEqualTo(EncodingId.VORTEX_SEQUENCE); + assertThat(ENCODER.encodingId()).isEqualTo(EncodingId.VORTEX_SEQUENCE); } @Test void encode_i64_roundTrips() { - // Given - var sut = new SequenceEncoding(); long[] data = {10L, 12L, 14L, 16L}; - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); DecodeContext ctx = encodeResultToCtx(result, DTypes.I64, data.length); - LongArray decoded = (LongArray) sut.decode(ctx); + LongArray decoded = (LongArray) DECODER.decode(ctx); - // Then for (int i = 0; i < data.length; i++) { assertThat(decoded.getLong(i)).as("index %d", i).isEqualTo(data[i]); } @@ -60,16 +65,11 @@ void encode_i64_roundTrips() { @Test void encode_f64_roundTrips() { - // Given - var sut = new SequenceEncoding(); double[] data = {1.0, 1.5, 2.0, 2.5}; - - // When - EncodeResult result = sut.encode(DTypes.F64, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.F64, data, EncodeTestHelper.testCtx()); DecodeContext ctx = encodeResultToCtx(result, DTypes.F64, data.length); - DoubleArray decoded = (DoubleArray) sut.decode(ctx); + DoubleArray decoded = (DoubleArray) DECODER.decode(ctx); - // Then for (int i = 0; i < data.length; i++) { assertThat(decoded.getDouble(i)).as("index %d", i).isEqualTo(data[i]); } @@ -77,16 +77,11 @@ void encode_f64_roundTrips() { @Test void encode_f16_roundTrips() { - // Given — 0.0, 1.0, 2.0 as F16 bit patterns - var sut = new SequenceEncoding(); short[] data = {Float.floatToFloat16(0.0f), Float.floatToFloat16(1.0f), Float.floatToFloat16(2.0f)}; - - // 
When - EncodeResult result = sut.encode(DTypes.F16, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.F16, data, EncodeTestHelper.testCtx()); DecodeContext ctx = encodeResultToCtx(result, DTypes.F16, data.length); - Float16Array decoded = (Float16Array) sut.decode(ctx); + Float16Array decoded = (Float16Array) DECODER.decode(ctx); - // Then for (int i = 0; i < data.length; i++) { assertThat(decoded.getFloat(i)).as("index %d", i).isEqualTo(Float.float16ToFloat(data[i])); } @@ -94,22 +89,14 @@ void encode_f16_roundTrips() { @Test void encode_nonArithmeticSequence_throwsVortexException() { - // Given - var sut = new SequenceEncoding(); long[] data = {1L, 2L, 4L}; - - // When / Then - assertThatThrownBy(() -> sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx())) + assertThatThrownBy(() -> ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx())) .isInstanceOf(VortexException.class); } @Test void encode_nonPrimitiveDtype_throwsVortexException() { - // Given - var sut = new SequenceEncoding(); - - // When / Then - assertThatThrownBy(() -> sut.encode(new DType.Utf8(false), new long[]{1L}, EncodeTestHelper.testCtx())) + assertThatThrownBy(() -> ENCODER.encode(new DType.Utf8(false), new long[]{1L}, EncodeTestHelper.testCtx())) .isInstanceOf(VortexException.class); } } @@ -136,11 +123,9 @@ static Stream i32Sequences() { } private static DecodeContext makeCtx(byte[] meta, DType dtype, long n) { - ArrayNode node = ArrayNode.of( - EncodingId.VORTEX_SEQUENCE, - ByteBuffer.wrap(meta), - new ArrayNode[0], new int[0], null); - return new DecodeContext(node, dtype, n, new MemorySegment[0], Registry.empty(), Arena.ofAuto()); + ArrayNode node = ArrayNode.of(EncodingId.VORTEX_SEQUENCE, + ByteBuffer.wrap(meta), new ArrayNode[0], new int[0], null); + return new DecodeContext(node, dtype, n, new MemorySegment[0], ReadRegistry.empty(), Arena.ofAuto()); } private static byte[] intMeta(long base, long mul) { @@ -164,17 +149,12 @@ private static byte[] 
f16Meta(short baseShort, short mulShort) { @ParameterizedTest @MethodSource("i64Sequences") void decode_i64_generatesCorrectSequence(long base, long mul, long[] expected) { - // Given - var sut = new SequenceEncoding(); DecodeContext ctx = makeCtx(intMeta(base, mul), DTypes.I64, expected.length); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(expected.length); + LongArray longArray = (LongArray) result; for (int i = 0; i < expected.length; i++) { - LongArray longArray = (LongArray) result; assertThat(longArray.getLong(i)).as("index %d", i).isEqualTo(expected[i]); } } @@ -182,31 +162,21 @@ void decode_i64_generatesCorrectSequence(long base, long mul, long[] expected) { @ParameterizedTest @MethodSource("i32Sequences") void decode_i32_generatesCorrectSequence(long base, long mul, int[] expected) { - // Given - var sut = new SequenceEncoding(); DecodeContext ctx = makeCtx(intMeta(base, mul), DTypes.I32, expected.length); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(expected.length); + IntArray intArray = (IntArray) result; for (int i = 0; i < expected.length; i++) { - IntArray longArray = (IntArray) result; - assertThat(longArray.getInt(i)).as("index %d", i).isEqualTo(expected[i]); + assertThat(intArray.getInt(i)).as("index %d", i).isEqualTo(expected[i]); } } @Test void decode_f64_generatesCorrectSequence() { - // Given - var sut = new SequenceEncoding(); DecodeContext ctx = makeCtx(f64Meta(1.0, 0.5), DTypes.F64, 4); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(4); DoubleArray doubleArray = (DoubleArray) result; assertThat(doubleArray.getDouble(0)).isEqualTo(1.0); @@ -217,14 +187,9 @@ void decode_f64_generatesCorrectSequence() { @Test void decode_f32_generatesCorrectSequence() { - // Given - var sut = new 
SequenceEncoding(); DecodeContext ctx = makeCtx(f32Meta(0.0f, 1.0f), DTypes.F32, 3); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(3); FloatArray floatArray = (FloatArray) result; assertThat(floatArray.getFloat(0)).isEqualTo(0.0f); @@ -234,42 +199,27 @@ void decode_f32_generatesCorrectSequence() { @Test void decode_emptySequence_returnsZeroLengthArray() { - // Given - var sut = new SequenceEncoding(); DecodeContext ctx = makeCtx(intMeta(0, 1), DTypes.I64, 0); - - // When - Array result = sut.decode(ctx); - - // Then + Array result = DECODER.decode(ctx); assertThat(result.length()).isZero(); } @Test void decode_missingMetadata_throwsVortexException() { - // Given - var sut = new SequenceEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_SEQUENCE, null, new ArrayNode[0], new int[0], null); - DecodeContext ctx = new DecodeContext(node, DTypes.I64, 3, new MemorySegment[0], Registry.empty(), Arena.ofAuto()); + DecodeContext ctx = new DecodeContext(node, DTypes.I64, 3, new MemorySegment[0], ReadRegistry.empty(), Arena.ofAuto()); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .isInstanceOf(VortexException.class); + assertThatThrownBy(() -> DECODER.decode(ctx)).isInstanceOf(VortexException.class); } @Test void decode_f16_generatesCorrectSequence() { - // Given — 0.0, 1.0, 2.0 as F16 bit patterns via the direct-metadata path - var sut = new SequenceEncoding(); short baseShort = Float.floatToFloat16(0.0f); short mulShort = Float.floatToFloat16(1.0f); byte[] meta = f16Meta(baseShort, mulShort); DecodeContext ctx = makeCtx(meta, DTypes.F16, 3); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(3); Float16Array f16Array = (Float16Array) result; assertThat(f16Array.getFloat(0)).isEqualTo(0.0f); @@ -279,14 +229,9 @@ void decode_f16_generatesCorrectSequence() { @Test void 
decode_nonPrimitiveDtype_throwsVortexException() { - // Given - var sut = new SequenceEncoding(); DType utf8 = new DType.Utf8(false); DecodeContext ctx = makeCtx(intMeta(0, 1), utf8, 3); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .isInstanceOf(VortexException.class); + assertThatThrownBy(() -> DECODER.decode(ctx)).isInstanceOf(VortexException.class); } } @@ -295,17 +240,11 @@ class Metadata { @Test void encode_i64_metadata_base_andMultiplier_areSet() throws Exception { - // Given — arithmetic sequence {10, 12, 14, 16} → base=10, multiplier=2 - // if tag drifts, base/multiplier messages are missing (hasBase() == false) long[] data = {10L, 12L, 14L, 16L}; - SequenceEncoding sut = new SequenceEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); MemorySegment metaSeg = MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); SequenceMetadata meta = SequenceMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.base()).isNotNull(); assertThat(meta.multiplier()).isNotNull(); assertThat(meta.base().int64_value()).isEqualTo(10L); diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/SparseEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoderTest.java similarity index 67% rename from core/src/test/java/io/github/dfa1/vortex/encoding/SparseEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoderTest.java index dee032b9..65e2e501 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/SparseEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoderTest.java @@ -1,8 +1,5 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.PatchesMetadata; -import 
io.github.dfa1.vortex.proto.SparseMetadata; -import io.github.dfa1.vortex.proto.VarBinMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; @@ -10,12 +7,29 @@ import io.github.dfa1.vortex.core.array.ArraySegments; import io.github.dfa1.vortex.core.array.BoolArray; import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.NullValue; +import io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.SparseMetadata; +import io.github.dfa1.vortex.proto.VarBinMetadata; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.SparseEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.VarBinEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import io.github.dfa1.vortex.proto.NullValue; -import io.github.dfa1.vortex.proto.ScalarValue; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; @@ -26,8 +40,11 @@ import static org.assertj.core.api.Assertions.assertThat; -class SparseEncodingTest { +class SparseEncodingEncoderTest { + private static final SparseEncodingEncoder ENCODER = new 
SparseEncodingEncoder(); + private static final SparseEncodingDecoder DECODER = new SparseEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Encode { @@ -45,76 +62,48 @@ private static Array decodeResult(EncodeResult encoded, DType dtype, int n) { ArrayNode sparseNode = ArrayNode.of(root.encodingId(), root.metadata(), new ArrayNode[]{idxNode, valNode}, root.bufferIndices(), ArrayStats.empty()); - Registry registry = TestRegistry.of(new SparseEncoding(), new PrimitiveEncoding()); - - DecodeContext ctx = new DecodeContext(sparseNode, dtype, n, segments, registry, Arena.global()); - return new SparseEncoding().decode(ctx); + DecodeContext ctx = new DecodeContext(sparseNode, dtype, n, segments, REGISTRY, Arena.global()); + return DECODER.decode(ctx); } @Test void encode_allZeros_noPatches() throws java.io.IOException { - // Given long[] data = {0L, 0L, 0L, 0L}; - SparseEncoding sut = new SparseEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - - // Then - java.nio.ByteBuffer metaBuf = result.rootNode().metadata().duplicate(); - java.lang.foreign.MemorySegment metaSeg = java.lang.foreign.MemorySegment.ofBuffer(metaBuf); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); SparseMetadata meta = SparseMetadata.decode(metaSeg, 0, metaSeg.byteSize()); assertThat(meta.patches().len()).isZero(); } @Test void encode_withNonZero_createsPatches() throws java.io.IOException { - // Given — [0, 10, 0, 50, 0] long[] data = {0L, 10L, 0L, 50L, 0L}; - SparseEncoding sut = new SparseEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - - // Then - java.nio.ByteBuffer metaBuf = result.rootNode().metadata().duplicate(); - java.lang.foreign.MemorySegment 
metaSeg = java.lang.foreign.MemorySegment.ofBuffer(metaBuf); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); SparseMetadata meta = SparseMetadata.decode(metaSeg, 0, metaSeg.byteSize()); assertThat(meta.patches().len()).isEqualTo(2); } @Test void encode_roundTrip_i64() { - // Given — sparse long array long[] data = {0L, 0L, 42L, 0L, 99L, 0L, 0L, 7L}; - SparseEncoding sut = new SparseEncoding(); - - // When - EncodeResult encoded = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); Array decoded = decodeResult(encoded, DTypes.I64, data.length); - // Then - var layout = PTypeIO.LE_LONG; for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(decoded).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(decoded).get(PTypeIO.LE_LONG, (long) i * 8)) .as("index %d", i).isEqualTo(data[i]); } } @Test void encode_roundTrip_f64() { - // Given double[] data = {0.0, 3.14, 0.0, 0.0, 2.72}; - SparseEncoding sut = new SparseEncoding(); - - // When - EncodeResult encoded = sut.encode(DTypes.F64, data, EncodeTestHelper.testCtx()); + EncodeResult encoded = ENCODER.encode(DTypes.F64, data, EncodeTestHelper.testCtx()); Array decoded = decodeResult(encoded, DTypes.F64, data.length); - // Then - var layout = PTypeIO.LE_DOUBLE; for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(decoded).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(decoded).get(PTypeIO.LE_DOUBLE, (long) i * 8)) .as("index %d", i).isEqualTo(data[i]); } } @@ -122,16 +111,9 @@ void encode_roundTrip_f64() { @ParameterizedTest @ValueSource(ints = {0, 1, 100}) void encode_empty_or_allZero_noPatches(int size) throws java.io.IOException { - // Given long[] data = new long[size]; - SparseEncoding sut = new SparseEncoding(); - - // When - EncodeResult 
result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - - // Then - java.nio.ByteBuffer metaBuf = result.rootNode().metadata().duplicate(); - java.lang.foreign.MemorySegment metaSeg = java.lang.foreign.MemorySegment.ofBuffer(metaBuf); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); SparseMetadata meta = SparseMetadata.decode(metaSeg, 0, metaSeg.byteSize()); assertThat(meta.patches().len()).isZero(); } @@ -152,14 +134,11 @@ private static DecodeContext buildSparseCtxWithOffset( long[] patchIndices, long[] patchValues, long offset ) { byte[] fillBytes = ScalarValue.ofInt64Value(fillLong).encode(); - byte[] metaBytes = buildSparseMetaBytes(patchIndices.length, offset, idxPtype); - byte[] idxBuf = toLEBytes(patchIndices, idxPtype); byte[] valBuf = toLEBytes(patchValues, PType.I64); - return buildCtx(dtype, rowCount, fillBytes, metaBytes, idxBuf, valBuf, - new DType.Primitive(idxPtype, false)); + return buildCtx(dtype, rowCount, fillBytes, metaBytes, idxBuf, valBuf); } private static DecodeContext buildSparseCtxF64( @@ -168,20 +147,14 @@ private static DecodeContext buildSparseCtxF64( ) { byte[] fillBytes = ScalarValue.ofF64Value(fillDouble).encode(); byte[] metaBytes = buildSparseMetaBytes(patchIndices.length, 0L, PType.U32); - byte[] idxBuf = toLEBytes(patchIndices, PType.U32); byte[] valBuf = f64LEBytes(patchValues); - return buildCtx(dtype, rowCount, fillBytes, metaBytes, idxBuf, valBuf, - new DType.Primitive(PType.U32, false)); + return buildCtx(dtype, rowCount, fillBytes, metaBytes, idxBuf, valBuf); } - private static DecodeContext buildCtx( - DType dtype, long rowCount, - byte[] fillBytes, byte[] metaBytes, - byte[] idxBuf, byte[] valBuf, - DType idxDtype - ) { + private static DecodeContext buildCtx(DType dtype, long rowCount, + byte[] fillBytes, byte[] metaBytes, byte[] idxBuf, byte[] valBuf) { ArrayNode idxNode 
= ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{1}, ArrayStats.empty()); ArrayNode valNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, @@ -189,24 +162,20 @@ private static DecodeContext buildCtx( ArrayNode sparseNode = ArrayNode.of(EncodingId.VORTEX_SPARSE, ByteBuffer.wrap(metaBytes), new ArrayNode[]{idxNode, valNode}, - new int[]{0}, - ArrayStats.empty()); + new int[]{0}, ArrayStats.empty()); MemorySegment[] segments = { MemorySegment.ofArray(fillBytes), MemorySegment.ofArray(idxBuf), MemorySegment.ofArray(valBuf) }; - - Registry registry = TestRegistry.of(new SparseEncoding(), new PrimitiveEncoding()); - - return new DecodeContext(sparseNode, dtype, rowCount, segments, registry, java.lang.foreign.Arena.global()); + return new DecodeContext(sparseNode, dtype, rowCount, segments, REGISTRY, java.lang.foreign.Arena.global()); } private static byte[] buildSparseMetaBytes(long numPatches, long offset, PType idxPtype) { - PatchesMetadata patchesMeta = new PatchesMetadata(numPatches, offset, io.github.dfa1.vortex.proto.PType.fromValue(idxPtype.ordinal()), null, null, null); - return new SparseMetadata(patchesMeta) - .encode(); + PatchesMetadata patchesMeta = new PatchesMetadata(numPatches, offset, + io.github.dfa1.vortex.proto.PType.fromValue(idxPtype.ordinal()), null, null, null); + return new SparseMetadata(patchesMeta).encode(); } private static byte[] toLEBytes(long[] values, PType ptype) { @@ -245,114 +214,75 @@ private static byte[] intLEBytes(int[] values) { @Test void decode_noPatches_returnsFillValue() { - // Given — 5 elements, fill=99, no patches long fill = 99L; DecodeContext ctx = buildSparseCtx(DTypes.I64, 5, fill, PType.U32, new long[0], new long[0]); - SparseEncoding sut = new SparseEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(5L); - var layout = PTypeIO.LE_LONG; for (int i = 0; i < 5; i++) { - 
assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)) .as("index %d", i).isEqualTo(fill); } } @Test void decode_withPatches_overwritesAtIndices() { - // Given — 8 elements, fill=0, patches at indices [1, 5] with values [10, 50] long fill = 0L; long[] patchIndices = {1L, 5L}; long[] patchValues = {10L, 50L}; DecodeContext ctx = buildSparseCtx(DTypes.I64, 8, fill, PType.U32, patchIndices, patchValues); - SparseEncoding sut = new SparseEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then - var layout = PTypeIO.LE_LONG; long[] expected = {0, 10, 0, 0, 0, 50, 0, 0}; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)) .as("index %d", i).isEqualTo(expected[i]); } } @Test void decode_f64_fillAndPatches() { - // Given — 4 F64 elements, fill=NaN bits, patch at index 2 with value 3.14 double fillVal = Double.NaN; double patchVal = 3.14; DecodeContext ctx = buildSparseCtxF64(DTypes.F64, 4, fillVal, new long[]{2L}, new double[]{patchVal}); - SparseEncoding sut = new SparseEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then - var layout = PTypeIO.LE_DOUBLE; - assertThat(ArraySegments.of(result).get(layout, 0L)).isNaN(); - assertThat(ArraySegments.of(result).get(layout, 8L)).isNaN(); - assertThat(ArraySegments.of(result).get(layout, 16L)).isEqualTo(3.14); - assertThat(ArraySegments.of(result).get(layout, 24L)).isNaN(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 0L)).isNaN(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 8L)).isNaN(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 16L)).isEqualTo(3.14); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 24L)).isNaN(); } @Test void decode_offsetSubtracted() { - 
// Given — offset=10, patch index=12 → absolute position = 12 - 10 = 2 long[] patchIndices = {12L}; long[] patchValues = {777L}; DecodeContext ctx = buildSparseCtxWithOffset(DTypes.I64, 5, 0L, PType.U32, patchIndices, patchValues, 10L); - SparseEncoding sut = new SparseEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then - var layout = PTypeIO.LE_LONG; - assertThat(ArraySegments.of(result).get(layout, 16L)).isEqualTo(777L); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, 16L)).isEqualTo(777L); } - // regression: NULL_VALUE fill caused "unexpected scalar kind NULL_VALUE" on nullable cols @Test void decode_nullValueFill_treatedAsZero() { - // Given — fill encoded as ScalarValue.NULL_VALUE (as Rust writes for nullable cols) byte[] nullFill = ScalarValue.ofNullValue(NullValue.NULL_VALUE).encode(); byte[] meta = buildSparseMetaBytes(0, 0L, PType.U32); - DecodeContext ctx = buildCtx(DTypes.I64, 4, nullFill, meta, new byte[0], new byte[0], - new DType.Primitive(PType.U32, false)); - SparseEncoding sut = new SparseEncoding(); - - // When - Array result = sut.decode(ctx); + DecodeContext ctx = buildCtx(DTypes.I64, 4, nullFill, meta, new byte[0], new byte[0]); + Array result = DECODER.decode(ctx); - // Then - var layout = PTypeIO.LE_LONG; for (int i = 0; i < 4; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)).as("index %d", i).isZero(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isZero(); } } - // regression: Utf8 dtype caused "expected primitive dtype, got Utf8[nullable=true]" @Test void decode_utf8_noPatches_allEmpty() { - // Given — Utf8 sparse, no patches → all positions empty (null fill) DType utf8 = new DType.Utf8(true); byte[] nullFill = ScalarValue.ofNullValue(NullValue.NULL_VALUE).encode(); byte[] meta = buildSparseMetaBytes(0, 0L, PType.U32); - DecodeContext ctx = buildCtx(utf8, 3, nullFill, meta, new byte[0], new byte[0], - new 
DType.Primitive(PType.U32, false)); - SparseEncoding sut = new SparseEncoding(); - - // When - Array result = sut.decode(ctx); + DecodeContext ctx = buildCtx(utf8, 3, nullFill, meta, new byte[0], new byte[0]); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(3L); VarBinArray varBin = (VarBinArray) result; for (int i = 0; i < 3; i++) { @@ -360,10 +290,8 @@ void decode_utf8_noPatches_allEmpty() { } } - // regression: Utf8 dtype caused "expected primitive dtype, got Utf8[nullable=true]" @Test void decode_utf8_withPatches_writesStringsAtIndices() { - // Given — 5 Utf8 elements, patches at [1]="hi" and [3]="bye" DType utf8 = new DType.Utf8(true); byte[] nullFill = ScalarValue.ofNullValue(NullValue.NULL_VALUE).encode(); byte[] meta = buildSparseMetaBytes(2, 0L, PType.U32); @@ -384,7 +312,7 @@ void decode_utf8_withPatches_writesStringsAtIndices() { ByteBuffer.wrap(meta), new ArrayNode[]{idxNode, valNode}, new int[]{0}, ArrayStats.empty()); - Registry registry = TestRegistry.of(new SparseEncoding(), new PrimitiveEncoding(), new VarBinEncoding()); + ReadRegistry registry = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder(), new VarBinEncodingDecoder()); MemorySegment[] segments = { MemorySegment.ofArray(nullFill), @@ -393,12 +321,9 @@ void decode_utf8_withPatches_writesStringsAtIndices() { MemorySegment.ofArray(offsets), }; DecodeContext ctx = new DecodeContext(sparseNode, utf8, 5, segments, registry, Arena.global()); - SparseEncoding sut = new SparseEncoding(); - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then VarBinArray varBin = (VarBinArray) result; assertThat(varBin.length()).isEqualTo(5L); assertThat(varBin.getByteLength(0)).isZero(); @@ -408,10 +333,8 @@ void decode_utf8_withPatches_writesStringsAtIndices() { assertThat(varBin.getByteLength(4)).isZero(); } - // regression: Bool dtype caused "expected primitive dtype, got Bool[nullable=true]" @Test void 
decode_bool_withPatches_setsBitsAtIndices() { - // Given — 6 Bool elements, patches at [2]=true and [5]=true DType bool = new DType.Bool(true); byte[] nullFill = ScalarValue.ofNullValue(NullValue.NULL_VALUE).encode(); byte[] meta = buildSparseMetaBytes(2, 0L, PType.U32); @@ -426,7 +349,7 @@ void decode_bool_withPatches_setsBitsAtIndices() { ByteBuffer.wrap(meta), new ArrayNode[]{idxNode, valNode}, new int[]{0}, ArrayStats.empty()); - Registry registry = TestRegistry.of(new SparseEncoding(), new PrimitiveEncoding(), new BoolEncoding()); + ReadRegistry registry = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder(), new BoolEncodingDecoder()); MemorySegment[] segments = { MemorySegment.ofArray(nullFill), @@ -434,12 +357,9 @@ void decode_bool_withPatches_setsBitsAtIndices() { MemorySegment.ofArray(boolBits), }; DecodeContext ctx = new DecodeContext(sparseNode, bool, 6, segments, registry, Arena.global()); - SparseEncoding sut = new SparseEncoding(); - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then BoolArray boolArr = (BoolArray) result; assertThat(boolArr.length()).isEqualTo(6L); assertThat(boolArr.getBoolean(0)).isFalse(); diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/StructEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoderTest.java similarity index 65% rename from core/src/test/java/io/github/dfa1/vortex/encoding/StructEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoderTest.java index 777445f5..2c93f0ef 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/StructEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoderTest.java @@ -1,4 +1,4 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; @@ -7,6 +7,21 @@ import 
io.github.dfa1.vortex.core.array.LongArray; import io.github.dfa1.vortex.core.array.MaskedArray; import io.github.dfa1.vortex.core.array.StructArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.encoding.StructData; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.encoding.TestSegments; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.StructEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -16,7 +31,11 @@ import static org.assertj.core.api.Assertions.assertThat; -class StructEncodingTest { +class StructEncodingEncoderTest { + + private static final StructEncodingEncoder ENCODER = new StructEncodingEncoder(); + private static final StructEncodingDecoder DECODER = new StructEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); private static ArrayNode toArrayNode(EncodeNode node) { ArrayNode[] children = new ArrayNode[node.children().length]; @@ -31,35 +50,25 @@ class Encode { @Test void accepts_structDtype_trueForStruct_falseForPrimitive() { - // Given - StructEncoding sut = new StructEncoding(); - DType.Struct structDtype = new DType.Struct( - List.of("x"), List.of(DTypes.I64), false); - - // When / Then - assertThat(sut.accepts(structDtype)).isTrue(); - assertThat(sut.accepts(DTypes.I64)).isFalse(); + DType.Struct structDtype = new 
DType.Struct(List.of("x"), List.of(DTypes.I64), false); + assertThat(ENCODER.accepts(structDtype)).isTrue(); + assertThat(ENCODER.accepts(DTypes.I64)).isFalse(); } @Test void roundTrip_twoI64Fields_preservesValues() { - // Given long[] ids = {1L, 2L, 3L}; long[] values = {10L, 20L, 30L}; DType.Struct dtype = new DType.Struct( List.of("id", "value"), List.of(DTypes.I64, DTypes.I64), false); StructData data = new StructData(List.of(ids, values)); - StructEncoding sut = new StructEncoding(); - // When - EncodeResult result = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); - // Then — decode round-trip MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); - Registry registry = TestRegistry.of(new StructEncoding(), new PrimitiveEncoding()); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), dtype, ids.length, bufs, registry, Arena.global()); - StructArray decoded = (StructArray) sut.decode(ctx); + toArrayNode(result.rootNode()), dtype, ids.length, bufs, REGISTRY, Arena.global()); + StructArray decoded = (StructArray) DECODER.decode(ctx); assertThat(decoded.length()).isEqualTo(ids.length); assertThat(decoded.fieldCount()).isEqualTo(2); @@ -73,32 +82,25 @@ void roundTrip_twoI64Fields_preservesValues() { @Test void singleField_encodeResult_hasOneChildAndNoBuffers() { - // Given long[] data = {7L, 14L, 21L}; DType.Struct dtype = new DType.Struct(List.of("v"), List.of(DTypes.I64), false); - StructEncoding sut = new StructEncoding(); - // When - EncodeResult result = sut.encode(dtype, new StructData(List.of(data)), EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(dtype, new StructData(List.of(data)), EncodeTestHelper.testCtx()); - // Then — struct node wraps one field child with remapped buffers assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_STRUCT); assertThat(result.rootNode().children()).hasSize(1); 
assertThat(result.rootNode().bufferIndices()).isEmpty(); - assertThat(result.buffers()).hasSize(1); // one buffer for the DTypes.I64 field + assertThat(result.buffers()).hasSize(1); } @Test void fieldCountMismatch_throwsVortexException() { - // Given DType.Struct dtype = new DType.Struct(List.of("a", "b"), List.of(DTypes.I64, DTypes.I64), false); - StructData data = new StructData(List.of(new long[]{1L})); // only 1 field, dtype has 2 - StructEncoding sut = new StructEncoding(); + StructData data = new StructData(List.of(new long[]{1L})); - // When / Then org.junit.jupiter.api.Assertions.assertThrows( io.github.dfa1.vortex.core.VortexException.class, - () -> sut.encode(dtype, data, EncodeTestHelper.testCtx())); + () -> ENCODER.encode(dtype, data, EncodeTestHelper.testCtx())); } } @@ -115,27 +117,18 @@ private static ArrayNode boolNode(int bufferIdx) { new int[]{bufferIdx}, ArrayStats.empty()); } - private static DecodeContext buildStructCtx(ArrayNode structNode, MemorySegment[] segs, long rowCount) { - Registry registry = TestRegistry.of(new StructEncoding(), new PrimitiveEncoding()); - return new DecodeContext(structNode, DTypes.I64, rowCount, segs, registry, Arena.global()); - } - @Test void decode_nonNullableWrapper_oneChild_returnsValues() { - // Given — struct{values: DTypes.I64} (non-nullable, 1 child) long[] data = {10L, 20L, 30L}; MemorySegment seg = TestSegments.leLongs(data); ArrayNode valuesNode = primitiveNode(0); ArrayNode structNode = ArrayNode.of(EncodingId.VORTEX_STRUCT, null, new ArrayNode[]{valuesNode}, new int[0], ArrayStats.empty()); - DecodeContext ctx = buildStructCtx(structNode, new MemorySegment[]{seg}, data.length); - StructEncoding sut = new StructEncoding(); + DecodeContext ctx = new DecodeContext(structNode, DTypes.I64, data.length, + new MemorySegment[]{seg}, REGISTRY, Arena.global()); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(data.length); for 
(int i = 0; i < data.length; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).isEqualTo(data[i]); @@ -144,28 +137,23 @@ void decode_nonNullableWrapper_oneChild_returnsValues() { @Test void decode_nullableWrapper_twoChildren_returnsMaskedArray() { - // Given — struct{validity: Bool, values: DTypes.I64} (nullable, 2 children) long[] data = {7L, 14L, 21L}; - MemorySegment validitySeg = MemorySegment.ofArray(new byte[]{(byte) 0xFF}); // all valid + MemorySegment validitySeg = MemorySegment.ofArray(new byte[]{(byte) 0xFF}); MemorySegment valuesSeg = TestSegments.leLongs(data); - ArrayNode validityNode = boolNode(0); // slot 0 = validity bitmap - ArrayNode valuesNode = primitiveNode(1); // slot 1 = actual values + ArrayNode validityNode = boolNode(0); + ArrayNode valuesNode = primitiveNode(1); ArrayNode structNode = ArrayNode.of(EncodingId.VORTEX_STRUCT, null, new ArrayNode[]{validityNode, valuesNode}, new int[0], ArrayStats.empty()); - Registry registry = TestRegistry.of(new StructEncoding(), new PrimitiveEncoding(), new BoolEncoding()); + ReadRegistry registry = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder(), new BoolEncodingDecoder()); DecodeContext ctx = new DecodeContext( structNode, DTypes.I64, data.length, new MemorySegment[]{validitySeg, valuesSeg}, registry, Arena.global()); - StructEncoding sut = new StructEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then — validity preserved; values accessible via inner array assertThat(result).isInstanceOf(MaskedArray.class); MaskedArray masked = (MaskedArray) result; LongArray values = (LongArray) masked.inner(); diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoderTest.java new file mode 100644 index 00000000..c5ef1934 --- /dev/null +++ 
b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoderTest.java @@ -0,0 +1,146 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.VarBinMetadata; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.VarBinEncodingDecoder; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.charset.StandardCharsets; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class VarBinEncodingEncoderTest { + + private static final VarBinEncodingEncoder ENCODER = new VarBinEncodingEncoder(); + private static final VarBinEncodingDecoder DECODER = new VarBinEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); + + @Nested + class Encode { + + @Test + void encodingId_isVortexVarbin() { + assertThat(ENCODER.encodingId()).isEqualTo(EncodingId.VORTEX_VARBIN); + } + + @Test + void accepts_utf8Dtype_returnsTrue() { + assertThat(ENCODER.accepts(DTypes.UTF8)).isTrue(); + } + + @Test + void accepts_binaryDtype_returnsTrue() { + 
assertThat(ENCODER.accepts(DTypes.BINARY)).isTrue(); + } + + @Test + void accepts_primitiveDtype_returnsFalse() { + assertThat(ENCODER.accepts(new DType.Primitive(PType.I64, false))).isFalse(); + } + + @Test + void encode_singleString_roundTrips() { + String[] data = {"hello"}; + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, REGISTRY); + VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); + + assertThat(decoded.length()).isEqualTo(1); + assertThat(decoded.getBytes(0)).isEqualTo("hello".getBytes(StandardCharsets.UTF_8)); + } + + @Test + void encode_multipleStrings_roundTrips() { + String[] data = {"foo", "bar", "baz"}; + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, REGISTRY); + VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); + + assertThat(decoded.length()).isEqualTo(3); + for (int i = 0; i < data.length; i++) { + assertThat(decoded.getBytes(i)).isEqualTo(data[i].getBytes(StandardCharsets.UTF_8)); + } + } + + @Test + void encode_unicodeString_roundTrips() { + String[] data = {"héllo", "wörld", "日本語"}; + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, REGISTRY); + VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); + + assertThat(decoded.length()).isEqualTo(3); + for (int i = 0; i < data.length; i++) { + assertThat(decoded.getBytes(i)).isEqualTo(data[i].getBytes(StandardCharsets.UTF_8)); + } + } + + @Test + void encode_emptyStringInArray_roundTrips() { + String[] data = {"a", "", "b"}; + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, REGISTRY); + 
VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); + + assertThat(decoded.length()).isEqualTo(3); + assertThat(decoded.getBytes(0)).isEqualTo(new byte[]{'a'}); + assertThat(decoded.getBytes(1)).isEmpty(); + assertThat(decoded.getBytes(2)).isEqualTo(new byte[]{'b'}); + } + + @Test + void encode_emptyArray_producesZeroLengthResult() { + String[] data = {}; + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, REGISTRY); + VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); + + assertThat(decoded.length()).isZero(); + } + } + + @Nested + class Decode { + + @Test + void decode_missingMetadata_throwsVortexException() { + ArrayNode node = ArrayNode.of(EncodingId.VORTEX_VARBIN, null, new ArrayNode[0], new int[0], null); + DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, 3, new MemorySegment[0], + ReadRegistry.empty(), Arena.ofAuto()); + assertThatThrownBy(() -> DECODER.decode(ctx)) + .isInstanceOf(VortexException.class) + .hasMessageContaining("missing metadata"); + } + } + + @Nested + class Metadata { + + @Test + void encode_utf8_metadata_offsetsPtype_isI64() throws Exception { + String[] data = {"hello", "world"}; + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + VarBinMetadata meta = VarBinMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + + assertThat(meta.offsets_ptype().value()).isEqualTo(7); + } + } +} diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/VarBinViewEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoderTest.java similarity index 71% rename from core/src/test/java/io/github/dfa1/vortex/encoding/VarBinViewEncodingTest.java rename to 
writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoderTest.java index 457070a0..09a9a9b6 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/VarBinViewEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoderTest.java @@ -1,6 +1,16 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.reader.decode.VarBinViewEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -14,58 +24,43 @@ import static org.assertj.core.api.Assertions.assertThat; -/// Property: decode reconstructs every string value exactly, -/// regardless of inlined vs reference layout. 
-class VarBinViewEncodingTest { +class VarBinViewEncodingEncoderTest { + + private static final VarBinViewEncodingEncoder ENCODER = new VarBinViewEncodingEncoder(); + private static final VarBinViewEncodingDecoder DECODER = new VarBinViewEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER); @Nested class Encode { @Test void accepts_utf8_true() { - // Given - var sut = new VarBinViewEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.UTF8)).isTrue(); + assertThat(ENCODER.accepts(DTypes.UTF8)).isTrue(); } @Test void accepts_binary_true() { - // Given - var sut = new VarBinViewEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.BINARY)).isTrue(); + assertThat(ENCODER.accepts(DTypes.BINARY)).isTrue(); } @Test void accepts_primitive_false() { - // Given - var sut = new VarBinViewEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.I32)).isFalse(); + assertThat(ENCODER.accepts(DTypes.I32)).isFalse(); } @ParameterizedTest(name = "{0}") - @MethodSource("io.github.dfa1.vortex.encoding.VarBinViewEncodingTest$Decode#stringArrays") + @MethodSource("io.github.dfa1.vortex.writer.encode.VarBinViewEncodingEncoderTest$Decode#stringArrays") void encode_thenDecode_roundtripsAllStrings(String name, String[] values) { - // Given - var sut = new VarBinViewEncoding(); Arena arena = Arena.ofAuto(); - // When - EncodeResult result = sut.encode(DTypes.UTF8, values, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.UTF8, values, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); ArrayNode node = ArrayNode.of( EncodingId.VORTEX_VARBINVIEW, null, new ArrayNode[0], result.rootNode().bufferIndices(), null); - Registry registry = TestRegistry.of(sut); - DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, values.length, bufs, registry, arena); - var decoded = (VarBinArray) sut.decode(ctx); + DecodeContext ctx = new DecodeContext(node, 
DTypes.UTF8, values.length, bufs, REGISTRY, arena); + var decoded = (VarBinArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(values.length); for (int i = 0; i < values.length; i++) { assertThat(decoded.getString(i)).as("index %d", i).isEqualTo(values[i]); @@ -81,8 +76,8 @@ static Stream stringArrays() { Arguments.of("empty-array", new String[0]), Arguments.of("single-empty-string", new String[]{""}), Arguments.of("short-strings", new String[]{"hi", "ok", "no"}), - Arguments.of("exactly-12-bytes", new String[]{"123456789012"}), // max inlined - Arguments.of("just-over-12-bytes", new String[]{"1234567890123"}), // min reference + Arguments.of("exactly-12-bytes", new String[]{"123456789012"}), + Arguments.of("just-over-12-bytes", new String[]{"1234567890123"}), Arguments.of("long-strings", new String[]{"the quick brown fox jumps over the lazy dog"}), Arguments.of("mixed-lengths", new String[]{"a", "hello", "this is a longer string than twelve"}), Arguments.of("repeated-short", repeat("ab", 50)), @@ -99,13 +94,11 @@ private static String[] repeat(String s, int n) { } @ParameterizedTest(name = "{0}") - @MethodSource("io.github.dfa1.vortex.encoding.VarBinViewEncodingTest$Decode#stringArrays") + @MethodSource("stringArrays") void decode_roundtrip_returnsAllStrings(String name, String[] values) { - // Given Arena arena = Arena.ofAuto(); long n = values.length; - // Encode all long strings into one data buffer byte[][] bytesArr = new byte[values.length][]; int dataBufLen = 0; for (int i = 0; i < values.length; i++) { @@ -124,18 +117,15 @@ void decode_roundtrip_returnsAllStrings(String name, String[] values) { long viewOff = (long) i * 16; views.set(PTypeIO.LE_INT, viewOff, b.length); if (b.length <= 12) { - // inlined: data at viewOff+4 MemorySegment.copy(MemorySegment.ofArray(b), 0, views, viewOff + 4, b.length); } else { - // reference: buffer_index=0, offset=dataOffset - views.set(PTypeIO.LE_INT, viewOff + 8, 0); // buffer_index - 
views.set(PTypeIO.LE_INT, viewOff + 12, dataOffset); // offset + views.set(PTypeIO.LE_INT, viewOff + 8, 0); + views.set(PTypeIO.LE_INT, viewOff + 12, dataOffset); MemorySegment.copy(MemorySegment.ofArray(b), 0, dataBuf, dataOffset, b.length); dataOffset += b.length; } } - // bufferIndices: [0=dataBuf, 1=views] when data buffer needed, else just [0=views] int[] bufIndices; MemorySegment[] segBufs; if (dataBufLen > 0) { @@ -149,23 +139,17 @@ void decode_roundtrip_returnsAllStrings(String name, String[] values) { ArrayNode node = ArrayNode.of(EncodingId.VORTEX_VARBINVIEW, null, new ArrayNode[0], bufIndices, null); - Registry registry = TestRegistry.of(new VarBinViewEncoding()); - - DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, n, segBufs, registry, arena); - var sut = new VarBinViewEncoding(); + DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, n, segBufs, REGISTRY, arena); - // When - var result = sut.decode(ctx); + var result = DECODER.decode(ctx); - // Then assertThat(result).isInstanceOf(VarBinArray.class); assertThat(result.length()).isEqualTo(n); + VarBinArray varBinArray = (VarBinArray) result; for (int i = 0; i < values.length; i++) { - VarBinArray varBinArray = (VarBinArray) result; String decoded = varBinArray.getString(i); assertThat(decoded).as("index %d", i).isEqualTo(values[i]); } } } - } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ZigZagEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoderTest.java similarity index 66% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ZigZagEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoderTest.java index eec143e7..1be86e81 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ZigZagEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoderTest.java @@ -1,8 +1,20 @@ -package io.github.dfa1.vortex.encoding; +package 
io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.ZigZagEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -17,15 +29,17 @@ import static org.assertj.core.api.Assertions.assertThat; -class ZigZagEncodingTest { +class ZigZagEncodingEncoderTest { + private static final ZigZagEncodingEncoder ENCODER = new ZigZagEncodingEncoder(); + private static final ZigZagEncodingDecoder DECODER = new ZigZagEncodingDecoder(); + private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Decode { static Stream i32Cases() { return Stream.of( - // zigzag: 0→0, 1→-1, 2→1, 3→-2, 4→2 Arguments.of("zeros", new int[]{0, 0, 0}, new int[]{0, 0, 0}), Arguments.of("mixed", new int[]{0, 1, 2, 3, 4}, new int[]{0, -1, 1, -2, 2}), Arguments.of("large", new int[]{Integer.MAX_VALUE & ~1, (Integer.MAX_VALUE & ~1) | 1}, @@ -44,22 +58,16 @@ private static DecodeContext buildI32Ctx(int[] encodedUnsigned) { ArrayNode primitiveNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0}, null); ArrayNode zigzagNode = 
ArrayNode.of(EncodingId.VORTEX_ZIGZAG, null, new ArrayNode[]{primitiveNode}, new int[0], null); - Registry registry = TestRegistry.of(new ZigZagEncoding(), new PrimitiveEncoding()); return new DecodeContext(zigzagNode, DTypes.I32, encodedUnsigned.length, - new MemorySegment[]{seg}, registry, Arena.ofAuto()); + new MemorySegment[]{seg}, REGISTRY, Arena.ofAuto()); } @ParameterizedTest(name = "{0}") @MethodSource("i32Cases") void decode_i32_zigzagDecodesCorrectly(String name, int[] encoded, int[] expected) { - // Given DecodeContext ctx = buildI32Ctx(encoded); - var sut = new ZigZagEncoding(); + var result = DECODER.decode(ctx); - // When - var result = sut.decode(ctx); - - // Then assertThat(result).isInstanceOf(IntArray.class); assertThat(result.length()).isEqualTo(expected.length); MemorySegment seg = ArraySegments.of(result); @@ -71,14 +79,8 @@ void decode_i32_zigzagDecodesCorrectly(String name, int[] encoded, int[] expecte @Test void decode_empty_returnsEmptyArray() { - // Given DecodeContext ctx = buildI32Ctx(new int[]{}); - var sut = new ZigZagEncoding(); - - // When - var result = sut.decode(ctx); - - // Then + var result = DECODER.decode(ctx); assertThat(result.length()).isZero(); } } @@ -108,40 +110,26 @@ static Stream i64RoundtripArrays() { @ParameterizedTest @MethodSource("i32RoundtripArrays") void encodeDecode_i32_isLossless(int[] data) { - // Given - var sut = new ZigZagEncoding(); - Registry registry = TestRegistry.withPrimitive(sut); - var le = PTypeIO.LE_INT; + EncodeResult encoded = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, REGISTRY); + Array result = DECODER.decode(ctx); - // When - EncodeResult encoded = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, registry); - Array result = sut.decode(ctx); - - // Then 
assertThat(result.length()).isEqualTo(data.length); for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); } } @ParameterizedTest @MethodSource("i64RoundtripArrays") void encodeDecode_i64_isLossless(long[] data) { - // Given - var sut = new ZigZagEncoding(); - Registry registry = TestRegistry.withPrimitive(sut); - var le = PTypeIO.LE_LONG; - - // When - EncodeResult encoded = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(data.length); for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); } } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ZstdEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java similarity index 69% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ZstdEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java index b1b0eaff..0a8c0cda 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ZstdEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java @@ -1,7 +1,5 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import 
io.github.dfa1.vortex.proto.ZstdFrameMetadata; -import io.github.dfa1.vortex.proto.ZstdMetadata; import com.github.luben.zstd.ZstdCompressCtx; import io.airlift.compress.v3.zstd.ZstdCompressor; import io.airlift.compress.v3.zstd.ZstdJavaCompressor; @@ -12,6 +10,20 @@ import io.github.dfa1.vortex.core.array.LongArray; import io.github.dfa1.vortex.core.array.MaskedArray; import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.reader.decode.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.reader.decode.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.reader.decode.DecodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.reader.ReadRegistry; +import io.github.dfa1.vortex.reader.decode.TestRegistry; +import io.github.dfa1.vortex.proto.ZstdFrameMetadata; +import io.github.dfa1.vortex.proto.ZstdMetadata; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.ZstdEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -27,24 +39,21 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class ZstdEncodingTest { +class ZstdEncodingEncoderTest { + + private static final ZstdEncodingEncoder ENCODER = new ZstdEncodingEncoder(); + private static final ZstdEncodingDecoder DECODER = new ZstdEncodingDecoder(); + private static final BoolEncodingEncoder BOOL_ENCODER = new BoolEncodingEncoder(); @Nested class Encode { @Test void encode_i32_roundTrips() { - // Given - var sut = new ZstdEncoding(); int[] data = {10, 20, 30, 40}; - - // When - EncodeResult result = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = 
EncodeTestHelper.toDecodeContext(result, data.length, DTypes.I32, Registry.empty()); - IntArray decoded = (IntArray) sut.decode(ctx); - - // Then - assertThat(decoded.length()).isEqualTo(data.length); + EncodeResult result = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.I32, ReadRegistry.empty()); + IntArray decoded = (IntArray) DECODER.decode(ctx); for (int i = 0; i < data.length; i++) { assertThat(decoded.getInt(i)).as("index %d", i).isEqualTo(data[i]); } @@ -52,17 +61,10 @@ void encode_i32_roundTrips() { @Test void encode_i64_roundTrips() { - // Given - var sut = new ZstdEncoding(); long[] data = {100L, 200L, 300L}; - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.I64, Registry.empty()); - LongArray decoded = (LongArray) sut.decode(ctx); - - // Then - assertThat(decoded.length()).isEqualTo(data.length); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.I64, ReadRegistry.empty()); + LongArray decoded = (LongArray) DECODER.decode(ctx); for (int i = 0; i < data.length; i++) { assertThat(decoded.getLong(i)).as("index %d", i).isEqualTo(data[i]); } @@ -70,17 +72,10 @@ void encode_i64_roundTrips() { @Test void encode_utf8_roundTrips() { - // Given - var sut = new ZstdEncoding(); String[] data = {"hello", "world", "zstd"}; - - // When - EncodeResult result = sut.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, Registry.empty()); - VarBinArray decoded = (VarBinArray) sut.decode(ctx); - - // Then - assertThat(decoded.length()).isEqualTo(data.length); + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + 
DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, ReadRegistry.empty()); + VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); for (int i = 0; i < data.length; i++) { assertThat(decoded.getString(i)).as("index %d", i).isEqualTo(data[i]); } @@ -88,26 +83,16 @@ void encode_utf8_roundTrips() { @Test void encode_emptyArray_roundTrips() { - // Given - var sut = new ZstdEncoding(); int[] data = {}; - - // When - EncodeResult result = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.I32, Registry.empty()); - IntArray decoded = (IntArray) sut.decode(ctx); - - // Then + EncodeResult result = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.I32, ReadRegistry.empty()); + IntArray decoded = (IntArray) DECODER.decode(ctx); assertThat(decoded.length()).isZero(); } @Test void encode_unsupportedDtype_throwsVortexException() { - // Given - var sut = new ZstdEncoding(); - - // When / Then - assertThatThrownBy(() -> sut.encode(new DType.Null(false), null, EncodeTestHelper.testCtx())) + assertThatThrownBy(() -> ENCODER.encode(new DType.Null(false), null, EncodeTestHelper.testCtx())) .isInstanceOf(VortexException.class); } } @@ -118,7 +103,6 @@ class Decode { private static DecodeContext makeDictCtx( byte[] meta, DType dtype, long n, byte[] dictBytes, byte[]... compressedFrames ) { - // buffer[0] = dict, buffer[1..] 
= frames MemorySegment[] segments = new MemorySegment[1 + compressedFrames.length]; segments[0] = MemorySegment.ofArray(dictBytes); int[] bufIndices = new int[1 + compressedFrames.length]; @@ -129,11 +113,10 @@ private static DecodeContext makeDictCtx( } ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ZSTD, ByteBuffer.wrap(meta), new ArrayNode[0], bufIndices, null); - return new DecodeContext(node, dtype, n, segments, Registry.empty(), Arena.ofAuto()); + return new DecodeContext(node, dtype, n, segments, ReadRegistry.empty(), Arena.ofAuto()); } private static byte[] makeDictFor(byte[]... samples) { - // Repeat samples to meet zstd's minimum training data requirement (~1 KB) int total = 0; for (byte[] s : samples) { total += s.length; @@ -160,8 +143,7 @@ private static byte[] compressWithDict(byte[] data, byte[] dictBytes) { private static DecodeContext makeNullableCtx( byte[] meta, DType dtype, long n, boolean[] validityBits, byte[]... compressedFrames ) { - BoolEncoding boolEncoding = new BoolEncoding(); - EncodeResult validityResult = boolEncoding.encode(new DType.Bool(false), validityBits, EncodeTestHelper.testCtx()); + EncodeResult validityResult = BOOL_ENCODER.encode(new DType.Bool(false), validityBits, EncodeTestHelper.testCtx()); EncodeNode remappedValidity = EncodeNode.remapBufferIndices( validityResult.rootNode(), compressedFrames.length); @@ -177,7 +159,7 @@ private static DecodeContext makeNullableCtx( ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ZSTD, ByteBuffer.wrap(meta), new ArrayNode[]{validityNode}, bufIndices, null); - Registry registry = Registry.builder().register(new BoolEncoding()).build(); + ReadRegistry registry = TestRegistry.ofDecoders(new BoolEncodingDecoder()); return new DecodeContext(node, dtype, n, allSegments.toArray(new MemorySegment[0]), registry, Arena.ofAuto()); @@ -191,18 +173,6 @@ private static ArrayNode toArrayNode(EncodeNode enc) { return ArrayNode.of(enc.encodingId(), enc.metadata(), children, enc.bufferIndices(), null); 
} - private static DecodeContext makeCtx(byte[] meta, DType dtype, long n, byte[]... compressedFrames) { - MemorySegment[] segments = new MemorySegment[compressedFrames.length]; - int[] bufIndices = new int[compressedFrames.length]; - for (int i = 0; i < compressedFrames.length; i++) { - segments[i] = MemorySegment.ofArray(compressedFrames[i]); - bufIndices[i] = i; - } - ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ZSTD, ByteBuffer.wrap(meta), - new ArrayNode[0], bufIndices, null); - return new DecodeContext(node, dtype, n, segments, Registry.empty(), Arena.ofAuto()); - } - private static byte[] compress(byte[] input) { ZstdCompressor compressor = new ZstdJavaCompressor(); byte[] out = new byte[compressor.maxCompressedLength(input.length)]; @@ -226,14 +196,6 @@ private static byte[] toLeBytes(int[] values) { return buf.array(); } - private static byte[] toLeBytes(long[] values) { - ByteBuffer buf = ByteBuffer.allocate(values.length * 8).order(ByteOrder.LITTLE_ENDIAN); - for (long v : values) { - buf.putLong(v); - } - return buf.array(); - } - private static byte[] toLengthPrefixed(String[] strings) { int total = 0; for (String s : strings) { @@ -250,8 +212,6 @@ private static byte[] toLengthPrefixed(String[] strings) { @Test void decode_withDictionary_utf8_roundTrips() { - // Given - var sut = new ZstdEncoding(); String[] strings = {"hello", "world", "zstd"}; byte[] raw = toLengthPrefixed(strings); byte[] dictBytes = makeDictFor(raw); @@ -260,10 +220,8 @@ void decode_withDictionary_utf8_roundTrips() { java.util.List.of(new ZstdFrameMetadata(raw.length, strings.length))).encode(); DecodeContext ctx = makeDictCtx(meta, DTypes.UTF8, strings.length, dictBytes, compressed); - // When - VarBinArray result = (VarBinArray) sut.decode(ctx); + VarBinArray result = (VarBinArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(strings.length); for (int i = 0; i < strings.length; i++) { assertThat(result.getString(i)).as("index %d", 
i).isEqualTo(strings[i]); @@ -272,8 +230,6 @@ void decode_withDictionary_utf8_roundTrips() { @Test void decode_withDictionary_multipleFrames_roundTrips() { - // Given - var sut = new ZstdEncoding(); int[] frame0 = {1, 2, 3}; int[] frame1 = {4, 5}; byte[] raw0 = toLeBytes(frame0); @@ -286,10 +242,8 @@ void decode_withDictionary_multipleFrames_roundTrips() { new ZstdFrameMetadata(raw1.length, frame1.length))).encode(); DecodeContext ctx = makeDictCtx(meta, DTypes.I32, 5, dictBytes, comp0, comp1); - // When - IntArray result = (IntArray) sut.decode(ctx); + IntArray result = (IntArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(5); for (int i = 0; i < 3; i++) { assertThat(result.getInt(i)).isEqualTo(frame0[i]); @@ -301,11 +255,7 @@ void decode_withDictionary_multipleFrames_roundTrips() { @Test void decode_nullable_primitive_scattersValuesCorrectly() { - // Given - var sut = new ZstdEncoding(); - // validity: [true, false, true, false] — positions 0,2 are valid boolean[] validityBits = {true, false, true, false}; - // only valid values compressed: 10, 30 byte[] raw = toLeBytes(new int[]{10, 30}); byte[] compressed = compress(raw); DType i32Nullable = new DType.Primitive(PType.I32, true); @@ -313,10 +263,8 @@ void decode_nullable_primitive_scattersValuesCorrectly() { metaNoDict(new long[]{raw.length}, new long[]{2}), i32Nullable, 4, validityBits, compressed); - // When - MaskedArray result = (MaskedArray) sut.decode(ctx); + MaskedArray result = (MaskedArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(4); assertThat(result.isValid(0)).isTrue(); assertThat(result.isValid(1)).isFalse(); @@ -329,11 +277,7 @@ void decode_nullable_primitive_scattersValuesCorrectly() { @Test void decode_nullable_utf8_scattersValuesCorrectly() { - // Given - var sut = new ZstdEncoding(); - // validity: [true, false, true] — positions 0,2 are valid boolean[] validityBits = {true, false, true}; - // only valid strings compressed byte[] raw = 
toLengthPrefixed(new String[]{"hello", "world"}); byte[] compressed = compress(raw); DType utf8Nullable = new DType.Utf8(true); @@ -341,10 +285,8 @@ void decode_nullable_utf8_scattersValuesCorrectly() { metaNoDict(new long[]{raw.length}, new long[]{2}), utf8Nullable, 3, validityBits, compressed); - // When - MaskedArray result = (MaskedArray) sut.decode(ctx); + MaskedArray result = (MaskedArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(3); assertThat(result.isValid(0)).isTrue(); assertThat(result.isValid(1)).isFalse(); @@ -356,10 +298,7 @@ void decode_nullable_utf8_scattersValuesCorrectly() { @Test void decode_allNull_returnsEmptyMaskedArray() { - // Given - var sut = new ZstdEncoding(); boolean[] validityBits = {false, false, false}; - // no valid values — zero-length compressed buffer byte[] raw = new byte[0]; byte[] compressed = compress(raw); DType i32Nullable = new DType.Primitive(PType.I32, true); @@ -367,10 +306,8 @@ void decode_allNull_returnsEmptyMaskedArray() { metaNoDict(new long[]{raw.length}, new long[]{0}), i32Nullable, 3, validityBits, compressed); - // When - MaskedArray result = (MaskedArray) sut.decode(ctx); + MaskedArray result = (MaskedArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(3); assertThat(result.isValid(0)).isFalse(); assertThat(result.isValid(1)).isFalse(); @@ -379,14 +316,11 @@ void decode_allNull_returnsEmptyMaskedArray() { @Test void decode_missingMetadata_throwsVortexException() { - // Given - var sut = new ZstdEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ZSTD, null, new ArrayNode[0], new int[0], null); DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0, new MemorySegment[0], - Registry.empty(), Arena.ofAuto()); + ReadRegistry.empty(), Arena.ofAuto()); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> DECODER.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("missing metadata"); } @@ -397,20 
+331,14 @@ class Metadata { @Test void encode_i32_metadata_framesCount_isNonZero() throws Exception { - // Given — any non-empty encode produces at least one zstd frame - // if tag drifts, frames list is empty and decode silently produces no data int[] data = new int[100]; for (int i = 0; i < data.length; i++) { data[i] = i; } - ZstdEncoding sut = new ZstdEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); - ZstdMetadata meta = - ZstdMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); + EncodeResult result = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + ZstdMetadata meta = ZstdMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.frames().size()).isGreaterThan(0); } }