Add negative axis support to Pack.
PiperOrigin-RevId: 239265388
commit d43453baa5 (parent e52328ebc5)
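The substance of the change is in Prepare(): a negative axis is now normalized against the rank of the packed output (input rank + 1) before the existing bounds checks run. A minimal standalone sketch of that normalization, following the same convention as the kernel (the name NormalizePackAxis is only for illustration):

#include <cassert>

// Normalize a possibly negative pack axis for inputs of rank `rank`.
// Pack inserts one new dimension, so valid axes span [-(rank + 1), rank].
int NormalizePackAxis(int axis, int rank) {
  const int output_rank = rank + 1;
  if (axis < 0) {
    axis += output_rank;  // e.g. axis = -1 with rank-1 inputs becomes 1.
  }
  assert(axis >= 0 && axis <= rank);  // Mirrors the TF_LITE_ENSURE checks.
  return axis;
}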
@@ -28,16 +28,20 @@ namespace {
 constexpr int kOutputTensor = 0;
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  const TfLitePackParams* data =
+  TfLitePackParams* data =
       reinterpret_cast<TfLitePackParams*>(node->builtin_data);
 
   TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
 
   const TfLiteTensor* input0 = GetInput(context, node, 0);
+  const int dimension_size = NumDimensions(input0) + 1;
+  if (data->axis < 0) {
+    data->axis += dimension_size;
+  }
   TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis);
-  // TODO(renjieliu): Support negative axis.
   TF_LITE_ENSURE(context, data->axis >= 0);
+
   if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 &&
       input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt8 &&
       input0->type != kTfLiteInt16 && input0->type != kTfLiteInt64) {
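Dropping the const qualifier on the params pointer is what lets Prepare() write the normalized axis back into node->builtin_data above, so everything downstream of Prepare() only ever sees a non-negative axis.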
@@ -53,7 +57,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }
 
   // Resize output. rank R will become rank R + 1
-  const int dimension_size = NumDimensions(input0) + 1;
   const TfLiteIntArray* input_shape = input0->dims;
   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size);
   int i = 0;
@@ -81,8 +84,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 
 template <typename T>
-void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
-              int values_count, int axis) {
+TfLiteStatus PackImpl(TfLiteContext* context, TfLiteNode* node,
+                      TfLiteTensor* output, int values_count, int axis) {
+  TF_LITE_ENSURE(context, axis >= 0);
+
   VectorOfTensors<T> all_inputs(*context, *node->inputs);
   tflite::PackParams op_params;
   op_params.axis = axis;
@@ -90,6 +95,7 @@ void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
 
   reference_ops::Pack<T>(op_params, all_inputs.shapes(), all_inputs.data(),
                          GetTensorShape(output), GetTensorData<T>(output));
+  return kTfLiteOk;
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -99,24 +105,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
   switch (output->type) {
     case kTfLiteFloat32: {
-      PackImpl<float>(context, node, output, data->values_count, data->axis);
-      break;
+      return PackImpl<float>(context, node, output, data->values_count,
+                             data->axis);
     }
     case kTfLiteUInt8: {
-      PackImpl<uint8_t>(context, node, output, data->values_count, data->axis);
-      break;
+      return PackImpl<uint8_t>(context, node, output, data->values_count,
+                               data->axis);
    }
     case kTfLiteInt8: {
-      PackImpl<int8_t>(context, node, output, data->values_count, data->axis);
-      break;
+      return PackImpl<int8_t>(context, node, output, data->values_count,
+                              data->axis);
     }
     case kTfLiteInt32: {
-      PackImpl<int32_t>(context, node, output, data->values_count, data->axis);
-      break;
+      return PackImpl<int32_t>(context, node, output, data->values_count,
+                               data->axis);
     }
     case kTfLiteInt64: {
-      PackImpl<int64_t>(context, node, output, data->values_count, data->axis);
-      break;
+      return PackImpl<int64_t>(context, node, output, data->values_count,
+                               data->axis);
     }
     default: {
       context->ReportError(context, "Type '%s' is not supported by pack.",
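PackImpl() now returns a TfLiteStatus so that the TF_LITE_ENSURE on the axis can fail gracefully, and each case in Eval() forwards that status directly instead of breaking out of the switch. The remaining hunks add a negative-axis case to the Pack op tests for each supported type.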
@@ -72,6 +72,16 @@ TEST(PackOpTest, FloatThreeInputsDifferentAxis) {
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
 }
 
+TEST(PackOpTest, FloatThreeInputsNegativeAxis) {
+  PackOpModel<float> model({TensorType_FLOAT32, {2}}, -1, 3);
+  model.SetInput(0, {1, 4});
+  model.SetInput(1, {2, 5});
+  model.SetInput(2, {3, 6});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
 TEST(PackOpTest, FloatMultilDimensions) {
   PackOpModel<float> model({TensorType_FLOAT32, {2, 3}}, 1, 2);
   model.SetInput(0, {1, 2, 3, 4, 5, 6});
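A quick worked sketch of what these negative-axis tests expect (illustration only, not part of the change): packing three {2}-shaped inputs along axis -1 normalizes to axis 1, so the output has shape {2, 3} and its second dimension indexes the inputs.

#include <array>
#include <cassert>

int main() {
  // Three rank-1 inputs of shape {2}, matching the tests above and below.
  const std::array<std::array<int, 2>, 3> inputs = {{{1, 4}, {2, 5}, {3, 6}}};
  // axis = -1 normalizes to 1 for rank-1 inputs, so the output is {2, 3}.
  std::array<int, 6> output{};
  for (int row = 0; row < 2; ++row) {
    for (int which = 0; which < 3; ++which) {
      output[row * 3 + which] = inputs[which][row];
    }
  }
  assert((output == std::array<int, 6>{1, 2, 3, 4, 5, 6}));
  return 0;
}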
@@ -116,6 +126,16 @@ TEST(PackOpTest, Int32ThreeInputsDifferentAxis) {
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
 }
 
+TEST(PackOpTest, Int32ThreeInputsNegativeAxis) {
+  PackOpModel<int32_t> model({TensorType_INT32, {2}}, -1, 3);
+  model.SetInput(0, {1, 4});
+  model.SetInput(1, {2, 5});
+  model.SetInput(2, {3, 6});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
 TEST(PackOpTest, Int32MultilDimensions) {
   PackOpModel<int32_t> model({TensorType_INT32, {2, 3}}, 1, 2);
   model.SetInput(0, {1, 2, 3, 4, 5, 6});
@@ -149,6 +169,17 @@ TEST(PackOpTest, Int64ThreeInputsDifferentAxis) {
               ElementsAreArray({1LL << 33, 2LL, 3LL, 4LL, 5LL, -(1LL << 34)}));
 }
 
+TEST(PackOpTest, Int64ThreeInputsNegativeAxis) {
+  PackOpModel<int64_t> model({TensorType_INT64, {2}}, -1, 3);
+  model.SetInput(0, {1LL << 33, 4});
+  model.SetInput(1, {2, 5});
+  model.SetInput(2, {3, -(1LL << 34)});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({1LL << 33, 2LL, 3LL, 4LL, 5LL, -(1LL << 34)}));
+}
+
 TEST(PackOpTest, Int64MultilDimensions) {
   PackOpModel<int64_t> model({TensorType_INT64, {2, 3}}, 1, 2);
   model.SetInput(0, {1LL << 33, 2, 3, 4, 5, 6});
@@ -181,6 +212,16 @@ TEST(PackOpTest, Uint8ThreeInputsDifferentAxis) {
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
 }
 
+TEST(PackOpTest, Uint8ThreeInputsNegativeAxis) {
+  PackOpModel<uint8_t> model({TensorType_UINT8, {2}}, -1, 3);
+  model.SetInput(0, {1, 4});
+  model.SetInput(1, {2, 5});
+  model.SetInput(2, {3, 6});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
 TEST(PackOpTest, Uint8MultilDimensions) {
   PackOpModel<uint8_t> model({TensorType_UINT8, {2, 3}}, 1, 2);
   model.SetInput(0, {1, 2, 3, 4, 5, 6});
@@ -212,6 +253,16 @@ TEST(PackOpTest, Int8ThreeInputsDifferentAxis) {
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
 }
 
+TEST(PackOpTest, Int8ThreeInputsNegativeAxis) {
+  PackOpModel<int8_t> model({TensorType_INT8, {2}}, -1, 3);
+  model.SetInput(0, {1, 4});
+  model.SetInput(1, {2, 5});
+  model.SetInput(2, {3, 6});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
 TEST(PackOpTest, Int8MultilDimensions) {
   PackOpModel<int8_t> model({TensorType_INT8, {2, 3}}, 1, 2);
   model.SetInput(0, {1, 2, 3, 4, 5, 6});