From b5ec87165f54923ccf7ac4e474bfba6d53121d9a Mon Sep 17 00:00:00 2001 From: rsun Date: Thu, 7 Jan 2021 17:55:27 -0800 Subject: [PATCH 1/2] Refactor for BuiltinOperator_GATHER_ND in flatbuffer_conversions --- tensorflow/lite/core/api/flatbuffer_conversions.cc | 13 ++++++++++++- tensorflow/lite/core/api/flatbuffer_conversions.h | 3 +++ tensorflow/lite/kernels/expand_dims.cc | 3 +-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index fcd043734ff..519877a76f5 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -229,6 +229,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, return ParseFullyConnected(op, error_reporter, allocator, builtin_data); } + case BuiltinOperator_GATHER_ND: { + return ParseGatherNd(op, error_reporter, allocator, builtin_data); + } + case BuiltinOperator_GREATER: { return ParseGreater(op, error_reporter, allocator, builtin_data); } @@ -805,7 +809,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_SQUARED_DIFFERENCE: case BuiltinOperator_REVERSE_V2: case BuiltinOperator_ADD_N: - case BuiltinOperator_GATHER_ND: case BuiltinOperator_WHERE: case BuiltinOperator_RANK: case BuiltinOperator_NON_MAX_SUPPRESSION_V4: @@ -1207,6 +1210,14 @@ TfLiteStatus ParseFullyConnected(const Operator* op, return kTfLiteOk; } +// We have this parse function instead of directly returning kTfLiteOk from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +TfLiteStatus ParseGatherNd(const Operator*, ErrorReporter*, + BuiltinDataAllocator*, void**) { + return kTfLiteOk; +} + // We have this parse function instead of directly returning kTfLiteOk from the // switch-case in ParseOpData because this function is used as part of the // selective registration for the OpResolver implementation in micro. diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h index 9f6131fa9c4..b00dfd0c1b3 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -135,6 +135,9 @@ TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinDataAllocator* allocator, void** builtin_data); +TfLiteStatus ParseGatherNd(const Operator* op, ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data); + TfLiteStatus ParseGreater(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data); diff --git a/tensorflow/lite/kernels/expand_dims.cc b/tensorflow/lite/kernels/expand_dims.cc index 950131c8d69..231ba6df8ba 100644 --- a/tensorflow/lite/kernels/expand_dims.cc +++ b/tensorflow/lite/kernels/expand_dims.cc @@ -1,5 +1,3 @@ - -#include /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "tensorflow/lite/c/common.h" From 77fc5b5049c384b1c94e0e082df0f15bd77cf4d3 Mon Sep 17 00:00:00 2001 From: rsun Date: Tue, 19 Jan 2021 14:36:27 -0800 Subject: [PATCH 2/2] Resolve conflicts in flatbuffer_conversions between GATHER and GATHER_ND --- .../lite/core/api/flatbuffer_conversions.cc | 19 +++++++++++++++++++ .../lite/core/api/flatbuffer_conversions.h | 3 +++ 2 files changed, 22 insertions(+) diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 519877a76f5..5072c1aa0dd 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -1210,6 +1210,25 @@ TfLiteStatus ParseFullyConnected(const Operator* op, return kTfLiteOk; } +// We have this parse function instead of directly returning kTfLiteOk from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +TfLiteStatus ParseGather(const Operator* op, ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data) { + CheckParsePointerParams(op, error_reporter, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + params->axis = 0; + if (const auto* gather_params = op->builtin_options_as_GatherOptions()) { + params->axis = gather_params->axis(); + } + + *builtin_data = params.release(); + return kTfLiteOk; +} + // We have this parse function instead of directly returning kTfLiteOk from the // switch-case in ParseOpData because this function is used as part of the // selective registration for the OpResolver implementation in micro. diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h index b00dfd0c1b3..2f75959c680 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -135,6 +135,9 @@ TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinDataAllocator* allocator, void** builtin_data); +TfLiteStatus ParseGather(const Operator* op, ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data); + TfLiteStatus ParseGatherNd(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data);