2139 lines
75 KiB
C++
2139 lines
75 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xla/literal.h"
|
|
|
|
#include <limits>
|
|
#include <vector>
|
|
|
|
#include "absl/base/casts.h"
|
|
#include "absl/memory/memory.h"
|
|
#include "absl/strings/match.h"
|
|
#include "absl/strings/str_cat.h"
|
|
#include "tensorflow/compiler/xla/array3d.h"
|
|
#include "tensorflow/compiler/xla/array4d.h"
|
|
#include "tensorflow/compiler/xla/layout_util.h"
|
|
#include "tensorflow/compiler/xla/literal_util.h"
|
|
#include "tensorflow/compiler/xla/shape_util.h"
|
|
#include "tensorflow/compiler/xla/test.h"
|
|
#include "tensorflow/compiler/xla/types.h"
|
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
#include "tensorflow/core/platform/macros.h"
|
|
#include "tensorflow/core/platform/types.h"
|
|
|
|
namespace xla {
|
|
namespace {
|
|
|
|
using ::testing::ElementsAre;
|
|
using ::testing::HasSubstr;
|
|
|
|
class LiteralUtilTest : public ::testing::Test {
|
|
protected:
|
|
LiteralUtilTest() {
|
|
Array4D<float> arr4d({
|
|
// clang-format off
|
|
{ // i0=0
|
|
{ // i1=0
|
|
{1, 2, 3}, // i2=0
|
|
{4, 5, 6}, // i2=1
|
|
{7, 8, 9}, // i2=2
|
|
},
|
|
{ // i1=1
|
|
{11, 12, 13},
|
|
{14, 15, 16},
|
|
{17, 18, 19},
|
|
},
|
|
},
|
|
{ // i0=1
|
|
{ // i1=0
|
|
{101, 102, 103},
|
|
{104, 105, 106},
|
|
{107, 108, 109},
|
|
},
|
|
{ // i1=1
|
|
{201, 202, 203}, // i2=0
|
|
{204, 205, 206}, // i2=1
|
|
{207, 208, 209}, // i2=2
|
|
},
|
|
},
|
|
// clang-format on
|
|
});
|
|
|
|
layout_r2_dim0major_ = LayoutUtil::MakeLayout({1, 0});
|
|
layout_r2_dim0minor_ = LayoutUtil::MakeLayout({0, 1});
|
|
layout_r3_dim0major_ = LayoutUtil::MakeLayout({2, 1, 0});
|
|
layout_r3_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2});
|
|
layout_r4_dim0major_ = LayoutUtil::MakeLayout({3, 2, 1, 0});
|
|
layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3});
|
|
|
|
literal_r4_2x2x3x3_dim0major_ =
|
|
LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
|
|
layout_r4_dim0major_);
|
|
literal_r4_2x2x3x3_dim0minor_ =
|
|
LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
|
|
layout_r4_dim0minor_);
|
|
}
|
|
|
|
Layout layout_r2_dim0major_;
|
|
Layout layout_r2_dim0minor_;
|
|
Layout layout_r3_dim0major_;
|
|
Layout layout_r3_dim0minor_;
|
|
Layout layout_r4_dim0major_;
|
|
Layout layout_r4_dim0minor_;
|
|
Literal literal_r4_2x2x3x3_dim0major_;
|
|
Literal literal_r4_2x2x3x3_dim0minor_;
|
|
};
|
|
|
|
TEST_F(LiteralUtilTest, LiteralScalarToString) {
|
|
auto true_lit = LiteralUtil::CreateR0<bool>(true);
|
|
EXPECT_EQ("pred[] true", true_lit.ToString());
|
|
|
|
auto false_lit = LiteralUtil::CreateR0<bool>(false);
|
|
EXPECT_EQ("pred[] false", false_lit.ToString());
|
|
|
|
auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
|
|
EXPECT_EQ("u32[] 42", u32_lit.ToString());
|
|
|
|
auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
|
|
EXPECT_EQ("s32[] -999", s32_lit.ToString());
|
|
|
|
auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
|
|
EXPECT_EQ("f32[] 3.14", f32_lit.ToString());
|
|
|
|
auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
|
|
EXPECT_EQ("f16[] 0.5", f16_lit.ToString());
|
|
|
|
auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
|
|
EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString());
|
|
|
|
auto c128_lit = LiteralUtil::CreateR0<complex128>({3.14, 2.78});
|
|
EXPECT_EQ("c128[] (3.14, 2.78)", c128_lit.ToString());
|
|
|
|
auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
|
|
EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString());
|
|
|
|
// 3.14 will be rounded to 3.140625 in bfloat16 format.
|
|
auto bf16_lit_truncated =
|
|
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
|
|
ASSERT_EQ("bf16[] 3.141", bf16_lit_truncated.ToString());
|
|
|
|
auto bf16_lit_truncated2 =
|
|
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
|
|
EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString());
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, LiteralVectorToString) {
|
|
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
|
|
EXPECT_EQ("pred[3] {1, 0, 1}", pred_vec.ToString());
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, R2ToString) {
|
|
const auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
|
|
const string expected = R"(s32[3,2] {
|
|
{ 1, 2 },
|
|
{ 3, 4 },
|
|
{ 5, 6 }
|
|
})";
|
|
EXPECT_EQ(expected, literal.ToString());
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, R2DynamicToString) {
|
|
auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
|
|
literal.SetDynamicSize(0, {}, 2);
|
|
const string expected = R"(s32[<=3,2](2,2) {
|
|
{ 1, 2 },
|
|
{ 3, 4 }
|
|
})";
|
|
EXPECT_EQ(expected, literal.ToString());
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, R3ToString) {
|
|
const auto literal =
|
|
LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
|
|
const string expected = R"(s32[3,2,1] {
|
|
{
|
|
{1},
|
|
{2}
|
|
},
|
|
{
|
|
{3},
|
|
{4}
|
|
},
|
|
{
|
|
{5},
|
|
{6}
|
|
}
|
|
})";
|
|
EXPECT_EQ(expected, literal.ToString());
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, R6ToString) {
|
|
const auto literal =
|
|
LiteralUtil::CreateFromDimensions(S32, {2, 2, 1, 1, 1, 2});
|
|
const string expected = R"(s32[2,2,1,1,1,2] {
|
|
{ /*i0=0*/
|
|
{ /*i1=0*/
|
|
{ /*i2=0*/
|
|
{ /*i3=0*/
|
|
{ 0, 0 }
|
|
}
|
|
}
|
|
},
|
|
{ /*i1=1*/
|
|
{ /*i2=0*/
|
|
{ /*i3=0*/
|
|
{ 0, 0 }
|
|
}
|
|
}
|
|
}
|
|
},
|
|
{ /*i0=1*/
|
|
{ /*i1=0*/
|
|
{ /*i2=0*/
|
|
{ /*i3=0*/
|
|
{ 0, 0 }
|
|
}
|
|
}
|
|
},
|
|
{ /*i1=1*/
|
|
{ /*i2=0*/
|
|
{ /*i3=0*/
|
|
{ 0, 0 }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
})";
|
|
EXPECT_EQ(expected, literal.ToString());
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, TupleToString) {
|
|
auto scalar = LiteralUtil::CreateR0<float>(1.0);
|
|
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
|
|
const string expected = R"((
|
|
f32[] 1,
|
|
f32[2,2] {
|
|
{ 1, 2 },
|
|
{ 3, 4 }
|
|
}
|
|
))";
|
|
EXPECT_EQ(expected, tuple.ToString());
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
|
|
// clang-format off
|
|
Array3D<float> array_3d({
|
|
{{1.0f, 2.0f},
|
|
{3.0f, 4.0f},
|
|
{5.0f, 6.0f}},
|
|
{{7.0f, 8.0f},
|
|
{9.0f, 10.0f},
|
|
{11.0f, 12.0f}},
|
|
});
|
|
// clang-format on
|
|
|
|
auto literal = LiteralUtil::CreateR3FromArray3D(array_3d);
|
|
EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2));
|
|
string result = literal.ToString();
|
|
const string expected = R"(f32[2,3,2] {
|
|
{
|
|
{ 1, 2 },
|
|
{ 3, 4 },
|
|
{ 5, 6 }
|
|
},
|
|
{
|
|
{ 7, 8 },
|
|
{ 9, 10 },
|
|
{ 11, 12 }
|
|
}
|
|
})";
|
|
EXPECT_EQ(expected, result);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
|
|
// clang-format off
|
|
auto literal = LiteralUtil::CreateR4Projected<float>({
|
|
{1, 2},
|
|
{1001, 1002},
|
|
{2001, 2002},
|
|
}, /*projection_p=*/1, /*projection_z=*/2);
|
|
// clang-format on
|
|
EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2));
|
|
string result = literal.ToString();
|
|
const string expected = R"(f32[1,2,3,2] {
|
|
{ /*i0=0*/
|
|
{ /*i1=0*/
|
|
{ 1, 2 },
|
|
{ 1001, 1002 },
|
|
{ 2001, 2002 }
|
|
},
|
|
{ /*i1=1*/
|
|
{ 1, 2 },
|
|
{ 1001, 1002 },
|
|
{ 2001, 2002 }
|
|
}
|
|
}
|
|
})";
|
|
EXPECT_EQ(expected, result);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
|
|
EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(),
|
|
ElementsAre(2, 2, 3, 3));
|
|
string result = literal_r4_2x2x3x3_dim0major_.ToString();
|
|
const string expected = R"(f32[2,2,3,3] {
|
|
{ /*i0=0*/
|
|
{ /*i1=0*/
|
|
{ 1, 2, 3 },
|
|
{ 4, 5, 6 },
|
|
{ 7, 8, 9 }
|
|
},
|
|
{ /*i1=1*/
|
|
{ 11, 12, 13 },
|
|
{ 14, 15, 16 },
|
|
{ 17, 18, 19 }
|
|
}
|
|
},
|
|
{ /*i0=1*/
|
|
{ /*i1=0*/
|
|
{ 101, 102, 103 },
|
|
{ 104, 105, 106 },
|
|
{ 107, 108, 109 }
|
|
},
|
|
{ /*i1=1*/
|
|
{ 201, 202, 203 },
|
|
{ 204, 205, 206 },
|
|
{ 207, 208, 209 }
|
|
}
|
|
}
|
|
})";
|
|
EXPECT_EQ(expected, result);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, EachCellR2F32) {
|
|
// clang-format off
|
|
auto literal = LiteralUtil::CreateR2<float>({
|
|
{3.1f, 4.2f},
|
|
{9.3f, 12.4f},
|
|
});
|
|
// clang-format on
|
|
std::vector<std::tuple<int64, int64, string>> seen;
|
|
literal.EachCellAsString(
|
|
[&seen](absl::Span<const int64> indices, const string& value) {
|
|
seen.emplace_back(indices[0], indices[1], value);
|
|
});
|
|
|
|
using Elem = std::tuple<int64, int64, string>;
|
|
std::vector<Elem> expected = {Elem(0, 0, "3.1"), Elem(0, 1, "4.2"),
|
|
Elem(1, 0, "9.3"), Elem(1, 1, "12.4")};
|
|
EXPECT_EQ(expected, seen);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, ScalarEquality) {
|
|
// Test equality with scalars.
|
|
auto f32_42 = LiteralUtil::CreateR0<float>(42.0);
|
|
auto f32_42_clone = LiteralUtil::CreateR0<float>(42.0);
|
|
|
|
EXPECT_EQ(f32_42, f32_42);
|
|
EXPECT_EQ(f32_42, f32_42_clone);
|
|
|
|
auto f32_123 = LiteralUtil::CreateR0<float>(123.0);
|
|
EXPECT_NE(f32_42, f32_123);
|
|
|
|
auto f64_42 = LiteralUtil::CreateR0<double>(42.0);
|
|
EXPECT_NE(f32_42, f64_42);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, NonScalarEquality) {
|
|
// Test equality with nonscalars.
|
|
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
auto matrix_clone = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
auto matrix_different =
|
|
LiteralUtil::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
|
|
auto vector_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
|
|
auto scalar = LiteralUtil::CreateR0<float>(1.0);
|
|
Literal nil(ShapeUtil::MakeNil());
|
|
|
|
EXPECT_EQ(matrix, matrix);
|
|
EXPECT_EQ(matrix, matrix_clone);
|
|
EXPECT_NE(matrix, matrix_different);
|
|
EXPECT_NE(matrix, vector_literal);
|
|
EXPECT_NE(matrix, scalar);
|
|
EXPECT_NE(matrix, nil);
|
|
EXPECT_EQ(nil, nil);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, TokenEquality) {
|
|
auto token0 = LiteralUtil::CreateToken();
|
|
auto token1 = LiteralUtil::CreateToken();
|
|
auto scalar = LiteralUtil::CreateR0<float>(1.0);
|
|
|
|
EXPECT_EQ(token0, token1);
|
|
EXPECT_NE(token0, scalar);
|
|
|
|
EXPECT_EQ(LiteralUtil::MakeTuple({&token0}),
|
|
LiteralUtil::MakeTuple({&token0}));
|
|
EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}),
|
|
LiteralUtil::MakeTuple({&token1, &scalar}));
|
|
EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}),
|
|
LiteralUtil::MakeTuple({&scalar, &token1}));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
|
|
// Test equality with literals which have different layouts.
|
|
Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
|
|
colmajor.Set<float>({0, 0}, 1.0);
|
|
colmajor.Set<float>({0, 1}, 2.0);
|
|
colmajor.Set<float>({1, 0}, 3.0);
|
|
colmajor.Set<float>({1, 1}, 4.0);
|
|
|
|
Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
|
|
rowmajor.Set<float>({0, 0}, 1.0);
|
|
rowmajor.Set<float>({0, 1}, 2.0);
|
|
rowmajor.Set<float>({1, 0}, 3.0);
|
|
rowmajor.Set<float>({1, 1}, 4.0);
|
|
|
|
EXPECT_EQ(rowmajor, colmajor);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, TupleEquality) {
|
|
// Test equality with tuples.
|
|
auto scalar = LiteralUtil::CreateR0<float>(1.0);
|
|
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix});
|
|
|
|
// Tuple with the same elements. One element is shared with the original
|
|
// tuple, the other is a clone of the element in the original tuple.
|
|
auto scalar_clone = LiteralUtil::CreateR0<float>(1.0);
|
|
auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix});
|
|
EXPECT_EQ(tuple1, tuple2);
|
|
|
|
// Tuple with elements reversed.
|
|
auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar});
|
|
EXPECT_NE(tuple1, reversed_tuple);
|
|
|
|
// Tuple with different value.
|
|
auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
|
|
auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix});
|
|
EXPECT_NE(tuple1, different_tuple);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, DynamicShapeEquality) {
|
|
// Test equality with tuples.
|
|
auto r1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
|
|
r1.SetDynamicSize(0, {}, 1);
|
|
auto r2 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
r2.SetDynamicSize(0, {}, 1);
|
|
auto tuple1 = LiteralUtil::MakeTuple({&r1, &r2});
|
|
|
|
// Tuple with the same elements. One element is shared with the original
|
|
// tuple, the other is a clone of the element in the original tuple.
|
|
auto r1_clone = LiteralUtil::CreateR1<float>({1.0, 3.0});
|
|
r1_clone.SetDynamicSize(0, {}, 1);
|
|
auto tuple2 = LiteralUtil::MakeTuple({&r1_clone, &r2});
|
|
EXPECT_EQ(tuple1, tuple2);
|
|
|
|
// Tuple with different dynamic sizes.
|
|
auto r2_clone = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
r2_clone.SetDynamicSize(0, {}, 2);
|
|
auto tuple_3 = LiteralUtil::MakeTuple({&r1_clone, &r2_clone});
|
|
EXPECT_NE(tuple1, tuple_3);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, C64Equality) {
|
|
// Test equality with tuples.
|
|
auto vector = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
|
|
|
|
// Tuple with the same elements. One element is shared with the original
|
|
// tuple, the other is a clone of the element in the original tuple.
|
|
auto vector_clone =
|
|
LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
|
|
EXPECT_EQ(vector, vector_clone);
|
|
|
|
auto vector_reversed =
|
|
LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
|
|
EXPECT_NE(vector, vector_reversed);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, C128Equality) {
|
|
// Test equality with tuples.
|
|
auto vector = LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
|
|
|
|
// Tuple with the same elements. One element is shared with the original
|
|
// tuple, the other is a clone of the element in the original tuple.
|
|
auto vector_clone =
|
|
LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
|
|
EXPECT_EQ(vector, vector_clone);
|
|
|
|
auto vector_reversed =
|
|
LiteralUtil::CreateR1<complex128>({{3.0, 4.0}, {1.0, 2.0}});
|
|
EXPECT_NE(vector, vector_reversed);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, IsAllTuple) {
|
|
auto element1 = LiteralUtil::CreateR0<float>(0.0);
|
|
auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
|
|
auto tuple = LiteralUtil::MakeTuple({&element1, &element1});
|
|
|
|
// Tuples should always return false for IsAll.
|
|
EXPECT_FALSE(tuple.IsAll(0));
|
|
EXPECT_FALSE(tuple.IsAll(1));
|
|
}
|
|
|
|
// Verifies that CreateFromShape works for tuples.
|
|
TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
|
|
auto scalar = LiteralUtil::CreateR0<float>(0.0);
|
|
auto matrix = LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}});
|
|
auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
|
|
|
|
auto x = Literal::CreateFromShape(tuple.shape());
|
|
EXPECT_EQ(tuple, x);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, IsAll) {
|
|
EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false).IsAll(0));
|
|
EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true).IsAll(1));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(1));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(2));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(0));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(2));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(-1));
|
|
|
|
// We shouldn't reinterpret int8_min as an unsigned type and then decide that
|
|
// it is equal to 255.
|
|
auto int8_min = std::numeric_limits<int8>::min();
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255).IsAll(int8_min));
|
|
|
|
EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0).IsAll(42));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001).IsAll(42));
|
|
|
|
EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100}).IsAll(100));
|
|
EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001}).IsAll(100));
|
|
|
|
EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}}).IsAll(8));
|
|
EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}}).IsAll(8));
|
|
EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}}).IsAll(8));
|
|
|
|
half h8(8.0f);
|
|
half h9(9.0f);
|
|
EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}}).IsAll(8));
|
|
EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}}).IsAll(8));
|
|
EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}}).IsAll(8));
|
|
|
|
bfloat16 b8(8.0f);
|
|
bfloat16 b9(9.0f);
|
|
|
|
EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}}).IsAll(8));
|
|
EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}}).IsAll(8));
|
|
EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}}).IsAll(8));
|
|
|
|
// 9.001 will be truncated to 9.0
|
|
bfloat16 b91(9.001f);
|
|
bfloat16 b90(9.00f);
|
|
EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}}).IsAll(9.0));
|
|
|
|
complex64 c8_9 = {8, 9};
|
|
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAll(8));
|
|
|
|
auto uint64_max = std::numeric_limits<uint64>::max();
|
|
EXPECT_FALSE(LiteralUtil::CreateR2<uint64>(
|
|
{{uint64_max, uint64_max}, {uint64_max, uint64_max}})
|
|
.IsAll(-1));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, IsAllFloat) {
|
|
// IsAllFloat always returns false when the literal is not floating-point.
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllFloat(0));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllFloat(0));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllFloat(0));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllFloat(0));
|
|
|
|
EXPECT_TRUE(LiteralUtil::CreateR0<float>(0).IsAllFloat(0));
|
|
EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5).IsAllFloat(.5));
|
|
EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.5));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.49));
|
|
EXPECT_FALSE(
|
|
LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
|
|
EXPECT_TRUE(LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})
|
|
.IsAllFloat(.5));
|
|
|
|
EXPECT_TRUE(LiteralUtil::CreateR0<double>(0).IsAllFloat(0));
|
|
EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5).IsAllFloat(.5));
|
|
EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.5));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.49));
|
|
EXPECT_FALSE(
|
|
LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, IsAllComplex) {
|
|
// IsAllComplex always returns false when the literal is not complex.
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllComplex(0));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllComplex(0));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllComplex(0));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllComplex(0));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<float>(0).IsAllComplex(0));
|
|
EXPECT_FALSE(LiteralUtil::CreateR0<double>(0).IsAllComplex(0));
|
|
|
|
complex64 c8_9 = {8, 9};
|
|
complex64 c7_9 = {7, 9};
|
|
EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})
|
|
.IsAllComplex({8.0f, 9.0f}));
|
|
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})
|
|
.IsAllComplex({8.0f, 9.0f}));
|
|
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c7_9}})
|
|
.IsAllComplex({8.0f, 9.0f}));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, IsAllFirst) {
|
|
// IsAllComplex always returns false when the literal is not complex.
|
|
EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true}).IsAllFirst());
|
|
EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false}).IsAllFirst());
|
|
EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2}).IsAllFirst());
|
|
EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5}).IsAllFirst());
|
|
EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2}).IsAllFirst());
|
|
EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5}).IsAllFirst());
|
|
EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2}).IsAllFirst());
|
|
EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5}).IsAllFirst());
|
|
EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2}).IsAllFirst());
|
|
|
|
complex64 c8_9 = {8, 9};
|
|
complex64 c7_9 = {7, 9};
|
|
EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllFirst());
|
|
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllFirst());
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, IsZero) {
|
|
auto scalar_zero = LiteralUtil::CreateR0<float>(0.0f);
|
|
auto scalar_one = LiteralUtil::CreateR0<float>(1.0f);
|
|
EXPECT_TRUE(scalar_zero.IsZero({}));
|
|
EXPECT_FALSE(scalar_one.IsZero({}));
|
|
|
|
auto array = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
|
|
EXPECT_FALSE(array.IsZero({0, 1}));
|
|
EXPECT_TRUE(array.IsZero({0, 2}));
|
|
EXPECT_TRUE(array.IsZero({1, 1}));
|
|
EXPECT_FALSE(array.IsZero({1, 2}));
|
|
|
|
auto complex_zero = LiteralUtil::CreateR0<complex64>(0.0f);
|
|
auto complex_nonzero = LiteralUtil::CreateR0<complex64>(0.5f);
|
|
EXPECT_TRUE(complex_zero.IsZero({}));
|
|
EXPECT_FALSE(complex_nonzero.IsZero({}));
|
|
}
|
|
|
|
template <typename T>
|
|
class LiteralUtilTestTemplated : public ::testing::Test {};
|
|
|
|
using TestedTypes = ::testing::Types<float, int32, uint32, complex64>;
|
|
TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes);
|
|
|
|
TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
|
|
// Make a non-integer for floating point types.
|
|
TypeParam half = TypeParam(1) / TypeParam(2);
|
|
auto data = LiteralUtil::CreateR2<TypeParam>({{half, 2}, {3, 4}});
|
|
const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
|
|
const Layout layout10 = LayoutUtil::MakeLayout({1, 0});
|
|
|
|
auto data01 = data.Relayout(layout01);
|
|
EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01));
|
|
EXPECT_EQ(data, data01);
|
|
|
|
auto data10 = data.Relayout(layout10);
|
|
EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10));
|
|
EXPECT_EQ(data, data10);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, ReshapeR0) {
|
|
auto original = LiteralUtil::CreateR0<float>(1.7f);
|
|
auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
|
|
EXPECT_EQ(original, reshape);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, ReshapeR4) {
|
|
// clang-format off
|
|
// F32[1x3x2x4]
|
|
auto original = LiteralUtil::CreateR4WithLayout<float>({{
|
|
{{10, 11, 12, 13}, {14, 15, 16, 17}},
|
|
{{18, 19, 20, 21}, {22, 23, 24, 25}},
|
|
{{26, 27, 28, 29}, {30, 31, 32, 33}},
|
|
}}, layout_r4_dim0major_);
|
|
// F32[1x3x4x2]
|
|
auto expected = LiteralUtil::CreateR3WithLayout<float>({
|
|
{{10, 11}, {12, 13}, {14, 15}, {16, 17}},
|
|
{{18, 19}, {20, 21}, {22, 23}, {24, 25}},
|
|
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
|
|
}, layout_r3_dim0major_);
|
|
// clang-format on
|
|
auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
|
|
|
|
EXPECT_EQ(expected, reshape);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
|
|
// clang-format off
|
|
// F32[1x3x2x4]
|
|
auto original = LiteralUtil::CreateR4WithLayout<float>({{
|
|
{{10, 11, 12, 13}, {14, 15, 16, 17}},
|
|
{{18, 19, 20, 21}, {22, 23, 24, 25}},
|
|
{{26, 27, 28, 29}, {30, 31, 32, 33}},
|
|
}}, layout_r4_dim0minor_);
|
|
// F32[1x3x4x2]
|
|
auto expected = LiteralUtil::CreateR3WithLayout<float>({
|
|
{{10, 11}, {12, 13}, {14, 15}, {16, 17}},
|
|
{{18, 19}, {20, 21}, {22, 23}, {24, 25}},
|
|
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
|
|
}, layout_r3_dim0major_);
|
|
// clang-format on
|
|
auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
|
|
|
|
EXPECT_EQ(expected, reshape);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, TransposeR0) {
|
|
auto original = LiteralUtil::CreateR0<float>(1.7f);
|
|
auto reshape = original.Transpose(/*permutation=*/{});
|
|
EXPECT_EQ(original, reshape);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, TransposeR4) {
|
|
// clang-format off
|
|
// F32[1x3x2x4]
|
|
auto original = LiteralUtil::CreateR4<float>({{
|
|
{{10, 11, 12, 13}, {14, 15, 16, 17}},
|
|
{{18, 19, 20, 21}, {22, 23, 24, 25}},
|
|
{{26, 27, 28, 29}, {30, 31, 32, 33}},
|
|
}});
|
|
// clang-format on
|
|
auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1});
|
|
|
|
reshape.EachCell<float>([&](absl::Span<const int64> indices, float value) {
|
|
EXPECT_EQ(value, original.Get<float>(
|
|
{indices[2], indices[3], indices[0], indices[1]}));
|
|
});
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, TransposeDynamicR2) {
|
|
// F32[2, <=3] (2, 1)
|
|
auto original = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
|
|
original.SetDynamicSize(1, 1);
|
|
// F32[<=3, 2] (1, 2)
|
|
auto reshape = original.Transpose(/*permutation=*/{1, 0});
|
|
|
|
reshape.EachCell<float>([&](absl::Span<const int64> indices, float value) {
|
|
EXPECT_EQ(value, original.Get<float>({indices[1], indices[0]}));
|
|
});
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, ToStaticR2) {
|
|
// F32[2, <=3] (2, 1)
|
|
auto original = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
|
|
original.SetDynamicSize(1, 1);
|
|
// F32[2, 1]
|
|
auto static_literal = original.ToStatic();
|
|
EXPECT_EQ(static_literal.shape(), ShapeUtil::MakeShape(F32, {2, 1}));
|
|
EXPECT_TRUE(static_literal.shape().is_static());
|
|
|
|
static_literal.EachCell<float>(
|
|
[&](absl::Span<const int64> indices, float value) {
|
|
EXPECT_EQ(value, original.Get<float>({indices[0], indices[1]}));
|
|
});
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, ToBoundedDynamicR2) {
|
|
// F32[2, 1]
|
|
auto original = LiteralUtil::CreateR2<float>({{1}, {4}});
|
|
// F32[2, <=3] (2, 1)
|
|
auto dynamic_shape = ShapeUtil::MakeShape(F32, {2, 3}, {false, true});
|
|
auto dynamic_literal = original.ToBoundedDynamic(dynamic_shape);
|
|
EXPECT_EQ(dynamic_literal.shape(), dynamic_shape);
|
|
|
|
dynamic_literal.EachCell<float>(
|
|
[&](absl::Span<const int64> indices, float value) {
|
|
EXPECT_EQ(value, original.Get<float>({indices[0], indices[1]}));
|
|
});
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
|
|
// Tests that using Relayout on an array is equivalent to creating it in the
|
|
// target layout in the first place.
|
|
auto dim0minor_relaid_to_dim0major =
|
|
literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_);
|
|
EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major);
|
|
|
|
auto dim0major_relaid_to_dim0minor =
|
|
literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_);
|
|
EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
|
|
// Test expected memory layout of R2 dim0-minor (column-major) literal.
|
|
auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int32>(
|
|
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
|
|
EXPECT_EQ(mat_dim0minor.element_count(), 6);
|
|
EXPECT_THAT(mat_dim0minor.data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
|
|
|
|
// Test expected memory layout when using Relayout to row major.
|
|
auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_);
|
|
EXPECT_THAT(relaid_mat_to_dim0major.data<int32>(),
|
|
ElementsAre(1, 2, 3, 4, 5, 6));
|
|
|
|
// Test expected memory layout of R2 created with dim0-major (row-major).
|
|
auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int32>(
|
|
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
|
|
EXPECT_EQ(mat_dim0major.element_count(), 6);
|
|
EXPECT_THAT(mat_dim0major.data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
|
|
|
|
// Test expected memory layout when using Relayout to column major.
|
|
auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_);
|
|
EXPECT_THAT(relaid_mat_to_dim0minor.data<int32>(),
|
|
ElementsAre(1, 4, 2, 5, 3, 6));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, TestR3LinearLayout) {
|
|
// Test expected memory layout of R3 dim0-minor (column-major) literal.
|
|
Array3D<int> arr3d(
|
|
// clang-format off
|
|
{
|
|
{
|
|
{1, 2, 3},
|
|
{4, 5, 6},
|
|
},
|
|
{
|
|
{7, 8, 9},
|
|
{10, 11, 12},
|
|
},
|
|
}); // clang-format on
|
|
auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
|
|
arr3d, layout_r3_dim0minor_);
|
|
|
|
EXPECT_EQ(lit_dim0minor.element_count(), 12);
|
|
std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
|
|
EXPECT_THAT(lit_dim0minor.data<int32>(),
|
|
testing::ElementsAreArray(expected_dim0minor));
|
|
|
|
// Test expected memory layout when using Relayout to row major.
|
|
auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_);
|
|
std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
|
|
EXPECT_THAT(relaid_lit_to_dim0major.data<int32>(),
|
|
testing::ElementsAreArray(expected_dim0major));
|
|
|
|
// Test expected memory layout of R3 created with dim0-major (row-major).
|
|
auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
|
|
arr3d, layout_r3_dim0major_);
|
|
EXPECT_EQ(lit_dim0major.element_count(), 12);
|
|
EXPECT_THAT(lit_dim0major.data<int32>(),
|
|
testing::ElementsAreArray(expected_dim0major));
|
|
|
|
// Test expected memory layout when using Relayout to column major.
|
|
auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_);
|
|
EXPECT_THAT(relaid_lit_to_dim0minor.data<int32>(),
|
|
testing::ElementsAreArray(expected_dim0minor));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, SliceR0S32) {
|
|
auto input = LiteralUtil::CreateR0<int32>(1);
|
|
auto result = input.Slice({}, {});
|
|
EXPECT_EQ(input, result);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, SliceR1F32) {
|
|
auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
|
|
auto result = input.Slice({3}, {4});
|
|
auto expected = LiteralUtil::CreateR1<float>({4.0});
|
|
EXPECT_EQ(expected, result);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, SliceR2U32) {
|
|
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
|
|
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
|
|
auto result = input_3x4.Slice({0, 2}, {2, 4});
|
|
auto expected = LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}});
|
|
EXPECT_EQ(expected, result);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, SliceR3U32Full) {
|
|
auto input_2x3x2 = LiteralUtil::CreateR3<uint32>(
|
|
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
|
|
auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2});
|
|
EXPECT_EQ(input_2x3x2, result);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, SliceR2Dynamic) {
|
|
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
|
|
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
|
|
input_3x4.SetDynamicSize(1, 3);
|
|
// slice second dim from dynamic size 3 to dynamic size 1.
|
|
auto result = input_3x4.Slice({0, 1}, {2, 2});
|
|
auto expected = LiteralUtil::CreateR2<uint32>({{2}, {6}});
|
|
EXPECT_EQ(expected, result);
|
|
EXPECT_EQ(result.GetDynamicSize(1), 1);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, SliceR2DynamicInBound) {
|
|
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
|
|
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
|
|
input_3x4.SetDynamicSize(1, 1);
|
|
auto result = input_3x4.Slice({0, 0}, {2, 2});
|
|
auto expected = LiteralUtil::CreateR2<uint32>({{1}, {5}});
|
|
EXPECT_EQ(expected, result);
|
|
EXPECT_EQ(result.GetDynamicSize(1), 1);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, SliceR2DynamicOutOfBound) {
|
|
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
|
|
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
|
|
input_3x4.SetDynamicSize(1, 1);
|
|
auto result = input_3x4.Slice({0, 1}, {2, 3});
|
|
auto expected = LiteralUtil::CreateR2<uint32>({{}, {}});
|
|
EXPECT_EQ(expected, result);
|
|
// Out of bound access clamps into 0 sized dimension.
|
|
EXPECT_EQ(result.GetDynamicSize(1), 0);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateR1S64) {
|
|
Literal output(ShapeUtil::MakeShape(S64, {1}));
|
|
output.PopulateR1<int64>({77});
|
|
auto expected = LiteralUtil::CreateR1<int64>({77});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateR1U64) {
|
|
Literal output(ShapeUtil::MakeShape(U64, {2}));
|
|
output.PopulateR1<uint64>({{77, 88}});
|
|
auto expected = LiteralUtil::CreateR1<uint64>({{77, 88}});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateR1C64) {
|
|
Literal output(ShapeUtil::MakeShape(C64, {1}));
|
|
output.PopulateR1<complex64>({{77, 88}});
|
|
auto expected = LiteralUtil::CreateR1<complex64>({{77, 88}});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateR1C128) {
|
|
Literal output(ShapeUtil::MakeShape(C128, {1}));
|
|
output.PopulateR1<complex128>({{77, 88}});
|
|
auto expected = LiteralUtil::CreateR1<complex128>({{77, 88}});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateR2C64) {
|
|
Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
|
|
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
|
|
auto expected =
|
|
LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
|
|
Literal output(ShapeUtil::MakeShape(BF16, {}));
|
|
bfloat16 h(0.25f);
|
|
output.PopulateWithValue<bfloat16>(h);
|
|
auto expected = LiteralUtil::CreateR0<bfloat16>(h);
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
|
|
Literal output(ShapeUtil::MakeShape(BF16, {3}));
|
|
bfloat16 h(0.5f);
|
|
output.PopulateWithValue<bfloat16>(h);
|
|
auto expected = LiteralUtil::CreateR1<bfloat16>({h, h, h});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
|
|
Literal output(ShapeUtil::MakeShape(BF16, {2, 2}));
|
|
bfloat16 h(2.0f);
|
|
output.PopulateWithValue<bfloat16>(h);
|
|
auto expected = LiteralUtil::CreateR2<bfloat16>({{h, h}, {h, h}});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
|
|
Literal output(ShapeUtil::MakeShape(F32, {}));
|
|
output.PopulateWithValue<float>(2.5f);
|
|
auto expected = LiteralUtil::CreateR0<float>(2.5f);
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
|
|
Literal output(ShapeUtil::MakeShape(S64, {3}));
|
|
output.PopulateWithValue<int64>(-7);
|
|
auto expected = LiteralUtil::CreateR1<int64>({-7, -7, -7});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
|
|
Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
|
|
output.PopulateWithValue<uint64>(42);
|
|
auto expected = LiteralUtil::CreateR2<uint64>({{42, 42}, {42, 42}});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
|
|
Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
|
|
output.PopulateWithValue<complex64>({4, 2});
|
|
auto expected =
|
|
LiteralUtil::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateWithValueR2C128) {
|
|
Literal output(ShapeUtil::MakeShape(C128, {2, 2}));
|
|
output.PopulateWithValue<complex128>({4, 2});
|
|
auto expected =
|
|
LiteralUtil::CreateR2<complex128>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
|
|
Literal output(ShapeUtil::MakeShape(F16, {}));
|
|
half h(0.25f);
|
|
output.PopulateWithValue<half>(h);
|
|
auto expected = LiteralUtil::CreateR0<half>(h);
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
|
|
Literal output(ShapeUtil::MakeShape(F16, {3}));
|
|
half h(0.5f);
|
|
output.PopulateWithValue<half>(h);
|
|
auto expected = LiteralUtil::CreateR1<half>({h, h, h});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
|
|
Literal output(ShapeUtil::MakeShape(F16, {2, 2}));
|
|
half h(2.0f);
|
|
output.PopulateWithValue<half>(h);
|
|
auto expected = LiteralUtil::CreateR2<half>({{h, h}, {h, h}});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, ReplicateR2U32) {
|
|
auto input = LiteralUtil::CreateR2<uint32>(
|
|
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
|
|
auto output = input.Replicate<uint32>(3);
|
|
auto expected = LiteralUtil::CreateR3<uint32>(
|
|
{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
|
|
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
|
|
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
|
|
EXPECT_EQ(output, expected);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, CopySliceFrom) {
|
|
const int64 dimensions[] = {17, 15, 34, 21};
|
|
const int64 layouts[][4] = {
|
|
{3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}};
|
|
for (const auto& layout : layouts) {
|
|
Shape shape = ShapeUtil::MakeShapeWithLayout(
|
|
primitive_util::NativeToPrimitiveType<uint32>(), dimensions, layout);
|
|
|
|
auto source = Literal::CreateFromShape(shape);
|
|
const int64 zero_base[] = {0, 0, 0, 0};
|
|
const int64 step[] = {1, 1, 1, 1};
|
|
uint32 seqnr = 0;
|
|
auto init_proc = [&](absl::Span<const int64> indexes) {
|
|
source.Set(indexes, ++seqnr);
|
|
return true;
|
|
};
|
|
ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step,
|
|
init_proc);
|
|
|
|
auto blank = Literal::CreateFromShape(shape);
|
|
const int64 src_base[] = {3, 1, 5, 7};
|
|
const int64 dest_base[] = {6, 4, 12, 2};
|
|
const int64 copy_size[] = {7, 8, 11, 9};
|
|
TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size));
|
|
|
|
std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
|
|
std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
|
|
bool matched = true;
|
|
auto check_proc = [&](absl::Span<const int64> indexes) {
|
|
std::copy(indexes.begin(), indexes.end(), source_indexes.begin());
|
|
std::transform(source_indexes.begin(), source_indexes.end(), src_base,
|
|
source_indexes.begin(), std::plus<int64>());
|
|
std::copy(indexes.begin(), indexes.end(), blank_indexes.begin());
|
|
std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base,
|
|
blank_indexes.begin(), std::plus<int64>());
|
|
auto bval = blank.Get<uint32>(blank_indexes);
|
|
matched = (bval != 0 && bval == source.Get<uint32>(source_indexes));
|
|
return matched;
|
|
};
|
|
|
|
ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step,
|
|
check_proc);
|
|
EXPECT_TRUE(matched);
|
|
}
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, CopyFromScalars) {
|
|
auto zero = LiteralUtil::CreateR0<uint32>(0);
|
|
auto nine = LiteralUtil::CreateR0<uint32>(9);
|
|
TF_EXPECT_OK(zero.CopyFrom(nine));
|
|
EXPECT_EQ(zero, nine);
|
|
|
|
auto vect = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
|
|
TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {}));
|
|
EXPECT_EQ(zero.Get<uint32>({}), 17);
|
|
TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {}));
|
|
EXPECT_EQ(vect.Get<uint32>({4}), 17);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
|
|
const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0});
|
|
const auto const_nine = LiteralUtil::CreateR1<float>({9});
|
|
const auto const_empty = Literal::CreateFromShape(empty_r1_shape);
|
|
|
|
{
|
|
// Source contains dimension with zero elements.
|
|
const auto empty = Literal::CreateFromShape(empty_r1_shape);
|
|
auto nine = LiteralUtil::CreateR1<float>({9});
|
|
|
|
TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0}));
|
|
EXPECT_EQ(nine, const_nine);
|
|
}
|
|
|
|
{
|
|
// Copy 0 element to destination with zero elements.
|
|
auto empty = Literal::CreateFromShape(empty_r1_shape);
|
|
auto nine = LiteralUtil::CreateR1<float>({9});
|
|
|
|
TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0}));
|
|
EXPECT_EQ(empty, const_empty);
|
|
}
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, CopyFromNilShape) {
|
|
Literal nil_literal0(ShapeUtil::MakeNil());
|
|
Literal nil_literal1(ShapeUtil::MakeNil());
|
|
// This doesn't actually do any copying, but it should succeed.
|
|
TF_ASSERT_OK(nil_literal0.CopyFrom(nil_literal1));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, CopyFromArrays) {
|
|
auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
|
|
auto scalar_123 = LiteralUtil::CreateR0<float>(123.0);
|
|
EXPECT_NE(scalar_42, scalar_123);
|
|
TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{},
|
|
/*src_shape_index=*/{}));
|
|
EXPECT_EQ(scalar_42, scalar_123);
|
|
EXPECT_EQ(scalar_42.Get<float>({}), 123.0f);
|
|
|
|
auto matrix_1234 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
auto matrix_5678 = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
|
|
EXPECT_NE(matrix_1234, matrix_5678);
|
|
EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 1.0f);
|
|
TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{},
|
|
/*src_shape_index=*/{}));
|
|
EXPECT_EQ(matrix_1234, matrix_5678);
|
|
EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 5.0f);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, CopyFromTuples) {
|
|
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
Literal nil_literal(ShapeUtil::MakeNil());
|
|
Literal inner_elements[] = {LiteralUtil::CreateR0<int32>(42),
|
|
LiteralUtil::CreateR1<double>({23.0, 44.0})};
|
|
Literal inner_tuple = LiteralUtil::MakeTuple(
|
|
{&inner_elements[0], &inner_elements[1], &nil_literal});
|
|
Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple});
|
|
// Create a tuple the same shape as the inner tuple of nested_tuple but with
|
|
// different values..
|
|
Literal int32_minus5 = LiteralUtil::CreateR0<int32>(-5);
|
|
Literal double_2_4 = LiteralUtil::CreateR1<double>({2.0, 4.0});
|
|
Literal tuple =
|
|
LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal});
|
|
|
|
EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
|
|
EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), 42);
|
|
EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 23.0);
|
|
EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 44.0);
|
|
|
|
// Overwrite the inner tuple element of nested_tuple with the contents of
|
|
// 'tuple'.
|
|
TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
|
|
/*src_shape_index=*/{}));
|
|
|
|
// The matrix element should be unchanged.
|
|
EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
|
|
|
|
// The tuple element should have been copied from 'tuple'.
|
|
EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), -5);
|
|
EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 2.0);
|
|
EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 4.0);
|
|
}
|
|
TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
|
|
Literal elements[] = {LiteralUtil::CreateR0<int32>(-2),
|
|
LiteralUtil::CreateR0<int32>(4)};
|
|
Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
|
|
|
|
EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
|
|
EXPECT_EQ(tuple.Get<int32>({}, {1}), 4);
|
|
|
|
// Copy from one element to the other.
|
|
TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
|
|
/*src_shape_index=*/{0}));
|
|
|
|
EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
|
|
EXPECT_EQ(tuple.Get<int32>({}, {1}), -2);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
|
|
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
auto vector = LiteralUtil::CreateR1<float>({5.0, 7.0});
|
|
Status status = matrix.CopyFrom(vector);
|
|
ASSERT_FALSE(status.ok());
|
|
EXPECT_THAT(status.error_message(),
|
|
HasSubstr("Destination subshape incompatible"));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, F16) {
|
|
// Verify that the internal data views are consistent and that they
|
|
// are in little endian format
|
|
// TODO - modify if we make the data format machine endianness dependent
|
|
Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
|
|
const char* d1 = reinterpret_cast<const char*>(m1.data<half>().data());
|
|
EXPECT_EQ(d1[0], 0);
|
|
EXPECT_EQ(d1[1], 0);
|
|
EXPECT_EQ(d1[2], 0);
|
|
EXPECT_EQ(d1[3], 0);
|
|
EXPECT_EQ(d1[4], 0);
|
|
EXPECT_EQ(d1[5], 0);
|
|
EXPECT_EQ(d1[6], 0);
|
|
EXPECT_EQ(d1[7], 0);
|
|
|
|
half h1(1.0f);
|
|
half h2(2.0f);
|
|
auto m2 = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
|
|
const char* d2 = reinterpret_cast<const char*>(m2.data<half>().data());
|
|
EXPECT_EQ(d2[0], 0);
|
|
EXPECT_EQ(d2[1], 0x3C);
|
|
EXPECT_EQ(d2[2], 0);
|
|
EXPECT_EQ(d2[3], 0x40);
|
|
EXPECT_EQ(d2[4], 0);
|
|
EXPECT_EQ(d2[5], 0x40);
|
|
EXPECT_EQ(d2[6], 0);
|
|
EXPECT_EQ(d2[7], 0x3C);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, Populate) {
|
|
struct PopulateData {
|
|
std::vector<int64> dimensions;
|
|
std::vector<int64> layout;
|
|
} populate_data[] = {
|
|
{{}, {}},
|
|
{{0}, {0}},
|
|
{{16}, {0}},
|
|
{{2, 0}, {1, 0}},
|
|
{{4, 16}, {1, 0}},
|
|
{{21, 12}, {0, 1}},
|
|
{{6, 11, 17}, {2, 0, 1}},
|
|
{{6, 11, 5, 17}, {3, 2, 0, 1}},
|
|
};
|
|
for (const auto& data : populate_data) {
|
|
Shape shape = ShapeUtil::MakeShapeWithLayout(
|
|
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
|
|
data.layout);
|
|
Literal literal(shape);
|
|
auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
|
|
// Offsets from linear index just to avoid R0 literals to be initialized
|
|
// with zero.
|
|
return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
|
|
indexes) +
|
|
17;
|
|
};
|
|
TF_EXPECT_OK(literal.Populate<uint32>(generator));
|
|
|
|
std::vector<int64> zero_base(data.dimensions.size(), 0);
|
|
std::vector<int64> step(data.dimensions.size(), 1);
|
|
bool matched = true;
|
|
auto check_function = [&](absl::Span<const int64> indexes) {
|
|
auto value = literal.Get<uint32>(indexes);
|
|
matched = matched && (value == generator(indexes));
|
|
return matched;
|
|
};
|
|
ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
|
|
check_function);
|
|
EXPECT_TRUE(matched);
|
|
}
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, PopulateParallel) {
|
|
struct PopulateData {
|
|
std::vector<int64> dimensions;
|
|
std::vector<int64> layout;
|
|
} populate_data[] = {
|
|
{{}, {}},
|
|
{{0}, {0}},
|
|
{{16}, {0}},
|
|
{{2, 0}, {1, 0}},
|
|
{{4, 16}, {1, 0}},
|
|
{{21, 12}, {0, 1}},
|
|
{{6, 11, 17}, {2, 0, 1}},
|
|
{{6, 11, 5, 17}, {3, 2, 0, 1}},
|
|
};
|
|
for (const auto& data : populate_data) {
|
|
Shape shape = ShapeUtil::MakeShapeWithLayout(
|
|
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
|
|
data.layout);
|
|
Literal literal(shape);
|
|
auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
|
|
// Offsets from linear index just to avoid R0 literals to be initialized
|
|
// with zero.
|
|
return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
|
|
indexes) +
|
|
17;
|
|
};
|
|
TF_EXPECT_OK(literal.PopulateParallel<uint32>(generator));
|
|
|
|
std::vector<int64> zero_base(data.dimensions.size(), 0);
|
|
std::vector<int64> step(data.dimensions.size(), 1);
|
|
bool matched = true;
|
|
auto check_function = [&](absl::Span<const int64> indexes) {
|
|
auto value = literal.Get<uint32>(indexes);
|
|
matched = matched && (value == generator(indexes));
|
|
return matched;
|
|
};
|
|
ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
|
|
check_function);
|
|
EXPECT_TRUE(matched);
|
|
}
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, ConvertR4) {
|
|
// clang-format off
|
|
auto original = LiteralUtil::CreateR4WithLayout<int8>({{
|
|
{{10, 11, 12, 13}, {14, 15, 16, 17}},
|
|
{{18, 19, 20, 21}, {22, 23, 24, 25}},
|
|
{{26, 27, 28, 29}, {30, 31, 32, 33}},
|
|
}}, layout_r4_dim0major_);
|
|
auto expected = LiteralUtil::CreateR4WithLayout<uint32>({{
|
|
{{10, 11, 12, 13}, {14, 15, 16, 17}},
|
|
{{18, 19, 20, 21}, {22, 23, 24, 25}},
|
|
{{26, 27, 28, 29}, {30, 31, 32, 33}},
|
|
}}, layout_r4_dim0major_);
|
|
// clang-format on
|
|
TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32));
|
|
|
|
EXPECT_EQ(expected, converted);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
|
|
// clang-format off
|
|
auto s8 = LiteralUtil::CreateR4WithLayout<int8>({{
|
|
{{10, 0, 12, 0}, {0, 15, 0, 17}},
|
|
{{0, 19, 0, 21}, {22, 0, 24, 0}},
|
|
{{26, 0, 28, 0}, {0, 31, 0, 33}},
|
|
}}, layout_r4_dim0major_);
|
|
auto s16 = LiteralUtil::CreateR4WithLayout<int16>({{
|
|
{{10, 0, 12, 0}, {0, 15, 0, 17}},
|
|
{{0, 19, 0, 21}, {22, 0, 24, 0}},
|
|
{{26, 0, 28, 0}, {0, 31, 0, 33}},
|
|
}}, layout_r4_dim0major_);
|
|
auto s32 = LiteralUtil::CreateR4WithLayout<int32>({{
|
|
{{10, 0, 12, 0}, {0, 15, 0, 17}},
|
|
{{0, 19, 0, 21}, {22, 0, 24, 0}},
|
|
{{26, 0, 28, 0}, {0, 31, 0, 33}},
|
|
}}, layout_r4_dim0major_);
|
|
auto u16 = LiteralUtil::CreateR4WithLayout<uint16>({{
|
|
{{10, 0, 12, 0}, {0, 15, 0, 17}},
|
|
{{0, 19, 0, 21}, {22, 0, 24, 0}},
|
|
{{26, 0, 28, 0}, {0, 31, 0, 33}},
|
|
}}, layout_r4_dim0major_);
|
|
auto u32 = LiteralUtil::CreateR4WithLayout<uint32>({{
|
|
{{10, 0, 12, 0}, {0, 15, 0, 17}},
|
|
{{0, 19, 0, 21}, {22, 0, 24, 0}},
|
|
{{26, 0, 28, 0}, {0, 31, 0, 33}},
|
|
}}, layout_r4_dim0major_);
|
|
auto s64 = LiteralUtil::CreateR4WithLayout<int64>({{
|
|
{{10, 0, 12, 0}, {0, 15, 0, 17}},
|
|
{{0, 19, 0, 21}, {22, 0, 24, 0}},
|
|
{{26, 0, 28, 0}, {0, 31, 0, 33}},
|
|
}}, layout_r4_dim0major_);
|
|
auto u64 = LiteralUtil::CreateR4WithLayout<uint64>({{
|
|
{{10, 0, 12, 0}, {0, 15, 0, 17}},
|
|
{{0, 19, 0, 21}, {22, 0, 24, 0}},
|
|
{{26, 0, 28, 0}, {0, 31, 0, 33}},
|
|
}}, layout_r4_dim0major_);
|
|
auto pred = LiteralUtil::CreateR4WithLayout<bool>({{
|
|
{{true, false, true, false}, {false, true, false, true}},
|
|
{{false, true, false, true}, {true, false, true, false}},
|
|
{{true, false, true, false}, {false, true, false, true}},
|
|
}}, layout_r4_dim0major_);
|
|
auto int32_pred = LiteralUtil::CreateR4WithLayout<int32>({{
|
|
{{1, 0, 1, 0}, {0, 1, 0, 1}},
|
|
{{0, 1, 0, 1}, {1, 0, 1, 0}},
|
|
{{1, 0, 1, 0}, {0, 1, 0, 1}},
|
|
}}, layout_r4_dim0major_);
|
|
auto f16 = LiteralUtil::CreateR4WithLayout<half>({{
|
|
{{half(10.0), half(0.0), half(12.0), half(0.0)},
|
|
{half(0.0), half(15.0), half(0.0), half(17.0)}},
|
|
{{half(0.0), half(19.0), half(0.0), half(21.0)},
|
|
{half(22.0), half(0.0), half(24.0), half(0.0)}},
|
|
{{half(26.0), half(0.0), half(28.0), half(0.0)},
|
|
{half(0.0), half(31.0), half(0.0), half(33.0)}},
|
|
}}, layout_r4_dim0major_);
|
|
auto bf16 = LiteralUtil::CreateR4WithLayout<bfloat16>({{
|
|
{{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
|
|
{bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
|
|
{{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
|
|
{bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
|
|
{{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
|
|
{bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
|
|
}}, layout_r4_dim0major_);
|
|
auto f32 = LiteralUtil::CreateR4WithLayout<float>({{
|
|
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
|
|
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
|
|
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
|
|
}}, layout_r4_dim0major_);
|
|
auto f64 = LiteralUtil::CreateR4WithLayout<double>({{
|
|
{{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
|
|
{{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
|
|
{{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
|
|
}}, layout_r4_dim0major_);
|
|
auto c64 = LiteralUtil::CreateR4WithLayout<complex64>({{
|
|
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
|
|
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
|
|
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
|
|
}}, layout_r4_dim0major_);
|
|
auto c128 = LiteralUtil::CreateR4WithLayout<complex128>({{
|
|
{{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
|
|
{{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
|
|
{{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
|
|
}}, layout_r4_dim0major_); // clang-format on
|
|
Literal conv;
|
|
|
|
conv = s8.Convert(U16).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, u16);
|
|
|
|
conv = s8.Convert(S16).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, s16);
|
|
|
|
conv = s8.Convert(U32).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, u32);
|
|
|
|
conv = s8.Convert(S32).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, s32);
|
|
|
|
conv = s8.Convert(U64).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, u64);
|
|
|
|
conv = s8.Convert(S64).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, s64);
|
|
|
|
conv = s8.Convert(PRED).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, pred);
|
|
|
|
conv = bf16.Convert(S32).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, s32);
|
|
|
|
conv = bf16.Convert(F32).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, f32);
|
|
|
|
conv = pred.Convert(S32).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, int32_pred);
|
|
|
|
conv = f32.Convert(S32).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, s32);
|
|
|
|
conv = f64.Convert(S32).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, s32);
|
|
|
|
conv = s32.Convert(F32).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, f32);
|
|
|
|
conv = f32.Convert(F16).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, f16);
|
|
|
|
conv = f64.Convert(F16).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, f16);
|
|
|
|
conv = s32.Convert(F16).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, f16);
|
|
|
|
conv = u32.Convert(F16).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, f16);
|
|
|
|
conv = s32.Convert(C64).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, c64);
|
|
|
|
conv = f16.Convert(C64).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, c64);
|
|
|
|
conv = s32.Convert(S16).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, s16);
|
|
|
|
conv = s32.Convert(U16).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, u16);
|
|
|
|
conv = s32.Convert(C128).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, c128);
|
|
|
|
conv = f16.Convert(C128).ConsumeValueOrDie();
|
|
EXPECT_EQ(conv, c128);
|
|
|
|
EXPECT_EQ(s32.Convert(TUPLE).status().code(),
|
|
tensorflow::error::UNIMPLEMENTED);
|
|
EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
|
|
EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
|
|
EXPECT_EQ(c128.Convert(F32).status().code(),
|
|
tensorflow::error::UNIMPLEMENTED);
|
|
EXPECT_EQ(c128.Convert(S32).status().code(),
|
|
tensorflow::error::UNIMPLEMENTED);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, BitcastConvert) {
|
|
auto original = LiteralUtil::CreateR1<uint32>(
|
|
{absl::bit_cast<uint32>(2.5f), absl::bit_cast<uint32>(-42.25f),
|
|
absl::bit_cast<uint32>(100.f), 0xbeef});
|
|
auto expected = LiteralUtil::CreateR1<float>(
|
|
{2.5f, -42.25f, 100.0f, absl::bit_cast<float>(0xbeef)});
|
|
TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
|
|
auto literal = LiteralUtil::CreateR0<uint32>(1234);
|
|
Status status = literal.BitcastConvert(F64).status();
|
|
EXPECT_NE(Status::OK(), status);
|
|
EXPECT_TRUE(
|
|
absl::StrContains(status.error_message(), "bit widths are different"));
|
|
}
|
|
|
|
// Sets the layout of the given ShapeProto to the default.
|
|
void SetDefaultLayoutOnProto(ShapeProto* shape_proto) {
|
|
CHECK(ShapeUtil::IsArrayPrimitiveType(shape_proto->element_type()));
|
|
shape_proto->mutable_layout()->set_format(DENSE);
|
|
auto* minor_to_major =
|
|
shape_proto->mutable_layout()->mutable_minor_to_major();
|
|
minor_to_major->Resize(shape_proto->dimensions_size(), 0);
|
|
const int64 size = minor_to_major->size();
|
|
for (int64 i = 0; i < size; ++i) {
|
|
minor_to_major->Set(i, size - 1 - i);
|
|
}
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
|
|
LiteralProto p;
|
|
p.mutable_shape()->set_element_type(PRED);
|
|
for (int len = 0; len < 25; ++len) {
|
|
p.mutable_shape()->clear_dimensions();
|
|
p.mutable_shape()->add_dimensions(len);
|
|
SetDefaultLayoutOnProto(p.mutable_shape());
|
|
p.clear_preds();
|
|
for (int i = 0; i < len; ++i) {
|
|
p.add_preds((i % 2) == (len % 2));
|
|
}
|
|
|
|
TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
|
|
ASSERT_EQ(len, literal.data<bool>().size());
|
|
int i = 0;
|
|
for (bool value : literal.data<bool>()) {
|
|
EXPECT_EQ((i % 2) == (len % 2), value);
|
|
++i;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Note that f16 is currently stored in a byte array in little endian byte order
|
|
TEST_F(LiteralUtilTest, ToProto_f16) {
|
|
half h1(1.0f);
|
|
half h2(2.0f);
|
|
|
|
auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
|
|
EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape()));
|
|
EXPECT_EQ(4, m.data<half>().size());
|
|
|
|
LiteralProto p = m.ToProto();
|
|
EXPECT_EQ(4, ShapeUtil::ElementsIn(Shape(p.shape())));
|
|
EXPECT_EQ(8, p.f16s().size());
|
|
const char* d = p.f16s().data();
|
|
EXPECT_EQ(d[0], 0);
|
|
EXPECT_EQ(d[1], 0x3C);
|
|
EXPECT_EQ(d[2], 0);
|
|
EXPECT_EQ(d[3], 0x40);
|
|
EXPECT_EQ(d[4], 0);
|
|
EXPECT_EQ(d[5], 0x40);
|
|
EXPECT_EQ(d[6], 0);
|
|
EXPECT_EQ(d[7], 0x3C);
|
|
}
|
|
|
|
// Note that f16 is currently stored in a byte array in little endian byte order
|
|
TEST_F(LiteralUtilTest, CopyFromProto_f16) {
|
|
half h1(1.0f);
|
|
half h2(2.0f);
|
|
|
|
const char half_vals[8] = {0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C};
|
|
LiteralProto p;
|
|
p.mutable_shape()->set_element_type(F16);
|
|
p.mutable_shape()->clear_dimensions();
|
|
p.mutable_shape()->add_dimensions(4);
|
|
SetDefaultLayoutOnProto(p.mutable_shape());
|
|
p.clear_f16s();
|
|
p.set_f16s(half_vals, 8);
|
|
TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
|
|
auto r = literal.data<half>();
|
|
ASSERT_EQ(4, r.size());
|
|
EXPECT_EQ(h1, r[0]);
|
|
EXPECT_EQ(h2, r[1]);
|
|
EXPECT_EQ(h2, r[2]);
|
|
EXPECT_EQ(h1, r[3]);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, CopyFromProto_u16) {
|
|
uint16 u1(0xabcd);
|
|
uint16 u2(0x1234);
|
|
|
|
const unsigned char uint16_vals[8] = {0xcd, 0xab, 0x34, 0x12,
|
|
0x34, 0x12, 0xcd, 0xab};
|
|
LiteralProto p;
|
|
p.mutable_shape()->set_element_type(U16);
|
|
p.mutable_shape()->clear_dimensions();
|
|
p.mutable_shape()->add_dimensions(4);
|
|
SetDefaultLayoutOnProto(p.mutable_shape());
|
|
p.clear_u16s();
|
|
p.set_u16s(uint16_vals, 8);
|
|
TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
|
|
auto r = literal.data<uint16>();
|
|
ASSERT_EQ(4, r.size());
|
|
EXPECT_EQ(u1, r[0]);
|
|
EXPECT_EQ(u2, r[1]);
|
|
EXPECT_EQ(u2, r[2]);
|
|
EXPECT_EQ(u1, r[3]);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, LiteralDynamicSliceTest) {
|
|
auto scalar = LiteralUtil::CreateR0<float>(1.0);
|
|
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
|
|
auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
|
|
Literal nil(ShapeUtil::MakeNil());
|
|
|
|
EXPECT_EQ(LiteralSlice(scalar, {}), scalar);
|
|
EXPECT_EQ(LiteralSlice(matrix, {}), matrix);
|
|
EXPECT_EQ(LiteralSlice(tuple, {}), tuple);
|
|
EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple);
|
|
EXPECT_EQ(LiteralSlice(nil, {}), nil);
|
|
|
|
EXPECT_EQ(LiteralSlice(tuple, {0}), scalar);
|
|
EXPECT_EQ(LiteralSlice(tuple, {1}), matrix);
|
|
|
|
EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple);
|
|
EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar);
|
|
EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix);
|
|
EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
|
|
auto scalar = LiteralUtil::CreateR0<float>(1.0);
|
|
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
|
|
auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
|
|
// Verify that changing the underlying data beneath the view changes the
|
|
// data of the view itself.
|
|
const auto nested_tuple_view = LiteralSlice(nested_tuple);
|
|
EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
|
|
1.0f);
|
|
EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
|
|
/*shape_index=*/{0, 0}),
|
|
1.0f);
|
|
nested_tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
|
|
EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
|
|
555.0f);
|
|
EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
|
|
/*shape_index=*/{0, 0}),
|
|
555.0f);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
|
|
auto scalar = LiteralUtil::CreateR0<float>(1.0);
|
|
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
|
|
auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
|
|
|
|
const auto nested_tuple_view = LiteralSlice(nested_tuple);
|
|
const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
|
|
const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
|
|
EXPECT_EQ(matrix_view,
|
|
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
|
|
std::vector<int64> int64_values = {1, 2, 3};
|
|
const Shape literal_shape = ShapeUtil::MakeShape(S64, {3});
|
|
|
|
BorrowingLiteral literal(reinterpret_cast<const char*>(int64_values.data()),
|
|
literal_shape);
|
|
|
|
EXPECT_EQ(literal.Get<int64>({0}), 1);
|
|
EXPECT_EQ(literal.Get<int64>({1}), 2);
|
|
EXPECT_EQ(literal.Get<int64>({2}), 3);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
|
|
std::vector<int64> one_two_three = {1, 2, 3};
|
|
const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3});
|
|
|
|
std::vector<int64> hundred = {100};
|
|
const Shape hundred_shape = ShapeUtil::MakeShape(S64, {1});
|
|
|
|
std::vector<const char*> src_buf_ptrs;
|
|
src_buf_ptrs.emplace_back(
|
|
reinterpret_cast<const char*>(one_two_three.data()));
|
|
src_buf_ptrs.emplace_back(reinterpret_cast<const char*>(hundred.data()));
|
|
auto literal_tuple = BorrowingLiteral(
|
|
src_buf_ptrs,
|
|
ShapeUtil::MakeTupleShape({one_two_three_shape, hundred_shape}));
|
|
|
|
EXPECT_EQ(literal_tuple.Get<int64>(/*multi_index=*/{0}, /*shape_index=*/{0}),
|
|
1);
|
|
EXPECT_EQ(literal_tuple.Get<int64>(/*multi_index=*/{0}, /*shape_index=*/{1}),
|
|
100);
|
|
|
|
EXPECT_EQ(literal_tuple.Get<int64>(/*multi_index=*/{1}, /*shape_index=*/{0}),
|
|
2);
|
|
|
|
EXPECT_EQ(literal_tuple.Get<int64>(/*multi_index=*/{2}, /*shape_index=*/{0}),
|
|
3);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, LiteralMove) {
|
|
Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
Literal literal(std::move(matrix));
|
|
|
|
EXPECT_TRUE(
|
|
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
|
|
EXPECT_EQ(literal.Get<float>({0, 0}), 1.0);
|
|
EXPECT_EQ(literal.Get<float>({0, 1}), 2.0);
|
|
EXPECT_EQ(literal.Get<float>({1, 0}), 3.0);
|
|
EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, DecomposeTuple) {
|
|
Literal nil_literal(ShapeUtil::MakeNil());
|
|
Literal inner_elements[] = {
|
|
LiteralUtil::CreateR0<int32>(42),
|
|
LiteralUtil::CreateR1<double>({23.0, 44.0}),
|
|
};
|
|
Literal tuple_elements[] = {
|
|
LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}),
|
|
LiteralUtil::MakeTuple(
|
|
{&inner_elements[0], &inner_elements[1], &nil_literal}),
|
|
};
|
|
Literal nested_tuple = LiteralUtil::MakeTuple(
|
|
{&tuple_elements[0], &tuple_elements[1], &nil_literal});
|
|
|
|
EXPECT_FALSE(ShapeUtil::IsEmptyTuple(nested_tuple.shape()));
|
|
std::vector<Literal> elements = nested_tuple.DecomposeTuple();
|
|
EXPECT_TRUE(ShapeUtil::IsEmptyTuple(nested_tuple.shape()));
|
|
|
|
ASSERT_EQ(elements.size(), 3);
|
|
|
|
EXPECT_TRUE(ShapeUtil::Compatible(elements[0].shape(),
|
|
ShapeUtil::MakeShape(S32, {2, 2})));
|
|
EXPECT_EQ(elements[0].Get<int32>({0, 0}), 1);
|
|
EXPECT_EQ(elements[0].Get<int32>({0, 1}), 2);
|
|
EXPECT_EQ(elements[0].Get<int32>({1, 0}), 3);
|
|
EXPECT_EQ(elements[0].Get<int32>({1, 1}), 4);
|
|
|
|
EXPECT_TRUE(ShapeUtil::Compatible(
|
|
elements[1].shape(),
|
|
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}),
|
|
ShapeUtil::MakeShape(F64, {2}),
|
|
ShapeUtil::MakeNil()})));
|
|
EXPECT_EQ(elements[1].Get<int32>({}, /*shape_index=*/{0}), 42);
|
|
EXPECT_EQ(elements[1].Get<double>({0}, /*shape_index=*/{1}), 23.0);
|
|
EXPECT_EQ(elements[1].Get<double>({1}, /*shape_index=*/{1}), 44.0);
|
|
|
|
EXPECT_TRUE(ShapeUtil::Compatible(elements[2].shape(), ShapeUtil::MakeNil()));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
|
|
Literal nil_literal(ShapeUtil::MakeNil());
|
|
std::vector<Literal> elements = nil_literal.DecomposeTuple();
|
|
EXPECT_EQ(elements.size(), 0);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, MoveIntoTuple) {
|
|
std::vector<Literal> elements;
|
|
elements.push_back(LiteralUtil::CreateR0<float>(1.0));
|
|
elements.push_back(LiteralUtil::CreateR1<int32>({4, 8}));
|
|
std::vector<Literal> inner_elements;
|
|
inner_elements.push_back(LiteralUtil::CreateR0<int32>(42));
|
|
inner_elements.push_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));
|
|
elements.push_back(
|
|
LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]}));
|
|
|
|
Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements));
|
|
ASSERT_TRUE(literal.shape().IsTuple());
|
|
ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3);
|
|
|
|
EXPECT_EQ(literal.Get<float>({}, /*shape_index=*/{0}), 1.0);
|
|
EXPECT_EQ(literal.Get<int32>({0}, /*shape_index=*/{1}), 4);
|
|
EXPECT_EQ(literal.Get<int32>({1}, /*shape_index=*/{1}), 8);
|
|
EXPECT_EQ(literal.Get<int32>({}, /*shape_index=*/{2, 0}), 42);
|
|
EXPECT_EQ(literal.Get<double>({0}, /*shape_index=*/{2, 1}), 23.0);
|
|
EXPECT_EQ(literal.Get<double>({1}, /*shape_index=*/{2, 1}), 44.0);
|
|
|
|
for (const Literal& element : elements) {
|
|
EXPECT_TRUE(ShapeUtil::IsEmptyTuple(element.shape()));
|
|
}
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) {
|
|
Literal literal = Literal::MoveIntoTuple({});
|
|
ASSERT_TRUE(literal.shape().IsTuple());
|
|
EXPECT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
|
|
Literal literal;
|
|
EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
|
|
|
|
Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
literal = std::move(matrix);
|
|
|
|
EXPECT_TRUE(
|
|
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
|
|
EXPECT_EQ(literal.Get<float>({0, 0}), 1.0);
|
|
EXPECT_EQ(literal.Get<float>({0, 1}), 2.0);
|
|
EXPECT_EQ(literal.Get<float>({1, 0}), 3.0);
|
|
EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, LiteralSliceCopy) {
|
|
Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
const auto matrix_view = LiteralSlice(matrix);
|
|
LiteralSlice matrix_view_copy(matrix_view);
|
|
|
|
EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
|
|
EXPECT_EQ(matrix_view_copy.Get<float>({0, 1}), 2.0);
|
|
EXPECT_EQ(matrix_view_copy.Get<float>({1, 0}), 3.0);
|
|
EXPECT_EQ(matrix_view_copy.Get<float>({1, 1}), 4.0);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, GetSetTuple) {
|
|
Literal elements[] = {
|
|
LiteralUtil::CreateR0<float>(42.0),
|
|
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
|
|
};
|
|
auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
|
|
EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
|
|
tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
|
|
EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
|
|
|
|
EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0);
|
|
tuple.Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
|
|
EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
|
|
-4.0);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
|
|
// Literals constructed using CreateFromShape should be zero initialized.
|
|
Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
|
|
EXPECT_EQ(scalar_f32.Get<float>({}), 0.0);
|
|
EXPECT_TRUE(scalar_f32.IsAll(0));
|
|
|
|
Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
|
|
EXPECT_EQ(vector_s32.Get<int32>({0}), 0);
|
|
EXPECT_EQ(vector_s32.Get<int32>({1}), 0);
|
|
EXPECT_EQ(vector_s32.Get<int32>({2}), 0);
|
|
EXPECT_TRUE(vector_s32.IsAll(0));
|
|
|
|
Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
|
|
{ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
|
|
ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {}),
|
|
ShapeUtil::MakeShape(C128, {})}));
|
|
|
|
EXPECT_EQ(tuple.Get<double>({}, {0}), 0.0);
|
|
EXPECT_EQ(tuple.Get<bool>({0}, {1}), false);
|
|
EXPECT_EQ(tuple.Get<bool>({1}, {1}), false);
|
|
EXPECT_EQ(tuple.Get<uint64>({0, 0}, {2}), 0);
|
|
EXPECT_EQ(tuple.Get<uint64>({1, 0}, {2}), 0);
|
|
EXPECT_EQ(tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
|
|
EXPECT_EQ(tuple.Get<complex128>({}, {4}), complex128(0.0, 0.0));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
|
|
// Test serializing then deserializing a Literal through a proto.
|
|
auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
|
|
auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
|
|
auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
|
|
auto vector_uint8 = LiteralUtil::CreateR1<uint8>({128, 0, 2, 56, 127, 255});
|
|
auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
|
|
auto vector_c128 =
|
|
LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
|
|
auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
|
|
{bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
|
|
auto vector_half =
|
|
LiteralUtil::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
|
|
auto matrix_pred =
|
|
LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
|
|
auto tuple = LiteralUtil::MakeTuple(
|
|
{&one_f32, &vector_half, &matrix_pred, &matrix_pred});
|
|
Literal nil_literal(ShapeUtil::MakeNil());
|
|
auto nested_tuple =
|
|
LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal});
|
|
|
|
auto to_from_proto = [](const Literal& literal) -> Literal {
|
|
return Literal::CreateFromProto(literal.ToProto()).ValueOrDie();
|
|
};
|
|
|
|
EXPECT_EQ(one_f32, to_from_proto(one_f32));
|
|
EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
|
|
EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
|
|
EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
|
|
EXPECT_EQ(vector_c128, to_from_proto(vector_c128));
|
|
EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
|
|
EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
|
|
EXPECT_EQ(tuple, to_from_proto(tuple));
|
|
EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple));
|
|
EXPECT_EQ(nil_literal, to_from_proto(nil_literal));
|
|
|
|
EXPECT_NE(one_f32, two_f32);
|
|
EXPECT_NE(one_f32, to_from_proto(two_f32));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
|
|
// Proto contains a shape, but no values.
|
|
LiteralProto proto;
|
|
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
|
|
Status status = Literal::CreateFromProto(proto).status();
|
|
ASSERT_FALSE(status.ok());
|
|
EXPECT_THAT(status.error_message(),
|
|
HasSubstr("Expected 3 elements in LiteralProto"));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, ValidProtoNoValues) {
|
|
// Proto contains a shape, but no values.
|
|
LiteralProto proto;
|
|
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
|
|
Status status =
|
|
Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false)
|
|
.status();
|
|
EXPECT_TRUE(status.ok());
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, ValidProtoWithClearedValues) {
|
|
auto literal = LiteralUtil::CreateR1<bool>({true, false, true});
|
|
LiteralProto proto = literal.ToProto();
|
|
EXPECT_EQ(proto.preds_size(), 3);
|
|
|
|
// Clear values.
|
|
proto.clear_preds();
|
|
EXPECT_EQ(proto.preds_size(), 0);
|
|
Status status =
|
|
Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false)
|
|
.status();
|
|
EXPECT_TRUE(status.ok());
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
|
|
// Proto contains values, but no shape.
|
|
LiteralProto proto;
|
|
proto.add_preds(false);
|
|
proto.add_preds(true);
|
|
proto.add_preds(false);
|
|
Status status = Literal::CreateFromProto(proto).status();
|
|
ASSERT_FALSE(status.ok());
|
|
EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape"));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
|
|
// Proto contains values in wrong container.
|
|
LiteralProto proto;
|
|
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
|
|
proto.add_preds(false);
|
|
proto.add_preds(true);
|
|
proto.add_preds(false);
|
|
Status status = Literal::CreateFromProto(proto).status();
|
|
ASSERT_FALSE(status.ok());
|
|
EXPECT_THAT(status.error_message(),
|
|
HasSubstr("Expected 3 elements in LiteralProto"));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
|
|
// Proto contains too few values.
|
|
LiteralProto proto;
|
|
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}).ToProto();
|
|
proto.add_f32s(1.0);
|
|
proto.add_f32s(2.0);
|
|
proto.add_f32s(3.0);
|
|
Status status = Literal::CreateFromProto(proto).status();
|
|
ASSERT_FALSE(status.ok());
|
|
EXPECT_THAT(status.error_message(),
|
|
HasSubstr("Expected 84 elements in LiteralProto"));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
|
|
// Proto contains too many values.
|
|
LiteralProto proto;
|
|
*proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}).ToProto();
|
|
proto.add_s32s(42);
|
|
proto.add_s32s(-10);
|
|
proto.add_s32s(100);
|
|
Status status = Literal::CreateFromProto(proto).status();
|
|
ASSERT_FALSE(status.ok());
|
|
EXPECT_THAT(status.error_message(),
|
|
HasSubstr("Expected 2 elements in LiteralProto"));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
|
|
// Proto shape missing layout.
|
|
LiteralProto proto;
|
|
*proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}).ToProto();
|
|
proto.mutable_shape()->clear_layout();
|
|
proto.add_preds(true);
|
|
proto.add_preds(false);
|
|
proto.add_preds(true);
|
|
proto.add_preds(false);
|
|
Status status = Literal::CreateFromProto(proto).status();
|
|
ASSERT_FALSE(status.ok());
|
|
EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout"));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
|
|
// Proto has the too few tuple elements.
|
|
LiteralProto proto;
|
|
*proto.mutable_shape() =
|
|
ShapeUtil::MakeTupleShape(
|
|
{ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})})
|
|
.ToProto();
|
|
LiteralProto* element0 = proto.add_tuple_literals();
|
|
*element0->mutable_shape() =
|
|
ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto();
|
|
element0->add_preds(false);
|
|
element0->add_preds(true);
|
|
|
|
Status status = Literal::CreateFromProto(proto).status();
|
|
ASSERT_FALSE(status.ok());
|
|
EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
|
|
// Proto has the too many tuple elements.
|
|
LiteralProto proto;
|
|
*proto.mutable_shape() =
|
|
ShapeUtil::MakeTupleShape(
|
|
{ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})})
|
|
.ToProto();
|
|
LiteralProto* element0 = proto.add_tuple_literals();
|
|
*element0->mutable_shape() =
|
|
ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto();
|
|
element0->add_preds(false);
|
|
element0->add_preds(true);
|
|
LiteralProto* element1 = proto.add_tuple_literals();
|
|
*element1->mutable_shape() =
|
|
ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 1).ToProto();
|
|
element1->add_f32s(42.0);
|
|
LiteralProto* element2 = proto.add_tuple_literals();
|
|
*element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}).ToProto();
|
|
element2->add_f32s(123.0);
|
|
|
|
Status status = Literal::CreateFromProto(proto).status();
|
|
ASSERT_FALSE(status.ok());
|
|
EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
|
|
Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
|
|
TF_ASSERT_OK_AND_ASSIGN(
|
|
Literal broadcasted_literal,
|
|
literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
|
|
/*dimensions=*/{0}));
|
|
EXPECT_EQ(broadcasted_literal,
|
|
LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
|
|
Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
|
|
TF_ASSERT_OK_AND_ASSIGN(
|
|
Literal broadcasted_literal,
|
|
literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
|
|
/*dimensions=*/{1}));
|
|
EXPECT_EQ(broadcasted_literal,
|
|
LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
|
|
Literal literal = LiteralUtil::CreateR0<int32>(9);
|
|
TF_ASSERT_OK_AND_ASSIGN(
|
|
Literal broadcasted_literal,
|
|
literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
|
|
/*dimensions=*/{}));
|
|
EXPECT_EQ(broadcasted_literal,
|
|
LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, DynamicBroadcast) {
|
|
Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
|
|
literal.SetDynamicSize(0, 1);
|
|
TF_ASSERT_OK_AND_ASSIGN(
|
|
Literal broadcasted_literal,
|
|
literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
|
|
/*dimensions=*/{1}));
|
|
EXPECT_EQ(broadcasted_literal, LiteralUtil::CreateR2<int64>({{1}, {1}}));
|
|
EXPECT_EQ(broadcasted_literal.GetDynamicSize(1), 1);
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, GetAsComplex128) {
|
|
complex128 value = {1, 0};
|
|
Literal c1 = LiteralUtil::CreateR0<complex128>(value);
|
|
EXPECT_EQ(*c1.GetAsComplex128({}), value);
|
|
Literal c2 = LiteralUtil::CreateR0<double>(1);
|
|
EXPECT_EQ(*c2.GetAsComplex128({}), value);
|
|
complex64 float_value = {1, 0};
|
|
Literal c4 = LiteralUtil::CreateR0<complex64>(float_value);
|
|
EXPECT_EQ(*c4.GetAsComplex128({}), value);
|
|
complex128 other_value = {1, 2};
|
|
Literal c5 = LiteralUtil::CreateR0<complex128>(other_value);
|
|
EXPECT_EQ(*c5.GetAsComplex128({}), other_value);
|
|
Literal c6 = LiteralUtil::CreateR0<int64>(1);
|
|
EXPECT_FALSE(c6.GetAsComplex128({}).has_value());
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, SliceOnBool) {
|
|
Literal c1 = LiteralUtil::CreateR1<bool>({true, true, false});
|
|
EXPECT_EQ(c1, c1.Slice({0}, {3}));
|
|
}
|
|
|
|
TEST_F(LiteralUtilTest, IsEqualAt) {
|
|
double val_double = 10.0;
|
|
int val_integral = 10;
|
|
Literal c1 = LiteralUtil::CreateR0<int>(10);
|
|
EXPECT_TRUE(c1.IsEqualAt({}, val_double));
|
|
EXPECT_TRUE(c1.IsEqualAt({}, val_integral));
|
|
Literal c2 = LiteralUtil::CreateR0<double>(10);
|
|
EXPECT_TRUE(c2.IsEqualAt({}, val_double));
|
|
EXPECT_TRUE(c2.IsEqualAt({}, val_integral));
|
|
complex128 val_complex = {10, 0};
|
|
EXPECT_TRUE(c2.IsEqualAt({}, val_complex));
|
|
EXPECT_TRUE(c1.IsEqualAt({}, val_complex));
|
|
Literal c3 = LiteralUtil::CreateR0<complex128>(val_complex);
|
|
EXPECT_TRUE(c3.IsEqualAt({}, val_double));
|
|
EXPECT_TRUE(c3.IsEqualAt({}, val_integral));
|
|
EXPECT_TRUE(c3.IsEqualAt({}, val_complex));
|
|
EXPECT_FALSE(c3.IsEqualAt({}, std::numeric_limits<double>::infinity()));
|
|
complex128 val_true_complex = {10, 3};
|
|
complex64 val_smaller_complex = {10, 3};
|
|
Literal c4 = LiteralUtil::CreateR0<complex128>(val_true_complex);
|
|
EXPECT_TRUE(c4.IsEqualAt({}, val_true_complex));
|
|
EXPECT_TRUE(c4.IsEqualAt({}, val_smaller_complex));
|
|
}
|
|
|
|
} // namespace
|
|
} // namespace xla
|