
Commit d3ad9d2

refactor based on PR comments
1 parent 1d3505e commit d3ad9d2

7 files changed: +62 -62 lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 6 deletions
@@ -65,12 +65,6 @@ jobs:
           path: testdata/tpcds/data/**
           retention-days: 7
           if-no-files-found: ignore
-      - name: Clean up test data
-        run: |
-          rm -rf testdata/tpcds/data/*
-          rm -f $HOME/.local/bin/duckdb
-          rm -rf /home/runner/.duckdb
-          df -h
 
   format-check:
     runs-on: ubuntu-latest

src/test_utils/property_based.rs

Lines changed: 25 additions & 22 deletions
@@ -1,18 +1,17 @@
 use arrow::{
-    array::{ArrayRef, UInt32Array},
+    array::{ArrayRef, Float16Array, Float32Array, Float64Array, UInt32Array},
     compute::{SortColumn, concat_batches, lexsort_to_indices},
     record_batch::RecordBatch,
 };
 use datafusion::{
     common::{internal_datafusion_err, internal_err},
     error::{DataFusionError, Result},
-    execution::context::SessionContext,
     physical_expr::LexOrdering,
     physical_plan::ExecutionPlan,
 };
 use std::sync::Arc;
 
-/// compares the set of record batches for equality
+/// compares the set of record batches for equality
 pub async fn compare_result_set(
     actual_result: &Result<Vec<RecordBatch>>,
     expected_result: &Result<Vec<RecordBatch>>,
@@ -21,10 +20,7 @@ pub async fn compare_result_set(
         Ok(batches) => batches,
         Err(e) => {
             if expected_result.is_ok() {
-                return internal_err!(
-                    "expected no error but got: {}",
-                    e
-                );
+                return internal_err!("expected no error but got: {}", e);
             }
             return Ok(()); // Both errored, so the query is valid
         }
@@ -34,10 +30,7 @@ pub async fn compare_result_set(
         Ok(batches) => batches,
         Err(e) => {
             if actual_result.is_ok() {
-                return internal_err!(
-                    "expected error but got none, error: {}",
-                    e
-                );
+                return internal_err!("expected error but got none, error: {}", e);
             }
             return Ok(()); // Both errored, so the query is valid
         }
@@ -47,7 +40,7 @@ pub async fn compare_result_set(
         .map_err(|e| internal_datafusion_err!("result sets were not equal: {}", e))
 }
 
-// Ensures that the plans have the same ordering properties and that the actual result is sorted
+// Ensures that the plans have the same ordering properties and that the actual result is sorted
 // correctly.
 pub async fn compare_ordering(
     actual_physical_plan: Arc<dyn ExecutionPlan>,
@@ -203,6 +196,12 @@ fn batch_rows_to_strings(batches: &[RecordBatch]) -> Vec<String> {
 
         if array.is_null(row_idx) {
             row_values.push("NULL".to_string());
+        } else if let Some(arr) = array.as_any().downcast_ref::<Float16Array>() {
+            row_values.push(format!("{:.1$}", arr.value(row_idx), 2));
+        } else if let Some(arr) = array.as_any().downcast_ref::<Float32Array>() {
+            row_values.push(format!("{:.1$}", arr.value(row_idx), 2));
+        } else if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
+            row_values.push(format!("{:.1$}", arr.value(row_idx), 2));
         } else {
             // Use Arrow's deterministic string representation
             let value_str = array_value_to_string(array, row_idx)
@@ -282,6 +281,8 @@ mod tests {
 
     use arrow::array::{Int32Array, StringArray};
     use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::physical_plan::collect;
+    use datafusion::prelude::SessionContext;
 
     use std::sync::Arc;
 
@@ -438,25 +439,27 @@ mod tests {
 
         // Query which sorted by id should pass.
         let ordered_query = "SELECT * FROM test_table ORDER BY id";
+
         let df = actual_ctx.sql(ordered_query).await.unwrap();
-        let result = df.collect().await;
+        let task_ctx = actual_ctx.task_ctx();
+        let actual_plan = df.create_physical_plan().await.unwrap();
+        let results = collect(actual_plan.clone(), task_ctx).await;
+
+        let df = expected_ctx.sql(ordered_query).await.unwrap();
+        let expected_plan = df.create_physical_plan().await.unwrap();
+
         assert!(
-            compare_ordering(&actual_ctx, &expected_ctx, ordered_query, &result)
+            compare_ordering(actual_plan.clone(), expected_plan.clone(), &results)
                 .await
                 .is_ok()
         );
 
         // This should fail because the batch is not sorted by value
         let result = Ok(vec![batch]);
         assert!(
-            compare_ordering(
-                &actual_ctx,
-                &expected_ctx,
-                "SELECT * FROM test_table ORDER BY value",
-                &result
-            )
-            .await
-            .is_err()
+            compare_ordering(actual_plan.clone(), expected_plan.clone(), &result)
                .await
                .is_err()
         );
     }
 }
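The new Float16/Float32/Float64 arms normalize floating-point cells to two decimal places before rows are compared as strings, so distributed and single-node plans that aggregate in different orders do not fail on last-bit differences. A minimal standalone sketch of the downcast-and-format pattern follows; the normalize_float_cell helper name is hypothetical and not part of this commit, and Float16 is omitted because it needs the half crate:

use arrow::array::{Array, ArrayRef, Float32Array, Float64Array};

// Hypothetical helper mirroring the technique added to batch_rows_to_strings:
// render float cells with fixed precision so 1.0000001 and 1.0000002 compare equal.
fn normalize_float_cell(array: &ArrayRef, row_idx: usize) -> Option<String> {
    if array.is_null(row_idx) {
        return Some("NULL".to_string());
    }
    if let Some(arr) = array.as_any().downcast_ref::<Float32Array>() {
        return Some(format!("{:.2}", arr.value(row_idx)));
    }
    if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
        return Some(format!("{:.2}", arr.value(row_idx)));
    }
    None // non-float types keep their exact string representation
}

Note that format!("{:.2}", x) is equivalent to the format!("{:.1$}", x, 2) form used in the diff; rounding to two decimals is a blunt tolerance, applied uniformly to all three float widths.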

src/test_utils/tpcds.rs

Lines changed: 4 additions & 15 deletions
@@ -70,12 +70,8 @@ pub fn queries() -> Result<Vec<(String, String)>> {
 
 /// Load a single TPC-DS query by ID (1-99).
 pub fn get_test_tpcds_query(id: usize) -> Result<String> {
-    if id < 1 || id > 99 {
-        return internal_err!("Query ID must be between 1 and 99, got {}", id);
-    }
-
     let queries_dir = get_queries_dir();
-
+
     if !queries_dir.exists() {
         return internal_err!(
             "TPC-DS queries directory not found: {}",
@@ -84,21 +80,14 @@ pub fn get_test_tpcds_query(id: usize) -> Result<String> {
     }
 
     let query_file = queries_dir.join(format!("q{}.sql", id));
-
+
     if !query_file.exists() {
-        return internal_err!(
-            "Query file not found: {}",
-            query_file.display()
-        );
+        return internal_err!("Query file not found: {}", query_file.display());
     }
 
     let query_sql = fs::read_to_string(&query_file)
         .map_err(|e| {
-            internal_datafusion_err!(
-                "Failed to read query file {}: {}",
-                query_file.display(),
-                e
-            )
+            internal_datafusion_err!("Failed to read query file {}: {}", query_file.display(), e)
         })?
         .trim()
         .to_string();
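With the explicit 1-99 bounds check removed, an out-of-range ID now surfaces as the "Query file not found" error from the file-existence check rather than a dedicated message. A hedged usage sketch, assuming the public path shown in the tests/tpcds_test.rs imports:

use datafusion::error::Result;
use datafusion_distributed::test_utils::tpcds::get_test_tpcds_query;

fn main() -> Result<()> {
    // Reads testdata/tpcds/queries/q1.sql and returns the trimmed SQL text.
    let sql = get_test_tpcds_query(1)?;
    assert!(!sql.is_empty());

    // IDs outside 1-99 now fail at the file-existence check
    // ("Query file not found: .../q100.sql") rather than a bounds check.
    assert!(get_test_tpcds_query(100).is_err());
    Ok(())
}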

testdata/tpcds/README.md

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@ This directory contains 99 TPC-DS queries from https://github.com/duckdb/duckdb
 
 ## Modifications for DataFusion Compatibility
 
-- Query 57 was modified to add explicit ORDER BY d_moy to avg() window function. DataFusion requires explicit ordering with PARTITION BY.
-- Query 72 was modified to support data functions in datafusion
+- Queries 47 and 57 were modified to add an explicit ORDER BY d_moy to the avg() window function. DataFusion requires explicit ordering in window functions with PARTITION BY for deterministic results.
+- Query 72 was modified to support date functions in DataFusion.
 
 `generate.sh {SCALE_FACTOR}` is a script which can generate TPC-DS parquet data. Requires the duckdb CLI: https://duckdb.org/install/

testdata/tpcds/queries/q47.sql

Lines changed: 4 additions & 1 deletion
@@ -1,3 +1,5 @@
+-- TPC-DS Query 47
+-- Modified: Added ORDER BY d_moy to avg() window function for DataFusion compatibility
 WITH v1 AS
   (SELECT i_category,
           i_brand,
@@ -10,7 +12,8 @@ WITH v1 AS
           i_brand,
           s_store_name,
           s_company_name,
-          d_year) avg_monthly_sales,
+          d_year
+          ORDER BY d_moy) avg_monthly_sales,
    rank() OVER (PARTITION BY i_category,
                 i_brand,
                 s_store_name,

testdata/tpcds/queries/q72.sql

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ LEFT OUTER JOIN catalog_returns ON (cr_item_sk = cs_item_sk
   AND cr_order_number = cs_order_number)
 WHERE d1.d_week_seq = d2.d_week_seq
   AND inv_quantity_on_hand < cs_quantity
-  AND d3.d_date > d1.d_date + INTERVAL '5' DAY -- DuckDB: day + 5
+  AND d3.d_date > d1.d_date + INTERVAL '5' DAY -- Modified: original DuckDB syntax is d1.d_date + 5
   AND hd_buy_potential = '>10000'
   AND d1.d_year = 1999
   AND cd_marital_status = 'D'

tests/tpcds_test.rs

Lines changed: 26 additions & 15 deletions
@@ -2,17 +2,17 @@
 mod tests {
     use datafusion::common::runtime::JoinSet;
     use datafusion::error::Result;
+    use datafusion::physical_plan::{ExecutionPlan, collect};
     use datafusion::prelude::SessionContext;
-    use datafusion::physical_plan::{collect, ExecutionPlan};
     use datafusion_distributed::test_utils::{
         localhost::start_localhost_context,
         property_based::{compare_ordering, compare_result_set},
-        tpcds::{generate_tpcds_data, get_test_tpcds_query, register_tables, get_data_dir},
+        tpcds::{generate_tpcds_data, get_data_dir, get_test_tpcds_query, register_tables},
     };
 
     use datafusion::arrow::array::RecordBatch;
 
-    use datafusion_distributed::{DefaultSessionBuilder, DistributedExt, display_plan_ascii};
+    use datafusion_distributed::{DefaultSessionBuilder, DistributedExt};
     use std::env;
     use std::fs;
     use std::sync::Arc;
@@ -25,7 +25,8 @@ mod tests {
         INIT_TEST_TPCDS_TABLES
             .get_or_init(|| async {
                 if !fs::exists(get_data_dir()).unwrap_or(false) {
-                    let scale_factor = env::var("SCALE_FACTOR").unwrap_or_else(|_| "0.01".to_string());
+                    let scale_factor =
+                        env::var("SCALE_FACTOR").unwrap_or_else(|_| "0.01".to_string());
                     generate_tpcds_data(scale_factor.as_str()).unwrap();
                 }
             })
@@ -41,7 +42,8 @@ mod tests {
         let (mut distributed_ctx, worker_tasks) =
             start_localhost_context(NUM_WORKERS, DefaultSessionBuilder).await;
         distributed_ctx.set_distributed_files_per_task(FILES_PER_TASK)?;
-        distributed_ctx.set_distributed_cardinality_effect_task_scale_factor(CARDINALITY_TASK_COUNT_FACTOR)?;
+        distributed_ctx
+            .set_distributed_cardinality_effect_task_scale_factor(CARDINALITY_TASK_COUNT_FACTOR)?;
         register_tables(&distributed_ctx).await?;
 
         // Create single node context to compare results to.
@@ -51,11 +53,13 @@ mod tests {
         Ok((distributed_ctx, single_node_ctx, worker_tasks))
     }
 
-    async fn run(ctx: &SessionContext, query_sql: &str) -> (Arc<dyn ExecutionPlan>, Result<Vec<RecordBatch>>) {
+    async fn run(
+        ctx: &SessionContext,
+        query_sql: &str,
+    ) -> (Arc<dyn ExecutionPlan>, Result<Vec<RecordBatch>>) {
         let df = ctx.sql(&query_sql).await.unwrap();
         let task_ctx = ctx.task_ctx();
-        let plan = df.create_physical_plan().await.unwrap();
-        println!("{}", display_plan_ascii(plan.as_ref(), false));
+        let plan = df.create_physical_plan().await.unwrap();
         (plan.clone(), collect(plan, task_ctx).await) // Collect execution errors, do not unwrap.
     }
 
@@ -65,17 +69,24 @@ mod tests {
         let query_sql = get_test_tpcds_query(query_id)?;
         let (distributed_ctx, single_node_ctx, _handles) = setup().await?;
 
-        let (distributed_physical_plan, distributed_results) = run(&distributed_ctx, &query_sql).await;
-        println!("execution complete");
-        let (single_node_physical_plan, single_node_results) = run(&single_node_ctx, &query_sql).await;
-
-        // println!(display(&distributed_physical_plan));
+        let (single_node_physical_plan, single_node_results) =
+            run(&single_node_ctx, &query_sql).await;
+        let (distributed_physical_plan, distributed_results) =
+            run(&distributed_ctx, &query_sql).await;
 
         let compare_result = tokio::try_join!(
             compare_result_set(&distributed_results, &single_node_results),
-            compare_ordering(distributed_physical_plan, single_node_physical_plan, &distributed_results),
+            compare_ordering(
+                distributed_physical_plan,
+                single_node_physical_plan,
+                &distributed_results
+            ),
+        );
+        assert!(
+            compare_result.is_ok(),
+            "Query {query_id} failed: {}",
+            compare_result.unwrap_err()
         );
-        assert!(compare_result.is_ok(), "Query {query_id} failed: {}", compare_result.unwrap_err());
         Ok(())
     }
 
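The refactored test drives both checks through tokio::try_join!, which awaits the comparison futures together and returns the first Err, so a result-set mismatch and an ordering violation both fail the same assert. A minimal sketch of that pattern with hypothetical check functions (not the crate's API):

use datafusion::error::{DataFusionError, Result};

async fn check_rows() -> Result<()> {
    Ok(())
}

async fn check_order() -> Result<()> {
    Err(DataFusionError::Internal("rows out of order".into()))
}

#[tokio::main]
async fn main() {
    // try_join! polls both futures; the first Err short-circuits and is returned.
    let outcome = tokio::try_join!(check_rows(), check_order());
    assert!(outcome.is_err());
}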
