diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 4fa3f87ed15..a2f581fa89f 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -253,6 +253,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, return ParsePool(op, error_reporter, allocator, builtin_data); } + case BuiltinOperator_LEAKY_RELU: { + return ParseLeakyRelu(op, error_reporter, allocator, builtin_data); + } + case BuiltinOperator_LESS: { return ParseLess(op, error_reporter, allocator, builtin_data); } @@ -682,16 +686,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } - case BuiltinOperator_LEAKY_RELU: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* leaky_relu_params = - op->builtin_options_as_LeakyReluOptions()) { - params->alpha = leaky_relu_params->alpha(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } case BuiltinOperator_MIRROR_PAD: { auto params = safe_allocator.Allocate(); TF_LITE_ENSURE(error_reporter, params != nullptr); @@ -1266,6 +1260,22 @@ TfLiteStatus ParseL2Normalization(const Operator* op, return kTfLiteOk; } +TfLiteStatus ParseLeakyRelu(const Operator* op, ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, error_reporter, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* leaky_relu_params = + op->builtin_options_as_LeakyReluOptions()) { + params->alpha = leaky_relu_params->alpha(); + } + *builtin_data = params.release(); + return kTfLiteOk; +} + // We have this parse function instead of directly returning kTfLiteOk from the // switch-case in ParseOpData because this function is used as part of the // selective registration for the OpResolver implementation in micro. diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h index e903f40790f..45041ad8543 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -155,6 +155,10 @@ TfLiteStatus ParseL2Normalization(const Operator* op, BuiltinDataAllocator* allocator, void** builtin_data); +TfLiteStatus ParseLeakyRelu(const Operator* op, ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data); + TfLiteStatus ParseLess(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data);