tile: add support for booleans
PiperOrigin-RevId: 220356963
This commit is contained in:
parent
bdd6af98f0
commit
bfb4bda0ff
@ -182,6 +182,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
case kTfLiteInt64:
|
case kTfLiteInt64:
|
||||||
Tile<int64_t>(*(input->dims), input, multipliers, output);
|
Tile<int64_t>(*(input->dims), input, multipliers, output);
|
||||||
break;
|
break;
|
||||||
|
case kTfLiteBool:
|
||||||
|
Tile<bool>(*(input->dims), input, multipliers, output);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
context->ReportError(context, "Type '%s' is not supported by tile.",
|
context->ReportError(context, "Type '%s' is not supported by tile.",
|
||||||
TfLiteTypeGetName(output->type));
|
TfLiteTypeGetName(output->type));
|
||||||
|
@ -34,34 +34,18 @@ class TileOpModel : public SingleOpModel {
|
|||||||
BuildInterpreter({input_shape, {static_cast<int>(input_shape.size())}});
|
BuildInterpreter({input_shape, {static_cast<int>(input_shape.size())}});
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetInputFloat(std::initializer_list<float> data) {
|
template <typename T>
|
||||||
PopulateTensor<float>(input_, data);
|
void SetInput(std::initializer_list<T> data) {
|
||||||
}
|
PopulateTensor<T>(input_, data);
|
||||||
|
|
||||||
void SetInputUInt8(std::initializer_list<uint8_t> data) {
|
|
||||||
PopulateTensor<uint8_t>(input_, data);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetInputInt32(std::initializer_list<int32_t> data) {
|
|
||||||
PopulateTensor<int32_t>(input_, data);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetInputInt64(std::initializer_list<int64_t> data) {
|
|
||||||
PopulateTensor<int64_t>(input_, data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetMultipliers(std::initializer_list<int32_t> data) {
|
void SetMultipliers(std::initializer_list<int32_t> data) {
|
||||||
PopulateTensor<int32_t>(multipliers_, data);
|
PopulateTensor<int32_t>(multipliers_, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> GetOutputFloat() { return ExtractVector<float>(output_); }
|
template <typename T>
|
||||||
|
std::vector<T> GetOutput() {
|
||||||
std::vector<uint8_t> GetOutputUInt8() { return ExtractVector<uint8_t>(output_); }
|
return ExtractVector<T>(output_);
|
||||||
|
|
||||||
std::vector<int32_t> GetOutputInt32() { return ExtractVector<int32_t>(output_); }
|
|
||||||
|
|
||||||
std::vector<int64_t> GetOutputInt64() {
|
|
||||||
return ExtractVector<int64_t>(output_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||||
@ -74,16 +58,16 @@ class TileOpModel : public SingleOpModel {
|
|||||||
|
|
||||||
TEST(TileTest, Float32Vector) {
|
TEST(TileTest, Float32Vector) {
|
||||||
TileOpModel m({3}, TensorType_FLOAT32, TensorType_INT32);
|
TileOpModel m({3}, TensorType_FLOAT32, TensorType_INT32);
|
||||||
m.SetInputFloat({1.f, 2.f, 3.f});
|
m.SetInput<float>({1.f, 2.f, 3.f});
|
||||||
m.SetMultipliers({2});
|
m.SetMultipliers({2});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputFloat(),
|
EXPECT_THAT(m.GetOutput<float>(),
|
||||||
ElementsAreArray({1.f, 2.f, 3.f, 1.f, 2.f, 3.f}));
|
ElementsAreArray({1.f, 2.f, 3.f, 1.f, 2.f, 3.f}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TileTest, Float32Matrix) {
|
TEST(TileTest, Float32Matrix) {
|
||||||
TileOpModel m({2, 3}, TensorType_FLOAT32, TensorType_INT32);
|
TileOpModel m({2, 3}, TensorType_FLOAT32, TensorType_INT32);
|
||||||
m.SetInputFloat({
|
m.SetInput<float>({
|
||||||
11.f,
|
11.f,
|
||||||
12.f,
|
12.f,
|
||||||
13.f,
|
13.f,
|
||||||
@ -93,7 +77,7 @@ TEST(TileTest, Float32Matrix) {
|
|||||||
});
|
});
|
||||||
m.SetMultipliers({2, 1});
|
m.SetMultipliers({2, 1});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray({
|
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({
|
||||||
11.f,
|
11.f,
|
||||||
12.f,
|
12.f,
|
||||||
13.f,
|
13.f,
|
||||||
@ -112,7 +96,7 @@ TEST(TileTest, Float32Matrix) {
|
|||||||
|
|
||||||
TEST(TileTest, Float32HighDimension) {
|
TEST(TileTest, Float32HighDimension) {
|
||||||
TileOpModel m({1, 2, 3}, TensorType_FLOAT32, TensorType_INT32);
|
TileOpModel m({1, 2, 3}, TensorType_FLOAT32, TensorType_INT32);
|
||||||
m.SetInputFloat({
|
m.SetInput<float>({
|
||||||
11.f,
|
11.f,
|
||||||
12.f,
|
12.f,
|
||||||
13.f,
|
13.f,
|
||||||
@ -123,7 +107,7 @@ TEST(TileTest, Float32HighDimension) {
|
|||||||
m.SetMultipliers({2, 3, 1});
|
m.SetMultipliers({2, 3, 1});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
m.GetOutputFloat(),
|
m.GetOutput<float>(),
|
||||||
ElementsAreArray({11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f,
|
ElementsAreArray({11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f,
|
||||||
21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f,
|
21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f,
|
||||||
11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f,
|
11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f,
|
||||||
@ -133,7 +117,7 @@ TEST(TileTest, Float32HighDimension) {
|
|||||||
|
|
||||||
TEST(TileTest, Uint8Matrix) {
|
TEST(TileTest, Uint8Matrix) {
|
||||||
TileOpModel m({2, 3}, TensorType_UINT8, TensorType_INT32);
|
TileOpModel m({2, 3}, TensorType_UINT8, TensorType_INT32);
|
||||||
m.SetInputUInt8({
|
m.SetInput<uint8_t>({
|
||||||
11,
|
11,
|
||||||
12,
|
12,
|
||||||
13,
|
13,
|
||||||
@ -143,7 +127,7 @@ TEST(TileTest, Uint8Matrix) {
|
|||||||
});
|
});
|
||||||
m.SetMultipliers({2, 1});
|
m.SetMultipliers({2, 1});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputUInt8(), ElementsAreArray({
|
EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({
|
||||||
11,
|
11,
|
||||||
12,
|
12,
|
||||||
13,
|
13,
|
||||||
@ -162,7 +146,7 @@ TEST(TileTest, Uint8Matrix) {
|
|||||||
|
|
||||||
TEST(TileTest, Int32Matrix) {
|
TEST(TileTest, Int32Matrix) {
|
||||||
TileOpModel m({2, 3}, TensorType_INT32, TensorType_INT32);
|
TileOpModel m({2, 3}, TensorType_INT32, TensorType_INT32);
|
||||||
m.SetInputInt32({
|
m.SetInput<int32_t>({
|
||||||
11,
|
11,
|
||||||
12,
|
12,
|
||||||
13,
|
13,
|
||||||
@ -172,7 +156,7 @@ TEST(TileTest, Int32Matrix) {
|
|||||||
});
|
});
|
||||||
m.SetMultipliers({2, 1});
|
m.SetMultipliers({2, 1});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputInt32(), ElementsAreArray({
|
EXPECT_THAT(m.GetOutput<int32_t>(), ElementsAreArray({
|
||||||
11,
|
11,
|
||||||
12,
|
12,
|
||||||
13,
|
13,
|
||||||
@ -189,9 +173,22 @@ TEST(TileTest, Int32Matrix) {
|
|||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TileTest, BooleanMatrix) {
|
||||||
|
TileOpModel m({2, 3}, TensorType_BOOL, TensorType_INT32);
|
||||||
|
m.SetInput<bool>({true, false, false, true, true, false});
|
||||||
|
m.SetMultipliers({2, 1});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<bool>(),
|
||||||
|
ElementsAreArray({
|
||||||
|
true, false, false, true, true, false, // first tiletrue,
|
||||||
|
true, false, false, true, true, false // second tile
|
||||||
|
}));
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(TileTest, Int64Matrix) {
|
TEST(TileTest, Int64Matrix) {
|
||||||
TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT32);
|
TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT32);
|
||||||
m.SetInputInt64({
|
m.SetInput<int64_t>({
|
||||||
11,
|
11,
|
||||||
12,
|
12,
|
||||||
13,
|
13,
|
||||||
@ -201,7 +198,7 @@ TEST(TileTest, Int64Matrix) {
|
|||||||
});
|
});
|
||||||
m.SetMultipliers({2, 1});
|
m.SetMultipliers({2, 1});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({
|
EXPECT_THAT(m.GetOutput<int64_t>(), ElementsAreArray({
|
||||||
11,
|
11,
|
||||||
12,
|
12,
|
||||||
13,
|
13,
|
||||||
@ -220,7 +217,7 @@ TEST(TileTest, Int64Matrix) {
|
|||||||
|
|
||||||
TEST(TileTest, Int64Matrix64Multipliers) {
|
TEST(TileTest, Int64Matrix64Multipliers) {
|
||||||
TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT64);
|
TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT64);
|
||||||
m.SetInputInt64({
|
m.SetInput<int64_t>({
|
||||||
11,
|
11,
|
||||||
12,
|
12,
|
||||||
13,
|
13,
|
||||||
@ -230,7 +227,7 @@ TEST(TileTest, Int64Matrix64Multipliers) {
|
|||||||
});
|
});
|
||||||
m.SetMultipliers({2, 1});
|
m.SetMultipliers({2, 1});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({
|
EXPECT_THAT(m.GetOutput<int64_t>(), ElementsAreArray({
|
||||||
11,
|
11,
|
||||||
12,
|
12,
|
||||||
13,
|
13,
|
||||||
|
Loading…
Reference in New Issue
Block a user