Add EXPAND_DIMS support to NN API delegate

PiperOrigin-RevId: 255578517
This commit is contained in:
A. Unique TensorFlower 2019-06-28 04:25:31 -07:00 committed by TensorFlower Gardener
parent 2fd2eff6cb
commit d97ccf68c9
4 changed files with 76 additions and 26 deletions

View File

@ -1760,6 +1760,25 @@ class NNAPIDelegateKernel {
}; };
} }
break; break;
case kTfLiteBuiltinExpandDims: {
const auto input_type = context->tensors[node->inputs->data[0]].type;
const auto axis = context->tensors[node->inputs->data[1]];
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
(input_type == kTfLiteFloat16 || input_type == kTfLiteFloat32 ||
input_type == kTfLiteInt32 || input_type == kTfLiteUInt8) &&
// TFLite supports axis also as int64 but NNAPI only int32
(axis.type == kTfLiteInt32 &&
axis.allocation_type == kTfLiteMmapRo)) {
return [](const NNAPIOpMappingArgs& mapping_args)
-> ANeuralNetworksOperationType {
const TfLiteTensor& axis_param =
mapping_args.context
->tensors[mapping_args.node->inputs->data[1]];
mapping_args.builder->AddScalarInt32Operand(*axis_param.data.i32);
return ANEURALNETWORKS_EXPAND_DIMS;
};
}
} break;
default: default:
// All other operators are not mapped. // All other operators are not mapped.
return nullptr; return nullptr;
@ -2224,6 +2243,10 @@ class NNAPIDelegateKernel {
// Everything is added during Map since input tensors // Everything is added during Map since input tensors
// have different order. // have different order.
continue; continue;
} else if (reg->builtin_code == kTfLiteBuiltinExpandDims &&
input_pos == 1) {
// The axis param is added during Map
continue;
} else { } else {
TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op, TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op,
input_tensor_flags)); input_tensor_flags));

View File

@ -1370,6 +1370,7 @@ cc_test(
name = "expand_dims_test", name = "expand_dims_test",
size = "small", size = "small",
srcs = ["expand_dims_test.cc"], srcs = ["expand_dims_test.cc"],
tags = ["tflite_nnapi"],
deps = [ deps = [
":builtin_ops", ":builtin_ops",
":test_main", ":test_main",

View File

@ -25,22 +25,39 @@ namespace {
using ::testing::ElementsAreArray; using ::testing::ElementsAreArray;
enum class TestType {
CONST = 0,
DYNAMIC = 1,
};
template <typename InputType>
class ExpandDimsOpModel : public SingleOpModel { class ExpandDimsOpModel : public SingleOpModel {
public: public:
ExpandDimsOpModel(std::initializer_list<int> input_shape, ExpandDimsOpModel(int axis, std::initializer_list<int> input_shape,
TensorType input_type) { std::initializer_list<InputType> input_data,
input_ = AddInput(input_type); TestType input_tensor_types) {
axis_ = AddInput(TensorType_INT32); if (input_tensor_types == TestType::DYNAMIC) {
output_ = AddOutput(input_type); input_ = AddInput(GetTensorType<InputType>());
axis_ = AddInput(TensorType_INT32);
} else {
input_ =
AddConstInput(GetTensorType<InputType>(), input_data, input_shape);
axis_ = AddConstInput(TensorType_INT32, {axis}, {1});
}
output_ = AddOutput(GetTensorType<InputType>());
SetBuiltinOp(BuiltinOperator_EXPAND_DIMS, BuiltinOptions_ExpandDimsOptions, SetBuiltinOp(BuiltinOperator_EXPAND_DIMS, BuiltinOptions_ExpandDimsOptions,
0); 0);
BuildInterpreter({input_shape, {1}}); BuildInterpreter({input_shape, {1}});
if (input_tensor_types == TestType::DYNAMIC) {
PopulateTensor<InputType>(input_, input_data);
PopulateTensor<int32_t>(axis_, {axis});
}
} }
void SetInputFloat(std::initializer_list<float> data) { std::vector<InputType> GetValues() {
PopulateTensor<float>(input_, data); return ExtractVector<InputType>(output_);
} }
void SetAxis(int axis) { PopulateTensor<int32_t>(axis_, {axis}); }
std::vector<float> GetValuesFloat() { return ExtractVector<float>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); } std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
protected: protected:
@ -49,29 +66,37 @@ class ExpandDimsOpModel : public SingleOpModel {
int output_; int output_;
}; };
TEST(ExpandDimsOpTest, DifferentAxis) { class ExpandDimsOpTest : public ::testing::TestWithParam<TestType> {};
ExpandDimsOpModel m({2, 2}, TensorType_FLOAT32);
TEST_P(ExpandDimsOpTest, PositiveAxis) {
std::initializer_list<float> values = {-1.f, 1.f, -2.f, 2.f}; std::initializer_list<float> values = {-1.f, 1.f, -2.f, 2.f};
m.SetInputFloat(values);
m.SetAxis(0);
m.Invoke();
EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2}));
m.SetAxis(1); ExpandDimsOpModel<float> axis_0(0, {2, 2}, values, GetParam());
m.Invoke(); axis_0.Invoke();
EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); EXPECT_THAT(axis_0.GetValues(), ElementsAreArray(values));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2})); EXPECT_THAT(axis_0.GetOutputShape(), ElementsAreArray({1, 2, 2}));
m.SetAxis(2); ExpandDimsOpModel<float> axis_1(1, {2, 2}, values, GetParam());
m.Invoke(); axis_1.Invoke();
EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); EXPECT_THAT(axis_1.GetValues(), ElementsAreArray(values));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1})); EXPECT_THAT(axis_1.GetOutputShape(), ElementsAreArray({2, 1, 2}));
m.SetAxis(-1); ExpandDimsOpModel<float> axis_2(2, {2, 2}, values, GetParam());
axis_2.Invoke();
EXPECT_THAT(axis_2.GetValues(), ElementsAreArray(values));
EXPECT_THAT(axis_2.GetOutputShape(), ElementsAreArray({2, 2, 1}));
}
TEST_P(ExpandDimsOpTest, NegativeAxis) {
std::initializer_list<float> values = {-1.f, 1.f, -2.f, 2.f};
ExpandDimsOpModel<float> m(-1, {2, 2}, values, GetParam());
m.Invoke(); m.Invoke();
EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); EXPECT_THAT(m.GetValues(), ElementsAreArray(values));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1})); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1}));
} }
INSTANTIATE_TEST_SUITE_P(ExpandDimsOpTest, ExpandDimsOpTest,
::testing::Values(TestType::DYNAMIC, TestType::CONST));
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -93,6 +93,7 @@ enum {
ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM = 42, ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM = 42,
ANEURALNETWORKS_EQUAL = 48, ANEURALNETWORKS_EQUAL = 48,
ANEURALNETWORKS_EXP = 49, ANEURALNETWORKS_EXP = 49,
ANEURALNETWORKS_EXPAND_DIMS = 50,
ANEURALNETWORKS_GATHER = 51, ANEURALNETWORKS_GATHER = 51,
ANEURALNETWORKS_GREATER = 53, ANEURALNETWORKS_GREATER = 53,
ANEURALNETWORKS_GREATER_EQUAL = 54, ANEURALNETWORKS_GREATER_EQUAL = 54,