[XLA] Break out literal comparisons from testonly target.

Moves methods from LiteralTestUtil::* to Literal::* where they have nothing
to do with test infrastructure.

Pares down the "void" variants of the LiteralTestUtil methods and consolidates
to the version that return success/failure such that the values can be
EXPECT_TRUE / ASSERT_TRUE asserted in the caller test cases.

This way the literal comparison functionality can be used from cc_libraries
that are not test only / cc_binary.

PiperOrigin-RevId: 196209410
This commit is contained in:
Chris Leary 2018-05-10 20:10:34 -07:00 committed by TensorFlower Gardener
parent 5a492ef9bb
commit 400dd49b4c
32 changed files with 843 additions and 782 deletions

View File

@ -225,7 +225,7 @@ TEST_F(XlaCompilerTest, Simple) {
xla::Literal::CreateR1<int32>({4, 143}); xla::Literal::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal = std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get()}); xla::Literal::MakeTuple({expected0.get()});
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
} }
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
@ -320,7 +320,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
xla::Literal::CreateR1<int32>({-7, -42}); xla::Literal::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected_literal = std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get()}); xla::Literal::MakeTuple({expected0.get()});
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); EXPECT_TRUE(
xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
} }
{ {
@ -355,7 +356,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
xla::Literal::CreateR1<int32>({-7, -42}); xla::Literal::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected = std::unique_ptr<xla::Literal> expected =
xla::Literal::MakeTuple({expected0.get(), expected1.get()}); xla::Literal::MakeTuple({expected0.get(), expected1.get()});
xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
} }
} }
@ -523,7 +524,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
{output_base.get(), output_grad1.get(), output_grad2.get()}); {output_base.get(), output_grad1.get(), output_grad2.get()});
std::unique_ptr<xla::Literal> expected_literal = std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
} }
// Tests compilation and execution of a graph that adds two tensors. // Tests compilation and execution of a graph that adds two tensors.
@ -746,7 +747,7 @@ TEST_F(XlaCompilerTest, Variables) {
xla::Literal::CreateR1<int32>({4, 143}); xla::Literal::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal = std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get(), expected1.get()}); xla::Literal::MakeTuple({expected0.get(), expected1.get()});
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
} }
// Tests a simple graph that reads and writes a variable, with a // Tests a simple graph that reads and writes a variable, with a
@ -811,7 +812,7 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
xla::Literal::CreateR1<int32>({26, 66, 34, 401}); xla::Literal::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal = std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get(), expected1.get()}); xla::Literal::MakeTuple({expected0.get(), expected1.get()});
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
} }
} // namespace } // namespace

View File

@ -330,6 +330,17 @@ tf_cc_test(
], ],
) )
cc_library(
name = "literal_comparison",
srcs = ["literal_comparison.cc"],
hdrs = ["literal_comparison.h"],
deps = [
":literal_util",
":util",
"//tensorflow/core:lib",
],
)
cc_library( cc_library(
name = "metric_table_report", name = "metric_table_report",
srcs = ["metric_table_report.cc"], srcs = ["metric_table_report.cc"],

View File

@ -0,0 +1,226 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/literal_comparison.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/strings/strcat.h"
using tensorflow::strings::StrCat;
namespace xla {
namespace literal_comparison {
namespace {
// Helper function for comparing a floating point type, FloatT, bitwise equal
// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
// -- on miscompare, a nice error message is given in the AssertionFailure.
template <typename FloatT, typename UnsignedT>
Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
auto lhs_double = static_cast<double>(lhs);
auto rhs_double = static_cast<double>(rhs);
if (ulhs != urhs) {
return InvalidArgument(
"floating values are not bitwise-equal; and equality testing "
"was requested: %s=%g=%a vs %s=%g=%a",
StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double,
StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double);
}
return Status::OK();
}
// Templated comparator that specializes for float equality comparison with the
// bitwise helper above (this is the un-specialized fallback, to just use the
// default gunit implementation).
template <typename NativeT>
Status CompareEqual(NativeT lhs, NativeT rhs) {
if (lhs == rhs) {
return Status::OK();
}
return InvalidArgument("Expected equality of these values:\n %s\n %s",
StrCat(lhs).c_str(), StrCat(rhs).c_str());
}
// Specializations for floating types that do bitwise comparisons when equality
// comparison is requested.
template <>
Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
}
template <>
Status CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs) {
return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs);
}
template <>
Status CompareEqual<float>(float lhs, float rhs) {
return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
}
template <>
Status CompareEqual<double>(double lhs, double rhs) {
return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
}
template <>
Status CompareEqual<complex64>(complex64 lhs, complex64 rhs) {
auto res = CompareEqual<float>(lhs.real(), rhs.real());
if (!res.ok()) {
return res;
}
return CompareEqual<float>(lhs.imag(), rhs.imag());
}
// A recursive function which iterates through every index of expected and
// actual literal and compares their values elementwise. Returns true if all
// elements are equal.
template <typename NativeT>
Status Equal(LiteralSlice expected, LiteralSlice actual,
tensorflow::gtl::MutableArraySlice<int64> multi_index,
int64 dimension) {
if (dimension == expected.shape().dimensions_size()) {
NativeT expected_value = expected.Get<NativeT>(multi_index);
NativeT actual_value = actual.Get<NativeT>(multi_index);
return CompareEqual<NativeT>(expected_value, actual_value);
}
Status result;
for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
multi_index[dimension] = i;
result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1));
}
return result;
}
} // namespace
Status EqualShapes(const Shape& expected, const Shape& actual) {
if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
return InvalidArgument("tupleness-mismatch! want: %s got %s",
ShapeUtil::HumanString(expected).c_str(),
ShapeUtil::HumanString(actual).c_str());
}
if (ShapeUtil::IsTuple(expected)) {
if (ShapeUtil::TupleElementCount(expected) !=
ShapeUtil::TupleElementCount(actual)) {
return InvalidArgument(
"want tuple element count: %lld got tuple element count: %lld",
ShapeUtil::TupleElementCount(expected),
ShapeUtil::TupleElementCount(actual));
}
for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
Status result =
EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
if (!result.ok()) {
return AppendStatus(result, StrCat("mismatch in tuple index", i));
}
}
} else {
if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
return InvalidArgument("want rank of %s got rank of %s",
ShapeUtil::HumanString(expected).c_str(),
ShapeUtil::HumanString(actual).c_str());
}
if (expected.element_type() != actual.element_type()) {
return InvalidArgument(
"mismatch in primitive type %s vs %s",
PrimitiveType_Name(expected.element_type()).c_str(),
PrimitiveType_Name(actual.element_type()).c_str());
}
if (expected.dimensions_size() != actual.dimensions_size()) {
return InvalidArgument("want dimensions_size %d got dimensions_size %d",
expected.dimensions_size(),
actual.dimensions_size());
}
for (int i = 0; i < expected.dimensions_size(); ++i) {
if (expected.dimensions(i) != actual.dimensions(i)) {
return InvalidArgument(
"mismatch in dimension #%d expected: %s actual: %s", i,
ShapeUtil::HumanString(expected).c_str(),
ShapeUtil::HumanString(actual).c_str());
}
}
}
return Status::OK();
}
Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
VLOG(1) << "expected:";
XLA_VLOG_LINES(1, expected.ToString());
VLOG(1) << "actual:";
XLA_VLOG_LINES(1, actual.ToString());
TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
Status result;
switch (expected.shape().element_type()) {
case PRED:
result = Equal<bool>(expected, actual, &multi_index, 0);
break;
case U8:
result = Equal<uint8>(expected, actual, &multi_index, 0);
break;
case S32:
result = Equal<int32>(expected, actual, &multi_index, 0);
break;
case S64:
result = Equal<int64>(expected, actual, &multi_index, 0);
break;
case U32:
result = Equal<uint32>(expected, actual, &multi_index, 0);
break;
case U64:
result = Equal<uint64>(expected, actual, &multi_index, 0);
break;
case BF16:
result = Equal<bfloat16>(expected, actual, &multi_index, 0);
break;
case F16:
result = Equal<half>(expected, actual, &multi_index, 0);
break;
case F32:
result = Equal<float>(expected, actual, &multi_index, 0);
break;
case F64:
result = Equal<double>(expected, actual, &multi_index, 0);
break;
case C64:
result = Equal<complex64>(expected, actual, &multi_index, 0);
break;
case TUPLE: {
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
result.Update(
Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i})));
}
break;
}
default:
LOG(FATAL)
<< "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
<< PrimitiveType_Name(expected.shape().element_type());
}
if (result.ok()) {
return Status::OK();
}
return AppendStatus(result,
tensorflow::strings::Printf("expected: %s\nactual: %s",
expected.ToString().c_str(),
actual.ToString().c_str()));
}
} // namespace literal_comparison
} // namespace xla

View File

@ -0,0 +1,40 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Library for comparing literals without taking a dependency on testing
// libraries.
#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/lib/core/status.h"
namespace xla {
namespace literal_comparison {
// Returns ok if the given shapes have the same rank, dimension sizes, and
// primitive types.
Status EqualShapes(const Shape& expected, const Shape& actual);
// Returns ok if the expected and actual literals are (bitwise) equal for all
// elements in the literal. Also, asserts that the rank, dimensions sizes, and
// primitive type are equal.
Status Equal(const LiteralSlice& expected, const LiteralSlice& actual);
} // namespace literal_comparison
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_

View File

@ -62,6 +62,45 @@ void ConvertEndianShort(char* bytes, int64 size) {
} }
} }
// Return a literal with all arrays of type FromNativeT converted to type
// ToNativeT in the given literal.
template <typename FromNativeT, typename ToNativeT>
std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
// First construct shape of the result.
Shape result_shape(literal.shape());
ShapeUtil::ForEachMutableSubshape(
&result_shape, [](Shape* subshape, const ShapeIndex&) {
if (subshape->element_type() ==
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
subshape->set_element_type(
primitive_util::NativeToPrimitiveType<ToNativeT>());
}
});
auto result = MakeUnique<Literal>(result_shape);
// Then copy over the data from 'literal' converting FromNativeT values to
// ToNativeT values as necessary.
ShapeUtil::ForEachSubshape(
literal.shape(),
[&](const Shape& subshape, const ShapeIndex& shape_index) {
if (ShapeUtil::IsArray(subshape)) {
if (subshape.element_type() ==
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
auto src = literal.data<FromNativeT>(shape_index);
auto dest = result->data<ToNativeT>(shape_index);
for (int64 i = 0; i < src.size(); ++i) {
dest[i] = static_cast<ToNativeT>(src[i]);
}
} else {
TF_CHECK_OK(result->CopyFrom(literal,
/*dest_shape_index=*/shape_index,
/*src_shape_index=*/shape_index));
}
}
});
return result;
}
} // namespace } // namespace
LiteralBase::~LiteralBase() {} LiteralBase::~LiteralBase() {}
@ -195,6 +234,16 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
} }
/* static */ std::unique_ptr<Literal> Literal::ConvertBF16ToF32(
const LiteralSlice& bf16_literal) {
return ConvertType<bfloat16, float>(bf16_literal);
}
/* static */ std::unique_ptr<Literal> Literal::ConvertF32ToBF16(
const LiteralSlice& f32_literal) {
return ConvertType<float, bfloat16>(f32_literal);
}
template <typename NativeT> template <typename NativeT>
Status Literal::CopySliceFromInternal( Status Literal::CopySliceFromInternal(
const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base, const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
@ -788,6 +837,78 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
return std::move(output); return std::move(output);
} }
/* static */ std::unique_ptr<Literal> Literal::ReshapeSlice(
tensorflow::gtl::ArraySlice<int64> new_dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major,
const LiteralSlice& literal) {
int64 new_num_elements = 1;
for (int64 i = 0; i < new_dimensions.size(); ++i) {
new_num_elements *= new_dimensions[i];
}
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
CHECK_EQ(new_dimensions.size(), minor_to_major.size());
auto new_literal = MakeUnique<Literal>(
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
// Create a new shape with the given minor-to-major layout. This shape is used
// solely for converting linear address to multi-dimensional addresses when
// writing elements to the new literal.
Shape shape_with_layout = new_literal->shape();
*shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
// Copy data into new literal, element-by-element.
for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
std::vector<int64> from_multi_index =
IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
std::vector<int64> to_multi_index =
IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
switch (literal.shape().element_type()) {
case PRED:
new_literal->Set<bool>(to_multi_index,
literal.Get<bool>(from_multi_index));
break;
case U8:
new_literal->Set<uint8>(to_multi_index,
literal.Get<uint8>(from_multi_index));
break;
case U32:
new_literal->Set<uint32>(to_multi_index,
literal.Get<uint32>(from_multi_index));
break;
case S32:
new_literal->Set<int32>(to_multi_index,
literal.Get<int32>(from_multi_index));
break;
case U64:
new_literal->Set<uint64>(to_multi_index,
literal.Get<uint64>(from_multi_index));
break;
case S64:
new_literal->Set<int64>(to_multi_index,
literal.Get<int64>(from_multi_index));
break;
case F32:
new_literal->Set<float>(to_multi_index,
literal.Get<float>(from_multi_index));
break;
case F64:
new_literal->Set<double>(to_multi_index,
literal.Get<double>(from_multi_index));
break;
case C64:
new_literal->Set<complex64>(to_multi_index,
literal.Get<complex64>(from_multi_index));
break;
default:
LOG(FATAL) << "Unhandled primitive element type: "
<< PrimitiveType_Name(literal.shape().element_type());
}
}
return new_literal;
}
std::unique_ptr<Literal> LiteralBase::Transpose( std::unique_ptr<Literal> LiteralBase::Transpose(
tensorflow::gtl::ArraySlice<int64> permutation) const { tensorflow::gtl::ArraySlice<int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
@ -2123,6 +2244,11 @@ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
return std::move(literal); return std::move(literal);
} }
/* static */ string Literal::MultiIndexAsString(
tensorflow::gtl::ArraySlice<int64> multi_index) {
return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
}
const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
return piece(shape_index).untyped_data(); return piece(shape_index).untyped_data();
} }

View File

@ -920,9 +920,66 @@ class Literal : public LiteralBase {
PrimitiveType primitive_type, PrimitiveType primitive_type,
tensorflow::gtl::ArraySlice<int64> dimensions); tensorflow::gtl::ArraySlice<int64> dimensions);
// If the given literal's data type is bfloat16, converts it to a float
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
static std::unique_ptr<Literal> ConvertBF16ToF32(
const LiteralSlice& bf16_literal);
// If the given literal's data type is float, converts it to a bfloat16
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
static std::unique_ptr<Literal> ConvertF32ToBF16(
const LiteralSlice& f32_literal);
// Creates a literal with a new shape with the given new dimensions using the
// data in the given input literal. For reshaping purposes the (flat) data
// buffer of the input literal is assumed to have the given minor_to_major
// layout order.
static std::unique_ptr<Literal> ReshapeSlice(
tensorflow::gtl::ArraySlice<int64> new_dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major,
const LiteralSlice& literal);
// Creates a literal with the supplied shape, and uses the provided value
// generator to populate the literal's values.
// Returns the new literal object, or an error Status if failed.
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape,
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
// deviation, and using the engine as entropy generator.
// Returns the new literal object, or an error Status if failed.
template <
PrimitiveType type, typename E,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape, E* engine, T mean, T stddev);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
// deviation.
// Returns the new literal object, or an error Status if failed.
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape, T mean, T stddev);
// //
// End of factory methods. // End of factory methods.
// Returns a multi-dimensional index as a string. For example: '{7, 8}' will
// be returned for a 2-dimensional index with dimension 0 index equal to 7,
// dimension 1 equal to 8.
static string MultiIndexAsString(
tensorflow::gtl::ArraySlice<int64> multi_index);
protected: protected:
// Recursively sets the subshapes and buffers of all subpieces rooted at // Recursively sets the subshapes and buffers of all subpieces rooted at
// 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
@ -1558,6 +1615,38 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
return literal; return literal;
} }
template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
const Shape& shape,
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
[&](tensorflow::gtl::ArraySlice<int64> indexes) {
return generator(indexes);
}));
return std::move(literal);
}
template <PrimitiveType type, typename E, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
const Shape& shape, E* engine, T mean, T stddev) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
return generator(*engine);
});
}
template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
const Shape& shape, T mean, T stddev) {
std::minstd_rand0 engine;
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}
} // namespace xla } // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_

View File

@ -101,8 +101,8 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
computation, {}, nullptr)); computation, {}, nullptr));
LiteralTestUtil::ExpectNear(*expected_literal, *result_literal, EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
ErrorSpec(0.0001)); ErrorSpec(0.0001)));
} }
} // namespace } // namespace

View File

@ -149,12 +149,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
EXPECT_TRUE(OutputsBF16(dot->operand(1))); EXPECT_TRUE(OutputsBF16(dot->operand(1)));
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(LiteralTestUtil::Equal(
dot->operand(0)->literal(), dot->operand(0)->literal(),
*LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))); *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))));
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(LiteralTestUtil::Equal(
dot->operand(1)->literal(), dot->operand(1)->literal(),
*LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))); *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))));
} }
// Tests that BF16 can be propagated through nested tuples. // Tests that BF16 can be propagated through nested tuples.

View File

@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
const int64 slice_limits[] = {10, 8, 6, 5, 9}; const int64 slice_limits[] = {10, 8, 6, 5, 9};
const int64 slice_strides[] = {1, 1, 1, 1, 1}; const int64 slice_strides[] = {1, 1, 1, 1, 1};
TF_ASSERT_OK_AND_ASSIGN(auto literal, TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralTestUtil::CreateRandomLiteral<F32>( Literal::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal))); HloInstruction::CreateConstant(std::move(literal)));
@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
HloComputation::Builder builder(TestName()); HloComputation::Builder builder(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9}; const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSERT_OK_AND_ASSIGN(auto literal, TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralTestUtil::CreateRandomLiteral<F32>( Literal::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = literal->Literal::CloneToUnique(); auto literal_clone = literal->Literal::CloneToUnique();
HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction* literal_instruction = builder.AddInstruction(

View File

@ -72,7 +72,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
auto result = ExecuteAndTransfer(std::move(module), {}); auto result = ExecuteAndTransfer(std::move(module), {});
auto expected = Literal::CreateR0<float>(84.0); auto expected = Literal::CreateR0<float>(84.0);
LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
} }
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
@ -104,7 +104,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
auto result = ExecuteAndTransfer(std::move(module), {}); auto result = ExecuteAndTransfer(std::move(module), {});
auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
} }
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
@ -134,7 +134,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
auto result = ExecuteAndTransfer(std::move(module), {}); auto result = ExecuteAndTransfer(std::move(module), {});
auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
} }
TEST_F(HloCseTest, ConstantsSameValueDifferentType) { TEST_F(HloCseTest, ConstantsSameValueDifferentType) {

View File

@ -82,9 +82,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
auto element_type = expected->shape().element_type(); auto element_type = expected->shape().element_type();
if (element_type == F32 || element_type == F64) { if (element_type == F32 || element_type == F64) {
ErrorSpec error(aabs); ErrorSpec error(aabs);
LiteralTestUtil::ExpectNear(*expected, *result, error); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error));
} else { } else {
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
} }
@ -100,7 +100,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
std::unique_ptr<Literal> result = Evaluate(); std::unique_ptr<Literal> result = Evaluate();
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
bool use_bfloat16_; bool use_bfloat16_;
@ -129,7 +129,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
auto expected = Literal::CreateR2<float>({{0, 4}, {2, 4}}); auto expected = Literal::CreateR2<float>({{0, 4}, {2, 4}});
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
@ -150,7 +150,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
auto expected = Literal::CreateR2<float>({{0, 0}, {1, 1}}); auto expected = Literal::CreateR2<float>({{0, 0}, {1, 1}});
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
// Verifies that HloEvaluator evaluates a HLO instruction that performs select // Verifies that HloEvaluator evaluates a HLO instruction that performs select
@ -175,7 +175,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
auto expected = Literal::CreateR2<float>({{2, 5}, {0, 4}}); auto expected = Literal::CreateR2<float>({{2, 5}, {0, 4}});
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
// Verifies that HloEvaluator evaluates a HLO instruction that performs // Verifies that HloEvaluator evaluates a HLO instruction that performs
@ -307,7 +307,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
auto expected = Literal::CreateR2<int64>({{4, -16}, {-196, 12}}); auto expected = Literal::CreateR2<int64>({{4, -16}, {-196, 12}});
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
// Verifies Reshape operation is correctly evaluated. // Verifies Reshape operation is correctly evaluated.
@ -315,7 +315,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
HloComputation::Builder b(TestName()); HloComputation::Builder b(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9}; const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSERT_OK_AND_ASSIGN(auto literal, TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralTestUtil::CreateRandomLiteral<F32>( Literal::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = literal->CloneToUnique(); auto literal_clone = literal->CloneToUnique();
HloInstruction* literal_instruction = HloInstruction* literal_instruction =
@ -351,7 +351,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) {
std::unique_ptr<Literal> result = Evaluate({}); std::unique_ptr<Literal> result = Evaluate({});
LiteralTestUtil::ExpectEqual(*result, *output_literal); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
} }
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
@ -370,7 +370,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
std::unique_ptr<Literal> result = Evaluate({}); std::unique_ptr<Literal> result = Evaluate({});
LiteralTestUtil::ExpectEqual(*result, *output_literal); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
} }
TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
@ -392,7 +392,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
auto expected = auto expected =
Literal::CreateR2<int64>({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); Literal::CreateR2<int64>({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
@ -413,7 +413,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
std::unique_ptr<Literal> result = Evaluate(); std::unique_ptr<Literal> result = Evaluate();
auto expected = Literal::CreateR1<int64>({100, 200}); auto expected = Literal::CreateR1<int64>({100, 200});
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
@ -432,7 +432,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
std::unique_ptr<Literal> result = Evaluate(); std::unique_ptr<Literal> result = Evaluate();
LiteralTestUtil::ExpectEqual(*result, *expected); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
} }
TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
@ -452,7 +452,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
std::unique_ptr<Literal> result = Evaluate(); std::unique_ptr<Literal> result = Evaluate();
LiteralTestUtil::ExpectEqual(*result, *expected); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
} }
PaddingConfig CreatePaddingConfig( PaddingConfig CreatePaddingConfig(
@ -490,7 +490,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
auto expected = Literal::CreateR2<int32>( auto expected = Literal::CreateR2<int32>(
{{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
@ -525,7 +525,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
auto expected = Literal::CreateR4FromArray4D<float>(*expected_array); auto expected = Literal::CreateR4FromArray4D<float>(*expected_array);
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, NegativePadding2D) { TEST_P(HloEvaluatorTest, NegativePadding2D) {
@ -567,7 +567,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
(*expected_array)(0, 4) = 2.718f; (*expected_array)(0, 4) = 2.718f;
auto expected = Literal::CreateR2FromArray2D<float>(*expected_array); auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5)); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0x1.0P-5)));
} }
TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
@ -606,7 +606,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
auto expected_array = MakeUnique<Array2D<float>>(0, 9); auto expected_array = MakeUnique<Array2D<float>>(0, 9);
auto expected = Literal::CreateR2FromArray2D<float>(*expected_array); auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, DotRank2AndRank1) { TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
@ -651,7 +651,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// clang-format on // clang-format on
auto expected = Literal::CreateR2FromArray2D<float>(expected_array); auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, DotRank1AndRank2) { TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
@ -688,7 +688,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
auto expected = Literal::CreateR1<float>({22.f, 28.f}); auto expected = Literal::CreateR1<float>({22.f, 28.f});
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, DotRank2AndRank2) { TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
@ -737,7 +737,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
}); });
auto expected = Literal::CreateR2FromArray2D<float>(expected_array); auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, SimpleConv1D) { TEST_P(HloEvaluatorTest, SimpleConv1D) {
@ -785,7 +785,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}}; Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
auto expected = Literal::CreateR3FromArray3D<float>(expected_array); auto expected = Literal::CreateR3FromArray3D<float>(expected_array);
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
@ -847,7 +847,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
// clang-format on // clang-format on
auto expected = Literal::CreateR4FromArray4D<float>(expected_array); auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
@ -927,7 +927,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
auto expected = Literal::CreateR4FromArray4D<float>( auto expected = Literal::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array); use_bfloat16_ ? expected_array_bf16 : expected_array);
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
@ -1004,7 +1004,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
auto expected = Literal::CreateR4FromArray4D<float>( auto expected = Literal::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array); use_bfloat16_ ? expected_array_bf16 : expected_array);
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
@ -1067,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
})); }));
auto expected = Literal::CreateR4FromArray4D<float>(expected_array); auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
@ -1131,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
})); }));
auto expected = Literal::CreateR4FromArray4D<float>(expected_array); auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, TEST_P(HloEvaluatorTest,
@ -1203,7 +1203,7 @@ TEST_P(HloEvaluatorTest,
})); }));
auto expected = Literal::CreateR4FromArray4D<float>(expected_array); auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
@ -1319,7 +1319,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
auto expected = Literal::CreateR1<float>({6, 18}); auto expected = Literal::CreateR1<float>({6, 18});
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, ReduceWindowMax) { TEST_P(HloEvaluatorTest, ReduceWindowMax) {
@ -1370,7 +1370,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
std::unique_ptr<Literal> result = Evaluate(); std::unique_ptr<Literal> result = Evaluate();
auto expected = Literal::CreateR2<float>({{6, 7}}); auto expected = Literal::CreateR2<float>({{6, 7}});
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, ReduceWindowAdd) { TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
@ -1427,7 +1427,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
std::unique_ptr<Literal> result = Evaluate(); std::unique_ptr<Literal> result = Evaluate();
auto expected = Literal::CreateR2<float>({{1, 3, 5}, {5, 11, 13}}); auto expected = Literal::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
@ -1490,7 +1490,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4}; std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
std::unique_ptr<Literal> result_literal = std::unique_ptr<Literal> result_literal =
Literal::CreateFullWithDescendingLayout<float>(output_dims, 8.0f); Literal::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
LiteralTestUtil::ExpectEqual(*result_literal, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
} }
TEST_P(HloEvaluatorTest, StridedSlice) { TEST_P(HloEvaluatorTest, StridedSlice) {
@ -1523,7 +1523,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
{19}, {19},
}); });
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, DynamicSlice) { TEST_P(HloEvaluatorTest, DynamicSlice) {
@ -1556,7 +1556,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
{6, 7, 8}, {6, 7, 8},
}); });
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
// Verifies that the HloEvaluator's implementation goes along with existing // Verifies that the HloEvaluator's implementation goes along with existing
@ -1591,7 +1591,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
{6, 7, 8}, {6, 7, 8},
}); });
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
@ -1627,7 +1627,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
{5, -6, -7}, {5, -6, -7},
}); });
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, SetAndGetTuples) { TEST_P(HloEvaluatorTest, SetAndGetTuples) {
@ -1662,7 +1662,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
{5, 6, 7}, {5, 6, 7},
}); });
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
@ -1703,7 +1703,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
result_inner_literal.get(), result_inner_literal.get(),
}); });
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, Reverse) { TEST_P(HloEvaluatorTest, Reverse) {
@ -1756,7 +1756,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
}); });
// clang-format on // clang-format on
LiteralTestUtil::ExpectEqual(*expected, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
} }
TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
@ -1776,8 +1776,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
add, {{param0, Literal::CreateR1<float>({1, 2, 3, 4}).get()}, add, {{param0, Literal::CreateR1<float>({1, 2, 3, 4}).get()},
{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}}); {square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
TF_ASSERT_OK(result.status()); TF_ASSERT_OK(result.status());
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<float>({11, 22, 33, 44}), EXPECT_TRUE(LiteralTestUtil::Equal(
*result.ValueOrDie()); *Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
} }
// Check that EvaluateWithSubstitutions works if one of the operands to the op // Check that EvaluateWithSubstitutions works if one of the operands to the op
@ -1800,8 +1800,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) {
auto result = evaluator.EvaluateWithSubstitutions( auto result = evaluator.EvaluateWithSubstitutions(
add, {{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}}); add, {{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
TF_ASSERT_OK(result.status()); TF_ASSERT_OK(result.status());
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<float>({11, 22, 33, 44}), EXPECT_TRUE(LiteralTestUtil::Equal(
*result.ValueOrDie()); *Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
} }
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
@ -1823,9 +1823,9 @@ ENTRY main {
std::unique_ptr<Literal> operand = std::unique_ptr<Literal> operand =
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2}); std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(
*Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}), LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
*Evaluate({operand.get(), gather_indices.get()})); *Evaluate({operand.get(), gather_indices.get()})));
} }
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@ -1847,9 +1847,9 @@ ENTRY main {
std::unique_ptr<Literal> operand = std::unique_ptr<Literal> operand =
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2}); std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}), *Literal::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
*Evaluate({operand.get(), gather_indices.get()})); *Evaluate({operand.get(), gather_indices.get()})));
} }
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
@ -1872,10 +1872,10 @@ ENTRY main {
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices = std::unique_ptr<Literal> gather_indices =
Literal::CreateR2<int32>({{0, 2}, {2, 1}}); Literal::CreateR2<int32>({{0, 2}, {2, 1}});
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR3<int32>( *Literal::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
*Evaluate({operand.get(), gather_indices.get()})); *Evaluate({operand.get(), gather_indices.get()})));
} }
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
@ -1900,9 +1900,9 @@ ENTRY main {
{{-7, 7}, {-8, 8}, {-9, 9}}}); {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices = std::unique_ptr<Literal> gather_indices =
Literal::CreateR2<int32>({{0, 0}, {1, 0}}); Literal::CreateR2<int32>({{0, 0}, {1, 0}});
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(
*Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}), LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
*Evaluate({operand.get(), gather_indices.get()})); *Evaluate({operand.get(), gather_indices.get()})));
} }
TEST_P(HloEvaluatorTest, TEST_P(HloEvaluatorTest,
@ -1928,9 +1928,9 @@ ENTRY main {
{{-7, 7}, {-8, 8}, {-9, 9}}}); {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices = std::unique_ptr<Literal> gather_indices =
Literal::CreateR2<int32>({{0, 0}, {1, 0}}); Literal::CreateR2<int32>({{0, 0}, {1, 0}});
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(
*Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}), LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
*Evaluate({operand.get(), gather_indices.get()})); *Evaluate({operand.get(), gather_indices.get()})));
} }
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
@ -1952,9 +1952,9 @@ ENTRY main {
std::unique_ptr<Literal> operand = std::unique_ptr<Literal> operand =
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1}); std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(
*Literal::CreateR2<int32>({{5}}), LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{5}}),
*Evaluate({operand.get(), gather_indices.get()})); *Evaluate({operand.get(), gather_indices.get()})));
} }
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
@ -1977,9 +1977,9 @@ ENTRY main {
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices = std::unique_ptr<Literal> gather_indices =
Literal::CreateR2<int32>({{2, 1}, {1, 1}}); Literal::CreateR2<int32>({{2, 1}, {1, 1}});
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(
*Literal::CreateR3<int32>({{{8}}, {{5}}}), LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{8}}, {{5}}}),
*Evaluate({operand.get(), gather_indices.get()})); *Evaluate({operand.get(), gather_indices.get()})));
} }
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
@ -2000,9 +2000,9 @@ ENTRY main {
ParseAndVerifyModule(hlo_text); ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}}); std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2}); std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(
*Literal::CreateR2<int32>({{}, {}}), LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{}, {}}),
*Evaluate({operand.get(), gather_indices.get()})); *Evaluate({operand.get(), gather_indices.get()})));
} }
TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
@ -2025,9 +2025,9 @@ ENTRY main {
std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({0, 1, 2}); std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({0, 1, 2});
std::unique_ptr<Literal> gather_indices = std::unique_ptr<Literal> gather_indices =
Literal::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}}); Literal::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(
*Literal::CreateR2<int32>({{0, 1}, {2, 1}}), LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{0, 1}, {2, 1}}),
*Evaluate({operand.get(), gather_indices.get()})); *Evaluate({operand.get(), gather_indices.get()})));
} }
// Verifies that HloEvaluator evaluates a HLO instruction that performs // Verifies that HloEvaluator evaluates a HLO instruction that performs

View File

@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) {
// Verify execution on CPU. // Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = Literal::CreateR1<float>({4, 3, 3, 4}); auto expected = Literal::CreateR1<float>({4, 3, 3, 4});
LiteralTestUtil::ExpectEqual(*result, *expected); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
} }
// Test that `constant` function is changed to `broadcast`. // Test that `constant` function is changed to `broadcast`.
@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) {
// Verify execution on CPU. // Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = Literal::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}}); auto expected = Literal::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
LiteralTestUtil::ExpectEqual(*result, *expected); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
} }
TEST_F(InlinerTest, MapSubtractOppositeOrder) { TEST_F(InlinerTest, MapSubtractOppositeOrder) {
@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
// Verify execution on CPU. // Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = Literal::CreateR1<float>({3, 1, -1, -3}); auto expected = Literal::CreateR1<float>({3, 1, -1, -3});
LiteralTestUtil::ExpectEqual(*result, *expected); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
} }

View File

@ -87,6 +87,7 @@ cc_library(
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal_comparison",
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test",

View File

@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(42.0), *result, EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0<float>(42.0), *result,
error_spec_); error_spec_));
} }
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
@ -62,9 +62,9 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
LiteralTestUtil::ExpectNear( EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result, *Literal::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
error_spec_); error_spec_));
} }
XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
@ -85,13 +85,13 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
LiteralTestUtil::ExpectNear( EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), *Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
LiteralSlice(*result, {0}), error_spec_); LiteralSlice(*result, {0}), error_spec_));
LiteralTestUtil::ExpectNear( EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), *Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
LiteralSlice(*result, {1}), error_spec_); LiteralSlice(*result, {1}), error_spec_));
} }
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
LiteralTestUtil::ExpectNear( EXPECT_TRUE(
*Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result, LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
error_spec_); *result, error_spec_));
} }
XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
@ -125,9 +125,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
LiteralTestUtil::ExpectNear( EXPECT_TRUE(
*Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result, LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}),
error_spec_); *result, error_spec_));
} }
XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
@ -142,10 +142,10 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
LiteralTestUtil::ExpectNear( EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, *Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
{{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
*result, error_spec_); *result, error_spec_));
} }
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
@ -166,8 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
Array2D<float> pz({{1, 2}, {1, 2}}); Array2D<float> pz({{1, 2}, {1, 2}});
expected.FillWithPZ(pz); expected.FillWithPZ(pz);
LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected), EXPECT_TRUE(LiteralTestUtil::Near(
*result, error_spec_); *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
} }
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
@ -196,8 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
} }
expected.FillWithYX(yx); expected.FillWithYX(yx);
LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected), EXPECT_TRUE(LiteralTestUtil::Near(
*result, error_spec_); *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
} }
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
@ -218,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result, EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array),
error_spec_); *result, error_spec_));
} }
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
@ -238,8 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
Array4D<float> expected(64, 64, 3, 3); Array4D<float> expected(64, 64, 3, 3);
expected.Fill(1.0f); expected.Fill(1.0f);
LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected), EXPECT_TRUE(LiteralTestUtil::Near(
*result, error_spec_); *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
} }
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
@ -260,8 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
Array4D<float> expected(3, 3, 2, 2); Array4D<float> expected(3, 3, 2, 2);
expected.FillWithYX(to_broadcast); expected.FillWithYX(to_broadcast);
LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected), EXPECT_TRUE(LiteralTestUtil::Near(
*result, error_spec_); *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
} }
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
@ -291,8 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected), EXPECT_TRUE(LiteralTestUtil::Near(
*result, error_spec_); *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
} }
} // namespace } // namespace

View File

@ -297,7 +297,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
std::unique_ptr<Literal> converted_expected; std::unique_ptr<Literal> converted_expected;
Shape layout_shape; Shape layout_shape;
if (use_bfloat16_) { if (use_bfloat16_) {
converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); converted_expected = Literal::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get(); expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) { if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout; layout_shape = *shape_with_layout;
@ -311,7 +311,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
} }
} }
auto expect_equal = [&](const Literal& actual, const string& error_message) { auto expect_equal = [&](const Literal& actual, const string& error_message) {
LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message;
}; };
if (execution_options_.debug_options().xla_test_all_output_layouts()) { if (execution_options_.debug_options().xla_test_all_output_layouts()) {
return ComputeAndCompareLiteralWithAllOutputLayouts( return ComputeAndCompareLiteralWithAllOutputLayouts(
@ -323,7 +323,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
} }
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout)); shape_with_layout));
LiteralTestUtil::ExpectEqual(*expected_ptr, *actual); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
@ -349,7 +349,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
std::unique_ptr<Literal> converted_expected; std::unique_ptr<Literal> converted_expected;
Shape layout_shape; Shape layout_shape;
if (use_bfloat16_) { if (use_bfloat16_) {
converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); converted_expected = Literal::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get(); expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) { if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout; layout_shape = *shape_with_layout;
@ -363,7 +363,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
} }
} }
auto expect_near = [&](const Literal& actual, const string& error_message) { auto expect_near = [&](const Literal& actual, const string& error_message) {
LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message); EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error))
<< error_message;
}; };
if (execution_options_.debug_options().xla_test_all_output_layouts()) { if (execution_options_.debug_options().xla_test_all_output_layouts()) {
return ComputeAndCompareLiteralWithAllOutputLayouts( return ComputeAndCompareLiteralWithAllOutputLayouts(
@ -375,7 +376,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
} }
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout)); shape_with_layout));
LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error); EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
@ -407,7 +408,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return; return;
} }
auto actual = actual_status.ConsumeValueOrDie(); auto actual = actual_status.ConsumeValueOrDie();
LiteralTestUtil::ExpectEqual(expected, *actual); EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
} }
void ClientLibraryTestBase::ComputeAndCompareTuple( void ClientLibraryTestBase::ComputeAndCompareTuple(
@ -419,7 +420,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return; return;
} }
auto actual = actual_status.ConsumeValueOrDie(); auto actual = actual_status.ConsumeValueOrDie();
LiteralTestUtil::ExpectNear(expected, *actual, error); EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
} }
void ClientLibraryTestBase::ComputeAndCompare( void ClientLibraryTestBase::ComputeAndCompare(
@ -431,7 +432,7 @@ void ClientLibraryTestBase::ComputeAndCompare(
} }
std::unique_ptr<Literal> reference, result; std::unique_ptr<Literal> reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
LiteralTestUtil::ExpectEqual(*reference, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
} }
void ClientLibraryTestBase::ComputeAndCompare( void ClientLibraryTestBase::ComputeAndCompare(
@ -444,7 +445,7 @@ void ClientLibraryTestBase::ComputeAndCompare(
} }
std::unique_ptr<Literal> reference, result; std::unique_ptr<Literal> reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
LiteralTestUtil::ExpectNear(*reference, *result, error); EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
} }
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>> StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
@ -562,7 +563,7 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) { XlaBuilder* builder) {
return builder->ConstantLiteral( return builder->ConstantLiteral(
use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal);
} }
std::unique_ptr<GlobalData> std::unique_ptr<GlobalData>
@ -583,7 +584,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(
const Literal* param_literal = &literal; const Literal* param_literal = &literal;
std::unique_ptr<Literal> converted_literal; std::unique_ptr<Literal> converted_literal;
if (use_bfloat16_) { if (use_bfloat16_) {
converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); converted_literal = Literal::ConvertF32ToBF16(literal);
param_literal = converted_literal.get(); param_literal = converted_literal.get();
} }
std::unique_ptr<GlobalData> data = std::unique_ptr<GlobalData> data =

View File

@ -541,7 +541,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
XlaBuilder* builder, XlaOp* data_handle) { XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR0(value); std::unique_ptr<Literal> literal = Literal::CreateR0(value);
if (use_bfloat16_ && literal->shape().element_type() == F32) { if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralTestUtil::ConvertF32ToBF16(*literal); literal = Literal::ConvertF32ToBF16(*literal);
} }
std::unique_ptr<GlobalData> data = std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie(); client_->TransferToServer(*literal).ConsumeValueOrDie();
@ -555,7 +555,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
const string& name, XlaBuilder* builder, XlaOp* data_handle) { const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR1(values); std::unique_ptr<Literal> literal = Literal::CreateR1(values);
if (use_bfloat16_ && literal->shape().element_type() == F32) { if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralTestUtil::ConvertF32ToBF16(*literal); literal = Literal::ConvertF32ToBF16(*literal);
} }
std::unique_ptr<GlobalData> data = std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie(); client_->TransferToServer(*literal).ConsumeValueOrDie();
@ -569,7 +569,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const string& name, XlaBuilder* builder, XlaOp* data_handle) { const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d); std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
if (use_bfloat16_ && literal->shape().element_type() == F32) { if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralTestUtil::ConvertF32ToBF16(*literal); literal = Literal::ConvertF32ToBF16(*literal);
} }
std::unique_ptr<GlobalData> data = std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie(); client_->TransferToServer(*literal).ConsumeValueOrDie();
@ -583,7 +583,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const string& name, XlaBuilder* builder, XlaOp* data_handle) { const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d); std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
if (use_bfloat16_ && literal->shape().element_type() == F32) { if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralTestUtil::ConvertF32ToBF16(*literal); literal = Literal::ConvertF32ToBF16(*literal);
} }
std::unique_ptr<GlobalData> data = std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie(); client_->TransferToServer(*literal).ConsumeValueOrDie();

View File

@ -62,9 +62,9 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto computed, client_->Transfer(*data, &expected_literal->shape())); auto computed, client_->Transfer(*data, &expected_literal->shape()));
LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
computed->shape()); expected_literal->shape(), computed->shape()));
LiteralTestUtil::ExpectEqual(*expected_literal, *computed); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
} }
} }
} }
@ -142,7 +142,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
auto result_literal, auto result_literal,
client_->Transfer(*results[0], &expected_result->shape())); client_->Transfer(*results[0], &expected_result->shape()));
LiteralTestUtil::ExpectEqual(*expected_result, *result_literal); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
} }
} // namespace } // namespace

View File

@ -50,8 +50,8 @@ class CompilationCacheTest : public ClientLibraryTestBase {
/*execution_options=*/&execution_options_, /*execution_options=*/&execution_options_,
&execution_profile) &execution_profile)
.ConsumeValueOrDie(); .ConsumeValueOrDie();
LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(expected_result), EXPECT_TRUE(LiteralTestUtil::Near(
*result, error_spec_); *Literal::CreateR0<float>(expected_result), *result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
} }
@ -67,8 +67,8 @@ class CompilationCacheTest : public ClientLibraryTestBase {
.ConsumeValueOrDie(); .ConsumeValueOrDie();
std::unique_ptr<Literal> result = std::unique_ptr<Literal> result =
client_->Transfer(*data_handle).ConsumeValueOrDie(); client_->Transfer(*data_handle).ConsumeValueOrDie();
LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>(expected_result), EXPECT_TRUE(LiteralTestUtil::Near(
*result, error_spec_); *Literal::CreateR2<float>(expected_result), *result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
} }

View File

@ -208,7 +208,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
ComputeConstantLiteral(client, computation, &b)); ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal = std::unique_ptr<Literal> expected_literal =
Literal::CreateR1<int32>({4, 6}); Literal::CreateR1<int32>({4, 6});
LiteralTestUtil::ExpectEqual(*expected_literal, *computed); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
} }
} }
@ -222,7 +222,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
TF_ASSERT_OK_AND_ASSIGN(auto computed, TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b)); ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5); std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
LiteralTestUtil::ExpectEqual(*expected_literal, *computed); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
} }
} }
@ -244,9 +244,9 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
std::unique_ptr<Literal> expected_literal = std::unique_ptr<Literal> expected_literal =
Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}}, Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
LayoutUtil::MakeLayout(layout)); LayoutUtil::MakeLayout(layout));
LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
computed->shape()); expected_literal->shape(), computed->shape()));
LiteralTestUtil::ExpectEqual(*expected_literal, *computed); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
} }
} }
} }

View File

@ -49,7 +49,7 @@ class CopyOpTest : public HloTestBase {
module->AddEntryComputation(std::move(computation)); module->AddEntryComputation(std::move(computation));
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {}); std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectEqual(literal, *result); EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
} }
void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
@ -253,7 +253,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
.ConsumeValueOrDie(); .ConsumeValueOrDie();
LiteralTestUtil::ExpectEqual(*empty, *actual); EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
} }
} // namespace } // namespace

View File

@ -118,9 +118,9 @@ class FusionTest : public HloTestBase {
auto expected = Literal::CreateR2FromArray2D(answer_data); auto expected = Literal::CreateR2FromArray2D(answer_data);
auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
if (primitive_util::IsFloatingPointType(prim_type)) { if (primitive_util::IsFloatingPointType(prim_type)) {
LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4)); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
} else { } else {
LiteralTestUtil::ExpectEqual(*expected, *actual); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
} }
} }
@ -221,9 +221,9 @@ XLA_TEST_F(FusionTest, Test) {
const4, reshape3, add2, const1, const0}, const4, reshape3, add2, const1, const0},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{0.5}, {2.72}}), EXPECT_TRUE(LiteralTestUtil::Near(
*ExecuteAndTransfer(std::move(hlo_module), {}), *Literal::CreateR2<float>({{0.5}, {2.72}}),
ErrorSpec(1e-4)); *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
} }
// Test whether we emit appropriate code for parameters of fusion instructions. // Test whether we emit appropriate code for parameters of fusion instructions.
@ -247,9 +247,9 @@ XLA_TEST_F(FusionTest, Parameter) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}), EXPECT_TRUE(LiteralTestUtil::Near(
*ExecuteAndTransfer(std::move(hlo_module), {}), *Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
ErrorSpec(1e-4)); *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
} }
XLA_TEST_F(FusionTest, RandomizedParallelPartition) { XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
@ -307,9 +307,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectNear( EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), *Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
} }
XLA_TEST_F(FusionTest, ReshapeToScalar) { XLA_TEST_F(FusionTest, ReshapeToScalar) {
@ -322,8 +322,9 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
hlo_module->AddEntryComputation(builder.Build()) hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(5), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR0<int32>(5),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
@ -336,9 +337,9 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
hlo_module->AddEntryComputation(builder.Build()) hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}), *Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
*ExecuteAndTransfer(std::move(hlo_module), {})); *ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
@ -351,9 +352,9 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
hlo_module->AddEntryComputation(builder.Build()) hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}), *Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
*ExecuteAndTransfer(std::move(hlo_module), {})); *ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, Reshape_1by1by1_) { XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
@ -366,8 +367,9 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
hlo_module->AddEntryComputation(builder.Build()) hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, Reshape__1by1by1) { XLA_TEST_F(FusionTest, Reshape__1by1by1) {
@ -380,8 +382,9 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
hlo_module->AddEntryComputation(builder.Build()) hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR3<int32>({{{7}}}), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{7}}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, Reshape__) { XLA_TEST_F(FusionTest, Reshape__) {
@ -394,8 +397,9 @@ XLA_TEST_F(FusionTest, Reshape__) {
hlo_module->AddEntryComputation(builder.Build()) hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
@ -408,9 +412,9 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
hlo_module->AddEntryComputation(builder.Build()) hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), *Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
*ExecuteAndTransfer(std::move(hlo_module), {})); *ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, Transpose_2by3) { XLA_TEST_F(FusionTest, Transpose_2by3) {
@ -423,9 +427,9 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
hlo_module->AddEntryComputation(builder.Build()) hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}), *Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
*ExecuteAndTransfer(std::move(hlo_module), {})); *ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, Transpose_3by3) { XLA_TEST_F(FusionTest, Transpose_3by3) {
@ -438,9 +442,9 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
hlo_module->AddEntryComputation(builder.Build()) hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), *Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
*ExecuteAndTransfer(std::move(hlo_module), {})); *ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, Reverse) { XLA_TEST_F(FusionTest, Reverse) {
@ -454,8 +458,9 @@ XLA_TEST_F(FusionTest, Reverse) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({3, 2, 1}), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR1<int32>({3, 2, 1}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, ReverseNegate) { XLA_TEST_F(FusionTest, ReverseNegate) {
@ -471,8 +476,9 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-3, -2, -1}), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-3, -2, -1}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, BroadcastNegate) { XLA_TEST_F(FusionTest, BroadcastNegate) {
@ -488,8 +494,9 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -1}), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -1}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, SliceNegate) { XLA_TEST_F(FusionTest, SliceNegate) {
@ -505,8 +512,9 @@ XLA_TEST_F(FusionTest, SliceNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -3}), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -3}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, DynamicSliceNegate) { XLA_TEST_F(FusionTest, DynamicSliceNegate) {
@ -526,8 +534,9 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) {
/*instructions_to_fuse=*/{negate3, dynamic_slice2}, /*instructions_to_fuse=*/{negate3, dynamic_slice2},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-2, -3}), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-2, -3}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, ReshapeNegate) { XLA_TEST_F(FusionTest, ReshapeNegate) {
@ -543,8 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
// TODO(b/64070202): Investigate failure. // TODO(b/64070202): Investigate failure.
@ -561,8 +571,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
std::unique_ptr<HloComputation> MakeReduceTestComputation() { std::unique_ptr<HloComputation> MakeReduceTestComputation() {
@ -591,8 +602,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(15), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR0<int32>(15),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
@ -612,8 +624,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(-15), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR0<int32>(-15),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
@ -661,9 +674,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2}, ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
HloInstruction::FusionKind::kLoop); HloInstruction::FusionKind::kLoop);
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}), *Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
*ExecuteAndTransfer(std::move(hlo_module), {})); *ExecuteAndTransfer(std::move(hlo_module), {})));
} }
// When a constant (or other op) which has multiple users is imported // When a constant (or other op) which has multiple users is imported
@ -697,8 +710,9 @@ XLA_TEST_F(FusionTest, SharedConstant) {
// fused instruction contains the constant(2), the parameter, and 4 adds // fused instruction contains the constant(2), the parameter, and 4 adds
EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}), EXPECT_TRUE(
*ExecuteAndTransfer(std::move(hlo_module), {})); LiteralTestUtil::Equal(*Literal::CreateR1<int32>({8}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
} }
XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); } XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }

View File

@ -629,8 +629,8 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
client_->ExecuteParallel(computation_instances)); client_->ExecuteParallel(computation_instances));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
client_->Transfer(*(result_data[0]))); client_->Transfer(*(result_data[0])));
LiteralTestUtil::ExpectEqual( EXPECT_TRUE(LiteralTestUtil::Equal(
*result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}})); *result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}})));
} }
} // namespace } // namespace
} // namespace xla } // namespace xla

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_comparison.h"
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
@ -46,119 +47,23 @@ using ::tensorflow::strings::StrCat;
/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( /* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes(
const Shape& expected, const Shape& actual) { const Shape& expected, const Shape& actual) {
if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { Status result = literal_comparison::EqualShapes(expected, actual);
return ::testing::AssertionFailure() if (result.ok()) {
<< "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected) return ::testing::AssertionSuccess();
<< " got: " << ShapeUtil::HumanString(actual);
} }
if (ShapeUtil::IsTuple(expected)) { return ::testing::AssertionFailure() << result;
if (ShapeUtil::TupleElementCount(expected) != }
ShapeUtil::TupleElementCount(actual)) {
return ::testing::AssertionFailure() /* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapesAndLayouts(
<< "want tuple element count: " const Shape& expected, const Shape& actual) {
<< ShapeUtil::TupleElementCount(expected) if (expected.ShortDebugString() != actual.ShortDebugString()) {
<< " got tuple element count: " return ::testing::AssertionFailure()
<< ShapeUtil::TupleElementCount(actual); << "want: " << expected.ShortDebugString()
} << " got: " << actual.ShortDebugString();
for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
::testing::AssertionResult result =
EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i))
<< "mismatch in tuple index " << i;
if (!result) {
return result;
}
}
} else {
if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
return ::testing::AssertionFailure()
<< "want rank of: " << ShapeUtil::HumanString(expected)
<< " got rank of: " << ShapeUtil::HumanString(actual);
}
if (expected.element_type() != actual.element_type()) {
return ::testing::AssertionFailure()
<< PrimitiveType_Name(expected.element_type()) << " vs "
<< PrimitiveType_Name(actual.element_type());
}
if (expected.dimensions_size() != actual.dimensions_size()) {
return ::testing::AssertionFailure()
<< "want dimensions_size " << expected.dimensions_size()
<< " got dimensions_size " << actual.dimensions_size();
}
for (int i = 0; i < expected.dimensions_size(); ++i) {
if (expected.dimensions(i) != actual.dimensions(i)) {
return ::testing::AssertionFailure()
<< "mismatch in dimension #" << i
<< " expected: " << ShapeUtil::HumanString(expected)
<< " actual: " << ShapeUtil::HumanString(actual);
}
}
} }
return ::testing::AssertionSuccess(); return ::testing::AssertionSuccess();
} }
/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected,
const Shape& actual) {
ASSERT_TRUE(EqualShapes(expected, actual));
}
/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts(
const Shape& expected, const Shape& actual) {
ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString());
}
namespace {
// Return a literal with all arrays of type FromNativeT converted to type
// ToNativeT in the given literal.
template <typename FromNativeT, typename ToNativeT>
std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
// First construct shape of the result.
Shape result_shape(literal.shape());
ShapeUtil::ForEachMutableSubshape(
&result_shape, [](Shape* subshape, const ShapeIndex&) {
if (subshape->element_type() ==
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
subshape->set_element_type(
primitive_util::NativeToPrimitiveType<ToNativeT>());
}
});
auto result = MakeUnique<Literal>(result_shape);
// Then copy over the data from 'literal' converting FromNativeT values to
// ToNativeT values as necessary.
ShapeUtil::ForEachSubshape(
literal.shape(),
[&](const Shape& subshape, const ShapeIndex& shape_index) {
if (ShapeUtil::IsArray(subshape)) {
if (subshape.element_type() ==
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
auto src = literal.data<FromNativeT>(shape_index);
auto dest = result->data<ToNativeT>(shape_index);
for (int64 i = 0; i < src.size(); ++i) {
dest[i] = static_cast<ToNativeT>(src[i]);
}
} else {
TF_CHECK_OK(result->CopyFrom(literal,
/*dest_shape_index=*/shape_index,
/*src_shape_index=*/shape_index));
}
}
});
return result;
}
} // namespace
/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
LiteralSlice literal) {
return ConvertType<bfloat16, float>(literal);
}
/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
LiteralSlice literal) {
return ConvertType<float, bfloat16>(literal);
}
namespace { namespace {
string Hostname() { string Hostname() {
@ -168,183 +73,15 @@ string Hostname() {
return string(hostname); return string(hostname);
} }
// Helper function for comparing a floating point type, FloatT, bitwise equal
// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
// -- on miscompare, a nice error message is given in the AssertionFailure.
template <typename FloatT, typename UnsignedT>
::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
auto lhs_double = static_cast<double>(lhs);
auto rhs_double = static_cast<double>(rhs);
if (ulhs != urhs) {
return ::testing::AssertionFailure() << Printf(
"floating values are not bitwise-equal; and equality testing "
"was requested: %s=%g=%a vs %s=%g=%a",
StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double,
lhs_double, StrCat(tensorflow::strings::Hex(urhs)).c_str(),
rhs_double, rhs_double);
}
return ::testing::AssertionSuccess();
}
// Templated comparator that specializes for float equality comparison with the
// bitwise helper above (this is the un-specialized fallback, to just use the
// default gunit implementation).
template <typename NativeT>
::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) {
if (lhs == rhs) {
return ::testing::AssertionSuccess();
}
::testing::Message msg;
msg << "Expected equality of these values:";
msg << "\n " << lhs;
msg << "\n " << rhs;
return ::testing::AssertionFailure() << msg;
}
// Specializations for floating types that do bitwise comparisons when equality
// comparison is requested.
template <>
::testing::AssertionResult CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
}
template <>
::testing::AssertionResult CompareEqual<Eigen::half>(Eigen::half lhs,
Eigen::half rhs) {
return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs);
}
template <>
::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
}
template <>
::testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
}
template <>
::testing::AssertionResult CompareEqual<complex64>(complex64 lhs,
complex64 rhs) {
auto res = CompareEqual<float>(lhs.real(), rhs.real());
if (!res) {
return res;
}
return CompareEqual<float>(lhs.imag(), rhs.imag());
}
// A recursive function which iterates through every index of expected and
// actual literal and compares their values elementwise. Returns true if all
// elements are equal.
template <typename NativeT>
bool ExpectLiteralsEqual(LiteralSlice expected, LiteralSlice actual,
tensorflow::gtl::MutableArraySlice<int64> multi_index,
int64 dimension) {
if (dimension == expected.shape().dimensions_size()) {
NativeT expected_value = expected.Get<NativeT>(multi_index);
NativeT actual_value = actual.Get<NativeT>(multi_index);
::testing::AssertionResult result =
CompareEqual<NativeT>(expected_value, actual_value);
return result; // Defines implicit coersion to bool.
}
bool all_match = true;
for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
multi_index[dimension] = i;
all_match = all_match && ExpectLiteralsEqual<NativeT>(
expected, actual, multi_index, dimension + 1);
}
return all_match;
}
} // namespace } // namespace
/* static */ void LiteralTestUtil::ExpectEqual(LiteralSlice expected,
LiteralSlice actual,
const string& message) {
EXPECT_TRUE(Equal(expected, actual))
<< "expected:\n"
<< expected.ToString() << "\n\tvs actual:\n"
<< actual.ToString()
<< (message.empty() ? "" : StrCat("\nmessage: ", message));
}
/* static */ void LiteralTestUtil::ExpectNotEqual(LiteralSlice expected,
LiteralSlice actual) {
EXPECT_FALSE(Equal(expected, actual));
}
/* static */ ::testing::AssertionResult LiteralTestUtil::Equal( /* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
LiteralSlice expected, LiteralSlice actual) { const LiteralSlice& expected, const LiteralSlice& actual) {
VLOG(1) << "expected:"; Status result = literal_comparison::Equal(expected, actual);
XLA_VLOG_LINES(1, expected.ToString()); if (result.ok()) {
VLOG(1) << "actual:"; return ::testing::AssertionSuccess();
XLA_VLOG_LINES(1, actual.ToString());
AssertEqualShapes(expected.shape(), actual.shape());
std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
bool match = false;
switch (expected.shape().element_type()) {
case PRED:
match = ExpectLiteralsEqual<bool>(expected, actual, &multi_index, 0);
break;
case U8:
match = ExpectLiteralsEqual<uint8>(expected, actual, &multi_index, 0);
break;
case S32:
match = ExpectLiteralsEqual<int32>(expected, actual, &multi_index, 0);
break;
case S64:
match = ExpectLiteralsEqual<int64>(expected, actual, &multi_index, 0);
break;
case U32:
match = ExpectLiteralsEqual<uint32>(expected, actual, &multi_index, 0);
break;
case U64:
match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
break;
case BF16:
match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0);
break;
case F16:
match = ExpectLiteralsEqual<half>(expected, actual, &multi_index, 0);
break;
case F32:
match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
break;
case F64:
match = ExpectLiteralsEqual<double>(expected, actual, &multi_index, 0);
break;
case C64:
match = ExpectLiteralsEqual<complex64>(expected, actual, &multi_index, 0);
break;
case TUPLE: {
bool tuple_match = true;
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
SCOPED_TRACE(StrCat("Tuple index ", i, " in ",
ShapeUtil::HumanString(expected.shape())));
// Create LiteralSlices of the expected and actual elements.
auto result =
Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}));
tuple_match = tuple_match ? !!result : false;
}
match = tuple_match;
break;
}
default:
LOG(FATAL)
<< "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
<< PrimitiveType_Name(expected.shape().element_type());
} }
::testing::AssertionResult result = ::testing::AssertionSuccess(); return ::testing::AssertionFailure() << result;
if (!match) {
result = ::testing::AssertionFailure()
<< "expected: " << expected.ToString()
<< "\nactual: " << actual.ToString();
VLOG(1) << result.message();
}
return result;
} }
namespace { namespace {
@ -368,7 +105,7 @@ int64 RecursiveElementCount(const Shape& shape) {
// 3 minutes. The utility of printing a literal with >1000 elements is // 3 minutes. The utility of printing a literal with >1000 elements is
// questionable, especially when writing the Literal proto to disk is orders // questionable, especially when writing the Literal proto to disk is orders
// of magnitude faster. // of magnitude faster.
string TruncateHugeLiteral(LiteralSlice literal) { string TruncateHugeLiteral(const LiteralSlice& literal) {
return RecursiveElementCount(literal.shape()) < 1000 return RecursiveElementCount(literal.shape()) < 1000
? literal.ToString() ? literal.ToString()
: "[TRUNCATED, Literal with more than 1000 values]"; : "[TRUNCATED, Literal with more than 1000 values]";
@ -435,8 +172,8 @@ class NearComparator {
// result. The assertion result is successful if all actual and expected // result. The assertion result is successful if all actual and expected
// elements are within the given error bound. In case of error, the assertion // elements are within the given error bound. In case of error, the assertion
// result contains a detailed error message in case of failure. // result contains a detailed error message in case of failure.
static ::testing::AssertionResult Compare(LiteralSlice expected, static ::testing::AssertionResult Compare(const LiteralSlice& expected,
LiteralSlice actual, const LiteralSlice& actual,
ErrorSpec error, ErrorSpec error,
bool detailed_message) { bool detailed_message) {
NearComparator<NativeT> comparator(expected, actual, error, NearComparator<NativeT> comparator(expected, actual, error,
@ -464,7 +201,7 @@ class NearComparator {
return Printf( return Printf(
"actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), FpValueToString(actual).c_str(), FpValueToString(expected).c_str(),
LiteralTestUtil::MultiIndexAsString( Literal::MultiIndexAsString(
IndexUtil::LinearIndexToMultidimensionalIndex(shape, IndexUtil::LinearIndexToMultidimensionalIndex(shape,
linear_index)) linear_index))
.c_str(), .c_str(),
@ -472,8 +209,9 @@ class NearComparator {
} }
}; };
explicit NearComparator(LiteralSlice expected, LiteralSlice actual, explicit NearComparator(const LiteralSlice& expected,
ErrorSpec error, bool detailed_message) const LiteralSlice& actual, ErrorSpec error,
bool detailed_message)
: expected_(expected), : expected_(expected),
actual_(actual), actual_(actual),
error_(error), error_(error),
@ -649,7 +387,7 @@ class NearComparator {
} }
// Writes the given literal to a file in the test temporary directory. // Writes the given literal to a file in the test temporary directory.
void WriteLiteralToTempFile(LiteralSlice literal, const string& name) { void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) {
int64 now_usec = tensorflow::Env::Default()->NowMicros(); int64 now_usec = tensorflow::Env::Default()->NowMicros();
string filename = tensorflow::io::JoinPath( string filename = tensorflow::io::JoinPath(
tensorflow::testing::TmpDir(), tensorflow::testing::TmpDir(),
@ -794,8 +532,8 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
// Helper function for comparing two literals for nearness. Handles tuple-shapes // Helper function for comparing two literals for nearness. Handles tuple-shapes
// via recursion. shape_index is the ShapeIndex of expected (or actual) // via recursion. shape_index is the ShapeIndex of expected (or actual)
// currently being compared. // currently being compared.
::testing::AssertionResult NearHelper(LiteralSlice expected, ::testing::AssertionResult NearHelper(const LiteralSlice& expected,
LiteralSlice actual, const LiteralSlice& actual,
const ErrorSpec& error, const ErrorSpec& error,
bool detailed_message, bool detailed_message,
const ShapeIndex& shape_index) { const ShapeIndex& shape_index) {
@ -874,30 +612,14 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
} // namespace } // namespace
/* static */ ::testing::AssertionResult LiteralTestUtil::Near( /* static */ ::testing::AssertionResult LiteralTestUtil::Near(
LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error, const LiteralSlice& expected, const LiteralSlice& actual,
bool detailed_message) { const ErrorSpec& error, bool detailed_message) {
return NearHelper(expected, actual, error, detailed_message, return NearHelper(expected, actual, error, detailed_message,
/*shape_index=*/{}); /*shape_index=*/{});
} }
/* static */ void LiteralTestUtil::ExpectNear(LiteralSlice expected, /* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
LiteralSlice actual, const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error,
const string& message) {
::testing::AssertionResult res =
Near(expected, actual, error, /*detailed_message=*/false);
if (!res) {
res << "Expected: " << TruncateHugeLiteral(expected) << "\n";
res << "Actual: " << TruncateHugeLiteral(actual) << "\n";
if (!message.empty()) {
res << StrCat("\nmessage: ", message);
}
}
EXPECT_TRUE(res);
}
/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
LiteralSlice expected, LiteralSlice actual,
const tensorflow::gtl::optional<ErrorSpec>& error) { const tensorflow::gtl::optional<ErrorSpec>& error) {
if (error.has_value()) { if (error.has_value()) {
VLOG(1) << "Expects near"; VLOG(1) << "Expects near";
@ -907,86 +629,4 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
return Equal(expected, actual); return Equal(expected, actual);
} }
/*static*/ void LiteralTestUtil::ExpectNearOrEqual(
LiteralSlice expected, LiteralSlice actual,
const tensorflow::gtl::optional<ErrorSpec>& error) {
EXPECT_TRUE(NearOrEqual(expected, actual, error));
}
/* static */ string LiteralTestUtil::MultiIndexAsString(
tensorflow::gtl::ArraySlice<int64> multi_index) {
return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
}
/* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape(
tensorflow::gtl::ArraySlice<int64> new_dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal) {
int64 new_num_elements = 1;
for (int64 i = 0; i < new_dimensions.size(); ++i) {
new_num_elements *= new_dimensions[i];
}
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
CHECK_EQ(new_dimensions.size(), minor_to_major.size());
auto new_literal = MakeUnique<Literal>(
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
// Create a new shape with the given minor-to-major layout. This shape is used
// solely for converting linear address to multi-dimensional addresses when
// writing elements to the new literal.
Shape shape_with_layout = new_literal->shape();
*shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
// Copy data into new literal, element-by-element.
for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
std::vector<int64> from_multi_index =
IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
std::vector<int64> to_multi_index =
IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
switch (literal.shape().element_type()) {
case PRED:
new_literal->Set<bool>(to_multi_index,
literal.Get<bool>(from_multi_index));
break;
case U8:
new_literal->Set<uint8>(to_multi_index,
literal.Get<uint8>(from_multi_index));
break;
case U32:
new_literal->Set<uint32>(to_multi_index,
literal.Get<uint32>(from_multi_index));
break;
case S32:
new_literal->Set<int32>(to_multi_index,
literal.Get<int32>(from_multi_index));
break;
case U64:
new_literal->Set<uint64>(to_multi_index,
literal.Get<uint64>(from_multi_index));
break;
case S64:
new_literal->Set<int64>(to_multi_index,
literal.Get<int64>(from_multi_index));
break;
case F32:
new_literal->Set<float>(to_multi_index,
literal.Get<float>(from_multi_index));
break;
case F64:
new_literal->Set<double>(to_multi_index,
literal.Get<double>(from_multi_index));
break;
case C64:
new_literal->Set<complex64>(to_multi_index,
literal.Get<complex64>(from_multi_index));
break;
default:
LOG(FATAL) << "Unhandled primitive element type: "
<< PrimitiveType_Name(literal.shape().element_type());
}
}
return new_literal;
}
} // namespace xla } // namespace xla

View File

@ -57,65 +57,47 @@ class LiteralTestUtil {
public: public:
// Asserts that the given shapes have the same rank, dimension sizes, and // Asserts that the given shapes have the same rank, dimension sizes, and
// primitive types. // primitive types.
static ::testing::AssertionResult EqualShapes(const Shape& expected, static ::testing::AssertionResult EqualShapes(
const Shape& actual); const Shape& expected, const Shape& actual) MUST_USE_RESULT;
static void AssertEqualShapes(const Shape& expected, const Shape& actual);
// Asserts that the provided shapes are equal as defined in AssertEqualShapes // Asserts that the provided shapes are equal as defined in AssertEqualShapes
// and that they have the same layout. // and that they have the same layout.
static void AssertEqualShapesAndLayouts(const Shape& expected, static ::testing::AssertionResult EqualShapesAndLayouts(
const Shape& actual); const Shape& expected, const Shape& actual) MUST_USE_RESULT;
// If the given literal's data type is bfloat16, converts it to a float static ::testing::AssertionResult Equal(const LiteralSlice& expected,
// literal; otherwise, returns a copy of it. If the literal is a tuple, const LiteralSlice& actual)
// recursively converts its elements. TF_MUST_USE_RESULT;
static std::unique_ptr<Literal> ConvertBF16ToF32(LiteralSlice bf16_literal);
// If the given literal's data type is float, converts it to a bfloat16
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
static std::unique_ptr<Literal> ConvertF32ToBF16(LiteralSlice f32_literal);
// Asserts that the expected and actual literals are (bitwise) equal for all
// elements in the literal. Also, asserts that the rank, dimensions sizes, and
// primitive type are equal.
static ::testing::AssertionResult Equal(
LiteralSlice expected, LiteralSlice actual) TF_MUST_USE_RESULT;
// Expects that expected and actual are Equal.
static void ExpectEqual(LiteralSlice expected, LiteralSlice actual,
const string& message = "");
// Expects that expected and actual are Not Equal.
static void ExpectNotEqual(LiteralSlice expected, LiteralSlice actual);
// Asserts the given literal are (bitwise) equal to given expected values. // Asserts the given literal are (bitwise) equal to given expected values.
template <typename NativeT> template <typename NativeT>
static void ExpectR0Equal(NativeT expected, LiteralSlice actual); static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual);
template <typename NativeT> template <typename NativeT>
static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected, static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected,
LiteralSlice actual); const LiteralSlice& actual);
template <typename NativeT> template <typename NativeT>
static void ExpectR2Equal( static void ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected, std::initializer_list<std::initializer_list<NativeT>> expected,
LiteralSlice actual); const LiteralSlice& actual);
template <typename NativeT> template <typename NativeT>
static void ExpectR3Equal( static void ExpectR3Equal(
std::initializer_list< std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>> std::initializer_list<std::initializer_list<NativeT>>>
expected, expected,
LiteralSlice actual); const LiteralSlice& actual);
// Asserts the given literal are (bitwise) equal to given array. // Asserts the given literal are (bitwise) equal to given array.
template <typename NativeT> template <typename NativeT>
static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected, static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
LiteralSlice actual); const LiteralSlice& actual);
template <typename NativeT> template <typename NativeT>
static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected, static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
LiteralSlice actual); const LiteralSlice& actual);
template <typename NativeT> template <typename NativeT>
static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected, static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
LiteralSlice actual); const LiteralSlice& actual);
// Asserts that the expected and actual literals are within the given error // Asserts that the expected and actual literals are within the given error
// bound for all elements. Also, asserts that the rank, dimensions sizes, and // bound for all elements. Also, asserts that the rank, dimensions sizes, and
@ -133,183 +115,138 @@ class LiteralTestUtil {
// If detailed_message is true, then the error message in the assertion result // If detailed_message is true, then the error message in the assertion result
// will contain a more detailed breakdown of mismatches. // will contain a more detailed breakdown of mismatches.
static ::testing::AssertionResult Near( static ::testing::AssertionResult Near(
LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error, const LiteralSlice& expected, const LiteralSlice& actual,
bool detailed_message = false) TF_MUST_USE_RESULT; const ErrorSpec& error, bool detailed_message = false) TF_MUST_USE_RESULT;
// Expects expected and actual to be Near with the given error.
static void ExpectNear(LiteralSlice expected, LiteralSlice actual,
const ErrorSpec& error, const string& message = "");
// Asserts the given literal are within the given error bound of the given // Asserts the given literal are within the given error bound of the given
// expected values. Only supported for floating point values. // expected values. Only supported for floating point values.
template <typename NativeT> template <typename NativeT>
static void ExpectR0Near(NativeT expected, LiteralSlice actual, static void ExpectR0Near(NativeT expected, const LiteralSlice& actual,
const ErrorSpec& error); const ErrorSpec& error);
template <typename NativeT> template <typename NativeT>
static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected, static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected,
LiteralSlice actual, const ErrorSpec& error); const LiteralSlice& actual, const ErrorSpec& error);
template <typename NativeT> template <typename NativeT>
static void ExpectR2Near( static void ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected, std::initializer_list<std::initializer_list<NativeT>> expected,
LiteralSlice actual, const ErrorSpec& error); const LiteralSlice& actual, const ErrorSpec& error);
template <typename NativeT> template <typename NativeT>
static void ExpectR3Near( static void ExpectR3Near(
std::initializer_list< std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>> std::initializer_list<std::initializer_list<NativeT>>>
expected, expected,
LiteralSlice actual, const ErrorSpec& error); const LiteralSlice& actual, const ErrorSpec& error);
template <typename NativeT> template <typename NativeT>
static void ExpectR4Near( static void ExpectR4Near(
std::initializer_list<std::initializer_list< std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>> std::initializer_list<std::initializer_list<NativeT>>>>
expected, expected,
LiteralSlice actual, const ErrorSpec& error); const LiteralSlice& actual, const ErrorSpec& error);
// Asserts the given literal are within the given error bound to the given // Asserts the given literal are within the given error bound to the given
// array. Only supported for floating point values. // array. Only supported for floating point values.
template <typename NativeT> template <typename NativeT>
static void ExpectR2NearArray2D(const Array2D<NativeT>& expected, static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
LiteralSlice actual, const ErrorSpec& error); const LiteralSlice& actual,
const ErrorSpec& error);
template <typename NativeT> template <typename NativeT>
static void ExpectR3NearArray3D(const Array3D<NativeT>& expected, static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
LiteralSlice actual, const ErrorSpec& error); const LiteralSlice& actual,
const ErrorSpec& error);
template <typename NativeT> template <typename NativeT>
static void ExpectR4NearArray4D(const Array4D<NativeT>& expected, static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
LiteralSlice actual, const ErrorSpec& error); const LiteralSlice& actual,
const ErrorSpec& error);
// If the error spec is given, returns whether the expected and the actual are // If the error spec is given, returns whether the expected and the actual are
// within the error bound; otherwise, returns whether they are equal. Tuples // within the error bound; otherwise, returns whether they are equal. Tuples
// will be compared recursively. // will be compared recursively.
static ::testing::AssertionResult NearOrEqual( static ::testing::AssertionResult NearOrEqual(
LiteralSlice expected, LiteralSlice actual, const LiteralSlice& expected, const LiteralSlice& actual,
const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT; const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
// If the error spec is given, expects the expected and the actual to be near;
// otherwise, expects them to be equal. Tuples will be compared recursively.
static void ExpectNearOrEqual(
LiteralSlice expected, LiteralSlice actual,
const tensorflow::gtl::optional<ErrorSpec>& error);
// Returns a multi-dimensional index as a string. For example: '{7, 8}' will
// be returned for a 2-dimensional index with dimension 0 index equal to 7,
// dimension 1 equal to 8.
static string MultiIndexAsString(
tensorflow::gtl::ArraySlice<int64> multi_index);
// Creates a literal with a new shape with the given new dimensions using the
// data in the given input literal. For reshaping purposes the (flat) data
// buffer of the input literal is assumed to have the given minor_to_major
// layout order.
static std::unique_ptr<Literal> Reshape(
tensorflow::gtl::ArraySlice<int64> new_dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal);
// Creates a literal with the supplied shape, and uses the provided value
// generator to populate the literal's values.
// Returns the new literal object, or an error Status if failed.
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape,
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
// deviation, and using the engine as entropy generator.
// Returns the new literal object, or an error Status if failed.
template <
PrimitiveType type, typename E,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape, E* engine, T mean, T stddev);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
// deviation.
// Returns the new literal object, or an error Status if failed.
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape, T mean, T stddev);
private: private:
TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
}; };
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
LiteralSlice actual) { const LiteralSlice& actual) {
ExpectEqual(*Literal::CreateR0<NativeT>(expected), actual); EXPECT_TRUE(Equal(*Literal::CreateR0<NativeT>(expected), actual));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Equal( /* static */ void LiteralTestUtil::ExpectR1Equal(
tensorflow::gtl::ArraySlice<NativeT> expected, LiteralSlice actual) { tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual) {
ExpectEqual(*Literal::CreateR1<NativeT>(expected), actual); EXPECT_TRUE(Equal(*Literal::CreateR1<NativeT>(expected), actual));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Equal( /* static */ void LiteralTestUtil::ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected, std::initializer_list<std::initializer_list<NativeT>> expected,
LiteralSlice actual) { const LiteralSlice& actual) {
ExpectEqual(*Literal::CreateR2<NativeT>(expected), actual); EXPECT_TRUE(Equal(*Literal::CreateR2<NativeT>(expected), actual));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3Equal( /* static */ void LiteralTestUtil::ExpectR3Equal(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected, expected,
LiteralSlice actual) { const LiteralSlice& actual) {
ExpectEqual(*Literal::CreateR3<NativeT>(expected), actual); EXPECT_TRUE(Equal(*Literal::CreateR3<NativeT>(expected), actual));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2EqualArray2D( /* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
const Array2D<NativeT>& expected, LiteralSlice actual) { const Array2D<NativeT>& expected, const LiteralSlice& actual) {
ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual); EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3EqualArray3D( /* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
const Array3D<NativeT>& expected, LiteralSlice actual) { const Array3D<NativeT>& expected, const LiteralSlice& actual) {
ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual); EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4EqualArray4D( /* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
const Array4D<NativeT>& expected, LiteralSlice actual) { const Array4D<NativeT>& expected, const LiteralSlice& actual) {
ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual); EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
LiteralSlice actual, const LiteralSlice& actual,
const ErrorSpec& error) { const ErrorSpec& error) {
ExpectNear(*Literal::CreateR0<NativeT>(expected), actual, error); EXPECT_TRUE(Near(*Literal::CreateR0<NativeT>(expected), actual, error));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Near( /* static */ void LiteralTestUtil::ExpectR1Near(
tensorflow::gtl::ArraySlice<NativeT> expected, LiteralSlice actual, tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual,
const ErrorSpec& error) { const ErrorSpec& error) {
ExpectNear(*Literal::CreateR1<NativeT>(expected), actual, error); EXPECT_TRUE(Near(*Literal::CreateR1<NativeT>(expected), actual, error));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Near( /* static */ void LiteralTestUtil::ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected, std::initializer_list<std::initializer_list<NativeT>> expected,
LiteralSlice actual, const ErrorSpec& error) { const LiteralSlice& actual, const ErrorSpec& error) {
ExpectNear(*Literal::CreateR2<NativeT>(expected), actual, error); EXPECT_TRUE(Near(*Literal::CreateR2<NativeT>(expected), actual, error));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3Near( /* static */ void LiteralTestUtil::ExpectR3Near(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected, expected,
LiteralSlice actual, const ErrorSpec& error) { const LiteralSlice& actual, const ErrorSpec& error) {
ExpectNear(*Literal::CreateR3<NativeT>(expected), actual, error); EXPECT_TRUE(Near(*Literal::CreateR3<NativeT>(expected), actual, error));
} }
template <typename NativeT> template <typename NativeT>
@ -317,63 +254,29 @@ template <typename NativeT>
std::initializer_list<std::initializer_list< std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>> std::initializer_list<std::initializer_list<NativeT>>>>
expected, expected,
LiteralSlice actual, const ErrorSpec& error) { const LiteralSlice& actual, const ErrorSpec& error) {
ExpectNear(*Literal::CreateR4<NativeT>(expected), actual, error); EXPECT_TRUE(Near(*Literal::CreateR4<NativeT>(expected), actual, error));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2NearArray2D( /* static */ void LiteralTestUtil::ExpectR2NearArray2D(
const Array2D<NativeT>& expected, LiteralSlice actual, const Array2D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) { const ErrorSpec& error) {
ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error); EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3NearArray3D( /* static */ void LiteralTestUtil::ExpectR3NearArray3D(
const Array3D<NativeT>& expected, LiteralSlice actual, const Array3D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) { const ErrorSpec& error) {
ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error); EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error));
} }
template <typename NativeT> template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4NearArray4D( /* static */ void LiteralTestUtil::ExpectR4NearArray4D(
const Array4D<NativeT>& expected, LiteralSlice actual, const Array4D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) { const ErrorSpec& error) {
ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error); EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error));
}
template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralTestUtil::CreateRandomLiteral(
const Shape& shape,
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
[&](tensorflow::gtl::ArraySlice<int64> indexes) {
return generator(indexes);
}));
return std::move(literal);
}
template <PrimitiveType type, typename E, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
T stddev) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
return generator(*engine);
});
}
template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
std::minstd_rand0 engine;
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
} }
} // namespace xla } // namespace xla

View File

@ -34,7 +34,7 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
std::unique_ptr<Literal> literal = Literal::MakeTuple({ std::unique_ptr<Literal> literal = Literal::MakeTuple({
Literal::CreateR0<int32>(42).get(), Literal::CreateR0<int32>(64).get(), Literal::CreateR0<int32>(42).get(), Literal::CreateR0<int32>(64).get(),
}); });
LiteralTestUtil::ExpectEqual(*literal, *literal); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal));
} }
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
@ -97,6 +97,15 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
} }
} }
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
auto expected = Literal::CreateR1<int32>({1, 2, 3});
auto actual = Literal::CreateR1<int32>({4, 5, 6});
::testing::AssertionResult result =
LiteralTestUtil::Equal(*expected, *actual);
EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}"));
EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}"));
}
TEST(LiteralTestUtilTest, NearComparatorR1) { TEST(LiteralTestUtilTest, NearComparatorR1) {
auto a = auto a =
Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});

View File

@ -108,7 +108,7 @@ class MultiOutputFusionTest : public HloTestBase {
expect.PopulateWithValue<float>(size * 1.5f * 3.5f); expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
auto actual = ExecuteAndTransfer( auto actual = ExecuteAndTransfer(
std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1}); std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1});
LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
} }
void RunTest1D(bool manual_fusion, int size) { void RunTest1D(bool manual_fusion, int size) {
@ -168,7 +168,7 @@ class MultiOutputFusionTest : public HloTestBase {
Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f})); Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f}));
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
} }
}; };

View File

@ -273,11 +273,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
&execution_options_)); &execution_options_));
} }
LiteralTestUtil::ExpectEqual(*result1, *result2); EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2));
LiteralTestUtil::ExpectEqual(*result1, *result3); EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3));
LiteralTestUtil::ExpectNotEqual(*result1, *result4); EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4));
LiteralTestUtil::ExpectNotEqual(*result4, *result5); EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5));
LiteralTestUtil::ExpectNotEqual(*result5, *result6); EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6));
} }
XLA_TEST_F(PrngTest, TenValuesN01) { XLA_TEST_F(PrngTest, TenValuesN01) {

View File

@ -656,9 +656,9 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
std::unique_ptr<Literal> expected = std::unique_ptr<Literal> expected =
Literal::CreateR2FromArray2D<float>(expected_array); Literal::CreateR2FromArray2D<float>(expected_array);
if (use_bfloat16()) { if (use_bfloat16()) {
expected = LiteralTestUtil::ConvertF32ToBF16(*expected); expected = Literal::ConvertF32ToBF16(*expected);
} }
LiteralTestUtil::ExpectEqual(*expected, *actual); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
} }
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
@ -731,7 +731,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
std::unique_ptr<Literal> expected = std::unique_ptr<Literal> expected =
LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal); Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
zero_error_spec_); zero_error_spec_);
} }
@ -753,7 +753,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
std::unique_ptr<Literal> expected = std::unique_ptr<Literal> expected =
LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal); Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
zero_error_spec_); zero_error_spec_);
} }
@ -817,7 +817,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
// Since the reshape is a no-op, verify that it does not change the underlying // Since the reshape is a no-op, verify that it does not change the underlying
// data. // data.
if (use_bfloat16()) { if (use_bfloat16()) {
auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal); auto expected = Literal::ConvertF32ToBF16(*input_literal);
EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>()); EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
} else { } else {
EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>()); EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
@ -886,7 +886,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
/*new_sizes=*/new_bounds); /*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected = std::unique_ptr<Literal> expected =
LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape // Specify the requested output shape explicitly to ensure that this reshape
@ -915,7 +915,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
/*new_sizes=*/new_bounds); /*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected = std::unique_ptr<Literal> expected =
LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape // Specify the requested output shape explicitly to ensure that this reshape
@ -944,7 +944,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
/*new_sizes=*/new_bounds); /*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected = std::unique_ptr<Literal> expected =
LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape // Specify the requested output shape explicitly to ensure that this reshape
@ -974,7 +974,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
/*new_sizes=*/new_bounds); /*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected = std::unique_ptr<Literal> expected =
LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape // Specify the requested output shape explicitly to ensure that this reshape
@ -1003,7 +1003,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
/*new_sizes=*/new_bounds); /*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected = std::unique_ptr<Literal> expected =
LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
->Relayout(input_literal->shape().layout()); ->Relayout(input_literal->shape().layout());
// Specify the requested output shape explicitly to ensure that this reshape // Specify the requested output shape explicitly to ensure that this reshape

View File

@ -100,7 +100,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
EXPECT_EQ(46.0f, actual->Get<float>({1, 1})); EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual); std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
LiteralTestUtil::ExpectEqual(*round_tripped, *actual); EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
} }
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
@ -135,7 +135,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
EXPECT_EQ(46.0f, actual->Get<float>({1, 1})); EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual); std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
LiteralTestUtil::ExpectEqual(*round_tripped, *actual); EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
} }
} // namespace } // namespace

View File

@ -41,7 +41,7 @@ class RoundTripTransferTest : public ClientLibraryTestBase {
client_->TransferToServer(original).ConsumeValueOrDie(); client_->TransferToServer(original).ConsumeValueOrDie();
std::unique_ptr<Literal> result = std::unique_ptr<Literal> result =
client_->Transfer(*data).ConsumeValueOrDie(); client_->Transfer(*data).ConsumeValueOrDie();
LiteralTestUtil::ExpectEqual(original, *result); EXPECT_TRUE(LiteralTestUtil::Equal(original, *result));
} }
}; };

View File

@ -390,7 +390,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
&execution_options_) &execution_options_)
.ConsumeValueOrDie(); .ConsumeValueOrDie();
auto expected_literal = Literal::CreateR0<uint32>(dividend / divisor); auto expected_literal = Literal::CreateR0<uint32>(dividend / divisor);
LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
} }
} }
} }
@ -431,7 +431,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
&execution_options_) &execution_options_)
.ConsumeValueOrDie(); .ConsumeValueOrDie();
auto expected_literal = Literal::CreateR0<uint32>(dividend % divisor); auto expected_literal = Literal::CreateR0<uint32>(dividend % divisor);
LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
} }
} }
} }

View File

@ -175,7 +175,7 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) {
transfer_manager_->TransferLiteralFromDevice( transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer)); stream_executor_, device_buffer));
LiteralTestUtil::ExpectEqual(*literal, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
} }
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
@ -189,7 +189,7 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
transfer_manager_->TransferLiteralFromDevice( transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer)); stream_executor_, device_buffer));
LiteralTestUtil::ExpectEqual(*literal, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
} }
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
@ -209,7 +209,7 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
transfer_manager_->TransferLiteralFromDevice( transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer)); stream_executor_, device_buffer));
LiteralTestUtil::ExpectEqual(*literal, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
} }
XLA_TEST_F(TransferManagerTest, TransferComplexValue) { XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
@ -224,7 +224,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
transfer_manager_->TransferLiteralFromDevice( transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer)); stream_executor_, device_buffer));
LiteralTestUtil::ExpectEqual(*literal, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
} }
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
@ -243,7 +243,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
transfer_manager_->TransferLiteralFromDevice( transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer)); stream_executor_, device_buffer));
LiteralTestUtil::ExpectEqual(*literal, *result); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
} }
} // namespace } // namespace