Add unidirectional sequence rnn support to the NN API delegate

PiperOrigin-RevId: 253204784
This commit is contained in:
A. Unique TensorFlower 2019-06-14 04:20:37 -07:00 committed by TensorFlower Gardener
parent 8167da2aa5
commit 26945dc15e
3 changed files with 25 additions and 0 deletions

View File

@ -135,6 +135,13 @@ bool IsHybridOperator(const TfLiteContext* context, int builtin_code,
const TfLiteType weights_type = context->tensors[weights_id].type;
return IsFloat(input_type) && IsQuantized(weights_type);
}
case kTfLiteBuiltinUnidirectionalSequenceRnn: {
const int input_id = node->inputs->data[0];
const int weights_id = node->inputs->data[1];
const TfLiteType input_type = context->tensors[input_id].type;
const TfLiteType weights_type = context->tensors[weights_id].type;
return IsFloat(input_type) && IsQuantized(weights_type);
}
default:
return false;
}
@ -1191,6 +1198,22 @@ class NNAPIDelegateKernel {
}
}
} break;
case kTfLiteBuiltinUnidirectionalSequenceRnn:
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
if (IsHybridOperator(context, builtin_code, node)) {
// Hybrid version of this op is not supported by NN API.
return nullptr;
}
return [](const NNAPIOpMappingArgs& mapping_args)
-> ANeuralNetworksOperationType {
auto builtin = reinterpret_cast<TfLiteSequenceRNNParams*>(
mapping_args.node->builtin_data);
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
mapping_args.builder->AddScalarInt32Operand(builtin->time_major);
return ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN;
};
}
break;
case kTfLiteBuiltinSpaceToBatchNd:
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) {
return BasicMappingFn<ANEURALNETWORKS_SPACE_TO_BATCH_ND>;

View File

@ -852,6 +852,7 @@ cc_test(
name = "unidirectional_sequence_rnn_test",
size = "small",
srcs = ["unidirectional_sequence_rnn_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",

View File

@ -109,6 +109,7 @@ enum {
ANEURALNETWORKS_SQRT = 88,
ANEURALNETWORKS_TILE = 89,
ANEURALNETWORKS_TOPK_V2 = 90,
ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN = 93,
};
/**