Add POW op support to the NNAPI delegate
PiperOrigin-RevId: 256183546
This commit is contained in:
parent
ae8d42bbb4
commit
1979135fb7
@ -1424,6 +1424,13 @@ class NNAPIDelegateKernel {
|
|||||||
return BasicMappingFn<ANEURALNETWORKS_RSQRT>;
|
return BasicMappingFn<ANEURALNETWORKS_RSQRT>;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case kTfLiteBuiltinPow:
|
||||||
|
// NN API only supports float inputs to this op.
|
||||||
|
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||||
|
context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) {
|
||||||
|
return BasicMappingFn<ANEURALNETWORKS_POW>;
|
||||||
|
}
|
||||||
|
break;
|
||||||
case kTfLiteBuiltinSin:
|
case kTfLiteBuiltinSin:
|
||||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
|
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
|
||||||
return BasicMappingFn<ANEURALNETWORKS_SIN>;
|
return BasicMappingFn<ANEURALNETWORKS_SIN>;
|
||||||
|
@ -1429,6 +1429,7 @@ cc_test(
|
|||||||
name = "pow_test",
|
name = "pow_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["pow_test.cc"],
|
srcs = ["pow_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
":test_main",
|
":test_main",
|
||||||
|
@ -108,6 +108,16 @@ TEST(PowOpModel, BroadcastTest) {
|
|||||||
EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096));
|
EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(PowOpModel, BroadcastFloatTest) {
|
||||||
|
PowOpModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}},
|
||||||
|
{TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {}});
|
||||||
|
model.PopulateTensor<float>(model.input1(), {12, 2, 7, 8});
|
||||||
|
model.PopulateTensor<float>(model.input2(), {4});
|
||||||
|
model.Invoke();
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096));
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void CalculateTrueResults(const std::vector<T>& input_data, T exponent,
|
void CalculateTrueResults(const std::vector<T>& input_data, T exponent,
|
||||||
int flat_size, std::vector<T>* output_data) {
|
int flat_size, std::vector<T>* output_data) {
|
||||||
|
@ -108,6 +108,7 @@ enum {
|
|||||||
ANEURALNETWORKS_NEG = 67,
|
ANEURALNETWORKS_NEG = 67,
|
||||||
ANEURALNETWORKS_NOT_EQUAL = 68,
|
ANEURALNETWORKS_NOT_EQUAL = 68,
|
||||||
ANEURALNETWORKS_PAD_V2 = 69,
|
ANEURALNETWORKS_PAD_V2 = 69,
|
||||||
|
ANEURALNETWORKS_POW = 70,
|
||||||
ANEURALNETWORKS_PRELU = 71,
|
ANEURALNETWORKS_PRELU = 71,
|
||||||
ANEURALNETWORKS_RSQRT = 83,
|
ANEURALNETWORKS_RSQRT = 83,
|
||||||
ANEURALNETWORKS_SELECT = 84,
|
ANEURALNETWORKS_SELECT = 84,
|
||||||
|
Loading…
Reference in New Issue
Block a user