Move MakeFakeLiteral from client/lib/testing.h to tests/test_utils.h. Also remove superfluous literal creation methods in that file, and replace them with the existing ones in the Literal class.
Also, optionally print layout in Literal::ToString. PiperOrigin-RevId: 175076277
This commit is contained in:
parent
35febc0cc9
commit
a6babd6a4f
@ -44,6 +44,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/client:computation",
|
||||
"//tensorflow/compiler/xla/client:computation_builder",
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
@ -48,62 +49,6 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
|
||||
if (ShapeUtil::IsTuple(shape)) {
|
||||
std::vector<std::unique_ptr<Literal>> elements;
|
||||
for (const Shape& element_shape : shape.tuple_shapes()) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
|
||||
MakeFakeLiteral(element_shape));
|
||||
elements.push_back(std::move(element));
|
||||
}
|
||||
return Literal::MakeTupleOwned(std::move(elements));
|
||||
}
|
||||
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
|
||||
std::minstd_rand0 engine;
|
||||
switch (shape.element_type()) {
|
||||
case F32: {
|
||||
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
|
||||
TF_CHECK_OK(literal->Populate<float>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
|
||||
return generator(engine);
|
||||
}));
|
||||
break;
|
||||
}
|
||||
case S32: {
|
||||
std::uniform_int_distribution<int32> generator(
|
||||
std::numeric_limits<int32>::lowest(),
|
||||
std::numeric_limits<int32>::max());
|
||||
TF_CHECK_OK(literal->Populate<int32>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
|
||||
return generator(engine);
|
||||
}));
|
||||
break;
|
||||
}
|
||||
case S64: {
|
||||
std::uniform_int_distribution<int64> generator(
|
||||
std::numeric_limits<int64>::lowest(),
|
||||
std::numeric_limits<int64>::max());
|
||||
TF_CHECK_OK(literal->Populate<int64>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
|
||||
return generator(engine);
|
||||
}));
|
||||
break;
|
||||
}
|
||||
case PRED: {
|
||||
std::uniform_int_distribution<int> generator(0, 1);
|
||||
TF_CHECK_OK(literal->Populate<bool>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
|
||||
return generator(engine);
|
||||
}));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return Unimplemented("Unsupported type for fake literal generation: %s",
|
||||
ShapeUtil::HumanString(shape).c_str());
|
||||
}
|
||||
return std::move(literal);
|
||||
}
|
||||
|
||||
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
|
||||
Client* client) {
|
||||
if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) {
|
||||
|
@ -26,10 +26,6 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Generates fake data in a literal of the given shape, or returns an error
|
||||
// status if the element type is currently unhandled for fake data generation.
|
||||
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
|
||||
|
||||
// Generates fake data of the given shape on the device or dies. The fake data
|
||||
// is created by performing a computation on the device rather than transferring
|
||||
// data from the host to the device.
|
||||
|
@ -569,9 +569,17 @@ int64 Literal::LinearIndex(
|
||||
return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
|
||||
}
|
||||
|
||||
string Literal::ToString() const {
|
||||
string Literal::ToString(bool print_layout) const {
|
||||
std::vector<string> pieces;
|
||||
|
||||
auto shape_to_string = [print_layout](const Shape& shape) {
|
||||
if (print_layout) {
|
||||
return ShapeUtil::HumanStringWithLayout(shape);
|
||||
} else {
|
||||
return ShapeUtil::HumanString(shape);
|
||||
}
|
||||
};
|
||||
|
||||
auto element_to_string =
|
||||
[this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
|
||||
PrimitiveType element_type = shape().element_type();
|
||||
@ -585,7 +593,7 @@ string Literal::ToString() const {
|
||||
|
||||
// TODO(b/32894291): refactor this code to reduce code duplication.
|
||||
if (ShapeUtil::IsTuple(shape())) {
|
||||
pieces.push_back(ShapeUtil::HumanString(shape()));
|
||||
pieces.push_back(shape_to_string(shape()));
|
||||
pieces.push_back(" (\n");
|
||||
pieces.push_back(tensorflow::str_util::Join(
|
||||
tuple_literals(), ",\n", [](string* out, const Literal& element) {
|
||||
@ -601,7 +609,7 @@ string Literal::ToString() const {
|
||||
}
|
||||
pieces.push_back("}");
|
||||
} else if (ShapeUtil::Rank(shape()) == 2) {
|
||||
pieces.push_back(ShapeUtil::HumanString(shape()));
|
||||
pieces.push_back(shape_to_string(shape()));
|
||||
pieces.push_back(" {\n");
|
||||
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
|
||||
pieces.push_back(" { ");
|
||||
@ -613,7 +621,7 @@ string Literal::ToString() const {
|
||||
}
|
||||
pieces.push_back("}");
|
||||
} else if (ShapeUtil::Rank(shape()) == 3) {
|
||||
pieces.push_back(ShapeUtil::HumanString(shape()));
|
||||
pieces.push_back(shape_to_string(shape()));
|
||||
pieces.push_back(" {\n");
|
||||
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
|
||||
pieces.push_back(i0 > 0 ? ",\n{" : "{");
|
||||
@ -628,7 +636,7 @@ string Literal::ToString() const {
|
||||
}
|
||||
pieces.push_back("\n}");
|
||||
} else if (ShapeUtil::Rank(shape()) == 4) {
|
||||
pieces.push_back(ShapeUtil::HumanString(shape()));
|
||||
pieces.push_back(shape_to_string(shape()));
|
||||
pieces.push_back(" {\n");
|
||||
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
|
||||
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
|
||||
@ -649,7 +657,7 @@ string Literal::ToString() const {
|
||||
}
|
||||
pieces.push_back("}");
|
||||
} else if (ShapeUtil::Rank(shape()) == 5) {
|
||||
pieces.push_back(ShapeUtil::HumanString(shape()));
|
||||
pieces.push_back(shape_to_string(shape()));
|
||||
pieces.push_back(" {\n");
|
||||
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
|
||||
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
|
||||
@ -676,7 +684,7 @@ string Literal::ToString() const {
|
||||
}
|
||||
pieces.push_back("}");
|
||||
} else {
|
||||
pieces.push_back(ShapeUtil::HumanString(shape()));
|
||||
pieces.push_back(shape_to_string(shape()));
|
||||
pieces.push_back(" {...}");
|
||||
}
|
||||
|
||||
|
@ -450,7 +450,7 @@ class Literal {
|
||||
tensorflow::Status ValidateLiteral() const;
|
||||
|
||||
// Returns a string representation of the literal value.
|
||||
string ToString() const;
|
||||
string ToString(bool print_layout = false) const;
|
||||
|
||||
// Invokes the "per cell" callback for each element in the provided
|
||||
// literal with the element's indices and a string representation of
|
||||
|
@ -1780,7 +1780,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
@ -1851,7 +1850,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
@ -79,12 +79,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
|
||||
// Test that two identical constants with different layouts are commoned if
|
||||
// the pass is not layout sensitive.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
|
||||
/*minor_to_major=*/{0, 1})));
|
||||
auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
|
||||
/*minor_to_major=*/{1, 0})));
|
||||
auto constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
|
||||
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
|
||||
auto constant2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
|
||||
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
|
||||
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
|
||||
|
||||
@ -111,12 +111,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
|
||||
// Test that two identical constants with different layouts are *not* commoned
|
||||
// if the pass is layout sensitive.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
|
||||
/*minor_to_major=*/{0, 1})));
|
||||
auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
|
||||
/*minor_to_major=*/{1, 0})));
|
||||
auto constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
|
||||
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
|
||||
auto constant2 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
|
||||
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
|
||||
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
|
||||
|
||||
|
@ -131,10 +131,10 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
|
||||
std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
|
||||
for (auto& minor_to_major : minor_to_majors) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant_literal1 = test_utils::CreateR2LiteralWithLayout<float>(
|
||||
{{1.0, 2.0}, {3.0, 4.0}}, minor_to_major);
|
||||
auto constant_literal2 = test_utils::CreateR2LiteralWithLayout<float>(
|
||||
{{5.0, 6.0}, {7.0, 8.0}}, minor_to_major);
|
||||
auto constant_literal1 = Literal::CreateR2WithLayout<float>(
|
||||
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
|
||||
auto constant_literal2 = Literal::CreateR2WithLayout<float>(
|
||||
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
|
||||
Shape ashape = constant_literal1->shape();
|
||||
|
||||
auto constant1 = builder.AddInstruction(
|
||||
@ -181,12 +181,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
|
||||
// Verify the layouts of a tuple are assigned properly (the element layouts
|
||||
// match their source).
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
|
||||
{0, 1})));
|
||||
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
|
||||
{1, 0})));
|
||||
auto constant0 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
|
||||
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
|
||||
auto constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
|
||||
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
|
||||
auto tuple = builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({constant0, constant1}));
|
||||
|
||||
@ -218,12 +218,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
|
||||
TEST_F(LayoutAssignmentTest, TupleSelect) {
|
||||
// Verify layouts of a select with tuple operands is assigned properly.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
|
||||
{0, 1})));
|
||||
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
|
||||
{1, 0})));
|
||||
auto constant0 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
|
||||
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
|
||||
auto constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
|
||||
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
|
||||
auto tuple0 = builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({constant0, constant1}));
|
||||
auto tuple1 = builder.AddInstruction(
|
||||
|
@ -61,13 +61,14 @@ generate_backend_test_macros()
|
||||
|
||||
cc_library(
|
||||
name = "test_utils",
|
||||
testonly = True,
|
||||
srcs = ["test_utils.cc"],
|
||||
hdrs = ["test_utils.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
@ -469,8 +469,7 @@ template <typename NativeT>
|
||||
std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
|
||||
const int width, NativeT min_value, NativeT max_value, uint32 seed) {
|
||||
std::vector<NativeT> result(width);
|
||||
test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
|
||||
seed);
|
||||
PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
|
||||
for (int i = 0; i < width; ++i) {
|
||||
result[i] = generator.get();
|
||||
}
|
||||
@ -482,8 +481,7 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
|
||||
const int rows, const int cols, NativeT min_value, NativeT max_value,
|
||||
uint32 seed) {
|
||||
auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
|
||||
test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
|
||||
seed);
|
||||
PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
|
||||
for (int y = 0; y < rows; ++y) {
|
||||
for (int x = 0; x < cols; ++x) {
|
||||
(*result)(y, x) = generator.get();
|
||||
|
@ -54,8 +54,8 @@ TEST_F(ClientTest, ExecuteWithLayout) {
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<Literal> expected_literal =
|
||||
test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
|
||||
transfer_layout);
|
||||
Literal::CreateR2WithLayout<int32>(
|
||||
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
|
||||
|
||||
auto computed = client_->Transfer(*data, &expected_literal->shape());
|
||||
|
||||
|
@ -138,13 +138,13 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) {
|
||||
// layouts. Use these arrays as parameters to a simple computation. If the
|
||||
// layout of the array changes then computation should be recompiled (cache
|
||||
// miss).
|
||||
auto rowmaj_array = test_utils::CreateR2LiteralWithLayout(
|
||||
{{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0});
|
||||
auto rowmaj_array = Literal::CreateR2WithLayout(
|
||||
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
|
||||
auto rowmaj_handle =
|
||||
client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
|
||||
|
||||
auto colmaj_array = test_utils::CreateR2LiteralWithLayout(
|
||||
{{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1});
|
||||
auto colmaj_array = Literal::CreateR2WithLayout(
|
||||
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
|
||||
auto colmaj_handle =
|
||||
client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
|
||||
|
||||
|
@ -264,8 +264,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
|
||||
ASSERT_TRUE(computed.ok()) << computed.status();
|
||||
|
||||
std::unique_ptr<Literal> expected_literal =
|
||||
test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
|
||||
layout);
|
||||
Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
|
||||
LayoutUtil::MakeLayout(layout));
|
||||
LiteralTestUtil::AssertEqualShapesAndLayouts(
|
||||
expected_literal->shape(), computed.ValueOrDie()->shape());
|
||||
LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
|
||||
|
@ -177,15 +177,15 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
|
||||
bool rhs_row_major) {
|
||||
auto lhs_handle =
|
||||
client_
|
||||
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
|
||||
->TransferToServer(*Literal::CreateR2WithLayout<Element>(
|
||||
{{1.0, 2.0}, {3.0, -4.0}},
|
||||
MinorToMajorForIsRowMajor(lhs_row_major)))
|
||||
LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
|
||||
.ConsumeValueOrDie();
|
||||
auto rhs_handle =
|
||||
client_
|
||||
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
|
||||
->TransferToServer(*Literal::CreateR2WithLayout<Element>(
|
||||
{{1.0, 6.0}, {7.0, -4.0}},
|
||||
MinorToMajorForIsRowMajor(rhs_row_major)))
|
||||
LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
@ -362,15 +362,15 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
|
||||
bool rhs_row_major) {
|
||||
auto lhs_handle =
|
||||
client_
|
||||
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
|
||||
->TransferToServer(*Literal::CreateR2WithLayout<Element>(
|
||||
{{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
|
||||
MinorToMajorForIsRowMajor(lhs_row_major)))
|
||||
LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
|
||||
.ConsumeValueOrDie();
|
||||
auto rhs_handle =
|
||||
client_
|
||||
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
|
||||
->TransferToServer(*Literal::CreateR2WithLayout<Element>(
|
||||
{{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
|
||||
MinorToMajorForIsRowMajor(rhs_row_major)))
|
||||
LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
@ -420,13 +420,14 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) {
|
||||
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
|
||||
auto lhs_handle =
|
||||
client_
|
||||
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>(
|
||||
{{1.0, 2.0, 3.0, -4.0}}, {1, 0}))
|
||||
->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
|
||||
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
|
||||
.ConsumeValueOrDie();
|
||||
auto rhs_handle =
|
||||
client_
|
||||
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>(
|
||||
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, {1, 0}))
|
||||
->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
|
||||
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
|
||||
LayoutUtil::MakeLayout({1, 0})))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
@ -136,16 +136,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
|
||||
auto computation = builder.Build().ConsumeValueOrDie();
|
||||
|
||||
// Create x as a col-major array.
|
||||
auto x_array = LiteralToShapedBuffer(
|
||||
*test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}},
|
||||
/*minor_to_major=*/{0, 1}));
|
||||
auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
|
||||
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
|
||||
EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(),
|
||||
LayoutUtil::MakeLayout({0, 1})));
|
||||
|
||||
// Create y as a row-major array.
|
||||
auto y_array = LiteralToShapedBuffer(
|
||||
*test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}},
|
||||
/*minor_to_major=*/{1, 0}));
|
||||
auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
|
||||
{{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
|
||||
EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(),
|
||||
LayoutUtil::MakeLayout({1, 0})));
|
||||
|
||||
|
@ -405,13 +405,13 @@ TEST_F(MapTest, MapBinaryAdder) {
|
||||
// for Map that used to fail in shape inference (b/28989438).
|
||||
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
std::unique_ptr<Literal> param0_literal =
|
||||
test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0});
|
||||
std::unique_ptr<Literal> param0_literal = Literal::CreateR2WithLayout(
|
||||
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
|
||||
std::unique_ptr<GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<Literal> param1_literal =
|
||||
test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1});
|
||||
std::unique_ptr<Literal> param1_literal = Literal::CreateR2WithLayout(
|
||||
{{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
|
||||
std::unique_ptr<GlobalData> param1_data =
|
||||
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
|
||||
|
||||
|
120
tensorflow/compiler/xla/tests/test_utils.cc
Normal file
120
tensorflow/compiler/xla/tests/test_utils.cc
Normal file
@ -0,0 +1,120 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename FloatT>
|
||||
void PopulateWithRandomFloatingPointData(Literal* literal) {
|
||||
CHECK_EQ(literal->shape().element_type(),
|
||||
primitive_util::NativeToPrimitiveType<FloatT>());
|
||||
std::minstd_rand0 engine;
|
||||
std::uniform_real_distribution<FloatT> generator(0.0f, 1.0f);
|
||||
TF_CHECK_OK(literal->Populate<FloatT>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
|
||||
return generator(engine);
|
||||
}));
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
void PopulateWithRandomIntegralData(Literal* literal) {
|
||||
CHECK_EQ(literal->shape().element_type(),
|
||||
primitive_util::NativeToPrimitiveType<IntT>());
|
||||
std::minstd_rand0 engine;
|
||||
std::uniform_int_distribution<IntT> generator(
|
||||
std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
|
||||
TF_CHECK_OK(literal->Populate<IntT>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
|
||||
return generator(engine);
|
||||
}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
|
||||
if (ShapeUtil::IsTuple(shape)) {
|
||||
std::vector<std::unique_ptr<Literal>> elements;
|
||||
for (const Shape& element_shape : shape.tuple_shapes()) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
|
||||
MakeFakeLiteral(element_shape));
|
||||
elements.push_back(std::move(element));
|
||||
}
|
||||
return Literal::MakeTupleOwned(std::move(elements));
|
||||
}
|
||||
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
|
||||
switch (shape.element_type()) {
|
||||
case F32:
|
||||
PopulateWithRandomFloatingPointData<float>(literal.get());
|
||||
break;
|
||||
case F64:
|
||||
PopulateWithRandomFloatingPointData<double>(literal.get());
|
||||
break;
|
||||
case S8:
|
||||
PopulateWithRandomIntegralData<int8>(literal.get());
|
||||
break;
|
||||
case U8:
|
||||
PopulateWithRandomIntegralData<uint8>(literal.get());
|
||||
break;
|
||||
case S16:
|
||||
PopulateWithRandomIntegralData<int16>(literal.get());
|
||||
break;
|
||||
case U16:
|
||||
PopulateWithRandomIntegralData<uint16>(literal.get());
|
||||
break;
|
||||
case S32:
|
||||
PopulateWithRandomIntegralData<int32>(literal.get());
|
||||
break;
|
||||
case U32:
|
||||
PopulateWithRandomIntegralData<uint32>(literal.get());
|
||||
break;
|
||||
case S64:
|
||||
PopulateWithRandomIntegralData<int64>(literal.get());
|
||||
break;
|
||||
case U64:
|
||||
PopulateWithRandomIntegralData<uint64>(literal.get());
|
||||
break;
|
||||
case PRED: {
|
||||
std::uniform_int_distribution<int> generator(0, 1);
|
||||
std::minstd_rand0 engine;
|
||||
TF_CHECK_OK(literal->Populate<bool>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
|
||||
return generator(engine);
|
||||
}));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return Unimplemented("Unsupported type for fake literal generation: %s",
|
||||
ShapeUtil::HumanString(shape).c_str());
|
||||
}
|
||||
return std::move(literal);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
|
||||
const HloModule& module) {
|
||||
std::vector<std::unique_ptr<Literal>> arguments;
|
||||
for (const ShapeLayout& shape_layout :
|
||||
module.config().entry_computation_layout().parameter_layouts()) {
|
||||
TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape()));
|
||||
arguments.push_back(std::move(literal));
|
||||
}
|
||||
return std::move(arguments);
|
||||
}
|
||||
|
||||
} // namespace xla
|
@ -23,12 +23,12 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
namespace test_utils {
|
||||
|
||||
// A class which generates pseudorandom numbers of a given type within a given
|
||||
// range. Not cryptographically secure and likely not perfectly evenly
|
||||
@ -53,63 +53,15 @@ class PseudorandomGenerator {
|
||||
std::mt19937 generator_;
|
||||
};
|
||||
|
||||
// Convenience function for creating a rank-2 array with arbitrary layout.
|
||||
template <typename NativeT>
|
||||
std::unique_ptr<Literal> CreateR2LiteralWithLayout(
|
||||
std::initializer_list<std::initializer_list<NativeT>> values,
|
||||
tensorflow::gtl::ArraySlice<int64> minor_to_major) {
|
||||
auto literal = MakeUnique<Literal>();
|
||||
const int64 d0 = values.size();
|
||||
const int64 d1 = values.begin()->size();
|
||||
literal.get()->PopulateWithValue<NativeT>(0, {d0, d1});
|
||||
*literal->mutable_shape()->mutable_layout() =
|
||||
LayoutUtil::MakeLayout(minor_to_major);
|
||||
TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
|
||||
// Generates fake data in a literal of the given shape, or returns an error
|
||||
// status if the element type is currently unhandled for fake data generation.
|
||||
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
|
||||
|
||||
int64 dim0 = 0;
|
||||
for (auto inner_list : values) {
|
||||
int64 dim1 = 0;
|
||||
for (auto value : inner_list) {
|
||||
literal.get()->Set({dim0, dim1}, value);
|
||||
++dim1;
|
||||
}
|
||||
++dim0;
|
||||
}
|
||||
return literal;
|
||||
}
|
||||
// Generates a vector of arguments containing fake data. The number, shape and
|
||||
// layout of the arguments is appropriate for given HLO module.
|
||||
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
|
||||
const HloModule& module);
|
||||
|
||||
// Convenience function for creating a rank-3 array with arbitrary layout.
|
||||
template <typename NativeT>
|
||||
std::unique_ptr<Literal> CreateR3LiteralWithLayout(
|
||||
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
|
||||
values,
|
||||
tensorflow::gtl::ArraySlice<int64> minor_to_major) {
|
||||
auto literal = MakeUnique<Literal>();
|
||||
const int64 d0 = values.size();
|
||||
const int64 d1 = values.begin()->size();
|
||||
const int64 d2 = values.begin()->begin()->size();
|
||||
literal.get()->PopulateWithValue<NativeT>(0, {d0, d1, d2});
|
||||
*literal->mutable_shape()->mutable_layout() =
|
||||
LayoutUtil::MakeLayout(minor_to_major);
|
||||
TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
|
||||
|
||||
int64 dim0 = 0;
|
||||
for (auto inner_list : values) {
|
||||
int64 dim1 = 0;
|
||||
for (auto inner_inner_list : inner_list) {
|
||||
int64 dim2 = 0;
|
||||
for (auto value : inner_inner_list) {
|
||||
literal.get()->Set({dim0, dim1, dim2}, value);
|
||||
++dim2;
|
||||
}
|
||||
++dim1;
|
||||
}
|
||||
++dim0;
|
||||
}
|
||||
return literal;
|
||||
}
|
||||
|
||||
} // namespace test_utils
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
|
||||
|
@ -88,6 +88,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client/lib:testing",
|
||||
"//tensorflow/compiler/xla/service:session_proto",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
|
@ -45,6 +45,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
|
Loading…
Reference in New Issue
Block a user