Extract and rename ExtendShape within headers

This commit is contained in:
Nat Jeffries 2021-02-01 16:05:51 -08:00
parent 9472504604
commit 9315fc6ab8
4 changed files with 26 additions and 15 deletions

View File

@ -18,12 +18,23 @@ limitations under the License.
#include <cmath>
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
// TODO(b/135760455): Move this method anonymous namespace in a cc file.
inline RuntimeShape ExtendShapeBatchToSpace(const RuntimeShape& shape) {
if (shape.DimensionsCount() == 4) {
return shape;
}
RuntimeShape new_shape(4, 1);
new_shape.SetDim(0, shape.Dims(0));
new_shape.SetDim(1, shape.Dims(1));
new_shape.SetDim(3, shape.Dims(2));
return new_shape;
}
template <typename T>
inline void BatchToSpaceND(const RuntimeShape& unextended_input1_shape,
const T* input1_data,
@ -40,13 +51,9 @@ inline void BatchToSpaceND(const RuntimeShape& unextended_input1_shape,
unextended_output_shape.DimensionsCount());
const RuntimeShape input1_shape =
(unextended_input1_shape.DimensionsCount() == 4)
? unextended_input1_shape
: RuntimeShape::ExtendedShape(4, unextended_input1_shape);
ExtendShapeBatchToSpace(unextended_input1_shape);
const RuntimeShape output_shape =
(unextended_output_shape.DimensionsCount() == 4)
? unextended_output_shape
: RuntimeShape::ExtendedShape(4, unextended_output_shape);
ExtendShapeBatchToSpace(unextended_output_shape);
const int output_width = output_shape.Dims(2);
const int output_height = output_shape.Dims(1);

View File

@ -18,12 +18,14 @@ limitations under the License.
#include <cmath>
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
inline RuntimeShape ExtendShape(const RuntimeShape& shape) {
// TODO(b/135760455): Move this method anonymous namespace in a cc file.
inline RuntimeShape ExtendShapeSpaceToBatch(const RuntimeShape& shape) {
if (shape.DimensionsCount() == 4) {
return shape;
}
@ -51,8 +53,10 @@ inline void SpaceToBatchND(const SpaceToBatchParams& params,
unextended_output_shape.DimensionsCount());
// Extends the input/output shape from 3D to 4D if needed, NHC -> NH1C.
const RuntimeShape input1_shape = ExtendShape(unextended_input1_shape);
const RuntimeShape output_shape = ExtendShape(unextended_output_shape);
const RuntimeShape input1_shape =
ExtendShapeSpaceToBatch(unextended_input1_shape);
const RuntimeShape output_shape =
ExtendShapeSpaceToBatch(unextended_output_shape);
const int depth = input1_shape.Dims(3);
const int input_width = input1_shape.Dims(2);

View File

@ -32,8 +32,6 @@ constexpr int kOutputTensor = 0;
constexpr int kInputDims = 4;
constexpr int kOutputDims = 4;
} // namespace.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@ -95,6 +93,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
} // namespace.
TfLiteRegistration Register_BATCH_TO_SPACE_ND() {
return {/*init=*/nullptr,
/*free=*/nullptr,

View File

@ -33,8 +33,6 @@ constexpr int kOutputTensor = 0;
constexpr int kInputDims = 4;
constexpr int kOutputDims = 4;
} // namespace.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(SpaceToBatchParams));
@ -109,6 +107,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
} // namespace.
TfLiteRegistration Register_SPACE_TO_BATCH_ND() {
return {/*init=*/Init,
/*free=*/nullptr,