From d43453baa525c04d401fc88a57602030fa16a7b8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 19 Mar 2019 14:07:23 -0700 Subject: [PATCH] Add negative axis support to Pack. PiperOrigin-RevId: 239265388 --- tensorflow/lite/kernels/pack.cc | 36 ++++++++++++-------- tensorflow/lite/kernels/pack_test.cc | 51 ++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 15 deletions(-) diff --git a/tensorflow/lite/kernels/pack.cc b/tensorflow/lite/kernels/pack.cc index e26abaaff1e..3515c3e873f 100644 --- a/tensorflow/lite/kernels/pack.cc +++ b/tensorflow/lite/kernels/pack.cc @@ -28,16 +28,20 @@ namespace { constexpr int kOutputTensor = 0; TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - const TfLitePackParams* data = + TfLitePackParams* data = reinterpret_cast(node->builtin_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* input0 = GetInput(context, node, 0); + const int dimension_size = NumDimensions(input0) + 1; + if (data->axis < 0) { + data->axis += dimension_size; + } TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis); - // TODO(renjieliu): Support negative axis. TF_LITE_ENSURE(context, data->axis >= 0); + if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 && input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt8 && input0->type != kTfLiteInt16 && input0->type != kTfLiteInt64) { @@ -53,7 +57,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } // Resize output. rank R will become rank R + 1 - const int dimension_size = NumDimensions(input0) + 1; const TfLiteIntArray* input_shape = input0->dims; TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size); int i = 0; @@ -81,8 +84,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } template -void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output, - int values_count, int axis) { +TfLiteStatus PackImpl(TfLiteContext* context, TfLiteNode* node, + TfLiteTensor* output, int values_count, int axis) { + TF_LITE_ENSURE(context, axis >= 0); + VectorOfTensors all_inputs(*context, *node->inputs); tflite::PackParams op_params; op_params.axis = axis; @@ -90,6 +95,7 @@ void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output, reference_ops::Pack(op_params, all_inputs.shapes(), all_inputs.data(), GetTensorShape(output), GetTensorData(output)); + return kTfLiteOk; } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { @@ -99,24 +105,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); switch (output->type) { case kTfLiteFloat32: { - PackImpl(context, node, output, data->values_count, data->axis); - break; + return PackImpl(context, node, output, data->values_count, + data->axis); } case kTfLiteUInt8: { - PackImpl(context, node, output, data->values_count, data->axis); - break; + return PackImpl(context, node, output, data->values_count, + data->axis); } case kTfLiteInt8: { - PackImpl(context, node, output, data->values_count, data->axis); - break; + return PackImpl(context, node, output, data->values_count, + data->axis); } case kTfLiteInt32: { - PackImpl(context, node, output, data->values_count, data->axis); - break; + return PackImpl(context, node, output, data->values_count, + data->axis); } case kTfLiteInt64: { - PackImpl(context, node, output, data->values_count, data->axis); - break; + return PackImpl(context, node, output, data->values_count, + data->axis); } default: { context->ReportError(context, "Type '%s' is not supported by pack.", diff --git a/tensorflow/lite/kernels/pack_test.cc b/tensorflow/lite/kernels/pack_test.cc index f44111567fc..b40bb45b76a 100644 --- a/tensorflow/lite/kernels/pack_test.cc +++ b/tensorflow/lite/kernels/pack_test.cc @@ -72,6 +72,16 @@ TEST(PackOpTest, FloatThreeInputsDifferentAxis) { EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } +TEST(PackOpTest, FloatThreeInputsNegativeAxis) { + PackOpModel model({TensorType_FLOAT32, {2}}, -1, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + TEST(PackOpTest, FloatMultilDimensions) { PackOpModel model({TensorType_FLOAT32, {2, 3}}, 1, 2); model.SetInput(0, {1, 2, 3, 4, 5, 6}); @@ -116,6 +126,16 @@ TEST(PackOpTest, Int32ThreeInputsDifferentAxis) { EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } +TEST(PackOpTest, Int32ThreeInputsNegativeAxis) { + PackOpModel model({TensorType_INT32, {2}}, -1, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + TEST(PackOpTest, Int32MultilDimensions) { PackOpModel model({TensorType_INT32, {2, 3}}, 1, 2); model.SetInput(0, {1, 2, 3, 4, 5, 6}); @@ -149,6 +169,17 @@ TEST(PackOpTest, Int64ThreeInputsDifferentAxis) { ElementsAreArray({1LL << 33, 2LL, 3LL, 4LL, 5LL, -(1LL << 34)})); } +TEST(PackOpTest, Int64ThreeInputsNegativeAxis) { + PackOpModel model({TensorType_INT64, {2}}, -1, 3); + model.SetInput(0, {1LL << 33, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, -(1LL << 34)}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({1LL << 33, 2LL, 3LL, 4LL, 5LL, -(1LL << 34)})); +} + TEST(PackOpTest, Int64MultilDimensions) { PackOpModel model({TensorType_INT64, {2, 3}}, 1, 2); model.SetInput(0, {1LL << 33, 2, 3, 4, 5, 6}); @@ -181,6 +212,16 @@ TEST(PackOpTest, Uint8ThreeInputsDifferentAxis) { EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } +TEST(PackOpTest, Uint8ThreeInputsNegativeAxis) { + PackOpModel model({TensorType_UINT8, {2}}, -1, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + TEST(PackOpTest, Uint8MultilDimensions) { PackOpModel model({TensorType_UINT8, {2, 3}}, 1, 2); model.SetInput(0, {1, 2, 3, 4, 5, 6}); @@ -212,6 +253,16 @@ TEST(PackOpTest, Int8ThreeInputsDifferentAxis) { EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } +TEST(PackOpTest, Int8ThreeInputsNegativeAxis) { + PackOpModel model({TensorType_INT8, {2}}, -1, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + TEST(PackOpTest, Int8MultilDimensions) { PackOpModel model({TensorType_INT8, {2, 3}}, 1, 2); model.SetInput(0, {1, 2, 3, 4, 5, 6});