Skip to content

Commit 731ad2a

Browse files
authored
Rework task assignation mechanism (#216)
* Introduce with_distributed_execution method and remove DistributedExt from anything that's not a SessionStateBuilder * Rework how tasks are assigned to stages * Use files_per_task and cardinality_effect_factor for choosing the right tasks for the different stages * Update all tests to use the new task estimation mechanism * Fix when worker urls is zero or one and add tests * Move back to `with_distributed_channel_resolver` and let users inject either `DistributedPhysicalOptimizerRule` or their own rules * Add files_per_task and cardinality_task_sf to examples * Improve docs * add impl TaskEstimator for Arc<&dyn TaskEstimator> * Add task estimator tests * Explain why we need to count distinct files * Return an error if task_count == 0 * Factor out apply_scale_factor * Add task estimator doc comments * Add more docs to apply_network_boundaries * Extend files_per_task docs * Rollback DistributedExt de-implementation from SessionConfig, SessionState and SessionContext
1 parent dd853a7 commit 731ad2a

23 files changed

+1476
-693
lines changed

benchmarks/src/tpch/run.rs

Lines changed: 29 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -101,28 +101,6 @@ pub struct RunOpt {
101101
#[structopt(short = "t", long = "sorted")]
102102
sorted: bool,
103103

104-
/// Upon shuffling data, this defines how many tasks are employed into performing the shuffling.
105-
/// ```text
106-
/// ( task 1 ) ( task 2 ) ( task 3 )
107-
/// ▲ ▲ ▲
108-
/// └────┬──────┴─────┬────┘
109-
/// ( task 1 ) ( task 2 ) N tasks
110-
/// ```
111-
/// This parameter defines N
112-
#[structopt(long)]
113-
shuffle_tasks: Option<usize>,
114-
115-
/// Upon merging multiple tasks into one, this defines how many tasks are merged.
116-
/// ```text
117-
/// ( task 1 )
118-
/// ▲
119-
/// ┌───────────┴──────────┐
120-
/// ( task 1 ) ( task 2 ) ( task 3 ) N tasks
121-
/// ```
122-
/// This parameter defines N
123-
#[structopt(long)]
124-
coalesce_tasks: Option<usize>,
125-
126104
/// Spawns a worker in the specified port.
127105
#[structopt(long)]
128106
spawn: Option<u16>,
@@ -134,6 +112,14 @@ pub struct RunOpt {
134112
/// Number of physical threads per worker.
135113
#[structopt(long)]
136114
threads: Option<usize>,
115+
116+
/// Number of files per each distributed task.
117+
#[structopt(long)]
118+
files_per_task: Option<usize>,
119+
120+
/// Task count scale factor for when nodes in stages change the cardinality of the data
121+
#[structopt(long)]
122+
cardinality_task_sf: Option<f64>,
137123
}
138124

139125
#[async_trait]
@@ -142,37 +128,32 @@ impl DistributedSessionBuilder for RunOpt {
142128
&self,
143129
ctx: DistributedSessionBuilderContext,
144130
) -> Result<SessionState, DataFusionError> {
145-
let mut builder = SessionStateBuilder::new().with_default_features();
146-
131+
let rt_builder = self.common.runtime_env_builder()?;
147132
let config = self
148133
.common
149134
.config()?
150-
.with_collect_statistics(!self.disable_statistics)
135+
.with_target_partitions(self.partitions())
136+
.with_collect_statistics(!self.disable_statistics);
137+
let mut builder = SessionStateBuilder::new()
138+
.with_runtime_env(rt_builder.build_arc()?)
139+
.with_default_features()
140+
.with_config(config)
151141
.with_distributed_user_codec(InMemoryCacheExecCodec)
152142
.with_distributed_channel_resolver(LocalHostChannelResolver::new(self.workers.clone()))
143+
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
153144
.with_distributed_option_extension_from_headers::<WarmingUpMarker>(&ctx.headers)?
154-
.with_target_partitions(self.partitions());
155-
156-
let rt_builder = self.common.runtime_env_builder()?;
145+
.with_distributed_files_per_task(
146+
self.files_per_task.unwrap_or(get_available_parallelism()),
147+
)?
148+
.with_distributed_cardinality_effect_task_scale_factor(
149+
self.cardinality_task_sf.unwrap_or(1.0),
150+
)?;
157151

158152
if self.mem_table {
159153
builder = builder.with_physical_optimizer_rule(Arc::new(InMemoryDataSourceRule));
160154
}
161-
if !self.workers.is_empty() {
162-
builder = builder
163-
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
164-
.with_distributed_network_coalesce_tasks(
165-
self.coalesce_tasks.unwrap_or(self.workers.len()),
166-
)
167-
.with_distributed_network_shuffle_tasks(
168-
self.shuffle_tasks.unwrap_or(self.workers.len()),
169-
);
170-
}
171155

172-
Ok(builder
173-
.with_config(config)
174-
.with_runtime_env(rt_builder.build_arc()?)
175-
.build())
156+
Ok(builder.build())
176157
}
177158
}
178159

@@ -196,7 +177,12 @@ impl RunOpt {
196177
}
197178

198179
async fn run_local(mut self) -> Result<()> {
199-
let state = self.build_session_state(Default::default()).await?;
180+
let mut state = self.build_session_state(Default::default()).await?;
181+
if self.mem_table {
182+
state = SessionStateBuilder::from(state)
183+
.with_distributed_option_extension(WarmingUpMarker::warming_up())?
184+
.build();
185+
}
200186
let ctx = SessionContext::new_with_state(state);
201187
self.register_tables(&ctx).await?;
202188

@@ -218,9 +204,6 @@ impl RunOpt {
218204
for query_id in query_range.clone() {
219205
// put the WarmingUpMarker in the context, otherwise, queries will fail as the
220206
// InMemoryCacheExec node will think they should already be warmed up.
221-
let ctx = ctx
222-
.clone()
223-
.with_distributed_option_extension(WarmingUpMarker::warming_up())?;
224207
for query in get_query_sql(query_id)? {
225208
self.execute_query(&ctx, &query).await?;
226209
}

examples/in_memory_cluster.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use arrow::util::pretty::pretty_format_batches;
22
use arrow_flight::flight_service_client::FlightServiceClient;
33
use async_trait::async_trait;
44
use datafusion::common::DataFusionError;
5+
use datafusion::common::utils::get_available_parallelism;
56
use datafusion::execution::SessionStateBuilder;
67
use datafusion::physical_plan::displayable;
78
use datafusion::prelude::{ParquetReadOptions, SessionContext};
@@ -28,11 +29,11 @@ struct Args {
2829
#[structopt(long)]
2930
explain: bool,
3031

31-
#[structopt(long, default_value = "3")]
32-
network_shuffle_tasks: usize,
32+
#[structopt(long)]
33+
files_per_task: Option<usize>,
3334

34-
#[structopt(long, default_value = "3")]
35-
network_coalesce_tasks: usize,
35+
#[structopt(long)]
36+
cardinality_task_sf: Option<f64>,
3637
}
3738

3839
#[tokio::main]
@@ -43,8 +44,12 @@ async fn main() -> Result<(), Box<dyn Error>> {
4344
.with_default_features()
4445
.with_distributed_channel_resolver(InMemoryChannelResolver::new())
4546
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
46-
.with_distributed_network_coalesce_tasks(args.network_shuffle_tasks)
47-
.with_distributed_network_shuffle_tasks(args.network_coalesce_tasks)
47+
.with_distributed_files_per_task(
48+
args.files_per_task.unwrap_or(get_available_parallelism()),
49+
)?
50+
.with_distributed_cardinality_effect_task_scale_factor(
51+
args.cardinality_task_sf.unwrap_or(1.),
52+
)?
4853
.build();
4954

5055
let ctx = SessionContext::from(state);

examples/localhost_run.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use arrow_flight::flight_service_client::FlightServiceClient;
33
use async_trait::async_trait;
44
use dashmap::{DashMap, Entry};
55
use datafusion::common::DataFusionError;
6+
use datafusion::common::utils::get_available_parallelism;
67
use datafusion::execution::SessionStateBuilder;
78
use datafusion::physical_plan::displayable;
89
use datafusion::prelude::{ParquetReadOptions, SessionContext};
@@ -29,11 +30,11 @@ struct Args {
2930
#[structopt(long)]
3031
explain: bool,
3132

32-
#[structopt(long, default_value = "3")]
33-
network_shuffle_tasks: usize,
33+
#[structopt(long)]
34+
files_per_task: Option<usize>,
3435

35-
#[structopt(long, default_value = "3")]
36-
network_coalesce_tasks: usize,
36+
#[structopt(long)]
37+
cardinality_task_sf: Option<f64>,
3738
}
3839

3940
#[tokio::main]
@@ -49,8 +50,12 @@ async fn main() -> Result<(), Box<dyn Error>> {
4950
.with_default_features()
5051
.with_distributed_channel_resolver(localhost_resolver)
5152
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
52-
.with_distributed_network_coalesce_tasks(args.network_coalesce_tasks)
53-
.with_distributed_network_shuffle_tasks(args.network_shuffle_tasks)
53+
.with_distributed_files_per_task(
54+
args.files_per_task.unwrap_or(get_available_parallelism()),
55+
)?
56+
.with_distributed_cardinality_effect_task_scale_factor(
57+
args.cardinality_task_sf.unwrap_or(1.),
58+
)?
5459
.build();
5560

5661
let ctx = SessionContext::from(state);

src/channel_resolver_ext.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
use crate::DistributedConfig;
12
use arrow_flight::flight_service_client::FlightServiceClient;
23
use async_trait::async_trait;
3-
use datafusion::common::exec_datafusion_err;
4+
use datafusion::common::exec_err;
45
use datafusion::error::DataFusionError;
56
use datafusion::prelude::SessionConfig;
67
use std::sync::Arc;
@@ -11,21 +12,30 @@ pub(crate) fn set_distributed_channel_resolver(
1112
cfg: &mut SessionConfig,
1213
channel_resolver: impl ChannelResolver + Send + Sync + 'static,
1314
) {
14-
cfg.set_extension(Arc::new(ChannelResolverExtension(Arc::new(
15-
channel_resolver,
16-
))));
15+
let opts = cfg.options_mut();
16+
let channel_resolver_ext = ChannelResolverExtension(Arc::new(channel_resolver));
17+
if let Some(distributed_cfg) = opts.extensions.get_mut::<DistributedConfig>() {
18+
distributed_cfg.__private_channel_resolver = channel_resolver_ext;
19+
} else {
20+
opts.extensions.insert(DistributedConfig {
21+
__private_channel_resolver: channel_resolver_ext,
22+
..Default::default()
23+
});
24+
}
1725
}
1826

1927
pub(crate) fn get_distributed_channel_resolver(
2028
cfg: &SessionConfig,
2129
) -> Result<Arc<dyn ChannelResolver + Send + Sync>, DataFusionError> {
22-
cfg.get_extension::<ChannelResolverExtension>()
23-
.map(|cm| cm.0.clone())
24-
.ok_or_else(|| exec_datafusion_err!("ChannelResolver not present in the session config"))
30+
let opts = cfg.options();
31+
let Some(distributed_cfg) = opts.extensions.get::<DistributedConfig>() else {
32+
return exec_err!("ChannelResolver not present in the session config");
33+
};
34+
Ok(Arc::clone(&distributed_cfg.__private_channel_resolver.0))
2535
}
2636

2737
#[derive(Clone)]
28-
struct ChannelResolverExtension(Arc<dyn ChannelResolver + Send + Sync>);
38+
pub(crate) struct ChannelResolverExtension(pub(crate) Arc<dyn ChannelResolver + Send + Sync>);
2939

3040
pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService<
3141
http::Request<Body>,

0 commit comments

Comments
 (0)