Add unidirectional sequence rnn support to the NN API delegate
PiperOrigin-RevId: 253204784
This commit is contained in:
parent
8167da2aa5
commit
26945dc15e
@ -135,6 +135,13 @@ bool IsHybridOperator(const TfLiteContext* context, int builtin_code,
|
||||
const TfLiteType weights_type = context->tensors[weights_id].type;
|
||||
return IsFloat(input_type) && IsQuantized(weights_type);
|
||||
}
|
||||
case kTfLiteBuiltinUnidirectionalSequenceRnn: {
|
||||
const int input_id = node->inputs->data[0];
|
||||
const int weights_id = node->inputs->data[1];
|
||||
const TfLiteType input_type = context->tensors[input_id].type;
|
||||
const TfLiteType weights_type = context->tensors[weights_id].type;
|
||||
return IsFloat(input_type) && IsQuantized(weights_type);
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -1191,6 +1198,22 @@ class NNAPIDelegateKernel {
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case kTfLiteBuiltinUnidirectionalSequenceRnn:
|
||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
|
||||
if (IsHybridOperator(context, builtin_code, node)) {
|
||||
// Hybrid version of this op is not supported by NN API.
|
||||
return nullptr;
|
||||
}
|
||||
return [](const NNAPIOpMappingArgs& mapping_args)
|
||||
-> ANeuralNetworksOperationType {
|
||||
auto builtin = reinterpret_cast<TfLiteSequenceRNNParams*>(
|
||||
mapping_args.node->builtin_data);
|
||||
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
|
||||
mapping_args.builder->AddScalarInt32Operand(builtin->time_major);
|
||||
return ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN;
|
||||
};
|
||||
}
|
||||
break;
|
||||
case kTfLiteBuiltinSpaceToBatchNd:
|
||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) {
|
||||
return BasicMappingFn<ANEURALNETWORKS_SPACE_TO_BATCH_ND>;
|
||||
|
@ -852,6 +852,7 @@ cc_test(
|
||||
name = "unidirectional_sequence_rnn_test",
|
||||
size = "small",
|
||||
srcs = ["unidirectional_sequence_rnn_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
|
@ -109,6 +109,7 @@ enum {
|
||||
ANEURALNETWORKS_SQRT = 88,
|
||||
ANEURALNETWORKS_TILE = 89,
|
||||
ANEURALNETWORKS_TOPK_V2 = 90,
|
||||
ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN = 93,
|
||||
};
|
||||
|
||||
/**
|
||||
|
Loading…
x
Reference in New Issue
Block a user