@@ -661,10 +661,14 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
       kernel_content.GetInternalTensor(tflite::kLstmInputTensor);
   TfLiteEvalTensor* recurrent = kernel_content.HiddenStateTensor();
 
-  int time_major = step_info.time_major();
-  int num_batches = time_major == 0 ? 1 : step_info.batch_size();
-  int input_dimension = step_info.input_dimension();
-  int state_dimension = step_info.state_dimension();
+  const auto& size_info = op_data.size_info;
+  const int time_major = step_info.time_major();
+  const int batch_size = size_info.batch_size;
+  const int time_steps = size_info.time_steps;
+  const int num_batches = time_major == 0 ? (time_steps == 1 ? batch_size : 1)
+                                          : step_info.batch_size();
+  const int input_dimension = step_info.input_dimension();
+  const int state_dimension = step_info.state_dimension();
 
   // Check offset validity to avoid memory overflow
   TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension,
@@ -803,8 +807,10 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
       // prepare for the next time step
       step_info.UpdateTime();
     }
+  } else if (size_info.batch_size > 1 && size_info.time_steps == 1) {
+    lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
+        step_info, op_data, kernel_content, buffers);
   } else {
-    // batch first, unable to size the input data. single batch inference
     for (int b = 0; b < size_info.batch_size; b++) {
       for (int t = 0; t < size_info.time_steps; t++) {
         lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(