diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 920838fc981..0cdfbf05ea1 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -26,7 +26,11 @@ bool HloMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { // These cases are self-explanatory from the printed value. - if (!instruction || instruction->opcode() != opcode_) { + if (!instruction) { + return false; + } + *listener << "(" << instruction->ToString() << ")"; + if (instruction->opcode() != opcode_) { return false; } // Special case: no operand matchers means don't verify. @@ -35,7 +39,7 @@ bool HloMatcher::MatchAndExplain( } const auto& operands = instruction->operands(); if (operands.size() != operands_.size()) { - *listener << "has too " + *listener << " has too " << (operands.size() > operands_.size() ? "many" : "few") << " operands (got " << operands.size() << ", want " << operands_.size() << ")"; @@ -81,7 +85,7 @@ bool HloParameterMatcher::MatchAndExplain( return false; } if (instruction->parameter_number() != parameter_number_) { - *listener << "has wrong parameter number (got " + *listener << " has wrong parameter number (got " << instruction->parameter_number() << ", want " << parameter_number_ << ")"; return false; @@ -96,7 +100,7 @@ bool HloComparisonMatcher::MatchAndExplain( return false; } if (instruction->comparison_direction() != direction_) { - *listener << "has wrong comparison direction (got " + *listener << " has wrong comparison direction (got " << ComparisonDirectionToString( instruction->comparison_direction()) << ", want " << ComparisonDirectionToString(direction_) << ")"; @@ -112,7 +116,7 @@ bool HloGetTupleElementMatcher::MatchAndExplain( return false; } if (instruction->tuple_index() != tuple_index_) { - *listener << "has wrong tuple index (got " << instruction->tuple_index() + *listener << " has wrong tuple index (got " << instruction->tuple_index() << ", want " << tuple_index_ << ")"; return false; } @@ -145,7 +149,7 @@ bool HloCustomCallMatcher::MatchAndExplain( } sub_listener << desc_stream.str(); } - *listener << "custom-call with call target" << sub_listener.str(); + *listener << " custom-call with call target" << sub_listener.str(); return result; } @@ -222,8 +226,7 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers(); if (dim_nums.lhs_contracting_dimensions_size() != 1 || dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { - *listener << instruction->ToString() - << " has wrong lhs_contracting_dimensions (got {" + *listener << " has wrong lhs_contracting_dimensions (got {" << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",") << "} want {" << lhs_contracting_dim_ << "})"; return false; @@ -231,8 +234,7 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( if (dim_nums.rhs_contracting_dimensions_size() != 1 || dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { - *listener << instruction->ToString() - << " has wrong rhs_contracting_dimensions (got {" + *listener << " has wrong rhs_contracting_dimensions (got {" << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",") << "} want {" << rhs_contracting_dim_ << "})"; return false; @@ -256,13 +258,12 @@ bool HloAsyncCopyMatcher::MatchAndExplain( const HloInstruction* copy_done = instruction; if (!copy_done->shape().has_layout()) { - *listener << copy_done->ToString() - << " does not have layout, expected a layout with memory space " + *listener << " does not have layout, expected a layout with memory space " << to_space_; return false; } if (copy_done->shape().layout().memory_space() != to_space_) { - *listener << copy_done->ToString() << " copies to memory space " + *listener << " copies to memory space " << copy_done->shape().layout().memory_space() << ", expected " << to_space_; return false; @@ -277,7 +278,7 @@ bool HloAsyncCopyMatcher::MatchAndExplain( return false; } if (copy_start_operand->shape().layout().memory_space() != from_space_) { - *listener << copy_done->ToString() << " is in the memory space " + *listener << " is in the memory space " << copy_start_operand->shape().layout().memory_space() << ", expected " << from_space_; return false; diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index e6dd45d2990..0c6c632f5c8 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -53,26 +53,36 @@ TEST(HloMatchersTest, Test) { op::Add(op::Parameter(), op::Multiply(_, op::Parameter()))); // Negative matches: check the explanation string. - EXPECT_THAT(Explain(add.get(), op::Parameter()), Eq("")); - EXPECT_THAT(Explain(add.get(), op::Add(op::Parameter())), - Eq("has too many operands (got 2, want 1)")); + EXPECT_THAT( + Explain(add.get(), op::Parameter()), + Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply))")); + EXPECT_THAT( + Explain(add.get(), op::Add(op::Parameter())), + Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply)) " + "has too many operands (got 2, want 1)")); EXPECT_THAT( Explain(add.get(), op::Add(op::Parameter(), op::Parameter())), - Eq("\noperand 1:\n\t" + Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply))" + "\noperand 1:\n\t" "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n" "doesn't match expected:\n\t" - "parameter")); + "parameter" + ", (%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} " + "%param))")); EXPECT_THAT( Explain(add.get(), op::Add(op::Parameter(), op::Multiply(op::Add(), op::Add()))), - Eq("\noperand 1:\n\t" + Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply))" + "\noperand 1:\n\t" "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n" "doesn't match expected:\n\t" - "multiply(add, add), \n" + "multiply(add, add)" + ", (%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} " + "%param))\n" "operand 0:\n\t" "%param = f32[1]{0} parameter(0)\n" "doesn't match expected:\n\t" - "add")); + "add, (%param = f32[1]{0} parameter(0))")); } TEST(HloMatchersTest, CustomCallMatcher) { @@ -99,7 +109,9 @@ TEST(HloMatchersTest, CustomCallMatcher) { ::testing::Not(op::CustomCall(::testing::StartsWith("bar")))); EXPECT_THAT(Explain(call.get(), op::CustomCall("bar")), - R"(custom-call with call target that isn't equal to "bar")"); + "(%custom-call = f32[1]{0} custom-call(f32[3]{0} %constant, " + "s32[3]{0} %constant), custom_call_target=\"foo_target\") " + "custom-call with call target that isn't equal to \"bar\""); EXPECT_THAT(DescribeHloMatcher(op::CustomCall("foo_target")), R"(custom-call with call target that is equal to "foo_target")"); } @@ -207,16 +219,16 @@ ENTRY DotOperationFusion_TransposeFusion { Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0)), - "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " - "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "(%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}) has wrong " "lhs_contracting_dimensions (got {1} want {0})"); EXPECT_THAT( Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)), - "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " - "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "(%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}) has wrong " "rhs_contracting_dimensions (got {0} want {1})"); } @@ -243,9 +255,13 @@ TEST(HloMatchersTest, ComparisonMatcher) { EXPECT_THAT(le.get(), op::Le(op::Parameter(0), op::Add(op::Parameter(0), op::Parameter(1)))); - EXPECT_THAT(Explain(eq.get(), op::Add()), Eq("")); + EXPECT_THAT(Explain(eq.get(), op::Add()), + Eq("(%compare = f32[1]{0} compare(f32[1]{0} %param.0, " + "f32[1]{0} %param.1), direction=EQ)")); EXPECT_THAT(Explain(eq.get(), op::Ne()), - Eq("has wrong comparison direction (got EQ, want NE)")); + Eq("(%compare = f32[1]{0} compare(f32[1]{0} %param.0, " + "f32[1]{0} %param.1), direction=EQ) " + "has wrong comparison direction (got EQ, want NE)")); } TEST(HloMatchersTest, AsyncCopyMatcher) { @@ -267,15 +283,18 @@ TEST(HloMatchersTest, AsyncCopyMatcher) { EXPECT_THAT(copy_done.get(), op::AsyncCopy(2, 1, op::Parameter(0))); EXPECT_THAT(Explain(copy_start.get(), op::AsyncCopy(2, 1, op::Parameter(0))), - Eq("")); - EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))), - "%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) " - "%copy-start) " - "copies to memory space 2, expected 3"); - EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))), - "%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) " - "%copy-start) " - "is in the memory space 1, expected 3"); + Eq("(%copy-start = (f32[16]{0:S(2)}, u32[]) " + "copy-start(f32[16]{0:S(1)} %p0))")); + EXPECT_THAT( + Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))), + "(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) " + "%copy-start)) " + "copies to memory space 2, expected 3"); + EXPECT_THAT( + Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))), + "(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) " + "%copy-start)) " + "is in the memory space 1, expected 3"); } } // namespace