
Commit 681014d

working
1 parent 604980a commit 681014d

3 files changed: +78 -84 lines changed


src/bin/fuzz.rs

Lines changed: 25 additions & 11 deletions
@@ -2,13 +2,14 @@ use async_trait::async_trait;
 use clap::{Parser, Subcommand};
 use datafusion::error::Result;
 use datafusion_distributed::test_utils::{
-    fuzz::FuzzDB,
-    tpcds::{discover_tpcds_queries, register_available_tpcds_tables},
+    fuzz::{FuzzDB, FuzzConfig},
+    tpcds::{discover_tpcds_queries, register_available_tpcds_tables, generate_tpcds_data},
 };
 use datafusion::prelude::SessionContext;
 use log::{debug, info, warn, error};
 use std::process;
 use std::time::Instant;
+use rand::Rng;
 
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
@@ -53,18 +54,21 @@ enum Commands {
 /// TPC-DS workload implementation
 #[derive(Clone)]
 struct TpcdsWorkload {
-    scale_factor: String,
-    force_regenerate: bool,
     query_filter: Option<String>,
 }
 
 impl TpcdsWorkload {
-    fn new(scale_factor: String, force_regenerate: bool, query_filter: Option<String>) -> Self {
-        Self {
-            scale_factor,
-            force_regenerate,
-            query_filter,
+    fn try_new(scale_factor: String, force_regenerate: bool, query_filter: Option<String>) -> Result<Self> {
+        // Generate data if force regenerate is enabled
+        if force_regenerate {
+            info!("Generating TPC-DS data with scale factor {}...", scale_factor);
+            generate_tpcds_data(&scale_factor)?;
+            info!("✅ TPC-DS data generation completed");
         }
+
+        Ok(Self {
+            query_filter,
+        })
     }
 }
 
@@ -114,7 +118,7 @@ async fn main() -> Result<()> {
         process::exit(1);
     }
 
-    let workload = TpcdsWorkload::new(scale_factor.clone(), force_regenerate, queries);
+    let workload = TpcdsWorkload::try_new(scale_factor.clone(), force_regenerate, queries)?;
 
     info!("🚀 Starting DataFusion distributed fuzz testing");
     info!(" Workload: {}", workload.name());
@@ -142,6 +146,16 @@ fn validate_tpcds_args(scale_factor: &str) -> std::result::Result<(), String> {
     Ok(())
 }
 
+fn randomized_fuzz_config() -> FuzzConfig {
+    let mut rng = rand::thread_rng();
+    let config = FuzzConfig {
+        num_workers: rng.gen_range(2..=8),
+        files_per_task: rng.gen_range(1..=8),
+        cardinality_task_count_factor: rng.gen_range(1.0..=3.0),
+    };
+    config
+}
+
 /// Run a workload using the generic workload trait
 async fn run_workload<W>(workload: W) -> Result<()>
 where W: Workload {
@@ -150,7 +164,7 @@ where W: Workload {
     // Create FuzzDB with randomized session config
     info!("⚙️ Setting up distributed session with randomized configuration...");
 
-    let fuzz_db = match FuzzDB::new(W::setup).await {
+    let fuzz_db = match FuzzDB::new(randomized_fuzz_config(), W::setup).await {
         Ok(db) => db,
         Err(e) => {
             return Err(datafusion::error::DataFusionError::Execution(format!(
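
Note: randomized_fuzz_config() above draws from rand::thread_rng(), so every run samples a fresh configuration. A minimal sketch of a seeded variant (hypothetical, not part of this commit) that keeps the same ranges but makes a failing run reproducible:

// Hypothetical helper, not in this commit: same ranges as randomized_fuzz_config(),
// but seeded so a failing fuzz run can be replayed deterministically.
use rand::{Rng, SeedableRng};
use rand::rngs::StdRng;
use datafusion_distributed::test_utils::fuzz::FuzzConfig;

fn seeded_fuzz_config(seed: u64) -> FuzzConfig {
    let mut rng = StdRng::seed_from_u64(seed);
    FuzzConfig {
        num_workers: rng.gen_range(2..=8),
        files_per_task: rng.gen_range(1..=8),
        cardinality_task_count_factor: rng.gen_range(1.0..=3.0),
    }
}

Logging the seed alongside the failing query would be enough to reconstruct the exact worker/task layout.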

src/test_utils/fuzz.rs

Lines changed: 17 additions & 73 deletions
@@ -1,4 +1,4 @@
-use crate::test_utils::localhost::start_localhost_context;
+use crate::{DistributedExt, test_utils::localhost::start_localhost_context};
 use crate::DefaultSessionBuilder;
 use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
@@ -8,7 +8,6 @@ use datafusion::{
     execution::context::SessionContext,
     logical_expr::LogicalPlan,
 };
-use rand::Rng;
 
 /// Fuzzing database with distributed session context and helper functions
 pub struct FuzzDB {
@@ -22,98 +21,58 @@ pub struct FuzzDB {
 
 /// Configuration parameters for randomized session setup
 #[derive(Debug, Clone)]
-pub struct SessionConfig {
+pub struct FuzzConfig {
     pub num_workers: usize,
-    pub tasks_per_file: usize,
-    pub cardinality_task_count_factor: usize,
-    pub target_partitions: usize,
-}
-
-impl Default for SessionConfig {
-    fn default() -> Self {
-        Self {
-            num_workers: 4,
-            tasks_per_file: 4,
-            cardinality_task_count_factor: 4,
-            target_partitions: 8,
-        }
-    }
+    pub files_per_task: usize,
+    pub cardinality_task_count_factor: f64,
 }
 
 impl FuzzDB {
     /// Create a new FuzzDB with randomized session parameters and setup function
-    pub async fn new<F, Fut>(setup: F) -> Result<Self>
+    pub async fn new<F, Fut>(cfg: FuzzConfig, setup: F) -> Result<Self>
     where
         F: Fn(SessionContext) -> Fut + Send + Sync,
        Fut: std::future::Future<Output = Result<()>> + Send
     {
-        let config = randomize_session_config();
-        create_db(config, setup).await
+        create_db(cfg, setup).await
     }
 
     /// Execute a query and validate results using all oracles
     pub async fn run(&self, query: &str) -> Result<Vec<RecordBatch>> {
 
         // Execute on distributed context
         let df = self.distributed_ctx.sql(query).await?;
-        let _logical_plan = df.logical_plan().clone();
+        let logical_plan = df.logical_plan().clone();
         let results = df.collect().await?;
 
         // Run oracles
-        let _single_node_oracle = SingleNodeOracle::new(&self.single_node_ctx);
-        let _ordering_oracle = OrderingOracle::new();
+        let single_node_oracle = SingleNodeOracle::new(&self.single_node_ctx);
+        let ordering_oracle = OrderingOracle::new();
 
         // Validate with SingleNodeOracle
-        // single_node_oracle.validate(&self.distributed_ctx, query, &results).await?;
+        single_node_oracle.validate(&self.distributed_ctx, query, &results).await?;
 
         // Validate with OrderingOracle
-        // ordering_oracle.validate_with_plan(&logical_plan, &results).await?;
+        ordering_oracle.validate_with_plan(&logical_plan, &results).await?;
 
         Ok(results)
     }
 }
 
-/// Randomize session configuration parameters
-fn randomize_session_config() -> SessionConfig {
-    let mut rng = rand::thread_rng();
-
-    let config = SessionConfig {
-        num_workers: rng.gen_range(2..=8),
-        tasks_per_file: rng.gen_range(1..=8),
-        cardinality_task_count_factor: rng.gen_range(1..=8),
-        target_partitions: rng.gen_range(4..=16),
-    };
-
-    println!("Generated random session config: {:?}", config);
-    config
-}
+
 
 /// Create distributed session context with specified configuration
-async fn create_db<F, Fut>(config: SessionConfig, setup: F) -> Result<FuzzDB>
+async fn create_db<F, Fut>(cfg: FuzzConfig, setup: F) -> Result<FuzzDB>
 where
     F: Fn(SessionContext) -> Fut + Send + Sync,
     Fut: std::future::Future<Output = Result<()>> + Send
 {
-    println!("Creating FuzzDB with {} workers", config.num_workers);
+    println!("Creating FuzzDB with {} workers", cfg.num_workers);
 
     // Start localhost context with workers using DefaultSessionBuilder
-    let (distributed_ctx, worker_tasks) = start_localhost_context(config.num_workers, DefaultSessionBuilder).await;
-
-    // Configure session parameters
-    {
-        let mut session_config = distributed_ctx.state().config().clone();
-        session_config.options_mut().execution.target_partitions = config.target_partitions;
-        session_config.options_mut().optimizer.enable_round_robin_repartition = true;
-
-        // Set additional distributed-specific parameters
-        if let Ok(tasks_per_file) = std::env::var("DATAFUSION_EXECUTION_TASKS_PER_FILE") {
-            println!("Using DATAFUSION_EXECUTION_TASKS_PER_FILE from environment: {}", tasks_per_file);
-        } else {
-            unsafe {
-                std::env::set_var("DATAFUSION_EXECUTION_TASKS_PER_FILE", config.tasks_per_file.to_string());
-            }
-        }
-    }
+    let (mut distributed_ctx, worker_tasks) = start_localhost_context(cfg.num_workers, DefaultSessionBuilder).await;
+    distributed_ctx.set_distributed_files_per_task(cfg.files_per_task)?;
+    distributed_ctx.set_distributed_cardinality_effect_task_scale_factor(cfg.cardinality_task_count_factor)?;
 
     // Create single node context for oracle comparison
     let single_node_ctx = SessionContext::new();
@@ -122,19 +81,6 @@ where
     setup(distributed_ctx.clone()).await?;
     setup(single_node_ctx.clone()).await?;
 
-    // Log worker configuration
-    println!("Session configuration:");
-    println!(" Number of workers: {}", config.num_workers);
-    println!(" Tasks per file: {}", config.tasks_per_file);
-    println!(" Cardinality task count factor: {}", config.cardinality_task_count_factor);
-    println!(" Target partitions: {}", config.target_partitions);
-
-    // Get worker ports from distributed context (this is a simplified approach)
-    // In a real implementation, you might want to extract actual port information
-    for i in 0..config.num_workers {
-        println!(" Worker {}: localhost:random_port", i);
-    }
-
     Ok(FuzzDB {
         distributed_ctx,
         single_node_ctx,
@@ -163,8 +109,6 @@ impl<'a> SingleNodeOracle<'a> {
 #[async_trait]
 impl<'a> Oracle for SingleNodeOracle<'a> {
     async fn validate(&self, _distributed_ctx: &SessionContext, query: &str, distributed_results: &[RecordBatch]) -> Result<()> {
-        println!("SingleNodeOracle: Validating query against single-node execution");
-
         // Execute the same query on single node context
         let single_node_df = self.single_node_ctx.sql(query).await?;
         let single_node_results = single_node_df.collect().await?;
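
Note: with the configuration now injected rather than generated inside FuzzDB::new, the harness can also be driven directly from a test. A minimal sketch, assuming only the paths already imported in src/bin/fuzz.rs above; the no-op setup closure, the field values, and the query are illustrative:

use datafusion::error::DataFusionError;
use datafusion::prelude::SessionContext;
use datafusion_distributed::test_utils::fuzz::{FuzzConfig, FuzzDB};

async fn fuzzdb_smoke() -> datafusion::error::Result<()> {
    let cfg = FuzzConfig {
        num_workers: 2,
        files_per_task: 1,
        cardinality_task_count_factor: 1.0,
    };
    // The setup closure normally registers tables on both the distributed and
    // single-node contexts; a no-op is enough to exercise the plumbing.
    let db = FuzzDB::new(cfg, |_ctx: SessionContext| async move {
        Ok::<(), DataFusionError>(())
    })
    .await?;
    // run() executes on the distributed context and, with this commit,
    // validates the result through both oracles.
    let _batches = db.run("SELECT 1 AS one").await?;
    Ok(())
}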

src/test_utils/tpcds.rs

Lines changed: 36 additions & 0 deletions
@@ -2,6 +2,7 @@
 use std::collections::HashSet;
 use std::fs;
 use std::path::Path;
+use std::process::Command;
 use datafusion::{
     execution::context::SessionContext,
     prelude::ParquetReadOptions,
@@ -16,6 +17,10 @@ pub fn get_queries_dir() -> std::path::PathBuf {
     std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("testdata/tpcds/queries")
 }
 
+pub fn get_tpcds_dir() -> std::path::PathBuf {
+    std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("testdata/tpcds")
+}
+
 pub fn tpcds_query_from_dir(queries_dir: &std::path::Path, name: &str) -> String {
     let query_path = queries_dir.join(format!("{}.sql", name));
     fs::read_to_string(query_path)
@@ -133,4 +138,35 @@ pub async fn register_available_tpcds_tables(
     Ok((registered_tables, missing_tables))
 }
 
+/// Generate TPC-DS data using the generation script
+pub fn generate_tpcds_data(scale_factor: &str) -> Result<()> {
+    let tpcds_dir = get_tpcds_dir();
+    let generate_script = tpcds_dir.join("generate.sh");
+
+    if !generate_script.exists() {
+        return Err(DataFusionError::Execution(format!(
+            "TPC-DS generation script not found: {}", generate_script.display()
+        )));
+    }
+
+    let output = Command::new("bash")
+        .arg(&generate_script)
+        .arg(scale_factor)
+        .current_dir(&tpcds_dir)
+        .output()
+        .map_err(|e| DataFusionError::Execution(format!(
+            "Failed to execute TPC-DS generation script: {}", e
+        )))?;
+
+    if !output.status.success() {
+        let stderr = String::from_utf8_lossy(&output.stderr);
+        let stdout = String::from_utf8_lossy(&output.stdout);
+        return Err(DataFusionError::Execution(format!(
+            "TPC-DS generation failed:\nstdout: {}\nstderr: {}", stdout, stderr
+        )));
+    }
+
+    Ok(())
+}
+
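
Note: generate_tpcds_data() shells out to testdata/tpcds/generate.sh unconditionally. A sketch of a guard that skips regeneration when output already exists, under the assumption (not confirmed by this commit) that the script writes into a data/ subdirectory:

// Sketch only: the "data" subdirectory is an assumed output location of generate.sh.
use datafusion::error::Result;

pub fn ensure_tpcds_data(scale_factor: &str, force_regenerate: bool) -> Result<()> {
    let data_dir = get_tpcds_dir().join("data"); // assumed layout
    if force_regenerate || !data_dir.exists() {
        generate_tpcds_data(scale_factor)?;
    }
    Ok(())
}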
