From bfb4bda0ffd2539d64f755f65f088c1c6aca2b6d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 6 Nov 2018 14:49:50 -0800 Subject: [PATCH] tile: add support for booleans PiperOrigin-RevId: 220356963 --- tensorflow/lite/kernels/tile.cc | 3 + tensorflow/lite/kernels/tile_test.cc | 199 +++++++++++++-------------- 2 files changed, 101 insertions(+), 101 deletions(-) diff --git a/tensorflow/lite/kernels/tile.cc b/tensorflow/lite/kernels/tile.cc index 6d13f9e92f9..1b747974743 100644 --- a/tensorflow/lite/kernels/tile.cc +++ b/tensorflow/lite/kernels/tile.cc @@ -182,6 +182,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: Tile(*(input->dims), input, multipliers, output); break; + case kTfLiteBool: + Tile(*(input->dims), input, multipliers, output); + break; default: context->ReportError(context, "Type '%s' is not supported by tile.", TfLiteTypeGetName(output->type)); diff --git a/tensorflow/lite/kernels/tile_test.cc b/tensorflow/lite/kernels/tile_test.cc index d12a7c19a36..a88ff66f075 100644 --- a/tensorflow/lite/kernels/tile_test.cc +++ b/tensorflow/lite/kernels/tile_test.cc @@ -34,34 +34,18 @@ class TileOpModel : public SingleOpModel { BuildInterpreter({input_shape, {static_cast(input_shape.size())}}); } - void SetInputFloat(std::initializer_list data) { - PopulateTensor(input_, data); - } - - void SetInputUInt8(std::initializer_list data) { - PopulateTensor(input_, data); - } - - void SetInputInt32(std::initializer_list data) { - PopulateTensor(input_, data); - } - - void SetInputInt64(std::initializer_list data) { - PopulateTensor(input_, data); + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); } void SetMultipliers(std::initializer_list data) { PopulateTensor(multipliers_, data); } - std::vector GetOutputFloat() { return ExtractVector(output_); } - - std::vector GetOutputUInt8() { return ExtractVector(output_); } - - std::vector GetOutputInt32() { return ExtractVector(output_); } - - std::vector GetOutputInt64() { - return ExtractVector(output_); + template + std::vector GetOutput() { + return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } @@ -74,16 +58,16 @@ class TileOpModel : public SingleOpModel { TEST(TileTest, Float32Vector) { TileOpModel m({3}, TensorType_FLOAT32, TensorType_INT32); - m.SetInputFloat({1.f, 2.f, 3.f}); + m.SetInput({1.f, 2.f, 3.f}); m.SetMultipliers({2}); m.Invoke(); - EXPECT_THAT(m.GetOutputFloat(), + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1.f, 2.f, 3.f, 1.f, 2.f, 3.f})); } TEST(TileTest, Float32Matrix) { TileOpModel m({2, 3}, TensorType_FLOAT32, TensorType_INT32); - m.SetInputFloat({ + m.SetInput({ 11.f, 12.f, 13.f, @@ -93,26 +77,26 @@ TEST(TileTest, Float32Matrix) { }); m.SetMultipliers({2, 1}); m.Invoke(); - EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray({ - 11.f, - 12.f, - 13.f, - 21.f, - 22.f, - 23.f, - 11.f, - 12.f, - 13.f, - 21.f, - 22.f, - 23.f, - })); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + })); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); } TEST(TileTest, Float32HighDimension) { TileOpModel m({1, 2, 3}, TensorType_FLOAT32, TensorType_INT32); - m.SetInputFloat({ + m.SetInput({ 11.f, 12.f, 13.f, @@ -123,7 +107,7 @@ TEST(TileTest, Float32HighDimension) { m.SetMultipliers({2, 3, 1}); m.Invoke(); EXPECT_THAT( - m.GetOutputFloat(), + m.GetOutput(), ElementsAreArray({11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, @@ -133,7 +117,7 @@ TEST(TileTest, Float32HighDimension) { TEST(TileTest, Uint8Matrix) { TileOpModel m({2, 3}, TensorType_UINT8, TensorType_INT32); - m.SetInputUInt8({ + m.SetInput({ 11, 12, 13, @@ -143,26 +127,26 @@ TEST(TileTest, Uint8Matrix) { }); m.SetMultipliers({2, 1}); m.Invoke(); - EXPECT_THAT(m.GetOutputUInt8(), ElementsAreArray({ - 11, - 12, - 13, - 21, - 22, - 23, - 11, - 12, - 13, - 21, - 22, - 23, - })); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); } TEST(TileTest, Int32Matrix) { TileOpModel m({2, 3}, TensorType_INT32, TensorType_INT32); - m.SetInputInt32({ + m.SetInput({ 11, 12, 13, @@ -172,26 +156,39 @@ TEST(TileTest, Int32Matrix) { }); m.SetMultipliers({2, 1}); m.Invoke(); - EXPECT_THAT(m.GetOutputInt32(), ElementsAreArray({ - 11, - 12, - 13, - 21, - 22, - 23, - 11, - 12, - 13, - 21, - 22, - 23, - })); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, BooleanMatrix) { + TileOpModel m({2, 3}, TensorType_BOOL, TensorType_INT32); + m.SetInput({true, false, false, true, true, false}); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({ + true, false, false, true, true, false, // first tiletrue, + true, false, false, true, true, false // second tile + })); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); } TEST(TileTest, Int64Matrix) { TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT32); - m.SetInputInt64({ + m.SetInput({ 11, 12, 13, @@ -201,26 +198,26 @@ TEST(TileTest, Int64Matrix) { }); m.SetMultipliers({2, 1}); m.Invoke(); - EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({ - 11, - 12, - 13, - 21, - 22, - 23, - 11, - 12, - 13, - 21, - 22, - 23, - })); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); } TEST(TileTest, Int64Matrix64Multipliers) { TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT64); - m.SetInputInt64({ + m.SetInput({ 11, 12, 13, @@ -230,20 +227,20 @@ TEST(TileTest, Int64Matrix64Multipliers) { }); m.SetMultipliers({2, 1}); m.Invoke(); - EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({ - 11, - 12, - 13, - 21, - 22, - 23, - 11, - 12, - 13, - 21, - 22, - 23, - })); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); } } // namespace