From 66c17ef509f8f33372ec0b232e1c7f8edb53cce0 Mon Sep 17 00:00:00 2001 From: guangtao Date: Mon, 30 Mar 2026 01:19:46 -0700 Subject: [PATCH 01/16] Add experimental Arrow Flight support --- .github/workflows/ci.yml | 44 + .github/workflows/ci_nightly.yml | 45 + Project.toml | 17 +- README.md | 28 +- dev/release/rat_exclude_files.txt | 1 + docs/src/manual.md | 10 + ext/ArrowgRPCServerExt.jl | 29 + ext/arrowgrpcserverext/constants.jl | 20 + ext/arrowgrpcserverext/context.jl | 46 + ext/arrowgrpcserverext/descriptor.jl | 39 + ext/arrowgrpcserverext/handlers.jl | 111 ++ ext/arrowgrpcserverext/streams.jl | 39 + src/Arrow.jl | 23 +- src/ArrowTypes/Project.toml | 2 +- src/ArrowTypes/src/ArrowTypes.jl | 93 +- src/ArrowTypes/test/tests.jl | 78 + src/arraytypes/arraytypes.jl | 84 +- src/arraytypes/bool.jl | 20 +- src/arraytypes/fixedsizelist.jl | 60 +- src/arraytypes/list.jl | 184 ++- src/arraytypes/map.jl | 115 +- src/arraytypes/primitive.jl | 16 + src/arraytypes/struct.jl | 65 +- src/flight/Flight.jl | 33 + src/flight/client.jl | 25 + src/flight/client/auth.jl | 91 ++ src/flight/client/constants.jl | 22 + src/flight/client/headers.jl | 90 ++ src/flight/client/locations.jl | 33 + src/flight/client/methods/actions.jl | 68 + src/flight/client/methods/data.jl | 96 ++ src/flight/client/methods/discovery.jl | 84 + src/flight/client/protocol_clients.jl | 78 + src/flight/client/rpc_methods.jl | 20 + src/flight/client/transport.jl | 132 ++ src/flight/client/types.jl | 93 ++ src/flight/convert.jl | 22 + src/flight/convert/constants.jl | 21 + src/flight/convert/flightdata.jl | 108 ++ src/flight/convert/framing.jl | 57 + src/flight/convert/schema.jl | 65 + src/flight/convert/streaming.jl | 80 + src/flight/exports.jl | 46 + src/flight/generated/arrow/arrow.jl | 24 + src/flight/generated/arrow/flight/flight.jl | 22 + .../arrow/flight/protocol/Flight_pb.jl | 1359 +++++++++++++++++ .../arrow/flight/protocol/protocol.jl | 24 + src/flight/generated/google/google.jl | 22 + 
.../generated/google/protobuf/protobuf.jl | 22 + .../generated/google/protobuf/timestamp_pb.jl | 61 + src/flight/proto/Flight.proto | 678 ++++++++ src/flight/protocol.jl | 22 + src/flight/server.jl | 21 + src/flight/server/descriptors.jl | 136 ++ src/flight/server/dispatch.jl | 59 + src/flight/server/handlers.jl | 96 ++ src/flight/server/types.jl | 45 + src/utils.jl | 287 +++- src/write.jl | 11 + test/Project.toml | 24 +- test/flight.jl | 31 + test/flight/client_surface.jl | 28 + .../client_surface/constructor_tests.jl | 52 + .../flight/client_surface/header_tls_tests.jl | 61 + .../client_surface/protocol_client_tests.jl | 52 + test/flight/client_surface/support.jl | 29 + test/flight/grpcserver_extension.jl | 37 + .../bidi_streaming_tests.jl | 107 ++ .../grpcserver_extension/descriptor_tests.jl | 34 + .../server_streaming_tests.jl | 118 ++ .../grpcserver_extension/streaming_tests.jl | 24 + test/flight/grpcserver_extension/support.jl | 21 + .../grpcserver_extension/support/context.jl | 27 + .../grpcserver_extension/support/fixture.jl | 54 + .../grpcserver_extension/support/service.jl | 73 + .../grpcserver_extension/support/streams.jl | 49 + .../grpcserver_extension/unary_tests.jl | 40 + test/flight/handshake_interop.jl | 71 + test/flight/header_interop.jl | 67 + test/flight/ipc_conversion.jl | 78 + test/flight/ipc_schema_separation.jl | 43 + test/flight/poll_interop.jl | 59 + test/flight/pyarrow_interop.jl | 44 + .../flight/pyarrow_interop/discovery_tests.jl | 37 + test/flight/pyarrow_interop/download_tests.jl | 32 + test/flight/pyarrow_interop/exchange_tests.jl | 39 + test/flight/pyarrow_interop/support.jl | 53 + test/flight/pyarrow_interop/upload_tests.jl | 44 + test/flight/server_core.jl | 30 + test/flight/server_core/descriptor_tests.jl | 37 + .../server_core/direct_handler_tests.jl | 50 + test/flight/server_core/dispatch_tests.jl | 56 + test/flight/server_core/metadata_tests.jl | 25 + test/flight/server_core/support.jl | 59 + test/flight/support.jl | 43 + 
test/flight/support/grpc.jl | 33 + test/flight/support/paths.jl | 55 + test/flight/support/python_servers.jl | 78 + test/flight/support/streams.jl | 26 + test/flight/support/tls.jl | 65 + test/flight/support/types.jl | 23 + test/flight/tls_interop.jl | 75 + test/flight_grpcserver.jl | 75 + test/flight_handshake_server.py | 68 + test/flight_headers_server.py | 81 + test/flight_poll_server.py | 125 ++ test/flight_pyarrow_server.py | 146 ++ test/flight_tls_server.py | 111 ++ test/runtests.jl | 323 ++++ 109 files changed, 8260 insertions(+), 74 deletions(-) create mode 100644 ext/ArrowgRPCServerExt.jl create mode 100644 ext/arrowgrpcserverext/constants.jl create mode 100644 ext/arrowgrpcserverext/context.jl create mode 100644 ext/arrowgrpcserverext/descriptor.jl create mode 100644 ext/arrowgrpcserverext/handlers.jl create mode 100644 ext/arrowgrpcserverext/streams.jl create mode 100644 src/flight/Flight.jl create mode 100644 src/flight/client.jl create mode 100644 src/flight/client/auth.jl create mode 100644 src/flight/client/constants.jl create mode 100644 src/flight/client/headers.jl create mode 100644 src/flight/client/locations.jl create mode 100644 src/flight/client/methods/actions.jl create mode 100644 src/flight/client/methods/data.jl create mode 100644 src/flight/client/methods/discovery.jl create mode 100644 src/flight/client/protocol_clients.jl create mode 100644 src/flight/client/rpc_methods.jl create mode 100644 src/flight/client/transport.jl create mode 100644 src/flight/client/types.jl create mode 100644 src/flight/convert.jl create mode 100644 src/flight/convert/constants.jl create mode 100644 src/flight/convert/flightdata.jl create mode 100644 src/flight/convert/framing.jl create mode 100644 src/flight/convert/schema.jl create mode 100644 src/flight/convert/streaming.jl create mode 100644 src/flight/exports.jl create mode 100644 src/flight/generated/arrow/arrow.jl create mode 100644 src/flight/generated/arrow/flight/flight.jl create mode 100644 
src/flight/generated/arrow/flight/protocol/Flight_pb.jl create mode 100644 src/flight/generated/arrow/flight/protocol/protocol.jl create mode 100644 src/flight/generated/google/google.jl create mode 100644 src/flight/generated/google/protobuf/protobuf.jl create mode 100644 src/flight/generated/google/protobuf/timestamp_pb.jl create mode 100644 src/flight/proto/Flight.proto create mode 100644 src/flight/protocol.jl create mode 100644 src/flight/server.jl create mode 100644 src/flight/server/descriptors.jl create mode 100644 src/flight/server/dispatch.jl create mode 100644 src/flight/server/handlers.jl create mode 100644 src/flight/server/types.jl create mode 100644 test/flight.jl create mode 100644 test/flight/client_surface.jl create mode 100644 test/flight/client_surface/constructor_tests.jl create mode 100644 test/flight/client_surface/header_tls_tests.jl create mode 100644 test/flight/client_surface/protocol_client_tests.jl create mode 100644 test/flight/client_surface/support.jl create mode 100644 test/flight/grpcserver_extension.jl create mode 100644 test/flight/grpcserver_extension/bidi_streaming_tests.jl create mode 100644 test/flight/grpcserver_extension/descriptor_tests.jl create mode 100644 test/flight/grpcserver_extension/server_streaming_tests.jl create mode 100644 test/flight/grpcserver_extension/streaming_tests.jl create mode 100644 test/flight/grpcserver_extension/support.jl create mode 100644 test/flight/grpcserver_extension/support/context.jl create mode 100644 test/flight/grpcserver_extension/support/fixture.jl create mode 100644 test/flight/grpcserver_extension/support/service.jl create mode 100644 test/flight/grpcserver_extension/support/streams.jl create mode 100644 test/flight/grpcserver_extension/unary_tests.jl create mode 100644 test/flight/handshake_interop.jl create mode 100644 test/flight/header_interop.jl create mode 100644 test/flight/ipc_conversion.jl create mode 100644 test/flight/ipc_schema_separation.jl create mode 100644 
test/flight/poll_interop.jl create mode 100644 test/flight/pyarrow_interop.jl create mode 100644 test/flight/pyarrow_interop/discovery_tests.jl create mode 100644 test/flight/pyarrow_interop/download_tests.jl create mode 100644 test/flight/pyarrow_interop/exchange_tests.jl create mode 100644 test/flight/pyarrow_interop/support.jl create mode 100644 test/flight/pyarrow_interop/upload_tests.jl create mode 100644 test/flight/server_core.jl create mode 100644 test/flight/server_core/descriptor_tests.jl create mode 100644 test/flight/server_core/direct_handler_tests.jl create mode 100644 test/flight/server_core/dispatch_tests.jl create mode 100644 test/flight/server_core/metadata_tests.jl create mode 100644 test/flight/server_core/support.jl create mode 100644 test/flight/support.jl create mode 100644 test/flight/support/grpc.jl create mode 100644 test/flight/support/paths.jl create mode 100644 test/flight/support/python_servers.jl create mode 100644 test/flight/support/streams.jl create mode 100644 test/flight/support/tls.jl create mode 100644 test/flight/support/types.jl create mode 100644 test/flight/tls_interop.jl create mode 100644 test/flight_grpcserver.jl create mode 100644 test/flight_handshake_server.py create mode 100644 test/flight_headers_server.py create mode 100644 test/flight_poll_server.py create mode 100644 test/flight_pyarrow_server.py create mode 100644 test/flight_tls_server.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index aebdfc8a..31f5b03a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -168,6 +168,50 @@ jobs: continue-on-error: false run: > julia --color=yes --project=monorepo -e 'using Pkg; Pkg.test("Arrow")' + flight_interop: + name: Arrow Flight interop - Julia 1 - ubuntu-latest + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: '3.11' + - name: Install Flight Python dependencies + run: | + python -m 
pip install --upgrade pip + python -m pip install pyarrow grpcio grpcio-tools + - uses: julia-actions/setup-julia@v2 + with: + version: '1' + - uses: actions/cache@v5 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-flight-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-flight-${{ env.cache-name }}- + ${{ runner.os }}-flight- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1.6 + with: + project: . + - name: Dev local ArrowTypes for Arrow.jl tests + shell: julia --project=. {0} + run: | + using Pkg + Pkg.develop(PackageSpec(path="src/ArrowTypes")) + - name: Run Arrow Flight interop tests + env: + ARROW_FLIGHT_PYTHON: ${{ env.pythonLocation }}/bin/python + run: > + julia --color=yes --project=test -e 'using Pkg; + Pkg.develop(PackageSpec(path=".")); + Pkg.develop(PackageSpec(path="src/ArrowTypes")); + Pkg.instantiate(); + using Test, Arrow; + include("test/flight.jl")' docs: name: Documentation runs-on: ubuntu-latest diff --git a/.github/workflows/ci_nightly.yml b/.github/workflows/ci_nightly.yml index fb71886f..9d7d6ce6 100644 --- a/.github/workflows/ci_nightly.yml +++ b/.github/workflows/ci_nightly.yml @@ -106,3 +106,48 @@ jobs: continue-on-error: false run: > julia --color=yes --project=monorepo -e 'using Pkg; Pkg.test("Arrow")' + flight_interop: + name: Arrow Flight interop - Julia nightly - ubuntu-latest + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: '3.11' + - name: Install Flight Python dependencies + run: | + python -m pip install --upgrade pip + python -m pip install pyarrow grpcio grpcio-tools + - uses: julia-actions/setup-julia@v2 + with: + version: 'nightly' + arch: x64 + - uses: actions/cache@v5 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-flight-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} 
+ restore-keys: | + ${{ runner.os }}-flight-${{ env.cache-name }}- + ${{ runner.os }}-flight- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1.6 + with: + project: . + - name: Dev local ArrowTypes for Arrow.jl tests + shell: julia --project=. {0} + run: | + using Pkg + Pkg.develop(PackageSpec(path="src/ArrowTypes")) + - name: Run Arrow Flight interop tests + env: + ARROW_FLIGHT_PYTHON: ${{ env.pythonLocation }}/bin/python + run: > + julia --color=yes --project=test -e 'using Pkg; + Pkg.develop(PackageSpec(path=".")); + Pkg.develop(PackageSpec(path="src/ArrowTypes")); + Pkg.instantiate(); + using Test, Arrow; + include("test/flight.jl")' diff --git a/Project.toml b/Project.toml index b87ff796..5b4a3bcc 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ version = "2.8.1" [deps] ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" +Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1" CodecLz4 = "5ba52731-8f18-5e0d-9241-30f10d1ec561" CodecZstd = "6b39b394-51ab-5f42-8807-6242bab2b4c2" @@ -28,6 +29,8 @@ ConcurrentUtilities = "f0e56b4a-5159-44fe-b623-3e5288b988bb" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429" +gRPCClient = "aaca4a50-36af-4a1d-b878-4c443f2061ad" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" @@ -37,6 +40,15 @@ TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53" TranscodingStreams = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +[weakdeps] +gRPCServer = "608c6337-0d7d-447f-bb69-0f5674ee3959" + +[extensions] +ArrowgRPCServerExt = "gRPCServer" + +[sources] +ArrowTypes = { path = "src/ArrowTypes" } + [compat] ArrowTypes = "1.1,2" BitIntegers = "0.2, 0.3" @@ -45,10 +57,13 @@ CodecZstd = "0.7, 0.8" 
ConcurrentUtilities = "2" DataAPI = "1" EnumX = "1" +ProtoBuf = "~1.2.1" +gRPCClient = "1" +gRPCServer = "0.1" PooledArrays = "0.5, 1.0" SentinelArrays = "1" StringViews = "1" Tables = "1.1" TimeZones = "1" TranscodingStreams = "0.9.12, 0.10, 0.11" -julia = "1.9" +julia = "1.12" diff --git a/README.md b/README.md index 98bc9fd9..daada49a 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,8 @@ The package can be installed by typing in the following in a Julia REPL: julia> using Pkg; Pkg.add("Arrow") ``` +Arrow.jl currently requires Julia `1.12+`. + ## Local Development When developing on Arrow.jl it is recommended that you run the following to ensure that any @@ -49,6 +51,12 @@ changes to ArrowTypes.jl are immediately available to Arrow.jl without requiring julia --project -e 'using Pkg; Pkg.develop(path="src/ArrowTypes")' ``` +Current write-path notes: + * `Arrow.tobuffer` includes a direct single-partition fast path for eligible inputs + * `Arrow.tobuffer(Tables.partitioner(...))` also includes a targeted direct multi-record-batch path for single-column top-level strings and single-column non-missing binary/code-units columns + * `Arrow.write(io, Tables.partitioner(...))` now reuses that same targeted direct multi-record-batch path instead of always going through the legacy `Writer` orchestration + * multi-column partitions, dictionary-encoded top-level columns, map-heavy inputs, and missing-binary partitions retain the existing writer path + ## Format Support This implementation supports the 1.0 version of the specification, including support for: @@ -60,9 +68,27 @@ This implementation supports the 1.0 version of the specification, including sup It currently doesn't include support for: * Tensors or sparse tensors - * Flight RPC * C data interface +Flight RPC status: + * Experimental `Arrow.Flight` support is available in-tree + * Requires Julia `1.12+` + * Includes generated protocol bindings and complete client constructors for the `FlightService` RPC surface 
+ * Keeps the top-level Flight module shell thin, with exports and generated-protocol setup split out of `src/flight/Flight.jl` + * Includes high-level `FlightData <-> Arrow IPC` helpers for `Arrow.Table`, `Arrow.Stream`, and DoPut payload generation + * Keeps the Flight IPC conversion layer modular under `src/flight/convert/`, with `src/flight/convert.jl` retained as a thin entrypoint + * Includes client helpers for request headers, binary metadata, handshake token reuse, and TLS configuration via `withheaders`, `withtoken`, and `authenticate` + * Keeps the Flight client implementation modular under `src/flight/client/`, with thin entrypoints at `src/flight/client.jl` and `src/flight/client/rpc_methods.jl` + * Includes a transport-agnostic server core (`Service`, `ServerCallContext`, `ServiceDescriptor`, `MethodDescriptor`) for local Flight method dispatch, path lookup, and handler testing + * Keeps the transport-agnostic server core modular under `src/flight/server/`, with `src/flight/server.jl` retained as a thin entrypoint + * Includes an optional `gRPCServer.jl` package extension that maps `Arrow.Flight.Service` into `gRPCServer.ServiceDescriptor` and registers Flight proto types with the external server package when it is present + * Keeps the optional `gRPCServer.jl` bridge modular under `ext/arrowgrpcserverext/`, with `ext/ArrowgRPCServerExt.jl` retained as a thin entrypoint + * Includes optional live interoperability coverage for `Handshake`, authenticated token propagation, `PollFlightInfo`, and TLS via dedicated Python reference servers + * Includes optional live `pyarrow.flight` interoperability coverage for `ListFlights`, `GetFlightInfo`, `GetSchema`, `DoGet`, `DoPut`, `DoExchange`, `ListActions`, and `DoAction` + * Keeps targeted Flight verification modular under `test/flight/`, with `test/flight.jl` retained as a thin entrypoint for local and CI invocation stability, the client-constructor/protocol-wrapper checks decomposed under 
`test/flight/client_surface/`, the optional `gRPCServer` extension scenarios decomposed under `test/flight/grpcserver_extension/`, the `pyarrow.flight` interop scenarios decomposed under `test/flight/pyarrow_interop/`, and the transport-agnostic server-core checks decomposed under `test/flight/server_core/` + * Includes `test/flight_grpcserver.jl` as a temporary-environment runner for optional native `gRPCServer` coverage without mutating `test/Project.toml` + * Dedicated CI jobs now exercise the Flight interop suite on stable and nightly Linux; native Julia server transport remains optional/experimental and is not part of the default Flight suite + Third-party data formats: * CSV, parquet and avro support via the existing [CSV.jl](https://github.com/JuliaData/CSV.jl), [Parquet.jl](https://github.com/JuliaIO/Parquet.jl) and [Avro.jl](https://github.com/JuliaData/Avro.jl) packages * Other Tables.jl-compatible packages automatically supported ([DataFrames.jl](https://github.com/JuliaData/DataFrames.jl), [JSONTables.jl](https://github.com/JuliaData/JSONTables.jl), [JuliaDB.jl](https://github.com/JuliaData/JuliaDB.jl), [SQLite.jl](https://github.com/JuliaDatabases/SQLite.jl), [MySQL.jl](https://github.com/JuliaDatabases/MySQL.jl), [JDBC.jl](https://github.com/JuliaDatabases/JDBC.jl), [ODBC.jl](https://github.com/JuliaDatabases/ODBC.jl), [XLSX.jl](https://github.com/felipenoris/XLSX.jl), etc.) diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 6e32d072..728dee92 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -16,6 +16,7 @@ # under the License. 
Manifest.toml +*/Manifest.toml dev/release/apache-rat-*.jar dev/release/filtered_rat.txt dev/release/rat.xml diff --git a/docs/src/manual.md b/docs/src/manual.md index 5a3330fd..264b8170 100644 --- a/docs/src/manual.md +++ b/docs/src/manual.md @@ -97,10 +97,20 @@ One note on performance: when writing `TimeZones.ZonedDateTime` columns to the a as the column has `ZonedDateTime` elements that all share a common timezone. This ensures the writing process can know "upfront" which timezone will be encoded and is thus much more efficient and performant. +Similarly, `ArrowTypes.ToArrow` avoids repeated type-promotion work for +homogeneous custom columns even when `ArrowTypes.ArrowType(T)` is abstract, so +write-time conversion does not pay unnecessary overhead once the serialized +element type is stable. + #### Custom types To support writing your custom Julia struct, Arrow.jl utilizes the format's mechanism for "extension types" by allowing the storing of Julia type name and metadata in the field metadata. To "hook in" to this machinery, custom types can utilize the interface methods defined in the `Arrow.ArrowTypes` submodule. For example: +Arrow.jl already uses this mechanism for several Base logical types, including +`nothing`, `Tuple`, `VersionNumber`, and `Complex`, so those values roundtrip as +their original Julia types instead of falling back to plain struct-shaped +`NamedTuple`s. + ```julia using Arrow diff --git a/ext/ArrowgRPCServerExt.jl b/ext/ArrowgRPCServerExt.jl new file mode 100644 index 00000000..af16d70a --- /dev/null +++ b/ext/ArrowgRPCServerExt.jl @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module ArrowgRPCServerExt + +using Arrow +using gRPCServer + +include("arrowgrpcserverext/constants.jl") +include("arrowgrpcserverext/context.jl") +include("arrowgrpcserverext/streams.jl") +include("arrowgrpcserverext/handlers.jl") +include("arrowgrpcserverext/descriptor.jl") + +end # module ArrowgRPCServerExt diff --git a/ext/arrowgrpcserverext/constants.jl b/ext/arrowgrpcserverext/constants.jl new file mode 100644 index 00000000..3fd7ab4a --- /dev/null +++ b/ext/arrowgrpcserverext/constants.jl @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +const Flight = Arrow.Flight +const STREAM_BUFFER_SIZE = 16 +const GENERATED_TYPE_PREFIX = "Arrow.Flight.Generated." 
diff --git a/ext/arrowgrpcserverext/context.jl b/ext/arrowgrpcserverext/context.jl new file mode 100644 index 00000000..88598e58 --- /dev/null +++ b/ext/arrowgrpcserverext/context.jl @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _method_type(method::Flight.MethodDescriptor) + if method.request_streaming + return method.response_streaming ? gRPCServer.MethodType.BIDI_STREAMING : + gRPCServer.MethodType.CLIENT_STREAMING + end + return method.response_streaming ? gRPCServer.MethodType.SERVER_STREAMING : + gRPCServer.MethodType.UNARY +end + +function _call_context(context::gRPCServer.ServerContext) + headers = Flight.HeaderPair[ + String(name) => (value isa String ? 
value : Vector{UInt8}(value)) for + (name, value) in pairs(context.metadata) + ] + peer = string(context.peer.address, ":", context.peer.port) + return Flight.ServerCallContext( + headers=headers, + peer=peer, + secure=(context.peer.certificate !== nothing), + ) +end + +function _proto_type_name(T::Type) + type_name = string(T) + if startswith(type_name, GENERATED_TYPE_PREFIX) + return type_name[(ncodeunits(GENERATED_TYPE_PREFIX) + 1):end] + end + return type_name +end diff --git a/ext/arrowgrpcserverext/descriptor.jl b/ext/arrowgrpcserverext/descriptor.jl new file mode 100644 index 00000000..d3b6459c --- /dev/null +++ b/ext/arrowgrpcserverext/descriptor.jl @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +function _register_proto_types!(method::Flight.MethodDescriptor) + registry = gRPCServer.get_type_registry() + registry[_proto_type_name(method.request_type)] = method.request_type + registry[_proto_type_name(method.response_type)] = method.response_type + return nothing +end + +function gRPCServer.service_descriptor(service::Flight.Service) + descriptor = Flight.servicedescriptor(service) + methods = Dict{String,gRPCServer.MethodDescriptor}() + for method in descriptor.methods + _register_proto_types!(method) + methods[method.name] = gRPCServer.MethodDescriptor( + method.name, + _method_type(method), + _proto_type_name(method.request_type), + _proto_type_name(method.response_type), + _handler(service, method), + ) + end + return gRPCServer.ServiceDescriptor(descriptor.name, methods, nothing) +end diff --git a/ext/arrowgrpcserverext/handlers.jl b/ext/arrowgrpcserverext/handlers.jl new file mode 100644 index 00000000..028107e2 --- /dev/null +++ b/ext/arrowgrpcserverext/handlers.jl @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +function _unary_handler(service::Flight.Service, method::Flight.MethodDescriptor) + return (context, request) -> + Flight.dispatch(service, _call_context(context), method, request) +end + +function _server_streaming_handler(service::Flight.Service, method::Flight.MethodDescriptor) + return (context, request, stream) -> begin + response = Channel{method.response_type}(STREAM_BUFFER_SIZE) + task = @async begin + try + if method.handler_field === :listactions + Flight.listactions(service, _call_context(context), response) + else + Flight.dispatch(service, _call_context(context), method, request, response) + end + finally + close(response) + end + end + try + _drain_response!(stream, response) + _streaming_handler_result(task) + gRPCServer.close!(stream) + finally + istaskdone(task) || wait(task) + end + end +end + +function _client_streaming_handler(service::Flight.Service, method::Flight.MethodDescriptor) + return (context, stream) -> begin + request = Channel{method.request_type}(STREAM_BUFFER_SIZE) + producer = @async begin + try + for message in stream + put!(request, message) + end + finally + close(request) + end + end + task = @async Flight.dispatch(service, _call_context(context), method, request) + try + return fetch(task) + finally + _streaming_handler_result(task, producer) + end + end +end + +function _bidi_streaming_handler(service::Flight.Service, method::Flight.MethodDescriptor) + return (context, stream) -> begin + request = Channel{method.request_type}(STREAM_BUFFER_SIZE) + response = Channel{method.response_type}(STREAM_BUFFER_SIZE) + producer = @async begin + try + for message in stream + put!(request, message) + end + finally + close(request) + end + end + task = @async begin + try + Flight.dispatch(service, _call_context(context), method, request, response) + finally + close(response) + end + end + try + for message in response + gRPCServer.send!(stream, message) + end + _streaming_handler_result(task, producer) + gRPCServer.close!(stream) + 
finally + istaskdone(task) || wait(task) + isnothing(producer) || (istaskdone(producer) || wait(producer)) + end + return nothing + end +end + +function _handler(service::Flight.Service, method::Flight.MethodDescriptor) + if !method.request_streaming && !method.response_streaming + return _unary_handler(service, method) + elseif !method.request_streaming && method.response_streaming + return _server_streaming_handler(service, method) + elseif method.request_streaming && !method.response_streaming + return _client_streaming_handler(service, method) + end + return _bidi_streaming_handler(service, method) +end diff --git a/ext/arrowgrpcserverext/streams.jl b/ext/arrowgrpcserverext/streams.jl new file mode 100644 index 00000000..22f79251 --- /dev/null +++ b/ext/arrowgrpcserverext/streams.jl @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _drain_response!(stream::gRPCServer.ServerStream, response::Channel) + # gRPCServer falls back to `ServerStream{Any}` when a descriptor only carries + # protobuf type names. Drain generically and let `send!` enforce compatibility. 
+ for message in response + gRPCServer.send!(stream, message) + end + return nothing +end + +function _streaming_handler_result(task::Task, producer::Union{Nothing,Task}=nothing) + if !isnothing(producer) + if istaskfailed(producer) + throw(producer.exception) + end + wait(producer) + end + if istaskfailed(task) + throw(task.exception) + end + wait(task) + return nothing +end diff --git a/src/Arrow.jl b/src/Arrow.jl index 6f3ccdf8..eb9c2188 100644 --- a/src/Arrow.jl +++ b/src/Arrow.jl @@ -29,9 +29,27 @@ This implementation supports the 1.0 version of the specification, including sup It currently doesn't include support for: * Tensors or sparse tensors - * Flight RPC * C data interface +Flight RPC status: + * Experimental `Arrow.Flight` support is available in-tree + * Requires Julia `1.12+` + * Includes generated protocol bindings and client constructors for the `FlightService` RPC surface + * Keeps the top-level Flight module shell thin, with exports and generated-protocol setup split out of `src/flight/Flight.jl` + * Includes high-level `FlightData <-> Arrow IPC` helpers for `Arrow.Table`, `Arrow.Stream`, and DoPut payload generation + * Keeps the Flight IPC conversion layer modular under `src/flight/convert/`, with `src/flight/convert.jl` retained as a thin entrypoint + * Includes client helpers for request headers, binary metadata, handshake token reuse, and TLS configuration via `withheaders`, `withtoken`, and `authenticate` + * Keeps the Flight client implementation modular under `src/flight/client/`, with thin entrypoints at `src/flight/client.jl` and `src/flight/client/rpc_methods.jl` + * Includes a transport-agnostic server core (`Service`, `ServerCallContext`, `ServiceDescriptor`, `MethodDescriptor`) for local Flight method dispatch, path lookup, and handler testing + * Keeps the transport-agnostic server core modular under `src/flight/server/`, with `src/flight/server.jl` retained as a thin entrypoint + * Includes an optional `gRPCServer.jl` package 
extension that maps `Arrow.Flight.Service` into `gRPCServer.ServiceDescriptor` and registers Flight proto types with the external server package when it is present + * Keeps the optional `gRPCServer.jl` bridge modular under `ext/arrowgrpcserverext/`, with `ext/ArrowgRPCServerExt.jl` retained as a thin entrypoint + * Includes optional live interoperability coverage for `Handshake`, authenticated token propagation, `PollFlightInfo`, and TLS via dedicated Python reference servers + * Includes optional live `pyarrow.flight` interoperability coverage for `ListFlights`, `GetFlightInfo`, `GetSchema`, `DoGet`, `DoPut`, `DoExchange`, `ListActions`, and `DoAction` + * Keeps targeted Flight verification modular under `test/flight/`, with `test/flight.jl` retained as a thin entrypoint for local and CI invocation stability, the client-constructor/protocol-wrapper checks decomposed under `test/flight/client_surface/`, the optional `gRPCServer` extension scenarios decomposed under `test/flight/grpcserver_extension/`, the `pyarrow.flight` interop scenarios decomposed under `test/flight/pyarrow_interop/`, and the transport-agnostic server-core checks decomposed under `test/flight/server_core/` + * Includes `test/flight_grpcserver.jl` as a temporary-environment runner for optional native `gRPCServer` coverage without mutating `test/Project.toml` + * Dedicated CI jobs now exercise the Flight interop suite on stable and nightly Linux; native Julia server transport remains optional/experimental and is not part of the default Flight suite + Third-party data formats: * csv and parquet support via the existing [CSV.jl](https://github.com/JuliaData/CSV.jl) and [Parquet.jl](https://github.com/JuliaIO/Parquet.jl) packages * Other [Tables.jl](https://github.com/JuliaData/Tables.jl)-compatible packages automatically supported ([DataFrames.jl](https://github.com/JuliaData/DataFrames.jl), [JSONTables.jl](https://github.com/JuliaData/JSONTables.jl), 
[JuliaDB.jl](https://github.com/JuliaData/JuliaDB.jl), [SQLite.jl](https://github.com/JuliaDatabases/SQLite.jl), [MySQL.jl](https://github.com/JuliaDatabases/MySQL.jl), [JDBC.jl](https://github.com/JuliaDatabases/JDBC.jl), [ODBC.jl](https://github.com/JuliaDatabases/ODBC.jl), [XLSX.jl](https://github.com/felipenoris/XLSX.jl), etc.) @@ -55,7 +73,7 @@ using DataAPI, ConcurrentUtilities, StringViews -export ArrowTypes +export ArrowTypes, Flight using Base: @propagate_inbounds import Base: == @@ -79,6 +97,7 @@ include("table.jl") include("write.jl") include("append.jl") include("show.jl") +include("flight/Flight.jl") const ZSTD_COMPRESSOR = Lockable{ZstdCompressor}[] const ZSTD_DECOMPRESSOR = Lockable{ZstdDecompressor}[] diff --git a/src/ArrowTypes/Project.toml b/src/ArrowTypes/Project.toml index 0166f602..50fd3796 100644 --- a/src/ArrowTypes/Project.toml +++ b/src/ArrowTypes/Project.toml @@ -25,4 +25,4 @@ Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] -julia = "1.0" +julia = "1.12" diff --git a/src/ArrowTypes/src/ArrowTypes.jl b/src/ArrowTypes/src/ArrowTypes.jl index fe2223f4..46c1c21c 100644 --- a/src/ArrowTypes/src/ArrowTypes.jl +++ b/src/ArrowTypes/src/ArrowTypes.jl @@ -324,6 +324,14 @@ arrowname(::Type{Tuple{}}) = TUPLE JuliaType(::Val{TUPLE}, ::Type{NamedTuple{names,types}}) where {names,types<:Tuple} = types fromarrow(::Type{T}, x::NamedTuple) where {T<:Tuple} = Tuple(x) +# Complex +const COMPLEX = Symbol("JuliaLang.Complex") +arrowname(::Type{<:Complex}) = COMPLEX +JuliaType(::Val{COMPLEX}, ::Type{NamedTuple{names,Tuple{T,T}}}) where {names,T<:Real} = + Complex{T} +fromarrowstruct(::Type{T}, ::Val{(:re, :im)}, re, im) where {T<:Complex} = T(re, im) +fromarrowstruct(::Type{T}, ::Val{(:im, :re)}, im, re) where {T<:Complex} = T(re, im) + # VersionNumber const VERSION_NUMBER = Symbol("JuliaLang.VersionNumber") ArrowKind(::Type{VersionNumber}) = StructKind() @@ -388,13 +396,67 @@ 
default(::Type{NamedTuple{names,types}}) where {names,types} = NamedTuple{names}(Tuple(default(fieldtype(types, i)) for i = 1:length(names))) function promoteunion(T, S) + T === S && return T new = promote_type(T, S) return isabstracttype(new) ? Union{T,S} : new end +function _toarroweltype(x) + state = iterate(x) + state === nothing && return Missing + y, st = state + srcT = Union{} + stable = false + T = Missing + if y !== missing + srcT = typeof(y) + mapped = ArrowType(srcT) + stable = isconcretetype(mapped) + T = stable ? mapped : typeof(toarrow(y)) + end + while true + state = iterate(x, st) + state === nothing && return T + y, st = state + if y === missing + S = Missing + elseif srcT === Union{} + srcT = typeof(y) + mapped = ArrowType(srcT) + stable = isconcretetype(mapped) + S = stable ? mapped : typeof(toarrow(y)) + elseif stable && typeof(y) === srcT + continue + else + S = typeof(toarrow(y)) + if stable && typeof(y) !== srcT + stable = false + end + end + S === T && continue + T = promoteunion(T, S) + end +end + +@inline _hasoffsetaxes(data) = Base.has_offset_axes(data) +@inline _offsetshift(data) = _hasoffsetaxes(data) ? 
firstindex(data) - 1 : 0 +@inline _hasonebasedaxes(data) = !_hasoffsetaxes(data) + # lazily call toarrow(x) on getindex for each x in data struct ToArrow{T,A} <: AbstractVector{T} data::A + offset::Int + needsconvert::Bool +end +@inline _sourcedata(x::ToArrow) = getfield(x, :data) +@inline _sourceoffset(x::ToArrow) = getfield(x, :offset) +@inline _needsconvert(x::ToArrow) = getfield(x, :needsconvert) +@inline _sourcevalue(x::ToArrow, i::Integer) = + @inbounds getindex(_sourcedata(x), i + _sourceoffset(x)) + +function ToArrow{T,A}(data::A) where {T,A} + needsconvert = !(eltype(A) === T && concrete_or_concreteunion(T)) + return ToArrow{T,A}(data, _offsetshift(data), needsconvert) end concrete_or_concreteunion(T) = @@ -404,15 +466,14 @@ concrete_or_concreteunion(T) = function ToArrow(x::A) where {A} S = eltype(A) T = ArrowType(S) - fi = firstindex(x) - if S === T && concrete_or_concreteunion(S) && fi == 1 + if S === T && concrete_or_concreteunion(S) && _hasonebasedaxes(x) return x elseif !concrete_or_concreteunion(T) # arrow needs concrete types, so try to find a concrete common type, preferring unions if isempty(x) return Missing[] end - T = mapreduce(typeof ∘ toarrow, promoteunion, x) + T = _toarroweltype(x) if T === Missing && concrete_or_concreteunion(S) T = promoteunion(T, typeof(toarrow(default(S)))) end @@ -440,7 +501,29 @@ function _convert(::Type{T}, x) where {T} return convert(T, x) end end -Base.getindex(x::ToArrow{T}, i::Int) where {T} = - _convert(T, toarrow(getindex(x.data, i + firstindex(x.data) - 1))) + +@inline function _toarrowvalue(x::ToArrow{T}, value) where {T} + _needsconvert(x) || return value + return _convert(T, toarrow(value)) +end + +Base.@propagate_inbounds function Base.getindex(x::ToArrow{T}, i::Int) where {T} + value = _sourcevalue(x, i) + return _toarrowvalue(x, value) +end + +function Base.iterate(x::ToArrow) + state = iterate(x.data) + state === nothing && return nothing + value, st = state + return _toarrowvalue(x, value), st +end + 
+function Base.iterate(x::ToArrow, st) + state = iterate(x.data, st) + state === nothing && return nothing + value, st = state + return _toarrowvalue(x, value), st +end end # module ArrowTypes diff --git a/src/ArrowTypes/test/tests.jl b/src/ArrowTypes/test/tests.jl index 22d8dd0e..e5822f43 100644 --- a/src/ArrowTypes/test/tests.jl +++ b/src/ArrowTypes/test/tests.jl @@ -144,6 +144,17 @@ end @test ArrowTypes.default(Tuple{Vararg{Int}}) == () @test ArrowTypes.default(Tuple{String,Vararg{Int}}) == ("",) + z = 1.0 + 2.0im + @test ArrowTypes.ArrowKind(typeof(z)) == ArrowTypes.StructKind() + @test ArrowTypes.arrowname(typeof(z)) == ArrowTypes.COMPLEX + @test ArrowTypes.arrowname(Union{Missing,typeof(z)}) == ArrowTypes.COMPLEX + @test ArrowTypes.JuliaType( + Val(ArrowTypes.COMPLEX), + NamedTuple{(:re, :im),Tuple{Float64,Float64}}, + ) == ComplexF64 + @test ArrowTypes.fromarrowstruct(ComplexF64, Val((:re, :im)), 1.0, 2.0) == z + @test ArrowTypes.fromarrowstruct(ComplexF64, Val((:im, :re)), 2.0, 1.0) == z + v = v"1" v_nt = (major=1, minor=0, patch=0, prerelease=(), build=()) @test ArrowTypes.ArrowKind(VersionNumber) == ArrowTypes.StructKind() @@ -167,39 +178,97 @@ end @test ArrowTypes.promoteunion(Int, Float64) == Float64 @test ArrowTypes.promoteunion(Int, String) == Union{Int,String} + @test ArrowTypes.promoteunion(Int, Int) == Int @test ArrowTypes.concrete_or_concreteunion(Int) @test !ArrowTypes.concrete_or_concreteunion(Union{Real,String}) @test !ArrowTypes.concrete_or_concreteunion(Any) @testset "ToArrow" begin + @test !ArrowTypes._hasoffsetaxes([1, 2, 3]) + @test ArrowTypes._offsetshift([1, 2, 3]) == 0 + x = ArrowTypes.ToArrow([1, 2, 3]) @test x isa Vector{Int} @test x == [1, 2, 3] + baseview = @view [1, 2, 3][1:3] + x = ArrowTypes.ToArrow(baseview) + @test x === baseview + x = ArrowTypes.ToArrow([:hey, :ho]) @test x isa ArrowTypes.ToArrow{String,Vector{Symbol}} @test eltype(x) == String + @test ArrowTypes._needsconvert(x) + @test x[1] == "hey" + @test collect(x) == 
["hey", "ho"] @test x == ["hey", "ho"] x = ArrowTypes.ToArrow(Any[1, 3.14]) @test x isa ArrowTypes.ToArrow{Float64,Vector{Any}} @test eltype(x) == Float64 + @test collect(x) == [1.0, 3.14] @test x == [1.0, 3.14] + x = ArrowTypes.ToArrow(Any[UUID(UInt128(1)), UUID(UInt128(2))]) + @test x isa ArrowTypes.ToArrow{NTuple{16,UInt8},Vector{Any}} + @test eltype(x) == NTuple{16,UInt8} + @test collect(x) == + [ArrowTypes.toarrow(UUID(UInt128(1))), ArrowTypes.toarrow(UUID(UInt128(2)))] + + x = ArrowTypes.ToArrow(Any[missing, UUID(UInt128(1))]) + @test x isa ArrowTypes.ToArrow{Union{Missing,NTuple{16,UInt8}},Vector{Any}} + @test eltype(x) == Union{Missing,NTuple{16,UInt8}} + @test isequal( + collect(x), + Union{Missing,NTuple{16,UInt8}}[missing, ArrowTypes.toarrow(UUID(UInt128(1)))], + ) + x = ArrowTypes.ToArrow(Any[1, 3.14, "hey"]) @test x isa ArrowTypes.ToArrow{Union{Float64,String},Vector{Any}} @test eltype(x) == Union{Float64,String} + @test collect(x) == Union{Float64,String}[1.0, 3.14, "hey"] @test x == [1.0, 3.14, "hey"] + x = ArrowTypes.ToArrow(Any[UUID(UInt128(1)), "tail"]) + @test x isa ArrowTypes.ToArrow{Union{NTuple{16,UInt8},String},Vector{Any}} + @test eltype(x) == Union{NTuple{16,UInt8},String} + @test collect(x) == + Union{NTuple{16,UInt8},String}[ArrowTypes.toarrow(UUID(UInt128(1))), "tail"] + x = ArrowTypes.ToArrow(OffsetArray([1, 2, 3], -3:-1)) @test x isa ArrowTypes.ToArrow{Int,OffsetVector{Int,Vector{Int}}} + @test ArrowTypes._hasoffsetaxes(getfield(x, :data)) + @test getfield(x, :offset) == ArrowTypes._offsetshift(getfield(x, :data)) + @test ArrowTypes._sourcedata(x) === getfield(x, :data) + @test ArrowTypes._sourceoffset(x) == getfield(x, :offset) + @test !ArrowTypes._needsconvert(x) + @test ArrowTypes._sourcevalue(x, 1) == 1 @test eltype(x) == Int + @test x[1] == 1 + @test x[3] == 3 + @test collect(x) == [1, 2, 3] @test x == [1, 2, 3] + x = ArrowTypes.ToArrow(OffsetArray(Union{Missing,Int}[1, missing], -3:-2)) + @test x isa ArrowTypes.ToArrow{ + 
Union{Missing,Int}, + OffsetVector{Union{Missing,Int},Vector{Union{Missing,Int}}}, + } + @test !ArrowTypes._needsconvert(x) + @test x[1] == 1 + @test x[2] === missing + @test isequal(collect(x), Union{Missing,Int}[1, missing]) + x = ArrowTypes.ToArrow(OffsetArray(Any[1, 3.14], -3:-2)) @test x isa ArrowTypes.ToArrow{Float64,OffsetVector{Any,Vector{Any}}} + @test getfield(x, :offset) == ArrowTypes._offsetshift(getfield(x, :data)) + @test ArrowTypes._sourcevalue(x, 2) == 3.14 @test eltype(x) == Float64 + @test ArrowTypes._needsconvert(x) + @test x[1] == 1 + @test x[2] == 3.14 + @test collect(x) == [1.0, 3.14] @test x == [1, 3.14] @testset "respect non-missing concrete type" begin @@ -219,6 +288,15 @@ end T = Union{DateTimeTZ,Missing} @test !ArrowTypes.concrete_or_concreteunion(ArrowTypes.ArrowType(T)) @test eltype(ArrowTypes.ToArrow(T[missing])) == Union{Timestamp{:UTC},Missing} + @test eltype( + ArrowTypes.ToArrow(DateTimeTZ[DateTimeTZ(1, "UTC"), DateTimeTZ(2, "UTC")]), + ) == Timestamp{:UTC} + @test eltype( + ArrowTypes.ToArrow(DateTimeTZ[DateTimeTZ(1, "UTC"), DateTimeTZ(2, "PST")]), + ) == Timestamp + @test eltype( + ArrowTypes.ToArrow(Any[DateTimeTZ(1, "UTC"), DateTimeTZ(2, "UTC")]), + ) == Timestamp{:UTC} # Works since `ArrowTypes.default(Any) === nothing` and # `ArrowTypes.toarrow(nothing) === missing`. 
Defining `toarrow(::Nothing) = nothing` diff --git a/src/arraytypes/arraytypes.jl b/src/arraytypes/arraytypes.jl index 58bab082..20bbce2d 100644 --- a/src/arraytypes/arraytypes.jl +++ b/src/arraytypes/arraytypes.jl @@ -99,7 +99,7 @@ function arrowvector( dictencode=dictencode, kw..., ) - elseif !(x isa DictEncode) + elseif !(x isa DictEncode) && !_keeprawmapvector(T, x) x = ToArrow(x) end S = maybemissing(eltype(x)) @@ -144,6 +144,64 @@ function _arrowtypemeta(meta, n, m) return toidict(dict) end +@inline function _materializeconverted(x::ArrowTypes.ToArrow) + data = ArrowTypes._sourcedata(x) + if ArrowTypes._needsconvert(x) && !ArrowTypes.concrete_or_concreteunion(eltype(data)) + return _materializeconverted(eltype(x), x) + end + return x +end + +function _materializeconverted(::Type{T}, x::ArrowTypes.ToArrow{T,A}) where {T,A} + len = length(x) + data = Vector{T}(undef, len) + source = ArrowTypes._sourcedata(x) + i = 1 + for value in source + @inbounds data[i] = + value isa T ? value : ArrowTypes._convert(T, ArrowTypes.toarrow(value)) + i += 1 + end + return data +end + +@inline function _materializefixedbytes16(value) + if value isa ArrowTypes.UUID + return ArrowTypes._cast(NTuple{16,UInt8}, value.value) + elseif value isa NTuple{16,UInt8} + return value + else + return ArrowTypes._convert(NTuple{16,UInt8}, ArrowTypes.toarrow(value)) + end +end + +function _materializeconverted( + ::Type{NTuple{16,UInt8}}, + x::ArrowTypes.ToArrow{NTuple{16,UInt8},A}, +) where {A} + len = length(x) + data = Vector{NTuple{16,UInt8}}(undef, len) + source = ArrowTypes._sourcedata(x) + i = 1 + for value in source + @inbounds data[i] = _materializefixedbytes16(value) + i += 1 + end + return data +end + +@inline _toarrowvaliditysource(x::ArrowTypes.ToArrow) = + ArrowTypes._needsconvert(x) ? x : ArrowTypes._sourcedata(x) + +@inline _toarrowvalidity(x::ArrowTypes.ToArrow, data) = + data === x ? 
ValidityBitmap(x) : ValidityBitmap(data) + +@inline function _keeprawmapvector(::Type{T}, x) where {T} + return Base.has_offset_axes(x) && + ArrowTypes.concrete_or_concreteunion(T) && + ArrowKind(T) isa ArrowTypes.MapKind +end + # now we check for ArrowType converions and dispatch on ArrowKind function arrowvector(::Type{S}, x, i, nl, fi, de, ded, meta; kw...) where {S} meta = _normalizemeta(meta) @@ -201,15 +259,10 @@ end Base.size(p::ValidityBitmap) = (p.ℓ,) nullcount(x::ValidityBitmap) = x.nc -function ValidityBitmap(x) - T = eltype(x) - if !(T >: Missing) - return ValidityBitmap(UInt8[], 1, length(x), 0) - end +function _validitybitmap(x, len) len = length(x) blen = cld(len, 8) bytes = Vector{UInt8}(undef, blen) - st = iterate(x) nc = 0 b = 0xff j = k = 1 @@ -232,6 +285,23 @@ function ValidityBitmap(x) return ValidityBitmap(nc == 0 ? UInt8[] : bytes, 1, nc == 0 ? 0 : len, nc) end +function ValidityBitmap(x) + T = eltype(x) + if !(T >: Missing) + return ValidityBitmap(UInt8[], 1, length(x), 0) + end + return _validitybitmap(x, length(x)) +end + +function ValidityBitmap(x::ArrowTypes.ToArrow) + T = eltype(x) + if !(T >: Missing) + return ValidityBitmap(UInt8[], 1, length(x), 0) + end + source = _toarrowvaliditysource(x) + return _validitybitmap(source, length(x)) +end + @propagate_inbounds function Base.getindex(p::ValidityBitmap, i::Integer) # no boundscheck because parent array should do it # if a validity bitmap is empty, it either means: diff --git a/src/arraytypes/bool.jl b/src/arraytypes/bool.jl index 29c1505a..8a33668a 100644 --- a/src/arraytypes/bool.jl +++ b/src/arraytypes/bool.jl @@ -52,9 +52,7 @@ end arrowvector(::BoolKind, x::BoolVector, i, nl, fi, de, ded, meta; kw...) = x -function arrowvector(::BoolKind, x, i, nl, fi, de, ded, meta; kw...) 
- validity = ValidityBitmap(x) - len = length(x) +function _packboolbytes(x, len) blen = cld(len, 8) bytes = Vector{UInt8}(undef, blen) b = 0xff @@ -74,9 +72,25 @@ function arrowvector(::BoolKind, x, i, nl, fi, de, ded, meta; kw...) if j > 1 bytes[k] = b end + return bytes +end + +function arrowvector(::BoolKind, x, i, nl, fi, de, ded, meta; kw...) + validity = ValidityBitmap(x) + len = length(x) + bytes = _packboolbytes(x, len) return BoolVector{eltype(x)}(bytes, 1, validity, len, meta) end +function arrowvector(::BoolKind, x::ArrowTypes.ToArrow, i, nl, fi, de, ded, meta; kw...) + data = _materializeconverted(x) + validity = _toarrowvalidity(x, data) + len = length(data) + source = data === x ? _toarrowvaliditysource(x) : data + bytes = _packboolbytes(source, len) + return BoolVector{eltype(data)}(bytes, 1, validity, len, meta) +end + function compress(Z::Meta.CompressionType.T, comp, p::P) where {P<:BoolVector} len = length(p) nc = nullcount(p) diff --git a/src/arraytypes/fixedsizelist.jl b/src/arraytypes/fixedsizelist.jl index 2558dd54..4e8f74cf 100644 --- a/src/arraytypes/fixedsizelist.jl +++ b/src/arraytypes/fixedsizelist.jl @@ -81,6 +81,8 @@ struct ToFixedSizeList{T,N,A} <: AbstractVector{T} end origtype(::ToFixedSizeList{T,N,A}) where {T,N,A} = eltype(A) +@inline _fixedsizedata(A::ToFixedSizeList) = getfield(A, :data) +@inline _fixedsizevalue(A::ToFixedSizeList, i::Integer) = @inbounds _fixedsizedata(A)[i] function ToFixedSizeList(input) NT = ArrowTypes.ArrowKind(Base.nonmissingtype(eltype(input))) # typically NTuple{N, T} @@ -90,7 +92,7 @@ function ToFixedSizeList(input) end Base.IndexStyle(::Type{<:ToFixedSizeList}) = Base.IndexLinear() -Base.size(x::ToFixedSizeList{T,N}) where {T,N} = (N * length(x.data),) +Base.size(x::ToFixedSizeList{T,N}) where {T,N} = (N * length(_fixedsizedata(x)),) Base.@propagate_inbounds function Base.getindex( A::ToFixedSizeList{T,N}, @@ -98,7 +100,7 @@ Base.@propagate_inbounds function Base.getindex( ) where {T,N} @boundscheck 
checkbounds(A, i) a, b = fldmod1(i, N) - @inbounds x = A.data[a] + x = _fixedsizevalue(A, a) return @inbounds x === missing ? ArrowTypes.default(T) : x[b] end @@ -108,7 +110,7 @@ end (i, chunk, chunk_i, len)=(1, 1, 1, length(A)), ) where {T,N} i > len && return nothing - @inbounds y = A.data[chunk] + y = _fixedsizevalue(A, chunk) @inbounds x = y === missing ? ArrowTypes.default(T) : y[chunk_i] if chunk_i == N chunk += 1 @@ -119,8 +121,60 @@ end return x, (i + 1, chunk, chunk_i, len) end +@inline function _writefixedsizechunk(io::IO, chunk::NTuple{N,UInt8}) where {N} + ref = Ref(chunk) + GC.@preserve ref begin + return Base.unsafe_write(io, Base.unsafe_convert(Ptr{UInt8}, ref), N) + end +end + +@inline function _writefixedsizecontiguous(io::IO, data::Vector{NTuple{N,UInt8}}) where {N} + GC.@preserve data begin + return Base.unsafe_write(io, Ptr{UInt8}(pointer(data)), N * length(data)) + end +end + +function writearray(io::IO, ::Type{UInt8}, col::ToFixedSizeList{UInt8,N}) where {N} + n = 0 + defaultchunk = ntuple(_ -> ArrowTypes.default(UInt8), Val(N)) + data = _fixedsizedata(col) + data isa Vector{NTuple{N,UInt8}} && return _writefixedsizecontiguous(io, data) + for chunk in data + n += _writefixedsizechunk(io, chunk === missing ? defaultchunk : chunk) + end + return n +end + arrowvector(::FixedSizeListKind, x::FixedSizeList, i, nl, fi, de, ded, meta; kw...) = x +function arrowvector( + kind::FixedSizeListKind{N,T}, + x::ArrowTypes.ToArrow, + i, + nl, + fi, + de, + ded, + meta; + kw..., +) where {N,T} + data = _materializeconverted(x) + if data !== x + return arrowvector(kind, data, i, nl, fi, de, ded, meta; kw...) + end + len = length(x) + validity = ValidityBitmap(x) + flat = ToFixedSizeList(x) + if eltype(flat) == UInt8 + child = flat + S = origtype(flat) + else + child = arrowvector(flat, i, nl + 1, fi, de, ded, nothing; kw...) 
+ S = withmissing(eltype(x), NTuple{N,eltype(child)}) + end + return FixedSizeList{S,typeof(child)}(UInt8[], validity, child, len, meta) +end + function arrowvector( ::FixedSizeListKind{N,T}, x, diff --git a/src/arraytypes/list.jl b/src/arraytypes/list.jl index 41ac66f9..5d0cf6d0 100644 --- a/src/arraytypes/list.jl +++ b/src/arraytypes/list.jl @@ -86,13 +86,17 @@ _codeunits(x::Base.CodeUnits) = x # an AbstractVector version of Iterators.flatten # code based on SentinelArrays.ChainedVector -struct ToList{T,stringtype,A,I} <: AbstractVector{T} - data::Vector{A} # A is AbstractVector or AbstractString +struct ToList{T,stringtype,A<:AbstractVector,I} <: AbstractVector{T} + data::A # A is the outer AbstractVector of AbstractVector or AbstractString inds::Vector{I} + offset::Int end -origtype(::ToList{T,S,A,I}) where {T,S,A,I} = A +origtype(::ToList{T,S,A,I}) where {T,S,A<:AbstractVector,I} = eltype(A) liststringtype(::Type{ToList{T,S,A,I}}) where {T,S,A,I} = S +materializeouter(::Type) = false +materializeouter(input) = materializeouter(typeof(input)) +materializeouterdata(input) = materializeouter(input) ? collect(input) : input function liststringtype(::List{T,O,A}) where {T,O,A} ST = Base.nonmissingtype(T) K = ArrowTypes.ArrowKind(ST) @@ -100,42 +104,80 @@ function liststringtype(::List{T,O,A}) where {T,O,A} end liststringtype(T) = false -function ToList(input; largelists::Bool=false) +@inline function _tolisttraits(input) AT = eltype(input) ST = Base.nonmissingtype(AT) K = ArrowTypes.ArrowKind(ST) stringtype = ArrowTypes.isstringtype(K) || ST <: Base.CodeUnits # add the CodeUnits check for ArrowTypes compat for now T = stringtype ? UInt8 : eltype(ST) - len = stringtype ? _ncodeunits : length - data = AT[] + lenf = stringtype ? 
_ncodeunits : length + return T, stringtype, lenf +end + +@inline function _promotetolistinds(inds::Vector{Int32}, len::Int, filled::Int) + promoted = Vector{Int64}(undef, len + 1) + copyto!(promoted, 1, inds, 1, filled) + return promoted +end + +function _buildtolist(input, data, dataoffset::Int, len::Int; largelists::Bool=false) + T, stringtype, lenf = _tolisttraits(input) I = largelists ? Int64 : Int32 - inds = I[0] - sizehint!(data, length(input)) - sizehint!(inds, length(input)) + inds = Vector{I}(undef, len + 1) + inds[1] = zero(I) totalsize = I(0) - for x in input - if x === missing - push!(data, missing) - else - push!(data, x) - totalsize += len(x) - if I === Int32 && totalsize > 2147483647 + @inbounds for i = 1:len + x = data[i + dataoffset] + if x !== missing + totalsize += lenf(x) + if I === Int32 && totalsize > typemax(Int32) I = Int64 - inds = convert(Vector{Int64}, inds) + inds = _promotetolistinds(inds, len, i) end end - push!(inds, totalsize) + inds[i + 1] = totalsize end - return ToList{T,stringtype,AT,I}(data, inds) + return ToList{T,stringtype,typeof(data),I}(data, inds, dataoffset) +end + +function _tolistgeneric(input; largelists::Bool=false) + data = materializeouterdata(input) + return _buildtolist( + input, + data, + ArrowTypes._offsetshift(data), + length(data); + largelists=largelists, + ) +end + +function ToList(input; largelists::Bool=false) + return _tolistgeneric(input; largelists=largelists) +end + +function ToList(input::ArrowTypes.ToArrow; largelists::Bool=false) + ArrowTypes._needsconvert(input) && return _tolistgeneric(input; largelists=largelists) + data = ArrowTypes._sourcedata(input) + return _buildtolist( + input, + data, + ArrowTypes._sourceoffset(input), + length(input); + largelists=largelists, + ) end Base.IndexStyle(::Type{<:ToList}) = Base.IndexLinear() Base.size(x::ToList{T,S,A,I}) where {T,S,A,I} = (isempty(x.inds) ? 
zero(I) : x.inds[end],) +@inline _tolistdata(A::ToList) = getfield(A, :data) +@inline _tolistoffset(A::ToList) = getfield(A, :offset) +@inline _tolistchunk(A::ToList, i::Integer) = @inbounds _tolistdata(A)[i + _tolistoffset(A)] + function Base.pointer(A::ToList{UInt8}, i::Integer) chunk = searchsortedfirst(A.inds, i) chunk = chunk > length(A.inds) ? 1 : (chunk - 1) - return pointer(A.data[chunk]) + return pointer(_tolistchunk(A, chunk)) end @inline function index(A::ToList, i::Integer) @@ -149,7 +191,7 @@ Base.@propagate_inbounds function Base.getindex( ) where {T,stringtype} @boundscheck checkbounds(A, i) chunk, ix = index(A, i) - @inbounds x = A.data[chunk] + x = _tolistchunk(A, chunk) return @inbounds stringtype ? _codeunits(x)[ix] : x[ix] end @@ -160,7 +202,7 @@ Base.@propagate_inbounds function Base.setindex!( ) where {T,stringtype} @boundscheck checkbounds(A, i) chunk, ix = index(A, i) - @inbounds x = A.data[chunk] + x = _tolistchunk(A, chunk) if stringtype _codeunits(x)[ix] = v else @@ -180,7 +222,7 @@ end chunk += 1 chunk_len = A.inds[chunk] end - val = A.data[chunk - 1] + val = _tolistchunk(A, chunk - 1) x = stringtype ? _codeunits(val)[1] : val[1] # find next valid index i += 1 @@ -202,7 +244,7 @@ end (i, chunk, chunk_i, chunk_len, len), ) where {T,stringtype} i > len && return nothing - @inbounds val = A.data[chunk - 1] + val = _tolistchunk(A, chunk - 1) @inbounds x = stringtype ? 
_codeunits(val)[chunk_i] : val[chunk_i] i += 1 if i > chunk_len @@ -219,6 +261,100 @@ end return x, (i, chunk, chunk_i, chunk_len, len) end +@inline function _writeuint8chunk(io::IO, bytes) + GC.@preserve bytes begin + return Base.unsafe_write(io, pointer(bytes), length(bytes)) + end +end + +@inline function _writeutf8chunk(io::IO, chunk::AbstractString) + GC.@preserve chunk begin + return Base.unsafe_write(io, pointer(chunk), ncodeunits(chunk)) + end +end + +@inline function _sizehint_iobuffer!(io::IO, n::Integer) + io isa IOBuffer || return nothing + data = getfield(io, :data) + data isa Vector{UInt8} || return nothing + sizehint!(data, max(length(data), position(io) + n)) + return nothing +end + +function _writearray_tolist_bitstype(io::IO, ::Type{T}, col::ToList{T,false}) where {T} + n = 0 + off = _tolistoffset(col) + data = _tolistdata(col) + if off == 0 + for chunk in data + chunk === missing && continue + n += writearray(io, T, chunk) + end + else + len = length(data) + @inbounds for i = 1:len + chunk = data[i + off] + chunk === missing && continue + n += writearray(io, T, chunk) + end + end + return n +end + +function _writearray_tolist_uint8(io::IO, col::ToList{UInt8,stringtype}) where {stringtype} + n = 0 + _sizehint_iobuffer!(io, length(col)) + off = _tolistoffset(col) + data = _tolistdata(col) + if off == 0 + for chunk in data + chunk === missing && continue + bytes = stringtype ? _codeunits(chunk) : chunk + n += _writeuint8chunk(io, bytes) + end + else + len = length(data) + @inbounds for i = 1:len + chunk = data[i + off] + chunk === missing && continue + bytes = stringtype ? 
_codeunits(chunk) : chunk + n += _writeuint8chunk(io, bytes) + end + end + return n +end + +function _writearray_tolist_uint8( + io::IO, + col::ToList{UInt8,true,A}, +) where {A<:AbstractVector{<:AbstractString}} + n = 0 + _sizehint_iobuffer!(io, length(col)) + off = _tolistoffset(col) + data = _tolistdata(col) + if off == 0 + for chunk in data + chunk === missing && continue + n += _writeutf8chunk(io, chunk) + end + else + len = length(data) + @inbounds for i = 1:len + chunk = data[i + off] + chunk === missing && continue + n += _writeutf8chunk(io, chunk) + end + end + return n +end + +function writearray(io::IO, ::Type{T}, col::ToList{T,stringtype}) where {T,stringtype} + T === UInt8 && return _writearray_tolist_uint8(io, col) + isbitstype(T) || return _writearrayfallback(io, T, col) + stringtype && return _writearrayfallback(io, T, col) + return _writearray_tolist_bitstype(io, T, col) +end + arrowvector(::ListKind, x::List, i, nl, fi, de, ded, meta; kw...) = x function arrowvector(::ListKind, x, i, nl, fi, de, ded, meta; largelists::Bool=false, kw...) diff --git a/src/arraytypes/map.jl b/src/arraytypes/map.jl index 42160732..c7b82756 100644 --- a/src/arraytypes/map.jl +++ b/src/arraytypes/map.jl @@ -43,8 +43,108 @@ Base.size(l::Map) = (l.ℓ,) end end -keyvalues(KT, ::Missing) = missing -keyvalues(KT, x::AbstractDict) = [KT(k, v) for (k, v) in pairs(x)] +@inline function _promotemapoffsets(offsets::Vector{Int32}, len::Int, filled::Int) + promoted = Vector{Int64}(undef, len + 1) + copyto!(promoted, 1, offsets, 1, filled) + return promoted +end + +function _mapoffsetsandvaluesindexed(::Type{KT}, x; largelists::Bool=false) where {KT} + len = length(x) + O = largelists ? 
Int64 : Int32 + offsets = Vector{O}(undef, len + 1) + offsets[1] = zero(O) + total = 0 + off = firstindex(x) - 1 + @inbounds for i = 1:len + y = x[i + off] + if y !== missing + total += length(y) + if O === Int32 && total > typemax(Int32) + O = Int64 + offsets = _promotemapoffsets(offsets, len, i) + end + end + offsets[i + 1] = total + end + values = Vector{KT}(undef, total) + pos = 1 + @inbounds for i = 1:len + y = x[i + off] + y === missing && continue + for (k, v) in pairs(y) + values[pos] = KT(k, v) + pos += 1 + end + end + return offsets, values +end + +function mapoffsetsandvalues(::Type{KT}, x; largelists::Bool=false) where {KT} + Base.has_offset_axes(x) && + return _mapoffsetsandvaluesindexed(KT, x; largelists=largelists) + len = length(x) + O = largelists ? Int64 : Int32 + offsets = Vector{O}(undef, len + 1) + offsets[1] = zero(O) + total = 0 + i = 1 + for y in x + if y !== missing + total += length(y) + if O === Int32 && total > typemax(Int32) + O = Int64 + offsets = _promotemapoffsets(offsets, len, i) + end + end + @inbounds offsets[i + 1] = total + i += 1 + end + values = Vector{KT}(undef, total) + pos = 1 + for y in x + y === missing && continue + for (k, v) in pairs(y) + @inbounds values[pos] = KT(k, v) + pos += 1 + end + end + return offsets, values +end + +function mapoffsetsandvalues( + ::Type{KT}, + x::ArrowTypes.ToArrow; + largelists::Bool=false, +) where {KT} + len = length(x) + O = largelists ? 
Int64 : Int32 + offsets = Vector{O}(undef, len + 1) + offsets[1] = zero(O) + total = 0 + @inbounds for i = 1:len + y = x[i] + if y !== missing + total += length(y) + if O === Int32 && total > typemax(Int32) + O = Int64 + offsets = _promotemapoffsets(offsets, len, i) + end + end + offsets[i + 1] = total + end + values = Vector{KT}(undef, total) + pos = 1 + @inbounds for i = 1:len + y = x[i] + y === missing && continue + for (k, v) in pairs(y) + values[pos] = KT(k, v) + pos += 1 + end + end + return offsets, values +end keyvaluetypes(::Type{NamedTuple{(:key, :value),Tuple{K,V}}}) where {K,V} = (K, V) @@ -67,13 +167,12 @@ function arrowvector(::MapKind, x, i, nl, fi, de, ded, meta; largelists::Bool=fa ), ) KT = KeyValue{KDT,VDT} - VT = Vector{KT} - T = DT !== ET ? Union{Missing,VT} : VT - flat = ToList(T[keyvalues(KT, y) for y in x]; largelists=largelists) - offsets = Offsets(UInt8[], flat.inds) - data = arrowvector(flat, i, nl + 1, fi, de, ded, nothing; largelists=largelists, kw...) + offsetsdata, values = mapoffsetsandvalues(KT, x; largelists=largelists) + offsets = Offsets(UInt8[], offsetsdata) + data = + arrowvector(values, i, nl + 1, fi, de, ded, nothing; largelists=largelists, kw...) K, V = keyvaluetypes(eltype(data)) - return Map{withmissing(ET, Dict{K,V}),eltype(flat.inds),typeof(data)}( + return Map{withmissing(ET, Dict{K,V}),eltype(offsetsdata),typeof(data)}( validity, offsets, data, diff --git a/src/arraytypes/primitive.jl b/src/arraytypes/primitive.jl index 7d86bfe0..fbd6483c 100644 --- a/src/arraytypes/primitive.jl +++ b/src/arraytypes/primitive.jl @@ -70,6 +70,22 @@ function arrowvector(::PrimitiveKind, x, i, nl, fi, de, ded, meta; kw...) 
return Primitive(eltype(x), UInt8[], validity, x, length(x), meta) end +function arrowvector( + ::PrimitiveKind, + x::ArrowTypes.ToArrow, + i, + nl, + fi, + de, + ded, + meta; + kw..., +) + data = _materializeconverted(x) + validity = _toarrowvalidity(x, data) + return Primitive(eltype(data), UInt8[], validity, data, length(data), meta) +end + function compress(Z::Meta.CompressionType.T, comp, p::P) where {P<:Primitive} len = length(p) nc = nullcount(p) diff --git a/src/arraytypes/struct.jl b/src/arraytypes/struct.jl index 23a8b641..b66633ea 100644 --- a/src/arraytypes/struct.jl +++ b/src/arraytypes/struct.jl @@ -80,13 +80,72 @@ end ToStruct(x::A, j::Integer) where {A} = ToStruct{fieldtype(Base.nonmissingtype(eltype(A)), j),j,A}(x) +@inline _structsource(A::ToStruct) = getfield(A, :data) +@inline _structsourcevalue(A::ToStruct, i::Integer) = @inbounds _structsource(A)[i] + Base.IndexStyle(::Type{<:ToStruct}) = Base.IndexLinear() -Base.size(x::ToStruct) = (length(x.data),) +Base.size(x::ToStruct) = (length(_structsource(x)),) + +@inline _structfield(::Type{T}, x, j) where {T} = + x === missing ? ArrowTypes.default(T) : getfield(x, j) Base.@propagate_inbounds function Base.getindex(A::ToStruct{T,j}, i::Integer) where {T,j} @boundscheck checkbounds(A, i) - @inbounds x = A.data[i] - return x === missing ? 
ArrowTypes.default(T) : getfield(x, j) + x = _structsourcevalue(A, i) + return _structfield(T, x, j) +end + +function Base.iterate(A::ToStruct{T,j}) where {T,j} + state = iterate(_structsource(A)) + state === nothing && return nothing + x, st = state + return _structfield(T, x, j), st +end + +function Base.iterate(A::ToStruct{T,j}, st) where {T,j} + state = iterate(_structsource(A), st) + state === nothing && return nothing + x, st = state + return _structfield(T, x, j), st +end + +function writearray(io::IO, ::Type{T}, col::ToStruct{T,j}) where {T,j} + isbitstype(T) || return _writearrayfallback(io, T, col) + data = Vector{T}(undef, length(col)) + i = 1 + for x in col + @inbounds data[i] = x + i += 1 + end + return _writearraycontiguous(io, T, data) +end + +function writearray( + io::IO, + ::Type{UInt8}, + col::ToList{UInt8,stringtype,A}, +) where {stringtype,T,j,A<:ToStruct{T,j}} + off = _tolistoffset(col) + off == 0 || return _writearray_tolist_uint8(io, col) + len = length(col) + len <= 1_048_576 || return _writearray_tolist_uint8(io, col) + outer = _tolistdata(col) + data = _structsource(outer) + buf = Vector{UInt8}(undef, len) + pos = 1 + @inbounds for idx in eachindex(data) + chunk = _structfield(T, data[idx], j) + chunk === missing && continue + bytes = stringtype ? _codeunits(chunk) : chunk + for b in bytes + buf[pos] = b + pos += 1 + end + end + written = pos - 1 + GC.@preserve buf begin + return Base.unsafe_write(io, pointer(buf), written) + end end arrowvector(::StructKind, x::Struct, i, nl, fi, de, ded, meta; kw...) = x diff --git a/src/flight/Flight.jl b/src/flight/Flight.jl new file mode 100644 index 00000000..f6ae04a4 --- /dev/null +++ b/src/flight/Flight.jl @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module Flight + +using Base64 +using ProtoBuf +using gRPCClient +using Tables + +const ArrowParent = parentmodule(@__MODULE__) + +include("exports.jl") +include("protocol.jl") +include("client.jl") +include("server.jl") +include("convert.jl") + +end # module Flight diff --git a/src/flight/client.jl b/src/flight/client.jl new file mode 100644 index 00000000..dc196bd0 --- /dev/null +++ b/src/flight/client.jl @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +include("client/constants.jl") +include("client/locations.jl") +include("client/types.jl") +include("client/headers.jl") +include("client/transport.jl") +include("client/protocol_clients.jl") +include("client/auth.jl") +include("client/rpc_methods.jl") diff --git a/src/flight/client/auth.jl b/src/flight/client/auth.jl new file mode 100644 index 00000000..9750cbb9 --- /dev/null +++ b/src/flight/client/auth.jl @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +handshake( + client::Client, + request::Channel{Protocol.HandshakeRequest}, + response::Channel{Protocol.HandshakeResponse}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) = _grpc_async_request( + client, + _handshake_client(client; kwargs...), + request, + response, + headers=_merge_headers(client, headers), +) + +function handshake( + client::Client; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + request = Channel{Protocol.HandshakeRequest}(request_capacity) + response = Channel{Protocol.HandshakeResponse}(response_capacity) + req = handshake(client, request, response; headers=headers, kwargs...) 
+ return req, request, response +end + +function authenticate( + client::Client, + requests::AbstractVector{<:Protocol.HandshakeRequest}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + req, request_channel, response_channel = handshake(client; headers=headers, kwargs...) + for request in requests + put!(request_channel, request) + end + close(request_channel) + + responses = collect(response_channel) + gRPCClient.grpc_async_await(req) + + isempty(responses) && + throw(ArgumentError("Arrow Flight handshake returned no response messages")) + + return withtoken(client, responses[end].payload), responses +end + +function authenticate( + client::Client, + payloads::AbstractVector{<:AbstractVector{UInt8}}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + requests = [ + Protocol.HandshakeRequest(UInt64(0), Vector{UInt8}(payload)) for payload in payloads + ] + return authenticate(client, requests; headers=headers, kwargs...) +end + +function authenticate( + client::Client, + username::AbstractString, + password::AbstractString; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + return authenticate( + client, + [Vector{UInt8}(codeunits(username)), Vector{UInt8}(codeunits(password))]; + headers=headers, + kwargs..., + ) +end diff --git a/src/flight/client/constants.jl b/src/flight/client/constants.jl new file mode 100644 index 00000000..f5276985 --- /dev/null +++ b/src/flight/client/constants.jl @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +const DEFAULT_MAX_MESSAGE_LENGTH = 4 * 1024 * 1024 +const DEFAULT_STREAM_BUFFER = 16 +const HeaderValue = Union{String,Vector{UInt8}} +const HeaderPair = Pair{String,HeaderValue} +const AUTH_TOKEN_HEADER = "auth-token-bin" diff --git a/src/flight/client/headers.jl b/src/flight/client/headers.jl new file mode 100644 index 00000000..6f1ca7ff --- /dev/null +++ b/src/flight/client/headers.jl @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +_normalize_header_value(value::AbstractString) = String(value) +_normalize_header_value(value::AbstractVector{UInt8}) = Vector{UInt8}(value) +function _normalize_header_value(value) + throw( + ArgumentError( + "Arrow Flight header values must be strings or byte vectors, got $(typeof(value))", + ), + ) +end + +function _normalize_headers(headers::AbstractVector{<:Pair}) + normalized = HeaderPair[] + for header in headers + push!(normalized, String(first(header)) => _normalize_header_value(last(header))) + end + return normalized +end + +withheaders(client::Client, headers::Pair...) = withheaders(client, collect(headers)) + +function withheaders(client::Client, headers::AbstractVector{<:Pair}) + merged_headers = copy(client.headers) + append!(merged_headers, _normalize_headers(headers)) + return _rebuild_client(client; headers=merged_headers) +end + +withtoken(client::Client, token::AbstractString) = + withtoken(client, Vector{UInt8}(codeunits(token))) +withtoken(client::Client, token::AbstractVector{UInt8}) = + _withreplacedheader(client, AUTH_TOKEN_HEADER => Vector{UInt8}(token)) + +function _withreplacedheader(client::Client, header::Pair) + normalized_header = String(first(header)) => _normalize_header_value(last(header)) + name = lowercase(first(normalized_header)) + filtered_headers = HeaderPair[ + existing for existing in client.headers if lowercase(first(existing)) != name + ] + push!(filtered_headers, normalized_header) + return _rebuild_client(client; headers=filtered_headers) +end + +function _header_lines(headers::AbstractVector{HeaderPair}) + lines = String[] + for (name, value) in headers + isempty(name) && throw(ArgumentError("Arrow Flight header names must not be empty")) + any(ch -> ch == '\r' || ch == '\n', name) && + throw(ArgumentError("Arrow Flight header names must not contain newlines")) + rendered_value = _render_header_value(name, value) + any(ch -> ch == '\r' || ch == '\n', rendered_value) && + throw(ArgumentError("Arrow Flight header 
values must not contain newlines")) + push!(lines, string(name, ": ", rendered_value)) + end + return lines +end + +function _render_header_value(name::String, value::String) + if endswith(lowercase(name), "-bin") + return Base64.base64encode(codeunits(value)) + end + return value +end + +function _render_header_value(name::String, value::Vector{UInt8}) + endswith(lowercase(name), "-bin") || + throw(ArgumentError("Arrow Flight binary header values require a '-bin' suffix")) + return Base64.base64encode(value) +end + +function _merge_headers(client::Client, headers::AbstractVector{<:Pair}=HeaderPair[]) + merged_headers = copy(client.headers) + append!(merged_headers, _normalize_headers(headers)) + return merged_headers +end diff --git a/src/flight/client/locations.jl b/src/flight/client/locations.jl new file mode 100644 index 00000000..846d72c5 --- /dev/null +++ b/src/flight/client/locations.jl @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +function _parse_location(uri::String) + match_result = match( + r"^(grpc\+tls|grpc\+tcp|grpc|https|http)://(\[[^\]]+\]|[^:/?#]+):([0-9]+)(/.*)?$", + uri, + ) + isnothing(match_result) && + throw(ArgumentError("unsupported Arrow Flight location URI: $uri")) + scheme = match_result.captures[1] + host = match_result.captures[2] + port = parse(Int64, match_result.captures[3]) + secure = scheme == "grpc+tls" || scheme == "https" + if startswith(host, "[") && endswith(host, "]") + host = host[2:(end - 1)] + end + return secure, host, port +end diff --git a/src/flight/client/methods/actions.jl b/src/flight/client/methods/actions.jl new file mode 100644 index 00000000..7de2201c --- /dev/null +++ b/src/flight/client/methods/actions.jl @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +doaction( + client::Client, + action::Protocol.Action, + response::Channel{Protocol.Result}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) = _grpc_async_request( + client, + _doaction_client(client; kwargs...), + action, + response; + headers=_merge_headers(client, headers), +) + +function doaction( + client::Client, + action::Protocol.Action; + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + response = Channel{Protocol.Result}(response_capacity) + req = doaction(client, action, response; headers=headers, kwargs...) + return req, response +end + +function listactions( + client::Client, + response::Channel{Protocol.ActionType}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + return _grpc_async_request( + client, + _listactions_client(client; kwargs...), + Protocol.Empty(), + response, + headers=_merge_headers(client, headers), + ) +end + +function listactions( + client::Client; + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + response = Channel{Protocol.ActionType}(response_capacity) + req = listactions(client, response; headers=headers, kwargs...) + return req, response +end diff --git a/src/flight/client/methods/data.jl b/src/flight/client/methods/data.jl new file mode 100644 index 00000000..10a2f54a --- /dev/null +++ b/src/flight/client/methods/data.jl @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +doget( + client::Client, + ticket::Protocol.Ticket, + response::Channel{Protocol.FlightData}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) = _grpc_async_request( + client, + _doget_client(client; kwargs...), + ticket, + response; + headers=_merge_headers(client, headers), +) + +function doget( + client::Client, + ticket::Protocol.Ticket; + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + response = Channel{Protocol.FlightData}(response_capacity) + req = doget(client, ticket, response; headers=headers, kwargs...) + return req, response +end + +doput( + client::Client, + request::Channel{Protocol.FlightData}, + response::Channel{Protocol.PutResult}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) = _grpc_async_request( + client, + _doput_client(client; kwargs...), + request, + response; + headers=_merge_headers(client, headers), +) + +function doput( + client::Client; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + request = Channel{Protocol.FlightData}(request_capacity) + response = Channel{Protocol.PutResult}(response_capacity) + req = doput(client, request, response; headers=headers, kwargs...) 
+ return req, request, response +end + +doexchange( + client::Client, + request::Channel{Protocol.FlightData}, + response::Channel{Protocol.FlightData}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) = _grpc_async_request( + client, + _doexchange_client(client; kwargs...), + request, + response, + headers=_merge_headers(client, headers), +) + +function doexchange( + client::Client; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + request = Channel{Protocol.FlightData}(request_capacity) + response = Channel{Protocol.FlightData}(response_capacity) + req = doexchange(client, request, response; headers=headers, kwargs...) + return req, request, response +end diff --git a/src/flight/client/methods/discovery.jl b/src/flight/client/methods/discovery.jl new file mode 100644 index 00000000..56351545 --- /dev/null +++ b/src/flight/client/methods/discovery.jl @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +listflights( + client::Client, + criteria::Protocol.Criteria, + response::Channel{Protocol.FlightInfo}; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) = _grpc_async_request( + client, + _listflights_client(client; kwargs...), + criteria, + response, + headers=_merge_headers(client, headers), +) + +function listflights( + client::Client, + criteria::Protocol.Criteria=Protocol.Criteria(UInt8[]); + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + response = Channel{Protocol.FlightInfo}(response_capacity) + req = listflights(client, criteria, response; headers=headers, kwargs...) + return req, response +end + +function getflightinfo( + client::Client, + descriptor::Protocol.FlightDescriptor; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + return _grpc_sync_request( + client, + _getflightinfo_client(client; kwargs...), + descriptor; + headers=_merge_headers(client, headers), + ) +end + +function pollflightinfo( + client::Client, + descriptor::Protocol.FlightDescriptor; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + return _grpc_sync_request( + client, + _pollflightinfo_client(client; kwargs...), + descriptor; + headers=_merge_headers(client, headers), + ) +end + +function getschema( + client::Client, + descriptor::Protocol.FlightDescriptor; + headers::AbstractVector{<:Pair}=HeaderPair[], + kwargs..., +) + return _grpc_sync_request( + client, + _getschema_client(client; kwargs...), + descriptor; + headers=_merge_headers(client, headers), + ) +end diff --git a/src/flight/client/protocol_clients.jl b/src/flight/client/protocol_clients.jl new file mode 100644 index 00000000..5524cbb4 --- /dev/null +++ b/src/flight/client/protocol_clients.jl @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +_handshake_client(client::Client; kwargs...) = Protocol.FlightService_Handshake_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_listflights_client(client::Client; kwargs...) = Protocol.FlightService_ListFlights_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_getflightinfo_client(client::Client; kwargs...) = + Protocol.FlightService_GetFlightInfo_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., + ) + +_pollflightinfo_client(client::Client; kwargs...) = + Protocol.FlightService_PollFlightInfo_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., + ) + +_getschema_client(client::Client; kwargs...) = Protocol.FlightService_GetSchema_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_doget_client(client::Client; kwargs...) = Protocol.FlightService_DoGet_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_doput_client(client::Client; kwargs...) = Protocol.FlightService_DoPut_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_doexchange_client(client::Client; kwargs...) 
= Protocol.FlightService_DoExchange_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_doaction_client(client::Client; kwargs...) = Protocol.FlightService_DoAction_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) + +_listactions_client(client::Client; kwargs...) = Protocol.FlightService_ListActions_Client( + client.host, + client.port; + _rpc_options(client; kwargs...)..., +) diff --git a/src/flight/client/rpc_methods.jl b/src/flight/client/rpc_methods.jl new file mode 100644 index 00000000..7e8e2cdb --- /dev/null +++ b/src/flight/client/rpc_methods.jl @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("methods/discovery.jl") +include("methods/data.jl") +include("methods/actions.jl") diff --git a/src/flight/client/transport.jl b/src/flight/client/transport.jl new file mode 100644 index 00000000..4d6fef80 --- /dev/null +++ b/src/flight/client/transport.jl @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _append_headers_unlocked!( + req::gRPCClient.gRPCRequest, + headers::AbstractVector{HeaderPair}, +) + isempty(headers) && return req + for header_line in _header_lines(headers) + req.headers = gRPCClient.curl_slist_append(req.headers, header_line) + end + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_HTTPHEADER, req.headers) + return req +end + +function _apply_tls_options_unlocked!(client::Client, req::gRPCClient.gRPCRequest) + if !client.secure + return req + end + + if client.disable_server_verification + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_SSL_VERIFYPEER, Clong(0)) + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_SSL_VERIFYHOST, Clong(0)) + else + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_SSL_VERIFYPEER, Clong(1)) + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_SSL_VERIFYHOST, Clong(2)) + end + + !isnothing(client.tls_root_certs) && gRPCClient.curl_easy_setopt( + req.easy, + gRPCClient.CURLOPT_CAINFO, + client.tls_root_certs, + ) + !isnothing(client.cert_chain) && + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_SSLCERT, client.cert_chain) + !isnothing(client.private_key) && + gRPCClient.curl_easy_setopt(req.easy, gRPCClient.CURLOPT_SSLKEY, client.private_key) + !isnothing(client.key_password) && gRPCClient.curl_easy_setopt( + req.easy, + gRPCClient.CURLOPT_KEYPASSWD, + client.key_password, + ) + + return 
req +end + +function _apply_client_options_unlocked!( + client::Client, + req::gRPCClient.gRPCRequest, + headers::AbstractVector{HeaderPair}, +) + _append_headers_unlocked!(req, headers) + return _apply_tls_options_unlocked!(client, req) +end + +function _grpc_sync_request( + client::Client, + rpc_client::gRPCClient.gRPCServiceClient{TRequest,false,TResponse,false}, + request::TRequest; + headers::AbstractVector{HeaderPair}=HeaderPair[], +) where {TRequest<:Any,TResponse<:Any} + req = lock(rpc_client.grpc.lock) do + req = gRPCClient.grpc_async_request(rpc_client, request) + _apply_client_options_unlocked!(client, req, headers) + end + return gRPCClient.grpc_async_await(rpc_client, req) +end + +function _grpc_async_request( + client::Client, + rpc_client::gRPCClient.gRPCServiceClient{TRequest,false,TResponse,true}, + request::TRequest, + response::Channel{TResponse}; + headers::AbstractVector{HeaderPair}=HeaderPair[], +) where {TRequest<:Any,TResponse<:Any} + return lock(rpc_client.grpc.lock) do + req = gRPCClient.grpc_async_request(rpc_client, request, response) + _apply_client_options_unlocked!(client, req, headers) + end +end + +function _grpc_async_request( + client::Client, + rpc_client::gRPCClient.gRPCServiceClient{TRequest,true,TResponse,false}, + request::Channel{TRequest}, + response::Channel{TResponse}; + headers::AbstractVector{HeaderPair}=HeaderPair[], +) where {TRequest<:Any,TResponse<:Any} + return lock(rpc_client.grpc.lock) do + req = gRPCClient.grpc_async_request(rpc_client, request, response) + _apply_client_options_unlocked!(client, req, headers) + end +end + +function _grpc_async_request( + client::Client, + rpc_client::gRPCClient.gRPCServiceClient{TRequest,true,TResponse,true}, + request::Channel{TRequest}, + response::Channel{TResponse}; + headers::AbstractVector{HeaderPair}=HeaderPair[], +) where {TRequest<:Any,TResponse<:Any} + return lock(rpc_client.grpc.lock) do + req = gRPCClient.grpc_async_request(rpc_client, request, response) + 
_apply_client_options_unlocked!(client, req, headers) + end +end + +_default_rpc_options(client::Client) = ( + secure=client.secure, + grpc=client.grpc, + deadline=client.deadline, + keepalive=client.keepalive, + max_send_message_length=client.max_send_message_length, + max_recieve_message_length=client.max_recieve_message_length, +) + +_rpc_options(client::Client; kwargs...) = + merge(_default_rpc_options(client), NamedTuple(kwargs)) diff --git a/src/flight/client/types.jl b/src/flight/client/types.jl new file mode 100644 index 00000000..77d0bdd5 --- /dev/null +++ b/src/flight/client/types.jl @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +struct Client + host::String + port::Int64 + secure::Bool + grpc::gRPCClient.gRPCCURL + deadline::Float64 + keepalive::Float64 + max_send_message_length::Int64 + max_recieve_message_length::Int64 + headers::Vector{HeaderPair} + tls_root_certs::Union{Nothing,String} + cert_chain::Union{Nothing,String} + private_key::Union{Nothing,String} + key_password::Union{Nothing,String} + disable_server_verification::Bool +end + +function Client( + host, + port; + secure::Bool=false, + grpc::gRPCClient.gRPCCURL=gRPCClient.grpc_global_handle(), + deadline::Real=10, + keepalive::Real=60, + max_send_message_length::Integer=DEFAULT_MAX_MESSAGE_LENGTH, + max_recieve_message_length::Integer=DEFAULT_MAX_MESSAGE_LENGTH, + headers::AbstractVector{<:Pair}=HeaderPair[], + tls_root_certs::Union{Nothing,AbstractString}=nothing, + cert_chain::Union{Nothing,AbstractString}=nothing, + private_key::Union{Nothing,AbstractString}=nothing, + key_password::Union{Nothing,AbstractString}=nothing, + disable_server_verification::Bool=false, +) + Client( + String(host), + Int64(port), + secure, + grpc, + Float64(deadline), + Float64(keepalive), + Int64(max_send_message_length), + Int64(max_recieve_message_length), + _normalize_headers(headers), + isnothing(tls_root_certs) ? nothing : String(tls_root_certs), + isnothing(cert_chain) ? nothing : String(cert_chain), + isnothing(private_key) ? nothing : String(private_key), + isnothing(key_password) ? nothing : String(key_password), + disable_server_verification, + ) +end + +Client(location::Protocol.Location; kwargs...) = Client(location.uri; kwargs...) + +function Client(uri::AbstractString; kwargs...) + secure, host, port = _parse_location(String(uri)) + Client(host, port; secure=secure, kwargs...) 
+end + +function _rebuild_client(client::Client; headers::AbstractVector{<:Pair}=client.headers) + return Client( + client.host, + client.port; + secure=client.secure, + grpc=client.grpc, + deadline=client.deadline, + keepalive=client.keepalive, + max_send_message_length=client.max_send_message_length, + max_recieve_message_length=client.max_recieve_message_length, + headers=headers, + tls_root_certs=client.tls_root_certs, + cert_chain=client.cert_chain, + private_key=client.private_key, + key_password=client.key_password, + disable_server_verification=client.disable_server_verification, + ) +end diff --git a/src/flight/convert.jl b/src/flight/convert.jl new file mode 100644 index 00000000..b6d82896 --- /dev/null +++ b/src/flight/convert.jl @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +include("convert/constants.jl") +include("convert/framing.jl") +include("convert/schema.jl") +include("convert/streaming.jl") +include("convert/flightdata.jl") diff --git a/src/flight/convert/constants.jl b/src/flight/convert/constants.jl new file mode 100644 index 00000000..80db62e8 --- /dev/null +++ b/src/flight/convert/constants.jl @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +const DEFAULT_IPC_ALIGNMENT = 8 + +_collect_messages(messages::AbstractVector{<:Protocol.FlightData}) = messages +_collect_messages(messages) = collect(messages) diff --git a/src/flight/convert/flightdata.jl b/src/flight/convert/flightdata.jl new file mode 100644 index 00000000..d78d83a9 --- /dev/null +++ b/src/flight/convert/flightdata.jl @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flightdata( + source; + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, +) + dictencodings = Dict{Int64,Any}() + messages = Protocol.FlightData[] + schema = Ref{Tables.Schema}() + normalized_colmetadata = ArrowParent._normalizecolmeta(colmetadata) + meta = isnothing(metadata) ? 
ArrowParent.getmetadata(source) : metadata + + for tbl in Tables.partitions(source) + tblcols = Tables.columns(tbl) + cols = ArrowParent.toarrowtable( + tblcols, + dictencodings, + largelists, + compress, + denseunions, + dictencode, + dictencodenested, + maxdepth, + meta, + normalized_colmetadata, + ) + if !isassigned(schema) + schema[] = Tables.schema(cols) + push!( + messages, + _flightdata_message( + ArrowParent.makeschemamsg(schema[], cols); + descriptor=descriptor, + alignment=alignment, + ), + ) + if !isempty(dictencodings) + for (id, delock) in sort!(collect(dictencodings); by=x -> x.first, rev=true) + de = delock.value + dictsch = Tables.Schema((:col,), (eltype(de.data),)) + push!( + messages, + _flightdata_message( + ArrowParent.makedictionarybatchmsg( + dictsch, + (col=de.data,), + id, + false, + alignment, + ); + alignment=alignment, + ), + ) + end + end + elseif !isempty(cols.dictencodingdeltas) + for de in cols.dictencodingdeltas + dictsch = Tables.Schema((:col,), (eltype(de.data),)) + push!( + messages, + _flightdata_message( + ArrowParent.makedictionarybatchmsg( + dictsch, + (col=de.data,), + de.id, + true, + alignment, + ); + alignment=alignment, + ), + ) + end + end + push!( + messages, + _flightdata_message( + ArrowParent.makerecordbatchmsg(schema[], cols, alignment); + alignment=alignment, + ), + ) + descriptor = nothing + end + return messages +end diff --git a/src/flight/convert/framing.jl b/src/flight/convert/framing.jl new file mode 100644 index 00000000..8b205d15 --- /dev/null +++ b/src/flight/convert/framing.jl @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _message_body(msg::ArrowParent.Message, alignment::Integer) + msg.columns === nothing && return UInt8[] + io = IOBuffer() + for col in Tables.Columns(msg.columns) + ArrowParent.writebuffer(io, col, alignment) + end + return take!(io) +end + +function _flightdata_message( + msg::ArrowParent.Message; + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, +) + body = _message_body(msg, alignment) + length(body) == msg.bodylen || + throw(ArgumentError("FlightData body length mismatch while encoding Arrow IPC")) + return Protocol.FlightData(descriptor, Vector{UInt8}(msg.msgflatbuf), UInt8[], body) +end + +function _write_framed_message( + io::IO, + data_header::AbstractVector{UInt8}, + data_body::AbstractVector{UInt8}, + alignment::Integer, +) + metalen = ArrowParent.padding(length(data_header), alignment) + Base.write(io, ArrowParent.CONTINUATION_INDICATOR_BYTES) + Base.write(io, Int32(metalen)) + Base.write(io, data_header) + ArrowParent.writezeros(io, ArrowParent.paddinglength(length(data_header), alignment)) + Base.write(io, data_body) + return +end + +function _write_end_marker(io::IO) + Base.write(io, ArrowParent.CONTINUATION_INDICATOR_BYTES) + Base.write(io, Int32(0)) + return +end diff --git a/src/flight/convert/schema.jl b/src/flight/convert/schema.jl new file mode 100644 index 00000000..55cf33d4 --- /dev/null +++ b/src/flight/convert/schema.jl @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _normalize_schemaipc( + schema::AbstractVector{UInt8}; + alignment::Integer=DEFAULT_IPC_ALIGNMENT, +) + bytes = Vector{UInt8}(schema) + isempty(bytes) && throw(ArgumentError("schema bytes cannot be empty")) + if length(bytes) >= 8 && + ArrowParent.readbuffer(bytes, 1, UInt32) == ArrowParent.CONTINUATION_INDICATOR_BYTES + return bytes + end + if length(bytes) >= 4 + metalen = ArrowParent.readbuffer(bytes, 1, Int32) + if metalen >= 0 && metalen == length(bytes) - 4 + io = IOBuffer() + Base.write(io, ArrowParent.CONTINUATION_INDICATOR_BYTES) + Base.write(io, bytes) + return take!(io) + end + end + io = IOBuffer() + _write_framed_message(io, bytes, UInt8[], alignment) + return take!(io) +end + +schemaipc(result::Protocol.SchemaResult; alignment::Integer=DEFAULT_IPC_ALIGNMENT) = + _normalize_schemaipc(result.schema; alignment=alignment) + +schemaipc(info::Protocol.FlightInfo; alignment::Integer=DEFAULT_IPC_ALIGNMENT) = + _normalize_schemaipc(info.schema; alignment=alignment) + +schemaipc(schema::AbstractVector{UInt8}; alignment::Integer=DEFAULT_IPC_ALIGNMENT) = + _normalize_schemaipc(schema; alignment=alignment) + +function schemaipc(message::Protocol.FlightData; alignment::Integer=DEFAULT_IPC_ALIGNMENT) + isempty(message.data_header) && + throw(ArgumentError("FlightData message is 
missing the Arrow IPC header")) + io = IOBuffer() + _write_framed_message(io, message.data_header, message.data_body, alignment) + return take!(io) +end + +function schemaipc(source; kwargs...) + alignment = get(kwargs, :alignment, DEFAULT_IPC_ALIGNMENT) + messages = flightdata(source; kwargs...) + isempty(messages) && + throw(ArgumentError("cannot derive schema bytes from an empty Flight source")) + return schemaipc(first(messages); alignment=alignment) +end diff --git a/src/flight/convert/streaming.jl b/src/flight/convert/streaming.jl new file mode 100644 index 00000000..b18396b6 --- /dev/null +++ b/src/flight/convert/streaming.jl @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +streambytes(message::Protocol.FlightData; kwargs...) = + streambytes(Protocol.FlightData[message]; kwargs...) 
+ +function _missing_schema_message() + return join( + [ + "cannot derive Arrow Flight schema from a response stream without a schema message", + "the server may have terminated the stream before emitting the first schema-bearing FlightData message", + "or the underlying transport did not surface the corresponding gRPC status", + ], + "; ", + ) +end + +function _require_schema_messages(messages::AbstractVector{<:Protocol.FlightData}, schema) + schema === nothing || return messages + any(message -> !isempty(message.data_header), messages) && return messages + throw(ArgumentError(_missing_schema_message())) +end + +function streambytes( + messages; + schema=nothing, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + end_marker::Bool=true, +) + collected = _require_schema_messages(_collect_messages(messages), schema) + io = IOBuffer() + schema === nothing || Base.write(io, schemaipc(schema; alignment=alignment)) + for message in collected + if isempty(message.data_header) + isempty(message.data_body) || throw( + ArgumentError("FlightData message has a body but no Arrow IPC header"), + ) + continue + end + _write_framed_message(io, message.data_header, message.data_body, alignment) + end + end_marker && _write_end_marker(io) + return take!(io) +end + +function stream( + messages; + schema=nothing, + convert::Bool=true, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + end_marker::Bool=true, +) + bytes = streambytes(messages; schema=schema, alignment=alignment, end_marker=end_marker) + return ArrowParent.Stream(bytes; convert=convert) +end + +function table( + messages; + schema=nothing, + convert::Bool=true, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + end_marker::Bool=true, +) + bytes = streambytes(messages; schema=schema, alignment=alignment, end_marker=end_marker) + return ArrowParent.Table(bytes; convert=convert) +end diff --git a/src/flight/exports.jl b/src/flight/exports.jl new file mode 100644 index 00000000..2809ce22 --- /dev/null +++ b/src/flight/exports.jl @@ -0,0 
+1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +export Client, + Service, + ServerCallContext, + MethodDescriptor, + ServiceDescriptor, + withheaders, + withtoken, + Protocol, + Generated, + authenticate, + callheader, + servicedescriptor, + lookupmethod, + dispatch, + handshake, + listflights, + getflightinfo, + pollflightinfo, + getschema, + doget, + doput, + doexchange, + doaction, + listactions, + schemaipc, + streambytes, + stream, + table, + flightdata diff --git a/src/flight/generated/arrow/arrow.jl b/src/flight/generated/arrow/arrow.jl new file mode 100644 index 00000000..1b821af7 --- /dev/null +++ b/src/flight/generated/arrow/arrow.jl @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module arrow + +include("../google/google.jl") + +include("flight/flight.jl") + +end # module arrow diff --git a/src/flight/generated/arrow/flight/flight.jl b/src/flight/generated/arrow/flight/flight.jl new file mode 100644 index 00000000..9a6b34b5 --- /dev/null +++ b/src/flight/generated/arrow/flight/flight.jl @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module flight + +include("protocol/protocol.jl") + +end # module flight diff --git a/src/flight/generated/arrow/flight/protocol/Flight_pb.jl b/src/flight/generated/arrow/flight/protocol/Flight_pb.jl new file mode 100644 index 00000000..a7c73d35 --- /dev/null +++ b/src/flight/generated/arrow/flight/protocol/Flight_pb.jl @@ -0,0 +1,1359 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Autogenerated using ProtoBuf.jl v1.2.3 +# original file: Flight.proto (proto3 syntax) + +import ProtoBuf as PB +import gRPCClient +using ProtoBuf: OneOf +using ProtoBuf.EnumX: @enumx + +export HandshakeRequest, Ticket, HandshakeResponse, Action +export var"FlightDescriptor.DescriptorType", Criteria, CloseSessionRequest, Result +export ActionType, PutResult, Empty, var"SessionOptionValue.StringListValue", SchemaResult +export CancelStatus, GetSessionOptionsRequest, var"SetSessionOptionsResult.ErrorValue" +export Location, var"CloseSessionResult.Status", BasicAuth, FlightDescriptor +export SessionOptionValue, CancelFlightInfoResult, var"SetSessionOptionsResult.Error" +export FlightEndpoint, CloseSessionResult, FlightData, SetSessionOptionsRequest +export GetSessionOptionsResult, SetSessionOptionsResult, RenewFlightEndpointRequest +export FlightInfo, CancelFlightInfoRequest, PollInfo + +struct HandshakeRequest + protocol_version::UInt64 + payload::Vector{UInt8} +end +PB.default_values(::Type{HandshakeRequest}) = + (; protocol_version=zero(UInt64), payload=UInt8[]) +PB.field_numbers(::Type{HandshakeRequest}) = (; protocol_version=1, payload=2) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:HandshakeRequest}) + protocol_version = zero(UInt64) + payload = UInt8[] + while 
!PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + protocol_version = PB.decode(d, UInt64) + elseif field_number == 2 + payload = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return HandshakeRequest(protocol_version, payload) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::HandshakeRequest) + initpos = position(e.io) + x.protocol_version != zero(UInt64) && PB.encode(e, 1, x.protocol_version) + !isempty(x.payload) && PB.encode(e, 2, x.payload) + return position(e.io) - initpos +end +function PB._encoded_size(x::HandshakeRequest) + encoded_size = 0 + x.protocol_version != zero(UInt64) && + (encoded_size += PB._encoded_size(x.protocol_version, 1)) + !isempty(x.payload) && (encoded_size += PB._encoded_size(x.payload, 2)) + return encoded_size +end + +struct Ticket + ticket::Vector{UInt8} +end +PB.default_values(::Type{Ticket}) = (; ticket=UInt8[]) +PB.field_numbers(::Type{Ticket}) = (; ticket=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Ticket}) + ticket = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + ticket = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return Ticket(ticket) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Ticket) + initpos = position(e.io) + !isempty(x.ticket) && PB.encode(e, 1, x.ticket) + return position(e.io) - initpos +end +function PB._encoded_size(x::Ticket) + encoded_size = 0 + !isempty(x.ticket) && (encoded_size += PB._encoded_size(x.ticket, 1)) + return encoded_size +end + +struct HandshakeResponse + protocol_version::UInt64 + payload::Vector{UInt8} +end +PB.default_values(::Type{HandshakeResponse}) = + (; protocol_version=zero(UInt64), payload=UInt8[]) +PB.field_numbers(::Type{HandshakeResponse}) = (; protocol_version=1, payload=2) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:HandshakeResponse}) + protocol_version = zero(UInt64) + 
payload = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + protocol_version = PB.decode(d, UInt64) + elseif field_number == 2 + payload = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return HandshakeResponse(protocol_version, payload) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::HandshakeResponse) + initpos = position(e.io) + x.protocol_version != zero(UInt64) && PB.encode(e, 1, x.protocol_version) + !isempty(x.payload) && PB.encode(e, 2, x.payload) + return position(e.io) - initpos +end +function PB._encoded_size(x::HandshakeResponse) + encoded_size = 0 + x.protocol_version != zero(UInt64) && + (encoded_size += PB._encoded_size(x.protocol_version, 1)) + !isempty(x.payload) && (encoded_size += PB._encoded_size(x.payload, 2)) + return encoded_size +end + +struct Action + var"#type"::String + body::Vector{UInt8} +end +PB.default_values(::Type{Action}) = (; var"#type"="", body=UInt8[]) +PB.field_numbers(::Type{Action}) = (; var"#type"=1, body=2) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Action}) + var"#type" = "" + body = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + var"#type" = PB.decode(d, String) + elseif field_number == 2 + body = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return Action(var"#type", body) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Action) + initpos = position(e.io) + !isempty(x.var"#type") && PB.encode(e, 1, x.var"#type") + !isempty(x.body) && PB.encode(e, 2, x.body) + return position(e.io) - initpos +end +function PB._encoded_size(x::Action) + encoded_size = 0 + !isempty(x.var"#type") && (encoded_size += PB._encoded_size(x.var"#type", 1)) + !isempty(x.body) && (encoded_size += PB._encoded_size(x.body, 2)) + return encoded_size +end + +@enumx var"FlightDescriptor.DescriptorType" UNKNOWN=0 PATH=1 CMD=2 + +struct Criteria + 
expression::Vector{UInt8} +end +PB.default_values(::Type{Criteria}) = (; expression=UInt8[]) +PB.field_numbers(::Type{Criteria}) = (; expression=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Criteria}) + expression = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + expression = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return Criteria(expression) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Criteria) + initpos = position(e.io) + !isempty(x.expression) && PB.encode(e, 1, x.expression) + return position(e.io) - initpos +end +function PB._encoded_size(x::Criteria) + encoded_size = 0 + !isempty(x.expression) && (encoded_size += PB._encoded_size(x.expression, 1)) + return encoded_size +end + +struct CloseSessionRequest end + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:CloseSessionRequest}) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + Base.skip(d, wire_type) + end + return CloseSessionRequest() +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::CloseSessionRequest) + initpos = position(e.io) + return position(e.io) - initpos +end +function PB._encoded_size(x::CloseSessionRequest) + encoded_size = 0 + return encoded_size +end + +struct Result + body::Vector{UInt8} +end +PB.default_values(::Type{Result}) = (; body=UInt8[]) +PB.field_numbers(::Type{Result}) = (; body=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Result}) + body = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + body = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return Result(body) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Result) + initpos = position(e.io) + !isempty(x.body) && PB.encode(e, 1, x.body) + return position(e.io) - initpos +end +function PB._encoded_size(x::Result) + encoded_size = 0 + !isempty(x.body) && 
(encoded_size += PB._encoded_size(x.body, 1)) + return encoded_size +end + +struct ActionType + var"#type"::String + description::String +end +PB.default_values(::Type{ActionType}) = (; var"#type"="", description="") +PB.field_numbers(::Type{ActionType}) = (; var"#type"=1, description=2) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:ActionType}) + var"#type" = "" + description = "" + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + var"#type" = PB.decode(d, String) + elseif field_number == 2 + description = PB.decode(d, String) + else + Base.skip(d, wire_type) + end + end + return ActionType(var"#type", description) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::ActionType) + initpos = position(e.io) + !isempty(x.var"#type") && PB.encode(e, 1, x.var"#type") + !isempty(x.description) && PB.encode(e, 2, x.description) + return position(e.io) - initpos +end +function PB._encoded_size(x::ActionType) + encoded_size = 0 + !isempty(x.var"#type") && (encoded_size += PB._encoded_size(x.var"#type", 1)) + !isempty(x.description) && (encoded_size += PB._encoded_size(x.description, 2)) + return encoded_size +end + +struct PutResult + app_metadata::Vector{UInt8} +end +PB.default_values(::Type{PutResult}) = (; app_metadata=UInt8[]) +PB.field_numbers(::Type{PutResult}) = (; app_metadata=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:PutResult}) + app_metadata = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + app_metadata = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return PutResult(app_metadata) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::PutResult) + initpos = position(e.io) + !isempty(x.app_metadata) && PB.encode(e, 1, x.app_metadata) + return position(e.io) - initpos +end +function PB._encoded_size(x::PutResult) + encoded_size = 0 + !isempty(x.app_metadata) && (encoded_size += 
PB._encoded_size(x.app_metadata, 1)) + return encoded_size +end + +struct Empty end + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Empty}) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + Base.skip(d, wire_type) + end + return Empty() +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Empty) + initpos = position(e.io) + return position(e.io) - initpos +end +function PB._encoded_size(x::Empty) + encoded_size = 0 + return encoded_size +end + +struct var"SessionOptionValue.StringListValue" + values::Vector{String} +end +PB.default_values(::Type{var"SessionOptionValue.StringListValue"}) = + (; values=Vector{String}()) +PB.field_numbers(::Type{var"SessionOptionValue.StringListValue"}) = (; values=1) + +function PB.decode( + d::PB.AbstractProtoDecoder, + ::Type{<:var"SessionOptionValue.StringListValue"}, +) + values = PB.BufferedVector{String}() + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, values) + else + Base.skip(d, wire_type) + end + end + return var"SessionOptionValue.StringListValue"(values[]) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::var"SessionOptionValue.StringListValue") + initpos = position(e.io) + !isempty(x.values) && PB.encode(e, 1, x.values) + return position(e.io) - initpos +end +function PB._encoded_size(x::var"SessionOptionValue.StringListValue") + encoded_size = 0 + !isempty(x.values) && (encoded_size += PB._encoded_size(x.values, 1)) + return encoded_size +end + +struct SchemaResult + schema::Vector{UInt8} +end +PB.default_values(::Type{SchemaResult}) = (; schema=UInt8[]) +PB.field_numbers(::Type{SchemaResult}) = (; schema=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:SchemaResult}) + schema = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + schema = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return SchemaResult(schema) 
+end + +function PB.encode(e::PB.AbstractProtoEncoder, x::SchemaResult) + initpos = position(e.io) + !isempty(x.schema) && PB.encode(e, 1, x.schema) + return position(e.io) - initpos +end +function PB._encoded_size(x::SchemaResult) + encoded_size = 0 + !isempty(x.schema) && (encoded_size += PB._encoded_size(x.schema, 1)) + return encoded_size +end + +@enumx CancelStatus CANCEL_STATUS_UNSPECIFIED=0 CANCEL_STATUS_CANCELLED=1 CANCEL_STATUS_CANCELLING=2 CANCEL_STATUS_NOT_CANCELLABLE=3 + +struct GetSessionOptionsRequest end + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:GetSessionOptionsRequest}) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + Base.skip(d, wire_type) + end + return GetSessionOptionsRequest() +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::GetSessionOptionsRequest) + initpos = position(e.io) + return position(e.io) - initpos +end +function PB._encoded_size(x::GetSessionOptionsRequest) + encoded_size = 0 + return encoded_size +end + +@enumx var"SetSessionOptionsResult.ErrorValue" UNSPECIFIED=0 INVALID_NAME=1 INVALID_VALUE=2 ERROR=3 + +struct Location + uri::String +end +PB.default_values(::Type{Location}) = (; uri="") +PB.field_numbers(::Type{Location}) = (; uri=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Location}) + uri = "" + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + uri = PB.decode(d, String) + else + Base.skip(d, wire_type) + end + end + return Location(uri) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Location) + initpos = position(e.io) + !isempty(x.uri) && PB.encode(e, 1, x.uri) + return position(e.io) - initpos +end +function PB._encoded_size(x::Location) + encoded_size = 0 + !isempty(x.uri) && (encoded_size += PB._encoded_size(x.uri, 1)) + return encoded_size +end + +@enumx var"CloseSessionResult.Status" UNSPECIFIED=0 CLOSED=1 CLOSING=2 NOT_CLOSEABLE=3 + +struct BasicAuth + username::String + password::String +end 
+PB.default_values(::Type{BasicAuth}) = (; username="", password="") +PB.field_numbers(::Type{BasicAuth}) = (; username=2, password=3) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:BasicAuth}) + username = "" + password = "" + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 2 + username = PB.decode(d, String) + elseif field_number == 3 + password = PB.decode(d, String) + else + Base.skip(d, wire_type) + end + end + return BasicAuth(username, password) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::BasicAuth) + initpos = position(e.io) + !isempty(x.username) && PB.encode(e, 2, x.username) + !isempty(x.password) && PB.encode(e, 3, x.password) + return position(e.io) - initpos +end +function PB._encoded_size(x::BasicAuth) + encoded_size = 0 + !isempty(x.username) && (encoded_size += PB._encoded_size(x.username, 2)) + !isempty(x.password) && (encoded_size += PB._encoded_size(x.password, 3)) + return encoded_size +end + +struct FlightDescriptor + var"#type"::var"FlightDescriptor.DescriptorType".T + cmd::Vector{UInt8} + path::Vector{String} +end +PB.default_values(::Type{FlightDescriptor}) = (; + var"#type"=var"FlightDescriptor.DescriptorType".UNKNOWN, + cmd=UInt8[], + path=Vector{String}(), +) +PB.field_numbers(::Type{FlightDescriptor}) = (; var"#type"=1, cmd=2, path=3) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:FlightDescriptor}) + var"#type" = var"FlightDescriptor.DescriptorType".UNKNOWN + cmd = UInt8[] + path = PB.BufferedVector{String}() + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + var"#type" = PB.decode(d, var"FlightDescriptor.DescriptorType".T) + elseif field_number == 2 + cmd = PB.decode(d, Vector{UInt8}) + elseif field_number == 3 + PB.decode!(d, path) + else + Base.skip(d, wire_type) + end + end + return FlightDescriptor(var"#type", cmd, path[]) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::FlightDescriptor) + 
initpos = position(e.io) + x.var"#type" != var"FlightDescriptor.DescriptorType".UNKNOWN && + PB.encode(e, 1, x.var"#type") + !isempty(x.cmd) && PB.encode(e, 2, x.cmd) + !isempty(x.path) && PB.encode(e, 3, x.path) + return position(e.io) - initpos +end +function PB._encoded_size(x::FlightDescriptor) + encoded_size = 0 + x.var"#type" != var"FlightDescriptor.DescriptorType".UNKNOWN && + (encoded_size += PB._encoded_size(x.var"#type", 1)) + !isempty(x.cmd) && (encoded_size += PB._encoded_size(x.cmd, 2)) + !isempty(x.path) && (encoded_size += PB._encoded_size(x.path, 3)) + return encoded_size +end + +struct SessionOptionValue + option_value::Union{ + Nothing, + OneOf{<:Union{String,Bool,Int64,Float64,var"SessionOptionValue.StringListValue"}}, + } +end +PB.oneof_field_types(::Type{SessionOptionValue}) = (; + option_value=(; + string_value=String, + bool_value=Bool, + int64_value=Int64, + double_value=Float64, + string_list_value=var"SessionOptionValue.StringListValue", + ), +) +PB.default_values(::Type{SessionOptionValue}) = (; + string_value="", + bool_value=false, + int64_value=zero(Int64), + double_value=zero(Float64), + string_list_value=nothing, +) +PB.field_numbers(::Type{SessionOptionValue}) = + (; string_value=1, bool_value=2, int64_value=3, double_value=4, string_list_value=5) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:SessionOptionValue}) + option_value = nothing + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + option_value = OneOf(:string_value, PB.decode(d, String)) + elseif field_number == 2 + option_value = OneOf(:bool_value, PB.decode(d, Bool)) + elseif field_number == 3 + option_value = OneOf(:int64_value, PB.decode(d, Int64, Val{:fixed})) + elseif field_number == 4 + option_value = OneOf(:double_value, PB.decode(d, Float64)) + elseif field_number == 5 + option_value = OneOf( + :string_list_value, + PB.decode(d, Ref{var"SessionOptionValue.StringListValue"}), + ) + else + Base.skip(d, 
wire_type) + end + end + return SessionOptionValue(option_value) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::SessionOptionValue) + initpos = position(e.io) + if isnothing(x.option_value) + ; + elseif x.option_value.name === :string_value + PB.encode(e, 1, x.option_value[]::String) + elseif x.option_value.name === :bool_value + PB.encode(e, 2, x.option_value[]::Bool) + elseif x.option_value.name === :int64_value + PB.encode(e, 3, x.option_value[]::Int64, Val{:fixed}) + elseif x.option_value.name === :double_value + PB.encode(e, 4, x.option_value[]::Float64) + elseif x.option_value.name === :string_list_value + PB.encode(e, 5, x.option_value[]::var"SessionOptionValue.StringListValue") + end + return position(e.io) - initpos +end +function PB._encoded_size(x::SessionOptionValue) + encoded_size = 0 + if isnothing(x.option_value) + ; + elseif x.option_value.name === :string_value + encoded_size += PB._encoded_size(x.option_value[]::String, 1) + elseif x.option_value.name === :bool_value + encoded_size += PB._encoded_size(x.option_value[]::Bool, 2) + elseif x.option_value.name === :int64_value + encoded_size += PB._encoded_size(x.option_value[]::Int64, 3, Val{:fixed}) + elseif x.option_value.name === :double_value + encoded_size += PB._encoded_size(x.option_value[]::Float64, 4) + elseif x.option_value.name === :string_list_value + encoded_size += + PB._encoded_size(x.option_value[]::var"SessionOptionValue.StringListValue", 5) + end + return encoded_size +end + +struct CancelFlightInfoResult + status::CancelStatus.T +end +PB.default_values(::Type{CancelFlightInfoResult}) = + (; status=CancelStatus.CANCEL_STATUS_UNSPECIFIED) +PB.field_numbers(::Type{CancelFlightInfoResult}) = (; status=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:CancelFlightInfoResult}) + status = CancelStatus.CANCEL_STATUS_UNSPECIFIED + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + status = PB.decode(d, CancelStatus.T) + 
else + Base.skip(d, wire_type) + end + end + return CancelFlightInfoResult(status) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::CancelFlightInfoResult) + initpos = position(e.io) + x.status != CancelStatus.CANCEL_STATUS_UNSPECIFIED && PB.encode(e, 1, x.status) + return position(e.io) - initpos +end +function PB._encoded_size(x::CancelFlightInfoResult) + encoded_size = 0 + x.status != CancelStatus.CANCEL_STATUS_UNSPECIFIED && + (encoded_size += PB._encoded_size(x.status, 1)) + return encoded_size +end + +struct var"SetSessionOptionsResult.Error" + value::var"SetSessionOptionsResult.ErrorValue".T +end +PB.default_values(::Type{var"SetSessionOptionsResult.Error"}) = + (; value=var"SetSessionOptionsResult.ErrorValue".UNSPECIFIED) +PB.field_numbers(::Type{var"SetSessionOptionsResult.Error"}) = (; value=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:var"SetSessionOptionsResult.Error"}) + value = var"SetSessionOptionsResult.ErrorValue".UNSPECIFIED + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + value = PB.decode(d, var"SetSessionOptionsResult.ErrorValue".T) + else + Base.skip(d, wire_type) + end + end + return var"SetSessionOptionsResult.Error"(value) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::var"SetSessionOptionsResult.Error") + initpos = position(e.io) + x.value != var"SetSessionOptionsResult.ErrorValue".UNSPECIFIED && + PB.encode(e, 1, x.value) + return position(e.io) - initpos +end +function PB._encoded_size(x::var"SetSessionOptionsResult.Error") + encoded_size = 0 + x.value != var"SetSessionOptionsResult.ErrorValue".UNSPECIFIED && + (encoded_size += PB._encoded_size(x.value, 1)) + return encoded_size +end + +struct FlightEndpoint + ticket::Union{Nothing,Ticket} + location::Vector{Location} + expiration_time::Union{Nothing,google.protobuf.Timestamp} + app_metadata::Vector{UInt8} +end +PB.default_values(::Type{FlightEndpoint}) = (; + ticket=nothing, + 
location=Vector{Location}(), + expiration_time=nothing, + app_metadata=UInt8[], +) +PB.field_numbers(::Type{FlightEndpoint}) = + (; ticket=1, location=2, expiration_time=3, app_metadata=4) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:FlightEndpoint}) + ticket = Ref{Union{Nothing,Ticket}}(nothing) + location = PB.BufferedVector{Location}() + expiration_time = Ref{Union{Nothing,google.protobuf.Timestamp}}(nothing) + app_metadata = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, ticket) + elseif field_number == 2 + PB.decode!(d, location) + elseif field_number == 3 + PB.decode!(d, expiration_time) + elseif field_number == 4 + app_metadata = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return FlightEndpoint(ticket[], location[], expiration_time[], app_metadata) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::FlightEndpoint) + initpos = position(e.io) + !isnothing(x.ticket) && PB.encode(e, 1, x.ticket) + !isempty(x.location) && PB.encode(e, 2, x.location) + !isnothing(x.expiration_time) && PB.encode(e, 3, x.expiration_time) + !isempty(x.app_metadata) && PB.encode(e, 4, x.app_metadata) + return position(e.io) - initpos +end +function PB._encoded_size(x::FlightEndpoint) + encoded_size = 0 + !isnothing(x.ticket) && (encoded_size += PB._encoded_size(x.ticket, 1)) + !isempty(x.location) && (encoded_size += PB._encoded_size(x.location, 2)) + !isnothing(x.expiration_time) && + (encoded_size += PB._encoded_size(x.expiration_time, 3)) + !isempty(x.app_metadata) && (encoded_size += PB._encoded_size(x.app_metadata, 4)) + return encoded_size +end + +struct CloseSessionResult + status::var"CloseSessionResult.Status".T +end +PB.default_values(::Type{CloseSessionResult}) = + (; status=var"CloseSessionResult.Status".UNSPECIFIED) +PB.field_numbers(::Type{CloseSessionResult}) = (; status=1) + +function PB.decode(d::PB.AbstractProtoDecoder, 
::Type{<:CloseSessionResult}) + status = var"CloseSessionResult.Status".UNSPECIFIED + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + status = PB.decode(d, var"CloseSessionResult.Status".T) + else + Base.skip(d, wire_type) + end + end + return CloseSessionResult(status) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::CloseSessionResult) + initpos = position(e.io) + x.status != var"CloseSessionResult.Status".UNSPECIFIED && PB.encode(e, 1, x.status) + return position(e.io) - initpos +end +function PB._encoded_size(x::CloseSessionResult) + encoded_size = 0 + x.status != var"CloseSessionResult.Status".UNSPECIFIED && + (encoded_size += PB._encoded_size(x.status, 1)) + return encoded_size +end + +struct FlightData + flight_descriptor::Union{Nothing,FlightDescriptor} + data_header::Vector{UInt8} + app_metadata::Vector{UInt8} + data_body::Vector{UInt8} +end +PB.default_values(::Type{FlightData}) = (; + flight_descriptor=nothing, + data_header=UInt8[], + app_metadata=UInt8[], + data_body=UInt8[], +) +PB.field_numbers(::Type{FlightData}) = + (; flight_descriptor=1, data_header=2, app_metadata=3, data_body=1000) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:FlightData}) + flight_descriptor = Ref{Union{Nothing,FlightDescriptor}}(nothing) + data_header = UInt8[] + app_metadata = UInt8[] + data_body = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, flight_descriptor) + elseif field_number == 2 + data_header = PB.decode(d, Vector{UInt8}) + elseif field_number == 3 + app_metadata = PB.decode(d, Vector{UInt8}) + elseif field_number == 1000 + data_body = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return FlightData(flight_descriptor[], data_header, app_metadata, data_body) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::FlightData) + initpos = position(e.io) + !isnothing(x.flight_descriptor) && 
PB.encode(e, 1, x.flight_descriptor) + !isempty(x.data_header) && PB.encode(e, 2, x.data_header) + !isempty(x.app_metadata) && PB.encode(e, 3, x.app_metadata) + !isempty(x.data_body) && PB.encode(e, 1000, x.data_body) + return position(e.io) - initpos +end +function PB._encoded_size(x::FlightData) + encoded_size = 0 + !isnothing(x.flight_descriptor) && + (encoded_size += PB._encoded_size(x.flight_descriptor, 1)) + !isempty(x.data_header) && (encoded_size += PB._encoded_size(x.data_header, 2)) + !isempty(x.app_metadata) && (encoded_size += PB._encoded_size(x.app_metadata, 3)) + !isempty(x.data_body) && (encoded_size += PB._encoded_size(x.data_body, 1000)) + return encoded_size +end + +struct SetSessionOptionsRequest + session_options::Dict{String,SessionOptionValue} +end +PB.default_values(::Type{SetSessionOptionsRequest}) = + (; session_options=Dict{String,SessionOptionValue}()) +PB.field_numbers(::Type{SetSessionOptionsRequest}) = (; session_options=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:SetSessionOptionsRequest}) + session_options = Dict{String,SessionOptionValue}() + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, session_options) + else + Base.skip(d, wire_type) + end + end + return SetSessionOptionsRequest(session_options) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::SetSessionOptionsRequest) + initpos = position(e.io) + !isempty(x.session_options) && PB.encode(e, 1, x.session_options) + return position(e.io) - initpos +end +function PB._encoded_size(x::SetSessionOptionsRequest) + encoded_size = 0 + !isempty(x.session_options) && (encoded_size += PB._encoded_size(x.session_options, 1)) + return encoded_size +end + +struct GetSessionOptionsResult + session_options::Dict{String,SessionOptionValue} +end +PB.default_values(::Type{GetSessionOptionsResult}) = + (; session_options=Dict{String,SessionOptionValue}()) +PB.field_numbers(::Type{GetSessionOptionsResult}) = (; 
session_options=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:GetSessionOptionsResult}) + session_options = Dict{String,SessionOptionValue}() + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, session_options) + else + Base.skip(d, wire_type) + end + end + return GetSessionOptionsResult(session_options) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::GetSessionOptionsResult) + initpos = position(e.io) + !isempty(x.session_options) && PB.encode(e, 1, x.session_options) + return position(e.io) - initpos +end +function PB._encoded_size(x::GetSessionOptionsResult) + encoded_size = 0 + !isempty(x.session_options) && (encoded_size += PB._encoded_size(x.session_options, 1)) + return encoded_size +end + +struct SetSessionOptionsResult + errors::Dict{String,var"SetSessionOptionsResult.Error"} +end +PB.default_values(::Type{SetSessionOptionsResult}) = + (; errors=Dict{String,var"SetSessionOptionsResult.Error"}()) +PB.field_numbers(::Type{SetSessionOptionsResult}) = (; errors=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:SetSessionOptionsResult}) + errors = Dict{String,var"SetSessionOptionsResult.Error"}() + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, errors) + else + Base.skip(d, wire_type) + end + end + return SetSessionOptionsResult(errors) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::SetSessionOptionsResult) + initpos = position(e.io) + !isempty(x.errors) && PB.encode(e, 1, x.errors) + return position(e.io) - initpos +end +function PB._encoded_size(x::SetSessionOptionsResult) + encoded_size = 0 + !isempty(x.errors) && (encoded_size += PB._encoded_size(x.errors, 1)) + return encoded_size +end + +struct RenewFlightEndpointRequest + endpoint::Union{Nothing,FlightEndpoint} +end +PB.default_values(::Type{RenewFlightEndpointRequest}) = (; endpoint=nothing) 
+PB.field_numbers(::Type{RenewFlightEndpointRequest}) = (; endpoint=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:RenewFlightEndpointRequest}) + endpoint = Ref{Union{Nothing,FlightEndpoint}}(nothing) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, endpoint) + else + Base.skip(d, wire_type) + end + end + return RenewFlightEndpointRequest(endpoint[]) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::RenewFlightEndpointRequest) + initpos = position(e.io) + !isnothing(x.endpoint) && PB.encode(e, 1, x.endpoint) + return position(e.io) - initpos +end +function PB._encoded_size(x::RenewFlightEndpointRequest) + encoded_size = 0 + !isnothing(x.endpoint) && (encoded_size += PB._encoded_size(x.endpoint, 1)) + return encoded_size +end + +struct FlightInfo + schema::Vector{UInt8} + flight_descriptor::Union{Nothing,FlightDescriptor} + endpoint::Vector{FlightEndpoint} + total_records::Int64 + total_bytes::Int64 + ordered::Bool + app_metadata::Vector{UInt8} +end +PB.default_values(::Type{FlightInfo}) = (; + schema=UInt8[], + flight_descriptor=nothing, + endpoint=Vector{FlightEndpoint}(), + total_records=zero(Int64), + total_bytes=zero(Int64), + ordered=false, + app_metadata=UInt8[], +) +PB.field_numbers(::Type{FlightInfo}) = (; + schema=1, + flight_descriptor=2, + endpoint=3, + total_records=4, + total_bytes=5, + ordered=6, + app_metadata=7, +) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:FlightInfo}) + schema = UInt8[] + flight_descriptor = Ref{Union{Nothing,FlightDescriptor}}(nothing) + endpoint = PB.BufferedVector{FlightEndpoint}() + total_records = zero(Int64) + total_bytes = zero(Int64) + ordered = false + app_metadata = UInt8[] + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + schema = PB.decode(d, Vector{UInt8}) + elseif field_number == 2 + PB.decode!(d, flight_descriptor) + elseif field_number == 3 + PB.decode!(d, endpoint) + 
elseif field_number == 4 + total_records = PB.decode(d, Int64) + elseif field_number == 5 + total_bytes = PB.decode(d, Int64) + elseif field_number == 6 + ordered = PB.decode(d, Bool) + elseif field_number == 7 + app_metadata = PB.decode(d, Vector{UInt8}) + else + Base.skip(d, wire_type) + end + end + return FlightInfo( + schema, + flight_descriptor[], + endpoint[], + total_records, + total_bytes, + ordered, + app_metadata, + ) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::FlightInfo) + initpos = position(e.io) + !isempty(x.schema) && PB.encode(e, 1, x.schema) + !isnothing(x.flight_descriptor) && PB.encode(e, 2, x.flight_descriptor) + !isempty(x.endpoint) && PB.encode(e, 3, x.endpoint) + x.total_records != zero(Int64) && PB.encode(e, 4, x.total_records) + x.total_bytes != zero(Int64) && PB.encode(e, 5, x.total_bytes) + x.ordered != false && PB.encode(e, 6, x.ordered) + !isempty(x.app_metadata) && PB.encode(e, 7, x.app_metadata) + return position(e.io) - initpos +end +function PB._encoded_size(x::FlightInfo) + encoded_size = 0 + !isempty(x.schema) && (encoded_size += PB._encoded_size(x.schema, 1)) + !isnothing(x.flight_descriptor) && + (encoded_size += PB._encoded_size(x.flight_descriptor, 2)) + !isempty(x.endpoint) && (encoded_size += PB._encoded_size(x.endpoint, 3)) + x.total_records != zero(Int64) && (encoded_size += PB._encoded_size(x.total_records, 4)) + x.total_bytes != zero(Int64) && (encoded_size += PB._encoded_size(x.total_bytes, 5)) + x.ordered != false && (encoded_size += PB._encoded_size(x.ordered, 6)) + !isempty(x.app_metadata) && (encoded_size += PB._encoded_size(x.app_metadata, 7)) + return encoded_size +end + +struct CancelFlightInfoRequest + info::Union{Nothing,FlightInfo} +end +PB.default_values(::Type{CancelFlightInfoRequest}) = (; info=nothing) +PB.field_numbers(::Type{CancelFlightInfoRequest}) = (; info=1) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:CancelFlightInfoRequest}) + info = 
Ref{Union{Nothing,FlightInfo}}(nothing) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, info) + else + Base.skip(d, wire_type) + end + end + return CancelFlightInfoRequest(info[]) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::CancelFlightInfoRequest) + initpos = position(e.io) + !isnothing(x.info) && PB.encode(e, 1, x.info) + return position(e.io) - initpos +end +function PB._encoded_size(x::CancelFlightInfoRequest) + encoded_size = 0 + !isnothing(x.info) && (encoded_size += PB._encoded_size(x.info, 1)) + return encoded_size +end + +struct PollInfo + info::Union{Nothing,FlightInfo} + flight_descriptor::Union{Nothing,FlightDescriptor} + progress::Float64 + expiration_time::Union{Nothing,google.protobuf.Timestamp} +end +PB.default_values(::Type{PollInfo}) = (; + info=nothing, + flight_descriptor=nothing, + progress=zero(Float64), + expiration_time=nothing, +) +PB.field_numbers(::Type{PollInfo}) = + (; info=1, flight_descriptor=2, progress=3, expiration_time=4) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:PollInfo}) + info = Ref{Union{Nothing,FlightInfo}}(nothing) + flight_descriptor = Ref{Union{Nothing,FlightDescriptor}}(nothing) + progress = zero(Float64) + expiration_time = Ref{Union{Nothing,google.protobuf.Timestamp}}(nothing) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + PB.decode!(d, info) + elseif field_number == 2 + PB.decode!(d, flight_descriptor) + elseif field_number == 3 + progress = PB.decode(d, Float64) + elseif field_number == 4 + PB.decode!(d, expiration_time) + else + Base.skip(d, wire_type) + end + end + return PollInfo(info[], flight_descriptor[], progress, expiration_time[]) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::PollInfo) + initpos = position(e.io) + !isnothing(x.info) && PB.encode(e, 1, x.info) + !isnothing(x.flight_descriptor) && PB.encode(e, 2, x.flight_descriptor) + x.progress !== 
zero(Float64) && PB.encode(e, 3, x.progress) + !isnothing(x.expiration_time) && PB.encode(e, 4, x.expiration_time) + return position(e.io) - initpos +end +function PB._encoded_size(x::PollInfo) + encoded_size = 0 + !isnothing(x.info) && (encoded_size += PB._encoded_size(x.info, 1)) + !isnothing(x.flight_descriptor) && + (encoded_size += PB._encoded_size(x.flight_descriptor, 2)) + x.progress !== zero(Float64) && (encoded_size += PB._encoded_size(x.progress, 3)) + !isnothing(x.expiration_time) && + (encoded_size += PB._encoded_size(x.expiration_time, 4)) + return encoded_size +end + +# gRPCClient.jl BEGIN +FlightService_Handshake_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{HandshakeRequest,true,HandshakeResponse,true}( + host, + port, + "/arrow.flight.protocol.FlightService/Handshake"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_Handshake_Client + +FlightService_ListFlights_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{Criteria,false,FlightInfo,true}( + host, + port, + "/arrow.flight.protocol.FlightService/ListFlights"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_ListFlights_Client + +FlightService_GetFlightInfo_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + 
max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{FlightDescriptor,false,FlightInfo,false}( + host, + port, + "/arrow.flight.protocol.FlightService/GetFlightInfo"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_GetFlightInfo_Client + +FlightService_PollFlightInfo_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{FlightDescriptor,false,PollInfo,false}( + host, + port, + "/arrow.flight.protocol.FlightService/PollFlightInfo"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_PollFlightInfo_Client + +FlightService_GetSchema_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{FlightDescriptor,false,SchemaResult,false}( + host, + port, + "/arrow.flight.protocol.FlightService/GetSchema"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_GetSchema_Client + +FlightService_DoGet_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{Ticket,false,FlightData,true}( + host, + port, + "/arrow.flight.protocol.FlightService/DoGet"; + secure=secure, + grpc=grpc, + deadline=deadline, 
+ keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_DoGet_Client + +FlightService_DoPut_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{FlightData,true,PutResult,true}( + host, + port, + "/arrow.flight.protocol.FlightService/DoPut"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_DoPut_Client + +FlightService_DoExchange_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{FlightData,true,FlightData,true}( + host, + port, + "/arrow.flight.protocol.FlightService/DoExchange"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_DoExchange_Client + +FlightService_DoAction_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{Action,false,Result,true}( + host, + port, + "/arrow.flight.protocol.FlightService/DoAction"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_DoAction_Client + +FlightService_ListActions_Client( + host, + port; + secure=false, + grpc=gRPCClient.grpc_global_handle(), + 
deadline=10, + keepalive=60, + max_send_message_length=4*1024*1024, + max_recieve_message_length=4*1024*1024, +) = gRPCClient.gRPCServiceClient{Empty,false,ActionType,true}( + host, + port, + "/arrow.flight.protocol.FlightService/ListActions"; + secure=secure, + grpc=grpc, + deadline=deadline, + keepalive=keepalive, + max_send_message_length=max_send_message_length, + max_recieve_message_length=max_recieve_message_length, +) +export FlightService_ListActions_Client +# gRPCClient.jl END diff --git a/src/flight/generated/arrow/flight/protocol/protocol.jl b/src/flight/generated/arrow/flight/protocol/protocol.jl new file mode 100644 index 00000000..0e8132f4 --- /dev/null +++ b/src/flight/generated/arrow/flight/protocol/protocol.jl @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module protocol + +import ...google + +include("Flight_pb.jl") + +end # module protocol diff --git a/src/flight/generated/google/google.jl b/src/flight/generated/google/google.jl new file mode 100644 index 00000000..eaea4251 --- /dev/null +++ b/src/flight/generated/google/google.jl @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module google + +include("protobuf/protobuf.jl") + +end # module google diff --git a/src/flight/generated/google/protobuf/protobuf.jl b/src/flight/generated/google/protobuf/protobuf.jl new file mode 100644 index 00000000..f066b99e --- /dev/null +++ b/src/flight/generated/google/protobuf/protobuf.jl @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +module protobuf + +include("timestamp_pb.jl") + +end # module protobuf diff --git a/src/flight/generated/google/protobuf/timestamp_pb.jl b/src/flight/generated/google/protobuf/timestamp_pb.jl new file mode 100644 index 00000000..2831ff78 --- /dev/null +++ b/src/flight/generated/google/protobuf/timestamp_pb.jl @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Autogenerated using ProtoBuf.jl v1.2.3 +# original file: google/protobuf/timestamp.proto (proto3 syntax) + +import ProtoBuf as PB +using ProtoBuf: OneOf +using ProtoBuf.EnumX: @enumx + +export Timestamp + +struct Timestamp + seconds::Int64 + nanos::Int32 +end +PB.default_values(::Type{Timestamp}) = (; seconds=zero(Int64), nanos=zero(Int32)) +PB.field_numbers(::Type{Timestamp}) = (; seconds=1, nanos=2) + +function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Timestamp}) + seconds = zero(Int64) + nanos = zero(Int32) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + seconds = PB.decode(d, Int64) + elseif field_number == 2 + nanos = PB.decode(d, Int32) + else + Base.skip(d, wire_type) + end + end + return Timestamp(seconds, nanos) +end + +function PB.encode(e::PB.AbstractProtoEncoder, x::Timestamp) + initpos = position(e.io) + x.seconds != zero(Int64) && PB.encode(e, 1, x.seconds) + x.nanos != zero(Int32) && PB.encode(e, 2, x.nanos) + return position(e.io) - initpos +end +function PB._encoded_size(x::Timestamp) + encoded_size = 0 + x.seconds != zero(Int64) && (encoded_size += PB._encoded_size(x.seconds, 1)) + x.nanos != zero(Int32) && (encoded_size += PB._encoded_size(x.nanos, 2)) + return encoded_size +end diff --git a/src/flight/proto/Flight.proto b/src/flight/proto/Flight.proto new file mode 100644 index 00000000..69e74c5d --- /dev/null +++ b/src/flight/proto/Flight.proto @@ -0,0 +1,678 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; +import "google/protobuf/timestamp.proto"; + +option java_package = "org.apache.arrow.flight.impl"; +option go_package = "github.com/apache/arrow-go/arrow/flight/gen/flight"; +option csharp_namespace = "Apache.Arrow.Flight.Protocol"; + +package arrow.flight.protocol; + +/* + * A flight service is an endpoint for retrieving or storing Arrow data. A + * flight service can expose one or more predefined endpoints that can be + * accessed using the Arrow Flight Protocol. Additionally, a flight service + * can expose a set of actions that are available. + */ +service FlightService { + + /* + * Handshake between client and server. Depending on the server, the + * handshake may be required to determine the token that should be used for + * future operations. Both request and response are streams to allow multiple + * round-trips depending on auth mechanism. + */ + rpc Handshake(stream HandshakeRequest) returns (stream HandshakeResponse) {} + + /* + * Get a list of available streams given a particular criteria. Most flight + * services will expose one or more streams that are readily available for + * retrieval. This api allows listing the streams available for + * consumption. A user can also provide a criteria. The criteria can limit + * the subset of streams that can be listed via this interface. Each flight + * service allows its own definition of how to consume criteria. + */ + rpc ListFlights(Criteria) returns (stream FlightInfo) {} + + /* + * For a given FlightDescriptor, get information about how the flight can be + * consumed. 
This is a useful interface if the consumer of the interface + * already can identify the specific flight to consume. This interface can + * also allow a consumer to generate a flight stream through a specified + * descriptor. For example, a flight descriptor might be something that + * includes a SQL statement or a Pickled Python operation that will be + * executed. In those cases, the descriptor will not be previously available + * within the list of available streams provided by ListFlights but will be + * available for consumption for the duration defined by the specific flight + * service. + */ + rpc GetFlightInfo(FlightDescriptor) returns (FlightInfo) {} + + /* + * For a given FlightDescriptor, start a query and get information + * to poll its execution status. This is a useful interface if the + * query may be a long-running query. The first PollFlightInfo call + * should return as quickly as possible. (GetFlightInfo doesn't + * return until the query is complete.) + * + * A client can consume any available results before + * the query is completed. See PollInfo.info for details. + * + * A client can poll the updated query status by calling + * PollFlightInfo() with PollInfo.flight_descriptor. A server + * should not respond until the result would be different from last + * time. That way, the client can "long poll" for updates + * without constantly making requests. Clients can set a short timeout + * to avoid blocking calls if desired. + * + * A client can't use PollInfo.flight_descriptor after + * PollInfo.expiration_time passes. A server might not accept the + * retry descriptor anymore and the query may be cancelled. + * + * A client may use the CancelFlightInfo action with + * PollInfo.info to cancel the running query. + */ + rpc PollFlightInfo(FlightDescriptor) returns (PollInfo) {} + + /* + * For a given FlightDescriptor, get the Schema as described in Schema.fbs::Schema + * This is used when a consumer needs the Schema of flight stream. 
Similar to + * GetFlightInfo this interface may generate a new flight that was not previously + * available in ListFlights. + */ + rpc GetSchema(FlightDescriptor) returns (SchemaResult) {} + + /* + * Retrieve a single stream associated with a particular descriptor + * associated with the referenced ticket. A Flight can be composed of one or + * more streams where each stream can be retrieved using a separate opaque + * ticket that the flight service uses for managing a collection of streams. + */ + rpc DoGet(Ticket) returns (stream FlightData) {} + + /* + * Push a stream to the flight service associated with a particular + * flight stream. This allows a client of a flight service to upload a stream + * of data. Depending on the particular flight service, a client consumer + * could be allowed to upload a single stream per descriptor or an unlimited + * number. In the latter, the service might implement a 'seal' action that + * can be applied to a descriptor once all streams are uploaded. + */ + rpc DoPut(stream FlightData) returns (stream PutResult) {} + + /* + * Open a bidirectional data channel for a given descriptor. This + * allows clients to send and receive arbitrary Arrow data and + * application-specific metadata in a single logical stream. In + * contrast to DoGet/DoPut, this is more suited for clients + * offloading computation (rather than storage) to a Flight service. + */ + rpc DoExchange(stream FlightData) returns (stream FlightData) {} + + /* + * Flight services can support an arbitrary number of simple actions in + * addition to the possible ListFlights, GetFlightInfo, DoGet, DoPut + * operations that are potentially available. DoAction allows a flight client + * to do a specific action against a flight service. An action includes + * opaque request and response objects that are specific to the type action + * being undertaken. 
+ */ + rpc DoAction(Action) returns (stream Result) {} + + /* + * A flight service exposes all of the available action types that it has + * along with descriptions. This allows different flight consumers to + * understand the capabilities of the flight service. + */ + rpc ListActions(Empty) returns (stream ActionType) {} +} + +/* + * The request that a client provides to a server on handshake. + */ +message HandshakeRequest { + + /* + * A defined protocol version + */ + uint64 protocol_version = 1; + + /* + * Arbitrary auth/handshake info. + */ + bytes payload = 2; +} + +message HandshakeResponse { + + /* + * A defined protocol version + */ + uint64 protocol_version = 1; + + /* + * Arbitrary auth/handshake info. + */ + bytes payload = 2; +} + +/* + * A message for doing simple auth. + */ +message BasicAuth { + string username = 2; + string password = 3; +} + +message Empty {} + +/* + * Describes an available action, including both the name used for execution + * along with a short description of the purpose of the action. + */ +message ActionType { + string type = 1; + string description = 2; +} + +/* + * A service specific expression that can be used to return a limited set + * of available Arrow Flight streams. + */ +message Criteria { + bytes expression = 1; +} + +/* + * An opaque action specific for the service. + */ +message Action { + string type = 1; + bytes body = 2; +} + +/* + * An opaque result returned after executing an action. + */ +message Result { + bytes body = 1; +} + +/* + * Wrap the result of a getSchema call + */ +message SchemaResult { + // The schema of the dataset in its IPC form: + // 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + // 4 bytes - the byte length of the payload + // a flatbuffer Message whose header is the Schema + bytes schema = 1; +} + +/* + * The name or tag for a Flight. May be used as a way to retrieve or generate + * a flight or be used to expose a set of previously defined flights. 
+ */ +message FlightDescriptor { + + /* + * Describes what type of descriptor is defined. + */ + enum DescriptorType { + + // Protobuf pattern, not used. + UNKNOWN = 0; + + /* + * A named path that identifies a dataset. A path is composed of a string + * or list of strings describing a particular dataset. This is conceptually + * similar to a path inside a filesystem. + */ + PATH = 1; + + /* + * An opaque command to generate a dataset. + */ + CMD = 2; + } + + DescriptorType type = 1; + + /* + * Opaque value used to express a command. Should only be defined when + * type = CMD. + */ + bytes cmd = 2; + + /* + * List of strings identifying a particular dataset. Should only be defined + * when type = PATH. + */ + repeated string path = 3; +} + +/* + * The access coordinates for retrieval of a dataset. With a FlightInfo, a + * consumer is able to determine how to retrieve a dataset. + */ +message FlightInfo { + // The schema of the dataset in its IPC form: + // 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + // 4 bytes - the byte length of the payload + // a flatbuffer Message whose header is the Schema + bytes schema = 1; + + /* + * The descriptor associated with this info. + */ + FlightDescriptor flight_descriptor = 2; + + /* + * A list of endpoints associated with the flight. To consume the + * whole flight, all endpoints (and hence all Tickets) must be + * consumed. Endpoints can be consumed in any order. + * + * In other words, an application can use multiple endpoints to + * represent partitioned data. + * + * If the returned data has an ordering, an application can use + * "FlightInfo.ordered = true" or should return all data in a + * single endpoint. Otherwise, there is no ordering defined on + * endpoints or the data within. + * + * A client can read ordered data by reading data from returned + * endpoints, in order, from front to back. + * + * Note that a client may ignore "FlightInfo.ordered = true". 
If an + * ordering is important for an application, an application must + * choose one of them: + * + * * An application requires that all clients must read data in + * returned endpoints order. + * * An application must return all data in a single endpoint. + */ + repeated FlightEndpoint endpoint = 3; + + // Set these to -1 if unknown. + int64 total_records = 4; + int64 total_bytes = 5; + + /* + * FlightEndpoints are in the same order as the data. + */ + bool ordered = 6; + + /* + * Application-defined metadata. + * + * There is no inherent or required relationship between this + * and the app_metadata fields in the FlightEndpoints or resulting + * FlightData messages. Since this metadata is application-defined, + * a given application could define there to be a relationship, + * but there is none required by the spec. + */ + bytes app_metadata = 7; +} + +/* + * The information to process a long-running query. + */ +message PollInfo { + /* + * The currently available results. + * + * If "flight_descriptor" is not specified, the query is complete + * and "info" specifies all results. Otherwise, "info" contains + * partial query results. + * + * Note that each PollInfo response contains a complete + * FlightInfo (not just the delta between the previous and current + * FlightInfo). + * + * Subsequent PollInfo responses may only append new endpoints to + * info. + * + * Clients can begin fetching results via DoGet(Ticket) with the + * ticket in the info before the query is + * completed. FlightInfo.ordered is also valid. + */ + FlightInfo info = 1; + + /* + * The descriptor the client should use on the next try. + * If unset, the query is complete. + */ + FlightDescriptor flight_descriptor = 2; + + /* + * Query progress. If known, must be in [0.0, 1.0] but need not be + * monotonic or nondecreasing. If unknown, do not set. + */ + optional double progress = 3; + + /* + * Expiration time for this request. 
After this passes, the server + * might not accept the retry descriptor anymore (and the query may + * be cancelled). This may be updated on a call to PollFlightInfo. + */ + google.protobuf.Timestamp expiration_time = 4; +} + +/* + * The request of the CancelFlightInfo action. + * + * The request should be stored in Action.body. + */ +message CancelFlightInfoRequest { + FlightInfo info = 1; +} + +/* + * The result of a cancel operation. + * + * This is used by CancelFlightInfoResult.status. + */ +enum CancelStatus { + // The cancellation status is unknown. Servers should avoid using + // this value (send a NOT_FOUND error if the requested query is + // not known). Clients can retry the request. + CANCEL_STATUS_UNSPECIFIED = 0; + // The cancellation request is complete. Subsequent requests with + // the same payload may return CANCELLED or a NOT_FOUND error. + CANCEL_STATUS_CANCELLED = 1; + // The cancellation request is in progress. The client may retry + // the cancellation request. + CANCEL_STATUS_CANCELLING = 2; + // The query is not cancellable. The client should not retry the + // cancellation request. + CANCEL_STATUS_NOT_CANCELLABLE = 3; +} + +/* + * The result of the CancelFlightInfo action. + * + * The result should be stored in Result.body. + */ +message CancelFlightInfoResult { + CancelStatus status = 1; +} + +/* + * An opaque identifier that the service can use to retrieve a particular + * portion of a stream. + * + * Tickets are meant to be single use. It is an error/application-defined + * behavior to reuse a ticket. + */ +message Ticket { + bytes ticket = 1; +} + +/* + * A location to retrieve a particular stream from. This URI should be one of + * the following: + * - An empty string or the string 'arrow-flight-reuse-connection://?': + * indicating that the ticket can be redeemed on the service where the + * ticket was generated via a DoGet request. 
+ * - A valid grpc URI (grpc://, grpc+tls://, grpc+unix://, etc.): + * indicating that the ticket can be redeemed on the service at the given + * URI via a DoGet request. + * - A valid HTTP URI (http://, https://, etc.): + * indicating that the client should perform a GET request against the + * given URI to retrieve the stream. The ticket should be empty + * in this case and should be ignored by the client. Cloud object storage + * can be utilized by presigned URLs or mediating the auth separately and + * returning the full URL (e.g. https://amzn-s3-demo-bucket.s3.us-west-2.amazonaws.com/...). + * + * We allow non-Flight URIs for the purpose of allowing Flight services to indicate that + * results can be downloaded in formats other than Arrow (such as Parquet) or to allow + * direct fetching of results from a URI to reduce excess copying and data movement. + * In these cases, the following conventions should be followed by servers and clients: + * + * - Unless otherwise specified by the 'Content-Type' header of the response, + * a client should assume the response is using the Arrow IPC Streaming format. + * Usage of an IANA media type like 'application/octet-stream' should be assumed to + * be using the Arrow IPC Streaming format. + * - The server may allow the client to choose a specific response format by + * specifying an 'Accept' header in the request, such as 'application/vnd.apache.parquet' + * or 'application/vnd.apache.arrow.stream'. If multiple types are requested and + * supported by the server, the choice of which to use is server-specific. If + * none of the requested content-types are supported, the server may respond with + * either 406 (Not Acceptable) or 415 (Unsupported Media Type), or successfully + * respond with a different format that it does support along with the correct + * 'Content-Type' header. + * + * Note: new schemes may be proposed in the future to allow for more flexibility based + * on community requests. 
+ */ +message Location { + string uri = 1; +} + +/* + * A particular stream or split associated with a flight. + */ +message FlightEndpoint { + + /* + * Token used to retrieve this stream. + */ + Ticket ticket = 1; + + /* + * A list of URIs where this ticket can be redeemed via DoGet(). + * + * If the list is empty, the expectation is that the ticket can only + * be redeemed on the current service where the ticket was + * generated. + * + * If the list is not empty, the expectation is that the ticket can be + * redeemed at any of the locations, and that the data returned will be + * equivalent. In this case, the ticket may only be redeemed at one of the + * given locations, and not (necessarily) on the current service. If one + * of the given locations is "arrow-flight-reuse-connection://?", the + * client may redeem the ticket on the service where the ticket was + * generated (i.e., the same as above), in addition to the other + * locations. (This URI was chosen to maximize compatibility, as 'scheme:' + * or 'scheme://' are not accepted by Java's java.net.URI.) + * + * In other words, an application can use multiple locations to + * represent redundant and/or load balanced services. + */ + repeated Location location = 2; + + /* + * Expiration time of this stream. If present, clients may assume + * they can retry DoGet requests. Otherwise, it is + * application-defined whether DoGet requests may be retried. + */ + google.protobuf.Timestamp expiration_time = 3; + + /* + * Application-defined metadata. + * + * There is no inherent or required relationship between this + * and the app_metadata fields in the FlightInfo or resulting + * FlightData messages. Since this metadata is application-defined, + * a given application could define there to be a relationship, + * but there is none required by the spec. + */ + bytes app_metadata = 4; +} + +/* + * The request of the RenewFlightEndpoint action. + * + * The request should be stored in Action.body. 
+ */ +message RenewFlightEndpointRequest { + FlightEndpoint endpoint = 1; +} + +/* + * A batch of Arrow data as part of a stream of batches. + */ +message FlightData { + + /* + * The descriptor of the data. This is only relevant when a client is + * starting a new DoPut stream. + */ + FlightDescriptor flight_descriptor = 1; + + /* + * Header for message data as described in Message.fbs::Message. + */ + bytes data_header = 2; + + /* + * Application-defined metadata. + */ + bytes app_metadata = 3; + + /* + * The actual batch of Arrow data. Preferably handled with minimal-copies + * coming last in the definition to help with sidecar patterns (it is + * expected that some implementations will fetch this field off the wire + * with specialized code to avoid extra memory copies). + */ + bytes data_body = 1000; +} + +/** + * The response message associated with the submission of a DoPut. + */ +message PutResult { + bytes app_metadata = 1; +} + +/* + * EXPERIMENTAL: Union of possible value types for a Session Option to be set to. + * + * By convention, an attempt to set a valueless SessionOptionValue should + * attempt to unset or clear the named option value on the server. + */ +message SessionOptionValue { + message StringListValue { + repeated string values = 1; + } + + oneof option_value { + string string_value = 1; + bool bool_value = 2; + sfixed64 int64_value = 3; + double double_value = 4; + StringListValue string_list_value = 5; + } +} + +/* + * EXPERIMENTAL: A request to set session options for an existing or new (implicit) + * server session. + * + * Sessions are persisted and referenced via a transport-level state management, typically + * RFC 6265 HTTP cookies when using an HTTP transport. The suggested cookie name or state + * context key is 'arrow_flight_session_id', although implementations may freely choose their + * own name. 
+ * + * Session creation (if one does not already exist) is implied by this RPC request, however + * server implementations may choose to initiate a session that also contains client-provided + * session options at any other time, e.g. on authentication, or when any other call is made + * and the server wishes to use a session to persist any state (or lack thereof). + */ +message SetSessionOptionsRequest { + map<string, SessionOptionValue> session_options = 1; +} + +/* + * EXPERIMENTAL: The results (individually) of setting a set of session options. + * + * Option names should only be present in the response if they were not successfully + * set on the server; that is, a response without an Error for a name provided in the + * SetSessionOptionsRequest implies that the named option value was set successfully. + */ +message SetSessionOptionsResult { + enum ErrorValue { + // Protobuf deserialization fallback value: The status is unknown or unrecognized. + // Servers should avoid using this value. The request may be retried by the client. + UNSPECIFIED = 0; + // The given session option name is invalid. + INVALID_NAME = 1; + // The session option value or type is invalid. + INVALID_VALUE = 2; + // The session option cannot be set. + ERROR = 3; + } + + message Error { + ErrorValue value = 1; + } + + map<string, Error> errors = 1; +} + +/* + * EXPERIMENTAL: A request to access the session options for the current server session. + * + * The existing session is referenced via a cookie header or similar (see + * SetSessionOptionsRequest above); it is an error to make this request with a missing, + * invalid, or expired session cookie header or other implementation-defined session + * reference token. + */ +message GetSessionOptionsRequest { +} + +/* + * EXPERIMENTAL: The result containing the current server session options. + */ +message GetSessionOptionsResult { + map<string, SessionOptionValue> session_options = 1; +} + +/* + * Request message for the "Close Session" action. + * + * The existing session is referenced via a cookie header. 
+ */ +message CloseSessionRequest { +} + +/* + * The result of closing a session. + */ +message CloseSessionResult { + enum Status { + // Protobuf deserialization fallback value: The session close status is unknown or + // not recognized. Servers should avoid using this value (send a NOT_FOUND error if + // the requested session is not known or expired). Clients can retry the request. + UNSPECIFIED = 0; + // The session close request is complete. Subsequent requests with + // the same session produce a NOT_FOUND error. + CLOSED = 1; + // The session close request is in progress. The client may retry + // the close request. + CLOSING = 2; + // The session is not closeable. The client should not retry the + // close request. + NOT_CLOSEABLE = 3; + } + + Status status = 1; +} diff --git a/src/flight/protocol.jl b/src/flight/protocol.jl new file mode 100644 index 00000000..a1f74c0f --- /dev/null +++ b/src/flight/protocol.jl @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +module Generated +include("generated/arrow/arrow.jl") +end + +const Protocol = Generated.arrow.flight.protocol diff --git a/src/flight/server.jl b/src/flight/server.jl new file mode 100644 index 00000000..e7f12128 --- /dev/null +++ b/src/flight/server.jl @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("server/types.jl") +include("server/descriptors.jl") +include("server/handlers.jl") +include("server/dispatch.jl") diff --git a/src/flight/server/descriptors.jl b/src/flight/server/descriptors.jl new file mode 100644 index 00000000..913b6094 --- /dev/null +++ b/src/flight/server/descriptors.jl @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +const FLIGHT_SERVICE_NAME = "arrow.flight.protocol.FlightService" + +struct MethodDescriptor + name::String + path::String + handler_field::Symbol + request_streaming::Bool + response_streaming::Bool + request_type::Type + response_type::Type +end + +struct ServiceDescriptor + name::String + methods::Vector{MethodDescriptor} + method_lookup::Dict{String,MethodDescriptor} +end + +function MethodDescriptor( + name::AbstractString, + handler_field::Symbol, + request_streaming::Bool, + response_streaming::Bool, + request_type::Type, + response_type::Type, +) + normalized_name = String(name) + MethodDescriptor( + normalized_name, + "/$(FLIGHT_SERVICE_NAME)/$(normalized_name)", + handler_field, + request_streaming, + response_streaming, + request_type, + response_type, + ) +end + +function ServiceDescriptor(name::AbstractString, methods::Vector{MethodDescriptor}) + lookup = Dict{String,MethodDescriptor}() + for method in methods + lookup[method.name] = method + lookup[method.path] = method + end + return ServiceDescriptor(String(name), methods, lookup) +end + +const FLIGHT_METHODS = [ + MethodDescriptor( + "Handshake", + :handshake, + true, + true, + Protocol.HandshakeRequest, + Protocol.HandshakeResponse, + ), + MethodDescriptor( + "ListFlights", + :listflights, + false, + true, + Protocol.Criteria, + Protocol.FlightInfo, + ), + MethodDescriptor( + "GetFlightInfo", + :getflightinfo, + false, + false, + Protocol.FlightDescriptor, + Protocol.FlightInfo, + ), + MethodDescriptor( + "PollFlightInfo", + :pollflightinfo, + false, + false, + 
Protocol.FlightDescriptor, + Protocol.PollInfo, + ), + MethodDescriptor( + "GetSchema", + :getschema, + false, + false, + Protocol.FlightDescriptor, + Protocol.SchemaResult, + ), + MethodDescriptor("DoGet", :doget, false, true, Protocol.Ticket, Protocol.FlightData), + MethodDescriptor("DoPut", :doput, true, true, Protocol.FlightData, Protocol.PutResult), + MethodDescriptor( + "DoExchange", + :doexchange, + true, + true, + Protocol.FlightData, + Protocol.FlightData, + ), + MethodDescriptor("DoAction", :doaction, false, true, Protocol.Action, Protocol.Result), + MethodDescriptor( + "ListActions", + :listactions, + false, + true, + Protocol.Empty, + Protocol.ActionType, + ), +] + +const FLIGHT_SERVICE_DESCRIPTOR = ServiceDescriptor(FLIGHT_SERVICE_NAME, FLIGHT_METHODS) + +servicedescriptor(::Service) = FLIGHT_SERVICE_DESCRIPTOR + +function lookupmethod(descriptor::ServiceDescriptor, key::AbstractString) + return get(descriptor.method_lookup, String(key), nothing) +end + +lookupmethod(service::Service, key::AbstractString) = + lookupmethod(servicedescriptor(service), key) diff --git a/src/flight/server/dispatch.jl b/src/flight/server/dispatch.jl new file mode 100644 index 00000000..2f3241b7 --- /dev/null +++ b/src/flight/server/dispatch.jl @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +function dispatch( + service::Service, + context::ServerCallContext, + method::MethodDescriptor, + args..., +) + if method.handler_field === :handshake + return handshake(service, context, args...) + elseif method.handler_field === :listflights + return listflights(service, context, args...) + elseif method.handler_field === :getflightinfo + return getflightinfo(service, context, args...) + elseif method.handler_field === :pollflightinfo + return pollflightinfo(service, context, args...) + elseif method.handler_field === :getschema + return getschema(service, context, args...) + elseif method.handler_field === :doget + return doget(service, context, args...) + elseif method.handler_field === :doput + return doput(service, context, args...) + elseif method.handler_field === :doexchange + return doexchange(service, context, args...) + elseif method.handler_field === :doaction + return doaction(service, context, args...) + elseif method.handler_field === :listactions + return listactions(service, context, args...) + end + + throw(ArgumentError("unsupported Arrow Flight handler field $(method.handler_field)")) +end + +function dispatch( + service::Service, + context::ServerCallContext, + key::AbstractString, + args..., +) + method = lookupmethod(service, key) + isnothing(method) && + throw(ArgumentError("unknown Arrow Flight method path or name: $(String(key))")) + return dispatch(service, context, method, args...) +end diff --git a/src/flight/server/handlers.jl b/src/flight/server/handlers.jl new file mode 100644 index 00000000..c0be8a78 --- /dev/null +++ b/src/flight/server/handlers.jl @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function _unimplemented(service_method::String) + throw( + gRPCClient.gRPCServiceCallException( + gRPCClient.GRPC_UNIMPLEMENTED, + "Arrow Flight server method $(service_method) is not implemented", + ), + ) +end + +function _invoke_handler(handler::Union{Nothing,Function}, service_method::String, args...) + isnothing(handler) && _unimplemented(service_method) + return handler(args...) +end + +handshake( + service::Service, + context::ServerCallContext, + request::Channel{Protocol.HandshakeRequest}, + response::Channel{Protocol.HandshakeResponse}, +) = _invoke_handler(service.handshake, "Handshake", context, request, response) + +listflights( + service::Service, + context::ServerCallContext, + criteria::Protocol.Criteria, + response::Channel{Protocol.FlightInfo}, +) = _invoke_handler(service.listflights, "ListFlights", context, criteria, response) + +getflightinfo( + service::Service, + context::ServerCallContext, + descriptor::Protocol.FlightDescriptor, +) = _invoke_handler(service.getflightinfo, "GetFlightInfo", context, descriptor) + +pollflightinfo( + service::Service, + context::ServerCallContext, + descriptor::Protocol.FlightDescriptor, +) = _invoke_handler(service.pollflightinfo, "PollFlightInfo", context, descriptor) + +getschema( + service::Service, + context::ServerCallContext, + descriptor::Protocol.FlightDescriptor, +) = _invoke_handler(service.getschema, "GetSchema", context, 
descriptor) + +doget( + service::Service, + context::ServerCallContext, + ticket::Protocol.Ticket, + response::Channel{Protocol.FlightData}, +) = _invoke_handler(service.doget, "DoGet", context, ticket, response) + +doput( + service::Service, + context::ServerCallContext, + request::Channel{Protocol.FlightData}, + response::Channel{Protocol.PutResult}, +) = _invoke_handler(service.doput, "DoPut", context, request, response) + +doexchange( + service::Service, + context::ServerCallContext, + request::Channel{Protocol.FlightData}, + response::Channel{Protocol.FlightData}, +) = _invoke_handler(service.doexchange, "DoExchange", context, request, response) + +doaction( + service::Service, + context::ServerCallContext, + action::Protocol.Action, + response::Channel{Protocol.Result}, +) = _invoke_handler(service.doaction, "DoAction", context, action, response) + +listactions( + service::Service, + context::ServerCallContext, + response::Channel{Protocol.ActionType}, +) = _invoke_handler(service.listactions, "ListActions", context, response) diff --git a/src/flight/server/types.jl b/src/flight/server/types.jl new file mode 100644 index 00000000..71b98724 --- /dev/null +++ b/src/flight/server/types.jl @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +const ServerHeaderPair = HeaderPair + +Base.@kwdef struct ServerCallContext + headers::Vector{ServerHeaderPair} = ServerHeaderPair[] + peer::Union{Nothing,String} = nothing + secure::Bool = false +end + +Base.@kwdef struct Service + handshake::Union{Nothing,Function} = nothing + listflights::Union{Nothing,Function} = nothing + getflightinfo::Union{Nothing,Function} = nothing + pollflightinfo::Union{Nothing,Function} = nothing + getschema::Union{Nothing,Function} = nothing + doget::Union{Nothing,Function} = nothing + doput::Union{Nothing,Function} = nothing + doexchange::Union{Nothing,Function} = nothing + doaction::Union{Nothing,Function} = nothing + listactions::Union{Nothing,Function} = nothing +end + +function callheader(context::ServerCallContext, name::AbstractString) + needle = lowercase(String(name)) + for (header_name, header_value) in context.headers + lowercase(header_name) == needle && return header_value + end + return nothing +end diff --git a/src/utils.jl b/src/utils.jl index 8e2dfeed..7c042b71 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -37,6 +37,33 @@ end # efficient writing of arrays writearray(io, col) = writearray(io, maybemissing(eltype(col)), col) +function _writearrayfallback(io::IO, ::Type{T}, col) where {T} + n = 0 + data = Vector{UInt8}(undef, sizeof(col)) + buf = IOBuffer(data; write=true) + for x in col + n += Base.write(buf, coalesce(x, ArrowTypes.default(T))) + end + n = Base.write(io, take!(buf)) + return n +end + +@inline function _writearraycontiguous(io::IO, ::Type{T}, data) where {T} + return Base.unsafe_write(io, pointer(data), sizeof(T) * length(data)) +end + +@inline function _contiguoustoarrowdata(::Type{T}, col::ArrowTypes.ToArrow) where {T} + ArrowTypes._needsconvert(col) && return nothing + data = ArrowTypes._sourcedata(col) + strides(data) == (1,) || return nothing + if data isa AbstractVector{T} + return 
isbitstype(T) ? data : nothing + elseif isbitstype(T) && data isa AbstractVector{Union{T,Missing}} + return data + end + return nothing +end + function writearray(io::IO, ::Type{T}, col) where {T} if col isa Vector{T} n = Base.write(io, col) @@ -51,17 +78,17 @@ function writearray(io::IO, ::Type{T}, col) where {T} n += writearray(io, T, A) end else - n = 0 - data = Vector{UInt8}(undef, sizeof(col)) - buf = IOBuffer(data; write=true) - for x in col - n += Base.write(buf, coalesce(x, ArrowTypes.default(T))) - end - n = Base.write(io, take!(buf)) + n = _writearrayfallback(io, T, col) end return n end +function writearray(io::IO, ::Type{T}, col::ArrowTypes.ToArrow) where {T} + data = _contiguoustoarrowdata(T, col) + isnothing(data) || return _writearraycontiguous(io, T, data) + return _writearrayfallback(io, T, col) +end + getbit(v::UInt8, n::Integer) = (v & (1 << (n - 1))) > 0x00 function setbit(v::UInt8, b::Bool, n::Integer) @@ -120,6 +147,11 @@ function getrb(filebytes) # FlatBuffers.getrootas(Meta.Message, filebytes, rb.offset) end +@inline function messagebytes(msg, alignment) + metalen = padding(length(msg.msgflatbuf), alignment) + return 8 + metalen + msg.bodylen +end + function readmessage(filebytes, off=9) @assert readbuffer(filebytes, off, UInt32) === 0xFFFFFFFF len = readbuffer(filebytes, off + 4, Int32) @@ -127,7 +159,248 @@ function readmessage(filebytes, off=9) FlatBuffers.getrootas(Meta.Message, filebytes, off + 8) end +@inline _issinglepartition(parts) = parts isa Tuple && length(parts) == 1 + +@inline function _directtobuffercoleligible(col) + T = Base.nonmissingtype(eltype(col)) + return !(T <: AbstractString || T <: Base.CodeUnits) +end + +@inline function _directtobufferstringonly(col) + T = Base.nonmissingtype(eltype(col)) + return T <: AbstractString +end + +@inline function _directtobufferbinaryonly(col) + return eltype(col) <: Base.CodeUnits +end + +@inline function _directstreamcoleligible(col) + return !(col isa DictEncode) && + 
DataAPI.refarray(col) === col && + (_directtobufferstringonly(col) || _directtobufferbinaryonly(col)) +end + +function _directtobuffereligible(part) + tblcols = Tables.columns(part) + sch = Tables.schema(tblcols) + ncols = 0 + singlecolspecial = false + allnonstrings = true + Tables.eachcolumn(sch, tblcols) do col, _, _ + ncols += 1 + eligible = _directtobuffercoleligible(col) + allnonstrings &= eligible + singlecolspecial = + ncols == 1 && (_directtobufferstringonly(col) || _directtobufferbinaryonly(col)) + end + return allnonstrings || (ncols == 1 && singlecolspecial) +end + +@inline function _directstreameligible(part) + tblcols = Tables.columns(part) + sch = Tables.schema(tblcols) + ncols = 0 + singlecolspecial = false + Tables.eachcolumn(sch, tblcols) do col, _, _ + ncols += 1 + singlecolspecial = ncols == 1 && _directstreamcoleligible(col) + end + return ncols == 1 && singlecolspecial +end + +@inline _partitionsinspectable(parts) = + parts isa Tuple || parts isa AbstractVector || parts isa Tables.Partitioner + +@inline function _directtobuffersizehint( + cols, + dictmsgs, + schmsg, + recbatchmsg, + endmsg, + alignment, +) + for col in Tables.Columns(cols) + if col isa Map + return + messagebytes(schmsg, alignment) + + sum(msg -> messagebytes(msg, alignment), dictmsgs; init=0) + + messagebytes(recbatchmsg, alignment) + + messagebytes(endmsg, alignment) + end + end + return nothing +end + +function _writedictionarymessages!(io, blocks, schref, alignment, dictencodings) + isempty(dictencodings) && return + des = sort!(collect(dictencodings); by=x -> x.first, rev=true) + for (id, delock) in des + de = delock.value + dictsch = Tables.Schema((:col,), (eltype(de.data),)) + msg = makedictionarybatchmsg(dictsch, (col=de.data,), id, false, alignment) + Base.write(io, msg, blocks, schref, alignment) + end + return +end + +function _writedictionarydeltas!(io, blocks, schref, alignment, deltas) + isempty(deltas) && return + for de in deltas + dictsch = 
Tables.Schema((:col,), (eltype(de.data),)) + msg = makedictionarybatchmsg(dictsch, (col=de.data,), de.id, true, alignment) + Base.write(io, msg, blocks, schref, alignment) + end + return +end + +@inline function _directstreamstate(parts) + _partitionsinspectable(parts) || return nothing + firststate = iterate(parts) + isnothing(firststate) && return nothing + firstpart, state = firststate + isnothing(iterate(parts, state)) && return nothing + return firstpart, state +end + +function _directtobuffer(part, source, kwargs) + largelists = get(kwargs, :largelists, false) + compress = get(kwargs, :compress, nothing) + denseunions = get(kwargs, :denseunions, true) + dictencode = get(kwargs, :dictencode, false) + dictencodenested = get(kwargs, :dictencodenested, false) + alignment = Int32(get(kwargs, :alignment, 8)) + maxdepth = get(kwargs, :maxdepth, DEFAULT_MAX_DEPTH) + metadata = get(kwargs, :metadata, getmetadata(source)) + colmetadata = get(kwargs, :colmetadata, nothing) + + tblcols = Tables.columns(part) + dictencodings = Dict{Int64,Any}() + cols = toarrowtable( + tblcols, + dictencodings, + largelists, + compress, + denseunions, + dictencode, + dictencodenested, + maxdepth, + metadata, + colmetadata, + ) + sch = Tables.schema(cols) + schmsg = makeschemamsg(sch, cols) + dictmsgs = if isempty(dictencodings) + Message[] + else + des = sort!(collect(dictencodings); by=x -> x.first, rev=true) + [ + begin + de = delock.value + dictsch = Tables.Schema((:col,), (eltype(de.data),)) + makedictionarybatchmsg(dictsch, (col=de.data,), id, false, alignment) + end for (id, delock) in des + ] + end + recbatchmsg = makerecordbatchmsg(sch, cols, alignment) + endmsg = Message(UInt8[], nothing, 0, true, false, Meta.Schema) + sizehint = + _directtobuffersizehint(cols, dictmsgs, schmsg, recbatchmsg, endmsg, alignment) + io = isnothing(sizehint) ? 
IOBuffer() : IOBuffer(; sizehint=sizehint) + blocks = (Block[], Block[]) + schref = Ref(sch) + Base.write(io, schmsg, blocks, schref, alignment) + foreach(msg -> Base.write(io, msg, blocks, schref, alignment), dictmsgs) + Base.write(io, recbatchmsg, blocks, schref, alignment) + Base.write(io, endmsg, blocks, schref, alignment) + seekstart(io) + return io +end + +function _directstreamwrite!(io::IO, firstpart, state, parts, source, kwargs) + largelists = get(kwargs, :largelists, false) + compress = get(kwargs, :compress, nothing) + denseunions = get(kwargs, :denseunions, true) + dictencode = get(kwargs, :dictencode, false) + dictencodenested = get(kwargs, :dictencodenested, false) + alignment = Int32(get(kwargs, :alignment, 8)) + maxdepth = get(kwargs, :maxdepth, DEFAULT_MAX_DEPTH) + metadata = get(kwargs, :metadata, getmetadata(source)) + colmetadata = get(kwargs, :colmetadata, nothing) + + dictencodings = Dict{Int64,Any}() + firstcols = toarrowtable( + Tables.columns(firstpart), + dictencodings, + largelists, + compress, + denseunions, + dictencode, + dictencodenested, + maxdepth, + metadata, + colmetadata, + ) + sch = Tables.schema(firstcols) + schmsg = makeschemamsg(sch, firstcols) + blocks = (Block[], Block[]) + schref = Ref(sch) + Base.write(io, schmsg, blocks, schref, alignment) + _writedictionarymessages!(io, blocks, schref, alignment, dictencodings) + Base.write(io, makerecordbatchmsg(sch, firstcols, alignment), blocks, schref, alignment) + + next = iterate(parts, state) + while !isnothing(next) + part, state = next + cols = toarrowtable( + Tables.columns(part), + dictencodings, + largelists, + compress, + denseunions, + dictencode, + dictencodenested, + maxdepth, + metadata, + colmetadata, + ) + Tables.schema(cols) == sch || + throw(ArgumentError("all partitions must have the exact same Tables.Schema")) + _writedictionarydeltas!(io, blocks, schref, alignment, cols.dictencodingdeltas) + Base.write(io, makerecordbatchmsg(sch, cols, alignment), blocks, 
schref, alignment) + next = iterate(parts, state) + end + Base.write( + io, + Message(UInt8[], nothing, 0, true, false, Meta.Schema), + blocks, + schref, + alignment, + ) + return io +end + +function _directstreamtobuffer(firstpart, state, parts, source, kwargs) + io = IOBuffer() + _directstreamwrite!(io, firstpart, state, parts, source, kwargs) + seekstart(io) + return io +end + function tobuffer(data; kwargs...) + parts = Tables.partitions(data) + if !get(kwargs, :file, false) + if _issinglepartition(parts) && _directtobuffereligible(parts[1]) + return _directtobuffer(parts[1], data, kwargs) + else + streamstate = _directstreamstate(parts) + if !isnothing(streamstate) + firstpart, state = streamstate + _directstreameligible(firstpart) && + return _directstreamtobuffer(firstpart, state, parts, data, kwargs) + end + end + end io = IOBuffer() write(io, data; kwargs...) seekstart(io) diff --git a/src/write.jl b/src/write.jl index 4c3800f2..a25d4afe 100644 --- a/src/write.jl +++ b/src/write.jl @@ -403,6 +403,17 @@ function Base.close(writer::Writer) end function write(io::IO, tbl; kwargs...) + if !get(kwargs, :file, false) + parts = Tables.partitions(tbl) + streamstate = _directstreamstate(parts) + if !isnothing(streamstate) + firstpart, state = streamstate + if _directstreameligible(firstpart) + _directstreamwrite!(io, firstpart, state, parts, tbl, kwargs) + return io + end + end + end open(Writer, io; file=false, kwargs...) do writer write(writer, tbl) end diff --git a/test/Project.toml b/test/Project.toml index c2e02aa8..f5e62b1e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,15 +6,17 @@ # "License"); you may not use this file except in compliance # with the License. 
You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. [deps] +Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" @@ -25,15 +27,17 @@ JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" +TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +gRPCClient = "aaca4a50-36af-4a1d-b878-4c443f2061ad" [compat] ArrowTypes = "2.3" @@ -44,8 +48,10 @@ FilePathsBase = "0.9" JSON3 = "1" OffsetArrays = "1" 
PooledArrays = "1" -StructTypes = "1" +ProtoBuf = "~1.2.1" SentinelArrays = "1" +StructTypes = "1" Tables = "1" TestSetExtensions = "3" TimeZones = "1" +gRPCClient = "1" diff --git a/test/flight.jl b/test/flight.jl new file mode 100644 index 00000000..fdd8f223 --- /dev/null +++ b/test/flight.jl @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +using gRPCClient +using Tables + +include("flight/support.jl") +include("flight/header_interop.jl") +include("flight/handshake_interop.jl") +include("flight/tls_interop.jl") +include("flight/poll_interop.jl") +include("flight/client_surface.jl") +include("flight/server_core.jl") +include("flight/grpcserver_extension.jl") +include("flight/ipc_conversion.jl") +include("flight/ipc_schema_separation.jl") +include("flight/pyarrow_interop.jl") diff --git a/test/flight/client_surface.jl b/test/flight/client_surface.jl new file mode 100644 index 00000000..e50bdd9f --- /dev/null +++ b/test/flight/client_surface.jl @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("client_surface/support.jl") +include("client_surface/constructor_tests.jl") +include("client_surface/header_tls_tests.jl") +include("client_surface/protocol_client_tests.jl") + +@testset "Flight RPC client surface" begin + fixture = flight_client_surface_fixture() + flight_client_surface_test_constructors(fixture) + flight_client_surface_test_header_tls_helpers(fixture) + flight_client_surface_test_protocol_clients(fixture) +end diff --git a/test/flight/client_surface/constructor_tests.jl b/test/flight/client_surface/constructor_tests.jl new file mode 100644 index 00000000..ca47f02c --- /dev/null +++ b/test/flight/client_surface/constructor_tests.jl @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_client_surface_test_constructors(fixture) + client = fixture.client + + @test client.host == "localhost" + @test client.port == 8815 + @test client.secure + @test client.deadline == 30.0 + @test client.keepalive == 15.0 + @test client.max_send_message_length == 1024 + @test client.max_recieve_message_length == 2048 + @test isempty(client.headers) + @test isnothing(client.tls_root_certs) + @test isnothing(client.cert_chain) + @test isnothing(client.private_key) + @test isnothing(client.key_password) + @test !client.disable_server_verification + + uri_client = Arrow.Flight.Client("grpc://127.0.0.1:31337") + @test uri_client.host == "127.0.0.1" + @test uri_client.port == 31337 + @test !uri_client.secure + + tls_client = Arrow.Flight.Client("grpc+tls://example.com:9443") + @test tls_client.host == "example.com" + @test tls_client.port == 9443 + @test tls_client.secure + + location_client = + Arrow.Flight.Client(fixture.protocol.Location("https://demo.example:8443")) + @test location_client.host == "demo.example" + @test location_client.port == 8443 + @test location_client.secure + + @test_throws ArgumentError Arrow.Flight.Client("grpc://missing-port") +end diff --git a/test/flight/client_surface/header_tls_tests.jl b/test/flight/client_surface/header_tls_tests.jl new file mode 100644 index 00000000..326684d4 --- /dev/null +++ b/test/flight/client_surface/header_tls_tests.jl @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_client_surface_test_header_tls_helpers(fixture) + client = fixture.client + + tls_client = Arrow.Flight.Client( + "grpc+tls://secure.example:9443"; + tls_root_certs="/tmp/root.pem", + cert_chain="/tmp/client.pem", + private_key="/tmp/client.key", + key_password="secret", + disable_server_verification=true, + ) + @test tls_client.tls_root_certs == "/tmp/root.pem" + @test tls_client.cert_chain == "/tmp/client.pem" + @test tls_client.private_key == "/tmp/client.key" + @test tls_client.key_password == "secret" + @test tls_client.disable_server_verification + + header_client = Arrow.Flight.withheaders( + client, + "authorization" => "Bearer token1234", + "x-trace-id" => "trace-1", + ) + @test header_client.headers == + ["authorization" => "Bearer token1234", "x-trace-id" => "trace-1"] + @test header_client.host == client.host + @test header_client.grpc === client.grpc + @test header_client.disable_server_verification == client.disable_server_verification + + binary_header_client = + Arrow.Flight.withheaders(client, "auth-token-bin" => UInt8[0x00, 0xff, 0x41]) + @test binary_header_client.headers == ["auth-token-bin" => UInt8[0x00, 0xff, 0x41]] + @test Arrow.Flight._header_lines(binary_header_client.headers) == + ["auth-token-bin: AP9B"] + + token_client = Arrow.Flight.withtoken(client, UInt8[0x01, 0x02]) + @test token_client.headers == ["auth-token-bin" => UInt8[0x01, 0x02]] + @test Arrow.Flight._header_lines(token_client.headers) == ["auth-token-bin: AQI="] + + invalid_binary_header_client = + Arrow.Flight.withheaders(client, "x-binary" => 
UInt8[0x00]) + @test_throws ArgumentError Arrow.Flight._header_lines( + invalid_binary_header_client.headers, + ) +end diff --git a/test/flight/client_surface/protocol_client_tests.jl b/test/flight/client_surface/protocol_client_tests.jl new file mode 100644 index 00000000..eae3b8f1 --- /dev/null +++ b/test/flight/client_surface/protocol_client_tests.jl @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +function flight_client_surface_test_protocol_clients(fixture) + client = fixture.client + + @test isdefined(fixture.protocol, :FlightService_Handshake_Client) + @test isdefined(fixture.protocol, :FlightService_ListFlights_Client) + @test isdefined(fixture.protocol, :FlightService_GetFlightInfo_Client) + @test isdefined(fixture.protocol, :FlightService_PollFlightInfo_Client) + @test isdefined(fixture.protocol, :FlightService_GetSchema_Client) + @test isdefined(fixture.protocol, :FlightService_DoGet_Client) + @test isdefined(fixture.protocol, :FlightService_DoPut_Client) + @test isdefined(fixture.protocol, :FlightService_DoExchange_Client) + @test isdefined(fixture.protocol, :FlightService_DoAction_Client) + @test isdefined(fixture.protocol, :FlightService_ListActions_Client) + + @test Arrow.Flight._handshake_client(client).path == + "/arrow.flight.protocol.FlightService/Handshake" + @test Arrow.Flight._listflights_client(client).path == + "/arrow.flight.protocol.FlightService/ListFlights" + @test Arrow.Flight._getflightinfo_client(client).path == + "/arrow.flight.protocol.FlightService/GetFlightInfo" + @test Arrow.Flight._pollflightinfo_client(client).path == + "/arrow.flight.protocol.FlightService/PollFlightInfo" + @test Arrow.Flight._getschema_client(client).path == + "/arrow.flight.protocol.FlightService/GetSchema" + @test Arrow.Flight._doget_client(client).path == + "/arrow.flight.protocol.FlightService/DoGet" + @test Arrow.Flight._doput_client(client).path == + "/arrow.flight.protocol.FlightService/DoPut" + @test Arrow.Flight._doexchange_client(client).path == + "/arrow.flight.protocol.FlightService/DoExchange" + @test Arrow.Flight._doaction_client(client).path == + "/arrow.flight.protocol.FlightService/DoAction" + @test Arrow.Flight._listactions_client(client).path == + "/arrow.flight.protocol.FlightService/ListActions" +end diff --git a/test/flight/client_surface/support.jl b/test/flight/client_surface/support.jl new file mode 100644 index 
00000000..eaa065d0 --- /dev/null +++ b/test/flight/client_surface/support.jl @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_client_surface_fixture() + client = Arrow.Flight.Client( + "localhost", + 8815; + secure=true, + deadline=30, + keepalive=15, + max_send_message_length=1024, + max_recieve_message_length=2048, + ) + return (; client, protocol=Arrow.Flight.Protocol) +end diff --git a/test/flight/grpcserver_extension.jl b/test/flight/grpcserver_extension.jl new file mode 100644 index 00000000..55c7a648 --- /dev/null +++ b/test/flight/grpcserver_extension.jl @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("grpcserver_extension/support.jl") +include("grpcserver_extension/descriptor_tests.jl") +include("grpcserver_extension/unary_tests.jl") +include("grpcserver_extension/streaming_tests.jl") + +@testset "Flight gRPCServer extension" begin + grpcserver = FlightTestSupport.load_grpcserver() + if isnothing(grpcserver) + @test true + else + protocol = Arrow.Flight.Protocol + fixture = grpcserver_extension_fixture(protocol) + service = grpcserver_extension_service(protocol, fixture) + metadata = grpcserver_extension_metadata() + + grpcserver_extension_test_descriptor(grpcserver, service) + grpcserver_extension_test_unary(grpcserver, service, fixture, metadata) + grpcserver_extension_test_streaming(grpcserver, service, fixture, metadata) + end +end diff --git a/test/flight/grpcserver_extension/bidi_streaming_tests.jl b/test/flight/grpcserver_extension/bidi_streaming_tests.jl new file mode 100644 index 00000000..bd5bc4ca --- /dev/null +++ b/test/flight/grpcserver_extension/bidi_streaming_tests.jl @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function grpcserver_extension_test_bidi_streaming(grpcserver, service, fixture, metadata) + grpc_descriptor = grpcserver.service_descriptor(service) + protocol = Arrow.Flight.Protocol + + handshake_messages, handshake_closed, handshake_stream = grpcserver_capture_bidi_stream( + grpcserver, + protocol.HandshakeRequest, + protocol.HandshakeResponse, + fixture.handshake_requests, + ) + grpc_descriptor.methods["Handshake"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/Handshake"; + metadata=metadata, + ), + handshake_stream, + ) + @test handshake_closed[] + @test length(handshake_messages) == 1 + @test handshake_messages[1].payload == b"native-token" + + doput_messages, doput_closed, doput_stream = grpcserver_capture_bidi_stream( + grpcserver, + protocol.FlightData, + protocol.PutResult, + fixture.messages, + ) + grpc_descriptor.methods["DoPut"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoPut"; + metadata=metadata, + ), + doput_stream, + ) + @test doput_closed[] + @test length(doput_messages) == 1 + @test String(doput_messages[1].app_metadata) == "stored" + + doexchange_messages, doexchange_closed, doexchange_stream = + grpcserver_capture_bidi_stream( + grpcserver, + protocol.FlightData, + protocol.FlightData, + fixture.exchange_messages, + ) + grpc_descriptor.methods["DoExchange"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoExchange"; + metadata=metadata, + ), + doexchange_stream, + ) + @test 
doexchange_closed[] + @test length(doexchange_messages) == length(fixture.exchange_messages) + + failing_service = Arrow.Flight.Service( + doexchange=(ctx, request, response) -> + throw(ArgumentError("bidi streaming failed before first response")), + ) + failing_descriptor = grpcserver.service_descriptor(failing_service) + failing_messages, failing_closed, failing_stream = grpcserver_capture_bidi_stream( + grpcserver, + protocol.FlightData, + protocol.FlightData, + fixture.exchange_messages, + ) + failure = try + failing_descriptor.methods["DoExchange"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoExchange"; + metadata=metadata, + ), + failing_stream, + ) + nothing + catch err + err + end + @test failure isa ArgumentError + @test occursin( + "bidi streaming failed before first response", + sprint(showerror, failure), + ) + @test !failing_closed[] + @test isempty(failing_messages) +end diff --git a/test/flight/grpcserver_extension/descriptor_tests.jl b/test/flight/grpcserver_extension/descriptor_tests.jl new file mode 100644 index 00000000..631891be --- /dev/null +++ b/test/flight/grpcserver_extension/descriptor_tests.jl @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +function grpcserver_extension_test_descriptor(grpcserver, service) + grpc_descriptor = grpcserver.service_descriptor(service) + @test Base.get_extension(Arrow, :ArrowgRPCServerExt) !== nothing + @test grpc_descriptor.name == "arrow.flight.protocol.FlightService" + @test haskey(grpc_descriptor.methods, "GetFlightInfo") + @test haskey(grpc_descriptor.methods, "DoGet") + @test haskey(grpc_descriptor.methods, "DoExchange") + @test grpc_descriptor.methods["GetFlightInfo"].method_type == + grpcserver.MethodType.UNARY + @test grpc_descriptor.methods["DoGet"].method_type == + grpcserver.MethodType.SERVER_STREAMING + @test grpc_descriptor.methods["DoExchange"].method_type == + grpcserver.MethodType.BIDI_STREAMING + @test grpc_descriptor.methods["DoGet"].input_type == "arrow.flight.protocol.Ticket" + @test grpc_descriptor.methods["DoGet"].output_type == "arrow.flight.protocol.FlightData" + return grpc_descriptor +end diff --git a/test/flight/grpcserver_extension/server_streaming_tests.jl b/test/flight/grpcserver_extension/server_streaming_tests.jl new file mode 100644 index 00000000..80bfcf24 --- /dev/null +++ b/test/flight/grpcserver_extension/server_streaming_tests.jl @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +function grpcserver_extension_test_server_streaming(grpcserver, service, fixture, metadata) + grpc_descriptor = grpcserver.service_descriptor(service) + protocol = Arrow.Flight.Protocol + + doget_messages, doget_closed, doget_stream = + grpcserver_capture_server_stream(grpcserver, protocol.FlightData) + grpc_descriptor.methods["DoGet"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoGet"; + metadata=metadata, + ), + fixture.ticket, + doget_stream, + ) + @test doget_closed[] + @test length(doget_messages) == length(fixture.messages) + @test Arrow.Flight.table(doget_messages; schema=fixture.info).name == + ["one", "two", "three"] + + doget_any_messages = Any[] + doget_any_closed = Ref(false) + doget_any_stream = grpcserver.ServerStream{Any}( + (message, compress) -> begin + @test compress + push!(doget_any_messages, message) + end, + () -> (doget_any_closed[] = true), + ) + grpc_descriptor.methods["DoGet"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoGet"; + metadata=metadata, + ), + fixture.ticket, + doget_any_stream, + ) + @test doget_any_closed[] + @test length(doget_any_messages) == length(fixture.messages) + @test all(message -> message isa protocol.FlightData, doget_any_messages) + + actions_messages, actions_closed, actions_stream = + grpcserver_capture_server_stream(grpcserver, protocol.ActionType) + grpc_descriptor.methods["ListActions"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/ListActions"; + metadata=metadata, + ), + protocol.Empty(), + actions_stream, + ) + @test actions_closed[] + @test length(actions_messages) == 1 + @test actions_messages[1].var"#type" == "ping" + + action_messages, action_closed, action_stream = + grpcserver_capture_server_stream(grpcserver, protocol.Result) + 
grpc_descriptor.methods["DoAction"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoAction"; + metadata=metadata, + ), + protocol.Action("ping", UInt8[]), + action_stream, + ) + @test action_closed[] + @test length(action_messages) == 1 + @test String(action_messages[1].body) == "pong" + + failing_service = Arrow.Flight.Service( + doget=(ctx, req, response) -> + throw(ArgumentError("server streaming failed before first response")), + ) + failing_descriptor = grpcserver.service_descriptor(failing_service) + failing_messages, failing_closed, failing_stream = + grpcserver_capture_server_stream(grpcserver, protocol.FlightData) + failure = try + failing_descriptor.methods["DoGet"].handler( + grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/DoGet"; + metadata=metadata, + ), + fixture.ticket, + failing_stream, + ) + nothing + catch err + err + end + @test failure isa ArgumentError + @test occursin( + "server streaming failed before first response", + sprint(showerror, failure), + ) + @test !failing_closed[] + @test isempty(failing_messages) +end diff --git a/test/flight/grpcserver_extension/streaming_tests.jl b/test/flight/grpcserver_extension/streaming_tests.jl new file mode 100644 index 00000000..d82ba909 --- /dev/null +++ b/test/flight/grpcserver_extension/streaming_tests.jl @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("server_streaming_tests.jl") +include("bidi_streaming_tests.jl") + +function grpcserver_extension_test_streaming(grpcserver, service, fixture, metadata) + grpcserver_extension_test_server_streaming(grpcserver, service, fixture, metadata) + grpcserver_extension_test_bidi_streaming(grpcserver, service, fixture, metadata) +end diff --git a/test/flight/grpcserver_extension/support.jl b/test/flight/grpcserver_extension/support.jl new file mode 100644 index 00000000..7daf8a41 --- /dev/null +++ b/test/flight/grpcserver_extension/support.jl @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +include("support/fixture.jl") +include("support/service.jl") +include("support/context.jl") +include("support/streams.jl") diff --git a/test/flight/grpcserver_extension/support/context.jl b/test/flight/grpcserver_extension/support/context.jl new file mode 100644 index 00000000..2d465989 --- /dev/null +++ b/test/flight/grpcserver_extension/support/context.jl @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +grpcserver_extension_metadata() = + Dict{String,Union{String,Vector{UInt8}}}("authorization" => "Bearer native") + +function grpcserver_extension_context( + grpcserver, + method::AbstractString; + metadata=grpcserver_extension_metadata(), +) + return grpcserver.ServerContext(method=String(method), metadata=metadata) +end diff --git a/test/flight/grpcserver_extension/support/fixture.jl b/test/flight/grpcserver_extension/support/fixture.jl new file mode 100644 index 00000000..23ff5a94 --- /dev/null +++ b/test/flight/grpcserver_extension/support/fixture.jl @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function grpcserver_extension_fixture(protocol) + descriptor_type = protocol.var"FlightDescriptor.DescriptorType" + descriptor = + protocol.FlightDescriptor(descriptor_type.PATH, UInt8[], ["native", "dataset"]) + ticket = protocol.Ticket(b"native-ticket") + messages = Arrow.Flight.flightdata( + Tables.partitioner(( + (id=Int64[1, 2], name=["one", "two"]), + (id=Int64[3], name=["three"]), + )); + descriptor=descriptor, + ) + schema_bytes = Arrow.Flight.schemaipc(first(messages)) + info = protocol.FlightInfo( + schema_bytes[5:end], + descriptor, + [protocol.FlightEndpoint(ticket, protocol.Location[], nothing, UInt8[])], + Int64(3), + Int64(-1), + false, + UInt8[], + ) + handshake_requests = [protocol.HandshakeRequest(UInt64(0), b"native-token")] + exchange_messages = Arrow.Flight.flightdata( + Tables.partitioner(((id=Int64[10], name=["ten"]),)); + descriptor=descriptor, + ) + return ( + descriptor=descriptor, + ticket=ticket, + messages=messages, + schema_bytes=schema_bytes, + info=info, + handshake_requests=handshake_requests, + exchange_messages=exchange_messages, + ) +end diff --git a/test/flight/grpcserver_extension/support/service.jl b/test/flight/grpcserver_extension/support/service.jl new file mode 100644 index 00000000..8b0ef260 --- /dev/null +++ b/test/flight/grpcserver_extension/support/service.jl @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function grpcserver_extension_service(protocol, fixture) + return Arrow.Flight.Service( + handshake=(ctx, request, response) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + incoming = collect(request) + @test length(incoming) == 1 + put!(response, protocol.HandshakeResponse(UInt64(0), incoming[1].payload)) + close(response) + return :handshake_ok + end, + getflightinfo=(ctx, req) -> begin + @test req.path == fixture.descriptor.path + return fixture.info + end, + getschema=(ctx, req) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + @test req.path == fixture.descriptor.path + return protocol.SchemaResult(fixture.schema_bytes[5:end]) + end, + doget=(ctx, req, response) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + @test req.ticket == fixture.ticket.ticket + foreach(message -> put!(response, message), fixture.messages) + close(response) + return :doget_ok + end, + listactions=(ctx, response) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + put!(response, protocol.ActionType("ping", "Ping action")) + close(response) + return :listactions_ok + end, + doaction=(ctx, action, response) -> begin + @test 
Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + @test action.var"#type" == "ping" + put!(response, protocol.Result(b"pong")) + close(response) + return :doaction_ok + end, + doput=(ctx, request, response) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + incoming = collect(request) + @test length(incoming) == length(fixture.messages) + put!(response, protocol.PutResult(b"stored")) + close(response) + return :doput_ok + end, + doexchange=(ctx, request, response) -> begin + @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" + incoming = collect(request) + foreach(message -> put!(response, message), incoming) + close(response) + return :doexchange_ok + end, + ) +end diff --git a/test/flight/grpcserver_extension/support/streams.jl b/test/flight/grpcserver_extension/support/streams.jl new file mode 100644 index 00000000..7d9da0a5 --- /dev/null +++ b/test/flight/grpcserver_extension/support/streams.jl @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +function grpcserver_capture_server_stream(grpcserver, ::Type{T}) where {T} + messages = T[] + closed = Ref(false) + stream = grpcserver.ServerStream{T}( + (message, compress) -> begin + @test compress + push!(messages, message) + end, + () -> (closed[] = true), + ) + return messages, closed, stream +end + +function grpcserver_capture_bidi_stream( + grpcserver, + ::Type{Request}, + ::Type{Response}, + requests, +) where {Request,Response} + messages = Response[] + closed = Ref(false) + stream = grpcserver.BidiStream{Request,Response}( + FlightTestSupport.next_message_factory(requests), + (message, compress) -> begin + @test compress + push!(messages, message) + end, + () -> (closed[] = true), + () -> false, + ) + return messages, closed, stream +end diff --git a/test/flight/grpcserver_extension/unary_tests.jl b/test/flight/grpcserver_extension/unary_tests.jl new file mode 100644 index 00000000..2f005a19 --- /dev/null +++ b/test/flight/grpcserver_extension/unary_tests.jl @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +function grpcserver_extension_test_unary(grpcserver, service, fixture, metadata) + grpc_descriptor = grpcserver.service_descriptor(service) + + unary_context = grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/GetFlightInfo"; + metadata=metadata, + ) + schema_context = grpcserver_extension_context( + grpcserver, + "/arrow.flight.protocol.FlightService/GetSchema"; + metadata=metadata, + ) + + direct_info = + grpc_descriptor.methods["GetFlightInfo"].handler(unary_context, fixture.descriptor) + @test direct_info.total_records == 3 + @test direct_info.endpoint[1].ticket.ticket == fixture.ticket.ticket + + direct_schema = + grpc_descriptor.methods["GetSchema"].handler(schema_context, fixture.descriptor) + @test Arrow.Flight.schemaipc(direct_schema) == fixture.schema_bytes +end diff --git a/test/flight/handshake_interop.jl b/test/flight/handshake_interop.jl new file mode 100644 index 00000000..0ef27117 --- /dev/null +++ b/test/flight/handshake_interop.jl @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +@testset "Flight handshake interop" begin + server = FlightTestSupport.start_handshake_flight_server() + if isnothing(server) + @test true + else + protocol = Arrow.Flight.Protocol + + try + FlightTestSupport.with_test_grpc_handle() do grpc + client = Arrow.Flight.Client("grpc://127.0.0.1:$(server.port)"; grpc=grpc) + + handshake_req, handshake_request, handshake_response = + Arrow.Flight.handshake(client) + put!(handshake_request, protocol.HandshakeRequest(UInt64(0), b"test")) + put!(handshake_request, protocol.HandshakeRequest(UInt64(0), b"p4ssw0rd")) + close(handshake_request) + + handshake_messages = collect(handshake_response) + gRPCClient.grpc_async_await(handshake_req) + + @test length(handshake_messages) == 1 + @test handshake_messages[1].protocol_version == 0 + @test handshake_messages[1].payload == b"secret:test" + + token_client = Arrow.Flight.withtoken(client, handshake_messages[1].payload) + actions_req, actions_channel = Arrow.Flight.listactions(token_client) + actions = collect(actions_channel) + gRPCClient.grpc_async_await(actions_req) + @test actions == + [protocol.ActionType("authenticated", "Requires a valid auth token")] + + auth_client, auth_messages = + Arrow.Flight.authenticate(client, "test", "p4ssw0rd") + @test length(auth_messages) == 1 + @test auth_messages[1].protocol_version == + handshake_messages[1].protocol_version + @test auth_messages[1].payload == handshake_messages[1].payload + @test auth_client.headers == ["auth-token-bin" => b"secret:test"] + + bad_req, bad_request, bad_response = Arrow.Flight.handshake(client) + put!(bad_request, protocol.HandshakeRequest(UInt64(0), b"test")) + put!(bad_request, protocol.HandshakeRequest(UInt64(0), b"wrong")) + close(bad_request) + + @test isempty(collect(bad_response)) + @test_throws gRPCClient.gRPCServiceCallException gRPCClient.grpc_async_await( + bad_req, + ) + end + finally + FlightTestSupport.stop_pyarrow_flight_server(server) + end + end +end diff --git 
a/test/flight/header_interop.jl b/test/flight/header_interop.jl new file mode 100644 index 00000000..a2f35a6f --- /dev/null +++ b/test/flight/header_interop.jl @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@testset "Flight header interop" begin + server = FlightTestSupport.start_headers_flight_server() + if isnothing(server) + @test true + else + protocol = Arrow.Flight.Protocol + + try + FlightTestSupport.with_test_grpc_handle() do grpc + base_client = + Arrow.Flight.Client("grpc://127.0.0.1:$(server.port)"; grpc=grpc) + client = Arrow.Flight.withheaders( + base_client, + "authorization" => "Bearer token1234", + ) + + actions_req, actions_channel = Arrow.Flight.listactions(client) + actions = collect(actions_channel) + gRPCClient.grpc_async_await(actions_req) + @test actions == [ + protocol.ActionType( + "echo-authorization", + "Return the Authorization header", + ), + ] + + action_req, action_channel = Arrow.Flight.doaction( + client, + protocol.Action("echo-authorization", UInt8[]), + ) + action_results = collect(action_channel) + gRPCClient.grpc_async_await(action_req) + @test length(action_results) == 1 + @test String(action_results[1].body) == "Bearer token1234" + + call_req, call_channel = 
Arrow.Flight.doaction( + base_client, + protocol.Action("echo-authorization", UInt8[]); + headers=["authorization" => "Bearer call-level"], + ) + call_results = collect(call_channel) + gRPCClient.grpc_async_await(call_req) + @test length(call_results) == 1 + @test String(call_results[1].body) == "Bearer call-level" + end + finally + FlightTestSupport.stop_pyarrow_flight_server(server) + end + end +end diff --git a/test/flight/ipc_conversion.jl b/test/flight/ipc_conversion.jl new file mode 100644 index 00000000..a2d20671 --- /dev/null +++ b/test/flight/ipc_conversion.jl @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +@testset "Flight IPC conversion helpers" begin + missing_schema_fragment = "the server may have terminated the stream before emitting the first schema-bearing FlightData message" + descriptor = Arrow.Flight.Protocol.FlightDescriptor( + Arrow.Flight.Protocol.var"FlightDescriptor.DescriptorType".PATH, + UInt8[], + ["datasets", "roundtrip"], + ) + source = Tables.partitioner(( + (id=Int64[1, 2], label=["one", "two"]), + (id=Int64[3], label=["three"]), + )) + messages = Arrow.Flight.flightdata(source; descriptor=descriptor) + + @test !isempty(messages) + @test messages[1].flight_descriptor == descriptor + @test all(isnothing(msg.flight_descriptor) for msg in messages[2:end]) + @test !isempty(messages[1].data_header) + @test isempty(messages[1].data_body) + + bytes = Arrow.Flight.streambytes(messages) + @test Arrow.readbuffer(bytes, 1, UInt32) == Arrow.CONTINUATION_INDICATOR_BYTES + @test Arrow.readbuffer(bytes, length(bytes) - 3, Int32) == 0 + + batches = collect(Arrow.Flight.stream(messages)) + @test length(batches) == 2 + @test batches[1].id == [1, 2] + @test batches[2].label == ["three"] + + tbl = Arrow.Flight.table(messages) + @test tbl.id == [1, 2, 3] + @test tbl.label == ["one", "two", "three"] + + schema_bytes = Arrow.Flight.schemaipc(first(messages)) + @test Arrow.Flight.schemaipc(Arrow.Flight.Protocol.SchemaResult(schema_bytes[5:end])) == + schema_bytes + + stream_error = try + Arrow.Flight.stream(Arrow.Flight.Protocol.FlightData[]) + nothing + catch err + err + end + @test stream_error isa ArgumentError + @test occursin(missing_schema_fragment, sprint(showerror, stream_error)) + + table_error = try + Arrow.Flight.table(Arrow.Flight.Protocol.FlightData[]) + nothing + catch err + err + end + @test table_error isa ArgumentError + @test occursin(missing_schema_fragment, sprint(showerror, table_error)) + + empty_tbl = Arrow.Flight.table( + Arrow.Flight.Protocol.FlightData[]; + schema=Arrow.Flight.Protocol.SchemaResult(schema_bytes[5:end]), + ) + @test 
isempty(empty_tbl.id) + @test isempty(empty_tbl.label) +end diff --git a/test/flight/ipc_schema_separation.jl b/test/flight/ipc_schema_separation.jl new file mode 100644 index 00000000..facd1e1a --- /dev/null +++ b/test/flight/ipc_schema_separation.jl @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +@testset "Flight IPC schema separation" begin + source = Tables.partitioner(((word=["red", "blue"],), (word=["red", "green"],))) + messages = Arrow.Flight.flightdata(source; dictencode=true) + schema_bytes = Arrow.Flight.schemaipc(first(messages)) + info = Arrow.Flight.Protocol.FlightInfo( + schema_bytes[5:end], + nothing, + Arrow.Flight.Protocol.FlightEndpoint[], + Int64(-1), + Int64(-1), + false, + UInt8[], + ) + payload = messages[2:end] + + @test length(messages) >= 4 + @test Arrow.Flight.schemaipc(info) == schema_bytes + + batches = collect(Arrow.Flight.stream(payload; schema=info)) + @test length(batches) == 2 + @test isequal(batches[1].word, ["red", "blue"]) + @test isequal(batches[2].word, ["red", "green"]) + + tbl = Arrow.Flight.table(payload; schema=info) + @test isequal(tbl.word, ["red", "blue", "red", "green"]) +end diff --git a/test/flight/poll_interop.jl b/test/flight/poll_interop.jl new file mode 100644 index 00000000..b3a5cf17 --- /dev/null +++ b/test/flight/poll_interop.jl @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +@testset "Flight poll interop" begin + server = FlightTestSupport.start_poll_flight_server() + if isnothing(server) + @test true + else + protocol = Arrow.Flight.Protocol + descriptor_type = protocol.var"FlightDescriptor.DescriptorType" + + try + FlightTestSupport.with_test_grpc_handle() do grpc + client = Arrow.Flight.Client("grpc://127.0.0.1:$(server.port)"; grpc=grpc) + initial_descriptor = protocol.FlightDescriptor( + descriptor_type.PATH, + UInt8[], + ["interop", "poll"], + ) + + first_poll = Arrow.Flight.pollflightinfo(client, initial_descriptor) + @test !isnothing(first_poll.info) + @test !isnothing(first_poll.flight_descriptor) + @test first_poll.flight_descriptor.path == ["interop", "poll", "retry"] + @test first_poll.info.total_records == 1 + @test first_poll.info.ordered + @test first_poll.progress ≈ 0.5 + @test Arrow.Flight.schemaipc(first_poll.info) == Arrow.Flight.schemaipc( + protocol.SchemaResult(first_poll.info.schema[5:end]), + ) + + second_poll = + Arrow.Flight.pollflightinfo(client, first_poll.flight_descriptor) + @test !isnothing(second_poll.info) + @test isnothing(second_poll.flight_descriptor) + @test second_poll.info.flight_descriptor.path == ["interop", "poll"] + @test second_poll.progress ≈ 1.0 + @test length(second_poll.info.endpoint) == 1 + @test second_poll.info.endpoint[1].ticket.ticket == b"poll-ticket" + end + finally + FlightTestSupport.stop_pyarrow_flight_server(server) + end + end +end diff --git a/test/flight/pyarrow_interop.jl b/test/flight/pyarrow_interop.jl new file mode 100644 index 00000000..ebf0e0f8 --- /dev/null +++ b/test/flight/pyarrow_interop.jl @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("pyarrow_interop/support.jl") +include("pyarrow_interop/discovery_tests.jl") +include("pyarrow_interop/download_tests.jl") +include("pyarrow_interop/upload_tests.jl") +include("pyarrow_interop/exchange_tests.jl") + +@testset "Flight pyarrow interop" begin + server = FlightTestSupport.start_pyarrow_flight_server() + if isnothing(server) + @test true + else + protocol = Arrow.Flight.Protocol + descriptors = pyarrow_interop_descriptors(protocol) + + try + FlightTestSupport.with_test_grpc_handle() do grpc + client = Arrow.Flight.Client("grpc://127.0.0.1:$(server.port)"; grpc=grpc) + pyarrow_interop_test_discovery(client, protocol, descriptors.download) + pyarrow_interop_test_download(client, descriptors.download) + pyarrow_interop_test_upload(client, descriptors.upload) + pyarrow_interop_test_exchange(client, descriptors.exchange) + end + finally + FlightTestSupport.stop_pyarrow_flight_server(server) + end + end +end diff --git a/test/flight/pyarrow_interop/discovery_tests.jl b/test/flight/pyarrow_interop/discovery_tests.jl new file mode 100644 index 00000000..ee70af1f --- /dev/null +++ b/test/flight/pyarrow_interop/discovery_tests.jl @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function pyarrow_interop_test_discovery(client, protocol, download_descriptor) + flights_req, flights_channel = Arrow.Flight.listflights(client) + flights = pyarrow_interop_collect(flights_req, flights_channel) + @test any( + info -> + !isnothing(info.flight_descriptor) && + info.flight_descriptor.path == download_descriptor.path, + flights, + ) + + actions_req, actions_channel = Arrow.Flight.listactions(client) + actions = pyarrow_interop_collect(actions_req, actions_channel) + @test any(action -> action.var"#type" == "ping", actions) + + action_req, action_channel = + Arrow.Flight.doaction(client, protocol.Action("ping", UInt8[])) + action_results = pyarrow_interop_collect(action_req, action_channel) + @test length(action_results) == 1 + @test String(action_results[1].body) == "pong" +end diff --git a/test/flight/pyarrow_interop/download_tests.jl b/test/flight/pyarrow_interop/download_tests.jl new file mode 100644 index 00000000..6b6b25a8 --- /dev/null +++ b/test/flight/pyarrow_interop/download_tests.jl @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function pyarrow_interop_test_download(client, download_descriptor) + download_info = Arrow.Flight.getflightinfo(client, download_descriptor) + @test download_info.total_records == 3 + @test length(download_info.endpoint) == 1 + + download_schema = Arrow.Flight.getschema(client, download_descriptor) + @test Arrow.Flight.schemaipc(download_schema) == Arrow.Flight.schemaipc(download_info) + + doget_req, doget_channel = Arrow.Flight.doget(client, download_info.endpoint[1].ticket) + download_messages = pyarrow_interop_collect(doget_req, doget_channel) + + download_table = Arrow.Flight.table(download_messages; schema=download_info) + @test download_table.id == [1, 2, 3] + @test download_table.name == ["one", "two", "three"] +end diff --git a/test/flight/pyarrow_interop/exchange_tests.jl b/test/flight/pyarrow_interop/exchange_tests.jl new file mode 100644 index 00000000..aaad35e3 --- /dev/null +++ b/test/flight/pyarrow_interop/exchange_tests.jl @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function pyarrow_interop_test_exchange(client, exchange_descriptor) + exchange_source = Tables.partitioner(( + (id=Int64[21, 22], name=["twenty-one", "twenty-two"]), + (id=Int64[23], name=["twenty-three"]), + )) + exchange_messages = + Arrow.Flight.flightdata(exchange_source; descriptor=exchange_descriptor) + + exchange_req, exchange_request, exchange_response = Arrow.Flight.doexchange(client) + exchanged_messages = pyarrow_interop_send_messages( + exchange_req, + exchange_request, + exchange_response, + exchange_messages, + ) + + exchange_table = Arrow.Flight.table(exchanged_messages) + @test exchange_table.id == [21, 22, 23] + @test exchange_table.name == ["twenty-one", "twenty-two", "twenty-three"] + @test filter(!isempty, getfield.(exchanged_messages, :app_metadata)) == + [b"exchange:0", b"exchange:1"] +end diff --git a/test/flight/pyarrow_interop/support.jl b/test/flight/pyarrow_interop/support.jl new file mode 100644 index 00000000..893f65ce --- /dev/null +++ b/test/flight/pyarrow_interop/support.jl @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function pyarrow_interop_descriptors(protocol) + descriptor_type = protocol.var"FlightDescriptor.DescriptorType" + return ( + download=protocol.FlightDescriptor( + descriptor_type.PATH, + UInt8[], + ["interop", "download"], + ), + upload=protocol.FlightDescriptor( + descriptor_type.PATH, + UInt8[], + ["interop", "upload"], + ), + exchange=protocol.FlightDescriptor( + descriptor_type.PATH, + UInt8[], + ["interop", "exchange"], + ), + ) +end + +function pyarrow_interop_collect(req, channel) + messages = collect(channel) + gRPCClient.grpc_async_await(req) + return messages +end + +function pyarrow_interop_send_messages(req, request, response, messages) + for message in messages + put!(request, message) + end + close(request) + responses = collect(response) + gRPCClient.grpc_async_await(req) + return responses +end diff --git a/test/flight/pyarrow_interop/upload_tests.jl b/test/flight/pyarrow_interop/upload_tests.jl new file mode 100644 index 00000000..1c10c9e2 --- /dev/null +++ b/test/flight/pyarrow_interop/upload_tests.jl @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function pyarrow_interop_test_upload(client, upload_descriptor) + upload_source = Tables.partitioner(( + (id=Int64[10, 11], name=["ten", "eleven"]), + (id=Int64[12], name=["twelve"]), + )) + upload_messages = Arrow.Flight.flightdata(upload_source; descriptor=upload_descriptor) + + doput_req, doput_request, doput_response = Arrow.Flight.doput(client) + put_results = pyarrow_interop_send_messages( + doput_req, + doput_request, + doput_response, + upload_messages, + ) + + @test !isempty(put_results) + @test String(put_results[end].app_metadata) == "stored" + + uploaded_info = Arrow.Flight.getflightinfo(client, upload_descriptor) + uploaded_req, uploaded_channel = + Arrow.Flight.doget(client, uploaded_info.endpoint[1].ticket) + uploaded_messages = pyarrow_interop_collect(uploaded_req, uploaded_channel) + + uploaded_table = Arrow.Flight.table(uploaded_messages; schema=uploaded_info) + @test uploaded_table.id == [10, 11, 12] + @test uploaded_table.name == ["ten", "eleven", "twelve"] +end diff --git a/test/flight/server_core.jl b/test/flight/server_core.jl new file mode 100644 index 00000000..3ca8e2da --- /dev/null +++ b/test/flight/server_core.jl @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include("server_core/support.jl") +include("server_core/metadata_tests.jl") +include("server_core/descriptor_tests.jl") +include("server_core/direct_handler_tests.jl") +include("server_core/dispatch_tests.jl") + +@testset "Flight server core surface" begin + fixture = flight_server_core_fixture() + flight_server_core_test_metadata(fixture) + flight_server_core_test_descriptors(fixture) + flight_server_core_test_direct_handlers(fixture) + flight_server_core_test_dispatch(fixture) +end diff --git a/test/flight/server_core/descriptor_tests.jl b/test/flight/server_core/descriptor_tests.jl new file mode 100644 index 00000000..1cf24550 --- /dev/null +++ b/test/flight/server_core/descriptor_tests.jl @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +function flight_server_core_test_descriptors(fixture) + handshake_descriptor = Arrow.Flight.lookupmethod(fixture.descriptor_info, "Handshake") + @test !isnothing(handshake_descriptor) + @test handshake_descriptor.path == "/arrow.flight.protocol.FlightService/Handshake" + @test handshake_descriptor.request_streaming + @test handshake_descriptor.response_streaming + @test handshake_descriptor.request_type === fixture.protocol.HandshakeRequest + @test handshake_descriptor.response_type === fixture.protocol.HandshakeResponse + + doget_descriptor = Arrow.Flight.lookupmethod( + fixture.descriptor_info, + "/arrow.flight.protocol.FlightService/DoGet", + ) + @test !isnothing(doget_descriptor) + @test !doget_descriptor.request_streaming + @test doget_descriptor.response_streaming + @test doget_descriptor.request_type === fixture.protocol.Ticket + @test doget_descriptor.response_type === fixture.protocol.FlightData + @test isnothing(Arrow.Flight.lookupmethod(fixture.descriptor_info, "MissingMethod")) +end diff --git a/test/flight/server_core/direct_handler_tests.jl b/test/flight/server_core/direct_handler_tests.jl new file mode 100644 index 00000000..e936dad2 --- /dev/null +++ b/test/flight/server_core/direct_handler_tests.jl @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_server_core_test_direct_handlers(fixture) + @test_throws gRPCClient.gRPCServiceCallException Arrow.Flight.getflightinfo( + fixture.service, + fixture.context, + fixture.descriptor, + ) + + info = + Arrow.Flight.getflightinfo(fixture.implemented, fixture.context, fixture.descriptor) + @test info.total_records == 7 + @test info.total_bytes == 42 + @test info.flight_descriptor.path == ["server", "dataset"] + + get_response = Channel{fixture.protocol.FlightData}(1) + @test Arrow.Flight.doget( + fixture.implemented, + fixture.context, + fixture.protocol.Ticket(b"ticket-1"), + get_response, + ) == :doget_ok + @test length(collect(get_response)) == 1 + + actions_response = Channel{fixture.protocol.ActionType}(1) + @test Arrow.Flight.listactions( + fixture.implemented, + fixture.context, + actions_response, + ) == :listactions_ok + actions = collect(actions_response) + @test length(actions) == 1 + @test getfield(actions[1], Symbol("#type")) == "ping" + @test actions[1].description == "Ping action" +end diff --git a/test/flight/server_core/dispatch_tests.jl b/test/flight/server_core/dispatch_tests.jl new file mode 100644 index 00000000..28bc7b42 --- /dev/null +++ b/test/flight/server_core/dispatch_tests.jl @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_server_core_test_dispatch(fixture) + dispatch_info = Arrow.Flight.dispatch( + fixture.implemented, + fixture.context, + "/arrow.flight.protocol.FlightService/GetFlightInfo", + fixture.descriptor, + ) + @test dispatch_info.total_records == 7 + @test dispatch_info.flight_descriptor.path == ["server", "dataset"] + + doget_descriptor = Arrow.Flight.lookupmethod( + fixture.descriptor_info, + "/arrow.flight.protocol.FlightService/DoGet", + ) + get_response = Channel{fixture.protocol.FlightData}(1) + @test Arrow.Flight.dispatch( + fixture.implemented, + fixture.context, + doget_descriptor, + fixture.protocol.Ticket(b"ticket-1"), + get_response, + ) == :doget_ok + @test length(collect(get_response)) == 1 + + actions_response = Channel{fixture.protocol.ActionType}(1) + @test Arrow.Flight.dispatch( + fixture.implemented, + fixture.context, + "ListActions", + actions_response, + ) == :listactions_ok + @test length(collect(actions_response)) == 1 + @test_throws ArgumentError Arrow.Flight.dispatch( + fixture.implemented, + fixture.context, + "/arrow.flight.protocol.FlightService/MissingMethod", + fixture.descriptor, + ) +end diff --git a/test/flight/server_core/metadata_tests.jl b/test/flight/server_core/metadata_tests.jl new file mode 100644 index 00000000..3393792c --- /dev/null +++ b/test/flight/server_core/metadata_tests.jl @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_server_core_test_metadata(fixture) + @test Arrow.Flight.callheader(fixture.context, "authorization") == "Bearer test" + @test Arrow.Flight.callheader(fixture.context, "Authorization") == "Bearer test" + @test Arrow.Flight.callheader(fixture.context, "auth-token-bin") == UInt8[0x01, 0x02] + @test isnothing(Arrow.Flight.callheader(fixture.context, "missing")) + @test fixture.descriptor_info.name == "arrow.flight.protocol.FlightService" + @test length(fixture.descriptor_info.methods) == 10 +end diff --git a/test/flight/server_core/support.jl b/test/flight/server_core/support.jl new file mode 100644 index 00000000..3c3397ae --- /dev/null +++ b/test/flight/server_core/support.jl @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function flight_server_core_fixture() + protocol = Arrow.Flight.Protocol + context = Arrow.Flight.ServerCallContext( + headers=["authorization" => "Bearer test", "auth-token-bin" => UInt8[0x01, 0x02]], + peer="127.0.0.1:4000", + secure=true, + ) + descriptor_info = Arrow.Flight.servicedescriptor(Arrow.Flight.Service()) + descriptor_type = protocol.var"FlightDescriptor.DescriptorType" + descriptor = + protocol.FlightDescriptor(descriptor_type.PATH, UInt8[], ["server", "dataset"]) + service = Arrow.Flight.Service() + implemented = Arrow.Flight.Service( + getflightinfo=(ctx, req) -> begin + @test ctx === context + @test req.path == descriptor.path + return protocol.FlightInfo( + UInt8[], + req, + protocol.FlightEndpoint[], + 7, + 42, + false, + UInt8[], + ) + end, + doget=(ctx, ticket, response) -> begin + @test ctx === context + @test ticket.ticket == b"ticket-1" + put!(response, protocol.FlightData(nothing, UInt8[], UInt8[], UInt8[])) + close(response) + return :doget_ok + end, + listactions=(ctx, response) -> begin + @test ctx === context + put!(response, protocol.ActionType("ping", "Ping action")) + close(response) + return :listactions_ok + end, + ) + return (; protocol, context, descriptor_info, descriptor, service, implemented) +end diff --git a/test/flight/support.jl b/test/flight/support.jl new file mode 100644 index 00000000..d1cb8d39 --- /dev/null +++ b/test/flight/support.jl @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module FlightTestSupport + +using gRPCClient + +export PyArrowFlightServer, + flight_test_roots, + pyarrow_flight_python, + start_pyarrow_flight_server, + start_headers_flight_server, + start_handshake_flight_server, + start_poll_flight_server, + start_tls_flight_server, + stop_pyarrow_flight_server, + with_test_grpc_handle, + load_grpcserver, + generate_test_tls_certificate, + next_message_factory + +include("support/types.jl") +include("support/paths.jl") +include("support/python_servers.jl") +include("support/grpc.jl") +include("support/tls.jl") +include("support/streams.jl") + +end diff --git a/test/flight/support/grpc.jl b/test/flight/support/grpc.jl new file mode 100644 index 00000000..a9eb7cd2 --- /dev/null +++ b/test/flight/support/grpc.jl @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +function with_test_grpc_handle(f::F) where {F} + grpc = gRPCClient.gRPCCURL() + gRPCClient.grpc_init(grpc) + try + return f(grpc) + finally + gRPCClient.grpc_shutdown(grpc) + end +end + +function load_grpcserver() + isnothing(Base.find_package("gRPCServer")) && return nothing + return Base.require( + Base.PkgId(Base.UUID("608c6337-0d7d-447f-bb69-0f5674ee3959"), "gRPCServer"), + ) +end diff --git a/test/flight/support/paths.jl b/test/flight/support/paths.jl new file mode 100644 index 00000000..0632072d --- /dev/null +++ b/test/flight/support/paths.jl @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +const TEST_ROOT = normpath(joinpath(@__DIR__, "..", "..")) + +function git_toplevel(path::AbstractString) + try + cmd = pipeline( + Cmd(["git", "-C", path, "rev-parse", "--show-toplevel"]); + stderr=devnull, + ) + return normpath(chomp(read(cmd, String))) + catch + return nothing + end +end + +function flight_test_roots() + roots = String[] + path = abspath(TEST_ROOT) + while true + top = git_toplevel(path) + !isnothing(top) && push!(roots, top) + parent = dirname(path) + parent == path && break + path = parent + end + push!(roots, TEST_ROOT) + unique!(roots) + return roots +end + +function pyarrow_flight_python() + haskey(ENV, "ARROW_FLIGHT_PYTHON") && return ENV["ARROW_FLIGHT_PYTHON"] + cache_home = get(ENV, "PRJ_CACHE_HOME", ".cache") + for root in flight_test_roots() + python = joinpath(root, cache_home, "arrow-julia-flight-pyenv", "bin", "python") + isfile(python) && return python + end + return nothing +end diff --git a/test/flight/support/python_servers.jl b/test/flight/support/python_servers.jl new file mode 100644 index 00000000..07ba6933 --- /dev/null +++ b/test/flight/support/python_servers.jl @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +function start_python_flight_server( + script_name::AbstractString; + env_overrides::AbstractDict{<:AbstractString,<:AbstractString}=Dict{String,String}(), +) + python = pyarrow_flight_python() + isnothing(python) && return nothing + + stdout = Pipe() + stderr = Pipe() + env = merge( + Dict{String,String}(ENV), + Dict("PYTHONUNBUFFERED" => "1"), + Dict{String,String}(string(k) => string(v) for (k, v) in pairs(env_overrides)), + ) + cmd = setenv(Cmd([python, joinpath(TEST_ROOT, script_name)]), env) + process = run(pipeline(cmd; stdout=stdout, stderr=stderr), wait=false) + close(stdout.in) + close(stderr.in) + + line = try + readline(stdout) + catch err + errout = read(stderr, String) + wait(process) + error( + "failed to start pyarrow Flight server: $(sprint(showerror, err)); stderr=$(repr(errout))", + ) + end + port = parse(Int, chomp(line)) + return PyArrowFlightServer(process, stdout, stderr, port) +end + +start_pyarrow_flight_server() = start_python_flight_server("flight_pyarrow_server.py") +start_headers_flight_server() = start_python_flight_server("flight_headers_server.py") +start_handshake_flight_server() = start_python_flight_server("flight_handshake_server.py") +start_poll_flight_server() = start_python_flight_server("flight_poll_server.py") +function start_tls_flight_server(cert_path::AbstractString, key_path::AbstractString) + start_python_flight_server( + "flight_tls_server.py"; + env_overrides=Dict( + "ARROW_FLIGHT_TLS_CERT" => String(cert_path), + "ARROW_FLIGHT_TLS_KEY" => String(key_path), + ), + ) +end + +function stop_pyarrow_flight_server(server::PyArrowFlightServer) + try + kill(server.process) + catch + end + try + wait(server.process) + catch + end + close(server.stdout) + close(server.stderr) + return +end + +stop_pyarrow_flight_server(::Nothing) = nothing diff --git a/test/flight/support/streams.jl b/test/flight/support/streams.jl new file mode 100644 index 00000000..8b9d048c --- /dev/null +++ b/test/flight/support/streams.jl @@ -0,0 +1,26 
@@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Return a closure that yields successive elements of `messages`, and
+# `nothing` once the collection is exhausted — a simple pull-based message
+# source for streaming test doubles.
+function next_message_factory(messages)
+    index = Ref(1)
+    return () -> begin
+        current = index[]
+        current > length(messages) && return nothing
+        index[] = current + 1
+        return messages[current]
+    end
+end
diff --git a/test/flight/support/tls.jl b/test/flight/support/tls.jl
new file mode 100644
index 00000000..ca7dfcf8
--- /dev/null
+++ b/test/flight/support/tls.jl
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Generate a short-lived (1 day) self-signed certificate/key pair for
+# `localhost`/`127.0.0.1` using the system `openssl` binary. Returns
+# `(cert_path, key_path)`, or `nothing` when openssl is not installed so
+# callers can skip TLS tests.
+function generate_test_tls_certificate(dir::AbstractString)
+    openssl = Sys.which("openssl")
+    isnothing(openssl) && return nothing
+
+    config_path = joinpath(dir, "openssl.cnf")
+    cert_path = joinpath(dir, "cert.pem")
+    key_path = joinpath(dir, "key.pem")
+    # Minimal config with a SAN section so modern clients accept the cert.
+    write(
+        config_path,
+        """
+        [req]
+        distinguished_name = dn
+        x509_extensions = v3_req
+        prompt = no
+
+        [dn]
+        CN = localhost
+
+        [v3_req]
+        subjectAltName = @alt_names
+
+        [alt_names]
+        DNS.1 = localhost
+        IP.1 = 127.0.0.1
+        """,
+    )
+    run(
+        Cmd([
+            openssl,
+            "req",
+            "-x509",
+            "-nodes",
+            "-newkey",
+            "rsa:2048",
+            "-keyout",
+            key_path,
+            "-out",
+            cert_path,
+            "-days",
+            "1",
+            "-config",
+            config_path,
+            "-extensions",
+            "v3_req",
+        ]),
+    )
+    return cert_path, key_path
+end
diff --git a/test/flight/support/types.jl b/test/flight/support/types.jl
new file mode 100644
index 00000000..5cbca294
--- /dev/null
+++ b/test/flight/support/types.jl
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Handle to a running pyarrow Flight test-server subprocess.
+struct PyArrowFlightServer
+    process::Base.Process  # the python server process
+    stdout::Pipe           # server stdout; its first line carries the port
+    stderr::Pipe           # captured for diagnostics on startup failure
+    port::Int              # TCP port the server reported it is listening on
+end
diff --git a/test/flight/tls_interop.jl b/test/flight/tls_interop.jl
new file mode 100644
index 00000000..7cdc3a77
--- /dev/null
+++ b/test/flight/tls_interop.jl
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+@testset "Flight TLS interop" begin
+    mktempdir() do dir
+        tls_material = FlightTestSupport.generate_test_tls_certificate(dir)
+        if isnothing(tls_material)
+            # openssl is unavailable: record a trivial pass and skip.
+            @test true
+            return
+        end
+        cert_path, key_path = tls_material
+        server = FlightTestSupport.start_tls_flight_server(cert_path, key_path)
+        if isnothing(server)
+            # pyarrow is unavailable: record a trivial pass and skip.
+            @test true
+        else
+            protocol = Arrow.Flight.Protocol
+            descriptor_type = protocol.var"FlightDescriptor.DescriptorType"
+            descriptor = protocol.FlightDescriptor(
+                descriptor_type.PATH,
+                UInt8[],
+                ["interop", "tls", "download"],
+            )
+
+            try
+                # Verified TLS: trust the freshly-generated root cert.
+                FlightTestSupport.with_test_grpc_handle() do grpc
+                    client = Arrow.Flight.Client(
+                        "grpc+tls://localhost:$(server.port)";
+                        grpc=grpc,
+                        tls_root_certs=cert_path,
+                    )
+                    info = Arrow.Flight.getflightinfo(client, descriptor)
+                    @test info.total_records == 3
+                    @test length(info.endpoint) == 1
+
+                    schema = Arrow.Flight.getschema(client, descriptor)
+                    @test Arrow.Flight.schemaipc(schema) == Arrow.Flight.schemaipc(info)
+
+                    req, channel = Arrow.Flight.doget(client, info.endpoint[1].ticket)
+                    messages = collect(channel)
+                    gRPCClient.grpc_async_await(req)
+
+                    table = Arrow.Flight.table(messages; schema=info)
+                    @test table.id == [31, 32, 33]
+                    @test table.name == ["thirty-one", "thirty-two", "thirty-three"]
+                end
+
+                # Unverified TLS: same server, certificate verification disabled.
+                FlightTestSupport.with_test_grpc_handle() do grpc
+                    insecure_client = Arrow.Flight.Client(
+                        "grpc+tls://localhost:$(server.port)";
+                        grpc=grpc,
+                        disable_server_verification=true,
+                    )
+                    info = Arrow.Flight.getflightinfo(insecure_client, descriptor)
+                    @test info.total_records == 3
+                end
+            finally
+                FlightTestSupport.stop_pyarrow_flight_server(server)
+            end
+        end
+    end
+end
diff --git a/test/flight_grpcserver.jl b/test/flight_grpcserver.jl
new file mode 100644
index 00000000..7fd8c4c3
--- /dev/null
+++ b/test/flight_grpcserver.jl
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+using Pkg
+
+const TEST_ROOT = @__DIR__
+const ARROW_ROOT = normpath(joinpath(TEST_ROOT, ".."))
+const ARROWTYPES_ROOT = joinpath(ARROW_ROOT, "src", "ArrowTypes")
+
+# Return the toplevel of the git work tree containing `path`, or `nothing`
+# when `path` is not inside one.
+function maybe_git_root(path::AbstractString)
+    try
+        return readchomp(pipeline(`git -C $path rev-parse --show-toplevel`; stderr=devnull))
+    catch
+        return nothing
+    end
+end
+
+# Collect the distinct git toplevels discovered while walking from `path` up
+# to the filesystem root (innermost first).
+function flight_grpcserver_roots(path::AbstractString)
+    roots = String[]
+    current = abspath(path)
+    while true
+        root = maybe_git_root(current)
+        if !isnothing(root) && root ∉ roots
+            push!(roots, root)
+        end
+        parent = dirname(current)
+        parent == current && break
+        current = parent
+    end
+    return roots
+end
+
+# Find the vendored gRPCServer.jl checkout, honoring the
+# ARROW_FLIGHT_GRPCSERVER_PATH environment override before searching
+# `.cache/vendor` under each enclosing git root.
+function locate_grpcserver()
+    if haskey(ENV, "ARROW_FLIGHT_GRPCSERVER_PATH")
+        candidate = abspath(ENV["ARROW_FLIGHT_GRPCSERVER_PATH"])
+        isdir(candidate) || error("ARROW_FLIGHT_GRPCSERVER_PATH does not exist: $candidate")
+        return candidate
+    end
+    for root in flight_grpcserver_roots(TEST_ROOT)
+        candidate = joinpath(root, ".cache", "vendor", "gRPCServer.jl")
+        isdir(candidate) && return candidate
+    end
+    error(
+        "Could not locate vendored gRPCServer.jl. " *
+        "Set ARROW_FLIGHT_GRPCSERVER_PATH to an explicit checkout path.",
+    )
+end
+
+# Build an isolated temporary environment with Arrow, ArrowTypes, and the
+# vendored gRPCServer.jl dev'ed in, then run the Flight test suite.
+const TEMP_ENV = mktempdir()
+cp(joinpath(TEST_ROOT, "Project.toml"), joinpath(TEMP_ENV, "Project.toml"))
+
+Pkg.activate(TEMP_ENV)
+Pkg.develop(PackageSpec(path=ARROW_ROOT))
+Pkg.develop(PackageSpec(path=ARROWTYPES_ROOT))
+Pkg.develop(PackageSpec(path=locate_grpcserver()))
+Pkg.instantiate()
+
+using Test
+using Arrow
+
+include(joinpath(TEST_ROOT, "flight.jl"))
diff --git a/test/flight_handshake_server.py b/test/flight_handshake_server.py
new file mode 100644
index 00000000..ab4f855f
--- /dev/null
+++ b/test/flight_handshake_server.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#!/usr/bin/env python3
+
+import signal
+
+import pyarrow.flight as fl
+
+
+class TokenAuthHandler(fl.ServerAuthHandler):
+    # Fixed username/password handshake that hands out a constant bearer token.
+    def authenticate(self, outgoing, incoming):
+        username = incoming.read()
+        password = incoming.read()
+        if username == b"test" and password == b"p4ssw0rd":
+            outgoing.write(b"secret:test")
+            return
+        raise fl.FlightUnauthenticatedError("invalid username/password")
+
+    def is_valid(self, token):
+        # Return the peer identity for a valid token; raise otherwise.
+        if token != b"secret:test":
+            raise fl.FlightUnauthenticatedError("invalid token")
+        return b"test"
+
+
+class HandshakeFlightServer(fl.FlightServerBase):
+    def __init__(self):
+        super().__init__(
+            location="grpc://127.0.0.1:0",
+            auth_handler=TokenAuthHandler(),
+        )
+
+    def list_actions(self, context):
+        del context
+        return [fl.ActionType("authenticated", "Requires a valid auth token")]
+
+
+def main():
+    server = HandshakeFlightServer()
+
+    def shutdown_handler(signum, frame):
+        del signum, frame
+        server.shutdown()
+
+    signal.signal(signal.SIGTERM, shutdown_handler)
+    signal.signal(signal.SIGINT, shutdown_handler)
+
+    # First stdout line is the bound port; the Julia harness waits for it.
+    print(server.port, flush=True)
+    server.serve()
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/test/flight_headers_server.py b/test/flight_headers_server.py
new file mode 100644
index 00000000..39b86ccd
--- /dev/null
+++ b/test/flight_headers_server.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#!/usr/bin/env python3
+
+import signal
+
+import pyarrow.flight as fl
+
+
+# Case-insensitive lookup of a required gRPC header; raises when absent.
+def case_insensitive_header_lookup(headers, lookup_key):
+    lookup_key = lookup_key.lower()
+    for key, value in headers.items():
+        if key.lower() == lookup_key:
+            return value
+    raise fl.FlightUnauthenticatedError(f"missing required header: {lookup_key}")
+
+
+class HeaderEchoServerMiddlewareFactory(fl.ServerMiddlewareFactory):
+    def start_call(self, info, headers):
+        del info
+        authorization = case_insensitive_header_lookup(headers, "authorization")
+        # Header values arrive as a list; keep only the first entry.
+        return HeaderEchoServerMiddleware(authorization[0])
+
+
+class HeaderEchoServerMiddleware(fl.ServerMiddleware):
+    # Carries the captured Authorization header value for the call.
+    def __init__(self, authorization):
+        self.authorization = authorization
+
+
+class HeaderEchoFlightServer(fl.FlightServerBase):
+    def __init__(self):
+        super().__init__(
+            location="grpc://127.0.0.1:0",
+            middleware={"auth": HeaderEchoServerMiddlewareFactory()},
+        )
+
+    def list_actions(self, context):
+        del context
+        return [("echo-authorization", "Return the Authorization header")]
+
+    def do_action(self, context, action):
+        if action.type != "echo-authorization":
+            raise KeyError(f"unsupported action: {action.type}")
+        middleware = context.get_middleware("auth")
+        if middleware is None:
+            raise fl.FlightUnauthenticatedError("missing auth middleware")
+        return [middleware.authorization.encode("utf-8")]
+
+
+def main():
+    server = HeaderEchoFlightServer()
+
+    def shutdown_handler(signum, frame):
+        del signum, frame
+        server.shutdown()
+
+    signal.signal(signal.SIGTERM, shutdown_handler)
+    signal.signal(signal.SIGINT, shutdown_handler)
+
+    # First stdout line is the bound port; the Julia harness waits for it.
+    print(server.port, flush=True)
+    server.serve()
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/test/flight_poll_server.py b/test/flight_poll_server.py
new file mode 100644
index 00000000..17d03923
--- /dev/null
+++ b/test/flight_poll_server.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#!/usr/bin/env python3
+
+import pathlib
+import signal
+import sys
+import tempfile
+from concurrent import futures
+
+import grpc
+import pyarrow as pa
+import grpc_tools
+from grpc_tools import protoc
+
+
+ROOT = pathlib.Path(__file__).resolve().parent.parent / "src" / "flight" / "proto"
+PROTO = ROOT / "Flight.proto"
+GRPC_TOOLS_PROTO = pathlib.Path(grpc_tools.__file__).resolve().parent / "_proto"
+
+
+# Compile Flight.proto with grpcio-tools into a temp dir, put that dir on
+# sys.path, and import the generated message/stub modules.
+def load_proto_modules():
+    out = pathlib.Path(tempfile.mkdtemp(prefix="flight_poll_proto_"))
+    result = protoc.main(
+        [
+            "grpc_tools.protoc",
+            f"-I{ROOT}",
+            f"-I{GRPC_TOOLS_PROTO}",
+            f"--python_out={out}",
+            f"--grpc_python_out={out}",
+            str(PROTO),
+        ]
+    )
+    if result != 0:
+        raise RuntimeError(f"protoc failed with exit code {result}")
+    sys.path.insert(0, str(out))
+    import Flight_pb2
+    import Flight_pb2_grpc
+
+    return Flight_pb2, Flight_pb2_grpc
+
+
+def descriptor_key(descriptor):
+    return tuple(descriptor.path)
+
+
+def main():
+    pb2, pb2_grpc = load_proto_modules()
+
+    # Raw-grpc servicer (pyarrow does not expose PollFlightInfo natively).
+    class PollFlightInfoServicer(pb2_grpc.FlightServiceServicer):
+        def __init__(self, port):
+            self.pb2 = pb2
+            self.port = port
+            self.schema_bytes = bytes(pa.schema([("id", pa.int64())]).serialize())
+
+        def _descriptor(self, path):
+            return self.pb2.FlightDescriptor(
+                type=self.pb2.FlightDescriptor.PATH,
+                path=list(path),
+            )
+
+        def _flight_info(self, path):
+            endpoint = self.pb2.FlightEndpoint(
+                ticket=self.pb2.Ticket(ticket=b"poll-ticket"),
+                location=[self.pb2.Location(uri=f"grpc://127.0.0.1:{self.port}")],
+            )
+            return self.pb2.FlightInfo(
+                schema=self.schema_bytes,
+                flight_descriptor=self._descriptor(path),
+                endpoint=[endpoint],
+                total_records=1,
+                total_bytes=8,
+                ordered=True,
+            )
+
+        def PollFlightInfo(self, request, context):
+            del context
+            key = descriptor_key(request)
+            # First poll: report 50% progress plus a retry descriptor for the
+            # client to poll next.
+            if key == ("interop", "poll"):
+                return self.pb2.PollInfo(
+                    info=self._flight_info(key),
+                    flight_descriptor=self._descriptor(("interop", "poll", "retry")),
+                    progress=0.5,
+                )
+            # Retry poll: complete (no further descriptor is attached).
+            if key == ("interop", "poll", "retry"):
+                return self.pb2.PollInfo(
+                    info=self._flight_info(("interop", "poll")),
+                    progress=1.0,
+                )
+            raise KeyError(f"unsupported poll descriptor: {key}")
+
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
+    port = server.add_insecure_port("127.0.0.1:0")
+    pb2_grpc.add_FlightServiceServicer_to_server(PollFlightInfoServicer(port), server)
+
+    def shutdown_handler(signum, frame):
+        del signum, frame
+        server.stop(grace=None)
+
+    signal.signal(signal.SIGTERM, shutdown_handler)
+    signal.signal(signal.SIGINT, shutdown_handler)
+
+    server.start()
+    # First stdout line is the bound port; the Julia harness waits for it.
+    print(port, flush=True)
+    server.wait_for_termination()
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/test/flight_pyarrow_server.py b/test/flight_pyarrow_server.py
new file mode 100644
index 00000000..386b3bd0
--- /dev/null
+++ b/test/flight_pyarrow_server.py
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#!/usr/bin/env python3
+
+import json
+import signal
+import sys
+
+import pyarrow as pa
+import pyarrow.flight as fl
+
+
+# NOTE(review): `sys` appears unused in this script — confirm before removing.
+def normalize_component(value):
+    # Descriptor path components may arrive as bytes or str; normalize to str.
+    if isinstance(value, bytes):
+        return value.decode("utf-8")
+    return str(value)
+
+
+def descriptor_key(descriptor):
+    if descriptor.descriptor_type != fl.DescriptorType.PATH:
+        raise KeyError("only PATH descriptors are supported")
+    return tuple(normalize_component(part) for part in descriptor.path)
+
+
+def ticket_key(ticket):
+    # Tickets carry the dataset key as a JSON-encoded list (see key_ticket).
+    return tuple(normalize_component(part) for part in json.loads(ticket.ticket.decode("utf-8")))
+
+
+def key_ticket(key):
+    return fl.Ticket(json.dumps(list(key)).encode("utf-8"))
+
+
+class InteropFlightServer(fl.FlightServerBase):
+    # In-memory Flight server exercised by the Julia interop tests.
+    def __init__(self):
+        super().__init__(location="grpc://127.0.0.1:0")
+        self._datasets = {
+            ("interop", "download"): pa.table(
+                {
+                    "id": pa.array([1, 2, 3], type=pa.int64()),
+                    "name": pa.array(["one", "two", "three"]),
+                }
+            )
+        }
+
+    def _descriptor(self, key):
+        return fl.FlightDescriptor.for_path(*key)
+
+    def _flight_info(self, key):
+        table = self._datasets[key]
+        endpoint = fl.FlightEndpoint(
+            key_ticket(key),
+            [fl.Location.for_grpc_tcp("127.0.0.1", self.port)],
+        )
+        return fl.FlightInfo(
+            table.schema,
+            self._descriptor(key),
+            [endpoint],
+            total_records=table.num_rows,
+            total_bytes=table.nbytes,
+        )
+
+    def list_flights(self, context, criteria):
+        del context, criteria
+        for key in sorted(self._datasets):
+            yield self._flight_info(key)
+
+    def get_flight_info(self, context, descriptor):
+        del context
+        return self._flight_info(descriptor_key(descriptor))
+
+    def get_schema(self, context, descriptor):
+        del context
+        return fl.SchemaResult(self._datasets[descriptor_key(descriptor)].schema)
+
+    def do_get(self, context, ticket):
+        del context
+        table = self._datasets[ticket_key(ticket)]
+        # Chunk the data so multi-batch streaming is exercised.
+        return fl.GeneratorStream(table.schema, iter(table.to_batches(max_chunksize=2)))
+
+    def do_put(self, context, descriptor, reader, writer):
+        del context
+        self._datasets[descriptor_key(descriptor)] = reader.read_all()
+        writer.write(b"stored")
+
+    def do_exchange(self, context, descriptor, reader, writer):
+        del context
+        key = descriptor_key(descriptor)
+        if key != ("interop", "exchange"):
+            raise KeyError(f"unsupported exchange descriptor: {key}")
+
+        # Echo each incoming batch back tagged with "exchange:<index>" metadata.
+        writer.begin(reader.schema)
+        batch_index = 0
+        while True:
+            try:
+                chunk = reader.read_chunk()
+            except StopIteration:
+                break
+            if chunk.data is None:
+                continue
+            metadata = pa.py_buffer(f"exchange:{batch_index}".encode("utf-8"))
+            writer.write_with_metadata(chunk.data, metadata)
+            batch_index += 1
+
+    def list_actions(self, context):
+        del context
+        return [("ping", "Return a fixed pong payload")]
+
+    def do_action(self, context, action):
+        del context
+        if action.type != "ping":
+            raise KeyError(f"unsupported action: {action.type}")
+        return [b"pong"]
+
+
+def main():
+    server = InteropFlightServer()
+
+    def shutdown_handler(signum, frame):
+        del signum, frame
+        server.shutdown()
+
+    signal.signal(signal.SIGTERM, shutdown_handler)
+    signal.signal(signal.SIGINT, shutdown_handler)
+
+    # First stdout line is the bound port; the Julia harness waits for it.
+    print(server.port, flush=True)
+    server.serve()
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/test/flight_tls_server.py b/test/flight_tls_server.py
new file mode 100644
index 00000000..c86de2b6
--- /dev/null
+++ b/test/flight_tls_server.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#!/usr/bin/env python3
+
+import json
+import os
+import signal
+
+import pyarrow as pa
+import pyarrow.flight as fl
+
+
+def normalize_component(value):
+    # Descriptor path components may arrive as bytes or str; normalize to str.
+    if isinstance(value, bytes):
+        return value.decode("utf-8")
+    return str(value)
+
+
+def descriptor_key(descriptor):
+    if descriptor.descriptor_type != fl.DescriptorType.PATH:
+        raise KeyError("only PATH descriptors are supported")
+    return tuple(normalize_component(part) for part in descriptor.path)
+
+
+def key_ticket(key):
+    return fl.Ticket(json.dumps(list(key)).encode("utf-8"))
+
+
+class TLSInteropFlightServer(fl.FlightServerBase):
+    # TLS variant of the interop server; cert/key paths are provided by the
+    # Julia test harness via environment variables.
+    def __init__(self, cert_path, key_path):
+        cert = open(cert_path, "rb").read()
+        key = open(key_path, "rb").read()
+        super().__init__(
+            location="grpc+tls://127.0.0.1:0",
+            tls_certificates=[fl.CertKeyPair(cert=cert, key=key)],
+        )
+        self._datasets = {
+            ("interop", "tls", "download"): pa.table(
+                {
+                    "id": pa.array([31, 32, 33], type=pa.int64()),
+                    "name": pa.array(["thirty-one", "thirty-two", "thirty-three"]),
+                }
+            )
+        }
+
+    def _descriptor(self, key):
+        return fl.FlightDescriptor.for_path(*key)
+
+    def _flight_info(self, key):
+        table = self._datasets[key]
+        endpoint = fl.FlightEndpoint(
+            key_ticket(key),
+            [fl.Location.for_grpc_tls("localhost", self.port)],
+        )
+        return fl.FlightInfo(
+            table.schema,
+            self._descriptor(key),
+            [endpoint],
+            total_records=table.num_rows,
+            total_bytes=table.nbytes,
+        )
+
+    def get_flight_info(self, context, descriptor):
+        del context
+        return self._flight_info(descriptor_key(descriptor))
+
+    def get_schema(self, context, descriptor):
+        del context
+        return fl.SchemaResult(self._datasets[descriptor_key(descriptor)].schema)
+
+    def do_get(self, context, ticket):
+        del context
+        key = tuple(normalize_component(part) for part in json.loads(ticket.ticket.decode("utf-8")))
+        table = self._datasets[key]
+        return fl.GeneratorStream(table.schema, iter(table.to_batches(max_chunksize=2)))
+
+
+def main():
+    cert_path = os.environ["ARROW_FLIGHT_TLS_CERT"]
+    key_path = os.environ["ARROW_FLIGHT_TLS_KEY"]
+    server = TLSInteropFlightServer(cert_path, key_path)
+
+    def shutdown_handler(signum, frame):
+        del signum, frame
+        server.shutdown()
+
+    signal.signal(signal.SIGTERM, shutdown_handler)
+    signal.signal(signal.SIGINT, shutdown_handler)
+
+    # First stdout line is the bound port; the Julia harness waits for it.
+    print(server.port, flush=True)
+    server.serve()
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/test/runtests.jl b/test/runtests.jl
index 315d1b60..7b66ec93 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -39,6 +39,7 @@ include(joinpath(@__DIR__, "testtables.jl"))
 include(joinpath(@__DIR__, "testappend.jl"))
 include(joinpath(@__DIR__, "integrationtest.jl"))
 include(joinpath(@__DIR__, "dates.jl"))
+include(joinpath(@__DIR__, "flight.jl"))
 
 struct CustomStruct
     x::Int
@@ -263,6 +264,93 @@
     @test all(isequal.(values(t), values(tt)))
 end
 
+    @testset "single-partition tobuffer byte equivalence" begin
+        t = (col=OffsetArray(["a", "bc", "def"], 0:2),)
+        io = IOBuffer()
+        Arrow.write(io, t)
+        seekstart(io)
+        @test read(Arrow.tobuffer(t)) == read(io)
+
+        tm = (col=OffsetArray(Union{Missing,String}["a", missing, "def"], 0:2),)
+        io = IOBuffer()
+        Arrow.write(io, tm)
+        seekstart(io)
+        @test read(Arrow.tobuffer(tm)) == read(io)
+
+        bt =
+            (col=OffsetArray([codeunits("a"), codeunits("bc"), codeunits("def")], 0:2),)
+        io = IOBuffer()
+        Arrow.write(io, bt)
+        seekstart(io)
+        @test read(Arrow.tobuffer(bt)) == read(io)
+
+        btm = (
+            col=OffsetArray(
+                Union{Missing,Base.CodeUnits{UInt8,String}}[
+                    codeunits("a"),
+                    missing,
+
codeunits("def"), + ], + 0:2, + ), + ) + io = IOBuffer() + Arrow.write(io, btm) + seekstart(io) + @test read(Arrow.tobuffer(btm)) == read(io) + + mapt = ( + col=OffsetArray([Dict("a" => 1, "b" => 2), Dict("a" => 3, "b" => 4)], 0:1), + ) + io = IOBuffer() + Arrow.write(io, mapt) + seekstart(io) + @test read(Arrow.tobuffer(mapt)) == read(io) + + pooled = (col=PooledArray(["a", "b", "a", "c"]),) + io = IOBuffer() + Arrow.write(io, pooled; dictencode=true) + seekstart(io) + @test read(Arrow.tobuffer(pooled; dictencode=true)) == read(io) + + meta = Dict("key1" => "value1") + colmeta = Dict(:col => Dict("colkey1" => "colvalue1")) + io = IOBuffer() + Arrow.write(io, t; metadata=meta, colmetadata=colmeta) + seekstart(io) + @test read(Arrow.tobuffer(t; metadata=meta, colmetadata=colmeta)) == read(io) + + parts = Tables.partitioner([t, t]) + io = IOBuffer() + Arrow.write(io, parts) + seekstart(io) + @test read(Arrow.tobuffer(parts)) == read(io) + + string_missing_parts = Tables.partitioner([tm, tm]) + io = IOBuffer() + Arrow.write(io, string_missing_parts) + seekstart(io) + @test read(Arrow.tobuffer(string_missing_parts)) == read(io) + + binary_parts = Tables.partitioner([bt, bt]) + io = IOBuffer() + Arrow.write(io, binary_parts) + seekstart(io) + @test read(Arrow.tobuffer(binary_parts)) == read(io) + + binary_missing_parts = Tables.partitioner([btm, btm]) + io = IOBuffer() + Arrow.write(io, binary_missing_parts) + seekstart(io) + @test read(Arrow.tobuffer(binary_missing_parts)) == read(io) + + map_parts = Tables.partitioner([mapt, mapt]) + io = IOBuffer() + Arrow.write(io, map_parts) + seekstart(io) + @test read(Arrow.tobuffer(map_parts)) == read(io) + end + @testset "# 53" begin s = "a"^100 t = (a=[SubString(s, 1:10), SubString(s, 11:20)],) @@ -294,6 +382,38 @@ end @test isequal(tt.a, ['a', missing]) end + @testset "# offset bool write paths" begin + t = ( + a=OffsetArray(Bool[true, false, true], -1:1), + b=OffsetArray(Union{Missing,Bool}[true, missing, false], -1:1), + 
c=OffsetArray(Any[true, false, true], -1:1), + d=OffsetArray(Any[true, missing, false], -1:1), + ) + tt = Arrow.Table(Arrow.tobuffer(t)) + @test eltype(tt.c) == Bool + @test eltype(tt.d) == Union{Missing,Bool} + @test tt.a == Bool[true, false, true] + @test isequal(tt.b, Union{Missing,Bool}[true, missing, false]) + @test tt.c == Bool[true, false, true] + @test isequal(tt.d, Union{Missing,Bool}[true, missing, false]) + end + + @testset "# offset primitive write paths" begin + t = ( + a=OffsetArray(Int64[1, 2, 3], -1:1), + b=OffsetArray(Union{Missing,Int64}[1, missing, 3], -1:1), + c=OffsetArray(Any[1, 2, 3], -1:1), + d=OffsetArray(Any[1, missing, 3], -1:1), + ) + tt = Arrow.Table(Arrow.tobuffer(t)) + @test eltype(tt.c) == Int64 + @test eltype(tt.d) == Union{Missing,Int64} + @test tt.a == Int64[1, 2, 3] + @test isequal(tt.b, Union{Missing,Int64}[1, missing, 3]) + @test tt.c == Int64[1, 2, 3] + @test isequal(tt.d, Union{Missing,Int64}[1, missing, 3]) + end + @testset "# automatic custom struct serialization/deserialization" begin t = (col1=[CustomStruct(1, 2.3, "hey"), CustomStruct(4, 5.6, "there")],) @@ -328,6 +448,47 @@ end @test copy(tt.a) isa Vector{Nanosecond} @test copy(tt.b) isa Vector{UUID} @test copy(tt.c) isa Vector{Union{Missing,Nanosecond}} + + toffset = ( + b=OffsetArray( + [ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + -1:0, + ), + bm=OffsetArray( + Union{Missing,UUID}[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + missing, + ], + -1:0, + ), + ba=OffsetArray( + Any[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + -1:0, + ), + bam=OffsetArray( + Any[UUID("550e8400-e29b-41d4-a716-446655440000"), missing], + -1:0, + ), + ) + ttoffset = Arrow.Table(Arrow.tobuffer(toffset)) + @test collect(toffset.b) == ttoffset.b + @test isequal(collect(toffset.bm), ttoffset.bm) + @test eltype(ttoffset.ba) == NTuple{16,UInt8} + @test eltype(ttoffset.bam) == 
Union{Missing,NTuple{16,UInt8}} + @test map(Arrow.ArrowTypes.toarrow, collect(toffset.ba)) == copy(ttoffset.ba) + @test isequal( + map( + x -> ismissing(x) ? missing : Arrow.ArrowTypes.toarrow(x), + collect(toffset.bam), + ), + copy(ttoffset.bam), + ) end @testset "# copy on DictEncoding w/ missing values" begin @@ -624,6 +785,19 @@ end t = (col1=[["boop", "she"], ["boop", "she"], ["boo"]],) tbl = Arrow.Table(Arrow.tobuffer(t)) @test eltype(tbl.col1) <: AbstractVector{String} + + toffset = ( + col1=OffsetArray([Int64[1, 2], Int64[3, 4], Int64[]], -1:1), + col2=OffsetArray( + Union{Missing,Vector{Int64}}[Int64[1], missing, Int64[2, 3]], + -1:1, + ), + ) + tt = Arrow.Table(Arrow.tobuffer(toffset)) + @test eltype(tt.col1) <: AbstractVector{Int64} + @test Base.nonmissingtype(eltype(tt.col2)) <: AbstractVector{Int64} + @test collect(toffset.col1) == tt.col1 + @test isequal(collect(toffset.col2), tt.col2) end @testset "# 200 VersionNumber" begin @@ -632,6 +806,27 @@ end @test eltype(tbl.col1) == VersionNumber end + @testset "offset struct string write paths" begin + rows = OffsetArray( + Union{Missing,NamedTuple{(:s,),Tuple{String}}}[ + (s="a",), + missing, + (s="bc",), + ], + -1:1, + ) + tt = Arrow.Table(Arrow.tobuffer((rows=rows,))) + @test Base.nonmissingtype(eltype(tt.rows)) == NamedTuple{(:s,),Tuple{String}} + @test isequal(collect(rows), tt.rows) + end + + @testset "Complex" begin + t = (col1=Union{ComplexF64,Missing}[1 + 2im, missing, 3 + 4im],) + tbl = Arrow.Table(Arrow.tobuffer(t)) + @test eltype(tbl.col1) == Union{ComplexF64,Missing} + @test isequal(collect(tbl.col1), t.col1) + end + @testset "`show`" begin str = nothing table = (; a=1:5, b=fill(1.0, 5)) @@ -852,6 +1047,95 @@ end @test_throws ArgumentError( "`keytype(d)` must be concrete to serialize map-like `d`, but `keytype(d) == Real`", ) Arrow.tobuffer(t) + + t = ( + x=OffsetArray([Dict("a" => 1, "b" => 2), Dict("c" => 3)], -1:0), + xm=OffsetArray( + Union{Missing,Dict{String,Int}}[Dict("a" => 1), 
missing], + -1:0, + ), + xe=OffsetArray( + [Dict("a" => 1, "b" => 2, "c" => 3), Dict{String,Int}()], + -1:0, + ), + xem=OffsetArray( + Union{Missing,Dict{String,Int}}[Dict{String,Int}(), missing], + -1:0, + ), + xa=OffsetArray(Any[Dict("a" => 1, "b" => 2), Dict("c" => 3)], -1:0), + xam=OffsetArray(Any[Dict("a" => 1), missing], -1:0), + xame=OffsetArray(Any[Dict{String,Int}(), missing], -1:0), + ) + tt = Arrow.Table(Arrow.tobuffer(t)) + @test eltype(tt.x) == Dict{String,Int64} + @test eltype(tt.xm) == Union{Missing,Dict{String,Int64}} + @test eltype(tt.xe) == Dict{String,Int64} + @test eltype(tt.xem) == Union{Missing,Dict{String,Int64}} + @test eltype(tt.xa) == Dict{String,Int64} + @test eltype(tt.xam) == Union{Missing,Dict{String,Int64}} + @test eltype(tt.xame) == Union{Missing,Dict{String,Int64}} + @test copy(tt.x) isa Vector{Dict{String,Int64}} + @test copy(tt.xm) isa Vector{Union{Missing,Dict{String,Int64}}} + @test copy(tt.xem) isa Vector{Union{Missing,Dict{String,Int64}}} + @test copy(tt.xa) isa Vector{Dict{String,Int64}} + @test copy(tt.xam) isa Vector{Union{Missing,Dict{String,Int64}}} + @test copy(tt.xame) isa Vector{Union{Missing,Dict{String,Int64}}} + @test collect(t.x) == tt.x + @test isequal(collect(t.xm), tt.xm) + @test collect(t.xe) == tt.xe + @test isequal(collect(t.xem), tt.xem) + @test collect(t.xa) == tt.xa + @test isequal(collect(t.xam), tt.xam) + @test isequal(collect(t.xame), tt.xame) + + mapio = IOBuffer() + Arrow.write(mapio, (x=t.xm,)) + seekstart(mapio) + @test read(Arrow.tobuffer((x=t.xm,))) == read(mapio) + + mapbuf = Arrow.tobuffer((x=t.xm,)) + seekend(mapbuf) + mappos = position(mapbuf) + Arrow.append(mapbuf, Arrow.Table(Arrow.tobuffer((x=t.xm,)))) + seekstart(mapbuf) + mapbuf1 = read(mapbuf, mappos) + mapbuf2 = read(mapbuf) + mapt1 = Arrow.Table(mapbuf1) + mapt2 = Arrow.Table(mapbuf2) + @test isequal(collect(mapt1.x), collect(mapt2.x)) + + emptymapbuf = Arrow.tobuffer((x=t.xe,)) + seekend(emptymapbuf) + emptymappos = 
position(emptymapbuf) + Arrow.append(emptymapbuf, Arrow.Table(Arrow.tobuffer((x=t.xe,)))) + seekstart(emptymapbuf) + emptymapbuf1 = read(emptymapbuf, emptymappos) + emptymapbuf2 = read(emptymapbuf) + emptymapt1 = Arrow.Table(emptymapbuf1) + emptymapt2 = Arrow.Table(emptymapbuf2) + @test isequal(collect(emptymapt1.x), collect(emptymapt2.x)) + + anymapbuf = Arrow.tobuffer((x=t.xam,)) + seekend(anymapbuf) + anymappos = position(anymapbuf) + Arrow.append(anymapbuf, Arrow.Table(Arrow.tobuffer((x=t.xam,)))) + seekstart(anymapbuf) + anymapbuf1 = read(anymapbuf, anymappos) + anymapbuf2 = read(anymapbuf) + anymapt1 = Arrow.Table(anymapbuf1) + anymapt2 = Arrow.Table(anymapbuf2) + @test isequal(collect(anymapt1.x), collect(anymapt2.x)) + + anyemptymapbuf = Arrow.tobuffer((x=t.xame,)) + seekend(anyemptymapbuf) + anyemptymappos = position(anyemptymapbuf) + Arrow.append(anyemptymapbuf, Arrow.Table(Arrow.tobuffer((x=t.xame,)))) + seekstart(anyemptymapbuf) + anyemptymapbuf1 = read(anyemptymapbuf, anyemptymappos) + anyemptymapbuf2 = read(anyemptymapbuf) + anyemptymapt1 = Arrow.Table(anyemptymapbuf1) + anyemptymapt2 = Arrow.Table(anyemptymapbuf2) + @test isequal(collect(anyemptymapt1.x), collect(anyemptymapt2.x)) end @testset "# 214" begin @@ -966,6 +1250,45 @@ end @test isequal(t1.bm, t2.bm) @test isequal(t1.c, t2.c) @test isequal(t1.cm, t2.cm) + + toffset = ( + b=OffsetArray([b"01", b"", b"3"], -1:1), + bm=OffsetArray( + Union{Missing,Base.CodeUnits{UInt8,String}}[b"01", b"3", missing], + -1:1, + ), + ba=OffsetArray(Any[b"01", b"", b"3"], -1:1), + bam=OffsetArray(Any[b"01", missing, b"3"], -1:1), + c=OffsetArray(["a", "b", "c"], -1:1), + cm=OffsetArray(Union{Missing,String}["a", "c", missing], -1:1), + ) + ttoffset = Arrow.Table(Arrow.tobuffer(toffset)) + @test eltype(ttoffset.b) <: Base.CodeUnits + @test Base.nonmissingtype(eltype(ttoffset.bm)) <: Base.CodeUnits + @test eltype(ttoffset.ba) <: Base.CodeUnits + @test Base.nonmissingtype(eltype(ttoffset.bam)) <: Base.CodeUnits + 
@test eltype(ttoffset.c) == String + @test eltype(ttoffset.cm) == Union{Missing,String} + @test collect(toffset.b) == ttoffset.b + @test isequal(collect(toffset.bm), ttoffset.bm) + @test collect(toffset.ba) == copy(ttoffset.ba) + @test isequal(collect(toffset.bam), copy(ttoffset.bam)) + @test collect(toffset.c) == ttoffset.c + @test isequal(collect(toffset.cm), ttoffset.cm) + + offsetbuf = Arrow.tobuffer(toffset) + seekend(offsetbuf) + offsetpos = position(offsetbuf) + Arrow.append(offsetbuf, ttoffset) + seekstart(offsetbuf) + offsetbuf1 = read(offsetbuf, offsetpos) + offsetbuf2 = read(offsetbuf) + offsett1 = Arrow.Table(offsetbuf1) + offsett2 = Arrow.Table(offsetbuf2) + @test collect(offsett1.b) == collect(offsett2.b) + @test isequal(collect(offsett1.bm), collect(offsett2.bm)) + @test collect(offsett1.c) == collect(offsett2.c) + @test isequal(collect(offsett1.cm), collect(offsett2.cm)) end @testset "# 435" begin From 64cc61b3c12c1fbd784f999d91546d0a502d2257 Mon Sep 17 00:00:00 2001 From: guangtao Date: Mon, 30 Mar 2026 01:41:39 -0700 Subject: [PATCH 02/16] Fix #590 view-backed variadic buffer inference Centralize the view inline/external layout boundary, infer missing external buffers from valid view elements, and ignore null slots during buffer-count recovery for Utf8View/BinaryView readers. 
--- README.md | 1 + src/arraytypes/views.jl | 25 +++++++++++++++---------- src/table.jl | 15 ++++++++++++++- test/runtests.jl | 24 ++++++++++++++++++++++++ 4 files changed, 54 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index daada49a..f0ef95b8 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ This implementation supports the 1.0 version of the specification, including sup * All nested data types * Dictionary encodings and messages * Extension types + * View-backed Utf8/Binary columns, including recovery from under-reported variadic buffer counts by inferring the required external buffers from valid view elements * Streaming, file, record batch, and replacement and isdelta dictionary messages It currently doesn't include support for: diff --git a/src/arraytypes/views.jl b/src/arraytypes/views.jl index 0a43f6fc..0d23a70d 100644 --- a/src/arraytypes/views.jl +++ b/src/arraytypes/views.jl @@ -21,6 +21,17 @@ struct ViewElement offset::Int32 end +const VIEW_ELEMENT_BYTES = sizeof(ViewElement) +const VIEW_LENGTH_BYTES = sizeof(Int32) +const VIEW_INLINE_BYTES = VIEW_ELEMENT_BYTES - VIEW_LENGTH_BYTES + +@inline _viewisinline(length::Integer) = length <= VIEW_INLINE_BYTES +@inline _viewinlinestart(i::Integer) = + ((i - 1) * VIEW_ELEMENT_BYTES) + VIEW_LENGTH_BYTES + 1 +@inline _viewinlineend(i::Integer, length::Integer) = _viewinlinestart(i) + length - 1 +@inline _viewinlineslice(inline::Vector{UInt8}, i::Integer, length::Integer) = + @view inline[_viewinlinestart(i):_viewinlineend(i, length)] + """ Arrow.View @@ -45,12 +56,8 @@ Base.size(l::View) = (l.ℓ,) if S <: Base.CodeUnits # BinaryView return !l.validity[i] ? missing : - v.length < 13 ? - Base.CodeUnits( - StringView( - @view l.inline[(((i - 1) * 16) + 5):(((i - 1) * 16) + 5 + v.length - 1)] - ), - ) : + _viewisinline(v.length) ? 
+ Base.CodeUnits(StringView(_viewinlineslice(l.inline, i, v.length))) : Base.CodeUnits( StringView( @view l.buffers[v.bufindex + 1][(v.offset + 1):(v.offset + v.length)] @@ -59,12 +66,10 @@ Base.size(l::View) = (l.ℓ,) else # Utf8View return !l.validity[i] ? missing : - v.length < 13 ? + _viewisinline(v.length) ? ArrowTypes.fromarrow( T, - StringView( - @view l.inline[(((i - 1) * 16) + 5):(((i - 1) * 16) + 5 + v.length - 1)] - ), + StringView(_viewinlineslice(l.inline, i, v.length)), ) : ArrowTypes.fromarrow( T, diff --git a/src/table.jl b/src/table.jl index de8bfc37..ad6adfda 100644 --- a/src/table.jl +++ b/src/table.jl @@ -732,6 +732,18 @@ const ListTypes = const LargeLists = Union{Meta.LargeUtf8,Meta.LargeBinary,Meta.LargeList,Meta.LargeListView} const ViewTypes = Union{Meta.Utf8View,Meta.BinaryView,Meta.ListView,Meta.LargeListView} +@inline function _viewbuffercount(validity, views, declared::Integer) + count = Int(declared) + for i in eachindex(views) + validity[i] || continue + v = @inbounds views[i] + if !_viewisinline(v.length) + count = max(count, Int(v.bufindex) + 1) + end + end + return count +end + function build(field::Meta.Field, batch, rb, de, nodeidx, bufferidx, varbufferidx, convert) d = field.dictionary if d !== nothing @@ -910,7 +922,8 @@ function build( inline = reinterpret(UInt8, views) # reuse the (possibly realigned) memory backing `views` bufferidx += 1 buffers = Vector{UInt8}[] - for i = 1:rb.variadicBufferCounts[varbufferidx] + nvariadic = _viewbuffercount(validity, views, rb.variadicBufferCounts[varbufferidx]) + for i = 1:nvariadic buffer = rb.buffers[bufferidx] _, A = reinterp(UInt8, batch, buffer, rb.compression) push!(buffers, A) diff --git a/test/runtests.jl b/test/runtests.jl index 7b66ec93..02292515 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -264,6 +264,30 @@ end @test all(isequal.(values(t), values(tt))) end + @testset "View buffer count inference" begin + inline_len = Int32(Arrow.VIEW_INLINE_BYTES) + views = 
Arrow.ViewElement[ + Arrow.ViewElement(inline_len, Int32(0), Int32(0), Int32(0)), + Arrow.ViewElement(inline_len + Int32(148), Int32(0), Int32(0), Int32(0)), + Arrow.ViewElement(inline_len + Int32(207), Int32(0), Int32(1), Int32(160)), + ] + validity = Arrow.ValidityBitmap(UInt8[], 1, length(views), 0) + @test Arrow._viewisinline(inline_len) + @test !Arrow._viewisinline(inline_len + Int32(1)) + @test Arrow._viewbuffercount(validity, views, Int32(0)) == 2 + @test Arrow._viewbuffercount(validity, views, Int32(1)) == 2 + @test Arrow._viewbuffercount(validity, views, Int32(3)) == 3 + + sparse_validity = Arrow.ValidityBitmap(UInt8[0x05], 1, 3, 1) + sparse_views = Arrow.ViewElement[ + Arrow.ViewElement(inline_len + Int32(64), Int32(0), Int32(0), Int32(0)), + Arrow.ViewElement(inline_len + Int32(64), Int32(0), Int32(99), Int32(0)), + Arrow.ViewElement(inline_len, Int32(0), Int32(0), Int32(0)), + ] + @test !sparse_validity[2] + @test Arrow._viewbuffercount(sparse_validity, sparse_views, Int32(0)) == 1 + end + @testset "single-partition tobuffer byte equivalence" begin t = (col=OffsetArray(["a", "bc", "def"], 0:2),) io = IOBuffer() From 5cba03308515a27189fc3ffe89803ce2334f305e Mon Sep 17 00:00:00 2001 From: guangtao Date: Mon, 30 Mar 2026 01:57:14 -0700 Subject: [PATCH 03/16] Fix #586 dictionary-encoded CategoricalArray missing refs Normalize dictionary codes against the refpool axis for missing-aware CategoricalArray inputs so Arrow roundtrips preserve values and copy/DataFrame copycols paths no longer fail. 
--- README.md | 1 + src/arraytypes/dictencoding.jl | 4 +++- test/runtests.jl | 7 +++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f0ef95b8..dd3a08d3 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ This implementation supports the 1.0 version of the specification, including sup * All primitive data types * All nested data types * Dictionary encodings and messages + * Dictionary-encoded `CategoricalArray` interop, including missing-value roundtrips through `Arrow.Table`, `copy`, and `DataFrame(...; copycols=true)` * Extension types * View-backed Utf8/Binary columns, including recovery from under-reported variadic buffer counts by inferring the required external buffers from valid view elements * Streaming, file, record batch, and replacement and isdelta dictionary messages diff --git a/src/arraytypes/dictencoding.jl b/src/arraytypes/dictencoding.jl index 3e3576c5..c8582e81 100644 --- a/src/arraytypes/dictencoding.jl +++ b/src/arraytypes/dictencoding.jl @@ -119,6 +119,8 @@ signedtype(::Type{UInt32}) = Int32 signedtype(::Type{UInt64}) = Int64 signedtype(::Type{T}) where {T<:Signed} = T +@inline _dictrefshift(pool) = firstindex(pool) + indtype(d::DictEncoded{T,S,A}) where {T,S,A} = S indtype(c::Compressed{Z,A}) where {Z,A<:DictEncoded} = indtype(c.data) @@ -232,7 +234,7 @@ function arrowvector( inds = copyto!(similar(Vector{signedtype(length(pool))}, length(refa)), refa) end # adjust to "offset" instead of index - inds .-= firstindex(refa) + inds .-= _dictrefshift(pool) data = arrowvector( pool, i, diff --git a/test/runtests.jl b/test/runtests.jl index 02292515..ea3408eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -538,6 +538,13 @@ end @test isa(first(av.indices), Signed) @test length(av) == 3 @test eltype(av) == String + + x = CategoricalArray(Union{Missing,String}["a", missing, "ccc"]) + tt = Arrow.Table(Arrow.tobuffer((x=x,); dictencode=true)) + @test isequal(collect(tt.x), collect(x)) + @test 
isequal(collect(copy(tt.x)), collect(x)) + df = DataFrame(tt; copycols=true) + @test isequal(collect(df.x), collect(x)) end @testset "# 120" begin From 83cc25ba03fbadfd772fddd93c22aa4f2cbae693 Mon Sep 17 00:00:00 2001 From: guangtao Date: Mon, 30 Mar 2026 23:58:55 -0700 Subject: [PATCH 04/16] Add Julia Enum extension roundtrips --- README.md | 1 + docs/src/manual.md | 20 +++++++++ src/ArrowTypes/src/ArrowTypes.jl | 76 ++++++++++++++++++++++++++++++++ src/ArrowTypes/test/tests.jl | 20 +++++++++ test/runtests.jl | 38 ++++++++++++++++ 5 files changed, 155 insertions(+) diff --git a/README.md b/README.md index dd3a08d3..4e7b559a 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,7 @@ This implementation supports the 1.0 version of the specification, including sup * Dictionary encodings and messages * Dictionary-encoded `CategoricalArray` interop, including missing-value roundtrips through `Arrow.Table`, `copy`, and `DataFrame(...; copycols=true)` * Extension types + * Base Julia `Enum` logical types via the `JuliaLang.Enum` extension label, with native Julia roundtrips back to the original enum type while `convert=false` and non-Julia consumers still see the primitive storage type * View-backed Utf8/Binary columns, including recovery from under-reported variadic buffer counts by inferring the required external buffers from valid view elements * Streaming, file, record batch, and replacement and isdelta dictionary messages diff --git a/docs/src/manual.md b/docs/src/manual.md index 264b8170..790eb826 100644 --- a/docs/src/manual.md +++ b/docs/src/manual.md @@ -111,6 +111,26 @@ Arrow.jl already uses this mechanism for several Base logical types, including their original Julia types instead of falling back to plain struct-shaped `NamedTuple`s. +Base Julia `@enum` types also work out of the box through the same extension +machinery. 
Arrow stores the enum as its primitive basetype plus a +`JuliaLang.Enum` extension label that records the qualified Julia type path and +label/value mapping. Native Julia readers reconstruct the enum type, while +`Arrow.Table(...; convert=false)` and non-Julia consumers continue to see the +primitive storage values. + +```julia +using Arrow + +@enum RankingStrategy lexical=1 semantic=2 hybrid=3 + +bytes = read(Arrow.tobuffer((strategy = [lexical, hybrid],))) +typed = Arrow.Table(IOBuffer(bytes)) +raw = Arrow.Table(IOBuffer(bytes); convert=false) + +eltype(typed.strategy) == RankingStrategy +eltype(raw.strategy) == Int32 +``` + ```julia using Arrow diff --git a/src/ArrowTypes/src/ArrowTypes.jl b/src/ArrowTypes/src/ArrowTypes.jl index 46c1c21c..4fe9c5a5 100644 --- a/src/ArrowTypes/src/ArrowTypes.jl +++ b/src/ArrowTypes/src/ArrowTypes.jl @@ -213,6 +213,81 @@ arrowname(::Type{Char}) = CHAR JuliaType(::Val{CHAR}) = Char fromarrow(::Type{Char}, x::UInt32) = Char(x) +ArrowType(::Type{T}) where {T<:Enum} = Base.Enums.basetype(T) +toarrow(x::T) where {T<:Enum} = convert(Base.Enums.basetype(T), Int(x)) +const ENUM = Symbol("JuliaLang.Enum") +arrowname(::Type{T}) where {T<:Enum} = ENUM + +function _qualifiedtypepath(::Type{T}) where {T} + module_path = join(string.(Base.fullname(parentmodule(T))), ".") + return string(module_path, ".", nameof(T)) +end + +function _enum_labels(::Type{T}) where {T<:Enum} + B = Base.Enums.basetype(T) + return join((string(instance, ":", convert(B, Int(instance))) for instance in instances(T)), ",") +end + +function arrowmetadata(::Type{T}) where {T<:Enum} + return string("type=", _qualifiedtypepath(T), ";labels=", _enum_labels(T)) +end + +function _parsemetadata(metadata::AbstractString) + parsed = Dict{String, String}() + isempty(metadata) && return parsed + for entry in split(metadata, ';') + isempty(entry) && continue + delimiter = findfirst(==('='), entry) + delimiter === nothing && continue + key = entry[1:prevind(entry, delimiter)] + value 
= entry[nextind(entry, delimiter):end] + parsed[key] = value + end + return parsed +end + +function _rootmodule(name::Symbol) + name === :Main && return Main + if isdefined(Main, name) + candidate = getfield(Main, name) + candidate isa Module && return candidate + end + try + return Base.root_module(Main, name) + catch + return nothing + end +end + +function _resolvequalifiedtype(path::AbstractString) + parts = split(path, '.') + length(parts) < 2 && return nothing + current = _rootmodule(Symbol(first(parts))) + current isa Module || return nothing + for part in parts[2:(end - 1)] + symbol = Symbol(part) + isdefined(current, symbol) || return nothing + current = getfield(current, symbol) + current isa Module || return nothing + end + type_symbol = Symbol(last(parts)) + isdefined(current, type_symbol) || return nothing + return getfield(current, type_symbol) +end + +function JuliaType(::Val{ENUM}, S, metadata::String) + parsed = _parsemetadata(metadata) + haskey(parsed, "type") || return nothing + T = _resolvequalifiedtype(parsed["type"]) + T isa DataType || return nothing + T <: Enum || return nothing + storage_type = Base.nonmissingtype(S) + Base.Enums.basetype(T) === storage_type || return nothing + return T +end + +fromarrow(::Type{T}, x::Integer) where {T<:Enum} = T(x) + "BoolKind data is stored with values packed down to individual bits; so instead of a traditional Bool being 1 byte/8 bits, 8 Bool values would be packed into a single byte" struct BoolKind <: ArrowKind end ArrowKind(::Type{Bool}) = BoolKind() @@ -367,6 +442,7 @@ function default end default(T) = zero(T) default(::Type{Symbol}) = Symbol() default(::Type{Char}) = '\0' +default(::Type{T}) where {T<:Enum} = first(instances(T)) default(::Type{<:AbstractString}) = "" default(::Type{Any}) = nothing default(::Type{Missing}) = missing diff --git a/src/ArrowTypes/test/tests.jl b/src/ArrowTypes/test/tests.jl index e5822f43..87b7cfa6 100644 --- a/src/ArrowTypes/test/tests.jl +++ 
b/src/ArrowTypes/test/tests.jl @@ -22,6 +22,15 @@ struct Person name::String end +module EnumTestModule +@enum RankingStrategy lexical=1 semantic=2 hybrid=3 +end + +const RankingStrategy = EnumTestModule.RankingStrategy +const lexical = EnumTestModule.lexical +const semantic = EnumTestModule.semantic +const hybrid = EnumTestModule.hybrid + @testset "ArrowTypes" begin @test ArrowTypes.ArrowKind(MyInt) == ArrowTypes.PrimitiveKind() @test ArrowTypes.ArrowKind(Person) == ArrowTypes.StructKind() @@ -67,6 +76,17 @@ end @test ArrowTypes.JuliaType(Val(ArrowTypes.CHAR)) == Char @test ArrowTypes.fromarrow(Char, UInt32('1')) == '1' + enum_metadata = ArrowTypes.arrowmetadata(RankingStrategy) + @test ArrowTypes.ArrowKind(RankingStrategy) == ArrowTypes.PrimitiveKind() + @test ArrowTypes.ArrowType(RankingStrategy) == Int32 + @test ArrowTypes.toarrow(hybrid) == Int32(3) + @test ArrowTypes.arrowname(RankingStrategy) == ArrowTypes.ENUM + @test occursin("type=Main.EnumTestModule.RankingStrategy", enum_metadata) + @test occursin("labels=lexical:1,semantic:2,hybrid:3", enum_metadata) + @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), Int32, enum_metadata) == RankingStrategy + @test ArrowTypes.fromarrow(RankingStrategy, Int32(2)) == semantic + @test ArrowTypes.default(RankingStrategy) == lexical + @test ArrowTypes.ArrowKind(Bool) == ArrowTypes.BoolKind() @test ArrowTypes.ListKind() == ArrowTypes.ListKind{false}() diff --git a/test/runtests.jl b/test/runtests.jl index ea3408eb..1707b9c7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,6 +51,15 @@ struct CustomStruct2{sym} x::Int end +module EnumRoundtripModule +@enum RankingStrategy lexical=1 semantic=2 hybrid=3 +end + +const RankingStrategy = EnumRoundtripModule.RankingStrategy +const lexical = EnumRoundtripModule.lexical +const semantic = EnumRoundtripModule.semantic +const hybrid = EnumRoundtripModule.hybrid + @testset ExtendedTestSet "Arrow" begin @testset "table roundtrips" begin for case in testtables @@ -450,6 +459,35 
@@ end @test all(isequal.(values(t), values(tt))) end + @testset "# Julia Enum extension logical type roundtrip" begin + t = ( + col1=[lexical, hybrid], + col2=Union{Missing, RankingStrategy}[missing, semantic], + ) + + bytes = read(Arrow.tobuffer(t)) + tt = Arrow.Table(IOBuffer(bytes)) + raw = Arrow.Table(IOBuffer(bytes); convert=false) + + @test length(tt) == length(t) + @test eltype(tt.col1) == RankingStrategy + @test eltype(tt.col2) == Union{Missing, RankingStrategy} + @test tt.col1 == [lexical, hybrid] + @test isequal( + tt.col2, + Union{Missing, RankingStrategy}[missing, semantic], + ) + @test eltype(raw.col1) == Int32 + @test eltype(raw.col2) == Union{Missing, Int32} + @test raw.col1 == Int32[1, 3] + @test isequal(raw.col2, Union{Missing, Int32}[missing, 2]) + @test Arrow.getmetadata(tt.col1)["ARROW:extension:name"] == "JuliaLang.Enum" + @test occursin( + "Main.EnumRoundtripModule.RankingStrategy", + Arrow.getmetadata(tt.col1)["ARROW:extension:metadata"], + ) + end + @testset "# 76" begin t = (col1=NamedTuple{(:a,),Tuple{Union{Int,String}}}[(a=1,), (a="x",)],) tt = Arrow.Table(Arrow.tobuffer(t)) From 3bc1d80aa324c49bc52e35aeb553b0d0527b7ae3 Mon Sep 17 00:00:00 2001 From: guangtao Date: Tue, 31 Mar 2026 13:19:03 -0700 Subject: [PATCH 05/16] Add canonical Arrow core type support --- README.md | 6 ++ docs/src/manual.md | 7 ++- src/ArrowTypes/src/ArrowTypes.jl | 4 +- src/ArrowTypes/test/tests.jl | 1 + src/eltypes.jl | 50 +++++++++++++++- src/table.jl | 20 +++++++ src/utils.jl | 5 +- test/run_end_encoded_small.arrow | Bin 0 -> 794 bytes test/runtests.jl | 94 +++++++++++++++++++++++++++++++ 9 files changed, 183 insertions(+), 4 deletions(-) create mode 100644 test/run_end_encoded_small.arrow diff --git a/README.md b/README.md index 4e7b559a..8fc21914 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,7 @@ This implementation supports the 1.0 version of the specification, including sup It currently doesn't include support for: * Tensors or sparse tensors * C data 
interface + * Run-End Encoded arrays; Arrow.jl now rejects them explicitly during read instead of falling through partially Flight RPC status: * Experimental `Arrow.Flight` support is available in-tree @@ -97,4 +98,9 @@ Third-party data formats: * Other Tables.jl-compatible packages automatically supported ([DataFrames.jl](https://github.com/JuliaData/DataFrames.jl), [JSONTables.jl](https://github.com/JuliaData/JSONTables.jl), [JuliaDB.jl](https://github.com/JuliaData/JuliaDB.jl), [SQLite.jl](https://github.com/JuliaDatabases/SQLite.jl), [MySQL.jl](https://github.com/JuliaDatabases/MySQL.jl), [JDBC.jl](https://github.com/JuliaDatabases/JDBC.jl), [ODBC.jl](https://github.com/JuliaDatabases/ODBC.jl), [XLSX.jl](https://github.com/felipenoris/XLSX.jl), etc.) * No current Julia packages support ORC +Canonical extension highlights: + * `UUID` now writes the canonical `arrow.uuid` extension name by default while retaining reader compatibility with legacy `JuliaLang.UUID` metadata + * `Arrow.TimestampWithOffset{U}` provides a canonical `arrow.timestamp_with_offset` logical type without conflating offset-only semantics with `ZonedDateTime` + * Legacy `JuliaLang.ZonedDateTime-UTC` and `JuliaLang.ZonedDateTime` files remain readable for backward compatibility + See the [full documentation](https://arrow.apache.org/julia/) for details on reading and writing arrow data. diff --git a/docs/src/manual.md b/docs/src/manual.md index 790eb826..8162b571 100644 --- a/docs/src/manual.md +++ b/docs/src/manual.md @@ -87,7 +87,8 @@ In the arrow data format, specific logical types are supported, a list of which * `Date`, `Time`, `Timestamp`, and `Duration` all have natural Julia defintions in `Dates.Date`, `Dates.Time`, `TimeZones.ZonedDateTime`, and `Dates.Period` subtypes, respectively. 
* `Char` and `Symbol` Julia types are mapped to arrow string types, with additional metadata of the original Julia type; this allows deserializing directly to `Char` and `Symbol` in Julia, while other language implementations will see these columns as just strings -* Similarly to the above, the `UUID` Julia type is mapped to a 128-bit `FixedSizeBinary` arrow type. +* `UUID` is mapped to a 128-bit `FixedSizeBinary` arrow type and now writes the canonical `arrow.uuid` extension name by default while still reading older `JuliaLang.UUID` metadata +* `Arrow.TimestampWithOffset{U}` is the canonical offset-only logical type for `arrow.timestamp_with_offset`; it stores a UTC `Arrow.Timestamp{U,:UTC}` plus `offset_minutes::Int16` and does not imply a timezone-name interpretation * `Decimal128` and `Decimal256` have no corresponding builtin Julia types, so they're deserialized using a compatible type definition in Arrow.jl itself: `Arrow.Decimal` @@ -97,6 +98,10 @@ One note on performance: when writing `TimeZones.ZonedDateTime` columns to the a as the column has `ZonedDateTime` elements that all share a common timezone. This ensures the writing process can know "upfront" which timezone will be encoded and is thus much more efficient and performant. +Run-End Encoded arrays are not implemented in Arrow.jl yet. Files containing +that layout now fail explicitly during read with a clear unsupported error +instead of partially decoding. 
+ Similarly, `ArrowTypes.ToArrow` avoids repeated type-promotion work for homogeneous custom columns even when `ArrowTypes.ArrowType(T)` is abstract, so write-time conversion does not pay unnecessary overhead once the serialized diff --git a/src/ArrowTypes/src/ArrowTypes.jl b/src/ArrowTypes/src/ArrowTypes.jl index 4fe9c5a5..ce832710 100644 --- a/src/ArrowTypes/src/ArrowTypes.jl +++ b/src/ArrowTypes/src/ArrowTypes.jl @@ -339,9 +339,11 @@ ArrowKind(::Type{NTuple{N,T}}) where {N,T} = FixedSizeListKind{N,T}() ArrowKind(::Type{UUID}) = FixedSizeListKind{16,UInt8}() ArrowType(::Type{UUID}) = NTuple{16,UInt8} toarrow(x::UUID) = _cast(NTuple{16,UInt8}, x.value) -const UUIDSYMBOL = Symbol("JuliaLang.UUID") +const UUIDSYMBOL = Symbol("arrow.uuid") +const LEGACY_UUIDSYMBOL = Symbol("JuliaLang.UUID") arrowname(::Type{UUID}) = UUIDSYMBOL JuliaType(::Val{UUIDSYMBOL}) = UUID +JuliaType(::Val{LEGACY_UUIDSYMBOL}) = UUID fromarrow(::Type{UUID}, x::NTuple{16,UInt8}) = UUID(_cast(UInt128, x)) ArrowKind(::Type{IPv4}) = PrimitiveKind() diff --git a/src/ArrowTypes/test/tests.jl b/src/ArrowTypes/test/tests.jl index 87b7cfa6..16f420e5 100644 --- a/src/ArrowTypes/test/tests.jl +++ b/src/ArrowTypes/test/tests.jl @@ -126,6 +126,7 @@ const hybrid = EnumTestModule.hybrid @test ArrowTypes.toarrow(u) == ubytes @test ArrowTypes.arrowname(UUID) == ArrowTypes.UUIDSYMBOL @test ArrowTypes.JuliaType(Val(ArrowTypes.UUIDSYMBOL)) == UUID + @test ArrowTypes.JuliaType(Val(ArrowTypes.LEGACY_UUIDSYMBOL)) == UUID @test ArrowTypes.fromarrow(UUID, ubytes) == u ip4 = IPv4(rand(UInt32)) diff --git a/src/eltypes.jl b/src/eltypes.jl index 52dbb809..d00c9e0b 100644 --- a/src/eltypes.jl +++ b/src/eltypes.jl @@ -24,6 +24,8 @@ finaljuliatype(T) = T finaljuliatype(::Type{Missing}) = Missing finaljuliatype(::Type{Union{T,Missing}}) where {T} = Union{Missing,finaljuliatype(T)} +const RUN_END_ENCODED_UNSUPPORTED = "Run-End Encoded arrays are not supported yet" + """ Given a FlatBuffers.Builder and a Julia column or column 
eltype, Write the field.type flatbuffer definition of the eltype @@ -46,7 +48,11 @@ function juliaeltype(f::Meta.Field, meta::AbstractDict{String,String}, convert:: if haskey(meta, "ARROW:extension:name") typename = meta["ARROW:extension:name"] metadata = get(meta, "ARROW:extension:metadata", "") - JT = ArrowTypes.JuliaType(Val(Symbol(typename)), maybemissing(TT), metadata) + typenamesym = Symbol(typename) + storageT = + typenamesym === TIMESTAMP_WITH_OFFSET_SYMBOL ? maybemissing(juliaeltype(f, false)) : + maybemissing(TT) + JT = ArrowTypes.JuliaType(Val(typenamesym), storageT, metadata) if JT !== nothing return f.nullable ? Union{JT,Missing} : JT else @@ -265,6 +271,19 @@ end Base.zero(::Type{Timestamp{U,T}}) where {U,T} = Timestamp{U,T}(Int64(0)) +struct TimestampWithOffset{U} + timestamp::Timestamp{U,:UTC} + offset_minutes::Int16 +end + +TimestampWithOffset( + timestamp::Timestamp{U,:UTC}, + offset_minutes::Integer, +) where {U} = TimestampWithOffset{U}(timestamp, Int16(offset_minutes)) + +Base.zero(::Type{TimestampWithOffset{U}}) where {U} = + TimestampWithOffset{U}(zero(Timestamp{U,:UTC}), Int16(0)) + function juliaeltype(f::Meta.Field, x::Meta.Timestamp, convert) return Timestamp{x.unit,x.timezone === nothing ? 
nothing : Symbol(x.timezone)} end @@ -335,6 +354,32 @@ ArrowTypes.fromarrow(::Type{ZonedDateTime}, x::Timestamp) = convert(ZonedDateTim ArrowTypes.default(::Type{TimeZones.ZonedDateTime}) = TimeZones.ZonedDateTime(1, 1, 1, 1, 1, 1, TimeZones.tz"UTC") +const TIMESTAMP_WITH_OFFSET_SYMBOL = Symbol("arrow.timestamp_with_offset") +ArrowTypes.arrowname(::Type{TimestampWithOffset{U}}) where {U} = TIMESTAMP_WITH_OFFSET_SYMBOL +ArrowTypes.JuliaType( + ::Val{TIMESTAMP_WITH_OFFSET_SYMBOL}, + ::Type{ + NamedTuple{ + (:timestamp, :offset_minutes), + Tuple{Timestamp{U,:UTC},Int16}, + }, + }, + metadata::String, +) where {U} = TimestampWithOffset{U} +ArrowTypes.default(::Type{TimestampWithOffset{U}}) where {U} = zero(TimestampWithOffset{U}) +ArrowTypes.fromarrowstruct( + ::Type{TimestampWithOffset{U}}, + ::Val{(:timestamp, :offset_minutes)}, + timestamp::Timestamp{U,:UTC}, + offset_minutes::Int16, +) where {U} = TimestampWithOffset{U}(timestamp, offset_minutes) +ArrowTypes.fromarrowstruct( + ::Type{TimestampWithOffset{U}}, + ::Val{(:offset_minutes, :timestamp)}, + offset_minutes::Int16, + timestamp::Timestamp{U,:UTC}, +) where {U} = TimestampWithOffset{U}(timestamp, offset_minutes) + # Backwards compatibility: older versions of Arrow saved ZonedDateTime's with this metdata: const OLD_ZONEDDATETIME_SYMBOL = Symbol("JuliaLang.ZonedDateTime") # and stored the local time instead of the UTC time. 
@@ -390,6 +435,9 @@ function juliaeltype(f::Meta.Field, x::Meta.Interval, convert) return Interval{x.unit,bitwidth(x.unit)} end +juliaeltype(f::Meta.Field, x::Meta.RunEndEncoded, convert) = + throw(ArgumentError(RUN_END_ENCODED_UNSUPPORTED)) + function arrowtype(b, ::Type{Interval{U,T}}) where {U,T} Meta.intervalStart(b) Meta.intervalAddUnit(b, U) diff --git a/src/table.jl b/src/table.jl index ad6adfda..1852dbc4 100644 --- a/src/table.jl +++ b/src/table.jl @@ -28,6 +28,10 @@ tobytes(io::IO) = Base.read(io) tobytes(io::IOStream) = Mmap.mmap(io) tobytes(file_path) = open(tobytes, file_path, "r") +rejectunsupported(field::Meta.Field) = (rejectunsupported(field.type); foreach(rejectunsupported, field.children)) +rejectunsupported(x) = nothing +rejectunsupported(x::Meta.RunEndEncoded) = throw(ArgumentError(RUN_END_ENCODED_UNSUPPORTED)) + struct BatchIterator bytes::Vector{UInt8} startpos::Int @@ -187,6 +191,7 @@ function Base.iterate(x::Stream, (pos, id)=(1, 0)) # assert endianness? # store custom_metadata? for (i, field) in enumerate(x.schema.fields) + rejectunsupported(field) push!(x.names, Symbol(field.name)) push!( x.types, @@ -487,6 +492,7 @@ function Table(blobs::Vector{ArrowBlob}; convert::Bool=true) # store custom_metadata? 
if sch === nothing for (i, field) in enumerate(header.fields) + rejectunsupported(field) push!(names(t), Symbol(field.name)) # recursively find any dictionaries for any fields getdictionaries!(dictencoded, field) @@ -903,6 +909,20 @@ function build( varbufferidx end +function build( + f::Meta.Field, + x::Meta.RunEndEncoded, + batch, + rb, + de, + nodeidx, + bufferidx, + varbufferidx, + convert, +) + throw(ArgumentError(RUN_END_ENCODED_UNSUPPORTED)) +end + function build( f::Meta.Field, L::ViewTypes, diff --git a/src/utils.jl b/src/utils.jl index 7c042b71..05a297cb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -163,7 +163,10 @@ end @inline function _directtobuffercoleligible(col) T = Base.nonmissingtype(eltype(col)) - return !(T <: AbstractString || T <: Base.CodeUnits) + T <: AbstractString && return false + T <: Base.CodeUnits && return false + K = ArrowTypes.ArrowKind(ArrowTypes.ArrowType(T)) + return !(K isa ArrowTypes.ListKind) end @inline function _directtobufferstringonly(col) diff --git a/test/run_end_encoded_small.arrow b/test/run_end_encoded_small.arrow new file mode 100644 index 0000000000000000000000000000000000000000..17155c3533579bb291c52c6af9fe26bba93ea145 GIT binary patch literal 794 zcmds0yAA(j#=DAjBv70EJTN6Lg$&cQz*cfg^Y3 zzRui7RymzcW+UK5JOFqKP{1KahBh@KNoj*tn`atMy3GQvyy2UXNhy@`u2E*jzvDasL)vp??Uo86jh*^bgO zKg~q>n)k-LLlGRhVPa>Up_xMS3=}jJv>j_yJK55wcd{yIXRJqVQ{&#^n0ghoUqLkm z?XO0Y%J Dict("ARROW:extension:name" => "JuliaLang.UUID"), + ), + ), + ) + @test copy(legacy_tt.b) == [ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ] toffset = ( b=OffsetArray( @@ -693,6 +719,74 @@ const hybrid = EnumRoundtripModule.hybrid ) end + @testset "canonical timestamp_with_offset" begin + values = Union{Missing,Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}}[ + Arrow.TimestampWithOffset( + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}(1577836800000), + 330, + ), + missing, + 
Arrow.TimestampWithOffset( + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}(1577923200000), + -480, + ), + ] + tt = Arrow.Table(Arrow.tobuffer((col=values,))) + @test eltype(tt.col) == + Union{Missing,Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}} + @test isequal(copy(tt.col), values) + @test Arrow.getmetadata(tt.col)["ARROW:extension:name"] == + "arrow.timestamp_with_offset" + + raw_tt = Arrow.Table(Arrow.tobuffer((col=values,)); convert=false) + @test eltype(raw_tt.col) == + Union{ + Missing, + NamedTuple{ + (:timestamp, :offset_minutes), + Tuple{ + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, + Int16, + }, + }, + } + @test isequal( + copy(raw_tt.col), + Union{ + Missing, + NamedTuple{ + (:timestamp, :offset_minutes), + Tuple{ + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, + Int16, + }, + }, + }[ + ( + timestamp=Arrow.Timestamp{ + Arrow.Meta.TimeUnit.MILLISECOND, + :UTC, + }(1577836800000), + offset_minutes=Int16(330), + ), + missing, + ( + timestamp=Arrow.Timestamp{ + Arrow.Meta.TimeUnit.MILLISECOND, + :UTC, + }(1577923200000), + offset_minutes=Int16(-480), + ), + ], + ) + end + + @testset "Run-End Encoded rejection" begin + path = joinpath(@__DIR__, "run_end_encoded_small.arrow") + @test_throws ArgumentError(Arrow.RUN_END_ENCODED_UNSUPPORTED) Arrow.Table(path) + @test_throws ArgumentError(Arrow.RUN_END_ENCODED_UNSUPPORTED) collect(Arrow.Stream(path)) + end + @testset "# 158" begin # arrow ipc stream generated from pyarrow with no record batches bytes = UInt8[ From bd7e88153a5856bda5ede4a30502f7e2f479dbb2 Mon Sep 17 00:00:00 2001 From: guangtao Date: Tue, 31 Mar 2026 13:37:41 -0700 Subject: [PATCH 06/16] Add canonical low-risk Arrow extensions --- README.md | 3 +++ docs/src/manual.md | 3 +++ src/eltypes.jl | 59 ++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 43 +++++++++++++++++++++++++++++++++ 4 files changed, 108 insertions(+) diff --git a/README.md b/README.md index 8fc21914..acd3d763 100644 --- 
a/README.md +++ b/README.md @@ -101,6 +101,9 @@ Third-party data formats: Canonical extension highlights: * `UUID` now writes the canonical `arrow.uuid` extension name by default while retaining reader compatibility with legacy `JuliaLang.UUID` metadata * `Arrow.TimestampWithOffset{U}` provides a canonical `arrow.timestamp_with_offset` logical type without conflating offset-only semantics with `ZonedDateTime` + * `Arrow.Bool8` provides an explicit opt-in writer/reader surface for the canonical `arrow.bool8` extension without changing the default packed-bit `Bool` path + * `Arrow.JSONText{String}` provides a text-backed logical type for the canonical `arrow.json` extension without parsing payloads during read or write + * `arrow.opaque` now reads as the underlying storage type without warning, and explicit writer metadata can be generated with `Arrow.opaquemetadata(type_name, vendor_name)` * Legacy `JuliaLang.ZonedDateTime-UTC` and `JuliaLang.ZonedDateTime` files remain readable for backward compatibility See the [full documentation](https://arrow.apache.org/julia/) for details on reading and writing arrow data. 
diff --git a/docs/src/manual.md b/docs/src/manual.md index 8162b571..da958808 100644 --- a/docs/src/manual.md +++ b/docs/src/manual.md @@ -89,6 +89,9 @@ In the arrow data format, specific logical types are supported, a list of which * `Char` and `Symbol` Julia types are mapped to arrow string types, with additional metadata of the original Julia type; this allows deserializing directly to `Char` and `Symbol` in Julia, while other language implementations will see these columns as just strings * `UUID` is mapped to a 128-bit `FixedSizeBinary` arrow type and now writes the canonical `arrow.uuid` extension name by default while still reading older `JuliaLang.UUID` metadata * `Arrow.TimestampWithOffset{U}` is the canonical offset-only logical type for `arrow.timestamp_with_offset`; it stores a UTC `Arrow.Timestamp{U,:UTC}` plus `offset_minutes::Int16` and does not imply a timezone-name interpretation +* `Arrow.Bool8` is an explicit opt-in logical type for the canonical `arrow.bool8` extension; it uses `Int8` storage, while plain Julia `Bool` continues to use Arrow's packed-bit boolean layout +* `Arrow.JSONText{String}` is a text-backed logical type for the canonical `arrow.json` extension; Arrow.jl preserves the payload as text and does not parse JSON automatically +* `arrow.opaque` is treated as interoperability metadata over the underlying storage type; explicit metadata can be generated with `Arrow.opaquemetadata(type_name, vendor_name)` when writing * `Decimal128` and `Decimal256` have no corresponding builtin Julia types, so they're deserialized using a compatible type definition in Arrow.jl itself: `Arrow.Decimal` diff --git a/src/eltypes.jl b/src/eltypes.jl index d00c9e0b..7790ae42 100644 --- a/src/eltypes.jl +++ b/src/eltypes.jl @@ -25,6 +25,9 @@ finaljuliatype(::Type{Missing}) = Missing finaljuliatype(::Type{Union{T,Missing}}) where {T} = Union{Missing,finaljuliatype(T)} const RUN_END_ENCODED_UNSUPPORTED = "Run-End Encoded arrays are not supported yet" +const 
BOOL8_SYMBOL = Symbol("arrow.bool8") +const JSON_SYMBOL = Symbol("arrow.json") +const OPAQUE_SYMBOL = Symbol("arrow.opaque") """ Given a FlatBuffers.Builder and a Julia column or column eltype, @@ -114,6 +117,62 @@ function arrowtype(b, ::Type{T}) where {T<:Integer} return Meta.Int, Meta.intEnd(b), nothing end +struct Bool8 + value::Bool +end + +Bool8(x::Integer) = Bool8(!iszero(x)) + +Base.Bool(x::Bool8) = getfield(x, :value) +Base.convert(::Type{Bool}, x::Bool8) = Bool(x) +Base.convert(::Type{Int8}, x::Bool8) = Int8(Bool(x)) +Base.zero(::Type{Bool8}) = Bool8(false) +Base.:(==)(x::Bool8, y::Bool8) = Bool(x) == Bool(y) +Base.isequal(x::Bool8, y::Bool8) = isequal(Bool(x), Bool(y)) + +ArrowTypes.ArrowType(::Type{Bool8}) = Int8 +ArrowTypes.toarrow(x::Bool8) = Int8(Bool(x)) +ArrowTypes.arrowname(::Type{Bool8}) = BOOL8_SYMBOL +ArrowTypes.JuliaType(::Val{BOOL8_SYMBOL}, ::Type{Int8}, metadata::String) = Bool8 +ArrowTypes.fromarrow(::Type{Bool8}, x::Int8) = Bool8(x) +ArrowTypes.default(::Type{Bool8}) = zero(Bool8) + +function writearray(io::IO, ::Type{Int8}, col::ArrowTypes.ToArrow{Int8,A}) where {A<:AbstractVector{Bool8}} + data = ArrowTypes._sourcedata(col) + strides(data) == (1,) || return _writearrayfallback(io, Int8, col) + return Base.write(io, reinterpret(Int8, data)) +end + +struct JSONText{S<:AbstractString} + value::S +end + +Base.String(x::JSONText) = String(getfield(x, :value)) +Base.convert(::Type{String}, x::JSONText) = String(x) +Base.:(==)(x::JSONText, y::JSONText) = getfield(x, :value) == getfield(y, :value) +Base.isequal(x::JSONText, y::JSONText) = isequal(getfield(x, :value), getfield(y, :value)) + +ArrowTypes.ArrowType(::Type{JSONText{S}}) where {S<:AbstractString} = S +ArrowTypes.toarrow(x::JSONText) = getfield(x, :value) +ArrowTypes.arrowname(::Type{JSONText{S}}) where {S<:AbstractString} = JSON_SYMBOL +ArrowTypes.JuliaType(::Val{JSON_SYMBOL}, ::Type{S}, metadata::String) where {S<:AbstractString} = + JSONText{S} 
+ArrowTypes.fromarrow(::Type{JSONText{String}}, ptr::Ptr{UInt8}, len::Int) = + JSONText(unsafe_string(ptr, len)) +ArrowTypes.fromarrow(::Type{JSONText{S}}, x::S) where {S<:AbstractString} = JSONText{S}(x) +ArrowTypes.default(::Type{JSONText{S}}) where {S<:AbstractString} = + JSONText{S}(ArrowTypes.default(S)) + +ArrowTypes.JuliaType(::Val{OPAQUE_SYMBOL}, S, metadata::String) = S + +@inline function _jsonstringliteral(x::AbstractString) + return '"' * escape_string(x) * '"' +end + +opaquemetadata(type_name::AbstractString, vendor_name::AbstractString) = + "{\"type_name\":" * _jsonstringliteral(type_name) * + ",\"vendor_name\":" * _jsonstringliteral(vendor_name) * "}" + # primitive types function juliaeltype(f::Meta.Field, fp::Meta.FloatingPoint, convert) if fp.precision == Meta.Precision.HALF diff --git a/test/runtests.jl b/test/runtests.jl index fd4f328b..dc14dfae 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -787,6 +787,49 @@ const hybrid = EnumRoundtripModule.hybrid @test_throws ArgumentError(Arrow.RUN_END_ENCODED_UNSUPPORTED) collect(Arrow.Stream(path)) end + @testset "canonical bool8/json/opaque" begin + bools = Union{Missing,Arrow.Bool8}[Arrow.Bool8(true), missing, Arrow.Bool8(false)] + tt = Arrow.Table(Arrow.tobuffer((col=bools,))) + @test eltype(tt.col) == Union{Missing,Arrow.Bool8} + @test isequal(copy(tt.col), bools) + @test Arrow.getmetadata(tt.col)["ARROW:extension:name"] == "arrow.bool8" + + raw_tt = Arrow.Table(Arrow.tobuffer((col=bools,)); convert=false) + @test eltype(raw_tt.col) == Union{Missing,Int8} + @test isequal(copy(raw_tt.col), Union{Missing,Int8}[1, missing, 0]) + + jsons = Union{Missing,Arrow.JSONText{String}}[ + Arrow.JSONText("{\"a\":1}"), + missing, + Arrow.JSONText("[1,2,3]"), + ] + json_tt = Arrow.Table(Arrow.tobuffer((col=jsons,))) + @test eltype(json_tt.col) == Union{Missing,Arrow.JSONText{String}} + @test isequal(copy(json_tt.col), jsons) + @test Arrow.getmetadata(json_tt.col)["ARROW:extension:name"] == "arrow.json" + + 
raw_json_tt = Arrow.Table(Arrow.tobuffer((col=jsons,)); convert=false) + @test eltype(raw_json_tt.col) == Union{Missing,String} + @test isequal(copy(raw_json_tt.col), Union{Missing,String}["{\"a\":1}", missing, "[1,2,3]"]) + + opaque_meta = Arrow.opaquemetadata("pkg.Type", "vendor.example") + opaque_tt = Arrow.Table( + Arrow.tobuffer( + (col=["a", "b"],); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.opaque", + "ARROW:extension:metadata" => opaque_meta, + ), + ), + ), + ) + @test eltype(opaque_tt.col) == String + @test copy(opaque_tt.col) == ["a", "b"] + @test Arrow.getmetadata(opaque_tt.col)["ARROW:extension:name"] == "arrow.opaque" + @test Arrow.getmetadata(opaque_tt.col)["ARROW:extension:metadata"] == opaque_meta + end + @testset "# 158" begin # arrow ipc stream generated from pyarrow with no record batches bytes = UInt8[ From d4cdd133ab6da7782275b73ea7c0f90c07dba77b Mon Sep 17 00:00:00 2001 From: guangtao Date: Tue, 31 Mar 2026 14:08:57 -0700 Subject: [PATCH 07/16] Add Run-End Encoded read support --- README.md | 2 +- docs/src/manual.md | 6 +- src/arraytypes/arraytypes.jl | 1 + src/arraytypes/runendencoded.jl | 107 ++++++++++++++++++++++++++++++++ src/eltypes.jl | 5 +- src/table.jl | 15 ++++- test/runtests.jl | 18 +++++- 7 files changed, 143 insertions(+), 11 deletions(-) create mode 100644 src/arraytypes/runendencoded.jl diff --git a/README.md b/README.md index acd3d763..3b459089 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ This implementation supports the 1.0 version of the specification, including sup It currently doesn't include support for: * Tensors or sparse tensors * C data interface - * Run-End Encoded arrays; Arrow.jl now rejects them explicitly during read instead of falling through partially + * Writing Run-End Encoded arrays; Arrow.jl now reads REE arrays and exposes them as read-only vectors, but still rejects REE on write paths Flight RPC status: * Experimental `Arrow.Flight` support is available in-tree diff 
--git a/docs/src/manual.md b/docs/src/manual.md index da958808..731828e2 100644 --- a/docs/src/manual.md +++ b/docs/src/manual.md @@ -101,9 +101,9 @@ One note on performance: when writing `TimeZones.ZonedDateTime` columns to the a as the column has `ZonedDateTime` elements that all share a common timezone. This ensures the writing process can know "upfront" which timezone will be encoded and is thus much more efficient and performant. -Run-End Encoded arrays are not implemented in Arrow.jl yet. Files containing -that layout now fail explicitly during read with a clear unsupported error -instead of partially decoding. +Run-End Encoded arrays are now supported on the read path. Arrow.jl exposes REE +columns as read-only vectors and continues to reject REE on write paths, rather +than attempting a partial or lossy re-encoding. Similarly, `ArrowTypes.ToArrow` avoids repeated type-promotion work for homogeneous custom columns even when `ArrowTypes.ArrowType(T)` is abstract, so diff --git a/src/arraytypes/arraytypes.jl b/src/arraytypes/arraytypes.jl index 20bbce2d..9f04292e 100644 --- a/src/arraytypes/arraytypes.jl +++ b/src/arraytypes/arraytypes.jl @@ -342,3 +342,4 @@ include("struct.jl") include("unions.jl") include("dictencoding.jl") include("views.jl") +include("runendencoded.jl") diff --git a/src/arraytypes/runendencoded.jl b/src/arraytypes/runendencoded.jl new file mode 100644 index 00000000..7ca318b0 --- /dev/null +++ b/src/arraytypes/runendencoded.jl @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + Arrow.RunEndEncoded + +A read-only `ArrowVector` for Arrow Run-End Encoded arrays. Logical indexing is +resolved by binary searching the physical `run_ends` child and then indexing the +corresponding `values` child. +""" +struct RunEndEncoded{T,R,A} <: ArrowVector{T} + run_ends::R + values::A + ℓ::Int + metadata::Union{Nothing,Base.ImmutableDict{String,String}} +end + +Base.size(r::RunEndEncoded) = (r.ℓ,) +Base.copy(r::RunEndEncoded) = collect(r) + +@inline _reephysicalindex(r::RunEndEncoded, i::Integer) = searchsortedfirst(r.run_ends, i) + +function _validaterunendencoded(run_ends, values, len) + nruns = length(run_ends) + nvals = length(values) + nruns == nvals || throw( + ArgumentError( + "invalid Run-End Encoded array: run_ends length $nruns does not match values length $nvals", + ), + ) + if len == 0 + nruns == 0 || throw( + ArgumentError( + "invalid Run-End Encoded array: zero logical length requires zero runs", + ), + ) + elseif nruns == 0 + throw( + ArgumentError( + "invalid Run-End Encoded array: non-zero logical length requires at least one run", + ), + ) + end + last_end = 0 + for (idx, run_end) in enumerate(run_ends) + current_end = Int(run_end) + current_end > last_end || throw( + ArgumentError( + "invalid Run-End Encoded array: run_ends must be strictly increasing positive integers (failed at run $idx)", + ), + ) + last_end = current_end + end + len == 0 || last_end == len || throw( + ArgumentError( + "invalid Run-End Encoded array: final run end $last_end does not match logical length $len", + ), + ) + return +end + +function 
RunEndEncoded(run_ends::R, values::A, len, meta) where {R,A} + _validaterunendencoded(run_ends, values, Int(len)) + T = eltype(values) + return RunEndEncoded{T,R,A}(run_ends, values, Int(len), meta) +end + +function _makerunendencoded(::Type{T}, run_ends::R, values::A, len, meta) where {T,R,A} + _validaterunendencoded(run_ends, values, Int(len)) + return RunEndEncoded{T,R,A}(run_ends, values, Int(len), meta) +end + +@propagate_inbounds function Base.getindex(r::RunEndEncoded{T}, i::Integer) where {T} + @boundscheck checkbounds(r, i) + physical = _reephysicalindex(r, i) + physical <= length(r.values) || throw( + ArgumentError( + "invalid Run-End Encoded array: no physical value found for logical index $i", + ), + ) + return @inbounds ArrowTypes.fromarrow(T, r.values[physical]) +end + +function toarrowvector( + x::RunEndEncoded, + i=1, + de=Dict{Int64,Any}(), + ded=DictEncoding[], + meta=getmetadata(x); + compression::Union{Nothing,Symbol,LZ4FrameCompressor,ZstdCompressor}=nothing, + kw..., +) + throw(ArgumentError(RUN_END_ENCODED_UNSUPPORTED)) +end diff --git a/src/eltypes.jl b/src/eltypes.jl index 7790ae42..5c5aa76f 100644 --- a/src/eltypes.jl +++ b/src/eltypes.jl @@ -494,8 +494,9 @@ function juliaeltype(f::Meta.Field, x::Meta.Interval, convert) return Interval{x.unit,bitwidth(x.unit)} end -juliaeltype(f::Meta.Field, x::Meta.RunEndEncoded, convert) = - throw(ArgumentError(RUN_END_ENCODED_UNSUPPORTED)) +function juliaeltype(f::Meta.Field, x::Meta.RunEndEncoded, convert) + return juliaeltype(f.children[2], buildmetadata(f.children[2]), convert) +end function arrowtype(b, ::Type{Interval{U,T}}) where {U,T} Meta.intervalStart(b) diff --git a/src/table.jl b/src/table.jl index 1852dbc4..f72d83ae 100644 --- a/src/table.jl +++ b/src/table.jl @@ -30,7 +30,6 @@ tobytes(file_path) = open(tobytes, file_path, "r") rejectunsupported(field::Meta.Field) = (rejectunsupported(field.type); foreach(rejectunsupported, field.children)) rejectunsupported(x) = nothing 
-rejectunsupported(x::Meta.RunEndEncoded) = throw(ArgumentError(RUN_END_ENCODED_UNSUPPORTED)) struct BatchIterator bytes::Vector{UInt8} @@ -920,7 +919,19 @@ function build( varbufferidx, convert, ) - throw(ArgumentError(RUN_END_ENCODED_UNSUPPORTED)) + @debug "building array: x = $x" + len = rb.nodes[nodeidx].length + nodeidx += 1 + meta = buildmetadata(f.custom_metadata) + T = juliaeltype(f, meta, convert) + run_ends, nodeidx, bufferidx, varbufferidx = + build(f.children[1], batch, rb, de, nodeidx, bufferidx, varbufferidx, false) + values, nodeidx, bufferidx, varbufferidx = + build(f.children[2], batch, rb, de, nodeidx, bufferidx, varbufferidx, convert) + return _makerunendencoded(T, run_ends, values, len, meta), + nodeidx, + bufferidx, + varbufferidx end function build( diff --git a/test/runtests.jl b/test/runtests.jl index dc14dfae..e909081e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -781,10 +781,22 @@ const hybrid = EnumRoundtripModule.hybrid ) end - @testset "Run-End Encoded rejection" begin + @testset "Run-End Encoded read support" begin path = joinpath(@__DIR__, "run_end_encoded_small.arrow") - @test_throws ArgumentError(Arrow.RUN_END_ENCODED_UNSUPPORTED) Arrow.Table(path) - @test_throws ArgumentError(Arrow.RUN_END_ENCODED_UNSUPPORTED) collect(Arrow.Stream(path)) + expected = ["a", "a", "b", "b", "b"] + + tt = Arrow.Table(path) + @test tt isa Arrow.Table + @test eltype(tt.x) == Union{Missing,String} + @test collect(tt.x) == expected + @test copy(tt.x) == expected + + batches = collect(Arrow.Stream(path)) + @test length(batches) == 1 + @test collect(batches[1].x) == expected + + @test_throws ArgumentError(Arrow.RUN_END_ENCODED_UNSUPPORTED) Arrow.tobuffer(tt) + @test_throws ArgumentError(Arrow.RUN_END_ENCODED_UNSUPPORTED) Arrow.tobuffer((x=tt.x,)) end @testset "canonical bool8/json/opaque" begin From 826a1b48a377d8737776bc4c0f4793c590134c09 Mon Sep 17 00:00:00 2001 From: guangtao Date: Tue, 31 Mar 2026 14:28:30 -0700 Subject: [PATCH 08/16] 
Recognize canonical advanced extension passthrough --- README.md | 1 + docs/src/manual.md | 1 + src/eltypes.jl | 6 ++++ test/runtests.jl | 81 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 89 insertions(+) diff --git a/README.md b/README.md index 3b459089..b341da1c 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,7 @@ Canonical extension highlights: * `Arrow.Bool8` provides an explicit opt-in writer/reader surface for the canonical `arrow.bool8` extension without changing the default packed-bit `Bool` path * `Arrow.JSONText{String}` provides a text-backed logical type for the canonical `arrow.json` extension without parsing payloads during read or write * `arrow.opaque` now reads as the underlying storage type without warning, and explicit writer metadata can be generated with `Arrow.opaquemetadata(type_name, vendor_name)` + * `arrow.parquet.variant`, `arrow.fixed_shape_tensor`, and `arrow.variable_shape_tensor` are recognized on read as canonical passthrough extensions over their storage types without claiming full semantic interpretation or automatic writer support * Legacy `JuliaLang.ZonedDateTime-UTC` and `JuliaLang.ZonedDateTime` files remain readable for backward compatibility See the [full documentation](https://arrow.apache.org/julia/) for details on reading and writing arrow data. 
diff --git a/docs/src/manual.md b/docs/src/manual.md index 731828e2..52984c44 100644 --- a/docs/src/manual.md +++ b/docs/src/manual.md @@ -92,6 +92,7 @@ In the arrow data format, specific logical types are supported, a list of which * `Arrow.Bool8` is an explicit opt-in logical type for the canonical `arrow.bool8` extension; it uses `Int8` storage, while plain Julia `Bool` continues to use Arrow's packed-bit boolean layout * `Arrow.JSONText{String}` is a text-backed logical type for the canonical `arrow.json` extension; Arrow.jl preserves the payload as text and does not parse JSON automatically * `arrow.opaque` is treated as interoperability metadata over the underlying storage type; explicit metadata can be generated with `Arrow.opaquemetadata(type_name, vendor_name)` when writing +* `arrow.parquet.variant`, `arrow.fixed_shape_tensor`, and `arrow.variable_shape_tensor` are recognized as canonical passthrough extensions on read; Arrow.jl currently returns their underlying storage types and does not yet implement higher-level semantic interpretation or automatic writer surfaces for them * `Decimal128` and `Decimal256` have no corresponding builtin Julia types, so they're deserialized using a compatible type definition in Arrow.jl itself: `Arrow.Decimal` diff --git a/src/eltypes.jl b/src/eltypes.jl index 5c5aa76f..24bdeac8 100644 --- a/src/eltypes.jl +++ b/src/eltypes.jl @@ -28,6 +28,9 @@ const RUN_END_ENCODED_UNSUPPORTED = "Run-End Encoded arrays are not supported ye const BOOL8_SYMBOL = Symbol("arrow.bool8") const JSON_SYMBOL = Symbol("arrow.json") const OPAQUE_SYMBOL = Symbol("arrow.opaque") +const PARQUET_VARIANT_SYMBOL = Symbol("arrow.parquet.variant") +const FIXED_SHAPE_TENSOR_SYMBOL = Symbol("arrow.fixed_shape_tensor") +const VARIABLE_SHAPE_TENSOR_SYMBOL = Symbol("arrow.variable_shape_tensor") """ Given a FlatBuffers.Builder and a Julia column or column eltype, @@ -164,6 +167,9 @@ ArrowTypes.default(::Type{JSONText{S}}) where {S<:AbstractString} = 
JSONText{S}(ArrowTypes.default(S)) ArrowTypes.JuliaType(::Val{OPAQUE_SYMBOL}, S, metadata::String) = S +ArrowTypes.JuliaType(::Val{PARQUET_VARIANT_SYMBOL}, S, metadata::String) = S +ArrowTypes.JuliaType(::Val{FIXED_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S +ArrowTypes.JuliaType(::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S @inline function _jsonstringliteral(x::AbstractString) return '"' * escape_string(x) * '"' diff --git a/test/runtests.jl b/test/runtests.jl index e909081e..395a79bf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -842,6 +842,87 @@ const hybrid = EnumRoundtripModule.hybrid @test Arrow.getmetadata(opaque_tt.col)["ARROW:extension:metadata"] == opaque_meta end + @testset "canonical advanced passthrough" begin + variant_values = Union{ + Missing, + NamedTuple{(:metadata, :value),Tuple{String,String}}, + }[ + (metadata="json", value="{\"a\":1}"), + missing, + (metadata="string", value="abc"), + ] + @test_logs min_level=Base.CoreLogging.Warn begin + variant_tt = Arrow.Table( + Arrow.tobuffer( + (col=variant_values,); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.parquet.variant", + "ARROW:extension:metadata" => "", + ), + ), + ), + ) + @test eltype(variant_tt.col) == eltype(variant_values) + @test isequal(copy(variant_tt.col), variant_values) + @test Arrow.getmetadata(variant_tt.col)["ARROW:extension:name"] == + "arrow.parquet.variant" + end + + fixed_tensor_values = Union{Missing,NTuple{4,Int32}}[ + (Int32(1), Int32(2), Int32(3), Int32(4)), + missing, + (Int32(5), Int32(6), Int32(7), Int32(8)), + ] + @test_logs min_level=Base.CoreLogging.Warn begin + fixed_tensor_tt = Arrow.Table( + Arrow.tobuffer( + (col=fixed_tensor_values,); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.fixed_shape_tensor", + "ARROW:extension:metadata" => "{\"shape\":[2,2],\"dim_names\":[\"x\",\"y\"]}", + ), + ), + ), + ) + @test eltype(fixed_tensor_tt.col) == eltype(fixed_tensor_values) + @test 
isequal(copy(fixed_tensor_tt.col), fixed_tensor_values) + @test Arrow.getmetadata(fixed_tensor_tt.col)["ARROW:extension:name"] == + "arrow.fixed_shape_tensor" + end + + variable_tensor_values = Union{Missing,Vector{Int32}}[ + Int32[1, 2, 3, 4], + missing, + Int32[5, 6], + ] + @test_logs min_level=Base.CoreLogging.Warn begin + variable_tensor_tt = Arrow.Table( + Arrow.tobuffer( + (col=variable_tensor_values,); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.variable_shape_tensor", + "ARROW:extension:metadata" => "{\"uniform_shape\":[2]}", + ), + ), + ), + ) + @test Base.nonmissingtype(eltype(variable_tensor_tt.col)) <: AbstractVector{Int32} + @test isequal( + map(x -> x === missing ? missing : copy(x), copy(variable_tensor_tt.col)), + Union{Missing,Vector{Int32}}[ + Int32[1, 2, 3, 4], + missing, + Int32[5, 6], + ], + ) + @test Arrow.getmetadata(variable_tensor_tt.col)["ARROW:extension:name"] == + "arrow.variable_shape_tensor" + end + end + @testset "# 158" begin # arrow ipc stream generated from pyarrow with no record batches bytes = UInt8[ From 2a39ac147f5c54b62646075636f036984c19ce87 Mon Sep 17 00:00:00 2001 From: guangtao Date: Tue, 31 Mar 2026 14:58:25 -0700 Subject: [PATCH 09/16] Add explicit tensor IPC boundary errors --- README.md | 2 +- docs/src/manual.md | 4 ++++ src/Arrow.jl | 2 ++ src/metadata/Message.jl | 24 ++++++++++++++++++++---- src/table.jl | 17 ++++++++++++++--- test/runtests.jl | 25 +++++++++++++++++++++++++ 6 files changed, 66 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index b341da1c..d6a5bca9 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ This implementation supports the 1.0 version of the specification, including sup * Streaming, file, record batch, and replacement and isdelta dictionary messages It currently doesn't include support for: - * Tensors or sparse tensors + * Tensor or sparse tensor IPC payload semantics; Arrow.jl now recognizes those message headers explicitly and rejects them 
with precise errors instead of falling through to a generic unsupported-message path * C data interface * Writing Run-End Encoded arrays; Arrow.jl now reads REE arrays and exposes them as read-only vectors, but still rejects REE on write paths diff --git a/docs/src/manual.md b/docs/src/manual.md index 52984c44..54cc9e74 100644 --- a/docs/src/manual.md +++ b/docs/src/manual.md @@ -106,6 +106,10 @@ Run-End Encoded arrays are now supported on the read path. Arrow.jl exposes REE columns as read-only vectors and continues to reject REE on write paths, rather than attempting a partial or lossy re-encoding. +Tensor and SparseTensor IPC messages are still unsupported, but Arrow.jl now +recognizes those message headers explicitly and rejects them with precise +errors instead of falling through to a generic unsupported-message failure. + Similarly, `ArrowTypes.ToArrow` avoids repeated type-promotion work for homogeneous custom columns even when `ArrowTypes.ArrowType(T)` is abstract, so write-time conversion does not pay unnecessary overhead once the serialized diff --git a/src/Arrow.jl b/src/Arrow.jl index eb9c2188..a6fa58a6 100644 --- a/src/Arrow.jl +++ b/src/Arrow.jl @@ -80,6 +80,8 @@ import Base: == const FILE_FORMAT_MAGIC_BYTES = b"ARROW1" const CONTINUATION_INDICATOR_BYTES = 0xffffffff +const TENSOR_UNSUPPORTED = "Tensor messages are not supported yet" +const SPARSE_TENSOR_UNSUPPORTED = "SparseTensor messages are not supported yet" # vendored flatbuffers code for now include("FlatBuffers/FlatBuffers.jl") diff --git a/src/metadata/Message.jl b/src/metadata/Message.jl index 0e494394..56491916 100644 --- a/src/metadata/Message.jl +++ b/src/metadata/Message.jl @@ -157,12 +157,28 @@ dictionaryBatchAddIsDelta(b::FlatBuffers.Builder, isdelta::Base.Bool) = FlatBuffers.prependslot!(b, 2, isdelta, false) dictionaryBatchEnd(b::FlatBuffers.Builder) = FlatBuffers.endobject!(b) +struct Tensor <: FlatBuffers.Table + bytes::Vector{UInt8} + pos::Base.Int +end + 
+Base.propertynames(x::Tensor) = () +Base.getproperty(x::Tensor, field::Symbol) = nothing + +struct SparseTensor <: FlatBuffers.Table + bytes::Vector{UInt8} + pos::Base.Int +end + +Base.propertynames(x::SparseTensor) = () +Base.getproperty(x::SparseTensor, field::Symbol) = nothing + function MessageHeader(b::UInt8) b == 1 && return Schema b == 2 && return DictionaryBatch b == 3 && return RecordBatch - # b == 4 && return Tensor - # b == 5 && return SparseTensor + b == 4 && return Tensor + b == 5 && return SparseTensor return nothing end @@ -170,8 +186,8 @@ function MessageHeader(::Base.Type{T})::Int16 where {T} T == Schema && return 1 T == DictionaryBatch && return 2 T == RecordBatch && return 3 - # T == Tensor && return 4 - # T == SparseTensor && return 5 + T == Tensor && return 4 + T == SparseTensor && return 5 return 0 end diff --git a/src/table.jl b/src/table.jl index f72d83ae..0af09dd8 100644 --- a/src/table.jl +++ b/src/table.jl @@ -181,10 +181,13 @@ function Base.iterate(x::Stream, (pos, id)=(1, 0)) end batch, (pos, id) = state header = batch.msg.header - if isnothing(x.schema) && !isa(header, Meta.Schema) + if header isa Meta.Tensor + throw(ArgumentError(TENSOR_UNSUPPORTED)) + elseif header isa Meta.SparseTensor + throw(ArgumentError(SPARSE_TENSOR_UNSUPPORTED)) + elseif isnothing(x.schema) && !isa(header, Meta.Schema) throw(ArgumentError("first arrow ipc message MUST be a schema message")) - end - if header isa Meta.Schema + elseif header isa Meta.Schema if isnothing(x.schema) x.schema = header # assert endianness? 
@@ -268,6 +271,10 @@ function Base.iterate(x::Stream, (pos, id)=(1, 0)) push!(columns, vec) end break + elseif header isa Meta.Tensor + throw(ArgumentError(TENSOR_UNSUPPORTED)) + elseif header isa Meta.SparseTensor + throw(ArgumentError(SPARSE_TENSOR_UNSUPPORTED)) else throw(ArgumentError("unsupported arrow message type: $(typeof(header))")) end @@ -583,6 +590,10 @@ function Table(blobs::Vector{ArrowBlob}; convert::Bool=true) ), ) rbi += 1 + elseif header isa Meta.Tensor + throw(ArgumentError(TENSOR_UNSUPPORTED)) + elseif header isa Meta.SparseTensor + throw(ArgumentError(SPARSE_TENSOR_UNSUPPORTED)) else throw(ArgumentError("unsupported arrow message type: $(typeof(header))")) end diff --git a/test/runtests.jl b/test/runtests.jl index 395a79bf..db5f6d85 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -923,6 +923,31 @@ const hybrid = EnumRoundtripModule.hybrid end end + @testset "tensor message boundary" begin + function patch_message_header_type(bytes, header_type::UInt8) + patched = copy(bytes) + msg = Arrow.FlatBuffers.getrootas(Arrow.Meta.Message, patched, 8) + offset = Arrow.FlatBuffers.offset(msg, 6) + @test offset != 0 + patched[Arrow.FlatBuffers.pos(msg) + offset + 1] = header_type + return patched + end + + base = take!(Arrow.tobuffer((x=[1, 2],))) + + tensor_bytes = patch_message_header_type(base, UInt8(4)) + @test_throws ArgumentError(Arrow.TENSOR_UNSUPPORTED) Arrow.Table(tensor_bytes) + @test_throws ArgumentError(Arrow.TENSOR_UNSUPPORTED) collect(Arrow.Stream(tensor_bytes)) + + sparse_tensor_bytes = patch_message_header_type(base, UInt8(5)) + @test_throws ArgumentError(Arrow.SPARSE_TENSOR_UNSUPPORTED) Arrow.Table( + sparse_tensor_bytes, + ) + @test_throws ArgumentError(Arrow.SPARSE_TENSOR_UNSUPPORTED) collect( + Arrow.Stream(sparse_tensor_bytes), + ) + end + @testset "# 158" begin # arrow ipc stream generated from pyarrow with no record batches bytes = UInt8[ From 1eb955bf9b1b3e0f7ed2a1e9806726334a8e2f6b Mon Sep 17 00:00:00 2001 From: guangtao 
Date: Tue, 31 Mar 2026 15:55:22 -0700 Subject: [PATCH 10/16] Validate advanced canonical Arrow extensions --- Project.toml | 2 + README.md | 4 +- docs/src/manual.md | 4 +- src/Arrow.jl | 1 + src/eltypes.jl | 259 ++++++++++++++++++++++++++++++++++++++++++++- test/runtests.jl | 117 +++++++++++++++++--- 6 files changed, 369 insertions(+), 18 deletions(-) diff --git a/Project.toml b/Project.toml index 5b4a3bcc..8150433b 100644 --- a/Project.toml +++ b/Project.toml @@ -29,6 +29,7 @@ ConcurrentUtilities = "f0e56b4a-5159-44fe-b623-3e5288b988bb" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429" gRPCClient = "aaca4a50-36af-4a1d-b878-4c443f2061ad" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" @@ -57,6 +58,7 @@ CodecZstd = "0.7, 0.8" ConcurrentUtilities = "2" DataAPI = "1" EnumX = "1" +JSON3 = "1" ProtoBuf = "~1.2.1" gRPCClient = "1" gRPCServer = "0.1" diff --git a/README.md b/README.md index d6a5bca9..cf7316cb 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,9 @@ Canonical extension highlights: * `Arrow.Bool8` provides an explicit opt-in writer/reader surface for the canonical `arrow.bool8` extension without changing the default packed-bit `Bool` path * `Arrow.JSONText{String}` provides a text-backed logical type for the canonical `arrow.json` extension without parsing payloads during read or write * `arrow.opaque` now reads as the underlying storage type without warning, and explicit writer metadata can be generated with `Arrow.opaquemetadata(type_name, vendor_name)` - * `arrow.parquet.variant`, `arrow.fixed_shape_tensor`, and `arrow.variable_shape_tensor` are recognized on read as canonical passthrough extensions over their storage types without claiming full semantic interpretation or automatic writer support + * `Arrow.variantmetadata()`, `Arrow.fixedshapetensormetadata(...)`, 
and `Arrow.variableshapetensormetadata(...)` generate canonical metadata strings for advanced canonical extensions + * `arrow.fixed_shape_tensor` and `arrow.variable_shape_tensor` are recognized on read as canonical passthrough extensions over their storage types, and Arrow.jl now validates their canonical metadata plus top-level storage shape before accepting them + * `arrow.parquet.variant` is recognized on read as a canonical passthrough extension over its storage type; Arrow.jl currently validates that its canonical metadata is the required empty string, but does not yet implement deeper variant semantics or an automatic writer surface * Legacy `JuliaLang.ZonedDateTime-UTC` and `JuliaLang.ZonedDateTime` files remain readable for backward compatibility See the [full documentation](https://arrow.apache.org/julia/) for details on reading and writing arrow data. diff --git a/docs/src/manual.md b/docs/src/manual.md index 54cc9e74..e806f5f6 100644 --- a/docs/src/manual.md +++ b/docs/src/manual.md @@ -92,7 +92,9 @@ In the arrow data format, specific logical types are supported, a list of which * `Arrow.Bool8` is an explicit opt-in logical type for the canonical `arrow.bool8` extension; it uses `Int8` storage, while plain Julia `Bool` continues to use Arrow's packed-bit boolean layout * `Arrow.JSONText{String}` is a text-backed logical type for the canonical `arrow.json` extension; Arrow.jl preserves the payload as text and does not parse JSON automatically * `arrow.opaque` is treated as interoperability metadata over the underlying storage type; explicit metadata can be generated with `Arrow.opaquemetadata(type_name, vendor_name)` when writing -* `arrow.parquet.variant`, `arrow.fixed_shape_tensor`, and `arrow.variable_shape_tensor` are recognized as canonical passthrough extensions on read; Arrow.jl currently returns their underlying storage types and does not yet implement higher-level semantic interpretation or automatic writer surfaces for them +* 
`Arrow.variantmetadata()`, `Arrow.fixedshapetensormetadata(...)`, and `Arrow.variableshapetensormetadata(...)` generate canonical metadata strings for advanced canonical extensions when writing explicit storage-backed values +* `arrow.fixed_shape_tensor` and `arrow.variable_shape_tensor` are recognized as canonical passthrough extensions on read; Arrow.jl returns their underlying storage types, validates canonical metadata and top-level storage shape, and does not yet implement higher-level semantic interpretation or automatic writer surfaces for them +* `arrow.parquet.variant` is recognized as a canonical passthrough extension on read; Arrow.jl currently validates only the required empty metadata string and does not yet implement deeper variant semantics or an automatic writer surface * `Decimal128` and `Decimal256` have no corresponding builtin Julia types, so they're deserialized using a compatible type definition in Arrow.jl itself: `Arrow.Decimal` diff --git a/src/Arrow.jl b/src/Arrow.jl index a6fa58a6..b899e88f 100644 --- a/src/Arrow.jl +++ b/src/Arrow.jl @@ -66,6 +66,7 @@ using DataAPI, Tables, SentinelArrays, PooledArrays, + JSON3, CodecLz4, CodecZstd, TimeZones, diff --git a/src/eltypes.jl b/src/eltypes.jl index 24bdeac8..3ae5d56b 100644 --- a/src/eltypes.jl +++ b/src/eltypes.jl @@ -32,6 +32,197 @@ const PARQUET_VARIANT_SYMBOL = Symbol("arrow.parquet.variant") const FIXED_SHAPE_TENSOR_SYMBOL = Symbol("arrow.fixed_shape_tensor") const VARIABLE_SHAPE_TENSOR_SYMBOL = Symbol("arrow.variable_shape_tensor") +@inline _canonicalextensionerror(sym::Symbol, msg::AbstractString) = + throw(ArgumentError("invalid canonical $(String(sym)) extension: $msg")) + +@inline _fieldchildren(field::Meta.Field) = + field.children === nothing ? 
Meta.Field[] : field.children + +@inline _jsonhaskey(x, key::AbstractString) = haskey(x, key) +@inline _jsonget(x, key::AbstractString) = x[key] + +function _parsecanonicalmetadata(sym::Symbol, metadata::String; required::Bool=false) + isempty(metadata) && return required ? _canonicalextensionerror(sym, "metadata is required") : nothing + value = try + JSON3.read(metadata) + catch + _canonicalextensionerror(sym, "metadata must be valid JSON") + end + value isa JSON3.Object || _canonicalextensionerror(sym, "metadata must be a JSON object") + return value +end + +function _parseintvector(sym::Symbol, value, label::AbstractString; allow_null::Bool=false) + value isa AbstractVector || + _canonicalextensionerror(sym, "\"$label\" must be a JSON array") + parsed = Vector{allow_null ? Union{Nothing,Int} : Int}() + for item in value + if allow_null && isnothing(item) + push!(parsed, nothing) + elseif item isa Integer + item >= 0 || _canonicalextensionerror(sym, "\"$label\" values must be non-negative") + push!(parsed, Int(item)) + else + suffix = allow_null ? 
"integers or null" : "integers" + _canonicalextensionerror(sym, "\"$label\" must contain only $suffix") + end + end + return parsed +end + +function _parsestringvector(sym::Symbol, value, label::AbstractString) + value isa AbstractVector || + _canonicalextensionerror(sym, "\"$label\" must be a JSON array") + parsed = String[] + for item in value + item isa AbstractString || + _canonicalextensionerror(sym, "\"$label\" must contain only strings") + push!(parsed, String(item)) + end + return parsed +end + +function _validatepermutation(sym::Symbol, permutation::Vector{Int}, ndim::Int) + length(permutation) == ndim || + _canonicalextensionerror(sym, "\"permutation\" must have length $ndim") + length(unique(permutation)) == ndim || + _canonicalextensionerror(sym, "\"permutation\" must not contain duplicates") + return permutation +end + +function _extractdimensionalmetadata(sym::Symbol, metadata; ndim::Union{Nothing,Int}=nothing) + metadata === nothing && return (nothing, nothing, nothing) + dim_names = + _jsonhaskey(metadata, "dim_names") ? + _parsestringvector(sym, _jsonget(metadata, "dim_names"), "dim_names") : nothing + permutation = + _jsonhaskey(metadata, "permutation") ? + _parseintvector(sym, _jsonget(metadata, "permutation"), "permutation") : nothing + uniform_shape = + _jsonhaskey(metadata, "uniform_shape") ? 
+ _parseintvector(sym, _jsonget(metadata, "uniform_shape"), "uniform_shape"; allow_null=true) : + nothing + if ndim !== nothing + dim_names !== nothing && length(dim_names) == ndim || + isnothing(dim_names) || _canonicalextensionerror(sym, "\"dim_names\" must have length $ndim") + permutation !== nothing && _validatepermutation(sym, permutation, ndim) + uniform_shape !== nothing && length(uniform_shape) == ndim || + isnothing(uniform_shape) || + _canonicalextensionerror(sym, "\"uniform_shape\" must have length $ndim") + end + return dim_names, permutation, uniform_shape +end + +@inline _isliststoragetype(x) = + x isa Union{Meta.List,Meta.LargeList,Meta.ListView,Meta.LargeListView} + +@inline _isbinarystoragetype(x) = + x isa Union{Meta.Binary,Meta.LargeBinary,Meta.BinaryView,Meta.FixedSizeBinary} + +function _validateparquetvariant(field::Meta.Field, metadata::String) + isempty(metadata) || _canonicalextensionerror(PARQUET_VARIANT_SYMBOL, "metadata must be the empty string") + field + return +end + +function _validatefixedshapetensor(field::Meta.Field, metadata::String) + meta = _parsecanonicalmetadata(FIXED_SHAPE_TENSOR_SYMBOL, metadata; required=true) + _jsonhaskey(meta, "shape") || + _canonicalextensionerror(FIXED_SHAPE_TENSOR_SYMBOL, "\"shape\" is required") + shape = _parseintvector( + FIXED_SHAPE_TENSOR_SYMBOL, + _jsonget(meta, "shape"), + "shape", + ) + dim_names, permutation, _ = _extractdimensionalmetadata( + FIXED_SHAPE_TENSOR_SYMBOL, + meta; + ndim=length(shape), + ) + field.type isa Meta.FixedSizeList || + _canonicalextensionerror( + FIXED_SHAPE_TENSOR_SYMBOL, + "storage must be a FixedSizeList", + ) + length(collect(_fieldchildren(field))) == 1 || + _canonicalextensionerror( + FIXED_SHAPE_TENSOR_SYMBOL, + "storage must contain exactly one child field", + ) + expected = isempty(shape) ? 
1 : prod(shape) + Int(field.type.listSize) == expected || + _canonicalextensionerror( + FIXED_SHAPE_TENSOR_SYMBOL, + "\"shape\" product $expected does not match FixedSizeList size $(field.type.listSize)", + ) + dim_names + permutation + return +end + +function _validatevariableshapetensor(field::Meta.Field, metadata::String) + field.type isa Meta.Struct || + _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "storage must be a Struct", + ) + children = Dict(String(child.name) => child for child in collect(_fieldchildren(field))) + keys(children) == Set(("data", "shape")) || + _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "storage must contain exactly \"data\" and \"shape\" fields", + ) + data_field = children["data"] + shape_field = children["shape"] + _isliststoragetype(data_field.type) || + _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"data\" field must use list storage", + ) + length(collect(_fieldchildren(data_field))) == 1 || + _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"data\" field must contain exactly one child field", + ) + shape_field.type isa Meta.FixedSizeList || + _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"shape\" field must use FixedSizeList storage", + ) + shape_children = collect(_fieldchildren(shape_field)) + length(shape_children) == 1 || + _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"shape\" field must contain exactly one child field", + ) + shape_value = only(shape_children) + shape_value.type isa Meta.Int || + _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"shape\" values must use Int32 storage", + ) + (shape_value.type.bitWidth == 32 && shape_value.type.is_signed) || + _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"shape\" values must use signed Int32 storage", + ) + ndim = Int(shape_field.type.listSize) + meta = _parsecanonicalmetadata(VARIABLE_SHAPE_TENSOR_SYMBOL, metadata) + 
_extractdimensionalmetadata(VARIABLE_SHAPE_TENSOR_SYMBOL, meta; ndim=ndim) + return +end + +function _validatecanonicalpassthrough(field::Meta.Field, typenamesym::Symbol, metadata::String) + if typenamesym === PARQUET_VARIANT_SYMBOL + _validateparquetvariant(field, metadata) + elseif typenamesym === FIXED_SHAPE_TENSOR_SYMBOL + _validatefixedshapetensor(field, metadata) + elseif typenamesym === VARIABLE_SHAPE_TENSOR_SYMBOL + _validatevariableshapetensor(field, metadata) + end + return +end + """ Given a FlatBuffers.Builder and a Julia column or column eltype, Write the field.type flatbuffer definition of the eltype @@ -49,12 +240,13 @@ end function juliaeltype(f::Meta.Field, meta::AbstractDict{String,String}, convert::Bool) TT = juliaeltype(f, convert) - !convert && return TT - T = finaljuliatype(TT) if haskey(meta, "ARROW:extension:name") typename = meta["ARROW:extension:name"] metadata = get(meta, "ARROW:extension:metadata", "") typenamesym = Symbol(typename) + _validatecanonicalpassthrough(f, typenamesym, metadata) + !convert && return TT + T = finaljuliatype(TT) storageT = typenamesym === TIMESTAMP_WITH_OFFSET_SYMBOL ? maybemissing(juliaeltype(f, false)) : maybemissing(TT) @@ -66,6 +258,8 @@ function juliaeltype(f::Meta.Field, meta::AbstractDict{String,String}, convert:: 1 _id = hash((:juliaeltype, typename, TT)) end end + !convert && return TT + T = finaljuliatype(TT) return something(TT, T) end @@ -179,6 +373,67 @@ opaquemetadata(type_name::AbstractString, vendor_name::AbstractString) = "{\"type_name\":" * _jsonstringliteral(type_name) * ",\"vendor_name\":" * _jsonstringliteral(vendor_name) * "}" +variantmetadata() = "" + +function fixedshapetensormetadata( + shape::AbstractVector{<:Integer}; + dim_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, + permutation::Union{Nothing,AbstractVector{<:Integer}}=nothing, +) + parsed_shape = _parseintvector(FIXED_SHAPE_TENSOR_SYMBOL, collect(shape), "shape") + parsed_dim_names = dim_names === nothing ? 
nothing : String.(dim_names) + parsed_permutation = + permutation === nothing ? nothing : _validatepermutation( + FIXED_SHAPE_TENSOR_SYMBOL, + Int.(permutation), + length(parsed_shape), + ) + parsed_dim_names !== nothing && length(parsed_dim_names) == length(parsed_shape) || + isnothing(parsed_dim_names) || + _canonicalextensionerror( + FIXED_SHAPE_TENSOR_SYMBOL, + "\"dim_names\" must have length $(length(parsed_shape))", + ) + body = Dict{String,Any}("shape" => parsed_shape) + parsed_dim_names !== nothing && (body["dim_names"] = parsed_dim_names) + parsed_permutation !== nothing && (body["permutation"] = parsed_permutation) + return JSON3.write(body) +end + +function variableshapetensormetadata(; + uniform_shape::Union{Nothing,AbstractVector}=nothing, + dim_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, + permutation::Union{Nothing,AbstractVector{<:Integer}}=nothing, +) + uniform = uniform_shape === nothing ? nothing : + _parseintvector( + VARIABLE_SHAPE_TENSOR_SYMBOL, + collect(uniform_shape), + "uniform_shape"; + allow_null=true, + ) + ndim = uniform === nothing ? nothing : length(uniform) + parsed_dim_names = dim_names === nothing ? nothing : String.(dim_names) + parsed_permutation = + permutation === nothing ? nothing : + Int.(permutation) + ndim !== nothing && parsed_dim_names !== nothing && + length(parsed_dim_names) == ndim || + ndim === nothing || + isnothing(parsed_dim_names) || + _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"dim_names\" must have length $ndim", + ) + ndim !== nothing && parsed_permutation !== nothing && + _validatepermutation(VARIABLE_SHAPE_TENSOR_SYMBOL, parsed_permutation, ndim) + body = Dict{String,Any}() + uniform !== nothing && (body["uniform_shape"] = uniform) + parsed_dim_names !== nothing && (body["dim_names"] = parsed_dim_names) + parsed_permutation !== nothing && (body["permutation"] = parsed_permutation) + return isempty(body) ? 
"" : JSON3.write(body) +end + # primitive types function juliaeltype(f::Meta.Field, fp::Meta.FloatingPoint, convert) if fp.precision == Meta.Precision.HALF diff --git a/test/runtests.jl b/test/runtests.jl index db5f6d85..23f14544 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,7 @@ using CategoricalArrays using DataAPI using FilePathsBase using DataFrames +using JSON3 using OffsetArrays import Random: randstring using TestSetExtensions: ExtendedTestSet @@ -843,13 +844,55 @@ const hybrid = EnumRoundtripModule.hybrid end @testset "canonical advanced passthrough" begin + function assert_canonical_extension_error(f::Function, needle::AbstractString) + err = try + f() + nothing + catch e + e + end + @test err !== nothing + @test occursin(needle, sprint(showerror, err)) + return + end + + @test Arrow.variantmetadata() == "" + + fixed_metadata = Arrow.fixedshapetensormetadata( + [2, 2]; + dim_names=["x", "y"], + permutation=[1, 0], + ) + @test JSON3.read(fixed_metadata)["shape"] == [2, 2] + @test JSON3.read(fixed_metadata)["dim_names"] == ["x", "y"] + @test JSON3.read(fixed_metadata)["permutation"] == [1, 0] + + variable_metadata = Arrow.variableshapetensormetadata( + uniform_shape=Union{Nothing,Int}[2], + dim_names=["axis0"], + permutation=[0], + ) + @test JSON3.read(variable_metadata)["uniform_shape"] == [2] + @test JSON3.read(variable_metadata)["dim_names"] == ["axis0"] + @test JSON3.read(variable_metadata)["permutation"] == [0] + @test Arrow.variableshapetensormetadata() == "" + + @test_throws ArgumentError Arrow.fixedshapetensormetadata( + [2, 2]; + dim_names=["x"], + ) + @test_throws ArgumentError Arrow.variableshapetensormetadata( + uniform_shape=Union{Nothing,Int}[2, nothing]; + permutation=[0], + ) + variant_values = Union{ Missing, NamedTuple{(:metadata, :value),Tuple{String,String}}, }[ (metadata="json", value="{\"a\":1}"), missing, - (metadata="string", value="abc"), + (metadata="str", value="abc"), ] @test_logs 
min_level=Base.CoreLogging.Warn begin variant_tt = Arrow.Table( @@ -858,7 +901,7 @@ const hybrid = EnumRoundtripModule.hybrid colmetadata=Dict( :col => Dict( "ARROW:extension:name" => "arrow.parquet.variant", - "ARROW:extension:metadata" => "", + "ARROW:extension:metadata" => Arrow.variantmetadata(), ), ), ), @@ -881,7 +924,7 @@ const hybrid = EnumRoundtripModule.hybrid colmetadata=Dict( :col => Dict( "ARROW:extension:name" => "arrow.fixed_shape_tensor", - "ARROW:extension:metadata" => "{\"shape\":[2,2],\"dim_names\":[\"x\",\"y\"]}", + "ARROW:extension:metadata" => fixed_metadata, ), ), ), @@ -892,10 +935,13 @@ const hybrid = EnumRoundtripModule.hybrid "arrow.fixed_shape_tensor" end - variable_tensor_values = Union{Missing,Vector{Int32}}[ - Int32[1, 2, 3, 4], + variable_tensor_values = Union{ + Missing, + NamedTuple{(:data, :shape),Tuple{Vector{Int32},NTuple{1,Int32}}}, + }[ + (data=Int32[1, 2, 3, 4], shape=(Int32(2),)), missing, - Int32[5, 6], + (data=Int32[5, 6], shape=(Int32(1),)), ] @test_logs min_level=Base.CoreLogging.Warn begin variable_tensor_tt = Arrow.Table( @@ -904,23 +950,66 @@ const hybrid = EnumRoundtripModule.hybrid colmetadata=Dict( :col => Dict( "ARROW:extension:name" => "arrow.variable_shape_tensor", - "ARROW:extension:metadata" => "{\"uniform_shape\":[2]}", + "ARROW:extension:metadata" => variable_metadata, ), ), ), ) - @test Base.nonmissingtype(eltype(variable_tensor_tt.col)) <: AbstractVector{Int32} + @test eltype(variable_tensor_tt.col) == eltype(variable_tensor_values) @test isequal( - map(x -> x === missing ? missing : copy(x), copy(variable_tensor_tt.col)), - Union{Missing,Vector{Int32}}[ - Int32[1, 2, 3, 4], - missing, - Int32[5, 6], - ], + map( + x -> x === missing ? 
missing : (data=copy(x.data), shape=x.shape), + copy(variable_tensor_tt.col), + ), + variable_tensor_values, ) @test Arrow.getmetadata(variable_tensor_tt.col)["ARROW:extension:name"] == "arrow.variable_shape_tensor" end + + invalid_variant_bytes = Arrow.tobuffer( + (col=variant_values,); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.parquet.variant", + "ARROW:extension:metadata" => "{\"unexpected\":true}", + ), + ), + ) + assert_canonical_extension_error( + () -> Arrow.Table(invalid_variant_bytes), + "invalid canonical arrow.parquet.variant extension", + ) + + invalid_fixed_bytes = Arrow.tobuffer( + (col=fixed_tensor_values,); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.fixed_shape_tensor", + "ARROW:extension:metadata" => Arrow.fixedshapetensormetadata([3, 2]), + ), + ), + ) + assert_canonical_extension_error( + () -> Arrow.Table(invalid_fixed_bytes), + "invalid canonical arrow.fixed_shape_tensor extension", + ) + + invalid_variable_bytes = Arrow.tobuffer( + (col=["a", "b"],); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "arrow.variable_shape_tensor", + "ARROW:extension:metadata" => Arrow.variableshapetensormetadata( + uniform_shape=Union{Nothing,Int}[1], + ), + ), + ), + ) + assert_canonical_extension_error( + () -> Arrow.Table(invalid_variable_bytes), + "invalid canonical arrow.variable_shape_tensor extension", + ) end @testset "tensor message boundary" begin From 412b8679a186a013a299bfbab246d14c649f8be8 Mon Sep 17 00:00:00 2001 From: guangtao Date: Tue, 31 Mar 2026 17:53:20 -0700 Subject: [PATCH 11/16] Centralize logical extension runtime contract --- src/Arrow.jl | 1 + src/ArrowTypes/src/ArrowTypes.jl | 7 +- src/ArrowTypes/test/tests.jl | 3 +- src/arraytypes/arraytypes.jl | 19 +-- src/arraytypes/runendencoded.jl | 12 +- src/eltypes.jl | 242 ++++++++++++++++++------------- src/logicaltypes.jl | 84 +++++++++++ src/table.jl | 3 +- test/flight/ipc_conversion.jl | 20 +++ test/runtests.jl | 161 
++++++++++++-------- 10 files changed, 362 insertions(+), 190 deletions(-) create mode 100644 src/logicaltypes.jl diff --git a/src/Arrow.jl b/src/Arrow.jl index b899e88f..8223449e 100644 --- a/src/Arrow.jl +++ b/src/Arrow.jl @@ -94,6 +94,7 @@ const Meta = Flatbuf using ArrowTypes include("utils.jl") +include("logicaltypes.jl") include("arraytypes/arraytypes.jl") include("eltypes.jl") include("table.jl") diff --git a/src/ArrowTypes/src/ArrowTypes.jl b/src/ArrowTypes/src/ArrowTypes.jl index ce832710..319224ee 100644 --- a/src/ArrowTypes/src/ArrowTypes.jl +++ b/src/ArrowTypes/src/ArrowTypes.jl @@ -225,7 +225,10 @@ end function _enum_labels(::Type{T}) where {T<:Enum} B = Base.Enums.basetype(T) - return join((string(instance, ":", convert(B, Int(instance))) for instance in instances(T)), ",") + return join( + (string(instance, ":", convert(B, Int(instance))) for instance in instances(T)), + ",", + ) end function arrowmetadata(::Type{T}) where {T<:Enum} @@ -233,7 +236,7 @@ function arrowmetadata(::Type{T}) where {T<:Enum} end function _parsemetadata(metadata::AbstractString) - parsed = Dict{String, String}() + parsed = Dict{String,String}() isempty(metadata) && return parsed for entry in split(metadata, ';') isempty(entry) && continue diff --git a/src/ArrowTypes/test/tests.jl b/src/ArrowTypes/test/tests.jl index 16f420e5..3d985b22 100644 --- a/src/ArrowTypes/test/tests.jl +++ b/src/ArrowTypes/test/tests.jl @@ -83,7 +83,8 @@ const hybrid = EnumTestModule.hybrid @test ArrowTypes.arrowname(RankingStrategy) == ArrowTypes.ENUM @test occursin("type=Main.EnumTestModule.RankingStrategy", enum_metadata) @test occursin("labels=lexical:1,semantic:2,hybrid:3", enum_metadata) - @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), Int32, enum_metadata) == RankingStrategy + @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), Int32, enum_metadata) == + RankingStrategy @test ArrowTypes.fromarrow(RankingStrategy, Int32(2)) == semantic @test ArrowTypes.default(RankingStrategy) == lexical diff 
--git a/src/arraytypes/arraytypes.jl b/src/arraytypes/arraytypes.jl index 9f04292e..281dab14 100644 --- a/src/arraytypes/arraytypes.jl +++ b/src/arraytypes/arraytypes.jl @@ -103,13 +103,7 @@ function arrowvector( x = ToArrow(x) end S = maybemissing(eltype(x)) - if ArrowTypes.hasarrowname(T) - meta = _arrowtypemeta( - _normalizemeta(meta), - String(ArrowTypes.arrowname(T)), - String(ArrowTypes.arrowmetadata(T)), - ) - end + meta = _extensionmetadatafor(T, _normalizemeta(meta)) return arrowvector( S, x, @@ -133,17 +127,6 @@ _normalizecolmeta(colmeta) = toidict( Symbol(k) => toidict(String(v1) => String(v2) for (v1, v2) in v) for (k, v) in colmeta ) -function _arrowtypemeta(::Nothing, n, m) - return toidict(("ARROW:extension:name" => n, "ARROW:extension:metadata" => m)) -end - -function _arrowtypemeta(meta, n, m) - dict = Dict(meta) - dict["ARROW:extension:name"] = n - dict["ARROW:extension:metadata"] = m - return toidict(dict) -end - @inline function _materializeconverted(x::ArrowTypes.ToArrow) data = ArrowTypes._sourcedata(x) if ArrowTypes._needsconvert(x) && !ArrowTypes.concrete_or_concreteunion(eltype(data)) diff --git a/src/arraytypes/runendencoded.jl b/src/arraytypes/runendencoded.jl index 7ca318b0..05946e38 100644 --- a/src/arraytypes/runendencoded.jl +++ b/src/arraytypes/runendencoded.jl @@ -64,11 +64,13 @@ function _validaterunendencoded(run_ends, values, len) ) last_end = current_end end - len == 0 || last_end == len || throw( - ArgumentError( - "invalid Run-End Encoded array: final run end $last_end does not match logical length $len", - ), - ) + len == 0 || + last_end == len || + throw( + ArgumentError( + "invalid Run-End Encoded array: final run end $last_end does not match logical length $len", + ), + ) return end diff --git a/src/eltypes.jl b/src/eltypes.jl index 3ae5d56b..9fbacb94 100644 --- a/src/eltypes.jl +++ b/src/eltypes.jl @@ -32,6 +32,13 @@ const PARQUET_VARIANT_SYMBOL = Symbol("arrow.parquet.variant") const FIXED_SHAPE_TENSOR_SYMBOL = 
Symbol("arrow.fixed_shape_tensor") const VARIABLE_SHAPE_TENSOR_SYMBOL = Symbol("arrow.variable_shape_tensor") +_builtinextensionspec(::Type{ArrowTypes.UUID}) = + ExtensionTypeSpec(ArrowTypes.UUIDSYMBOL, "") +_builtinextensionjuliatype(::Val{ArrowTypes.UUIDSYMBOL}, S, metadata::String) = + ArrowTypes.UUID +_builtinextensionjuliatype(::Val{ArrowTypes.LEGACY_UUIDSYMBOL}, S, metadata::String) = + ArrowTypes.UUID + @inline _canonicalextensionerror(sym::Symbol, msg::AbstractString) = throw(ArgumentError("invalid canonical $(String(sym)) extension: $msg")) @@ -42,13 +49,15 @@ const VARIABLE_SHAPE_TENSOR_SYMBOL = Symbol("arrow.variable_shape_tensor") @inline _jsonget(x, key::AbstractString) = x[key] function _parsecanonicalmetadata(sym::Symbol, metadata::String; required::Bool=false) - isempty(metadata) && return required ? _canonicalextensionerror(sym, "metadata is required") : nothing + isempty(metadata) && + return required ? _canonicalextensionerror(sym, "metadata is required") : nothing value = try JSON3.read(metadata) catch _canonicalextensionerror(sym, "metadata must be valid JSON") end - value isa JSON3.Object || _canonicalextensionerror(sym, "metadata must be a JSON object") + value isa JSON3.Object || + _canonicalextensionerror(sym, "metadata must be a JSON object") return value end @@ -60,7 +69,8 @@ function _parseintvector(sym::Symbol, value, label::AbstractString; allow_null:: if allow_null && isnothing(item) push!(parsed, nothing) elseif item isa Integer - item >= 0 || _canonicalextensionerror(sym, "\"$label\" values must be non-negative") + item >= 0 || + _canonicalextensionerror(sym, "\"$label\" values must be non-negative") push!(parsed, Int(item)) else suffix = allow_null ? 
"integers or null" : "integers" @@ -90,7 +100,11 @@ function _validatepermutation(sym::Symbol, permutation::Vector{Int}, ndim::Int) return permutation end -function _extractdimensionalmetadata(sym::Symbol, metadata; ndim::Union{Nothing,Int}=nothing) +function _extractdimensionalmetadata( + sym::Symbol, + metadata; + ndim::Union{Nothing,Int}=nothing, +) metadata === nothing && return (nothing, nothing, nothing) dim_names = _jsonhaskey(metadata, "dim_names") ? @@ -100,11 +114,16 @@ function _extractdimensionalmetadata(sym::Symbol, metadata; ndim::Union{Nothing, _parseintvector(sym, _jsonget(metadata, "permutation"), "permutation") : nothing uniform_shape = _jsonhaskey(metadata, "uniform_shape") ? - _parseintvector(sym, _jsonget(metadata, "uniform_shape"), "uniform_shape"; allow_null=true) : - nothing + _parseintvector( + sym, + _jsonget(metadata, "uniform_shape"), + "uniform_shape"; + allow_null=true, + ) : nothing if ndim !== nothing dim_names !== nothing && length(dim_names) == ndim || - isnothing(dim_names) || _canonicalextensionerror(sym, "\"dim_names\" must have length $ndim") + isnothing(dim_names) || + _canonicalextensionerror(sym, "\"dim_names\" must have length $ndim") permutation !== nothing && _validatepermutation(sym, permutation, ndim) uniform_shape !== nothing && length(uniform_shape) == ndim || isnothing(uniform_shape) || @@ -120,7 +139,10 @@ end x isa Union{Meta.Binary,Meta.LargeBinary,Meta.BinaryView,Meta.FixedSizeBinary} function _validateparquetvariant(field::Meta.Field, metadata::String) - isempty(metadata) || _canonicalextensionerror(PARQUET_VARIANT_SYMBOL, "metadata must be the empty string") + isempty(metadata) || _canonicalextensionerror( + PARQUET_VARIANT_SYMBOL, + "metadata must be the empty string", + ) field return end @@ -129,32 +151,22 @@ function _validatefixedshapetensor(field::Meta.Field, metadata::String) meta = _parsecanonicalmetadata(FIXED_SHAPE_TENSOR_SYMBOL, metadata; required=true) _jsonhaskey(meta, "shape") || 
_canonicalextensionerror(FIXED_SHAPE_TENSOR_SYMBOL, "\"shape\" is required") - shape = _parseintvector( + shape = _parseintvector(FIXED_SHAPE_TENSOR_SYMBOL, _jsonget(meta, "shape"), "shape") + dim_names, permutation, _ = + _extractdimensionalmetadata(FIXED_SHAPE_TENSOR_SYMBOL, meta; ndim=length(shape)) + field.type isa Meta.FixedSizeList || _canonicalextensionerror( FIXED_SHAPE_TENSOR_SYMBOL, - _jsonget(meta, "shape"), - "shape", + "storage must be a FixedSizeList", ) - dim_names, permutation, _ = _extractdimensionalmetadata( + length(collect(_fieldchildren(field))) == 1 || _canonicalextensionerror( FIXED_SHAPE_TENSOR_SYMBOL, - meta; - ndim=length(shape), + "storage must contain exactly one child field", ) - field.type isa Meta.FixedSizeList || - _canonicalextensionerror( - FIXED_SHAPE_TENSOR_SYMBOL, - "storage must be a FixedSizeList", - ) - length(collect(_fieldchildren(field))) == 1 || - _canonicalextensionerror( - FIXED_SHAPE_TENSOR_SYMBOL, - "storage must contain exactly one child field", - ) expected = isempty(shape) ? 
1 : prod(shape) - Int(field.type.listSize) == expected || - _canonicalextensionerror( - FIXED_SHAPE_TENSOR_SYMBOL, - "\"shape\" product $expected does not match FixedSizeList size $(field.type.listSize)", - ) + Int(field.type.listSize) == expected || _canonicalextensionerror( + FIXED_SHAPE_TENSOR_SYMBOL, + "\"shape\" product $expected does not match FixedSizeList size $(field.type.listSize)", + ) dim_names permutation return @@ -162,45 +174,36 @@ end function _validatevariableshapetensor(field::Meta.Field, metadata::String) field.type isa Meta.Struct || - _canonicalextensionerror( - VARIABLE_SHAPE_TENSOR_SYMBOL, - "storage must be a Struct", - ) + _canonicalextensionerror(VARIABLE_SHAPE_TENSOR_SYMBOL, "storage must be a Struct") children = Dict(String(child.name) => child for child in collect(_fieldchildren(field))) - keys(children) == Set(("data", "shape")) || - _canonicalextensionerror( - VARIABLE_SHAPE_TENSOR_SYMBOL, - "storage must contain exactly \"data\" and \"shape\" fields", - ) + keys(children) == Set(("data", "shape")) || _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "storage must contain exactly \"data\" and \"shape\" fields", + ) data_field = children["data"] shape_field = children["shape"] - _isliststoragetype(data_field.type) || - _canonicalextensionerror( - VARIABLE_SHAPE_TENSOR_SYMBOL, - "\"data\" field must use list storage", - ) - length(collect(_fieldchildren(data_field))) == 1 || - _canonicalextensionerror( - VARIABLE_SHAPE_TENSOR_SYMBOL, - "\"data\" field must contain exactly one child field", - ) - shape_field.type isa Meta.FixedSizeList || - _canonicalextensionerror( - VARIABLE_SHAPE_TENSOR_SYMBOL, - "\"shape\" field must use FixedSizeList storage", - ) + _isliststoragetype(data_field.type) || _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"data\" field must use list storage", + ) + length(collect(_fieldchildren(data_field))) == 1 || _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"data\" field must 
contain exactly one child field", + ) + shape_field.type isa Meta.FixedSizeList || _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"shape\" field must use FixedSizeList storage", + ) shape_children = collect(_fieldchildren(shape_field)) - length(shape_children) == 1 || - _canonicalextensionerror( - VARIABLE_SHAPE_TENSOR_SYMBOL, - "\"shape\" field must contain exactly one child field", - ) + length(shape_children) == 1 || _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"shape\" field must contain exactly one child field", + ) shape_value = only(shape_children) - shape_value.type isa Meta.Int || - _canonicalextensionerror( - VARIABLE_SHAPE_TENSOR_SYMBOL, - "\"shape\" values must use Int32 storage", - ) + shape_value.type isa Meta.Int || _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"shape\" values must use Int32 storage", + ) (shape_value.type.bitWidth == 32 && shape_value.type.is_signed) || _canonicalextensionerror( VARIABLE_SHAPE_TENSOR_SYMBOL, @@ -212,7 +215,11 @@ function _validatevariableshapetensor(field::Meta.Field, metadata::String) return end -function _validatecanonicalpassthrough(field::Meta.Field, typenamesym::Symbol, metadata::String) +function _validatecanonicalpassthrough( + field::Meta.Field, + typenamesym::Symbol, + metadata::String, +) if typenamesym === PARQUET_VARIANT_SYMBOL _validateparquetvariant(field, metadata) elseif typenamesym === FIXED_SHAPE_TENSOR_SYMBOL @@ -240,21 +247,20 @@ end function juliaeltype(f::Meta.Field, meta::AbstractDict{String,String}, convert::Bool) TT = juliaeltype(f, convert) - if haskey(meta, "ARROW:extension:name") - typename = meta["ARROW:extension:name"] - metadata = get(meta, "ARROW:extension:metadata", "") - typenamesym = Symbol(typename) - _validatecanonicalpassthrough(f, typenamesym, metadata) + spec = _extensionspec(meta) + if spec !== nothing + _validatecanonicalpassthrough(f, spec.name, spec.metadata) !convert && return TT T = finaljuliatype(TT) storageT = - 
typenamesym === TIMESTAMP_WITH_OFFSET_SYMBOL ? maybemissing(juliaeltype(f, false)) : - maybemissing(TT) - JT = ArrowTypes.JuliaType(Val(typenamesym), storageT, metadata) + spec.name === TIMESTAMP_WITH_OFFSET_SYMBOL ? + maybemissing(juliaeltype(f, false)) : maybemissing(TT) + JT = _resolveextensionjuliatype(spec, storageT) if JT !== nothing return f.nullable ? Union{JT,Missing} : JT else - @warn "unsupported ARROW:extension:name type: \"$typename\", arrow type = $TT" maxlog = + typename = _extensiontypename(spec) + @warn "unsupported $(EXTENSION_NAME_KEY) type: \"$typename\", arrow type = $TT" maxlog = 1 _id = hash((:juliaeltype, typename, TT)) end end @@ -333,8 +339,14 @@ ArrowTypes.arrowname(::Type{Bool8}) = BOOL8_SYMBOL ArrowTypes.JuliaType(::Val{BOOL8_SYMBOL}, ::Type{Int8}, metadata::String) = Bool8 ArrowTypes.fromarrow(::Type{Bool8}, x::Int8) = Bool8(x) ArrowTypes.default(::Type{Bool8}) = zero(Bool8) - -function writearray(io::IO, ::Type{Int8}, col::ArrowTypes.ToArrow{Int8,A}) where {A<:AbstractVector{Bool8}} +_builtinextensionspec(::Type{Bool8}) = ExtensionTypeSpec(BOOL8_SYMBOL, "") +_builtinextensionjuliatype(::Val{BOOL8_SYMBOL}, ::Type{Int8}, metadata::String) = Bool8 + +function writearray( + io::IO, + ::Type{Int8}, + col::ArrowTypes.ToArrow{Int8,A}, +) where {A<:AbstractVector{Bool8}} data = ArrowTypes._sourcedata(col) strides(data) == (1,) || return _writearrayfallback(io, Int8, col) return Base.write(io, reinterpret(Int8, data)) @@ -352,26 +364,43 @@ Base.isequal(x::JSONText, y::JSONText) = isequal(getfield(x, :value), getfield(y ArrowTypes.ArrowType(::Type{JSONText{S}}) where {S<:AbstractString} = S ArrowTypes.toarrow(x::JSONText) = getfield(x, :value) ArrowTypes.arrowname(::Type{JSONText{S}}) where {S<:AbstractString} = JSON_SYMBOL -ArrowTypes.JuliaType(::Val{JSON_SYMBOL}, ::Type{S}, metadata::String) where {S<:AbstractString} = - JSONText{S} +ArrowTypes.JuliaType( + ::Val{JSON_SYMBOL}, + ::Type{S}, + metadata::String, +) where {S<:AbstractString} = 
JSONText{S} ArrowTypes.fromarrow(::Type{JSONText{String}}, ptr::Ptr{UInt8}, len::Int) = JSONText(unsafe_string(ptr, len)) ArrowTypes.fromarrow(::Type{JSONText{S}}, x::S) where {S<:AbstractString} = JSONText{S}(x) ArrowTypes.default(::Type{JSONText{S}}) where {S<:AbstractString} = JSONText{S}(ArrowTypes.default(S)) +_builtinextensionspec(::Type{JSONText{S}}) where {S<:AbstractString} = + ExtensionTypeSpec(JSON_SYMBOL, "") +_builtinextensionjuliatype( + ::Val{JSON_SYMBOL}, + ::Type{S}, + metadata::String, +) where {S<:AbstractString} = JSONText{S} ArrowTypes.JuliaType(::Val{OPAQUE_SYMBOL}, S, metadata::String) = S ArrowTypes.JuliaType(::Val{PARQUET_VARIANT_SYMBOL}, S, metadata::String) = S ArrowTypes.JuliaType(::Val{FIXED_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S ArrowTypes.JuliaType(::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S +_builtinextensionjuliatype(::Val{OPAQUE_SYMBOL}, S, metadata::String) = S +_builtinextensionjuliatype(::Val{PARQUET_VARIANT_SYMBOL}, S, metadata::String) = S +_builtinextensionjuliatype(::Val{FIXED_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S +_builtinextensionjuliatype(::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S @inline function _jsonstringliteral(x::AbstractString) return '"' * escape_string(x) * '"' end opaquemetadata(type_name::AbstractString, vendor_name::AbstractString) = - "{\"type_name\":" * _jsonstringliteral(type_name) * - ",\"vendor_name\":" * _jsonstringliteral(vendor_name) * "}" + "{\"type_name\":" * + _jsonstringliteral(type_name) * + ",\"vendor_name\":" * + _jsonstringliteral(vendor_name) * + "}" variantmetadata() = "" @@ -383,7 +412,8 @@ function fixedshapetensormetadata( parsed_shape = _parseintvector(FIXED_SHAPE_TENSOR_SYMBOL, collect(shape), "shape") parsed_dim_names = dim_names === nothing ? nothing : String.(dim_names) parsed_permutation = - permutation === nothing ? nothing : _validatepermutation( + permutation === nothing ? 
nothing : + _validatepermutation( FIXED_SHAPE_TENSOR_SYMBOL, Int.(permutation), length(parsed_shape), @@ -405,27 +435,26 @@ function variableshapetensormetadata(; dim_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, permutation::Union{Nothing,AbstractVector{<:Integer}}=nothing, ) - uniform = uniform_shape === nothing ? nothing : - _parseintvector( - VARIABLE_SHAPE_TENSOR_SYMBOL, - collect(uniform_shape), - "uniform_shape"; - allow_null=true, - ) + uniform = + uniform_shape === nothing ? nothing : + _parseintvector( + VARIABLE_SHAPE_TENSOR_SYMBOL, + collect(uniform_shape), + "uniform_shape"; + allow_null=true, + ) ndim = uniform === nothing ? nothing : length(uniform) parsed_dim_names = dim_names === nothing ? nothing : String.(dim_names) - parsed_permutation = - permutation === nothing ? nothing : - Int.(permutation) - ndim !== nothing && parsed_dim_names !== nothing && - length(parsed_dim_names) == ndim || + parsed_permutation = permutation === nothing ? nothing : Int.(permutation) + ndim !== nothing && parsed_dim_names !== nothing && length(parsed_dim_names) == ndim || ndim === nothing || isnothing(parsed_dim_names) || _canonicalextensionerror( VARIABLE_SHAPE_TENSOR_SYMBOL, "\"dim_names\" must have length $ndim", ) - ndim !== nothing && parsed_permutation !== nothing && + ndim !== nothing && + parsed_permutation !== nothing && _validatepermutation(VARIABLE_SHAPE_TENSOR_SYMBOL, parsed_permutation, ndim) body = Dict{String,Any}() uniform !== nothing && (body["uniform_shape"] = uniform) @@ -596,10 +625,8 @@ struct TimestampWithOffset{U} offset_minutes::Int16 end -TimestampWithOffset( - timestamp::Timestamp{U,:UTC}, - offset_minutes::Integer, -) where {U} = TimestampWithOffset{U}(timestamp, Int16(offset_minutes)) +TimestampWithOffset(timestamp::Timestamp{U,:UTC}, offset_minutes::Integer) where {U} = + TimestampWithOffset{U}(timestamp, Int16(offset_minutes)) Base.zero(::Type{TimestampWithOffset{U}}) where {U} = 
TimestampWithOffset{U}(zero(Timestamp{U,:UTC}), Int16(0)) @@ -673,17 +700,15 @@ ArrowTypes.JuliaType(::Val{ZONEDDATETIME_SYMBOL}, S) = ZonedDateTime ArrowTypes.fromarrow(::Type{ZonedDateTime}, x::Timestamp) = convert(ZonedDateTime, x) ArrowTypes.default(::Type{TimeZones.ZonedDateTime}) = TimeZones.ZonedDateTime(1, 1, 1, 1, 1, 1, TimeZones.tz"UTC") +_builtinextensionspec(::Type{ZonedDateTime}) = ExtensionTypeSpec(ZONEDDATETIME_SYMBOL, "") +_builtinextensionjuliatype(::Val{ZONEDDATETIME_SYMBOL}, S, metadata::String) = ZonedDateTime const TIMESTAMP_WITH_OFFSET_SYMBOL = Symbol("arrow.timestamp_with_offset") -ArrowTypes.arrowname(::Type{TimestampWithOffset{U}}) where {U} = TIMESTAMP_WITH_OFFSET_SYMBOL +ArrowTypes.arrowname(::Type{TimestampWithOffset{U}}) where {U} = + TIMESTAMP_WITH_OFFSET_SYMBOL ArrowTypes.JuliaType( ::Val{TIMESTAMP_WITH_OFFSET_SYMBOL}, - ::Type{ - NamedTuple{ - (:timestamp, :offset_minutes), - Tuple{Timestamp{U,:UTC},Int16}, - }, - }, + ::Type{NamedTuple{(:timestamp, :offset_minutes),Tuple{Timestamp{U,:UTC},Int16}}}, metadata::String, ) where {U} = TimestampWithOffset{U} ArrowTypes.default(::Type{TimestampWithOffset{U}}) where {U} = zero(TimestampWithOffset{U}) @@ -699,12 +724,21 @@ ArrowTypes.fromarrowstruct( offset_minutes::Int16, timestamp::Timestamp{U,:UTC}, ) where {U} = TimestampWithOffset{U}(timestamp, offset_minutes) +_builtinextensionspec(::Type{TimestampWithOffset{U}}) where {U} = + ExtensionTypeSpec(TIMESTAMP_WITH_OFFSET_SYMBOL, "") +_builtinextensionjuliatype( + ::Val{TIMESTAMP_WITH_OFFSET_SYMBOL}, + ::Type{NamedTuple{(:timestamp, :offset_minutes),Tuple{Timestamp{U,:UTC},Int16}}}, + metadata::String, +) where {U} = TimestampWithOffset{U} # Backwards compatibility: older versions of Arrow saved ZonedDateTime's with this metdata: const OLD_ZONEDDATETIME_SYMBOL = Symbol("JuliaLang.ZonedDateTime") # and stored the local time instead of the UTC time. 
struct LocalZonedDateTime end ArrowTypes.JuliaType(::Val{OLD_ZONEDDATETIME_SYMBOL}, S) = LocalZonedDateTime +_builtinextensionjuliatype(::Val{OLD_ZONEDDATETIME_SYMBOL}, S, metadata::String) = + LocalZonedDateTime function ArrowTypes.fromarrow(::Type{LocalZonedDateTime}, x::Timestamp{U,TZ}) where {U,TZ} (U === Meta.TimeUnit.MICROSECOND || U == Meta.TimeUnit.NANOSECOND) && warntimestamp(U, ZonedDateTime) diff --git a/src/logicaltypes.jl b/src/logicaltypes.jl new file mode 100644 index 00000000..cdd28844 --- /dev/null +++ b/src/logicaltypes.jl @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +const EXTENSION_NAME_KEY = "ARROW:extension:name" +const EXTENSION_METADATA_KEY = "ARROW:extension:metadata" + +struct ExtensionTypeSpec + name::Symbol + metadata::String +end + +@inline _extensiontypename(spec::ExtensionTypeSpec) = String(spec.name) +@inline _builtinextensionspec(::Type{T}) where {T} = nothing +@inline _builtinextensionjuliatype(::Val{name}, storageT, metadata) where {name} = nothing + +@inline function _extensionmetadatafor(::Type{T}, meta) where {T} + spec = _extensionspec(T) + spec === nothing && return meta + return _mergeextensionmeta(meta, spec) +end + +@inline function _extensionspec(::Type{T}) where {T} + spec = _builtinextensionspec(T) + spec !== nothing && return spec + ArrowTypes.hasarrowname(T) || return nothing + return ExtensionTypeSpec(ArrowTypes.arrowname(T), String(ArrowTypes.arrowmetadata(T))) +end + +@inline function _extensionspec(meta::AbstractDict{String,String}) + haskey(meta, EXTENSION_NAME_KEY) || return nothing + return ExtensionTypeSpec( + Symbol(meta[EXTENSION_NAME_KEY]), + get(meta, EXTENSION_METADATA_KEY, ""), + ) +end + +function _mergeextensionmeta(::Nothing, spec::ExtensionTypeSpec) + return toidict(( + EXTENSION_NAME_KEY => _extensiontypename(spec), + EXTENSION_METADATA_KEY => spec.metadata, + ),) +end + +function _mergeextensionmeta(::Nothing, name::Symbol, metadata::String) + return toidict((EXTENSION_NAME_KEY => String(name), EXTENSION_METADATA_KEY => metadata)) +end + +function _mergeextensionmeta(meta, spec::ExtensionTypeSpec) + dict = Dict(meta) + dict[EXTENSION_NAME_KEY] = _extensiontypename(spec) + dict[EXTENSION_METADATA_KEY] = spec.metadata + return toidict(dict) +end + +function _mergeextensionmeta(meta, name::Symbol, metadata::String) + dict = Dict(meta) + dict[EXTENSION_NAME_KEY] = String(name) + dict[EXTENSION_METADATA_KEY] = metadata + return toidict(dict) +end + +@inline function _builtinextensionjuliatype(spec::ExtensionTypeSpec, storageT) + return _builtinextensionjuliatype(Val(spec.name), 
storageT, spec.metadata) +end + +@inline function _resolveextensionjuliatype(spec::ExtensionTypeSpec, storageT) + builtin = _builtinextensionjuliatype(spec, storageT) + builtin !== nothing && return builtin + return ArrowTypes.JuliaType(Val(spec.name), storageT, spec.metadata) +end diff --git a/src/table.jl b/src/table.jl index 0af09dd8..991da43e 100644 --- a/src/table.jl +++ b/src/table.jl @@ -28,7 +28,8 @@ tobytes(io::IO) = Base.read(io) tobytes(io::IOStream) = Mmap.mmap(io) tobytes(file_path) = open(tobytes, file_path, "r") -rejectunsupported(field::Meta.Field) = (rejectunsupported(field.type); foreach(rejectunsupported, field.children)) +rejectunsupported(field::Meta.Field) = + (rejectunsupported(field.type); foreach(rejectunsupported, field.children)) rejectunsupported(x) = nothing struct BatchIterator diff --git a/test/flight/ipc_conversion.jl b/test/flight/ipc_conversion.jl index a2d20671..0b59b662 100644 --- a/test/flight/ipc_conversion.jl +++ b/test/flight/ipc_conversion.jl @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. 
+using Tables +using UUIDs + @testset "Flight IPC conversion helpers" begin missing_schema_fragment = "the server may have terminated the stream before emitting the first schema-bearing FlightData message" descriptor = Arrow.Flight.Protocol.FlightDescriptor( @@ -75,4 +78,21 @@ ) @test isempty(empty_tbl.id) @test isempty(empty_tbl.label) + + extension_source = ( + uuid=[UUID(UInt128(1)), UUID(UInt128(2))], + flag=[Arrow.Bool8(true), Arrow.Bool8(false)], + ) + extension_messages = Arrow.Flight.flightdata(extension_source) + extension_batches = collect(Arrow.Flight.stream(extension_messages)) + extension_tbl = Arrow.Flight.table(extension_messages) + + @test Arrow.getmetadata(extension_batches[1].uuid)[Arrow.EXTENSION_NAME_KEY] == + "arrow.uuid" + @test Arrow.getmetadata(extension_batches[1].flag)[Arrow.EXTENSION_NAME_KEY] == + "arrow.bool8" + @test Arrow.getmetadata(extension_tbl.uuid)[Arrow.EXTENSION_NAME_KEY] == "arrow.uuid" + @test Arrow.getmetadata(extension_tbl.flag)[Arrow.EXTENSION_NAME_KEY] == "arrow.bool8" + @test copy(extension_tbl.uuid) == extension_source.uuid + @test Bool.(copy(extension_tbl.flag)) == Bool.(extension_source.flag) end diff --git a/test/runtests.jl b/test/runtests.jl index 23f14544..6d0d3e7b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,11 +56,6 @@ module EnumRoundtripModule @enum RankingStrategy lexical=1 semantic=2 hybrid=3 end -const RankingStrategy = EnumRoundtripModule.RankingStrategy -const lexical = EnumRoundtripModule.lexical -const semantic = EnumRoundtripModule.semantic -const hybrid = EnumRoundtripModule.hybrid - @testset ExtendedTestSet "Arrow" begin @testset "table roundtrips" begin for case in testtables @@ -468,8 +463,11 @@ const hybrid = EnumRoundtripModule.hybrid @testset "# Julia Enum extension logical type roundtrip" begin t = ( - col1=[lexical, hybrid], - col2=Union{Missing, RankingStrategy}[missing, semantic], + col1=[EnumRoundtripModule.lexical, EnumRoundtripModule.hybrid], + 
col2=Union{Missing,EnumRoundtripModule.RankingStrategy}[ + missing, + EnumRoundtripModule.semantic, + ], ) bytes = read(Arrow.tobuffer(t)) @@ -477,17 +475,20 @@ const hybrid = EnumRoundtripModule.hybrid raw = Arrow.Table(IOBuffer(bytes); convert=false) @test length(tt) == length(t) - @test eltype(tt.col1) == RankingStrategy - @test eltype(tt.col2) == Union{Missing, RankingStrategy} - @test tt.col1 == [lexical, hybrid] + @test eltype(tt.col1) == EnumRoundtripModule.RankingStrategy + @test eltype(tt.col2) == Union{Missing,EnumRoundtripModule.RankingStrategy} + @test tt.col1 == [EnumRoundtripModule.lexical, EnumRoundtripModule.hybrid] @test isequal( tt.col2, - Union{Missing, RankingStrategy}[missing, semantic], + Union{Missing,EnumRoundtripModule.RankingStrategy}[ + missing, + EnumRoundtripModule.semantic, + ], ) @test eltype(raw.col1) == Int32 - @test eltype(raw.col2) == Union{Missing, Int32} + @test eltype(raw.col2) == Union{Missing,Int32} @test raw.col1 == Int32[1, 3] - @test isequal(raw.col2, Union{Missing, Int32}[missing, 2]) + @test isequal(raw.col2, Union{Missing,Int32}[missing, 2]) @test Arrow.getmetadata(tt.col1)["ARROW:extension:name"] == "JuliaLang.Enum" @test occursin( "Main.EnumRoundtripModule.RankingStrategy", @@ -721,17 +722,22 @@ const hybrid = EnumRoundtripModule.hybrid end @testset "canonical timestamp_with_offset" begin - values = Union{Missing,Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}}[ - Arrow.TimestampWithOffset( - Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}(1577836800000), - 330, - ), - missing, - Arrow.TimestampWithOffset( - Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}(1577923200000), - -480, - ), - ] + values = + Union{Missing,Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}}[ + Arrow.TimestampWithOffset( + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}( + 1577836800000, + ), + 330, + ), + missing, + Arrow.TimestampWithOffset( + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}( + 
1577923200000, + ), + -480, + ), + ] tt = Arrow.Table(Arrow.tobuffer((col=values,))) @test eltype(tt.col) == Union{Missing,Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}} @@ -740,15 +746,11 @@ const hybrid = EnumRoundtripModule.hybrid "arrow.timestamp_with_offset" raw_tt = Arrow.Table(Arrow.tobuffer((col=values,)); convert=false) - @test eltype(raw_tt.col) == - Union{ + @test eltype(raw_tt.col) == Union{ Missing, NamedTuple{ (:timestamp, :offset_minutes), - Tuple{ - Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, - Int16, - }, + Tuple{Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC},Int16}, }, } @test isequal( @@ -757,25 +759,20 @@ const hybrid = EnumRoundtripModule.hybrid Missing, NamedTuple{ (:timestamp, :offset_minutes), - Tuple{ - Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, - Int16, - }, + Tuple{Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC},Int16}, }, }[ ( - timestamp=Arrow.Timestamp{ - Arrow.Meta.TimeUnit.MILLISECOND, - :UTC, - }(1577836800000), + timestamp=Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}( + 1577836800000, + ), offset_minutes=Int16(330), ), missing, ( - timestamp=Arrow.Timestamp{ - Arrow.Meta.TimeUnit.MILLISECOND, - :UTC, - }(1577923200000), + timestamp=Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}( + 1577923200000, + ), offset_minutes=Int16(-480), ), ], @@ -797,11 +794,14 @@ const hybrid = EnumRoundtripModule.hybrid @test collect(batches[1].x) == expected @test_throws ArgumentError(Arrow.RUN_END_ENCODED_UNSUPPORTED) Arrow.tobuffer(tt) - @test_throws ArgumentError(Arrow.RUN_END_ENCODED_UNSUPPORTED) Arrow.tobuffer((x=tt.x,)) + @test_throws ArgumentError(Arrow.RUN_END_ENCODED_UNSUPPORTED) Arrow.tobuffer(( + x=tt.x, + )) end @testset "canonical bool8/json/opaque" begin - bools = Union{Missing,Arrow.Bool8}[Arrow.Bool8(true), missing, Arrow.Bool8(false)] + bools = + Union{Missing,Arrow.Bool8}[Arrow.Bool8(true), missing, Arrow.Bool8(false)] tt = Arrow.Table(Arrow.tobuffer((col=bools,))) @test 
eltype(tt.col) == Union{Missing,Arrow.Bool8} @test isequal(copy(tt.col), bools) @@ -823,7 +823,10 @@ const hybrid = EnumRoundtripModule.hybrid raw_json_tt = Arrow.Table(Arrow.tobuffer((col=jsons,)); convert=false) @test eltype(raw_json_tt.col) == Union{Missing,String} - @test isequal(copy(raw_json_tt.col), Union{Missing,String}["{\"a\":1}", missing, "[1,2,3]"]) + @test isequal( + copy(raw_json_tt.col), + Union{Missing,String}["{\"a\":1}", missing, "[1,2,3]"], + ) opaque_meta = Arrow.opaquemetadata("pkg.Type", "vendor.example") opaque_tt = Arrow.Table( @@ -840,7 +843,8 @@ const hybrid = EnumRoundtripModule.hybrid @test eltype(opaque_tt.col) == String @test copy(opaque_tt.col) == ["a", "b"] @test Arrow.getmetadata(opaque_tt.col)["ARROW:extension:name"] == "arrow.opaque" - @test Arrow.getmetadata(opaque_tt.col)["ARROW:extension:metadata"] == opaque_meta + @test Arrow.getmetadata(opaque_tt.col)["ARROW:extension:metadata"] == + opaque_meta end @testset "canonical advanced passthrough" begin @@ -886,14 +890,12 @@ const hybrid = EnumRoundtripModule.hybrid permutation=[0], ) - variant_values = Union{ - Missing, - NamedTuple{(:metadata, :value),Tuple{String,String}}, - }[ - (metadata="json", value="{\"a\":1}"), - missing, - (metadata="str", value="abc"), - ] + variant_values = + Union{Missing,NamedTuple{(:metadata, :value),Tuple{String,String}}}[ + (metadata="json", value="{\"a\":1}"), + missing, + (metadata="str", value="abc"), + ] @test_logs min_level=Base.CoreLogging.Warn begin variant_tt = Arrow.Table( Arrow.tobuffer( @@ -986,7 +988,8 @@ const hybrid = EnumRoundtripModule.hybrid colmetadata=Dict( :col => Dict( "ARROW:extension:name" => "arrow.fixed_shape_tensor", - "ARROW:extension:metadata" => Arrow.fixedshapetensormetadata([3, 2]), + "ARROW:extension:metadata" => + Arrow.fixedshapetensormetadata([3, 2]), ), ), ) @@ -1000,9 +1003,10 @@ const hybrid = EnumRoundtripModule.hybrid colmetadata=Dict( :col => Dict( "ARROW:extension:name" => "arrow.variable_shape_tensor", - 
"ARROW:extension:metadata" => Arrow.variableshapetensormetadata( - uniform_shape=Union{Nothing,Int}[1], - ), + "ARROW:extension:metadata" => + Arrow.variableshapetensormetadata( + uniform_shape=Union{Nothing,Int}[1], + ), ), ), ) @@ -1012,6 +1016,43 @@ const hybrid = EnumRoundtripModule.hybrid ) end + @testset "logical extension runtime contract" begin + uuid_spec = Arrow._extensionspec(UUID) + @test uuid_spec isa Arrow.ExtensionTypeSpec + @test uuid_spec.name == Arrow.ArrowTypes.UUIDSYMBOL + @test uuid_spec.metadata == "" + @test Arrow._resolveextensionjuliatype( + Arrow.ExtensionTypeSpec(Arrow.ArrowTypes.LEGACY_UUIDSYMBOL, ""), + NTuple{16,UInt8}, + ) == UUID + + bool8_spec = Arrow._extensionspec(Arrow.Bool8) + @test bool8_spec isa Arrow.ExtensionTypeSpec + @test bool8_spec.name == Symbol("arrow.bool8") + @test Arrow._resolveextensionjuliatype(bool8_spec, Int8) == Arrow.Bool8 + + timestamp_storage = NamedTuple{ + (:timestamp, :offset_minutes), + Tuple{Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC},Int16}, + } + @test Arrow._resolveextensionjuliatype( + Arrow.ExtensionTypeSpec(Symbol("arrow.timestamp_with_offset"), ""), + timestamp_storage, + ) == Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND} + + opaque_spec = Arrow.ExtensionTypeSpec( + Symbol("arrow.opaque"), + Arrow.opaquemetadata("demo.type", "demo.vendor"), + ) + @test Arrow._resolveextensionjuliatype(opaque_spec, Vector{UInt8}) == + Vector{UInt8} + + @test Arrow._resolveextensionjuliatype( + Arrow.ExtensionTypeSpec(Symbol("JuliaLang.ZonedDateTime"), ""), + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, + ) == Arrow.LocalZonedDateTime + end + @testset "tensor message boundary" begin function patch_message_header_type(bytes, header_type::UInt8) patched = copy(bytes) @@ -1026,7 +1067,9 @@ const hybrid = EnumRoundtripModule.hybrid tensor_bytes = patch_message_header_type(base, UInt8(4)) @test_throws ArgumentError(Arrow.TENSOR_UNSUPPORTED) Arrow.Table(tensor_bytes) - @test_throws 
ArgumentError(Arrow.TENSOR_UNSUPPORTED) collect(Arrow.Stream(tensor_bytes)) + @test_throws ArgumentError(Arrow.TENSOR_UNSUPPORTED) collect( + Arrow.Stream(tensor_bytes), + ) sparse_tensor_bytes = patch_message_header_type(base, UInt8(5)) @test_throws ArgumentError(Arrow.SPARSE_TENSOR_UNSUPPORTED) Arrow.Table( From 1c7c795d8cc74cc3a12e7d1bd8ae40c676a73173 Mon Sep 17 00:00:00 2001 From: guangtao Date: Tue, 31 Mar 2026 17:53:36 -0700 Subject: [PATCH 12/16] Harden enum logical type resolution --- src/Arrow.jl | 1 + src/ArrowTypes/src/ArrowTypes.jl | 37 ++++++++++++++++++++--- src/ArrowTypes/test/tests.jl | 33 ++++++++++++++++++++ src/eltypes.jl | 31 ------------------- src/logicaltypes_builtin.jl | 52 ++++++++++++++++++++++++++++++++ test/runtests.jl | 49 ++++++++++++++++++++++++++++++ 6 files changed, 167 insertions(+), 36 deletions(-) create mode 100644 src/logicaltypes_builtin.jl diff --git a/src/Arrow.jl b/src/Arrow.jl index 8223449e..2d3e6956 100644 --- a/src/Arrow.jl +++ b/src/Arrow.jl @@ -97,6 +97,7 @@ include("utils.jl") include("logicaltypes.jl") include("arraytypes/arraytypes.jl") include("eltypes.jl") +include("logicaltypes_builtin.jl") include("table.jl") include("write.jl") include("append.jl") diff --git a/src/ArrowTypes/src/ArrowTypes.jl b/src/ArrowTypes/src/ArrowTypes.jl index 319224ee..825c6a54 100644 --- a/src/ArrowTypes/src/ArrowTypes.jl +++ b/src/ArrowTypes/src/ArrowTypes.jl @@ -214,7 +214,7 @@ JuliaType(::Val{CHAR}) = Char fromarrow(::Type{Char}, x::UInt32) = Char(x) ArrowType(::Type{T}) where {T<:Enum} = Base.Enums.basetype(T) -toarrow(x::T) where {T<:Enum} = convert(Base.Enums.basetype(T), Int(x)) +toarrow(x::T) where {T<:Enum} = Base.Enums.basetype(T)(x) const ENUM = Symbol("JuliaLang.Enum") arrowname(::Type{T}) where {T<:Enum} = ENUM @@ -225,10 +225,35 @@ end function _enum_labels(::Type{T}) where {T<:Enum} B = Base.Enums.basetype(T) - return join( - (string(instance, ":", convert(B, Int(instance))) for instance in instances(T)), - ",", - ) + 
return join((string(instance, ":", B(instance)) for instance in instances(T)), ",") +end + +function _parseenumlabels(labels::AbstractString, ::Type{B}) where {B<:Integer} + pairs = Pair{String,B}[] + isempty(labels) && return pairs + for entry in split(labels, ',') + isempty(entry) && return nothing + delimiter = findfirst(==(':'), entry) + delimiter === nothing && return nothing + label = entry[1:prevind(entry, delimiter)] + value = entry[nextind(entry, delimiter):end] + isempty(label) && return nothing + parsed = tryparse(B, value) + parsed === nothing && return nothing + push!(pairs, label => parsed) + end + return pairs +end + +function _enumlabelsmatch(::Type{T}, labels::AbstractString) where {T<:Enum} + B = Base.Enums.basetype(T) + parsed = _parseenumlabels(labels, B) + parsed === nothing && return false + expected = [string(instance) => B(instance) for instance in instances(T)] + length(parsed) == length(expected) || return false + parsed_dict = Dict(parsed) + length(parsed_dict) == length(parsed) || return false + return parsed_dict == Dict(expected) end function arrowmetadata(::Type{T}) where {T<:Enum} @@ -281,11 +306,13 @@ end function JuliaType(::Val{ENUM}, S, metadata::String) parsed = _parsemetadata(metadata) haskey(parsed, "type") || return nothing + haskey(parsed, "labels") || return nothing T = _resolvequalifiedtype(parsed["type"]) T isa DataType || return nothing T <: Enum || return nothing storage_type = Base.nonmissingtype(S) Base.Enums.basetype(T) === storage_type || return nothing + _enumlabelsmatch(T, parsed["labels"]) || return nothing return T end diff --git a/src/ArrowTypes/test/tests.jl b/src/ArrowTypes/test/tests.jl index 3d985b22..3363b4b5 100644 --- a/src/ArrowTypes/test/tests.jl +++ b/src/ArrowTypes/test/tests.jl @@ -26,10 +26,17 @@ module EnumTestModule @enum RankingStrategy lexical=1 semantic=2 hybrid=3 end +module WideEnumTestModule +@enum WideRanking::UInt64 small=1 colossal=0xffffffffffffffff +end + const RankingStrategy = 
EnumTestModule.RankingStrategy const lexical = EnumTestModule.lexical const semantic = EnumTestModule.semantic const hybrid = EnumTestModule.hybrid +const WideRanking = WideEnumTestModule.WideRanking +const small = WideEnumTestModule.small +const colossal = WideEnumTestModule.colossal @testset "ArrowTypes" begin @test ArrowTypes.ArrowKind(MyInt) == ArrowTypes.PrimitiveKind() @@ -85,9 +92,35 @@ const hybrid = EnumTestModule.hybrid @test occursin("labels=lexical:1,semantic:2,hybrid:3", enum_metadata) @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), Int32, enum_metadata) == RankingStrategy + reordered_enum_metadata = "type=Main.EnumTestModule.RankingStrategy;labels=semantic:2,hybrid:3,lexical:1" + mismatched_enum_metadata = "type=Main.EnumTestModule.RankingStrategy;labels=lexical:1,semantic:2,hybrid:4" + malformed_enum_metadata = "type=Main.EnumTestModule.RankingStrategy;labels=lexical:1,semantic:nope" + @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), Int32, reordered_enum_metadata) == + RankingStrategy + @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), Int32, mismatched_enum_metadata) === + nothing + @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), Int32, malformed_enum_metadata) === + nothing + @test ArrowTypes.JuliaType( + Val(ArrowTypes.ENUM), + Int32, + "type=Main.EnumTestModule.RankingStrategy", + ) === nothing @test ArrowTypes.fromarrow(RankingStrategy, Int32(2)) == semantic @test ArrowTypes.default(RankingStrategy) == lexical + wide_enum_metadata = ArrowTypes.arrowmetadata(WideRanking) + @test ArrowTypes.ArrowKind(WideRanking) == ArrowTypes.PrimitiveKind() + @test ArrowTypes.ArrowType(WideRanking) == UInt64 + @test ArrowTypes.toarrow(colossal) == typemax(UInt64) + @test ArrowTypes.arrowname(WideRanking) == ArrowTypes.ENUM + @test occursin("type=Main.WideEnumTestModule.WideRanking", wide_enum_metadata) + @test occursin("labels=small:1,colossal:18446744073709551615", wide_enum_metadata) + @test ArrowTypes.JuliaType(Val(ArrowTypes.ENUM), UInt64, 
wide_enum_metadata) == + WideRanking + @test ArrowTypes.fromarrow(WideRanking, typemax(UInt64)) == colossal + @test ArrowTypes.default(WideRanking) == small + @test ArrowTypes.ArrowKind(Bool) == ArrowTypes.BoolKind() @test ArrowTypes.ListKind() == ArrowTypes.ListKind{false}() diff --git a/src/eltypes.jl b/src/eltypes.jl index 9fbacb94..fe5b49a2 100644 --- a/src/eltypes.jl +++ b/src/eltypes.jl @@ -32,13 +32,6 @@ const PARQUET_VARIANT_SYMBOL = Symbol("arrow.parquet.variant") const FIXED_SHAPE_TENSOR_SYMBOL = Symbol("arrow.fixed_shape_tensor") const VARIABLE_SHAPE_TENSOR_SYMBOL = Symbol("arrow.variable_shape_tensor") -_builtinextensionspec(::Type{ArrowTypes.UUID}) = - ExtensionTypeSpec(ArrowTypes.UUIDSYMBOL, "") -_builtinextensionjuliatype(::Val{ArrowTypes.UUIDSYMBOL}, S, metadata::String) = - ArrowTypes.UUID -_builtinextensionjuliatype(::Val{ArrowTypes.LEGACY_UUIDSYMBOL}, S, metadata::String) = - ArrowTypes.UUID - @inline _canonicalextensionerror(sym::Symbol, msg::AbstractString) = throw(ArgumentError("invalid canonical $(String(sym)) extension: $msg")) @@ -339,8 +332,6 @@ ArrowTypes.arrowname(::Type{Bool8}) = BOOL8_SYMBOL ArrowTypes.JuliaType(::Val{BOOL8_SYMBOL}, ::Type{Int8}, metadata::String) = Bool8 ArrowTypes.fromarrow(::Type{Bool8}, x::Int8) = Bool8(x) ArrowTypes.default(::Type{Bool8}) = zero(Bool8) -_builtinextensionspec(::Type{Bool8}) = ExtensionTypeSpec(BOOL8_SYMBOL, "") -_builtinextensionjuliatype(::Val{BOOL8_SYMBOL}, ::Type{Int8}, metadata::String) = Bool8 function writearray( io::IO, @@ -374,22 +365,11 @@ ArrowTypes.fromarrow(::Type{JSONText{String}}, ptr::Ptr{UInt8}, len::Int) = ArrowTypes.fromarrow(::Type{JSONText{S}}, x::S) where {S<:AbstractString} = JSONText{S}(x) ArrowTypes.default(::Type{JSONText{S}}) where {S<:AbstractString} = JSONText{S}(ArrowTypes.default(S)) -_builtinextensionspec(::Type{JSONText{S}}) where {S<:AbstractString} = - ExtensionTypeSpec(JSON_SYMBOL, "") -_builtinextensionjuliatype( - ::Val{JSON_SYMBOL}, - ::Type{S}, - 
metadata::String, -) where {S<:AbstractString} = JSONText{S} ArrowTypes.JuliaType(::Val{OPAQUE_SYMBOL}, S, metadata::String) = S ArrowTypes.JuliaType(::Val{PARQUET_VARIANT_SYMBOL}, S, metadata::String) = S ArrowTypes.JuliaType(::Val{FIXED_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S ArrowTypes.JuliaType(::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S -_builtinextensionjuliatype(::Val{OPAQUE_SYMBOL}, S, metadata::String) = S -_builtinextensionjuliatype(::Val{PARQUET_VARIANT_SYMBOL}, S, metadata::String) = S -_builtinextensionjuliatype(::Val{FIXED_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S -_builtinextensionjuliatype(::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S @inline function _jsonstringliteral(x::AbstractString) return '"' * escape_string(x) * '"' @@ -700,8 +680,6 @@ ArrowTypes.JuliaType(::Val{ZONEDDATETIME_SYMBOL}, S) = ZonedDateTime ArrowTypes.fromarrow(::Type{ZonedDateTime}, x::Timestamp) = convert(ZonedDateTime, x) ArrowTypes.default(::Type{TimeZones.ZonedDateTime}) = TimeZones.ZonedDateTime(1, 1, 1, 1, 1, 1, TimeZones.tz"UTC") -_builtinextensionspec(::Type{ZonedDateTime}) = ExtensionTypeSpec(ZONEDDATETIME_SYMBOL, "") -_builtinextensionjuliatype(::Val{ZONEDDATETIME_SYMBOL}, S, metadata::String) = ZonedDateTime const TIMESTAMP_WITH_OFFSET_SYMBOL = Symbol("arrow.timestamp_with_offset") ArrowTypes.arrowname(::Type{TimestampWithOffset{U}}) where {U} = @@ -724,21 +702,12 @@ ArrowTypes.fromarrowstruct( offset_minutes::Int16, timestamp::Timestamp{U,:UTC}, ) where {U} = TimestampWithOffset{U}(timestamp, offset_minutes) -_builtinextensionspec(::Type{TimestampWithOffset{U}}) where {U} = - ExtensionTypeSpec(TIMESTAMP_WITH_OFFSET_SYMBOL, "") -_builtinextensionjuliatype( - ::Val{TIMESTAMP_WITH_OFFSET_SYMBOL}, - ::Type{NamedTuple{(:timestamp, :offset_minutes),Tuple{Timestamp{U,:UTC},Int16}}}, - metadata::String, -) where {U} = TimestampWithOffset{U} # Backwards compatibility: older versions of Arrow saved ZonedDateTime's with this 
metdata: const OLD_ZONEDDATETIME_SYMBOL = Symbol("JuliaLang.ZonedDateTime") # and stored the local time instead of the UTC time. struct LocalZonedDateTime end ArrowTypes.JuliaType(::Val{OLD_ZONEDDATETIME_SYMBOL}, S) = LocalZonedDateTime -_builtinextensionjuliatype(::Val{OLD_ZONEDDATETIME_SYMBOL}, S, metadata::String) = - LocalZonedDateTime function ArrowTypes.fromarrow(::Type{LocalZonedDateTime}, x::Timestamp{U,TZ}) where {U,TZ} (U === Meta.TimeUnit.MICROSECOND || U == Meta.TimeUnit.NANOSECOND) && warntimestamp(U, ZonedDateTime) diff --git a/src/logicaltypes_builtin.jl b/src/logicaltypes_builtin.jl new file mode 100644 index 00000000..9bd76383 --- /dev/null +++ b/src/logicaltypes_builtin.jl @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +_builtinextensionspec(::Type{ArrowTypes.UUID}) = + ExtensionTypeSpec(ArrowTypes.UUIDSYMBOL, "") +_builtinextensionjuliatype(::Val{ArrowTypes.UUIDSYMBOL}, S, metadata::String) = + ArrowTypes.UUID +_builtinextensionjuliatype(::Val{ArrowTypes.LEGACY_UUIDSYMBOL}, S, metadata::String) = + ArrowTypes.UUID + +_builtinextensionspec(::Type{Bool8}) = ExtensionTypeSpec(BOOL8_SYMBOL, "") +_builtinextensionjuliatype(::Val{BOOL8_SYMBOL}, ::Type{Int8}, metadata::String) = Bool8 + +_builtinextensionspec(::Type{JSONText{S}}) where {S<:AbstractString} = + ExtensionTypeSpec(JSON_SYMBOL, "") +_builtinextensionjuliatype( + ::Val{JSON_SYMBOL}, + ::Type{S}, + metadata::String, +) where {S<:AbstractString} = JSONText{S} + +_builtinextensionjuliatype(::Val{OPAQUE_SYMBOL}, S, metadata::String) = S +_builtinextensionjuliatype(::Val{PARQUET_VARIANT_SYMBOL}, S, metadata::String) = S +_builtinextensionjuliatype(::Val{FIXED_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S +_builtinextensionjuliatype(::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S + +_builtinextensionspec(::Type{ZonedDateTime}) = ExtensionTypeSpec(ZONEDDATETIME_SYMBOL, "") +_builtinextensionjuliatype(::Val{ZONEDDATETIME_SYMBOL}, S, metadata::String) = ZonedDateTime + +_builtinextensionspec(::Type{TimestampWithOffset{U}}) where {U} = + ExtensionTypeSpec(TIMESTAMP_WITH_OFFSET_SYMBOL, "") +_builtinextensionjuliatype( + ::Val{TIMESTAMP_WITH_OFFSET_SYMBOL}, + ::Type{NamedTuple{(:timestamp, :offset_minutes),Tuple{Timestamp{U,:UTC},Int16}}}, + metadata::String, +) where {U} = TimestampWithOffset{U} + +_builtinextensionjuliatype(::Val{OLD_ZONEDDATETIME_SYMBOL}, S, metadata::String) = + LocalZonedDateTime diff --git a/test/runtests.jl b/test/runtests.jl index 6d0d3e7b..147d8ee9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,6 +56,10 @@ module EnumRoundtripModule @enum RankingStrategy lexical=1 semantic=2 hybrid=3 end +module WideEnumRoundtripModule +@enum WideRanking::UInt64 tiny=1 colossal=0xffffffffffffffff 
+end + @testset ExtendedTestSet "Arrow" begin @testset "table roundtrips" begin for case in testtables @@ -496,6 +500,51 @@ end ) end + @testset "# Julia Enum extension contract edge cases" begin + t = ( + col=[WideEnumRoundtripModule.tiny, WideEnumRoundtripModule.colossal], + nullable=Union{Missing,WideEnumRoundtripModule.WideRanking}[ + missing, + WideEnumRoundtripModule.colossal, + ], + ) + bytes = read(Arrow.tobuffer(t)) + tt = Arrow.Table(IOBuffer(bytes)) + raw = Arrow.Table(IOBuffer(bytes); convert=false) + + @test eltype(tt.col) == WideEnumRoundtripModule.WideRanking + @test eltype(tt.nullable) == Union{Missing,WideEnumRoundtripModule.WideRanking} + @test tt.col == [WideEnumRoundtripModule.tiny, WideEnumRoundtripModule.colossal] + @test isequal( + tt.nullable, + Union{Missing,WideEnumRoundtripModule.WideRanking}[ + missing, + WideEnumRoundtripModule.colossal, + ], + ) + @test eltype(raw.col) == UInt64 + @test eltype(raw.nullable) == Union{Missing,UInt64} + @test raw.col == UInt64[1, typemax(UInt64)] + @test isequal(raw.nullable, Union{Missing,UInt64}[missing, typemax(UInt64)]) + + mismatch_metadata = "type=Main.WideEnumRoundtripModule.WideRanking;labels=tiny:1,colossal:2" + @test_logs (:warn, r"unsupported ARROW:extension:name type: \"JuliaLang.Enum\"") begin + mismatch_tt = Arrow.Table( + Arrow.tobuffer( + (col=UInt64[1, typemax(UInt64)],); + colmetadata=Dict( + :col => Dict( + "ARROW:extension:name" => "JuliaLang.Enum", + "ARROW:extension:metadata" => mismatch_metadata, + ), + ), + ), + ) + @test eltype(mismatch_tt.col) == UInt64 + @test copy(mismatch_tt.col) == UInt64[1, typemax(UInt64)] + end + end + @testset "# 76" begin t = (col1=NamedTuple{(:a,),Tuple{Union{Int,String}}}[(a=1,), (a="x",)],) tt = Arrow.Table(Arrow.tobuffer(t)) From 680577163bed4707f70698e66627c7aee9c1dedb Mon Sep 17 00:00:00 2001 From: guangtao Date: Tue, 31 Mar 2026 20:08:16 -0700 Subject: [PATCH 13/16] Complete built-in logical extension contract shims --- src/eltypes.jl | 181 
++++++++++++++-------------------- src/logicaltypes.jl | 17 ++++ src/logicaltypes_builtin.jl | 134 ++++++++++++++++++++++++- test/flight/ipc_conversion.jl | 17 ++++ test/runtests.jl | 137 +++++++++++++++++++++++++ 5 files changed, 377 insertions(+), 109 deletions(-) diff --git a/src/eltypes.jl b/src/eltypes.jl index fe5b49a2..98a2ab97 100644 --- a/src/eltypes.jl +++ b/src/eltypes.jl @@ -208,21 +208,6 @@ function _validatevariableshapetensor(field::Meta.Field, metadata::String) return end -function _validatecanonicalpassthrough( - field::Meta.Field, - typenamesym::Symbol, - metadata::String, -) - if typenamesym === PARQUET_VARIANT_SYMBOL - _validateparquetvariant(field, metadata) - elseif typenamesym === FIXED_SHAPE_TENSOR_SYMBOL - _validatefixedshapetensor(field, metadata) - elseif typenamesym === VARIABLE_SHAPE_TENSOR_SYMBOL - _validatevariableshapetensor(field, metadata) - end - return -end - """ Given a FlatBuffers.Builder and a Julia column or column eltype, Write the field.type flatbuffer definition of the eltype @@ -242,7 +227,7 @@ function juliaeltype(f::Meta.Field, meta::AbstractDict{String,String}, convert:: TT = juliaeltype(f, convert) spec = _extensionspec(meta) if spec !== nothing - _validatecanonicalpassthrough(f, spec.name, spec.metadata) + _validatebuiltinextension(spec, f) !convert && return TT T = finaljuliatype(TT) storageT = @@ -326,12 +311,13 @@ Base.zero(::Type{Bool8}) = Bool8(false) Base.:(==)(x::Bool8, y::Bool8) = Bool(x) == Bool(y) Base.isequal(x::Bool8, y::Bool8) = isequal(Bool(x), Bool(y)) -ArrowTypes.ArrowType(::Type{Bool8}) = Int8 -ArrowTypes.toarrow(x::Bool8) = Int8(Bool(x)) -ArrowTypes.arrowname(::Type{Bool8}) = BOOL8_SYMBOL -ArrowTypes.JuliaType(::Val{BOOL8_SYMBOL}, ::Type{Int8}, metadata::String) = Bool8 -ArrowTypes.fromarrow(::Type{Bool8}, x::Int8) = Bool8(x) -ArrowTypes.default(::Type{Bool8}) = zero(Bool8) +ArrowTypes.ArrowType(::Type{Bool8}) = _builtinarrowtype(Bool8) +ArrowTypes.toarrow(x::Bool8) = _builtintoarrow(x) 
+ArrowTypes.arrowname(::Type{Bool8}) = _builtinarrowname(Bool8) +ArrowTypes.JuliaType(::Val{BOOL8_SYMBOL}, ::Type{Int8}, metadata::String) = + _builtinextensionjuliatype(Val(BOOL8_SYMBOL), Int8, metadata) +ArrowTypes.fromarrow(::Type{Bool8}, x::Int8) = _builtinfromarrow(Bool8, x) +ArrowTypes.default(::Type{Bool8}) = _builtindefault(Bool8) function writearray( io::IO, @@ -352,62 +338,51 @@ Base.convert(::Type{String}, x::JSONText) = String(x) Base.:(==)(x::JSONText, y::JSONText) = getfield(x, :value) == getfield(y, :value) Base.isequal(x::JSONText, y::JSONText) = isequal(getfield(x, :value), getfield(y, :value)) -ArrowTypes.ArrowType(::Type{JSONText{S}}) where {S<:AbstractString} = S -ArrowTypes.toarrow(x::JSONText) = getfield(x, :value) -ArrowTypes.arrowname(::Type{JSONText{S}}) where {S<:AbstractString} = JSON_SYMBOL +ArrowTypes.ArrowType(::Type{JSONText{S}}) where {S<:AbstractString} = + _builtinarrowtype(JSONText{S}) +ArrowTypes.toarrow(x::JSONText) = _builtintoarrow(x) +ArrowTypes.arrowname(::Type{JSONText{S}}) where {S<:AbstractString} = + _builtinarrowname(JSONText{S}) ArrowTypes.JuliaType( ::Val{JSON_SYMBOL}, ::Type{S}, metadata::String, -) where {S<:AbstractString} = JSONText{S} +) where {S<:AbstractString} = _builtinextensionjuliatype(Val(JSON_SYMBOL), S, metadata) ArrowTypes.fromarrow(::Type{JSONText{String}}, ptr::Ptr{UInt8}, len::Int) = - JSONText(unsafe_string(ptr, len)) -ArrowTypes.fromarrow(::Type{JSONText{S}}, x::S) where {S<:AbstractString} = JSONText{S}(x) + _builtinfromarrow(JSONText{String}, ptr, len) +ArrowTypes.fromarrow(::Type{JSONText{S}}, x::S) where {S<:AbstractString} = + _builtinfromarrow(JSONText{S}, x) ArrowTypes.default(::Type{JSONText{S}}) where {S<:AbstractString} = - JSONText{S}(ArrowTypes.default(S)) + _builtindefault(JSONText{S}) -ArrowTypes.JuliaType(::Val{OPAQUE_SYMBOL}, S, metadata::String) = S -ArrowTypes.JuliaType(::Val{PARQUET_VARIANT_SYMBOL}, S, metadata::String) = S -ArrowTypes.JuliaType(::Val{FIXED_SHAPE_TENSOR_SYMBOL}, 
S, metadata::String) = S -ArrowTypes.JuliaType(::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S +ArrowTypes.JuliaType(::Val{OPAQUE_SYMBOL}, S, metadata::String) = + _builtinextensionjuliatype(Val(OPAQUE_SYMBOL), S, metadata) +ArrowTypes.JuliaType(::Val{PARQUET_VARIANT_SYMBOL}, S, metadata::String) = + _builtinextensionjuliatype(Val(PARQUET_VARIANT_SYMBOL), S, metadata) +ArrowTypes.JuliaType(::Val{FIXED_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = + _builtinextensionjuliatype(Val(FIXED_SHAPE_TENSOR_SYMBOL), S, metadata) +ArrowTypes.JuliaType(::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = + _builtinextensionjuliatype(Val(VARIABLE_SHAPE_TENSOR_SYMBOL), S, metadata) @inline function _jsonstringliteral(x::AbstractString) return '"' * escape_string(x) * '"' end opaquemetadata(type_name::AbstractString, vendor_name::AbstractString) = - "{\"type_name\":" * - _jsonstringliteral(type_name) * - ",\"vendor_name\":" * - _jsonstringliteral(vendor_name) * - "}" + _builtinopaquemetadata(type_name, vendor_name) -variantmetadata() = "" +variantmetadata() = _builtinvariantmetadata() function fixedshapetensormetadata( shape::AbstractVector{<:Integer}; dim_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, permutation::Union{Nothing,AbstractVector{<:Integer}}=nothing, ) - parsed_shape = _parseintvector(FIXED_SHAPE_TENSOR_SYMBOL, collect(shape), "shape") - parsed_dim_names = dim_names === nothing ? nothing : String.(dim_names) - parsed_permutation = - permutation === nothing ? 
nothing : - _validatepermutation( - FIXED_SHAPE_TENSOR_SYMBOL, - Int.(permutation), - length(parsed_shape), - ) - parsed_dim_names !== nothing && length(parsed_dim_names) == length(parsed_shape) || - isnothing(parsed_dim_names) || - _canonicalextensionerror( - FIXED_SHAPE_TENSOR_SYMBOL, - "\"dim_names\" must have length $(length(parsed_shape))", - ) - body = Dict{String,Any}("shape" => parsed_shape) - parsed_dim_names !== nothing && (body["dim_names"] = parsed_dim_names) - parsed_permutation !== nothing && (body["permutation"] = parsed_permutation) - return JSON3.write(body) + return _builtinfixedshapetensormetadata( + shape; + dim_names=dim_names, + permutation=permutation, + ) end function variableshapetensormetadata(; @@ -415,32 +390,11 @@ function variableshapetensormetadata(; dim_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, permutation::Union{Nothing,AbstractVector{<:Integer}}=nothing, ) - uniform = - uniform_shape === nothing ? nothing : - _parseintvector( - VARIABLE_SHAPE_TENSOR_SYMBOL, - collect(uniform_shape), - "uniform_shape"; - allow_null=true, - ) - ndim = uniform === nothing ? nothing : length(uniform) - parsed_dim_names = dim_names === nothing ? nothing : String.(dim_names) - parsed_permutation = permutation === nothing ? nothing : Int.(permutation) - ndim !== nothing && parsed_dim_names !== nothing && length(parsed_dim_names) == ndim || - ndim === nothing || - isnothing(parsed_dim_names) || - _canonicalextensionerror( - VARIABLE_SHAPE_TENSOR_SYMBOL, - "\"dim_names\" must have length $ndim", - ) - ndim !== nothing && - parsed_permutation !== nothing && - _validatepermutation(VARIABLE_SHAPE_TENSOR_SYMBOL, parsed_permutation, ndim) - body = Dict{String,Any}() - uniform !== nothing && (body["uniform_shape"] = uniform) - parsed_dim_names !== nothing && (body["dim_names"] = parsed_dim_names) - parsed_permutation !== nothing && (body["permutation"] = parsed_permutation) - return isempty(body) ? 
"" : JSON3.write(body) + return _builtinvariableshapetensormetadata(; + uniform_shape=uniform_shape, + dim_names=dim_names, + permutation=permutation, + ) end # primitive types @@ -671,53 +625,64 @@ ArrowTypes.fromarrow(::Type{Dates.DateTime}, x::Date{Meta.DateUnit.MILLISECOND,I convert(Dates.DateTime, x) ArrowTypes.default(::Type{Dates.DateTime}) = Dates.DateTime(1, 1, 1, 1, 1, 1) -ArrowTypes.ArrowType(::Type{ZonedDateTime}) = Timestamp -ArrowTypes.toarrow(x::ZonedDateTime) = - convert(Timestamp{Meta.TimeUnit.MILLISECOND,Symbol(x.timezone)}, x) +ArrowTypes.ArrowType(::Type{ZonedDateTime}) = _builtinarrowtype(ZonedDateTime) +ArrowTypes.toarrow(x::ZonedDateTime) = _builtintoarrow(x) const ZONEDDATETIME_SYMBOL = Symbol("JuliaLang.ZonedDateTime-UTC") -ArrowTypes.arrowname(::Type{ZonedDateTime}) = ZONEDDATETIME_SYMBOL -ArrowTypes.JuliaType(::Val{ZONEDDATETIME_SYMBOL}, S) = ZonedDateTime -ArrowTypes.fromarrow(::Type{ZonedDateTime}, x::Timestamp) = convert(ZonedDateTime, x) -ArrowTypes.default(::Type{TimeZones.ZonedDateTime}) = - TimeZones.ZonedDateTime(1, 1, 1, 1, 1, 1, TimeZones.tz"UTC") +ArrowTypes.arrowname(::Type{ZonedDateTime}) = _builtinarrowname(ZonedDateTime) +ArrowTypes.JuliaType(::Val{ZONEDDATETIME_SYMBOL}, S) = + _builtinextensionjuliatype(Val(ZONEDDATETIME_SYMBOL), S) +ArrowTypes.fromarrow(::Type{ZonedDateTime}, x::Timestamp) = + _builtinfromarrow(ZonedDateTime, x) +ArrowTypes.default(::Type{TimeZones.ZonedDateTime}) = _builtindefault(ZonedDateTime) const TIMESTAMP_WITH_OFFSET_SYMBOL = Symbol("arrow.timestamp_with_offset") +ArrowTypes.ArrowType(::Type{TimestampWithOffset{U}}) where {U} = + _builtinarrowtype(TimestampWithOffset{U}) +ArrowTypes.toarrow(x::TimestampWithOffset{U}) where {U} = _builtintoarrow(x) ArrowTypes.arrowname(::Type{TimestampWithOffset{U}}) where {U} = - TIMESTAMP_WITH_OFFSET_SYMBOL + _builtinarrowname(TimestampWithOffset{U}) ArrowTypes.JuliaType( ::Val{TIMESTAMP_WITH_OFFSET_SYMBOL}, ::Type{NamedTuple{(:timestamp, 
:offset_minutes),Tuple{Timestamp{U,:UTC},Int16}}}, metadata::String, -) where {U} = TimestampWithOffset{U} -ArrowTypes.default(::Type{TimestampWithOffset{U}}) where {U} = zero(TimestampWithOffset{U}) +) where {U} = _builtinextensionjuliatype( + Val(TIMESTAMP_WITH_OFFSET_SYMBOL), + NamedTuple{(:timestamp, :offset_minutes),Tuple{Timestamp{U,:UTC},Int16}}, + metadata, +) +ArrowTypes.default(::Type{TimestampWithOffset{U}}) where {U} = + _builtindefault(TimestampWithOffset{U}) ArrowTypes.fromarrowstruct( ::Type{TimestampWithOffset{U}}, ::Val{(:timestamp, :offset_minutes)}, timestamp::Timestamp{U,:UTC}, offset_minutes::Int16, -) where {U} = TimestampWithOffset{U}(timestamp, offset_minutes) +) where {U} = _builtinfromarrowstruct( + TimestampWithOffset{U}, + Val((:timestamp, :offset_minutes)), + timestamp, + offset_minutes, +) ArrowTypes.fromarrowstruct( ::Type{TimestampWithOffset{U}}, ::Val{(:offset_minutes, :timestamp)}, offset_minutes::Int16, timestamp::Timestamp{U,:UTC}, -) where {U} = TimestampWithOffset{U}(timestamp, offset_minutes) +) where {U} = _builtinfromarrowstruct( + TimestampWithOffset{U}, + Val((:offset_minutes, :timestamp)), + offset_minutes, + timestamp, +) # Backwards compatibility: older versions of Arrow saved ZonedDateTime's with this metdata: const OLD_ZONEDDATETIME_SYMBOL = Symbol("JuliaLang.ZonedDateTime") # and stored the local time instead of the UTC time. 
struct LocalZonedDateTime end -ArrowTypes.JuliaType(::Val{OLD_ZONEDDATETIME_SYMBOL}, S) = LocalZonedDateTime -function ArrowTypes.fromarrow(::Type{LocalZonedDateTime}, x::Timestamp{U,TZ}) where {U,TZ} - (U === Meta.TimeUnit.MICROSECOND || U == Meta.TimeUnit.NANOSECOND) && - warntimestamp(U, ZonedDateTime) - return ZonedDateTime( - Dates.DateTime( - Dates.UTM(Int64(Dates.toms(periodtype(U)(x.x)) + UNIX_EPOCH_DATETIME)), - ), - TimeZone(String(TZ)), - ) -end +ArrowTypes.JuliaType(::Val{OLD_ZONEDDATETIME_SYMBOL}, S) = + _builtinextensionjuliatype(Val(OLD_ZONEDDATETIME_SYMBOL), S) +ArrowTypes.fromarrow(::Type{LocalZonedDateTime}, x::Timestamp{U,TZ}) where {U,TZ} = + _builtinfromarrow(LocalZonedDateTime, x) """ Arrow.ToTimestamp(x::AbstractVector{ZonedDateTime}) diff --git a/src/logicaltypes.jl b/src/logicaltypes.jl index cdd28844..7695692d 100644 --- a/src/logicaltypes.jl +++ b/src/logicaltypes.jl @@ -25,7 +25,20 @@ end @inline _extensiontypename(spec::ExtensionTypeSpec) = String(spec.name) @inline _builtinextensionspec(::Type{T}) where {T} = nothing +@inline _builtinextensionjuliatype(::Val{name}, storageT) where {name} = + _builtinextensionjuliatype(Val(name), storageT, "") @inline _builtinextensionjuliatype(::Val{name}, storageT, metadata) where {name} = nothing +@inline _builtinarrowtype(::Type{T}) where {T} = nothing +@inline _builtintoarrow(x) = nothing +@inline _builtinarrowname(::Type{T}) where {T} = nothing +function _builtinfromarrow end +function _builtinfromarrowstruct end +function _builtindefault end +function _builtinopaquemetadata end +function _builtinvariantmetadata end +function _builtinfixedshapetensormetadata end +function _builtinvariableshapetensormetadata end +@inline _validatebuiltinextension(::Val{name}, field, metadata) where {name} = nothing @inline function _extensionmetadatafor(::Type{T}, meta) where {T} spec = _extensionspec(T) @@ -82,3 +95,7 @@ end builtin !== nothing && return builtin return ArrowTypes.JuliaType(Val(spec.name), 
storageT, spec.metadata) end + +@inline function _validatebuiltinextension(spec::ExtensionTypeSpec, field::Meta.Field) + return _validatebuiltinextension(Val(spec.name), field, spec.metadata) +end diff --git a/src/logicaltypes_builtin.jl b/src/logicaltypes_builtin.jl index 9bd76383..428811ec 100644 --- a/src/logicaltypes_builtin.jl +++ b/src/logicaltypes_builtin.jl @@ -14,39 +14,171 @@ # See the License for the specific language governing permissions and # limitations under the License. +_builtinarrowtype(::Type{ArrowTypes.UUID}) = NTuple{16,UInt8} +_builtintoarrow(x::ArrowTypes.UUID) = ArrowTypes._cast(NTuple{16,UInt8}, x.value) +_builtinarrowname(::Type{ArrowTypes.UUID}) = ArrowTypes.UUIDSYMBOL _builtinextensionspec(::Type{ArrowTypes.UUID}) = - ExtensionTypeSpec(ArrowTypes.UUIDSYMBOL, "") + ExtensionTypeSpec(_builtinarrowname(ArrowTypes.UUID), "") _builtinextensionjuliatype(::Val{ArrowTypes.UUIDSYMBOL}, S, metadata::String) = ArrowTypes.UUID _builtinextensionjuliatype(::Val{ArrowTypes.LEGACY_UUIDSYMBOL}, S, metadata::String) = ArrowTypes.UUID _builtinextensionspec(::Type{Bool8}) = ExtensionTypeSpec(BOOL8_SYMBOL, "") +_builtinarrowtype(::Type{Bool8}) = Int8 +_builtintoarrow(x::Bool8) = Int8(Bool(x)) +_builtinarrowname(::Type{Bool8}) = BOOL8_SYMBOL _builtinextensionjuliatype(::Val{BOOL8_SYMBOL}, ::Type{Int8}, metadata::String) = Bool8 +_builtinfromarrow(::Type{Bool8}, x::Int8) = Bool8(x) +_builtindefault(::Type{Bool8}) = zero(Bool8) _builtinextensionspec(::Type{JSONText{S}}) where {S<:AbstractString} = ExtensionTypeSpec(JSON_SYMBOL, "") +_builtinarrowtype(::Type{JSONText{S}}) where {S<:AbstractString} = S +_builtintoarrow(x::JSONText) = getfield(x, :value) +_builtinarrowname(::Type{JSONText{S}}) where {S<:AbstractString} = JSON_SYMBOL _builtinextensionjuliatype( ::Val{JSON_SYMBOL}, ::Type{S}, metadata::String, ) where {S<:AbstractString} = JSONText{S} +_builtinfromarrow(::Type{JSONText{String}}, ptr::Ptr{UInt8}, len::Int) = + JSONText(unsafe_string(ptr, len)) 
+_builtinfromarrow(::Type{JSONText{S}}, x::S) where {S<:AbstractString} = JSONText{S}(x) +_builtindefault(::Type{JSONText{S}}) where {S<:AbstractString} = + JSONText{S}(ArrowTypes.default(S)) _builtinextensionjuliatype(::Val{OPAQUE_SYMBOL}, S, metadata::String) = S _builtinextensionjuliatype(::Val{PARQUET_VARIANT_SYMBOL}, S, metadata::String) = S _builtinextensionjuliatype(::Val{FIXED_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S _builtinextensionjuliatype(::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, S, metadata::String) = S +_builtinopaquemetadata(type_name::AbstractString, vendor_name::AbstractString) = + "{\"type_name\":" * + _jsonstringliteral(type_name) * + ",\"vendor_name\":" * + _jsonstringliteral(vendor_name) * + "}" +_builtinvariantmetadata() = "" + +function _builtinfixedshapetensormetadata( + shape::AbstractVector{<:Integer}; + dim_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, + permutation::Union{Nothing,AbstractVector{<:Integer}}=nothing, +) + parsed_shape = _parseintvector(FIXED_SHAPE_TENSOR_SYMBOL, collect(shape), "shape") + parsed_dim_names = dim_names === nothing ? nothing : String.(dim_names) + parsed_permutation = + permutation === nothing ? 
nothing : + _validatepermutation( + FIXED_SHAPE_TENSOR_SYMBOL, + Int.(permutation), + length(parsed_shape), + ) + parsed_dim_names !== nothing && length(parsed_dim_names) == length(parsed_shape) || + isnothing(parsed_dim_names) || + _canonicalextensionerror( + FIXED_SHAPE_TENSOR_SYMBOL, + "\"dim_names\" must have length $(length(parsed_shape))", + ) + body = Dict{String,Any}("shape" => parsed_shape) + parsed_dim_names !== nothing && (body["dim_names"] = parsed_dim_names) + parsed_permutation !== nothing && (body["permutation"] = parsed_permutation) + return JSON3.write(body) +end + +function _builtinvariableshapetensormetadata(; + uniform_shape::Union{Nothing,AbstractVector}=nothing, + dim_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, + permutation::Union{Nothing,AbstractVector{<:Integer}}=nothing, +) + uniform = + uniform_shape === nothing ? nothing : + _parseintvector( + VARIABLE_SHAPE_TENSOR_SYMBOL, + collect(uniform_shape), + "uniform_shape"; + allow_null=true, + ) + ndim = uniform === nothing ? nothing : length(uniform) + parsed_dim_names = dim_names === nothing ? nothing : String.(dim_names) + parsed_permutation = permutation === nothing ? nothing : Int.(permutation) + ndim !== nothing && parsed_dim_names !== nothing && length(parsed_dim_names) == ndim || + ndim === nothing || + isnothing(parsed_dim_names) || + _canonicalextensionerror( + VARIABLE_SHAPE_TENSOR_SYMBOL, + "\"dim_names\" must have length $ndim", + ) + ndim !== nothing && + parsed_permutation !== nothing && + _validatepermutation(VARIABLE_SHAPE_TENSOR_SYMBOL, parsed_permutation, ndim) + body = Dict{String,Any}() + uniform !== nothing && (body["uniform_shape"] = uniform) + parsed_dim_names !== nothing && (body["dim_names"] = parsed_dim_names) + parsed_permutation !== nothing && (body["permutation"] = parsed_permutation) + return isempty(body) ? 
"" : JSON3.write(body) +end +_validatebuiltinextension( + ::Val{PARQUET_VARIANT_SYMBOL}, + field::Meta.Field, + metadata::String, +) = _validateparquetvariant(field, metadata) +_validatebuiltinextension( + ::Val{FIXED_SHAPE_TENSOR_SYMBOL}, + field::Meta.Field, + metadata::String, +) = _validatefixedshapetensor(field, metadata) +_validatebuiltinextension( + ::Val{VARIABLE_SHAPE_TENSOR_SYMBOL}, + field::Meta.Field, + metadata::String, +) = _validatevariableshapetensor(field, metadata) _builtinextensionspec(::Type{ZonedDateTime}) = ExtensionTypeSpec(ZONEDDATETIME_SYMBOL, "") +_builtinarrowtype(::Type{ZonedDateTime}) = Timestamp +_builtintoarrow(x::ZonedDateTime) = + convert(Timestamp{Meta.TimeUnit.MILLISECOND,Symbol(x.timezone)}, x) +_builtinarrowname(::Type{ZonedDateTime}) = ZONEDDATETIME_SYMBOL _builtinextensionjuliatype(::Val{ZONEDDATETIME_SYMBOL}, S, metadata::String) = ZonedDateTime +_builtinfromarrow(::Type{ZonedDateTime}, x::Timestamp) = convert(ZonedDateTime, x) +_builtindefault(::Type{TimeZones.ZonedDateTime}) = + TimeZones.ZonedDateTime(1, 1, 1, 1, 1, 1, TimeZones.tz"UTC") _builtinextensionspec(::Type{TimestampWithOffset{U}}) where {U} = ExtensionTypeSpec(TIMESTAMP_WITH_OFFSET_SYMBOL, "") +_builtinarrowtype(::Type{TimestampWithOffset{U}}) where {U} = + NamedTuple{(:timestamp, :offset_minutes),Tuple{Timestamp{U,:UTC},Int16}} +_builtintoarrow(x::TimestampWithOffset{U}) where {U} = + (timestamp=getfield(x, :timestamp), offset_minutes=getfield(x, :offset_minutes)) +_builtinarrowname(::Type{TimestampWithOffset{U}}) where {U} = TIMESTAMP_WITH_OFFSET_SYMBOL _builtinextensionjuliatype( ::Val{TIMESTAMP_WITH_OFFSET_SYMBOL}, ::Type{NamedTuple{(:timestamp, :offset_minutes),Tuple{Timestamp{U,:UTC},Int16}}}, metadata::String, ) where {U} = TimestampWithOffset{U} +_builtindefault(::Type{TimestampWithOffset{U}}) where {U} = zero(TimestampWithOffset{U}) +_builtinfromarrowstruct( + ::Type{TimestampWithOffset{U}}, + ::Val{(:timestamp, :offset_minutes)}, + 
timestamp::Timestamp{U,:UTC}, + offset_minutes::Int16, +) where {U} = TimestampWithOffset{U}(timestamp, offset_minutes) +_builtinfromarrowstruct( + ::Type{TimestampWithOffset{U}}, + ::Val{(:offset_minutes, :timestamp)}, + offset_minutes::Int16, + timestamp::Timestamp{U,:UTC}, +) where {U} = TimestampWithOffset{U}(timestamp, offset_minutes) _builtinextensionjuliatype(::Val{OLD_ZONEDDATETIME_SYMBOL}, S, metadata::String) = LocalZonedDateTime +function _builtinfromarrow(::Type{LocalZonedDateTime}, x::Timestamp{U,TZ}) where {U,TZ} + (U === Meta.TimeUnit.MICROSECOND || U == Meta.TimeUnit.NANOSECOND) && + warntimestamp(U, ZonedDateTime) + return ZonedDateTime( + Dates.DateTime( + Dates.UTM(Int64(Dates.toms(periodtype(U)(x.x)) + UNIX_EPOCH_DATETIME)), + ), + TimeZone(String(TZ)), + ) +end diff --git a/test/flight/ipc_conversion.jl b/test/flight/ipc_conversion.jl index 0b59b662..e0f35bec 100644 --- a/test/flight/ipc_conversion.jl +++ b/test/flight/ipc_conversion.jl @@ -82,6 +82,14 @@ using UUIDs extension_source = ( uuid=[UUID(UInt128(1)), UUID(UInt128(2))], flag=[Arrow.Bool8(true), Arrow.Bool8(false)], + json=Union{Missing,Arrow.JSONText{String}}[Arrow.JSONText("{\"a\":1}"), missing], + ts=Union{Missing,Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}}[ + Arrow.TimestampWithOffset( + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}(123), + Int16(-480), + ), + missing, + ], ) extension_messages = Arrow.Flight.flightdata(extension_source) extension_batches = collect(Arrow.Flight.stream(extension_messages)) @@ -91,8 +99,17 @@ using UUIDs "arrow.uuid" @test Arrow.getmetadata(extension_batches[1].flag)[Arrow.EXTENSION_NAME_KEY] == "arrow.bool8" + @test Arrow.getmetadata(extension_batches[1].json)[Arrow.EXTENSION_NAME_KEY] == + "arrow.json" + @test Arrow.getmetadata(extension_batches[1].ts)[Arrow.EXTENSION_NAME_KEY] == + "arrow.timestamp_with_offset" @test Arrow.getmetadata(extension_tbl.uuid)[Arrow.EXTENSION_NAME_KEY] == "arrow.uuid" @test 
Arrow.getmetadata(extension_tbl.flag)[Arrow.EXTENSION_NAME_KEY] == "arrow.bool8" + @test Arrow.getmetadata(extension_tbl.json)[Arrow.EXTENSION_NAME_KEY] == "arrow.json" + @test Arrow.getmetadata(extension_tbl.ts)[Arrow.EXTENSION_NAME_KEY] == + "arrow.timestamp_with_offset" @test copy(extension_tbl.uuid) == extension_source.uuid @test Bool.(copy(extension_tbl.flag)) == Bool.(extension_source.flag) + @test isequal(copy(extension_tbl.json), extension_source.json) + @test isequal(copy(extension_tbl.ts), extension_source.ts) end diff --git a/test/runtests.jl b/test/runtests.jl index 147d8ee9..671694ec 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -787,6 +787,14 @@ end -480, ), ] + @test ArrowTypes.JuliaType( + Val(Symbol("arrow.timestamp_with_offset")), + NamedTuple{ + (:timestamp, :offset_minutes), + Tuple{Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC},Int16}, + }, + "", + ) == Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND} tt = Arrow.Table(Arrow.tobuffer((col=values,))) @test eltype(tt.col) == Union{Missing,Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}} @@ -851,6 +859,7 @@ end @testset "canonical bool8/json/opaque" begin bools = Union{Missing,Arrow.Bool8}[Arrow.Bool8(true), missing, Arrow.Bool8(false)] + @test ArrowTypes.JuliaType(Val(Symbol("arrow.bool8")), Int8, "") == Arrow.Bool8 tt = Arrow.Table(Arrow.tobuffer((col=bools,))) @test eltype(tt.col) == Union{Missing,Arrow.Bool8} @test isequal(copy(tt.col), bools) @@ -865,6 +874,8 @@ end missing, Arrow.JSONText("[1,2,3]"), ] + @test ArrowTypes.JuliaType(Val(Symbol("arrow.json")), String, "") == + Arrow.JSONText{String} json_tt = Arrow.Table(Arrow.tobuffer((col=jsons,))) @test eltype(json_tt.col) == Union{Missing,Arrow.JSONText{String}} @test isequal(copy(json_tt.col), jsons) @@ -878,6 +889,8 @@ end ) opaque_meta = Arrow.opaquemetadata("pkg.Type", "vendor.example") + @test ArrowTypes.JuliaType(Val(Symbol("arrow.opaque")), String, opaque_meta) == + String opaque_tt = Arrow.Table( 
Arrow.tobuffer( (col=["a", "b"],); @@ -925,6 +938,18 @@ end dim_names=["axis0"], permutation=[0], ) + @test ArrowTypes.JuliaType(Val(Symbol("arrow.parquet.variant")), String, "") == + String + @test ArrowTypes.JuliaType( + Val(Symbol("arrow.fixed_shape_tensor")), + NTuple{4,Int32}, + fixed_metadata, + ) == NTuple{4,Int32} + @test ArrowTypes.JuliaType( + Val(Symbol("arrow.variable_shape_tensor")), + NamedTuple{(:data, :shape),Tuple{Vector{Int32},NTuple{1,Int32}}}, + variable_metadata, + ) == NamedTuple{(:data, :shape),Tuple{Vector{Int32},NTuple{1,Int32}}} @test JSON3.read(variable_metadata)["uniform_shape"] == [2] @test JSON3.read(variable_metadata)["dim_names"] == ["axis0"] @test JSON3.read(variable_metadata)["permutation"] == [0] @@ -1066,6 +1091,16 @@ end end @testset "logical extension runtime contract" begin + uuid = UUID("550e8400-e29b-41d4-a716-446655440000") + @test Arrow._builtinarrowtype(UUID) == NTuple{16,UInt8} + @test Arrow._builtintoarrow(uuid) == + ArrowTypes._cast(NTuple{16,UInt8}, uuid.value) + @test Arrow._builtinarrowname(UUID) == Symbol("arrow.uuid") + @test ArrowTypes.ArrowType(UUID) == Arrow._builtinarrowtype(UUID) + @test ArrowTypes.toarrow(uuid) == Arrow._builtintoarrow(uuid) + @test ArrowTypes.arrowname(UUID) == Arrow._builtinarrowname(UUID) + @test ArrowTypes.JuliaType(Val(Symbol("arrow.uuid"))) == UUID + @test ArrowTypes.JuliaType(Val(Symbol("JuliaLang.UUID"))) == UUID uuid_spec = Arrow._extensionspec(UUID) @test uuid_spec isa Arrow.ExtensionTypeSpec @test uuid_spec.name == Arrow.ArrowTypes.UUIDSYMBOL @@ -1078,12 +1113,73 @@ end bool8_spec = Arrow._extensionspec(Arrow.Bool8) @test bool8_spec isa Arrow.ExtensionTypeSpec @test bool8_spec.name == Symbol("arrow.bool8") + @test Arrow._builtinarrowtype(Arrow.Bool8) == Int8 + @test Arrow._builtintoarrow(Arrow.Bool8(true)) == Int8(1) + @test Arrow._builtinarrowname(Arrow.Bool8) == Symbol("arrow.bool8") + @test Arrow._builtinfromarrow(Arrow.Bool8, Int8(1)) == Arrow.Bool8(true) + @test 
Arrow._builtindefault(Arrow.Bool8) == Arrow.Bool8(false) @test Arrow._resolveextensionjuliatype(bool8_spec, Int8) == Arrow.Bool8 + @test Arrow._builtinarrowtype(Arrow.JSONText{String}) == String + @test Arrow._builtintoarrow(Arrow.JSONText("abc")) == "abc" + @test Arrow._builtinarrowname(Arrow.JSONText{String}) == Symbol("arrow.json") + @test Arrow._builtinfromarrow(Arrow.JSONText{String}, pointer("abc"), 3) == + Arrow.JSONText("abc") + @test Arrow._builtinfromarrow(Arrow.JSONText{String}, "xyz") == + Arrow.JSONText("xyz") + @test Arrow._builtindefault(Arrow.JSONText{String}) == Arrow.JSONText("") + timestamp_storage = NamedTuple{ (:timestamp, :offset_minutes), Tuple{Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC},Int16}, } + zdt = ZonedDateTime(Dates.DateTime(2020), tz"Europe/Paris") + @test Arrow._builtinarrowtype(ZonedDateTime) == Arrow.Timestamp + @test Arrow._builtintoarrow(zdt) == convert( + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,Symbol("Europe/Paris")}, + zdt, + ) + @test Arrow._builtinarrowname(ZonedDateTime) == + Symbol("JuliaLang.ZonedDateTime-UTC") + paris_timestamp = + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,Symbol("Europe/Paris")}(0) + @test Arrow._builtinfromarrow(ZonedDateTime, paris_timestamp) == + convert(ZonedDateTime, paris_timestamp) + @test Arrow._builtindefault(ZonedDateTime) == + ZonedDateTime(1, 1, 1, 1, 1, 1, tz"UTC") + @test Arrow._builtinarrowname( + Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}, + ) == Symbol("arrow.timestamp_with_offset") + @test Arrow._builtinarrowtype( + Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}, + ) == NamedTuple{ + (:timestamp, :offset_minutes), + Tuple{Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC},Int16}, + } + ts_with_offset = Arrow.TimestampWithOffset( + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}(123), + Int16(-480), + ) + @test Arrow._builtintoarrow(ts_with_offset) == ( + timestamp=Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}(123), 
+ offset_minutes=Int16(-480), + ) + @test ArrowTypes.ArrowType( + Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}, + ) == Arrow._builtinarrowtype( + Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}, + ) + @test ArrowTypes.toarrow(ts_with_offset) == + Arrow._builtintoarrow(ts_with_offset) + @test Arrow._builtindefault( + Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}, + ) == zero(Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}) + @test Arrow._builtinfromarrowstruct( + Arrow.TimestampWithOffset{Arrow.Meta.TimeUnit.MILLISECOND}, + Val((:timestamp, :offset_minutes)), + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}(123), + Int16(-480), + ) == ts_with_offset @test Arrow._resolveextensionjuliatype( Arrow.ExtensionTypeSpec(Symbol("arrow.timestamp_with_offset"), ""), timestamp_storage, @@ -1093,8 +1189,49 @@ end Symbol("arrow.opaque"), Arrow.opaquemetadata("demo.type", "demo.vendor"), ) + @test Arrow.opaquemetadata("demo.type", "demo.vendor") == + Arrow._builtinopaquemetadata("demo.type", "demo.vendor") @test Arrow._resolveextensionjuliatype(opaque_spec, Vector{UInt8}) == Vector{UInt8} + @test Arrow.variantmetadata() == Arrow._builtinvariantmetadata() + @test Arrow.fixedshapetensormetadata( + [2, 2]; + dim_names=["row", "col"], + permutation=[1, 0], + ) == Arrow._builtinfixedshapetensormetadata( + [2, 2]; + dim_names=["row", "col"], + permutation=[1, 0], + ) + @test Arrow.variableshapetensormetadata( + uniform_shape=[2, nothing]; + dim_names=["row", "col"], + permutation=[1, 0], + ) == Arrow._builtinvariableshapetensormetadata( + uniform_shape=[2, nothing]; + dim_names=["row", "col"], + permutation=[1, 0], + ) + @test Arrow._builtinextensionjuliatype( + Val(Symbol("JuliaLang.ZonedDateTime-UTC")), + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, + ) == ZonedDateTime + @test ArrowTypes.JuliaType( + Val(Symbol("JuliaLang.ZonedDateTime-UTC")), + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, + ) == 
ZonedDateTime + @test Arrow._builtinextensionjuliatype( + Val(Symbol("JuliaLang.ZonedDateTime")), + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, + ) == Arrow.LocalZonedDateTime + @test ArrowTypes.JuliaType( + Val(Symbol("JuliaLang.ZonedDateTime")), + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,:UTC}, + ) == Arrow.LocalZonedDateTime + local_zdt_timestamp = + Arrow.Timestamp{Arrow.Meta.TimeUnit.MILLISECOND,Symbol("Europe/Paris")}(0) + @test Arrow._builtinfromarrow(Arrow.LocalZonedDateTime, local_zdt_timestamp) == + ArrowTypes.fromarrow(Arrow.LocalZonedDateTime, local_zdt_timestamp) @test Arrow._resolveextensionjuliatype( Arrow.ExtensionTypeSpec(Symbol("JuliaLang.ZonedDateTime"), ""), From 23abf82a2812c9bad8ae76117c6cb332d8dd9865 Mon Sep 17 00:00:00 2001 From: guangtao Date: Tue, 31 Mar 2026 21:28:48 -0700 Subject: [PATCH 14/16] Unify Flight native runtime decode and emit --- .github/workflows/ci.yml | 2 + .github/workflows/ci_nightly.yml | 4 - src/flight/convert/flightdata.jl | 114 +++++- src/flight/convert/streaming.jl | 345 +++++++++++++++++- src/flight/exports.jl | 3 +- src/table.jl | 33 +- .../bidi_streaming_tests.jl | 5 + .../server_streaming_tests.jl | 6 +- .../grpcserver_extension/support/fixture.jl | 12 + .../grpcserver_extension/support/service.jl | 27 +- test/flight/ipc_conversion.jl | 45 +++ test/flight/pyarrow_interop/exchange_tests.jl | 23 +- 12 files changed, 580 insertions(+), 39 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 31f5b03a..f327c33b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -127,7 +127,9 @@ jobs: with: project: ${{ matrix.pkg.dir }} - uses: julia-actions/julia-processcoverage@v1 + if: matrix.pkg.name == 'Arrow.jl' && matrix.version == '1' && matrix.os == 'macos-latest' && matrix.nthreads == 1 - uses: codecov/codecov-action@v5 + if: matrix.pkg.name == 'Arrow.jl' && matrix.version == '1' && matrix.os == 'macos-latest' && matrix.nthreads == 1 with: files: 
lcov.info test_monorepo: diff --git a/.github/workflows/ci_nightly.yml b/.github/workflows/ci_nightly.yml index 9d7d6ce6..dd869da2 100644 --- a/.github/workflows/ci_nightly.yml +++ b/.github/workflows/ci_nightly.yml @@ -64,10 +64,6 @@ jobs: JULIA_NUM_THREADS: ${{ matrix.nthreads }} with: project: ${{ matrix.pkg.dir }} - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v5 - with: - files: lcov.info test_monorepo: name: Monorepo dev - Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} runs-on: ${{ matrix.os }} diff --git a/src/flight/convert/flightdata.jl b/src/flight/convert/flightdata.jl index d78d83a9..c515a50a 100644 --- a/src/flight/convert/flightdata.jl +++ b/src/flight/convert/flightdata.jl @@ -15,7 +15,20 @@ # specific language governing permissions and limitations # under the License. -function flightdata( +function _sourcedefaultcolmetadata(cols) + sch = Tables.schema(cols) + isnothing(sch) && return nothing + colmeta = Dict{Symbol,Any}() + Tables.eachcolumn(sch, cols) do col, _, nm + meta = ArrowParent.getmetadata(col) + isnothing(meta) || (colmeta[nm] = meta) + end + isempty(colmeta) && return nothing + return ArrowParent._normalizecolmeta(colmeta) +end + +function _emitflightdata!( + emit, source; descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, compress=nothing, @@ -29,13 +42,25 @@ function flightdata( colmetadata::Union{Nothing,Any}=nothing, ) dictencodings = Dict{Int64,Any}() - messages = Protocol.FlightData[] schema = Ref{Tables.Schema}() normalized_colmetadata = ArrowParent._normalizecolmeta(colmetadata) - meta = isnothing(metadata) ? ArrowParent.getmetadata(source) : metadata + source_meta = isnothing(metadata) ? ArrowParent.getmetadata(source) : metadata + source_colmetadata = isnothing(colmetadata) ? 
nothing : normalized_colmetadata for tbl in Tables.partitions(source) tblcols = Tables.columns(tbl) + if isnothing(metadata) + tblmeta = ArrowParent.getmetadata(tbl) + isnothing(tblmeta) && (tblmeta = source_meta) + else + tblmeta = metadata + end + if isnothing(colmetadata) + tblcolmetadata = _sourcedefaultcolmetadata(tblcols) + isnothing(tblcolmetadata) && (tblcolmetadata = source_colmetadata) + else + tblcolmetadata = normalized_colmetadata + end cols = ArrowParent.toarrowtable( tblcols, dictencodings, @@ -45,13 +70,12 @@ function flightdata( dictencode, dictencodenested, maxdepth, - meta, - normalized_colmetadata, + tblmeta, + tblcolmetadata, ) if !isassigned(schema) schema[] = Tables.schema(cols) - push!( - messages, + emit( _flightdata_message( ArrowParent.makeschemamsg(schema[], cols); descriptor=descriptor, @@ -62,8 +86,7 @@ function flightdata( for (id, delock) in sort!(collect(dictencodings); by=x -> x.first, rev=true) de = delock.value dictsch = Tables.Schema((:col,), (eltype(de.data),)) - push!( - messages, + emit( _flightdata_message( ArrowParent.makedictionarybatchmsg( dictsch, @@ -80,8 +103,7 @@ function flightdata( elseif !isempty(cols.dictencodingdeltas) for de in cols.dictencodingdeltas dictsch = Tables.Schema((:col,), (eltype(de.data),)) - push!( - messages, + emit( _flightdata_message( ArrowParent.makedictionarybatchmsg( dictsch, @@ -95,8 +117,7 @@ function flightdata( ) end end - push!( - messages, + emit( _flightdata_message( ArrowParent.makerecordbatchmsg(schema[], cols, alignment); alignment=alignment, @@ -104,5 +125,72 @@ function flightdata( ) descriptor = nothing end + return nothing +end + +function flightdata( + source; + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + 
metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, +) + messages = Protocol.FlightData[] + _emitflightdata!( + message -> push!(messages, message), + source; + descriptor=descriptor, + compress=compress, + largelists=largelists, + denseunions=denseunions, + dictencode=dictencode, + dictencodenested=dictencodenested, + alignment=alignment, + maxdepth=maxdepth, + metadata=metadata, + colmetadata=colmetadata, + ) return messages end + +function putflightdata!( + sink, + source; + close::Bool=false, + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, +) + try + _emitflightdata!( + message -> put!(sink, message), + source; + descriptor=descriptor, + compress=compress, + largelists=largelists, + denseunions=denseunions, + dictencode=dictencode, + dictencodenested=dictencodenested, + alignment=alignment, + maxdepth=maxdepth, + metadata=metadata, + colmetadata=colmetadata, + ) + finally + close && Base.close(sink) + end + return sink +end diff --git a/src/flight/convert/streaming.jl b/src/flight/convert/streaming.jl index b18396b6..458fee57 100644 --- a/src/flight/convert/streaming.jl +++ b/src/flight/convert/streaming.jl @@ -18,6 +18,76 @@ streambytes(message::Protocol.FlightData; kwargs...) = streambytes(Protocol.FlightData[message]; kwargs...) 
+mutable struct FlightStream{M} + messages::M + state::Any + started::Bool + exhausted::Bool + nextid::Int + names::Vector{Symbol} + types::Vector{Type} + schema::Union{Nothing,ArrowParent.Meta.Schema} + dictencodings::ArrowParent.Lockable{Dict{Int64,ArrowParent.DictEncoding}} + dictencoded::Dict{Int64,ArrowParent.Meta.Field} + convert::Bool +end + +struct FlightMetadataVector{T,A<:AbstractVector{T}} <: AbstractVector{T} + data::A + metadata::Union{Nothing,Base.ImmutableDict{String,String}} +end + +Base.IndexStyle(::Type{<:FlightMetadataVector}) = Base.IndexLinear() +Base.size(x::FlightMetadataVector) = size(x.data) +Base.axes(x::FlightMetadataVector) = axes(x.data) +Base.length(x::FlightMetadataVector) = length(x.data) +Base.getindex(x::FlightMetadataVector, i::Int) = getindex(x.data, i) +Base.iterate(x::FlightMetadataVector) = iterate(x.data) +Base.iterate(x::FlightMetadataVector, state) = iterate(x.data, state) +ArrowParent.getmetadata(x::FlightMetadataVector) = x.metadata + +function FlightStream(messages; schema=nothing, convert::Bool=true) + x = FlightStream( + messages, + nothing, + false, + false, + 0, + Symbol[], + Type[], + nothing, + ArrowParent.Lockable(Dict{Int64,ArrowParent.DictEncoding}()), + Dict{Int64,ArrowParent.Meta.Field}(), + convert, + ) + schema === nothing || _register_schema!(x, _flight_schema(schema)) + return x +end + +Base.IteratorSize(::Type{<:FlightStream}) = Base.SizeUnknown() +Base.eltype(::Type{<:FlightStream}) = ArrowParent.Table +Base.isdone(x::FlightStream) = x.exhausted + +Tables.partitions(x::FlightStream) = x + +function Tables.columnnames(x::FlightStream) + _ensure_schema!(x) + return getfield(x, :names) +end + +function Tables.schema(x::FlightStream) + _ensure_schema!(x) + return Tables.Schema(Tables.columnnames(x), getfield(x, :types)) +end + +function Base.iterate(x::FlightStream) + return _iterate_flight_stream!(x) +end + +function Base.iterate(x::FlightStream, ::Nothing) + return _iterate_flight_stream!(x) +end + 
function _missing_schema_message() return join( [ @@ -35,6 +105,273 @@ function _require_schema_messages(messages::AbstractVector{<:Protocol.FlightData throw(ArgumentError(_missing_schema_message())) end +function _flight_schema(schema) + schema isa ArrowParent.Meta.Schema && return schema + bytes = schemaipc(schema) + message = ArrowParent.FlatBuffers.getrootas(ArrowParent.Meta.Message, bytes, 8) + header = message.header + header isa ArrowParent.Meta.Schema || + throw(ArgumentError("Flight schema payload did not decode to an Arrow IPC schema")) + return header +end + +function _register_schema!(x::FlightStream, schema::ArrowParent.Meta.Schema) + if isnothing(getfield(x, :schema)) + setfield!(x, :schema, schema) + for field in schema.fields + ArrowParent.rejectunsupported(field) + push!(getfield(x, :names), Symbol(field.name)) + push!( + getfield(x, :types), + ArrowParent.juliaeltype( + field, + ArrowParent.buildmetadata(field.custom_metadata), + getfield(x, :convert), + ), + ) + ArrowParent.getdictionaries!(getfield(x, :dictencoded), field) + end + return x + end + schema == getfield(x, :schema) || throw( + ArgumentError( + "mismatched schemas between different arrow batches: $(getfield(x, :schema)) != $schema", + ), + ) + return x +end + +function _next_flight_message!(x::FlightStream) + getfield(x, :exhausted) && return nothing + state = + getfield(x, :started) ? 
iterate(getfield(x, :messages), getfield(x, :state)) : + iterate(getfield(x, :messages)) + setfield!(x, :started, true) + state === nothing && return (setfield!(x, :exhausted, true); nothing) + message, next_state = state + setfield!(x, :state, next_state) + setfield!(x, :nextid, getfield(x, :nextid) + 1) + return message +end + +function _flight_batch(message::Protocol.FlightData, id::Integer) + isempty(message.data_header) && + throw(ArgumentError("FlightData message is missing the Arrow IPC header")) + msg = + ArrowParent.FlatBuffers.getrootas(ArrowParent.Meta.Message, message.data_header, 0) + return ArrowParent.Batch(msg, message.data_body, 1, Int(id)) +end + +function _ensure_schema!(x::FlightStream) + isnothing(getfield(x, :schema)) || return x + while true + message = _next_flight_message!(x) + message === nothing && throw(ArgumentError(_missing_schema_message())) + if isempty(message.data_header) + isempty(message.data_body) || throw( + ArgumentError("FlightData message has a body but no Arrow IPC header"), + ) + continue + end + batch = _flight_batch(message, getfield(x, :nextid)) + header = batch.msg.header + if header isa ArrowParent.Meta.Schema + _register_schema!(x, header) + return x + elseif header isa ArrowParent.Meta.Tensor + throw(ArgumentError(ArrowParent.TENSOR_UNSUPPORTED)) + elseif header isa ArrowParent.Meta.SparseTensor + throw(ArgumentError(ArrowParent.SPARSE_TENSOR_UNSUPPORTED)) + end + throw(ArgumentError(_missing_schema_message())) + end +end + +function _store_dictionary_batch!( + x::FlightStream, + batch, + header::ArrowParent.Meta.DictionaryBatch, +) + id = header.id + recordbatch = header.data + @lock getfield(x, :dictencodings) begin + dictencodings = getfield(x, :dictencodings)[] + if haskey(dictencodings, id) && header.isDelta + field = getfield(x, :dictencoded)[id] + values, _, _, _ = ArrowParent.build( + field, + field.type, + batch, + recordbatch, + getfield(x, :dictencodings), + Int64(1), + Int64(1), + Int64(1), + getfield(x, 
:convert), + ) + dictencoding = dictencodings[id] + append!(dictencoding.data, values) + return + end + field = getfield(x, :dictencoded)[id] + values, _, _, _ = ArrowParent.build( + field, + field.type, + batch, + recordbatch, + getfield(x, :dictencodings), + Int64(1), + Int64(1), + Int64(1), + getfield(x, :convert), + ) + A = ArrowParent.ChainedVector([values]) + S = + field.dictionary.indexType === nothing ? Int32 : + ArrowParent.juliaeltype(field, field.dictionary.indexType, false) + dictencodings[id] = ArrowParent.DictEncoding{eltype(A),S,typeof(A)}( + id, + A, + field.dictionary.isOrdered, + values.metadata, + ) + end + return nothing +end + +function _flight_table(x::FlightStream, columns) + schema = getfield(x, :schema) + schema === nothing && throw(ArgumentError(_missing_schema_message())) + lookup = Dict{Symbol,AbstractVector}() + types = Type[] + for (nm, col) in zip(getfield(x, :names), columns) + lookup[nm] = col + push!(types, eltype(col)) + end + return ArrowParent.Table(getfield(x, :names), types, columns, lookup, Ref(schema)) +end + +function _empty_flight_table(x::FlightStream) + schema = getfield(x, :schema) + schema === nothing && throw(ArgumentError(_missing_schema_message())) + names = copy(getfield(x, :names)) + types = copy(getfield(x, :types)) + columns = AbstractVector[] + for field in schema.fields + T = ArrowParent.juliaeltype( + field, + ArrowParent.buildmetadata(field.custom_metadata), + getfield(x, :convert), + ) + push!(columns, T[]) + end + lookup = Dict{Symbol,AbstractVector}(names[i] => columns[i] for i in eachindex(names)) + return ArrowParent.Table(names, types, columns, lookup, Ref(schema)) +end + +function _copy_flight_table(batch::ArrowParent.Table) + names = copy(ArrowParent.names(batch)) + types = copy(ArrowParent.types(batch)) + columns = copy(ArrowParent.columns(batch)) + schema = ArrowParent.schema(batch)[] + lookup = Dict{Symbol,AbstractVector}(names[i] => columns[i] for i in eachindex(names)) + return 
ArrowParent.Table(names, types, columns, lookup, Ref(schema)) +end + +_flightcolumndata(col::FlightMetadataVector) = col.data +_flightcolumndata(col) = col + +function _chain_flight_column(col, batch_col) + metadata = ArrowParent.getmetadata(col) + chained = + ArrowParent.ChainedVector([_flightcolumndata(col), _flightcolumndata(batch_col)]) + return metadata === nothing ? chained : FlightMetadataVector(chained, metadata) +end + +function _append_flight_column!(col, batch_col) + append!(_flightcolumndata(col), _flightcolumndata(batch_col)) + return col +end + +function _append_flight_batch!( + table::ArrowParent.Table, + batch::ArrowParent.Table, + batchindex::Int, +) + columns = ArrowParent.columns(table) + batch_columns = ArrowParent.columns(batch) + if batchindex == 2 + for i in eachindex(columns) + columns[i] = _chain_flight_column(columns[i], batch_columns[i]) + end + else + for i in eachindex(columns) + _append_flight_column!(columns[i], batch_columns[i]) + end + end + lookup = getfield(table, :lookup) + for (nm, col) in zip(ArrowParent.names(table), columns) + lookup[nm] = col + end + return table +end + +function _materialize_flight_table(messages; schema=nothing, convert::Bool=true) + stream_state = FlightStream(messages; schema=schema, convert=convert) + state = iterate(stream_state) + state === nothing && return _empty_flight_table(stream_state) + table, next_state = state + next = iterate(stream_state, next_state) + next === nothing && return table + out = _copy_flight_table(table) + batchindex = 2 + while next !== nothing + batch, next_state = next + _append_flight_batch!(out, batch, batchindex) + batchindex += 1 + next = iterate(stream_state, next_state) + end + return out +end + +function _iterate_flight_stream!(x::FlightStream) + _ensure_schema!(x) + while true + message = _next_flight_message!(x) + message === nothing && return nothing + if isempty(message.data_header) + isempty(message.data_body) || throw( + ArgumentError("FlightData message has a 
body but no Arrow IPC header"), + ) + continue + end + batch = _flight_batch(message, getfield(x, :nextid)) + header = batch.msg.header + if header isa ArrowParent.Meta.Schema + _register_schema!(x, header) + continue + elseif header isa ArrowParent.Meta.DictionaryBatch + _store_dictionary_batch!(x, batch, header) + continue + elseif header isa ArrowParent.Meta.RecordBatch + columns = collect( + ArrowParent.VectorIterator( + getfield(x, :schema), + batch, + getfield(x, :dictencodings), + getfield(x, :convert), + ), + ) + return _flight_table(x, columns), nothing + elseif header isa ArrowParent.Meta.Tensor + throw(ArgumentError(ArrowParent.TENSOR_UNSUPPORTED)) + elseif header isa ArrowParent.Meta.SparseTensor + throw(ArgumentError(ArrowParent.SPARSE_TENSOR_UNSUPPORTED)) + end + throw(ArgumentError("unsupported arrow message type: $(typeof(header))")) + end +end + function streambytes( messages; schema=nothing, @@ -64,8 +401,9 @@ function stream( alignment::Integer=DEFAULT_IPC_ALIGNMENT, end_marker::Bool=true, ) - bytes = streambytes(messages; schema=schema, alignment=alignment, end_marker=end_marker) - return ArrowParent.Stream(bytes; convert=convert) + messages isa AbstractVector{<:Protocol.FlightData} && + _require_schema_messages(messages, schema) + return FlightStream(messages; schema=schema, convert=convert) end function table( @@ -75,6 +413,5 @@ function table( alignment::Integer=DEFAULT_IPC_ALIGNMENT, end_marker::Bool=true, ) - bytes = streambytes(messages; schema=schema, alignment=alignment, end_marker=end_marker) - return ArrowParent.Table(bytes; convert=convert) + return _materialize_flight_table(messages; schema=schema, convert=convert) end diff --git a/src/flight/exports.jl b/src/flight/exports.jl index 2809ce22..b595e47c 100644 --- a/src/flight/exports.jl +++ b/src/flight/exports.jl @@ -43,4 +43,5 @@ export Client, streambytes, stream, table, - flightdata + flightdata, + putflightdata! 
diff --git a/src/table.jl b/src/table.jl index 991da43e..556f1786 100644 --- a/src/table.jl +++ b/src/table.jl @@ -438,12 +438,39 @@ struct TablePartitions npartitions::Int end +Base.IteratorSize(::Type{TablePartitions}) = Base.HasLength() +Base.length(tp::TablePartitions) = tp.npartitions + +_partitionarrays(col) = col isa ChainedVector ? col.arrays : _wrappedpartitionarrays(col) + +function _wrappedpartitionarrays(col) + if hasfield(typeof(col), :data) + data = getfield(col, :data) + data isa ChainedVector && return data.arrays + end + return nothing +end + +_partitioncolumn(col, i::Int) = + col isa ChainedVector ? col.arrays[i] : _wrappedpartitioncolumn(col, i) + +function _wrappedpartitioncolumn(col, i::Int) + if hasfield(typeof(col), :data) && hasfield(typeof(col), :metadata) + data = getfield(col, :data) + if data isa ChainedVector + wrapper = getfield(parentmodule(typeof(col)), nameof(typeof(col))) + return wrapper(data.arrays[i], getfield(col, :metadata)) + end + end + return col +end + function TablePartitions(table::Table) cols = columns(table) npartitions = if length(cols) == 0 0 - elseif cols[1] isa ChainedVector - length(cols[1].arrays) + elseif (arrays = _partitionarrays(cols[1])) !== nothing + length(arrays) else 1 end @@ -454,7 +481,7 @@ function Base.iterate(tp::TablePartitions, i=1) i > tp.npartitions && return nothing tp.npartitions == 1 && return tp.table, i + 1 cols = columns(tp.table) - newcols = AbstractVector[cols[j].arrays[i] for j = 1:length(cols)] + newcols = AbstractVector[_partitioncolumn(cols[j], i) for j = 1:length(cols)] nms = names(tp.table) tbl = Table( nms, diff --git a/test/flight/grpcserver_extension/bidi_streaming_tests.jl b/test/flight/grpcserver_extension/bidi_streaming_tests.jl index bd5bc4ca..e2f1aff2 100644 --- a/test/flight/grpcserver_extension/bidi_streaming_tests.jl +++ b/test/flight/grpcserver_extension/bidi_streaming_tests.jl @@ -72,6 +72,11 @@ function grpcserver_extension_test_bidi_streaming(grpcserver, service, 
fixture, ) @test doexchange_closed[] @test length(doexchange_messages) == length(fixture.exchange_messages) + doexchange_table = Arrow.Flight.table(doexchange_messages) + @test doexchange_table.id == [10] + @test doexchange_table.name == ["ten"] + @test Arrow.getmetadata(doexchange_table)["dataset"] == "exchange" + @test Arrow.getmetadata(doexchange_table.name)["lang"] == "exchange" failing_service = Arrow.Flight.Service( doexchange=(ctx, request, response) -> diff --git a/test/flight/grpcserver_extension/server_streaming_tests.jl b/test/flight/grpcserver_extension/server_streaming_tests.jl index 80bfcf24..7eb8b45e 100644 --- a/test/flight/grpcserver_extension/server_streaming_tests.jl +++ b/test/flight/grpcserver_extension/server_streaming_tests.jl @@ -32,8 +32,10 @@ function grpcserver_extension_test_server_streaming(grpcserver, service, fixture ) @test doget_closed[] @test length(doget_messages) == length(fixture.messages) - @test Arrow.Flight.table(doget_messages; schema=fixture.info).name == - ["one", "two", "three"] + doget_table = Arrow.Flight.table(doget_messages; schema=fixture.info) + @test doget_table.name == ["one", "two", "three"] + @test Arrow.getmetadata(doget_table)["dataset"] == "native" + @test Arrow.getmetadata(doget_table.name)["lang"] == "en" doget_any_messages = Any[] doget_any_closed = Ref(false) diff --git a/test/flight/grpcserver_extension/support/fixture.jl b/test/flight/grpcserver_extension/support/fixture.jl index 23ff5a94..0bd00666 100644 --- a/test/flight/grpcserver_extension/support/fixture.jl +++ b/test/flight/grpcserver_extension/support/fixture.jl @@ -20,12 +20,16 @@ function grpcserver_extension_fixture(protocol) descriptor = protocol.FlightDescriptor(descriptor_type.PATH, UInt8[], ["native", "dataset"]) ticket = protocol.Ticket(b"native-ticket") + dataset_metadata = Dict("dataset" => "native") + dataset_colmetadata = Dict(:name => Dict("lang" => "en")) messages = Arrow.Flight.flightdata( Tables.partitioner(( (id=Int64[1, 2], 
name=["one", "two"]), (id=Int64[3], name=["three"]), )); descriptor=descriptor, + metadata=dataset_metadata, + colmetadata=dataset_colmetadata, ) schema_bytes = Arrow.Flight.schemaipc(first(messages)) info = protocol.FlightInfo( @@ -38,9 +42,13 @@ function grpcserver_extension_fixture(protocol) UInt8[], ) handshake_requests = [protocol.HandshakeRequest(UInt64(0), b"native-token")] + exchange_metadata = Dict("dataset" => "exchange") + exchange_colmetadata = Dict(:name => Dict("lang" => "exchange")) exchange_messages = Arrow.Flight.flightdata( Tables.partitioner(((id=Int64[10], name=["ten"]),)); descriptor=descriptor, + metadata=exchange_metadata, + colmetadata=exchange_colmetadata, ) return ( descriptor=descriptor, @@ -49,6 +57,10 @@ function grpcserver_extension_fixture(protocol) schema_bytes=schema_bytes, info=info, handshake_requests=handshake_requests, + dataset_metadata=dataset_metadata, + dataset_colmetadata=dataset_colmetadata, exchange_messages=exchange_messages, + exchange_metadata=exchange_metadata, + exchange_colmetadata=exchange_colmetadata, ) end diff --git a/test/flight/grpcserver_extension/support/service.jl b/test/flight/grpcserver_extension/support/service.jl index 8b0ef260..fea4cc58 100644 --- a/test/flight/grpcserver_extension/support/service.jl +++ b/test/flight/grpcserver_extension/support/service.jl @@ -37,8 +37,17 @@ function grpcserver_extension_service(protocol, fixture) doget=(ctx, req, response) -> begin @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" @test req.ticket == fixture.ticket.ticket - foreach(message -> put!(response, message), fixture.messages) - close(response) + Arrow.Flight.putflightdata!( + response, + Tables.partitioner(( + (id=Int64[1, 2], name=["one", "two"]), + (id=Int64[3], name=["three"]), + )); + descriptor=fixture.descriptor, + metadata=fixture.dataset_metadata, + colmetadata=fixture.dataset_colmetadata, + close=true, + ) return :doget_ok end, listactions=(ctx, response) -> begin @@ -56,17 
+65,21 @@ function grpcserver_extension_service(protocol, fixture) end, doput=(ctx, request, response) -> begin @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" - incoming = collect(request) - @test length(incoming) == length(fixture.messages) + incoming = collect(Arrow.Flight.stream(request)) + @test length(incoming) == 2 + @test incoming[1].id == [1, 2] + @test incoming[1].name == ["one", "two"] + @test Arrow.getmetadata(incoming[1])["dataset"] == "native" + @test Arrow.getmetadata(incoming[1].name)["lang"] == "en" + @test incoming[2].id == [3] + @test incoming[2].name == ["three"] put!(response, protocol.PutResult(b"stored")) close(response) return :doput_ok end, doexchange=(ctx, request, response) -> begin @test Arrow.Flight.callheader(ctx, "authorization") == "Bearer native" - incoming = collect(request) - foreach(message -> put!(response, message), incoming) - close(response) + Arrow.Flight.putflightdata!(response, Arrow.Flight.stream(request); close=true) return :doexchange_ok end, ) diff --git a/test/flight/ipc_conversion.jl b/test/flight/ipc_conversion.jl index e0f35bec..8bfb3238 100644 --- a/test/flight/ipc_conversion.jl +++ b/test/flight/ipc_conversion.jl @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
+using DataAPI using Tables using UUIDs @@ -79,6 +80,50 @@ using UUIDs @test isempty(empty_tbl.id) @test isempty(empty_tbl.label) + metadata_source = Tables.partitioner(((title=["red", "blue"],), (title=["green"],))) + metadata_messages = Arrow.Flight.flightdata( + metadata_source; + metadata=Dict("dataset" => "flight"), + colmetadata=Dict(:title => Dict("lang" => "en")), + ) + metadata_schema_bytes = Arrow.Flight.schemaipc(first(metadata_messages)) + metadata_info = Arrow.Flight.Protocol.FlightInfo( + metadata_schema_bytes[5:end], + nothing, + Arrow.Flight.Protocol.FlightEndpoint[], + Int64(-1), + Int64(-1), + false, + UInt8[], + ) + metadata_batches = + collect(Arrow.Flight.stream(metadata_messages[2:end]; schema=metadata_info)) + metadata_table = Arrow.Flight.table(metadata_messages[2:end]; schema=metadata_info) + + @test length(metadata_batches) == 2 + @test DataAPI.metadata(metadata_batches[1], "dataset") == "flight" + @test DataAPI.colmetadata(metadata_batches[1], :title, "lang") == "en" + @test DataAPI.metadata(metadata_batches[2], "dataset") == "flight" + @test DataAPI.colmetadata(metadata_batches[2], :title, "lang") == "en" + @test metadata_table.title == ["red", "blue", "green"] + @test DataAPI.metadata(metadata_table, "dataset") == "flight" + @test DataAPI.colmetadata(metadata_table, :title, "lang") == "en" + metadata_parts = collect(Tables.partitions(metadata_table)) + @test length(metadata_parts) == 2 + @test metadata_parts[1].title == ["red", "blue"] + @test metadata_parts[2].title == ["green"] + @test DataAPI.colmetadata(metadata_parts[1], :title, "lang") == "en" + + reemitted_channel = Channel{Arrow.Flight.Protocol.FlightData}(8) + reemit_task = + @async Arrow.Flight.putflightdata!(reemitted_channel, metadata_table; close=true) + reemitted_messages = collect(reemitted_channel) + wait(reemit_task) + reemitted_table = Arrow.Flight.table(reemitted_messages) + @test reemitted_table.title == metadata_table.title + @test DataAPI.metadata(reemitted_table, 
"dataset") == "flight" + @test DataAPI.colmetadata(reemitted_table, :title, "lang") == "en" + extension_source = ( uuid=[UUID(UInt128(1)), UUID(UInt128(2))], flag=[Arrow.Bool8(true), Arrow.Bool8(false)], diff --git a/test/flight/pyarrow_interop/exchange_tests.jl b/test/flight/pyarrow_interop/exchange_tests.jl index aaad35e3..47631f8a 100644 --- a/test/flight/pyarrow_interop/exchange_tests.jl +++ b/test/flight/pyarrow_interop/exchange_tests.jl @@ -24,13 +24,26 @@ function pyarrow_interop_test_exchange(client, exchange_descriptor) Arrow.Flight.flightdata(exchange_source; descriptor=exchange_descriptor) exchange_req, exchange_request, exchange_response = Arrow.Flight.doexchange(client) - exchanged_messages = pyarrow_interop_send_messages( - exchange_req, - exchange_request, - exchange_response, - exchange_messages, + sender = @async begin + for message in exchange_messages + put!(exchange_request, message) + end + close(exchange_request) + end + exchanged_messages = Arrow.Flight.Protocol.FlightData[] + exchange_batches = collect( + Arrow.Flight.stream(( + (push!(exchanged_messages, message); message) for message in exchange_response + ),), ) + wait(sender) + gRPCClient.grpc_async_await(exchange_req) + @test length(exchange_batches) == 2 + @test exchange_batches[1].id == [21, 22] + @test exchange_batches[1].name == ["twenty-one", "twenty-two"] + @test exchange_batches[2].id == [23] + @test exchange_batches[2].name == ["twenty-three"] exchange_table = Arrow.Flight.table(exchanged_messages) @test exchange_table.id == [21, 22, 23] @test exchange_table.name == ["twenty-one", "twenty-two", "twenty-three"] From 275eed24b21794619c546deb5f53b82d7878bbce Mon Sep 17 00:00:00 2001 From: guangtao Date: Tue, 31 Mar 2026 22:49:00 -0700 Subject: [PATCH 15/16] Finalize metadata overlays and Flight runtime contract --- README.md | 3 +- docs/src/manual.md | 6 + src/Arrow.jl | 1 + src/flight/client/methods/data.jl | 160 ++++++++++++++++++ src/flight/client/transport.jl | 26 +++ 
src/flight/convert/streaming.jl | 100 +++++++---- src/metadata/overlay.jl | 155 +++++++++++++++++ src/table.jl | 26 +++ test/flight/ipc_conversion.jl | 24 +++ test/flight/pyarrow_interop/exchange_tests.jl | 39 +++-- test/flight/pyarrow_interop/upload_tests.jl | 20 ++- test/runtests.jl | 20 +++ 12 files changed, 530 insertions(+), 50 deletions(-) create mode 100644 src/metadata/overlay.jl diff --git a/README.md b/README.md index cf7316cb..658c2f83 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,7 @@ This implementation supports the 1.0 version of the specification, including sup * Dictionary encodings and messages * Dictionary-encoded `CategoricalArray` interop, including missing-value roundtrips through `Arrow.Table`, `copy`, and `DataFrame(...; copycols=true)` * Extension types + * Lightweight schema/field metadata overlays via `Arrow.withmetadata(...)` for Tables.jl-compatible sources before serialization * Base Julia `Enum` logical types via the `JuliaLang.Enum` extension label, with native Julia roundtrips back to the original enum type while `convert=false` and non-Julia consumers still see the primitive storage type * View-backed Utf8/Binary columns, including recovery from under-reported variadic buffer counts by inferring the required external buffers from valid view elements * Streaming, file, record batch, and replacement and isdelta dictionary messages @@ -79,7 +80,7 @@ Flight RPC status: * Requires Julia `1.12+` * Includes generated protocol bindings and complete client constructors for the `FlightService` RPC surface * Keeps the top-level Flight module shell thin, with exports and generated-protocol setup split out of `src/flight/Flight.jl` - * Includes high-level `FlightData <-> Arrow IPC` helpers for `Arrow.Table`, `Arrow.Stream`, and DoPut payload generation + * Includes high-level `FlightData <-> Arrow IPC` helpers for `Arrow.Table`, `Arrow.Stream`, and DoPut payload generation, plus opt-in `app_metadata` surfacing through 
`include_app_metadata=true` on `Arrow.Flight.stream(...)` / `Arrow.Flight.table(...)` * Keeps the Flight IPC conversion layer modular under `src/flight/convert/`, with `src/flight/convert.jl` retained as a thin entrypoint * Includes client helpers for request headers, binary metadata, handshake token reuse, and TLS configuration via `withheaders`, `withtoken`, and `authenticate` * Keeps the Flight client implementation modular under `src/flight/client/`, with thin entrypoints at `src/flight/client.jl` and `src/flight/client/rpc_methods.jl` diff --git a/docs/src/manual.md b/docs/src/manual.md index e806f5f6..42a97828 100644 --- a/docs/src/manual.md +++ b/docs/src/manual.md @@ -245,6 +245,12 @@ Arrow.jl provides a convenient accessor for this metadata via [`Arrow.getmetadat To attach custom schema/column metadata to Arrow tables at serialization time, see the `metadata` and `colmetadata` keyword arguments to [`Arrow.write`](@ref). +For lightweight overlays on existing Tables.jl sources, Arrow.jl also provides +`Arrow.withmetadata(table_like; metadata=..., colmetadata=...)`. This keeps any +existing schema/field metadata already exposed by the source, overlays new +entries on top, and returns a wrapper that can be passed directly to +[`Arrow.write`](@ref), `Arrow.tobuffer`, or the Flight IPC helpers. + ## Writing arrow data Ok, so that's a pretty good rundown of *reading* arrow data, but how do you *produce* arrow data? Enter `Arrow.write`. 
diff --git a/src/Arrow.jl b/src/Arrow.jl index 2d3e6956..68ddda8d 100644 --- a/src/Arrow.jl +++ b/src/Arrow.jl @@ -99,6 +99,7 @@ include("arraytypes/arraytypes.jl") include("eltypes.jl") include("logicaltypes_builtin.jl") include("table.jl") +include("metadata/overlay.jl") include("write.jl") include("append.jl") include("show.jl") diff --git a/src/flight/client/methods/data.jl b/src/flight/client/methods/data.jl index 10a2f54a..62da0ec0 100644 --- a/src/flight/client/methods/data.jl +++ b/src/flight/client/methods/data.jl @@ -68,6 +68,86 @@ function doput( return req, request, response end +function doput( + client::Client, + source, + response::Channel{Protocol.PutResult}; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, + kwargs..., +) + request = Channel{Protocol.FlightData}(request_capacity) + grpc_request = doput(client, request, response; headers=headers, kwargs...) 
+ producer = errormonitor( + Threads.@spawn putflightdata!( + request, + source; + close=true, + descriptor=descriptor, + compress=compress, + largelists=largelists, + denseunions=denseunions, + dictencode=dictencode, + dictencodenested=dictencodenested, + alignment=alignment, + maxdepth=maxdepth, + metadata=metadata, + colmetadata=colmetadata, + ) + ) + return FlightAsyncRequest(grpc_request, producer) +end + +function doput( + client::Client, + source; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, + kwargs..., +) + response = Channel{Protocol.PutResult}(response_capacity) + req = doput( + client, + source, + response; + request_capacity=request_capacity, + headers=headers, + descriptor=descriptor, + compress=compress, + largelists=largelists, + denseunions=denseunions, + dictencode=dictencode, + dictencodenested=dictencodenested, + alignment=alignment, + maxdepth=maxdepth, + metadata=metadata, + colmetadata=colmetadata, + kwargs..., + ) + return req, response +end + doexchange( client::Client, request::Channel{Protocol.FlightData}, @@ -94,3 +174,83 @@ function doexchange( req = doexchange(client, request, response; headers=headers, kwargs...) 
return req, request, response end + +function doexchange( + client::Client, + source, + response::Channel{Protocol.FlightData}; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, + kwargs..., +) + request = Channel{Protocol.FlightData}(request_capacity) + grpc_request = doexchange(client, request, response; headers=headers, kwargs...) + producer = errormonitor( + Threads.@spawn putflightdata!( + request, + source; + close=true, + descriptor=descriptor, + compress=compress, + largelists=largelists, + denseunions=denseunions, + dictencode=dictencode, + dictencodenested=dictencodenested, + alignment=alignment, + maxdepth=maxdepth, + metadata=metadata, + colmetadata=colmetadata, + ) + ) + return FlightAsyncRequest(grpc_request, producer) +end + +function doexchange( + client::Client, + source; + request_capacity::Integer=DEFAULT_STREAM_BUFFER, + response_capacity::Integer=DEFAULT_STREAM_BUFFER, + headers::AbstractVector{<:Pair}=HeaderPair[], + descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + compress=nothing, + largelists::Bool=false, + denseunions::Bool=true, + dictencode::Bool=false, + dictencodenested::Bool=false, + alignment::Integer=DEFAULT_IPC_ALIGNMENT, + maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, + metadata::Union{Nothing,Any}=nothing, + colmetadata::Union{Nothing,Any}=nothing, + kwargs..., +) + response = Channel{Protocol.FlightData}(response_capacity) + req = doexchange( + client, + source, + response; + request_capacity=request_capacity, + headers=headers, + descriptor=descriptor, + compress=compress, + largelists=largelists, 
+ denseunions=denseunions, + dictencode=dictencode, + dictencodenested=dictencodenested, + alignment=alignment, + maxdepth=maxdepth, + metadata=metadata, + colmetadata=colmetadata, + kwargs..., + ) + return req, response +end diff --git a/src/flight/client/transport.jl b/src/flight/client/transport.jl index 4d6fef80..5b6ede79 100644 --- a/src/flight/client/transport.jl +++ b/src/flight/client/transport.jl @@ -119,6 +119,32 @@ function _grpc_async_request( end end +struct FlightAsyncRequest{R} + request::R + producer::Union{Nothing,Task} +end + +function Base.wait(req::FlightAsyncRequest) + producer = getfield(req, :producer) + isnothing(producer) || wait(producer) + return wait(getfield(req, :request)) +end + +function gRPCClient.grpc_async_await(req::FlightAsyncRequest) + producer = getfield(req, :producer) + isnothing(producer) || wait(producer) + return gRPCClient.grpc_async_await(getfield(req, :request)) +end + +function gRPCClient.grpc_async_await( + client::gRPCClient.gRPCServiceClient{TRequest,true,TResponse,false}, + req::FlightAsyncRequest, +) where {TRequest<:Any,TResponse<:Any} + producer = getfield(req, :producer) + isnothing(producer) || wait(producer) + return gRPCClient.grpc_async_await(client, getfield(req, :request)) +end + _default_rpc_options(client::Client) = ( secure=client.secure, grpc=client.grpc, diff --git a/src/flight/convert/streaming.jl b/src/flight/convert/streaming.jl index 458fee57..06b22749 100644 --- a/src/flight/convert/streaming.jl +++ b/src/flight/convert/streaming.jl @@ -32,20 +32,10 @@ mutable struct FlightStream{M} convert::Bool end -struct FlightMetadataVector{T,A<:AbstractVector{T}} <: AbstractVector{T} - data::A - metadata::Union{Nothing,Base.ImmutableDict{String,String}} +struct FlightStreamWithAppMetadata{S} + stream::S end -Base.IndexStyle(::Type{<:FlightMetadataVector}) = Base.IndexLinear() -Base.size(x::FlightMetadataVector) = size(x.data) -Base.axes(x::FlightMetadataVector) = axes(x.data) 
-Base.length(x::FlightMetadataVector) = length(x.data) -Base.getindex(x::FlightMetadataVector, i::Int) = getindex(x.data, i) -Base.iterate(x::FlightMetadataVector) = iterate(x.data) -Base.iterate(x::FlightMetadataVector, state) = iterate(x.data, state) -ArrowParent.getmetadata(x::FlightMetadataVector) = x.metadata - function FlightStream(messages; schema=nothing, convert::Bool=true) x = FlightStream( messages, @@ -68,7 +58,12 @@ Base.IteratorSize(::Type{<:FlightStream}) = Base.SizeUnknown() Base.eltype(::Type{<:FlightStream}) = ArrowParent.Table Base.isdone(x::FlightStream) = x.exhausted +Base.IteratorSize(::Type{<:FlightStreamWithAppMetadata}) = Base.SizeUnknown() +Base.eltype(::Type{<:FlightStreamWithAppMetadata}) = + NamedTuple{(:table, :app_metadata),Tuple{ArrowParent.Table,Vector{UInt8}}} + Tables.partitions(x::FlightStream) = x +Tables.partitions(x::FlightStreamWithAppMetadata) = x function Tables.columnnames(x::FlightStream) _ensure_schema!(x) @@ -88,6 +83,14 @@ function Base.iterate(x::FlightStream, ::Nothing) return _iterate_flight_stream!(x) end +function Base.iterate(x::FlightStreamWithAppMetadata) + return _iterate_flight_stream!(x.stream; include_app_metadata=true) +end + +function Base.iterate(x::FlightStreamWithAppMetadata, ::Nothing) + return _iterate_flight_stream!(x.stream; include_app_metadata=true) +end + function _missing_schema_message() return join( [ @@ -278,14 +281,24 @@ function _copy_flight_table(batch::ArrowParent.Table) return ArrowParent.Table(names, types, columns, lookup, Ref(schema)) end -_flightcolumndata(col::FlightMetadataVector) = col.data -_flightcolumndata(col) = col +_copy_app_metadata(message::Protocol.FlightData) = copy(message.app_metadata) + +function _flight_batch_result( + table::ArrowParent.Table, + message::Protocol.FlightData; + include_app_metadata::Bool, +) + include_app_metadata || return table + return (table=table, app_metadata=_copy_app_metadata(message)) +end + +_flightcolumndata(col) = 
ArrowParent._metadatavectordata(col) function _chain_flight_column(col, batch_col) metadata = ArrowParent.getmetadata(col) chained = ArrowParent.ChainedVector([_flightcolumndata(col), _flightcolumndata(batch_col)]) - return metadata === nothing ? chained : FlightMetadataVector(chained, metadata) + return ArrowParent._wrapmetadata(chained, metadata) end function _append_flight_column!(col, batch_col) @@ -316,25 +329,39 @@ function _append_flight_batch!( return table end -function _materialize_flight_table(messages; schema=nothing, convert::Bool=true) +function _materialize_flight_table( + messages; + schema=nothing, + convert::Bool=true, + include_app_metadata::Bool=false, +) stream_state = FlightStream(messages; schema=schema, convert=convert) - state = iterate(stream_state) - state === nothing && return _empty_flight_table(stream_state) - table, next_state = state - next = iterate(stream_state, next_state) - next === nothing && return table - out = _copy_flight_table(table) + state = _iterate_flight_stream!(stream_state; include_app_metadata=include_app_metadata) + if state === nothing + empty_table = _empty_flight_table(stream_state) + return include_app_metadata ? (table=empty_table, app_metadata=Vector{UInt8}[]) : + empty_table + end + first_value, _ = state + first_table = include_app_metadata ? first_value.table : first_value + out = _copy_flight_table(first_table) + batch_app_metadata = + include_app_metadata ? Vector{Vector{UInt8}}([first_value.app_metadata]) : nothing batchindex = 2 - while next !== nothing - batch, next_state = next + while true + next = + _iterate_flight_stream!(stream_state; include_app_metadata=include_app_metadata) + next === nothing && break + batch_value, _ = next + batch = include_app_metadata ? 
batch_value.table : batch_value _append_flight_batch!(out, batch, batchindex) + include_app_metadata && push!(batch_app_metadata, batch_value.app_metadata) batchindex += 1 - next = iterate(stream_state, next_state) end - return out + return include_app_metadata ? (table=out, app_metadata=batch_app_metadata) : out end -function _iterate_flight_stream!(x::FlightStream) +function _iterate_flight_stream!(x::FlightStream; include_app_metadata::Bool=false) _ensure_schema!(x) while true message = _next_flight_message!(x) @@ -362,7 +389,12 @@ function _iterate_flight_stream!(x::FlightStream) getfield(x, :convert), ), ) - return _flight_table(x, columns), nothing + return _flight_batch_result( + _flight_table(x, columns), + message; + include_app_metadata=include_app_metadata, + ), + nothing elseif header isa ArrowParent.Meta.Tensor throw(ArgumentError(ArrowParent.TENSOR_UNSUPPORTED)) elseif header isa ArrowParent.Meta.SparseTensor @@ -398,20 +430,28 @@ function stream( messages; schema=nothing, convert::Bool=true, + include_app_metadata::Bool=false, alignment::Integer=DEFAULT_IPC_ALIGNMENT, end_marker::Bool=true, ) messages isa AbstractVector{<:Protocol.FlightData} && _require_schema_messages(messages, schema) - return FlightStream(messages; schema=schema, convert=convert) + flight_stream = FlightStream(messages; schema=schema, convert=convert) + return include_app_metadata ? 
FlightStreamWithAppMetadata(flight_stream) : flight_stream end function table( messages; schema=nothing, convert::Bool=true, + include_app_metadata::Bool=false, alignment::Integer=DEFAULT_IPC_ALIGNMENT, end_marker::Bool=true, ) - return _materialize_flight_table(messages; schema=schema, convert=convert) + return _materialize_flight_table( + messages; + schema=schema, + convert=convert, + include_app_metadata=include_app_metadata, + ) end diff --git a/src/metadata/overlay.jl b/src/metadata/overlay.jl new file mode 100644 index 00000000..6c00ea05 --- /dev/null +++ b/src/metadata/overlay.jl @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +_metadata_entries(metadata) = metadata isa AbstractVector ? metadata : pairs(metadata) + +function _normalize_metadata_overlay(metadata) + metadata === nothing && return nothing + return toidict( + String(first(entry)) => String(last(entry)) for entry in _metadata_entries(metadata) + ) +end + +function _merge_metadata_overlays(metadata_sources...) + merged = Dict{String,String}() + for metadata in metadata_sources + metadata === nothing && continue + for entry in _metadata_entries(metadata) + merged[String(first(entry))] = String(last(entry)) + end + end + return isempty(merged) ? 
nothing : toidict(pairs(merged)) +end + +struct MetadataOverlayVector{T,V<:AbstractVector{T},M} <: AbstractVector{T} + data::V + metadata::M +end + +Base.IndexStyle(::Type{<:MetadataOverlayVector{T,V}}) where {T,V} = Base.IndexStyle(V) +Base.size(x::MetadataOverlayVector) = size(x.data) +Base.axes(x::MetadataOverlayVector) = axes(x.data) +Base.length(x::MetadataOverlayVector) = length(x.data) +Base.getindex(x::MetadataOverlayVector, i::Int) = x.data[i] +Base.iterate(x::MetadataOverlayVector, state...) = iterate(x.data, state...) +getmetadata(x::MetadataOverlayVector) = x.metadata + +struct MetadataOverlayTable{N,C,M} + columns::NamedTuple{N,C} + metadata::M +end + +function Base.getproperty(x::MetadataOverlayTable, name::Symbol) + if name === :columns || name === :metadata + return getfield(x, name) + end + columns = getfield(x, :columns) + if hasproperty(columns, name) + return getproperty(columns, name) + end + return getfield(x, name) +end + +function Base.propertynames(x::MetadataOverlayTable, private::Bool=false) + column_names = propertynames(getfield(x, :columns)) + return private ? (:columns, :metadata, column_names...) 
: column_names +end + +Tables.istable(::Type{<:MetadataOverlayTable}) = true +Tables.columnaccess(::Type{<:MetadataOverlayTable}) = true +Tables.columns(x::MetadataOverlayTable) = getfield(x, :columns) +Tables.schema(x::MetadataOverlayTable) = Tables.schema(getfield(x, :columns)) +getmetadata(x::MetadataOverlayTable) = getfield(x, :metadata) + +function _column_metadata_overlay(table_like) + merged = Dict{Symbol,Any}() + for name in Tables.schema(table_like).names + metadata = + _normalize_metadata_overlay(getmetadata(Tables.getcolumn(table_like, name))) + metadata === nothing || (merged[name] = metadata) + end + return merged +end + +function _merge_column_metadata_overlays(table_like, colmetadata) + merged = _column_metadata_overlay(table_like) + colmetadata === nothing && return merged + for (name, metadata) in pairs(colmetadata) + symbol_name = Symbol(name) + merged_metadata = + _merge_metadata_overlays(get(merged, symbol_name, nothing), metadata) + merged_metadata === nothing || (merged[symbol_name] = merged_metadata) + end + return merged +end + +function _metadata_overlay_table(columns::NamedTuple; metadata=nothing, colmetadata=nothing) + wrapped_columns = Pair{Symbol,Any}[] + for name in keys(columns) + column_metadata = isnothing(colmetadata) ? nothing : get(colmetadata, name, nothing) + push!( + wrapped_columns, + name => MetadataOverlayVector(columns[name], column_metadata), + ) + end + return MetadataOverlayTable((; wrapped_columns...), metadata) +end + +""" + Arrow.withmetadata(table_like; metadata=nothing, colmetadata=nothing) + +Return a lightweight Tables.jl-compatible wrapper around `table_like` that +preserves any existing Arrow schema/field metadata and overlays additional +schema `metadata` and column `colmetadata` for subsequent Arrow serialization. 
+ +Both `metadata` and `colmetadata` follow the same shape accepted by +[`Arrow.write`](@ref): schema metadata must be an iterable of string-like pairs, +while `colmetadata` must map column names to iterables of string-like pairs. +When the source already carries metadata, overlay entries win on key conflicts. +""" +function withmetadata(columns::NamedTuple; metadata=nothing, colmetadata=nothing) + normalized_metadata = _normalize_metadata_overlay(metadata) + normalized_colmetadata = if isnothing(colmetadata) + nothing + else + Dict( + Symbol(name) => _normalize_metadata_overlay(column_metadata) for + (name, column_metadata) in pairs(colmetadata) + ) + end + if normalized_metadata === nothing && isnothing(normalized_colmetadata) + return columns + end + return _metadata_overlay_table( + columns; + metadata=normalized_metadata, + colmetadata=normalized_colmetadata, + ) +end + +function withmetadata(table_like; metadata=nothing, colmetadata=nothing) + merged_metadata = _merge_metadata_overlays(getmetadata(table_like), metadata) + merged_colmetadata = _merge_column_metadata_overlays(table_like, colmetadata) + if merged_metadata === nothing && isempty(merged_colmetadata) + return table_like + end + return _metadata_overlay_table( + Tables.columntable(table_like); + metadata=merged_metadata, + colmetadata=isempty(merged_colmetadata) ? 
nothing : merged_colmetadata, + ) +end diff --git a/src/table.jl b/src/table.jl index 556f1786..e2df7285 100644 --- a/src/table.jl +++ b/src/table.jl @@ -433,6 +433,24 @@ Tables.columnnames(t::Table) = names(t) Tables.getcolumn(t::Table, i::Int) = columns(t)[i] Tables.getcolumn(t::Table, nm::Symbol) = lookup(t)[nm] +struct MetadataVector{T,A<:AbstractVector{T},M} <: AbstractVector{T} + data::A + metadata::M +end + +Base.IndexStyle(::Type{<:MetadataVector}) = Base.IndexLinear() +Base.size(x::MetadataVector) = size(x.data) +Base.axes(x::MetadataVector) = axes(x.data) +Base.length(x::MetadataVector) = length(x.data) +Base.getindex(x::MetadataVector, i::Int) = getindex(x.data, i) +Base.iterate(x::MetadataVector) = iterate(x.data) +Base.iterate(x::MetadataVector, state) = iterate(x.data, state) +getmetadata(x::MetadataVector) = x.metadata + +_metadatavectordata(x::MetadataVector) = x.data +_metadatavectordata(x) = x +_wrapmetadata(data, metadata) = metadata === nothing ? data : MetadataVector(data, metadata) + struct TablePartitions table::Table npartitions::Int @@ -441,6 +459,11 @@ end Base.IteratorSize(::Type{TablePartitions}) = Base.HasLength() Base.length(tp::TablePartitions) = tp.npartitions +function _partitionarrays(col::MetadataVector) + data = getfield(col, :data) + return data isa ChainedVector ? data.arrays : nothing +end + _partitionarrays(col) = col isa ChainedVector ? col.arrays : _wrappedpartitionarrays(col) function _wrappedpartitionarrays(col) @@ -451,6 +474,9 @@ function _wrappedpartitionarrays(col) return nothing end +_partitioncolumn(col::MetadataVector, i::Int) = + MetadataVector(getfield(col, :data).arrays[i], getfield(col, :metadata)) + _partitioncolumn(col, i::Int) = col isa ChainedVector ? 
col.arrays[i] : _wrappedpartitioncolumn(col, i) diff --git a/test/flight/ipc_conversion.jl b/test/flight/ipc_conversion.jl index 8bfb3238..b63b58a3 100644 --- a/test/flight/ipc_conversion.jl +++ b/test/flight/ipc_conversion.jl @@ -112,7 +112,31 @@ using UUIDs @test length(metadata_parts) == 2 @test metadata_parts[1].title == ["red", "blue"] @test metadata_parts[2].title == ["green"] + @test DataAPI.metadata(metadata_parts[1], "dataset") == "flight" @test DataAPI.colmetadata(metadata_parts[1], :title, "lang") == "en" + @test DataAPI.metadata(metadata_parts[2], "dataset") == "flight" + @test DataAPI.colmetadata(metadata_parts[2], :title, "lang") == "en" + + app_metadata_messages = [ + index == 1 ? message : + Arrow.Flight.Protocol.FlightData( + message.flight_descriptor, + message.data_header, + Vector{UInt8}(codeunits("batch:$(index - 2)")), + message.data_body, + ) for (index, message) in enumerate(metadata_messages) + ] + metadata_batches_with_app = + collect(Arrow.Flight.stream(app_metadata_messages; include_app_metadata=true)) + metadata_table_with_app = + Arrow.Flight.table(app_metadata_messages; include_app_metadata=true) + @test length(metadata_batches_with_app) == 2 + @test metadata_batches_with_app[1].table.title == ["red", "blue"] + @test metadata_batches_with_app[2].table.title == ["green"] + @test String(metadata_batches_with_app[1].app_metadata) == "batch:0" + @test String(metadata_batches_with_app[2].app_metadata) == "batch:1" + @test metadata_table_with_app.table.title == ["red", "blue", "green"] + @test String.(metadata_table_with_app.app_metadata) == ["batch:0", "batch:1"] reemitted_channel = Channel{Arrow.Flight.Protocol.FlightData}(8) reemit_task = diff --git a/test/flight/pyarrow_interop/exchange_tests.jl b/test/flight/pyarrow_interop/exchange_tests.jl index 47631f8a..572dcc3c 100644 --- a/test/flight/pyarrow_interop/exchange_tests.jl +++ b/test/flight/pyarrow_interop/exchange_tests.jl @@ -20,33 +20,50 @@ function 
pyarrow_interop_test_exchange(client, exchange_descriptor) (id=Int64[21, 22], name=["twenty-one", "twenty-two"]), (id=Int64[23], name=["twenty-three"]), )) - exchange_messages = - Arrow.Flight.flightdata(exchange_source; descriptor=exchange_descriptor) - - exchange_req, exchange_request, exchange_response = Arrow.Flight.doexchange(client) - sender = @async begin - for message in exchange_messages - put!(exchange_request, message) - end - close(exchange_request) - end + exchange_metadata = Dict("dataset" => "interop-exchange") + exchange_colmetadata = Dict(:name => Dict("lang" => "en")) + exchange_req, exchange_response = Arrow.Flight.doexchange( + client, + exchange_source; + descriptor=exchange_descriptor, + metadata=exchange_metadata, + colmetadata=exchange_colmetadata, + ) exchanged_messages = Arrow.Flight.Protocol.FlightData[] exchange_batches = collect( Arrow.Flight.stream(( (push!(exchanged_messages, message); message) for message in exchange_response ),), ) - wait(sender) gRPCClient.grpc_async_await(exchange_req) @test length(exchange_batches) == 2 @test exchange_batches[1].id == [21, 22] @test exchange_batches[1].name == ["twenty-one", "twenty-two"] + @test DataAPI.metadata(exchange_batches[1], "dataset") == "interop-exchange" + @test DataAPI.colmetadata(exchange_batches[1], :name, "lang") == "en" @test exchange_batches[2].id == [23] @test exchange_batches[2].name == ["twenty-three"] + @test DataAPI.metadata(exchange_batches[2], "dataset") == "interop-exchange" + @test DataAPI.colmetadata(exchange_batches[2], :name, "lang") == "en" exchange_table = Arrow.Flight.table(exchanged_messages) @test exchange_table.id == [21, 22, 23] @test exchange_table.name == ["twenty-one", "twenty-two", "twenty-three"] + @test DataAPI.metadata(exchange_table, "dataset") == "interop-exchange" + @test DataAPI.colmetadata(exchange_table, :name, "lang") == "en" @test filter(!isempty, getfield.(exchanged_messages, :app_metadata)) == [b"exchange:0", b"exchange:1"] + + 
exchange_batches_with_app = + collect(Arrow.Flight.stream(exchanged_messages; include_app_metadata=true)) + @test exchange_batches_with_app[1].table.id == [21, 22] + @test exchange_batches_with_app[2].table.id == [23] + @test String.(getproperty.(exchange_batches_with_app, :app_metadata)) == + ["exchange:0", "exchange:1"] + + exchange_table_with_app = + Arrow.Flight.table(exchanged_messages; include_app_metadata=true) + @test exchange_table_with_app.table.id == [21, 22, 23] + @test exchange_table_with_app.table.name == ["twenty-one", "twenty-two", "twenty-three"] + @test String.(exchange_table_with_app.app_metadata) == ["exchange:0", "exchange:1"] end diff --git a/test/flight/pyarrow_interop/upload_tests.jl b/test/flight/pyarrow_interop/upload_tests.jl index 1c10c9e2..e07b8de3 100644 --- a/test/flight/pyarrow_interop/upload_tests.jl +++ b/test/flight/pyarrow_interop/upload_tests.jl @@ -20,15 +20,17 @@ function pyarrow_interop_test_upload(client, upload_descriptor) (id=Int64[10, 11], name=["ten", "eleven"]), (id=Int64[12], name=["twelve"]), )) - upload_messages = Arrow.Flight.flightdata(upload_source; descriptor=upload_descriptor) - - doput_req, doput_request, doput_response = Arrow.Flight.doput(client) - put_results = pyarrow_interop_send_messages( - doput_req, - doput_request, - doput_response, - upload_messages, + upload_metadata = Dict("dataset" => "interop-upload") + upload_colmetadata = Dict(:name => Dict("lang" => "en")) + doput_req, doput_response = Arrow.Flight.doput( + client, + upload_source; + descriptor=upload_descriptor, + metadata=upload_metadata, + colmetadata=upload_colmetadata, ) + put_results = collect(doput_response) + gRPCClient.grpc_async_await(doput_req) @test !isempty(put_results) @test String(put_results[end].app_metadata) == "stored" @@ -41,4 +43,6 @@ function pyarrow_interop_test_upload(client, upload_descriptor) uploaded_table = Arrow.Flight.table(uploaded_messages; schema=uploaded_info) @test uploaded_table.id == [10, 11, 12] @test 
uploaded_table.name == ["ten", "eleven", "twelve"] + @test DataAPI.metadata(uploaded_table, "dataset") == "interop-upload" + @test DataAPI.colmetadata(uploaded_table, :name, "lang") == "en" end diff --git a/test/runtests.jl b/test/runtests.jl index 671694ec..166e4297 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -248,6 +248,26 @@ end @test Arrow.getmetadata(tt.col2)["colkey1"] == "colvalue1" @test Arrow.getmetadata(tt.col2)["colkey2"] == "colvalue2" @test Arrow.getmetadata(tt.col3)["colkey3"] == "colvalue3" + + source = Arrow.withmetadata( + (col1=collect(1:3), col2=["a", "b", "c"]); + metadata=["source" => "base"], + colmetadata=Dict(:col1 => ["semantic.role" => "left"]), + ) + overlay = Arrow.withmetadata( + source; + metadata=["overlay" => "yes"], + colmetadata=Dict( + :col1 => ["unit" => "count"], + :col2 => ["semantic.role" => "right"], + ), + ) + overlay_tt = Arrow.Table(Arrow.tobuffer(overlay)) + @test Arrow.getmetadata(overlay_tt)["source"] == "base" + @test Arrow.getmetadata(overlay_tt)["overlay"] == "yes" + @test Arrow.getmetadata(overlay_tt.col1)["semantic.role"] == "left" + @test Arrow.getmetadata(overlay_tt.col1)["unit"] == "count" + @test Arrow.getmetadata(overlay_tt.col2)["semantic.role"] == "right" end @testset "# custom compressors" begin From de4b6f7457eb91f42dc10be5bb2b8aae3a7420be Mon Sep 17 00:00:00 2001 From: guangtao Date: Wed, 1 Apr 2026 01:02:21 -0700 Subject: [PATCH 16/16] Enhance Arrow.Flight.withappmetadata --- README.md | 2 +- docs/src/manual.md | 11 +++ src/flight/client/methods/data.jl | 8 ++ src/flight/convert/flightdata.jl | 89 ++++++++++++++++++- src/flight/convert/framing.jl | 8 +- src/flight/exports.jl | 1 + test/flight/ipc_conversion.jl | 68 +++++++++++--- test/flight/pyarrow_interop/exchange_tests.jl | 9 +- test/flight_pyarrow_server.py | 4 +- 9 files changed, 182 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 658c2f83..7b8f4e68 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ Flight 
RPC status: * Requires Julia `1.12+` * Includes generated protocol bindings and complete client constructors for the `FlightService` RPC surface * Keeps the top-level Flight module shell thin, with exports and generated-protocol setup split out of `src/flight/Flight.jl` - * Includes high-level `FlightData <-> Arrow IPC` helpers for `Arrow.Table`, `Arrow.Stream`, and DoPut payload generation, plus opt-in `app_metadata` surfacing through `include_app_metadata=true` on `Arrow.Flight.stream(...)` / `Arrow.Flight.table(...)` + * Includes high-level `FlightData <-> Arrow IPC` helpers for `Arrow.Table`, `Arrow.Stream`, and DoPut/DoExchange payload generation, plus opt-in `app_metadata` surfacing through `include_app_metadata=true` on `Arrow.Flight.stream(...)` / `Arrow.Flight.table(...)`, explicit batch-wise `app_metadata=...` emission on `Arrow.Flight.flightdata(...)`, `Arrow.Flight.putflightdata!(...)`, and source-based `Arrow.Flight.doexchange(...)`, and a reusable `Arrow.Flight.withappmetadata(...)` wrapper so source-level batch metadata can stay attached without manual keyword threading * Keeps the Flight IPC conversion layer modular under `src/flight/convert/`, with `src/flight/convert.jl` retained as a thin entrypoint * Includes client helpers for request headers, binary metadata, handshake token reuse, and TLS configuration via `withheaders`, `withtoken`, and `authenticate` * Keeps the Flight client implementation modular under `src/flight/client/`, with thin entrypoints at `src/flight/client.jl` and `src/flight/client/rpc_methods.jl` diff --git a/docs/src/manual.md b/docs/src/manual.md index 42a97828..62aef961 100644 --- a/docs/src/manual.md +++ b/docs/src/manual.md @@ -251,6 +251,17 @@ existing schema/field metadata already exposed by the source, overlays new entries on top, and returns a wrapper that can be passed directly to [`Arrow.write`](@ref), `Arrow.tobuffer`, or the Flight IPC helpers. +The Flight IPC helpers also expose batch-wise Flight `app_metadata`. 
+[`Arrow.Flight.stream`](@ref) and [`Arrow.Flight.table`](@ref) can surface it +with `include_app_metadata=true`, while [`Arrow.Flight.flightdata`](@ref), +[`Arrow.Flight.putflightdata!`](@ref), and source-based +[`Arrow.Flight.doexchange`](@ref) accept `app_metadata=...` to emit one payload +per record batch without dropping down to raw protocol messages. +[`Arrow.Flight.withappmetadata`](@ref) provides the same payload metadata as a +lightweight wrapper around a table or partitioned source, so the metadata can +ride with the source itself instead of being re-specified at every emit call +site. + ## Writing arrow data Ok, so that's a pretty good rundown of *reading* arrow data, but how do you *produce* arrow data? Enter `Arrow.write`. diff --git a/src/flight/client/methods/data.jl b/src/flight/client/methods/data.jl index 62da0ec0..30159773 100644 --- a/src/flight/client/methods/data.jl +++ b/src/flight/client/methods/data.jl @@ -84,6 +84,7 @@ function doput( maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, metadata::Union{Nothing,Any}=nothing, colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, kwargs..., ) request = Channel{Protocol.FlightData}(request_capacity) @@ -103,6 +104,7 @@ function doput( maxdepth=maxdepth, metadata=metadata, colmetadata=colmetadata, + app_metadata=app_metadata, ) ) return FlightAsyncRequest(grpc_request, producer) @@ -124,6 +126,7 @@ function doput( maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, metadata::Union{Nothing,Any}=nothing, colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, kwargs..., ) response = Channel{Protocol.PutResult}(response_capacity) @@ -143,6 +146,7 @@ function doput( maxdepth=maxdepth, metadata=metadata, colmetadata=colmetadata, + app_metadata=app_metadata, kwargs..., ) return req, response @@ -191,6 +195,7 @@ function doexchange( maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, metadata::Union{Nothing,Any}=nothing, colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, 
kwargs..., ) request = Channel{Protocol.FlightData}(request_capacity) @@ -210,6 +215,7 @@ function doexchange( maxdepth=maxdepth, metadata=metadata, colmetadata=colmetadata, + app_metadata=app_metadata, ) ) return FlightAsyncRequest(grpc_request, producer) @@ -231,6 +237,7 @@ function doexchange( maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, metadata::Union{Nothing,Any}=nothing, colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, kwargs..., ) response = Channel{Protocol.FlightData}(response_capacity) @@ -250,6 +257,7 @@ function doexchange( maxdepth=maxdepth, metadata=metadata, colmetadata=colmetadata, + app_metadata=app_metadata, kwargs..., ) return req, response diff --git a/src/flight/convert/flightdata.jl b/src/flight/convert/flightdata.jl index c515a50a..2faf2993 100644 --- a/src/flight/convert/flightdata.jl +++ b/src/flight/convert/flightdata.jl @@ -27,6 +27,82 @@ function _sourcedefaultcolmetadata(cols) return ArrowParent._normalizecolmeta(colmeta) end +struct FlightAppMetadataSource{T,M} + source::T + app_metadata::M +end + +ArrowParent.getmetadata(x::FlightAppMetadataSource) = ArrowParent.getmetadata(x.source) + +""" + Arrow.Flight.withappmetadata(source; app_metadata) + +Return a lightweight wrapper around `source` that carries batch-wise Flight +`app_metadata` alongside the Arrow payload. The wrapper can be passed directly +to [`Arrow.Flight.flightdata`](@ref), [`Arrow.Flight.putflightdata!`](@ref), +or source-based [`Arrow.Flight.doexchange`](@ref) without manually threading +`app_metadata=...` through each call site. +""" +withappmetadata(source; app_metadata) = + isnothing(app_metadata) ? source : FlightAppMetadataSource(source, app_metadata) + +function _unwrap_app_metadata_source(source, app_metadata) + source isa FlightAppMetadataSource || return source, app_metadata + isnothing(app_metadata) || throw( + ArgumentError( + "app_metadata cannot be provided both via Arrow.Flight.withappmetadata(...) 
and the app_metadata keyword", + ), + ) + return source.source, source.app_metadata +end + +_is_app_metadata_value(x) = x isa AbstractString || x isa AbstractVector{UInt8} + +function _normalize_app_metadata_value(value) + value === nothing && return UInt8[] + value isa AbstractString && return Vector{UInt8}(codeunits(value)) + value isa AbstractVector{UInt8} && return Vector{UInt8}(value) + throw( + ArgumentError( + "app_metadata entries must be AbstractString, AbstractVector{UInt8}, or nothing", + ), + ) +end + +function _normalize_app_metadata_source(app_metadata) + isnothing(app_metadata) && return nothing + return _is_app_metadata_value(app_metadata) ? (app_metadata,) : app_metadata +end + +_app_metadata_cursor(app_metadata) = + let metadata_iter = _normalize_app_metadata_source(app_metadata) + isnothing(metadata_iter) ? nothing : + (iter=metadata_iter, state=nothing, started=false) + end + +function _next_app_metadata(cursor) + isnothing(cursor) && return UInt8[], cursor + iter = cursor.iter + next = cursor.started ? iterate(iter, cursor.state) : iterate(iter) + isnothing(next) && throw( + ArgumentError("app_metadata was exhausted before all record batches were emitted"), + ) + value, state = next + return _normalize_app_metadata_value(value), (iter=iter, state=state, started=true) +end + +function _ensure_app_metadata_consumed(cursor) + isnothing(cursor) && return nothing + next = cursor.started ? 
iterate(cursor.iter, cursor.state) : iterate(cursor.iter) + isnothing(next) && return nothing + throw(ArgumentError("app_metadata contains more entries than source partitions")) +end + +function _partition_with_app_metadata(tbl, cursor) + app_metadata, cursor = _next_app_metadata(cursor) + return tbl, app_metadata, cursor +end + function _emitflightdata!( emit, source; @@ -40,14 +116,19 @@ function _emitflightdata!( maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, metadata::Union{Nothing,Any}=nothing, colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, ) + source, app_metadata = _unwrap_app_metadata_source(source, app_metadata) dictencodings = Dict{Int64,Any}() schema = Ref{Tables.Schema}() normalized_colmetadata = ArrowParent._normalizecolmeta(colmetadata) source_meta = isnothing(metadata) ? ArrowParent.getmetadata(source) : metadata source_colmetadata = isnothing(colmetadata) ? nothing : normalized_colmetadata + app_metadata_cursor = _app_metadata_cursor(app_metadata) - for tbl in Tables.partitions(source) + for partition in Tables.partitions(source) + tbl, record_app_metadata, app_metadata_cursor = + _partition_with_app_metadata(partition, app_metadata_cursor) tblcols = Tables.columns(tbl) if isnothing(metadata) tblmeta = ArrowParent.getmetadata(tbl) @@ -120,11 +201,13 @@ function _emitflightdata!( emit( _flightdata_message( ArrowParent.makerecordbatchmsg(schema[], cols, alignment); + app_metadata=record_app_metadata, alignment=alignment, ), ) descriptor = nothing end + _ensure_app_metadata_consumed(app_metadata_cursor) return nothing end @@ -140,6 +223,7 @@ function flightdata( maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, metadata::Union{Nothing,Any}=nothing, colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, ) messages = Protocol.FlightData[] _emitflightdata!( @@ -155,6 +239,7 @@ function flightdata( maxdepth=maxdepth, metadata=metadata, colmetadata=colmetadata, + app_metadata=app_metadata, ) return messages end @@ -173,6 +258,7 
@@ function putflightdata!( maxdepth::Integer=ArrowParent.DEFAULT_MAX_DEPTH, metadata::Union{Nothing,Any}=nothing, colmetadata::Union{Nothing,Any}=nothing, + app_metadata=nothing, ) try _emitflightdata!( @@ -188,6 +274,7 @@ function putflightdata!( maxdepth=maxdepth, metadata=metadata, colmetadata=colmetadata, + app_metadata=app_metadata, ) finally close && Base.close(sink) diff --git a/src/flight/convert/framing.jl b/src/flight/convert/framing.jl index 8b205d15..555219d9 100644 --- a/src/flight/convert/framing.jl +++ b/src/flight/convert/framing.jl @@ -27,12 +27,18 @@ end function _flightdata_message( msg::ArrowParent.Message; descriptor::Union{Nothing,Protocol.FlightDescriptor}=nothing, + app_metadata::AbstractVector{UInt8}=UInt8[], alignment::Integer=DEFAULT_IPC_ALIGNMENT, ) body = _message_body(msg, alignment) length(body) == msg.bodylen || throw(ArgumentError("FlightData body length mismatch while encoding Arrow IPC")) - return Protocol.FlightData(descriptor, Vector{UInt8}(msg.msgflatbuf), UInt8[], body) + return Protocol.FlightData( + descriptor, + Vector{UInt8}(msg.msgflatbuf), + Vector{UInt8}(app_metadata), + body, + ) end function _write_framed_message( diff --git a/src/flight/exports.jl b/src/flight/exports.jl index b595e47c..89dc88fc 100644 --- a/src/flight/exports.jl +++ b/src/flight/exports.jl @@ -43,5 +43,6 @@ export Client, streambytes, stream, table, + withappmetadata, flightdata, putflightdata! diff --git a/test/flight/ipc_conversion.jl b/test/flight/ipc_conversion.jl index b63b58a3..285e2744 100644 --- a/test/flight/ipc_conversion.jl +++ b/test/flight/ipc_conversion.jl @@ -117,15 +117,12 @@ using UUIDs @test DataAPI.metadata(metadata_parts[2], "dataset") == "flight" @test DataAPI.colmetadata(metadata_parts[2], :title, "lang") == "en" - app_metadata_messages = [ - index == 1 ? 
message : - Arrow.Flight.Protocol.FlightData( - message.flight_descriptor, - message.data_header, - Vector{UInt8}(codeunits("batch:$(index - 2)")), - message.data_body, - ) for (index, message) in enumerate(metadata_messages) - ] + app_metadata_messages = Arrow.Flight.flightdata( + metadata_source; + metadata=Dict("dataset" => "flight"), + colmetadata=Dict(:title => Dict("lang" => "en")), + app_metadata=("batch:0", "batch:1"), + ) metadata_batches_with_app = collect(Arrow.Flight.stream(app_metadata_messages; include_app_metadata=true)) metadata_table_with_app = @@ -138,16 +135,65 @@ using UUIDs @test metadata_table_with_app.table.title == ["red", "blue", "green"] @test String.(metadata_table_with_app.app_metadata) == ["batch:0", "batch:1"] + wrapped_metadata_source = Arrow.Flight.withappmetadata( + metadata_source; + app_metadata=("wrapped:0", "wrapped:1"), + ) + wrapped_metadata_messages = Arrow.Flight.flightdata( + wrapped_metadata_source; + metadata=Dict("dataset" => "flight"), + colmetadata=Dict(:title => Dict("lang" => "en")), + ) + wrapped_metadata_table = + Arrow.Flight.table(wrapped_metadata_messages; include_app_metadata=true) + @test wrapped_metadata_table.table.title == ["red", "blue", "green"] + @test String.(wrapped_metadata_table.app_metadata) == ["wrapped:0", "wrapped:1"] + reemitted_channel = Channel{Arrow.Flight.Protocol.FlightData}(8) - reemit_task = - @async Arrow.Flight.putflightdata!(reemitted_channel, metadata_table; close=true) + reemit_task = @async Arrow.Flight.putflightdata!( + reemitted_channel, + Arrow.Flight.withappmetadata( + metadata_table_with_app.table; + app_metadata=("batch:0", "batch:1"), + ); + close=true, + ) reemitted_messages = collect(reemitted_channel) wait(reemit_task) + @test String.(getfield.(reemitted_messages[2:end], :app_metadata)) == + ["batch:0", "batch:1"] reemitted_table = Arrow.Flight.table(reemitted_messages) @test reemitted_table.title == metadata_table.title @test DataAPI.metadata(reemitted_table, "dataset") 
== "flight" @test DataAPI.colmetadata(reemitted_table, :title, "lang") == "en" + app_metadata_error = try + Arrow.Flight.flightdata(metadata_source; app_metadata=("only-one",)) + nothing + catch err + err + end + @test app_metadata_error isa ArgumentError + @test occursin("app_metadata was exhausted", sprint(showerror, app_metadata_error)) + + duplicate_app_metadata_error = try + Arrow.Flight.flightdata( + Arrow.Flight.withappmetadata( + metadata_source; + app_metadata=("wrapped:0", "wrapped:1"), + ); + app_metadata=("extra:0", "extra:1"), + ) + nothing + catch err + err + end + @test duplicate_app_metadata_error isa ArgumentError + @test occursin( + "Arrow.Flight.withappmetadata", + sprint(showerror, duplicate_app_metadata_error), + ) + extension_source = ( uuid=[UUID(UInt128(1)), UUID(UInt128(2))], flag=[Arrow.Bool8(true), Arrow.Bool8(false)], diff --git a/test/flight/pyarrow_interop/exchange_tests.jl b/test/flight/pyarrow_interop/exchange_tests.jl index 572dcc3c..7d212e1c 100644 --- a/test/flight/pyarrow_interop/exchange_tests.jl +++ b/test/flight/pyarrow_interop/exchange_tests.jl @@ -22,6 +22,9 @@ function pyarrow_interop_test_exchange(client, exchange_descriptor) )) exchange_metadata = Dict("dataset" => "interop-exchange") exchange_colmetadata = Dict(:name => Dict("lang" => "en")) + exchange_app_metadata = ["client:0", "client:1"] + exchange_source = + Arrow.Flight.withappmetadata(exchange_source; app_metadata=exchange_app_metadata) exchange_req, exchange_response = Arrow.Flight.doexchange( client, exchange_source; @@ -52,18 +55,18 @@ function pyarrow_interop_test_exchange(client, exchange_descriptor) @test DataAPI.metadata(exchange_table, "dataset") == "interop-exchange" @test DataAPI.colmetadata(exchange_table, :name, "lang") == "en" @test filter(!isempty, getfield.(exchanged_messages, :app_metadata)) == - [b"exchange:0", b"exchange:1"] + [b"client:0", b"client:1"] exchange_batches_with_app = collect(Arrow.Flight.stream(exchanged_messages; 
include_app_metadata=true)) @test exchange_batches_with_app[1].table.id == [21, 22] @test exchange_batches_with_app[2].table.id == [23] @test String.(getproperty.(exchange_batches_with_app, :app_metadata)) == - ["exchange:0", "exchange:1"] + exchange_app_metadata exchange_table_with_app = Arrow.Flight.table(exchanged_messages; include_app_metadata=true) @test exchange_table_with_app.table.id == [21, 22, 23] @test exchange_table_with_app.table.name == ["twenty-one", "twenty-two", "twenty-three"] - @test String.(exchange_table_with_app.app_metadata) == ["exchange:0", "exchange:1"] + @test String.(exchange_table_with_app.app_metadata) == exchange_app_metadata end diff --git a/test/flight_pyarrow_server.py b/test/flight_pyarrow_server.py index 386b3bd0..b100469f 100644 --- a/test/flight_pyarrow_server.py +++ b/test/flight_pyarrow_server.py @@ -112,7 +112,9 @@ def do_exchange(self, context, descriptor, reader, writer): break if chunk.data is None: continue - metadata = pa.py_buffer(f"exchange:{batch_index}".encode("utf-8")) + metadata = chunk.app_metadata + if metadata is None: + metadata = pa.py_buffer(f"exchange:{batch_index}".encode("utf-8")) writer.write_with_metadata(chunk.data, metadata) batch_index += 1