diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index ea651ed7382..953a49c66fd 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -135,6 +135,13 @@ bool IsHybridOperator(const TfLiteContext* context, int builtin_code, const TfLiteType weights_type = context->tensors[weights_id].type; return IsFloat(input_type) && IsQuantized(weights_type); } + case kTfLiteBuiltinUnidirectionalSequenceRnn: { + const int input_id = node->inputs->data[0]; + const int weights_id = node->inputs->data[1]; + const TfLiteType input_type = context->tensors[input_id].type; + const TfLiteType weights_type = context->tensors[weights_id].type; + return IsFloat(input_type) && IsQuantized(weights_type); + } default: return false; } @@ -1191,6 +1198,22 @@ class NNAPIDelegateKernel { } } } break; + case kTfLiteBuiltinUnidirectionalSequenceRnn: + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) { + if (IsHybridOperator(context, builtin_code, node)) { + // Hybrid version of this op is not supported by NN API. + return nullptr; + } + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->activation); + mapping_args.builder->AddScalarInt32Operand(builtin->time_major); + return ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN; + }; + } + break; case kTfLiteBuiltinSpaceToBatchNd: if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) { return BasicMappingFn; diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 763ba0ae9b8..02422a09730 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -852,6 +852,7 @@ cc_test( name = "unidirectional_sequence_rnn_test", size = "small", srcs = ["unidirectional_sequence_rnn_test.cc"], + tags = ["tflite_nnapi"], deps = [ ":builtin_ops", ":test_main", diff --git a/tensorflow/lite/nnapi/NeuralNetworksTypes.h b/tensorflow/lite/nnapi/NeuralNetworksTypes.h index d0b26e1bfb6..164d161db5c 100644 --- a/tensorflow/lite/nnapi/NeuralNetworksTypes.h +++ b/tensorflow/lite/nnapi/NeuralNetworksTypes.h @@ -109,6 +109,7 @@ enum { ANEURALNETWORKS_SQRT = 88, ANEURALNETWORKS_TILE = 89, ANEURALNETWORKS_TOPK_V2 = 90, + ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN = 93, }; /**