Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions src/adapter/src/optimize/dataflows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use mz_repr::adt::array::ArrayDimension;
use mz_repr::explain::trace_plan;
use mz_repr::optimize::OptimizerFeatures;
use mz_repr::role_id::RoleId;
use mz_repr::{Datum, GlobalId, ReprRelationType, ReprScalarType, Row};
use mz_repr::{Datum, GlobalId, ReprRelationType, Row};
use mz_sql::catalog::CatalogRole;
use mz_sql::rbac;
use mz_sql::session::metadata::SessionMetadata;
Expand Down Expand Up @@ -253,10 +253,7 @@ impl ExprPrep for ExprPrepWebhookValidation {
e
{
let now: Datum = now.try_into()?;
let const_expr = MirScalarExpr::literal_ok(
now,
ReprScalarType::from(&f.output_type().scalar_type),
);
let const_expr = MirScalarExpr::literal_ok(now, f.output_type().scalar_type);
*e = const_expr;
}
Ok(())
Expand Down Expand Up @@ -591,7 +588,7 @@ fn eval_unmaterializable_func(
.expect("known to be a valid array");
Ok(MirScalarExpr::literal_from_single_element_row(
row,
ReprScalarType::from(&f.output_type().scalar_type),
f.output_type().scalar_type,
))
};
let pack_dict = |mut datums: Vec<(String, String)>| {
Expand All @@ -604,13 +601,13 @@ fn eval_unmaterializable_func(
);
Ok(MirScalarExpr::literal_from_single_element_row(
row,
ReprScalarType::from(&f.output_type().scalar_type),
f.output_type().scalar_type,
))
};
let pack = |datum| {
Ok(MirScalarExpr::literal_ok(
datum,
ReprScalarType::from(&f.output_type().scalar_type),
f.output_type().scalar_type,
))
};

Expand Down Expand Up @@ -708,7 +705,7 @@ fn eval_unmaterializable_func(
});
Ok(MirScalarExpr::literal_from_single_element_row(
row,
ReprScalarType::from(&f.output_type().scalar_type),
f.output_type().scalar_type,
))
}
UnmaterializableFunc::MzSessionId => pack(Datum::from(state.config().session_id)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl crate::func::binary::EagerBinaryFunc for AddInt16 {
) -> Self::Output<'a> {
add_int16(a, b)
}
fn output_type(
fn output_sql_type(
&self,
input_types: &[mz_repr::SqlColumnType],
) -> mz_repr::SqlColumnType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl crate::func::binary::EagerBinaryFunc for UnaryFn {
) -> Self::Output<'a> {
unary_fn(a, b, temp_storage)
}
fn output_type(
fn output_sql_type(
&self,
input_types: &[mz_repr::SqlColumnType],
) -> mz_repr::SqlColumnType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl crate::func::binary::EagerBinaryFunc for ComplexOutputTypeFn {
) -> Self::Output<'a> {
complex_output_type_fn(a, b)
}
fn output_type(
fn output_sql_type(
&self,
input_types: &[mz_repr::SqlColumnType],
) -> mz_repr::SqlColumnType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ impl crate::func::EagerUnaryFunc for UnaryFn {
fn call<'a>(&self, a: Self::Input<'a>) -> Self::Output<'a> {
unary_fn(a)
}
fn output_type(&self, input_type: mz_repr::SqlColumnType) -> mz_repr::SqlColumnType {
fn output_sql_type(
&self,
input_type: mz_repr::SqlColumnType,
) -> mz_repr::SqlColumnType {
use mz_repr::AsColumnType;
let output = Self::Output::as_column_type();
let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ impl crate::func::EagerUnaryFunc for UnaryFn {
fn call<'a>(&self, a: Self::Input<'a>) -> Self::Output<'a> {
unary_fn(a)
}
fn output_type(&self, input_type: mz_repr::SqlColumnType) -> mz_repr::SqlColumnType {
fn output_sql_type(
&self,
input_type: mz_repr::SqlColumnType,
) -> mz_repr::SqlColumnType {
use mz_repr::AsColumnType;
let output = Self::Output::as_column_type();
let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
Expand Down
7 changes: 5 additions & 2 deletions src/expr-derive-impl/src/sqlfunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,10 @@ fn unary_func(func: &syn::ItemFn, modifiers: Modifiers) -> darling::Result<Token
#fn_name(a)
}

fn output_type(&self, input_type: mz_repr::SqlColumnType) -> mz_repr::SqlColumnType {
fn output_sql_type(
&self,
input_type: mz_repr::SqlColumnType
) -> mz_repr::SqlColumnType {
use mz_repr::AsColumnType;
let output = #output_type;
let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
Expand Down Expand Up @@ -510,7 +513,7 @@ fn binary_func(
#fn_name(a, b #arena)
}

fn output_type(
fn output_sql_type(
&self,
input_types: &[mz_repr::SqlColumnType],
) -> mz_repr::SqlColumnType {
Expand Down
18 changes: 9 additions & 9 deletions src/expr/src/interpret.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ impl<'a> ColumnSpecs<'a> {
let range = self
.unmaterializables
.entry(func.clone())
.or_insert_with(|| ResultSpec::has_type(&func.output_type(), true));
.or_insert_with(|| ResultSpec::has_type(&func.output_sql_type(), true));
*range = range.clone().intersect(update);
}

Expand Down Expand Up @@ -847,12 +847,12 @@ impl<'a> Interpreter for ColumnSpecs<'a> {
}

fn unmaterializable(&self, func: &UnmaterializableFunc) -> Self::Summary {
let col_type = func.output_type();
let col_type = func.output_sql_type();
let range = self
.unmaterializables
.get(func)
.cloned()
.unwrap_or_else(|| ResultSpec::has_type(&func.output_type(), true));
.unwrap_or_else(|| ResultSpec::has_type(&func.output_sql_type(), true));
ColumnSpec { col_type, range }
}

Expand All @@ -872,7 +872,7 @@ impl<'a> Interpreter for ColumnSpecs<'a> {
})
};

let col_type = func.output_type(summary.col_type);
let col_type = func.output_sql_type(summary.col_type);

let range = mapped_spec.intersect(ResultSpec::has_type(&col_type, fallible));
ColumnSpec { col_type, range }
Expand Down Expand Up @@ -904,7 +904,7 @@ impl<'a> Interpreter for ColumnSpecs<'a> {
})
};

let col_type = func.output_type(&[left.col_type, right.col_type]);
let col_type = func.output_sql_type(&[left.col_type, right.col_type]);

let range = mapped_spec.intersect(ResultSpec::has_type(&col_type, fallible));
ColumnSpec { col_type, range }
Expand Down Expand Up @@ -954,7 +954,7 @@ impl<'a> Interpreter for ColumnSpecs<'a> {
};

let col_types = args.into_iter().map(|spec| spec.col_type).collect();
let col_type = func.output_type(col_types);
let col_type = func.output_sql_type(col_types);

let range = mapped_spec.intersect(ResultSpec::has_type(&col_type, fallible));

Expand Down Expand Up @@ -1318,7 +1318,7 @@ mod tests {
if !unary_typecheck(&func, &type_in) {
return None;
}
let type_out = func.output_type(type_in);
let type_out = func.output_sql_type(type_in);
let expr_out = MirScalarExpr::CallUnary {
func,
expr: Box::new(expr_in),
Expand All @@ -1337,7 +1337,7 @@ mod tests {
if !binary_typecheck(&func, &type_left, &type_right) {
return None;
}
let type_out = func.output_type(&[type_left, type_right]);
let type_out = func.output_sql_type(&[type_left, type_right]);
let expr_out = MirScalarExpr::CallBinary {
func,
expr1: Box::new(expr_left),
Expand All @@ -1356,7 +1356,7 @@ mod tests {
if !variadic_typecheck(&func, &type_in) {
return None;
}
let type_out = func.output_type(type_in);
let type_out = func.output_sql_type(type_in);
let expr_out = MirScalarExpr::CallVariadic {
func,
exprs: exprs_in,
Expand Down
24 changes: 10 additions & 14 deletions src/expr/src/relation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ impl MirRelationExpr {
FlatMap { func, .. } => {
let mut result = input_types.next().unwrap().clone();
result.extend(
func.output_type()
func.output_sql_type()
.column_types
.iter()
.map(ReprColumnType::from),
Expand Down Expand Up @@ -1256,13 +1256,13 @@ impl MirRelationExpr {
/// # Example
///
/// ```rust
/// use mz_repr::{Datum, SqlColumnType, SqlRelationType, SqlScalarType};
/// use mz_repr::{Datum, SqlColumnType, ReprRelationType, ReprScalarType};
/// use mz_expr::MirRelationExpr;
///
/// // A common schema for each input.
/// let schema = SqlRelationType::new(vec![
/// SqlScalarType::Int32.nullable(false),
/// SqlScalarType::Int32.nullable(false),
/// let schema = ReprRelationType::new(vec![
/// ReprScalarType::Int32.nullable(false),
/// ReprScalarType::Int32.nullable(false),
/// ]);
///
/// // the specific data are not important here.
Expand Down Expand Up @@ -1519,16 +1519,16 @@ impl MirRelationExpr {
/// the correct type.
pub fn take_safely(&mut self, typ: Option<ReprRelationType>) -> MirRelationExpr {
if let Some(typ) = &typ {
let self_typ = self.sql_typ();
let self_typ = self.typ();
soft_assert_no_log!(
self_typ
.column_types
.iter()
.zip_eq(typ.column_types.iter())
.all(|(t1, t2)| ReprScalarType::from(&t1.scalar_type) == t2.scalar_type)
.all(|(t1, t2)| t1.scalar_type == t2.scalar_type)
);
}
let mut typ = typ.unwrap_or_else(|| ReprRelationType::from(&self.sql_typ()));
let mut typ = typ.unwrap_or_else(|| self.typ());
typ.keys = vec![vec![]];
for ct in typ.column_types.iter_mut() {
ct.nullable = false;
Expand Down Expand Up @@ -2478,16 +2478,12 @@ pub struct AggregateExpr {
impl AggregateExpr {
/// Computes the type of this `AggregateExpr`.
pub fn sql_typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
self.func.output_type(self.expr.sql_typ(column_types))
self.func.output_sql_type(self.expr.sql_typ(column_types))
}

/// Computes the type of this `AggregateExpr`.
pub fn typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
ReprColumnType::from(
&self
.func
.output_type(SqlColumnType::from_repr(&self.expr.typ(column_types))),
)
self.func.output_type(self.expr.typ(column_types))
}

/// Returns whether the expression has a constant result.
Expand Down
28 changes: 21 additions & 7 deletions src/expr/src/relation/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use mz_repr::adt::numeric::{self, Numeric, NumericMaxScale};
use mz_repr::adt::regex::{Regex as ReprRegex, RegexCompilationError};
use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampLike};
use mz_repr::{
ColumnName, Datum, Diff, Row, RowArena, RowPacker, SharedRow, SqlColumnType, SqlRelationType,
SqlScalarType, datum_size,
ColumnName, Datum, Diff, ReprColumnType, ReprRelationType, Row, RowArena, RowPacker, SharedRow,
SqlColumnType, SqlRelationType, SqlScalarType, datum_size,
};
use num::{CheckedAdd, Integer, Signed, ToPrimitive};
use ordered_float::OrderedFloat;
Expand Down Expand Up @@ -2324,7 +2324,7 @@ impl AggregateFunc {
/// The output column type also contains nullability information, which
/// is (without further information) true for aggregations that are not
/// counts.
pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
let scalar_type = match self {
AggregateFunc::Count => SqlScalarType::Int64,
AggregateFunc::Any => SqlScalarType::Bool,
Expand Down Expand Up @@ -2437,7 +2437,7 @@ impl AggregateFunc {
let arg_type = fields[0].unwrap_record_element_type()[1]
.clone()
.nullable(true);
let wrapped_aggr_out_type = wrapped_aggregate.output_type(arg_type);
let wrapped_aggr_out_type = wrapped_aggregate.output_sql_type(arg_type);

SqlScalarType::List {
element_type: Box::new(SqlScalarType::Record {
Expand Down Expand Up @@ -2465,7 +2465,7 @@ impl AggregateFunc {
|(arg_type, wrapped_agg)| {
(
ColumnName::from(wrapped_agg.name()),
wrapped_agg.output_type((**arg_type).clone().nullable(true)),
wrapped_agg.output_sql_type((**arg_type).clone().nullable(true)),
)
}).collect_vec();

Expand Down Expand Up @@ -2602,6 +2602,13 @@ impl AggregateFunc {
scalar_type.nullable(nullable)
}

/// Computes the representation type of this aggregate function.
///
/// This is a wrapper around [`Self::output_sql_type`] that converts the result to a representation type.
pub fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType {
ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type)))
}

/// Compute output type for ROW_NUMBER, RANK, DENSE_RANK
fn output_type_ranking_window_funcs(
input_type: &SqlColumnType,
Expand Down Expand Up @@ -3555,7 +3562,7 @@ impl TableFunc {
}
}

pub fn output_type(&self) -> SqlRelationType {
pub fn output_sql_type(&self) -> SqlRelationType {
let (column_types, keys) = match self {
TableFunc::AclExplode => {
let column_types = vec![
Expand Down Expand Up @@ -3694,7 +3701,7 @@ impl TableFunc {
(column_types, keys)
}
TableFunc::WithOrdinality(WithOrdinality { inner }) => {
let mut typ = inner.output_type();
let mut typ = inner.output_sql_type();
// Add the ordinality column.
typ.column_types.push(SqlScalarType::Int64.nullable(false));
// The ordinality column is always a key.
Expand All @@ -3712,6 +3719,13 @@ impl TableFunc {
}
}

/// Computes the representation type of this table function.
///
/// This is a wrapper around [`Self::output_sql_type`] that converts the result to a representation type.
pub fn output_type(&self) -> ReprRelationType {
ReprRelationType::from(&self.output_sql_type())
}

pub fn output_arity(&self) -> usize {
match self {
TableFunc::AclExplode => 4,
Expand Down
8 changes: 4 additions & 4 deletions src/expr/src/relation/join_input_mapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,13 @@ impl JoinInputMapper {
/// # Examples
///
/// ```
/// use mz_repr::{Datum, SqlColumnType, SqlRelationType, SqlScalarType};
/// use mz_repr::{Datum, ReprColumnType, ReprRelationType, ReprScalarType};
/// use mz_expr::{JoinInputMapper, MirRelationExpr, MirScalarExpr};
///
/// // A two-column schema common to each of the three inputs
/// let schema = SqlRelationType::new(vec![
/// SqlScalarType::Int32.nullable(false),
/// SqlScalarType::Int32.nullable(false),
/// let schema = ReprRelationType::new(vec![
/// ReprScalarType::Int32.nullable(false),
/// ReprScalarType::Int32.nullable(false),
/// ]);
///
/// // the specific data are not important here.
Expand Down
Loading