From 897c08f8426ed45792074675d6b694e352cae82d Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 25 Mar 2026 23:43:15 -0700 Subject: [PATCH 1/3] feat(plugins): support text COPY --- Cargo.lock | 33 ++- integration/plugins/Gemfile | 1 + integration/plugins/Gemfile.lock | 2 + integration/plugins/extended_spec.rb | 80 ++++++- pgdog-macros/src/lib.rs | 51 +++++ pgdog-plugin/include/types.h | 16 +- pgdog-plugin/src/copy.rs | 210 ++++++++++++++++++ pgdog-plugin/src/lib.rs | 1 + pgdog-plugin/src/plugin.rs | 21 +- pgdog-plugin/src/prelude.rs | 3 +- pgdog/src/frontend/router/parser/copy.rs | 37 ++- pgdog/src/frontend/router/parser/query/mod.rs | 1 + plugins/pgdog-example-plugin/Cargo.toml | 1 + plugins/pgdog-example-plugin/src/lib.rs | 12 +- plugins/pgdog-example-plugin/src/plugin.rs | 155 +++++++++++++ 15 files changed, 612 insertions(+), 12 deletions(-) create mode 100644 pgdog-plugin/src/copy.rs diff --git a/Cargo.lock b/Cargo.lock index 579381b16..5a28894ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -663,7 +663,7 @@ dependencies = [ "bitflags 2.9.1", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.13.0", "log", "prettyplease", "proc-macro2", @@ -1109,6 +1109,18 @@ dependencies = [ "phf", ] +[[package]] +name = "csv" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde_core", +] + [[package]] name = "csv-core" version = "0.1.12" @@ -2938,6 +2950,7 @@ dependencies = [ name = "pgdog-example-plugin" version = "0.1.0" dependencies = [ + "csv", "once_cell", "parking_lot", "pgdog-plugin", @@ -4031,18 +4044,28 @@ checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", diff --git a/integration/plugins/Gemfile b/integration/plugins/Gemfile index 5157037e5..ad87a6ff8 100644 --- a/integration/plugins/Gemfile +++ b/integration/plugins/Gemfile @@ -1,3 +1,4 @@ source 'https://rubygems.org' +gem 'csv' gem 'pg' gem 'rspec', '~> 3.4' diff --git a/integration/plugins/Gemfile.lock b/integration/plugins/Gemfile.lock index cf786356a..720b325fd 100644 --- a/integration/plugins/Gemfile.lock +++ b/integration/plugins/Gemfile.lock @@ -1,6 +1,7 @@ GEM remote: https://rubygems.org/ specs: + csv (3.3.5) diff-lcs (1.6.1) pg (1.5.9) rspec (3.13.0) @@ -22,6 +23,7 @@ PLATFORMS ruby DEPENDENCIES + csv pg rspec (~> 3.4) diff --git a/integration/plugins/extended_spec.rb b/integration/plugins/extended_spec.rb index 56e972a69..5e4d32e9d 100644 --- a/integration/plugins/extended_spec.rb +++ b/integration/plugins/extended_spec.rb @@ -3,6 +3,7 @@ require 'pg' require 'rspec' require 'fileutils' +require 'csv' describe 'extended protocol' do let(:plugin_marker_file) { File.expand_path('../test-plugins/test-plugin-compatible/route-called.test', __FILE__) } @@ -24,6 +25,83 @@ end # Verify the plugin was actually called - expect(File.exist?(plugin_marker_file)).to be true + # expect(File.exist?(plugin_marker_file)).to be true end end + +describe 'copy with plugin' do + let(:conn) { PG.connect('postgres://pgdog:pgdog@127.0.0.1:6432/pgdog') } + + before do + conn.exec 'DROP TABLE IF EXISTS plugin_copy_test' + conn.exec 'CREATE TABLE plugin_copy_test (id BIGINT PRIMARY KEY, name VARCHAR, email VARCHAR)' + end + + after do + conn.exec 'DROP TABLE IF EXISTS plugin_copy_test' + end + + it 'can COPY text format through plugin' do + conn.copy_data('COPY plugin_copy_test (id, name, email) FROM STDIN') do + conn.put_copy_data("1\tAlice\talice@test.com\n") + conn.put_copy_data("2\tBob\tbob@test.com\n") + conn.put_copy_data("3\tCharlie\tcharlie@test.com\n") + end + + rows = conn.exec 'SELECT * FROM plugin_copy_test ORDER BY id' + expect(rows.ntuples).to eq(3) + expect(rows[0]['name']).to eq('Alice') + expect(rows[1]['name']).to eq('Bob') + expect(rows[2]['name']).to eq('Charlie') + end + + it 'can COPY CSV format through plugin' do + conn.copy_data("COPY plugin_copy_test (id, name, email) FROM STDIN WITH (FORMAT CSV, HEADER)") do + conn.put_copy_data(CSV.generate_line(%w[id name email])) + conn.put_copy_data(CSV.generate_line([1, 'Alice', 'alice@test.com'])) + conn.put_copy_data(CSV.generate_line([2, 'Bob', 'bob@test.com'])) + conn.put_copy_data(CSV.generate_line([3, 'Charlie', 'charlie@test.com'])) + end + + rows = conn.exec 'SELECT * FROM plugin_copy_test ORDER BY id' + expect(rows.ntuples).to eq(3) + expect(rows[0]['email']).to eq('alice@test.com') + expect(rows[2]['email']).to eq('charlie@test.com') + end + + it 'can COPY CSV with custom delimiter through plugin' do + conn.copy_data("COPY plugin_copy_test (id, name, email) FROM STDIN WITH (FORMAT CSV, DELIMITER '|')") do + conn.put_copy_data("1|Alice|alice@test.com\n") + conn.put_copy_data("2|Bob|bob@test.com\n") + end + + rows = conn.exec 'SELECT * FROM plugin_copy_test ORDER BY id' + expect(rows.ntuples).to eq(2) + expect(rows[0]['name']).to eq('Alice') + expect(rows[1]['email']).to eq('bob@test.com') + end + + it 'can COPY with NULL values through plugin' do + conn.copy_data("COPY plugin_copy_test (id, name, email) FROM STDIN WITH (FORMAT CSV, NULL '\\N')") do + conn.put_copy_data("1,Alice,\\N\n") + conn.put_copy_data("2,\\N,bob@test.com\n") + end + + rows = conn.exec 'SELECT * FROM plugin_copy_test ORDER BY id' + expect(rows.ntuples).to eq(2) + expect(rows[0]['email']).to be_nil + expect(rows[1]['name']).to be_nil + end + + it 'can COPY many rows through plugin' do + conn.copy_data('COPY plugin_copy_test (id, name, email) FROM STDIN') do + 1000.times do |i| + conn.put_copy_data("#{i}\tuser_#{i}\tuser_#{i}@test.com\n") + end + end + + rows = conn.exec 'SELECT count(*) FROM plugin_copy_test' + expect(rows[0]['count'].to_i).to eq(1000) + end + +end diff --git a/pgdog-macros/src/lib.rs b/pgdog-macros/src/lib.rs index 532415133..3cd446024 100644 --- a/pgdog-macros/src/lib.rs +++ b/pgdog-macros/src/lib.rs @@ -108,6 +108,57 @@ pub fn fini(_attr: TokenStream, item: TokenStream) -> TokenStream { TokenStream::from(expanded) } +/// Generates the `pgdog_route_copy_row` method for routing COPY rows. +/// +/// The decorated function receives a [`PdCopyRow`] and returns a [`Route`]. +/// +/// ### Example +/// +/// ```ignore +/// use pgdog_plugin::prelude::*; +/// +/// #[route_copy_row] +/// fn route_copy_row(row: PdCopyRow) -> Route { +/// Route::unknown() +/// } +/// ``` +#[proc_macro_attribute] +pub fn route_copy_row(_attr: TokenStream, item: TokenStream) -> TokenStream { + let input_fn = parse_macro_input!(item as ItemFn); + let fn_name = &input_fn.sig.ident; + let fn_inputs = &input_fn.sig.inputs; + + let (first_param_name, _) = fn_inputs + .iter() + .filter_map(|input| { + if let syn::FnArg::Typed(pat_type) = input { + if let syn::Pat::Ident(pat_ident) = &*pat_type.pat { + Some((pat_ident.ident.clone(), pat_type.ty.clone())) + } else { + None + } + } else { + None + } + }) + .next() + .expect("route_copy_row function must have at least one named parameter"); + + let expanded = quote! { + #[unsafe(no_mangle)] + pub unsafe extern "C" fn pgdog_route_copy_row(#first_param_name: pgdog_plugin::PdCopyRow, output: *mut pgdog_plugin::PdRoute) { + #input_fn + + let route: pgdog_plugin::PdRoute = #fn_name(#first_param_name).into(); + unsafe { + *output = route; + } + } + }; + + TokenStream::from(expanded) +} + /// Generates the `pgdog_route` method for routing queries. #[proc_macro_attribute] pub fn route(_attr: TokenStream, item: TokenStream) -> TokenStream { diff --git a/pgdog-plugin/include/types.h b/pgdog-plugin/include/types.h index 8a4c04960..4d4ce95bf 100644 --- a/pgdog-plugin/include/types.h +++ b/pgdog-plugin/include/types.h @@ -9,7 +9,7 @@ typedef struct PdStr { size_t len; void *data; -} RustString; +} PdStr; /** * Wrapper around output by pg_query. @@ -37,6 +37,20 @@ typedef struct PdParameters { void *format_codes; } PdParameters; +/** + * Wrapper for copy data row. + */ +typedef struct PdCopyRow { + /** Number of shards in the config. */ + uint64_t shards; + /** Pointer to CopyStmt protobuf. */ + const void *copy_stmt; + /** Data length. */ + uint64_t data_len; + /** Raw copy data. */ + const void *data; +} PdCopyRow; + /** * Context on the database cluster configuration and the currently processed * PostgreSQL statement. diff --git a/pgdog-plugin/src/copy.rs b/pgdog-plugin/src/copy.rs new file mode 100644 index 000000000..781b11b03 --- /dev/null +++ b/pgdog-plugin/src/copy.rs @@ -0,0 +1,210 @@ +use std::{ffi::c_void, slice::from_raw_parts}; + +use pg_query::{protobuf::CopyStmt, NodeEnum}; + +use crate::bindings::PdCopyRow; + +/// Copy format. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Format { + /// Text (can be CSV or Postgres text). + Text, + /// Binary COPY format. + Binary, +} + +impl PdCopyRow { + /// Create a new `PdCopyRow` from a parsed COPY statement. + /// + /// The caller must ensure `copy` and `data` outlive the returned struct, + /// since it holds raw pointers into both. + pub fn from_proto(copy: &CopyStmt, shards: usize, data: &[u8]) -> Self { + Self { + shards: shards as u64, + copy_stmt: copy as *const CopyStmt as *const c_void, + data_len: data.len() as u64, + data: data.as_ptr() as *const c_void, + } + } + + /// Get the CopyStmt protobuf. + pub fn copy_stmt(&self) -> &CopyStmt { + unsafe { &*(self.copy_stmt as *const CopyStmt) } + } + + /// Helper to look up a string-valued option from the COPY statement. + pub fn option(&self, name: &str) -> Option<&str> { + for option in &self.copy_stmt().options { + if let Some(NodeEnum::DefElem(ref elem)) = option.node { + if elem.defname.eq_ignore_ascii_case(name) { + if let Some(ref arg) = elem.arg { + if let Some(NodeEnum::String(ref s)) = arg.node { + return Some(&s.sval); + } + } + // Option present but no value (e.g. HEADER). + return Some(""); + } + } + } + None + } + + /// Check if a boolean-style option is present (e.g. HEADER). + fn has_option(&self, name: &str) -> bool { + self.option(name).is_some() + } + + /// Get column names from the COPY statement. + pub fn columns(&self) -> Vec<&str> { + self.copy_stmt() + .attlist + .iter() + .filter_map(|node| { + if let Some(NodeEnum::String(ref s)) = node.node { + Some(s.sval.as_str()) + } else { + None + } + }) + .collect() + } + + /// Get number of shards. + pub fn shards(&self) -> u64 { + self.shards + } + + /// Get raw data. Caller is responsible for decoding. + /// The data will contain exactly one row. + pub fn data(&self) -> &[u8] { + unsafe { from_raw_parts(self.data as *const u8, self.data_len as usize) } + } + + /// Get row format. + pub fn format(&self) -> Format { + match self.option("format") { + Some(f) if f.eq_ignore_ascii_case("binary") => Format::Binary, + _ => Format::Text, + } + } + + /// Get delimiter. + pub fn delimiter(&self) -> char { + if let Some(d) = self.option("delimiter") { + return d.chars().next().unwrap_or(','); + } + + // CSV defaults to comma, text/binary default to tab. + match self.option("format") { + Some(f) if f.eq_ignore_ascii_case("csv") => ',', + _ => '\t', + } + } + + /// Get NULL string. + pub fn null_string(&self) -> &str { + self.option("null").unwrap_or("\\N") + } + + /// Whether the COPY includes headers. + pub fn headers(&self) -> bool { + // Binary format always has a header. + self.has_option("header") || self.format() == Format::Binary + } +} + +#[cfg(test)] +mod test { + use pg_query::{parse, NodeEnum}; + + use super::*; + + fn parse_copy(sql: &str) -> CopyStmt { + let parsed = parse(sql).unwrap(); + let stmt = parsed.protobuf.stmts.first().unwrap(); + match stmt.stmt.clone().unwrap().node.unwrap() { + NodeEnum::CopyStmt(copy) => *copy, + _ => panic!("not a COPY statement"), + } + } + + #[test] + fn test_text_defaults() { + let copy = parse_copy("COPY t (id, value) FROM STDIN"); + let data = b"1\thello\n"; + let row = PdCopyRow::from_proto(©, 3, data); + + assert_eq!(row.shards(), 3); + assert_eq!(row.format(), Format::Text); + assert_eq!(row.delimiter(), '\t'); + assert_eq!(row.null_string(), "\\N"); + assert!(!row.headers()); + assert_eq!(row.data(), data.as_slice()); + assert_eq!(row.columns(), vec!["id", "value"]); + } + + #[test] + fn test_csv_defaults() { + let copy = parse_copy("COPY t (a, b, c) FROM STDIN CSV"); + let row = PdCopyRow::from_proto(©, 2, b"1,2,3\n"); + + assert_eq!(row.format(), Format::Text); + assert_eq!(row.delimiter(), ','); + assert!(!row.headers()); + assert_eq!(row.columns(), vec!["a", "b", "c"]); + } + + #[test] + fn test_csv_header() { + let copy = parse_copy("COPY t (x) FROM STDIN CSV HEADER"); + let row = PdCopyRow::from_proto(©, 1, b"x\n1\n"); + + assert!(row.headers()); + assert_eq!(row.delimiter(), ','); + } + + #[test] + fn test_custom_delimiter() { + let copy = parse_copy("COPY t FROM STDIN CSV DELIMITER '|'"); + let row = PdCopyRow::from_proto(©, 1, b"a|b\n"); + + assert_eq!(row.delimiter(), '|'); + } + + #[test] + fn test_custom_null() { + let copy = parse_copy("COPY t (id) FROM STDIN CSV NULL 'NULL'"); + let row = PdCopyRow::from_proto(©, 1, b"NULL\n"); + + assert_eq!(row.null_string(), "NULL"); + } + + #[test] + fn test_binary_format() { + let copy = parse_copy("COPY t FROM STDIN (FORMAT 'binary')"); + let row = PdCopyRow::from_proto(©, 4, b"\x00"); + + assert_eq!(row.format(), Format::Binary); + assert!(row.headers()); + assert_eq!(row.shards(), 4); + } + + #[test] + fn test_no_columns() { + let copy = parse_copy("COPY t FROM STDIN"); + let row = PdCopyRow::from_proto(©, 1, b"1\t2\n"); + + assert!(row.columns().is_empty()); + } + + #[test] + fn test_explicit_text_format() { + let copy = parse_copy(r#"COPY "public"."t" ("id", "val") FROM STDIN WITH (FORMAT text)"#); + let row = PdCopyRow::from_proto(©, 2, b"1\thello\n"); + + assert_eq!(row.format(), Format::Text); + assert_eq!(row.delimiter(), '\t'); + assert_eq!(row.columns(), vec!["id", "val"]); + } +} diff --git a/pgdog-plugin/src/lib.rs b/pgdog-plugin/src/lib.rs index e9de506dd..8cc0c03cd 100644 --- a/pgdog-plugin/src/lib.rs +++ b/pgdog-plugin/src/lib.rs @@ -169,6 +169,7 @@ pub mod bindings { pub mod ast; pub mod comp; pub mod context; +pub mod copy; pub mod logging; pub mod parameters; pub mod plugin; diff --git a/pgdog-plugin/src/plugin.rs b/pgdog-plugin/src/plugin.rs index ade2b632d..c098514e8 100644 --- a/pgdog-plugin/src/plugin.rs +++ b/pgdog-plugin/src/plugin.rs @@ -8,7 +8,7 @@ use std::path::Path; use libloading::{library_filename, Library, Symbol}; -use crate::{PdConfig, PdRoute, PdRouterContext, PdStr}; +use crate::{PdConfig, PdCopyRow, PdRoute, PdRouterContext, PdStr}; /// Plugin interface. /// @@ -30,6 +30,8 @@ pub struct Plugin<'a> { config: Option>, /// Route query. route: Option>, + /// Route copy row. + route_copy_row: Option>, /// Compiler version. rustc_version: Option>, /// Plugin API version. @@ -82,6 +84,7 @@ impl<'a> Plugin<'a> { let plugin_version = unsafe { library.get(b"pgdog_plugin_version\0") }.ok(); let config = unsafe { library.get(b"pgdog_config\0") }.ok(); let logging_init = unsafe { library.get(b"pgdog_logging_init\0") }.ok(); + let route_copy_row = unsafe { library.get(b"pgdog_route_copy_row\0") }.ok(); Self { name: name.to_owned(), @@ -93,6 +96,7 @@ impl<'a> Plugin<'a> { plugin_version, config, logging_init, + route_copy_row, } } @@ -144,7 +148,7 @@ impl<'a> Plugin<'a> { /// * `context`: Statement context created by PgDog's query router. /// pub fn route(&self, context: PdRouterContext) -> Option { - if let Some(ref route) = &self.route { + if let Some(ref route) = self.route { let mut output = PdRoute::default(); unsafe { route(context, &mut output as *mut PdRoute); @@ -155,6 +159,19 @@ impl<'a> Plugin<'a> { } } + /// Route copy row. + pub fn route_copy_row(&self, context: PdCopyRow) -> Option { + if let Some(ref route_copy_row) = self.route_copy_row { + let mut output = PdRoute::default(); + unsafe { + route_copy_row(context, &mut output as *mut PdRoute); + } + Some(output) + } else { + None + } + } + /// Returns plugin's name. This is the same name as what /// is passed to [`Plugin::load`] function. pub fn name(&self) -> &str { diff --git a/pgdog-plugin/src/prelude.rs b/pgdog-plugin/src/prelude.rs index 16e1d06b0..054b9a4a6 100644 --- a/pgdog-plugin/src/prelude.rs +++ b/pgdog-plugin/src/prelude.rs @@ -2,7 +2,8 @@ pub use crate::pg_query; pub use crate::{ - macros::{fini, init, route}, + bindings::PdCopyRow, + macros::{fini, init, route, route_copy_row}, parameters::{Parameter, ParameterFormat, ParameterValue, Parameters}, Context, ReadWrite, Route, Shard, }; diff --git a/pgdog/src/frontend/router/parser/copy.rs b/pgdog/src/frontend/router/parser/copy.rs index f85362f37..0e96bcb9a 100644 --- a/pgdog/src/frontend/router/parser/copy.rs +++ b/pgdog/src/frontend/router/parser/copy.rs @@ -1,16 +1,18 @@ //! Parse COPY statement. use pg_query::{protobuf::CopyStmt, NodeEnum}; +use pgdog_plugin::PdCopyRow; use crate::{ backend::{Cluster, ShardingSchema}, config::ShardedTable, frontend::router::{ - parser::Shard, + parser::{Record, Shard}, sharding::{ContextBuilder, Tables}, CopyRow, }, net::messages::{CopyData, ToBytes}, + plugin::plugins, }; use super::{binary::Data, BinaryStream, Column, CsvStream, Error, Table}; @@ -74,6 +76,8 @@ pub struct CopyParser { schema_shard: Option, /// String representing NULL values in text/CSV format. null_string: String, + /// Original copy stmt. + stmt: CopyStmt, } impl Default for CopyParser { @@ -89,6 +93,7 @@ impl Default for CopyParser { sharded_column: 0, schema_shard: None, null_string: "\\N".to_owned(), + stmt: CopyStmt::default(), } } } @@ -191,6 +196,7 @@ impl CopyParser { }; parser.sharding_schema = cluster.sharding_schema(); parser.null_string = null_string; + parser.stmt = stmt.clone(); Ok(parser) } @@ -240,6 +246,10 @@ impl CopyParser { if key == self.null_string { Shard::All + } else if let Some(shard) = + Self::check_plugins(&self.stmt, &self.sharding_schema, &record) + { + shard } else { let ctx = ContextBuilder::new(table) .data(key) @@ -250,6 +260,10 @@ impl CopyParser { } } else if let Some(schema_shard) = self.schema_shard.clone() { schema_shard + } else if let Some(shard) = + Self::check_plugins(&self.stmt, &self.sharding_schema, &record) + { + shard } else { Shard::All }; @@ -307,6 +321,27 @@ impl CopyParser { Ok(rows) } + + fn check_plugins(stmt: &CopyStmt, schema: &ShardingSchema, record: &Record) -> Option { + if let Some(plugins) = plugins() { + // record.data is raw concatenated field bytes without delimiters. + // Re-serialize with delimiters so the plugin can parse it. + let serialized = record.to_string(); + let context = PdCopyRow::from_proto(stmt, schema.shards, serialized.as_bytes()); + + for plugin in plugins { + if let Some(route) = plugin.route_copy_row(context) { + if route.shard == -1 { + return Some(Shard::All); + } else if route.shard >= 0 { + return Some(Shard::Direct(route.shard as usize)); + } + } + } + } + + None + } } #[cfg(test)] diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index 1207f22a1..cf68d7108 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -491,6 +491,7 @@ impl QueryParser { } } + // let plugin_context = context.plugin_context(&ast.ast.protobuf, &None); let parser = CopyParser::new(stmt, context.router_context.cluster)?; if !stmt.is_from { context diff --git a/plugins/pgdog-example-plugin/Cargo.toml b/plugins/pgdog-example-plugin/Cargo.toml index 905c6098d..3bf2e8c98 100644 --- a/plugins/pgdog-example-plugin/Cargo.toml +++ b/plugins/pgdog-example-plugin/Cargo.toml @@ -8,6 +8,7 @@ crate-type = ["cdylib"] [dependencies] pgdog-plugin.workspace = true +csv = "1" once_cell = "1" parking_lot = "0.12" thiserror = "2" diff --git a/plugins/pgdog-example-plugin/src/lib.rs b/plugins/pgdog-example-plugin/src/lib.rs index f15bbeca7..dc796ef9f 100644 --- a/plugins/pgdog-example-plugin/src/lib.rs +++ b/plugins/pgdog-example-plugin/src/lib.rs @@ -8,7 +8,7 @@ pub mod plugin; -use pgdog_plugin::{Context, Route, macros}; +use pgdog_plugin::{Context, PdCopyRow, Route, macros}; // This identifies this library is a PgDog plugin and adds some // required methods automatically. @@ -31,6 +31,16 @@ fn route(context: Context) -> Route { crate::plugin::route_query(context).unwrap_or(Route::unknown()) } +/// If defined, this function is called for every row during a sharded COPY. +/// +/// It receives the raw row data along with metadata parsed from the COPY statement +/// (columns, format, delimiter, etc.) and returns a routing decision. +/// +#[macros::route_copy_row] +fn route_copy_row(row: PdCopyRow) -> Route { + crate::plugin::route_copy(row).unwrap_or(Route::unknown()) +} + /// Run any code before PgDog is shut down. /// /// This allows for plugins to upload stats to some external service diff --git a/plugins/pgdog-example-plugin/src/plugin.rs b/plugins/pgdog-example-plugin/src/plugin.rs index d6e26adf1..508409a56 100644 --- a/plugins/pgdog-example-plugin/src/plugin.rs +++ b/plugins/pgdog-example-plugin/src/plugin.rs @@ -3,6 +3,7 @@ use std::{ time::{Duration, Instant}, }; +use csv::ReaderBuilder; use once_cell::sync::Lazy; use parking_lot::Mutex; use pg_query::{NodeEnum, protobuf::RangeVar}; @@ -14,6 +15,9 @@ pub enum PluginError { #[error("{0}")] PgQuery(#[from] pg_query::Error), + #[error("{0}")] + Csv(#[from] csv::Error), + #[error("empty query")] EmptyQuery, } @@ -105,6 +109,57 @@ pub(crate) fn route_query(context: Context) -> Result { Ok(Route::unknown()) } +/// Route a COPY row to the correct shard. +/// +/// Uses the `csv` crate to parse the row, finds the "id" column, +/// and hashes its value to pick a shard. +pub(crate) fn route_copy(row: PdCopyRow) -> Result { + let columns = row.columns(); + let shards = row.shards() as usize; + + if columns.is_empty() { + return Ok(Route::unknown()); + } + + // Find the position of the "id" column. + let id_pos = match columns.iter().position(|&c| c == "id") { + Some(pos) => pos, + None => return Ok(Route::unknown()), + }; + + // Parse the row with the csv crate using the COPY delimiter. + let mut reader = ReaderBuilder::new() + .has_headers(false) + .delimiter(row.delimiter() as u8) + .from_reader(row.data()); + + let record = match reader.records().next() { + Some(r) => r?, + None => return Ok(Route::unknown()), + }; + + let field = match record.get(id_pos) { + Some(f) => f, + None => return Ok(Route::unknown()), + }; + + // NULL values go to all shards. + if field == row.null_string() { + return Ok(Route::new(Shard::All, ReadWrite::Write)); + } + + // Parse the id and hash to a shard. + let id: i64 = match field.parse() { + Ok(v) => v, + Err(_) => return Ok(Route::unknown()), + }; + + println!("copy decoded row with id {}", id); + + let shard = (id.unsigned_abs() as usize) % shards; + Ok(Route::new(Shard::Direct(shard), ReadWrite::Write)) +} + #[cfg(test)] mod test { use pgdog_plugin::{PdParameters, PdStatement}; @@ -132,4 +187,104 @@ mod test { assert_eq!(read_write, ReadWrite::Read); assert_eq!(shard, Shard::Unknown); } + + #[test] + fn test_copy_routes_by_id() { + let proto = pg_query::parse("COPY users (id, name) FROM STDIN") + .unwrap() + .protobuf; + let copy_stmt = match proto + .stmts + .first() + .unwrap() + .stmt + .clone() + .unwrap() + .node + .unwrap() + { + NodeEnum::CopyStmt(s) => s, + _ => panic!("not a COPY"), + }; + + let row = PdCopyRow::from_proto(©_stmt, 4, b"7\tAlice\n"); + let route = route_copy(row).unwrap(); + assert_eq!(route.shard.try_into(), Ok(Shard::Direct(3))); // 7 % 4 = 3 + + let row = PdCopyRow::from_proto(©_stmt, 4, b"0\tBob\n"); + let route = route_copy(row).unwrap(); + assert_eq!(route.shard.try_into(), Ok(Shard::Direct(0))); // 0 % 4 = 0 + } + + #[test] + fn test_copy_null_id_routes_to_all() { + let proto = pg_query::parse("COPY users (id, name) FROM STDIN") + .unwrap() + .protobuf; + let copy_stmt = match proto + .stmts + .first() + .unwrap() + .stmt + .clone() + .unwrap() + .node + .unwrap() + { + NodeEnum::CopyStmt(s) => s, + _ => panic!("not a COPY"), + }; + + let row = PdCopyRow::from_proto(©_stmt, 4, b"\\N\tAlice\n"); + let route = route_copy(row).unwrap(); + assert_eq!(route.shard.try_into(), Ok(Shard::All)); + } + + #[test] + fn test_copy_csv_delimiter() { + let proto = pg_query::parse("COPY users (id, name) FROM STDIN CSV") + .unwrap() + .protobuf; + let copy_stmt = match proto + .stmts + .first() + .unwrap() + .stmt + .clone() + .unwrap() + .node + .unwrap() + { + NodeEnum::CopyStmt(s) => s, + _ => panic!("not a COPY"), + }; + + let row = PdCopyRow::from_proto(©_stmt, 3, b"5,Charlie\n"); + let route = route_copy(row).unwrap(); + assert_eq!(route.shard.try_into(), Ok(Shard::Direct(2))); // 5 % 3 = 2 + } + + #[test] + fn test_copy_no_id_column() { + let proto = pg_query::parse("COPY users (name, email) FROM STDIN") + .unwrap() + .protobuf; + let copy_stmt = match proto + .stmts + .first() + .unwrap() + .stmt + .clone() + .unwrap() + .node + .unwrap() + { + NodeEnum::CopyStmt(s) => s, + _ => panic!("not a COPY"), + }; + + let row = PdCopyRow::from_proto(©_stmt, 4, b"Alice\talice@test.com\n"); + let route = route_copy(row).unwrap(); + assert_eq!(route.shard.try_into(), Ok(Shard::Unknown)); + } } From f27b782cebb5922195863a1e79dfe780b82dfe63 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 26 Mar 2026 10:42:10 -0700 Subject: [PATCH 2/3] only pass record --- pgdog-plugin/include/types.h | 12 +- pgdog-plugin/src/copy.rs | 290 +++++++----------- pgdog-plugin/src/prelude.rs | 1 + pgdog/src/frontend/router/parser/copy.rs | 56 ++-- pgdog/src/frontend/router/parser/csv/mod.rs | 2 +- .../src/frontend/router/parser/csv/record.rs | 94 +----- plugins/pgdog-example-plugin/src/plugin.rs | 116 ++----- 7 files changed, 180 insertions(+), 391 deletions(-) diff --git a/pgdog-plugin/include/types.h b/pgdog-plugin/include/types.h index 4d4ce95bf..b7f03b5e5 100644 --- a/pgdog-plugin/include/types.h +++ b/pgdog-plugin/include/types.h @@ -43,12 +43,12 @@ typedef struct PdParameters { typedef struct PdCopyRow { /** Number of shards in the config. */ uint64_t shards; - /** Pointer to CopyStmt protobuf. */ - const void *copy_stmt; - /** Data length. */ - uint64_t data_len; - /** Raw copy data. */ - const void *data; + /** CSV record. */ + const void *record; + /** Column names number. */ + uint64_t num_columns; + /** Column names */ + PdStr *columns; } PdCopyRow; /** diff --git a/pgdog-plugin/src/copy.rs b/pgdog-plugin/src/copy.rs index 781b11b03..9cc3c2dc3 100644 --- a/pgdog-plugin/src/copy.rs +++ b/pgdog-plugin/src/copy.rs @@ -1,210 +1,148 @@ -use std::{ffi::c_void, slice::from_raw_parts}; +use std::{ffi::c_void, ops::Range, str::from_utf8}; -use pg_query::{protobuf::CopyStmt, NodeEnum}; +use crate::bindings::{PdCopyRow, PdStr}; -use crate::bindings::PdCopyRow; - -/// Copy format. #[derive(Debug, Clone, Copy, PartialEq)] -pub enum Format { - /// Text (can be CSV or Postgres text). +pub enum CopyFormat { Text, - /// Binary COPY format. + Csv, Binary, } -impl PdCopyRow { - /// Create a new `PdCopyRow` from a parsed COPY statement. - /// - /// The caller must ensure `copy` and `data` outlive the returned struct, - /// since it holds raw pointers into both. - pub fn from_proto(copy: &CopyStmt, shards: usize, data: &[u8]) -> Self { - Self { - shards: shards as u64, - copy_stmt: copy as *const CopyStmt as *const c_void, - data_len: data.len() as u64, - data: data.as_ptr() as *const c_void, - } - } +/// A complete CSV record. +#[derive(Clone)] +pub struct Record { + /// Raw record data. + pub data: Vec, + /// Field ranges. + pub fields: Vec>, + /// Delimiter. + pub delimiter: char, + /// Format used. + pub format: CopyFormat, + /// Null string. + pub null_string: String, +} - /// Get the CopyStmt protobuf. - pub fn copy_stmt(&self) -> &CopyStmt { - unsafe { &*(self.copy_stmt as *const CopyStmt) } +impl std::fmt::Debug for Record { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Record") + .field("data", &from_utf8(&self.data)) + .field("fields", &self.fields) + .field("delimiter", &self.delimiter) + .field("format", &self.format) + .field("null_string", &self.null_string) + .finish() } +} - /// Helper to look up a string-valued option from the COPY statement. - pub fn option(&self, name: &str) -> Option<&str> { - for option in &self.copy_stmt().options { - if let Some(NodeEnum::DefElem(ref elem)) = option.node { - if elem.defname.eq_ignore_ascii_case(name) { - if let Some(ref arg) = elem.arg { - if let Some(NodeEnum::String(ref s)) = arg.node { - return Some(&s.sval); +impl std::fmt::Display for Record { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + "{}", + (0..self.len()) + .map(|field| match self.format { + CopyFormat::Csv => { + let text = self.get(field).unwrap(); + if text == self.null_string { + text.to_owned() + } else { + format!("\"{}\"", self.get(field).unwrap().replace("\"", "\"\"")) } } - // Option present but no value (e.g. HEADER). - return Some(""); - } - } - } - None - } - - /// Check if a boolean-style option is present (e.g. HEADER). - fn has_option(&self, name: &str) -> bool { - self.option(name).is_some() - } - - /// Get column names from the COPY statement. - pub fn columns(&self) -> Vec<&str> { - self.copy_stmt() - .attlist - .iter() - .filter_map(|node| { - if let Some(NodeEnum::String(ref s)) = node.node { - Some(s.sval.as_str()) - } else { - None - } - }) - .collect() - } - - /// Get number of shards. - pub fn shards(&self) -> u64 { - self.shards - } - - /// Get raw data. Caller is responsible for decoding. - /// The data will contain exactly one row. - pub fn data(&self) -> &[u8] { - unsafe { from_raw_parts(self.data as *const u8, self.data_len as usize) } + _ => self.get(field).unwrap().to_string(), + }) + .collect::>() + .join(&format!("{}", self.delimiter)) + ) } +} - /// Get row format. - pub fn format(&self) -> Format { - match self.option("format") { - Some(f) if f.eq_ignore_ascii_case("binary") => Format::Binary, - _ => Format::Text, +impl Record { + pub fn new( + data: &[u8], + ends: &[usize], + delimiter: char, + format: CopyFormat, + null_string: &str, + ) -> Self { + let mut last = 0; + let mut fields = vec![]; + for e in ends { + fields.push(last..*e); + last = *e; } - } - - /// Get delimiter. - pub fn delimiter(&self) -> char { - if let Some(d) = self.option("delimiter") { - return d.chars().next().unwrap_or(','); + Self { + data: data.to_vec(), + fields, + delimiter, + format, + null_string: null_string.to_owned(), } + } - // CSV defaults to comma, text/binary default to tab. - match self.option("format") { - Some(f) if f.eq_ignore_ascii_case("csv") => ',', - _ => '\t', - } + /// Number of fields in the record. + pub fn len(&self) -> usize { + self.fields.len() } - /// Get NULL string. - pub fn null_string(&self) -> &str { - self.option("null").unwrap_or("\\N") + /// Return true if there are no fields in the record. + pub fn is_empty(&self) -> bool { + self.len() == 0 } - /// Whether the COPY includes headers. - pub fn headers(&self) -> bool { - // Binary format always has a header. - self.has_option("header") || self.format() == Format::Binary + pub fn get(&self, index: usize) -> Option<&str> { + self.fields + .get(index) + .cloned() + .and_then(|range| from_utf8(&self.data[range]).ok()) } } -#[cfg(test)] -mod test { - use pg_query::{parse, NodeEnum}; - - use super::*; +/// Copy format. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Format { + /// Text (can be CSV or Postgres text). + Text, + /// Binary COPY format. + Binary, +} - fn parse_copy(sql: &str) -> CopyStmt { - let parsed = parse(sql).unwrap(); - let stmt = parsed.protobuf.stmts.first().unwrap(); - match stmt.stmt.clone().unwrap().node.unwrap() { - NodeEnum::CopyStmt(copy) => *copy, - _ => panic!("not a COPY statement"), +impl PdCopyRow { + /// Create a new `PdCopyRow` from a parsed COPY statement. + /// + /// The caller must ensure `copy` and `data` outlive the returned struct, + /// since it holds raw pointers into both. + pub fn from_proto(shards: usize, record: &Record, column_names: &[PdStr]) -> Self { + Self { + shards: shards as u64, + record: record as *const Record as *const c_void, + num_columns: column_names.len() as u64, + columns: column_names.as_ptr() as *mut PdStr, } } - #[test] - fn test_text_defaults() { - let copy = parse_copy("COPY t (id, value) FROM STDIN"); - let data = b"1\thello\n"; - let row = PdCopyRow::from_proto(©, 3, data); - - assert_eq!(row.shards(), 3); - assert_eq!(row.format(), Format::Text); - assert_eq!(row.delimiter(), '\t'); - assert_eq!(row.null_string(), "\\N"); - assert!(!row.headers()); - assert_eq!(row.data(), data.as_slice()); - assert_eq!(row.columns(), vec!["id", "value"]); - } - - #[test] - fn test_csv_defaults() { - let copy = parse_copy("COPY t (a, b, c) FROM STDIN CSV"); - let row = PdCopyRow::from_proto(©, 2, b"1,2,3\n"); - - assert_eq!(row.format(), Format::Text); - assert_eq!(row.delimiter(), ','); - assert!(!row.headers()); - assert_eq!(row.columns(), vec!["a", "b", "c"]); - } - - #[test] - fn test_csv_header() { - let copy = parse_copy("COPY t (x) FROM STDIN CSV HEADER"); - let row = PdCopyRow::from_proto(©, 1, b"x\n1\n"); - - assert!(row.headers()); - assert_eq!(row.delimiter(), ','); - } - - #[test] - fn test_custom_delimiter() { - let copy = parse_copy("COPY t FROM STDIN CSV DELIMITER '|'"); - let row = PdCopyRow::from_proto(©, 1, b"a|b\n"); - - assert_eq!(row.delimiter(), '|'); - } - - #[test] - fn test_custom_null() { - let copy = parse_copy("COPY t (id) FROM STDIN CSV NULL 'NULL'"); - let row = PdCopyRow::from_proto(©, 1, b"NULL\n"); - - assert_eq!(row.null_string(), "NULL"); - } - - #[test] - fn test_binary_format() { - let copy = parse_copy("COPY t FROM STDIN (FORMAT 'binary')"); - let row = PdCopyRow::from_proto(©, 4, b"\x00"); - - assert_eq!(row.format(), Format::Binary); - assert!(row.headers()); - assert_eq!(row.shards(), 4); + /// Get number of shards. + pub fn shards(&self) -> u64 { + self.shards } - #[test] - fn test_no_columns() { - let copy = parse_copy("COPY t FROM STDIN"); - let row = PdCopyRow::from_proto(©, 1, b"1\t2\n"); - - assert!(row.columns().is_empty()); + /// Get the parsed record. + pub fn record(&self) -> &Record { + unsafe { &*(self.record as *const Record) } } - #[test] - fn test_explicit_text_format() { - let copy = parse_copy(r#"COPY "public"."t" ("id", "val") FROM STDIN WITH (FORMAT text)"#); - let row = PdCopyRow::from_proto(©, 2, b"1\thello\n"); - - assert_eq!(row.format(), Format::Text); - assert_eq!(row.delimiter(), '\t'); - assert_eq!(row.columns(), vec!["id", "val"]); + /// Get column names. + pub fn columns(&self) -> Vec<&str> { + if self.num_columns == 0 { + return vec![]; + } + unsafe { + std::slice::from_raw_parts(self.columns, self.num_columns as usize) + .iter() + .map(|s| &**s) + .collect() + } } } diff --git a/pgdog-plugin/src/prelude.rs b/pgdog-plugin/src/prelude.rs index 054b9a4a6..5d39c340c 100644 --- a/pgdog-plugin/src/prelude.rs +++ b/pgdog-plugin/src/prelude.rs @@ -3,6 +3,7 @@ pub use crate::pg_query; pub use crate::{ bindings::PdCopyRow, + copy::{CopyFormat, Record}, macros::{fini, init, route, route_copy_row}, parameters::{Parameter, ParameterFormat, ParameterValue, Parameters}, Context, ReadWrite, Route, Shard, diff --git a/pgdog/src/frontend/router/parser/copy.rs b/pgdog/src/frontend/router/parser/copy.rs index 0e96bcb9a..8c025f564 100644 --- a/pgdog/src/frontend/router/parser/copy.rs +++ b/pgdog/src/frontend/router/parser/copy.rs @@ -1,7 +1,7 @@ //! Parse COPY statement. use pg_query::{protobuf::CopyStmt, NodeEnum}; -use pgdog_plugin::PdCopyRow; +use pgdog_plugin::{PdCopyRow, PdStr}; use crate::{ backend::{Cluster, ShardingSchema}, @@ -16,6 +16,7 @@ use crate::{ }; use super::{binary::Data, BinaryStream, Column, CsvStream, Error, Table}; +pub use pgdog_plugin::copy::CopyFormat; /// Copy information parsed from a COPY statement. #[derive(Debug, Clone)] @@ -41,13 +42,6 @@ impl Default for CopyInfo { } } -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum CopyFormat { - Text, - Csv, - Binary, -} - #[derive(Debug, Clone)] enum CopyStream { Text(Box), @@ -60,8 +54,8 @@ pub struct CopyParser { headers: bool, /// CSV delimiter. delimiter: Option, - /// Number of columns - columns: usize, + /// Column names from the COPY statement. + column_names: Vec, /// This is a COPY coming from the client. is_from: bool, /// Stream parser. @@ -76,8 +70,6 @@ pub struct CopyParser { schema_shard: Option, /// String representing NULL values in text/CSV format. null_string: String, - /// Original copy stmt. - stmt: CopyStmt, } impl Default for CopyParser { @@ -85,7 +77,7 @@ impl Default for CopyParser { Self { headers: false, delimiter: None, - columns: 0, + column_names: vec![], is_from: false, stream: CopyStream::Text(Box::new(CsvStream::new(',', false, CopyFormat::Csv, "\\N"))), sharding_schema: ShardingSchema::default(), @@ -93,7 +85,6 @@ impl Default for CopyParser { sharded_column: 0, schema_shard: None, null_string: "\\N".to_owned(), - stmt: CopyStmt::default(), } } } @@ -118,6 +109,15 @@ impl CopyParser { } } + parser.column_names = stmt + .attlist + .iter() + .filter_map(|n| match &n.node { + Some(NodeEnum::String(s)) => Some(s.sval.clone()), + _ => None, + }) + .collect(); + let table = Table::from(rel); // The CopyParser is used for replicating @@ -132,8 +132,6 @@ impl CopyParser { parser.sharded_column = key.position; } - parser.columns = columns.len(); - for option in &stmt.options { if let Some(NodeEnum::DefElem(ref elem)) = option.node { match elem.defname.to_lowercase().as_str() { @@ -196,7 +194,6 @@ impl CopyParser { }; parser.sharding_schema = cluster.sharding_schema(); parser.null_string = null_string; - parser.stmt = stmt.clone(); Ok(parser) } @@ -246,9 +243,11 @@ impl CopyParser { if key == self.null_string { Shard::All - } else if let Some(shard) = - Self::check_plugins(&self.stmt, &self.sharding_schema, &record) - { + } else if let Some(shard) = Self::check_plugins( + &self.column_names, + &self.sharding_schema, + &record, + ) { shard } else { let ctx = ContextBuilder::new(table) @@ -261,7 +260,7 @@ impl CopyParser { } else if let Some(schema_shard) = self.schema_shard.clone() { schema_shard } else if let Some(shard) = - Self::check_plugins(&self.stmt, &self.sharding_schema, &record) + Self::check_plugins(&self.column_names, &self.sharding_schema, &record) { shard } else { @@ -322,12 +321,17 @@ impl CopyParser { Ok(rows) } - fn check_plugins(stmt: &CopyStmt, schema: &ShardingSchema, record: &Record) -> Option { + fn check_plugins( + column_names: &[String], + schema: &ShardingSchema, + record: &Record, + ) -> Option { if let Some(plugins) = plugins() { - // record.data is raw concatenated field bytes without delimiters. - // Re-serialize with delimiters so the plugin can parse it. - let serialized = record.to_string(); - let context = PdCopyRow::from_proto(stmt, schema.shards, serialized.as_bytes()); + let columns: Vec = column_names + .iter() + .map(|s| PdStr::from(s.as_str())) + .collect(); + let context = PdCopyRow::from_proto(schema.shards, record, &columns); for plugin in plugins { if let Some(route) = plugin.route_copy_row(context) { diff --git a/pgdog/src/frontend/router/parser/csv/mod.rs b/pgdog/src/frontend/router/parser/csv/mod.rs index b36419e62..b0a82d358 100644 --- a/pgdog/src/frontend/router/parser/csv/mod.rs +++ b/pgdog/src/frontend/router/parser/csv/mod.rs @@ -6,7 +6,7 @@ pub mod record; pub use iterator::Iter; pub use record::Record; -use super::CopyFormat; +use pgdog_plugin::copy::CopyFormat; static RECORD_BUFFER: usize = 4096; static ENDS_BUFFER: usize = 2048; // Max of 2048 columns in a CSV. diff --git a/pgdog/src/frontend/router/parser/csv/record.rs b/pgdog/src/frontend/router/parser/csv/record.rs index 5a21e200f..d5d657e39 100644 --- a/pgdog/src/frontend/router/parser/csv/record.rs +++ b/pgdog/src/frontend/router/parser/csv/record.rs @@ -1,93 +1 @@ -use super::super::CopyFormat; -use std::{ops::Range, str::from_utf8}; - -/// A complete CSV record. -#[derive(Clone)] -pub struct Record { - /// Raw record data. - pub data: Vec, - /// Field ranges. - pub fields: Vec>, - /// Delimiter. - pub delimiter: char, - /// Format used. - pub format: CopyFormat, - /// Null string. - pub null_string: String, -} - -impl std::fmt::Debug for Record { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Record") - .field("data", &from_utf8(&self.data)) - .field("fields", &self.fields) - .field("delimiter", &self.delimiter) - .field("format", &self.format) - .field("null_string", &self.null_string) - .finish() - } -} - -impl std::fmt::Display for Record { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!( - f, - "{}", - (0..self.len()) - .map(|field| match self.format { - CopyFormat::Csv => { - let text = self.get(field).unwrap(); - if text == self.null_string { - text.to_owned() - } else { - format!("\"{}\"", self.get(field).unwrap().replace("\"", "\"\"")) - } - } - _ => self.get(field).unwrap().to_string(), - }) - .collect::>() - .join(&format!("{}", self.delimiter)) - ) - } -} - -impl Record { - pub(super) fn new( - data: &[u8], - ends: &[usize], - delimiter: char, - format: CopyFormat, - null_string: &str, - ) -> Self { - let mut last = 0; - let mut fields = vec![]; - for e in ends { - fields.push(last..*e); - last = *e; - } - Self { - data: data.to_vec(), - fields, - delimiter, - format, - null_string: null_string.to_owned(), - } - } - - /// Number of fields in the record. - pub fn len(&self) -> usize { - self.fields.len() - } - - /// Return true if there are no fields in the record. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - pub fn get(&self, index: usize) -> Option<&str> { - self.fields - .get(index) - .cloned() - .and_then(|range| from_utf8(&self.data[range]).ok()) - } -} +pub use pgdog_plugin::copy::{CopyFormat, Record}; diff --git a/plugins/pgdog-example-plugin/src/plugin.rs b/plugins/pgdog-example-plugin/src/plugin.rs index 508409a56..3fa7961c2 100644 --- a/plugins/pgdog-example-plugin/src/plugin.rs +++ b/plugins/pgdog-example-plugin/src/plugin.rs @@ -3,7 +3,6 @@ use std::{ time::{Duration, Instant}, }; -use csv::ReaderBuilder; use once_cell::sync::Lazy; use parking_lot::Mutex; use pg_query::{NodeEnum, protobuf::RangeVar}; @@ -15,9 +14,6 @@ pub enum PluginError { #[error("{0}")] PgQuery(#[from] pg_query::Error), - #[error("{0}")] - Csv(#[from] csv::Error), - #[error("empty query")] EmptyQuery, } @@ -111,11 +107,11 @@ pub(crate) fn route_query(context: Context) -> Result { /// Route a COPY row to the correct shard. /// -/// Uses the `csv` crate to parse the row, finds the "id" column, -/// and hashes its value to pick a shard. +/// Finds the "id" column and hashes its value to pick a shard. pub(crate) fn route_copy(row: PdCopyRow) -> Result { let columns = row.columns(); let shards = row.shards() as usize; + let record = row.record(); if columns.is_empty() { return Ok(Route::unknown()); @@ -127,24 +123,13 @@ pub(crate) fn route_copy(row: PdCopyRow) -> Result { None => return Ok(Route::unknown()), }; - // Parse the row with the csv crate using the COPY delimiter. - let mut reader = ReaderBuilder::new() - .has_headers(false) - .delimiter(row.delimiter() as u8) - .from_reader(row.data()); - - let record = match reader.records().next() { - Some(r) => r?, - None => return Ok(Route::unknown()), - }; - let field = match record.get(id_pos) { Some(f) => f, None => return Ok(Route::unknown()), }; // NULL values go to all shards. - if field == row.null_string() { + if field == record.null_string { return Ok(Route::new(Shard::All, ReadWrite::Write)); } @@ -162,7 +147,7 @@ pub(crate) fn route_copy(row: PdCopyRow) -> Result { #[cfg(test)] mod test { - use pgdog_plugin::{PdParameters, PdStatement}; + use pgdog_plugin::{PdParameters, PdStatement, PdStr}; use super::*; @@ -190,100 +175,53 @@ mod test { #[test] fn test_copy_routes_by_id() { - let proto = pg_query::parse("COPY users (id, name) FROM STDIN") - .unwrap() - .protobuf; - let copy_stmt = match proto - .stmts - .first() - .unwrap() - .stmt - .clone() - .unwrap() - .node - .unwrap() - { - NodeEnum::CopyStmt(s) => s, - _ => panic!("not a COPY"), - }; + let columns = [PdStr::from("id"), PdStr::from("name")]; - let row = PdCopyRow::from_proto(©_stmt, 4, b"7\tAlice\n"); + // "7" + "Alice" concatenated, ends at [1, 6] + let record = Record::new(b"7Alice", &[1, 6], '\t', CopyFormat::Text, "\\N"); + let row = PdCopyRow::from_proto(4, &record, &columns); let route = route_copy(row).unwrap(); assert_eq!(route.shard.try_into(), Ok(Shard::Direct(3))); // 7 % 4 = 3 - let row = PdCopyRow::from_proto(©_stmt, 4, b"0\tBob\n"); + // "0" + "Bob" concatenated, ends at [1, 4] + let record = Record::new(b"0Bob", &[1, 4], '\t', CopyFormat::Text, "\\N"); + let row = PdCopyRow::from_proto(4, &record, &columns); let route = route_copy(row).unwrap(); assert_eq!(route.shard.try_into(), Ok(Shard::Direct(0))); // 0 % 4 = 0 } #[test] fn test_copy_null_id_routes_to_all() { - let proto = pg_query::parse("COPY users (id, name) FROM STDIN") - .unwrap() - .protobuf; - let copy_stmt = match proto - .stmts - .first() - .unwrap() - .stmt - .clone() - .unwrap() - .node - .unwrap() - { - NodeEnum::CopyStmt(s) => s, - _ => panic!("not a COPY"), - }; + let columns = [PdStr::from("id"), PdStr::from("name")]; - let row = PdCopyRow::from_proto(©_stmt, 4, b"\\N\tAlice\n"); + let record = Record::new(b"\\NAlice", &[2, 7], '\t', CopyFormat::Text, "\\N"); + let row = PdCopyRow::from_proto(4, &record, &columns); let route = route_copy(row).unwrap(); assert_eq!(route.shard.try_into(), Ok(Shard::All)); } #[test] fn test_copy_csv_delimiter() { - let proto = pg_query::parse("COPY users (id, name) FROM STDIN CSV") - .unwrap() - .protobuf; - let copy_stmt = match proto - .stmts - .first() - .unwrap() - .stmt - .clone() - .unwrap() - .node - .unwrap() - { - NodeEnum::CopyStmt(s) => s, - _ => panic!("not a COPY"), - }; + let columns = [PdStr::from("id"), PdStr::from("name")]; - let row = PdCopyRow::from_proto(©_stmt, 3, b"5,Charlie\n"); + let record = Record::new(b"5Charlie", &[1, 8], ',', CopyFormat::Csv, "\\N"); + let row = PdCopyRow::from_proto(3, &record, &columns); let route = route_copy(row).unwrap(); assert_eq!(route.shard.try_into(), Ok(Shard::Direct(2))); // 5 % 3 = 2 } #[test] fn test_copy_no_id_column() { - let proto = pg_query::parse("COPY users (name, email) FROM STDIN") - .unwrap() - .protobuf; - let copy_stmt = match proto - .stmts - .first() - .unwrap() - .stmt - .clone() - .unwrap() - .node - .unwrap() - { - NodeEnum::CopyStmt(s) => s, - _ => panic!("not a COPY"), - }; - - let row = PdCopyRow::from_proto(©_stmt, 4, b"Alice\talice@test.com\n"); + let columns = [PdStr::from("name"), PdStr::from("email")]; + + let record = Record::new( + b"Alicealice@test.com", + &[5, 19], + '\t', + CopyFormat::Text, + "\\N", + ); + let row = PdCopyRow::from_proto(4, &record, &columns); let route = route_copy(row).unwrap(); assert_eq!(route.shard.try_into(), Ok(Shard::Unknown)); } From 9e6c84e532b9f23c49658f951caac9698f71f3a8 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 26 Mar 2026 13:29:39 -0700 Subject: [PATCH 3/3] that should be enough context --- pgdog-plugin/include/types.h | 5 +++ pgdog-plugin/src/copy.rs | 26 +++++++++++- pgdog/src/frontend/router/parser/copy.rs | 35 +++++++++++----- plugins/pgdog-example-plugin/src/plugin.rs | 47 +++++++++++++++++++--- 4 files changed, 95 insertions(+), 18 deletions(-) diff --git a/pgdog-plugin/include/types.h b/pgdog-plugin/include/types.h index b7f03b5e5..8551b6058 100644 --- a/pgdog-plugin/include/types.h +++ b/pgdog-plugin/include/types.h @@ -49,6 +49,11 @@ typedef struct PdCopyRow { uint64_t num_columns; /** Column names */ PdStr *columns; + /** Table name. */ + PdStr *table_name; + /** Schema name. Null if not provided. */ + PdStr *schema_name; + /** */ } PdCopyRow; /** diff --git a/pgdog-plugin/src/copy.rs b/pgdog-plugin/src/copy.rs index 9cc3c2dc3..c03d2d080 100644 --- a/pgdog-plugin/src/copy.rs +++ b/pgdog-plugin/src/copy.rs @@ -114,12 +114,20 @@ impl PdCopyRow { /// /// The caller must ensure `copy` and `data` outlive the returned struct, /// since it holds raw pointers into both. - pub fn from_proto(shards: usize, record: &Record, column_names: &[PdStr]) -> Self { + pub fn from_proto( + shards: usize, + record: &Record, + column_names: &[PdStr], + table_name: &PdStr, + schema_name: &PdStr, + ) -> Self { Self { shards: shards as u64, record: record as *const Record as *const c_void, num_columns: column_names.len() as u64, columns: column_names.as_ptr() as *mut PdStr, + table_name: table_name as *const PdStr as *mut PdStr, + schema_name: schema_name as *const PdStr as *mut PdStr, } } @@ -145,4 +153,20 @@ impl PdCopyRow { .collect() } } + + /// Get table name. + pub fn table_name(&self) -> &str { + if self.table_name.is_null() { + return ""; + } + unsafe { &*self.table_name } + } + + /// Get schema name. + pub fn schema_name(&self) -> &str { + if self.schema_name.is_null() { + return ""; + } + unsafe { &*self.schema_name } + } } diff --git a/pgdog/src/frontend/router/parser/copy.rs b/pgdog/src/frontend/router/parser/copy.rs index 8c025f564..8155209c2 100644 --- a/pgdog/src/frontend/router/parser/copy.rs +++ b/pgdog/src/frontend/router/parser/copy.rs @@ -70,6 +70,10 @@ pub struct CopyParser { schema_shard: Option, /// String representing NULL values in text/CSV format. null_string: String, + /// Table name from the COPY statement. + table_name: String, + /// Schema name from the COPY statement. + schema_name: String, } impl Default for CopyParser { @@ -85,6 +89,8 @@ impl Default for CopyParser { sharded_column: 0, schema_shard: None, null_string: "\\N".to_owned(), + table_name: String::new(), + schema_name: String::new(), } } } @@ -119,6 +125,10 @@ impl CopyParser { .collect(); let table = Table::from(rel); + parser.table_name = table.name.to_owned(); + if let Some(schema) = table.schema { + parser.schema_name = schema.to_owned(); + } // The CopyParser is used for replicating // data during data-sync. This will ensure all rows @@ -236,6 +246,14 @@ impl CopyParser { let shard = if is_end_marker { Shard::All + } else if let Some(shard) = Self::check_plugins( + &self.column_names, + &self.sharding_schema, + &record, + &self.table_name, + &self.schema_name, + ) { + shard } else if let Some(table) = &self.sharded_table { let key = record .get(self.sharded_column) @@ -243,12 +261,6 @@ impl CopyParser { if key == self.null_string { Shard::All - } else if let Some(shard) = Self::check_plugins( - &self.column_names, - &self.sharding_schema, - &record, - ) { - shard } else { let ctx = ContextBuilder::new(table) .data(key) @@ -259,10 +271,6 @@ impl CopyParser { } } else if let Some(schema_shard) = self.schema_shard.clone() { schema_shard - } else if let Some(shard) = - Self::check_plugins(&self.column_names, &self.sharding_schema, &record) - { - shard } else { Shard::All }; @@ -325,13 +333,18 @@ impl CopyParser { column_names: &[String], schema: &ShardingSchema, record: &Record, + table_name: &str, + schema_name: &str, ) -> Option { if let Some(plugins) = plugins() { let columns: Vec = column_names .iter() .map(|s| PdStr::from(s.as_str())) .collect(); - let context = PdCopyRow::from_proto(schema.shards, record, &columns); + let table_name = PdStr::from(table_name); + let schema_name = PdStr::from(schema_name); + let context = + PdCopyRow::from_proto(schema.shards, record, &columns, &table_name, &schema_name); for plugin in plugins { if let Some(route) = plugin.route_copy_row(context) { diff --git a/plugins/pgdog-example-plugin/src/plugin.rs b/plugins/pgdog-example-plugin/src/plugin.rs index 3fa7961c2..3039187cb 100644 --- a/plugins/pgdog-example-plugin/src/plugin.rs +++ b/plugins/pgdog-example-plugin/src/plugin.rs @@ -139,7 +139,12 @@ pub(crate) fn route_copy(row: PdCopyRow) -> Result { Err(_) => return Ok(Route::unknown()), }; - println!("copy decoded row with id {}", id); + println!( + "copy decoded row with id {} (table={}.{})", + id, + row.schema_name(), + row.table_name() + ); let shard = (id.unsigned_abs() as usize) % shards; Ok(Route::new(Shard::Direct(shard), ReadWrite::Write)) @@ -179,13 +184,25 @@ mod test { // "7" + "Alice" concatenated, ends at [1, 6] let record = Record::new(b"7Alice", &[1, 6], '\t', CopyFormat::Text, "\\N"); - let row = PdCopyRow::from_proto(4, &record, &columns); + let row = PdCopyRow::from_proto( + 4, + &record, + &columns, + &PdStr::from("users"), + &PdStr::from("public"), + ); let route = route_copy(row).unwrap(); assert_eq!(route.shard.try_into(), Ok(Shard::Direct(3))); // 7 % 4 = 3 // "0" + "Bob" concatenated, ends at [1, 4] let record = Record::new(b"0Bob", &[1, 4], '\t', CopyFormat::Text, "\\N"); - let row = PdCopyRow::from_proto(4, &record, &columns); + let row = PdCopyRow::from_proto( + 4, + &record, + &columns, + &PdStr::from("users"), + &PdStr::from("public"), + ); let route = route_copy(row).unwrap(); assert_eq!(route.shard.try_into(), Ok(Shard::Direct(0))); // 0 % 4 = 0 } @@ -195,7 +212,13 @@ mod test { let columns = [PdStr::from("id"), PdStr::from("name")]; let record = Record::new(b"\\NAlice", &[2, 7], '\t', CopyFormat::Text, "\\N"); - let row = PdCopyRow::from_proto(4, &record, &columns); + let row = PdCopyRow::from_proto( + 4, + &record, + &columns, + &PdStr::from("users"), + &PdStr::from("public"), + ); let route = route_copy(row).unwrap(); assert_eq!(route.shard.try_into(), Ok(Shard::All)); } @@ -205,7 +228,13 @@ mod test { let columns = [PdStr::from("id"), PdStr::from("name")]; let record = Record::new(b"5Charlie", &[1, 8], ',', CopyFormat::Csv, "\\N"); - let row = PdCopyRow::from_proto(3, &record, &columns); + let row = PdCopyRow::from_proto( + 3, + &record, + &columns, + &PdStr::from("users"), + &PdStr::from("public"), + ); let route = route_copy(row).unwrap(); assert_eq!(route.shard.try_into(), Ok(Shard::Direct(2))); // 5 % 3 = 2 } @@ -221,7 +250,13 @@ mod test { CopyFormat::Text, "\\N", ); - let row = PdCopyRow::from_proto(4, &record, &columns); + let row = PdCopyRow::from_proto( + 4, + &record, + &columns, + &PdStr::from("users"), + &PdStr::from("public"), + ); let route = route_copy(row).unwrap(); assert_eq!(route.shard.try_into(), Ok(Shard::Unknown)); }