Added support for ArgMin and ArgMax.

PiperOrigin-RevId: 256416647
This commit is contained in:
A. Unique TensorFlower 2019-07-03 12:24:10 -07:00 committed by TensorFlower Gardener
parent 1ad6aae48d
commit 82b4538721
6 changed files with 214 additions and 149 deletions

View File

@@ -871,6 +871,50 @@ class NNAPIDelegateKernel {
};
}
break;
case kTfLiteBuiltinArgMax:
case kTfLiteBuiltinArgMin:
if (version == 1) {
// Those operators were introduced in NNAPI 1.2.
if (android_sdk_version < kMinSdkVersionForNNAPI12) {
return nullptr;
}
// Only certain input types are supported.
auto input_type = context->tensors[node->inputs->data[0]].type;
if (input_type != kTfLiteFloat16 && input_type != kTfLiteFloat32 &&
input_type != kTfLiteInt32 && input_type != kTfLiteUInt8) {
return nullptr;
}
// NNAPI only supports axis as int32. If the axis type is int64 and
// constant we can convert it to int32 if the value isn't too large.
const auto& axis_tensor = context->tensors[node->inputs->data[1]];
if (axis_tensor.type == kTfLiteInt64) {
if (axis_tensor.allocation_type != kTfLiteMmapRo ||
*axis_tensor.data.i64 > std::numeric_limits<int32_t>::max() ||
*axis_tensor.data.i64 < std::numeric_limits<int32_t>::min()) {
return nullptr;
}
} else if (axis_tensor.type != kTfLiteInt32) {
return nullptr;
}
if (builtin_code == kTfLiteBuiltinArgMax) {
// NNAPI only supports int32 output.
auto builtin =
reinterpret_cast<TfLiteArgMaxParams*>(node->builtin_data);
if (builtin->output_type != kTfLiteInt32) {
return nullptr;
}
return BasicMappingFn<ANEURALNETWORKS_ARGMAX>;
} else {
// NNAPI only supports int32 output.
auto builtin =
reinterpret_cast<TfLiteArgMinParams*>(node->builtin_data);
if (builtin->output_type != kTfLiteInt32) {
return nullptr;
}
return BasicMappingFn<ANEURALNETWORKS_ARGMIN>;
}
}
break;
case kTfLiteBuiltinMul:
if (version == 1) {
if (!IsFloatOrUint8Operator(context, node)) {
@@ -2360,6 +2404,37 @@ class NNAPIDelegateKernel {
input_pos == 1) {
// The axis param is added during Map
continue;
} else if (reg->builtin_code == kTfLiteBuiltinArgMin ||
reg->builtin_code == kTfLiteBuiltinArgMax) {
// The first input tensor is added as is. The second one, specifying
// the axis, needs to be converted to a scalar since TFLite uses a
// tensor but NNAPI uses a scalar as the axis.
if (input_pos == 0) {
TF_LITE_ENSURE_STATUS(
builder.AddTensorInput(input_index, hybrid_op));
} else {
const int axis_id = node->inputs->data[1];
const TfLiteTensor& axis_tensor = context->tensors[axis_id];
switch (axis_tensor.type) {
case kTfLiteInt32:
if (axis_tensor.allocation_type == kTfLiteMmapRo) {
TF_LITE_ENSURE_STATUS(builder.AddScalarInt32Operand(
static_cast<int32_t>(*axis_tensor.data.i32)));
} else {
TF_LITE_ENSURE_STATUS(
builder.AddSingleValueTensorAsScalarOperand(
axis_id, ANEURALNETWORKS_INT32));
}
break;
case kTfLiteInt64:
// Map() function already makes sure int64 input is constant.
TF_LITE_ENSURE_STATUS(builder.AddScalarInt32Operand(
static_cast<int32_t>(*axis_tensor.data.i64)));
break;
default:
return kTfLiteError;
}
}
} else {
TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op,
input_tensor_flags));

View File

@@ -585,6 +585,7 @@ cc_test(
name = "arg_min_max_test",
size = "small",
srcs = ["arg_min_max_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",

View File

@@ -23,249 +23,236 @@ namespace {
using ::testing::ElementsAreArray;
template <typename T>
class ArgBaseOpModel : public SingleOpModel {
public:
ArgBaseOpModel(std::initializer_list<int> input_shape, TensorType input_type,
TensorType axis_type, TensorType output_type) {
ArgBaseOpModel(TensorType input_type, int axis_value, TensorType axis_type,
bool constant_axis, TensorType output_type)
: axis_value_(axis_value),
axis_type_(axis_type),
constant_axis_(constant_axis) {
input_ = AddInput(input_type);
axis_ = AddInput(axis_type);
if (constant_axis) {
if (axis_type == TensorType_INT64) {
axis_ =
AddConstInput(axis_type, {static_cast<int64_t>(axis_value)}, {1});
} else {
axis_ = AddConstInput(axis_type, {axis_value}, {1});
}
} else {
axis_ = AddInput(axis_type);
}
output_ = AddOutput(output_type);
}
int input() { return input_; }
int axis() { return axis_; }
int input() const { return input_; }
int axis() const { return axis_; }
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int32_t> GetInt32Output() const {
return ExtractVector<int32_t>(output_);
}
std::vector<int64_t> GetInt64Output() const {
return ExtractVector<int64_t>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
protected:
void PopulateAxisIfNeeded() {
if (constant_axis_) return;
if (axis_type_ == TensorType_INT32) {
PopulateTensor<int32_t>(axis(), {axis_value_});
} else {
PopulateTensor<int64_t>(axis(), {axis_value_});
}
}
const int axis_value_;
const TensorType axis_type_;
const bool constant_axis_;
int input_;
int axis_;
int output_;
};
template <typename T>
class ArgMaxOpModel : public ArgBaseOpModel<T> {
class ArgMaxOpModel : public ArgBaseOpModel {
public:
ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
TensorType axis_type, TensorType output_type)
: ArgBaseOpModel<T>(input_shape, input_type, axis_type, output_type) {
ArgBaseOpModel<T>::SetBuiltinOp(
int axis_value, TensorType axis_type, bool constant_axis,
TensorType output_type)
: ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis,
output_type) {
ArgBaseOpModel::SetBuiltinOp(
BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, output_type).Union());
ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
CreateArgMaxOptions(ArgBaseOpModel::builder_, output_type).Union());
ArgBaseOpModel::BuildInterpreter({input_shape, {1}});
PopulateAxisIfNeeded();
}
};
template <typename T>
class ArgMinOpModel : public ArgBaseOpModel<T> {
class ArgMinOpModel : public ArgBaseOpModel {
public:
ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
TensorType axis_type, TensorType output_type)
: ArgBaseOpModel<T>(input_shape, input_type, axis_type, output_type) {
ArgBaseOpModel<T>::SetBuiltinOp(
int axis_value, TensorType axis_type, bool constant_axis,
TensorType output_type)
: ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis,
output_type) {
ArgBaseOpModel::SetBuiltinOp(
BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
CreateArgMinOptions(ArgBaseOpModel<T>::builder_, output_type).Union());
ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
CreateArgMinOptions(ArgBaseOpModel::builder_, output_type).Union());
ArgBaseOpModel::BuildInterpreter({input_shape, {1}});
PopulateAxisIfNeeded();
}
};
TEST(ArgMaxOpTest, GetMaxArgFloat) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
TensorType_INT32, TensorType_INT32);
// Declare ArgMinMaxOpTest as a parameterized test, where the parameter is a
// tuple with:
// - boolean indicating whether to use a constant axis or not.
// - axis type (TensorType_INT32 or TensorType_INT64)
// - output type (TensorType_INT32 or TensorType_INT64)
class ArgMinMaxOpTest : public ::testing::TestWithParam<
std::tuple<bool, TensorType, TensorType>> {
public:
bool ConstantAxis() const { return std::get<0>(GetParam()); }
TensorType AxisType() const { return std::get<1>(GetParam()); }
TensorType OutputType() const { return std::get<2>(GetParam()); }
void ValidateOutput(const ArgBaseOpModel& model,
const std::vector<int>& expected_output) {
if (OutputType() == TensorType_INT32) {
EXPECT_THAT(model.GetInt32Output(), ElementsAreArray(expected_output));
} else {
EXPECT_THAT(model.GetInt64Output(), ElementsAreArray(expected_output));
}
}
};
INSTANTIATE_TEST_SUITE_P(
ArgMinMaxOpTest, ArgMinMaxOpTest,
::testing::Combine(::testing::Bool(),
::testing::Values(TensorType_INT32, TensorType_INT64),
::testing::Values(TensorType_INT32, TensorType_INT64)));
TEST_P(ArgMinMaxOpTest, GetMaxArgFloat) {
ArgMaxOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, 3, AxisType(),
ConstantAxis(), OutputType());
model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<int>(model.axis(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1}));
ValidateOutput(model, {1});
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
TEST(ArgMaxOpTest, GetMaxArgUInt8) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_UINT8, TensorType_INT32,
TensorType_INT32);
TEST_P(ArgMinMaxOpTest, GetMaxArgUInt8) {
ArgMaxOpModel model({1, 1, 1, 4}, TensorType_UINT8, 3, AxisType(),
ConstantAxis(), OutputType());
model.PopulateTensor<uint8_t>(model.input(), {1, 9, 7, 3});
model.PopulateTensor<int>(model.axis(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1}));
ValidateOutput(model, {1});
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
TEST(ArgMaxOpTest, GetMaxArgInt8) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT8, TensorType_INT32,
TensorType_INT32);
TEST_P(ArgMinMaxOpTest, GetMaxArgInt8) {
ArgMaxOpModel model({1, 1, 1, 4}, TensorType_INT8, 3, AxisType(),
ConstantAxis(), OutputType());
model.PopulateTensor<int8_t>(model.input(), {-1, -9, 7, 3});
model.PopulateTensor<int>(model.axis(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({2}));
ValidateOutput(model, {2});
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
TEST(ArgMaxOpTest, GetMaxArgInt) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32,
TensorType_INT32);
TEST_P(ArgMinMaxOpTest, GetMaxArgInt) {
ArgMaxOpModel model({1, 1, 1, 4}, TensorType_INT32, 3, AxisType(),
ConstantAxis(), OutputType());
model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
model.PopulateTensor<int>(model.axis(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1}));
ValidateOutput(model, {1});
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
TEST(ArgMaxOpTest, GetMaxArgMulDimensions) {
ArgMaxOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
TensorType_INT32);
TEST_P(ArgMinMaxOpTest, GetMaxArgMulDimensions) {
ArgMaxOpModel model({1, 1, 2, 4}, TensorType_INT32, 3, AxisType(),
ConstantAxis(), OutputType());
model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
model.PopulateTensor<int>(model.axis(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({3, 1}));
ValidateOutput(model, {3, 1});
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
TEST(ArgMaxOpTest, GetMaxArgNegativeAxis) {
ArgMaxOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
TensorType_INT32);
TEST_P(ArgMinMaxOpTest, GetMaxArgNegativeAxis) {
ArgMaxOpModel model({1, 1, 2, 4}, TensorType_INT32, -2, AxisType(),
ConstantAxis(), OutputType());
model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
model.PopulateTensor<int>(model.axis(), {-2});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1, 0, 0}));
ValidateOutput(model, {0, 1, 0, 0});
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 4}));
}
TEST(ArgMaxOpTest, GetMaxArgOutput64) {
ArgMaxOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
TensorType_INT64);
TEST_P(ArgMinMaxOpTest, GetMaxArgOutput64) {
ArgMaxOpModel model({1, 1, 2, 4}, TensorType_INT32, 3, AxisType(),
ConstantAxis(), OutputType());
model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
model.PopulateTensor<int>(model.axis(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1}));
ValidateOutput(model, {0, 1});
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
TEST(ArgMaxOpTest, GetMaxArgAxis64) {
// Input Int32, Axis Int64, Output Int64
ArgMaxOpModel<int64_t> model1({1, 1, 2, 4}, TensorType_INT32,
TensorType_INT64, TensorType_INT64);
model1.PopulateTensor<int>(model1.input(), {10, 2, 7, 8, 1, 9, 7, 3});
model1.PopulateTensor<int64_t>(model1.axis(), {3});
model1.Invoke();
EXPECT_THAT(model1.GetOutput(), ElementsAreArray({0, 1}));
EXPECT_THAT(model1.GetOutputShape(), ElementsAreArray({1, 1, 2}));
// Input Int8, Axis Int64, Output Int32
ArgMaxOpModel<int32_t> model2({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
TensorType_INT32);
model2.PopulateTensor<int8_t>(model2.input(), {10, 2, 7, 8, 1, 9, 7, 3});
model2.PopulateTensor<int64_t>(model2.axis(), {3});
model2.Invoke();
EXPECT_THAT(model2.GetOutput(), ElementsAreArray({0, 1}));
EXPECT_THAT(model2.GetOutputShape(), ElementsAreArray({1, 1, 2}));
// Input Int8, Axis Int64, Output Int64
ArgMaxOpModel<int64_t> model3({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
TensorType_INT64);
model3.PopulateTensor<int8_t>(model3.input(), {10, 2, 7, 8, 1, 9, 7, 3});
model3.PopulateTensor<int64_t>(model3.axis(), {3});
model3.Invoke();
EXPECT_THAT(model3.GetOutput(), ElementsAreArray({0, 1}));
EXPECT_THAT(model3.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
TEST(ArgMinOpTest, GetMinArgFloat) {
ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
TensorType_INT32, TensorType_INT32);
TEST_P(ArgMinMaxOpTest, GetMinArgFloat) {
ArgMinOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, 3, AxisType(),
ConstantAxis(), OutputType());
model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<int>(model.axis(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
ValidateOutput(model, {0});
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
TEST(ArgMinOpTest, GetMinArgInt) {
ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32,
TensorType_INT32);
TEST_P(ArgMinMaxOpTest, GetMinArgInt) {
ArgMinOpModel model({1, 1, 1, 4}, TensorType_INT32, 3, AxisType(),
ConstantAxis(), OutputType());
model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
model.PopulateTensor<int>(model.axis(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
ValidateOutput(model, {0});
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
TEST(ArgMinOpTest, GetMinArgMulDimensions) {
ArgMinOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
TensorType_INT32);
TEST_P(ArgMinMaxOpTest, GetMinArgMulDimensions) {
ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, 3, AxisType(),
ConstantAxis(), OutputType());
model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
model.PopulateTensor<int>(model.axis(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0}));
ValidateOutput(model, {0, 0});
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
TEST(ArgMinOpTest, GetMinArgNegativeAxis) {
ArgMinOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
TensorType_INT32);
TEST_P(ArgMinMaxOpTest, GetMinArgNegativeAxis) {
ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, -2, AxisType(),
ConstantAxis(), OutputType());
model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
model.PopulateTensor<int>(model.axis(), {-2});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0, 0, 1}));
ValidateOutput(model, {0, 0, 0, 1});
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 4}));
}
TEST(ArgMinOpTest, GetMinArgOutput64) {
ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
TensorType_INT64);
TEST_P(ArgMinMaxOpTest, GetMinArgOutput64) {
ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, 3, AxisType(),
ConstantAxis(), OutputType());
model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
model.PopulateTensor<int>(model.axis(), {3});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0}));
ValidateOutput(model, {1, 0});
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
TEST(ArgMinOpTest, GetMinArgAxis64) {
// Input Int32, Axis Int64, Output Int64
ArgMinOpModel<int64_t> model1({1, 1, 2, 4}, TensorType_INT32,
TensorType_INT64, TensorType_INT64);
model1.PopulateTensor<int>(model1.input(), {10, 2, 7, 8, 1, 9, 7, 3});
model1.PopulateTensor<int64_t>(model1.axis(), {3});
model1.Invoke();
EXPECT_THAT(model1.GetOutput(), ElementsAreArray({1, 0}));
EXPECT_THAT(model1.GetOutputShape(), ElementsAreArray({1, 1, 2}));
// Input Int8, Axis Int64, Output Int32
ArgMinOpModel<int32_t> model2({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
TensorType_INT32);
model2.PopulateTensor<int8_t>(model2.input(), {10, 2, 7, 8, 1, 9, 7, 3});
model2.PopulateTensor<int64_t>(model2.axis(), {3});
model2.Invoke();
EXPECT_THAT(model2.GetOutput(), ElementsAreArray({1, 0}));
EXPECT_THAT(model2.GetOutputShape(), ElementsAreArray({1, 1, 2}));
// Input Int8, Axis Int64, Output Int64
ArgMinOpModel<int64_t> model3({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
TensorType_INT64);
model3.PopulateTensor<int8_t>(model3.input(), {10, 2, 7, 8, 1, 9, 7, 3});
model3.PopulateTensor<int64_t>(model3.axis(), {3});
model3.Invoke();
EXPECT_THAT(model3.GetOutput(), ElementsAreArray({1, 0}));
EXPECT_THAT(model3.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
} // namespace
} // namespace tflite

View File

@@ -209,7 +209,7 @@ int32_t SingleOpModel::GetTensorSize(int index) const {
}
template <>
std::vector<string> SingleOpModel::ExtractVector(int index) {
std::vector<string> SingleOpModel::ExtractVector(int index) const {
TfLiteTensor* tensor_ptr = interpreter_->tensor(index);
CHECK(tensor_ptr != nullptr);
const int num_strings = GetStringCount(tensor_ptr);

View File

@@ -323,8 +323,8 @@ class SingleOpModel {
// Return a vector with the flattened contents of a tensor.
template <typename T>
std::vector<T> ExtractVector(int index) {
T* v = interpreter_->typed_tensor<T>(index);
std::vector<T> ExtractVector(int index) const {
const T* v = interpreter_->typed_tensor<T>(index);
CHECK(v);
return std::vector<T>(v, v + GetTensorSize(index));
}
@@ -594,7 +594,7 @@ TensorType GetTensorType() {
// Strings have a special implementation that is in test_util.cc
template <>
std::vector<string> SingleOpModel::ExtractVector(int index);
std::vector<string> SingleOpModel::ExtractVector(int index) const;
// The TypeUnion struct specializations hold a collection of related types.
// Each struct holds: 1. a primitive type (e.g. float), 2. a TensorType (e.g.

View File

@@ -90,6 +90,8 @@ enum {
ANEURALNETWORKS_SUB = 36,
ANEURALNETWORKS_TRANSPOSE = 37,
ANEURALNETWORKS_ABS = 38,
ANEURALNETWORKS_ARGMAX = 39,
ANEURALNETWORKS_ARGMIN = 40,
ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM = 42,
ANEURALNETWORKS_EQUAL = 48,
ANEURALNETWORKS_EXP = 49,