diff --git a/tensorflow/lite/micro/kernels/fill.cc b/tensorflow/lite/micro/kernels/fill.cc index 6a486109d46..a7839cd1e41 100644 --- a/tensorflow/lite/micro/kernels/fill.cc +++ b/tensorflow/lite/micro/kernels/fill.cc @@ -23,9 +23,6 @@ limitations under the License. #include "tensorflow/lite/string_util.h" namespace tflite { -namespace ops { -namespace builtin { -namespace fill { namespace { @@ -33,41 +30,6 @@ constexpr int kDimsTensor = 0; constexpr int kValueTensor = 1; constexpr int kOutputTensor = 0; -template <typename T> -TfLiteStatus ResizeOutputImpl(TfLiteContext* context, const TfLiteTensor* dims, - TfLiteTensor* output) { - TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dims->dims->data[0]); - for (int i = 0; i < output_shape->size; ++i) { - T data = GetTensorData<T>(dims)[i]; - if (data < 0) { - TfLiteIntArrayFree(output_shape); - TF_LITE_KERNEL_LOG(context, "Fill dimensions must be >= 0", - TfLiteTypeGetName(dims->type)); - return kTfLiteError; - } - output_shape->data[i] = data; - } - return context->ResizeTensor(context, output, output_shape); -} - -TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* dims, - TfLiteTensor* output) { - switch (dims->type) { - case kTfLiteInt32: - return ResizeOutputImpl<int32_t>(context, dims, output); - case kTfLiteInt64: - return ResizeOutputImpl<int64_t>(context, dims, output); - default: - TF_LITE_KERNEL_LOG( - context, - "Fill only currently supports int32, int64 for input 0, got %s.", - TfLiteTypeGetName(dims->type)); - return kTfLiteError; - } -} - -} // namespace - TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -100,20 +62,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -TfLiteStatus FillString(const TfLiteTensor* value, TfLiteTensor* output) { - DynamicBuffer buffer; - const auto string_ref = GetString(value, 0); - int n = 1; - for (int i = 0; i < output->dims->size; ++i) { - n *= output->dims->data[i]; - } - for (int i = 0; i < n; ++i) { - buffer.AddString(string_ref.str, string_ref.len); - } - buffer.WriteToTensor(output, /*new_shape=*/nullptr); - return kTfLiteOk; -} - TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* value; TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value)); @@ -132,26 +80,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorShape(output), \ GetTensorData<data_type>(output)) switch (output->type) { - case kTfLiteInt32: - TF_LITE_FILL(int32_t); - break; - case kTfLiteInt64: - TF_LITE_FILL(int64_t); - break; case kTfLiteFloat32: TF_LITE_FILL(float); break; - case kTfLiteBool: - TF_LITE_FILL(bool); - break; - case kTfLiteString: - FillString(value, output); - break; default: TF_LITE_KERNEL_LOG( - context, - "Fill only currently supports int32, int64, float32, bool, string " - "for input 1, got %d.", + context, "Fill only currently supports float32 for input 1, got %d.", TfLiteTypeGetName(value->type)); return kTfLiteError; } @@ -159,7 +93,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace fill +} // namespace TfLiteRegistration* Register_FILL() { static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, @@ -167,6 +101,4 @@ TfLiteRegistration* Register_FILL() { return &r; } -} // namespace builtin -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/fill_test.cc b/tensorflow/lite/micro/kernels/fill_test.cc index ff249920f1d..77c9baf4028 100644 --- a/tensorflow/lite/micro/kernels/fill_test.cc +++ b/tensorflow/lite/micro/kernels/fill_test.cc @@ -12,85 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include <stdint.h> -#include <initializer_list> -#include <string> -#include <vector> - -#include "tensorflow/lite/kernels/test_util.h" -#include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/string_type.h" - -namespace tflite { namespace { -using ::testing::ElementsAreArray; -using ::testing::IsEmpty; - -enum class TestType { - kConst = 0, - kDynamic = 1, -}; - -template <typename dims_type, typename value_type> -class FillOpModel : public SingleOpModel { - public: - explicit FillOpModel(TensorType dims_tensor_type, - std::initializer_list<int> dims_shape, - std::initializer_list<dims_type> dims_data, - value_type value, TestType input_tensor_types) { - if (input_tensor_types == TestType::kDynamic) { - dims_ = AddInput(dims_tensor_type); - value_ = AddInput(GetTensorType<value_type>()); - } else { - dims_ = AddConstInput(dims_tensor_type, dims_data, dims_shape); - value_ = AddConstInput(GetTensorType<value_type>(), {value}, {}); - } - output_ = AddOutput(GetTensorType<value_type>()); - SetBuiltinOp(BuiltinOperator_FILL, BuiltinOptions_FillOptions, - CreateFillOptions(builder_).Union()); - BuildInterpreter({dims_shape, {}}); - - if (input_tensor_types == TestType::kDynamic) { - if (dims_data.size() > 0) { - PopulateTensor<dims_type>(dims_, dims_data); - } - PopulateTensor<value_type>(value_, {value}); - } - } - - std::vector<value_type> GetOutput() { - return ExtractVector<value_type>(output_); - } - std::vector<int> GetOutputShape() { return GetTensorShape(output_); } - - protected: - int dims_; - int value_; - int output_; -}; - -class FillOpTest : public ::testing::TestWithParam<TestType> {}; - -TEST_P(FillOpTest, FillInt32) { - FillOpModel<int32_t, int32_t> m(TensorType_INT32, {2}, {2, 3}, -11, - GetParam()); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({-11, -11, -11, -11, -11, -11})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); -} - -TEST_P(FillOpTest, FillInt64) { - FillOpModel<int64_t, int64_t> m(TensorType_INT64, {2}, {2, 4}, 1LL << 45, - GetParam()); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), - ElementsAreArray({1LL << 45, 1LL << 45, 1LL << 45, 1LL << 45, - 1LL << 45, 1LL << 45, 1LL << 45, 1LL << 45})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 4})); -} - TEST_P(FillOpTest, FillFloat) { FillOpModel<int64_t, float> m(TensorType_INT64, {3}, {2, 2, 2}, 4.0, GetParam()); @@ -116,27 +40,4 @@ TEST_P(FillOpTest, FillOutputScalar) { EXPECT_THAT(m.GetOutputShape(), IsEmpty()); } -TEST_P(FillOpTest, FillBool) { - FillOpModel<int64_t, bool> m(TensorType_INT64, {3}, {2, 2, 2}, true, - GetParam()); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({true, true, true, true, true, - true, true, true})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); -} - -TEST(FillOpTest, FillString) { - FillOpModel<int64_t, std::string> m(TensorType_INT64, {3}, {2, 2, 2}, "AB", - TestType::kDynamic); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({"AB", "AB", "AB", "AB", "AB", - "AB", "AB", "AB"})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); -} - -INSTANTIATE_TEST_SUITE_P(FillOpTest, FillOpTest, - ::testing::Values(TestType::kConst, - TestType::kDynamic)); - } // namespace -} // namespace tflite