Add unidirectional sequence rnn support to the NN API delegate
PiperOrigin-RevId: 253204784
This commit is contained in:
parent
8167da2aa5
commit
26945dc15e
@ -135,6 +135,13 @@ bool IsHybridOperator(const TfLiteContext* context, int builtin_code,
|
|||||||
const TfLiteType weights_type = context->tensors[weights_id].type;
|
const TfLiteType weights_type = context->tensors[weights_id].type;
|
||||||
return IsFloat(input_type) && IsQuantized(weights_type);
|
return IsFloat(input_type) && IsQuantized(weights_type);
|
||||||
}
|
}
|
||||||
|
case kTfLiteBuiltinUnidirectionalSequenceRnn: {
|
||||||
|
const int input_id = node->inputs->data[0];
|
||||||
|
const int weights_id = node->inputs->data[1];
|
||||||
|
const TfLiteType input_type = context->tensors[input_id].type;
|
||||||
|
const TfLiteType weights_type = context->tensors[weights_id].type;
|
||||||
|
return IsFloat(input_type) && IsQuantized(weights_type);
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -1191,6 +1198,22 @@ class NNAPIDelegateKernel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case kTfLiteBuiltinUnidirectionalSequenceRnn:
|
||||||
|
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
|
||||||
|
if (IsHybridOperator(context, builtin_code, node)) {
|
||||||
|
// Hybrid version of this op is not supported by NN API.
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return [](const NNAPIOpMappingArgs& mapping_args)
|
||||||
|
-> ANeuralNetworksOperationType {
|
||||||
|
auto builtin = reinterpret_cast<TfLiteSequenceRNNParams*>(
|
||||||
|
mapping_args.node->builtin_data);
|
||||||
|
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
|
||||||
|
mapping_args.builder->AddScalarInt32Operand(builtin->time_major);
|
||||||
|
return ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
break;
|
||||||
case kTfLiteBuiltinSpaceToBatchNd:
|
case kTfLiteBuiltinSpaceToBatchNd:
|
||||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) {
|
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) {
|
||||||
return BasicMappingFn<ANEURALNETWORKS_SPACE_TO_BATCH_ND>;
|
return BasicMappingFn<ANEURALNETWORKS_SPACE_TO_BATCH_ND>;
|
||||||
|
@ -852,6 +852,7 @@ cc_test(
|
|||||||
name = "unidirectional_sequence_rnn_test",
|
name = "unidirectional_sequence_rnn_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["unidirectional_sequence_rnn_test.cc"],
|
srcs = ["unidirectional_sequence_rnn_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
":test_main",
|
":test_main",
|
||||||
|
@ -109,6 +109,7 @@ enum {
|
|||||||
ANEURALNETWORKS_SQRT = 88,
|
ANEURALNETWORKS_SQRT = 88,
|
||||||
ANEURALNETWORKS_TILE = 89,
|
ANEURALNETWORKS_TILE = 89,
|
||||||
ANEURALNETWORKS_TOPK_V2 = 90,
|
ANEURALNETWORKS_TOPK_V2 = 90,
|
||||||
|
ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN = 93,
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
Loading…
x
Reference in New Issue
Block a user