Extract and rename ExtendShape within headers
This commit is contained in:
parent
9472504604
commit
9315fc6ab8
@ -18,12 +18,23 @@ limitations under the License.
|
||||
#include <cmath>
|
||||
|
||||
#include "ruy/profiler/instrumentation.h" // from @ruy
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace reference_ops {
|
||||
|
||||
// TODO(b/135760455): Move this method anonymous namespace in a cc file.
|
||||
inline RuntimeShape ExtendShapeBatchToSpace(const RuntimeShape& shape) {
|
||||
if (shape.DimensionsCount() == 4) {
|
||||
return shape;
|
||||
}
|
||||
RuntimeShape new_shape(4, 1);
|
||||
new_shape.SetDim(0, shape.Dims(0));
|
||||
new_shape.SetDim(1, shape.Dims(1));
|
||||
new_shape.SetDim(3, shape.Dims(2));
|
||||
return new_shape;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void BatchToSpaceND(const RuntimeShape& unextended_input1_shape,
|
||||
const T* input1_data,
|
||||
@ -40,13 +51,9 @@ inline void BatchToSpaceND(const RuntimeShape& unextended_input1_shape,
|
||||
unextended_output_shape.DimensionsCount());
|
||||
|
||||
const RuntimeShape input1_shape =
|
||||
(unextended_input1_shape.DimensionsCount() == 4)
|
||||
? unextended_input1_shape
|
||||
: RuntimeShape::ExtendedShape(4, unextended_input1_shape);
|
||||
ExtendShapeBatchToSpace(unextended_input1_shape);
|
||||
const RuntimeShape output_shape =
|
||||
(unextended_output_shape.DimensionsCount() == 4)
|
||||
? unextended_output_shape
|
||||
: RuntimeShape::ExtendedShape(4, unextended_output_shape);
|
||||
ExtendShapeBatchToSpace(unextended_output_shape);
|
||||
|
||||
const int output_width = output_shape.Dims(2);
|
||||
const int output_height = output_shape.Dims(1);
|
||||
|
@ -18,12 +18,14 @@ limitations under the License.
|
||||
#include <cmath>
|
||||
|
||||
#include "ruy/profiler/instrumentation.h" // from @ruy
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace reference_ops {
|
||||
|
||||
inline RuntimeShape ExtendShape(const RuntimeShape& shape) {
|
||||
// TODO(b/135760455): Move this method anonymous namespace in a cc file.
|
||||
inline RuntimeShape ExtendShapeSpaceToBatch(const RuntimeShape& shape) {
|
||||
if (shape.DimensionsCount() == 4) {
|
||||
return shape;
|
||||
}
|
||||
@ -51,8 +53,10 @@ inline void SpaceToBatchND(const SpaceToBatchParams& params,
|
||||
unextended_output_shape.DimensionsCount());
|
||||
|
||||
// Extends the input/output shape from 3D to 4D if needed, NHC -> NH1C.
|
||||
const RuntimeShape input1_shape = ExtendShape(unextended_input1_shape);
|
||||
const RuntimeShape output_shape = ExtendShape(unextended_output_shape);
|
||||
const RuntimeShape input1_shape =
|
||||
ExtendShapeSpaceToBatch(unextended_input1_shape);
|
||||
const RuntimeShape output_shape =
|
||||
ExtendShapeSpaceToBatch(unextended_output_shape);
|
||||
|
||||
const int depth = input1_shape.Dims(3);
|
||||
const int input_width = input1_shape.Dims(2);
|
||||
|
@ -32,8 +32,6 @@ constexpr int kOutputTensor = 0;
|
||||
constexpr int kInputDims = 4;
|
||||
constexpr int kOutputDims = 4;
|
||||
|
||||
} // namespace.
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
@ -95,6 +93,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace.
|
||||
|
||||
TfLiteRegistration Register_BATCH_TO_SPACE_ND() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
|
@ -33,8 +33,6 @@ constexpr int kOutputTensor = 0;
|
||||
constexpr int kInputDims = 4;
|
||||
constexpr int kOutputDims = 4;
|
||||
|
||||
} // namespace.
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(SpaceToBatchParams));
|
||||
@ -109,6 +107,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace.
|
||||
|
||||
TfLiteRegistration Register_SPACE_TO_BATCH_ND() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
|
Loading…
x
Reference in New Issue
Block a user