src/ops/fused.cu: 82 changes (36 additions, 46 deletions)
@@ -101,40 +101,40 @@ __host__ void
 assert((int)regions.size() == fused->numInputs + fused->numWeights +
 fused->numOutputs +
 softmax_grad_additional_region);
-GenericTensorAccessorR input_accessor[MAX_NUM_INPUTS];
-GenericTensorAccessorR weight_accessor[MAX_NUM_WEIGHTS];
-GenericTensorAccessorW output_accessor[MAX_NUM_OUTPUTS];
+std::vector<GenericTensorAccessorR> input_accessor;
+std::vector<GenericTensorAccessorR> weight_accessor;
+std::vector<GenericTensorAccessorW> output_accessor;
 assert(fused->numInputs <= MAX_NUM_INPUTS);
 for (int i = 0; i < fused->numInputs; i++) {
-input_accessor[i] =
+input_accessor.push_back(
 helperGetGenericTensorAccessorRO(fused->input_data_types[i],
 regions[i],
 task->regions[i],
 FID_DATA,
 ctx,
-runtime);
+runtime));
 }
 int roff = fused->numInputs;
 assert(fused->numWeights <= MAX_NUM_WEIGHTS);
 for (int i = 0; i < fused->numWeights; i++) {
-weight_accessor[i] =
+weight_accessor.push_back(
 helperGetGenericTensorAccessorRO(fused->weight_data_types[i],
 regions[i + roff],
 task->regions[i + roff],
 FID_DATA,
 ctx,
-runtime);
+runtime));
 }
 roff += fused->numWeights;
 assert(fused->numOutputs <= MAX_NUM_OUTPUTS);
 for (int i = 0; i < fused->numOutputs; i++) {
-output_accessor[i] =
+output_accessor.push_back(
 helperGetGenericTensorAccessorWO(fused->output_data_types[i],
 regions[i + roff],
 task->regions[i + roff],
 FID_DATA,
 ctx,
-runtime);
+runtime));
 }
 roff += fused->numOutputs;
 // Assert that all meta share the same dnn/blas handler
@@ -153,39 +153,28 @@ __host__ void

 int ioff = 0, woff = 0, ooff = 0;
 for (int op = 0; op < fused->numOperators; op++) {
-#if 0
-std::cout << get_operator_type_name(fused->op_op_type[op]) << std::endl;
-#endif
-GenericTensorAccessorR my_input_accessor[MAX_NUM_INPUTS];
-GenericTensorAccessorR my_weight_accessor[MAX_NUM_WEIGHTS];
-GenericTensorAccessorW my_output_accessor[MAX_NUM_OUTPUTS];
+std::vector<GenericTensorAccessorR> my_input_accessor;
+std::vector<GenericTensorAccessorR> my_weight_accessor;
+std::vector<GenericTensorAccessorW> my_output_accessor;
 for (int i = 0; i < fused->op_num_inputs[op]; i++) {
 int my_off = fused->op_input_idx[i + ioff];
 if (fused->op_input_source[i + ioff] == SOURCE_INPUT) {
-my_input_accessor[i] = input_accessor[my_off];
-#if 0
-printf("\tmy_input_accessor[%i] = input_accessor[%i]\n", i, my_off);
-#endif
+my_input_accessor.push_back(input_accessor[my_off]);
 } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) {
-my_input_accessor[i] = output_accessor[my_off];
-#if 0
-printf("\tmy_input_accessor[%i] = output_accessor[%i]\n", i, my_off);
-#endif
+my_input_accessor.push_back(output_accessor[my_off]);
 } else {
 assert(false);
 }
 }
 for (int i = 0; i < fused->op_num_weights[op]; i++) {
 assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT);
-my_weight_accessor[i] = weight_accessor[fused->op_weight_idx[i + woff]];
+my_weight_accessor.push_back(
+weight_accessor[fused->op_weight_idx[i + woff]]);
 }
 for (int i = 0; i < fused->op_num_outputs[op]; i++) {
 int my_off = fused->op_output_idx[i + ooff];
 assert(fused->op_output_source[i + ooff] == SOURCE_OUTPUT);
-my_output_accessor[i] = output_accessor[my_off];
-#if 0
-printf("\tmy_output_accessor[%i] = output_accessor[%i]\n", i, my_off);
-#endif
+my_output_accessor.push_back(output_accessor[my_off]);
 }
 switch (fused->op_op_type[op]) {
 case OP_CONCAT: {
@@ -195,7 +184,7 @@ __host__ void
 int num_inputs = fused->op_num_inputs[op];
 Kernels::Concat::forward_kernel_wrapper(m,
 my_output_accessor[0],
-my_input_accessor,
+my_input_accessor.data(),
 num_inputs,
 m->legion_axis);
 break;
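
A side note on the `.data()` call in the hunk above: the Concat kernel wrapper evidently still expects a raw pointer to a contiguous block of accessors, so the new std::vector is handed over via data(). The sketch below is only an illustration of why that substitution is safe; Accessor and consume_array are made-up stand-ins, not FlexFlow APIs.

#include <cassert>
#include <vector>

// Hypothetical stand-ins for an accessor type and a kernel wrapper that
// expects a contiguous C-style array plus an element count.
struct Accessor {
  int id;
};

void consume_array(Accessor const *accessors, int num) {
  for (int i = 0; i < num; i++) {
    assert(accessors[i].id == i); // contiguous, indexable like a raw array
  }
}

int main() {
  std::vector<Accessor> accessors;
  for (int i = 0; i < 4; i++) {
    accessors.push_back(Accessor{i}); // mirrors my_input_accessor.push_back(...)
  }
  // std::vector storage is guaranteed contiguous, so .data() is a drop-in
  // replacement for the old fixed-size array when calling the wrapper.
  consume_array(accessors.data(), static_cast<int>(accessors.size()));
  return 0;
}

Because the standard guarantees contiguous storage, data() plus size() looks exactly like the old fixed-size array from the callee's point of view, as long as the vector is fully populated before the pointer is taken.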
@@ -1242,40 +1231,40 @@ __host__ void FusedOp::forward_task(Task const *task,
 assert(regions.size() == task->regions.size());
 assert((int)regions.size() ==
 fused->numInputs + fused->numWeights + fused->numOutputs);
-GenericTensorAccessorR input_accessor[MAX_NUM_INPUTS];
-GenericTensorAccessorR weight_accessor[MAX_NUM_WEIGHTS];
-GenericTensorAccessorW output_accessor[MAX_NUM_OUTPUTS];
+std::vector<GenericTensorAccessorR> input_accessor;
+std::vector<GenericTensorAccessorR> weight_accessor;
+std::vector<GenericTensorAccessorW> output_accessor;
 assert(fused->numInputs <= MAX_NUM_INPUTS);
 for (int i = 0; i < fused->numInputs; i++) {
-input_accessor[i] =
+input_accessor.push_back(
 helperGetGenericTensorAccessorRO(fused->input_data_types[i],
 regions[i],
 task->regions[i],
 FID_DATA,
 ctx,
-runtime);
+runtime));
 }
 int roff = fused->numInputs;
 assert(fused->numWeights <= MAX_NUM_WEIGHTS);
 for (int i = 0; i < fused->numWeights; i++) {
-weight_accessor[i] =
+weight_accessor.push_back(
 helperGetGenericTensorAccessorRO(fused->weight_data_types[i],
 regions[i + roff],
 task->regions[i + roff],
 FID_DATA,
 ctx,
-runtime);
+runtime));
 }
 roff += fused->numWeights;
 assert(fused->numOutputs <= MAX_NUM_OUTPUTS);
 for (int i = 0; i < fused->numOutputs; i++) {
-output_accessor[i] =
+output_accessor.push_back(
 helperGetGenericTensorAccessorWO(fused->output_data_types[i],
 regions[i + roff],
 task->regions[i + roff],
 FID_DATA,
 ctx,
-runtime);
+runtime));
 }
 // Assert that all meta share the same dnn/blas handler
 int start = 0;
@@ -1293,31 +1282,32 @@ __host__ void FusedOp::forward_task(Task const *task,

 int ioff = 0, woff = 0, ooff = 0;
 for (int op = 0; op < fused->numOperators; op++) {
-GenericTensorAccessorR my_input_accessor[MAX_NUM_INPUTS];
-GenericTensorAccessorR my_weight_accessor[MAX_NUM_WEIGHTS];
-GenericTensorAccessorW my_output_accessor[MAX_NUM_OUTPUTS];
+std::vector<GenericTensorAccessorR> my_input_accessor;
+std::vector<GenericTensorAccessorR> my_weight_accessor;
+std::vector<GenericTensorAccessorW> my_output_accessor;
 for (int i = 0; i < fused->op_num_inputs[op]; i++) {
 int my_off = fused->op_input_idx[i + ioff];
 if (fused->op_input_source[i + ioff] == SOURCE_INPUT) {
 assert(my_off < fused->numInputs);
-my_input_accessor[i] = input_accessor[my_off];
+my_input_accessor.push_back(input_accessor[my_off]);
 } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) {
 assert(my_off < fused->numOutputs);
-my_input_accessor[i] = output_accessor[my_off];
+my_input_accessor.push_back(output_accessor[my_off]);
 } else {
 assert(false);
 }
 }
 for (int i = 0; i < fused->op_num_weights[op]; i++) {
 assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT);
 assert(fused->op_weight_idx[i + woff] < fused->numWeights);
-my_weight_accessor[i] = weight_accessor[fused->op_weight_idx[i + woff]];
+my_weight_accessor.push_back(
+weight_accessor[fused->op_weight_idx[i + woff]]);
 }
 for (int i = 0; i < fused->op_num_outputs[op]; i++) {
 int my_off = fused->op_output_idx[i + ooff];
 assert(fused->op_output_source[i + ooff] == SOURCE_OUTPUT);
 assert(my_off < fused->numOutputs);
-my_output_accessor[i] = output_accessor[my_off];
+my_output_accessor.push_back(output_accessor[my_off]);
 }
 switch (fused->op_op_type[op]) {
 case OP_CONCAT: {
@@ -1327,7 +1317,7 @@ __host__ void FusedOp::forward_task(Task const *task,
 int num_inputs = fused->op_num_inputs[op];
 Kernels::Concat::forward_kernel_wrapper(m,
 my_output_accessor[0],
-my_input_accessor,
+my_input_accessor.data(),
 num_inputs,
 m->legion_axis);
 break;
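
For completeness, a minimal sketch of the idiom behind the array-to-vector change itself; Accessor and make_accessor are placeholders rather than FlexFlow types. Writing through operator[] was fine on the old fixed-size C arrays, but on a default-constructed std::vector it would index past size() and be undefined behavior, which is presumably why the patch builds the vectors with push_back.

#include <vector>

struct Accessor {
  int id;
};

// Placeholder for helperGetGenericTensorAccessorRO; the real FlexFlow helper
// takes Legion regions, a field id, a context, and a runtime.
Accessor make_accessor(int i) {
  return Accessor{i};
}

int main() {
  int const num_inputs = 4;

  // Idiom used by the patch: start empty, append each element.
  std::vector<Accessor> a;
  a.reserve(num_inputs); // optional, avoids reallocation while growing
  for (int i = 0; i < num_inputs; i++) {
    a.push_back(make_accessor(i));
  }

  // Equivalent alternative: size the vector up front, then assign by index.
  std::vector<Accessor> b(num_inputs);
  for (int i = 0; i < num_inputs; i++) {
    b[i] = make_accessor(i); // valid only because b already holds num_inputs elements
  }

  // A literal translation of the old array code onto an empty vector would be
  // undefined behavior:
  // std::vector<Accessor> c;
  // c[0] = make_accessor(0); // UB: c is empty
  return 0;
}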