[XLA] Break out literal comparisons from testonly target.
Moves methods from LiteralTestUtil::* to Literal::* where they have nothing to do with test infrastructure. Pares down the "void" variants of the LiteralTestUtil methods and consolidates to the version that return success/failure such that the values can be EXPECT_TRUE / ASSERT_TRUE asserted in the caller test cases. This way the literal comparison functionality can be used from cc_libraries that are not test only / cc_binary. PiperOrigin-RevId: 196209410
This commit is contained in:
parent
5a492ef9bb
commit
400dd49b4c
@ -225,7 +225,7 @@ TEST_F(XlaCompilerTest, Simple) {
|
|||||||
xla::Literal::CreateR1<int32>({4, 143});
|
xla::Literal::CreateR1<int32>({4, 143});
|
||||||
std::unique_ptr<xla::Literal> expected_literal =
|
std::unique_ptr<xla::Literal> expected_literal =
|
||||||
xla::Literal::MakeTuple({expected0.get()});
|
xla::Literal::MakeTuple({expected0.get()});
|
||||||
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
|
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
|
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
|
||||||
@ -320,7 +320,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
|
|||||||
xla::Literal::CreateR1<int32>({-7, -42});
|
xla::Literal::CreateR1<int32>({-7, -42});
|
||||||
std::unique_ptr<xla::Literal> expected_literal =
|
std::unique_ptr<xla::Literal> expected_literal =
|
||||||
xla::Literal::MakeTuple({expected0.get()});
|
xla::Literal::MakeTuple({expected0.get()});
|
||||||
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
|
EXPECT_TRUE(
|
||||||
|
xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -355,7 +356,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
|
|||||||
xla::Literal::CreateR1<int32>({-7, -42});
|
xla::Literal::CreateR1<int32>({-7, -42});
|
||||||
std::unique_ptr<xla::Literal> expected =
|
std::unique_ptr<xla::Literal> expected =
|
||||||
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
|
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
|
||||||
xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal);
|
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -523,7 +524,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
|
|||||||
{output_base.get(), output_grad1.get(), output_grad2.get()});
|
{output_base.get(), output_grad1.get(), output_grad2.get()});
|
||||||
std::unique_ptr<xla::Literal> expected_literal =
|
std::unique_ptr<xla::Literal> expected_literal =
|
||||||
xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
|
xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
|
||||||
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
|
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests compilation and execution of a graph that adds two tensors.
|
// Tests compilation and execution of a graph that adds two tensors.
|
||||||
@ -746,7 +747,7 @@ TEST_F(XlaCompilerTest, Variables) {
|
|||||||
xla::Literal::CreateR1<int32>({4, 143});
|
xla::Literal::CreateR1<int32>({4, 143});
|
||||||
std::unique_ptr<xla::Literal> expected_literal =
|
std::unique_ptr<xla::Literal> expected_literal =
|
||||||
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
|
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
|
||||||
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
|
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests a simple graph that reads and writes a variable, with a
|
// Tests a simple graph that reads and writes a variable, with a
|
||||||
@ -811,7 +812,7 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
|
|||||||
xla::Literal::CreateR1<int32>({26, 66, 34, 401});
|
xla::Literal::CreateR1<int32>({26, 66, 34, 401});
|
||||||
std::unique_ptr<xla::Literal> expected_literal =
|
std::unique_ptr<xla::Literal> expected_literal =
|
||||||
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
|
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
|
||||||
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
|
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -330,6 +330,17 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "literal_comparison",
|
||||||
|
srcs = ["literal_comparison.cc"],
|
||||||
|
hdrs = ["literal_comparison.h"],
|
||||||
|
deps = [
|
||||||
|
":literal_util",
|
||||||
|
":util",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "metric_table_report",
|
name = "metric_table_report",
|
||||||
srcs = ["metric_table_report.cc"],
|
srcs = ["metric_table_report.cc"],
|
||||||
|
226
tensorflow/compiler/xla/literal_comparison.cc
Normal file
226
tensorflow/compiler/xla/literal_comparison.cc
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/literal_comparison.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
#include "tensorflow/core/lib/core/casts.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
|
||||||
|
using tensorflow::strings::StrCat;
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace literal_comparison {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Helper function for comparing a floating point type, FloatT, bitwise equal
|
||||||
|
// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
|
||||||
|
// -- on miscompare, a nice error message is given in the AssertionFailure.
|
||||||
|
template <typename FloatT, typename UnsignedT>
|
||||||
|
Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
|
||||||
|
auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
|
||||||
|
auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
|
||||||
|
auto lhs_double = static_cast<double>(lhs);
|
||||||
|
auto rhs_double = static_cast<double>(rhs);
|
||||||
|
if (ulhs != urhs) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"floating values are not bitwise-equal; and equality testing "
|
||||||
|
"was requested: %s=%g=%a vs %s=%g=%a",
|
||||||
|
StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double,
|
||||||
|
StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Templated comparator that specializes for float equality comparison with the
|
||||||
|
// bitwise helper above (this is the un-specialized fallback, to just use the
|
||||||
|
// default gunit implementation).
|
||||||
|
template <typename NativeT>
|
||||||
|
Status CompareEqual(NativeT lhs, NativeT rhs) {
|
||||||
|
if (lhs == rhs) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return InvalidArgument("Expected equality of these values:\n %s\n %s",
|
||||||
|
StrCat(lhs).c_str(), StrCat(rhs).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Specializations for floating types that do bitwise comparisons when equality
|
||||||
|
// comparison is requested.
|
||||||
|
template <>
|
||||||
|
Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
|
||||||
|
return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
Status CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs) {
|
||||||
|
return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
Status CompareEqual<float>(float lhs, float rhs) {
|
||||||
|
return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
Status CompareEqual<double>(double lhs, double rhs) {
|
||||||
|
return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
Status CompareEqual<complex64>(complex64 lhs, complex64 rhs) {
|
||||||
|
auto res = CompareEqual<float>(lhs.real(), rhs.real());
|
||||||
|
if (!res.ok()) {
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
return CompareEqual<float>(lhs.imag(), rhs.imag());
|
||||||
|
}
|
||||||
|
|
||||||
|
// A recursive function which iterates through every index of expected and
|
||||||
|
// actual literal and compares their values elementwise. Returns true if all
|
||||||
|
// elements are equal.
|
||||||
|
template <typename NativeT>
|
||||||
|
Status Equal(LiteralSlice expected, LiteralSlice actual,
|
||||||
|
tensorflow::gtl::MutableArraySlice<int64> multi_index,
|
||||||
|
int64 dimension) {
|
||||||
|
if (dimension == expected.shape().dimensions_size()) {
|
||||||
|
NativeT expected_value = expected.Get<NativeT>(multi_index);
|
||||||
|
NativeT actual_value = actual.Get<NativeT>(multi_index);
|
||||||
|
return CompareEqual<NativeT>(expected_value, actual_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status result;
|
||||||
|
for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
|
||||||
|
multi_index[dimension] = i;
|
||||||
|
result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Status EqualShapes(const Shape& expected, const Shape& actual) {
|
||||||
|
if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
|
||||||
|
return InvalidArgument("tupleness-mismatch! want: %s got %s",
|
||||||
|
ShapeUtil::HumanString(expected).c_str(),
|
||||||
|
ShapeUtil::HumanString(actual).c_str());
|
||||||
|
}
|
||||||
|
if (ShapeUtil::IsTuple(expected)) {
|
||||||
|
if (ShapeUtil::TupleElementCount(expected) !=
|
||||||
|
ShapeUtil::TupleElementCount(actual)) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"want tuple element count: %lld got tuple element count: %lld",
|
||||||
|
ShapeUtil::TupleElementCount(expected),
|
||||||
|
ShapeUtil::TupleElementCount(actual));
|
||||||
|
}
|
||||||
|
for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
|
||||||
|
Status result =
|
||||||
|
EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
|
||||||
|
if (!result.ok()) {
|
||||||
|
return AppendStatus(result, StrCat("mismatch in tuple index", i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
|
||||||
|
return InvalidArgument("want rank of %s got rank of %s",
|
||||||
|
ShapeUtil::HumanString(expected).c_str(),
|
||||||
|
ShapeUtil::HumanString(actual).c_str());
|
||||||
|
}
|
||||||
|
if (expected.element_type() != actual.element_type()) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"mismatch in primitive type %s vs %s",
|
||||||
|
PrimitiveType_Name(expected.element_type()).c_str(),
|
||||||
|
PrimitiveType_Name(actual.element_type()).c_str());
|
||||||
|
}
|
||||||
|
if (expected.dimensions_size() != actual.dimensions_size()) {
|
||||||
|
return InvalidArgument("want dimensions_size %d got dimensions_size %d",
|
||||||
|
expected.dimensions_size(),
|
||||||
|
actual.dimensions_size());
|
||||||
|
}
|
||||||
|
for (int i = 0; i < expected.dimensions_size(); ++i) {
|
||||||
|
if (expected.dimensions(i) != actual.dimensions(i)) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"mismatch in dimension #%d expected: %s actual: %s", i,
|
||||||
|
ShapeUtil::HumanString(expected).c_str(),
|
||||||
|
ShapeUtil::HumanString(actual).c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
|
||||||
|
VLOG(1) << "expected:";
|
||||||
|
XLA_VLOG_LINES(1, expected.ToString());
|
||||||
|
VLOG(1) << "actual:";
|
||||||
|
XLA_VLOG_LINES(1, actual.ToString());
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
|
||||||
|
std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
|
||||||
|
Status result;
|
||||||
|
switch (expected.shape().element_type()) {
|
||||||
|
case PRED:
|
||||||
|
result = Equal<bool>(expected, actual, &multi_index, 0);
|
||||||
|
break;
|
||||||
|
case U8:
|
||||||
|
result = Equal<uint8>(expected, actual, &multi_index, 0);
|
||||||
|
break;
|
||||||
|
case S32:
|
||||||
|
result = Equal<int32>(expected, actual, &multi_index, 0);
|
||||||
|
break;
|
||||||
|
case S64:
|
||||||
|
result = Equal<int64>(expected, actual, &multi_index, 0);
|
||||||
|
break;
|
||||||
|
case U32:
|
||||||
|
result = Equal<uint32>(expected, actual, &multi_index, 0);
|
||||||
|
break;
|
||||||
|
case U64:
|
||||||
|
result = Equal<uint64>(expected, actual, &multi_index, 0);
|
||||||
|
break;
|
||||||
|
case BF16:
|
||||||
|
result = Equal<bfloat16>(expected, actual, &multi_index, 0);
|
||||||
|
break;
|
||||||
|
case F16:
|
||||||
|
result = Equal<half>(expected, actual, &multi_index, 0);
|
||||||
|
break;
|
||||||
|
case F32:
|
||||||
|
result = Equal<float>(expected, actual, &multi_index, 0);
|
||||||
|
break;
|
||||||
|
case F64:
|
||||||
|
result = Equal<double>(expected, actual, &multi_index, 0);
|
||||||
|
break;
|
||||||
|
case C64:
|
||||||
|
result = Equal<complex64>(expected, actual, &multi_index, 0);
|
||||||
|
break;
|
||||||
|
case TUPLE: {
|
||||||
|
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
|
||||||
|
result.Update(
|
||||||
|
Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i})));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
LOG(FATAL)
|
||||||
|
<< "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
|
||||||
|
<< PrimitiveType_Name(expected.shape().element_type());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (result.ok()) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
return AppendStatus(result,
|
||||||
|
tensorflow::strings::Printf("expected: %s\nactual: %s",
|
||||||
|
expected.ToString().c_str(),
|
||||||
|
actual.ToString().c_str()));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace literal_comparison
|
||||||
|
} // namespace xla
|
40
tensorflow/compiler/xla/literal_comparison.h
Normal file
40
tensorflow/compiler/xla/literal_comparison.h
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// Library for comparing literals without taking a dependency on testing
|
||||||
|
// libraries.
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace literal_comparison {
|
||||||
|
|
||||||
|
// Returns ok if the given shapes have the same rank, dimension sizes, and
|
||||||
|
// primitive types.
|
||||||
|
Status EqualShapes(const Shape& expected, const Shape& actual);
|
||||||
|
|
||||||
|
// Returns ok if the expected and actual literals are (bitwise) equal for all
|
||||||
|
// elements in the literal. Also, asserts that the rank, dimensions sizes, and
|
||||||
|
// primitive type are equal.
|
||||||
|
Status Equal(const LiteralSlice& expected, const LiteralSlice& actual);
|
||||||
|
|
||||||
|
} // namespace literal_comparison
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
|
@ -62,6 +62,45 @@ void ConvertEndianShort(char* bytes, int64 size) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return a literal with all arrays of type FromNativeT converted to type
|
||||||
|
// ToNativeT in the given literal.
|
||||||
|
template <typename FromNativeT, typename ToNativeT>
|
||||||
|
std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
|
||||||
|
// First construct shape of the result.
|
||||||
|
Shape result_shape(literal.shape());
|
||||||
|
ShapeUtil::ForEachMutableSubshape(
|
||||||
|
&result_shape, [](Shape* subshape, const ShapeIndex&) {
|
||||||
|
if (subshape->element_type() ==
|
||||||
|
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
|
||||||
|
subshape->set_element_type(
|
||||||
|
primitive_util::NativeToPrimitiveType<ToNativeT>());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
auto result = MakeUnique<Literal>(result_shape);
|
||||||
|
|
||||||
|
// Then copy over the data from 'literal' converting FromNativeT values to
|
||||||
|
// ToNativeT values as necessary.
|
||||||
|
ShapeUtil::ForEachSubshape(
|
||||||
|
literal.shape(),
|
||||||
|
[&](const Shape& subshape, const ShapeIndex& shape_index) {
|
||||||
|
if (ShapeUtil::IsArray(subshape)) {
|
||||||
|
if (subshape.element_type() ==
|
||||||
|
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
|
||||||
|
auto src = literal.data<FromNativeT>(shape_index);
|
||||||
|
auto dest = result->data<ToNativeT>(shape_index);
|
||||||
|
for (int64 i = 0; i < src.size(); ++i) {
|
||||||
|
dest[i] = static_cast<ToNativeT>(src[i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
TF_CHECK_OK(result->CopyFrom(literal,
|
||||||
|
/*dest_shape_index=*/shape_index,
|
||||||
|
/*src_shape_index=*/shape_index));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LiteralBase::~LiteralBase() {}
|
LiteralBase::~LiteralBase() {}
|
||||||
@ -195,6 +234,16 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
|
|||||||
return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
|
return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ std::unique_ptr<Literal> Literal::ConvertBF16ToF32(
|
||||||
|
const LiteralSlice& bf16_literal) {
|
||||||
|
return ConvertType<bfloat16, float>(bf16_literal);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* static */ std::unique_ptr<Literal> Literal::ConvertF32ToBF16(
|
||||||
|
const LiteralSlice& f32_literal) {
|
||||||
|
return ConvertType<float, bfloat16>(f32_literal);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
Status Literal::CopySliceFromInternal(
|
Status Literal::CopySliceFromInternal(
|
||||||
const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
|
const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
|
||||||
@ -788,6 +837,78 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
|
|||||||
return std::move(output);
|
return std::move(output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ std::unique_ptr<Literal> Literal::ReshapeSlice(
|
||||||
|
tensorflow::gtl::ArraySlice<int64> new_dimensions,
|
||||||
|
tensorflow::gtl::ArraySlice<int64> minor_to_major,
|
||||||
|
const LiteralSlice& literal) {
|
||||||
|
int64 new_num_elements = 1;
|
||||||
|
for (int64 i = 0; i < new_dimensions.size(); ++i) {
|
||||||
|
new_num_elements *= new_dimensions[i];
|
||||||
|
}
|
||||||
|
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
|
||||||
|
CHECK_EQ(new_dimensions.size(), minor_to_major.size());
|
||||||
|
|
||||||
|
auto new_literal = MakeUnique<Literal>(
|
||||||
|
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
|
||||||
|
|
||||||
|
// Create a new shape with the given minor-to-major layout. This shape is used
|
||||||
|
// solely for converting linear address to multi-dimensional addresses when
|
||||||
|
// writing elements to the new literal.
|
||||||
|
Shape shape_with_layout = new_literal->shape();
|
||||||
|
*shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
|
||||||
|
|
||||||
|
// Copy data into new literal, element-by-element.
|
||||||
|
for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
|
||||||
|
std::vector<int64> from_multi_index =
|
||||||
|
IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
|
||||||
|
std::vector<int64> to_multi_index =
|
||||||
|
IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
|
||||||
|
switch (literal.shape().element_type()) {
|
||||||
|
case PRED:
|
||||||
|
new_literal->Set<bool>(to_multi_index,
|
||||||
|
literal.Get<bool>(from_multi_index));
|
||||||
|
break;
|
||||||
|
case U8:
|
||||||
|
new_literal->Set<uint8>(to_multi_index,
|
||||||
|
literal.Get<uint8>(from_multi_index));
|
||||||
|
break;
|
||||||
|
case U32:
|
||||||
|
new_literal->Set<uint32>(to_multi_index,
|
||||||
|
literal.Get<uint32>(from_multi_index));
|
||||||
|
break;
|
||||||
|
case S32:
|
||||||
|
new_literal->Set<int32>(to_multi_index,
|
||||||
|
literal.Get<int32>(from_multi_index));
|
||||||
|
break;
|
||||||
|
case U64:
|
||||||
|
new_literal->Set<uint64>(to_multi_index,
|
||||||
|
literal.Get<uint64>(from_multi_index));
|
||||||
|
break;
|
||||||
|
case S64:
|
||||||
|
new_literal->Set<int64>(to_multi_index,
|
||||||
|
literal.Get<int64>(from_multi_index));
|
||||||
|
break;
|
||||||
|
case F32:
|
||||||
|
new_literal->Set<float>(to_multi_index,
|
||||||
|
literal.Get<float>(from_multi_index));
|
||||||
|
break;
|
||||||
|
case F64:
|
||||||
|
new_literal->Set<double>(to_multi_index,
|
||||||
|
literal.Get<double>(from_multi_index));
|
||||||
|
break;
|
||||||
|
case C64:
|
||||||
|
new_literal->Set<complex64>(to_multi_index,
|
||||||
|
literal.Get<complex64>(from_multi_index));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "Unhandled primitive element type: "
|
||||||
|
<< PrimitiveType_Name(literal.shape().element_type());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return new_literal;
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<Literal> LiteralBase::Transpose(
|
std::unique_ptr<Literal> LiteralBase::Transpose(
|
||||||
tensorflow::gtl::ArraySlice<int64> permutation) const {
|
tensorflow::gtl::ArraySlice<int64> permutation) const {
|
||||||
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
|
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
|
||||||
@ -2123,6 +2244,11 @@ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
|
|||||||
return std::move(literal);
|
return std::move(literal);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ string Literal::MultiIndexAsString(
|
||||||
|
tensorflow::gtl::ArraySlice<int64> multi_index) {
|
||||||
|
return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
|
||||||
|
}
|
||||||
|
|
||||||
const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
|
const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
|
||||||
return piece(shape_index).untyped_data();
|
return piece(shape_index).untyped_data();
|
||||||
}
|
}
|
||||||
|
@ -920,9 +920,66 @@ class Literal : public LiteralBase {
|
|||||||
PrimitiveType primitive_type,
|
PrimitiveType primitive_type,
|
||||||
tensorflow::gtl::ArraySlice<int64> dimensions);
|
tensorflow::gtl::ArraySlice<int64> dimensions);
|
||||||
|
|
||||||
|
// If the given literal's data type is bfloat16, converts it to a float
|
||||||
|
// literal; otherwise, returns a copy of it. If the literal is a tuple,
|
||||||
|
// recursively converts its elements.
|
||||||
|
static std::unique_ptr<Literal> ConvertBF16ToF32(
|
||||||
|
const LiteralSlice& bf16_literal);
|
||||||
|
|
||||||
|
// If the given literal's data type is float, converts it to a bfloat16
|
||||||
|
// literal; otherwise, returns a copy of it. If the literal is a tuple,
|
||||||
|
// recursively converts its elements.
|
||||||
|
static std::unique_ptr<Literal> ConvertF32ToBF16(
|
||||||
|
const LiteralSlice& f32_literal);
|
||||||
|
|
||||||
|
// Creates a literal with a new shape with the given new dimensions using the
|
||||||
|
// data in the given input literal. For reshaping purposes the (flat) data
|
||||||
|
// buffer of the input literal is assumed to have the given minor_to_major
|
||||||
|
// layout order.
|
||||||
|
static std::unique_ptr<Literal> ReshapeSlice(
|
||||||
|
tensorflow::gtl::ArraySlice<int64> new_dimensions,
|
||||||
|
tensorflow::gtl::ArraySlice<int64> minor_to_major,
|
||||||
|
const LiteralSlice& literal);
|
||||||
|
|
||||||
|
// Creates a literal with the supplied shape, and uses the provided value
|
||||||
|
// generator to populate the literal's values.
|
||||||
|
// Returns the new literal object, or an error Status if failed.
|
||||||
|
template <
|
||||||
|
PrimitiveType type,
|
||||||
|
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
|
||||||
|
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
|
||||||
|
const Shape& shape,
|
||||||
|
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
|
||||||
|
|
||||||
|
// Creates a literal with the supplied shape, and initializes the literal
|
||||||
|
// values using a normal distribution with given mean and stddev standard
|
||||||
|
// deviation, and using the engine as entropy generator.
|
||||||
|
// Returns the new literal object, or an error Status if failed.
|
||||||
|
template <
|
||||||
|
PrimitiveType type, typename E,
|
||||||
|
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
|
||||||
|
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
|
||||||
|
const Shape& shape, E* engine, T mean, T stddev);
|
||||||
|
|
||||||
|
// Creates a literal with the supplied shape, and initializes the literal
|
||||||
|
// values using a normal distribution with given mean and stddev standard
|
||||||
|
// deviation.
|
||||||
|
// Returns the new literal object, or an error Status if failed.
|
||||||
|
template <
|
||||||
|
PrimitiveType type,
|
||||||
|
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
|
||||||
|
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
|
||||||
|
const Shape& shape, T mean, T stddev);
|
||||||
|
|
||||||
//
|
//
|
||||||
// End of factory methods.
|
// End of factory methods.
|
||||||
|
|
||||||
|
// Returns a multi-dimensional index as a string. For example: '{7, 8}' will
|
||||||
|
// be returned for a 2-dimensional index with dimension 0 index equal to 7,
|
||||||
|
// dimension 1 equal to 8.
|
||||||
|
static string MultiIndexAsString(
|
||||||
|
tensorflow::gtl::ArraySlice<int64> multi_index);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Recursively sets the subshapes and buffers of all subpieces rooted at
|
// Recursively sets the subshapes and buffers of all subpieces rooted at
|
||||||
// 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
|
// 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
|
||||||
@ -1558,6 +1615,38 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
|
|||||||
return literal;
|
return literal;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <PrimitiveType type, typename T>
|
||||||
|
/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
|
||||||
|
const Shape& shape,
|
||||||
|
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
|
||||||
|
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
|
||||||
|
TF_RET_CHECK(shape.element_type() == type);
|
||||||
|
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
|
||||||
|
TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
|
||||||
|
[&](tensorflow::gtl::ArraySlice<int64> indexes) {
|
||||||
|
return generator(indexes);
|
||||||
|
}));
|
||||||
|
return std::move(literal);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <PrimitiveType type, typename E, typename T>
|
||||||
|
/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
|
||||||
|
const Shape& shape, E* engine, T mean, T stddev) {
|
||||||
|
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
|
||||||
|
std::normal_distribution<NativeT> generator(mean, stddev);
|
||||||
|
return CreateRandomLiteral<type, NativeT>(
|
||||||
|
shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
|
||||||
|
return generator(*engine);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <PrimitiveType type, typename T>
|
||||||
|
/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
|
||||||
|
const Shape& shape, T mean, T stddev) {
|
||||||
|
std::minstd_rand0 engine;
|
||||||
|
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
|
#endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
|
||||||
|
@ -101,8 +101,8 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
|
|||||||
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
|
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
|
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
|
||||||
computation, {}, nullptr));
|
computation, {}, nullptr));
|
||||||
LiteralTestUtil::ExpectNear(*expected_literal, *result_literal,
|
EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
|
||||||
ErrorSpec(0.0001));
|
ErrorSpec(0.0001)));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -149,12 +149,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
|
|||||||
EXPECT_TRUE(OutputsBF16(dot->operand(1)));
|
EXPECT_TRUE(OutputsBF16(dot->operand(1)));
|
||||||
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
|
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
|
||||||
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
|
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
dot->operand(0)->literal(),
|
dot->operand(0)->literal(),
|
||||||
*LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)));
|
*Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))));
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
dot->operand(1)->literal(),
|
dot->operand(1)->literal(),
|
||||||
*LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)));
|
*Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that BF16 can be propagated through nested tuples.
|
// Tests that BF16 can be propagated through nested tuples.
|
||||||
|
@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
|
|||||||
const int64 slice_limits[] = {10, 8, 6, 5, 9};
|
const int64 slice_limits[] = {10, 8, 6, 5, 9};
|
||||||
const int64 slice_strides[] = {1, 1, 1, 1, 1};
|
const int64 slice_strides[] = {1, 1, 1, 1, 1};
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto literal,
|
TF_ASSERT_OK_AND_ASSIGN(auto literal,
|
||||||
LiteralTestUtil::CreateRandomLiteral<F32>(
|
Literal::CreateRandomLiteral<F32>(
|
||||||
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
|
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
|
||||||
HloInstruction* literal_instruction = builder.AddInstruction(
|
HloInstruction* literal_instruction = builder.AddInstruction(
|
||||||
HloInstruction::CreateConstant(std::move(literal)));
|
HloInstruction::CreateConstant(std::move(literal)));
|
||||||
@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
|
|||||||
HloComputation::Builder builder(TestName());
|
HloComputation::Builder builder(TestName());
|
||||||
const int64 dimensions[] = {11, 8, 7, 5, 9};
|
const int64 dimensions[] = {11, 8, 7, 5, 9};
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto literal,
|
TF_ASSERT_OK_AND_ASSIGN(auto literal,
|
||||||
LiteralTestUtil::CreateRandomLiteral<F32>(
|
Literal::CreateRandomLiteral<F32>(
|
||||||
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
|
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
|
||||||
auto literal_clone = literal->Literal::CloneToUnique();
|
auto literal_clone = literal->Literal::CloneToUnique();
|
||||||
HloInstruction* literal_instruction = builder.AddInstruction(
|
HloInstruction* literal_instruction = builder.AddInstruction(
|
||||||
|
@ -72,7 +72,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
|
|||||||
|
|
||||||
auto result = ExecuteAndTransfer(std::move(module), {});
|
auto result = ExecuteAndTransfer(std::move(module), {});
|
||||||
auto expected = Literal::CreateR0<float>(84.0);
|
auto expected = Literal::CreateR0<float>(84.0);
|
||||||
LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
|
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
|
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
|
||||||
@ -104,7 +104,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
|
|||||||
|
|
||||||
auto result = ExecuteAndTransfer(std::move(module), {});
|
auto result = ExecuteAndTransfer(std::move(module), {});
|
||||||
auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
|
auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
|
||||||
LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
|
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
|
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
|
||||||
@ -134,7 +134,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
|
|||||||
|
|
||||||
auto result = ExecuteAndTransfer(std::move(module), {});
|
auto result = ExecuteAndTransfer(std::move(module), {});
|
||||||
auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
|
auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
|
||||||
LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
|
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
|
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
|
||||||
|
@ -82,9 +82,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
|
|||||||
auto element_type = expected->shape().element_type();
|
auto element_type = expected->shape().element_type();
|
||||||
if (element_type == F32 || element_type == F64) {
|
if (element_type == F32 || element_type == F64) {
|
||||||
ErrorSpec error(aabs);
|
ErrorSpec error(aabs);
|
||||||
LiteralTestUtil::ExpectNear(*expected, *result, error);
|
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error));
|
||||||
} else {
|
} else {
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,7 +100,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
|
|||||||
|
|
||||||
std::unique_ptr<Literal> result = Evaluate();
|
std::unique_ptr<Literal> result = Evaluate();
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool use_bfloat16_;
|
bool use_bfloat16_;
|
||||||
@ -129,7 +129,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
|
|||||||
|
|
||||||
auto expected = Literal::CreateR2<float>({{0, 4}, {2, 4}});
|
auto expected = Literal::CreateR2<float>({{0, 4}, {2, 4}});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
|
TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
|
||||||
@ -150,7 +150,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
|
|||||||
|
|
||||||
auto expected = Literal::CreateR2<float>({{0, 0}, {1, 1}});
|
auto expected = Literal::CreateR2<float>({{0, 0}, {1, 1}});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verifies that HloEvaluator evaluates a HLO instruction that performs select
|
// Verifies that HloEvaluator evaluates a HLO instruction that performs select
|
||||||
@ -175,7 +175,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
|
|||||||
|
|
||||||
auto expected = Literal::CreateR2<float>({{2, 5}, {0, 4}});
|
auto expected = Literal::CreateR2<float>({{2, 5}, {0, 4}});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verifies that HloEvaluator evaluates a HLO instruction that performs
|
// Verifies that HloEvaluator evaluates a HLO instruction that performs
|
||||||
@ -307,7 +307,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
|
|||||||
|
|
||||||
auto expected = Literal::CreateR2<int64>({{4, -16}, {-196, 12}});
|
auto expected = Literal::CreateR2<int64>({{4, -16}, {-196, 12}});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verifies Reshape operation is correctly evaluated.
|
// Verifies Reshape operation is correctly evaluated.
|
||||||
@ -315,7 +315,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
|
|||||||
HloComputation::Builder b(TestName());
|
HloComputation::Builder b(TestName());
|
||||||
const int64 dimensions[] = {11, 8, 7, 5, 9};
|
const int64 dimensions[] = {11, 8, 7, 5, 9};
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto literal,
|
TF_ASSERT_OK_AND_ASSIGN(auto literal,
|
||||||
LiteralTestUtil::CreateRandomLiteral<F32>(
|
Literal::CreateRandomLiteral<F32>(
|
||||||
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
|
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
|
||||||
auto literal_clone = literal->CloneToUnique();
|
auto literal_clone = literal->CloneToUnique();
|
||||||
HloInstruction* literal_instruction =
|
HloInstruction* literal_instruction =
|
||||||
@ -351,7 +351,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) {
|
|||||||
|
|
||||||
std::unique_ptr<Literal> result = Evaluate({});
|
std::unique_ptr<Literal> result = Evaluate({});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*result, *output_literal);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
|
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
|
||||||
@ -370,7 +370,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
|
|||||||
|
|
||||||
std::unique_ptr<Literal> result = Evaluate({});
|
std::unique_ptr<Literal> result = Evaluate({});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*result, *output_literal);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
|
TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
|
||||||
@ -392,7 +392,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
|
|||||||
|
|
||||||
auto expected =
|
auto expected =
|
||||||
Literal::CreateR2<int64>({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
|
Literal::CreateR2<int64>({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
|
TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
|
||||||
@ -413,7 +413,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
|
|||||||
std::unique_ptr<Literal> result = Evaluate();
|
std::unique_ptr<Literal> result = Evaluate();
|
||||||
|
|
||||||
auto expected = Literal::CreateR1<int64>({100, 200});
|
auto expected = Literal::CreateR1<int64>({100, 200});
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
|
TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
|
||||||
@ -432,7 +432,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
|
|||||||
|
|
||||||
std::unique_ptr<Literal> result = Evaluate();
|
std::unique_ptr<Literal> result = Evaluate();
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*result, *expected);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
|
TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
|
||||||
@ -452,7 +452,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
|
|||||||
|
|
||||||
std::unique_ptr<Literal> result = Evaluate();
|
std::unique_ptr<Literal> result = Evaluate();
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*result, *expected);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
PaddingConfig CreatePaddingConfig(
|
PaddingConfig CreatePaddingConfig(
|
||||||
@ -490,7 +490,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
|
|||||||
auto expected = Literal::CreateR2<int32>(
|
auto expected = Literal::CreateR2<int32>(
|
||||||
{{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
|
{{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
|
TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
|
||||||
@ -525,7 +525,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
|
|||||||
|
|
||||||
auto expected = Literal::CreateR4FromArray4D<float>(*expected_array);
|
auto expected = Literal::CreateR4FromArray4D<float>(*expected_array);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, NegativePadding2D) {
|
TEST_P(HloEvaluatorTest, NegativePadding2D) {
|
||||||
@ -567,7 +567,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
|
|||||||
(*expected_array)(0, 4) = 2.718f;
|
(*expected_array)(0, 4) = 2.718f;
|
||||||
auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
|
auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5));
|
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0x1.0P-5)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
|
TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
|
||||||
@ -606,7 +606,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
|
|||||||
auto expected_array = MakeUnique<Array2D<float>>(0, 9);
|
auto expected_array = MakeUnique<Array2D<float>>(0, 9);
|
||||||
auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
|
auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
|
TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
|
||||||
@ -651,7 +651,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
|
|||||||
// clang-format on
|
// clang-format on
|
||||||
auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
|
auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
|
TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
|
||||||
@ -688,7 +688,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
|
|||||||
|
|
||||||
auto expected = Literal::CreateR1<float>({22.f, 28.f});
|
auto expected = Literal::CreateR1<float>({22.f, 28.f});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
|
TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
|
||||||
@ -737,7 +737,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
|
|||||||
});
|
});
|
||||||
auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
|
auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, SimpleConv1D) {
|
TEST_P(HloEvaluatorTest, SimpleConv1D) {
|
||||||
@ -785,7 +785,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
|
|||||||
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
|
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
|
||||||
auto expected = Literal::CreateR3FromArray3D<float>(expected_array);
|
auto expected = Literal::CreateR3FromArray3D<float>(expected_array);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
|
TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
|
||||||
@ -847,7 +847,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
|
|||||||
// clang-format on
|
// clang-format on
|
||||||
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
|
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
|
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
|
||||||
@ -927,7 +927,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
|
|||||||
auto expected = Literal::CreateR4FromArray4D<float>(
|
auto expected = Literal::CreateR4FromArray4D<float>(
|
||||||
use_bfloat16_ ? expected_array_bf16 : expected_array);
|
use_bfloat16_ ? expected_array_bf16 : expected_array);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
|
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
|
||||||
@ -1004,7 +1004,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
|
|||||||
auto expected = Literal::CreateR4FromArray4D<float>(
|
auto expected = Literal::CreateR4FromArray4D<float>(
|
||||||
use_bfloat16_ ? expected_array_bf16 : expected_array);
|
use_bfloat16_ ? expected_array_bf16 : expected_array);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
|
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
|
||||||
@ -1067,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
|
|||||||
}));
|
}));
|
||||||
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
|
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
|
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
|
||||||
@ -1131,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
|
|||||||
}));
|
}));
|
||||||
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
|
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest,
|
TEST_P(HloEvaluatorTest,
|
||||||
@ -1203,7 +1203,7 @@ TEST_P(HloEvaluatorTest,
|
|||||||
}));
|
}));
|
||||||
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
|
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
|
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
|
||||||
@ -1319,7 +1319,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
|
|||||||
|
|
||||||
auto expected = Literal::CreateR1<float>({6, 18});
|
auto expected = Literal::CreateR1<float>({6, 18});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, ReduceWindowMax) {
|
TEST_P(HloEvaluatorTest, ReduceWindowMax) {
|
||||||
@ -1370,7 +1370,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
|
|||||||
std::unique_ptr<Literal> result = Evaluate();
|
std::unique_ptr<Literal> result = Evaluate();
|
||||||
|
|
||||||
auto expected = Literal::CreateR2<float>({{6, 7}});
|
auto expected = Literal::CreateR2<float>({{6, 7}});
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
|
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
|
||||||
@ -1427,7 +1427,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
|
|||||||
std::unique_ptr<Literal> result = Evaluate();
|
std::unique_ptr<Literal> result = Evaluate();
|
||||||
|
|
||||||
auto expected = Literal::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
|
auto expected = Literal::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
|
TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
|
||||||
@ -1490,7 +1490,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
|
|||||||
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
|
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
|
||||||
std::unique_ptr<Literal> result_literal =
|
std::unique_ptr<Literal> result_literal =
|
||||||
Literal::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
|
Literal::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
|
||||||
LiteralTestUtil::ExpectEqual(*result_literal, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, StridedSlice) {
|
TEST_P(HloEvaluatorTest, StridedSlice) {
|
||||||
@ -1523,7 +1523,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
|
|||||||
{19},
|
{19},
|
||||||
});
|
});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, DynamicSlice) {
|
TEST_P(HloEvaluatorTest, DynamicSlice) {
|
||||||
@ -1556,7 +1556,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
|
|||||||
{6, 7, 8},
|
{6, 7, 8},
|
||||||
});
|
});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verifies that the HloEvaluator's implementation goes along with existing
|
// Verifies that the HloEvaluator's implementation goes along with existing
|
||||||
@ -1591,7 +1591,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
|
|||||||
{6, 7, 8},
|
{6, 7, 8},
|
||||||
});
|
});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
|
TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
|
||||||
@ -1627,7 +1627,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
|
|||||||
{5, -6, -7},
|
{5, -6, -7},
|
||||||
});
|
});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, SetAndGetTuples) {
|
TEST_P(HloEvaluatorTest, SetAndGetTuples) {
|
||||||
@ -1662,7 +1662,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
|
|||||||
{5, 6, 7},
|
{5, 6, 7},
|
||||||
});
|
});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
|
TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
|
||||||
@ -1703,7 +1703,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
|
|||||||
result_inner_literal.get(),
|
result_inner_literal.get(),
|
||||||
});
|
});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, Reverse) {
|
TEST_P(HloEvaluatorTest, Reverse) {
|
||||||
@ -1756,7 +1756,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
|
|||||||
});
|
});
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
|
TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
|
||||||
@ -1776,8 +1776,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
|
|||||||
add, {{param0, Literal::CreateR1<float>({1, 2, 3, 4}).get()},
|
add, {{param0, Literal::CreateR1<float>({1, 2, 3, 4}).get()},
|
||||||
{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
|
{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
|
||||||
TF_ASSERT_OK(result.status());
|
TF_ASSERT_OK(result.status());
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<float>({11, 22, 33, 44}),
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
*result.ValueOrDie());
|
*Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that EvaluateWithSubstitutions works if one of the operands to the op
|
// Check that EvaluateWithSubstitutions works if one of the operands to the op
|
||||||
@ -1800,8 +1800,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) {
|
|||||||
auto result = evaluator.EvaluateWithSubstitutions(
|
auto result = evaluator.EvaluateWithSubstitutions(
|
||||||
add, {{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
|
add, {{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
|
||||||
TF_ASSERT_OK(result.status());
|
TF_ASSERT_OK(result.status());
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<float>({11, 22, 33, 44}),
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
*result.ValueOrDie());
|
*Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
|
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
|
||||||
@ -1823,9 +1823,9 @@ ENTRY main {
|
|||||||
std::unique_ptr<Literal> operand =
|
std::unique_ptr<Literal> operand =
|
||||||
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
||||||
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
|
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(
|
||||||
*Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
|
LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
|
||||||
*Evaluate({operand.get(), gather_indices.get()}));
|
*Evaluate({operand.get(), gather_indices.get()})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
|
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
|
||||||
@ -1847,9 +1847,9 @@ ENTRY main {
|
|||||||
std::unique_ptr<Literal> operand =
|
std::unique_ptr<Literal> operand =
|
||||||
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
||||||
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
|
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
*Literal::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
|
*Literal::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
|
||||||
*Evaluate({operand.get(), gather_indices.get()}));
|
*Evaluate({operand.get(), gather_indices.get()})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
|
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
|
||||||
@ -1872,10 +1872,10 @@ ENTRY main {
|
|||||||
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
||||||
std::unique_ptr<Literal> gather_indices =
|
std::unique_ptr<Literal> gather_indices =
|
||||||
Literal::CreateR2<int32>({{0, 2}, {2, 1}});
|
Literal::CreateR2<int32>({{0, 2}, {2, 1}});
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
*Literal::CreateR3<int32>(
|
*Literal::CreateR3<int32>(
|
||||||
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
|
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
|
||||||
*Evaluate({operand.get(), gather_indices.get()}));
|
*Evaluate({operand.get(), gather_indices.get()})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
|
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
|
||||||
@ -1900,9 +1900,9 @@ ENTRY main {
|
|||||||
{{-7, 7}, {-8, 8}, {-9, 9}}});
|
{{-7, 7}, {-8, 8}, {-9, 9}}});
|
||||||
std::unique_ptr<Literal> gather_indices =
|
std::unique_ptr<Literal> gather_indices =
|
||||||
Literal::CreateR2<int32>({{0, 0}, {1, 0}});
|
Literal::CreateR2<int32>({{0, 0}, {1, 0}});
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(
|
||||||
*Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
|
LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
|
||||||
*Evaluate({operand.get(), gather_indices.get()}));
|
*Evaluate({operand.get(), gather_indices.get()})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest,
|
TEST_P(HloEvaluatorTest,
|
||||||
@ -1928,9 +1928,9 @@ ENTRY main {
|
|||||||
{{-7, 7}, {-8, 8}, {-9, 9}}});
|
{{-7, 7}, {-8, 8}, {-9, 9}}});
|
||||||
std::unique_ptr<Literal> gather_indices =
|
std::unique_ptr<Literal> gather_indices =
|
||||||
Literal::CreateR2<int32>({{0, 0}, {1, 0}});
|
Literal::CreateR2<int32>({{0, 0}, {1, 0}});
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(
|
||||||
*Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
|
LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
|
||||||
*Evaluate({operand.get(), gather_indices.get()}));
|
*Evaluate({operand.get(), gather_indices.get()})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
|
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
|
||||||
@ -1952,9 +1952,9 @@ ENTRY main {
|
|||||||
std::unique_ptr<Literal> operand =
|
std::unique_ptr<Literal> operand =
|
||||||
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
||||||
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
|
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(
|
||||||
*Literal::CreateR2<int32>({{5}}),
|
LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{5}}),
|
||||||
*Evaluate({operand.get(), gather_indices.get()}));
|
*Evaluate({operand.get(), gather_indices.get()})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
|
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
|
||||||
@ -1977,9 +1977,9 @@ ENTRY main {
|
|||||||
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
||||||
std::unique_ptr<Literal> gather_indices =
|
std::unique_ptr<Literal> gather_indices =
|
||||||
Literal::CreateR2<int32>({{2, 1}, {1, 1}});
|
Literal::CreateR2<int32>({{2, 1}, {1, 1}});
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(
|
||||||
*Literal::CreateR3<int32>({{{8}}, {{5}}}),
|
LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{8}}, {{5}}}),
|
||||||
*Evaluate({operand.get(), gather_indices.get()}));
|
*Evaluate({operand.get(), gather_indices.get()})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
|
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
|
||||||
@ -2000,9 +2000,9 @@ ENTRY main {
|
|||||||
ParseAndVerifyModule(hlo_text);
|
ParseAndVerifyModule(hlo_text);
|
||||||
std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
|
std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
|
||||||
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
|
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(
|
||||||
*Literal::CreateR2<int32>({{}, {}}),
|
LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{}, {}}),
|
||||||
*Evaluate({operand.get(), gather_indices.get()}));
|
*Evaluate({operand.get(), gather_indices.get()})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
|
TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
|
||||||
@ -2025,9 +2025,9 @@ ENTRY main {
|
|||||||
std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({0, 1, 2});
|
std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({0, 1, 2});
|
||||||
std::unique_ptr<Literal> gather_indices =
|
std::unique_ptr<Literal> gather_indices =
|
||||||
Literal::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
|
Literal::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(
|
||||||
*Literal::CreateR2<int32>({{0, 1}, {2, 1}}),
|
LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{0, 1}, {2, 1}}),
|
||||||
*Evaluate({operand.get(), gather_indices.get()}));
|
*Evaluate({operand.get(), gather_indices.get()})));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verifies that HloEvaluator evaluates a HLO instruction that performs
|
// Verifies that HloEvaluator evaluates a HLO instruction that performs
|
||||||
|
@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) {
|
|||||||
// Verify execution on CPU.
|
// Verify execution on CPU.
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
auto expected = Literal::CreateR1<float>({4, 3, 3, 4});
|
auto expected = Literal::CreateR1<float>({4, 3, 3, 4});
|
||||||
LiteralTestUtil::ExpectEqual(*result, *expected);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that `constant` function is changed to `broadcast`.
|
// Test that `constant` function is changed to `broadcast`.
|
||||||
@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) {
|
|||||||
// Verify execution on CPU.
|
// Verify execution on CPU.
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
auto expected = Literal::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
|
auto expected = Literal::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
|
||||||
LiteralTestUtil::ExpectEqual(*result, *expected);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
|
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
|
||||||
@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
|
|||||||
// Verify execution on CPU.
|
// Verify execution on CPU.
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
auto expected = Literal::CreateR1<float>({3, 1, -1, -3});
|
auto expected = Literal::CreateR1<float>({3, 1, -1, -3});
|
||||||
LiteralTestUtil::ExpectEqual(*result, *expected);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,6 +87,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:array2d",
|
"//tensorflow/compiler/xla:array2d",
|
||||||
"//tensorflow/compiler/xla:array3d",
|
"//tensorflow/compiler/xla:array3d",
|
||||||
"//tensorflow/compiler/xla:array4d",
|
"//tensorflow/compiler/xla:array4d",
|
||||||
|
"//tensorflow/compiler/xla:literal_comparison",
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
|
@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(42.0), *result,
|
EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0<float>(42.0), *result,
|
||||||
error_spec_);
|
error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
|
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
|
||||||
@ -62,9 +62,9 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*Literal::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
|
*Literal::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
|
||||||
error_spec_);
|
error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
|
XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
|
||||||
@ -85,13 +85,13 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
|
*Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
|
||||||
LiteralSlice(*result, {0}), error_spec_);
|
LiteralSlice(*result, {0}), error_spec_));
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
|
*Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
|
||||||
LiteralSlice(*result, {1}), error_spec_);
|
LiteralSlice(*result, {1}), error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
|
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
|
||||||
@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(
|
EXPECT_TRUE(
|
||||||
*Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
|
LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
|
||||||
error_spec_);
|
*result, error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
|
XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
|
||||||
@ -125,9 +125,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(
|
EXPECT_TRUE(
|
||||||
*Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
|
LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}),
|
||||||
error_spec_);
|
*result, error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
|
XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
|
||||||
@ -142,10 +142,10 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
|
*Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
|
||||||
{{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
|
{{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
|
||||||
*result, error_spec_);
|
*result, error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
|
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
|
||||||
@ -166,8 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
|
|||||||
Array2D<float> pz({{1, 2}, {1, 2}});
|
Array2D<float> pz({{1, 2}, {1, 2}});
|
||||||
expected.FillWithPZ(pz);
|
expected.FillWithPZ(pz);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*result, error_spec_);
|
*Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
|
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
|
||||||
@ -196,8 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
|
|||||||
}
|
}
|
||||||
expected.FillWithYX(yx);
|
expected.FillWithYX(yx);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*result, error_spec_);
|
*Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
|
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
|
||||||
@ -218,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result,
|
EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array),
|
||||||
error_spec_);
|
*result, error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
|
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
|
||||||
@ -238,8 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
|
|||||||
Array4D<float> expected(64, 64, 3, 3);
|
Array4D<float> expected(64, 64, 3, 3);
|
||||||
expected.Fill(1.0f);
|
expected.Fill(1.0f);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*result, error_spec_);
|
*Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
|
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
|
||||||
@ -260,8 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
|
|||||||
Array4D<float> expected(3, 3, 2, 2);
|
Array4D<float> expected(3, 3, 2, 2);
|
||||||
expected.FillWithYX(to_broadcast);
|
expected.FillWithYX(to_broadcast);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*result, error_spec_);
|
*Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
|
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
|
||||||
@ -291,8 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*result, error_spec_);
|
*Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -297,7 +297,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
|
|||||||
std::unique_ptr<Literal> converted_expected;
|
std::unique_ptr<Literal> converted_expected;
|
||||||
Shape layout_shape;
|
Shape layout_shape;
|
||||||
if (use_bfloat16_) {
|
if (use_bfloat16_) {
|
||||||
converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
|
converted_expected = Literal::ConvertF32ToBF16(expected);
|
||||||
expected_ptr = converted_expected.get();
|
expected_ptr = converted_expected.get();
|
||||||
if (shape_with_layout != nullptr) {
|
if (shape_with_layout != nullptr) {
|
||||||
layout_shape = *shape_with_layout;
|
layout_shape = *shape_with_layout;
|
||||||
@ -311,7 +311,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto expect_equal = [&](const Literal& actual, const string& error_message) {
|
auto expect_equal = [&](const Literal& actual, const string& error_message) {
|
||||||
LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message;
|
||||||
};
|
};
|
||||||
if (execution_options_.debug_options().xla_test_all_output_layouts()) {
|
if (execution_options_.debug_options().xla_test_all_output_layouts()) {
|
||||||
return ComputeAndCompareLiteralWithAllOutputLayouts(
|
return ComputeAndCompareLiteralWithAllOutputLayouts(
|
||||||
@ -323,7 +323,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
|
|||||||
}
|
}
|
||||||
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
|
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
|
||||||
shape_with_layout));
|
shape_with_layout));
|
||||||
LiteralTestUtil::ExpectEqual(*expected_ptr, *actual);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
|
||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -349,7 +349,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
|
|||||||
std::unique_ptr<Literal> converted_expected;
|
std::unique_ptr<Literal> converted_expected;
|
||||||
Shape layout_shape;
|
Shape layout_shape;
|
||||||
if (use_bfloat16_) {
|
if (use_bfloat16_) {
|
||||||
converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
|
converted_expected = Literal::ConvertF32ToBF16(expected);
|
||||||
expected_ptr = converted_expected.get();
|
expected_ptr = converted_expected.get();
|
||||||
if (shape_with_layout != nullptr) {
|
if (shape_with_layout != nullptr) {
|
||||||
layout_shape = *shape_with_layout;
|
layout_shape = *shape_with_layout;
|
||||||
@ -363,7 +363,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto expect_near = [&](const Literal& actual, const string& error_message) {
|
auto expect_near = [&](const Literal& actual, const string& error_message) {
|
||||||
LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message);
|
EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error))
|
||||||
|
<< error_message;
|
||||||
};
|
};
|
||||||
if (execution_options_.debug_options().xla_test_all_output_layouts()) {
|
if (execution_options_.debug_options().xla_test_all_output_layouts()) {
|
||||||
return ComputeAndCompareLiteralWithAllOutputLayouts(
|
return ComputeAndCompareLiteralWithAllOutputLayouts(
|
||||||
@ -375,7 +376,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
|
|||||||
}
|
}
|
||||||
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
|
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
|
||||||
shape_with_layout));
|
shape_with_layout));
|
||||||
LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error);
|
EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
|
||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -407,7 +408,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto actual = actual_status.ConsumeValueOrDie();
|
auto actual = actual_status.ConsumeValueOrDie();
|
||||||
LiteralTestUtil::ExpectEqual(expected, *actual);
|
EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ClientLibraryTestBase::ComputeAndCompareTuple(
|
void ClientLibraryTestBase::ComputeAndCompareTuple(
|
||||||
@ -419,7 +420,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto actual = actual_status.ConsumeValueOrDie();
|
auto actual = actual_status.ConsumeValueOrDie();
|
||||||
LiteralTestUtil::ExpectNear(expected, *actual, error);
|
EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ClientLibraryTestBase::ComputeAndCompare(
|
void ClientLibraryTestBase::ComputeAndCompare(
|
||||||
@ -431,7 +432,7 @@ void ClientLibraryTestBase::ComputeAndCompare(
|
|||||||
}
|
}
|
||||||
std::unique_ptr<Literal> reference, result;
|
std::unique_ptr<Literal> reference, result;
|
||||||
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
|
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
|
||||||
LiteralTestUtil::ExpectEqual(*reference, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ClientLibraryTestBase::ComputeAndCompare(
|
void ClientLibraryTestBase::ComputeAndCompare(
|
||||||
@ -444,7 +445,7 @@ void ClientLibraryTestBase::ComputeAndCompare(
|
|||||||
}
|
}
|
||||||
std::unique_ptr<Literal> reference, result;
|
std::unique_ptr<Literal> reference, result;
|
||||||
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
|
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
|
||||||
LiteralTestUtil::ExpectNear(*reference, *result, error);
|
EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
|
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
|
||||||
@ -562,7 +563,7 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
|
|||||||
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
|
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
|
||||||
XlaBuilder* builder) {
|
XlaBuilder* builder) {
|
||||||
return builder->ConstantLiteral(
|
return builder->ConstantLiteral(
|
||||||
use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
|
use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<GlobalData>
|
std::unique_ptr<GlobalData>
|
||||||
@ -583,7 +584,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(
|
|||||||
const Literal* param_literal = &literal;
|
const Literal* param_literal = &literal;
|
||||||
std::unique_ptr<Literal> converted_literal;
|
std::unique_ptr<Literal> converted_literal;
|
||||||
if (use_bfloat16_) {
|
if (use_bfloat16_) {
|
||||||
converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
|
converted_literal = Literal::ConvertF32ToBF16(literal);
|
||||||
param_literal = converted_literal.get();
|
param_literal = converted_literal.get();
|
||||||
}
|
}
|
||||||
std::unique_ptr<GlobalData> data =
|
std::unique_ptr<GlobalData> data =
|
||||||
|
@ -541,7 +541,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
|
|||||||
XlaBuilder* builder, XlaOp* data_handle) {
|
XlaBuilder* builder, XlaOp* data_handle) {
|
||||||
std::unique_ptr<Literal> literal = Literal::CreateR0(value);
|
std::unique_ptr<Literal> literal = Literal::CreateR0(value);
|
||||||
if (use_bfloat16_ && literal->shape().element_type() == F32) {
|
if (use_bfloat16_ && literal->shape().element_type() == F32) {
|
||||||
literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
|
literal = Literal::ConvertF32ToBF16(*literal);
|
||||||
}
|
}
|
||||||
std::unique_ptr<GlobalData> data =
|
std::unique_ptr<GlobalData> data =
|
||||||
client_->TransferToServer(*literal).ConsumeValueOrDie();
|
client_->TransferToServer(*literal).ConsumeValueOrDie();
|
||||||
@ -555,7 +555,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
|
|||||||
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
|
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
|
||||||
std::unique_ptr<Literal> literal = Literal::CreateR1(values);
|
std::unique_ptr<Literal> literal = Literal::CreateR1(values);
|
||||||
if (use_bfloat16_ && literal->shape().element_type() == F32) {
|
if (use_bfloat16_ && literal->shape().element_type() == F32) {
|
||||||
literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
|
literal = Literal::ConvertF32ToBF16(*literal);
|
||||||
}
|
}
|
||||||
std::unique_ptr<GlobalData> data =
|
std::unique_ptr<GlobalData> data =
|
||||||
client_->TransferToServer(*literal).ConsumeValueOrDie();
|
client_->TransferToServer(*literal).ConsumeValueOrDie();
|
||||||
@ -569,7 +569,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
|
|||||||
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
|
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
|
||||||
std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
|
std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
|
||||||
if (use_bfloat16_ && literal->shape().element_type() == F32) {
|
if (use_bfloat16_ && literal->shape().element_type() == F32) {
|
||||||
literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
|
literal = Literal::ConvertF32ToBF16(*literal);
|
||||||
}
|
}
|
||||||
std::unique_ptr<GlobalData> data =
|
std::unique_ptr<GlobalData> data =
|
||||||
client_->TransferToServer(*literal).ConsumeValueOrDie();
|
client_->TransferToServer(*literal).ConsumeValueOrDie();
|
||||||
@ -583,7 +583,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
|
|||||||
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
|
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
|
||||||
std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
|
std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
|
||||||
if (use_bfloat16_ && literal->shape().element_type() == F32) {
|
if (use_bfloat16_ && literal->shape().element_type() == F32) {
|
||||||
literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
|
literal = Literal::ConvertF32ToBF16(*literal);
|
||||||
}
|
}
|
||||||
std::unique_ptr<GlobalData> data =
|
std::unique_ptr<GlobalData> data =
|
||||||
client_->TransferToServer(*literal).ConsumeValueOrDie();
|
client_->TransferToServer(*literal).ConsumeValueOrDie();
|
||||||
|
@ -62,9 +62,9 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
|
|||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
auto computed, client_->Transfer(*data, &expected_literal->shape()));
|
auto computed, client_->Transfer(*data, &expected_literal->shape()));
|
||||||
|
|
||||||
LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
|
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
|
||||||
computed->shape());
|
expected_literal->shape(), computed->shape()));
|
||||||
LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -142,7 +142,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
|
|||||||
auto result_literal,
|
auto result_literal,
|
||||||
client_->Transfer(*results[0], &expected_result->shape()));
|
client_->Transfer(*results[0], &expected_result->shape()));
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*expected_result, *result_literal);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -50,8 +50,8 @@ class CompilationCacheTest : public ClientLibraryTestBase {
|
|||||||
/*execution_options=*/&execution_options_,
|
/*execution_options=*/&execution_options_,
|
||||||
&execution_profile)
|
&execution_profile)
|
||||||
.ConsumeValueOrDie();
|
.ConsumeValueOrDie();
|
||||||
LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(expected_result),
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*result, error_spec_);
|
*Literal::CreateR0<float>(expected_result), *result, error_spec_));
|
||||||
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
|
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,8 +67,8 @@ class CompilationCacheTest : public ClientLibraryTestBase {
|
|||||||
.ConsumeValueOrDie();
|
.ConsumeValueOrDie();
|
||||||
std::unique_ptr<Literal> result =
|
std::unique_ptr<Literal> result =
|
||||||
client_->Transfer(*data_handle).ConsumeValueOrDie();
|
client_->Transfer(*data_handle).ConsumeValueOrDie();
|
||||||
LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>(expected_result),
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*result, error_spec_);
|
*Literal::CreateR2<float>(expected_result), *result, error_spec_));
|
||||||
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
|
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,7 +208,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
|
|||||||
ComputeConstantLiteral(client, computation, &b));
|
ComputeConstantLiteral(client, computation, &b));
|
||||||
std::unique_ptr<Literal> expected_literal =
|
std::unique_ptr<Literal> expected_literal =
|
||||||
Literal::CreateR1<int32>({4, 6});
|
Literal::CreateR1<int32>({4, 6});
|
||||||
LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -222,7 +222,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
|
|||||||
TF_ASSERT_OK_AND_ASSIGN(auto computed,
|
TF_ASSERT_OK_AND_ASSIGN(auto computed,
|
||||||
ComputeConstantLiteral(client, computation, &b));
|
ComputeConstantLiteral(client, computation, &b));
|
||||||
std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
|
std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
|
||||||
LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -244,9 +244,9 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
|
|||||||
std::unique_ptr<Literal> expected_literal =
|
std::unique_ptr<Literal> expected_literal =
|
||||||
Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
|
Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
|
||||||
LayoutUtil::MakeLayout(layout));
|
LayoutUtil::MakeLayout(layout));
|
||||||
LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
|
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
|
||||||
computed->shape());
|
expected_literal->shape(), computed->shape()));
|
||||||
LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -49,7 +49,7 @@ class CopyOpTest : public HloTestBase {
|
|||||||
module->AddEntryComputation(std::move(computation));
|
module->AddEntryComputation(std::move(computation));
|
||||||
|
|
||||||
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
|
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
|
||||||
LiteralTestUtil::ExpectEqual(literal, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
|
void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
|
||||||
@ -253,7 +253,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
|
|||||||
|
|
||||||
auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
|
auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
|
||||||
.ConsumeValueOrDie();
|
.ConsumeValueOrDie();
|
||||||
LiteralTestUtil::ExpectEqual(*empty, *actual);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -118,9 +118,9 @@ class FusionTest : public HloTestBase {
|
|||||||
auto expected = Literal::CreateR2FromArray2D(answer_data);
|
auto expected = Literal::CreateR2FromArray2D(answer_data);
|
||||||
auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
if (primitive_util::IsFloatingPointType(prim_type)) {
|
if (primitive_util::IsFloatingPointType(prim_type)) {
|
||||||
LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4));
|
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
|
||||||
} else {
|
} else {
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *actual);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,9 +221,9 @@ XLA_TEST_F(FusionTest, Test) {
|
|||||||
const4, reshape3, add2, const1, const0},
|
const4, reshape3, add2, const1, const0},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{0.5}, {2.72}}),
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}),
|
*Literal::CreateR2<float>({{0.5}, {2.72}}),
|
||||||
ErrorSpec(1e-4));
|
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test whether we emit appropriate code for parameters of fusion instructions.
|
// Test whether we emit appropriate code for parameters of fusion instructions.
|
||||||
@ -247,9 +247,9 @@ XLA_TEST_F(FusionTest, Parameter) {
|
|||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}),
|
*Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
|
||||||
ErrorSpec(1e-4));
|
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
|
XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
|
||||||
@ -307,9 +307,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
|
|||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectNear(
|
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||||
*Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
|
*Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4));
|
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, ReshapeToScalar) {
|
XLA_TEST_F(FusionTest, ReshapeToScalar) {
|
||||||
@ -322,8 +322,9 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build())
|
hlo_module->AddEntryComputation(builder.Build())
|
||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(5),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR0<int32>(5),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
|
XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
|
||||||
@ -336,9 +337,9 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build())
|
hlo_module->AddEntryComputation(builder.Build())
|
||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
*Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
|
*Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
|
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
|
||||||
@ -351,9 +352,9 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build())
|
hlo_module->AddEntryComputation(builder.Build())
|
||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
*Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
|
*Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
|
XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
|
||||||
@ -366,8 +367,9 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build())
|
hlo_module->AddEntryComputation(builder.Build())
|
||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Reshape__1by1by1) {
|
XLA_TEST_F(FusionTest, Reshape__1by1by1) {
|
||||||
@ -380,8 +382,9 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build())
|
hlo_module->AddEntryComputation(builder.Build())
|
||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR3<int32>({{{7}}}),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{7}}}),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Reshape__) {
|
XLA_TEST_F(FusionTest, Reshape__) {
|
||||||
@ -394,8 +397,9 @@ XLA_TEST_F(FusionTest, Reshape__) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build())
|
hlo_module->AddEntryComputation(builder.Build())
|
||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
|
XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
|
||||||
@ -408,9 +412,9 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build())
|
hlo_module->AddEntryComputation(builder.Build())
|
||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
*Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
|
*Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Transpose_2by3) {
|
XLA_TEST_F(FusionTest, Transpose_2by3) {
|
||||||
@ -423,9 +427,9 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build())
|
hlo_module->AddEntryComputation(builder.Build())
|
||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
*Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
|
*Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Transpose_3by3) {
|
XLA_TEST_F(FusionTest, Transpose_3by3) {
|
||||||
@ -438,9 +442,9 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
|
|||||||
hlo_module->AddEntryComputation(builder.Build())
|
hlo_module->AddEntryComputation(builder.Build())
|
||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
*Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
|
*Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Reverse) {
|
XLA_TEST_F(FusionTest, Reverse) {
|
||||||
@ -454,8 +458,9 @@ XLA_TEST_F(FusionTest, Reverse) {
|
|||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({3, 2, 1}),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR1<int32>({3, 2, 1}),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, ReverseNegate) {
|
XLA_TEST_F(FusionTest, ReverseNegate) {
|
||||||
@ -471,8 +476,9 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
|
|||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-3, -2, -1}),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-3, -2, -1}),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, BroadcastNegate) {
|
XLA_TEST_F(FusionTest, BroadcastNegate) {
|
||||||
@ -488,8 +494,9 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
|
|||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -1}),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -1}),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, SliceNegate) {
|
XLA_TEST_F(FusionTest, SliceNegate) {
|
||||||
@ -505,8 +512,9 @@ XLA_TEST_F(FusionTest, SliceNegate) {
|
|||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -3}),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -3}),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, DynamicSliceNegate) {
|
XLA_TEST_F(FusionTest, DynamicSliceNegate) {
|
||||||
@ -526,8 +534,9 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) {
|
|||||||
/*instructions_to_fuse=*/{negate3, dynamic_slice2},
|
/*instructions_to_fuse=*/{negate3, dynamic_slice2},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-2, -3}),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-2, -3}),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, ReshapeNegate) {
|
XLA_TEST_F(FusionTest, ReshapeNegate) {
|
||||||
@ -543,8 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
|
|||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(b/64070202): Investigate failure.
|
// TODO(b/64070202): Investigate failure.
|
||||||
@ -561,8 +571,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) {
|
|||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<HloComputation> MakeReduceTestComputation() {
|
std::unique_ptr<HloComputation> MakeReduceTestComputation() {
|
||||||
@ -591,8 +602,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
|
|||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(15),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR0<int32>(15),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
|
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
|
||||||
@ -612,8 +624,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
|
|||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(-15),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR0<int32>(-15),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
|
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
|
||||||
@ -661,9 +674,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
|
|||||||
->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
|
->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
|
||||||
HloInstruction::FusionKind::kLoop);
|
HloInstruction::FusionKind::kLoop);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
*Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
|
*Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
// When a constant (or other op) which has multiple users is imported
|
// When a constant (or other op) which has multiple users is imported
|
||||||
@ -697,8 +710,9 @@ XLA_TEST_F(FusionTest, SharedConstant) {
|
|||||||
// fused instruction contains the constant(2), the parameter, and 4 adds
|
// fused instruction contains the constant(2), the parameter, and 4 adds
|
||||||
EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
|
EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}),
|
EXPECT_TRUE(
|
||||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
LiteralTestUtil::Equal(*Literal::CreateR1<int32>({8}),
|
||||||
|
*ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
|
XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
|
||||||
|
@ -629,8 +629,8 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
|
|||||||
client_->ExecuteParallel(computation_instances));
|
client_->ExecuteParallel(computation_instances));
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
|
||||||
client_->Transfer(*(result_data[0])));
|
client_->Transfer(*(result_data[0])));
|
||||||
LiteralTestUtil::ExpectEqual(
|
EXPECT_TRUE(LiteralTestUtil::Equal(
|
||||||
*result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}));
|
*result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}})));
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/index_util.h"
|
#include "tensorflow/compiler/xla/index_util.h"
|
||||||
#include "tensorflow/compiler/xla/layout_util.h"
|
#include "tensorflow/compiler/xla/layout_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/literal_comparison.h"
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
@ -46,119 +47,23 @@ using ::tensorflow::strings::StrCat;
|
|||||||
|
|
||||||
/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes(
|
/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes(
|
||||||
const Shape& expected, const Shape& actual) {
|
const Shape& expected, const Shape& actual) {
|
||||||
if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
|
Status result = literal_comparison::EqualShapes(expected, actual);
|
||||||
return ::testing::AssertionFailure()
|
if (result.ok()) {
|
||||||
<< "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected)
|
return ::testing::AssertionSuccess();
|
||||||
<< " got: " << ShapeUtil::HumanString(actual);
|
|
||||||
}
|
}
|
||||||
if (ShapeUtil::IsTuple(expected)) {
|
return ::testing::AssertionFailure() << result;
|
||||||
if (ShapeUtil::TupleElementCount(expected) !=
|
}
|
||||||
ShapeUtil::TupleElementCount(actual)) {
|
|
||||||
return ::testing::AssertionFailure()
|
/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapesAndLayouts(
|
||||||
<< "want tuple element count: "
|
const Shape& expected, const Shape& actual) {
|
||||||
<< ShapeUtil::TupleElementCount(expected)
|
if (expected.ShortDebugString() != actual.ShortDebugString()) {
|
||||||
<< " got tuple element count: "
|
return ::testing::AssertionFailure()
|
||||||
<< ShapeUtil::TupleElementCount(actual);
|
<< "want: " << expected.ShortDebugString()
|
||||||
}
|
<< " got: " << actual.ShortDebugString();
|
||||||
for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
|
|
||||||
::testing::AssertionResult result =
|
|
||||||
EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i))
|
|
||||||
<< "mismatch in tuple index " << i;
|
|
||||||
if (!result) {
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
|
|
||||||
return ::testing::AssertionFailure()
|
|
||||||
<< "want rank of: " << ShapeUtil::HumanString(expected)
|
|
||||||
<< " got rank of: " << ShapeUtil::HumanString(actual);
|
|
||||||
}
|
|
||||||
if (expected.element_type() != actual.element_type()) {
|
|
||||||
return ::testing::AssertionFailure()
|
|
||||||
<< PrimitiveType_Name(expected.element_type()) << " vs "
|
|
||||||
<< PrimitiveType_Name(actual.element_type());
|
|
||||||
}
|
|
||||||
if (expected.dimensions_size() != actual.dimensions_size()) {
|
|
||||||
return ::testing::AssertionFailure()
|
|
||||||
<< "want dimensions_size " << expected.dimensions_size()
|
|
||||||
<< " got dimensions_size " << actual.dimensions_size();
|
|
||||||
}
|
|
||||||
for (int i = 0; i < expected.dimensions_size(); ++i) {
|
|
||||||
if (expected.dimensions(i) != actual.dimensions(i)) {
|
|
||||||
return ::testing::AssertionFailure()
|
|
||||||
<< "mismatch in dimension #" << i
|
|
||||||
<< " expected: " << ShapeUtil::HumanString(expected)
|
|
||||||
<< " actual: " << ShapeUtil::HumanString(actual);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return ::testing::AssertionSuccess();
|
return ::testing::AssertionSuccess();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected,
|
|
||||||
const Shape& actual) {
|
|
||||||
ASSERT_TRUE(EqualShapes(expected, actual));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts(
|
|
||||||
const Shape& expected, const Shape& actual) {
|
|
||||||
ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString());
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Return a literal with all arrays of type FromNativeT converted to type
|
|
||||||
// ToNativeT in the given literal.
|
|
||||||
template <typename FromNativeT, typename ToNativeT>
|
|
||||||
std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
|
|
||||||
// First construct shape of the result.
|
|
||||||
Shape result_shape(literal.shape());
|
|
||||||
ShapeUtil::ForEachMutableSubshape(
|
|
||||||
&result_shape, [](Shape* subshape, const ShapeIndex&) {
|
|
||||||
if (subshape->element_type() ==
|
|
||||||
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
|
|
||||||
subshape->set_element_type(
|
|
||||||
primitive_util::NativeToPrimitiveType<ToNativeT>());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
auto result = MakeUnique<Literal>(result_shape);
|
|
||||||
|
|
||||||
// Then copy over the data from 'literal' converting FromNativeT values to
|
|
||||||
// ToNativeT values as necessary.
|
|
||||||
ShapeUtil::ForEachSubshape(
|
|
||||||
literal.shape(),
|
|
||||||
[&](const Shape& subshape, const ShapeIndex& shape_index) {
|
|
||||||
if (ShapeUtil::IsArray(subshape)) {
|
|
||||||
if (subshape.element_type() ==
|
|
||||||
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
|
|
||||||
auto src = literal.data<FromNativeT>(shape_index);
|
|
||||||
auto dest = result->data<ToNativeT>(shape_index);
|
|
||||||
for (int64 i = 0; i < src.size(); ++i) {
|
|
||||||
dest[i] = static_cast<ToNativeT>(src[i]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
TF_CHECK_OK(result->CopyFrom(literal,
|
|
||||||
/*dest_shape_index=*/shape_index,
|
|
||||||
/*src_shape_index=*/shape_index));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
|
|
||||||
LiteralSlice literal) {
|
|
||||||
return ConvertType<bfloat16, float>(literal);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
|
|
||||||
LiteralSlice literal) {
|
|
||||||
return ConvertType<float, bfloat16>(literal);
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
string Hostname() {
|
string Hostname() {
|
||||||
@ -168,183 +73,15 @@ string Hostname() {
|
|||||||
return string(hostname);
|
return string(hostname);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function for comparing a floating point type, FloatT, bitwise equal
|
|
||||||
// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
|
|
||||||
// -- on miscompare, a nice error message is given in the AssertionFailure.
|
|
||||||
template <typename FloatT, typename UnsignedT>
|
|
||||||
::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
|
|
||||||
auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
|
|
||||||
auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
|
|
||||||
auto lhs_double = static_cast<double>(lhs);
|
|
||||||
auto rhs_double = static_cast<double>(rhs);
|
|
||||||
if (ulhs != urhs) {
|
|
||||||
return ::testing::AssertionFailure() << Printf(
|
|
||||||
"floating values are not bitwise-equal; and equality testing "
|
|
||||||
"was requested: %s=%g=%a vs %s=%g=%a",
|
|
||||||
StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double,
|
|
||||||
lhs_double, StrCat(tensorflow::strings::Hex(urhs)).c_str(),
|
|
||||||
rhs_double, rhs_double);
|
|
||||||
}
|
|
||||||
return ::testing::AssertionSuccess();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Templated comparator that specializes for float equality comparison with the
|
|
||||||
// bitwise helper above (this is the un-specialized fallback, to just use the
|
|
||||||
// default gunit implementation).
|
|
||||||
template <typename NativeT>
|
|
||||||
::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) {
|
|
||||||
if (lhs == rhs) {
|
|
||||||
return ::testing::AssertionSuccess();
|
|
||||||
}
|
|
||||||
::testing::Message msg;
|
|
||||||
msg << "Expected equality of these values:";
|
|
||||||
msg << "\n " << lhs;
|
|
||||||
msg << "\n " << rhs;
|
|
||||||
|
|
||||||
return ::testing::AssertionFailure() << msg;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Specializations for floating types that do bitwise comparisons when equality
|
|
||||||
// comparison is requested.
|
|
||||||
template <>
|
|
||||||
::testing::AssertionResult CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
|
|
||||||
return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
::testing::AssertionResult CompareEqual<Eigen::half>(Eigen::half lhs,
|
|
||||||
Eigen::half rhs) {
|
|
||||||
return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
|
|
||||||
return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
::testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
|
|
||||||
return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
::testing::AssertionResult CompareEqual<complex64>(complex64 lhs,
|
|
||||||
complex64 rhs) {
|
|
||||||
auto res = CompareEqual<float>(lhs.real(), rhs.real());
|
|
||||||
if (!res) {
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
return CompareEqual<float>(lhs.imag(), rhs.imag());
|
|
||||||
}
|
|
||||||
|
|
||||||
// A recursive function which iterates through every index of expected and
|
|
||||||
// actual literal and compares their values elementwise. Returns true if all
|
|
||||||
// elements are equal.
|
|
||||||
template <typename NativeT>
|
|
||||||
bool ExpectLiteralsEqual(LiteralSlice expected, LiteralSlice actual,
|
|
||||||
tensorflow::gtl::MutableArraySlice<int64> multi_index,
|
|
||||||
int64 dimension) {
|
|
||||||
if (dimension == expected.shape().dimensions_size()) {
|
|
||||||
NativeT expected_value = expected.Get<NativeT>(multi_index);
|
|
||||||
NativeT actual_value = actual.Get<NativeT>(multi_index);
|
|
||||||
::testing::AssertionResult result =
|
|
||||||
CompareEqual<NativeT>(expected_value, actual_value);
|
|
||||||
return result; // Defines implicit coersion to bool.
|
|
||||||
}
|
|
||||||
|
|
||||||
bool all_match = true;
|
|
||||||
for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
|
|
||||||
multi_index[dimension] = i;
|
|
||||||
all_match = all_match && ExpectLiteralsEqual<NativeT>(
|
|
||||||
expected, actual, multi_index, dimension + 1);
|
|
||||||
}
|
|
||||||
return all_match;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/* static */ void LiteralTestUtil::ExpectEqual(LiteralSlice expected,
|
|
||||||
LiteralSlice actual,
|
|
||||||
const string& message) {
|
|
||||||
EXPECT_TRUE(Equal(expected, actual))
|
|
||||||
<< "expected:\n"
|
|
||||||
<< expected.ToString() << "\n\tvs actual:\n"
|
|
||||||
<< actual.ToString()
|
|
||||||
<< (message.empty() ? "" : StrCat("\nmessage: ", message));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* static */ void LiteralTestUtil::ExpectNotEqual(LiteralSlice expected,
|
|
||||||
LiteralSlice actual) {
|
|
||||||
EXPECT_FALSE(Equal(expected, actual));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
|
/* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
|
||||||
LiteralSlice expected, LiteralSlice actual) {
|
const LiteralSlice& expected, const LiteralSlice& actual) {
|
||||||
VLOG(1) << "expected:";
|
Status result = literal_comparison::Equal(expected, actual);
|
||||||
XLA_VLOG_LINES(1, expected.ToString());
|
if (result.ok()) {
|
||||||
VLOG(1) << "actual:";
|
return ::testing::AssertionSuccess();
|
||||||
XLA_VLOG_LINES(1, actual.ToString());
|
|
||||||
|
|
||||||
AssertEqualShapes(expected.shape(), actual.shape());
|
|
||||||
std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
|
|
||||||
bool match = false;
|
|
||||||
switch (expected.shape().element_type()) {
|
|
||||||
case PRED:
|
|
||||||
match = ExpectLiteralsEqual<bool>(expected, actual, &multi_index, 0);
|
|
||||||
break;
|
|
||||||
case U8:
|
|
||||||
match = ExpectLiteralsEqual<uint8>(expected, actual, &multi_index, 0);
|
|
||||||
break;
|
|
||||||
case S32:
|
|
||||||
match = ExpectLiteralsEqual<int32>(expected, actual, &multi_index, 0);
|
|
||||||
break;
|
|
||||||
case S64:
|
|
||||||
match = ExpectLiteralsEqual<int64>(expected, actual, &multi_index, 0);
|
|
||||||
break;
|
|
||||||
case U32:
|
|
||||||
match = ExpectLiteralsEqual<uint32>(expected, actual, &multi_index, 0);
|
|
||||||
break;
|
|
||||||
case U64:
|
|
||||||
match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
|
|
||||||
break;
|
|
||||||
case BF16:
|
|
||||||
match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0);
|
|
||||||
break;
|
|
||||||
case F16:
|
|
||||||
match = ExpectLiteralsEqual<half>(expected, actual, &multi_index, 0);
|
|
||||||
break;
|
|
||||||
case F32:
|
|
||||||
match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
|
|
||||||
break;
|
|
||||||
case F64:
|
|
||||||
match = ExpectLiteralsEqual<double>(expected, actual, &multi_index, 0);
|
|
||||||
break;
|
|
||||||
case C64:
|
|
||||||
match = ExpectLiteralsEqual<complex64>(expected, actual, &multi_index, 0);
|
|
||||||
break;
|
|
||||||
case TUPLE: {
|
|
||||||
bool tuple_match = true;
|
|
||||||
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
|
|
||||||
SCOPED_TRACE(StrCat("Tuple index ", i, " in ",
|
|
||||||
ShapeUtil::HumanString(expected.shape())));
|
|
||||||
|
|
||||||
// Create LiteralSlices of the expected and actual elements.
|
|
||||||
auto result =
|
|
||||||
Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}));
|
|
||||||
tuple_match = tuple_match ? !!result : false;
|
|
||||||
}
|
|
||||||
match = tuple_match;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
LOG(FATAL)
|
|
||||||
<< "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
|
|
||||||
<< PrimitiveType_Name(expected.shape().element_type());
|
|
||||||
}
|
}
|
||||||
::testing::AssertionResult result = ::testing::AssertionSuccess();
|
return ::testing::AssertionFailure() << result;
|
||||||
if (!match) {
|
|
||||||
result = ::testing::AssertionFailure()
|
|
||||||
<< "expected: " << expected.ToString()
|
|
||||||
<< "\nactual: " << actual.ToString();
|
|
||||||
VLOG(1) << result.message();
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -368,7 +105,7 @@ int64 RecursiveElementCount(const Shape& shape) {
|
|||||||
// 3 minutes. The utility of printing a literal with >1000 elements is
|
// 3 minutes. The utility of printing a literal with >1000 elements is
|
||||||
// questionable, especially when writing the Literal proto to disk is orders
|
// questionable, especially when writing the Literal proto to disk is orders
|
||||||
// of magnitude faster.
|
// of magnitude faster.
|
||||||
string TruncateHugeLiteral(LiteralSlice literal) {
|
string TruncateHugeLiteral(const LiteralSlice& literal) {
|
||||||
return RecursiveElementCount(literal.shape()) < 1000
|
return RecursiveElementCount(literal.shape()) < 1000
|
||||||
? literal.ToString()
|
? literal.ToString()
|
||||||
: "[TRUNCATED, Literal with more than 1000 values]";
|
: "[TRUNCATED, Literal with more than 1000 values]";
|
||||||
@ -435,8 +172,8 @@ class NearComparator {
|
|||||||
// result. The assertion result is successful if all actual and expected
|
// result. The assertion result is successful if all actual and expected
|
||||||
// elements are within the given error bound. In case of error, the assertion
|
// elements are within the given error bound. In case of error, the assertion
|
||||||
// result contains a detailed error message in case of failure.
|
// result contains a detailed error message in case of failure.
|
||||||
static ::testing::AssertionResult Compare(LiteralSlice expected,
|
static ::testing::AssertionResult Compare(const LiteralSlice& expected,
|
||||||
LiteralSlice actual,
|
const LiteralSlice& actual,
|
||||||
ErrorSpec error,
|
ErrorSpec error,
|
||||||
bool detailed_message) {
|
bool detailed_message) {
|
||||||
NearComparator<NativeT> comparator(expected, actual, error,
|
NearComparator<NativeT> comparator(expected, actual, error,
|
||||||
@ -464,7 +201,7 @@ class NearComparator {
|
|||||||
return Printf(
|
return Printf(
|
||||||
"actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
|
"actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
|
||||||
FpValueToString(actual).c_str(), FpValueToString(expected).c_str(),
|
FpValueToString(actual).c_str(), FpValueToString(expected).c_str(),
|
||||||
LiteralTestUtil::MultiIndexAsString(
|
Literal::MultiIndexAsString(
|
||||||
IndexUtil::LinearIndexToMultidimensionalIndex(shape,
|
IndexUtil::LinearIndexToMultidimensionalIndex(shape,
|
||||||
linear_index))
|
linear_index))
|
||||||
.c_str(),
|
.c_str(),
|
||||||
@ -472,8 +209,9 @@ class NearComparator {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
explicit NearComparator(LiteralSlice expected, LiteralSlice actual,
|
explicit NearComparator(const LiteralSlice& expected,
|
||||||
ErrorSpec error, bool detailed_message)
|
const LiteralSlice& actual, ErrorSpec error,
|
||||||
|
bool detailed_message)
|
||||||
: expected_(expected),
|
: expected_(expected),
|
||||||
actual_(actual),
|
actual_(actual),
|
||||||
error_(error),
|
error_(error),
|
||||||
@ -649,7 +387,7 @@ class NearComparator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Writes the given literal to a file in the test temporary directory.
|
// Writes the given literal to a file in the test temporary directory.
|
||||||
void WriteLiteralToTempFile(LiteralSlice literal, const string& name) {
|
void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) {
|
||||||
int64 now_usec = tensorflow::Env::Default()->NowMicros();
|
int64 now_usec = tensorflow::Env::Default()->NowMicros();
|
||||||
string filename = tensorflow::io::JoinPath(
|
string filename = tensorflow::io::JoinPath(
|
||||||
tensorflow::testing::TmpDir(),
|
tensorflow::testing::TmpDir(),
|
||||||
@ -794,8 +532,8 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
|
|||||||
// Helper function for comparing two literals for nearness. Handles tuple-shapes
|
// Helper function for comparing two literals for nearness. Handles tuple-shapes
|
||||||
// via recursion. shape_index is the ShapeIndex of expected (or actual)
|
// via recursion. shape_index is the ShapeIndex of expected (or actual)
|
||||||
// currently being compared.
|
// currently being compared.
|
||||||
::testing::AssertionResult NearHelper(LiteralSlice expected,
|
::testing::AssertionResult NearHelper(const LiteralSlice& expected,
|
||||||
LiteralSlice actual,
|
const LiteralSlice& actual,
|
||||||
const ErrorSpec& error,
|
const ErrorSpec& error,
|
||||||
bool detailed_message,
|
bool detailed_message,
|
||||||
const ShapeIndex& shape_index) {
|
const ShapeIndex& shape_index) {
|
||||||
@ -874,30 +612,14 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(
|
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(
|
||||||
LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error,
|
const LiteralSlice& expected, const LiteralSlice& actual,
|
||||||
bool detailed_message) {
|
const ErrorSpec& error, bool detailed_message) {
|
||||||
return NearHelper(expected, actual, error, detailed_message,
|
return NearHelper(expected, actual, error, detailed_message,
|
||||||
/*shape_index=*/{});
|
/*shape_index=*/{});
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ void LiteralTestUtil::ExpectNear(LiteralSlice expected,
|
/* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
|
||||||
LiteralSlice actual,
|
const LiteralSlice& expected, const LiteralSlice& actual,
|
||||||
const ErrorSpec& error,
|
|
||||||
const string& message) {
|
|
||||||
::testing::AssertionResult res =
|
|
||||||
Near(expected, actual, error, /*detailed_message=*/false);
|
|
||||||
if (!res) {
|
|
||||||
res << "Expected: " << TruncateHugeLiteral(expected) << "\n";
|
|
||||||
res << "Actual: " << TruncateHugeLiteral(actual) << "\n";
|
|
||||||
if (!message.empty()) {
|
|
||||||
res << StrCat("\nmessage: ", message);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
EXPECT_TRUE(res);
|
|
||||||
}
|
|
||||||
|
|
||||||
/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
|
|
||||||
LiteralSlice expected, LiteralSlice actual,
|
|
||||||
const tensorflow::gtl::optional<ErrorSpec>& error) {
|
const tensorflow::gtl::optional<ErrorSpec>& error) {
|
||||||
if (error.has_value()) {
|
if (error.has_value()) {
|
||||||
VLOG(1) << "Expects near";
|
VLOG(1) << "Expects near";
|
||||||
@ -907,86 +629,4 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
|
|||||||
return Equal(expected, actual);
|
return Equal(expected, actual);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*static*/ void LiteralTestUtil::ExpectNearOrEqual(
|
|
||||||
LiteralSlice expected, LiteralSlice actual,
|
|
||||||
const tensorflow::gtl::optional<ErrorSpec>& error) {
|
|
||||||
EXPECT_TRUE(NearOrEqual(expected, actual, error));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* static */ string LiteralTestUtil::MultiIndexAsString(
|
|
||||||
tensorflow::gtl::ArraySlice<int64> multi_index) {
|
|
||||||
return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
|
|
||||||
}
|
|
||||||
|
|
||||||
/* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape(
|
|
||||||
tensorflow::gtl::ArraySlice<int64> new_dimensions,
|
|
||||||
tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal) {
|
|
||||||
int64 new_num_elements = 1;
|
|
||||||
for (int64 i = 0; i < new_dimensions.size(); ++i) {
|
|
||||||
new_num_elements *= new_dimensions[i];
|
|
||||||
}
|
|
||||||
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
|
|
||||||
CHECK_EQ(new_dimensions.size(), minor_to_major.size());
|
|
||||||
|
|
||||||
auto new_literal = MakeUnique<Literal>(
|
|
||||||
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
|
|
||||||
|
|
||||||
// Create a new shape with the given minor-to-major layout. This shape is used
|
|
||||||
// solely for converting linear address to multi-dimensional addresses when
|
|
||||||
// writing elements to the new literal.
|
|
||||||
Shape shape_with_layout = new_literal->shape();
|
|
||||||
*shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
|
|
||||||
|
|
||||||
// Copy data into new literal, element-by-element.
|
|
||||||
for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
|
|
||||||
std::vector<int64> from_multi_index =
|
|
||||||
IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
|
|
||||||
std::vector<int64> to_multi_index =
|
|
||||||
IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
|
|
||||||
switch (literal.shape().element_type()) {
|
|
||||||
case PRED:
|
|
||||||
new_literal->Set<bool>(to_multi_index,
|
|
||||||
literal.Get<bool>(from_multi_index));
|
|
||||||
break;
|
|
||||||
case U8:
|
|
||||||
new_literal->Set<uint8>(to_multi_index,
|
|
||||||
literal.Get<uint8>(from_multi_index));
|
|
||||||
break;
|
|
||||||
case U32:
|
|
||||||
new_literal->Set<uint32>(to_multi_index,
|
|
||||||
literal.Get<uint32>(from_multi_index));
|
|
||||||
break;
|
|
||||||
case S32:
|
|
||||||
new_literal->Set<int32>(to_multi_index,
|
|
||||||
literal.Get<int32>(from_multi_index));
|
|
||||||
break;
|
|
||||||
case U64:
|
|
||||||
new_literal->Set<uint64>(to_multi_index,
|
|
||||||
literal.Get<uint64>(from_multi_index));
|
|
||||||
break;
|
|
||||||
case S64:
|
|
||||||
new_literal->Set<int64>(to_multi_index,
|
|
||||||
literal.Get<int64>(from_multi_index));
|
|
||||||
break;
|
|
||||||
case F32:
|
|
||||||
new_literal->Set<float>(to_multi_index,
|
|
||||||
literal.Get<float>(from_multi_index));
|
|
||||||
break;
|
|
||||||
case F64:
|
|
||||||
new_literal->Set<double>(to_multi_index,
|
|
||||||
literal.Get<double>(from_multi_index));
|
|
||||||
break;
|
|
||||||
case C64:
|
|
||||||
new_literal->Set<complex64>(to_multi_index,
|
|
||||||
literal.Get<complex64>(from_multi_index));
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
LOG(FATAL) << "Unhandled primitive element type: "
|
|
||||||
<< PrimitiveType_Name(literal.shape().element_type());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return new_literal;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -57,65 +57,47 @@ class LiteralTestUtil {
|
|||||||
public:
|
public:
|
||||||
// Asserts that the given shapes have the same rank, dimension sizes, and
|
// Asserts that the given shapes have the same rank, dimension sizes, and
|
||||||
// primitive types.
|
// primitive types.
|
||||||
static ::testing::AssertionResult EqualShapes(const Shape& expected,
|
static ::testing::AssertionResult EqualShapes(
|
||||||
const Shape& actual);
|
const Shape& expected, const Shape& actual) MUST_USE_RESULT;
|
||||||
static void AssertEqualShapes(const Shape& expected, const Shape& actual);
|
|
||||||
|
|
||||||
// Asserts that the provided shapes are equal as defined in AssertEqualShapes
|
// Asserts that the provided shapes are equal as defined in AssertEqualShapes
|
||||||
// and that they have the same layout.
|
// and that they have the same layout.
|
||||||
static void AssertEqualShapesAndLayouts(const Shape& expected,
|
static ::testing::AssertionResult EqualShapesAndLayouts(
|
||||||
const Shape& actual);
|
const Shape& expected, const Shape& actual) MUST_USE_RESULT;
|
||||||
|
|
||||||
// If the given literal's data type is bfloat16, converts it to a float
|
static ::testing::AssertionResult Equal(const LiteralSlice& expected,
|
||||||
// literal; otherwise, returns a copy of it. If the literal is a tuple,
|
const LiteralSlice& actual)
|
||||||
// recursively converts its elements.
|
TF_MUST_USE_RESULT;
|
||||||
static std::unique_ptr<Literal> ConvertBF16ToF32(LiteralSlice bf16_literal);
|
|
||||||
|
|
||||||
// If the given literal's data type is float, converts it to a bfloat16
|
|
||||||
// literal; otherwise, returns a copy of it. If the literal is a tuple,
|
|
||||||
// recursively converts its elements.
|
|
||||||
static std::unique_ptr<Literal> ConvertF32ToBF16(LiteralSlice f32_literal);
|
|
||||||
|
|
||||||
// Asserts that the expected and actual literals are (bitwise) equal for all
|
|
||||||
// elements in the literal. Also, asserts that the rank, dimensions sizes, and
|
|
||||||
// primitive type are equal.
|
|
||||||
static ::testing::AssertionResult Equal(
|
|
||||||
LiteralSlice expected, LiteralSlice actual) TF_MUST_USE_RESULT;
|
|
||||||
|
|
||||||
// Expects that expected and actual are Equal.
|
|
||||||
static void ExpectEqual(LiteralSlice expected, LiteralSlice actual,
|
|
||||||
const string& message = "");
|
|
||||||
|
|
||||||
// Expects that expected and actual are Not Equal.
|
|
||||||
static void ExpectNotEqual(LiteralSlice expected, LiteralSlice actual);
|
|
||||||
|
|
||||||
// Asserts the given literal are (bitwise) equal to given expected values.
|
// Asserts the given literal are (bitwise) equal to given expected values.
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR0Equal(NativeT expected, LiteralSlice actual);
|
static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual);
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected,
|
static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected,
|
||||||
LiteralSlice actual);
|
const LiteralSlice& actual);
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR2Equal(
|
static void ExpectR2Equal(
|
||||||
std::initializer_list<std::initializer_list<NativeT>> expected,
|
std::initializer_list<std::initializer_list<NativeT>> expected,
|
||||||
LiteralSlice actual);
|
const LiteralSlice& actual);
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR3Equal(
|
static void ExpectR3Equal(
|
||||||
std::initializer_list<
|
std::initializer_list<
|
||||||
std::initializer_list<std::initializer_list<NativeT>>>
|
std::initializer_list<std::initializer_list<NativeT>>>
|
||||||
expected,
|
expected,
|
||||||
LiteralSlice actual);
|
const LiteralSlice& actual);
|
||||||
|
|
||||||
// Asserts the given literal are (bitwise) equal to given array.
|
// Asserts the given literal are (bitwise) equal to given array.
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
|
static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
|
||||||
LiteralSlice actual);
|
const LiteralSlice& actual);
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
|
static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
|
||||||
LiteralSlice actual);
|
const LiteralSlice& actual);
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
|
static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
|
||||||
LiteralSlice actual);
|
const LiteralSlice& actual);
|
||||||
|
|
||||||
// Asserts that the expected and actual literals are within the given error
|
// Asserts that the expected and actual literals are within the given error
|
||||||
// bound for all elements. Also, asserts that the rank, dimensions sizes, and
|
// bound for all elements. Also, asserts that the rank, dimensions sizes, and
|
||||||
@ -133,183 +115,138 @@ class LiteralTestUtil {
|
|||||||
// If detailed_message is true, then the error message in the assertion result
|
// If detailed_message is true, then the error message in the assertion result
|
||||||
// will contain a more detailed breakdown of mismatches.
|
// will contain a more detailed breakdown of mismatches.
|
||||||
static ::testing::AssertionResult Near(
|
static ::testing::AssertionResult Near(
|
||||||
LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error,
|
const LiteralSlice& expected, const LiteralSlice& actual,
|
||||||
bool detailed_message = false) TF_MUST_USE_RESULT;
|
const ErrorSpec& error, bool detailed_message = false) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
// Expects expected and actual to be Near with the given error.
|
|
||||||
static void ExpectNear(LiteralSlice expected, LiteralSlice actual,
|
|
||||||
const ErrorSpec& error, const string& message = "");
|
|
||||||
|
|
||||||
// Asserts the given literal are within the given error bound of the given
|
// Asserts the given literal are within the given error bound of the given
|
||||||
// expected values. Only supported for floating point values.
|
// expected values. Only supported for floating point values.
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR0Near(NativeT expected, LiteralSlice actual,
|
static void ExpectR0Near(NativeT expected, const LiteralSlice& actual,
|
||||||
const ErrorSpec& error);
|
const ErrorSpec& error);
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected,
|
static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected,
|
||||||
LiteralSlice actual, const ErrorSpec& error);
|
const LiteralSlice& actual, const ErrorSpec& error);
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR2Near(
|
static void ExpectR2Near(
|
||||||
std::initializer_list<std::initializer_list<NativeT>> expected,
|
std::initializer_list<std::initializer_list<NativeT>> expected,
|
||||||
LiteralSlice actual, const ErrorSpec& error);
|
const LiteralSlice& actual, const ErrorSpec& error);
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR3Near(
|
static void ExpectR3Near(
|
||||||
std::initializer_list<
|
std::initializer_list<
|
||||||
std::initializer_list<std::initializer_list<NativeT>>>
|
std::initializer_list<std::initializer_list<NativeT>>>
|
||||||
expected,
|
expected,
|
||||||
LiteralSlice actual, const ErrorSpec& error);
|
const LiteralSlice& actual, const ErrorSpec& error);
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR4Near(
|
static void ExpectR4Near(
|
||||||
std::initializer_list<std::initializer_list<
|
std::initializer_list<std::initializer_list<
|
||||||
std::initializer_list<std::initializer_list<NativeT>>>>
|
std::initializer_list<std::initializer_list<NativeT>>>>
|
||||||
expected,
|
expected,
|
||||||
LiteralSlice actual, const ErrorSpec& error);
|
const LiteralSlice& actual, const ErrorSpec& error);
|
||||||
|
|
||||||
// Asserts the given literal are within the given error bound to the given
|
// Asserts the given literal are within the given error bound to the given
|
||||||
// array. Only supported for floating point values.
|
// array. Only supported for floating point values.
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
|
static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
|
||||||
LiteralSlice actual, const ErrorSpec& error);
|
const LiteralSlice& actual,
|
||||||
|
const ErrorSpec& error);
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
|
static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
|
||||||
LiteralSlice actual, const ErrorSpec& error);
|
const LiteralSlice& actual,
|
||||||
|
const ErrorSpec& error);
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
|
static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
|
||||||
LiteralSlice actual, const ErrorSpec& error);
|
const LiteralSlice& actual,
|
||||||
|
const ErrorSpec& error);
|
||||||
|
|
||||||
// If the error spec is given, returns whether the expected and the actual are
|
// If the error spec is given, returns whether the expected and the actual are
|
||||||
// within the error bound; otherwise, returns whether they are equal. Tuples
|
// within the error bound; otherwise, returns whether they are equal. Tuples
|
||||||
// will be compared recursively.
|
// will be compared recursively.
|
||||||
static ::testing::AssertionResult NearOrEqual(
|
static ::testing::AssertionResult NearOrEqual(
|
||||||
LiteralSlice expected, LiteralSlice actual,
|
const LiteralSlice& expected, const LiteralSlice& actual,
|
||||||
const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
|
const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
// If the error spec is given, expects the expected and the actual to be near;
|
|
||||||
// otherwise, expects them to be equal. Tuples will be compared recursively.
|
|
||||||
static void ExpectNearOrEqual(
|
|
||||||
LiteralSlice expected, LiteralSlice actual,
|
|
||||||
const tensorflow::gtl::optional<ErrorSpec>& error);
|
|
||||||
|
|
||||||
// Returns a multi-dimensional index as a string. For example: '{7, 8}' will
|
|
||||||
// be returned for a 2-dimensional index with dimension 0 index equal to 7,
|
|
||||||
// dimension 1 equal to 8.
|
|
||||||
static string MultiIndexAsString(
|
|
||||||
tensorflow::gtl::ArraySlice<int64> multi_index);
|
|
||||||
|
|
||||||
// Creates a literal with a new shape with the given new dimensions using the
|
|
||||||
// data in the given input literal. For reshaping purposes the (flat) data
|
|
||||||
// buffer of the input literal is assumed to have the given minor_to_major
|
|
||||||
// layout order.
|
|
||||||
static std::unique_ptr<Literal> Reshape(
|
|
||||||
tensorflow::gtl::ArraySlice<int64> new_dimensions,
|
|
||||||
tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal);
|
|
||||||
|
|
||||||
// Creates a literal with the supplied shape, and uses the provided value
|
|
||||||
// generator to populate the literal's values.
|
|
||||||
// Returns the new literal object, or an error Status if failed.
|
|
||||||
template <
|
|
||||||
PrimitiveType type,
|
|
||||||
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
|
|
||||||
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
|
|
||||||
const Shape& shape,
|
|
||||||
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
|
|
||||||
|
|
||||||
// Creates a literal with the supplied shape, and initializes the literal
|
|
||||||
// values using a normal distribution with given mean and stddev standard
|
|
||||||
// deviation, and using the engine as entropy generator.
|
|
||||||
// Returns the new literal object, or an error Status if failed.
|
|
||||||
template <
|
|
||||||
PrimitiveType type, typename E,
|
|
||||||
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
|
|
||||||
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
|
|
||||||
const Shape& shape, E* engine, T mean, T stddev);
|
|
||||||
|
|
||||||
// Creates a literal with the supplied shape, and initializes the literal
|
|
||||||
// values using a normal distribution with given mean and stddev standard
|
|
||||||
// deviation.
|
|
||||||
// Returns the new literal object, or an error Status if failed.
|
|
||||||
template <
|
|
||||||
PrimitiveType type,
|
|
||||||
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
|
|
||||||
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
|
|
||||||
const Shape& shape, T mean, T stddev);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
|
TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
|
/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
|
||||||
LiteralSlice actual) {
|
const LiteralSlice& actual) {
|
||||||
ExpectEqual(*Literal::CreateR0<NativeT>(expected), actual);
|
EXPECT_TRUE(Equal(*Literal::CreateR0<NativeT>(expected), actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR1Equal(
|
/* static */ void LiteralTestUtil::ExpectR1Equal(
|
||||||
tensorflow::gtl::ArraySlice<NativeT> expected, LiteralSlice actual) {
|
tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual) {
|
||||||
ExpectEqual(*Literal::CreateR1<NativeT>(expected), actual);
|
EXPECT_TRUE(Equal(*Literal::CreateR1<NativeT>(expected), actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR2Equal(
|
/* static */ void LiteralTestUtil::ExpectR2Equal(
|
||||||
std::initializer_list<std::initializer_list<NativeT>> expected,
|
std::initializer_list<std::initializer_list<NativeT>> expected,
|
||||||
LiteralSlice actual) {
|
const LiteralSlice& actual) {
|
||||||
ExpectEqual(*Literal::CreateR2<NativeT>(expected), actual);
|
EXPECT_TRUE(Equal(*Literal::CreateR2<NativeT>(expected), actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR3Equal(
|
/* static */ void LiteralTestUtil::ExpectR3Equal(
|
||||||
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
|
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
|
||||||
expected,
|
expected,
|
||||||
LiteralSlice actual) {
|
const LiteralSlice& actual) {
|
||||||
ExpectEqual(*Literal::CreateR3<NativeT>(expected), actual);
|
EXPECT_TRUE(Equal(*Literal::CreateR3<NativeT>(expected), actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
|
/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
|
||||||
const Array2D<NativeT>& expected, LiteralSlice actual) {
|
const Array2D<NativeT>& expected, const LiteralSlice& actual) {
|
||||||
ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual);
|
EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
|
/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
|
||||||
const Array3D<NativeT>& expected, LiteralSlice actual) {
|
const Array3D<NativeT>& expected, const LiteralSlice& actual) {
|
||||||
ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual);
|
EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
|
/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
|
||||||
const Array4D<NativeT>& expected, LiteralSlice actual) {
|
const Array4D<NativeT>& expected, const LiteralSlice& actual) {
|
||||||
ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual);
|
EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
|
/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
|
||||||
LiteralSlice actual,
|
const LiteralSlice& actual,
|
||||||
const ErrorSpec& error) {
|
const ErrorSpec& error) {
|
||||||
ExpectNear(*Literal::CreateR0<NativeT>(expected), actual, error);
|
EXPECT_TRUE(Near(*Literal::CreateR0<NativeT>(expected), actual, error));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR1Near(
|
/* static */ void LiteralTestUtil::ExpectR1Near(
|
||||||
tensorflow::gtl::ArraySlice<NativeT> expected, LiteralSlice actual,
|
tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual,
|
||||||
const ErrorSpec& error) {
|
const ErrorSpec& error) {
|
||||||
ExpectNear(*Literal::CreateR1<NativeT>(expected), actual, error);
|
EXPECT_TRUE(Near(*Literal::CreateR1<NativeT>(expected), actual, error));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR2Near(
|
/* static */ void LiteralTestUtil::ExpectR2Near(
|
||||||
std::initializer_list<std::initializer_list<NativeT>> expected,
|
std::initializer_list<std::initializer_list<NativeT>> expected,
|
||||||
LiteralSlice actual, const ErrorSpec& error) {
|
const LiteralSlice& actual, const ErrorSpec& error) {
|
||||||
ExpectNear(*Literal::CreateR2<NativeT>(expected), actual, error);
|
EXPECT_TRUE(Near(*Literal::CreateR2<NativeT>(expected), actual, error));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR3Near(
|
/* static */ void LiteralTestUtil::ExpectR3Near(
|
||||||
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
|
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
|
||||||
expected,
|
expected,
|
||||||
LiteralSlice actual, const ErrorSpec& error) {
|
const LiteralSlice& actual, const ErrorSpec& error) {
|
||||||
ExpectNear(*Literal::CreateR3<NativeT>(expected), actual, error);
|
EXPECT_TRUE(Near(*Literal::CreateR3<NativeT>(expected), actual, error));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
@ -317,63 +254,29 @@ template <typename NativeT>
|
|||||||
std::initializer_list<std::initializer_list<
|
std::initializer_list<std::initializer_list<
|
||||||
std::initializer_list<std::initializer_list<NativeT>>>>
|
std::initializer_list<std::initializer_list<NativeT>>>>
|
||||||
expected,
|
expected,
|
||||||
LiteralSlice actual, const ErrorSpec& error) {
|
const LiteralSlice& actual, const ErrorSpec& error) {
|
||||||
ExpectNear(*Literal::CreateR4<NativeT>(expected), actual, error);
|
EXPECT_TRUE(Near(*Literal::CreateR4<NativeT>(expected), actual, error));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
|
/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
|
||||||
const Array2D<NativeT>& expected, LiteralSlice actual,
|
const Array2D<NativeT>& expected, const LiteralSlice& actual,
|
||||||
const ErrorSpec& error) {
|
const ErrorSpec& error) {
|
||||||
ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error);
|
EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
|
/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
|
||||||
const Array3D<NativeT>& expected, LiteralSlice actual,
|
const Array3D<NativeT>& expected, const LiteralSlice& actual,
|
||||||
const ErrorSpec& error) {
|
const ErrorSpec& error) {
|
||||||
ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error);
|
EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
|
/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
|
||||||
const Array4D<NativeT>& expected, LiteralSlice actual,
|
const Array4D<NativeT>& expected, const LiteralSlice& actual,
|
||||||
const ErrorSpec& error) {
|
const ErrorSpec& error) {
|
||||||
ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error);
|
EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error));
|
||||||
}
|
|
||||||
|
|
||||||
template <PrimitiveType type, typename T>
|
|
||||||
/* static */ StatusOr<std::unique_ptr<Literal>>
|
|
||||||
LiteralTestUtil::CreateRandomLiteral(
|
|
||||||
const Shape& shape,
|
|
||||||
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
|
|
||||||
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
|
|
||||||
TF_RET_CHECK(shape.element_type() == type);
|
|
||||||
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
|
|
||||||
TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
|
|
||||||
[&](tensorflow::gtl::ArraySlice<int64> indexes) {
|
|
||||||
return generator(indexes);
|
|
||||||
}));
|
|
||||||
return std::move(literal);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <PrimitiveType type, typename E, typename T>
|
|
||||||
/* static */ StatusOr<std::unique_ptr<Literal>>
|
|
||||||
LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
|
|
||||||
T stddev) {
|
|
||||||
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
|
|
||||||
std::normal_distribution<NativeT> generator(mean, stddev);
|
|
||||||
return CreateRandomLiteral<type, NativeT>(
|
|
||||||
shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
|
|
||||||
return generator(*engine);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <PrimitiveType type, typename T>
|
|
||||||
/* static */ StatusOr<std::unique_ptr<Literal>>
|
|
||||||
LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
|
|
||||||
std::minstd_rand0 engine;
|
|
||||||
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -34,7 +34,7 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
|
|||||||
std::unique_ptr<Literal> literal = Literal::MakeTuple({
|
std::unique_ptr<Literal> literal = Literal::MakeTuple({
|
||||||
Literal::CreateR0<int32>(42).get(), Literal::CreateR0<int32>(64).get(),
|
Literal::CreateR0<int32>(42).get(), Literal::CreateR0<int32>(64).get(),
|
||||||
});
|
});
|
||||||
LiteralTestUtil::ExpectEqual(*literal, *literal);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
|
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
|
||||||
@ -97,6 +97,15 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
|
||||||
|
auto expected = Literal::CreateR1<int32>({1, 2, 3});
|
||||||
|
auto actual = Literal::CreateR1<int32>({4, 5, 6});
|
||||||
|
::testing::AssertionResult result =
|
||||||
|
LiteralTestUtil::Equal(*expected, *actual);
|
||||||
|
EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}"));
|
||||||
|
EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}"));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(LiteralTestUtilTest, NearComparatorR1) {
|
TEST(LiteralTestUtilTest, NearComparatorR1) {
|
||||||
auto a =
|
auto a =
|
||||||
Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
|
Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
|
||||||
|
@ -108,7 +108,7 @@ class MultiOutputFusionTest : public HloTestBase {
|
|||||||
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
|
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
|
||||||
auto actual = ExecuteAndTransfer(
|
auto actual = ExecuteAndTransfer(
|
||||||
std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1});
|
std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1});
|
||||||
LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
|
EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void RunTest1D(bool manual_fusion, int size) {
|
void RunTest1D(bool manual_fusion, int size) {
|
||||||
@ -168,7 +168,7 @@ class MultiOutputFusionTest : public HloTestBase {
|
|||||||
|
|
||||||
Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f}));
|
Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f}));
|
||||||
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
|
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
|
||||||
LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
|
EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -273,11 +273,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
|
|||||||
&execution_options_));
|
&execution_options_));
|
||||||
}
|
}
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*result1, *result2);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2));
|
||||||
LiteralTestUtil::ExpectEqual(*result1, *result3);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3));
|
||||||
LiteralTestUtil::ExpectNotEqual(*result1, *result4);
|
EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4));
|
||||||
LiteralTestUtil::ExpectNotEqual(*result4, *result5);
|
EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5));
|
||||||
LiteralTestUtil::ExpectNotEqual(*result5, *result6);
|
EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(PrngTest, TenValuesN01) {
|
XLA_TEST_F(PrngTest, TenValuesN01) {
|
||||||
|
@ -656,9 +656,9 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
|
|||||||
std::unique_ptr<Literal> expected =
|
std::unique_ptr<Literal> expected =
|
||||||
Literal::CreateR2FromArray2D<float>(expected_array);
|
Literal::CreateR2FromArray2D<float>(expected_array);
|
||||||
if (use_bfloat16()) {
|
if (use_bfloat16()) {
|
||||||
expected = LiteralTestUtil::ConvertF32ToBF16(*expected);
|
expected = Literal::ConvertF32ToBF16(*expected);
|
||||||
}
|
}
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *actual);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
|
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
|
||||||
@ -731,7 +731,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
|
|||||||
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
|
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
|
||||||
|
|
||||||
std::unique_ptr<Literal> expected =
|
std::unique_ptr<Literal> expected =
|
||||||
LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal);
|
Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
|
||||||
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
|
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
|
||||||
zero_error_spec_);
|
zero_error_spec_);
|
||||||
}
|
}
|
||||||
@ -753,7 +753,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
|
|||||||
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
|
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
|
||||||
|
|
||||||
std::unique_ptr<Literal> expected =
|
std::unique_ptr<Literal> expected =
|
||||||
LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal);
|
Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
|
||||||
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
|
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
|
||||||
zero_error_spec_);
|
zero_error_spec_);
|
||||||
}
|
}
|
||||||
@ -817,7 +817,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
|
|||||||
// Since the reshape is a no-op, verify that it does not change the underlying
|
// Since the reshape is a no-op, verify that it does not change the underlying
|
||||||
// data.
|
// data.
|
||||||
if (use_bfloat16()) {
|
if (use_bfloat16()) {
|
||||||
auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal);
|
auto expected = Literal::ConvertF32ToBF16(*input_literal);
|
||||||
EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
|
EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
|
||||||
} else {
|
} else {
|
||||||
EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
|
EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
|
||||||
@ -886,7 +886,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
|
|||||||
/*new_sizes=*/new_bounds);
|
/*new_sizes=*/new_bounds);
|
||||||
|
|
||||||
std::unique_ptr<Literal> expected =
|
std::unique_ptr<Literal> expected =
|
||||||
LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
|
Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
|
||||||
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
|
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
|
||||||
|
|
||||||
// Specify the requested output shape explicitly to ensure that this reshape
|
// Specify the requested output shape explicitly to ensure that this reshape
|
||||||
@ -915,7 +915,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
|
|||||||
/*new_sizes=*/new_bounds);
|
/*new_sizes=*/new_bounds);
|
||||||
|
|
||||||
std::unique_ptr<Literal> expected =
|
std::unique_ptr<Literal> expected =
|
||||||
LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
|
Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
|
||||||
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
|
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
|
||||||
|
|
||||||
// Specify the requested output shape explicitly to ensure that this reshape
|
// Specify the requested output shape explicitly to ensure that this reshape
|
||||||
@ -944,7 +944,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
|
|||||||
/*new_sizes=*/new_bounds);
|
/*new_sizes=*/new_bounds);
|
||||||
|
|
||||||
std::unique_ptr<Literal> expected =
|
std::unique_ptr<Literal> expected =
|
||||||
LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
|
Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
|
||||||
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
|
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
|
||||||
|
|
||||||
// Specify the requested output shape explicitly to ensure that this reshape
|
// Specify the requested output shape explicitly to ensure that this reshape
|
||||||
@ -974,7 +974,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
|
|||||||
/*new_sizes=*/new_bounds);
|
/*new_sizes=*/new_bounds);
|
||||||
|
|
||||||
std::unique_ptr<Literal> expected =
|
std::unique_ptr<Literal> expected =
|
||||||
LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
|
Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
|
||||||
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
|
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
|
||||||
|
|
||||||
// Specify the requested output shape explicitly to ensure that this reshape
|
// Specify the requested output shape explicitly to ensure that this reshape
|
||||||
@ -1003,7 +1003,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
|
|||||||
/*new_sizes=*/new_bounds);
|
/*new_sizes=*/new_bounds);
|
||||||
|
|
||||||
std::unique_ptr<Literal> expected =
|
std::unique_ptr<Literal> expected =
|
||||||
LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal)
|
Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
|
||||||
->Relayout(input_literal->shape().layout());
|
->Relayout(input_literal->shape().layout());
|
||||||
|
|
||||||
// Specify the requested output shape explicitly to ensure that this reshape
|
// Specify the requested output shape explicitly to ensure that this reshape
|
||||||
|
@ -100,7 +100,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
|
|||||||
EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
|
EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
|
||||||
|
|
||||||
std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
|
std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
|
||||||
LiteralTestUtil::ExpectEqual(*round_tripped, *actual);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
|
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
|
||||||
@ -135,7 +135,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
|
|||||||
EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
|
EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
|
||||||
|
|
||||||
std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
|
std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
|
||||||
LiteralTestUtil::ExpectEqual(*round_tripped, *actual);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -41,7 +41,7 @@ class RoundTripTransferTest : public ClientLibraryTestBase {
|
|||||||
client_->TransferToServer(original).ConsumeValueOrDie();
|
client_->TransferToServer(original).ConsumeValueOrDie();
|
||||||
std::unique_ptr<Literal> result =
|
std::unique_ptr<Literal> result =
|
||||||
client_->Transfer(*data).ConsumeValueOrDie();
|
client_->Transfer(*data).ConsumeValueOrDie();
|
||||||
LiteralTestUtil::ExpectEqual(original, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(original, *result));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -390,7 +390,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
|
|||||||
&execution_options_)
|
&execution_options_)
|
||||||
.ConsumeValueOrDie();
|
.ConsumeValueOrDie();
|
||||||
auto expected_literal = Literal::CreateR0<uint32>(dividend / divisor);
|
auto expected_literal = Literal::CreateR0<uint32>(dividend / divisor);
|
||||||
LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -431,7 +431,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
|
|||||||
&execution_options_)
|
&execution_options_)
|
||||||
.ConsumeValueOrDie();
|
.ConsumeValueOrDie();
|
||||||
auto expected_literal = Literal::CreateR0<uint32>(dividend % divisor);
|
auto expected_literal = Literal::CreateR0<uint32>(dividend % divisor);
|
||||||
LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -175,7 +175,7 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) {
|
|||||||
transfer_manager_->TransferLiteralFromDevice(
|
transfer_manager_->TransferLiteralFromDevice(
|
||||||
stream_executor_, device_buffer));
|
stream_executor_, device_buffer));
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*literal, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
|
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
|
||||||
@ -189,7 +189,7 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
|
|||||||
transfer_manager_->TransferLiteralFromDevice(
|
transfer_manager_->TransferLiteralFromDevice(
|
||||||
stream_executor_, device_buffer));
|
stream_executor_, device_buffer));
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*literal, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
|
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
|
||||||
@ -209,7 +209,7 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
|
|||||||
transfer_manager_->TransferLiteralFromDevice(
|
transfer_manager_->TransferLiteralFromDevice(
|
||||||
stream_executor_, device_buffer));
|
stream_executor_, device_buffer));
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*literal, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
|
XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
|
||||||
@ -224,7 +224,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
|
|||||||
transfer_manager_->TransferLiteralFromDevice(
|
transfer_manager_->TransferLiteralFromDevice(
|
||||||
stream_executor_, device_buffer));
|
stream_executor_, device_buffer));
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*literal, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
|
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
|
||||||
@ -243,7 +243,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
|
|||||||
transfer_manager_->TransferLiteralFromDevice(
|
transfer_manager_->TransferLiteralFromDevice(
|
||||||
stream_executor_, device_buffer));
|
stream_executor_, device_buffer));
|
||||||
|
|
||||||
LiteralTestUtil::ExpectEqual(*literal, *result);
|
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user