Add bidirectional sequence LSTM support to the NNAPI delegate
PiperOrigin-RevId: 254756123
This commit is contained in:
parent
7fd0dcd8bc
commit
9a62d2e09e
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/builtin_op_data.h"
|
#include "tensorflow/lite/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/builtin_ops.h"
|
#include "tensorflow/lite/builtin_ops.h"
|
||||||
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/context_util.h"
|
#include "tensorflow/lite/context_util.h"
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
@ -135,6 +136,14 @@ bool IsHybridOperator(const TfLiteContext* context, int builtin_code,
|
|||||||
const TfLiteType weights_type = context->tensors[weights_id].type;
|
const TfLiteType weights_type = context->tensors[weights_id].type;
|
||||||
return IsFloat(input_type) && IsQuantized(weights_type);
|
return IsFloat(input_type) && IsQuantized(weights_type);
|
||||||
}
|
}
|
||||||
|
case kTfLiteBuiltinBidirectionalSequenceLstm: {
|
||||||
|
const int input_id = node->inputs->data[0];
|
||||||
|
// Input #1 is optional so use #2 to determine if hybrid.
|
||||||
|
const int weights_id = node->inputs->data[2];
|
||||||
|
const TfLiteType input_type = context->tensors[input_id].type;
|
||||||
|
const TfLiteType weights_type = context->tensors[weights_id].type;
|
||||||
|
return IsFloat(input_type) && IsQuantized(weights_type);
|
||||||
|
}
|
||||||
case kTfLiteBuiltinUnidirectionalSequenceRnn: {
|
case kTfLiteBuiltinUnidirectionalSequenceRnn: {
|
||||||
const int input_id = node->inputs->data[0];
|
const int input_id = node->inputs->data[0];
|
||||||
const int weights_id = node->inputs->data[1];
|
const int weights_id = node->inputs->data[1];
|
||||||
@ -1633,6 +1642,31 @@ class NNAPIDelegateKernel {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case kTfLiteBuiltinBidirectionalSequenceLstm:
|
||||||
|
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
|
||||||
|
if (IsHybridOperator(context, builtin_code, node)) {
|
||||||
|
// Hybrid version of this op is not supported by NN API.
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return [](const NNAPIOpMappingArgs& mapping_args)
|
||||||
|
-> ANeuralNetworksOperationType {
|
||||||
|
auto builtin =
|
||||||
|
reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
|
||||||
|
mapping_args.node->builtin_data);
|
||||||
|
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
|
||||||
|
mapping_args.builder->AddScalarFloat32Operand(builtin->cell_clip);
|
||||||
|
mapping_args.builder->AddScalarFloat32Operand(builtin->proj_clip);
|
||||||
|
mapping_args.builder->AddScalarBoolOperand(builtin->merge_outputs);
|
||||||
|
mapping_args.builder->AddScalarBoolOperand(builtin->time_major);
|
||||||
|
// TF Lite doesn't support layer normalization in bidirectional
|
||||||
|
// sequence LSTM, so we insert optional tensors for NNAPI
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
mapping_args.builder->AddVectorFloat32Operand(nullptr, 0);
|
||||||
|
}
|
||||||
|
return ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
// All other operators are not mapped.
|
// All other operators are not mapped.
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -2037,7 +2071,8 @@ class NNAPIDelegateKernel {
|
|||||||
|
|
||||||
if (input_index == kOptionalTensor &&
|
if (input_index == kOptionalTensor &&
|
||||||
(reg->builtin_code == kTfLiteBuiltinLstm ||
|
(reg->builtin_code == kTfLiteBuiltinLstm ||
|
||||||
reg->builtin_code == kTfLiteBuiltinSvdf)) {
|
reg->builtin_code == kTfLiteBuiltinSvdf ||
|
||||||
|
reg->builtin_code == kTfLiteBuiltinBidirectionalSequenceLstm)) {
|
||||||
// properly handle the optional tensor for LSTM and SVDF.
|
// properly handle the optional tensor for LSTM and SVDF.
|
||||||
// currently only support float32.
|
// currently only support float32.
|
||||||
// TODO(miaowang): make sure this is also able to handle quantized
|
// TODO(miaowang): make sure this is also able to handle quantized
|
||||||
|
@ -754,6 +754,7 @@ cc_test(
|
|||||||
name = "bidirectional_sequence_lstm_test",
|
name = "bidirectional_sequence_lstm_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["bidirectional_sequence_lstm_test.cc"],
|
srcs = ["bidirectional_sequence_lstm_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
":test_main",
|
":test_main",
|
||||||
|
@ -89,6 +89,7 @@ enum {
|
|||||||
ANEURALNETWORKS_SUB = 36,
|
ANEURALNETWORKS_SUB = 36,
|
||||||
ANEURALNETWORKS_TRANSPOSE = 37,
|
ANEURALNETWORKS_TRANSPOSE = 37,
|
||||||
ANEURALNETWORKS_ABS = 38,
|
ANEURALNETWORKS_ABS = 38,
|
||||||
|
ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM = 42,
|
||||||
ANEURALNETWORKS_EQUAL = 48,
|
ANEURALNETWORKS_EQUAL = 48,
|
||||||
ANEURALNETWORKS_EXP = 49,
|
ANEURALNETWORKS_EXP = 49,
|
||||||
ANEURALNETWORKS_GATHER = 51,
|
ANEURALNETWORKS_GATHER = 51,
|
||||||
|
Loading…
Reference in New Issue
Block a user