Skip to content

Commit eb85c46

Browse files
pramods-cadveblush
andauthored
Optimization in LSTM for batch > 1 cases on HiFi. (#3244)
* Optimization in LSTM for batch > 1 cases on HiFi. * Addressed review comments. * Fixed code style errors. * Removing unnecessary comment. --------- Co-authored-by: Esun Kim <veblush@google.com>
1 parent 6c1c1a8 commit eb85c46

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,8 @@ void LstmStepManager::UpdateBatch() {
473473
// Multi-batch for time_major input
474474
RuntimeShape LstmStepManager::InputShape() const {
475475
int batch_size = 1;
476-
if (size_info_.time_major) {
476+
if (size_info_.time_major ||
477+
(size_info_.batch_size > 1 && size_info_.time_steps == 1)) {
477478
batch_size = size_info_.batch_size;
478479
}
479480
const int dims[2] = {batch_size, size_info_.input_dimension};
@@ -485,7 +486,8 @@ RuntimeShape LstmStepManager::InputShape() const {
485486
// Multi-batch for time_major input
486487
RuntimeShape LstmStepManager::StateShape() const {
487488
int batch_size = 1;
488-
if (size_info_.time_major) {
489+
if (size_info_.time_major ||
490+
(size_info_.batch_size > 1 && size_info_.time_steps == 1)) {
489491
batch_size = size_info_.batch_size;
490492
}
491493
const int dims[2] = {batch_size, size_info_.state_dimension};

tensorflow/lite/micro/kernels/xtensa/lstm_eval.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -661,10 +661,14 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
661661
kernel_content.GetInternalTensor(tflite::kLstmInputTensor);
662662
TfLiteEvalTensor* recurrent = kernel_content.HiddenStateTensor();
663663

664-
int time_major = step_info.time_major();
665-
int num_batches = time_major == 0 ? 1 : step_info.batch_size();
666-
int input_dimension = step_info.input_dimension();
667-
int state_dimension = step_info.state_dimension();
664+
const auto& size_info = op_data.size_info;
665+
const int time_major = step_info.time_major();
666+
const int batch_size = size_info.batch_size;
667+
const int time_steps = size_info.time_steps;
668+
const int num_batches = time_major == 0 ? (time_steps == 1 ? batch_size : 1)
669+
: step_info.batch_size();
670+
const int input_dimension = step_info.input_dimension();
671+
const int state_dimension = step_info.state_dimension();
668672

669673
// Check offset validity to avoid memory overflow
670674
TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension,
@@ -803,8 +807,10 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
803807
// prepare for the next time step
804808
step_info.UpdateTime();
805809
}
810+
} else if (size_info.batch_size > 1 && size_info.time_steps == 1) {
811+
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
812+
step_info, op_data, kernel_content, buffers);
806813
} else {
807-
// batch first, unable to size the input data. single batch inference
808814
for (int b = 0; b < size_info.batch_size; b++) {
809815
for (int t = 0; t < size_info.time_steps; t++) {
810816
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(

0 commit comments

Comments
 (0)