Added support for ArgMin and ArgMax.
PiperOrigin-RevId: 256416647
This commit is contained in:
parent
1ad6aae48d
commit
82b4538721
@ -871,6 +871,50 @@ class NNAPIDelegateKernel {
|
||||
};
|
||||
}
|
||||
break;
|
||||
case kTfLiteBuiltinArgMax:
|
||||
case kTfLiteBuiltinArgMin:
|
||||
if (version == 1) {
|
||||
// Those operators were introduced in NNAPI 1.2.
|
||||
if (android_sdk_version < kMinSdkVersionForNNAPI12) {
|
||||
return nullptr;
|
||||
}
|
||||
// Only certain input types are supported.
|
||||
auto input_type = context->tensors[node->inputs->data[0]].type;
|
||||
if (input_type != kTfLiteFloat16 && input_type != kTfLiteFloat32 &&
|
||||
input_type != kTfLiteInt32 && input_type != kTfLiteUInt8) {
|
||||
return nullptr;
|
||||
}
|
||||
// NNAPI only supports axis as int32. If the axis type is int64 and
|
||||
// constant we can convert it to int32 if the value isn't too large.
|
||||
const auto& axis_tensor = context->tensors[node->inputs->data[1]];
|
||||
if (axis_tensor.type == kTfLiteInt64) {
|
||||
if (axis_tensor.allocation_type != kTfLiteMmapRo ||
|
||||
*axis_tensor.data.i64 > std::numeric_limits<int32_t>::max() ||
|
||||
*axis_tensor.data.i64 < std::numeric_limits<int32_t>::min()) {
|
||||
return nullptr;
|
||||
}
|
||||
} else if (axis_tensor.type != kTfLiteInt32) {
|
||||
return nullptr;
|
||||
}
|
||||
if (builtin_code == kTfLiteBuiltinArgMax) {
|
||||
// NNAPI only supports int32 output.
|
||||
auto builtin =
|
||||
reinterpret_cast<TfLiteArgMaxParams*>(node->builtin_data);
|
||||
if (builtin->output_type != kTfLiteInt32) {
|
||||
return nullptr;
|
||||
}
|
||||
return BasicMappingFn<ANEURALNETWORKS_ARGMAX>;
|
||||
} else {
|
||||
// NNAPI only supports int32 output.
|
||||
auto builtin =
|
||||
reinterpret_cast<TfLiteArgMinParams*>(node->builtin_data);
|
||||
if (builtin->output_type != kTfLiteInt32) {
|
||||
return nullptr;
|
||||
}
|
||||
return BasicMappingFn<ANEURALNETWORKS_ARGMIN>;
|
||||
}
|
||||
}
|
||||
break;
|
||||
case kTfLiteBuiltinMul:
|
||||
if (version == 1) {
|
||||
if (!IsFloatOrUint8Operator(context, node)) {
|
||||
@ -2360,6 +2404,37 @@ class NNAPIDelegateKernel {
|
||||
input_pos == 1) {
|
||||
// The axis param is added during Map
|
||||
continue;
|
||||
} else if (reg->builtin_code == kTfLiteBuiltinArgMin ||
|
||||
reg->builtin_code == kTfLiteBuiltinArgMax) {
|
||||
// The first input tensor is added as is. The second one, specifying
|
||||
// the axis, needs to be converted to a scalar since TFLite uses a
|
||||
// tensor but NNAPI uses a scalar as the axis.
|
||||
if (input_pos == 0) {
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
builder.AddTensorInput(input_index, hybrid_op));
|
||||
} else {
|
||||
const int axis_id = node->inputs->data[1];
|
||||
const TfLiteTensor& axis_tensor = context->tensors[axis_id];
|
||||
switch (axis_tensor.type) {
|
||||
case kTfLiteInt32:
|
||||
if (axis_tensor.allocation_type == kTfLiteMmapRo) {
|
||||
TF_LITE_ENSURE_STATUS(builder.AddScalarInt32Operand(
|
||||
static_cast<int32_t>(*axis_tensor.data.i32)));
|
||||
} else {
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
builder.AddSingleValueTensorAsScalarOperand(
|
||||
axis_id, ANEURALNETWORKS_INT32));
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt64:
|
||||
// Map() function already makes sure int64 input is constant.
|
||||
TF_LITE_ENSURE_STATUS(builder.AddScalarInt32Operand(
|
||||
static_cast<int32_t>(*axis_tensor.data.i64)));
|
||||
break;
|
||||
default:
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op,
|
||||
input_tensor_flags));
|
||||
|
@ -585,6 +585,7 @@ cc_test(
|
||||
name = "arg_min_max_test",
|
||||
size = "small",
|
||||
srcs = ["arg_min_max_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
|
@ -23,249 +23,236 @@ namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
template <typename T>
|
||||
class ArgBaseOpModel : public SingleOpModel {
|
||||
public:
|
||||
ArgBaseOpModel(std::initializer_list<int> input_shape, TensorType input_type,
|
||||
TensorType axis_type, TensorType output_type) {
|
||||
ArgBaseOpModel(TensorType input_type, int axis_value, TensorType axis_type,
|
||||
bool constant_axis, TensorType output_type)
|
||||
: axis_value_(axis_value),
|
||||
axis_type_(axis_type),
|
||||
constant_axis_(constant_axis) {
|
||||
input_ = AddInput(input_type);
|
||||
axis_ = AddInput(axis_type);
|
||||
if (constant_axis) {
|
||||
if (axis_type == TensorType_INT64) {
|
||||
axis_ =
|
||||
AddConstInput(axis_type, {static_cast<int64_t>(axis_value)}, {1});
|
||||
} else {
|
||||
axis_ = AddConstInput(axis_type, {axis_value}, {1});
|
||||
}
|
||||
} else {
|
||||
axis_ = AddInput(axis_type);
|
||||
}
|
||||
output_ = AddOutput(output_type);
|
||||
}
|
||||
|
||||
int input() { return input_; }
|
||||
int axis() { return axis_; }
|
||||
int input() const { return input_; }
|
||||
int axis() const { return axis_; }
|
||||
|
||||
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
|
||||
std::vector<int32_t> GetInt32Output() const {
|
||||
return ExtractVector<int32_t>(output_);
|
||||
}
|
||||
std::vector<int64_t> GetInt64Output() const {
|
||||
return ExtractVector<int64_t>(output_);
|
||||
}
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
protected:
|
||||
void PopulateAxisIfNeeded() {
|
||||
if (constant_axis_) return;
|
||||
if (axis_type_ == TensorType_INT32) {
|
||||
PopulateTensor<int32_t>(axis(), {axis_value_});
|
||||
} else {
|
||||
PopulateTensor<int64_t>(axis(), {axis_value_});
|
||||
}
|
||||
}
|
||||
|
||||
const int axis_value_;
|
||||
const TensorType axis_type_;
|
||||
const bool constant_axis_;
|
||||
|
||||
int input_;
|
||||
int axis_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ArgMaxOpModel : public ArgBaseOpModel<T> {
|
||||
class ArgMaxOpModel : public ArgBaseOpModel {
|
||||
public:
|
||||
ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
|
||||
TensorType axis_type, TensorType output_type)
|
||||
: ArgBaseOpModel<T>(input_shape, input_type, axis_type, output_type) {
|
||||
ArgBaseOpModel<T>::SetBuiltinOp(
|
||||
int axis_value, TensorType axis_type, bool constant_axis,
|
||||
TensorType output_type)
|
||||
: ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis,
|
||||
output_type) {
|
||||
ArgBaseOpModel::SetBuiltinOp(
|
||||
BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
|
||||
CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, output_type).Union());
|
||||
ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
|
||||
CreateArgMaxOptions(ArgBaseOpModel::builder_, output_type).Union());
|
||||
ArgBaseOpModel::BuildInterpreter({input_shape, {1}});
|
||||
PopulateAxisIfNeeded();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ArgMinOpModel : public ArgBaseOpModel<T> {
|
||||
class ArgMinOpModel : public ArgBaseOpModel {
|
||||
public:
|
||||
ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
|
||||
TensorType axis_type, TensorType output_type)
|
||||
: ArgBaseOpModel<T>(input_shape, input_type, axis_type, output_type) {
|
||||
ArgBaseOpModel<T>::SetBuiltinOp(
|
||||
int axis_value, TensorType axis_type, bool constant_axis,
|
||||
TensorType output_type)
|
||||
: ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis,
|
||||
output_type) {
|
||||
ArgBaseOpModel::SetBuiltinOp(
|
||||
BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
|
||||
CreateArgMinOptions(ArgBaseOpModel<T>::builder_, output_type).Union());
|
||||
ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
|
||||
CreateArgMinOptions(ArgBaseOpModel::builder_, output_type).Union());
|
||||
ArgBaseOpModel::BuildInterpreter({input_shape, {1}});
|
||||
PopulateAxisIfNeeded();
|
||||
}
|
||||
};
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgFloat) {
|
||||
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
|
||||
TensorType_INT32, TensorType_INT32);
|
||||
// Declare ArgMinMaxOpTest as a parameterized test, where the parameter is a
|
||||
// tuple with:
|
||||
// - boolean indicating whether to use a constant axis or not.
|
||||
// - axis type (TensorType_INT32 or TensorType_INT64)
|
||||
// - output type (TensorType_INT32 or TensorType_INT64)
|
||||
class ArgMinMaxOpTest : public ::testing::TestWithParam<
|
||||
std::tuple<bool, TensorType, TensorType>> {
|
||||
public:
|
||||
bool ConstantAxis() const { return std::get<0>(GetParam()); }
|
||||
|
||||
TensorType AxisType() const { return std::get<1>(GetParam()); }
|
||||
|
||||
TensorType OutputType() const { return std::get<2>(GetParam()); }
|
||||
|
||||
void ValidateOutput(const ArgBaseOpModel& model,
|
||||
const std::vector<int>& expected_output) {
|
||||
if (OutputType() == TensorType_INT32) {
|
||||
EXPECT_THAT(model.GetInt32Output(), ElementsAreArray(expected_output));
|
||||
} else {
|
||||
EXPECT_THAT(model.GetInt64Output(), ElementsAreArray(expected_output));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ArgMinMaxOpTest, ArgMinMaxOpTest,
|
||||
::testing::Combine(::testing::Bool(),
|
||||
::testing::Values(TensorType_INT32, TensorType_INT64),
|
||||
::testing::Values(TensorType_INT32, TensorType_INT64)));
|
||||
|
||||
TEST_P(ArgMinMaxOpTest, GetMaxArgFloat) {
|
||||
ArgMaxOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, 3, AxisType(),
|
||||
ConstantAxis(), OutputType());
|
||||
model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1}));
|
||||
ValidateOutput(model, {1});
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgUInt8) {
|
||||
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_UINT8, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(ArgMinMaxOpTest, GetMaxArgUInt8) {
|
||||
ArgMaxOpModel model({1, 1, 1, 4}, TensorType_UINT8, 3, AxisType(),
|
||||
ConstantAxis(), OutputType());
|
||||
model.PopulateTensor<uint8_t>(model.input(), {1, 9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1}));
|
||||
ValidateOutput(model, {1});
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgInt8) {
|
||||
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT8, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(ArgMinMaxOpTest, GetMaxArgInt8) {
|
||||
ArgMaxOpModel model({1, 1, 1, 4}, TensorType_INT8, 3, AxisType(),
|
||||
ConstantAxis(), OutputType());
|
||||
model.PopulateTensor<int8_t>(model.input(), {-1, -9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({2}));
|
||||
ValidateOutput(model, {2});
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgInt) {
|
||||
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(ArgMinMaxOpTest, GetMaxArgInt) {
|
||||
ArgMaxOpModel model({1, 1, 1, 4}, TensorType_INT32, 3, AxisType(),
|
||||
ConstantAxis(), OutputType());
|
||||
model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1}));
|
||||
ValidateOutput(model, {1});
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgMulDimensions) {
|
||||
ArgMaxOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(ArgMinMaxOpTest, GetMaxArgMulDimensions) {
|
||||
ArgMaxOpModel model({1, 1, 2, 4}, TensorType_INT32, 3, AxisType(),
|
||||
ConstantAxis(), OutputType());
|
||||
model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({3, 1}));
|
||||
ValidateOutput(model, {3, 1});
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
}
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgNegativeAxis) {
|
||||
ArgMaxOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(ArgMinMaxOpTest, GetMaxArgNegativeAxis) {
|
||||
ArgMaxOpModel model({1, 1, 2, 4}, TensorType_INT32, -2, AxisType(),
|
||||
ConstantAxis(), OutputType());
|
||||
model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {-2});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1, 0, 0}));
|
||||
ValidateOutput(model, {0, 1, 0, 0});
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 4}));
|
||||
}
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgOutput64) {
|
||||
ArgMaxOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
|
||||
TensorType_INT64);
|
||||
TEST_P(ArgMinMaxOpTest, GetMaxArgOutput64) {
|
||||
ArgMaxOpModel model({1, 1, 2, 4}, TensorType_INT32, 3, AxisType(),
|
||||
ConstantAxis(), OutputType());
|
||||
model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1}));
|
||||
ValidateOutput(model, {0, 1});
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
}
|
||||
|
||||
TEST(ArgMaxOpTest, GetMaxArgAxis64) {
|
||||
// Input Int32, Axis Int64, Output Int64
|
||||
ArgMaxOpModel<int64_t> model1({1, 1, 2, 4}, TensorType_INT32,
|
||||
TensorType_INT64, TensorType_INT64);
|
||||
model1.PopulateTensor<int>(model1.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model1.PopulateTensor<int64_t>(model1.axis(), {3});
|
||||
model1.Invoke();
|
||||
|
||||
EXPECT_THAT(model1.GetOutput(), ElementsAreArray({0, 1}));
|
||||
EXPECT_THAT(model1.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
|
||||
// Input Int8, Axis Int64, Output Int32
|
||||
ArgMaxOpModel<int32_t> model2({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
|
||||
TensorType_INT32);
|
||||
model2.PopulateTensor<int8_t>(model2.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model2.PopulateTensor<int64_t>(model2.axis(), {3});
|
||||
model2.Invoke();
|
||||
|
||||
EXPECT_THAT(model2.GetOutput(), ElementsAreArray({0, 1}));
|
||||
EXPECT_THAT(model2.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
|
||||
// Input Int8, Axis Int64, Output Int64
|
||||
ArgMaxOpModel<int64_t> model3({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
|
||||
TensorType_INT64);
|
||||
model3.PopulateTensor<int8_t>(model3.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model3.PopulateTensor<int64_t>(model3.axis(), {3});
|
||||
model3.Invoke();
|
||||
|
||||
EXPECT_THAT(model3.GetOutput(), ElementsAreArray({0, 1}));
|
||||
EXPECT_THAT(model3.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
}
|
||||
|
||||
TEST(ArgMinOpTest, GetMinArgFloat) {
|
||||
ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
|
||||
TensorType_INT32, TensorType_INT32);
|
||||
TEST_P(ArgMinMaxOpTest, GetMinArgFloat) {
|
||||
ArgMinOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, 3, AxisType(),
|
||||
ConstantAxis(), OutputType());
|
||||
model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
|
||||
ValidateOutput(model, {0});
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(ArgMinOpTest, GetMinArgInt) {
|
||||
ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(ArgMinMaxOpTest, GetMinArgInt) {
|
||||
ArgMinOpModel model({1, 1, 1, 4}, TensorType_INT32, 3, AxisType(),
|
||||
ConstantAxis(), OutputType());
|
||||
model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
|
||||
ValidateOutput(model, {0});
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(ArgMinOpTest, GetMinArgMulDimensions) {
|
||||
ArgMinOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(ArgMinMaxOpTest, GetMinArgMulDimensions) {
|
||||
ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, 3, AxisType(),
|
||||
ConstantAxis(), OutputType());
|
||||
model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0}));
|
||||
ValidateOutput(model, {0, 0});
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
}
|
||||
|
||||
TEST(ArgMinOpTest, GetMinArgNegativeAxis) {
|
||||
ArgMinOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(ArgMinMaxOpTest, GetMinArgNegativeAxis) {
|
||||
ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, -2, AxisType(),
|
||||
ConstantAxis(), OutputType());
|
||||
model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {-2});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0, 0, 1}));
|
||||
ValidateOutput(model, {0, 0, 0, 1});
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 4}));
|
||||
}
|
||||
|
||||
TEST(ArgMinOpTest, GetMinArgOutput64) {
|
||||
ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
|
||||
TensorType_INT64);
|
||||
TEST_P(ArgMinMaxOpTest, GetMinArgOutput64) {
|
||||
ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, 3, AxisType(),
|
||||
ConstantAxis(), OutputType());
|
||||
model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model.PopulateTensor<int>(model.axis(), {3});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0}));
|
||||
ValidateOutput(model, {1, 0});
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
}
|
||||
|
||||
TEST(ArgMinOpTest, GetMinArgAxis64) {
|
||||
// Input Int32, Axis Int64, Output Int64
|
||||
ArgMinOpModel<int64_t> model1({1, 1, 2, 4}, TensorType_INT32,
|
||||
TensorType_INT64, TensorType_INT64);
|
||||
model1.PopulateTensor<int>(model1.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model1.PopulateTensor<int64_t>(model1.axis(), {3});
|
||||
model1.Invoke();
|
||||
|
||||
EXPECT_THAT(model1.GetOutput(), ElementsAreArray({1, 0}));
|
||||
EXPECT_THAT(model1.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
|
||||
// Input Int8, Axis Int64, Output Int32
|
||||
ArgMinOpModel<int32_t> model2({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
|
||||
TensorType_INT32);
|
||||
model2.PopulateTensor<int8_t>(model2.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model2.PopulateTensor<int64_t>(model2.axis(), {3});
|
||||
model2.Invoke();
|
||||
|
||||
EXPECT_THAT(model2.GetOutput(), ElementsAreArray({1, 0}));
|
||||
EXPECT_THAT(model2.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
|
||||
// Input Int8, Axis Int64, Output Int64
|
||||
ArgMinOpModel<int64_t> model3({1, 1, 2, 4}, TensorType_INT8, TensorType_INT64,
|
||||
TensorType_INT64);
|
||||
model3.PopulateTensor<int8_t>(model3.input(), {10, 2, 7, 8, 1, 9, 7, 3});
|
||||
model3.PopulateTensor<int64_t>(model3.axis(), {3});
|
||||
model3.Invoke();
|
||||
|
||||
EXPECT_THAT(model3.GetOutput(), ElementsAreArray({1, 0}));
|
||||
EXPECT_THAT(model3.GetOutputShape(), ElementsAreArray({1, 1, 2}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -209,7 +209,7 @@ int32_t SingleOpModel::GetTensorSize(int index) const {
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<string> SingleOpModel::ExtractVector(int index) {
|
||||
std::vector<string> SingleOpModel::ExtractVector(int index) const {
|
||||
TfLiteTensor* tensor_ptr = interpreter_->tensor(index);
|
||||
CHECK(tensor_ptr != nullptr);
|
||||
const int num_strings = GetStringCount(tensor_ptr);
|
||||
|
@ -323,8 +323,8 @@ class SingleOpModel {
|
||||
|
||||
// Return a vector with the flattened contents of a tensor.
|
||||
template <typename T>
|
||||
std::vector<T> ExtractVector(int index) {
|
||||
T* v = interpreter_->typed_tensor<T>(index);
|
||||
std::vector<T> ExtractVector(int index) const {
|
||||
const T* v = interpreter_->typed_tensor<T>(index);
|
||||
CHECK(v);
|
||||
return std::vector<T>(v, v + GetTensorSize(index));
|
||||
}
|
||||
@ -594,7 +594,7 @@ TensorType GetTensorType() {
|
||||
|
||||
// Strings have a special implementation that is in test_util.cc
|
||||
template <>
|
||||
std::vector<string> SingleOpModel::ExtractVector(int index);
|
||||
std::vector<string> SingleOpModel::ExtractVector(int index) const;
|
||||
|
||||
// The TypeUnion struct specializations hold a collection of related types.
|
||||
// Each struct holds: 1. a primitive type (e.g. float), 2. a TensorType (e.g.
|
||||
|
@ -90,6 +90,8 @@ enum {
|
||||
ANEURALNETWORKS_SUB = 36,
|
||||
ANEURALNETWORKS_TRANSPOSE = 37,
|
||||
ANEURALNETWORKS_ABS = 38,
|
||||
ANEURALNETWORKS_ARGMAX = 39,
|
||||
ANEURALNETWORKS_ARGMIN = 40,
|
||||
ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM = 42,
|
||||
ANEURALNETWORKS_EQUAL = 48,
|
||||
ANEURALNETWORKS_EXP = 49,
|
||||
|
Loading…
Reference in New Issue
Block a user