Add negative axis support to Pack.

PiperOrigin-RevId: 239265388
Authored by A. Unique TensorFlower on 2019-03-19 14:07:23 -07:00; committed by TensorFlower Gardener
parent e52328ebc5
commit d43453baa5
2 changed files with 72 additions and 15 deletions
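For orientation, here is a minimal standalone sketch of the axis normalization this change introduces (illustrative helper names, not the TFLite kernel code itself): packing N inputs of rank R produces a rank R + 1 output, so a negative axis is shifted by R + 1 before the existing bounds checks run.

// Standalone sketch (not TFLite code): how a negative Pack axis is resolved.
// For N inputs of rank R, the packed output has rank R + 1, so a negative
// axis is normalized by adding R + 1 before any bounds checks run.
#include <cassert>
#include <cstdio>
#include <vector>

// Resolve a possibly-negative axis against the output rank (input rank + 1).
int ResolvePackAxis(int axis, int input_rank) {
  const int output_rank = input_rank + 1;
  if (axis < 0) axis += output_rank;
  assert(axis >= 0 && axis <= input_rank);
  return axis;
}

// Compute the packed output shape: the input shape with `values_count`
// inserted at the resolved axis.
std::vector<int> PackedShape(const std::vector<int>& input_shape,
                             int values_count, int axis) {
  axis = ResolvePackAxis(axis, static_cast<int>(input_shape.size()));
  std::vector<int> out(input_shape);
  out.insert(out.begin() + axis, values_count);
  return out;
}

int main() {
  // Mirrors the new tests below: packing 3 tensors of shape {2} along
  // axis -1 produces shape {2, 3}.
  std::vector<int> shape = PackedShape({2}, /*values_count=*/3, /*axis=*/-1);
  for (int d : shape) std::printf("%d ", d);  // prints "2 3"
  std::printf("\n");
  return 0;
}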

tensorflow/lite/kernels/pack.cc

@@ -28,16 +28,20 @@ namespace {
 constexpr int kOutputTensor = 0;
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  const TfLitePackParams* data =
+  TfLitePackParams* data =
       reinterpret_cast<TfLitePackParams*>(node->builtin_data);
   TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   const TfLiteTensor* input0 = GetInput(context, node, 0);
+  const int dimension_size = NumDimensions(input0) + 1;
+  if (data->axis < 0) {
+    data->axis += dimension_size;
+  }
   TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis);
-  // TODO(renjieliu): Support negative axis.
-  TF_LITE_ENSURE(context, data->axis >= 0);
   if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 &&
       input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt8 &&
       input0->type != kTfLiteInt16 && input0->type != kTfLiteInt64) {
@@ -53,7 +57,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }
   // Resize output. rank R will become rank R + 1
-  const int dimension_size = NumDimensions(input0) + 1;
   const TfLiteIntArray* input_shape = input0->dims;
   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size);
   int i = 0;
@@ -81,8 +84,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 template <typename T>
-void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
-              int values_count, int axis) {
+TfLiteStatus PackImpl(TfLiteContext* context, TfLiteNode* node,
+                      TfLiteTensor* output, int values_count, int axis) {
+  TF_LITE_ENSURE(context, axis >= 0);
   VectorOfTensors<T> all_inputs(*context, *node->inputs);
   tflite::PackParams op_params;
   op_params.axis = axis;
@@ -90,6 +95,7 @@ void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
   reference_ops::Pack<T>(op_params, all_inputs.shapes(), all_inputs.data(),
                          GetTensorShape(output), GetTensorData<T>(output));
+  return kTfLiteOk;
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -99,24 +105,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
   switch (output->type) {
     case kTfLiteFloat32: {
-      PackImpl<float>(context, node, output, data->values_count, data->axis);
-      break;
+      return PackImpl<float>(context, node, output, data->values_count,
+                             data->axis);
     }
     case kTfLiteUInt8: {
-      PackImpl<uint8_t>(context, node, output, data->values_count, data->axis);
-      break;
+      return PackImpl<uint8_t>(context, node, output, data->values_count,
+                               data->axis);
     }
     case kTfLiteInt8: {
-      PackImpl<int8_t>(context, node, output, data->values_count, data->axis);
-      break;
+      return PackImpl<int8_t>(context, node, output, data->values_count,
+                              data->axis);
     }
     case kTfLiteInt32: {
-      PackImpl<int32_t>(context, node, output, data->values_count, data->axis);
-      break;
+      return PackImpl<int32_t>(context, node, output, data->values_count,
+                               data->axis);
     }
     case kTfLiteInt64: {
-      PackImpl<int64_t>(context, node, output, data->values_count, data->axis);
-      break;
+      return PackImpl<int64_t>(context, node, output, data->values_count,
+                               data->axis);
     }
     default: {
       context->ReportError(context, "Type '%s' is not supported by pack.",
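As a side note, the PackImpl signature change above (void to TfLiteStatus) lets a failed check inside the implementation propagate out through Eval's return value, and lets each switch case return directly instead of breaking. A minimal sketch of that pattern, with illustrative names rather than the real TFLite macros and types:

// Standalone sketch of the error-propagation pattern (illustrative names,
// not the TFLite API): the per-type impl returns a status so a failed check
// surfaces through the dispatcher's return value.
#include <cstdio>

enum Status { kOk, kError };

// Simplified stand-in for TF_LITE_ENSURE: bail out with an error status.
#define ENSURE(cond)            \
  do {                          \
    if (!(cond)) return kError; \
  } while (false)

template <typename T>
Status PackImplSketch(int values_count, int axis) {
  ENSURE(axis >= 0);  // By now Prepare should have normalized negative axes.
  // ... pack `values_count` tensors of element type T along `axis` ...
  return kOk;
}

Status EvalSketch(int type, int values_count, int axis) {
  switch (type) {
    case 0:
      return PackImplSketch<float>(values_count, axis);
    case 1:
      return PackImplSketch<int>(values_count, axis);
    default:
      return kError;
  }
}

int main() {
  std::printf("%d\n", EvalSketch(0, 3, 1));   // kOk (0)
  std::printf("%d\n", EvalSketch(0, 3, -1));  // kError (1): axis not normalized
  return 0;
}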

tensorflow/lite/kernels/pack_test.cc

@@ -72,6 +72,16 @@ TEST(PackOpTest, FloatThreeInputsDifferentAxis) {
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
 }
+TEST(PackOpTest, FloatThreeInputsNegativeAxis) {
+  PackOpModel<float> model({TensorType_FLOAT32, {2}}, -1, 3);
+  model.SetInput(0, {1, 4});
+  model.SetInput(1, {2, 5});
+  model.SetInput(2, {3, 6});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
 TEST(PackOpTest, FloatMultilDimensions) {
   PackOpModel<float> model({TensorType_FLOAT32, {2, 3}}, 1, 2);
   model.SetInput(0, {1, 2, 3, 4, 5, 6});
@@ -116,6 +126,16 @@ TEST(PackOpTest, Int32ThreeInputsDifferentAxis) {
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
 }
+TEST(PackOpTest, Int32ThreeInputsNegativeAxis) {
+  PackOpModel<int32_t> model({TensorType_INT32, {2}}, -1, 3);
+  model.SetInput(0, {1, 4});
+  model.SetInput(1, {2, 5});
+  model.SetInput(2, {3, 6});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
 TEST(PackOpTest, Int32MultilDimensions) {
   PackOpModel<int32_t> model({TensorType_INT32, {2, 3}}, 1, 2);
   model.SetInput(0, {1, 2, 3, 4, 5, 6});
@@ -149,6 +169,17 @@ TEST(PackOpTest, Int64ThreeInputsDifferentAxis) {
               ElementsAreArray({1LL << 33, 2LL, 3LL, 4LL, 5LL, -(1LL << 34)}));
 }
+TEST(PackOpTest, Int64ThreeInputsNegativeAxis) {
+  PackOpModel<int64_t> model({TensorType_INT64, {2}}, -1, 3);
+  model.SetInput(0, {1LL << 33, 4});
+  model.SetInput(1, {2, 5});
+  model.SetInput(2, {3, -(1LL << 34)});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({1LL << 33, 2LL, 3LL, 4LL, 5LL, -(1LL << 34)}));
+}
 TEST(PackOpTest, Int64MultilDimensions) {
   PackOpModel<int64_t> model({TensorType_INT64, {2, 3}}, 1, 2);
   model.SetInput(0, {1LL << 33, 2, 3, 4, 5, 6});
@@ -181,6 +212,16 @@ TEST(PackOpTest, Uint8ThreeInputsDifferentAxis) {
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
 }
+TEST(PackOpTest, Uint8ThreeInputsNegativeAxis) {
+  PackOpModel<uint8_t> model({TensorType_UINT8, {2}}, -1, 3);
+  model.SetInput(0, {1, 4});
+  model.SetInput(1, {2, 5});
+  model.SetInput(2, {3, 6});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
 TEST(PackOpTest, Uint8MultilDimensions) {
   PackOpModel<uint8_t> model({TensorType_UINT8, {2, 3}}, 1, 2);
   model.SetInput(0, {1, 2, 3, 4, 5, 6});
@@ -212,6 +253,16 @@ TEST(PackOpTest, Int8ThreeInputsDifferentAxis) {
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
 }
+TEST(PackOpTest, Int8ThreeInputsNegativeAxis) {
+  PackOpModel<int8_t> model({TensorType_INT8, {2}}, -1, 3);
+  model.SetInput(0, {1, 4});
+  model.SetInput(1, {2, 5});
+  model.SetInput(2, {3, 6});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
 TEST(PackOpTest, Int8MultilDimensions) {
   PackOpModel<int8_t> model({TensorType_INT8, {2, 3}}, 1, 2);
   model.SetInput(0, {1, 2, 3, 4, 5, 6});