Skip to content

Commit 62cbc8e

Browse files
Porting Dynamic_Update_Slice operator from TFLite (#3246)
* Sync files related to Reverse_V2 from TFLite #3110 * PRelu Int16x8 support in RefC * Fix code style in prelu_test.cc * 1. Reverted the copyright year * Resolved compilation error for Int8x8 test case * Add Dynamic_Update_Slice support to TFLM * Code style error correction * Code style correction * Replaced hard coded MaxDimensions to RuntimeShape::kMaxSmallSize * 1. Added more test cases 2. Removed unused code * Updates for test failure on ARM * Code style updates * Updates on test case failure for ARM * Updates on test case failure for ARM * Code style updates --------- Co-authored-by: Esun Kim <veblush@google.com>
1 parent 9b9b1e3 commit 62cbc8e

File tree

10 files changed

+505
-1
lines changed

10 files changed

+505
-1
lines changed

python/tflite_micro/python_ops_resolver.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ PythonOpsResolver::PythonOpsResolver() {
4747
AddDequantize();
4848
AddDetectionPostprocess();
4949
AddDiv();
50+
AddDynamicUpdateSlice();
5051
AddElu();
5152
AddEmbeddingLookup();
5253
AddEnergy();

tensorflow/lite/micro/kernels/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ tflm_kernel_cc_library(
262262
"dequantize_common.cc",
263263
"detection_postprocess.cc",
264264
"div.cc",
265+
"dynamic_update_slice.cc",
265266
"elementwise.cc",
266267
"elu.cc",
267268
"embedding_lookup.cc",
@@ -352,6 +353,7 @@ tflm_kernel_cc_library(
352353
"decode_state_prune.h",
353354
"depthwise_conv.h",
354355
"dequantize.h",
356+
"dynamic_update_slice.h",
355357
"ethosu.h",
356358
"fully_connected.h",
357359
"hard_swish.h",
@@ -824,6 +826,21 @@ tflm_cc_test(
824826
],
825827
)
826828

829+
# Test target built from dynamic_update_slice_test.cc; it links the kernel
# runner and the micro test helpers listed in deps.
tflm_cc_test(
830+
name = "dynamic_update_slice_test",
831+
srcs = [
832+
"dynamic_update_slice_test.cc",
833+
],
834+
deps = [
835+
":kernel_runner",
836+
"//tensorflow/lite/c:common",
837+
"//tensorflow/lite/micro:debug_log",
838+
"//tensorflow/lite/micro:op_resolvers",
839+
"//tensorflow/lite/micro:test_helpers",
840+
"//tensorflow/lite/micro/testing:micro_test",
841+
],
842+
)
843+
827844
tflm_cc_test(
828845
name = "elementwise_test",
829846
srcs = ["elementwise_test.cc"],

tensorflow/lite/micro/kernels/Makefile.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space_test.cc \
131131
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_test.cc \
132132
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_test.cc \
133133
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/div_test.cc \
134+
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc \
134135
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elementwise_test.cc \
135136
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elu_test.cc \
136137
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/embedding_lookup_test.cc \
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow/lite/micro/kernels/dynamic_update_slice.h"
16+
17+
#include "tensorflow/lite/c/builtin_op_data.h"
18+
#include "tensorflow/lite/c/common.h"
19+
#include "tensorflow/lite/kernels/internal/common.h"
20+
#include "tensorflow/lite/kernels/internal/quantization_util.h"
21+
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22+
#include "tensorflow/lite/kernels/kernel_util.h"
23+
#include "tensorflow/lite/kernels/op_macros.h"
24+
#include "tensorflow/lite/micro/kernels/kernel_util.h"
25+
#include "tensorflow/lite/micro/micro_log.h"
26+
#include "tensorflow/lite/micro/micro_utils.h"
27+
28+
namespace tflite {
29+
30+
// Capacity of the fixed-size stack arrays this kernel uses for per-dimension
// data (start indices and strides in EvalImpl, widened indices in Eval).
constexpr int kMaxDimensions = RuntimeShape::kMaxSmallSize;
31+
32+
namespace {
33+
34+
// Clamps the requested start index of every dimension so that the update
// region lies inside the operand.
//
//   num_dims                      number of entries in each array argument.
//   raw_indices_data              requested start index per dimension; may be
//                                 negative or past the end of the operand.
//   input_dims_data               operand extent per dimension.
//   update_dims_data              update extent per dimension.
//   clamped_start_indices_output  receives one value per dimension in
//                                 [0, max(0, input_dim - update_dim)].
void CalculateClampedStartIndices(int num_dims, const int64_t* raw_indices_data,
                                  const int32_t* input_dims_data,
                                  const int32_t* update_dims_data,
                                  int32_t* clamped_start_indices_output) {
  for (int i = 0; i < num_dims; ++i) {
    // Largest start index that still keeps the whole update inside the
    // operand along this dimension.
    const int64_t max_start =
        static_cast<int64_t>(input_dims_data[i]) - update_dims_data[i];
    // Apply the upper bound first and the zero floor last: should an update
    // extent ever exceed the operand extent, the result is 0 rather than a
    // negative offset. For valid shapes this equals min(max(0, raw), max).
    const int64_t clamped = std::max<int64_t>(
        0, std::min<int64_t>(raw_indices_data[i], max_start));
    clamped_start_indices_output[i] = static_cast<int32_t>(clamped);
  }
}
45+
46+
// Copies the update tensor into the output one dimension at a time.
//
// Each call offsets the output pointer by this dimension's clamped start
// index. The innermost dimension is written with a single memcpy of
// update_dims_data[current_dim] elements; every outer dimension walks its
// update extent and recurses, stepping both pointers by that dimension's
// stride. A call with current_dim == max_dims does nothing.
template <typename T>
void UpdateSliceRecursive(int current_dim, int max_dims,
                          const int32_t* output_strides,
                          const int32_t* update_strides,
                          const int32_t* update_dims_data,
                          const T* update_tensor_data,
                          const int32_t* clamped_start_indices,
                          T* output_tensor_data) {
  if (current_dim == max_dims) {
    return;
  }

  const int32_t dst_step = output_strides[current_dim];
  T* dst = output_tensor_data + clamped_start_indices[current_dim] * dst_step;

  const bool is_innermost = (current_dim + 1 == max_dims);
  if (is_innermost) {
    // The innermost run is contiguous in both tensors.
    std::memcpy(dst, update_tensor_data,
                update_dims_data[current_dim] * sizeof(T));
    return;
  }

  const int32_t src_step = update_strides[current_dim];
  const T* src = update_tensor_data;
  for (int32_t remaining = update_dims_data[current_dim]; remaining > 0;
       --remaining) {
    UpdateSliceRecursive<T>(current_dim + 1, max_dims, output_strides,
                            update_strides, update_dims_data, src,
                            clamped_start_indices, dst);
    dst += dst_step;
    src += src_step;
  }
}
72+
73+
// Main dispatch function for Eval, templated on data type.
74+
template <typename T>
75+
void EvalImpl(const TfLiteEvalTensor* operand_eval,
76+
const TfLiteEvalTensor* update_eval, const int64_t* indices_eval,
77+
TfLiteEvalTensor* output_eval) {
78+
const RuntimeShape operand_shape =
79+
tflite::micro::GetTensorShape(operand_eval);
80+
const RuntimeShape update_shape = tflite::micro::GetTensorShape(update_eval);
81+
const T* update_tensor_data = tflite::micro::GetTensorData<T>(update_eval);
82+
T* output_tensor_data = tflite::micro::GetTensorData<T>(output_eval);
83+
84+
const int num_dims = operand_shape.DimensionsCount();
85+
if (operand_shape.FlatSize() == update_shape.FlatSize()) {
86+
std::memcpy(output_tensor_data, update_tensor_data,
87+
ElementCount(*operand_eval->dims) * sizeof(T));
88+
return;
89+
}
90+
91+
// If the operation is not done in-place, copy the input data to the output.
92+
if (operand_eval->data.data != output_eval->data.data) {
93+
std::memcpy(output_eval->data.data, operand_eval->data.data,
94+
ElementCount(*operand_eval->dims) * sizeof(T));
95+
}
96+
97+
// If update tensor is empty, no actual update is needed after operand copy.
98+
if (ElementCount(*update_eval->dims) == 0) {
99+
return;
100+
}
101+
102+
// Calculate clamped start indices (stack-allocated)
103+
int32_t clamped_start_indices[kMaxDimensions];
104+
CalculateClampedStartIndices(num_dims, indices_eval, operand_shape.DimsData(),
105+
update_shape.DimsData(), clamped_start_indices);
106+
107+
// Calculate strides (stack-allocated)
108+
int32_t output_stride[kMaxDimensions];
109+
int32_t update_stride[kMaxDimensions];
110+
output_stride[num_dims - 1] = 1;
111+
update_stride[num_dims - 1] = 1;
112+
for (int i = num_dims - 2; i >= 0; --i) {
113+
output_stride[i] = output_stride[i + 1] * operand_shape.Dims(i + 1);
114+
update_stride[i] = update_stride[i + 1] * update_shape.Dims(i + 1);
115+
}
116+
117+
// Perform the N-dimensional update
118+
// The recursive function needs base pointers and initial offsets.
119+
UpdateSliceRecursive<T>(
120+
/*current_dim=*/0, num_dims, output_stride, update_stride,
121+
update_shape.DimsData(), update_tensor_data, clamped_start_indices,
122+
output_tensor_data);
123+
}
124+
125+
// Checks every type and shape constraint between the op's tensors. Kept
// separate from Prepare so that Prepare can release its temporary tensors no
// matter which check fails. Any of the tensor pointers may be null; that is
// reported as an error.
TfLiteStatus ValidateDynamicUpdateSliceTensors(
    TfLiteContext* context, const TfLiteTensor* operand,
    const TfLiteTensor* update, const TfLiteTensor* start_indices,
    const TfLiteTensor* output) {
  TF_LITE_ENSURE(context, operand != nullptr);
  TF_LITE_ENSURE(context, update != nullptr);
  TF_LITE_ENSURE(context, start_indices != nullptr);
  TF_LITE_ENSURE(context, output != nullptr);

  // Type checks
  TF_LITE_ENSURE_TYPES_EQ(context, operand->type, update->type);
  // Eval writes operand-typed elements into the output buffer.
  TF_LITE_ENSURE_TYPES_EQ(context, operand->type, output->type);
  TF_LITE_ENSURE(context, start_indices->type == kTfLiteInt32 ||
                              start_indices->type == kTfLiteInt64);

  // Eval keeps per-dimension data in stack arrays of kMaxDimensions entries.
  TF_LITE_ENSURE(context, NumDimensions(operand) <= kMaxDimensions);

  // start_indices is a vector with one entry per operand dimension.
  TF_LITE_ENSURE_EQ(context, NumDimensions(start_indices), 1);
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(start_indices, 0),
                    NumDimensions(operand));

  TF_LITE_ENSURE_EQ(context, NumDimensions(update), NumDimensions(operand));
  // Check that update dimensions are not larger than operand dimensions
  for (int i = 0; i < NumDimensions(operand); ++i) {
    TF_LITE_ENSURE(context,
                   SizeOfDimension(update, i) <= SizeOfDimension(operand, i));
  }

  // Eval copies the operand's full element count into the output, so the two
  // must have identical shapes.
  TF_LITE_ENSURE(context, HaveSameShapes(operand, output));

  return kTfLiteOk;
}

// Validates the node's tensors: three inputs (operand, update, start indices)
// and one output with compatible types and shapes. Returns kTfLiteOk when all
// checks pass and kTfLiteError otherwise.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  MicroContext* micro_context = GetMicroContext(context);
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  // Use MicroContext to allocate temporary tensors for inspection.
  TfLiteTensor* operand = micro_context->AllocateTempInputTensor(
      node, kDynamicUpdateSliceOperandTensor);
  TfLiteTensor* update = micro_context->AllocateTempInputTensor(
      node, kDynamicUpdateSliceUpdateTensor);
  TfLiteTensor* start_indices = micro_context->AllocateTempInputTensor(
      node, kDynamicUpdateSliceStartIndicesTensor);
  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
      node, kDynamicUpdateSliceOutputTensor);

  const TfLiteStatus status = ValidateDynamicUpdateSliceTensors(
      context, operand, update, start_indices, output);

  // Release whatever was allocated, on the success and the failure path
  // alike, so a failed check does not leave temporary tensors outstanding.
  TfLiteTensor* const temp_tensors[] = {operand, update, start_indices,
                                        output};
  for (TfLiteTensor* tensor : temp_tensors) {
    if (tensor != nullptr) {
      micro_context->DeallocateTempTfLiteTensor(tensor);
    }
  }

  return status;
}
172+
173+
// Kernel invoke entry point: widens the start indices to int64 and dispatches
// to the EvalImpl instantiation that matches the operand's element type.
// Returns kTfLiteError for an unsupported index type, an unsupported operand
// type, or an operand rank larger than kMaxDimensions.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* operand_eval = tflite::micro::GetEvalInput(
      context, node, kDynamicUpdateSliceOperandTensor);
  const TfLiteEvalTensor* update_eval = tflite::micro::GetEvalInput(
      context, node, kDynamicUpdateSliceUpdateTensor);
  const TfLiteEvalTensor* indices_eval = tflite::micro::GetEvalInput(
      context, node, kDynamicUpdateSliceStartIndicesTensor);
  TfLiteEvalTensor* output_eval = tflite::micro::GetEvalOutput(
      context, node, kDynamicUpdateSliceOutputTensor);

  const auto& input_shape = tflite::micro::GetTensorShape(operand_eval);
  const int input_dims = input_shape.DimensionsCount();
  // indices_data_i64 below (and the per-dimension arrays in EvalImpl) hold
  // kMaxDimensions entries; refuse to run rather than write past them.
  TF_LITE_ENSURE(context, input_dims <= kMaxDimensions);

  // Start indices may be int32 or int64; normalise to int64 so that a single
  // clamping routine serves both.
  int64_t indices_data_i64[kMaxDimensions];
  if (indices_eval->type == kTfLiteInt32) {
    for (int i = 0; i < input_dims; i++)
      indices_data_i64[i] = static_cast<int64_t>(indices_eval->data.i32[i]);
  } else if (indices_eval->type == kTfLiteInt64) {
    for (int i = 0; i < input_dims; i++)
      indices_data_i64[i] = indices_eval->data.i64[i];
  } else {
    TF_LITE_KERNEL_LOG(context,
                       "DynamicUpdateSlice only currently supports "
                       "int32 or int64 indices type, got %d.",
                       indices_eval->type);
    return kTfLiteError;
  }
  // Dispatch based on tensor type
  switch (operand_eval->type) {
    case kTfLiteFloat32:
      EvalImpl<float>(operand_eval, update_eval, indices_data_i64, output_eval);
      break;
    case kTfLiteInt8:
      EvalImpl<int8_t>(operand_eval, update_eval, indices_data_i64,
                       output_eval);
      break;
    case kTfLiteInt16:
      EvalImpl<int16_t>(operand_eval, update_eval, indices_data_i64,
                        output_eval);
      break;
    case kTfLiteInt32:
      EvalImpl<int32_t>(operand_eval, update_eval, indices_data_i64,
                        output_eval);
      break;
    default:
      MicroPrintf("DYNAMIC_UPDATE_SLICE: Operand type %s not supported.",
                  TfLiteTypeGetName(operand_eval->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}
223+
224+
} // namespace
225+
226+
TFLMRegistration Register_DYNAMIC_UPDATE_SLICE() {
227+
return tflite::micro::RegisterOp(/*init=*/nullptr, /*prepare=*/Prepare,
228+
/*invoke=*/Eval);
229+
}
230+
231+
} // namespace tflite
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_DYNAMIC_UPDATE_SLICE_H_
16+
#define TENSORFLOW_LITE_MICRO_KERNELS_DYNAMIC_UPDATE_SLICE_H_
17+
18+
#include "tensorflow/lite/c/builtin_op_data.h"
19+
#include "tensorflow/lite/kernels/internal/types.h"
20+
#include "tensorflow/lite/micro/micro_common.h"
21+
22+
namespace tflite {
23+
24+
// Positions of the op's input tensors (operand, update, start indices).
constexpr int kDynamicUpdateSliceOperandTensor = 0;
25+
constexpr int kDynamicUpdateSliceUpdateTensor = 1;
26+
constexpr int kDynamicUpdateSliceStartIndicesTensor = 2;
27+
// Position of the op's single output tensor.
constexpr int kDynamicUpdateSliceOutputTensor = 0;
28+
29+
// Prepare-step entry point for the kernel.
// NOTE(review): verify that a definition with external linkage exists for
// this declaration; the Prepare function in dynamic_update_slice.cc is
// file-local.
TfLiteStatus PrepareDynamicUpdateSlice(TfLiteContext* context,
30+
TfLiteNode* node);
31+
32+
// Returns the registration for the DYNAMIC_UPDATE_SLICE kernel.
TFLMRegistration Register_DYNAMIC_UPDATE_SLICE();
33+
34+
} // namespace tflite
35+
36+
#endif // TENSORFLOW_LITE_MICRO_KERNELS_DYNAMIC_UPDATE_SLICE_H_

0 commit comments

Comments
 (0)