diff --git a/src/adapter/src/optimize/dataflows.rs b/src/adapter/src/optimize/dataflows.rs index c812db9f7aafc..413b69308d33e 100644 --- a/src/adapter/src/optimize/dataflows.rs +++ b/src/adapter/src/optimize/dataflows.rs @@ -36,7 +36,7 @@ use mz_repr::adt::array::ArrayDimension; use mz_repr::explain::trace_plan; use mz_repr::optimize::OptimizerFeatures; use mz_repr::role_id::RoleId; -use mz_repr::{Datum, GlobalId, ReprRelationType, ReprScalarType, Row}; +use mz_repr::{Datum, GlobalId, ReprRelationType, Row}; use mz_sql::catalog::CatalogRole; use mz_sql::rbac; use mz_sql::session::metadata::SessionMetadata; @@ -253,10 +253,7 @@ impl ExprPrep for ExprPrepWebhookValidation { e { let now: Datum = now.try_into()?; - let const_expr = MirScalarExpr::literal_ok( - now, - ReprScalarType::from(&f.output_type().scalar_type), - ); + let const_expr = MirScalarExpr::literal_ok(now, f.output_type().scalar_type); *e = const_expr; } Ok(()) @@ -591,7 +588,7 @@ fn eval_unmaterializable_func( .expect("known to be a valid array"); Ok(MirScalarExpr::literal_from_single_element_row( row, - ReprScalarType::from(&f.output_type().scalar_type), + f.output_type().scalar_type, )) }; let pack_dict = |mut datums: Vec<(String, String)>| { @@ -604,13 +601,13 @@ fn eval_unmaterializable_func( ); Ok(MirScalarExpr::literal_from_single_element_row( row, - ReprScalarType::from(&f.output_type().scalar_type), + f.output_type().scalar_type, )) }; let pack = |datum| { Ok(MirScalarExpr::literal_ok( datum, - ReprScalarType::from(&f.output_type().scalar_type), + f.output_type().scalar_type, )) }; @@ -708,7 +705,7 @@ fn eval_unmaterializable_func( }); Ok(MirScalarExpr::literal_from_single_element_row( row, - ReprScalarType::from(&f.output_type().scalar_type), + f.output_type().scalar_type, )) } UnmaterializableFunc::MzSessionId => pack(Datum::from(state.config().session_id)), diff --git a/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__add_int16.snap b/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__add_int16.snap index a27ac8876b7fc..da4ed84613705 100644 --- a/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__add_int16.snap +++ b/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__add_int16.snap @@ -26,7 +26,7 @@ impl crate::func::binary::EagerBinaryFunc for AddInt16 { ) -> Self::Output<'a> { add_int16(a, b) } - fn output_type( + fn output_sql_type( &self, input_types: &[mz_repr::SqlColumnType], ) -> mz_repr::SqlColumnType { diff --git a/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__binary_arena_fn.snap b/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__binary_arena_fn.snap index 15b55e2e0f9b7..58a4883e15656 100644 --- a/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__binary_arena_fn.snap +++ b/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__binary_arena_fn.snap @@ -26,7 +26,7 @@ impl crate::func::binary::EagerBinaryFunc for UnaryFn { ) -> Self::Output<'a> { unary_fn(a, b, temp_storage) } - fn output_type( + fn output_sql_type( &self, input_types: &[mz_repr::SqlColumnType], ) -> mz_repr::SqlColumnType { diff --git a/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__complex_type.snap b/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__complex_type.snap index 5256c4bc5d6e3..d6355419549e1 100644 --- a/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__complex_type.snap +++ b/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__complex_type.snap @@ -26,7 +26,7 @@ impl crate::func::binary::EagerBinaryFunc for ComplexOutputTypeFn { ) -> Self::Output<'a> { complex_output_type_fn(a, b) } - fn output_type( + fn output_sql_type( &self, input_types: &[mz_repr::SqlColumnType], ) -> mz_repr::SqlColumnType { diff --git a/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__unary_fn.snap b/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__unary_fn.snap index 3c5330a9234fc..712c18562e8a0 100644 --- a/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__unary_fn.snap +++ b/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__unary_fn.snap @@ -22,7 +22,10 @@ impl crate::func::EagerUnaryFunc for UnaryFn { fn call<'a>(&self, a: Self::Input<'a>) -> Self::Output<'a> { unary_fn(a) } - fn output_type(&self, input_type: mz_repr::SqlColumnType) -> mz_repr::SqlColumnType { + fn output_sql_type( + &self, + input_type: mz_repr::SqlColumnType, + ) -> mz_repr::SqlColumnType { use mz_repr::AsColumnType; let output = Self::Output::as_column_type(); let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self); diff --git a/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__unary_ref.snap b/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__unary_ref.snap index f3f9c855fdf3d..d63e7719e115d 100644 --- a/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__unary_ref.snap +++ b/src/expr-derive-impl/src/snapshots/mz_expr_derive_impl__test__unary_ref.snap @@ -22,7 +22,10 @@ impl crate::func::EagerUnaryFunc for UnaryFn { fn call<'a>(&self, a: Self::Input<'a>) -> Self::Output<'a> { unary_fn(a) } - fn output_type(&self, input_type: mz_repr::SqlColumnType) -> mz_repr::SqlColumnType { + fn output_sql_type( + &self, + input_type: mz_repr::SqlColumnType, + ) -> mz_repr::SqlColumnType { use mz_repr::AsColumnType; let output = Self::Output::as_column_type(); let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self); diff --git a/src/expr-derive-impl/src/sqlfunc.rs b/src/expr-derive-impl/src/sqlfunc.rs index 17005c0e451d9..b01fb07441f93 100644 --- a/src/expr-derive-impl/src/sqlfunc.rs +++ b/src/expr-derive-impl/src/sqlfunc.rs @@ -340,7 +340,10 @@ fn unary_func(func: &syn::ItemFn, modifiers: Modifiers) -> darling::Result mz_repr::SqlColumnType { + fn output_sql_type( + &self, + input_type: mz_repr::SqlColumnType + ) -> mz_repr::SqlColumnType { use mz_repr::AsColumnType; let output = #output_type; let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self); @@ -510,7 +513,7 @@ fn binary_func( #fn_name(a, b #arena) } - fn output_type( + fn output_sql_type( &self, input_types: &[mz_repr::SqlColumnType], ) -> mz_repr::SqlColumnType { diff --git a/src/expr/src/interpret.rs b/src/expr/src/interpret.rs index bdecaa1e6a83d..553e47d4c82d8 100644 --- a/src/expr/src/interpret.rs +++ b/src/expr/src/interpret.rs @@ -764,7 +764,7 @@ impl<'a> ColumnSpecs<'a> { let range = self .unmaterializables .entry(func.clone()) - .or_insert_with(|| ResultSpec::has_type(&func.output_type(), true)); + .or_insert_with(|| ResultSpec::has_type(&func.output_sql_type(), true)); *range = range.clone().intersect(update); } @@ -847,12 +847,12 @@ impl<'a> Interpreter for ColumnSpecs<'a> { } fn unmaterializable(&self, func: &UnmaterializableFunc) -> Self::Summary { - let col_type = func.output_type(); + let col_type = func.output_sql_type(); let range = self .unmaterializables .get(func) .cloned() - .unwrap_or_else(|| ResultSpec::has_type(&func.output_type(), true)); + .unwrap_or_else(|| ResultSpec::has_type(&func.output_sql_type(), true)); ColumnSpec { col_type, range } } @@ -872,7 +872,7 @@ impl<'a> Interpreter for ColumnSpecs<'a> { }) }; - let col_type = func.output_type(summary.col_type); + let col_type = func.output_sql_type(summary.col_type); let range = mapped_spec.intersect(ResultSpec::has_type(&col_type, fallible)); ColumnSpec { col_type, range } @@ -904,7 +904,7 @@ impl<'a> Interpreter for ColumnSpecs<'a> { }) }; - let col_type = func.output_type(&[left.col_type, right.col_type]); + let col_type = func.output_sql_type(&[left.col_type, right.col_type]); let range = mapped_spec.intersect(ResultSpec::has_type(&col_type, fallible)); ColumnSpec { col_type, range } @@ -954,7 +954,7 @@ impl<'a> Interpreter for ColumnSpecs<'a> { }; let col_types = args.into_iter().map(|spec| spec.col_type).collect(); - let col_type = func.output_type(col_types); + let col_type = func.output_sql_type(col_types); let range = mapped_spec.intersect(ResultSpec::has_type(&col_type, fallible)); @@ -1318,7 +1318,7 @@ mod tests { if !unary_typecheck(&func, &type_in) { return None; } - let type_out = func.output_type(type_in); + let type_out = func.output_sql_type(type_in); let expr_out = MirScalarExpr::CallUnary { func, expr: Box::new(expr_in), @@ -1337,7 +1337,7 @@ mod tests { if !binary_typecheck(&func, &type_left, &type_right) { return None; } - let type_out = func.output_type(&[type_left, type_right]); + let type_out = func.output_sql_type(&[type_left, type_right]); let expr_out = MirScalarExpr::CallBinary { func, expr1: Box::new(expr_left), @@ -1356,7 +1356,7 @@ mod tests { if !variadic_typecheck(&func, &type_in) { return None; } - let type_out = func.output_type(type_in); + let type_out = func.output_sql_type(type_in); let expr_out = MirScalarExpr::CallVariadic { func, exprs: exprs_in, diff --git a/src/expr/src/relation.rs b/src/expr/src/relation.rs index 828739e0d020c..08d4dcf14410d 100644 --- a/src/expr/src/relation.rs +++ b/src/expr/src/relation.rs @@ -467,7 +467,7 @@ impl MirRelationExpr { FlatMap { func, .. } => { let mut result = input_types.next().unwrap().clone(); result.extend( - func.output_type() + func.output_sql_type() .column_types .iter() .map(ReprColumnType::from), @@ -1256,13 +1256,13 @@ impl MirRelationExpr { /// # Example /// /// ```rust - /// use mz_repr::{Datum, SqlColumnType, SqlRelationType, SqlScalarType}; + /// use mz_repr::{Datum, SqlColumnType, ReprRelationType, ReprScalarType}; /// use mz_expr::MirRelationExpr; /// /// // A common schema for each input. - /// let schema = SqlRelationType::new(vec![ - /// SqlScalarType::Int32.nullable(false), - /// SqlScalarType::Int32.nullable(false), + /// let schema = ReprRelationType::new(vec![ + /// ReprScalarType::Int32.nullable(false), + /// ReprScalarType::Int32.nullable(false), /// ]); /// /// // the specific data are not important here. @@ -1519,16 +1519,16 @@ impl MirRelationExpr { /// the correct type. pub fn take_safely(&mut self, typ: Option) -> MirRelationExpr { if let Some(typ) = &typ { - let self_typ = self.sql_typ(); + let self_typ = self.typ(); soft_assert_no_log!( self_typ .column_types .iter() .zip_eq(typ.column_types.iter()) - .all(|(t1, t2)| ReprScalarType::from(&t1.scalar_type) == t2.scalar_type) + .all(|(t1, t2)| t1.scalar_type == t2.scalar_type) ); } - let mut typ = typ.unwrap_or_else(|| ReprRelationType::from(&self.sql_typ())); + let mut typ = typ.unwrap_or_else(|| self.typ()); typ.keys = vec![vec![]]; for ct in typ.column_types.iter_mut() { ct.nullable = false; @@ -2478,16 +2478,12 @@ pub struct AggregateExpr { impl AggregateExpr { /// Computes the type of this `AggregateExpr`. pub fn sql_typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType { - self.func.output_type(self.expr.sql_typ(column_types)) + self.func.output_sql_type(self.expr.sql_typ(column_types)) } /// Computes the type of this `AggregateExpr`. pub fn typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType { - ReprColumnType::from( - &self - .func - .output_type(SqlColumnType::from_repr(&self.expr.typ(column_types))), - ) + self.func.output_type(self.expr.typ(column_types)) } /// Returns whether the expression has a constant result. diff --git a/src/expr/src/relation/func.rs b/src/expr/src/relation/func.rs index 6c7c288c519f7..7f282f21258c2 100644 --- a/src/expr/src/relation/func.rs +++ b/src/expr/src/relation/func.rs @@ -30,8 +30,8 @@ use mz_repr::adt::numeric::{self, Numeric, NumericMaxScale}; use mz_repr::adt::regex::{Regex as ReprRegex, RegexCompilationError}; use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampLike}; use mz_repr::{ - ColumnName, Datum, Diff, Row, RowArena, RowPacker, SharedRow, SqlColumnType, SqlRelationType, - SqlScalarType, datum_size, + ColumnName, Datum, Diff, ReprColumnType, ReprRelationType, Row, RowArena, RowPacker, SharedRow, + SqlColumnType, SqlRelationType, SqlScalarType, datum_size, }; use num::{CheckedAdd, Integer, Signed, ToPrimitive}; use ordered_float::OrderedFloat; @@ -2324,7 +2324,7 @@ impl AggregateFunc { /// The output column type also contains nullability information, which /// is (without further information) true for aggregations that are not /// counts. - pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { let scalar_type = match self { AggregateFunc::Count => SqlScalarType::Int64, AggregateFunc::Any => SqlScalarType::Bool, @@ -2437,7 +2437,7 @@ impl AggregateFunc { let arg_type = fields[0].unwrap_record_element_type()[1] .clone() .nullable(true); - let wrapped_aggr_out_type = wrapped_aggregate.output_type(arg_type); + let wrapped_aggr_out_type = wrapped_aggregate.output_sql_type(arg_type); SqlScalarType::List { element_type: Box::new(SqlScalarType::Record { @@ -2465,7 +2465,7 @@ impl AggregateFunc { |(arg_type, wrapped_agg)| { ( ColumnName::from(wrapped_agg.name()), - wrapped_agg.output_type((**arg_type).clone().nullable(true)), + wrapped_agg.output_sql_type((**arg_type).clone().nullable(true)), ) }).collect_vec(); @@ -2602,6 +2602,13 @@ impl AggregateFunc { scalar_type.nullable(nullable) } + /// Computes the representation type of this aggregate function. + /// + /// This is a wrapper around [`Self::output_sql_type`] that converts the result to a representation type. + pub fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType { + ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type))) + } + /// Compute output type for ROW_NUMBER, RANK, DENSE_RANK fn output_type_ranking_window_funcs( input_type: &SqlColumnType, @@ -3555,7 +3562,7 @@ impl TableFunc { } } - pub fn output_type(&self) -> SqlRelationType { + pub fn output_sql_type(&self) -> SqlRelationType { let (column_types, keys) = match self { TableFunc::AclExplode => { let column_types = vec![ @@ -3694,7 +3701,7 @@ impl TableFunc { (column_types, keys) } TableFunc::WithOrdinality(WithOrdinality { inner }) => { - let mut typ = inner.output_type(); + let mut typ = inner.output_sql_type(); // Add the ordinality column. typ.column_types.push(SqlScalarType::Int64.nullable(false)); // The ordinality column is always a key. @@ -3712,6 +3719,13 @@ impl TableFunc { } } + /// Computes the representation type of this table function. + /// + /// This is a wrapper around [`Self::output_sql_type`] that converts the result to a representation type. + pub fn output_type(&self) -> ReprRelationType { + ReprRelationType::from(&self.output_sql_type()) + } + pub fn output_arity(&self) -> usize { match self { TableFunc::AclExplode => 4, diff --git a/src/expr/src/relation/join_input_mapper.rs b/src/expr/src/relation/join_input_mapper.rs index 4e990cdc40058..e0de5419ee5f9 100644 --- a/src/expr/src/relation/join_input_mapper.rs +++ b/src/expr/src/relation/join_input_mapper.rs @@ -265,13 +265,13 @@ impl JoinInputMapper { /// # Examples /// /// ``` - /// use mz_repr::{Datum, SqlColumnType, SqlRelationType, SqlScalarType}; + /// use mz_repr::{Datum, ReprColumnType, ReprRelationType, ReprScalarType}; /// use mz_expr::{JoinInputMapper, MirRelationExpr, MirScalarExpr}; /// /// // A two-column schema common to each of the three inputs - /// let schema = SqlRelationType::new(vec![ - /// SqlScalarType::Int32.nullable(false), - /// SqlScalarType::Int32.nullable(false), + /// let schema = ReprRelationType::new(vec![ + /// ReprScalarType::Int32.nullable(false), + /// ReprScalarType::Int32.nullable(false), /// ]); /// /// // the specific data are not important here. diff --git a/src/expr/src/scalar.rs b/src/expr/src/scalar.rs index 0105e89dc92f4..4c1c9dc3f19a3 100644 --- a/src/expr/src/scalar.rs +++ b/src/expr/src/scalar.rs @@ -2012,24 +2012,14 @@ impl MirScalarExpr { match self { MirScalarExpr::Column(i, _name) => column_types[*i].clone(), MirScalarExpr::Literal(_, typ) => typ.clone(), - MirScalarExpr::CallUnmaterializable(func) => ReprColumnType::from(&func.output_type()), - MirScalarExpr::CallUnary { expr, func } => ReprColumnType::from( - &func.output_type(SqlColumnType::from_repr(&expr.typ(column_types))), - ), + MirScalarExpr::CallUnmaterializable(func) => func.output_type(), + MirScalarExpr::CallUnary { expr, func } => func.output_type(expr.typ(column_types)), MirScalarExpr::CallBinary { expr1, expr2, func } => { - ReprColumnType::from(&func.output_type(&[ - SqlColumnType::from_repr(&expr1.typ(column_types)), - SqlColumnType::from_repr(&expr2.typ(column_types)), - ])) - } - MirScalarExpr::CallVariadic { exprs, func } => ReprColumnType::from( - &func.output_type( - exprs - .iter() - .map(|e| SqlColumnType::from_repr(&e.typ(column_types))) - .collect(), - ), - ), + func.output_type(&[expr1.typ(column_types), expr2.typ(column_types)]) + } + MirScalarExpr::CallVariadic { exprs, func } => { + func.output_type(exprs.iter().map(|e| e.typ(column_types)).collect()) + } MirScalarExpr::If { cond: _, then, els } => { let then_type = then.typ(column_types); let else_type = els.typ(column_types); diff --git a/src/expr/src/scalar/func/binary.rs b/src/expr/src/scalar/func/binary.rs index b8da434d4bb08..6de0472112bc0 100644 --- a/src/expr/src/scalar/func/binary.rs +++ b/src/expr/src/scalar/func/binary.rs @@ -10,7 +10,7 @@ //! Utilities for binary functions. use mz_ore::assert_none; -use mz_repr::{Datum, InputDatumType, OutputDatumType, RowArena, SqlColumnType}; +use mz_repr::{Datum, InputDatumType, OutputDatumType, ReprColumnType, RowArena, SqlColumnType}; use crate::{EvalError, MirScalarExpr}; @@ -25,7 +25,19 @@ pub(crate) trait LazyBinaryFunc { ) -> Result, EvalError>; /// The output SqlColumnType of this function. - fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType; + fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType; + + /// A wrapper around [`Self::output_sql_type`] that works with representation types. + fn output_type(&self, input_types: &[ReprColumnType]) -> ReprColumnType { + ReprColumnType::from( + &self.output_sql_type( + &input_types + .iter() + .map(SqlColumnType::from_repr) + .collect::>(), + ), + ) + } /// Whether this function will produce NULL on NULL input. fn propagates_nulls(&self) -> bool; @@ -67,7 +79,20 @@ pub(crate) trait EagerBinaryFunc { fn call<'a>(&self, input: Self::Input<'a>, temp_storage: &'a RowArena) -> Self::Output<'a>; /// The output SqlColumnType of this function - fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType; + fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType; + + /// The output of this function as a representation type. + #[allow(dead_code)] + fn output_type(&self, input_types: &[ReprColumnType]) -> ReprColumnType { + ReprColumnType::from( + &self.output_sql_type( + &input_types + .iter() + .map(SqlColumnType::from_repr) + .collect::>(), + ), + ) + } /// Whether this function will produce NULL on NULL input fn propagates_nulls(&self) -> bool { @@ -128,8 +153,8 @@ impl LazyBinaryFunc for T { self.call(input, temp_storage).into_result(temp_storage) } - fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType { - self.output_type(input_types) + fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType { + self.output_sql_type(input_types) } fn propagates_nulls(&self) -> bool { @@ -162,7 +187,7 @@ pub use derive::BinaryFunc; mod derive { use std::fmt; - use mz_repr::{Datum, RowArena, SqlColumnType}; + use mz_repr::{Datum, ReprColumnType, RowArena, SqlColumnType}; use crate::scalar::func::binary::LazyBinaryFunc; use crate::scalar::func::*; @@ -426,28 +451,28 @@ mod test { #[mz_ore::test] fn output_types_infallible() { assert_eq!( - Infallible1.output_type(&[ + Infallible1.output_sql_type(&[ SqlScalarType::Float32.nullable(true), SqlScalarType::Float32.nullable(true) ]), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Infallible1.output_type(&[ + Infallible1.output_sql_type(&[ SqlScalarType::Float32.nullable(true), SqlScalarType::Float32.nullable(false) ]), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Infallible1.output_type(&[ + Infallible1.output_sql_type(&[ SqlScalarType::Float32.nullable(false), SqlScalarType::Float32.nullable(true) ]), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Infallible1.output_type(&[ + Infallible1.output_sql_type(&[ SqlScalarType::Float32.nullable(false), SqlScalarType::Float32.nullable(false) ]), @@ -455,28 +480,28 @@ mod test { ); assert_eq!( - Infallible2.output_type(&[ + Infallible2.output_sql_type(&[ SqlScalarType::Float32.nullable(true), SqlScalarType::Float32.nullable(true) ]), SqlScalarType::Float32.nullable(false) ); assert_eq!( - Infallible2.output_type(&[ + Infallible2.output_sql_type(&[ SqlScalarType::Float32.nullable(true), SqlScalarType::Float32.nullable(false) ]), SqlScalarType::Float32.nullable(false) ); assert_eq!( - Infallible2.output_type(&[ + Infallible2.output_sql_type(&[ SqlScalarType::Float32.nullable(false), SqlScalarType::Float32.nullable(true) ]), SqlScalarType::Float32.nullable(false) ); assert_eq!( - Infallible2.output_type(&[ + Infallible2.output_sql_type(&[ SqlScalarType::Float32.nullable(false), SqlScalarType::Float32.nullable(false) ]), @@ -484,28 +509,28 @@ mod test { ); assert_eq!( - Infallible3.output_type(&[ + Infallible3.output_sql_type(&[ SqlScalarType::Float32.nullable(true), SqlScalarType::Float32.nullable(true) ]), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Infallible3.output_type(&[ + Infallible3.output_sql_type(&[ SqlScalarType::Float32.nullable(true), SqlScalarType::Float32.nullable(false) ]), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Infallible3.output_type(&[ + Infallible3.output_sql_type(&[ SqlScalarType::Float32.nullable(false), SqlScalarType::Float32.nullable(true) ]), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Infallible3.output_type(&[ + Infallible3.output_sql_type(&[ SqlScalarType::Float32.nullable(false), SqlScalarType::Float32.nullable(false) ]), @@ -546,28 +571,28 @@ mod test { #[mz_ore::test] fn output_types_fallible() { assert_eq!( - Fallible1.output_type(&[ + Fallible1.output_sql_type(&[ SqlScalarType::Float32.nullable(true), SqlScalarType::Float32.nullable(true) ]), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Fallible1.output_type(&[ + Fallible1.output_sql_type(&[ SqlScalarType::Float32.nullable(true), SqlScalarType::Float32.nullable(false) ]), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Fallible1.output_type(&[ + Fallible1.output_sql_type(&[ SqlScalarType::Float32.nullable(false), SqlScalarType::Float32.nullable(true) ]), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Fallible1.output_type(&[ + Fallible1.output_sql_type(&[ SqlScalarType::Float32.nullable(false), SqlScalarType::Float32.nullable(false) ]), @@ -575,28 +600,28 @@ mod test { ); assert_eq!( - Fallible2.output_type(&[ + Fallible2.output_sql_type(&[ SqlScalarType::Float32.nullable(true), SqlScalarType::Float32.nullable(true) ]), SqlScalarType::Float32.nullable(false) ); assert_eq!( - Fallible2.output_type(&[ + Fallible2.output_sql_type(&[ SqlScalarType::Float32.nullable(true), SqlScalarType::Float32.nullable(false) ]), SqlScalarType::Float32.nullable(false) ); assert_eq!( - Fallible2.output_type(&[ + Fallible2.output_sql_type(&[ SqlScalarType::Float32.nullable(false), SqlScalarType::Float32.nullable(true) ]), SqlScalarType::Float32.nullable(false) ); assert_eq!( - Fallible2.output_type(&[ + Fallible2.output_sql_type(&[ SqlScalarType::Float32.nullable(false), SqlScalarType::Float32.nullable(false) ]), @@ -604,28 +629,28 @@ mod test { ); assert_eq!( - Fallible3.output_type(&[ + Fallible3.output_sql_type(&[ SqlScalarType::Float32.nullable(true), SqlScalarType::Float32.nullable(true) ]), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Fallible3.output_type(&[ + Fallible3.output_sql_type(&[ SqlScalarType::Float32.nullable(true), SqlScalarType::Float32.nullable(false) ]), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Fallible3.output_type(&[ + Fallible3.output_sql_type(&[ SqlScalarType::Float32.nullable(false), SqlScalarType::Float32.nullable(true) ]), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Fallible3.output_type(&[ + Fallible3.output_sql_type(&[ SqlScalarType::Float32.nullable(false), SqlScalarType::Float32.nullable(false) ]), diff --git a/src/expr/src/scalar/func/impls/array.rs b/src/expr/src/scalar/func/impls/array.rs index e0b451d1b0c02..edf934b111c89 100644 --- a/src/expr/src/scalar/func/impls/array.rs +++ b/src/expr/src/scalar/func/impls/array.rs @@ -74,7 +74,7 @@ impl LazyUnaryFunc for CastArrayToString { Ok(Datum::String(temp_storage.push_string(buf))) } - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { SqlScalarType::String.nullable(input_type.nullable) } @@ -177,7 +177,7 @@ impl LazyUnaryFunc for CastArrayToJsonb { Ok(temp_storage.push_unary_row(row)) } - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { SqlScalarType::Jsonb.nullable(input_type.nullable) } @@ -254,7 +254,7 @@ impl LazyUnaryFunc for CastArrayToArray { Ok(temp_storage.try_make_datum(|packer| packer.try_push_array(&dims, casted_datums))?) } - fn output_type(&self, _input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, _input_type: SqlColumnType) -> SqlColumnType { self.return_ty.clone().nullable(true) } diff --git a/src/expr/src/scalar/func/impls/char.rs b/src/expr/src/scalar/func/impls/char.rs index 59e1c57a6c044..ec65b18e961a0 100644 --- a/src/expr/src/scalar/func/impls/char.rs +++ b/src/expr/src/scalar/func/impls/char.rs @@ -44,7 +44,7 @@ impl EagerUnaryFunc for PadChar { Char(format_str_pad(a, self.length)) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Char { length: self.length, } diff --git a/src/expr/src/scalar/func/impls/date.rs b/src/expr/src/scalar/func/impls/date.rs index bfd0ffd0f65fd..f36c887df72e1 100644 --- a/src/expr/src/scalar/func/impls/date.rs +++ b/src/expr/src/scalar/func/impls/date.rs @@ -59,7 +59,7 @@ impl EagerUnaryFunc for CastDateToTimestamp { Ok(updated) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Timestamp { precision: self.0 }.nullable(input.nullable) } @@ -110,7 +110,7 @@ impl EagerUnaryFunc for CastDateToTimestampTz { Ok(updated) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::TimestampTz { precision: self.0 }.nullable(input.nullable) } @@ -187,7 +187,7 @@ impl EagerUnaryFunc for ExtractDate { extract_date_inner(self.0, a.into()) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: None }.nullable(input.nullable) } diff --git a/src/expr/src/scalar/func/impls/float32.rs b/src/expr/src/scalar/func/impls/float32.rs index 50e7ad3fc3b81..ffc7a6ffaf03f 100644 --- a/src/expr/src/scalar/func/impls/float32.rs +++ b/src/expr/src/scalar/func/impls/float32.rs @@ -218,7 +218,7 @@ impl EagerUnaryFunc for CastFloat32ToNumeric { Ok(a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: self.0 }.nullable(input.nullable) } diff --git a/src/expr/src/scalar/func/impls/float64.rs b/src/expr/src/scalar/func/impls/float64.rs index 115204a505025..f25a6741f5984 100644 --- a/src/expr/src/scalar/func/impls/float64.rs +++ b/src/expr/src/scalar/func/impls/float64.rs @@ -233,7 +233,7 @@ impl EagerUnaryFunc for CastFloat64ToNumeric { } } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: self.0 }.nullable(input.nullable) } diff --git a/src/expr/src/scalar/func/impls/int16.rs b/src/expr/src/scalar/func/impls/int16.rs index 70d744ada29dd..cc928db2c3eb2 100644 --- a/src/expr/src/scalar/func/impls/int16.rs +++ b/src/expr/src/scalar/func/impls/int16.rs @@ -153,7 +153,7 @@ impl EagerUnaryFunc for CastInt16ToNumeric { Ok(a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: self.0 }.nullable(input.nullable) } diff --git a/src/expr/src/scalar/func/impls/int32.rs b/src/expr/src/scalar/func/impls/int32.rs index b0f3c94296868..f2c6346b2908f 100644 --- a/src/expr/src/scalar/func/impls/int32.rs +++ b/src/expr/src/scalar/func/impls/int32.rs @@ -169,7 +169,7 @@ impl EagerUnaryFunc for CastInt32ToNumeric { Ok(a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: self.0 }.nullable(input.nullable) } diff --git a/src/expr/src/scalar/func/impls/int64.rs b/src/expr/src/scalar/func/impls/int64.rs index e56d3ae546b3f..a355b9de36b76 100644 --- a/src/expr/src/scalar/func/impls/int64.rs +++ b/src/expr/src/scalar/func/impls/int64.rs @@ -146,7 +146,7 @@ impl EagerUnaryFunc for CastInt64ToNumeric { Ok(a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: self.0 }.nullable(input.nullable) } diff --git a/src/expr/src/scalar/func/impls/jsonb.rs b/src/expr/src/scalar/func/impls/jsonb.rs index e69407b2c040c..b4280df534063 100644 --- a/src/expr/src/scalar/func/impls/jsonb.rs +++ b/src/expr/src/scalar/func/impls/jsonb.rs @@ -122,7 +122,7 @@ impl EagerUnaryFunc for CastJsonbToNumeric { } } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: self.0 }.nullable(input.nullable) } diff --git a/src/expr/src/scalar/func/impls/list.rs b/src/expr/src/scalar/func/impls/list.rs index c7026e6056c43..2e7107533a442 100644 --- a/src/expr/src/scalar/func/impls/list.rs +++ b/src/expr/src/scalar/func/impls/list.rs @@ -50,7 +50,7 @@ impl LazyUnaryFunc for CastListToString { Ok(Datum::String(temp_storage.push_string(buf))) } - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { SqlScalarType::String.nullable(input_type.nullable) } @@ -123,7 +123,7 @@ impl LazyUnaryFunc for CastListToJsonb { Ok(temp_storage.push_unary_row(row)) } - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { SqlScalarType::Jsonb.nullable(input_type.nullable) } @@ -199,7 +199,7 @@ impl LazyUnaryFunc for CastList1ToList2 { Ok(temp_storage.make_datum(|packer| packer.push_list(cast_datums))) } - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { self.return_ty .without_modifiers() .nullable(input_type.nullable) @@ -293,7 +293,7 @@ impl EagerBinaryFunc for ListLengthMax { } } } - fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType { + fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType { let output = Self::Output::as_column_type(); let propagates_nulls = self.propagates_nulls(); let nullable = output.nullable; diff --git a/src/expr/src/scalar/func/impls/map.rs b/src/expr/src/scalar/func/impls/map.rs index 59afc397fea91..78080a030cd7c 100644 --- a/src/expr/src/scalar/func/impls/map.rs +++ b/src/expr/src/scalar/func/impls/map.rs @@ -50,7 +50,7 @@ impl LazyUnaryFunc for CastMapToString { Ok(Datum::String(temp_storage.push_string(buf))) } - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { SqlScalarType::String.nullable(input_type.nullable) } @@ -137,7 +137,7 @@ impl LazyUnaryFunc for MapBuildFromRecordList { Ok(map) } - fn output_type(&self, _input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, _input_type: SqlColumnType) -> SqlColumnType { SqlScalarType::Map { value_type: Box::new(self.value_type.clone()), custom_id: None, diff --git a/src/expr/src/scalar/func/impls/numeric.rs b/src/expr/src/scalar/func/impls/numeric.rs index c6a225e6c9927..76fd18124ae61 100644 --- a/src/expr/src/scalar/func/impls/numeric.rs +++ b/src/expr/src/scalar/func/impls/numeric.rs @@ -330,7 +330,7 @@ impl EagerUnaryFunc for AdjustNumericScale { Ok(d) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: Some(self.0), } diff --git a/src/expr/src/scalar/func/impls/range.rs b/src/expr/src/scalar/func/impls/range.rs index 71b3bfd2683d7..920ad52e96ae4 100644 --- a/src/expr/src/scalar/func/impls/range.rs +++ b/src/expr/src/scalar/func/impls/range.rs @@ -50,7 +50,7 @@ impl LazyUnaryFunc for CastRangeToString { Ok(Datum::String(temp_storage.push_string(buf))) } - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { SqlScalarType::String.nullable(input_type.nullable) } diff --git a/src/expr/src/scalar/func/impls/record.rs b/src/expr/src/scalar/func/impls/record.rs index a8a349ec2ed52..a59df5b29f615 100644 --- a/src/expr/src/scalar/func/impls/record.rs +++ b/src/expr/src/scalar/func/impls/record.rs @@ -49,7 +49,7 @@ impl LazyUnaryFunc for CastRecordToString { Ok(Datum::String(temp_storage.push_string(buf))) } - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { SqlScalarType::String.nullable(input_type.nullable) } @@ -118,7 +118,7 @@ impl LazyUnaryFunc for CastRecord1ToRecord2 { Ok(temp_storage.make_datum(|packer| packer.push_list(cast_datums))) } - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { self.return_ty .without_modifiers() .nullable(input_type.nullable) @@ -183,7 +183,7 @@ impl LazyUnaryFunc for RecordGet { Ok(a.unwrap_list().iter().nth(self.0).unwrap()) } - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { match input_type.scalar_type { SqlScalarType::Record { fields, .. } => { let (_name, ty) = &fields[self.0]; diff --git a/src/expr/src/scalar/func/impls/string.rs b/src/expr/src/scalar/func/impls/string.rs index e518aa5402341..6c255f4b7b950 100644 --- a/src/expr/src/scalar/func/impls/string.rs +++ b/src/expr/src/scalar/func/impls/string.rs @@ -183,7 +183,7 @@ impl EagerUnaryFunc for CastStringToNumeric { Ok(d.into_inner()) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: self.0 }.nullable(input.nullable) } @@ -240,7 +240,7 @@ impl EagerUnaryFunc for CastStringToTimestamp { Ok(updated) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Timestamp { precision: self.0 }.nullable(input.nullable) } @@ -294,7 +294,7 @@ impl EagerUnaryFunc for CastStringToTimestampTz { Ok(updated) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::TimestampTz { precision: self.0 }.nullable(input.nullable) } @@ -375,7 +375,7 @@ impl LazyUnaryFunc for CastStringToArray { } /// The output SqlColumnType of this function - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { self.return_ty.clone().nullable(input_type.nullable) } @@ -463,7 +463,7 @@ impl LazyUnaryFunc for CastStringToList { } /// The output SqlColumnType of this function - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { self.return_ty .without_modifiers() .nullable(input_type.nullable) @@ -561,7 +561,7 @@ impl LazyUnaryFunc for CastStringToMap { } /// The output SqlColumnType of this function - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { self.return_ty.clone().nullable(input_type.nullable) } @@ -630,7 +630,7 @@ impl EagerUnaryFunc for CastStringToChar { Ok(Char(s)) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Char { length: self.length, } @@ -712,7 +712,7 @@ impl LazyUnaryFunc for CastStringToRange { } /// The output SqlColumnType of this function - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { self.return_ty .without_modifiers() .nullable(input_type.nullable) @@ -784,7 +784,7 @@ impl EagerUnaryFunc for CastStringToVarChar { Ok(VarChar(s)) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::VarChar { max_length: self.length, } @@ -864,7 +864,7 @@ impl LazyUnaryFunc for CastStringToInt2Vector { } /// The output SqlColumnType of this function - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { SqlScalarType::Int2Vector.nullable(input_type.nullable) } @@ -1011,7 +1011,7 @@ impl EagerUnaryFunc for IsLikeMatch { self.0.is_match(haystack) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Bool.nullable(input.nullable) } } @@ -1049,7 +1049,7 @@ impl EagerUnaryFunc for IsRegexpMatch { self.0.is_match(haystack) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Bool.nullable(input.nullable) } } @@ -1094,7 +1094,7 @@ impl LazyUnaryFunc for RegexpMatch { } /// The output SqlColumnType of this function - fn output_type(&self, _input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, _input_type: SqlColumnType) -> SqlColumnType { SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(true) } @@ -1163,7 +1163,7 @@ impl LazyUnaryFunc for RegexpSplitToArray { } /// The output SqlColumnType of this function - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(input_type.nullable) } @@ -1249,7 +1249,7 @@ impl binary::EagerBinaryFunc for RegexpReplace { self.regex.replacen(source, self.limit, replacement) } - fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType { + fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType { use mz_repr::AsColumnType; let output = as AsColumnType>::as_column_type(); let propagates_nulls = binary::EagerBinaryFunc::propagates_nulls(self); diff --git a/src/expr/src/scalar/func/impls/time.rs b/src/expr/src/scalar/func/impls/time.rs index d595025b90a9f..d3985320d4821 100644 --- a/src/expr/src/scalar/func/impls/time.rs +++ b/src/expr/src/scalar/func/impls/time.rs @@ -109,7 +109,7 @@ impl EagerUnaryFunc for ExtractTime { date_part_time_inner(self.0, a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: None }.nullable(input.nullable) } } @@ -142,7 +142,7 @@ impl EagerUnaryFunc for DatePartTime { date_part_time_inner(self.0, a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Float64.nullable(input.nullable) } } @@ -188,7 +188,7 @@ impl EagerUnaryFunc for TimezoneTime { timezone_time(self.tz, a, &self.wall_time) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Time.nullable(input.nullable) } } diff --git a/src/expr/src/scalar/func/impls/timestamp.rs b/src/expr/src/scalar/func/impls/timestamp.rs index 1de81d77e15e0..945aa2a1180ba 100644 --- a/src/expr/src/scalar/func/impls/timestamp.rs +++ b/src/expr/src/scalar/func/impls/timestamp.rs @@ -99,7 +99,7 @@ impl EagerUnaryFunc for CastTimestampToTimestampTz { Ok(updated) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::TimestampTz { precision: self.to }.nullable(input.nullable) } @@ -158,7 +158,7 @@ impl EagerUnaryFunc for AdjustTimestampPrecision { Ok(updated) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Timestamp { precision: self.to }.nullable(input.nullable) } @@ -211,7 +211,7 @@ impl EagerUnaryFunc for CastTimestampTzToTimestamp { Ok(updated) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Timestamp { precision: self.to }.nullable(input.nullable) } @@ -270,7 +270,7 @@ impl EagerUnaryFunc for AdjustTimestampTzPrecision { Ok(updated) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::TimestampTz { precision: self.to }.nullable(input.nullable) } @@ -363,7 +363,7 @@ impl EagerUnaryFunc for ExtractInterval { date_part_interval_inner(self.0, a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: None }.nullable(input.nullable) } } @@ -396,7 +396,7 @@ impl EagerUnaryFunc for DatePartInterval { date_part_interval_inner(self.0, a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Float64.nullable(input.nullable) } } @@ -475,7 +475,7 @@ impl EagerUnaryFunc for ExtractTimestamp { date_part_timestamp_inner(self.0, &*a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: None }.nullable(input.nullable) } @@ -512,7 +512,7 @@ impl EagerUnaryFunc for ExtractTimestampTz { date_part_timestamp_inner(self.0, &*a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: None }.nullable(input.nullable) } @@ -552,7 +552,7 @@ impl EagerUnaryFunc for DatePartTimestamp { date_part_timestamp_inner(self.0, &*a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Float64.nullable(input.nullable) } } @@ -585,7 +585,7 @@ impl EagerUnaryFunc for DatePartTimestampTz { date_part_timestamp_inner(self.0, &*a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Float64.nullable(input.nullable) } } @@ -647,7 +647,7 @@ impl EagerUnaryFunc for DateTruncTimestamp { date_trunc_inner(self.0, &*a)?.try_into().err_into() } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Timestamp { precision: None }.nullable(input.nullable) } @@ -684,7 +684,7 @@ impl EagerUnaryFunc for DateTruncTimestampTz { date_trunc_inner(self.0, &*a)?.try_into().err_into() } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::TimestampTz { precision: None }.nullable(input.nullable) } @@ -777,7 +777,7 @@ impl EagerUnaryFunc for TimezoneTimestamp { timezone_timestamp(self.0, a.to_naive()) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::TimestampTz { precision: None }.nullable(input.nullable) } } @@ -812,7 +812,7 @@ impl EagerUnaryFunc for TimezoneTimestampTz { .err_into() } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Timestamp { precision: None }.nullable(input.nullable) } } @@ -848,7 +848,7 @@ impl EagerUnaryFunc for ToCharTimestamp { self.format.render(&*input) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::String.nullable(input.nullable) } } @@ -884,7 +884,7 @@ impl EagerUnaryFunc for ToCharTimestampTz { self.format.render(&*input) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::String.nullable(input.nullable) } } diff --git a/src/expr/src/scalar/func/impls/uint16.rs b/src/expr/src/scalar/func/impls/uint16.rs index e9c5b87838e38..426eb6bc13582 100644 --- a/src/expr/src/scalar/func/impls/uint16.rs +++ b/src/expr/src/scalar/func/impls/uint16.rs @@ -135,7 +135,7 @@ impl EagerUnaryFunc for CastUint16ToNumeric { Ok(a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: self.0 }.nullable(input.nullable) } diff --git a/src/expr/src/scalar/func/impls/uint32.rs b/src/expr/src/scalar/func/impls/uint32.rs index 43fe7f3e8bd19..3990524bfa181 100644 --- a/src/expr/src/scalar/func/impls/uint32.rs +++ b/src/expr/src/scalar/func/impls/uint32.rs @@ -141,7 +141,7 @@ impl EagerUnaryFunc for CastUint32ToNumeric { Ok(a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: self.0 }.nullable(input.nullable) } diff --git a/src/expr/src/scalar/func/impls/uint64.rs b/src/expr/src/scalar/func/impls/uint64.rs index 09e7ed1e50f12..ad3d128fc70cd 100644 --- a/src/expr/src/scalar/func/impls/uint64.rs +++ b/src/expr/src/scalar/func/impls/uint64.rs @@ -145,7 +145,7 @@ impl EagerUnaryFunc for CastUint64ToNumeric { Ok(a) } - fn output_type(&self, input: SqlColumnType) -> SqlColumnType { + fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType { SqlScalarType::Numeric { max_scale: self.0 }.nullable(input.nullable) } diff --git a/src/expr/src/scalar/func/macros.rs b/src/expr/src/scalar/func/macros.rs index 56806901fea26..413b52c16237a 100644 --- a/src/expr/src/scalar/func/macros.rs +++ b/src/expr/src/scalar/func/macros.rs @@ -53,29 +53,29 @@ mod test { #[mz_ore::test] fn output_types_infallible() { assert_eq!( - Infallible1.output_type(SqlScalarType::Float32.nullable(true)), + Infallible1.output_sql_type(SqlScalarType::Float32.nullable(true)), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Infallible1.output_type(SqlScalarType::Float32.nullable(false)), + Infallible1.output_sql_type(SqlScalarType::Float32.nullable(false)), SqlScalarType::Float32.nullable(false) ); assert_eq!( - Infallible2.output_type(SqlScalarType::Float32.nullable(true)), + Infallible2.output_sql_type(SqlScalarType::Float32.nullable(true)), SqlScalarType::Float32.nullable(false) ); assert_eq!( - Infallible2.output_type(SqlScalarType::Float32.nullable(false)), + Infallible2.output_sql_type(SqlScalarType::Float32.nullable(false)), SqlScalarType::Float32.nullable(false) ); assert_eq!( - Infallible3.output_type(SqlScalarType::Float32.nullable(true)), + Infallible3.output_sql_type(SqlScalarType::Float32.nullable(true)), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Infallible3.output_type(SqlScalarType::Float32.nullable(false)), + Infallible3.output_sql_type(SqlScalarType::Float32.nullable(false)), SqlScalarType::Float32.nullable(true) ); } @@ -110,29 +110,29 @@ mod test { #[mz_ore::test] fn output_types_fallible() { assert_eq!( - Fallible1.output_type(SqlScalarType::Float32.nullable(true)), + Fallible1.output_sql_type(SqlScalarType::Float32.nullable(true)), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Fallible1.output_type(SqlScalarType::Float32.nullable(false)), + Fallible1.output_sql_type(SqlScalarType::Float32.nullable(false)), SqlScalarType::Float32.nullable(false) ); assert_eq!( - Fallible2.output_type(SqlScalarType::Float32.nullable(true)), + Fallible2.output_sql_type(SqlScalarType::Float32.nullable(true)), SqlScalarType::Float32.nullable(false) ); assert_eq!( - Fallible2.output_type(SqlScalarType::Float32.nullable(false)), + Fallible2.output_sql_type(SqlScalarType::Float32.nullable(false)), SqlScalarType::Float32.nullable(false) ); assert_eq!( - Fallible3.output_type(SqlScalarType::Float32.nullable(true)), + Fallible3.output_sql_type(SqlScalarType::Float32.nullable(true)), SqlScalarType::Float32.nullable(true) ); assert_eq!( - Fallible3.output_type(SqlScalarType::Float32.nullable(false)), + Fallible3.output_sql_type(SqlScalarType::Float32.nullable(false)), SqlScalarType::Float32.nullable(true) ); } @@ -166,7 +166,12 @@ macro_rules! derive_unary { } } - pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { + match self { + $(Self::$name(f) => LazyUnaryFunc::output_sql_type(f, input_type),)* + } + } + pub fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType { match self { $(Self::$name(f) => LazyUnaryFunc::output_type(f, input_type),)* } @@ -283,14 +288,20 @@ macro_rules! derive_binary { } } - pub fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType { + pub fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType { match self { $(Self::$name(f) => { - LazyBinaryFunc::output_type(f, input_types) + LazyBinaryFunc::output_sql_type(f, input_types) },)* } } + pub fn output_type(&self, input_types: &[ReprColumnType]) -> ReprColumnType { + match self { + $(Self::$name(f) => LazyBinaryFunc::output_type(f, input_types),)* + } + } + pub fn propagates_nulls(&self) -> bool { match self { $(Self::$name(f) => LazyBinaryFunc::propagates_nulls(f),)* diff --git a/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__fallible1.snap b/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__fallible1.snap index 408c014544ffc..b13869f7feb4b 100644 --- a/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__fallible1.snap +++ b/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__fallible1.snap @@ -26,7 +26,7 @@ impl crate::func::binary::EagerBinaryFunc for Fallible1 { ) -> Self::Output<'a> { fallible1(a, b) } - fn output_type( + fn output_sql_type( &self, input_types: &[mz_repr::SqlColumnType], ) -> mz_repr::SqlColumnType { diff --git a/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__fallible2.snap b/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__fallible2.snap index b3ce37938169e..5e494268506e7 100644 --- a/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__fallible2.snap +++ b/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__fallible2.snap @@ -26,7 +26,7 @@ impl crate::func::binary::EagerBinaryFunc for Fallible2 { ) -> Self::Output<'a> { fallible2(a, b) } - fn output_type( + fn output_sql_type( &self, input_types: &[mz_repr::SqlColumnType], ) -> mz_repr::SqlColumnType { diff --git a/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__fallible3.snap b/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__fallible3.snap index 6081be7da001e..75d76d7c12edf 100644 --- a/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__fallible3.snap +++ b/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__fallible3.snap @@ -26,7 +26,7 @@ impl crate::func::binary::EagerBinaryFunc for Fallible3 { ) -> Self::Output<'a> { fallible3(a, b) } - fn output_type( + fn output_sql_type( &self, input_types: &[mz_repr::SqlColumnType], ) -> mz_repr::SqlColumnType { diff --git a/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__infallible1.snap b/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__infallible1.snap index 35869c950acd8..4ec6d2231049c 100644 --- a/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__infallible1.snap +++ b/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__infallible1.snap @@ -26,7 +26,7 @@ impl crate::func::binary::EagerBinaryFunc for Infallible1 { ) -> Self::Output<'a> { infallible1(a, b) } - fn output_type( + fn output_sql_type( &self, input_types: &[mz_repr::SqlColumnType], ) -> mz_repr::SqlColumnType { diff --git a/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__infallible2.snap b/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__infallible2.snap index d2fd1b23a1687..99b72a6b293b8 100644 --- a/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__infallible2.snap +++ b/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__infallible2.snap @@ -26,7 +26,7 @@ impl crate::func::binary::EagerBinaryFunc for Infallible2 { ) -> Self::Output<'a> { infallible2(a, b) } - fn output_type( + fn output_sql_type( &self, input_types: &[mz_repr::SqlColumnType], ) -> mz_repr::SqlColumnType { diff --git a/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__infallible3.snap b/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__infallible3.snap index b58a1e35e27d1..d5d4fc6aea168 100644 --- a/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__infallible3.snap +++ b/src/expr/src/scalar/func/snapshots/mz_expr__scalar__func__binary__test__infallible3.snap @@ -26,7 +26,7 @@ impl crate::func::binary::EagerBinaryFunc for Infallible3 { ) -> Self::Output<'a> { infallible3(a, b) } - fn output_type( + fn output_sql_type( &self, input_types: &[mz_repr::SqlColumnType], ) -> mz_repr::SqlColumnType { diff --git a/src/expr/src/scalar/func/unary.rs b/src/expr/src/scalar/func/unary.rs index 879bbd5f8a163..69ee975dd238d 100644 --- a/src/expr/src/scalar/func/unary.rs +++ b/src/expr/src/scalar/func/unary.rs @@ -15,7 +15,7 @@ use std::{fmt, str}; -use mz_repr::{Datum, InputDatumType, OutputDatumType, RowArena, SqlColumnType}; +use mz_repr::{Datum, InputDatumType, OutputDatumType, ReprColumnType, RowArena, SqlColumnType}; use crate::scalar::func::impls::*; use crate::{EvalError, MirScalarExpr}; @@ -31,7 +31,11 @@ pub trait LazyUnaryFunc { ) -> Result, EvalError>; /// The output SqlColumnType of this function. - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType; + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType; + + fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType { + ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type))) + } /// Whether this function will produce NULL on NULL input. fn propagates_nulls(&self) -> bool; @@ -105,7 +109,12 @@ pub trait EagerUnaryFunc { fn call<'a>(&self, input: Self::Input<'a>) -> Self::Output<'a>; /// The output SqlColumnType of this function - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType; + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType; + + /// The output of this function as a representation type. + fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType { + ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type))) + } /// Whether this function will produce NULL on NULL input fn propagates_nulls(&self) -> bool { @@ -157,8 +166,8 @@ impl LazyUnaryFunc for T { } } - fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { - self.output_type(input_type) + fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { + self.output_sql_type(input_type) } fn propagates_nulls(&self) -> bool { diff --git a/src/expr/src/scalar/func/unmaterializable.rs b/src/expr/src/scalar/func/unmaterializable.rs index 6b2532525834c..92045b53abe55 100644 --- a/src/expr/src/scalar/func/unmaterializable.rs +++ b/src/expr/src/scalar/func/unmaterializable.rs @@ -19,7 +19,7 @@ use std::fmt; use mz_lowertest::MzReflect; -use mz_repr::{SqlColumnType, SqlScalarType}; +use mz_repr::{ReprColumnType, SqlColumnType, SqlScalarType}; use serde::{Deserialize, Serialize}; #[derive( @@ -58,7 +58,7 @@ pub enum UnmaterializableFunc { } impl UnmaterializableFunc { - pub fn output_type(&self) -> SqlColumnType { + pub fn output_sql_type(&self) -> SqlColumnType { match self { UnmaterializableFunc::CurrentDatabase => SqlScalarType::String.nullable(false), // TODO: The `CurrentSchema` function should return `name`. This is @@ -105,6 +105,13 @@ impl UnmaterializableFunc { .nullable(false), } } + + /// Computes the representation type of this unmaterializable function. + /// + /// This is a wrapper around [`Self::output_sql_type`] that converts the result to a representation type. + pub fn output_type(&self) -> ReprColumnType { + ReprColumnType::from(&self.output_sql_type()) + } } impl fmt::Display for UnmaterializableFunc { diff --git a/src/expr/src/scalar/func/variadic.rs b/src/expr/src/scalar/func/variadic.rs index 053cc3db9941b..69d6c535b9b62 100644 --- a/src/expr/src/scalar/func/variadic.rs +++ b/src/expr/src/scalar/func/variadic.rs @@ -26,6 +26,7 @@ use mz_lowertest::MzReflect; use mz_ore::cast::{CastFrom, ReinterpretCast}; use mz_ore::soft_assert_or_log; use mz_pgtz::timezone::TimezoneSpec; +use mz_repr::ReprColumnType; use mz_repr::adt::array::{ArrayDimension, ArrayDimensions, InvalidArrayError}; use mz_repr::adt::mz_acl_item::{AclItem, AclMode, MzAclItem}; use mz_repr::adt::range::{InvalidRangeError, Range, RangeBound, parse_range_bound_flags}; @@ -2412,7 +2413,7 @@ impl VariadicFunc { } } - pub fn output_type(&self, input_types: Vec) -> SqlColumnType { + pub fn output_sql_type(&self, input_types: Vec) -> SqlColumnType { let in_nullable = input_types.iter().any(|t| t.nullable); match self { Self::And(s) => s.output_type(&input_types), @@ -2523,6 +2524,15 @@ impl VariadicFunc { } } + /// Computes the representation type of this variadic function. + /// + /// This is a wrapper around [`Self::output_sql_type`] that converts the result to a representation type. + pub fn output_type(&self, input_types: Vec) -> ReprColumnType { + ReprColumnType::from( + &self.output_sql_type(input_types.iter().map(SqlColumnType::from_repr).collect()), + ) + } + /// Whether the function output is NULL if any of its inputs are NULL. /// /// NB: if any input is NULL the output will be returned as NULL without diff --git a/src/sql/src/plan/hir.rs b/src/sql/src/plan/hir.rs index 681cd9c849987..37a73f4e2c9c4 100644 --- a/src/sql/src/plan/hir.rs +++ b/src/sql/src/plan/hir.rs @@ -552,7 +552,7 @@ impl ScalarWindowExpr { _inner: &SqlRelationType, _params: &BTreeMap, ) -> SqlColumnType { - self.func.output_type() + self.func.output_sql_type() } pub fn into_expr(self) -> mz_expr::AggregateFunc { @@ -599,7 +599,7 @@ impl Display for ScalarWindowFunc { } impl ScalarWindowFunc { - pub fn output_type(&self) -> SqlColumnType { + pub fn output_sql_type(&self) -> SqlColumnType { match self { ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false), ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false), @@ -668,7 +668,8 @@ impl ValueWindowExpr { inner: &SqlRelationType, params: &BTreeMap, ) -> SqlColumnType { - self.func.output_type(self.args.typ(outers, inner, params)) + self.func + .output_sql_type(self.args.typ(outers, inner, params)) } /// Converts into `mz_expr::AggregateFunc`. @@ -734,7 +735,7 @@ pub enum ValueWindowFunc { } impl ValueWindowFunc { - pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { match self { ValueWindowFunc::Lag | ValueWindowFunc::Lead => { // The input is a (value, offset, default) record, so extract the type of the first arg @@ -751,7 +752,7 @@ impl ValueWindowFunc { fields: funcs .iter() .zip_eq(input_types) - .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone()))) + .map(|(f, t)| (ColumnName::from(""), f.output_sql_type(t.clone()))) .collect(), custom_id: None, } @@ -841,7 +842,7 @@ impl AggregateWindowExpr { ) -> SqlColumnType { self.aggregate_expr .func - .output_type(self.aggregate_expr.expr.typ(outers, inner, params)) + .output_sql_type(self.aggregate_expr.expr.typ(outers, inner, params)) } pub fn into_expr(self) -> (Box, mz_expr::AggregateFunc) { @@ -1499,7 +1500,7 @@ impl AggregateFunc { /// The output column type also contains nullability information, which /// is (without further information) true for aggregations that are not /// counts. - pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType { + pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType { let scalar_type = match self { AggregateFunc::Count => SqlScalarType::Int64, AggregateFunc::Any => SqlScalarType::Bool, @@ -1570,7 +1571,7 @@ impl AggregateFunc { fields: funcs .iter() .zip_eq(input_types) - .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone()))) + .map(|(f, t)| (ColumnName::from(""), f.output_sql_type(t.clone()))) .collect(), custom_id: None, } @@ -1631,7 +1632,7 @@ impl HirRelationExpr { } typ } - HirRelationExpr::CallTable { func, exprs: _ } => func.output_type(), + HirRelationExpr::CallTable { func, exprs: _ } => func.output_sql_type(), HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => { input.typ(outers, params) } @@ -3938,18 +3939,18 @@ impl AbstractExpr for HirScalarExpr { } HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true), HirScalarExpr::Literal(_, typ, _name) => typ.clone(), - HirScalarExpr::CallUnmaterializable(func, _name) => func.output_type(), + HirScalarExpr::CallUnmaterializable(func, _name) => func.output_sql_type(), HirScalarExpr::CallUnary { expr, func, name: _, - } => func.output_type(expr.typ(outers, inner, params)), + } => func.output_sql_type(expr.typ(outers, inner, params)), HirScalarExpr::CallBinary { expr1, expr2, func, name: _, - } => func.output_type(&[ + } => func.output_sql_type(&[ expr1.typ(outers, inner, params), expr2.typ(outers, inner, params), ]), @@ -3957,7 +3958,7 @@ impl AbstractExpr for HirScalarExpr { exprs, func, name: _, - } => func.output_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()), + } => func.output_sql_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()), HirScalarExpr::If { cond: _, then, @@ -3989,7 +3990,8 @@ impl AggregateExpr { inner: &SqlRelationType, params: &BTreeMap, ) -> SqlColumnType { - self.func.output_type(self.expr.typ(outers, inner, params)) + self.func + .output_sql_type(self.expr.typ(outers, inner, params)) } /// Returns whether the expression is COUNT(*) or not. Note that diff --git a/src/sql/src/plan/lowering.rs b/src/sql/src/plan/lowering.rs index e255d773cf58e..95043e860dfb5 100644 --- a/src/sql/src/plan/lowering.rs +++ b/src/sql/src/plan/lowering.rs @@ -1534,7 +1534,7 @@ impl HirScalarExpr { mz_expr::TableFunc::UnnestList { el_typ: aggregate .func - .output_type(agg_input_type) + .output_sql_type(agg_input_type) .scalar_type .unwrap_list_element_type() .clone(), diff --git a/src/transform/src/canonicalization/flat_map_elimination.rs b/src/transform/src/canonicalization/flat_map_elimination.rs index 3eb81506fcf56..483303d4c26a8 100644 --- a/src/transform/src/canonicalization/flat_map_elimination.rs +++ b/src/transform/src/canonicalization/flat_map_elimination.rs @@ -18,7 +18,7 @@ use itertools::Itertools; use mz_expr::visit::Visit; use mz_expr::{MirRelationExpr, MirScalarExpr, TableFunc}; -use mz_repr::{Diff, ReprScalarType, RowArena}; +use mz_repr::{Diff, RowArena}; use crate::TransformCtx; @@ -90,9 +90,7 @@ impl FlatMapElimination { let map_exprs = first_row .into_iter() .zip_eq(types) - .map(|(d, typ)| { - MirScalarExpr::literal_ok(d, ReprScalarType::from(&typ.scalar_type)) - }) + .map(|(d, typ)| MirScalarExpr::literal_ok(d, typ.scalar_type)) .collect(); *relation = input.take_dangerous().map(map_exprs); } diff --git a/src/transform/src/column_knowledge.rs b/src/transform/src/column_knowledge.rs index 67a09b3bff4f1..64525a09d2782 100644 --- a/src/transform/src/column_knowledge.rs +++ b/src/transform/src/column_knowledge.rs @@ -20,7 +20,7 @@ use mz_expr::{ use mz_ore::cast::CastFrom; use mz_ore::stack::{CheckedRecursion, RecursionGuard}; use mz_ore::{assert_none, soft_panic_or_log}; -use mz_repr::{Datum, ReprColumnType, ReprScalarType, Row, SqlColumnType}; +use mz_repr::{Datum, ReprColumnType, ReprScalarType, Row}; use crate::{TransformCtx, TransformError}; @@ -519,13 +519,6 @@ impl From<&MirScalarExpr> for DatumKnowledge { } } -impl From<(Datum<'_>, &SqlColumnType)> for DatumKnowledge { - fn from((d, t): (Datum<'_>, &SqlColumnType)) -> Self { - let value = Ok(Row::pack_slice(std::slice::from_ref(&d))); - let typ = ReprScalarType::from(&t.scalar_type); - Self::Lit { value, typ } - } -} impl From<(Datum<'_>, &ReprColumnType)> for DatumKnowledge { fn from((d, t): (Datum<'_>, &ReprColumnType)) -> Self { let value = Ok(Row::pack_slice(std::slice::from_ref(&d))); @@ -533,12 +526,6 @@ impl From<(Datum<'_>, &ReprColumnType)> for DatumKnowledge { Self::Lit { value, typ } } } -impl From<&SqlColumnType> for DatumKnowledge { - fn from(typ: &SqlColumnType) -> Self { - let nullable = typ.nullable; - Self::Any { nullable } - } -} impl From<&ReprColumnType> for DatumKnowledge { fn from(typ: &ReprColumnType) -> Self { let nullable = typ.nullable; diff --git a/src/transform/src/literal_lifting.rs b/src/transform/src/literal_lifting.rs index 4f76ee89af31e..f4aa0bd807b9e 100644 --- a/src/transform/src/literal_lifting.rs +++ b/src/transform/src/literal_lifting.rs @@ -27,7 +27,7 @@ use mz_expr::JoinImplementation::IndexedFilter; use mz_expr::visit::Visit; use mz_expr::{Id, JoinInputMapper, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT}; use mz_ore::stack::{CheckedRecursion, RecursionGuard}; -use mz_repr::{ReprScalarType, Row, RowPacker}; +use mz_repr::{Row, RowPacker}; use crate::TransformCtx; @@ -602,9 +602,7 @@ impl LiteralLifting { eval, // This type information should be available in the `a.expr` literal, // but extracting it with pattern matching seems awkward. - ReprScalarType::from( - &aggr.func.output_type(aggr.expr.sql_typ(&[])).scalar_type, - ), + aggr.func.output_type(aggr.expr.typ(&[])).scalar_type, ) }; diff --git a/src/transform/src/typecheck.rs b/src/transform/src/typecheck.rs index caa39a38b8f26..142c2732e972a 100644 --- a/src/transform/src/typecheck.rs +++ b/src/transform/src/typecheck.rs @@ -25,7 +25,6 @@ use mz_repr::adt::range::Range; use mz_repr::explain::{DummyHumanizer, ExprHumanizer}; use mz_repr::{ ColumnName, Datum, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType, - SqlColumnType, }; /// Typechecking contexts as shared by various typechecking passes. @@ -996,11 +995,7 @@ impl Typecheck { // TODO(mgree) check t_exprs agrees with `func`'s input type let t_out: Vec = func - .output_type() - .column_types - .iter() - .map(ReprColumnType::from) - .collect_vec(); + .output_type().column_types; // FlatMap extends the existing columns t_in.extend(t_out); @@ -1508,31 +1503,23 @@ impl Typecheck { Ok(typ) } - CallUnmaterializable(func) => Ok(ReprColumnType::from(&func.output_type())), + CallUnmaterializable(func) => Ok(func.output_type()), CallUnary { expr, func } => { let typ_in = tc.typecheck_scalar(expr, source, column_types)?; - let typ_out = func.output_type(SqlColumnType::from_repr(&typ_in)); - Ok(ReprColumnType::from(&typ_out)) + let typ_out = func.output_type(typ_in); + Ok(typ_out) } CallBinary { expr1, expr2, func } => { let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?; let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?; - let typ_out = func.output_type(&[ - SqlColumnType::from_repr(&typ_in1), - SqlColumnType::from_repr(&typ_in2), - ]); - Ok(ReprColumnType::from(&typ_out)) + let typ_out = func.output_type(&[typ_in1, typ_in2]); + Ok(typ_out) } - CallVariadic { exprs, func } => Ok(ReprColumnType::from( - &func.output_type( - exprs - .iter() - .map(|e| { - tc.typecheck_scalar(e, source, column_types) - .map(|typ| SqlColumnType::from_repr(&typ)) - }) - .collect::, TypeError>>()?, - ), + CallVariadic { exprs, func } => Ok(func.output_type( + exprs + .iter() + .map(|e| tc.typecheck_scalar(e, source, column_types)) + .collect::, TypeError>>()?, )), If { cond, then, els } => { let cond_type = tc.typecheck_scalar(cond, source, column_types)?; @@ -1589,9 +1576,7 @@ impl Typecheck { // TODO check that t_in is actually acceptable for `func` - Ok(ReprColumnType::from( - &expr.func.output_type(SqlColumnType::from_repr(&t_in)), - )) + Ok(expr.func.output_type(t_in)) }) } } @@ -2069,7 +2054,7 @@ impl<'a> TypeError<'a> { #[cfg(test)] mod tests { use mz_ore::{assert_err, assert_ok}; - use mz_repr::{arb_datum, arb_datum_for_column}; + use mz_repr::{SqlColumnType, arb_datum, arb_datum_for_column}; use proptest::prelude::*; use super::*;