From c9826766bc9ef62dc49962788b5c388d27afb134 Mon Sep 17 00:00:00 2001 From: rsun Date: Wed, 6 Jan 2021 17:01:47 -0800 Subject: [PATCH 1/3] Copy lite/kernels/gather.cc to lite/micro/kernels/gather.cc w/o any change --- tensorflow/lite/micro/kernels/gather.cc | 212 ++++++++++++++++++++++++ 1 file changed, 212 insertions(+) create mode 100644 tensorflow/lite/micro/kernels/gather.cc diff --git a/tensorflow/lite/micro/kernels/gather.cc b/tensorflow/lite/micro/kernels/gather.cc new file mode 100644 index 00000000000..57ac9c267e9 --- /dev/null +++ b/tensorflow/lite/micro/kernels/gather.cc @@ -0,0 +1,212 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace gather { +constexpr int kInputTensor = 0; +constexpr int kInputPositions = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const auto* params = + reinterpret_cast(node->builtin_data); + const TfLiteTensor* input; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); + const TfLiteTensor* positions; + TF_LITE_ENSURE_OK(context, + GetInputSafe(context, node, kInputPositions, &positions)); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, + GetOutputSafe(context, node, kOutputTensor, &output)); + + switch (positions->type) { + case kTfLiteInt64: + case kTfLiteInt32: + break; + default: + context->ReportError( + context, "Positions of type '%s' are not supported by gather.", + TfLiteTypeGetName(positions->type)); + return kTfLiteError; + } + + // Assign to output the input type. + output->type = input->type; + + // Check conditions for different types. + switch (input->type) { + case kTfLiteFloat32: + case kTfLiteUInt8: + case kTfLiteInt8: + case kTfLiteInt16: + case kTfLiteInt64: + case kTfLiteInt32: + case kTfLiteBool: + break; + case kTfLiteString: { + // Only 1D input is supported. + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + } break; + default: + context->ReportError(context, "Type '%s' is not supported by gather.", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + + int axis = params->axis; + if (axis < 0) { + axis += NumDimensions(input); + } + TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input)); + + const int num_dimensions = + NumDimensions(input) + NumDimensions(positions) - 1; + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + int output_index = 0; + for (int i = 0; i < axis; ++i) { + output_shape->data[output_index++] = input->dims->data[i]; + } + for (int i = 0; i < positions->dims->size; ++i) { + output_shape->data[output_index++] = positions->dims->data[i]; + } + for (int i = axis + 1; i < input->dims->size; ++i) { + output_shape->data[output_index++] = input->dims->data[i]; + } + return context->ResizeTensor(context, output, output_shape); +} + +template +TfLiteStatus Gather(const TfLiteGatherParams& params, const TfLiteTensor* input, + const TfLiteTensor* positions, TfLiteTensor* output) { + tflite::GatherParams op_params; + op_params.axis = params.axis; + optimized_ops::Gather(op_params, GetTensorShape(input), + GetTensorData(input), GetTensorShape(positions), + GetTensorData(positions), + GetTensorShape(output), GetTensorData(output)); + return kTfLiteOk; +} + +template +TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input, + const TfLiteTensor* positions, + TfLiteTensor* output) { + DynamicBuffer buffer; + const PositionT* indexes = GetTensorData(positions); + const PositionT num_strings = GetStringCount(input); + const int num_indexes = NumElements(positions); + + for (int i = 0; i < num_indexes; ++i) { + const PositionT pos = indexes[i]; + TF_LITE_ENSURE(context, pos < num_strings); + const auto string_ref = GetString(input, pos); + buffer.AddString(string_ref.str, string_ref.len); + } + buffer.WriteToTensor(output, /*new_shape=*/nullptr); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* params = + reinterpret_cast(node->builtin_data); + const TfLiteTensor* input; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); + const TfLiteTensor* positions; + TF_LITE_ENSURE_OK(context, + GetInputSafe(context, node, kInputPositions, &positions)); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, + GetOutputSafe(context, node, kOutputTensor, &output)); + + if (positions->type == kTfLiteInt32) { + switch (input->type) { + case kTfLiteFloat32: + return Gather(*params, input, positions, output); + case kTfLiteUInt8: + return Gather(*params, input, positions, output); + case kTfLiteInt8: + return Gather(*params, input, positions, output); + case kTfLiteInt16: + return Gather(*params, input, positions, output); + case kTfLiteInt32: + return Gather(*params, input, positions, output); + case kTfLiteInt64: + return Gather(*params, input, positions, output); + case kTfLiteBool: + return Gather(*params, input, positions, output); + case kTfLiteString: + return GatherStrings(context, input, positions, output); + default: + context->ReportError(context, "Type '%s' is not supported by gather.", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + } + if (positions->type == kTfLiteInt64) { + switch (input->type) { + case kTfLiteFloat32: + return Gather(*params, input, positions, output); + case kTfLiteUInt8: + return Gather(*params, input, positions, output); + case kTfLiteInt8: + return Gather(*params, input, positions, output); + case kTfLiteInt16: + return Gather(*params, input, positions, output); + case kTfLiteInt32: + return Gather(*params, input, positions, output); + case kTfLiteInt64: + return Gather(*params, input, positions, output); + case kTfLiteBool: + return Gather(*params, input, positions, output); + case kTfLiteString: + return GatherStrings(context, input, positions, output); + default: + context->ReportError(context, "Type '%s' is not supported by gather.", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + } + context->ReportError(context, + "Positions of type '%s' are not supported by gather.", + TfLiteTypeGetName(positions->type)); + return kTfLiteError; +} +} // namespace gather + +TfLiteRegistration* Register_GATHER() { + static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare, + gather::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite From 0b206e600d77b572afcd770837a4c9dd93c5c547 Mon Sep 17 00:00:00 2001 From: rsun Date: Mon, 11 Jan 2021 17:03:11 -0800 Subject: [PATCH 2/3] Replace context->ReportError with TF_LITE_KERNEL_LOG in lite/micro/kernels/gather.cc --- tensorflow/lite/micro/kernels/gather.cc | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tensorflow/lite/micro/kernels/gather.cc b/tensorflow/lite/micro/kernels/gather.cc index 57ac9c267e9..83487d4a58b 100644 --- a/tensorflow/lite/micro/kernels/gather.cc +++ b/tensorflow/lite/micro/kernels/gather.cc @@ -52,7 +52,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt32: break; default: - context->ReportError( + TF_LITE_KERNEL_LOG( context, "Positions of type '%s' are not supported by gather.", TfLiteTypeGetName(positions->type)); return kTfLiteError; @@ -76,8 +76,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); } break; default: - context->ReportError(context, "Type '%s' is not supported by gather.", - TfLiteTypeGetName(input->type)); + TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.", + TfLiteTypeGetName(input->type)); return kTfLiteError; } @@ -165,8 +165,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteString: return GatherStrings(context, input, positions, output); default: - context->ReportError(context, "Type '%s' is not supported by gather.", - TfLiteTypeGetName(input->type)); + TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.", + TfLiteTypeGetName(input->type)); return kTfLiteError; } } @@ -189,14 +189,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteString: return GatherStrings(context, input, positions, output); default: - context->ReportError(context, "Type '%s' is not supported by gather.", - TfLiteTypeGetName(input->type)); + TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.", + TfLiteTypeGetName(input->type)); return kTfLiteError; } } - context->ReportError(context, - "Positions of type '%s' are not supported by gather.", - TfLiteTypeGetName(positions->type)); + TF_LITE_KERNEL_LOG(context, + "Positions of type '%s' are not supported by gather.", + TfLiteTypeGetName(positions->type)); return kTfLiteError; } } // namespace gather From 64c09bb8db3e86dbed71ef68075c13b16c1442b4 Mon Sep 17 00:00:00 2001 From: rsun Date: Mon, 18 Jan 2021 11:18:13 -0800 Subject: [PATCH 3/3] Fix formatting for lite/micro/kernels/gather.cc --- tensorflow/lite/micro/kernels/gather.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/micro/kernels/gather.cc b/tensorflow/lite/micro/kernels/gather.cc index 83487d4a58b..22020e551c2 100644 --- a/tensorflow/lite/micro/kernels/gather.cc +++ b/tensorflow/lite/micro/kernels/gather.cc @@ -52,9 +52,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt32: break; default: - TF_LITE_KERNEL_LOG( - context, "Positions of type '%s' are not supported by gather.", - TfLiteTypeGetName(positions->type)); + TF_LITE_KERNEL_LOG(context, + "Positions of type '%s' are not supported by gather.", + TfLiteTypeGetName(positions->type)); return kTfLiteError; }