Move MakeFakeLiteral from client/lib/testing.h to tests/test_utils.h. Also remove superfluous literal creation methods in that file, and replace them with the existing ones in the Literal class.

Also, optionally print layout in Literal::ToString.

PiperOrigin-RevId: 175076277
This commit is contained in:
Mark Heffernan 2017-11-08 15:35:27 -08:00 committed by TensorFlower Gardener
parent 35febc0cc9
commit a6babd6a4f
20 changed files with 209 additions and 189 deletions

View File

@ -44,6 +44,7 @@ cc_library(
"//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
) )

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
@ -48,62 +49,6 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
} // namespace } // namespace
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
if (ShapeUtil::IsTuple(shape)) {
std::vector<std::unique_ptr<Literal>> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
MakeFakeLiteral(element_shape));
elements.push_back(std::move(element));
}
return Literal::MakeTupleOwned(std::move(elements));
}
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
std::minstd_rand0 engine;
switch (shape.element_type()) {
case F32: {
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
TF_CHECK_OK(literal->Populate<float>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
case S32: {
std::uniform_int_distribution<int32> generator(
std::numeric_limits<int32>::lowest(),
std::numeric_limits<int32>::max());
TF_CHECK_OK(literal->Populate<int32>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
case S64: {
std::uniform_int_distribution<int64> generator(
std::numeric_limits<int64>::lowest(),
std::numeric_limits<int64>::max());
TF_CHECK_OK(literal->Populate<int64>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
TF_CHECK_OK(literal->Populate<bool>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
default:
return Unimplemented("Unsupported type for fake literal generation: %s",
ShapeUtil::HumanString(shape).c_str());
}
return std::move(literal);
}
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape, std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) { Client* client) {
if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) { if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) {

View File

@ -26,10 +26,6 @@ limitations under the License.
namespace xla { namespace xla {
// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data generation.
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
// Generates fake data of the given shape on the device or dies. The fake data // Generates fake data of the given shape on the device or dies. The fake data
// is created by performing a computation on the device rather than transferring // is created by performing a computation on the device rather than transferring
// data from the host to the device. // data from the host to the device.

View File

@ -569,9 +569,17 @@ int64 Literal::LinearIndex(
return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index); return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
} }
string Literal::ToString() const { string Literal::ToString(bool print_layout) const {
std::vector<string> pieces; std::vector<string> pieces;
auto shape_to_string = [print_layout](const Shape& shape) {
if (print_layout) {
return ShapeUtil::HumanStringWithLayout(shape);
} else {
return ShapeUtil::HumanString(shape);
}
};
auto element_to_string = auto element_to_string =
[this](tensorflow::gtl::ArraySlice<int64> indices) -> string { [this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
PrimitiveType element_type = shape().element_type(); PrimitiveType element_type = shape().element_type();
@ -585,7 +593,7 @@ string Literal::ToString() const {
// TODO(b/32894291): refactor this code to reduce code duplication. // TODO(b/32894291): refactor this code to reduce code duplication.
if (ShapeUtil::IsTuple(shape())) { if (ShapeUtil::IsTuple(shape())) {
pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(shape_to_string(shape()));
pieces.push_back(" (\n"); pieces.push_back(" (\n");
pieces.push_back(tensorflow::str_util::Join( pieces.push_back(tensorflow::str_util::Join(
tuple_literals(), ",\n", [](string* out, const Literal& element) { tuple_literals(), ",\n", [](string* out, const Literal& element) {
@ -601,7 +609,7 @@ string Literal::ToString() const {
} }
pieces.push_back("}"); pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 2) { } else if (ShapeUtil::Rank(shape()) == 2) {
pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n"); pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(" { "); pieces.push_back(" { ");
@ -613,7 +621,7 @@ string Literal::ToString() const {
} }
pieces.push_back("}"); pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 3) { } else if (ShapeUtil::Rank(shape()) == 3) {
pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n"); pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(i0 > 0 ? ",\n{" : "{"); pieces.push_back(i0 > 0 ? ",\n{" : "{");
@ -628,7 +636,7 @@ string Literal::ToString() const {
} }
pieces.push_back("\n}"); pieces.push_back("\n}");
} else if (ShapeUtil::Rank(shape()) == 4) { } else if (ShapeUtil::Rank(shape()) == 4) {
pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n"); pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@ -649,7 +657,7 @@ string Literal::ToString() const {
} }
pieces.push_back("}"); pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 5) { } else if (ShapeUtil::Rank(shape()) == 5) {
pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n"); pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@ -676,7 +684,7 @@ string Literal::ToString() const {
} }
pieces.push_back("}"); pieces.push_back("}");
} else { } else {
pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {...}"); pieces.push_back(" {...}");
} }

View File

@ -450,7 +450,7 @@ class Literal {
tensorflow::Status ValidateLiteral() const; tensorflow::Status ValidateLiteral() const;
// Returns a string representation of the literal value. // Returns a string representation of the literal value.
string ToString() const; string ToString(bool print_layout = false) const;
// Invokes the "per cell" callback for each element in the provided // Invokes the "per cell" callback for each element in the provided
// literal with the element's indices and a string representation of // literal with the element's indices and a string representation of

View File

@ -1780,7 +1780,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
) )
@ -1851,7 +1850,6 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
) )

View File

@ -79,12 +79,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
// Test that two identical constants with different layouts are commoned if // Test that two identical constants with different layouts are commoned if
// the pass is not layout sensitive. // the pass is not layout sensitive.
auto builder = HloComputation::Builder(TestName()); auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( auto constant1 = builder.AddInstruction(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}}, HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
/*minor_to_major=*/{0, 1}))); {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant( auto constant2 = builder.AddInstruction(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}}, HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
/*minor_to_major=*/{1, 0}))); {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto add = builder.AddInstruction(HloInstruction::CreateBinary( auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2)); constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@ -111,12 +111,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
// Test that two identical constants with different layouts are *not* commoned // Test that two identical constants with different layouts are *not* commoned
// if the pass is layout sensitive. // if the pass is layout sensitive.
auto builder = HloComputation::Builder(TestName()); auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( auto constant1 = builder.AddInstruction(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}}, HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
/*minor_to_major=*/{0, 1}))); {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant( auto constant2 = builder.AddInstruction(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}}, HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
/*minor_to_major=*/{1, 0}))); {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto add = builder.AddInstruction(HloInstruction::CreateBinary( auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2)); constant1->shape(), HloOpcode::kAdd, constant1, constant2));

View File

@ -131,10 +131,10 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}}; std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
for (auto& minor_to_major : minor_to_majors) { for (auto& minor_to_major : minor_to_majors) {
auto builder = HloComputation::Builder(TestName()); auto builder = HloComputation::Builder(TestName());
auto constant_literal1 = test_utils::CreateR2LiteralWithLayout<float>( auto constant_literal1 = Literal::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, minor_to_major); {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
auto constant_literal2 = test_utils::CreateR2LiteralWithLayout<float>( auto constant_literal2 = Literal::CreateR2WithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, minor_to_major); {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
Shape ashape = constant_literal1->shape(); Shape ashape = constant_literal1->shape();
auto constant1 = builder.AddInstruction( auto constant1 = builder.AddInstruction(
@ -181,12 +181,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
// Verify the layouts of a tuple are assigned properly (the element layouts // Verify the layouts of a tuple are assigned properly (the element layouts
// match their source). // match their source).
auto builder = HloComputation::Builder(TestName()); auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant( auto constant0 = builder.AddInstruction(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}}, HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
{0, 1}))); {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( auto constant1 = builder.AddInstruction(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}}, HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
{1, 0}))); {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple = builder.AddInstruction( auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1})); HloInstruction::CreateTuple({constant0, constant1}));
@ -218,12 +218,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
TEST_F(LayoutAssignmentTest, TupleSelect) { TEST_F(LayoutAssignmentTest, TupleSelect) {
// Verify layouts of a select with tuple operands is assigned properly. // Verify layouts of a select with tuple operands is assigned properly.
auto builder = HloComputation::Builder(TestName()); auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant( auto constant0 = builder.AddInstruction(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}}, HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
{0, 1}))); {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( auto constant1 = builder.AddInstruction(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}}, HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
{1, 0}))); {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple0 = builder.AddInstruction( auto tuple0 = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1})); HloInstruction::CreateTuple({constant0, constant1}));
auto tuple1 = builder.AddInstruction( auto tuple1 = builder.AddInstruction(

View File

@ -61,13 +61,14 @@ generate_backend_test_macros()
cc_library( cc_library(
name = "test_utils", name = "test_utils",
testonly = True, srcs = ["test_utils.cc"],
hdrs = ["test_utils.h"], hdrs = ["test_utils.h"],
deps = [ deps = [
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
) )

View File

@ -469,8 +469,7 @@ template <typename NativeT>
std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1( std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
const int width, NativeT min_value, NativeT max_value, uint32 seed) { const int width, NativeT min_value, NativeT max_value, uint32 seed) {
std::vector<NativeT> result(width); std::vector<NativeT> result(width);
test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value, PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
seed);
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
result[i] = generator.get(); result[i] = generator.get();
} }
@ -482,8 +481,7 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
const int rows, const int cols, NativeT min_value, NativeT max_value, const int rows, const int cols, NativeT min_value, NativeT max_value,
uint32 seed) { uint32 seed) {
auto result = MakeUnique<Array2D<NativeT>>(rows, cols); auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value, PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
seed);
for (int y = 0; y < rows; ++y) { for (int y = 0; y < rows; ++y) {
for (int x = 0; x < cols; ++x) { for (int x = 0; x < cols; ++x) {
(*result)(y, x) = generator.get(); (*result)(y, x) = generator.get();

View File

@ -54,8 +54,8 @@ TEST_F(ClientTest, ExecuteWithLayout) {
.ConsumeValueOrDie(); .ConsumeValueOrDie();
std::unique_ptr<Literal> expected_literal = std::unique_ptr<Literal> expected_literal =
test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}}, Literal::CreateR2WithLayout<int32>(
transfer_layout); {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
auto computed = client_->Transfer(*data, &expected_literal->shape()); auto computed = client_->Transfer(*data, &expected_literal->shape());

View File

@ -138,13 +138,13 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) {
// layouts. Use these arrays as parameters to a simple computation. If the // layouts. Use these arrays as parameters to a simple computation. If the
// layout of the array changes then computation should be recompiled (cache // layout of the array changes then computation should be recompiled (cache
// miss). // miss).
auto rowmaj_array = test_utils::CreateR2LiteralWithLayout( auto rowmaj_array = Literal::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0}); {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
auto rowmaj_handle = auto rowmaj_handle =
client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
auto colmaj_array = test_utils::CreateR2LiteralWithLayout( auto colmaj_array = Literal::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1}); {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
auto colmaj_handle = auto colmaj_handle =
client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();

View File

@ -264,8 +264,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
ASSERT_TRUE(computed.ok()) << computed.status(); ASSERT_TRUE(computed.ok()) << computed.status();
std::unique_ptr<Literal> expected_literal = std::unique_ptr<Literal> expected_literal =
test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}}, Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
layout); LayoutUtil::MakeLayout(layout));
LiteralTestUtil::AssertEqualShapesAndLayouts( LiteralTestUtil::AssertEqualShapesAndLayouts(
expected_literal->shape(), computed.ValueOrDie()->shape()); expected_literal->shape(), computed.ValueOrDie()->shape());
LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());

View File

@ -177,15 +177,15 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
bool rhs_row_major) { bool rhs_row_major) {
auto lhs_handle = auto lhs_handle =
client_ client_
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>( ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 2.0}, {3.0, -4.0}}, {{1.0, 2.0}, {3.0, -4.0}},
MinorToMajorForIsRowMajor(lhs_row_major))) LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie(); .ConsumeValueOrDie();
auto rhs_handle = auto rhs_handle =
client_ client_
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>( ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 6.0}, {7.0, -4.0}}, {{1.0, 6.0}, {7.0, -4.0}},
MinorToMajorForIsRowMajor(rhs_row_major))) LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
.ConsumeValueOrDie(); .ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -362,15 +362,15 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
bool rhs_row_major) { bool rhs_row_major) {
auto lhs_handle = auto lhs_handle =
client_ client_
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>( ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}, {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
MinorToMajorForIsRowMajor(lhs_row_major))) LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie(); .ConsumeValueOrDie();
auto rhs_handle = auto rhs_handle =
client_ client_
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>( ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}, {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
MinorToMajorForIsRowMajor(rhs_row_major))) LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
.ConsumeValueOrDie(); .ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -420,13 +420,14 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) {
XLA_TEST_F(DotOperationTest, MatrixVectorC64) { XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
auto lhs_handle = auto lhs_handle =
client_ client_
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>( ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, {1, 0})) {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie(); .ConsumeValueOrDie();
auto rhs_handle = auto rhs_handle =
client_ client_
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>( ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, {1, 0})) {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie(); .ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());

View File

@ -136,16 +136,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie(); auto computation = builder.Build().ConsumeValueOrDie();
// Create x as a col-major array. // Create x as a col-major array.
auto x_array = LiteralToShapedBuffer( auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
*test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}}, {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
/*minor_to_major=*/{0, 1}));
EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(), EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(),
LayoutUtil::MakeLayout({0, 1}))); LayoutUtil::MakeLayout({0, 1})));
// Create y as a row-major array. // Create y as a row-major array.
auto y_array = LiteralToShapedBuffer( auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
*test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}}, {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
/*minor_to_major=*/{1, 0}));
EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(), EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(),
LayoutUtil::MakeLayout({1, 0}))); LayoutUtil::MakeLayout({1, 0})));

View File

@ -405,13 +405,13 @@ TEST_F(MapTest, MapBinaryAdder) {
// for Map that used to fail in shape inference (b/28989438). // for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) { XLA_TEST_F(MapTest, AddWithMixedLayouts) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
std::unique_ptr<Literal> param0_literal = std::unique_ptr<Literal> param0_literal = Literal::CreateR2WithLayout(
test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0}); {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
std::unique_ptr<GlobalData> param0_data = std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal = std::unique_ptr<Literal> param1_literal = Literal::CreateR2WithLayout(
test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1}); {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> param1_data = std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); client_->TransferToServer(*param1_literal).ConsumeValueOrDie();

View File

@ -0,0 +1,120 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
namespace xla {
namespace {
template <typename FloatT>
void PopulateWithRandomFloatingPointData(Literal* literal) {
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<FloatT>());
std::minstd_rand0 engine;
std::uniform_real_distribution<FloatT> generator(0.0f, 1.0f);
TF_CHECK_OK(literal->Populate<FloatT>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
}
template <typename IntT>
void PopulateWithRandomIntegralData(Literal* literal) {
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<IntT>());
std::minstd_rand0 engine;
std::uniform_int_distribution<IntT> generator(
std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
TF_CHECK_OK(literal->Populate<IntT>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
}
} // namespace
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
if (ShapeUtil::IsTuple(shape)) {
std::vector<std::unique_ptr<Literal>> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
MakeFakeLiteral(element_shape));
elements.push_back(std::move(element));
}
return Literal::MakeTupleOwned(std::move(elements));
}
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
switch (shape.element_type()) {
case F32:
PopulateWithRandomFloatingPointData<float>(literal.get());
break;
case F64:
PopulateWithRandomFloatingPointData<double>(literal.get());
break;
case S8:
PopulateWithRandomIntegralData<int8>(literal.get());
break;
case U8:
PopulateWithRandomIntegralData<uint8>(literal.get());
break;
case S16:
PopulateWithRandomIntegralData<int16>(literal.get());
break;
case U16:
PopulateWithRandomIntegralData<uint16>(literal.get());
break;
case S32:
PopulateWithRandomIntegralData<int32>(literal.get());
break;
case U32:
PopulateWithRandomIntegralData<uint32>(literal.get());
break;
case S64:
PopulateWithRandomIntegralData<int64>(literal.get());
break;
case U64:
PopulateWithRandomIntegralData<uint64>(literal.get());
break;
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
std::minstd_rand0 engine;
TF_CHECK_OK(literal->Populate<bool>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
default:
return Unimplemented("Unsupported type for fake literal generation: %s",
ShapeUtil::HumanString(shape).c_str());
}
return std::move(literal);
}
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
const HloModule& module) {
std::vector<std::unique_ptr<Literal>> arguments;
for (const ShapeLayout& shape_layout :
module.config().entry_computation_layout().parameter_layouts()) {
TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape()));
arguments.push_back(std::move(literal));
}
return std::move(arguments);
}
} // namespace xla

View File

@ -23,12 +23,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
namespace xla { namespace xla {
namespace test_utils {
// A class which generates pseudorandom numbers of a given type within a given // A class which generates pseudorandom numbers of a given type within a given
// range. Not cryptographically secure and likely not perfectly evenly // range. Not cryptographically secure and likely not perfectly evenly
@ -53,63 +53,15 @@ class PseudorandomGenerator {
std::mt19937 generator_; std::mt19937 generator_;
}; };
// Convenience function for creating a rank-2 array with arbitrary layout. // Generates fake data in a literal of the given shape, or returns an error
template <typename NativeT> // status if the element type is currently unhandled for fake data generation.
std::unique_ptr<Literal> CreateR2LiteralWithLayout( StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
std::initializer_list<std::initializer_list<NativeT>> values,
tensorflow::gtl::ArraySlice<int64> minor_to_major) {
auto literal = MakeUnique<Literal>();
const int64 d0 = values.size();
const int64 d1 = values.begin()->size();
literal.get()->PopulateWithValue<NativeT>(0, {d0, d1});
*literal->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout(minor_to_major);
TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
int64 dim0 = 0; // Generates a vector of arguments containing fake data. The number, shape and
for (auto inner_list : values) { // layout of the arguments is appropriate for given HLO module.
int64 dim1 = 0; StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
for (auto value : inner_list) { const HloModule& module);
literal.get()->Set({dim0, dim1}, value);
++dim1;
}
++dim0;
}
return literal;
}
// Convenience function for creating a rank-3 array with arbitrary layout.
template <typename NativeT>
std::unique_ptr<Literal> CreateR3LiteralWithLayout(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values,
tensorflow::gtl::ArraySlice<int64> minor_to_major) {
auto literal = MakeUnique<Literal>();
const int64 d0 = values.size();
const int64 d1 = values.begin()->size();
const int64 d2 = values.begin()->begin()->size();
literal.get()->PopulateWithValue<NativeT>(0, {d0, d1, d2});
*literal->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout(minor_to_major);
TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
int64 dim0 = 0;
for (auto inner_list : values) {
int64 dim1 = 0;
for (auto inner_inner_list : inner_list) {
int64 dim2 = 0;
for (auto value : inner_inner_list) {
literal.get()->Set({dim0, dim1, dim2}, value);
++dim2;
}
++dim1;
}
++dim0;
}
return literal;
}
} // namespace test_utils
} // namespace xla } // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_

View File

@ -88,6 +88,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:testing", "//tensorflow/compiler/xla/client/lib:testing",
"//tensorflow/compiler/xla/service:session_proto", "//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],

View File

@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/core/threadpool.h"