Split individual builtin op parsing into their own helper functions.

PiperOrigin-RevId: 314955578
Change-Id: I74be4ba92af2e0ed4382f0b4ac8cf22efef0fc0f
Advait Jain 2020-06-05 10:46:54 -07:00 committed by TensorFlower Gardener
parent 2178657934
commit 2438248693
3 changed files with 219 additions and 93 deletions
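A minimal caller-side sketch of what the split enables, assuming a hypothetical micro application that only needs SVDF; the wrapper below is illustrative, only ParseSvdf and the TFLite types come from this change:

#include "tensorflow/lite/core/api/flatbuffer_conversions.h"

// Hypothetical helper in an application that only registers SVDF.
TfLiteStatus ParseSvdfOnly(const tflite::Operator* op,
                           tflite::ErrorReporter* error_reporter,
                           tflite::BuiltinDataAllocator* allocator,
                           void** builtin_data) {
  // Calls the dedicated helper instead of the full ParseOpData switch.
  return tflite::ParseSvdf(op, tflite::BuiltinOperator_SVDF, error_reporter,
                           allocator, builtin_data);
}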

@@ -24,9 +24,12 @@ cc_library(
build_for_embedded = True,
copts = tflite_copts() + micro_copts(),
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs",
"@flatbuffers//:runtime_cc",
"//tensorflow/lite/c:common",
# TODO(b/158301698): consider moving internal:compatibility to a more
# central location.
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/schema:schema_fbs",
],
)

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
@@ -89,6 +90,25 @@ TfLiteStatus FlatBufferIntVectorToArray(
return kTfLiteOk;
}
// Converts the flatbuffer activation to what is used at runtime.
TfLiteFusedActivation ConvertActivation(ActivationFunctionType activation) {
switch (activation) {
case ActivationFunctionType_NONE:
return kTfLiteActNone;
case ActivationFunctionType_RELU:
return kTfLiteActRelu;
case ActivationFunctionType_RELU_N1_TO_1:
return kTfLiteActRelu1;
case ActivationFunctionType_RELU6:
return kTfLiteActRelu6;
case ActivationFunctionType_TANH:
return kTfLiteActTanh;
case ActivationFunctionType_SIGN_BIT:
return kTfLiteActSignBit;
}
return kTfLiteActNone;
}
} // namespace
TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
@@ -135,12 +155,131 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
}
}
// Parse the appropriate data out of the op.
//
// This handles builtin data explicitly as there are flatbuffer schemas.
// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
// needs to be released by calling `free`.
// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseDequantize(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
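// Illustrative sketch (hypothetical call site, not library code): per the
// contract above, a caller that goes through ParseOpData receives the
// op-specific params via `builtin_data` and releases them with `free`:
//
//   void* builtin_data = nullptr;
//   if (ParseOpData(op, op_type, error_reporter, allocator, &builtin_data) ==
//       kTfLiteOk) {
//     // Use the params (e.g. TfLiteFullyConnectedParams for FULLY_CONNECTED).
//     free(builtin_data);
//   }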
TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
TFLITE_DCHECK(op != nullptr);
TFLITE_DCHECK(error_reporter != nullptr);
TFLITE_DCHECK(allocator != nullptr);
TFLITE_DCHECK(builtin_data != nullptr);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteFullyConnectedParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLiteFullyConnectedParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const FullyConnectedOptions* schema_params =
op->builtin_options_as_FullyConnectedOptions();
if (schema_params != nullptr) {
params->activation =
ConvertActivation(schema_params->fused_activation_function());
params->keep_num_dims = schema_params->keep_num_dims();
params->asymmetric_quantize_inputs =
schema_params->asymmetric_quantize_inputs();
switch (schema_params->weights_format()) {
case FullyConnectedOptionsWeightsFormat_DEFAULT:
params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
break;
case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
params->weights_format =
kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
break;
default:
TF_LITE_REPORT_ERROR(error_reporter,
"Unhandled fully-connected weights format.");
return kTfLiteError;
}
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better understand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseQuantize(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
TfLiteStatus ParseSoftmax(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
TFLITE_DCHECK(op != nullptr);
TFLITE_DCHECK(error_reporter != nullptr);
TFLITE_DCHECK(allocator != nullptr);
TFLITE_DCHECK(builtin_data != nullptr);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteSoftmaxParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLiteSoftmaxParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const SoftmaxOptions* schema_params = op->builtin_options_as_SoftmaxOptions();
if (schema_params != nullptr) {
params->beta = schema_params->beta();
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better understand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
TfLiteStatus ParseSvdf(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
TFLITE_DCHECK(op != nullptr);
TFLITE_DCHECK(error_reporter != nullptr);
TFLITE_DCHECK(allocator != nullptr);
TFLITE_DCHECK(builtin_data != nullptr);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteSVDFParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLiteSVDFParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const SVDFOptions* schema_params = op->builtin_options_as_SVDFOptions();
if (schema_params != nullptr) {
params->rank = schema_params->rank();
params->activation =
ConvertActivation(schema_params->fused_activation_function());
params->asymmetric_quantize_inputs =
schema_params->asymmetric_quantize_inputs();
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better understand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
@@ -153,23 +292,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
return kTfLitePaddingUnknown;
};
auto parse_activation = [](ActivationFunctionType activation) {
switch (activation) {
case ActivationFunctionType_NONE:
return kTfLiteActNone;
case ActivationFunctionType_RELU:
return kTfLiteActRelu;
case ActivationFunctionType_RELU_N1_TO_1:
return kTfLiteActRelu1;
case ActivationFunctionType_RELU6:
return kTfLiteActRelu6;
case ActivationFunctionType_TANH:
return kTfLiteActTanh;
case ActivationFunctionType_SIGN_BIT:
return kTfLiteActSignBit;
}
return kTfLiteActNone;
};
auto parseLSHProjectionType = [](LSHProjectionType type) {
switch (type) {
case LSHProjectionType_SPARSE:
@@ -195,6 +317,29 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
SafeBuiltinDataAllocator safe_allocator(allocator);
*builtin_data = nullptr;
switch (op_type) {
case BuiltinOperator_DEQUANTIZE: {
return ParseDequantize(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_FULLY_CONNECTED: {
return ParseFullyConnected(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_QUANTIZE: {
return ParseQuantize(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_SOFTMAX: {
return ParseSoftmax(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_SVDF: {
return ParseSvdf(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_CONV_2D: {
auto params = safe_allocator.Allocate<TfLiteConvParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
@@ -203,7 +348,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->stride_width = conv_params->stride_w();
params->stride_height = conv_params->stride_h();
params->activation =
parse_activation(conv_params->fused_activation_function());
ConvertActivation(conv_params->fused_activation_function());
params->dilation_width_factor = conv_params->dilation_w_factor();
params->dilation_height_factor = conv_params->dilation_h_factor();
@@ -247,7 +392,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->filter_width = pool_params->filter_width();
params->filter_height = pool_params->filter_height();
params->activation =
parse_activation(pool_params->fused_activation_function());
ConvertActivation(pool_params->fused_activation_function());
}
*builtin_data = params.release();
return kTfLiteOk;
@@ -262,7 +407,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->stride_height = conv_params->stride_h();
params->depth_multiplier = conv_params->depth_multiplier();
params->activation =
parse_activation(conv_params->fused_activation_function());
ConvertActivation(conv_params->fused_activation_function());
params->dilation_width_factor = conv_params->dilation_w_factor();
params->dilation_height_factor = conv_params->dilation_h_factor();
@@ -270,26 +415,13 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_SVDF: {
auto params = safe_allocator.Allocate<TfLiteSVDFParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
params->rank = svdf_params->rank();
params->activation =
parse_activation(svdf_params->fused_activation_function());
params->asymmetric_quantize_inputs =
svdf_params->asymmetric_quantize_inputs();
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
auto params = safe_allocator.Allocate<TfLiteSequenceRNNParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* sequence_rnn_params =
op->builtin_options_as_SequenceRNNOptions()) {
params->activation =
parse_activation(sequence_rnn_params->fused_activation_function());
ConvertActivation(sequence_rnn_params->fused_activation_function());
params->time_major = sequence_rnn_params->time_major();
params->asymmetric_quantize_inputs =
sequence_rnn_params->asymmetric_quantize_inputs();
@@ -303,7 +435,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* bidi_sequence_rnn_params =
op->builtin_options_as_BidirectionalSequenceRNNOptions()) {
params->activation = parse_activation(
params->activation = ConvertActivation(
bidi_sequence_rnn_params->fused_activation_function());
params->time_major = bidi_sequence_rnn_params->time_major();
params->merge_outputs = bidi_sequence_rnn_params->merge_outputs();
@@ -318,7 +450,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) {
params->activation =
parse_activation(rnn_params->fused_activation_function());
ConvertActivation(rnn_params->fused_activation_function());
params->asymmetric_quantize_inputs =
rnn_params->asymmetric_quantize_inputs();
}
@@ -336,53 +468,17 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_FULLY_CONNECTED: {
auto params = safe_allocator.Allocate<TfLiteFullyConnectedParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* fully_connected_params =
op->builtin_options_as_FullyConnectedOptions()) {
params->activation = parse_activation(
fully_connected_params->fused_activation_function());
params->keep_num_dims = fully_connected_params->keep_num_dims();
params->asymmetric_quantize_inputs =
fully_connected_params->asymmetric_quantize_inputs();
switch (fully_connected_params->weights_format()) {
case FullyConnectedOptionsWeightsFormat_DEFAULT:
params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
break;
case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
params->weights_format =
kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
break;
default:
TF_LITE_REPORT_ERROR(error_reporter,
"Unhandled fully-connected weights format.");
return kTfLiteError;
}
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_HASHTABLE_LOOKUP:
// no-op.
return kTfLiteOk;
case BuiltinOperator_SOFTMAX: {
auto params = safe_allocator.Allocate<TfLiteSoftmaxParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* softmax_params =
op->builtin_options_as_SoftmaxOptions()) {
params->beta = softmax_params->beta();
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_CONCATENATION: {
auto params = safe_allocator.Allocate<TfLiteConcatenationParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* concatenation_params =
op->builtin_options_as_ConcatenationOptions()) {
params->activation =
parse_activation(concatenation_params->fused_activation_function());
params->activation = ConvertActivation(
concatenation_params->fused_activation_function());
params->axis = concatenation_params->axis();
}
*builtin_data = params.release();
@@ -393,7 +489,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* schema_params = op->builtin_options_as_MulOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
ConvertActivation(schema_params->fused_activation_function());
}
*builtin_data = params.release();
return kTfLiteOk;
@@ -403,7 +499,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* schema_params = op->builtin_options_as_AddOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
ConvertActivation(schema_params->fused_activation_function());
}
*builtin_data = params.release();
return kTfLiteOk;
@@ -413,7 +509,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* schema_params = op->builtin_options_as_DivOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
ConvertActivation(schema_params->fused_activation_function());
}
*builtin_data = params.release();
return kTfLiteOk;
@@ -423,7 +519,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* schema_params = op->builtin_options_as_SubOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
ConvertActivation(schema_params->fused_activation_function());
}
*builtin_data = params.release();
return kTfLiteOk;
@@ -433,7 +529,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* schema_params = op->builtin_options_as_L2NormOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
ConvertActivation(schema_params->fused_activation_function());
}
*builtin_data = params.release();
return kTfLiteOk;
@@ -456,7 +552,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
params->activation =
parse_activation(lstm_params->fused_activation_function());
ConvertActivation(lstm_params->fused_activation_function());
params->cell_clip = lstm_params->cell_clip();
params->proj_clip = lstm_params->proj_clip();
switch (lstm_params->kernel_type()) {
@@ -489,7 +585,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
if (const auto* seq_lstm_params =
op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) {
params->activation =
parse_activation(seq_lstm_params->fused_activation_function());
ConvertActivation(seq_lstm_params->fused_activation_function());
params->cell_clip = seq_lstm_params->cell_clip();
params->proj_clip = seq_lstm_params->proj_clip();
params->time_major = seq_lstm_params->time_major();
@@ -506,7 +602,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
if (const auto* bidi_lstm_params =
op->builtin_options_as_BidirectionalSequenceLSTMOptions()) {
params->activation =
parse_activation(bidi_lstm_params->fused_activation_function());
ConvertActivation(bidi_lstm_params->fused_activation_function());
params->cell_clip = bidi_lstm_params->cell_clip();
params->proj_clip = bidi_lstm_params->proj_clip();
params->merge_outputs = bidi_lstm_params->merge_outputs();
@@ -857,7 +953,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_CONCAT_EMBEDDINGS:
case BuiltinOperator_COS:
case BuiltinOperator_CUSTOM:
case BuiltinOperator_DEQUANTIZE:
case BuiltinOperator_ELU:
case BuiltinOperator_EMBEDDING_LOOKUP:
case BuiltinOperator_EQUAL:
@@ -913,7 +1008,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_GATHER_ND:
case BuiltinOperator_WHERE:
case BuiltinOperator_RANK:
case BuiltinOperator_QUANTIZE:
case BuiltinOperator_NON_MAX_SUPPRESSION_V4:
case BuiltinOperator_NON_MAX_SUPPRESSION_V5:
case BuiltinOperator_SCATTER_ND:

@@ -69,6 +69,35 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
ErrorReporter* error_reporter);
// TODO(b/149408647): The (unnecessary) op_type parameter in the functions below
// is to keep the same signature as ParseOpData. This allows for a gradual
// transfer to selective registration of the parse function, but should be
// removed once we are no longer using ParseOpData for the OpResolver
// implementation in micro.
TfLiteStatus ParseDequantize(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseQuantize(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseSoftmax(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSvdf(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
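// Illustrative only: because the per-op parsers keep the (unnecessary)
// op_type parameter, they share a signature with ParseOpData, so either can
// sit behind one function-pointer type. The alias name below is hypothetical:
//
//   using BuiltinParseFunction = TfLiteStatus (*)(
//       const Operator* op, BuiltinOperator op_type,
//       ErrorReporter* error_reporter, BuiltinDataAllocator* allocator,
//       void** builtin_data);
//   constexpr BuiltinParseFunction kSoftmaxParser = ParseSoftmax;
//   constexpr BuiltinParseFunction kFallbackParser = ParseOpData;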
} // namespace tflite
#endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_