tile: add support for booleans
PiperOrigin-RevId: 220356963
This commit is contained in:
parent
bdd6af98f0
commit
bfb4bda0ff
@ -182,6 +182,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteInt64:
|
||||
Tile<int64_t>(*(input->dims), input, multipliers, output);
|
||||
break;
|
||||
case kTfLiteBool:
|
||||
Tile<bool>(*(input->dims), input, multipliers, output);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context, "Type '%s' is not supported by tile.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
|
@ -34,34 +34,18 @@ class TileOpModel : public SingleOpModel {
|
||||
BuildInterpreter({input_shape, {static_cast<int>(input_shape.size())}});
|
||||
}
|
||||
|
||||
void SetInputFloat(std::initializer_list<float> data) {
|
||||
PopulateTensor<float>(input_, data);
|
||||
}
|
||||
|
||||
void SetInputUInt8(std::initializer_list<uint8_t> data) {
|
||||
PopulateTensor<uint8_t>(input_, data);
|
||||
}
|
||||
|
||||
void SetInputInt32(std::initializer_list<int32_t> data) {
|
||||
PopulateTensor<int32_t>(input_, data);
|
||||
}
|
||||
|
||||
void SetInputInt64(std::initializer_list<int64_t> data) {
|
||||
PopulateTensor<int64_t>(input_, data);
|
||||
template <typename T>
|
||||
void SetInput(std::initializer_list<T> data) {
|
||||
PopulateTensor<T>(input_, data);
|
||||
}
|
||||
|
||||
void SetMultipliers(std::initializer_list<int32_t> data) {
|
||||
PopulateTensor<int32_t>(multipliers_, data);
|
||||
}
|
||||
|
||||
std::vector<float> GetOutputFloat() { return ExtractVector<float>(output_); }
|
||||
|
||||
std::vector<uint8_t> GetOutputUInt8() { return ExtractVector<uint8_t>(output_); }
|
||||
|
||||
std::vector<int32_t> GetOutputInt32() { return ExtractVector<int32_t>(output_); }
|
||||
|
||||
std::vector<int64_t> GetOutputInt64() {
|
||||
return ExtractVector<int64_t>(output_);
|
||||
template <typename T>
|
||||
std::vector<T> GetOutput() {
|
||||
return ExtractVector<T>(output_);
|
||||
}
|
||||
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
@ -74,16 +58,16 @@ class TileOpModel : public SingleOpModel {
|
||||
|
||||
TEST(TileTest, Float32Vector) {
|
||||
TileOpModel m({3}, TensorType_FLOAT32, TensorType_INT32);
|
||||
m.SetInputFloat({1.f, 2.f, 3.f});
|
||||
m.SetInput<float>({1.f, 2.f, 3.f});
|
||||
m.SetMultipliers({2});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputFloat(),
|
||||
EXPECT_THAT(m.GetOutput<float>(),
|
||||
ElementsAreArray({1.f, 2.f, 3.f, 1.f, 2.f, 3.f}));
|
||||
}
|
||||
|
||||
TEST(TileTest, Float32Matrix) {
|
||||
TileOpModel m({2, 3}, TensorType_FLOAT32, TensorType_INT32);
|
||||
m.SetInputFloat({
|
||||
m.SetInput<float>({
|
||||
11.f,
|
||||
12.f,
|
||||
13.f,
|
||||
@ -93,7 +77,7 @@ TEST(TileTest, Float32Matrix) {
|
||||
});
|
||||
m.SetMultipliers({2, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray({
|
||||
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({
|
||||
11.f,
|
||||
12.f,
|
||||
13.f,
|
||||
@ -112,7 +96,7 @@ TEST(TileTest, Float32Matrix) {
|
||||
|
||||
TEST(TileTest, Float32HighDimension) {
|
||||
TileOpModel m({1, 2, 3}, TensorType_FLOAT32, TensorType_INT32);
|
||||
m.SetInputFloat({
|
||||
m.SetInput<float>({
|
||||
11.f,
|
||||
12.f,
|
||||
13.f,
|
||||
@ -123,7 +107,7 @@ TEST(TileTest, Float32HighDimension) {
|
||||
m.SetMultipliers({2, 3, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(
|
||||
m.GetOutputFloat(),
|
||||
m.GetOutput<float>(),
|
||||
ElementsAreArray({11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f,
|
||||
21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f,
|
||||
11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f,
|
||||
@ -133,7 +117,7 @@ TEST(TileTest, Float32HighDimension) {
|
||||
|
||||
TEST(TileTest, Uint8Matrix) {
|
||||
TileOpModel m({2, 3}, TensorType_UINT8, TensorType_INT32);
|
||||
m.SetInputUInt8({
|
||||
m.SetInput<uint8_t>({
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
@ -143,7 +127,7 @@ TEST(TileTest, Uint8Matrix) {
|
||||
});
|
||||
m.SetMultipliers({2, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputUInt8(), ElementsAreArray({
|
||||
EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
@ -162,7 +146,7 @@ TEST(TileTest, Uint8Matrix) {
|
||||
|
||||
TEST(TileTest, Int32Matrix) {
|
||||
TileOpModel m({2, 3}, TensorType_INT32, TensorType_INT32);
|
||||
m.SetInputInt32({
|
||||
m.SetInput<int32_t>({
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
@ -172,7 +156,7 @@ TEST(TileTest, Int32Matrix) {
|
||||
});
|
||||
m.SetMultipliers({2, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputInt32(), ElementsAreArray({
|
||||
EXPECT_THAT(m.GetOutput<int32_t>(), ElementsAreArray({
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
@ -189,9 +173,22 @@ TEST(TileTest, Int32Matrix) {
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
|
||||
}
|
||||
|
||||
TEST(TileTest, BooleanMatrix) {
|
||||
TileOpModel m({2, 3}, TensorType_BOOL, TensorType_INT32);
|
||||
m.SetInput<bool>({true, false, false, true, true, false});
|
||||
m.SetMultipliers({2, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput<bool>(),
|
||||
ElementsAreArray({
|
||||
true, false, false, true, true, false, // first tiletrue,
|
||||
true, false, false, true, true, false // second tile
|
||||
}));
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
|
||||
}
|
||||
|
||||
TEST(TileTest, Int64Matrix) {
|
||||
TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT32);
|
||||
m.SetInputInt64({
|
||||
m.SetInput<int64_t>({
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
@ -201,7 +198,7 @@ TEST(TileTest, Int64Matrix) {
|
||||
});
|
||||
m.SetMultipliers({2, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({
|
||||
EXPECT_THAT(m.GetOutput<int64_t>(), ElementsAreArray({
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
@ -220,7 +217,7 @@ TEST(TileTest, Int64Matrix) {
|
||||
|
||||
TEST(TileTest, Int64Matrix64Multipliers) {
|
||||
TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT64);
|
||||
m.SetInputInt64({
|
||||
m.SetInput<int64_t>({
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
@ -230,7 +227,7 @@ TEST(TileTest, Int64Matrix64Multipliers) {
|
||||
});
|
||||
m.SetMultipliers({2, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({
|
||||
EXPECT_THAT(m.GetOutput<int64_t>(), ElementsAreArray({
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
|
Loading…
Reference in New Issue
Block a user