From a5a87b420b87a1f832e241db3a5b724207ea700a Mon Sep 17 00:00:00 2001
From: Ryan Kuester
Date: Fri, 11 Dec 2020 10:48:06 -0600
Subject: [PATCH] micro: remove lite-specific code from copy of FILL

Remove the bulk of lite-specific code from the micro implementation of
operator FILL.

- Flatten the namespace
- Don't resize output tensors
- Remove input and output types other than int8 and float32
- Don't use gtest
---
 tensorflow/lite/micro/kernels/fill.cc      | 71 +---------------
 tensorflow/lite/micro/kernels/fill_test.cc | 99 ----------------------
 2 files changed, 2 insertions(+), 168 deletions(-)

diff --git a/tensorflow/lite/micro/kernels/fill.cc b/tensorflow/lite/micro/kernels/fill.cc
index 6a486109d46..af2e46afae0 100644
--- a/tensorflow/lite/micro/kernels/fill.cc
+++ b/tensorflow/lite/micro/kernels/fill.cc
@@ -23,9 +23,6 @@ limitations under the License.
 #include "tensorflow/lite/string_util.h"
 
 namespace tflite {
-namespace ops {
-namespace builtin {
-namespace fill {
 
 namespace {
 
@@ -33,41 +30,6 @@ constexpr int kDimsTensor = 0;
 constexpr int kValueTensor = 1;
 constexpr int kOutputTensor = 0;
 
-template <typename T>
-TfLiteStatus ResizeOutputImpl(TfLiteContext* context, const TfLiteTensor* dims,
-                              TfLiteTensor* output) {
-  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dims->dims->data[0]);
-  for (int i = 0; i < output_shape->size; ++i) {
-    T data = GetTensorData<T>(dims)[i];
-    if (data < 0) {
-      TfLiteIntArrayFree(output_shape);
-      TF_LITE_KERNEL_LOG(context, "Fill dimensions must be >= 0",
-                         TfLiteTypeGetName(dims->type));
-      return kTfLiteError;
-    }
-    output_shape->data[i] = data;
-  }
-  return context->ResizeTensor(context, output, output_shape);
-}
-
-TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* dims,
-                          TfLiteTensor* output) {
-  switch (dims->type) {
-    case kTfLiteInt32:
-      return ResizeOutputImpl<int32_t>(context, dims, output);
-    case kTfLiteInt64:
-      return ResizeOutputImpl<int64_t>(context, dims, output);
-    default:
-      TF_LITE_KERNEL_LOG(
-          context,
-          "Fill only currently supports int32, int64 for input 0, got %s.",
-          TfLiteTypeGetName(dims->type));
-      return kTfLiteError;
-  }
-}
-
-}  // namespace
-
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -100,20 +62,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
-TfLiteStatus FillString(const TfLiteTensor* value, TfLiteTensor* output) {
-  DynamicBuffer buffer;
-  const auto string_ref = GetString(value, 0);
-  int n = 1;
-  for (int i = 0; i < output->dims->size; ++i) {
-    n *= output->dims->data[i];
-  }
-  for (int i = 0; i < n; ++i) {
-    buffer.AddString(string_ref.str, string_ref.len);
-  }
-  buffer.WriteToTensor(output, /*new_shape=*/nullptr);
-  return kTfLiteOk;
-}
-
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* value;
   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));
@@ -132,26 +80,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                       GetTensorShape(output),                                 \
                       GetTensorData<data_type>(output))
   switch (output->type) {
-    case kTfLiteInt32:
-      TF_LITE_FILL(int32_t);
-      break;
-    case kTfLiteInt64:
-      TF_LITE_FILL(int64_t);
-      break;
     case kTfLiteFloat32:
       TF_LITE_FILL(float);
       break;
-    case kTfLiteBool:
-      TF_LITE_FILL(bool);
-      break;
-    case kTfLiteString:
-      FillString(value, output);
-      break;
     default:
       TF_LITE_KERNEL_LOG(
           context,
-          "Fill only currently supports int32, int64, float32, bool, string "
-          "for input 1, got %d.",
+          "Fill only currently supports float32 for input 1, got %d.",
           TfLiteTypeGetName(value->type));
       return kTfLiteError;
   }
@@ -159,7 +94,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
-}  // namespace fill
+}  // namespace
 
 TfLiteRegistration* Register_FILL() {
   static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
@@ -167,6 +102,4 @@ TfLiteRegistration* Register_FILL() {
   return &r;
 }
 
-}  // namespace builtin
-}  // namespace ops
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/fill_test.cc b/tensorflow/lite/micro/kernels/fill_test.cc
index ff249920f1d..77c9baf4028 100644
--- a/tensorflow/lite/micro/kernels/fill_test.cc
+++ b/tensorflow/lite/micro/kernels/fill_test.cc
@@ -12,85 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <stdint.h>
-#include <string>
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-
-#include "tensorflow/lite/kernels/test_util.h"
-#include "tensorflow/lite/schema/schema_generated.h"
-#include "tensorflow/lite/string_type.h"
-
-namespace tflite {
 namespace {
 
-using ::testing::ElementsAreArray;
-using ::testing::IsEmpty;
-
-enum class TestType {
-  kConst = 0,
-  kDynamic = 1,
-};
-
-template <typename dims_type, typename value_type>
-class FillOpModel : public SingleOpModel {
- public:
-  explicit FillOpModel(TensorType dims_tensor_type,
-                       std::initializer_list<int> dims_shape,
-                       std::initializer_list<dims_type> dims_data,
-                       value_type value, TestType input_tensor_types) {
-    if (input_tensor_types == TestType::kDynamic) {
-      dims_ = AddInput(dims_tensor_type);
-      value_ = AddInput(GetTensorType<value_type>());
-    } else {
-      dims_ = AddConstInput(dims_tensor_type, dims_data, dims_shape);
-      value_ = AddConstInput(GetTensorType<value_type>(), {value}, {});
-    }
-    output_ = AddOutput(GetTensorType<value_type>());
-    SetBuiltinOp(BuiltinOperator_FILL, BuiltinOptions_FillOptions,
-                 CreateFillOptions(builder_).Union());
-    BuildInterpreter({dims_shape, {}});
-
-    if (input_tensor_types == TestType::kDynamic) {
-      if (dims_data.size() > 0) {
-        PopulateTensor<dims_type>(dims_, dims_data);
-      }
-      PopulateTensor<value_type>(value_, {value});
-    }
-  }
-
-  std::vector<value_type> GetOutput() {
-    return ExtractVector<value_type>(output_);
-  }
-  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- protected:
-  int dims_;
-  int value_;
-  int output_;
-};
-
-class FillOpTest : public ::testing::TestWithParam<TestType> {};
-
-TEST_P(FillOpTest, FillInt32) {
-  FillOpModel<int32_t, int32_t> m(TensorType_INT32, {2}, {2, 3}, -11,
-                                  GetParam());
-  m.Invoke();
-  EXPECT_THAT(m.GetOutput(), ElementsAreArray({-11, -11, -11, -11, -11, -11}));
-  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
-}
-
-TEST_P(FillOpTest, FillInt64) {
-  FillOpModel<int64_t, int64_t> m(TensorType_INT64, {2}, {2, 4}, 1LL << 45,
-                                  GetParam());
-  m.Invoke();
-  EXPECT_THAT(m.GetOutput(),
-              ElementsAreArray({1LL << 45, 1LL << 45, 1LL << 45, 1LL << 45,
-                                1LL << 45, 1LL << 45, 1LL << 45, 1LL << 45}));
-  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 4}));
-}
-
 TEST_P(FillOpTest, FillFloat) {
   FillOpModel<int64_t, float> m(TensorType_INT64, {3}, {2, 2, 2}, 4.0,
                                 GetParam());
@@ -116,27 +40,4 @@ TEST_P(FillOpTest, FillOutputScalar) {
   EXPECT_THAT(m.GetOutputShape(), IsEmpty());
 }
 
-TEST_P(FillOpTest, FillBool) {
-  FillOpModel<int64_t, bool> m(TensorType_INT64, {3}, {2, 2, 2}, true,
-                               GetParam());
-  m.Invoke();
-  EXPECT_THAT(m.GetOutput(), ElementsAreArray({true, true, true, true, true,
-                                               true, true, true}));
-  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
-}
-
-TEST(FillOpTest, FillString) {
FillOpModel m(TensorType_INT64, {3}, {2, 2, 2}, "AB", - TestType::kDynamic); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({"AB", "AB", "AB", "AB", "AB", - "AB", "AB", "AB"})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); -} - -INSTANTIATE_TEST_SUITE_P(FillOpTest, FillOpTest, - ::testing::Values(TestType::kConst, - TestType::kDynamic)); - } // namespace -} // namespace tflite