diff --git a/python/benchmarks/bench_eval_type.py b/python/benchmarks/bench_eval_type.py
index 6dd352cf5d7e2..1674a7e660ebf 100644
--- a/python/benchmarks/bench_eval_type.py
+++ b/python/benchmarks/bench_eval_type.py
@@ -558,6 +558,103 @@ class CogroupedMapArrowUDFPeakmemBench(_CogroupedMapArrowBenchMixin, _PeakmemBen
     pass
 
 
+# -- SQL_COGROUPED_MAP_PANDAS_UDF ----------------------------------------------
+# UDF receives two ``pandas.DataFrame`` (left, right) per co-group, returns
+# ``pandas.DataFrame``. Optional 3-arg variant ``(key, left, right)``.
+
+
+class _CogroupedMapPandasBenchMixin:
+    """Provides _write_scenario for SQL_COGROUPED_MAP_PANDAS_UDF."""
+
+    def _cogrouped_map_pandas_identity(left, right):
+        """Identity cogroup UDF: returns the left DataFrame as-is."""
+        return left
+
+    def _cogrouped_map_pandas_concat(left, right):
+        """Concat cogroup UDF: vertically concatenates the left and right DataFrames."""
+        import pandas as pd
+
+        return pd.concat([left, right], ignore_index=True)
+
+    def _cogrouped_map_pandas_left_semi(left, right):
+        """Left-semi cogroup UDF: keeps left rows whose key exists in right."""
+        key_col = left.columns[0]
+        return left[left[key_col].isin(right[key_col])]
+
+    def _cogrouped_map_pandas_key_identity(key, left, right):
+        """3-arg cogroup UDF that consumes ``key`` and returns the left DataFrame."""
+        return left
+
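+    # Illustrative only (hypothetical values, not benchmark data): with
+    # left = pd.DataFrame({"k": [1, 2], "v": [10, 20]}) and
+    # right = pd.DataFrame({"k": [2, 3]}), the left-semi UDF above keeps only
+    # the row (k=2, v=20). This eval type corresponds to the user-facing
+    # ``df1.groupby("k").cogroup(df2.groupby("k")).applyInPandas(func, schema)``
+    # API (df1/df2/func/schema are placeholders).
+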
+    # Scaled down vs SQL_COGROUPED_MAP_ARROW_UDF: the pandas conversion adds
+    # per-group Arrow<->pandas overhead on both the left and right sides.
+    _scenario_configs = {
+        "few_groups_sm": (50, 1_000, 1, 4),
+        "few_groups_lg": (50, 10_000, 1, 4),
+        "many_groups_sm": (500, 200, 1, 4),
+        "many_groups_lg": (200, 2_000, 1, 4),
+        "wide_values": (100, 1_000, 1, 20),
+        "multi_key": (100, 1_000, 3, 5),
+    }
+
+    @staticmethod
+    def _build_scenario(name):
+        """Build a cogroup scenario: two DataFrames with the same grouping structure.
+
+        As in the cogrouped arrow benchmark, batches have flat columns:
+        [key_col_0, ..., key_col_k, val_col_0, ..., val_col_v].
+        """
+        np.random.seed(42)
+        num_groups, rows_per_group, num_key_cols, num_value_cols = (
+            _CogroupedMapPandasBenchMixin._scenario_configs[name]
+        )
+        n_cols = num_key_cols + num_value_cols
+        type_pool = MockDataFactory.MIXED_TYPES[:n_cols]
+        while len(type_pool) < n_cols:
+            type_pool = type_pool + MockDataFactory.MIXED_TYPES[: n_cols - len(type_pool)]
+
+        cogroups, schema = MockDataFactory.make_cogrouped_batches(
+            num_groups=num_groups,
+            num_rows=rows_per_group,
+            num_cols=n_cols,
+            spark_type_pool=type_pool,
+            batch_size=rows_per_group,
+        )
+        return_type = StructType(schema.fields[num_key_cols:])
+        return (cogroups, return_type, num_key_cols, num_value_cols)
+
+    # Each UDF entry: (func, n_args). n_args=2 -> func(left, right);
+    # n_args=3 -> func(key, left, right).
+    _udfs = {
+        "identity_udf": (_cogrouped_map_pandas_identity, 2),
+        "concat_udf": (_cogrouped_map_pandas_concat, 2),
+        "left_semi_udf": (_cogrouped_map_pandas_left_semi, 2),
+        "key_identity_udf": (_cogrouped_map_pandas_key_identity, 3),
+    }
+    params = [list(_scenario_configs), list(_udfs)]
+    param_names = ["scenario", "udf"]
+
+    def _write_scenario(self, scenario, udf_name, buf):
+        groups, return_type, num_key_cols, num_value_cols = self._build_scenario(scenario)
+        udf_func, _ = self._udfs[udf_name]
+        left_offsets = MockUDFFactory.make_grouped_arg_offsets(num_key_cols, num_value_cols)
+        right_offsets = MockUDFFactory.make_grouped_arg_offsets(num_key_cols, num_value_cols)
+        arg_offsets = left_offsets + right_offsets
+        MockProtocolWriter.write_worker_input(
+            PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+            lambda b: MockProtocolWriter.write_udf_payload(udf_func, return_type, arg_offsets, b),
+            lambda b: MockProtocolWriter.write_grouped_data_payload(groups, buf=b),
+            buf,
+        )
+
+
+class CogroupedMapPandasUDFTimeBench(_CogroupedMapPandasBenchMixin, _TimeBenchBase):
+    pass
+
+
+class CogroupedMapPandasUDFPeakmemBench(_CogroupedMapPandasBenchMixin, _PeakmemBenchBase):
+    pass
+
+
 # -- SQL_GROUPED_AGG_ARROW_UDF ------------------------------------------------
 # UDF receives ``pa.Array`` columns per group, returns scalar.
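+# (Illustrative only, hypothetical helper: a minimal grouped-agg arrow UDF has
+# the shape ``def arrow_mean(v): return pc.mean(v).as_py()``, assuming
+# ``import pyarrow.compute as pc``; it yields one scalar result per group.)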