tile: add support for booleans

PiperOrigin-RevId: 220356963
This commit is contained in:
A. Unique TensorFlower 2018-11-06 14:49:50 -08:00 committed by TensorFlower Gardener
parent bdd6af98f0
commit bfb4bda0ff
2 changed files with 101 additions and 101 deletions

View File

@ -182,6 +182,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt64:
Tile<int64_t>(*(input->dims), input, multipliers, output);
break;
case kTfLiteBool:
Tile<bool>(*(input->dims), input, multipliers, output);
break;
default:
context->ReportError(context, "Type '%s' is not supported by tile.",
TfLiteTypeGetName(output->type));

View File

@ -34,34 +34,18 @@ class TileOpModel : public SingleOpModel {
BuildInterpreter({input_shape, {static_cast<int>(input_shape.size())}});
}
void SetInputFloat(std::initializer_list<float> data) {
PopulateTensor<float>(input_, data);
}
void SetInputUInt8(std::initializer_list<uint8_t> data) {
PopulateTensor<uint8_t>(input_, data);
}
void SetInputInt32(std::initializer_list<int32_t> data) {
PopulateTensor<int32_t>(input_, data);
}
void SetInputInt64(std::initializer_list<int64_t> data) {
PopulateTensor<int64_t>(input_, data);
template <typename T>
void SetInput(std::initializer_list<T> data) {
PopulateTensor<T>(input_, data);
}
void SetMultipliers(std::initializer_list<int32_t> data) {
PopulateTensor<int32_t>(multipliers_, data);
}
std::vector<float> GetOutputFloat() { return ExtractVector<float>(output_); }
std::vector<uint8_t> GetOutputUInt8() { return ExtractVector<uint8_t>(output_); }
std::vector<int32_t> GetOutputInt32() { return ExtractVector<int32_t>(output_); }
std::vector<int64_t> GetOutputInt64() {
return ExtractVector<int64_t>(output_);
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
@ -74,16 +58,16 @@ class TileOpModel : public SingleOpModel {
TEST(TileTest, Float32Vector) {
TileOpModel m({3}, TensorType_FLOAT32, TensorType_INT32);
m.SetInputFloat({1.f, 2.f, 3.f});
m.SetInput<float>({1.f, 2.f, 3.f});
m.SetMultipliers({2});
m.Invoke();
EXPECT_THAT(m.GetOutputFloat(),
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray({1.f, 2.f, 3.f, 1.f, 2.f, 3.f}));
}
TEST(TileTest, Float32Matrix) {
TileOpModel m({2, 3}, TensorType_FLOAT32, TensorType_INT32);
m.SetInputFloat({
m.SetInput<float>({
11.f,
12.f,
13.f,
@ -93,26 +77,26 @@ TEST(TileTest, Float32Matrix) {
});
m.SetMultipliers({2, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray({
11.f,
12.f,
13.f,
21.f,
22.f,
23.f,
11.f,
12.f,
13.f,
21.f,
22.f,
23.f,
}));
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({
11.f,
12.f,
13.f,
21.f,
22.f,
23.f,
11.f,
12.f,
13.f,
21.f,
22.f,
23.f,
}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
}
TEST(TileTest, Float32HighDimension) {
TileOpModel m({1, 2, 3}, TensorType_FLOAT32, TensorType_INT32);
m.SetInputFloat({
m.SetInput<float>({
11.f,
12.f,
13.f,
@ -123,7 +107,7 @@ TEST(TileTest, Float32HighDimension) {
m.SetMultipliers({2, 3, 1});
m.Invoke();
EXPECT_THAT(
m.GetOutputFloat(),
m.GetOutput<float>(),
ElementsAreArray({11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f,
21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f,
11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f,
@ -133,7 +117,7 @@ TEST(TileTest, Float32HighDimension) {
TEST(TileTest, Uint8Matrix) {
TileOpModel m({2, 3}, TensorType_UINT8, TensorType_INT32);
m.SetInputUInt8({
m.SetInput<uint8_t>({
11,
12,
13,
@ -143,26 +127,26 @@ TEST(TileTest, Uint8Matrix) {
});
m.SetMultipliers({2, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputUInt8(), ElementsAreArray({
11,
12,
13,
21,
22,
23,
11,
12,
13,
21,
22,
23,
}));
EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({
11,
12,
13,
21,
22,
23,
11,
12,
13,
21,
22,
23,
}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
}
TEST(TileTest, Int32Matrix) {
TileOpModel m({2, 3}, TensorType_INT32, TensorType_INT32);
m.SetInputInt32({
m.SetInput<int32_t>({
11,
12,
13,
@ -172,26 +156,39 @@ TEST(TileTest, Int32Matrix) {
});
m.SetMultipliers({2, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputInt32(), ElementsAreArray({
11,
12,
13,
21,
22,
23,
11,
12,
13,
21,
22,
23,
}));
EXPECT_THAT(m.GetOutput<int32_t>(), ElementsAreArray({
11,
12,
13,
21,
22,
23,
11,
12,
13,
21,
22,
23,
}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
}
TEST(TileTest, BooleanMatrix) {
TileOpModel m({2, 3}, TensorType_BOOL, TensorType_INT32);
m.SetInput<bool>({true, false, false, true, true, false});
m.SetMultipliers({2, 1});
m.Invoke();
EXPECT_THAT(m.GetOutput<bool>(),
ElementsAreArray({
true, false, false, true, true, false, // first tiletrue,
true, false, false, true, true, false // second tile
}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
}
TEST(TileTest, Int64Matrix) {
TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT32);
m.SetInputInt64({
m.SetInput<int64_t>({
11,
12,
13,
@ -201,26 +198,26 @@ TEST(TileTest, Int64Matrix) {
});
m.SetMultipliers({2, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({
11,
12,
13,
21,
22,
23,
11,
12,
13,
21,
22,
23,
}));
EXPECT_THAT(m.GetOutput<int64_t>(), ElementsAreArray({
11,
12,
13,
21,
22,
23,
11,
12,
13,
21,
22,
23,
}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
}
TEST(TileTest, Int64Matrix64Multipliers) {
TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT64);
m.SetInputInt64({
m.SetInput<int64_t>({
11,
12,
13,
@ -230,20 +227,20 @@ TEST(TileTest, Int64Matrix64Multipliers) {
});
m.SetMultipliers({2, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({
11,
12,
13,
21,
22,
23,
11,
12,
13,
21,
22,
23,
}));
EXPECT_THAT(m.GetOutput<int64_t>(), ElementsAreArray({
11,
12,
13,
21,
22,
23,
11,
12,
13,
21,
22,
23,
}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
}
} // namespace