diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 59341c84958..012d9153b2d 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -608,15 +608,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, return kTfLiteOk; } case BuiltinOperator_GATHER: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - params->axis = 0; - if (const auto* gather_params = op->builtin_options_as_GatherOptions()) { - params->axis = gather_params->axis(); - } - - *builtin_data = params.release(); - return kTfLiteOk; + return ParseGather(op, error_reporter, allocator, builtin_data); } case BuiltinOperator_SQUEEZE: { @@ -1196,6 +1188,27 @@ TfLiteStatus ParseFullyConnected(const Operator* op, return kTfLiteOk; } +// We have this parse function instead of directly returning kTfLiteOk from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +TfLiteStatus ParseGather(const Operator* op, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, error_reporter, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + params->axis = 0; + if (const auto* gather_params = op->builtin_options_as_GatherOptions()) { + params->axis = gather_params->axis(); + } + + *builtin_data = params.release(); + return kTfLiteOk; +} + // We have this parse function instead of directly returning kTfLiteOk from the // switch-case in ParseOpData because this function is used as part of the // selective registration for the OpResolver implementation in micro. diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h index 8b4a0267e82..0493fa69228 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -131,6 +131,9 @@ TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinDataAllocator* allocator, void** builtin_data); +TfLiteStatus ParseGather(const Operator* op, ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data); + TfLiteStatus ParseGreater(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data);