Skip to content

Commit b4ee442

Browse files
committed
fix: udtf single-row expectation
1 parent e74c0c1 commit b4ee442

File tree

7 files changed

+53
-36
lines changed

7 files changed

+53
-36
lines changed

src/query/expression/src/utils/udf_client.rs

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -302,15 +302,14 @@ impl UDFFlightClient {
302302
&mut self,
303303
name: &str,
304304
func_name: &str,
305-
num_rows: usize,
305+
num_rows: Option<usize>,
306306
args: Vec<BlockEntry>,
307307
return_type: &DataType,
308308
) -> Result<DataBlock> {
309309
let instant = Instant::now();
310310

311311
Profile::record_usize_profile(ProfileStatisticsName::ExternalServerRequestCount, 1);
312312
record_running_requests_external_start(name, 1);
313-
record_request_external_batch_rows(func_name, num_rows);
314313

315314
let args = args
316315
.into_iter()
@@ -333,7 +332,9 @@ impl UDFFlightClient {
333332
.collect::<Vec<_>>();
334333
let data_schema = DataSchema::new(fields);
335334

336-
let input_batch = DataBlock::new(args, num_rows)
335+
// at least 1 for `UDFFlightClient::batch_rows`
336+
let input_num_rows = args.first().map(|entry| entry.len()).unwrap_or(1);
337+
let input_batch = DataBlock::new(args, input_num_rows)
337338
.to_record_batch_with_dataschema(&data_schema)
338339
.map_err(|err| ErrorCode::from_string(format!("{err}")))?;
339340

@@ -359,13 +360,17 @@ impl UDFFlightClient {
359360
));
360361
}
361362

362-
if result_block.num_rows() != num_rows {
363-
return Err(ErrorCode::UDFDataError(format!(
364-
"UDF server should return {} rows, but it returned {} rows",
365-
num_rows,
366-
result_block.num_rows()
367-
)));
363+
if let Some(expected_rows) = num_rows {
364+
if result_block.num_rows() != expected_rows {
365+
return Err(ErrorCode::UDFDataError(format!(
366+
"UDF server should return {} rows, but it returned {} rows",
367+
expected_rows,
368+
result_block.num_rows()
369+
)));
370+
}
368371
}
372+
record_request_external_batch_rows(func_name, result_block.num_rows());
373+
369374
if return_type.remove_nullable().is_tuple() && result_fields.len() > 1 {
370375
if let DataType::Tuple(tys) = return_type.remove_nullable() {
371376
if tys

src/query/service/src/pipelines/builders/builder_udtf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ impl UdtfServerSource {
123123

124124
debug_assert!(func.return_ty.is_tuple());
125125
let result = client
126-
.do_exchange(&func.name, &func.func_name, 1, args, &func.return_ty)
126+
.do_exchange(&func.name, &func.func_name, None, args, &func.return_ty)
127127
.await?;
128128

129129
drop(permit);

src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ impl TransformUdfServer {
140140
.do_exchange(
141141
&func.name,
142142
&func.func_name,
143-
num_rows,
143+
Some(num_rows),
144144
block_entries,
145145
&func.data_type,
146146
)

src/query/service/src/table_functions/others/udf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ impl Table for UdfEchoTable {
163163
let return_type = DataType::Nullable(Box::new(DataType::String));
164164

165165
let result = client
166-
.do_exchange(name, name, num_rows, block_entries, &return_type)
166+
.do_exchange(name, name, Some(num_rows), block_entries, &return_type)
167167
.await?;
168168

169169
let scalar = unsafe { result.get_by_offset(0).index_unchecked(0) };

src/query/sql/src/planner/semantic/type_check.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5347,7 +5347,7 @@ impl<'a> TypeChecker<'a> {
53475347
.do_exchange(
53485348
name,
53495349
&udf_definition.handler,
5350-
num_rows,
5350+
Some(num_rows),
53515351
block_entries,
53525352
&udf_definition.return_type,
53535353
)

tests/sqllogictests/suites/udf_server/udf_server_test.test

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,8 +717,10 @@ query I
717717
SELECT * from stage_summary_udtf(@s3_stage/output/2024, 21)
718718
----
719719
s3_stage External test output/2024 21 s3_stage:test:output/2024:21
720+
s3_stage External test output/2024 22 s3_stage:test:output/2024:22
720721

721722
query I
722723
SELECT * from multi_stage_process_udtf(@s3_stage/input/2024/, @s3_stage/output/2024, 21)
723724
----
724725
s3_stage s3_stage test test input/2024/ output/2024 29
726+
s3_stage s3_stage test test input/2024/ output/2024 30

tests/udf/udf_server.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -454,17 +454,23 @@ def stage_summary_udtf(data_stage: StageLocation, value: int):
454454
assert data_stage.stage_type.lower() == "external"
455455
assert data_stage.storage
456456
bucket = _stage_bucket(data_stage)
457-
summary = f"{data_stage.stage_name}:{bucket}:{data_stage.relative_path}:{value}"
458-
return [
459-
{
460-
"stage_name": data_stage.stage_name or "",
461-
"stage_type": data_stage.stage_type or "",
462-
"bucket": bucket,
463-
"relative_path": data_stage.relative_path or "",
464-
"value": value,
465-
"summary": summary,
466-
}
467-
]
457+
rows = []
458+
for offset in (0, 1):
459+
current_value = value + offset
460+
summary = (
461+
f"{data_stage.stage_name}:{bucket}:{data_stage.relative_path}:{current_value}"
462+
)
463+
rows.append(
464+
{
465+
"stage_name": data_stage.stage_name or "",
466+
"stage_type": data_stage.stage_type or "",
467+
"bucket": bucket,
468+
"relative_path": data_stage.relative_path or "",
469+
"value": current_value,
470+
"summary": summary,
471+
}
472+
)
473+
return rows
468474

469475

470476
@udf(
@@ -507,18 +513,22 @@ def multi_stage_process_udtf(
507513
assert output_stage.stage_type.lower() == "external"
508514
input_bucket = _stage_bucket(input_stage)
509515
output_bucket = _stage_bucket(output_stage)
510-
result = value + len(input_bucket) + len(output_bucket)
511-
return [
512-
{
513-
"input_stage": input_stage.stage_name or "",
514-
"output_stage": output_stage.stage_name or "",
515-
"input_bucket": input_bucket,
516-
"output_bucket": output_bucket,
517-
"input_relative_path": input_stage.relative_path or "",
518-
"output_relative_path": output_stage.relative_path or "",
519-
"result": result,
520-
}
521-
]
516+
rows = []
517+
for offset in (0, 1):
518+
current_value = value + offset
519+
result = current_value + len(input_bucket) + len(output_bucket)
520+
rows.append(
521+
{
522+
"input_stage": input_stage.stage_name or "",
523+
"output_stage": output_stage.stage_name or "",
524+
"input_bucket": input_bucket,
525+
"output_bucket": output_bucket,
526+
"input_relative_path": input_stage.relative_path or "",
527+
"output_relative_path": output_stage.relative_path or "",
528+
"result": result,
529+
}
530+
)
531+
return rows
522532

523533

524534
@udf(

0 commit comments

Comments
 (0)