[XLA] Improve error message in HLO matchers.

Previously, mismatches in opcode or number of operands wasn't very informative
because the error message didn't print out the HloInstruction string. For
example, an opcode mismatch might have looked like previously:

Value of: body_data_add
Expected: subtract
  Actual: 0x7f58fe3e8c00 (of type xla::HloInstruction*)

With this CL, it now looks like:

Value of: body_data_add
Expected: subtract
  Actual: 0x7efefd68ec00 (of type xla::HloInstruction*), (%add.1 = f32[2,3]{1,0:S(1)} add(f32[2,3]{1,0} %get-tuple-element.2, f32[2,3]{1,0} %constant.2))
PiperOrigin-RevId: 258814523
This commit is contained in:
A. Unique TensorFlower 2019-07-18 11:47:16 -07:00 committed by TensorFlower Gardener
parent f1b70ad839
commit 943792dcb0
2 changed files with 58 additions and 38 deletions

View File

@ -26,7 +26,11 @@ bool HloMatcher::MatchAndExplain(
const HloInstruction* instruction,
::testing::MatchResultListener* listener) const {
// These cases are self-explanatory from the printed value.
if (!instruction || instruction->opcode() != opcode_) {
if (!instruction) {
return false;
}
*listener << "(" << instruction->ToString() << ")";
if (instruction->opcode() != opcode_) {
return false;
}
// Special case: no operand matchers means don't verify.
@ -35,7 +39,7 @@ bool HloMatcher::MatchAndExplain(
}
const auto& operands = instruction->operands();
if (operands.size() != operands_.size()) {
*listener << "has too "
*listener << " has too "
<< (operands.size() > operands_.size() ? "many" : "few")
<< " operands (got " << operands.size() << ", want "
<< operands_.size() << ")";
@ -81,7 +85,7 @@ bool HloParameterMatcher::MatchAndExplain(
return false;
}
if (instruction->parameter_number() != parameter_number_) {
*listener << "has wrong parameter number (got "
*listener << " has wrong parameter number (got "
<< instruction->parameter_number() << ", want "
<< parameter_number_ << ")";
return false;
@ -96,7 +100,7 @@ bool HloComparisonMatcher::MatchAndExplain(
return false;
}
if (instruction->comparison_direction() != direction_) {
*listener << "has wrong comparison direction (got "
*listener << " has wrong comparison direction (got "
<< ComparisonDirectionToString(
instruction->comparison_direction())
<< ", want " << ComparisonDirectionToString(direction_) << ")";
@ -112,7 +116,7 @@ bool HloGetTupleElementMatcher::MatchAndExplain(
return false;
}
if (instruction->tuple_index() != tuple_index_) {
*listener << "has wrong tuple index (got " << instruction->tuple_index()
*listener << " has wrong tuple index (got " << instruction->tuple_index()
<< ", want " << tuple_index_ << ")";
return false;
}
@ -145,7 +149,7 @@ bool HloCustomCallMatcher::MatchAndExplain(
}
sub_listener << desc_stream.str();
}
*listener << "custom-call with call target" << sub_listener.str();
*listener << " custom-call with call target" << sub_listener.str();
return result;
}
@ -222,8 +226,7 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain(
const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers();
if (dim_nums.lhs_contracting_dimensions_size() != 1 ||
dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) {
*listener << instruction->ToString()
<< " has wrong lhs_contracting_dimensions (got {"
*listener << " has wrong lhs_contracting_dimensions (got {"
<< absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",")
<< "} want {" << lhs_contracting_dim_ << "})";
return false;
@ -231,8 +234,7 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain(
if (dim_nums.rhs_contracting_dimensions_size() != 1 ||
dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) {
*listener << instruction->ToString()
<< " has wrong rhs_contracting_dimensions (got {"
*listener << " has wrong rhs_contracting_dimensions (got {"
<< absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",")
<< "} want {" << rhs_contracting_dim_ << "})";
return false;
@ -256,13 +258,12 @@ bool HloAsyncCopyMatcher::MatchAndExplain(
const HloInstruction* copy_done = instruction;
if (!copy_done->shape().has_layout()) {
*listener << copy_done->ToString()
<< " does not have layout, expected a layout with memory space "
*listener << " does not have layout, expected a layout with memory space "
<< to_space_;
return false;
}
if (copy_done->shape().layout().memory_space() != to_space_) {
*listener << copy_done->ToString() << " copies to memory space "
*listener << " copies to memory space "
<< copy_done->shape().layout().memory_space() << ", expected "
<< to_space_;
return false;
@ -277,7 +278,7 @@ bool HloAsyncCopyMatcher::MatchAndExplain(
return false;
}
if (copy_start_operand->shape().layout().memory_space() != from_space_) {
*listener << copy_done->ToString() << " is in the memory space "
*listener << " is in the memory space "
<< copy_start_operand->shape().layout().memory_space()
<< ", expected " << from_space_;
return false;

View File

@ -53,26 +53,36 @@ TEST(HloMatchersTest, Test) {
op::Add(op::Parameter(), op::Multiply(_, op::Parameter())));
// Negative matches: check the explanation string.
EXPECT_THAT(Explain(add.get(), op::Parameter()), Eq(""));
EXPECT_THAT(Explain(add.get(), op::Add(op::Parameter())),
Eq("has too many operands (got 2, want 1)"));
EXPECT_THAT(
Explain(add.get(), op::Parameter()),
Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply))"));
EXPECT_THAT(
Explain(add.get(), op::Add(op::Parameter())),
Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply)) "
"has too many operands (got 2, want 1)"));
EXPECT_THAT(
Explain(add.get(), op::Add(op::Parameter(), op::Parameter())),
Eq("\noperand 1:\n\t"
Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply))"
"\noperand 1:\n\t"
"%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n"
"doesn't match expected:\n\t"
"parameter"));
"parameter"
", (%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} "
"%param))"));
EXPECT_THAT(
Explain(add.get(),
op::Add(op::Parameter(), op::Multiply(op::Add(), op::Add()))),
Eq("\noperand 1:\n\t"
Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply))"
"\noperand 1:\n\t"
"%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n"
"doesn't match expected:\n\t"
"multiply(add, add), \n"
"multiply(add, add)"
", (%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} "
"%param))\n"
"operand 0:\n\t"
"%param = f32[1]{0} parameter(0)\n"
"doesn't match expected:\n\t"
"add"));
"add, (%param = f32[1]{0} parameter(0))"));
}
TEST(HloMatchersTest, CustomCallMatcher) {
@ -99,7 +109,9 @@ TEST(HloMatchersTest, CustomCallMatcher) {
::testing::Not(op::CustomCall(::testing::StartsWith("bar"))));
EXPECT_THAT(Explain(call.get(), op::CustomCall("bar")),
R"(custom-call with call target that isn't equal to "bar")");
"(%custom-call = f32[1]{0} custom-call(f32[3]{0} %constant, "
"s32[3]{0} %constant), custom_call_target=\"foo_target\") "
"custom-call with call target that isn't equal to \"bar\"");
EXPECT_THAT(DescribeHloMatcher(op::CustomCall("foo_target")),
R"(custom-call with call target that is equal to "foo_target")");
}
@ -207,16 +219,16 @@ ENTRY DotOperationFusion_TransposeFusion {
Explain(root, op::Dot(op::Parameter(0), op::Parameter(1),
/*lhs_contracting_dim=*/0,
/*rhs_contracting_dim=*/0)),
"%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} "
"%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong "
"(%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} "
"%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}) has wrong "
"lhs_contracting_dimensions (got {1} want {0})");
EXPECT_THAT(
Explain(root, op::Dot(op::Parameter(0), op::Parameter(1),
/*lhs_contracting_dim=*/1,
/*rhs_contracting_dim=*/1)),
"%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} "
"%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong "
"(%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} "
"%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}) has wrong "
"rhs_contracting_dimensions (got {0} want {1})");
}
@ -243,9 +255,13 @@ TEST(HloMatchersTest, ComparisonMatcher) {
EXPECT_THAT(le.get(), op::Le(op::Parameter(0),
op::Add(op::Parameter(0), op::Parameter(1))));
EXPECT_THAT(Explain(eq.get(), op::Add()), Eq(""));
EXPECT_THAT(Explain(eq.get(), op::Add()),
Eq("(%compare = f32[1]{0} compare(f32[1]{0} %param.0, "
"f32[1]{0} %param.1), direction=EQ)"));
EXPECT_THAT(Explain(eq.get(), op::Ne()),
Eq("has wrong comparison direction (got EQ, want NE)"));
Eq("(%compare = f32[1]{0} compare(f32[1]{0} %param.0, "
"f32[1]{0} %param.1), direction=EQ) "
"has wrong comparison direction (got EQ, want NE)"));
}
TEST(HloMatchersTest, AsyncCopyMatcher) {
@ -267,15 +283,18 @@ TEST(HloMatchersTest, AsyncCopyMatcher) {
EXPECT_THAT(copy_done.get(), op::AsyncCopy(2, 1, op::Parameter(0)));
EXPECT_THAT(Explain(copy_start.get(), op::AsyncCopy(2, 1, op::Parameter(0))),
Eq(""));
EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))),
"%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) "
"%copy-start) "
"copies to memory space 2, expected 3");
EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))),
"%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) "
"%copy-start) "
"is in the memory space 1, expected 3");
Eq("(%copy-start = (f32[16]{0:S(2)}, u32[]) "
"copy-start(f32[16]{0:S(1)} %p0))"));
EXPECT_THAT(
Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))),
"(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) "
"%copy-start)) "
"copies to memory space 2, expected 3");
EXPECT_THAT(
Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))),
"(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) "
"%copy-start)) "
"is in the memory space 1, expected 3");
}
} // namespace