Add negative axis support to Pack.
PiperOrigin-RevId: 239265388
This commit is contained in:
parent
e52328ebc5
commit
d43453baa5
@ -28,16 +28,20 @@ namespace {
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLitePackParams* data =
|
||||
TfLitePackParams* data =
|
||||
reinterpret_cast<TfLitePackParams*>(node->builtin_data);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input0 = GetInput(context, node, 0);
|
||||
const int dimension_size = NumDimensions(input0) + 1;
|
||||
if (data->axis < 0) {
|
||||
data->axis += dimension_size;
|
||||
}
|
||||
TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis);
|
||||
// TODO(renjieliu): Support negative axis.
|
||||
TF_LITE_ENSURE(context, data->axis >= 0);
|
||||
|
||||
if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 &&
|
||||
input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt8 &&
|
||||
input0->type != kTfLiteInt16 && input0->type != kTfLiteInt64) {
|
||||
@ -53,7 +57,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
// Resize output. rank R will become rank R + 1
|
||||
const int dimension_size = NumDimensions(input0) + 1;
|
||||
const TfLiteIntArray* input_shape = input0->dims;
|
||||
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size);
|
||||
int i = 0;
|
||||
@ -81,8 +84,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
|
||||
int values_count, int axis) {
|
||||
TfLiteStatus PackImpl(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteTensor* output, int values_count, int axis) {
|
||||
TF_LITE_ENSURE(context, axis >= 0);
|
||||
|
||||
VectorOfTensors<T> all_inputs(*context, *node->inputs);
|
||||
tflite::PackParams op_params;
|
||||
op_params.axis = axis;
|
||||
@ -90,6 +95,7 @@ void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
|
||||
|
||||
reference_ops::Pack<T>(op_params, all_inputs.shapes(), all_inputs.data(),
|
||||
GetTensorShape(output), GetTensorData<T>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
@ -99,24 +105,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
switch (output->type) {
|
||||
case kTfLiteFloat32: {
|
||||
PackImpl<float>(context, node, output, data->values_count, data->axis);
|
||||
break;
|
||||
return PackImpl<float>(context, node, output, data->values_count,
|
||||
data->axis);
|
||||
}
|
||||
case kTfLiteUInt8: {
|
||||
PackImpl<uint8_t>(context, node, output, data->values_count, data->axis);
|
||||
break;
|
||||
return PackImpl<uint8_t>(context, node, output, data->values_count,
|
||||
data->axis);
|
||||
}
|
||||
case kTfLiteInt8: {
|
||||
PackImpl<int8_t>(context, node, output, data->values_count, data->axis);
|
||||
break;
|
||||
return PackImpl<int8_t>(context, node, output, data->values_count,
|
||||
data->axis);
|
||||
}
|
||||
case kTfLiteInt32: {
|
||||
PackImpl<int32_t>(context, node, output, data->values_count, data->axis);
|
||||
break;
|
||||
return PackImpl<int32_t>(context, node, output, data->values_count,
|
||||
data->axis);
|
||||
}
|
||||
case kTfLiteInt64: {
|
||||
PackImpl<int64_t>(context, node, output, data->values_count, data->axis);
|
||||
break;
|
||||
return PackImpl<int64_t>(context, node, output, data->values_count,
|
||||
data->axis);
|
||||
}
|
||||
default: {
|
||||
context->ReportError(context, "Type '%s' is not supported by pack.",
|
||||
|
@ -72,6 +72,16 @@ TEST(PackOpTest, FloatThreeInputsDifferentAxis) {
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(PackOpTest, FloatThreeInputsNegativeAxis) {
|
||||
PackOpModel<float> model({TensorType_FLOAT32, {2}}, -1, 3);
|
||||
model.SetInput(0, {1, 4});
|
||||
model.SetInput(1, {2, 5});
|
||||
model.SetInput(2, {3, 6});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(PackOpTest, FloatMultilDimensions) {
|
||||
PackOpModel<float> model({TensorType_FLOAT32, {2, 3}}, 1, 2);
|
||||
model.SetInput(0, {1, 2, 3, 4, 5, 6});
|
||||
@ -116,6 +126,16 @@ TEST(PackOpTest, Int32ThreeInputsDifferentAxis) {
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(PackOpTest, Int32ThreeInputsNegativeAxis) {
|
||||
PackOpModel<int32_t> model({TensorType_INT32, {2}}, -1, 3);
|
||||
model.SetInput(0, {1, 4});
|
||||
model.SetInput(1, {2, 5});
|
||||
model.SetInput(2, {3, 6});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(PackOpTest, Int32MultilDimensions) {
|
||||
PackOpModel<int32_t> model({TensorType_INT32, {2, 3}}, 1, 2);
|
||||
model.SetInput(0, {1, 2, 3, 4, 5, 6});
|
||||
@ -149,6 +169,17 @@ TEST(PackOpTest, Int64ThreeInputsDifferentAxis) {
|
||||
ElementsAreArray({1LL << 33, 2LL, 3LL, 4LL, 5LL, -(1LL << 34)}));
|
||||
}
|
||||
|
||||
TEST(PackOpTest, Int64ThreeInputsNegativeAxis) {
|
||||
PackOpModel<int64_t> model({TensorType_INT64, {2}}, -1, 3);
|
||||
model.SetInput(0, {1LL << 33, 4});
|
||||
model.SetInput(1, {2, 5});
|
||||
model.SetInput(2, {3, -(1LL << 34)});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray({1LL << 33, 2LL, 3LL, 4LL, 5LL, -(1LL << 34)}));
|
||||
}
|
||||
|
||||
TEST(PackOpTest, Int64MultilDimensions) {
|
||||
PackOpModel<int64_t> model({TensorType_INT64, {2, 3}}, 1, 2);
|
||||
model.SetInput(0, {1LL << 33, 2, 3, 4, 5, 6});
|
||||
@ -181,6 +212,16 @@ TEST(PackOpTest, Uint8ThreeInputsDifferentAxis) {
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(PackOpTest, Uint8ThreeInputsNegativeAxis) {
|
||||
PackOpModel<uint8_t> model({TensorType_UINT8, {2}}, -1, 3);
|
||||
model.SetInput(0, {1, 4});
|
||||
model.SetInput(1, {2, 5});
|
||||
model.SetInput(2, {3, 6});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(PackOpTest, Uint8MultilDimensions) {
|
||||
PackOpModel<uint8_t> model({TensorType_UINT8, {2, 3}}, 1, 2);
|
||||
model.SetInput(0, {1, 2, 3, 4, 5, 6});
|
||||
@ -212,6 +253,16 @@ TEST(PackOpTest, Int8ThreeInputsDifferentAxis) {
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(PackOpTest, Int8ThreeInputsNegativeAxis) {
|
||||
PackOpModel<int8_t> model({TensorType_INT8, {2}}, -1, 3);
|
||||
model.SetInput(0, {1, 4});
|
||||
model.SetInput(1, {2, 5});
|
||||
model.SetInput(2, {3, 6});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(PackOpTest, Int8MultilDimensions) {
|
||||
PackOpModel<int8_t> model({TensorType_INT8, {2, 3}}, 1, 2);
|
||||
model.SetInput(0, {1, 2, 3, 4, 5, 6});
|
||||
|
Loading…
Reference in New Issue
Block a user