Move MakeFakeLiteral from client/lib/testing.h to tests/test_utils.h. Also remove superfluous literal creation methods in that file, and replace them with the existing ones in the Literal class.
Also, optionally print layout in Literal::ToString.

PiperOrigin-RevId: 175076277
commit a6babd6a4f (parent 35febc0cc9)
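The mechanical change at most call sites is visible throughout the hunks below; as a summary sketch (values illustrative), the free helper in namespace test_utils, which took a bare minor_to_major list, gives way to the pre-existing factory on the Literal class plus an explicit Layout:

    // Before: helper from tensorflow/compiler/xla/tests/test_utils.h.
    auto old_style = test_utils::CreateR2LiteralWithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{0, 1});
    // After: the existing Literal factory, with the layout made explicit.
    auto new_style = Literal::CreateR2WithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}));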
@@ -44,6 +44,7 @@ cc_library(
         "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:global_data",
+        "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/core:lib",
     ],
 )
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -48,62 +49,6 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
 
 }  // namespace
 
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
-  if (ShapeUtil::IsTuple(shape)) {
-    std::vector<std::unique_ptr<Literal>> elements;
-    for (const Shape& element_shape : shape.tuple_shapes()) {
-      TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
-                          MakeFakeLiteral(element_shape));
-      elements.push_back(std::move(element));
-    }
-    return Literal::MakeTupleOwned(std::move(elements));
-  }
-  std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
-  std::minstd_rand0 engine;
-  switch (shape.element_type()) {
-    case F32: {
-      std::uniform_real_distribution<float> generator(0.0f, 1.0f);
-      TF_CHECK_OK(literal->Populate<float>(
-          [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
-            return generator(engine);
-          }));
-      break;
-    }
-    case S32: {
-      std::uniform_int_distribution<int32> generator(
-          std::numeric_limits<int32>::lowest(),
-          std::numeric_limits<int32>::max());
-      TF_CHECK_OK(literal->Populate<int32>(
-          [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
-            return generator(engine);
-          }));
-      break;
-    }
-    case S64: {
-      std::uniform_int_distribution<int64> generator(
-          std::numeric_limits<int64>::lowest(),
-          std::numeric_limits<int64>::max());
-      TF_CHECK_OK(literal->Populate<int64>(
-          [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
-            return generator(engine);
-          }));
-      break;
-    }
-    case PRED: {
-      std::uniform_int_distribution<int> generator(0, 1);
-      TF_CHECK_OK(literal->Populate<bool>(
-          [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
-            return generator(engine);
-          }));
-      break;
-    }
-    default:
-      return Unimplemented("Unsupported type for fake literal generation: %s",
-                           ShapeUtil::HumanString(shape).c_str());
-  }
-  return std::move(literal);
-}
-
 std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
                                               Client* client) {
   if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) {
@@ -26,10 +26,6 @@ limitations under the License.
 
 namespace xla {
 
-// Generates fake data in a literal of the given shape, or returns an error
-// status if the element type is currently unhandled for fake data generation.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
-
 // Generates fake data of the given shape on the device or dies. The fake data
 // is created by performing a computation on the device rather than transferring
 // data from the host to the device.
@@ -569,9 +569,17 @@ int64 Literal::LinearIndex(
   return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
 }
 
-string Literal::ToString() const {
+string Literal::ToString(bool print_layout) const {
   std::vector<string> pieces;
 
+  auto shape_to_string = [print_layout](const Shape& shape) {
+    if (print_layout) {
+      return ShapeUtil::HumanStringWithLayout(shape);
+    } else {
+      return ShapeUtil::HumanString(shape);
+    }
+  };
+
   auto element_to_string =
       [this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
     PrimitiveType element_type = shape().element_type();
@@ -585,7 +593,7 @@ string Literal::ToString() const {
 
   // TODO(b/32894291): refactor this code to reduce code duplication.
   if (ShapeUtil::IsTuple(shape())) {
-    pieces.push_back(ShapeUtil::HumanString(shape()));
+    pieces.push_back(shape_to_string(shape()));
     pieces.push_back(" (\n");
     pieces.push_back(tensorflow::str_util::Join(
         tuple_literals(), ",\n", [](string* out, const Literal& element) {
@@ -601,7 +609,7 @@ string Literal::ToString() const {
     }
     pieces.push_back("}");
   } else if (ShapeUtil::Rank(shape()) == 2) {
-    pieces.push_back(ShapeUtil::HumanString(shape()));
+    pieces.push_back(shape_to_string(shape()));
     pieces.push_back(" {\n");
     for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
       pieces.push_back("  { ");
@@ -613,7 +621,7 @@ string Literal::ToString() const {
     }
     pieces.push_back("}");
   } else if (ShapeUtil::Rank(shape()) == 3) {
-    pieces.push_back(ShapeUtil::HumanString(shape()));
+    pieces.push_back(shape_to_string(shape()));
     pieces.push_back(" {\n");
     for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
       pieces.push_back(i0 > 0 ? ",\n{" : "{");
@@ -628,7 +636,7 @@ string Literal::ToString() const {
     }
     pieces.push_back("\n}");
   } else if (ShapeUtil::Rank(shape()) == 4) {
-    pieces.push_back(ShapeUtil::HumanString(shape()));
+    pieces.push_back(shape_to_string(shape()));
     pieces.push_back(" {\n");
     for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
       pieces.push_back(tensorflow::strings::Printf("  { /*i0=%lld*/\n", i0));
@@ -649,7 +657,7 @@ string Literal::ToString() const {
     }
     pieces.push_back("}");
   } else if (ShapeUtil::Rank(shape()) == 5) {
-    pieces.push_back(ShapeUtil::HumanString(shape()));
+    pieces.push_back(shape_to_string(shape()));
     pieces.push_back(" {\n");
     for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
      pieces.push_back(tensorflow::strings::Printf("  { /*i0=%lld*/\n", i0));
@@ -676,7 +684,7 @@ string Literal::ToString() const {
     }
     pieces.push_back("}");
   } else {
-    pieces.push_back(ShapeUtil::HumanString(shape()));
+    pieces.push_back(shape_to_string(shape()));
     pieces.push_back(" {...}");
   }
 
@@ -450,7 +450,7 @@ class Literal {
   tensorflow::Status ValidateLiteral() const;
 
   // Returns a string representation of the literal value.
-  string ToString() const;
+  string ToString(bool print_layout = false) const;
 
   // Invokes the "per cell" callback for each element in the provided
   // literal with the element's indices and a string representation of
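From a caller's perspective the new defaulted parameter is used as in the sketch below; the exact strings come from ShapeUtil::HumanString and ShapeUtil::HumanStringWithLayout, so the renderings in the comments are assumptions rather than part of this diff:

    auto literal = Literal::CreateR2WithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}));
    string plain = literal->ToString();          // shape via HumanString, e.g. "f32[2,2] {...}"
    string with_layout =
        literal->ToString(/*print_layout=*/true);  // e.g. "f32[2,2]{0,1} {...}"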
@@ -1780,7 +1780,6 @@ tf_cc_test(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:test_utils",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
     ],
 )
@@ -1851,7 +1850,6 @@ tf_cc_test(
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:test_utils",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
     ],
 )
@@ -79,12 +79,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
   // Test that two identical constants with different layouts are commoned if
   // the pass is not layout sensitive.
   auto builder = HloComputation::Builder(TestName());
-  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   /*minor_to_major=*/{0, 1})));
-  auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   /*minor_to_major=*/{1, 0})));
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
       constant1->shape(), HloOpcode::kAdd, constant1, constant2));
 
@@ -111,12 +111,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
   // Test that two identical constants with different layouts are *not* commoned
   // if the pass is layout sensitive.
   auto builder = HloComputation::Builder(TestName());
-  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   /*minor_to_major=*/{0, 1})));
-  auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   /*minor_to_major=*/{1, 0})));
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
       constant1->shape(), HloOpcode::kAdd, constant1, constant2));
 
@@ -131,10 +131,10 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
   std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
   for (auto& minor_to_major : minor_to_majors) {
     auto builder = HloComputation::Builder(TestName());
-    auto constant_literal1 = test_utils::CreateR2LiteralWithLayout<float>(
-        {{1.0, 2.0}, {3.0, 4.0}}, minor_to_major);
-    auto constant_literal2 = test_utils::CreateR2LiteralWithLayout<float>(
-        {{5.0, 6.0}, {7.0, 8.0}}, minor_to_major);
+    auto constant_literal1 = Literal::CreateR2WithLayout<float>(
+        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
+    auto constant_literal2 = Literal::CreateR2WithLayout<float>(
+        {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
     Shape ashape = constant_literal1->shape();
 
     auto constant1 = builder.AddInstruction(
@@ -181,12 +181,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
   // Verify the layouts of a tuple are assigned properly (the element layouts
   // match their source).
   auto builder = HloComputation::Builder(TestName());
-  auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   {0, 1})));
-  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   {1, 0})));
+  auto constant0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
   auto tuple = builder.AddInstruction(
       HloInstruction::CreateTuple({constant0, constant1}));
 
@@ -218,12 +218,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
 TEST_F(LayoutAssignmentTest, TupleSelect) {
   // Verify layouts of a select with tuple operands is assigned properly.
   auto builder = HloComputation::Builder(TestName());
-  auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   {0, 1})));
-  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   {1, 0})));
+  auto constant0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
   auto tuple0 = builder.AddInstruction(
       HloInstruction::CreateTuple({constant0, constant1}));
   auto tuple1 = builder.AddInstruction(
@@ -61,13 +61,14 @@ generate_backend_test_macros()
 
 cc_library(
     name = "test_utils",
-    testonly = True,
+    srcs = ["test_utils.cc"],
     hdrs = ["test_utils.h"],
     deps = [
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/core:lib",
     ],
 )
@@ -469,8 +469,7 @@ template <typename NativeT>
 std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
     const int width, NativeT min_value, NativeT max_value, uint32 seed) {
   std::vector<NativeT> result(width);
-  test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
-                                                       seed);
+  PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
   for (int i = 0; i < width; ++i) {
     result[i] = generator.get();
   }
@@ -482,8 +481,7 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
     const int rows, const int cols, NativeT min_value, NativeT max_value,
     uint32 seed) {
   auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
-  test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
-                                                       seed);
+  PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
   for (int y = 0; y < rows; ++y) {
     for (int x = 0; x < cols; ++x) {
       (*result)(y, x) = generator.get();
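For reference, PseudorandomGenerator itself is untouched by this commit; only the test_utils:: qualifier disappears along with that namespace. A minimal usage sketch, with illustrative bounds and seed:

    PseudorandomGenerator<float> generator(/*min_value=*/0.0f,
                                           /*max_value=*/1.0f, /*seed=*/42);
    float value = generator.get();  // pseudorandom value in [min_value, max_value]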
@@ -54,8 +54,8 @@ TEST_F(ClientTest, ExecuteWithLayout) {
           .ConsumeValueOrDie();
 
       std::unique_ptr<Literal> expected_literal =
-          test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
-                                                       transfer_layout);
+          Literal::CreateR2WithLayout<int32>(
+              {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
 
       auto computed = client_->Transfer(*data, &expected_literal->shape());
 
@@ -138,13 +138,13 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) {
   // layouts. Use these arrays as parameters to a simple computation. If the
   // layout of the array changes then computation should be recompiled (cache
   // miss).
-  auto rowmaj_array = test_utils::CreateR2LiteralWithLayout(
-      {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0});
+  auto rowmaj_array = Literal::CreateR2WithLayout(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
   auto rowmaj_handle =
       client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
 
-  auto colmaj_array = test_utils::CreateR2LiteralWithLayout(
-      {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1});
+  auto colmaj_array = Literal::CreateR2WithLayout(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
   auto colmaj_handle =
       client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
 
@@ -264,8 +264,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
     ASSERT_TRUE(computed.ok()) << computed.status();
 
     std::unique_ptr<Literal> expected_literal =
-        test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
-                                                     layout);
+        Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
+                                           LayoutUtil::MakeLayout(layout));
     LiteralTestUtil::AssertEqualShapesAndLayouts(
         expected_literal->shape(), computed.ValueOrDie()->shape());
     LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
@@ -177,15 +177,15 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
                                            bool rhs_row_major) {
   auto lhs_handle =
       client_
-          ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
-              {{1.0, 2.0}, {3.0, -4.0}},
-              MinorToMajorForIsRowMajor(lhs_row_major)))
+          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
+              {{1.0, 2.0}, {3.0, -4.0}},
+              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
           .ConsumeValueOrDie();
   auto rhs_handle =
       client_
-          ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
-              {{1.0, 6.0}, {7.0, -4.0}},
-              MinorToMajorForIsRowMajor(rhs_row_major)))
+          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
+              {{1.0, 6.0}, {7.0, -4.0}},
+              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
           .ConsumeValueOrDie();
 
   ComputationBuilder builder(client_, TestName());
@@ -362,15 +362,15 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
                                               bool rhs_row_major) {
   auto lhs_handle =
       client_
-          ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
-              {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
-              MinorToMajorForIsRowMajor(lhs_row_major)))
+          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
+              {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
+              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
           .ConsumeValueOrDie();
   auto rhs_handle =
       client_
-          ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
-              {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
-              MinorToMajorForIsRowMajor(rhs_row_major)))
+          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
+              {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
+              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
           .ConsumeValueOrDie();
 
   ComputationBuilder builder(client_, TestName());
|
|||||||
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
|
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
|
||||||
auto lhs_handle =
|
auto lhs_handle =
|
||||||
client_
|
client_
|
||||||
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>(
|
->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
|
||||||
{{1.0, 2.0, 3.0, -4.0}}, {1, 0}))
|
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
|
||||||
.ConsumeValueOrDie();
|
.ConsumeValueOrDie();
|
||||||
auto rhs_handle =
|
auto rhs_handle =
|
||||||
client_
|
client_
|
||||||
->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>(
|
->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
|
||||||
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, {1, 0}))
|
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
|
||||||
|
LayoutUtil::MakeLayout({1, 0})))
|
||||||
.ConsumeValueOrDie();
|
.ConsumeValueOrDie();
|
||||||
|
|
||||||
ComputationBuilder builder(client_, TestName());
|
ComputationBuilder builder(client_, TestName());
|
||||||
|
@@ -136,16 +136,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
   auto computation = builder.Build().ConsumeValueOrDie();
 
   // Create x as a col-major array.
-  auto x_array = LiteralToShapedBuffer(
-      *test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                             /*minor_to_major=*/{0, 1}));
+  auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
   EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(),
                                 LayoutUtil::MakeLayout({0, 1})));
 
   // Create y as a row-major array.
-  auto y_array = LiteralToShapedBuffer(
-      *test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}},
-                                             /*minor_to_major=*/{1, 0}));
+  auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
+      {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
   EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(),
                                 LayoutUtil::MakeLayout({1, 0})));
 
@@ -405,13 +405,13 @@ TEST_F(MapTest, MapBinaryAdder) {
 // for Map that used to fail in shape inference (b/28989438).
 XLA_TEST_F(MapTest, AddWithMixedLayouts) {
   ComputationBuilder builder(client_, TestName());
-  std::unique_ptr<Literal> param0_literal =
-      test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0});
+  std::unique_ptr<Literal> param0_literal = Literal::CreateR2WithLayout(
+      {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
   std::unique_ptr<GlobalData> param0_data =
       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> param1_literal =
-      test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1});
+  std::unique_ptr<Literal> param1_literal = Literal::CreateR2WithLayout(
+      {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
   std::unique_ptr<GlobalData> param1_data =
       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
 
tensorflow/compiler/xla/tests/test_utils.cc (new file, 120 lines)
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+
+#include "tensorflow/compiler/xla/primitive_util.h"
+
+namespace xla {
+
+namespace {
+
+template <typename FloatT>
+void PopulateWithRandomFloatingPointData(Literal* literal) {
+  CHECK_EQ(literal->shape().element_type(),
+           primitive_util::NativeToPrimitiveType<FloatT>());
+  std::minstd_rand0 engine;
+  std::uniform_real_distribution<FloatT> generator(0.0f, 1.0f);
+  TF_CHECK_OK(literal->Populate<FloatT>(
+      [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+        return generator(engine);
+      }));
+}
+
+template <typename IntT>
+void PopulateWithRandomIntegralData(Literal* literal) {
+  CHECK_EQ(literal->shape().element_type(),
+           primitive_util::NativeToPrimitiveType<IntT>());
+  std::minstd_rand0 engine;
+  std::uniform_int_distribution<IntT> generator(
+      std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
+  TF_CHECK_OK(literal->Populate<IntT>(
+      [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+        return generator(engine);
+      }));
+}
+
+}  // namespace
+
+StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
+  if (ShapeUtil::IsTuple(shape)) {
+    std::vector<std::unique_ptr<Literal>> elements;
+    for (const Shape& element_shape : shape.tuple_shapes()) {
+      TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
+                          MakeFakeLiteral(element_shape));
+      elements.push_back(std::move(element));
+    }
+    return Literal::MakeTupleOwned(std::move(elements));
+  }
+  std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
+  switch (shape.element_type()) {
+    case F32:
+      PopulateWithRandomFloatingPointData<float>(literal.get());
+      break;
+    case F64:
+      PopulateWithRandomFloatingPointData<double>(literal.get());
+      break;
+    case S8:
+      PopulateWithRandomIntegralData<int8>(literal.get());
+      break;
+    case U8:
+      PopulateWithRandomIntegralData<uint8>(literal.get());
+      break;
+    case S16:
+      PopulateWithRandomIntegralData<int16>(literal.get());
+      break;
+    case U16:
+      PopulateWithRandomIntegralData<uint16>(literal.get());
+      break;
+    case S32:
+      PopulateWithRandomIntegralData<int32>(literal.get());
+      break;
+    case U32:
+      PopulateWithRandomIntegralData<uint32>(literal.get());
+      break;
+    case S64:
+      PopulateWithRandomIntegralData<int64>(literal.get());
+      break;
+    case U64:
+      PopulateWithRandomIntegralData<uint64>(literal.get());
+      break;
+    case PRED: {
+      std::uniform_int_distribution<int> generator(0, 1);
+      std::minstd_rand0 engine;
+      TF_CHECK_OK(literal->Populate<bool>(
+          [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+            return generator(engine);
+          }));
+      break;
+    }
+    default:
+      return Unimplemented("Unsupported type for fake literal generation: %s",
+                           ShapeUtil::HumanString(shape).c_str());
+  }
+  return std::move(literal);
+}
+
+StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
+    const HloModule& module) {
+  std::vector<std::unique_ptr<Literal>> arguments;
+  for (const ShapeLayout& shape_layout :
+       module.config().entry_computation_layout().parameter_layouts()) {
+    TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape()));
+    arguments.push_back(std::move(literal));
+  }
+  return std::move(arguments);
+}
+
+}  // namespace xla
@@ -23,12 +23,12 @@ limitations under the License.
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
-namespace test_utils {
 
 // A class which generates pseudorandom numbers of a given type within a given
 // range. Not cryptographically secure and likely not perfectly evenly
@@ -53,63 +53,15 @@ class PseudorandomGenerator {
   std::mt19937 generator_;
 };
 
-// Convenience function for creating a rank-2 array with arbitrary layout.
-template <typename NativeT>
-std::unique_ptr<Literal> CreateR2LiteralWithLayout(
-    std::initializer_list<std::initializer_list<NativeT>> values,
-    tensorflow::gtl::ArraySlice<int64> minor_to_major) {
-  auto literal = MakeUnique<Literal>();
-  const int64 d0 = values.size();
-  const int64 d1 = values.begin()->size();
-  literal.get()->PopulateWithValue<NativeT>(0, {d0, d1});
-  *literal->mutable_shape()->mutable_layout() =
-      LayoutUtil::MakeLayout(minor_to_major);
-  TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
-
-  int64 dim0 = 0;
-  for (auto inner_list : values) {
-    int64 dim1 = 0;
-    for (auto value : inner_list) {
-      literal.get()->Set({dim0, dim1}, value);
-      ++dim1;
-    }
-    ++dim0;
-  }
-  return literal;
-}
-
-// Convenience function for creating a rank-3 array with arbitrary layout.
-template <typename NativeT>
-std::unique_ptr<Literal> CreateR3LiteralWithLayout(
-    std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
-        values,
-    tensorflow::gtl::ArraySlice<int64> minor_to_major) {
-  auto literal = MakeUnique<Literal>();
-  const int64 d0 = values.size();
-  const int64 d1 = values.begin()->size();
-  const int64 d2 = values.begin()->begin()->size();
-  literal.get()->PopulateWithValue<NativeT>(0, {d0, d1, d2});
-  *literal->mutable_shape()->mutable_layout() =
-      LayoutUtil::MakeLayout(minor_to_major);
-  TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
-
-  int64 dim0 = 0;
-  for (auto inner_list : values) {
-    int64 dim1 = 0;
-    for (auto inner_inner_list : inner_list) {
-      int64 dim2 = 0;
-      for (auto value : inner_inner_list) {
-        literal.get()->Set({dim0, dim1, dim2}, value);
-        ++dim2;
-      }
-      ++dim1;
-    }
-    ++dim0;
-  }
-  return literal;
-}
-
-}  // namespace test_utils
+// Generates fake data in a literal of the given shape, or returns an error
+// status if the element type is currently unhandled for fake data generation.
+StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
+
+// Generates a vector of arguments containing fake data. The number, shape and
+// layout of the arguments is appropriate for given HLO module.
+StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
+    const HloModule& module);
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
@@ -88,6 +88,7 @@ cc_library(
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client/lib:testing",
         "//tensorflow/compiler/xla/service:session_proto",
+        "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
     ],
@@ -45,6 +45,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/threadpool.h"