diff --git a/examples/python/pytorch/mt5/mt5_ff.py b/examples/python/pytorch/mt5/mt5_ff.py index c2868e9d1e..5dff7415d3 100644 --- a/examples/python/pytorch/mt5/mt5_ff.py +++ b/examples/python/pytorch/mt5/mt5_ff.py @@ -122,6 +122,7 @@ def top_level_task(): input_names = ["input_ids", "attention_mask"] print("Tracing the model...") + print(batch_size) hf_model = PyTorchModel( model, is_hf_model=True, input_names=input_names, batch_size=batch_size, seq_length=seq_length, diff --git a/include/flexflow/config.h b/include/flexflow/config.h index d1fe6231da..b6a27a4f2a 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -124,6 +124,7 @@ class FFConfig { size_t workSpaceSize; Legion::Context lg_ctx; Legion::Runtime *lg_hlr; + Legion::IndexSpaceT<1> all_gpu_task_is; Legion::FieldSpace field_space; bool syntheticInput, profiling, perform_fusion; size_t simulator_work_space_size; @@ -137,6 +138,8 @@ class FFConfig { bool enable_parameter_parallel; bool enable_attribute_parallel; bool enable_inplace_optimizations; + int data_parallelism_degree; + int tensor_parallelism_degree; // Control Tensor Op Math Conversion bool allow_tensor_op_math_conversion; std::string dataset_path; diff --git a/include/flexflow/ffconst.h b/include/flexflow/ffconst.h index 5658e2923d..060983b020 100644 --- a/include/flexflow/ffconst.h +++ b/include/flexflow/ffconst.h @@ -157,6 +157,7 @@ enum OperatorType { OP_REPLICATE, OP_REDUCTION, OP_PIPELINE, + OP_ALLREDUCE, OP_FUSED_PARALLEL, OP_INVALID, }; @@ -189,6 +190,7 @@ enum PMParameter { PM_COMBINE_DEGREE, // Combine PM_REDUCTION_DIM, // Reduction PM_REDUCTION_DEGREE, // Reduction + PM_ALLREDUCE_DIM, // AllReduce PM_SOFTMAX_DIM, // Softmax PM_NUM_HEADS, // MultiHeadAttention PM_INVALID, diff --git a/include/flexflow/flexflow_c.h b/include/flexflow/flexflow_c.h index bde40de31b..2ddc8549fa 100644 --- a/include/flexflow/flexflow_c.h +++ b/include/flexflow/flexflow_c.h @@ -95,6 +95,8 @@ void flexflow_model_compute_metrics(flexflow_model_t handle); void flexflow_model_update(flexflow_model_t handle); +void flexflow_model_unified_update(flexflow_model_t handle); + void flexflow_model_compile(flexflow_model_t handle, enum LossType loss_type, int *metrics, diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 5e7ad8025c..fe73e6a0e3 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -151,6 +151,7 @@ enum TaskIDs { // Optimizer with NCCL SGD_UPD_NCCL_TASK_ID, ADAM_UPD_NCCL_TASK_ID, + ADAM_UNIFY_UPD_NCCL_TASK_ID, // Initializer GLOROT_INIT_TASK_ID, ZERO_INIT_TASK_ID, @@ -190,6 +191,10 @@ enum TaskIDs { PIPELINE_INIT_TASK_ID, PIPELINE_FWD_TASK_ID, PIPELINE_BWD_TASK_ID, + ALLREDUCE_INIT_TASK_ID, + ALLREDUCE_INF_TASK_ID, + ALLREDUCE_FWD_TASK_ID, + ALLREDUCE_BWD_TASK_ID, FUSED_PARALLELOP_INIT_TASK_ID, FUSED_PARALLELOP_FWD_TASK_ID, FUSED_PARALLELOP_BWD_TASK_ID, @@ -273,6 +278,7 @@ class Split; class TopK; class Transpose; class Combine; +class AllReduce; class Repartition; class Reduction; class Replicate; @@ -777,6 +783,7 @@ class FFModel { void get_metrics(); void backward(int seq_length = -1); void update(); + void unified_update(); bool apply_fusion(std::vector const &operators, std::vector &new_operators); Op *get_final_operator() const; @@ -828,6 +835,8 @@ class FFModel { Legion::IndexSpace get_task_is(Legion::Domain const &domain) const; Legion::IndexSpace get_task_is(ParallelConfig const &pc) const; Legion::IndexSpace get_task_is(MachineView const &view) const; + bool is_transformer_block(int layer_idx) const; + bool 
is_mlp_block(int layer_idx) const; void create_operators_from_layers(); Op *create_operator_from_layer(Layer *layer, std::vector const &inputs); @@ -854,6 +863,7 @@ class FFModel { int metrics_input; ParallelTensor parallel_label_tensor; Tensor label_tensor; + int num_inputs = 0; std::vector layers; std::vector operators; @@ -923,6 +933,8 @@ class FFModel { Replicate *>, std::unordered_map, Reduction *>, + std::unordered_map, + AllReduce *>, std::unordered_map, Combine *>, std::unordered_map, diff --git a/include/flexflow/operator_params.h b/include/flexflow/operator_params.h index 24c84a85ed..84653ac9ca 100644 --- a/include/flexflow/operator_params.h +++ b/include/flexflow/operator_params.h @@ -7,6 +7,7 @@ #include "flexflow/ops/batch_matmul_params.h" #include "flexflow/ops/cast_params.h" #include "flexflow/ops/concat_params.h" +#include "flexflow/parallel_ops/allreduce_params.h" #include "flexflow/ops/conv_2d_params.h" #include "flexflow/ops/dropout_params.h" #include "flexflow/ops/element_binary_params.h" @@ -62,6 +63,7 @@ using OperatorParameters = mp::variant; tl::optional get_op_parameters(Op const *op); diff --git a/include/flexflow/ops/dropout.h b/include/flexflow/ops/dropout.h index 37304bdada..b8033c98ba 100644 --- a/include/flexflow/ops/dropout.h +++ b/include/flexflow/ops/dropout.h @@ -5,6 +5,13 @@ #include "flexflow/node.h" #include "flexflow/operator.h" #include "flexflow/ops/dropout_params.h" +#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) +#include +#include +#elif defined(FF_USE_HIP_ROCM) +#include +#include +#endif namespace FlexFlow { diff --git a/include/flexflow/ops/element_binary.h b/include/flexflow/ops/element_binary.h index f83ce20609..677ff23ce2 100644 --- a/include/flexflow/ops/element_binary.h +++ b/include/flexflow/ops/element_binary.h @@ -53,6 +53,11 @@ class ElementBinary : public Op { bool measure_operator_cost(Simulator *sim, MachineView const &pc, CostMetrics &cost_metrics) const override; + void serialize(Legion::Serializer &) const override; + static PCG::Node deserialize(FFModel &ff, + Legion::Deserializer &d, + ParallelTensor inputs[], + int num_inputs); Params get_params() const; public: diff --git a/include/flexflow/ops/element_binary_params.h b/include/flexflow/ops/element_binary_params.h index 5aa20e25a5..c70e1b597a 100644 --- a/include/flexflow/ops/element_binary_params.h +++ b/include/flexflow/ops/element_binary_params.h @@ -8,6 +8,7 @@ namespace FlexFlow { struct ElementBinaryParams { OperatorType type; + bool inplace_a; bool is_valid( std::pair const &) const; diff --git a/include/flexflow/ops/kernels/dropout_kernels.h b/include/flexflow/ops/kernels/dropout_kernels.h index 421974fbaa..b2201dd34e 100644 --- a/include/flexflow/ops/kernels/dropout_kernels.h +++ b/include/flexflow/ops/kernels/dropout_kernels.h @@ -5,6 +5,7 @@ #include "flexflow/fftype.h" #include "flexflow/op_meta.h" #include "flexflow/ops/dropout.h" +#include "flexflow/accessor.h" namespace FlexFlow { @@ -17,33 +18,40 @@ class DropoutMeta : public OpMeta { ~DropoutMeta(void); Realm::RegionInstance reserveInst; #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) + curandState *state; cudnnTensorDescriptor_t inputTensor, outputTensor; cudnnDropoutDescriptor_t dropoutDesc; #else miopenTensorDescriptor_t inputTensor, outputTensor; miopenDropoutDescriptor_t dropoutDesc; + hiprandState *state; #endif void *reserveSpace, *dropoutStates; size_t reserveSpaceSize, dropoutStateSize; + size_t num_elements; + long long seed; + float rate; }; namespace Kernels { namespace Dropout { void 
forward_kernel_wrapper(DropoutMeta *m, - float const *input_ptr, - float *output_ptr); + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output); void backward_kernel_wrapper(DropoutMeta *m, - float const *output_grad_ptr, - float *input_grad_ptr); + GenericTensorAccessorR const &output_grad, + GenericTensorAccessorW const &input_grad); namespace Internal { void forward_kernel(DropoutMeta *m, float const *input_ptr, float *output_ptr, + size_t num_elements, ffStream_t stream); void backward_kernel(DropoutMeta *m, float const *output_grad_ptr, float *input_grad_ptr, + size_t num_elements, ffStream_t stream); } // namespace Internal } // namespace Dropout diff --git a/include/flexflow/optimizer.h b/include/flexflow/optimizer.h index bab7e6e4ed..401fffb351 100644 --- a/include/flexflow/optimizer.h +++ b/include/flexflow/optimizer.h @@ -18,6 +18,7 @@ #include "flexflow/parallel_tensor.h" #include "legion.h" +#include "accessor.h" namespace FlexFlow { @@ -30,6 +31,7 @@ class Optimizer { virtual void init(void) = 0; virtual void next(void) = 0; virtual void update(const ParallelTensor p) = 0; + virtual void unified_update(std::vector const parameters) = 0; FFModel const *model; }; @@ -43,6 +45,7 @@ class SGDOptimizer : public Optimizer { void init(void); void next(void); void update(const ParallelTensor p); + void unified_update(std::vector const parameters); void set_weight_decay(double _weight_decay); static void ps_update_task(Legion::Task const *task, std::vector const ®ions, @@ -60,6 +63,11 @@ class SGDOptimizer : public Optimizer { std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); + static void + nccl_unified_update_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); static void nccl_update_task_gpu(SGDOptimizer const *op, OpMeta const *meta, float const *w_grad_ptr, @@ -85,6 +93,7 @@ class AdamOptimizer : public Optimizer { void init(void); void next(void); void update(const ParallelTensor p); + void unified_update(std::vector const parameters); void set_weight_decay(double _weight_decay); static void ps_update_task(Legion::Task const *task, std::vector const ®ions, @@ -103,6 +112,11 @@ class AdamOptimizer : public Optimizer { std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); + static void + nccl_unified_update_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); static void nccl_update_task_gpu(AdamOptimizer const *op, OpMeta const *meta, float const *w_grad_ptr, @@ -110,10 +124,19 @@ class AdamOptimizer : public Optimizer { float *w_ptr, float *v_ptr, float *m_ptr); + static void nccl_unified_update_task_gpu(AdamOptimizer const *op, + OpMeta const *meta, + GenericTensorAccessorR *accWGrads, + size_t *size, + GenericTensorAccessorW *accWs, + GenericTensorAccessorW *accVs, + GenericTensorAccessorW *accMs); #endif double alpha, beta1, beta2, weight_decay, epsilon; double alpha_t, beta1_t, beta2_t; std::map v_values, m_values; + size_t reservedWorkSpaceSize = 0; + int parameters_num = 0; }; }; // namespace FlexFlow diff --git a/include/flexflow/parallel_ops/allreduce.h b/include/flexflow/parallel_ops/allreduce.h new file mode 100644 index 0000000000..a28d4cef9e --- /dev/null +++ b/include/flexflow/parallel_ops/allreduce.h @@ -0,0 +1,57 @@ +#ifndef _FLEXFLOW_ALLREDUCE_H +#define _FLEXFLOW_ALLREDUCE_H + +#include "flexflow/layer.h" +#include "flexflow/node.h" +#include "flexflow/op_meta.h" +#include 
"flexflow/operator.h" +#include "flexflow/parallel_ops/allreduce_params.h" +#include "parallel_op.h" + +namespace FlexFlow { + +class AllReduce : public ParallelOp { +public: + using Params = AllReduceParams; + using Input = ParallelTensor; + + AllReduce(FFModel &model, + const ParallelTensor input, + int allreduce_legion_dim, + char const *name = NULL); + AllReduce(FFModel &model, + Params const ¶ms, + Input const input, + char const *name = nullptr); + void create_input_partition(FFModel &model) override; + void init(FFModel const &) override; + void forward(FFModel const &) override; + void backward(FFModel const &) override; + bool get_int_parameter(PMParameter, int *) const override; + bool append_parallel_op_info( + std::vector ¶llel_ops) const override; + static OpMeta *init_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + static void forward_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + static void backward_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + bool measure_operator_cost(Simulator *sim, + MachineView const &pc, + CostMetrics &cost_metrics) const override; + + Params get_params() const; + +public: + int allreduce_dim; +}; + +}; // namespace FlexFlow + +#endif // _FLEXFLOW_ALLREDUCE_H diff --git a/include/flexflow/parallel_ops/allreduce_params.h b/include/flexflow/parallel_ops/allreduce_params.h new file mode 100644 index 0000000000..a0daac8f9a --- /dev/null +++ b/include/flexflow/parallel_ops/allreduce_params.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_ALLREDUCE_PARAMS_H +#define _FLEXFLOW_ALLREDUCE_PARAMS_H + +namespace FlexFlow { + +struct AllReduceParams { + int allreduce_legion_dim; + char name[MAX_OPNAME]; + bool is_valid(ParallelTensorShape const &) const; +}; +bool operator==(AllReduceParams const &, AllReduceParams const &); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::AllReduceParams const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_ALLREDUCE_PARAMS_H diff --git a/include/flexflow/parallel_ops/kernels/allreduce_kernels.h b/include/flexflow/parallel_ops/kernels/allreduce_kernels.h new file mode 100644 index 0000000000..02a5026fcf --- /dev/null +++ b/include/flexflow/parallel_ops/kernels/allreduce_kernels.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H +#define _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H + +#include "flexflow/device.h" +#include "flexflow/fftype.h" +#include "flexflow/op_meta.h" +#include "flexflow/parallel_ops/allreduce.h" + +namespace FlexFlow { + +class AllReduceMeta : public OpMeta { +public: + AllReduceMeta(FFHandler handle, AllReduce const *reduct); +}; + +namespace Kernels { +namespace AllReduce { + +void forward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output); + +void backward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad); + +} // namespace AllReduce +} // namespace Kernels +} // namespace FlexFlow + +#endif // _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H diff --git a/include/flexflow/utils/cuda_helper.h b/include/flexflow/utils/cuda_helper.h index a4b2be0a66..d077995884 100644 --- a/include/flexflow/utils/cuda_helper.h +++ b/include/flexflow/utils/cuda_helper.h @@ -82,6 +82,12 @@ __global__ void assign_kernel(DT *ptr, 
Legion::coord_t size, DT value); template __global__ void copy_kernel(DT *dst, const DT *src, Legion::coord_t size); +template +__global__ void copy_kernel_with_replicate(DT *dst, + const DT *src, + Legion::coord_t origin_size, + Legion::coord_t size); + template __global__ void add_kernel(T *data_ptr, T const *grad_ptr, size_t size); diff --git a/include/flexflow/utils/hip_helper.h b/include/flexflow/utils/hip_helper.h index 709e78f517..8c589305c2 100644 --- a/include/flexflow/utils/hip_helper.h +++ b/include/flexflow/utils/hip_helper.h @@ -82,6 +82,12 @@ __global__ void assign_kernel(DT *ptr, Legion::coord_t size, DT value); template __global__ void copy_kernel(DT *dst, const DT *src, Legion::coord_t size); +template +__global__ void copy_kernel_with_replicate(DT *dst, + const DT *src, + Legion::coord_t origin_size, + Legion::coord_t size); + template __global__ void add_kernel(T *data_ptr, T const *grad_ptr, size_t size); diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index c05fb96661..750838d829 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -2018,6 +2018,13 @@ def update(self): :returns: None -- no returns. """ ffc.flexflow_model_update(self.handle) + + def unified_update(self): + """Update weights and biases of all layers. + + :returns: None -- no returns. + """ + ffc.flexflow_model_unified_update(self.handle) def compile(self, optimizer=None, loss_type=None, metrics=None, comp_mode=None): """Configure the model for trainting. FlexFlow uses lazy initialization, @@ -2072,11 +2079,11 @@ def load_bert_pretrained(self, checkpoint=None): layer = self._layers[i] if (layer.name + "_weight") in weights_dict: print('weight: ' + layer.name) - weight = layer.get_parameter_by_id(0); + weight = layer.get_parameter_by_id(0) weight.set_tensor(self, weights_dict[layer.name + "_weight"]) if (layer.name + "_bias") in weights_dict: print('bias: ' + layer.name) - bias = layer.get_parameter_by_id(1); + bias = layer.get_parameter_by_id(1) bias.set_tensor(self, weights_dict[layer.name + "_bias"]) def fit(self, x=None, y=None, batch_size=None, epochs=1): """Trains the model for a fixed number of epochs (iterations on a dataset). 
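A minimal sketch (illustrative, not part of the patch) of a training step that uses the new unified_update() binding in place of the per-parameter update() kept in the fit() hunk below; ffmodel, ffconfig, dataloaders, and tracing_id are assumed to be set up the same way fit() sets up its own, and an NCCL-enabled build is assumed:

    # Illustrative only: mirrors the fit() loop below, but finishes each step with
    # unified_update(), which applies the optimizer step to all parameters in one
    # NCCL-based task instead of one update() launch per parameter tensor.
    def train_one_epoch(ffmodel, ffconfig, dataloaders, num_samples, batch_size, tracing_id):
        for _ in range(num_samples // batch_size):
            ffconfig.begin_trace(tracing_id)
            for d in dataloaders:
                d.next_batch(ffmodel)      # load the next batch into the input tensors
            ffmodel.forward()
            ffmodel.backward()
            ffmodel.unified_update()       # single fused update for every parameter
            ffconfig.end_trace(tracing_id)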
@@ -2117,7 +2124,7 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1): for d in dataloaders: d.next_batch(self) self.forward() - self.zero_gradients() + # self.zero_gradients() self.backward() self.update() self._ffconfig.end_trace(self._tracing_id) diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index e3c430fd92..1d8634f224 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -178,6 +178,11 @@ void flexflow_model_update(flexflow_model_t handle_) { handle->update(); } +void flexflow_model_unified_update(flexflow_model_t handle_) { + FFModel *handle = FFCObjectWrapper::unwrap(handle_); + handle->unified_update(); +} + void flexflow_model_compile(flexflow_model_t handle_, enum LossType loss_type, int *metrics, diff --git a/src/dataloader/dataloader.cc b/src/dataloader/dataloader.cc index 441a088194..614482e8b1 100644 --- a/src/dataloader/dataloader.cc +++ b/src/dataloader/dataloader.cc @@ -97,7 +97,7 @@ SingleDataLoader::SingleDataLoader(FFModel &ff, datatype = datatype_; // Currently assume that the leading dim of input is a replica dim of degree 1 assert(input->dims[input->num_dims - 1].is_replica_dim); - assert(input->dims[input->num_dims - 1].size == 1); + // assert(input->dims[input->num_dims - 1].size == 1); batch_input = input; ParallelDim dims[MAX_TENSOR_DIM]; diff --git a/src/dataloader/dataloader.cpp b/src/dataloader/dataloader.cpp index 7d9ffc02b1..97668d705d 100644 --- a/src/dataloader/dataloader.cpp +++ b/src/dataloader/dataloader.cpp @@ -41,10 +41,12 @@ void SingleDataLoader::load_input(Task const *task, int num_dims = full_input_domain.get_dim(); assert(num_dims + 1 == batch_input_domain.get_dim()); // assert the leading replica dim has a degree of one - assert(batch_input_domain.hi()[num_dims] == - batch_input_domain.lo()[num_dims]); + // assert(batch_input_domain.hi()[num_dims] == + // batch_input_domain.lo()[num_dims]); coord_t batch_size = batch_input_domain.hi()[num_dims - 1] - batch_input_domain.lo()[num_dims - 1] + 1; + coord_t replicate_num = + batch_input_domain.hi()[num_dims] - batch_input_domain.lo()[num_dims] + 1; coord_t num_elements_per_batch = batch_input_domain.get_volume() / batch_size; // FIXME: currently assume continous indices assert(batch_size == meta->num_samples); @@ -61,13 +63,15 @@ void SingleDataLoader::load_input(Task const *task, // printf("ptr(%p, %p), idx0 %d nb_elements_per_batch %d, batch_size %d, // %d\n", acc_full_input.ptr, input_zc, start_idx, num_elements_per_batch, // batch_size, start_idx * num_elements_per_batch); - hipLaunchKernelGGL(HIP_KERNEL_NAME(copy_kernel
), + assert(batch_input_domain.get_volume() % replicate_num == 0); + hipLaunchKernelGGL(HIP_KERNEL_NAME(copy_kernel_with_replicate<DT>
), GET_BLOCKS(batch_input_domain.get_volume()), CUDA_NUM_THREADS, 0, stream, batch_input_ptr, input_zc, + batch_input_domain.get_volume() / replicate_num, batch_input_domain.get_volume()); checkCUDA(hipDeviceSynchronize()); } diff --git a/src/dataloader/dataloader.cu b/src/dataloader/dataloader.cu index c2994d00a2..5462532d76 100644 --- a/src/dataloader/dataloader.cu +++ b/src/dataloader/dataloader.cu @@ -40,10 +40,13 @@ void SingleDataLoader::load_input(Task const *task, int num_dims = full_input_domain.get_dim(); assert(num_dims + 1 == batch_input_domain.get_dim()); // assert the leading replica dim has a degree of one - assert(batch_input_domain.hi()[num_dims] == - batch_input_domain.lo()[num_dims]); + // assert(batch_input_domain.hi()[num_dims] == + // batch_input_domain.lo()[num_dims]); coord_t batch_size = batch_input_domain.hi()[num_dims - 1] - batch_input_domain.lo()[num_dims - 1] + 1; + + coord_t replicate_num = + batch_input_domain.hi()[num_dims] - batch_input_domain.lo()[num_dims] + 1; coord_t num_elements_per_batch = batch_input_domain.get_volume() / batch_size; // FIXME: currently assume continous indices assert(batch_size == meta->num_samples); @@ -60,11 +63,15 @@ void SingleDataLoader::load_input(Task const *task, // printf("ptr(%p, %p), idx0 %d nb_elements_per_batch %d, batch_size %d, // %d\n", acc_full_input.ptr, input_zc, start_idx, num_elements_per_batch, // batch_size, start_idx * num_elements_per_batch); - copy_kernel
+ assert(batch_input_domain.get_volume() % replicate_num == 0); + copy_kernel_with_replicate<DT>
<<>>(batch_input_ptr, input_zc, batch_input_domain.get_volume()); + stream>>>(batch_input_ptr, + input_zc, + batch_input_domain.get_volume() / replicate_num, + batch_input_domain.get_volume()); checkCUDA(cudaDeviceSynchronize()); } diff --git a/src/ops/dropout.cc b/src/ops/dropout.cc index 55f6730827..2ebfaff539 100644 --- a/src/ops/dropout.cc +++ b/src/ops/dropout.cc @@ -28,7 +28,7 @@ using PCG::Node; using namespace FlexFlow::Kernels::Dropout; -Tensor FFModel::dropout(const Tensor input, +Tensor FFModel::dropout(Tensor const input, float rate, unsigned long long seed, char const *name) { @@ -86,7 +86,7 @@ bool operator==(DropoutParams const &lhs, DropoutParams const &rhs) { } Dropout::Dropout(FFModel &model, - const ParallelTensor _input, + ParallelTensor const _input, float _rate, unsigned long long _seed, char const *name) @@ -111,12 +111,12 @@ Dropout::Dropout(FFModel &model, Dropout::Dropout(FFModel &model, Dropout const &other, - const ParallelTensor input) + ParallelTensor const input) : Dropout(model, input, other.rate, other.seed, other.name) {} Dropout::Dropout(FFModel &model, DropoutParams const ¶ms, - const ParallelTensor input, + ParallelTensor const input, char const *name) : Dropout(model, input, params.rate, params.seed, name) {} @@ -210,12 +210,12 @@ void Dropout::forward_task(Task const *task, assert(task->regions.size() == 2); // const Dropout* dropout = (const Dropout*) task->args; DropoutMeta *m = *((DropoutMeta **)task->local_args); - float const *input_ptr = helperGetTensorPointerRO( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - float *output_ptr = helperGetTensorPointerWO( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - - forward_kernel_wrapper(m, input_ptr, output_ptr); + + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + forward_kernel_wrapper(m, input, output); } void Dropout::backward(FFModel const &ff) { @@ -264,7 +264,13 @@ void Dropout::backward_task(Task const *task, float const *output_grad_ptr = helperGetTensorPointerRO( regions[1], task->regions[1], FID_DATA, ctx, runtime); - backward_kernel_wrapper(m, output_grad_ptr, input_grad_ptr); + + GenericTensorAccessorW input_grad = helperGetGenericTensorAccessorRW( + m->output_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( + m->input_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + + backward_kernel_wrapper(m, output_grad, input_grad); } void Dropout::serialize(Legion::Serializer &sez) const { @@ -304,30 +310,36 @@ bool Dropout::measure_operator_cost(Simulator *sim, sim->free_all(); float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); assert(input_ptr != NULL); + + GenericTensorAccessorR input_acc(m->input_type[0], sub_input.get_domain(), input_ptr); cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); assert(output_ptr != NULL); + + GenericTensorAccessorW output_acc(m->output_type[0], sub_input.get_domain(), output_ptr); cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); assert(m->profiling == false); std::function forward, backward; - forward = [&] { forward_kernel_wrapper(m, input_ptr, 
output_ptr); }; + forward = [&] { forward_kernel_wrapper(m, input_acc, output_acc); }; if (sim->computationMode == COMP_MODE_TRAINING) { float *input_grad_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT); assert(input_grad_ptr != NULL); + GenericTensorAccessorW input_grad_acc(m->output_type[0], sub_input.get_domain(), input_grad_ptr); cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); float *output_grad_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); assert(output_grad_ptr != NULL); + GenericTensorAccessorR output_grad_acc(m->output_type[0], sub_input.get_domain(), output_grad_ptr); cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); backward = [&] { - backward_kernel_wrapper(m, output_grad_ptr, input_grad_ptr); + backward_kernel_wrapper(m, output_grad_acc, input_grad_acc); }; } diff --git a/src/ops/element_binary.cc b/src/ops/element_binary.cc index 12895cfd98..84c3f8ba93 100644 --- a/src/ops/element_binary.cc +++ b/src/ops/element_binary.cc @@ -802,10 +802,32 @@ bool ElementBinary::measure_operator_cost(Simulator *sim, return true; } +void ElementBinary::serialize(Legion::Serializer &sez) const { + sez.serialize(this->op_type); + sez.serialize(this->inplace_a); +} + +using PCG::Node; +/*static*/ +Node ElementBinary::deserialize(FFModel &ff, + Legion::Deserializer &dez, + ParallelTensor inputs[], + int num_inputs) { + assert(num_inputs == 2); + OperatorType op_type; + bool inplace_a; + dez.deserialize(op_type); + dez.deserialize(inplace_a); + ElementBinaryParams params; + params.type = op_type; + params.inplace_a = inplace_a; + return ff.get_or_create_node({inputs[0], inputs[1]}, params); +} ElementBinaryParams ElementBinary::get_params() const { ElementBinaryParams params; params.type = this->op_type; + params.inplace_a = this->inplace_a; return params; } diff --git a/src/ops/embedding.cc b/src/ops/embedding.cc index 18bf7e324e..8df6324460 100644 --- a/src/ops/embedding.cc +++ b/src/ops/embedding.cc @@ -148,23 +148,27 @@ int Embedding::output_size(ParallelDim output_dims[MAX_TENSOR_DIM]) { int const OUT_CHANNELS = Output::OUT_CHANNELS; if (aggr == AGGR_MODE_NONE) { int num_dims = input->num_dims + 1; - for (int i = 1; i < num_dims; i++) { + for (int i = 1; i < num_dims - 1; i++) { output_dims[i] = input->dims[i - 1]; } assert(OUT_CHANNELS == 0); output_dims[OUT_CHANNELS].size = this->out_channels; output_dims[OUT_CHANNELS].degree = 1; output_dims[OUT_CHANNELS].parallel_idx = -1; + // Copy replica dim + output_dims[num_dims - 1] = input->dims[input->num_dims - 1]; return num_dims; } else { int num_dims = input->num_dims; - for (int i = 1; i < num_dims; i++) { + for (int i = 1; i < num_dims - 1; i++) { output_dims[i] = input->dims[i]; } assert(OUT_CHANNELS == 0); output_dims[OUT_CHANNELS].size = this->out_channels; output_dims[OUT_CHANNELS].degree = 1; output_dims[OUT_CHANNELS].parallel_idx = -1; + // Copy replica dim + output_dims[num_dims - 1] = input->dims[input->num_dims - 1]; return num_dims; } // const int REPLICA = this->output_vocab_size_replica_dim(); @@ -179,13 +183,13 @@ int Embedding::weight_size(ParallelDim weight_dims[MAX_TENSOR_DIM]) { weight_dims[Weight::VOCAB_SIZE].size = this->num_entries; weight_dims[Weight::VOCAB_SIZE].degree = 1; weight_dims[Weight::VOCAB_SIZE].parallel_idx = -1; - for (int i = 2; i < input->num_dims; i++) { + for (int i = 2; i < input->num_dims + 1; i++) { weight_dims[i].size = input->dims[i - 1].degree; weight_dims[i].degree = weight_dims[i].size; 
weight_dims[i].parallel_idx = input->dims[i - 1].parallel_idx; weight_dims[i].is_replica_dim = true; } - return input->num_dims; + return input->num_dims + 1; } void Embedding::register_output_mappings() { diff --git a/src/ops/fused.cc b/src/ops/fused.cc index d0c895c53d..b241ff1587 100644 --- a/src/ops/fused.cc +++ b/src/ops/fused.cc @@ -129,7 +129,7 @@ bool FusedOp::add_operator(FFModel &model, Op *op) { // op->name, op_config)); // Cannot fuse parallel operators since they have different paralel_is // in forward and backward - assert(!op->is_parallel_op()); + assert(!op->is_parallel_op() || op->op_type == OP_ALLREDUCE); // Currently don't consider nested fusion if (op->op_type == OP_FUSED) { return false; diff --git a/src/ops/fused.cpp b/src/ops/fused.cpp index cb2b81fbcf..9da93f0c65 100644 --- a/src/ops/fused.cpp +++ b/src/ops/fused.cpp @@ -31,6 +31,7 @@ #include "flexflow/ops/kernels/reshape_kernels.h" #include "flexflow/ops/kernels/softmax_kernels.h" #include "flexflow/ops/kernels/transpose_kernels.h" +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" #include "flexflow/ops/layer_norm.h" #include "flexflow/ops/linear.h" #include "flexflow/utils/hip_helper.h" @@ -205,9 +206,7 @@ __host__ void FusedOp::forward_task(Task const *task, assert(fused->op_num_outputs[op] == 1); DropoutMeta *m = (DropoutMeta *)metas->meta[op]; Kernels::Dropout::forward_kernel_wrapper( - m, - my_input_accessor[0].get_float_ptr(), - my_output_accessor[0].get_float_ptr()); + m, my_input_accessor[0], my_output_accessor[0]); break; } case OP_LINEAR: { @@ -422,6 +421,14 @@ __host__ void FusedOp::forward_task(Task const *task, } break; } + case OP_ALLREDUCE: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + Kernels::AllReduce::forward_kernel_wrapper( + m, my_input_accessor[0], my_output_accessor[0]); + break; + } case OP_RESHAPE: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_weights[op] == 0); @@ -959,6 +966,14 @@ __host__ void FusedOp::backward_task(Task const *task, } break; } + case OP_ALLREDUCE: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + Kernels::AllReduce::backward_kernel_wrapper( + m, my_input_grad_accessor[0], my_output_grad_accessor[0]); + break; + } case OP_TRANSPOSE: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_weights[op] == 0); diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 62262c89af..b78447ba41 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -33,6 +33,7 @@ #include "flexflow/ops/kernels/reshape_kernels.h" #include "flexflow/ops/kernels/softmax_kernels.h" #include "flexflow/ops/kernels/transpose_kernels.h" +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" #include "flexflow/ops/layer_norm.h" #include "flexflow/utils/cuda_helper.h" @@ -216,9 +217,7 @@ __host__ void FusedOp::forward_task(Task const *task, assert(fused->op_num_outputs[op] == 1); DropoutMeta *m = (DropoutMeta *)metas->meta[op]; Kernels::Dropout::forward_kernel_wrapper( - m, - my_input_accessor[0].get_float_ptr(), - my_output_accessor[0].get_float_ptr()); + m, my_input_accessor[0], my_output_accessor[0]); break; } case OP_LINEAR: { @@ -462,6 +461,14 @@ __host__ void FusedOp::forward_task(Task const *task, } break; } + case OP_ALLREDUCE: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + AllReduceMeta const *m = (AllReduceMeta 
*)metas->meta[op]; + Kernels::AllReduce::forward_kernel_wrapper( + m, my_input_accessor[0], my_output_accessor[0]); + break; + } case OP_TRANSPOSE: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_weights[op] == 0); @@ -819,9 +826,7 @@ __host__ void FusedOp::backward_task(Task const *task, assert(fused->op_num_outputs[op] == 1); DropoutMeta *m = (DropoutMeta *)metas->meta[op]; Kernels::Dropout::backward_kernel_wrapper( - m, - my_output_grad_accessor[0].get_float_ptr(), - my_input_grad_accessor[0].get_float_ptr()); + m, my_output_grad_accessor[0], my_input_grad_accessor[0]); break; } case OP_EW_ADD: @@ -1006,6 +1011,14 @@ __host__ void FusedOp::backward_task(Task const *task, } break; } + case OP_ALLREDUCE: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + Kernels::AllReduce::backward_kernel_wrapper( + m, my_input_grad_accessor[0], my_output_grad_accessor[0]); + break; + } case OP_TRANSPOSE: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_weights[op] == 0); diff --git a/src/ops/kernels/dropout_kernels.cpp b/src/ops/kernels/dropout_kernels.cpp index b0dd4c644e..c0d5748464 100644 --- a/src/ops/kernels/dropout_kernels.cpp +++ b/src/ops/kernels/dropout_kernels.cpp @@ -30,6 +30,11 @@ DropoutMeta::DropoutMeta(FFHandler handler, Domain const &output_domain) : OpMeta(handler) { profiling = dropout->profiling; + rate = dropout->rate; + seed = dropout->seed; + input_type[0] = dropout->data_type; + output_type[0] = dropout->data_type; + checkCUDNN(miopenCreateTensorDescriptor(&inputTensor)); checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); checkCUDNN(miopenCreateDropoutDescriptor(&dropoutDesc)); @@ -78,20 +83,68 @@ DropoutMeta::~DropoutMeta(void) { namespace Kernels { namespace Dropout { +__global__ void dropout_forward_kernel(float p, + long long seed, + size_t num_elements, + float const *input_ptr, + float *output_ptr) { + CUDA_KERNEL_LOOP(i, num_elements) { + float scale = 1.0 / p; + hiprandStatePhilox4_32_10_t state; + hiprand_init(seed, i, 0, &state); + float rand = hiprand_uniform(&state); + if (input_ptr[i] < p) { + output_ptr[i] = 0; + } else { + output_ptr[i] = input_ptr[i] * scale; + } + } +} + +__global__ void dropout_backward_kernel(float p, + long long seed, + size_t num_elements, + float const *input_ptr, + float *output_ptr) { + CUDA_KERNEL_LOOP(i, num_elements) { + float scale = 1.0 / p; + hiprandStatePhilox4_32_10_t state; + hiprand_init(seed, i, 0, &state); + float rand = hiprand_uniform(&state); + if (input_ptr[i] < p) { + output_ptr[i] = 0; + } else { + output_ptr[i] = input_ptr[i] * scale; + } + } +} + void forward_kernel_wrapper(DropoutMeta *m, - float const *input_ptr, - float *output_ptr) { + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { hipStream_t stream; checkCUDA(get_legion_stream(&stream)); - Internal::forward_kernel(m, input_ptr, output_ptr, stream); + + Internal::forward_kernel(m, + input.get_float_ptr(), + output.get_float_ptr(), + input.domain.get_volume(), + stream); + + // printf("dropout %d\n", input.domain.get_volume()); + // assert(false); } void backward_kernel_wrapper(DropoutMeta *m, - float const *output_grad_ptr, - float *input_grad_ptr) { + GenericTensorAccessorR const &output_grad, + GenericTensorAccessorW const &input_grad) { hipStream_t stream; checkCUDA(get_legion_stream(&stream)); - Internal::backward_kernel(m, output_grad_ptr, input_grad_ptr, stream); + Internal::backward_kernel(m, + 
output_grad.get_float_ptr(), + input_grad.get_float_ptr(), + output_grad.domain.get_volume(), + stream); } namespace Internal { @@ -99,35 +152,58 @@ namespace Internal { void forward_kernel(DropoutMeta *m, float const *input_ptr, float *output_ptr, + size_t num_elements, hipStream_t stream) { checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + int parallelism = num_elements; + hipLaunchKernelGGL(HIP_KERNEL_NAME(dropout_forward_kernel), + GET_BLOCKS(parallelism), + min(CUDA_NUM_THREADS, parallelism), + 0, + stream, + m->seed, + m->rate, + num_elements, + input_ptr, + output_ptr); - checkCUDNN(miopenDropoutForward(m->handle.dnn, - m->dropoutDesc, - m->inputTensor /* not used */, - m->inputTensor, - input_ptr, - m->outputTensor, - output_ptr, - m->reserveSpace, - m->reserveSpaceSize)); + // checkCUDNN(miopenDropoutForward(m->handle.dnn, + // m->dropoutDesc, + // m->inputTensor /* not used */, + // m->inputTensor, + // input_ptr, + // m->outputTensor, + // output_ptr, + // m->reserveSpace, + // m->reserveSpaceSize)); } void backward_kernel(DropoutMeta *m, float const *output_grad_ptr, float *input_grad_ptr, + size_t num_elements, hipStream_t stream) { checkCUDNN(miopenSetStream(m->handle.dnn, stream)); - - checkCUDNN(miopenDropoutBackward(m->handle.dnn, - m->dropoutDesc, - m->inputTensor /* not used */, - m->outputTensor, - output_grad_ptr, - m->inputTensor, - input_grad_ptr, - m->reserveSpace, - m->reserveSpaceSize)); + int parallelism = num_elements; + hipLaunchKernelGGL(HIP_KERNEL_NAME(dropout_backward_kernel), + GET_BLOCKS(parallelism), + min(CUDA_NUM_THREADS, parallelism), + 0, + stream, + m->seed, + m->rate, + num_elements, + output_grad_ptr, + input_grad_ptr); + // checkCUDNN(miopenDropoutBackward(m->handle.dnn, + // m->dropoutDesc, + // m->inputTensor /* not used */, + // m->outputTensor, + // output_grad_ptr, + // m->inputTensor, + // input_grad_ptr, + // m->reserveSpace, + // m->reserveSpaceSize)); } } // namespace Internal diff --git a/src/ops/kernels/dropout_kernels.cu b/src/ops/kernels/dropout_kernels.cu index 4a76301fd6..c5b1a384df 100644 --- a/src/ops/kernels/dropout_kernels.cu +++ b/src/ops/kernels/dropout_kernels.cu @@ -29,6 +29,10 @@ DropoutMeta::DropoutMeta(FFHandler handler, Domain const &output_domain) : OpMeta(handler) { profiling = dropout->profiling; + rate = dropout->rate; + seed = dropout->seed; + input_type[0] = dropout->data_type; + output_type[0] = dropout->data_type; checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); checkCUDNN(cudnnCreateDropoutDescriptor(&dropoutDesc)); @@ -74,20 +78,97 @@ DropoutMeta::~DropoutMeta(void) { namespace Kernels { namespace Dropout { +__global__ void dropout_forward_kernel(float p, + long long seed, + size_t num_elements, + float const *input_ptr, + float *output_ptr) { + CUDA_KERNEL_LOOP(i, num_elements) { + float scale = 1.0 / p; + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + float rand = curand_uniform(&state); + if (input_ptr[i] < p) { + output_ptr[i] = 0; + } else { + output_ptr[i] = input_ptr[i] * scale; + } + } +} + +__global__ void dropout_backward_kernel(float p, + long long seed, + size_t num_elements, + float const *input_ptr, + float *output_ptr) { + CUDA_KERNEL_LOOP(i, num_elements) { + float scale = 1.0 / p; + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + float rand = curand_uniform(&state); + if (input_ptr[i] < p) { + output_ptr[i] = 0; + } else { + output_ptr[i] = input_ptr[i] * scale; + } + } +} + void 
forward_kernel_wrapper(DropoutMeta *m, - float const *input_ptr, - float *output_ptr) { + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - Internal::forward_kernel(m, input_ptr, output_ptr, stream); + + cudaEvent_t t_start, t_end; + if (m->profiling) { + cudaEventCreate(&t_start); + cudaEventCreate(&t_end); + cudaEventRecord(t_start, stream); + } + + Internal::forward_kernel(m, + input.get_float_ptr(), + output.get_float_ptr(), + input.domain.get_volume(), + stream); + if (m->profiling) { + cudaEventRecord(t_end, stream); + checkCUDA(cudaEventSynchronize(t_end)); + float elapsed = 0; + checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); + cudaEventDestroy(t_start); + cudaEventDestroy(t_end); + printf(" [dropout] forward time = %.2lfms\n", elapsed); + } } void backward_kernel_wrapper(DropoutMeta *m, - float const *output_grad_ptr, - float *input_grad_ptr) { + GenericTensorAccessorR const &output_grad, + GenericTensorAccessorW const &input_grad) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - Internal::backward_kernel(m, output_grad_ptr, input_grad_ptr, stream); + + cudaEvent_t t_start, t_end; + if (m->profiling) { + cudaEventCreate(&t_start); + cudaEventCreate(&t_end); + cudaEventRecord(t_start, stream); + } + Internal::backward_kernel(m, + output_grad.get_float_ptr(), + input_grad.get_float_ptr(), + output_grad.domain.get_volume(), + stream); + if (m->profiling) { + cudaEventRecord(t_end, stream); + checkCUDA(cudaEventSynchronize(t_end)); + float elapsed = 0; + checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); + cudaEventDestroy(t_start); + cudaEventDestroy(t_end); + printf(" [dropout] backward time = %.2lfms\n", elapsed); + } } namespace Internal { @@ -95,33 +176,48 @@ namespace Internal { void forward_kernel(DropoutMeta *m, float const *input_ptr, float *output_ptr, + size_t num_elements, cudaStream_t stream) { checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - checkCUDNN(cudnnDropoutForward(m->handle.dnn, - m->dropoutDesc, - m->inputTensor, - input_ptr, - m->outputTensor, - output_ptr, - m->reserveSpace, - m->reserveSpaceSize)); + int parallelism = num_elements; + dropout_forward_kernel<<>>( + m->seed, m->rate, num_elements, input_ptr, output_ptr); + + // checkCUDNN(cudnnDropoutForward(m->handle.dnn, + // m->dropoutDesc, + // m->inputTensor, + // input_ptr, + // m->outputTensor, + // output_ptr, + // m->reserveSpace, + // m->reserveSpaceSize)); } void backward_kernel(DropoutMeta *m, float const *output_grad_ptr, float *input_grad_ptr, + size_t num_elements, cudaStream_t stream) { checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + int parallelism = num_elements; + dropout_backward_kernel<<>>( + m->seed, m->rate, num_elements, output_grad_ptr, input_grad_ptr); - checkCUDNN(cudnnDropoutBackward(m->handle.dnn, - m->dropoutDesc, - m->outputTensor, - output_grad_ptr, - m->inputTensor, - input_grad_ptr, - m->reserveSpace, - m->reserveSpaceSize)); + // checkCUDNN(cudnnDropoutBackward(m->handle.dnn, + // m->dropoutDesc, + // m->outputTensor, + // output_grad_ptr, + // m->inputTensor, + // input_grad_ptr, + // m->reserveSpace, + // m->reserveSpaceSize)); } } // namespace Internal diff --git a/src/ops/layer_norm.cc b/src/ops/layer_norm.cc index a6130bb425..ccbd7c2dd6 100644 --- a/src/ops/layer_norm.cc +++ b/src/ops/layer_norm.cc @@ -226,10 +226,10 @@ LayerNorm::LayerNorm(FFModel &model, dims[i] = inputs[0]->dims[i]; } assert(numInputs == 1); - dims[num_dims].degree = 
inputs[0]->dims[inputs[0]->num_dims - 2].degree; + dims[num_dims].degree = inputs[0]->dims[inputs[0]->num_dims - 1].degree; dims[num_dims].size = dims[num_dims].degree; dims[num_dims].parallel_idx = - inputs[0]->dims[inputs[0]->num_dims - 2].parallel_idx; + inputs[0]->dims[inputs[0]->num_dims - 1].parallel_idx; dims[num_dims].is_replica_dim = true; num_dims += 1; diff --git a/src/ops/linear.cc b/src/ops/linear.cc index d5257b9c3e..20cc193107 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -190,6 +190,23 @@ Linear::Linear(FFModel &model, params.construct_mappings(*this->parallel_dims_mapping, input_shape); params.solve_dims(input_shape, output_shape, kernel_shape, bias_shape); + kernel_shape.dims[0].size = this->in_channels; + bias_shape.dims[0].degree = _input->dims[_input->num_dims - 1].degree; + bias_shape.dims[0].parallel_idx = + _input->dims[_input->num_dims - 1].parallel_idx; + bias_shape.dims[1].size = bias_shape.dims[1].degree = 1; + bias_shape.dims[1].parallel_idx = -1; + bias_shape.dims[bias_shape.num_dims - 1].size = + bias_shape.dims[bias_shape.num_dims - 1].degree = 1; + for (int i = 0; i < input_shape.num_dims - 1; i++) { + if (_input->dims[i].degree > 1) { + bias_shape.dims[bias_shape.num_dims - 1].size *= _input->dims[i].degree; + bias_shape.dims[bias_shape.num_dims - 1].degree *= _input->dims[i].degree; + bias_shape.dims[bias_shape.num_dims - 1].parallel_idx = + _input->dims[i].parallel_idx; + } + } + if (allocate_weights) { Initializer *kernel_initializer = new GlorotUniform(std::rand() /*seed*/); @@ -220,7 +237,6 @@ Linear::Linear(FFModel &model, outputs[0] = model.create_parallel_tensor_legion_ordering( output_shape.num_dims, output_shape.dims, _data_type, this); - assert(check_output_input_weight_parallel_dims(allocate_weights)); } void Linear::init(FFModel const &ff) { @@ -433,7 +449,7 @@ void Linear::forward_task_with_dim(Task const *task, int out_dim = acc_output.rect.hi[0] - acc_output.rect.lo[0] + 1; int batch_size = acc_output.rect.volume() / out_dim; assert(acc_output.rect.volume() == static_cast(out_dim * batch_size)); - assert(acc_input.rect.volume() == static_cast(in_dim * batch_size)); + // assert(acc_input.rect.volume() == static_cast(in_dim * batch_size)); assert(acc_kernel.rect.volume() == static_cast(in_dim * out_dim)); float const *acc_bias_ptr = NULL; if (m->use_bias) { diff --git a/src/ops/reshape.cc b/src/ops/reshape.cc index 2b8a60bf21..07797bd223 100644 --- a/src/ops/reshape.cc +++ b/src/ops/reshape.cc @@ -80,9 +80,14 @@ Op *Reshape::create_operator_from_layer( return new Reshape(model, layer->layer_guid, inputs[0], shape, layer->name); } +bool match_pattern(std::vector const &_shape) { + return (_shape.size() == 4 && _shape[1] == 1 && _shape[2] == 1 && + _shape[3] == 512); +} + Reshape::Reshape(FFModel &model, LayerID const &_layer_guid, - const ParallelTensor input, + ParallelTensor const input, std::vector const &_shape, char const *name) : Op(model, @@ -106,19 +111,64 @@ Reshape::Reshape(FFModel &model, if (input->dims[i].is_replica_dim) { num_replica_dims++; } + // std::cout << "reshape input size: " << input->dims[i].size + // << ", parallelidx: " << input->dims[i].parallel_idx << ". 
degree: " << input->dims[i].degree + // << "is replicate dim: " << input->dims[i].is_replica_dim << + // "\n"; } + + // assert(false); // assert that all replica dims are leading dims for (int i = 0; i < num_replica_dims; i++) { assert(input->dims[input->num_dims - 1 - i].is_replica_dim); } int numdim = (int)_shape.size(); ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i].size = _shape[numdim - 1 - i]; - dims[i].degree = 1; - dims[i].parallel_idx = -1; - dims[i].is_replica_dim = false; - } + + + bool expanded = numdim >= input->num_dims; + bool aggregation = numdim < input->num_dims - 1; + + for (int i = 0; i < numdim; i++) { + if (expanded && i < numdim - 1 && + _shape[i] * _shape[i + 1] == input->dims[numdim - i - 2].size) { + dims[numdim - i - 1].size = _shape[i]; + dims[numdim - i - 1].degree = input->dims[numdim - i - 2].degree; + dims[numdim - i - 1].parallel_idx = + input->dims[numdim - i - 2].parallel_idx; + dims[numdim - i - 1].is_replica_dim = + input->dims[numdim - i - 2].is_replica_dim; + std::cout << "expand dim i:" << i << ", " << dims[numdim - i - 1].degree + << ", " << dims[numdim - i - 1].size << "\n"; + } else if (aggregation && + (_shape[i] == input->dims[input->num_dims - 2 - i].size * + input->dims[input->num_dims - 3 - i].size)) { + // inherit + dims[numdim - i - 1].size = _shape[i]; + dims[numdim - i - 1].degree = + input->dims[input->num_dims - 2 - i].degree; + dims[numdim - i - 1].parallel_idx = + input->dims[input->num_dims - 2 - i].parallel_idx; + dims[numdim - i - 1].is_replica_dim = + input->dims[input->num_dims - 2 - i].is_replica_dim; + // std::cout << "agree i: " << i <<", " << _shape[i] << "\n"; + } else { + dims[numdim - i - 1].size = _shape[i]; + dims[numdim - i - 1].degree = 1; + dims[numdim - i - 1].parallel_idx = -1; + dims[numdim - i - 1].is_replica_dim = false; + } + } + + + + + // for (int i = 0; i < numdim; i++) { + // dims[i].size = _shape[numdim - 1 - i]; + // dims[i].degree = 1; + // dims[i].parallel_idx = -1; + // dims[i].is_replica_dim = false; + // } // copy all replica dims for (int i = 0; i < num_replica_dims; i++) { dims[i + numdim] = input->dims[input->num_dims - 1 - i]; @@ -131,6 +181,24 @@ Reshape::Reshape(FFModel &model, } dims[numdim - 1 - i] = input->dims[input->num_dims - 1 - i]; } + + //TODO temporary fix for input to attention QK, fix it after fuse the attention block + if(match_pattern(_shape) && model.config.tensor_parallelism_degree > 1){ + //number of heads + + dims[2].size = 12; + dims[2].degree = model.config.tensor_parallelism_degree; + dims[2].parallel_idx = 0; + dims[2].is_replica_dim = true; + + dims[4].size = 1; + dims[4].degree = 1; + dims[4].parallel_idx = -1; + dims[4].is_replica_dim = false; + + } + + outputs[0] = model.create_parallel_tensor_legion_ordering( numdim, dims, input->data_type, this); assert(outputs[0]->get_volume() == inputs[0]->get_volume()); diff --git a/src/parallel_ops/allreduce.cc b/src/parallel_ops/allreduce.cc new file mode 100644 index 0000000000..7052bb3ed5 --- /dev/null +++ b/src/parallel_ops/allreduce.cc @@ -0,0 +1,280 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flexflow/parallel_ops/allreduce.h" +#include "flexflow/ffconst_utils.h" +#include "flexflow/model.h" +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" +#include "flexflow/utils/hash_utils.h" + +namespace FlexFlow { +// declare Legion names +using Legion::ArgumentMap; +using Legion::Context; +using Legion::coord_t; +using Legion::Domain; +using Legion::Future; +using Legion::FutureMap; +using Legion::IndexLauncher; +using Legion::LogicalPartition; +using Legion::LogicalRegion; +using Legion::Machine; +using Legion::Memory; +using Legion::PhysicalRegion; +using Legion::Predicate; +using Legion::Rect; +using Legion::RegionRequirement; +using Legion::Runtime; +using Legion::Task; +using Legion::TaskArgument; +using Legion::TaskLauncher; + +using namespace FlexFlow::Kernels::AllReduce; + +/* Params */ +bool operator==(AllReduceParams const &lhs, AllReduceParams const &rhs) { + return lhs.allreduce_legion_dim == rhs.allreduce_legion_dim; +} + +bool AllReduceParams::is_valid(ParallelTensorShape const &input) const { + return input.is_valid(); +} + +AllReduceParams AllReduce::get_params() const { + AllReduceParams params; + params.allreduce_legion_dim = this->allreduce_dim; + if (this->name != nullptr) { + strcpy(params.name, this->name); + } + return params; +} + +AllReduce::AllReduce(FFModel &model, + const ParallelTensor _input, + int _allreduce_legion_dim, + char const *name) + : ParallelOp(model, OP_ALLREDUCE, name, _input), + allreduce_dim(_allreduce_legion_dim) { + int numdim = _input->num_dims; + ParallelDim dims[MAX_TENSOR_DIM]; + for (int i = 0; i < numdim; i++) { + dims[i] = _input->dims[i]; + } + assert(dims[allreduce_dim].degree > 1); + // ParallelTensorBase::update_parallel_ids(numdim, dims); + outputs[0] = model.create_parallel_tensor_legion_ordering( + numdim, dims, _input->data_type, this); +} + +AllReduce::AllReduce(FFModel &model, + AllReduceParams const ¶ms, + ParallelTensor const input, + char const *name) + : AllReduce(model, input, params.allreduce_legion_dim, params.name) {} + +void AllReduce::create_input_partition(FFModel &ff) { + // Do nothing + return; +} + + +OpMeta *AllReduce::init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + AllReduce *ar = (AllReduce *)task->args; + FFHandler handle = *((FFHandler const *)task->local_args); + AllReduceMeta *meta = new AllReduceMeta(handle, ar); + meta->input_type[0] = ar->inputs[0]->data_type; + meta->output_type[0] = ar->outputs[0]->data_type; + assert(meta->input_type[0] == meta->output_type[0]); + return meta; +} + +void AllReduce::init(FFModel const &ff) { + ArgumentMap argmap; + parallel_is = outputs[0]->parallel_is; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + assert(numOutputs == 1); + assert(numInputs == 1); + set_argumentmap_for_init(ff, argmap); + IndexLauncher launcher(ALLREDUCE_INIT_TASK_ID, + parallel_is, + TaskArgument(this, sizeof(AllReduce)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + outputs[0]->machine_view.hash()); + 
launcher.add_region_requirement(RegionRequirement(inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + outputs[0]->region)); + launcher.add_field(1, FID_DATA); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); + set_opmeta_from_futuremap(ff, fm); +} + +void AllReduce::forward(FFModel const &ff) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + parallel_is = outputs[0]->parallel_is; + assert(numOutputs == 1); + assert(numInputs == 1); + set_argumentmap_for_forward(ff, argmap); + IndexLauncher launcher(ALLREDUCE_FWD_TASK_ID, + outputs[0]->parallel_is, + TaskArgument(NULL, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + outputs[0]->machine_view.hash()); + launcher.add_region_requirement(RegionRequirement(inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + outputs[0]->region)); + launcher.add_field(1, FID_DATA); + runtime->execute_index_space(ctx, launcher); +} + +void AllReduce::backward(FFModel const &ff) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + assert(numOutputs == 1); + assert(numInputs == 1); + set_argumentmap_for_backward(ff, argmap); + IndexLauncher launcher(ALLREDUCE_BWD_TASK_ID, + inputs[0]->parallel_is, + TaskArgument(NULL, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + inputs[0]->machine_view.hash()); + launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + inputs[0]->region_grad)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + outputs[0]->region_grad)); + launcher.add_field(1, FID_DATA); + runtime->execute_index_space(ctx, launcher); +} + +bool AllReduce::measure_operator_cost(Simulator *sim, + MachineView const &pc, + CostMetrics &cost_metrics) const { + cost_metrics = CostMetrics(); + cost_metrics.forward_time = 0.0f; + cost_metrics.backward_time = 0.0f; + + cost_metrics.sync_time = 0; + cost_metrics.inputs_memory = 0; + cost_metrics.outputs_memory = 0; + cost_metrics.weights_memory = 0; + return true; +} + +bool AllReduce::get_int_parameter(PMParameter para, int *value) const { + switch (para) { + case PM_ALLREDUCE_DIM: + *value = allreduce_dim; + return true; + default: + return Op::get_int_parameter(para, value); + } +} + +bool AllReduce::append_parallel_op_info( + std::vector ¶llel_ops) const { + ParallelOpInfo ret; + ret.op_type = op_type; + ret.parallel_dim = allreduce_dim; + ret.parallel_degree = -1; // AllReduce does not affect parallel degree + parallel_ops.push_back(ret); + return true; +} + +/*static*/ +void AllReduce::forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 2); + assert(task->regions.size() == 2); + + AllReduceMeta const *m = *((AllReduceMeta **)task->local_args); + + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + 
GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + + assert(input.data_type == output.data_type); + forward_kernel_wrapper(m, input, output); +} + +void AllReduce::backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 2); + assert(task->regions.size() == 2); + AllReduceMeta const *m = *((AllReduceMeta **)task->local_args); + + GenericTensorAccessorW input_grad = helperGetGenericTensorAccessorRW( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + + assert(input_grad.data_type == output_grad.data_type); + backward_kernel_wrapper(m, input_grad, output_grad); +} + +}; // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::AllReduceParams const ¶ms) const { + size_t key = 0; + hash_combine(key, params.allreduce_legion_dim); + return key; +} + +} // namespace std diff --git a/src/parallel_ops/kernels/allreduce_kernels.cpp b/src/parallel_ops/kernels/allreduce_kernels.cpp new file mode 100644 index 0000000000..0aea27107d --- /dev/null +++ b/src/parallel_ops/kernels/allreduce_kernels.cpp @@ -0,0 +1,58 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" +#include "flexflow/utils/hip_helper.h" +#include + +namespace FlexFlow { + +AllReduceMeta::AllReduceMeta(FFHandler handle, AllReduce const *reduct) + : OpMeta(handle) {} + +namespace Kernels { +namespace AllReduce { + +void forward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + assert(input.data_type == output.data_type); + assert(input.domain == output.domain); + size_t hidden_dim_size = input.domain.hi()[0] - input.domain.lo()[0] + 1; +#ifdef FF_USE_NCCL + // ncclDataType_t nccl_data_type = ff_to_nccl_datatype(input.data_type); + checkNCCL(ncclAllReduce(input.ptr, + output.ptr, + input.domain.get_volume(), + ncclFloat, + ncclSum, + m->handle.ncclComm, + stream)); +#else + assert(false && "Must enable FF_USE_NCCL to use AllReduce operators"); +#endif +} + +void backward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad) { + assert(false && "To be implemented"); +} + +} // namespace AllReduce +} // namespace Kernels +} // namespace FlexFlow diff --git a/src/parallel_ops/kernels/allreduce_kernels.cu b/src/parallel_ops/kernels/allreduce_kernels.cu new file mode 100644 index 0000000000..1e932d2b12 --- /dev/null +++ b/src/parallel_ops/kernels/allreduce_kernels.cu @@ -0,0 +1,73 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" +#include "flexflow/utils/cuda_helper.h" + +namespace FlexFlow { + +AllReduceMeta::AllReduceMeta(FFHandler handle, AllReduce const *reduct) + : OpMeta(handle) {} + +namespace Kernels { +namespace AllReduce { + +void forward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + assert(input.data_type == output.data_type); + assert(input.domain == output.domain); +#ifdef FF_USE_NCCL + // ncclDataType_t nccl_data_type = ff_to_nccl_datatype(input.data_type); + checkNCCL(ncclAllReduce(input.ptr, + output.ptr, + input.domain.get_volume(), + ncclFloat, + ncclSum, + m->handle.ncclComm, + stream)); +#else + assert(false && "Must enable FF_USE_NCCL to use AllReduce operators"); +#endif +} + +void backward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + assert(input_grad.data_type == output_grad.data_type); + assert(input_grad.domain == output_grad.domain); +#ifdef FF_USE_NCCL + // ncclDataType_t nccl_data_type = ff_to_nccl_datatype(input.data_type); + // std::cout <<"input volume: " << input.domain.get_volume() << "\n"; + // print_tensor((float*)input.ptr, 32, "input ptr"); + checkNCCL(ncclAllReduce(output_grad.ptr, + input_grad.ptr, + output_grad.domain.get_volume(), + ncclFloat, + ncclSum, + m->handle.ncclComm, + stream)); +#else + assert(false && "Must enable FF_USE_NCCL to use AllReduce operators"); +#endif +} + +} // namespace AllReduce +} // namespace Kernels +} // namespace FlexFlow diff --git a/src/runtime/cuda_helper.cu b/src/runtime/cuda_helper.cu index 3236545cb9..a4a58e60fc 100644 --- a/src/runtime/cuda_helper.cu +++ b/src/runtime/cuda_helper.cu @@ -61,6 +61,15 @@ __global__ void copy_kernel(DT *dst, const DT *src, coord_t size) { dst[i] = src[i]; } } +template +__global__ void copy_kernel_with_replicate(DT *dst, + const DT *src, + coord_t origin_size, + coord_t size) { + CUDA_KERNEL_LOOP(i, size) { + dst[i] = src[i % origin_size]; + } +} template __global__ void reluBackward(DT *grad_ptr, const DT *output, size_t n) { @@ -410,6 +419,14 @@ template __global__ void template __global__ void copy_kernel(float *dst, float const *src, coord_t size); +template __global__ void copy_kernel_with_replicate(float *dst, + float const *src, + coord_t origin_size, + coord_t size); +template __global__ void copy_kernel_with_replicate( + int32_t *dst, int32_t const *src, coord_t origin_size, coord_t size); +template __global__ void copy_kernel_with_replicate( + int64_t *dst, int64_t const *src, coord_t origin_size, coord_t size); template __global__ void copy_kernel(int32_t *dst, int32_t const *src, coord_t size); template __global__ void diff --git a/src/runtime/ffconst_utils.cc b/src/runtime/ffconst_utils.cc index 7ab9201113..e2debfa2d5 100644 --- a/src/runtime/ffconst_utils.cc +++ b/src/runtime/ffconst_utils.cc @@ -168,6 +168,8 @@ std::string get_operator_type_name(OperatorType type) { return "Replicate"; case OP_REDUCTION: return "Reduction"; + case OP_ALLREDUCE: + return "AllReduce"; case OP_PIPELINE: return "Pipeline"; case OP_FUSED_PARALLEL: diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index 7447125197..762c5911d6 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -39,6 +39,7 @@ #include "flexflow/ops/topk.h" #include 
"flexflow/ops/transpose.h" #include "flexflow/parallel_ops/combine.h" +#include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/fused_parallel_op.h" #include "flexflow/parallel_ops/partition.h" #include "flexflow/parallel_ops/reduction.h" @@ -1927,9 +1928,37 @@ std::pair, std::unordered_map> model->config.numNodes * model->config.workersPerNode; data_parallel_view.stride[0] = 1; data_parallel_view.start_device_id = 0; + // Currently assume a 1D machine view is needed + assert(model->config.data_parallelism_degree == 1 || + model->config.tensor_parallelism_degree == 1); + int degree = model->config.data_parallelism_degree * + model->config.tensor_parallelism_degree; for (auto const &node : curr_best_graph->inEdges) { - curr_optimal_views[node.first] = data_parallel_view; - } + Op const *op = node.first.ptr; + MachineView mv; + mv.device_type = MachineView::GPU; + mv.ndims = 1; + int total_parallel_degree = 1; + for (int i = 0; i < op->outputs[0]->num_dims; i++) { + total_parallel_degree *= op->outputs[0]->dims[i].degree; + } + mv.dim[0] = total_parallel_degree; + mv.stride[0] = 1; + mv.start_device_id = 0; + // std::cout << mv.start_device_id + degree - 1 << "\n"; + // std::cout << model->config.numNodes << "\n"; + // std::cout << model->config.workersPerNode << "\n"; + // assert(false); + assert(mv.start_device_id + degree - 1 < + model->config.numNodes * model->config.workersPerNode); + curr_optimal_views[node.first] = mv; + for (int i = 0; i < node.first.ptr->numOutputs; i++) { + assert(node.first.ptr->outputs[i]->is_valid_machine_view(mv)); + } + } + // for (auto const &node : curr_best_graph->inEdges) { + // curr_optimal_views[node.first] = data_parallel_view; + // } return std::make_pair(std::move(curr_best_graph), curr_optimal_views); } @@ -2295,6 +2324,13 @@ GraphOptimalViewSerialized sez.serialize(reduction->reduction_dim); sez.serialize(reduction->reduction_degree); break; + } + case OP_ALLREDUCE: { + AllReduce *allreduce = (AllReduce *)op; + sez.serialize(allreduce->allreduce_dim); + sez.serialize(strlen(allreduce->name)); + sez.serialize(allreduce->name, strlen(allreduce->name)); + break; } case OP_COMBINE: { Combine *combine = (Combine *)op; @@ -2704,6 +2740,17 @@ void FFModel::deserialize_graph_optimal_view( {reduction_dim, reduction_degree}); break; } + case OP_ALLREDUCE: { + assert(num_inputs == 1); + int allreduce_dim; + dez.deserialize(allreduce_dim); + size_t name_len; + char name[MAX_OPNAME] = {0}; + dez.deserialize(name_len); + dez.deserialize(name, name_len); + node = get_or_create_node(inputs[0], {allreduce_dim}); + break; + } case OP_FUSED_PARALLEL: { assert(num_inputs == 1); std::vector parallel_ops; diff --git a/src/runtime/hip_helper.cpp b/src/runtime/hip_helper.cpp index ffdcf0dac1..8617cb2ef3 100644 --- a/src/runtime/hip_helper.cpp +++ b/src/runtime/hip_helper.cpp @@ -55,6 +55,16 @@ __global__ void copy_kernel(DT *dst, const DT *src, coord_t size) { } } +template +__global__ void copy_kernel_with_replicate(DT *dst, + const DT *src, + coord_t origin_size, + coord_t size) { + CUDA_KERNEL_LOOP(i, size) { + dst[i] = src[i % origin_size]; + } +} + template __global__ void reluBackward(DT *grad_ptr, const DT *output, size_t n) { CUDA_KERNEL_LOOP(i, n) { @@ -404,6 +414,16 @@ template __global__ void template __global__ void copy_kernel(float *dst, float const *src, coord_t size); + +template __global__ void copy_kernel_with_replicate(float *dst, + float const *src, + coord_t origin_size, + coord_t size); +template __global__ void 
copy_kernel_with_replicate( + int32_t *dst, int32_t const *src, coord_t origin_size, coord_t size); +template __global__ void copy_kernel_with_replicate( + int64_t *dst, int64_t const *src, coord_t origin_size, coord_t size); + template __global__ void copy_kernel(int32_t *dst, int32_t const *src, coord_t size); template __global__ void diff --git a/src/runtime/model.cc b/src/runtime/model.cc index c52d5e7f0c..6feddcd03c 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -50,6 +50,7 @@ #include "flexflow/ops/split.h" #include "flexflow/ops/topk.h" #include "flexflow/ops/transpose.h" +#include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/combine.h" #include "flexflow/parallel_ops/fused_parallel_op.h" #include "flexflow/parallel_ops/partition.h" @@ -78,10 +79,10 @@ Op::Op(FFModel &model, int numWeights, bool allocate_weights, int numOutputs, - const ParallelTensor input1, - const ParallelTensor input2, - const ParallelTensor input3, - const ParallelTensor input4) + ParallelTensor const input1, + ParallelTensor const input2, + ParallelTensor const input3, + ParallelTensor const input4) : Op(model, otype, dtype, @@ -101,10 +102,10 @@ Op::Op(FFModel &model, int _numInputs, int _numWeights, int _numOutputs, - const ParallelTensor _input1, - const ParallelTensor _input2, - const ParallelTensor _input3, - const ParallelTensor _input4) + ParallelTensor const _input1, + ParallelTensor const _input2, + ParallelTensor const _input3, + ParallelTensor const _input4) : op_type(_otype), data_type(_dtype), op_guid(model.op_global_guid++), numInputs(_numInputs), numWeights(_numWeights), numOutputs(_numOutputs), profiling(model.config.profiling) { @@ -586,8 +587,13 @@ ncclComm_t Op::init_nccl_comms_task(Task const *task, ncclComm_t ncclComm; fprintf(stderr, "Before ncclCommInitRank\n"); checkNCCL(ncclCommInitRank(&ncclComm, allRanks, ncclId, myRank)); - fprintf(stderr, "After ncclCommInitRank ncclComm(%p) allRanks(%d) myRank(%d) ncclId(%p)\n", - ncclComm, allRanks, myRank, ncclId); + fprintf(stderr, + "After ncclCommInitRank ncclComm(%p) allRanks(%d) myRank(%d) " + "ncclId(%p)\n", + ncclComm, + allRanks, + myRank, + ncclId); return ncclComm; } #endif @@ -782,9 +788,9 @@ void Op::register_output_parallel_dims( operation); } -int Op::get_output_to_input_dim_mapping(const ParallelTensor output, +int Op::get_output_to_input_dim_mapping(ParallelTensor const output, int output_dim, - const ParallelTensor input) { + ParallelTensor const input) { int output_idx = -1, input_idx = -1; for (int i = 0; i < numOutputs; i++) { if (output == outputs[i]) { @@ -817,9 +823,9 @@ int Op::get_output_to_input_dim_mapping(const ParallelTensor output, return -1; } -int Op::get_output_to_weight_dim_mapping(const ParallelTensor output, +int Op::get_output_to_weight_dim_mapping(ParallelTensor const output, int output_dim, - const ParallelTensor weight) { + ParallelTensor const weight) { int output_idx = -1, weight_idx = -1; for (int i = 0; i < numOutputs; i++) { if (output == outputs[i]) { @@ -885,6 +891,9 @@ bool Op::check_output_input_weight_parallel_dims(bool allocate_weights) const { break; } + printf("other dim degree: %d, input dim degree %d\n", + other_dim.degree, + input_dim.degree); assert(other_dim.degree == input_dim.degree); assert(other_dim.parallel_idx == input_dim.parallel_idx); } @@ -894,18 +903,25 @@ bool Op::check_output_input_weight_parallel_dims(bool allocate_weights) const { bool Op::check_output_input_weight_same_parallel_is() const { assert(numOutputs > 0); IndexSpace 
parallel_is = outputs[0]->parallel_is; + std::cout << "output space: " + << ", " << parallel_is << "\n"; for (int i = 0; i < numOutputs; i++) { if (outputs[i]->parallel_is != parallel_is) { + std::cout << "output mismatch" + << "\n"; return false; } } for (int i = 0; i < numInputs; i++) { + std::cout << "input space: " << i << ", " << inputs[i]->parallel_is << "\n"; if (inputs[i]->parallel_is != parallel_is) { return false; } } for (int i = 0; i < numWeights; i++) { if (weights[i]->parallel_is != parallel_is) { + std::cout << "weight mismatch" + << "\n"; return false; } } @@ -947,7 +963,7 @@ void Op::set_argumentmap_for_init(FFModel const &ff, ArgumentMap &argmap) { for (PointInRectIterator it(rect); it(); it++) { \ FFHandler handle = ff.handlers[view.get_device_id(*it)]; \ if (ff.config.computationMode == COMP_MODE_TRAINING && \ - op_type == OP_WEIGHT) { \ + (op_type == OP_WEIGHT || op_type == OP_ALLREDUCE)) { \ ncclComm_t *nccl_comms = ff.find_nccl_comms(view); \ handle.ncclComm = nccl_comms[idx++]; \ } \ @@ -1191,9 +1207,11 @@ FFModel::FFModel(FFConfig &_config) //} ArgumentMap argmap; - Rect<1> task_rect(Point<1>(0), - Point<1>(config.workersPerNode * config.numNodes - 1)); - IndexSpaceT<1> task_is = runtime->create_index_space(ctx, task_rect); + // Rect<1> task_rect(Point<1>(0), + // Point<1>(config.workersPerNode * config.numNodes - 1)); + // IndexSpaceT<1> task_is = runtime->create_index_space(ctx, task_rect); + Domain domain = runtime->get_index_space_domain(ctx, config.all_gpu_task_is); + Rect<1> task_rect = domain; // int rank = 0; for (PointInRectIterator<1> it(task_rect); it(); it++) { @@ -1207,7 +1225,7 @@ FFModel::FFModel(FFConfig &_config) // Init CUDA library on each worker IndexLauncher initLauncher(FF_INIT_TASK_ID, - task_is, + config.all_gpu_task_is, TaskArgument(NULL, 0), argmap, Predicate::TRUE_PRED, @@ -1300,7 +1318,7 @@ Tensor FFModel::create_tensor(int numdim, } ParallelTensor FFModel::create_parallel_tensor(int numdim, - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *op, int idx, @@ -1333,7 +1351,7 @@ Tensor FFModel::create_tensor_legion_ordering(int numdim, ParallelTensor FFModel::create_parallel_tensor_legion_ordering(int numdim, - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *op, int idx, @@ -1383,7 +1401,7 @@ Tensor FFModel::create_tensor(int const dims[], } template -ParallelTensor FFModel::create_parallel_tensor(const ParallelDim dims[], +ParallelTensor FFModel::create_parallel_tensor(ParallelDim const dims[], DataType data_type, Op const *owner_op, int owner_idx, @@ -1464,7 +1482,7 @@ Parameter FFModel::create_weight(int numdim, } template -ParallelParameter FFModel::create_parallel_weight(const ParallelDim dims[], +ParallelParameter FFModel::create_parallel_weight(ParallelDim const dims[], DataType data_type, Op const *owner_op, bool create_grad, @@ -1494,7 +1512,7 @@ ParallelParameter FFModel::create_parallel_weight(const ParallelDim dims[], } ParallelParameter FFModel::create_parallel_weight(int numdim, - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *owner_op, bool create_grad, @@ -1514,7 +1532,7 @@ ParallelParameter FFModel::create_parallel_weight(int numdim, ParallelParameter FFModel::create_parallel_weight_legion_ordering( int numdim, - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *owner_op, bool create_grad, @@ -1719,7 +1737,7 @@ void FFModel::map_weight_with_dim(ParallelTensor weight, } bool 
FFModel::get_parallel_tensor_from_tensor( - const Tensor tensor, ParallelTensor ¶llel_tensor) const { + Tensor const tensor, ParallelTensor ¶llel_tensor) const { // check if tensor->parallel_tensor is already set if (tensor->parallel_tensor != nullptr) { parallel_tensor = tensor->parallel_tensor; @@ -1756,7 +1774,7 @@ bool FFModel::get_parallel_tensor_from_tensor( } void FFModel::create_disjoint_partition(int num_dims, - const ParallelDim dims[], + ParallelDim const dims[], IndexSpace const &part_is, LogicalRegion const ®ion, LogicalPartition &part) { @@ -1779,7 +1797,7 @@ void FFModel::create_disjoint_partition(int num_dims, template void FFModel::create_disjoint_partition_with_dim2( - const ParallelDim dims[], + ParallelDim const dims[], IndexSpaceT const &part_is, LogicalRegion const ®ion, LogicalPartition &part) { @@ -1812,7 +1830,7 @@ void FFModel::create_disjoint_partition_with_dim2( } void FFModel::create_aliased_partition(int num_dims, - const ParallelDim dims[], + ParallelDim const dims[], int aliased_dim, IndexSpace const &part_is, LogicalRegion const ®ion, @@ -1836,7 +1854,7 @@ void FFModel::create_aliased_partition(int num_dims, template void FFModel::create_aliased_partition_with_dim2( - const ParallelDim dims[], + ParallelDim const dims[], int aliased_dim, IndexSpaceT const &part_is, LogicalRegion const ®ion, @@ -1873,7 +1891,7 @@ void FFModel::create_aliased_partition_with_dim2( } template -void FFModel::create_disjoint_partition(const ParallelTensor tensor, +void FFModel::create_disjoint_partition(ParallelTensor const tensor, IndexSpaceT const &part_is, LogicalPartition &part_fwd, LogicalPartition &part_bwd) { @@ -1921,7 +1939,7 @@ void FFModel::create_disjoint_partition(const ParallelTensor tensor, template void FFModel::create_data_parallel_partition_with_diff_dims( - const ParallelTensor tensor, + ParallelTensor const tensor, IndexSpaceT const &part_is, LogicalPartition &part_fwd, LogicalPartition &part_bwd) { @@ -2303,7 +2321,7 @@ IndexSpace FFModel::get_task_is(ParallelConfig const &pc) const { return get_task_is(view); } -IndexSpace FFModel::get_or_create_task_is(const ParallelTensor tensor) { +IndexSpace FFModel::get_or_create_task_is(ParallelTensor const tensor) { MachineView view; view.ndims = 0; for (int i = 0; i < tensor->num_dims; i++) { @@ -2474,6 +2492,10 @@ void FFModel::update() { } } +void FFModel::unified_update() { + optimizer->unified_update(parameters); +} + Op *FFModel::get_final_operator() const { int idx = operators.size() - 1; while (operators[idx]->op_type == OP_INPUT || @@ -2504,9 +2526,10 @@ bool FFModel::apply_fusion(std::vector const &operators, operators[l]->op_type == OP_WEIGHT) { continue; } - // don't fuse parallel op since they have different parallel_is in - // forward/backward - if (operators[l]->is_parallel_op()) { + // don't fuse parallel op except allReduce since they have different + // parallel_is in forward/backward + if (operators[l]->is_parallel_op() && + operators[l]->op_type != OP_ALLREDUCE) { continue; } size_t start = 0; @@ -2549,9 +2572,10 @@ bool FFModel::apply_fusion(std::vector const &operators, operators[i]->op_type == OP_WEIGHT) { continue; } - // don't fuse parallel op since they have different parallel_is in - // forward/backward - if (operators[i]->is_parallel_op()) { + // don't fuse parallel op except allReduce since they have different + // parallel_is in forward/backward + if (operators[i]->is_parallel_op() && + operators[i]->op_type != OP_ALLREDUCE) { continue; } fused_op = new FusedOp(*this, operators[i]); @@ 
-2623,6 +2647,18 @@ Op *FFModel::create_operator_from_layer( dims[num_dims].degree = 1; dims[num_dims].parallel_idx = -1; dims[num_dims].is_replica_dim = true; + if (config.tensor_parallelism_degree > 1 && num_inputs != 1) { + dims[num_dims].size *= config.tensor_parallelism_degree; + dims[num_dims].degree *= config.tensor_parallelism_degree; + dims[num_dims].parallel_idx = 0; + } + //TODO temporary fix for input to attention QK, fix it after fuse the attention block + else if(config.tensor_parallelism_degree > 1){ + //n heads + dims[num_dims].size *= 12; + dims[num_dims].degree *= config.tensor_parallelism_degree; + dims[num_dims].parallel_idx = 0; + } // create_parallel_tensor adds an NoOp into operators ParallelTensor pt = create_parallel_tensor_legion_ordering(num_dims + 1, @@ -2636,19 +2672,14 @@ Op *FFModel::create_operator_from_layer( assert(tensor->parallel_tensor == nullptr); tensor->parallel_tensor = pt; // start from data parllel tensor - if (config.only_data_parallel && - config.numNodes * config.workersPerNode > 1) { - if (pt->dims[num_dims - 1].size == 1) { - Replicate *repl = new Replicate( - *this, pt, num_dims, config.numNodes * config.workersPerNode); - repl->outputs[0]->dims[num_dims].is_replica_dim = true; - operators.push_back(repl); - } else { - Repartition *part = new Repartition( - *this, pt, num_dims - 1, config.numNodes * config.workersPerNode); - operators.push_back(part); - } - } + // if (config.only_data_parallel && + // config.computationMode == COMP_MODE_TRAINING) { + // Repartition *part = new Repartition( + // *this, pt, num_dims - 1, config.numNodes * + // config.workersPerNode); + // operators.push_back(part); + // } + num_inputs++; return operators[operators.size() - 1]; } case OP_MULTIHEAD_ATTENTION: { @@ -2791,9 +2822,42 @@ Op *FFModel::create_operator_from_layer( } } +bool FFModel::is_transformer_block(int layer_idx) const { + auto const &l = layers[layer_idx]; + if (l->op_type == OP_DROPOUT && layer_idx >= 4 && + layers[layer_idx - 1]->op_type == OP_LINEAR && + layers[layer_idx - 2]->op_type == OP_RESHAPE && + layers[layer_idx - 3]->op_type == OP_TRANSPOSE && + layers[layer_idx - 4]->op_type == OP_BATCHMATMUL) { + return true; + } + return false; +} +bool FFModel::is_mlp_block(int layer_idx) const { + auto const &l = layers[layer_idx]; + // standard opt relu + if (l->op_type == OP_LINEAR && layer_idx >= 2 && + layers[layer_idx - 1]->op_type == OP_RELU && + layers[layer_idx - 2]->op_type == OP_LINEAR) { + return true; + } + // mlp layer with relu embedded in first dense layer + if (l->op_type == OP_LINEAR && layer_idx >= 1 && + layers[layer_idx - 1]->op_type == OP_LINEAR) { + long long value; + layers[layer_idx - 1]->get_int_property("activation", value); + ActiMode activation = (ActiMode)value; + if (activation == AC_MODE_RELU) { + return true; + } + } + return false; +} + void FFModel::create_operators_from_layers() { - std::map tensors_to_parallel_tensors; - for (auto const &l : layers) { + std::map tensors_to_parallel_tensors; + for (int layer_idx = 0; layer_idx < layers.size(); layer_idx++) { + auto const &l = layers[layer_idx]; std::vector inputs; for (int i = 0; i < l->numInputs; i++) { // create new input tensors @@ -2801,7 +2865,50 @@ void FFModel::create_operators_from_layers() { tensors_to_parallel_tensors.end()); inputs.push_back(tensors_to_parallel_tensors[l->inputs[i]]); } - Op *op = create_operator_from_layer(l, inputs); + // Op *op = create_operator_from_layer(l, inputs); + Op *op = nullptr; + if (config.tensor_parallelism_degree > 1 
&& l->op_type == OP_LAYERNORM && + layer_idx == layers.size() - 6) { + std::vector partitioned_inputs; + Combine *comb = new Combine(*this, + inputs[0], + 3 /*inner most dim*/, + config.tensor_parallelism_degree); + partitioned_inputs.push_back(comb->outputs[0]); + operators.push_back(comb); + op = create_operator_from_layer(l, partitioned_inputs); + } else { + op = create_operator_from_layer(l, inputs); + } + + // add replicate operators after op if needed + if (config.tensor_parallelism_degree > 1 && l->op_type == OP_EMBEDDING) { + // assert(op->numOutputs == 1); + // Replicate *repl = new Replicate(*this, + // op->outputs[0], + // op->outputs[0]->num_dims - 1, + // config.tensor_parallelism_degree); + // operators.push_back(repl); + // op = repl; + } else if (config.tensor_parallelism_degree > 1 && + (is_transformer_block(layer_idx) || is_mlp_block(layer_idx) || + // llama mlp layer + (l->op_type == OP_LINEAR && layer_idx >= 2 && + layers[layer_idx - 1]->op_type == OP_GELU && + layers[layer_idx - 2]->op_type == OP_LINEAR) || + // LLAMA without element-wise operator fusion + (l->op_type == OP_LINEAR && layer_idx >= 5 && + layers[layer_idx - 1]->op_type == OP_EW_MUL && + layers[layer_idx - 2]->op_type == OP_EW_MUL && + layers[layer_idx - 3]->op_type == OP_SIGMOID && + layers[layer_idx - 4]->op_type == OP_LINEAR && + layers[layer_idx - 5]->op_type == OP_LINEAR))) { + assert(op->numOutputs == 1); + AllReduce *allreduce = + new AllReduce(*this, op->outputs[0], op->outputs[0]->num_dims - 1); + operators.push_back(allreduce); + op = allreduce; + } assert(op->numOutputs == l->numOutputs); for (int i = 0; i < op->numOutputs; i++) { tensors_to_parallel_tensors[l->outputs[i]] = op->outputs[i]; @@ -2837,8 +2944,7 @@ void FFModel::compile(LossType loss_type, TaskLauncher launcher(GRAPH_OPTIMIZE_TASK_ID, TaskArgument(&model, sizeof(FFModel *))); Future future = runtime->execute_task(ctx, launcher); - ret = - future.get_result(); + ret = future.get_result(); } else { ret = PCG::Graph::graph_optimize_wrapper(this); } @@ -2974,6 +3080,60 @@ void FFModel::compile(LossType loss_type, } } + int degree = + config.data_parallelism_degree * config.tensor_parallelism_degree; + + for (int op_idx = 0; op_idx < operators.size(); op_idx++) { + Op const *op = operators[op_idx]; + // Skip weight operators + if (op->op_type == OP_WEIGHT) { + continue; + } + // Get machine views + std::vector machine_views; + for (int j = 0; j < config.data_parallelism_degree; j++) { + MachineView mv; + mv.device_type = MachineView::GPU; + mv.ndims = 1; + // mv.start_device_id = 0; + mv.stride[0] = 1; + int parallel_degree = 1; + for (int k = 0; k < op->outputs[0]->num_dims; k++) { + parallel_degree *= op->outputs[0]->dims[k].degree; + } + mv.dim[0] = parallel_degree; + mv.start_device_id = 0; + assert(mv == op->outputs[0]->machine_view); + machine_views.push_back(mv); + } + for (int i = 0; i < op->numOutputs; i++) { + ParallelTensor pt_base = op->outputs[i]; + + if (op->op_type == OP_REPLICATE) { + assert(op->numInputs == 1 && op->numOutputs == 1); + } + std::vector list; + bool found_parallel_tensor = false; + if (!found_parallel_tensor) { + for (int j = 0; j < config.data_parallelism_degree; j++) { + // Copy the metadata from pt_base to pt + ParallelTensor pt = new ParallelTensorBase(*pt_base); + pt->region = + runtime->create_logical_region(ctx, + pt_base->region.get_index_space(), + pt_base->region.get_field_space()); + pt->part = runtime->get_logical_partition( + ctx, pt->region, pt_base->part.get_index_partition()); + 
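// Illustration (a sketch, not code from this patch) of where the block
// matching in create_operators_from_layers() above places an AllReduce when
// -tensor-parallelism-degree > 1. For a plain two-layer MLP built through the
// public API: ff is an FFModel, x an input Tensor, and ff_size/hidden_size are
// placeholder dimensions; dense() and AC_MODE_RELU are existing FlexFlow names.
Tensor t = ff.dense(x, ff_size, AC_MODE_RELU);  // first linear, sharded
t = ff.dense(t, hidden_size);                   // second linear: partial sums
// is_mlp_block(layer_idx) matches at the second dense, so the pass appends
//   new AllReduce(*this, op->outputs[0], op->outputs[0]->num_dims - 1)
// summing the partial results across the tensor-parallel group before any
// downstream operator consumes them.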
pt->machine_view = machine_views[j]; + Domain part_domain = + runtime->get_index_space_domain(ctx, pt_base->parallel_is); + assert(pt->machine_view.get_domain() == part_domain); + list.push_back(pt); + } + } + } + } + // Perform fusion optimizations if (config.perform_fusion) { fprintf(stderr, "Applying fusion optimizations during compilation...\n"); @@ -3481,34 +3641,34 @@ void FFIterationConfig::reset() { // Default Config Parameters struct DefaultConfig { - const static int epochs = 1; + static int const epochs = 1; // const static int iterations = 1; - const static int batchSize = 64; - const static bool profiling = false; + static int const batchSize = 64; + static bool const profiling = false; constexpr static float learningRate = 0.01f; constexpr static float weightDecay = 0.0001f; - const static size_t workSpaceSize = (size_t)1 * 1024 * 1024 * 1024; // 2GB - const static int numNodes = 1; - const static int workersPerNode = 0; - const static int cpusPerNode = 0; - const static size_t searchBudget = -1; - const static size_t simulatorWorkSpaceSize = + static size_t const workSpaceSize = (size_t)2 * 1024 * 1024 * 1024; // 2GB + static int const numNodes = 1; + static int const workersPerNode = 0; + static int const cpusPerNode = 0; + static size_t const searchBudget = -1; + static size_t const simulatorWorkSpaceSize = (size_t)2 * 1024 * 1024 * 1024; // 2GB constexpr static float searchAlpha = 1.2f; - const static bool searchOverlapBackwardUpdate = false; - const static bool onlyDataParallel = false; - const static bool enableSampleParallel = true; - const static bool enableParameterParallel = false; - const static bool enableAttributeParallel = false; - const static bool enableInplaceOptimizations = false; - const static bool allowTensorOpMathConversion = false; - const static int machine_model_version = 0; - const static int simulator_segment_size = 16777216; // 16 MB - const static int simulator_max_num_segments = 1; - const static int base_optimize_threshold = 10; - const static bool enable_control_replication = true; + static bool const searchOverlapBackwardUpdate = false; + static bool const onlyDataParallel = false; + static bool const enableSampleParallel = true; + static bool const enableParameterParallel = false; + static bool const enableAttributeParallel = false; + static bool const enableInplaceOptimizations = false; + static bool const allowTensorOpMathConversion = false; + static int const machine_model_version = 0; + static int const simulator_segment_size = 16777216; // 16 MB + static int const simulator_max_num_segments = 1; + static int const base_optimize_threshold = 10; + static bool const enable_control_replication = true; // The default python data loader type is 2 to enable control replication - const static int python_data_loader_type = 2; + static int const python_data_loader_type = 2; }; FFConfig::FFConfig() { @@ -3528,6 +3688,8 @@ FFConfig::FFConfig() { search_overlap_backward_update = DefaultConfig::searchOverlapBackwardUpdate; computationMode = COMP_MODE_TRAINING; only_data_parallel = DefaultConfig::onlyDataParallel; + data_parallelism_degree = 1; + tensor_parallelism_degree = 1; enable_sample_parallel = DefaultConfig::enableSampleParallel; enable_parameter_parallel = DefaultConfig::enableParameterParallel; enable_attribute_parallel = DefaultConfig::enableAttributeParallel; @@ -3573,6 +3735,9 @@ FFConfig::FFConfig() { Runtime *runtime = Runtime::get_runtime(); lg_hlr = runtime; lg_ctx = Runtime::get_context(); + Rect<1> task_rect(Point<1>(0), 
Point<1>(workersPerNode * numNodes - 1)); + // Create an index space for tasks running on all GPUs + all_gpu_task_is = runtime->create_index_space(lg_ctx, task_rect); field_space = runtime->create_field_space(lg_ctx); } @@ -3633,6 +3798,16 @@ void FFConfig::parse_args(char **argv, int argc) { only_data_parallel = true; continue; } + // data parallelism degree + if (!strcmp(argv[i], "-data-parallelism-degree")) { + data_parallelism_degree = std::stoi(argv[++i]); + continue; + } + // tensor parallelism degree + if (!strcmp(argv[i], "-tensor-parallelism-degree")) { + tensor_parallelism_degree = std::stoi(argv[++i]); + continue; + } if ((!strcmp(argv[i], "--enable-parameter-parallel"))) { enable_parameter_parallel = true; continue; @@ -5210,6 +5385,49 @@ void register_flexflow_internal_tasks(Runtime *runtime, runtime->register_task_variant(registrar); } } + // AllReduce + { + TaskVariantRegistrar registrar(ALLREDUCE_INIT_TASK_ID, "AllReduce Init"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "AllReduce init Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } + } + { + TaskVariantRegistrar registrar(ALLREDUCE_FWD_TASK_ID, "AllReduce Forward"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "AllReduce Forward Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } + } + { + TaskVariantRegistrar registrar(ALLREDUCE_BWD_TASK_ID, "AllReduce Backward"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "AllReduce Backward Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } + } // FusedParallelOp { TaskVariantRegistrar registrar(FUSED_PARALLELOP_FWD_TASK_ID, @@ -5302,6 +5520,23 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar); } } + { + TaskVariantRegistrar registrar(ADAM_UNIFY_UPD_NCCL_TASK_ID, + "Adam unified NCCL Update"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant< + AdamOptimizer::nccl_unified_update_task>( + registrar, "Adam unified NCCL Update Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant( + registrar); + } + } #endif // Initializer { diff --git a/src/runtime/operator_params.cc b/src/runtime/operator_params.cc index 41dd37dec7..322d7840fb 100644 --- a/src/runtime/operator_params.cc +++ b/src/runtime/operator_params.cc @@ -28,6 +28,7 @@ #include "flexflow/ops/topk.h" #include "flexflow/ops/transpose.h" #include "flexflow/parallel_ops/combine.h" +#include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/fused_parallel_op.h" #include "flexflow/parallel_ops/partition.h" #include "flexflow/parallel_ops/reduction.h" @@ -94,6 +95,8 @@ tl::optional get_op_parameters(Op const *op) { return ((Reduction *)op)->get_params(); case OP_COMBINE: return ((Combine *)op)->get_params(); + case OP_ALLREDUCE: + return ((AllReduce *)op)->get_params(); 
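// Usage sketch (not from this patch): the two flags parsed above select the
// parallelization degrees at launch time, e.g.
//   ./train_model -ll:gpu 4 -ll:cpu 4 -data-parallelism-degree 2 \
//       -tensor-parallelism-degree 2
// (the binary name is a placeholder), and a training loop can opt into the
// fused optimizer path registered above by calling unified_update() where
// update() was called before. iterations_per_epoch is a placeholder; the
// other calls are existing FFModel/FFConfig API.
for (int epoch = 0; epoch < ffConfig.epochs; epoch++) {
  for (int iter = 0; iter < iterations_per_epoch; iter++) {
    ff.forward();
    ff.zero_gradients();
    ff.backward();
    ff.unified_update();  // single packed NCCL all-reduce + per-tensor Adam
  }
}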
case OP_FUSED_PARALLEL: return ((FusedParallelOp *)op)->get_params(); case OP_TRANSPOSE: diff --git a/src/runtime/optimizer.cc b/src/runtime/optimizer.cc index c42a0c9aa6..91a16e8db7 100644 --- a/src/runtime/optimizer.cc +++ b/src/runtime/optimizer.cc @@ -333,6 +333,7 @@ void AdamOptimizer::init(void) { Context ctx = model->config.lg_ctx; Runtime *runtime = model->config.lg_hlr; Initializer *initializer = new ZeroInitializer(); + reservedWorkSpaceSize = 0; for (size_t i = 0; i < model->parameters.size(); i++) { ParallelTensor p = model->parameters[i]; Domain domain = @@ -381,6 +382,7 @@ void AdamOptimizer::update(const ParallelTensor p) { assert(v_values.find(p->region) != v_values.end()); assert(m_values.find(p->region) != m_values.end()); assert(p->owner_op != NULL); + reservedWorkSpaceSize += p->get_volume() * sizeof(float); if (p->sync_type == ParameterSyncType::PS) { TaskLauncher launcher(ADAM_UPD_PS_TASK_ID, TaskArgument(this, sizeof(AdamOptimizer)), @@ -492,6 +494,119 @@ void AdamOptimizer::update(const ParallelTensor p) { } } +void SGDOptimizer::unified_update(std::vector const parameters) { + //todo +} + +void AdamOptimizer::unified_update(std::vector const parameters) { + Context ctx = model->config.lg_ctx; + Runtime *runtime = model->config.lg_hlr; + const ParallelTensor p0 = parameters.at(0); + ArgumentMap argmap; + Domain domain = runtime->get_index_space_domain(ctx, p0->parallel_is); + switch (domain.get_dim()) { +#define DIMFUNC(DIM) \ + case DIM: { \ + Rect rect = domain; \ + int idx = 0; \ + for (PointInRectIterator it(rect); it(); it++) { \ + OpMeta *mp = p0->owner_op->meta[idx++]; \ + argmap.set_point(*it, TaskArgument(&mp, sizeof(OpMeta *))); \ + } \ + break; \ + } + LEGION_FOREACH_N(DIMFUNC) +#undef DIMFUNC + default: + assert(false); + } + + int offset = 0; + int processed_parameters_num = 0; + // printf("param size: %d\n", parameters.size()); + + size_t workSpaceSize = model->handlers->workSpaceSize * + model->config.workersPerNode * model->config.numNodes; + + while (processed_parameters_num < parameters.size()) { + parameters_num = 0; + + for(int i = processed_parameters_num; i < parameters.size(); i++){ + const ParallelTensor p = parameters.at(i); + assert(v_values.find(p->region) != v_values.end()); + assert(m_values.find(p->region) != m_values.end()); + assert(p->owner_op != NULL); + if (reservedWorkSpaceSize + p->get_volume() * sizeof(float) >= workSpaceSize) { + break; + } + reservedWorkSpaceSize += p->get_volume() * sizeof(float); + parameters_num += 1; + assert(p->sync_type == ParameterSyncType::NCCL); + assert(p->parallel_is != IndexSpace::NO_SPACE); + } + + // printf("parameters_num: %d %zu, %zu, %d\n", parameters_num, + // reservedWorkSpaceSize, model->handlers->workSpaceSize, + // parameters.size()); + assert(processed_parameters_num <= parameters.size()); + + IndexLauncher launcher(ADAM_UNIFY_UPD_NCCL_TASK_ID, + p0->parallel_is, + TaskArgument(this, sizeof(AdamOptimizer)), + argmap, + Predicate::TRUE_PRED, + false /*must_epoch*/, + 0 /*mapper_id*/, + p0->machine_view.hash()); + // launch a unified task + for (int j = 0; j < parameters_num; j++) { + const ParallelTensor p = parameters.at(processed_parameters_num + j); + + // regions[0]: region_grad + launcher.add_region_requirement(RegionRequirement(p->part_grad, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + p->region_grad)); + launcher.add_field(offset, FID_DATA); + // regions[1]: region + launcher.add_region_requirement(RegionRequirement( + p->part, 0 /*projection id*/, READ_WRITE, EXCLUSIVE, 
p->region)); + launcher.add_field(offset + 1, FID_DATA); + // regions[2]: w_region + launcher.add_region_requirement( + RegionRequirement(v_values[p->region]->part, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + v_values[p->region]->region)); + launcher.add_field(offset + 2, FID_DATA); + // regions[3]: m_region + launcher.add_region_requirement( + RegionRequirement(m_values[p->region]->part, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + m_values[p->region]->region)); + launcher.add_field(offset + 3, FID_DATA); + offset += 4; + } + + // update alpha, beta + for (int i = 0; i < parameters_num; i++) { + this->next(); + } + launcher.concurrent = true; + FutureMap fm = runtime->execute_index_space(ctx, launcher); + // runtime->execute_must_epoch(ctx, must_epoch_launcher); + runtime->issue_execution_fence(ctx); + reservedWorkSpaceSize = 0; + offset = 0; + processed_parameters_num += parameters_num; + } + parameters_num = 0; +} + void AdamOptimizer::ps_update_task(Task const *task, std::vector const ®ions, Context ctx, @@ -605,6 +720,72 @@ void AdamOptimizer::nccl_update_task(Task const *task, nccl_update_task_gpu(op, meta, w_grad_ptr, size, w_ptr, v_ptr, m_ptr); } + +void AdamOptimizer::nccl_unified_update_task( + Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + // assert(regions.size() == 4); + // assert(task->regions.size() == 4); + AdamOptimizer const *op = (AdamOptimizer *)task->args; + OpMeta const *meta = *((OpMeta **)task->local_args); + // FFHandler handler = *((FFHandler*) task->local_args); + Domain domain = runtime->get_index_space_domain( + ctx, task->regions[1].region.get_index_space()); + + // float const *w_grad_ptr[op->parameters_num]; + // float *w_ptr[op->parameters_num], *v_ptr[op->parameters_num], + // *m_ptr[op->parameters_num]; + + // hipMalloc(w_grad_ptr, sizeof(float*) * op->parameters_num); + // hipMalloc(w_ptr, sizeof(float*) * op->parameters_num); + // hipMalloc(v_ptr, sizeof(float*) * op->parameters_num); + // hipMalloc(m_ptr, sizeof(float*) * op->parameters_num); + GenericTensorAccessorR accWGrads[op->parameters_num]; + GenericTensorAccessorW accWs[op->parameters_num]; + GenericTensorAccessorW accVs[op->parameters_num]; + GenericTensorAccessorW accMs[op->parameters_num]; + size_t *size = new size_t[op->parameters_num]; + int offset = 0; + + // printf("parameters_num: %d\n", op->parameters_num); + + for (int i = 0; i < op->parameters_num; i++) { + accWGrads[i] = helperGetGenericTensorAccessorRO(DataType::DT_FLOAT, + regions[offset], + task->regions[offset], + FID_DATA, + ctx, + runtime); + accWs[i] = helperGetGenericTensorAccessorWO(DataType::DT_FLOAT, + regions[offset + 1], + task->regions[offset + 1], + FID_DATA, + ctx, + runtime); + accVs[i] = helperGetGenericTensorAccessorWO(DataType::DT_FLOAT, + regions[offset + 2], + task->regions[offset + 2], + FID_DATA, + ctx, + runtime); + accMs[i] = helperGetGenericTensorAccessorWO(DataType::DT_FLOAT, + regions[offset + 3], + task->regions[offset + 3], + FID_DATA, + ctx, + runtime); + offset += 4; + + size[i] = accWGrads[i].domain.get_volume(); + // w_grad_ptr[i] = accWGrad.get_float_ptr(); + // w_ptr[i] = accW.get_float_ptr(); + // v_ptr[i] = accV.get_float_ptr(); + // m_ptr[i] = accM.get_float_ptr(); + } + nccl_unified_update_task_gpu(op, meta, accWGrads, size, accWs, accVs, accMs); +} #endif }; // namespace FlexFlow diff --git a/src/runtime/optimizer_kernel.cpp b/src/runtime/optimizer_kernel.cpp index e71adc87a8..373eb3fe7a 100644 --- a/src/runtime/optimizer_kernel.cpp +++ 
b/src/runtime/optimizer_kernel.cpp @@ -204,6 +204,7 @@ __host__ void AdamOptimizer::ps_update_task_gpu(AdamOptimizer const *op, m_ptr, v_ptr, w_ptr); + // checkCUDA(hipDeviceSynchronize()); } @@ -245,6 +246,74 @@ __host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op, w_ptr); // checkCUDA(hipDeviceSynchronize()); } + +__host__ void AdamOptimizer::nccl_unified_update_task_gpu( + AdamOptimizer const *op, + OpMeta const *meta, + GenericTensorAccessorR *accWGrads, + size_t *size, + GenericTensorAccessorW *accWs, + GenericTensorAccessorW *accVs, + GenericTensorAccessorW *accMs) { + + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + // assert(op->reservedWorkSpaceSize < meta->handle.workSpaceSize); + + void *workSpace_ptr = meta->handle.workSpace; + + for (int i = 0; i < op->parameters_num; i++) { + hipMemcpyAsync(workSpace_ptr, + accWGrads[i].get_float_ptr(), + size[i] * sizeof(float), + hipMemcpyDeviceToDevice, + stream); + workSpace_ptr = + static_cast(workSpace_ptr) + size[i] * sizeof(float); + } + + // do allreduce once + checkNCCL(ncclAllReduce(meta->handle.workSpace, + (float *)meta->handle.workSpace, + meta->handle.workSpaceSize, + ncclFloat, + ncclSum, + meta->handle.ncclComm, + stream)); + + workSpace_ptr = static_cast(meta->handle.workSpace); + float alpha_t = op->alpha_t; + float beta1_t = op->beta1_t; + float beta2_t = op->beta2_t; + for (int i = 0; i < op->parameters_num; i++) { + // update + // printf("update %d\n", i); + hipLaunchKernelGGL(HIP_KERNEL_NAME(adam_update), + GET_BLOCKS(size[i]), + CUDA_NUM_THREADS, + 0, + stream, + size[i], + alpha_t, + op->beta1, + op->beta2, + op->weight_decay, + op->epsilon, + static_cast(workSpace_ptr), + accMs[i].get_float_ptr(), + accVs[i].get_float_ptr(), + accWs[i].get_float_ptr()); + workSpace_ptr = + static_cast(workSpace_ptr) + size[i] * sizeof(float); + + // update + beta1_t *= op->beta1; + beta2_t *= op->beta2; + alpha_t = op->alpha * sqrt(1 - beta2_t) / (1 - beta1_t); + } + + // checkCUDA(hipDeviceSynchronize()); +} #endif }; // namespace FlexFlow \ No newline at end of file diff --git a/src/runtime/optimizer_kernel.cu b/src/runtime/optimizer_kernel.cu index 5f654fbb5b..17adce94b8 100644 --- a/src/runtime/optimizer_kernel.cu +++ b/src/runtime/optimizer_kernel.cu @@ -216,6 +216,103 @@ __host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op, w_ptr); // checkCUDA(cudaDeviceSynchronize()); } + +__host__ void AdamOptimizer::nccl_unified_update_task_gpu( + AdamOptimizer const *op, + OpMeta const *meta, + GenericTensorAccessorR *accWGrads, + size_t *size, + GenericTensorAccessorW *accWs, + GenericTensorAccessorW *accVs, + GenericTensorAccessorW *accMs) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + // assert(op->reservedWorkSpaceSize < meta->handle.workSpaceSize); + + cudaEvent_t t_start, t_start1, t_start2, t_end; + cudaEventCreate(&t_start); + cudaEventCreate(&t_start1); + cudaEventCreate(&t_start2); + cudaEventCreate(&t_end); + cudaEventRecord(t_start, stream); + cudaEventRecord(t_start1, stream); + cudaEventRecord(t_start2, stream); + + void *allocate_ptr; + // = meta->handle.workSpace; + checkCUDA( + cudaMalloc(&allocate_ptr,meta->handle.workSpaceSize)); + + void *workSpace_ptr = allocate_ptr; + + for (int i = 0; i < op->parameters_num; i++) { + cudaMemcpyAsync(workSpace_ptr, + accWGrads[i].get_float_ptr(), + size[i] * sizeof(float), + cudaMemcpyDeviceToDevice, + stream); + workSpace_ptr = + static_cast(workSpace_ptr) + size[i] * sizeof(float); + } + + 
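// Sketch (not from this patch) of deriving the element count for the single
// fused all-reduce from the per-parameter volumes gathered in this task body;
// the call below passes meta->handle.workSpaceSize, a buffer size rather than
// the number of packed floats, so the two only coincide under specific
// assumptions about how workSpaceSize is defined.
size_t total_elements = 0;
for (int i = 0; i < op->parameters_num; i++) {
  total_elements += size[i];  // size[i] = accWGrads[i].domain.get_volume()
}
// e.g. checkNCCL(ncclAllReduce(allocate_ptr, allocate_ptr, total_elements,
//                              ncclFloat, ncclSum, meta->handle.ncclComm,
//                              stream));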
cudaEventRecord(t_end, stream); + checkCUDA(cudaEventSynchronize(t_end)); + float elapsed = 0; + checkCUDA(cudaEventElapsedTime(&elapsed, t_start1, t_end)); + cudaEventDestroy(t_start1); + printf("[optimizer] data copy time = %.2lfms\n", elapsed); + + // do allreduce once + checkNCCL(ncclAllReduce(meta->handle.workSpace, + (float *)meta->handle.workSpace, + meta->handle.workSpaceSize, + ncclFloat, + ncclSum, + meta->handle.ncclComm, + stream)); + cudaEventRecord(t_end, stream); + checkCUDA(cudaEventSynchronize(t_end)); + elapsed = 0; + checkCUDA(cudaEventElapsedTime(&elapsed, t_start2, t_end)); + cudaEventDestroy(t_start2); + printf("[optimizer] allreduce time = %.2lfms\n", elapsed); + + // workSpace_ptr = static_cast(meta->handle.workSpace); + workSpace_ptr = static_cast(allocate_ptr); + float alpha_t = op->alpha_t; + float beta1_t = op->beta1_t; + float beta2_t = op->beta2_t; + for (int i = 0; i < op->parameters_num; i++) { + // update + // printf("update %d\n", i); + adam_update<<>>( + size[i], + alpha_t, + op->beta1, + op->beta2, + op->weight_decay, + op->epsilon, + static_cast(workSpace_ptr), + accMs[i].get_float_ptr(), + accVs[i].get_float_ptr(), + accWs[i].get_float_ptr()); + workSpace_ptr = + static_cast(workSpace_ptr) + size[i] * sizeof(float); + + // update + beta1_t *= op->beta1; + beta2_t *= op->beta2; + alpha_t = op->alpha * sqrt(1 - beta2_t) / (1 - beta1_t); + } + cudaEventRecord(t_end, stream); + checkCUDA(cudaEventSynchronize(t_end)); + elapsed = 0; + checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); + cudaEventDestroy(t_start); + cudaEventDestroy(t_end); + checkCUDA(cudaFree(allocate_ptr)); + printf("[optimizer] total time = %.2lfms\n", elapsed); +} #endif }; // namespace FlexFlow diff --git a/src/runtime/parallel_tensor.cc b/src/runtime/parallel_tensor.cc index 18318db3ce..b9f3dc89f7 100644 --- a/src/runtime/parallel_tensor.cc +++ b/src/runtime/parallel_tensor.cc @@ -660,14 +660,24 @@ bool ParallelTensorBase::set_tensor(FFModel const *ff, // TODO: check data type matches // TODO: Currently we use a task launch, change to index launch for NCCL // parameter - size_t volume = 1, num_replicas = 0; + size_t volume = 1, num_replicas = 1; if (sync_type == ParameterSyncType::NCCL) { - Domain domain = runtime->get_index_space_domain(ctx, parallel_is); - num_replicas = domain.get_volume(); + // Domain domain = runtime->get_index_space_domain(ctx, parallel_is); + // num_replicas = domain.get_volume(); + for (int i = 0; i < this->num_dims; i++) { + if (this->dims[i].is_replica_dim) { + num_replicas *= this->dims[i].size; + } + } } else if (sync_type == ParameterSyncType::PS) { num_replicas = 1; } else { - num_replicas = 1; + for (int i = 0; i < this->num_dims; i++) { + if (this->dims[i].is_replica_dim) { + num_replicas *= this->dims[i].size; + } + } + // num_replicas = 1; } for (size_t i = 0; i < dim_sizes.size(); i++) { volume = volume * dim_sizes[i]; diff --git a/src/runtime/substitution.cc b/src/runtime/substitution.cc index e3adfec5b7..4f44a3a574 100644 --- a/src/runtime/substitution.cc +++ b/src/runtime/substitution.cc @@ -34,6 +34,7 @@ #include "flexflow/ops/split.h" #include "flexflow/parallel_ops/combine.h" #include "flexflow/parallel_ops/fused_parallel_op.h" +#include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/partition.h" #include "flexflow/parallel_ops/reduction.h" #include "flexflow/parallel_ops/replicate.h" @@ -3743,6 +3744,12 @@ bool FFModel::convert_graph_to_operators( reduction->reduction_degree); break; } + case OP_ALLREDUCE: { + 
assert(inList.size() == 1); + AllReduce *allreduce = (AllReduce *)node.ptr; + new_op = new AllReduce(*this, inputs[0], allreduce->allreduce_dim); + break; + } case OP_FUSED_PARALLEL: { assert(inList.size() == 1); FusedParallelOp *fused = (FusedParallelOp *)node.ptr;
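// Hedged reconstruction (the header itself is outside this excerpt) of the
// AllReduceParams struct that get_or_create_node<AllReduce>(...), the graph
// deserializer, and the std::hash specialization earlier in the patch agree
// on: a single Legion dimension plus the equality/validity helpers the other
// parallel-op params provide. Names other than allreduce_legion_dim are
// assumptions, not the patch's definitions.
struct AllReduceParams {
  int allreduce_legion_dim;
  bool is_valid(ParallelTensorShape const &input) const {
    return input.is_valid();
  }
};
bool operator==(AllReduceParams const &lhs, AllReduceParams const &rhs) {
  return lhs.allreduce_legion_dim == rhs.allreduce_legion_dim;
}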