diff --git a/.sqlxrc.sample.json b/.sqlxrc.sample.json index ee7909cf..4e6d1baf 100644 --- a/.sqlxrc.sample.json +++ b/.sqlxrc.sample.json @@ -12,7 +12,11 @@ "DB_USER": "postgres", "DB_PASS": "postgres", "DB_NAME": "postgres", - "PG_SEARCH_PATH": "public,myschema" + "PG_SEARCH_PATH": "public,myschema", + "type_mapping": { + "date": "string", + "timestamp": { "type": "DateTime", "import": "import type { DateTime } from \"luxon\"" } + } }, "db_mysql": { "DB_TYPE": "mysql", diff --git a/src/common/config.rs b/src/common/config.rs index 0bbad139..5b85da37 100644 --- a/src/common/config.rs +++ b/src/common/config.rs @@ -7,6 +7,7 @@ use regex::Regex; use serde; use serde::{Deserialize, Serialize}; use serde_json; +use serde_json::Value as JsonValue; use std::collections::HashMap; use std::fs; use std::path::PathBuf; @@ -20,6 +21,50 @@ pub struct SqlxConfig { pub connections: HashMap, } +#[derive(Clone, Debug)] +pub enum CustomTypeMapping { + Simple(String), + WithImport { type_name: String, import: String }, +} + +impl<'de> Deserialize<'de> for CustomTypeMapping { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let value = JsonValue::deserialize(deserializer)?; + match value { + JsonValue::String(s) => Ok(CustomTypeMapping::Simple(s)), + JsonValue::Object(map) => { + let type_name = map.get("type") + .and_then(|v| v.as_str()).ok_or_else(|| serde::de::Error::missing_field("type"))? + .to_string(); + let import = map.get("import") + .and_then(|v| v.as_str()).ok_or_else(|| serde::de::Error::missing_field("import"))? + .to_string(); + Ok(CustomTypeMapping::WithImport { type_name, import }) + } + _ => Err(serde::de::Error::custom("Expected a string or an object for CustomTypeMapping")), + } + } +} + +impl Serialize for CustomTypeMapping { + fn serialize(&self, serializer: S) -> Result + where S: serde::Serializer { + match self { + CustomTypeMapping::Simple(s) => serializer.serialize_str(s), + CustomTypeMapping::WithImport { type_name, import } => { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("type", type_name)?; + map.serialize_entry("import", import)?; + map.end() + } + } + } +} + pub const fn default_bool() -> bool { V } @@ -36,6 +81,8 @@ pub struct GenerateTypesConfig { pub generate_path: Option, } + + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct DbConnectionConfig { #[serde(rename = "DB_TYPE")] @@ -58,6 +105,7 @@ pub struct DbConnectionConfig { pub pool_size: u32, #[serde(rename = "CONNECTION_TIMEOUT", default = "default_connection_timeout")] pub connection_timeout: u64, + pub type_mapping: Option>, } fn default_pool_size() -> u32 { @@ -318,6 +366,11 @@ impl Config { .or_else(|| Some(default_connection_timeout())) .unwrap(); + let type_mapping = default_config + .and_then(|x| x.type_mapping.clone()); + + println!("checking {:#?}", type_mapping); + DbConnectionConfig { db_type: db_type.to_owned(), db_host, @@ -329,6 +382,7 @@ impl Config { pg_search_path: pg_search_path.to_owned(), pool_size, connection_timeout, + type_mapping, } } diff --git a/tests/custom_type_mapping.rs b/tests/custom_type_mapping.rs new file mode 100644 index 00000000..a5e9cf0b --- /dev/null +++ b/tests/custom_type_mapping.rs @@ -0,0 +1,214 @@ +#[cfg(test)] +mod custom_type_mapping_tests { + use std::fs; + use std::io::Write; + use tempfile::tempdir; + + use assert_cmd::cargo::cargo_bin_cmd; + use pretty_assertions::assert_eq; + use test_utils::test_utils::TSString; + + /// Helper: creates a temporary SQLite database, writes a .sqlxrc.json with type_mapping, + /// runs sqlx-ts, and returns the generated types. + fn run_type_mapping_test( + schema_sql: &str, + ts_content: &str, + type_mapping_json: &str, + ) -> Result<(String, String), Box> { + let dir = tempdir()?; + let parent_path = dir.path(); + + // Create the SQLite database + let db_path = parent_path.join("test.db"); + let conn = rusqlite::Connection::open(&db_path)?; + conn.execute_batch(schema_sql)?; + drop(conn); + + // Write the .sqlxrc.json config with type_mapping + let config = format!( + r#"{{ + "generate_types": {{ + "enabled": true + }}, + "connections": {{ + "default": {{ + "DB_TYPE": "sqlite", + "DB_NAME": "{}", + "type_mapping": {} + }} + }} +}}"#, + db_path.display(), + type_mapping_json + ); + let config_path = parent_path.join(".sqlxrc.json"); + let mut config_file = fs::File::create(&config_path)?; + write!(config_file, "{}", config)?; + + // Write the TS file + let file_path = parent_path.join("index.ts"); + let mut temp_file = fs::File::create(&file_path)?; + writeln!(temp_file, "{}", ts_content)?; + + // Run sqlx-ts with CLI args for DB connection + config file for type_mapping + let mut cmd = cargo_bin_cmd!("sqlx-ts"); + cmd + .arg(parent_path.to_str().unwrap()) + .arg("--ext=ts") + .arg("--db-type=sqlite") + .arg(format!("--db-name={}", db_path.display())) + .arg(format!("--config={}", config_path.display())) + .arg("-g"); + + let output = cmd.output()?; + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + + assert!( + output.status.success(), + "sqlx-ts failed!\nstdout: {stdout}\nstderr: {stderr}" + ); + + // Read generated types + let type_file_path = parent_path.join("index.queries.ts"); + let type_file = if type_file_path.exists() { + fs::read_to_string(type_file_path)? + } else { + String::new() + }; + + Ok((stdout, type_file)) + } + + #[test] + fn should_override_integer_to_string() -> Result<(), Box> { + let schema = "CREATE TABLE test_custom_types (id INTEGER PRIMARY KEY NOT NULL, count BIGINT NOT NULL);"; + + let ts_content = r#" +import { sql } from 'sqlx-ts' +const someQuery = sql`SELECT * FROM test_custom_types` +"#; + + let type_mapping = r#"{ "bigint": "string" }"#; + + let (_, type_file) = run_type_mapping_test(schema, ts_content, type_mapping)?; + + let expected = r#" +export type SomeQueryParams = []; + +export interface ISomeQueryResult { + count: string; + id: number; +} + +export interface ISomeQueryQuery { + params: SomeQueryParams; + result: ISomeQueryResult; +} +"#; + + assert_eq!( + expected.trim().to_string().flatten(), + type_file.trim().to_string().flatten() + ); + Ok(()) + } + + #[test] + fn should_override_with_union_type() -> Result<(), Box> { + let schema = "CREATE TABLE test_custom_types (id INTEGER PRIMARY KEY NOT NULL, count BIGINT NOT NULL);"; + + let ts_content = r#" +import { sql } from 'sqlx-ts' +const someQuery = sql`SELECT * FROM test_custom_types` +"#; + + let type_mapping = r#"{ "bigint": "string | number" }"#; + + let (_, type_file) = run_type_mapping_test(schema, ts_content, type_mapping)?; + + let expected = r#" +export type SomeQueryParams = []; + +export interface ISomeQueryResult { + count: string | number; + id: number; +} + +export interface ISomeQueryQuery { + params: SomeQueryParams; + result: ISomeQueryResult; +} +"#; + + assert_eq!( + expected.trim().to_string().flatten(), + type_file.trim().to_string().flatten() + ); + Ok(()) + } + + #[test] + fn should_override_with_import() -> Result<(), Box> { + let schema = "CREATE TABLE events (id INTEGER PRIMARY KEY NOT NULL, created_at DATETIME NOT NULL);"; + + let ts_content = r#" +import { sql } from 'sqlx-ts' +const someQuery = sql`SELECT * FROM events` +"#; + + let type_mapping = r#"{ "datetime": { "type": "DateTime", "import": "import type { DateTime } from \"luxon\"" } }"#; + + let (_, type_file) = run_type_mapping_test(schema, ts_content, type_mapping)?; + + // Should contain the import at the top + assert!( + type_file.contains("import type { DateTime } from \"luxon\""), + "Expected import statement in generated file, got:\n{type_file}" + ); + + // Should use the custom type + assert!( + type_file.contains("created_at: DateTime;"), + "Expected DateTime type for created_at, got:\n{type_file}" + ); + + Ok(()) + } + + #[test] + fn should_not_override_unmapped_types() -> Result<(), Box> { + let schema = "CREATE TABLE test_custom_types (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, count BIGINT NOT NULL);"; + + let ts_content = r#" +import { sql } from 'sqlx-ts' +const someQuery = sql`SELECT * FROM test_custom_types` +"#; + + // Only override bigint, text should remain string + let type_mapping = r#"{ "bigint": "string" }"#; + + let (_, type_file) = run_type_mapping_test(schema, ts_content, type_mapping)?; + + let expected = r#" +export type SomeQueryParams = []; + +export interface ISomeQueryResult { + count: string; + id: number; + name: string; +} + +export interface ISomeQueryQuery { + params: SomeQueryParams; + result: ISomeQueryResult; +} +"#; + + assert_eq!( + expected.trim().to_string().flatten(), + type_file.trim().to_string().flatten() + ); + Ok(()) + } +}