From 943792dcb01e55e21546f77936f4ff12e77eba19 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 18 Jul 2019 11:47:16 -0700 Subject: [PATCH] [XLA] Improve error message in HLO matchers. Previously, mismatches in opcode or number of operands wasn't very informative because the error message didn't print out the HloInstruction string. For example, an opcode mismatch might have looked like previously: Value of: body_data_add Expected: subtract Actual: 0x7f58fe3e8c00 (of type xla::HloInstruction*) With this CL, it now looks like: Value of: body_data_add Expected: subtract Actual: 0x7efefd68ec00 (of type xla::HloInstruction*), (%add.1 = f32[2,3]{1,0:S(1)} add(f32[2,3]{1,0} %get-tuple-element.2, f32[2,3]{1,0} %constant.2)) PiperOrigin-RevId: 258814523 --- .../compiler/xla/service/hlo_matchers.cc | 29 ++++---- .../compiler/xla/service/hlo_matchers_test.cc | 67 ++++++++++++------- 2 files changed, 58 insertions(+), 38 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 920838fc981..0cdfbf05ea1 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -26,7 +26,11 @@ bool HloMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { // These cases are self-explanatory from the printed value. - if (!instruction || instruction->opcode() != opcode_) { + if (!instruction) { + return false; + } + *listener << "(" << instruction->ToString() << ")"; + if (instruction->opcode() != opcode_) { return false; } // Special case: no operand matchers means don't verify. @@ -35,7 +39,7 @@ bool HloMatcher::MatchAndExplain( } const auto& operands = instruction->operands(); if (operands.size() != operands_.size()) { - *listener << "has too " + *listener << " has too " << (operands.size() > operands_.size() ? "many" : "few") << " operands (got " << operands.size() << ", want " << operands_.size() << ")"; @@ -81,7 +85,7 @@ bool HloParameterMatcher::MatchAndExplain( return false; } if (instruction->parameter_number() != parameter_number_) { - *listener << "has wrong parameter number (got " + *listener << " has wrong parameter number (got " << instruction->parameter_number() << ", want " << parameter_number_ << ")"; return false; @@ -96,7 +100,7 @@ bool HloComparisonMatcher::MatchAndExplain( return false; } if (instruction->comparison_direction() != direction_) { - *listener << "has wrong comparison direction (got " + *listener << " has wrong comparison direction (got " << ComparisonDirectionToString( instruction->comparison_direction()) << ", want " << ComparisonDirectionToString(direction_) << ")"; @@ -112,7 +116,7 @@ bool HloGetTupleElementMatcher::MatchAndExplain( return false; } if (instruction->tuple_index() != tuple_index_) { - *listener << "has wrong tuple index (got " << instruction->tuple_index() + *listener << " has wrong tuple index (got " << instruction->tuple_index() << ", want " << tuple_index_ << ")"; return false; } @@ -145,7 +149,7 @@ bool HloCustomCallMatcher::MatchAndExplain( } sub_listener << desc_stream.str(); } - *listener << "custom-call with call target" << sub_listener.str(); + *listener << " custom-call with call target" << sub_listener.str(); return result; } @@ -222,8 +226,7 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers(); if (dim_nums.lhs_contracting_dimensions_size() != 1 || dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { - *listener << instruction->ToString() - << " has wrong lhs_contracting_dimensions (got {" + *listener << " has wrong lhs_contracting_dimensions (got {" << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",") << "} want {" << lhs_contracting_dim_ << "})"; return false; @@ -231,8 +234,7 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( if (dim_nums.rhs_contracting_dimensions_size() != 1 || dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { - *listener << instruction->ToString() - << " has wrong rhs_contracting_dimensions (got {" + *listener << " has wrong rhs_contracting_dimensions (got {" << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",") << "} want {" << rhs_contracting_dim_ << "})"; return false; @@ -256,13 +258,12 @@ bool HloAsyncCopyMatcher::MatchAndExplain( const HloInstruction* copy_done = instruction; if (!copy_done->shape().has_layout()) { - *listener << copy_done->ToString() - << " does not have layout, expected a layout with memory space " + *listener << " does not have layout, expected a layout with memory space " << to_space_; return false; } if (copy_done->shape().layout().memory_space() != to_space_) { - *listener << copy_done->ToString() << " copies to memory space " + *listener << " copies to memory space " << copy_done->shape().layout().memory_space() << ", expected " << to_space_; return false; @@ -277,7 +278,7 @@ bool HloAsyncCopyMatcher::MatchAndExplain( return false; } if (copy_start_operand->shape().layout().memory_space() != from_space_) { - *listener << copy_done->ToString() << " is in the memory space " + *listener << " is in the memory space " << copy_start_operand->shape().layout().memory_space() << ", expected " << from_space_; return false; diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index e6dd45d2990..0c6c632f5c8 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -53,26 +53,36 @@ TEST(HloMatchersTest, Test) { op::Add(op::Parameter(), op::Multiply(_, op::Parameter()))); // Negative matches: check the explanation string. - EXPECT_THAT(Explain(add.get(), op::Parameter()), Eq("")); - EXPECT_THAT(Explain(add.get(), op::Add(op::Parameter())), - Eq("has too many operands (got 2, want 1)")); + EXPECT_THAT( + Explain(add.get(), op::Parameter()), + Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply))")); + EXPECT_THAT( + Explain(add.get(), op::Add(op::Parameter())), + Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply)) " + "has too many operands (got 2, want 1)")); EXPECT_THAT( Explain(add.get(), op::Add(op::Parameter(), op::Parameter())), - Eq("\noperand 1:\n\t" + Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply))" + "\noperand 1:\n\t" "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n" "doesn't match expected:\n\t" - "parameter")); + "parameter" + ", (%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} " + "%param))")); EXPECT_THAT( Explain(add.get(), op::Add(op::Parameter(), op::Multiply(op::Add(), op::Add()))), - Eq("\noperand 1:\n\t" + Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply))" + "\noperand 1:\n\t" "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n" "doesn't match expected:\n\t" - "multiply(add, add), \n" + "multiply(add, add)" + ", (%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} " + "%param))\n" "operand 0:\n\t" "%param = f32[1]{0} parameter(0)\n" "doesn't match expected:\n\t" - "add")); + "add, (%param = f32[1]{0} parameter(0))")); } TEST(HloMatchersTest, CustomCallMatcher) { @@ -99,7 +109,9 @@ TEST(HloMatchersTest, CustomCallMatcher) { ::testing::Not(op::CustomCall(::testing::StartsWith("bar")))); EXPECT_THAT(Explain(call.get(), op::CustomCall("bar")), - R"(custom-call with call target that isn't equal to "bar")"); + "(%custom-call = f32[1]{0} custom-call(f32[3]{0} %constant, " + "s32[3]{0} %constant), custom_call_target=\"foo_target\") " + "custom-call with call target that isn't equal to \"bar\""); EXPECT_THAT(DescribeHloMatcher(op::CustomCall("foo_target")), R"(custom-call with call target that is equal to "foo_target")"); } @@ -207,16 +219,16 @@ ENTRY DotOperationFusion_TransposeFusion { Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0)), - "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " - "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "(%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}) has wrong " "lhs_contracting_dimensions (got {1} want {0})"); EXPECT_THAT( Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)), - "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " - "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "(%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}) has wrong " "rhs_contracting_dimensions (got {0} want {1})"); } @@ -243,9 +255,13 @@ TEST(HloMatchersTest, ComparisonMatcher) { EXPECT_THAT(le.get(), op::Le(op::Parameter(0), op::Add(op::Parameter(0), op::Parameter(1)))); - EXPECT_THAT(Explain(eq.get(), op::Add()), Eq("")); + EXPECT_THAT(Explain(eq.get(), op::Add()), + Eq("(%compare = f32[1]{0} compare(f32[1]{0} %param.0, " + "f32[1]{0} %param.1), direction=EQ)")); EXPECT_THAT(Explain(eq.get(), op::Ne()), - Eq("has wrong comparison direction (got EQ, want NE)")); + Eq("(%compare = f32[1]{0} compare(f32[1]{0} %param.0, " + "f32[1]{0} %param.1), direction=EQ) " + "has wrong comparison direction (got EQ, want NE)")); } TEST(HloMatchersTest, AsyncCopyMatcher) { @@ -267,15 +283,18 @@ TEST(HloMatchersTest, AsyncCopyMatcher) { EXPECT_THAT(copy_done.get(), op::AsyncCopy(2, 1, op::Parameter(0))); EXPECT_THAT(Explain(copy_start.get(), op::AsyncCopy(2, 1, op::Parameter(0))), - Eq("")); - EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))), - "%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) " - "%copy-start) " - "copies to memory space 2, expected 3"); - EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))), - "%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) " - "%copy-start) " - "is in the memory space 1, expected 3"); + Eq("(%copy-start = (f32[16]{0:S(2)}, u32[]) " + "copy-start(f32[16]{0:S(1)} %p0))")); + EXPECT_THAT( + Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))), + "(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) " + "%copy-start)) " + "copies to memory space 2, expected 3"); + EXPECT_THAT( + Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))), + "(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) " + "%copy-start)) " + "is in the memory space 1, expected 3"); } } // namespace