diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index 286b16a4e..3a38fbaae 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -19,7 +19,13 @@ //! (commonly referred to as Data Definition Language, or DDL) #[cfg(not(feature = "std"))] -use alloc::{boxed::Box, format, string::String, vec, vec::Vec}; +use alloc::{ + boxed::Box, + format, + string::{String, ToString}, + vec, + vec::Vec, +}; use core::fmt::{self, Display, Write}; #[cfg(feature = "serde")] @@ -3952,3 +3958,233 @@ impl Spanned for DropFunction { Span::empty() } } + +/// CREATE OPERATOR statement +/// See +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct CreateOperator { + /// Operator name (can be schema-qualified) + pub name: ObjectName, + /// FUNCTION or PROCEDURE parameter (function name) + pub function: ObjectName, + /// Whether PROCEDURE keyword was used (vs FUNCTION) + pub is_procedure: bool, + /// LEFTARG parameter (left operand type) + pub left_arg: Option, + /// RIGHTARG parameter (right operand type) + pub right_arg: Option, + /// COMMUTATOR parameter (commutator operator) + pub commutator: Option, + /// NEGATOR parameter (negator operator) + pub negator: Option, + /// RESTRICT parameter (restriction selectivity function) + pub restrict: Option, + /// JOIN parameter (join selectivity function) + pub join: Option, + /// HASHES flag + pub hashes: bool, + /// MERGES flag + pub merges: bool, +} + +/// CREATE OPERATOR FAMILY statement +/// See +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct CreateOperatorFamily { + /// Operator family name (can be schema-qualified) + pub name: ObjectName, + /// Index method (btree, hash, gist, gin, etc.) + pub using: Ident, +} + +/// CREATE OPERATOR CLASS statement +/// See +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct CreateOperatorClass { + /// Operator class name (can be schema-qualified) + pub name: ObjectName, + /// Whether this is the default operator class for the type + pub default: bool, + /// The data type + pub for_type: DataType, + /// Index method (btree, hash, gist, gin, etc.) + pub using: Ident, + /// Optional operator family name + pub family: Option, + /// List of operator class items (operators, functions, storage) + pub items: Vec, +} + +impl fmt::Display for CreateOperator { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "CREATE OPERATOR {} (", self.name)?; + + let function_keyword = if self.is_procedure { + "PROCEDURE" + } else { + "FUNCTION" + }; + let mut params = vec![format!("{} = {}", function_keyword, self.function)]; + + if let Some(left_arg) = &self.left_arg { + params.push(format!("LEFTARG = {}", left_arg)); + } + if let Some(right_arg) = &self.right_arg { + params.push(format!("RIGHTARG = {}", right_arg)); + } + if let Some(commutator) = &self.commutator { + params.push(format!("COMMUTATOR = {}", commutator)); + } + if let Some(negator) = &self.negator { + params.push(format!("NEGATOR = {}", negator)); + } + if let Some(restrict) = &self.restrict { + params.push(format!("RESTRICT = {}", restrict)); + } + if let Some(join) = &self.join { + params.push(format!("JOIN = {}", join)); + } + if self.hashes { + params.push("HASHES".to_string()); + } + if self.merges { + params.push("MERGES".to_string()); + } + + write!(f, "{}", params.join(", "))?; + write!(f, ")") + } +} + +impl fmt::Display for CreateOperatorFamily { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "CREATE OPERATOR FAMILY {} USING {}", + self.name, self.using + ) + } +} + +impl fmt::Display for CreateOperatorClass { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "CREATE OPERATOR CLASS {}", self.name)?; + if self.default { + write!(f, " DEFAULT")?; + } + write!(f, " FOR TYPE {} USING {}", self.for_type, self.using)?; + if let Some(family) = &self.family { + write!(f, " FAMILY {}", family)?; + } + write!(f, " AS {}", display_comma_separated(&self.items)) + } +} + +/// Operator argument types for CREATE OPERATOR CLASS +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct OperatorArgTypes { + pub left: DataType, + pub right: DataType, +} + +impl fmt::Display for OperatorArgTypes { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}, {}", self.left, self.right) + } +} + +/// An item in a CREATE OPERATOR CLASS statement +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum OperatorClassItem { + /// OPERATOR clause + Operator { + strategy_number: u32, + operator_name: ObjectName, + /// Optional operator argument types + op_types: Option, + /// FOR SEARCH or FOR ORDER BY + purpose: Option, + }, + /// FUNCTION clause + Function { + support_number: u32, + /// Optional function argument types for the operator class + op_types: Option>, + function_name: ObjectName, + /// Function argument types + argument_types: Vec, + }, + /// STORAGE clause + Storage { storage_type: DataType }, +} + +/// Purpose of an operator in an operator class +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum OperatorPurpose { + ForSearch, + ForOrderBy { sort_family: ObjectName }, +} + +impl fmt::Display for OperatorClassItem { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + OperatorClassItem::Operator { + strategy_number, + operator_name, + op_types, + purpose, + } => { + write!(f, "OPERATOR {strategy_number} {operator_name}")?; + if let Some(types) = op_types { + write!(f, " ({types})")?; + } + if let Some(purpose) = purpose { + write!(f, " {purpose}")?; + } + Ok(()) + } + OperatorClassItem::Function { + support_number, + op_types, + function_name, + argument_types, + } => { + write!(f, "FUNCTION {support_number}")?; + if let Some(types) = op_types { + write!(f, " ({})", display_comma_separated(types))?; + } + write!(f, " {function_name}")?; + if !argument_types.is_empty() { + write!(f, "({})", display_comma_separated(argument_types))?; + } + Ok(()) + } + OperatorClassItem::Storage { storage_type } => { + write!(f, "STORAGE {storage_type}") + } + } + } +} + +impl fmt::Display for OperatorPurpose { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + OperatorPurpose::ForSearch => write!(f, "FOR SEARCH"), + OperatorPurpose::ForOrderBy { sort_family } => { + write!(f, "FOR ORDER BY {sort_family}") + } + } + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index aa3fb0820..482c38132 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -65,11 +65,12 @@ pub use self::ddl::{ AlterTypeAddValuePosition, AlterTypeOperation, AlterTypeRename, AlterTypeRenameValue, ClusteredBy, ColumnDef, ColumnOption, ColumnOptionDef, ColumnOptions, ColumnPolicy, ColumnPolicyProperty, ConstraintCharacteristics, CreateConnector, CreateDomain, - CreateExtension, CreateFunction, CreateIndex, CreateTable, CreateTrigger, CreateView, - Deduplicate, DeferrableInitial, DropBehavior, DropExtension, DropFunction, DropTrigger, - GeneratedAs, GeneratedExpressionMode, IdentityParameters, IdentityProperty, - IdentityPropertyFormatKind, IdentityPropertyKind, IdentityPropertyOrder, IndexColumn, - IndexOption, IndexType, KeyOrIndexDisplay, Msck, NullsDistinctOption, Owner, Partition, + CreateExtension, CreateFunction, CreateIndex, CreateOperator, CreateOperatorClass, + CreateOperatorFamily, CreateTable, CreateTrigger, CreateView, Deduplicate, DeferrableInitial, + DropBehavior, DropExtension, DropFunction, DropTrigger, GeneratedAs, GeneratedExpressionMode, + IdentityParameters, IdentityProperty, IdentityPropertyFormatKind, IdentityPropertyKind, + IdentityPropertyOrder, IndexColumn, IndexOption, IndexType, KeyOrIndexDisplay, Msck, + NullsDistinctOption, OperatorArgTypes, OperatorClassItem, OperatorPurpose, Owner, Partition, ProcedureParam, ReferentialAction, RenameTableNameKind, ReplicaIdentity, TagsColumnOption, TriggerObjectKind, Truncate, UserDefinedTypeCompositeAttributeDef, UserDefinedTypeInternalLength, UserDefinedTypeRangeOption, UserDefinedTypeRepresentation, @@ -3347,6 +3348,21 @@ pub enum Statement { /// See [Hive](https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=27362034#LanguageManualDDL-CreateDataConnectorCreateConnector) CreateConnector(CreateConnector), /// ```sql + /// CREATE OPERATOR + /// ``` + /// See [PostgreSQL](https://www.postgresql.org/docs/current/sql-createoperator.html) + CreateOperator(CreateOperator), + /// ```sql + /// CREATE OPERATOR FAMILY + /// ``` + /// See [PostgreSQL](https://www.postgresql.org/docs/current/sql-createopfamily.html) + CreateOperatorFamily(CreateOperatorFamily), + /// ```sql + /// CREATE OPERATOR CLASS + /// ``` + /// See [PostgreSQL](https://www.postgresql.org/docs/current/sql-createopclass.html) + CreateOperatorClass(CreateOperatorClass), + /// ```sql /// ALTER TABLE /// ``` AlterTable(AlterTable), @@ -4901,6 +4917,11 @@ impl fmt::Display for Statement { Ok(()) } Statement::CreateConnector(create_connector) => create_connector.fmt(f), + Statement::CreateOperator(create_operator) => create_operator.fmt(f), + Statement::CreateOperatorFamily(create_operator_family) => { + create_operator_family.fmt(f) + } + Statement::CreateOperatorClass(create_operator_class) => create_operator_class.fmt(f), Statement::AlterTable(alter_table) => write!(f, "{alter_table}"), Statement::AlterIndex { name, operation } => { write!(f, "ALTER INDEX {name} {operation}") diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 3a4f1d028..cfaaf8f09 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -17,7 +17,8 @@ use crate::ast::{ ddl::AlterSchema, query::SelectItemQualifiedWildcardKind, AlterSchemaOperation, AlterTable, - ColumnOptions, CreateView, ExportData, Owner, TypedString, + ColumnOptions, CreateOperator, CreateOperatorClass, CreateOperatorFamily, CreateView, + ExportData, Owner, TypedString, }; use core::iter; @@ -368,6 +369,11 @@ impl Spanned for Statement { Statement::CreateSecret { .. } => Span::empty(), Statement::CreateServer { .. } => Span::empty(), Statement::CreateConnector { .. } => Span::empty(), + Statement::CreateOperator(create_operator) => create_operator.span(), + Statement::CreateOperatorFamily(create_operator_family) => { + create_operator_family.span() + } + Statement::CreateOperatorClass(create_operator_class) => create_operator_class.span(), Statement::AlterTable(alter_table) => alter_table.span(), Statement::AlterIndex { name, operation } => name.span().union(&operation.span()), Statement::AlterView { @@ -2357,6 +2363,24 @@ impl Spanned for AlterTable { } } +impl Spanned for CreateOperator { + fn span(&self) -> Span { + Span::empty() + } +} + +impl Spanned for CreateOperatorFamily { + fn span(&self) -> Span { + Span::empty() + } +} + +impl Spanned for CreateOperatorClass { + fn span(&self) -> Span { + Span::empty() + } +} + #[cfg(test)] pub mod tests { use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect}; diff --git a/src/keywords.rs b/src/keywords.rs index 7ff42b412..834d34955 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -197,6 +197,7 @@ define_keywords!( CHECK, CHECKSUM, CIRCLE, + CLASS, CLEANPATH, CLEAR, CLOB, @@ -217,6 +218,7 @@ define_keywords!( COMMENT, COMMIT, COMMITTED, + COMMUTATOR, COMPATIBLE, COMPRESSION, COMPUPDATE, @@ -385,6 +387,7 @@ define_keywords!( FAIL, FAILOVER, FALSE, + FAMILY, FETCH, FIELDS, FILE, @@ -446,6 +449,7 @@ define_keywords!( GROUPS, GZIP, HASH, + HASHES, HAVING, HEADER, HEAP, @@ -539,7 +543,10 @@ define_keywords!( LATERAL, LEAD, LEADING, + LEAKPROOF, + LEAST, LEFT, + LEFTARG, LEVEL, LIKE, LIKE_REGEX, @@ -594,6 +601,7 @@ define_keywords!( MEDIUMTEXT, MEMBER, MERGE, + MERGES, MESSAGE, METADATA, METHOD, @@ -632,6 +640,7 @@ define_keywords!( NATURAL, NCHAR, NCLOB, + NEGATOR, NEST, NESTED, NETWORK, @@ -844,6 +853,7 @@ define_keywords!( RETURNS, REVOKE, RIGHT, + RIGHTARG, RLIKE, RM, ROLE, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 1ab4626f6..f835f5417 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -4792,6 +4792,15 @@ impl<'a> Parser<'a> { self.parse_create_procedure(or_alter) } else if self.parse_keyword(Keyword::CONNECTOR) { self.parse_create_connector() + } else if self.parse_keyword(Keyword::OPERATOR) { + // Check if this is CREATE OPERATOR FAMILY or CREATE OPERATOR CLASS + if self.parse_keyword(Keyword::FAMILY) { + self.parse_create_operator_family() + } else if self.parse_keyword(Keyword::CLASS) { + self.parse_create_operator_class() + } else { + self.parse_create_operator() + } } else if self.parse_keyword(Keyword::SERVER) { self.parse_pg_create_server() } else { @@ -6436,6 +6445,281 @@ impl<'a> Parser<'a> { })) } + /// Parse an operator name, which can contain special characters like +, -, <, >, = + /// that are tokenized as operator tokens rather than identifiers. + /// This is used for PostgreSQL CREATE OPERATOR statements. + /// + /// Examples: `+`, `myschema.+`, `pg_catalog.<=` + fn parse_operator_name(&mut self) -> Result { + let mut parts = vec![]; + loop { + parts.push(ObjectNamePart::Identifier(Ident::new( + self.next_token().to_string(), + ))); + if !self.consume_token(&Token::Period) { + break; + } + } + Ok(ObjectName(parts)) + } + + /// Parse a [Statement::CreateOperator] + /// + /// [PostgreSQL Documentation](https://www.postgresql.org/docs/current/sql-createoperator.html) + pub fn parse_create_operator(&mut self) -> Result { + let name = self.parse_operator_name()?; + self.expect_token(&Token::LParen)?; + + let mut function: Option = None; + let mut is_procedure = false; + let mut left_arg: Option = None; + let mut right_arg: Option = None; + let mut commutator: Option = None; + let mut negator: Option = None; + let mut restrict: Option = None; + let mut join: Option = None; + let mut hashes = false; + let mut merges = false; + + loop { + let keyword = self.expect_one_of_keywords(&[ + Keyword::FUNCTION, + Keyword::PROCEDURE, + Keyword::LEFTARG, + Keyword::RIGHTARG, + Keyword::COMMUTATOR, + Keyword::NEGATOR, + Keyword::RESTRICT, + Keyword::JOIN, + Keyword::HASHES, + Keyword::MERGES, + ])?; + + match keyword { + Keyword::HASHES if !hashes => { + hashes = true; + } + Keyword::MERGES if !merges => { + merges = true; + } + Keyword::FUNCTION | Keyword::PROCEDURE if function.is_none() => { + self.expect_token(&Token::Eq)?; + function = Some(self.parse_object_name(false)?); + is_procedure = keyword == Keyword::PROCEDURE; + } + Keyword::LEFTARG if left_arg.is_none() => { + self.expect_token(&Token::Eq)?; + left_arg = Some(self.parse_data_type()?); + } + Keyword::RIGHTARG if right_arg.is_none() => { + self.expect_token(&Token::Eq)?; + right_arg = Some(self.parse_data_type()?); + } + Keyword::COMMUTATOR if commutator.is_none() => { + self.expect_token(&Token::Eq)?; + if self.parse_keyword(Keyword::OPERATOR) { + self.expect_token(&Token::LParen)?; + commutator = Some(self.parse_operator_name()?); + self.expect_token(&Token::RParen)?; + } else { + commutator = Some(self.parse_operator_name()?); + } + } + Keyword::NEGATOR if negator.is_none() => { + self.expect_token(&Token::Eq)?; + if self.parse_keyword(Keyword::OPERATOR) { + self.expect_token(&Token::LParen)?; + negator = Some(self.parse_operator_name()?); + self.expect_token(&Token::RParen)?; + } else { + negator = Some(self.parse_operator_name()?); + } + } + Keyword::RESTRICT if restrict.is_none() => { + self.expect_token(&Token::Eq)?; + restrict = Some(self.parse_object_name(false)?); + } + Keyword::JOIN if join.is_none() => { + self.expect_token(&Token::Eq)?; + join = Some(self.parse_object_name(false)?); + } + _ => { + return Err(ParserError::ParserError(format!( + "Duplicate or unexpected keyword {:?} in CREATE OPERATOR", + keyword + ))) + } + } + + if !self.consume_token(&Token::Comma) { + break; + } + } + + // Expect closing parenthesis + self.expect_token(&Token::RParen)?; + + // FUNCTION is required + let function = function.ok_or_else(|| { + ParserError::ParserError("CREATE OPERATOR requires FUNCTION parameter".to_string()) + })?; + + Ok(Statement::CreateOperator(CreateOperator { + name, + function, + is_procedure, + left_arg, + right_arg, + commutator, + negator, + restrict, + join, + hashes, + merges, + })) + } + + /// Parse a [Statement::CreateOperatorFamily] + /// + /// [PostgreSQL Documentation](https://www.postgresql.org/docs/current/sql-createopfamily.html) + pub fn parse_create_operator_family(&mut self) -> Result { + let name = self.parse_object_name(false)?; + self.expect_keyword(Keyword::USING)?; + let using = self.parse_identifier()?; + + Ok(Statement::CreateOperatorFamily(CreateOperatorFamily { + name, + using, + })) + } + + /// Parse a [Statement::CreateOperatorClass] + /// + /// [PostgreSQL Documentation](https://www.postgresql.org/docs/current/sql-createopclass.html) + pub fn parse_create_operator_class(&mut self) -> Result { + let name = self.parse_object_name(false)?; + let default = self.parse_keyword(Keyword::DEFAULT); + self.expect_keywords(&[Keyword::FOR, Keyword::TYPE])?; + let for_type = self.parse_data_type()?; + self.expect_keyword(Keyword::USING)?; + let using = self.parse_identifier()?; + + let family = if self.parse_keyword(Keyword::FAMILY) { + Some(self.parse_object_name(false)?) + } else { + None + }; + + self.expect_keyword(Keyword::AS)?; + + let mut items = vec![]; + loop { + if self.parse_keyword(Keyword::OPERATOR) { + let strategy_number = self.parse_literal_uint()? as u32; + let operator_name = self.parse_operator_name()?; + + // Optional operator argument types + let op_types = if self.consume_token(&Token::LParen) { + let left = self.parse_data_type()?; + self.expect_token(&Token::Comma)?; + let right = self.parse_data_type()?; + self.expect_token(&Token::RParen)?; + Some(OperatorArgTypes { left, right }) + } else { + None + }; + + // Optional purpose + let purpose = if self.parse_keyword(Keyword::FOR) { + if self.parse_keyword(Keyword::SEARCH) { + Some(OperatorPurpose::ForSearch) + } else if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) { + let sort_family = self.parse_object_name(false)?; + Some(OperatorPurpose::ForOrderBy { sort_family }) + } else { + return self.expected("SEARCH or ORDER BY after FOR", self.peek_token()); + } + } else { + None + }; + + items.push(OperatorClassItem::Operator { + strategy_number, + operator_name, + op_types, + purpose, + }); + } else if self.parse_keyword(Keyword::FUNCTION) { + let support_number = self.parse_literal_uint()? as u32; + + // Optional operator types + let op_types = + if self.consume_token(&Token::LParen) && self.peek_token() != Token::RParen { + let mut types = vec![]; + loop { + types.push(self.parse_data_type()?); + if !self.consume_token(&Token::Comma) { + break; + } + } + self.expect_token(&Token::RParen)?; + Some(types) + } else if self.consume_token(&Token::LParen) { + self.expect_token(&Token::RParen)?; + Some(vec![]) + } else { + None + }; + + let function_name = self.parse_object_name(false)?; + + // Function argument types + let argument_types = if self.consume_token(&Token::LParen) { + let mut types = vec![]; + loop { + if self.peek_token() == Token::RParen { + break; + } + types.push(self.parse_data_type()?); + if !self.consume_token(&Token::Comma) { + break; + } + } + self.expect_token(&Token::RParen)?; + types + } else { + vec![] + }; + + items.push(OperatorClassItem::Function { + support_number, + op_types, + function_name, + argument_types, + }); + } else if self.parse_keyword(Keyword::STORAGE) { + let storage_type = self.parse_data_type()?; + items.push(OperatorClassItem::Storage { storage_type }); + } else { + break; + } + + // Check for comma separator + if !self.consume_token(&Token::Comma) { + break; + } + } + + Ok(Statement::CreateOperatorClass(CreateOperatorClass { + name, + default, + for_type, + using, + family, + items, + })) + } + pub fn parse_drop(&mut self) -> Result { // MySQL dialect supports `TEMPORARY` let temporary = dialect_of!(self is MySqlDialect | GenericDialect | DuckDbDialect) diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 3bdf6d189..fbfa66588 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -6553,7 +6553,9 @@ fn parse_create_server() { #[test] fn parse_alter_schema() { - match pg_and_generic().verified_stmt("ALTER SCHEMA foo RENAME TO bar") { + // Test RENAME operation + let stmt = pg_and_generic().verified_stmt("ALTER SCHEMA foo RENAME TO bar"); + match stmt { Statement::AlterSchema(AlterSchema { operations, .. }) => { assert_eq!( operations, @@ -6565,52 +6567,26 @@ fn parse_alter_schema() { _ => unreachable!(), } - match pg_and_generic().verified_stmt("ALTER SCHEMA foo OWNER TO bar") { - Statement::AlterSchema(AlterSchema { operations, .. }) => { - assert_eq!( - operations, - vec![AlterSchemaOperation::OwnerTo { - owner: Owner::Ident("bar".into()) - }] - ); - } - _ => unreachable!(), - } - - match pg_and_generic().verified_stmt("ALTER SCHEMA foo OWNER TO CURRENT_ROLE") { - Statement::AlterSchema(AlterSchema { operations, .. }) => { - assert_eq!( - operations, - vec![AlterSchemaOperation::OwnerTo { - owner: Owner::CurrentRole - }] - ); - } - _ => unreachable!(), - } - - match pg_and_generic().verified_stmt("ALTER SCHEMA foo OWNER TO CURRENT_USER") { - Statement::AlterSchema(AlterSchema { operations, .. }) => { - assert_eq!( - operations, - vec![AlterSchemaOperation::OwnerTo { - owner: Owner::CurrentUser - }] - ); - } - _ => unreachable!(), - } - - match pg_and_generic().verified_stmt("ALTER SCHEMA foo OWNER TO SESSION_USER") { - Statement::AlterSchema(AlterSchema { operations, .. }) => { - assert_eq!( - operations, - vec![AlterSchemaOperation::OwnerTo { - owner: Owner::SessionUser - }] - ); + // Test OWNER TO operations with different owner types + for (owner_clause, expected_owner) in &[ + ("bar", Owner::Ident("bar".into())), + ("CURRENT_ROLE", Owner::CurrentRole), + ("CURRENT_USER", Owner::CurrentUser), + ("SESSION_USER", Owner::SessionUser), + ] { + let sql = format!("ALTER SCHEMA foo OWNER TO {}", owner_clause); + let stmt = pg_and_generic().verified_stmt(&sql); + match stmt { + Statement::AlterSchema(AlterSchema { operations, .. }) => { + assert_eq!( + operations, + vec![AlterSchemaOperation::OwnerTo { + owner: expected_owner.clone() + }] + ); + } + _ => unreachable!(), } - _ => unreachable!(), } } @@ -6661,3 +6637,386 @@ fn parse_foreign_key_match_with_actions() { pg_and_generic().verified_stmt(sql); } + +#[test] +fn parse_create_operator() { + let sql = "CREATE OPERATOR myschema.@@ (PROCEDURE = myschema.my_proc, LEFTARG = TIMESTAMP WITH TIME ZONE, RIGHTARG = VARCHAR(255), COMMUTATOR = schema.>, NEGATOR = schema.<=, RESTRICT = myschema.sel_func, JOIN = myschema.join_func, HASHES, MERGES)"; + assert_eq!( + pg().verified_stmt(sql), + Statement::CreateOperator(CreateOperator { + name: ObjectName::from(vec![Ident::new("myschema"), Ident::new("@@")]), + function: ObjectName::from(vec![Ident::new("myschema"), Ident::new("my_proc")]), + is_procedure: true, + left_arg: Some(DataType::Timestamp(None, TimezoneInfo::WithTimeZone)), + right_arg: Some(DataType::Varchar(Some(CharacterLength::IntegerLength { + length: 255, + unit: None + }))), + commutator: Some(ObjectName::from(vec![ + Ident::new("schema"), + Ident::new(">") + ])), + negator: Some(ObjectName::from(vec![ + Ident::new("schema"), + Ident::new("<=") + ])), + restrict: Some(ObjectName::from(vec![ + Ident::new("myschema"), + Ident::new("sel_func") + ])), + join: Some(ObjectName::from(vec![ + Ident::new("myschema"), + Ident::new("join_func") + ])), + hashes: true, + merges: true, + }) + ); + + for op_symbol in &[ + "-", "*", "/", "<", ">", "=", "<=", ">=", "<>", "~", "!", "@", "#", "%", "^", "&", "|", + "<<", ">>", "&&", + ] { + assert_eq!( + pg().verified_stmt(&format!("CREATE OPERATOR {op_symbol} (FUNCTION = f)")), + Statement::CreateOperator(CreateOperator { + name: ObjectName::from(vec![Ident::new(*op_symbol)]), + function: ObjectName::from(vec![Ident::new("f")]), + is_procedure: false, + left_arg: None, + right_arg: None, + commutator: None, + negator: None, + restrict: None, + join: None, + hashes: false, + merges: false, + }) + ); + } + + pg().one_statement_parses_to( + "CREATE OPERATOR != (FUNCTION = func)", + "CREATE OPERATOR <> (FUNCTION = func)", + ); + + for (name, expected_name) in [ + ( + "s1.+", + ObjectName::from(vec![Ident::new("s1"), Ident::new("+")]), + ), + ( + "s2.-", + ObjectName::from(vec![Ident::new("s2"), Ident::new("-")]), + ), + ( + "s1.s3.*", + ObjectName::from(vec![Ident::new("s1"), Ident::new("s3"), Ident::new("*")]), + ), + ] { + match pg().verified_stmt(&format!("CREATE OPERATOR {name} (FUNCTION = f)")) { + Statement::CreateOperator(CreateOperator { + name, + hashes: false, + merges: false, + .. + }) => { + assert_eq!(name, expected_name); + } + _ => unreachable!(), + } + } + + pg().one_statement_parses_to( + "CREATE OPERATOR + (FUNCTION = f, COMMUTATOR = OPERATOR(>), NEGATOR = OPERATOR(>=))", + "CREATE OPERATOR + (FUNCTION = f, COMMUTATOR = >, NEGATOR = >=)", + ); + + // Test all duplicate clause errors + for field in &[ + "FUNCTION = f2", + "PROCEDURE = p", + "LEFTARG = INT4, LEFTARG = INT4", + "RIGHTARG = INT4, RIGHTARG = INT4", + "COMMUTATOR = -, COMMUTATOR = *", + "NEGATOR = -, NEGATOR = *", + "RESTRICT = f1, RESTRICT = f2", + "JOIN = f1, JOIN = f2", + "HASHES, HASHES", + "MERGES, MERGES", + ] { + assert!(pg() + .parse_sql_statements(&format!("CREATE OPERATOR + (FUNCTION = f, {field})")) + .is_err()); + } + + // Test missing FUNCTION/PROCEDURE error + assert!(pg() + .parse_sql_statements("CREATE OPERATOR + (LEFTARG = INT4)") + .is_err()); + + // Test empty parameter list error + assert!(pg().parse_sql_statements("CREATE OPERATOR + ()").is_err()); + + // Test nested empty parentheses error + assert!(pg().parse_sql_statements("CREATE OPERATOR > (()").is_err()); + assert!(pg().parse_sql_statements("CREATE OPERATOR > ())").is_err()); +} + +#[test] +fn parse_create_operator_family() { + for index_method in &["btree", "hash", "gist", "gin", "spgist", "brin"] { + assert_eq!( + pg().verified_stmt(&format!( + "CREATE OPERATOR FAMILY my_family USING {index_method}" + )), + Statement::CreateOperatorFamily(CreateOperatorFamily { + name: ObjectName::from(vec![Ident::new("my_family")]), + using: Ident::new(*index_method), + }) + ); + assert_eq!( + pg().verified_stmt(&format!( + "CREATE OPERATOR FAMILY myschema.test_family USING {index_method}" + )), + Statement::CreateOperatorFamily(CreateOperatorFamily { + name: ObjectName::from(vec![Ident::new("myschema"), Ident::new("test_family")]), + using: Ident::new(*index_method), + }) + ); + } +} + +#[test] +fn parse_create_operator_class() { + // Test all combinations of DEFAULT flag and FAMILY clause with different name qualifications + for (is_default, default_clause) in [(false, ""), (true, "DEFAULT ")] { + for (has_family, family_clause) in [(false, ""), (true, " FAMILY int4_family")] { + for (class_name, expected_name) in [ + ("int4_ops", ObjectName::from(vec![Ident::new("int4_ops")])), + ( + "myschema.test_ops", + ObjectName::from(vec![Ident::new("myschema"), Ident::new("test_ops")]), + ), + ] { + let sql = format!( + "CREATE OPERATOR CLASS {class_name} {default_clause}FOR TYPE INT4 USING btree{family_clause} AS OPERATOR 1 <" + ); + match pg().verified_stmt(&sql) { + Statement::CreateOperatorClass(CreateOperatorClass { + name, + default, + ref for_type, + ref using, + ref family, + ref items, + }) => { + assert_eq!(name, expected_name); + assert_eq!(default, is_default); + assert_eq!(for_type, &DataType::Int4(None)); + assert_eq!(using, &Ident::new("btree")); + assert_eq!( + family, + &if has_family { + Some(ObjectName::from(vec![Ident::new("int4_family")])) + } else { + None + } + ); + assert_eq!(items.len(), 1); + } + _ => panic!("Expected CreateOperatorClass statement"), + } + } + } + } + + // Test comprehensive operator class with all fields + match pg().verified_stmt("CREATE OPERATOR CLASS CAS_btree_ops DEFAULT FOR TYPE CAS USING btree FAMILY CAS_btree_ops AS OPERATOR 1 <, OPERATOR 2 <=, OPERATOR 3 =, OPERATOR 4 >=, OPERATOR 5 >, FUNCTION 1 cas_cmp(CAS, CAS)") { + Statement::CreateOperatorClass(CreateOperatorClass { + name, + default: true, + ref for_type, + ref using, + ref family, + ref items, + }) => { + assert_eq!(name, ObjectName::from(vec![Ident::new("CAS_btree_ops")])); + assert_eq!(for_type, &DataType::Custom(ObjectName::from(vec![Ident::new("CAS")]), vec![])); + assert_eq!(using, &Ident::new("btree")); + assert_eq!(family, &Some(ObjectName::from(vec![Ident::new("CAS_btree_ops")]))); + assert_eq!(items.len(), 6); + } + _ => panic!("Expected CreateOperatorClass statement"), + } + + // Test operator with argument types + match pg().verified_stmt( + "CREATE OPERATOR CLASS test_ops FOR TYPE INT4 USING gist AS OPERATOR 1 < (INT4, INT4)", + ) { + Statement::CreateOperatorClass(CreateOperatorClass { ref items, .. }) => { + assert_eq!(items.len(), 1); + match &items[0] { + OperatorClassItem::Operator { + strategy_number: 1, + ref operator_name, + op_types: + Some(OperatorArgTypes { + left: DataType::Int4(None), + right: DataType::Int4(None), + }), + purpose: None, + } => { + assert_eq!(operator_name, &ObjectName::from(vec![Ident::new("<")])); + } + _ => panic!("Expected Operator item with arg types"), + } + } + _ => panic!("Expected CreateOperatorClass statement"), + } + + // Test operator FOR SEARCH + match pg().verified_stmt( + "CREATE OPERATOR CLASS test_ops FOR TYPE INT4 USING gist AS OPERATOR 1 < FOR SEARCH", + ) { + Statement::CreateOperatorClass(CreateOperatorClass { ref items, .. }) => { + assert_eq!(items.len(), 1); + match &items[0] { + OperatorClassItem::Operator { + strategy_number: 1, + ref operator_name, + op_types: None, + purpose: Some(OperatorPurpose::ForSearch), + } => { + assert_eq!(operator_name, &ObjectName::from(vec![Ident::new("<")])); + } + _ => panic!("Expected Operator item FOR SEARCH"), + } + } + _ => panic!("Expected CreateOperatorClass statement"), + } + + // Test operator FOR ORDER BY + match pg().verified_stmt("CREATE OPERATOR CLASS test_ops FOR TYPE INT4 USING gist AS OPERATOR 2 <<-> FOR ORDER BY float_ops") { + Statement::CreateOperatorClass(CreateOperatorClass { + ref items, + .. + }) => { + assert_eq!(items.len(), 1); + match &items[0] { + OperatorClassItem::Operator { + strategy_number: 2, + ref operator_name, + op_types: None, + purpose: Some(OperatorPurpose::ForOrderBy { ref sort_family }), + } => { + assert_eq!(operator_name, &ObjectName::from(vec![Ident::new("<<->")])); + assert_eq!(sort_family, &ObjectName::from(vec![Ident::new("float_ops")])); + } + _ => panic!("Expected Operator item FOR ORDER BY"), + } + } + _ => panic!("Expected CreateOperatorClass statement"), + } + + // Test function with operator class arg types + match pg().verified_stmt("CREATE OPERATOR CLASS test_ops FOR TYPE INT4 USING btree AS FUNCTION 1 (INT4, INT4) btcmp(INT4, INT4)") { + Statement::CreateOperatorClass(CreateOperatorClass { + ref items, + .. + }) => { + assert_eq!(items.len(), 1); + match &items[0] { + OperatorClassItem::Function { + support_number: 1, + op_types: Some(_), + ref function_name, + ref argument_types, + } => { + assert_eq!(function_name, &ObjectName::from(vec![Ident::new("btcmp")])); + assert_eq!(argument_types.len(), 2); + } + _ => panic!("Expected Function item with op_types"), + } + } + _ => panic!("Expected CreateOperatorClass statement"), + } + + // Test function with no arguments (empty parentheses normalizes to no parentheses) + pg().one_statement_parses_to( + "CREATE OPERATOR CLASS test_ops FOR TYPE INT4 USING btree AS FUNCTION 1 my_func()", + "CREATE OPERATOR CLASS test_ops FOR TYPE INT4 USING btree AS FUNCTION 1 my_func", + ); + match pg().verified_stmt( + "CREATE OPERATOR CLASS test_ops FOR TYPE INT4 USING btree AS FUNCTION 1 my_func", + ) { + Statement::CreateOperatorClass(CreateOperatorClass { ref items, .. }) => { + assert_eq!(items.len(), 1); + match &items[0] { + OperatorClassItem::Function { + support_number: 1, + op_types: None, + ref function_name, + ref argument_types, + } => { + assert_eq!( + function_name, + &ObjectName::from(vec![Ident::new("my_func")]) + ); + assert_eq!(argument_types.len(), 0); + } + _ => panic!("Expected Function item without op_types and no arguments"), + } + } + _ => panic!("Expected CreateOperatorClass statement"), + } + + // Test multiple items including STORAGE + match pg().verified_stmt("CREATE OPERATOR CLASS gist_ops FOR TYPE geometry USING gist AS OPERATOR 1 <<, FUNCTION 1 gist_consistent(internal, geometry, INT4), STORAGE box") { + Statement::CreateOperatorClass(CreateOperatorClass { + ref items, + .. + }) => { + assert_eq!(items.len(), 3); + // Check operator item + match &items[0] { + OperatorClassItem::Operator { + strategy_number: 1, + ref operator_name, + .. + } => { + assert_eq!(operator_name, &ObjectName::from(vec![Ident::new("<<")])); + } + _ => panic!("Expected Operator item"), + } + // Check function item + match &items[1] { + OperatorClassItem::Function { + support_number: 1, + ref function_name, + ref argument_types, + .. + } => { + assert_eq!(function_name, &ObjectName::from(vec![Ident::new("gist_consistent")])); + assert_eq!(argument_types.len(), 3); + } + _ => panic!("Expected Function item"), + } + // Check storage item + match &items[2] { + OperatorClassItem::Storage { ref storage_type } => { + assert_eq!(storage_type, &DataType::Custom(ObjectName::from(vec![Ident::new("box")]), vec![])); + } + _ => panic!("Expected Storage item"), + } + } + _ => panic!("Expected CreateOperatorClass statement"), + } + + // Test nested empty parentheses error in function arguments + assert!(pg() + .parse_sql_statements( + "CREATE OPERATOR CLASS test_ops FOR TYPE INT4 USING btree AS FUNCTION 1 cas_cmp(()" + ) + .is_err()); +}