From 0a4263c9ab3213017d8b4798874958bc8050eb37 Mon Sep 17 00:00:00 2001 From: LucaCappelletti94 Date: Thu, 26 Feb 2026 23:26:22 +0100 Subject: [PATCH 1/2] feat(postgres): implement ALTER FUNCTION/AGGREGATE parsing parity Add ALTER FUNCTION/ALTER AGGREGATE support through AST and parser dispatch, including action parsing and aggregate signature handling. Tighten PostgreSQL parity semantics in parser branches (function-only DEPENDS ON EXTENSION, stricter aggregate signature argument rules) and add VARIADIC function argument mode support. --- src/ast/ddl.rs | 216 +++++++++++++++++++++++++++++++- src/ast/mod.rs | 25 +++- src/ast/spans.rs | 1 + src/keywords.rs | 1 + src/parser/mod.rs | 312 +++++++++++++++++++++++++++++++++++++++++++++- 5 files changed, 545 insertions(+), 10 deletions(-) diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index 0c4f93e647..9987bfbfe6 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -47,10 +47,10 @@ use crate::ast::{ FunctionDeterminismSpecifier, FunctionParallel, FunctionSecurity, HiveDistributionStyle, HiveFormat, HiveIOFormat, HiveRowFormat, HiveSetLocation, Ident, InitializeKind, MySQLColumnPosition, ObjectName, OnCommit, OneOrManyWithParens, OperateFunctionArg, - OrderByExpr, ProjectionSelect, Query, RefreshModeKind, RowAccessPolicy, SequenceOptions, - Spanned, SqlOption, StorageSerializationPolicy, TableVersion, Tag, TriggerEvent, - TriggerExecBody, TriggerObject, TriggerPeriod, TriggerReferencing, Value, ValueWithSpan, - WrappedCollection, + OrderByExpr, ProjectionSelect, Query, RefreshModeKind, ResetConfig, RowAccessPolicy, + SequenceOptions, Spanned, SqlOption, StorageSerializationPolicy, TableVersion, Tag, + TriggerEvent, TriggerExecBody, TriggerObject, TriggerPeriod, TriggerReferencing, Value, + ValueWithSpan, WrappedCollection, }; use crate::display_utils::{DisplayCommaSeparated, Indent, NewLine, SpaceOrNewline}; use crate::keywords::Keyword; @@ -5121,6 +5121,214 @@ impl Spanned for AlterOperatorClass { } } +/// `ALTER FUNCTION` / `ALTER AGGREGATE` statement. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct AlterFunction { + /// Object type being altered. + pub kind: AlterFunctionKind, + /// Function or aggregate signature. + pub function: FunctionDesc, + /// `ORDER BY` argument list for aggregate signatures. + /// + /// This is only used for `ALTER AGGREGATE`. + pub aggregate_order_by: Option>, + /// Whether the aggregate signature uses `*`. + /// + /// This is only used for `ALTER AGGREGATE`. + pub aggregate_star: bool, + /// Operation applied to the object. + pub operation: AlterFunctionOperation, +} + +/// Function-like object type used by [`AlterFunction`]. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum AlterFunctionKind { + /// `FUNCTION` + Function, + /// `AGGREGATE` + Aggregate, +} + +impl fmt::Display for AlterFunctionKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Function => write!(f, "FUNCTION"), + Self::Aggregate => write!(f, "AGGREGATE"), + } + } +} + +/// Operation for `ALTER FUNCTION` / `ALTER AGGREGATE`. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum AlterFunctionOperation { + /// `RENAME TO new_name` + RenameTo { + /// New unqualified function or aggregate name. + new_name: Ident, + }, + /// `OWNER TO { new_owner | CURRENT_ROLE | CURRENT_USER | SESSION_USER }` + OwnerTo(Owner), + /// `SET SCHEMA schema_name` + SetSchema { + /// The target schema name. + schema_name: ObjectName, + }, + /// `[ NO ] DEPENDS ON EXTENSION extension_name` + DependsOnExtension { + /// `true` when `NO DEPENDS ON EXTENSION`. + no: bool, + /// Extension name. + extension_name: ObjectName, + }, + /// `action [ ... ] [ RESTRICT ]` (function only). + Actions { + /// One or more function actions. + actions: Vec, + /// Whether `RESTRICT` is present. + restrict: bool, + }, +} + +/// Function action in `ALTER FUNCTION ... action [ ... ] [ RESTRICT ]`. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum AlterFunctionAction { + /// `CALLED ON NULL INPUT` / `RETURNS NULL ON NULL INPUT` / `STRICT` + CalledOnNull(FunctionCalledOnNull), + /// `IMMUTABLE` / `STABLE` / `VOLATILE` + Behavior(FunctionBehavior), + /// `[ NOT ] LEAKPROOF` + Leakproof(bool), + /// `[ EXTERNAL ] SECURITY { DEFINER | INVOKER }` + Security { + /// Whether the optional `EXTERNAL` keyword was present. + external: bool, + /// Security mode. + security: FunctionSecurity, + }, + /// `PARALLEL { UNSAFE | RESTRICTED | SAFE }` + Parallel(FunctionParallel), + /// `COST execution_cost` + Cost(Expr), + /// `ROWS result_rows` + Rows(Expr), + /// `SUPPORT support_function` + Support(ObjectName), + /// `SET configuration_parameter { TO | = } { value | DEFAULT }` + /// or `SET configuration_parameter FROM CURRENT` + Set(FunctionDefinitionSetParam), + /// `RESET configuration_parameter` or `RESET ALL` + Reset(ResetConfig), +} + +impl fmt::Display for AlterFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "ALTER {} ", self.kind)?; + match self.kind { + AlterFunctionKind::Function => { + write!(f, "{} ", self.function)?; + } + AlterFunctionKind::Aggregate => { + write!(f, "{}(", self.function.name)?; + if self.aggregate_star { + write!(f, "*")?; + } else { + if let Some(args) = &self.function.args { + write!(f, "{}", display_comma_separated(args))?; + } + if let Some(order_by_args) = &self.aggregate_order_by { + if self + .function + .args + .as_ref() + .is_some_and(|args| !args.is_empty()) + { + write!(f, " ")?; + } + write!(f, "ORDER BY {}", display_comma_separated(order_by_args))?; + } + } + write!(f, ") ")?; + } + } + write!(f, "{}", self.operation) + } +} + +impl fmt::Display for AlterFunctionOperation { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + AlterFunctionOperation::RenameTo { new_name } => { + write!(f, "RENAME TO {new_name}") + } + AlterFunctionOperation::OwnerTo(owner) => write!(f, "OWNER TO {owner}"), + AlterFunctionOperation::SetSchema { schema_name } => { + write!(f, "SET SCHEMA {schema_name}") + } + AlterFunctionOperation::DependsOnExtension { no, extension_name } => { + if *no { + write!(f, "NO DEPENDS ON EXTENSION {extension_name}") + } else { + write!(f, "DEPENDS ON EXTENSION {extension_name}") + } + } + AlterFunctionOperation::Actions { actions, restrict } => { + write!(f, "{}", display_separated(actions, " "))?; + if *restrict { + write!(f, " RESTRICT")?; + } + Ok(()) + } + } + } +} + +impl fmt::Display for AlterFunctionAction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + AlterFunctionAction::CalledOnNull(called_on_null) => write!(f, "{called_on_null}"), + AlterFunctionAction::Behavior(behavior) => write!(f, "{behavior}"), + AlterFunctionAction::Leakproof(leakproof) => { + if *leakproof { + write!(f, "LEAKPROOF") + } else { + write!(f, "NOT LEAKPROOF") + } + } + AlterFunctionAction::Security { external, security } => { + if *external { + write!(f, "EXTERNAL ")?; + } + write!(f, "{security}") + } + AlterFunctionAction::Parallel(parallel) => write!(f, "{parallel}"), + AlterFunctionAction::Cost(execution_cost) => write!(f, "COST {execution_cost}"), + AlterFunctionAction::Rows(result_rows) => write!(f, "ROWS {result_rows}"), + AlterFunctionAction::Support(support_function) => { + write!(f, "SUPPORT {support_function}") + } + AlterFunctionAction::Set(set_param) => write!(f, "{set_param}"), + AlterFunctionAction::Reset(reset_config) => match reset_config { + ResetConfig::ALL => write!(f, "RESET ALL"), + ResetConfig::ConfigName(name) => write!(f, "RESET {name}"), + }, + } + } +} + +impl Spanned for AlterFunction { + fn span(&self) -> Span { + Span::empty() + } +} + /// CREATE POLICY statement. /// /// See [PostgreSQL](https://www.postgresql.org/docs/current/sql-createpolicy.html) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 1e430171ee..cea133e4f2 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -60,7 +60,8 @@ pub use self::dcl::{ SetConfigValue, Use, }; pub use self::ddl::{ - Alignment, AlterColumnOperation, AlterConnectorOwner, AlterIndexOperation, AlterOperator, + Alignment, AlterColumnOperation, AlterConnectorOwner, AlterFunction, AlterFunctionAction, + AlterFunctionKind, AlterFunctionOperation, AlterIndexOperation, AlterOperator, AlterOperatorClass, AlterOperatorClassOperation, AlterOperatorFamily, AlterOperatorFamilyOperation, AlterOperatorOperation, AlterPolicy, AlterPolicyOperation, AlterSchema, AlterSchemaOperation, AlterTable, AlterTableAlgorithm, AlterTableLock, @@ -3739,6 +3740,13 @@ pub enum Statement { with_options: Vec, }, /// ```sql + /// ALTER FUNCTION + /// ALTER AGGREGATE + /// ``` + /// See [PostgreSQL](https://www.postgresql.org/docs/current/sql-alterfunction.html) + /// and [PostgreSQL](https://www.postgresql.org/docs/current/sql-alteraggregate.html) + AlterFunction(AlterFunction), + /// ```sql /// ALTER TYPE /// See [PostgreSQL](https://www.postgresql.org/docs/current/sql-altertype.html) /// ``` @@ -5462,6 +5470,7 @@ impl fmt::Display for Statement { } write!(f, " AS {query}") } + Statement::AlterFunction(alter_function) => write!(f, "{alter_function}"), Statement::AlterType(AlterType { name, operation }) => { write!(f, "ALTER TYPE {name} {operation}") } @@ -9687,6 +9696,8 @@ pub enum ArgMode { Out, /// `INOUT` mode. InOut, + /// `VARIADIC` mode. + Variadic, } impl fmt::Display for ArgMode { @@ -9695,6 +9706,7 @@ impl fmt::Display for ArgMode { ArgMode::In => write!(f, "IN"), ArgMode::Out => write!(f, "OUT"), ArgMode::InOut => write!(f, "INOUT"), + ArgMode::Variadic => write!(f, "VARIADIC"), } } } @@ -9751,6 +9763,8 @@ impl fmt::Display for FunctionSecurity { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub enum FunctionSetValue { + /// SET param = DEFAULT / SET param TO DEFAULT + Default, /// SET param = value1, value2, ... Values(Vec), /// SET param FROM CURRENT @@ -9765,7 +9779,7 @@ pub enum FunctionSetValue { #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct FunctionDefinitionSetParam { /// The name of the configuration parameter. - pub name: Ident, + pub name: ObjectName, /// The value to set for the parameter. pub value: FunctionSetValue, } @@ -9774,6 +9788,7 @@ impl fmt::Display for FunctionDefinitionSetParam { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "SET {} ", self.name)?; match &self.value { + FunctionSetValue::Default => write!(f, "= DEFAULT"), FunctionSetValue::Values(values) => { write!(f, "= {}", display_comma_separated(values)) } @@ -11918,6 +11933,12 @@ impl From for Statement { } } +impl From for Statement { + fn from(a: AlterFunction) -> Self { + Self::AlterFunction(a) + } +} + impl From for Statement { fn from(a: AlterType) -> Self { Self::AlterType(a) diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 0b95c3ed70..6b103d7969 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -402,6 +402,7 @@ impl Spanned for Statement { .chain(with_options.iter().map(|i| i.span())), ), // These statements need to be implemented + Statement::AlterFunction { .. } => Span::empty(), Statement::AlterType { .. } => Span::empty(), Statement::AlterOperator { .. } => Span::empty(), Statement::AlterOperatorFamily { .. } => Span::empty(), diff --git a/src/keywords.rs b/src/keywords.rs index cc2b9e9dd0..df5084df25 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -1127,6 +1127,7 @@ define_keywords!( VARCHAR2, VARIABLE, VARIABLES, + VARIADIC, VARYING, VAR_POP, VAR_SAMP, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index bea566bbe8..5a3d6f4c3d 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -5673,15 +5673,19 @@ impl<'a> Parser<'a> { return self.expected_ref("DEFINER or INVOKER", self.peek_token_ref()); } } else if self.parse_keyword(Keyword::SET) { - let name = self.parse_identifier()?; + let name = self.parse_object_name(false)?; let value = if self.parse_keywords(&[Keyword::FROM, Keyword::CURRENT]) { FunctionSetValue::FromCurrent } else { if !self.consume_token(&Token::Eq) && !self.parse_keyword(Keyword::TO) { return self.expected_ref("= or TO", self.peek_token_ref()); } - let values = self.parse_comma_separated(Parser::parse_expr)?; - FunctionSetValue::Values(values) + if self.parse_keyword(Keyword::DEFAULT) { + FunctionSetValue::Default + } else { + let values = self.parse_comma_separated(Parser::parse_expr)?; + FunctionSetValue::Values(values) + } }; set_params.push(FunctionDefinitionSetParam { name, value }); } else if self.parse_keyword(Keyword::RETURN) { @@ -5955,6 +5959,8 @@ impl<'a> Parser<'a> { Some(ArgMode::Out) } else if self.parse_keyword(Keyword::INOUT) { Some(ArgMode::InOut) + } else if self.parse_keyword(Keyword::VARIADIC) { + Some(ArgMode::Variadic) } else { None }; @@ -6007,6 +6013,69 @@ impl<'a> Parser<'a> { }) } + fn parse_aggregate_function_arg(&mut self) -> Result { + let mode = if self.parse_keyword(Keyword::IN) { + Some(ArgMode::In) + } else { + if self + .peek_one_of_keywords(&[Keyword::OUT, Keyword::INOUT, Keyword::VARIADIC]) + .is_some() + { + return self.expected_ref( + "IN or argument type in aggregate signature", + self.peek_token_ref(), + ); + } + None + }; + + // Parse: [ argname ] argtype, but do not consume ORDER from + // `... argtype ORDER BY ...` as a type-name disambiguator. + let mut name = None; + let mut data_type = self.parse_data_type()?; + let data_type_idx = self.get_current_index(); + + fn parse_data_type_for_aggregate_arg(parser: &mut Parser) -> Result { + if parser.peek_keyword(Keyword::DEFAULT) + || parser.peek_keyword(Keyword::ORDER) + || parser.peek_token_ref().token == Token::Comma + || parser.peek_token_ref().token == Token::RParen + { + // Dummy error ignored by maybe_parse + parser_err!( + "The current token cannot start an aggregate argument type", + parser.peek_token_ref().span.start + ) + } else { + parser.parse_data_type() + } + } + + if let Some(next_data_type) = self.maybe_parse(parse_data_type_for_aggregate_arg)? { + let token = self.token_at(data_type_idx); + if !matches!(token.token, Token::Word(_)) { + return self.expected("a name or type", token.clone()); + } + + name = Some(Ident::new(token.to_string())); + data_type = next_data_type; + } + + if self.peek_keyword(Keyword::DEFAULT) || self.peek_token_ref().token == Token::Eq { + return self.expected_ref( + "',' or ')' or ORDER BY after aggregate argument type", + self.peek_token_ref(), + ); + } + + Ok(OperateFunctionArg { + mode, + name, + data_type, + default_expr: None, + }) + } + /// Parse statements of the DropTrigger type such as: /// /// ```sql @@ -10433,6 +10502,8 @@ impl<'a> Parser<'a> { Keyword::TYPE, Keyword::TABLE, Keyword::INDEX, + Keyword::FUNCTION, + Keyword::AGGREGATE, Keyword::ROLE, Keyword::POLICY, Keyword::CONNECTOR, @@ -10472,6 +10543,8 @@ impl<'a> Parser<'a> { operation, }) } + Keyword::FUNCTION => self.parse_alter_function(AlterFunctionKind::Function), + Keyword::AGGREGATE => self.parse_alter_function(AlterFunctionKind::Aggregate), Keyword::OPERATOR => { if self.parse_keyword(Keyword::FAMILY) { self.parse_alter_operator_family().map(Into::into) @@ -10487,11 +10560,242 @@ impl<'a> Parser<'a> { Keyword::USER => self.parse_alter_user().map(Into::into), // unreachable because expect_one_of_keywords used above unexpected_keyword => Err(ParserError::ParserError( - format!("Internal parser error: expected any of {{VIEW, TYPE, TABLE, INDEX, ROLE, POLICY, CONNECTOR, ICEBERG, SCHEMA, USER, OPERATOR}}, got {unexpected_keyword:?}"), + format!("Internal parser error: expected any of {{VIEW, TYPE, TABLE, INDEX, FUNCTION, AGGREGATE, ROLE, POLICY, CONNECTOR, ICEBERG, SCHEMA, USER, OPERATOR}}, got {unexpected_keyword:?}"), )), } } + fn parse_unreserved_keyword(&mut self, expected: &str) -> bool { + match &self.peek_token_ref().token { + Token::Word(w) if w.quote_style.is_none() && w.value.eq_ignore_ascii_case(expected) => { + self.advance_token(); + true + } + _ => false, + } + } + + fn parse_alter_aggregate_signature( + &mut self, + ) -> Result<(FunctionDesc, bool, Option>), ParserError> { + let name = self.parse_object_name(false)?; + self.expect_token(&Token::LParen)?; + + if self.consume_token(&Token::Mul) { + self.expect_token(&Token::RParen)?; + return Ok(( + FunctionDesc { + name, + args: Some(vec![]), + }, + true, + None, + )); + } + + let args = + if self.peek_keyword(Keyword::ORDER) || self.peek_token_ref().token == Token::RParen { + vec![] + } else { + self.parse_comma_separated(Parser::parse_aggregate_function_arg)? + }; + + let aggregate_order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) { + Some(self.parse_comma_separated(Parser::parse_aggregate_function_arg)?) + } else { + None + }; + + self.expect_token(&Token::RParen)?; + Ok(( + FunctionDesc { + name, + args: Some(args), + }, + false, + aggregate_order_by, + )) + } + + fn parse_alter_function_action(&mut self) -> Result, ParserError> { + let action = if self.parse_keywords(&[ + Keyword::CALLED, + Keyword::ON, + Keyword::NULL, + Keyword::INPUT, + ]) { + Some(AlterFunctionAction::CalledOnNull( + FunctionCalledOnNull::CalledOnNullInput, + )) + } else if self.parse_keywords(&[ + Keyword::RETURNS, + Keyword::NULL, + Keyword::ON, + Keyword::NULL, + Keyword::INPUT, + ]) { + Some(AlterFunctionAction::CalledOnNull( + FunctionCalledOnNull::ReturnsNullOnNullInput, + )) + } else if self.parse_keyword(Keyword::STRICT) { + Some(AlterFunctionAction::CalledOnNull( + FunctionCalledOnNull::Strict, + )) + } else if self.parse_keyword(Keyword::IMMUTABLE) { + Some(AlterFunctionAction::Behavior(FunctionBehavior::Immutable)) + } else if self.parse_keyword(Keyword::STABLE) { + Some(AlterFunctionAction::Behavior(FunctionBehavior::Stable)) + } else if self.parse_keyword(Keyword::VOLATILE) { + Some(AlterFunctionAction::Behavior(FunctionBehavior::Volatile)) + } else if self.parse_keyword(Keyword::NOT) { + self.expect_keyword(Keyword::LEAKPROOF)?; + Some(AlterFunctionAction::Leakproof(false)) + } else if self.parse_keyword(Keyword::LEAKPROOF) { + Some(AlterFunctionAction::Leakproof(true)) + } else if self.parse_keyword(Keyword::EXTERNAL) { + self.expect_keyword(Keyword::SECURITY)?; + let security = if self.parse_keyword(Keyword::DEFINER) { + FunctionSecurity::Definer + } else if self.parse_keyword(Keyword::INVOKER) { + FunctionSecurity::Invoker + } else { + return self.expected_ref("DEFINER or INVOKER", self.peek_token_ref()); + }; + Some(AlterFunctionAction::Security { + external: true, + security, + }) + } else if self.parse_keyword(Keyword::SECURITY) { + let security = if self.parse_keyword(Keyword::DEFINER) { + FunctionSecurity::Definer + } else if self.parse_keyword(Keyword::INVOKER) { + FunctionSecurity::Invoker + } else { + return self.expected_ref("DEFINER or INVOKER", self.peek_token_ref()); + }; + Some(AlterFunctionAction::Security { + external: false, + security, + }) + } else if self.parse_keyword(Keyword::PARALLEL) { + let parallel = if self.parse_keyword(Keyword::UNSAFE) { + FunctionParallel::Unsafe + } else if self.parse_keyword(Keyword::RESTRICTED) { + FunctionParallel::Restricted + } else if self.parse_keyword(Keyword::SAFE) { + FunctionParallel::Safe + } else { + return self + .expected_ref("one of UNSAFE | RESTRICTED | SAFE", self.peek_token_ref()); + }; + Some(AlterFunctionAction::Parallel(parallel)) + } else if self.parse_unreserved_keyword("COST") { + Some(AlterFunctionAction::Cost(self.parse_number()?)) + } else if self.parse_keyword(Keyword::ROWS) { + Some(AlterFunctionAction::Rows(self.parse_number()?)) + } else if self.parse_keyword(Keyword::SUPPORT) { + Some(AlterFunctionAction::Support(self.parse_object_name(false)?)) + } else if self.parse_keyword(Keyword::SET) { + let name = self.parse_object_name(false)?; + let value = if self.parse_keywords(&[Keyword::FROM, Keyword::CURRENT]) { + FunctionSetValue::FromCurrent + } else { + if !self.consume_token(&Token::Eq) && !self.parse_keyword(Keyword::TO) { + return self.expected_ref("= or TO", self.peek_token_ref()); + } + if self.parse_keyword(Keyword::DEFAULT) { + FunctionSetValue::Default + } else { + FunctionSetValue::Values(self.parse_comma_separated(Parser::parse_expr)?) + } + }; + Some(AlterFunctionAction::Set(FunctionDefinitionSetParam { + name, + value, + })) + } else if self.parse_keyword(Keyword::RESET) { + let reset_config = if self.parse_keyword(Keyword::ALL) { + ResetConfig::ALL + } else { + ResetConfig::ConfigName(self.parse_object_name(false)?) + }; + Some(AlterFunctionAction::Reset(reset_config)) + } else { + None + }; + + Ok(action) + } + + fn parse_alter_function_actions( + &mut self, + ) -> Result<(Vec, bool), ParserError> { + let mut actions = vec![]; + while let Some(action) = self.parse_alter_function_action()? { + actions.push(action); + } + if actions.is_empty() { + return self.expected_ref("at least one ALTER FUNCTION action", self.peek_token_ref()); + } + let restrict = self.parse_keyword(Keyword::RESTRICT); + Ok((actions, restrict)) + } + + /// Parse an `ALTER FUNCTION` or `ALTER AGGREGATE` statement. + pub fn parse_alter_function( + &mut self, + kind: AlterFunctionKind, + ) -> Result { + let (function, aggregate_star, aggregate_order_by) = match kind { + AlterFunctionKind::Function => (self.parse_function_desc()?, false, None), + AlterFunctionKind::Aggregate => self.parse_alter_aggregate_signature()?, + }; + + let operation = if self.parse_keywords(&[Keyword::RENAME, Keyword::TO]) { + let new_name = self.parse_identifier()?; + AlterFunctionOperation::RenameTo { new_name } + } else if self.parse_keywords(&[Keyword::OWNER, Keyword::TO]) { + AlterFunctionOperation::OwnerTo(self.parse_owner()?) + } else if self.parse_keywords(&[Keyword::SET, Keyword::SCHEMA]) { + AlterFunctionOperation::SetSchema { + schema_name: self.parse_object_name(false)?, + } + } else if matches!(kind, AlterFunctionKind::Function) && self.parse_keyword(Keyword::NO) { + if !self.parse_unreserved_keyword("DEPENDS") { + return self.expected_ref("DEPENDS after NO", self.peek_token_ref()); + } + self.expect_keywords(&[Keyword::ON, Keyword::EXTENSION])?; + AlterFunctionOperation::DependsOnExtension { + no: true, + extension_name: self.parse_object_name(false)?, + } + } else if matches!(kind, AlterFunctionKind::Function) + && self.parse_unreserved_keyword("DEPENDS") + { + self.expect_keywords(&[Keyword::ON, Keyword::EXTENSION])?; + AlterFunctionOperation::DependsOnExtension { + no: false, + extension_name: self.parse_object_name(false)?, + } + } else if matches!(kind, AlterFunctionKind::Function) { + let (actions, restrict) = self.parse_alter_function_actions()?; + AlterFunctionOperation::Actions { actions, restrict } + } else { + return self.expected_ref( + "RENAME TO, OWNER TO, or SET SCHEMA after ALTER AGGREGATE", + self.peek_token_ref(), + ); + }; + + Ok(Statement::AlterFunction(AlterFunction { + kind, + function, + aggregate_order_by, + aggregate_star, + operation, + })) + } + /// Parse a [Statement::AlterTable] pub fn parse_alter_table(&mut self, iceberg: bool) -> Result { let if_exists = self.parse_keywords(&[Keyword::IF, Keyword::EXISTS]); From 25e267c6dca7c6261a9c2d06e993bfbce3b83a19 Mon Sep 17 00:00:00 2001 From: LucaCappelletti94 Date: Thu, 26 Feb 2026 23:26:26 +0100 Subject: [PATCH 2/2] test(postgres): cover ALTER FUNCTION/AGGREGATE variants Add a focused PostgreSQL test matrix for ALTER FUNCTION and ALTER AGGREGATE covering valid forms, canonical output, and strict-parity rejection cases. --- tests/sqlparser_postgres.rs | 219 ++++++++++++++++++++++++++++++++++++ 1 file changed, 219 insertions(+) diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 7c19f51e5e..2816834aa0 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -7932,6 +7932,225 @@ fn parse_alter_operator_class() { .is_err()); } +#[test] +fn parse_alter_function_and_aggregate() { + for (sql, expected) in [ + ( + "ALTER AGGREGATE alt_func1(int) RENAME TO alt_func3", + "ALTER AGGREGATE alt_func1(INT) RENAME TO alt_func3", + ), + ( + "ALTER AGGREGATE alt_func1(int) OWNER TO regress_alter_generic_user3", + "ALTER AGGREGATE alt_func1(INT) OWNER TO regress_alter_generic_user3", + ), + ( + "ALTER AGGREGATE alt_func1(int) SET SCHEMA alt_nsp2", + "ALTER AGGREGATE alt_func1(INT) SET SCHEMA alt_nsp2", + ), + ( + "ALTER AGGREGATE alt_agg1(int) RENAME TO alt_agg2", + "ALTER AGGREGATE alt_agg1(INT) RENAME TO alt_agg2", + ), + ( + "ALTER AGGREGATE alt_agg1(int) RENAME TO alt_agg3", + "ALTER AGGREGATE alt_agg1(INT) RENAME TO alt_agg3", + ), + ( + "ALTER AGGREGATE alt_agg2(int) OWNER TO regress_alter_generic_user2", + "ALTER AGGREGATE alt_agg2(INT) OWNER TO regress_alter_generic_user2", + ), + ( + "ALTER AGGREGATE alt_agg2(int) OWNER TO regress_alter_generic_user3", + "ALTER AGGREGATE alt_agg2(INT) OWNER TO regress_alter_generic_user3", + ), + ( + "ALTER AGGREGATE alt_agg2(int) SET SCHEMA alt_nsp2", + "ALTER AGGREGATE alt_agg2(INT) SET SCHEMA alt_nsp2", + ), + ( + "ALTER AGGREGATE alt_order(int ORDER BY text) RENAME TO alt_order2", + "ALTER AGGREGATE alt_order(INT ORDER BY TEXT) RENAME TO alt_order2", + ), + ( + "ALTER AGGREGATE alt_order_only(ORDER BY int) SET SCHEMA alt_nsp2", + "ALTER AGGREGATE alt_order_only(ORDER BY INT) SET SCHEMA alt_nsp2", + ), + ( + "ALTER AGGREGATE alt_star(*) OWNER TO regress_alter_generic_user2", + "ALTER AGGREGATE alt_star(*) OWNER TO regress_alter_generic_user2", + ), + ] { + let statement = pg_and_generic().one_statement_parses_to(sql, expected); + assert!(matches!( + statement, + Statement::AlterFunction(AlterFunction { + kind: AlterFunctionKind::Aggregate, + .. + }) + )); + } + + for (sql, expected) in [ + ( + "ALTER FUNCTION alt_func1(int) RENAME TO alt_func2", + "ALTER FUNCTION alt_func1(INT) RENAME TO alt_func2", + ), + ( + "ALTER FUNCTION alt_func1(int) RENAME TO alt_func3", + "ALTER FUNCTION alt_func1(INT) RENAME TO alt_func3", + ), + ( + "ALTER FUNCTION alt_func2(int) OWNER TO regress_alter_generic_user2", + "ALTER FUNCTION alt_func2(INT) OWNER TO regress_alter_generic_user2", + ), + ( + "ALTER FUNCTION alt_func2(int) OWNER TO regress_alter_generic_user3", + "ALTER FUNCTION alt_func2(INT) OWNER TO regress_alter_generic_user3", + ), + ( + "ALTER FUNCTION alt_func2(int) SET SCHEMA alt_nsp1", + "ALTER FUNCTION alt_func2(INT) SET SCHEMA alt_nsp1", + ), + ( + "ALTER FUNCTION alt_func2(int) SET SCHEMA alt_nsp2", + "ALTER FUNCTION alt_func2(INT) SET SCHEMA alt_nsp2", + ), + ( + "ALTER FUNCTION alt_func2(int) DEPENDS ON EXTENSION ext1", + "ALTER FUNCTION alt_func2(INT) DEPENDS ON EXTENSION ext1", + ), + ( + "ALTER FUNCTION alt_func2(int) NO DEPENDS ON EXTENSION ext1", + "ALTER FUNCTION alt_func2(INT) NO DEPENDS ON EXTENSION ext1", + ), + ( + "ALTER FUNCTION alt_func2 IMMUTABLE", + "ALTER FUNCTION alt_func2 IMMUTABLE", + ), + ( + "ALTER FUNCTION alt_func2(int) IMMUTABLE", + "ALTER FUNCTION alt_func2(INT) IMMUTABLE", + ), + ( + "ALTER FUNCTION alt_func2(int) STABLE", + "ALTER FUNCTION alt_func2(INT) STABLE", + ), + ( + "ALTER FUNCTION alt_func2(int) VOLATILE", + "ALTER FUNCTION alt_func2(INT) VOLATILE", + ), + ( + "ALTER FUNCTION alt_func2(int) CALLED ON NULL INPUT", + "ALTER FUNCTION alt_func2(INT) CALLED ON NULL INPUT", + ), + ( + "ALTER FUNCTION alt_func2(int) RETURNS NULL ON NULL INPUT", + "ALTER FUNCTION alt_func2(INT) RETURNS NULL ON NULL INPUT", + ), + ( + "ALTER FUNCTION alt_func2(int) STRICT", + "ALTER FUNCTION alt_func2(INT) STRICT", + ), + ( + "ALTER FUNCTION alt_func2(int) LEAKPROOF", + "ALTER FUNCTION alt_func2(INT) LEAKPROOF", + ), + ( + "ALTER FUNCTION alt_func2(int) NOT LEAKPROOF", + "ALTER FUNCTION alt_func2(INT) NOT LEAKPROOF", + ), + ( + "ALTER FUNCTION alt_func2(int) SECURITY DEFINER", + "ALTER FUNCTION alt_func2(INT) SECURITY DEFINER", + ), + ( + "ALTER FUNCTION alt_func2(int) EXTERNAL SECURITY INVOKER", + "ALTER FUNCTION alt_func2(INT) EXTERNAL SECURITY INVOKER", + ), + ( + "ALTER FUNCTION alt_func2(int) PARALLEL SAFE", + "ALTER FUNCTION alt_func2(INT) PARALLEL SAFE", + ), + ( + "ALTER FUNCTION alt_func2(int) PARALLEL RESTRICTED", + "ALTER FUNCTION alt_func2(INT) PARALLEL RESTRICTED", + ), + ( + "ALTER FUNCTION alt_func2(int) PARALLEL UNSAFE", + "ALTER FUNCTION alt_func2(INT) PARALLEL UNSAFE", + ), + ( + "ALTER FUNCTION alt_func2(int) COST 3.5", + "ALTER FUNCTION alt_func2(INT) COST 3.5", + ), + ( + "ALTER FUNCTION alt_func2(int) ROWS 42", + "ALTER FUNCTION alt_func2(INT) ROWS 42", + ), + ( + "ALTER FUNCTION alt_func2(int) SUPPORT pg_catalog.alt_support", + "ALTER FUNCTION alt_func2(INT) SUPPORT pg_catalog.alt_support", + ), + ( + "ALTER FUNCTION alt_func2(int) SET work_mem TO DEFAULT", + "ALTER FUNCTION alt_func2(INT) SET work_mem = DEFAULT", + ), + ( + "ALTER FUNCTION alt_func2(int) SET work_mem FROM CURRENT", + "ALTER FUNCTION alt_func2(INT) SET work_mem FROM CURRENT", + ), + ( + "ALTER FUNCTION alt_func2(int) SET search_path = pg_catalog, public", + "ALTER FUNCTION alt_func2(INT) SET search_path = pg_catalog, public", + ), + ( + "ALTER FUNCTION alt_func2(int) RESET work_mem", + "ALTER FUNCTION alt_func2(INT) RESET work_mem", + ), + ( + "ALTER FUNCTION alt_func2(int) RESET ALL", + "ALTER FUNCTION alt_func2(INT) RESET ALL", + ), + ( + "ALTER FUNCTION alt_func2(int) IMMUTABLE STRICT PARALLEL SAFE RESTRICT", + "ALTER FUNCTION alt_func2(INT) IMMUTABLE STRICT PARALLEL SAFE RESTRICT", + ), + ( + "ALTER FUNCTION alt_variadic(VARIADIC int[]) STABLE", + "ALTER FUNCTION alt_variadic(VARIADIC INT[]) STABLE", + ), + ] { + let statement = pg_and_generic().one_statement_parses_to(sql, expected); + assert!(matches!( + statement, + Statement::AlterFunction(AlterFunction { + kind: AlterFunctionKind::Function, + .. + }) + )); + } + + assert!(pg() + .parse_sql_statements("ALTER AGGREGATE alt_func1(INT) DEPENDS ON EXTENSION ext1") + .is_err()); + assert!(pg() + .parse_sql_statements("ALTER AGGREGATE alt_func1(INT) NO DEPENDS ON EXTENSION ext1") + .is_err()); + assert!(pg() + .parse_sql_statements("ALTER AGGREGATE alt_func1(OUT INT) OWNER TO joe") + .is_err()); + assert!(pg() + .parse_sql_statements("ALTER AGGREGATE alt_func1(INOUT INT) OWNER TO joe") + .is_err()); + assert!(pg() + .parse_sql_statements("ALTER AGGREGATE alt_func1(INT = 1) OWNER TO joe") + .is_err()); + + assert!(pg() + .parse_sql_statements("ALTER AGGREGATE alt_func1(INT) IMMUTABLE") + .is_err()); +} + #[test] fn parse_drop_operator_family() { for if_exists in [true, false] {