STT-tensorflow/tensorflow/compiler/xla/service/hlo_matchers.cc
A. Unique TensorFlower 943792dcb0 [XLA] Improve error message in HLO matchers.
Previously, error messages for mismatches in opcode or number of operands
weren't very informative because they didn't print out the HloInstruction
string. For example, an opcode mismatch previously looked like this:

Value of: body_data_add
Expected: subtract
  Actual: 0x7f58fe3e8c00 (of type xla::HloInstruction*)

With this CL, it now looks like:

Value of: body_data_add
Expected: subtract
  Actual: 0x7efefd68ec00 (of type xla::HloInstruction*), (%add.1 = f32[2,3]{1,0:S(1)} add(f32[2,3]{1,0} %get-tuple-element.2, f32[2,3]{1,0} %constant.2))
PiperOrigin-RevId: 258814523
2019-07-18 12:38:55 -07:00

303 lines
9.6 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
namespace xla {
namespace testing {
// Matches `instruction` against the expected opcode and, when operand
// matchers were supplied, against each operand in order.
//
// For a non-null instruction, its ToString() is appended to the listener
// (wrapped in parentheses) so that a failure shows the actual HLO text rather
// than only a pointer value. Returns false for a null instruction, an opcode
// mismatch, an operand-count mismatch, or the first operand whose matcher
// fails; returns true otherwise.
bool HloMatcher::MatchAndExplain(
    const HloInstruction* instruction,
    ::testing::MatchResultListener* listener) const {
  // A null instruction is self-explanatory from the printed value, so no
  // extra explanation is added.
  if (!instruction) {
    return false;
  }
  // ToString() is only worth computing when the listener will record it;
  // a non-interested listener (e.g. from Matches()) discards it anyway.
  if (listener->IsInterested()) {
    *listener << "(" << instruction->ToString() << ")";
  }
  if (instruction->opcode() != opcode_) {
    return false;
  }
  // Special case: no operand matchers means don't verify.
  if (operands_.empty()) {
    return true;
  }
  const auto& operands = instruction->operands();
  if (operands.size() != operands_.size()) {
    *listener << " has too "
              << (operands.size() > operands_.size() ? "many" : "few")
              << " operands (got " << operands.size() << ", want "
              << operands_.size() << ")";
    return false;
  }
  // Unsigned index: avoids a signed/unsigned comparison against size().
  for (size_t index = 0; index < operands.size(); ++index) {
    ::testing::StringMatchResultListener inner_listener;
    if (!operands_[index].MatchAndExplain(operands[index], &inner_listener)) {
      if (listener->IsInterested()) {
        *listener << "\noperand " << index << ":\n\t"
                  << operands[index]->ToString()
                  << "\ndoesn't match expected:\n\t";
        operands_[index].DescribeTo(listener->stream());
        string explanation = inner_listener.str();
        if (!explanation.empty()) {
          *listener << ", " << explanation;
        }
      }
      return false;
    }
  }
  return true;
}
// Writes the expected opcode, followed by a parenthesized, comma-separated
// list of operand descriptions when operand matchers were supplied.
void HloMatcher::DescribeTo(::std::ostream* os) const {
  *os << opcode_;
  if (operands_.empty()) {
    return;
  }
  *os << "(";
  const char* separator = "";
  for (const auto& operand_matcher : operands_) {
    *os << separator;
    operand_matcher.DescribeTo(os);
    separator = ", ";
  }
  *os << ")";
}
// Matches when the base HloMatcher checks pass and the instruction's
// parameter number equals `parameter_number_`.
bool HloParameterMatcher::MatchAndExplain(
    const HloInstruction* instruction,
    ::testing::MatchResultListener* listener) const {
  if (!HloMatcher::MatchAndExplain(instruction, listener)) {
    return false;
  }
  const auto actual_number = instruction->parameter_number();
  if (actual_number == parameter_number_) {
    return true;
  }
  *listener << " has wrong parameter number (got " << actual_number
            << ", want " << parameter_number_ << ")";
  return false;
}
// Matches when the base HloMatcher checks pass and the instruction's
// comparison direction equals `direction_`.
bool HloComparisonMatcher::MatchAndExplain(
    const HloInstruction* instruction,
    ::testing::MatchResultListener* listener) const {
  if (!HloMatcher::MatchAndExplain(instruction, listener)) {
    return false;
  }
  const auto actual_direction = instruction->comparison_direction();
  if (actual_direction == direction_) {
    return true;
  }
  *listener << " has wrong comparison direction (got "
            << ComparisonDirectionToString(actual_direction) << ", want "
            << ComparisonDirectionToString(direction_) << ")";
  return false;
}
// Matches when the base HloMatcher checks pass and the instruction's tuple
// index equals `tuple_index_`.
bool HloGetTupleElementMatcher::MatchAndExplain(
    const HloInstruction* instruction,
    ::testing::MatchResultListener* listener) const {
  if (!HloMatcher::MatchAndExplain(instruction, listener)) {
    return false;
  }
  const auto actual_index = instruction->tuple_index();
  if (actual_index == tuple_index_) {
    return true;
  }
  *listener << " has wrong tuple index (got " << actual_index << ", want "
            << tuple_index_ << ")";
  return false;
}
// Describes the base opcode/operand expectation, then the constraint on the
// custom-call target string.
void HloCustomCallMatcher::DescribeTo(std::ostream* os) const {
  HloMatcher::DescribeTo(os);
  std::ostream& out = *os;
  out << " with call target that ";
  call_target_matcher_.DescribeTo(&out);
}
bool HloCustomCallMatcher::MatchAndExplain(
const HloInstruction* instruction,
::testing::MatchResultListener* listener) const {
if (!HloMatcher::MatchAndExplain(instruction, listener)) {
return false;
}
::testing::StringMatchResultListener sub_listener;
bool result = ExplainMatchResult(
call_target_matcher_, instruction->custom_call_target(), &sub_listener);
if (sub_listener.str().empty()) {
sub_listener << " that ";
std::stringstream desc_stream;
if (result) {
call_target_matcher_.DescribeTo(&desc_stream);
} else {
call_target_matcher_.DescribeNegationTo(&desc_stream);
}
sub_listener << desc_stream.str();
}
*listener << " custom-call with call target" << sub_listener.str();
return result;
}
// Matches when `instruction`'s shape is compatible with the expected shape
// according to ShapeUtil::Compatible. On a mismatch the explanation includes
// the instruction text and the expected shape.
bool HloShapeMatcher::MatchAndExplain(
    const HloInstruction* instruction,
    ::testing::MatchResultListener* listener) const {
  // Same as HloMatcher: a null instruction is reported as a plain mismatch
  // rather than dereferenced (which would crash the test binary).
  if (instruction == nullptr) {
    return false;
  }
  if (ShapeUtil::Compatible(instruction->shape(), shape_)) {
    return true;
  }
  *listener << instruction->ToString() << " has incorrect shape (expected: "
            << ShapeUtil::HumanString(shape_) << ")";
  return false;
}
// Describes the expected shape in human-readable form (no layout).
void HloShapeMatcher::DescribeTo(std::ostream* os) const {
  std::ostream& out = *os;
  out << ShapeUtil::HumanString(shape_);
}
// Matches when `instruction`'s shape is equal to the expected shape
// according to ShapeUtil::Equal. On a mismatch the explanation includes the
// instruction text and the expected shape with its layout.
bool HloShapeAndLayoutMatcher::MatchAndExplain(
    const HloInstruction* instruction,
    ::testing::MatchResultListener* listener) const {
  // Same as HloMatcher: a null instruction is reported as a plain mismatch
  // rather than dereferenced (which would crash the test binary).
  if (instruction == nullptr) {
    return false;
  }
  if (ShapeUtil::Equal(instruction->shape(), shape_)) {
    return true;
  }
  *listener << instruction->ToString() << " has incorrect shape (expected: "
            << ShapeUtil::HumanStringWithLayout(shape_) << ")";
  return false;
}
// Describes the expected shape, including its layout.
void HloShapeAndLayoutMatcher::DescribeTo(std::ostream* os) const {
  std::ostream& out = *os;
  out << ShapeUtil::HumanStringWithLayout(shape_);
}
// Matches `instruction`'s sharding against the expectation:
//  - if `sharding_` is empty, the instruction must have no sharding;
//  - otherwise the instruction must have a sharding equal to `*sharding_`.
// Each kind of mismatch gets its own explanation, prefixed by the
// instruction's text.
bool HloShardingMatcher::MatchAndExplain(
    const HloInstruction* instruction,
    ::testing::MatchResultListener* listener) const {
  // Same as HloMatcher: a null instruction is reported as a plain mismatch
  // rather than dereferenced (which would crash the test binary).
  if (instruction == nullptr) {
    return false;
  }
  if (!sharding_.has_value()) {
    if (!instruction->has_sharding()) {
      return true;
    }
    *listener << instruction->ToString() << " expected to have no sharding.";
    return false;
  }
  if (!instruction->has_sharding()) {
    *listener << instruction->ToString()
              << " has no sharding (expected: " << sharding_->ToString() << ")";
    return false;
  }
  if (instruction->sharding() == sharding_.value()) {
    return true;
  }
  *listener << instruction->ToString()
            << " has incorrect sharding (expected: " << sharding_->ToString()
            << ")";
  return false;
}
// Describes the expected sharding, or "<no-sharding>" when the instruction
// is expected to be unsharded.
void HloShardingMatcher::DescribeTo(std::ostream* os) const {
  const std::string description = sharding_.has_value()
                                      ? sharding_->ToString()
                                      : std::string("<no-sharding>");
  *os << description;
}
bool HloDotWithContractingDimsMatcher::MatchAndExplain(
const HloInstruction* instruction,
::testing::MatchResultListener* listener) const {
if (!HloMatcher::MatchAndExplain(instruction, listener)) {
return false;
}
const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers();
if (dim_nums.lhs_contracting_dimensions_size() != 1 ||
dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) {
*listener << " has wrong lhs_contracting_dimensions (got {"
<< absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",")
<< "} want {" << lhs_contracting_dim_ << "})";
return false;
}
if (dim_nums.rhs_contracting_dimensions_size() != 1 ||
dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) {
*listener << " has wrong rhs_contracting_dimensions (got {"
<< absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",")
<< "} want {" << rhs_contracting_dim_ << "})";
return false;
}
return true;
}
// Describes the base opcode/operand expectation followed by the expected
// single lhs and rhs contracting dimensions.
void HloDotWithContractingDimsMatcher::DescribeTo(std::ostream* os) const {
  HloMatcher::DescribeTo(os);
  std::ostream& out = *os;
  out << " with lhs_contracting_dims={" << lhs_contracting_dim_ << "}";
  out << " and rhs_contracting_dims={" << rhs_contracting_dim_ << "}";
}
// Matches when the base HloMatcher checks pass, `instruction` (the
// copy-done) has a layout in memory space `to_space_`, and the operand of
// its first operand (the copy's source) has a layout in memory space
// `from_space_`.
//
// Messages about the source name it explicitly. Otherwise they would read
// as if they described the copy-done, whose text has already been printed.
bool HloAsyncCopyMatcher::MatchAndExplain(
    const HloInstruction* instruction,
    ::testing::MatchResultListener* listener) const {
  if (!HloMatcher::MatchAndExplain(instruction, listener)) {
    return false;
  }
  const HloInstruction* copy_done = instruction;
  if (!copy_done->shape().has_layout()) {
    *listener << " does not have layout, expected a layout with memory space "
              << to_space_;
    return false;
  }
  if (copy_done->shape().layout().memory_space() != to_space_) {
    *listener << " copies to memory space "
              << copy_done->shape().layout().memory_space() << ", expected "
              << to_space_;
    return false;
  }
  // HloMatcher skips operand verification when no operand matchers were
  // given, so check that the chain copy-done -> operand -> operand exists
  // before indexing into it.
  if (copy_done->operands().empty() ||
      copy_done->operands()[0]->operands().empty()) {
    *listener << " has no copy source operand, expected one in memory space "
              << from_space_;
    return false;
  }
  const HloInstruction* copy_start_operand =
      copy_done->operands()[0]->operands()[0];
  if (!copy_start_operand->shape().has_layout()) {
    *listener << ", source operand " << copy_start_operand->ToString()
              << " does not have layout, expected a layout with memory space "
              << from_space_;
    return false;
  }
  if (copy_start_operand->shape().layout().memory_space() != from_space_) {
    *listener << ", source operand " << copy_start_operand->ToString()
              << " is in the memory space "
              << copy_start_operand->shape().layout().memory_space()
              << ", expected " << from_space_;
    return false;
  }
  return true;
}
// Describes the base opcode/operand expectation followed by the expected
// source and destination memory spaces.
void HloAsyncCopyMatcher::DescribeTo(std::ostream* os) const {
  HloMatcher::DescribeTo(os);
  std::ostream& out = *os;
  out << " (copy from memory space " << from_space_;
  out << " to " << to_space_ << ")";
}
} // namespace testing
// gtest printer hook for HloInstruction pointers. It prints the
// instruction's ToString(), or "nullptr" for a null pointer.
void PrintTo(const HloInstruction* inst, ::std::ostream* os) {
  if (inst == nullptr) {
    *os << "nullptr";
    return;
  }
  *os << inst->ToString();
}
} // namespace xla