diff --git a/tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h b/tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h index 4abf5c7725c..cda46a2673c 100644 --- a/tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h +++ b/tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h @@ -18,12 +18,23 @@ limitations under the License. #include #include "ruy/profiler/instrumentation.h" // from @ruy -#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { - namespace reference_ops { +// TODO(b/135760455): Move this method anonymous namespace in a cc file. +inline RuntimeShape ExtendShapeBatchToSpace(const RuntimeShape& shape) { + if (shape.DimensionsCount() == 4) { + return shape; + } + RuntimeShape new_shape(4, 1); + new_shape.SetDim(0, shape.Dims(0)); + new_shape.SetDim(1, shape.Dims(1)); + new_shape.SetDim(3, shape.Dims(2)); + return new_shape; +} + template inline void BatchToSpaceND(const RuntimeShape& unextended_input1_shape, const T* input1_data, @@ -40,13 +51,9 @@ inline void BatchToSpaceND(const RuntimeShape& unextended_input1_shape, unextended_output_shape.DimensionsCount()); const RuntimeShape input1_shape = - (unextended_input1_shape.DimensionsCount() == 4) - ? unextended_input1_shape - : RuntimeShape::ExtendedShape(4, unextended_input1_shape); + ExtendShapeBatchToSpace(unextended_input1_shape); const RuntimeShape output_shape = - (unextended_output_shape.DimensionsCount() == 4) - ? unextended_output_shape - : RuntimeShape::ExtendedShape(4, unextended_output_shape); + ExtendShapeBatchToSpace(unextended_output_shape); const int output_width = output_shape.Dims(2); const int output_height = output_shape.Dims(1); diff --git a/tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h b/tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h index f58ba372e0e..7f844152c83 100644 --- a/tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h +++ b/tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h @@ -18,12 +18,14 @@ limitations under the License. #include #include "ruy/profiler/instrumentation.h" // from @ruy +#include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace reference_ops { -inline RuntimeShape ExtendShape(const RuntimeShape& shape) { +// TODO(b/135760455): Move this method anonymous namespace in a cc file. +inline RuntimeShape ExtendShapeSpaceToBatch(const RuntimeShape& shape) { if (shape.DimensionsCount() == 4) { return shape; } @@ -51,8 +53,10 @@ inline void SpaceToBatchND(const SpaceToBatchParams& params, unextended_output_shape.DimensionsCount()); // Extends the input/output shape from 3D to 4D if needed, NHC -> NH1C. - const RuntimeShape input1_shape = ExtendShape(unextended_input1_shape); - const RuntimeShape output_shape = ExtendShape(unextended_output_shape); + const RuntimeShape input1_shape = + ExtendShapeSpaceToBatch(unextended_input1_shape); + const RuntimeShape output_shape = + ExtendShapeSpaceToBatch(unextended_output_shape); const int depth = input1_shape.Dims(3); const int input_width = input1_shape.Dims(2); diff --git a/tensorflow/lite/micro/kernels/batch_to_space_nd.cc b/tensorflow/lite/micro/kernels/batch_to_space_nd.cc index 8bb9b68d39c..35aeb922d8a 100644 --- a/tensorflow/lite/micro/kernels/batch_to_space_nd.cc +++ b/tensorflow/lite/micro/kernels/batch_to_space_nd.cc @@ -32,8 +32,6 @@ constexpr int kOutputTensor = 0; constexpr int kInputDims = 4; constexpr int kOutputDims = 4; -} // namespace. - TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -95,6 +93,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +} // namespace. + TfLiteRegistration Register_BATCH_TO_SPACE_ND() { return {/*init=*/nullptr, /*free=*/nullptr, diff --git a/tensorflow/lite/micro/kernels/space_to_batch_nd.cc b/tensorflow/lite/micro/kernels/space_to_batch_nd.cc index c4c47205e6c..d2f8e138b96 100644 --- a/tensorflow/lite/micro/kernels/space_to_batch_nd.cc +++ b/tensorflow/lite/micro/kernels/space_to_batch_nd.cc @@ -33,8 +33,6 @@ constexpr int kOutputTensor = 0; constexpr int kInputDims = 4; constexpr int kOutputDims = 4; -} // namespace. - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(SpaceToBatchParams)); @@ -109,6 +107,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +} // namespace. + TfLiteRegistration Register_SPACE_TO_BATCH_ND() { return {/*init=*/Init, /*free=*/nullptr,