Move MakeFakeLiteral from client/lib/testing.h to tests/test_utils.h. Also remove the superfluous literal creation helpers (CreateR2LiteralWithLayout, CreateR3LiteralWithLayout) from tests/test_utils.h, replacing their uses with the existing factory methods on the Literal class.

Also, add an optional print_layout argument to Literal::ToString.

PiperOrigin-RevId: 175076277
Mark Heffernan 2017-11-08 15:35:27 -08:00 committed by TensorFlower Gardener
parent 35febc0cc9
commit a6babd6a4f
20 changed files with 209 additions and 189 deletions
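
For downstream code, the move amounts to an include-path change plus the new BUILD dependency ("//tensorflow/compiler/xla/tests:test_utils", added below); the function keeps its signature and stays in namespace xla. An illustrative sketch, not part of the diff (`shape` is a placeholder xla::Shape):

// Before: #include "tensorflow/compiler/xla/client/lib/testing.h"
// After:
#include "tensorflow/compiler/xla/tests/test_utils.h"

// The call itself is unchanged by the move.
xla::StatusOr<std::unique_ptr<xla::Literal>> fake =
    xla::MakeFakeLiteral(shape);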


@@ -44,6 +44,7 @@ cc_library(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
],
)


@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -48,62 +49,6 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
} // namespace
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
if (ShapeUtil::IsTuple(shape)) {
std::vector<std::unique_ptr<Literal>> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
MakeFakeLiteral(element_shape));
elements.push_back(std::move(element));
}
return Literal::MakeTupleOwned(std::move(elements));
}
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
std::minstd_rand0 engine;
switch (shape.element_type()) {
case F32: {
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
TF_CHECK_OK(literal->Populate<float>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
case S32: {
std::uniform_int_distribution<int32> generator(
std::numeric_limits<int32>::lowest(),
std::numeric_limits<int32>::max());
TF_CHECK_OK(literal->Populate<int32>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
case S64: {
std::uniform_int_distribution<int64> generator(
std::numeric_limits<int64>::lowest(),
std::numeric_limits<int64>::max());
TF_CHECK_OK(literal->Populate<int64>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
TF_CHECK_OK(literal->Populate<bool>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
default:
return Unimplemented("Unsupported type for fake literal generation: %s",
ShapeUtil::HumanString(shape).c_str());
}
return std::move(literal);
}
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) {
if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) {


@@ -26,10 +26,6 @@ limitations under the License.
namespace xla {
// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data generation.
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
// Generates fake data of the given shape on the device or dies. The fake data
// is created by performing a computation on the device rather than transferring
// data from the host to the device.


@@ -569,9 +569,17 @@ int64 Literal::LinearIndex(
return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
}
string Literal::ToString() const {
string Literal::ToString(bool print_layout) const {
std::vector<string> pieces;
auto shape_to_string = [print_layout](const Shape& shape) {
if (print_layout) {
return ShapeUtil::HumanStringWithLayout(shape);
} else {
return ShapeUtil::HumanString(shape);
}
};
auto element_to_string =
[this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
PrimitiveType element_type = shape().element_type();
@@ -585,7 +593,7 @@ string Literal::ToString() const {
// TODO(b/32894291): refactor this code to reduce code duplication.
if (ShapeUtil::IsTuple(shape())) {
pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(shape_to_string(shape()));
pieces.push_back(" (\n");
pieces.push_back(tensorflow::str_util::Join(
tuple_literals(), ",\n", [](string* out, const Literal& element) {
@@ -601,7 +609,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 2) {
pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(" { ");
@@ -613,7 +621,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 3) {
pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(i0 > 0 ? ",\n{" : "{");
@@ -628,7 +636,7 @@ string Literal::ToString() const {
}
pieces.push_back("\n}");
} else if (ShapeUtil::Rank(shape()) == 4) {
pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@@ -649,7 +657,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 5) {
pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@@ -676,7 +684,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else {
pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {...}");
}

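A minimal sketch of the new flag's effect (illustrative; the exact strings come from ShapeUtil::HumanString vs. ShapeUtil::HumanStringWithLayout, and `string` is the alias used throughout this code):

auto literal = xla::Literal::CreateR2WithLayout<float>(
    {{1.0, 2.0}, {3.0, 4.0}}, xla::LayoutUtil::MakeLayout({0, 1}));
// Default: the shape prints without layout, e.g. "f32[2,2] {...".
string plain = literal->ToString();
// print_layout=true includes the minor-to-major order in the shape,
// e.g. "f32[2,2]{0,1} {...".
string with_layout = literal->ToString(/*print_layout=*/true);
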

@@ -450,7 +450,7 @@ class Literal {
tensorflow::Status ValidateLiteral() const;
// Returns a string representation of the literal value.
string ToString() const;
string ToString(bool print_layout = false) const;
// Invokes the "per cell" callback for each element in the provided
// literal with the element's indices and a string representation of


@@ -1780,7 +1780,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
)
@@ -1851,7 +1850,6 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
)


@@ -79,12 +79,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
// Test that two identical constants with different layouts are commoned if
// the pass is not layout sensitive.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
/*minor_to_major=*/{0, 1})));
auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
/*minor_to_major=*/{1, 0})));
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant2 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@@ -111,12 +111,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
// Test that two identical constants with different layouts are *not* commoned
// if the pass is layout sensitive.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
/*minor_to_major=*/{0, 1})));
auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
/*minor_to_major=*/{1, 0})));
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant2 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));


@@ -131,10 +131,10 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
for (auto& minor_to_major : minor_to_majors) {
auto builder = HloComputation::Builder(TestName());
auto constant_literal1 = test_utils::CreateR2LiteralWithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, minor_to_major);
auto constant_literal2 = test_utils::CreateR2LiteralWithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, minor_to_major);
auto constant_literal1 = Literal::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
auto constant_literal2 = Literal::CreateR2WithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
Shape ashape = constant_literal1->shape();
auto constant1 = builder.AddInstruction(
@@ -181,12 +181,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
// Verify the layouts of a tuple are assigned properly (the element layouts
// match their source).
auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
{0, 1})));
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
{1, 0})));
auto constant0 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1}));
@@ -218,12 +218,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
TEST_F(LayoutAssignmentTest, TupleSelect) {
// Verify layouts of a select with tuple operands is assigned properly.
auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
{0, 1})));
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
{1, 0})));
auto constant0 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple0 = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1}));
auto tuple1 = builder.AddInstruction(


@@ -61,13 +61,14 @@ generate_backend_test_macros()
cc_library(
name = "test_utils",
testonly = True,
srcs = ["test_utils.cc"],
hdrs = ["test_utils.h"],
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
],
)


@@ -469,8 +469,7 @@ template <typename NativeT>
std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
const int width, NativeT min_value, NativeT max_value, uint32 seed) {
std::vector<NativeT> result(width);
test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
seed);
PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
for (int i = 0; i < width; ++i) {
result[i] = generator.get();
}
@@ -482,8 +481,7 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
const int rows, const int cols, NativeT min_value, NativeT max_value,
uint32 seed) {
auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
seed);
PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
for (int y = 0; y < rows; ++y) {
for (int x = 0; x < cols; ++x) {
(*result)(y, x) = generator.get();

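The only change at these call sites is the dropped test_utils:: qualifier; PseudorandomGenerator itself is unchanged and now lives directly in namespace xla. A usage sketch with placeholder arguments:

// Deterministic pseudorandom floats in a given range; the fixed seed makes
// test data reproducible across runs.
xla::PseudorandomGenerator<float> generator(/*min_value=*/0.0f,
                                            /*max_value=*/1.0f, /*seed=*/42);
float value = generator.get();
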

@@ -54,8 +54,8 @@ TEST_F(ClientTest, ExecuteWithLayout) {
.ConsumeValueOrDie();
std::unique_ptr<Literal> expected_literal =
test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
transfer_layout);
Literal::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
auto computed = client_->Transfer(*data, &expected_literal->shape());


@@ -138,13 +138,13 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) {
// layouts. Use these arrays as parameters to a simple computation. If the
// layout of the array changes then computation should be recompiled (cache
// miss).
auto rowmaj_array = test_utils::CreateR2LiteralWithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0});
auto rowmaj_array = Literal::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
auto rowmaj_handle =
client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
auto colmaj_array = test_utils::CreateR2LiteralWithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1});
auto colmaj_array = Literal::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
auto colmaj_handle =
client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();


@@ -264,8 +264,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
ASSERT_TRUE(computed.ok()) << computed.status();
std::unique_ptr<Literal> expected_literal =
test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
layout);
Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
LayoutUtil::MakeLayout(layout));
LiteralTestUtil::AssertEqualShapesAndLayouts(
expected_literal->shape(), computed.ValueOrDie()->shape());
LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());


@@ -177,15 +177,15 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
bool rhs_row_major) {
auto lhs_handle =
client_
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 2.0}, {3.0, -4.0}},
MinorToMajorForIsRowMajor(lhs_row_major)))
LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 6.0}, {7.0, -4.0}},
MinorToMajorForIsRowMajor(rhs_row_major)))
LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
.ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName());
@@ -362,15 +362,15 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
bool rhs_row_major) {
auto lhs_handle =
client_
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
MinorToMajorForIsRowMajor(lhs_row_major)))
LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
MinorToMajorForIsRowMajor(rhs_row_major)))
LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
.ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName());
@@ -420,13 +420,14 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) {
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
auto lhs_handle =
client_
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, {1, 0}))
->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
auto rhs_handle =
client_
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, {1, 0}))
->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName());


@@ -136,16 +136,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
// Create x as a col-major array.
auto x_array = LiteralToShapedBuffer(
*test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}},
/*minor_to_major=*/{0, 1}));
auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
// Create y as a row-major array.
auto y_array = LiteralToShapedBuffer(
*test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}},
/*minor_to_major=*/{1, 0}));
auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
{{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(),
LayoutUtil::MakeLayout({1, 0})));


@@ -405,13 +405,13 @@ TEST_F(MapTest, MapBinaryAdder) {
// for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
ComputationBuilder builder(client_, TestName());
std::unique_ptr<Literal> param0_literal =
test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0});
std::unique_ptr<Literal> param0_literal = Literal::CreateR2WithLayout(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1});
std::unique_ptr<Literal> param1_literal = Literal::CreateR2WithLayout(
{{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();


@@ -0,0 +1,120 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
namespace xla {
namespace {
template <typename FloatT>
void PopulateWithRandomFloatingPointData(Literal* literal) {
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<FloatT>());
std::minstd_rand0 engine;
std::uniform_real_distribution<FloatT> generator(0.0f, 1.0f);
TF_CHECK_OK(literal->Populate<FloatT>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
}
template <typename IntT>
void PopulateWithRandomIntegralData(Literal* literal) {
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<IntT>());
std::minstd_rand0 engine;
std::uniform_int_distribution<IntT> generator(
std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
TF_CHECK_OK(literal->Populate<IntT>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
}
} // namespace
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
if (ShapeUtil::IsTuple(shape)) {
std::vector<std::unique_ptr<Literal>> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
MakeFakeLiteral(element_shape));
elements.push_back(std::move(element));
}
return Literal::MakeTupleOwned(std::move(elements));
}
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
switch (shape.element_type()) {
case F32:
PopulateWithRandomFloatingPointData<float>(literal.get());
break;
case F64:
PopulateWithRandomFloatingPointData<double>(literal.get());
break;
case S8:
PopulateWithRandomIntegralData<int8>(literal.get());
break;
case U8:
PopulateWithRandomIntegralData<uint8>(literal.get());
break;
case S16:
PopulateWithRandomIntegralData<int16>(literal.get());
break;
case U16:
PopulateWithRandomIntegralData<uint16>(literal.get());
break;
case S32:
PopulateWithRandomIntegralData<int32>(literal.get());
break;
case U32:
PopulateWithRandomIntegralData<uint32>(literal.get());
break;
case S64:
PopulateWithRandomIntegralData<int64>(literal.get());
break;
case U64:
PopulateWithRandomIntegralData<uint64>(literal.get());
break;
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
std::minstd_rand0 engine;
TF_CHECK_OK(literal->Populate<bool>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
default:
return Unimplemented("Unsupported type for fake literal generation: %s",
ShapeUtil::HumanString(shape).c_str());
}
return std::move(literal);
}
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
const HloModule& module) {
std::vector<std::unique_ptr<Literal>> arguments;
for (const ShapeLayout& shape_layout :
module.config().entry_computation_layout().parameter_layouts()) {
TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape()));
arguments.push_back(std::move(literal));
}
return std::move(arguments);
}
} // namespace xla

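A minimal usage sketch for the two entry points defined above (assumes an enclosing function whose return type lets TF_ASSIGN_OR_RETURN propagate errors; `module` stands for some HloModule under test):

// Rank-2 F32 literal; float elements are drawn uniformly from [0, 1).
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> fake,
                    MakeFakeLiteral(ShapeUtil::MakeShape(F32, {2, 3})));
// One fake literal per parameter in the module's entry computation layout.
TF_ASSIGN_OR_RETURN(auto arguments, MakeFakeArguments(module));
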

@@ -23,12 +23,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace test_utils {
// A class which generates pseudorandom numbers of a given type within a given
// range. Not cryptographically secure and likely not perfectly evenly
@@ -53,63 +53,15 @@ class PseudorandomGenerator {
std::mt19937 generator_;
};
// Convenience function for creating a rank-2 array with arbitrary layout.
template <typename NativeT>
std::unique_ptr<Literal> CreateR2LiteralWithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
tensorflow::gtl::ArraySlice<int64> minor_to_major) {
auto literal = MakeUnique<Literal>();
const int64 d0 = values.size();
const int64 d1 = values.begin()->size();
literal.get()->PopulateWithValue<NativeT>(0, {d0, d1});
*literal->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout(minor_to_major);
TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data generation.
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
int64 dim0 = 0;
for (auto inner_list : values) {
int64 dim1 = 0;
for (auto value : inner_list) {
literal.get()->Set({dim0, dim1}, value);
++dim1;
}
++dim0;
}
return literal;
}
// Generates a vector of arguments containing fake data. The number, shape, and
// layout of the arguments are appropriate for the given HLO module.
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
const HloModule& module);
// Convenience function for creating a rank-3 array with arbitrary layout.
template <typename NativeT>
std::unique_ptr<Literal> CreateR3LiteralWithLayout(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values,
tensorflow::gtl::ArraySlice<int64> minor_to_major) {
auto literal = MakeUnique<Literal>();
const int64 d0 = values.size();
const int64 d1 = values.begin()->size();
const int64 d2 = values.begin()->begin()->size();
literal.get()->PopulateWithValue<NativeT>(0, {d0, d1, d2});
*literal->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout(minor_to_major);
TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
int64 dim0 = 0;
for (auto inner_list : values) {
int64 dim1 = 0;
for (auto inner_inner_list : inner_list) {
int64 dim2 = 0;
for (auto value : inner_inner_list) {
literal.get()->Set({dim0, dim1, dim2}, value);
++dim2;
}
++dim1;
}
++dim0;
}
return literal;
}
} // namespace test_utils
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_

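For callers of the removed helpers, the mechanical rewrite (applied throughout the test diffs above) is:

// Before:
//   auto literal = test_utils::CreateR2LiteralWithLayout<float>(
//       {{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{0, 1});
// After: the existing Literal factory, with the layout wrapped explicitly.
auto literal = Literal::CreateR2WithLayout<float>(
    {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}));
// The removed rank-3 helper rewrites analogously via Literal::CreateR3WithLayout.
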

@@ -88,6 +88,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:testing",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],


@@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"