diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD index 419a3b2486d..97a3d3f78de 100644 --- a/tensorflow/lite/core/api/BUILD +++ b/tensorflow/lite/core/api/BUILD @@ -24,9 +24,12 @@ cc_library( build_for_embedded = True, copts = tflite_copts() + micro_copts(), deps = [ - "//tensorflow/lite/c:common", - "//tensorflow/lite/schema:schema_fbs", "@flatbuffers//:runtime_cc", + "//tensorflow/lite/c:common", + # TODO(b/158301698): consider moving internal:compatibility to a more + # central location. + "//tensorflow/lite/kernels/internal:compatibility", + "//tensorflow/lite/schema:schema_fbs", ], ) diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index c52fc9f690b..5f39732b970 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { @@ -89,6 +90,25 @@ TfLiteStatus FlatBufferIntVectorToArray( return kTfLiteOk; } +// Converts the flatbuffer activation to what is used at runtime. +TfLiteFusedActivation ConvertActivation(ActivationFunctionType activation) { + switch (activation) { + case ActivationFunctionType_NONE: + return kTfLiteActNone; + case ActivationFunctionType_RELU: + return kTfLiteActRelu; + case ActivationFunctionType_RELU_N1_TO_1: + return kTfLiteActRelu1; + case ActivationFunctionType_RELU6: + return kTfLiteActRelu6; + case ActivationFunctionType_TANH: + return kTfLiteActTanh; + case ActivationFunctionType_SIGN_BIT: + return kTfLiteActSignBit; + } + return kTfLiteActNone; +} + } // namespace TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, @@ -135,12 +155,131 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, } } -// Parse the appropriate data out of the op. -// -// This handles builtin data explicitly as there are flatbuffer schemas. -// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which -// need to be released by calling `free`.` -// If it returns kTfLiteError, `builtin_data` will be `nullptr`. +// We have this parse function instead of directly returning kTfLiteOk from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +TfLiteStatus ParseDequantize(const Operator*, BuiltinOperator, ErrorReporter*, + BuiltinDataAllocator*, void**) { + return kTfLiteOk; +} + +TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data) { + TFLITE_DCHECK(op != nullptr); + TFLITE_DCHECK(error_reporter != nullptr); + TFLITE_DCHECK(allocator != nullptr); + TFLITE_DCHECK(builtin_data != nullptr); + + SafeBuiltinDataAllocator safe_allocator(allocator); + + std::unique_ptr + params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + + const FullyConnectedOptions* schema_params = + op->builtin_options_as_FullyConnectedOptions(); + + if (schema_params != nullptr) { + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + params->keep_num_dims = schema_params->keep_num_dims(); + params->asymmetric_quantize_inputs = + schema_params->asymmetric_quantize_inputs(); + + switch (schema_params->weights_format()) { + case FullyConnectedOptionsWeightsFormat_DEFAULT: + params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; + break; + case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + params->weights_format = + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; + break; + default: + TF_LITE_REPORT_ERROR(error_reporter, + "Unhandled fully-connected weights format."); + return kTfLiteError; + } + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better undertand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return kTfLiteOk; +} + +// We have this parse function instead of directly returning kTfLiteOk from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +TfLiteStatus ParseQuantize(const Operator*, BuiltinOperator, ErrorReporter*, + BuiltinDataAllocator*, void**) { + return kTfLiteOk; +} + +TfLiteStatus ParseSoftmax(const Operator* op, BuiltinOperator, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data) { + TFLITE_DCHECK(op != nullptr); + TFLITE_DCHECK(error_reporter != nullptr); + TFLITE_DCHECK(allocator != nullptr); + TFLITE_DCHECK(builtin_data != nullptr); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + + const SoftmaxOptions* schema_params = op->builtin_options_as_SoftmaxOptions(); + + if (schema_params != nullptr) { + params->beta = schema_params->beta(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better undertand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return kTfLiteOk; +} + +TfLiteStatus ParseSvdf(const Operator* op, BuiltinOperator, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data) { + TFLITE_DCHECK(op != nullptr); + TFLITE_DCHECK(error_reporter != nullptr); + TFLITE_DCHECK(allocator != nullptr); + TFLITE_DCHECK(builtin_data != nullptr); + + SafeBuiltinDataAllocator safe_allocator(allocator); + std::unique_ptr + params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + + const SVDFOptions* schema_params = op->builtin_options_as_SVDFOptions(); + if (schema_params != nullptr) { + params->rank = schema_params->rank(); + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + params->asymmetric_quantize_inputs = + schema_params->asymmetric_quantize_inputs(); + } else { + // TODO(b/157480169): We should either return kTfLiteError or fill in some + // reasonable defaults in the params struct. We are not doing so until we + // better undertand the ramifications of changing the legacy behavior. + } + + *builtin_data = params.release(); + return kTfLiteOk; +} + TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data) { @@ -153,23 +292,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } return kTfLitePaddingUnknown; }; - auto parse_activation = [](ActivationFunctionType activation) { - switch (activation) { - case ActivationFunctionType_NONE: - return kTfLiteActNone; - case ActivationFunctionType_RELU: - return kTfLiteActRelu; - case ActivationFunctionType_RELU_N1_TO_1: - return kTfLiteActRelu1; - case ActivationFunctionType_RELU6: - return kTfLiteActRelu6; - case ActivationFunctionType_TANH: - return kTfLiteActTanh; - case ActivationFunctionType_SIGN_BIT: - return kTfLiteActSignBit; - } - return kTfLiteActNone; - }; auto parseLSHProjectionType = [](LSHProjectionType type) { switch (type) { case LSHProjectionType_SPARSE: @@ -195,6 +317,29 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, SafeBuiltinDataAllocator safe_allocator(allocator); *builtin_data = nullptr; switch (op_type) { + case BuiltinOperator_DEQUANTIZE: { + return ParseDequantize(op, op_type, error_reporter, allocator, + builtin_data); + } + + case BuiltinOperator_FULLY_CONNECTED: { + return ParseFullyConnected(op, op_type, error_reporter, allocator, + builtin_data); + } + + case BuiltinOperator_QUANTIZE: { + return ParseQuantize(op, op_type, error_reporter, allocator, + builtin_data); + } + + case BuiltinOperator_SOFTMAX: { + return ParseSoftmax(op, op_type, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_SVDF: { + return ParseSvdf(op, op_type, error_reporter, allocator, builtin_data); + } + case BuiltinOperator_CONV_2D: { auto params = safe_allocator.Allocate(); TF_LITE_ENSURE(error_reporter, params != nullptr); @@ -203,7 +348,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->stride_width = conv_params->stride_w(); params->stride_height = conv_params->stride_h(); params->activation = - parse_activation(conv_params->fused_activation_function()); + ConvertActivation(conv_params->fused_activation_function()); params->dilation_width_factor = conv_params->dilation_w_factor(); params->dilation_height_factor = conv_params->dilation_h_factor(); @@ -247,7 +392,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->filter_width = pool_params->filter_width(); params->filter_height = pool_params->filter_height(); params->activation = - parse_activation(pool_params->fused_activation_function()); + ConvertActivation(pool_params->fused_activation_function()); } *builtin_data = params.release(); return kTfLiteOk; @@ -262,7 +407,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->stride_height = conv_params->stride_h(); params->depth_multiplier = conv_params->depth_multiplier(); params->activation = - parse_activation(conv_params->fused_activation_function()); + ConvertActivation(conv_params->fused_activation_function()); params->dilation_width_factor = conv_params->dilation_w_factor(); params->dilation_height_factor = conv_params->dilation_h_factor(); @@ -270,26 +415,13 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } - case BuiltinOperator_SVDF: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* svdf_params = op->builtin_options_as_SVDFOptions()) { - params->rank = svdf_params->rank(); - params->activation = - parse_activation(svdf_params->fused_activation_function()); - params->asymmetric_quantize_inputs = - svdf_params->asymmetric_quantize_inputs(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { auto params = safe_allocator.Allocate(); TF_LITE_ENSURE(error_reporter, params != nullptr); if (const auto* sequence_rnn_params = op->builtin_options_as_SequenceRNNOptions()) { params->activation = - parse_activation(sequence_rnn_params->fused_activation_function()); + ConvertActivation(sequence_rnn_params->fused_activation_function()); params->time_major = sequence_rnn_params->time_major(); params->asymmetric_quantize_inputs = sequence_rnn_params->asymmetric_quantize_inputs(); @@ -303,7 +435,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, TF_LITE_ENSURE(error_reporter, params != nullptr); if (const auto* bidi_sequence_rnn_params = op->builtin_options_as_BidirectionalSequenceRNNOptions()) { - params->activation = parse_activation( + params->activation = ConvertActivation( bidi_sequence_rnn_params->fused_activation_function()); params->time_major = bidi_sequence_rnn_params->time_major(); params->merge_outputs = bidi_sequence_rnn_params->merge_outputs(); @@ -318,7 +450,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, TF_LITE_ENSURE(error_reporter, params != nullptr); if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) { params->activation = - parse_activation(rnn_params->fused_activation_function()); + ConvertActivation(rnn_params->fused_activation_function()); params->asymmetric_quantize_inputs = rnn_params->asymmetric_quantize_inputs(); } @@ -336,53 +468,17 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } - case BuiltinOperator_FULLY_CONNECTED: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* fully_connected_params = - op->builtin_options_as_FullyConnectedOptions()) { - params->activation = parse_activation( - fully_connected_params->fused_activation_function()); - params->keep_num_dims = fully_connected_params->keep_num_dims(); - params->asymmetric_quantize_inputs = - fully_connected_params->asymmetric_quantize_inputs(); - switch (fully_connected_params->weights_format()) { - case FullyConnectedOptionsWeightsFormat_DEFAULT: - params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; - break; - case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: - params->weights_format = - kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; - break; - default: - TF_LITE_REPORT_ERROR(error_reporter, - "Unhandled fully-connected weights format."); - return kTfLiteError; - } - } - *builtin_data = params.release(); - return kTfLiteOk; - } + case BuiltinOperator_HASHTABLE_LOOKUP: // no-op. return kTfLiteOk; - case BuiltinOperator_SOFTMAX: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* softmax_params = - op->builtin_options_as_SoftmaxOptions()) { - params->beta = softmax_params->beta(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } case BuiltinOperator_CONCATENATION: { auto params = safe_allocator.Allocate(); TF_LITE_ENSURE(error_reporter, params != nullptr); if (const auto* concatenation_params = op->builtin_options_as_ConcatenationOptions()) { - params->activation = - parse_activation(concatenation_params->fused_activation_function()); + params->activation = ConvertActivation( + concatenation_params->fused_activation_function()); params->axis = concatenation_params->axis(); } *builtin_data = params.release(); @@ -393,7 +489,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, TF_LITE_ENSURE(error_reporter, params != nullptr); if (const auto* schema_params = op->builtin_options_as_MulOptions()) { params->activation = - parse_activation(schema_params->fused_activation_function()); + ConvertActivation(schema_params->fused_activation_function()); } *builtin_data = params.release(); return kTfLiteOk; @@ -403,7 +499,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, TF_LITE_ENSURE(error_reporter, params != nullptr); if (const auto* schema_params = op->builtin_options_as_AddOptions()) { params->activation = - parse_activation(schema_params->fused_activation_function()); + ConvertActivation(schema_params->fused_activation_function()); } *builtin_data = params.release(); return kTfLiteOk; @@ -413,7 +509,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, TF_LITE_ENSURE(error_reporter, params != nullptr); if (const auto* schema_params = op->builtin_options_as_DivOptions()) { params->activation = - parse_activation(schema_params->fused_activation_function()); + ConvertActivation(schema_params->fused_activation_function()); } *builtin_data = params.release(); return kTfLiteOk; @@ -423,7 +519,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, TF_LITE_ENSURE(error_reporter, params != nullptr); if (const auto* schema_params = op->builtin_options_as_SubOptions()) { params->activation = - parse_activation(schema_params->fused_activation_function()); + ConvertActivation(schema_params->fused_activation_function()); } *builtin_data = params.release(); return kTfLiteOk; @@ -433,7 +529,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, TF_LITE_ENSURE(error_reporter, params != nullptr); if (const auto* schema_params = op->builtin_options_as_L2NormOptions()) { params->activation = - parse_activation(schema_params->fused_activation_function()); + ConvertActivation(schema_params->fused_activation_function()); } *builtin_data = params.release(); return kTfLiteOk; @@ -456,7 +552,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, TF_LITE_ENSURE(error_reporter, params != nullptr); if (const auto* lstm_params = op->builtin_options_as_LSTMOptions()) { params->activation = - parse_activation(lstm_params->fused_activation_function()); + ConvertActivation(lstm_params->fused_activation_function()); params->cell_clip = lstm_params->cell_clip(); params->proj_clip = lstm_params->proj_clip(); switch (lstm_params->kernel_type()) { @@ -489,7 +585,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, if (const auto* seq_lstm_params = op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) { params->activation = - parse_activation(seq_lstm_params->fused_activation_function()); + ConvertActivation(seq_lstm_params->fused_activation_function()); params->cell_clip = seq_lstm_params->cell_clip(); params->proj_clip = seq_lstm_params->proj_clip(); params->time_major = seq_lstm_params->time_major(); @@ -506,7 +602,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, if (const auto* bidi_lstm_params = op->builtin_options_as_BidirectionalSequenceLSTMOptions()) { params->activation = - parse_activation(bidi_lstm_params->fused_activation_function()); + ConvertActivation(bidi_lstm_params->fused_activation_function()); params->cell_clip = bidi_lstm_params->cell_clip(); params->proj_clip = bidi_lstm_params->proj_clip(); params->merge_outputs = bidi_lstm_params->merge_outputs(); @@ -857,7 +953,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_CONCAT_EMBEDDINGS: case BuiltinOperator_COS: case BuiltinOperator_CUSTOM: - case BuiltinOperator_DEQUANTIZE: case BuiltinOperator_ELU: case BuiltinOperator_EMBEDDING_LOOKUP: case BuiltinOperator_EQUAL: @@ -913,7 +1008,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_GATHER_ND: case BuiltinOperator_WHERE: case BuiltinOperator_RANK: - case BuiltinOperator_QUANTIZE: case BuiltinOperator_NON_MAX_SUPPRESSION_V4: case BuiltinOperator_NON_MAX_SUPPRESSION_V5: case BuiltinOperator_SCATTER_ND: diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h index 2feddfaa8e6..45f2c9df3b7 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -69,6 +69,35 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, ErrorReporter* error_reporter); +// TODO(b/149408647): The (unnecessary) op_type parameter in the functions below +// is to keep the same signature as ParseOpData. This allows for a gradual +// transfer to selective registration of the parse function, but should be +// removed once we are no longer using ParseOpData for the OpResolver +// implementation in micro. + +TfLiteStatus ParseDequantize(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data); + +TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data); + +TfLiteStatus ParseQuantize(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data); + +TfLiteStatus ParseSoftmax(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data); + +TfLiteStatus ParseSvdf(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data); + } // namespace tflite #endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_