Use XLA_VLOG_LINES() in literal_test_util to avoid truncation of large tensors.

PiperOrigin-RevId: 163745522
This commit is contained in:
A. Unique TensorFlower 2017-07-31 14:08:59 -07:00 committed by TensorFlower Gardener
parent 043505a094
commit efb7fb8e58
5 changed files with 39 additions and 13 deletions

View File

@ -627,6 +627,11 @@ class HloInstruction {
return fusion_kind_; return fusion_kind_;
} }
void set_fusion_kind(FusionKind kind) {
CHECK_EQ(HloOpcode::kFusion, opcode_);
fusion_kind_ = kind;
}
// Merges the fused instructions from 'instruction_to_merge' into the // Merges the fused instructions from 'instruction_to_merge' into the
// fused instruction set of 'this', updating operands as necessary. // fused instruction set of 'this', updating operands as necessary.
// //

View File

@ -394,12 +394,15 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
VLOG(2) << "Fusing " << producer << " into " << consumer; VLOG(2) << "Fusing " << producer << " into " << consumer;
auto kind = ChooseKind(producer, consumer);
if (consumer->opcode() == HloOpcode::kFusion) { if (consumer->opcode() == HloOpcode::kFusion) {
fusion_instruction = consumer; fusion_instruction = consumer;
if (kind != fusion_instruction->fusion_kind()) {
fusion_instruction->set_fusion_kind(kind);
}
} else { } else {
fusion_instruction = fusion_instruction = computation_->AddInstruction(
computation_->AddInstruction(HloInstruction::CreateFusion( HloInstruction::CreateFusion(consumer->shape(), kind, consumer));
consumer->shape(), ChooseKind(producer, consumer), consumer));
TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction)); TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction));
} }
fusion_instruction->FuseInstruction(producer); fusion_instruction->FuseInstruction(producer);

View File

@ -170,8 +170,10 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
/* static */ ::testing::AssertionResult LiteralTestUtil::Equal( /* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
const Literal& expected, const Literal& actual) { const Literal& expected, const Literal& actual) {
VLOG(1) << "expected: " << expected.ToString(); VLOG(1) << "expected:";
VLOG(1) << "actual: " << actual.ToString(); XLA_VLOG_LINES(1, expected.ToString());
VLOG(1) << "actual:";
XLA_VLOG_LINES(1, actual.ToString());
AssertEqualShapes(expected.shape(), actual.shape()); AssertEqualShapes(expected.shape(), actual.shape());
std::vector<int64> multi_index(expected.shape().dimensions_size(), 0); std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
@ -256,8 +258,10 @@ class NearComparator {
// within the error bound. Emits useful log messages and dumps literals to // within the error bound. Emits useful log messages and dumps literals to
// temporary files on failure. Returns true if literals match. // temporary files on failure. Returns true if literals match.
bool ExpectNear(const Literal& expected, const Literal& actual) { bool ExpectNear(const Literal& expected, const Literal& actual) {
VLOG(1) << "expected: " << expected.ToString(); VLOG(1) << "expected:";
VLOG(1) << "actual: " << actual.ToString(); XLA_VLOG_LINES(1, expected.ToString());
VLOG(1) << "actual:";
XLA_VLOG_LINES(1, actual.ToString());
LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape()); LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape());

View File

@ -159,11 +159,12 @@ TEST_F(MatOpsSimpleTest, Max64x8Linspace) { TestLinspaceMax(64, 8); }
class MatOpsDotAddTest class MatOpsDotAddTest
: public ClientLibraryTestBase, : public ClientLibraryTestBase,
public ::testing::WithParamInterface<std::tuple<bool, bool>> {}; public ::testing::WithParamInterface<std::tuple<bool, bool, bool>> {};
TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) { TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) {
bool row_major = std::get<0>(GetParam()); bool row_major = std::get<0>(GetParam());
bool add_lhs = std::get<1>(GetParam()); bool add_lhs = std::get<1>(GetParam());
bool transpose = std::get<2>(GetParam());
Array2D<float> lhs({{1.0, 2.0}, {3.0, 4.0}}); Array2D<float> lhs({{1.0, 2.0}, {3.0, 4.0}});
Array2D<float> rhs({{10.0, 11.0}, {12.0, 13.0}}); Array2D<float> rhs({{10.0, 11.0}, {12.0, 13.0}});
@ -188,15 +189,27 @@ TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs");
auto lhs_mat_arg = lhs_arg;
if (transpose) {
lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0});
}
auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs");
auto result = builder.Dot(lhs_arg, rhs_arg); auto result = builder.Dot(lhs_mat_arg, rhs_arg);
Array2D<float> expected; Array2D<float> expected;
if (add_lhs) { if (add_lhs) {
result = builder.Add(result, lhs_arg); result = builder.Add(result, lhs_arg);
expected = Array2D<float>({{35, 39}, {81, 89}}); if (transpose) {
expected = Array2D<float>({{47, 52}, {71, 78}});
} else {
expected = Array2D<float>({{35, 39}, {81, 89}});
}
} else { } else {
result = builder.Add(result, rhs_arg); result = builder.Add(result, rhs_arg);
expected = Array2D<float>({{44, 48}, {90, 98}}); if (transpose) {
expected = Array2D<float>({{56, 61}, {80, 87}});
} else {
expected = Array2D<float>({{44, 48}, {90, 98}});
}
} }
ComputeAndCompareR2<float>(&builder, expected, ComputeAndCompareR2<float>(&builder, expected,
@ -205,7 +218,7 @@ TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) {
} }
INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest,
::testing::Combine(::testing::Bool(), ::testing::Combine(::testing::Bool(), ::testing::Bool(),
::testing::Bool())); ::testing::Bool()));
} // namespace } // namespace

View File

@ -748,7 +748,8 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
const auto& bounds = GetParam().bounds; const auto& bounds = GetParam().bounds;
Array3D<float> input_array(bounds[0], bounds[1], bounds[2]); Array3D<float> input_array(bounds[0], bounds[1], bounds[2]);
input_array.FillRandom(3.14f, 0.05); // input_array.FillRandom(3.14f, 0.05);
input_array.Fill(1.0f);
auto input_literal = Literal::CreateR3FromArray3D(input_array); auto input_literal = Literal::CreateR3FromArray3D(input_array);
input_literal = input_literal =