Add shape compatibility check for kCall and kConditional.

PiperOrigin-RevId: 204235233
This commit is contained in:
A. Unique TensorFlower 2018-07-11 20:14:28 -07:00 committed by TensorFlower Gardener
parent 26cd1d1d06
commit 05ba5ceb8e
4 changed files with 113 additions and 30 deletions

View File

@ -2121,6 +2121,7 @@ tf_cc_test(
srcs = ["hlo_verifier_test.cc"],
deps = [
":hlo",
":hlo_parser",
":hlo_verifier",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",

View File

@ -127,6 +127,22 @@ Status CheckIsTokenOperand(const HloInstruction* instruction,
return Status::OK();
}
Status CheckOperandAndParameter(const HloInstruction* instruction,
int64 operand_number,
const HloComputation* computation,
int64 parameter_number) {
const HloInstruction* operand = instruction->operand(operand_number);
const HloInstruction* parameter =
computation->parameter_instruction(parameter_number);
if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) {
return InternalError("Operand %s shape does not match parameter's %s in %s",
operand->ToString().c_str(),
parameter->ToString().c_str(),
instruction->ToString().c_str());
}
return Status::OK();
}
} // namespace
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
@ -253,8 +269,11 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); }
Status ShapeVerifier::HandleCall(HloInstruction* call) {
for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
}
// The shape of kCall should match the shape of the computation it calls.
return CheckShape(call, call->to_apply()->ComputeProgramShape().result());
return CheckShape(call, call->to_apply()->root_instruction()->shape());
}
Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); }
@ -323,19 +342,37 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
}
Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
TF_RETURN_IF_ERROR(
CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
TF_RETURN_IF_ERROR(
CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
const Shape& conditional_shape =
xla_while->while_condition()->root_instruction()->shape();
if (!ShapeUtil::Compatible(conditional_shape,
ShapeUtil::MakeShape(PRED, {}))) {
return InternalError(
"Conditional computation shape does not lead to a scalar predicate "
"shape: %s",
ShapeUtil::HumanString(conditional_shape).c_str());
}
// The shape of kWhile should match the shape of the body computation it
// calls.
return CheckShape(xla_while,
xla_while->while_body()->ComputeProgramShape().result());
xla_while->while_body()->root_instruction()->shape());
}
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
conditional, 1, conditional->true_computation(), 0));
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
conditional, 2, conditional->false_computation(), 0));
TF_RETURN_IF_ERROR(
CheckShape(conditional,
conditional->true_computation()->root_instruction()->shape()));
TF_RETURN_IF_ERROR(CheckShape(
conditional,
conditional->true_computation()->ComputeProgramShape().result()));
return CheckShape(
conditional,
conditional->false_computation()->ComputeProgramShape().result());
conditional->false_computation()->root_instruction()->shape()));
return Status::OK();
}
Status ShapeVerifier::HandlePad(HloInstruction* pad) {
@ -802,33 +839,23 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
"While loop must have exactly one operand; had %lld : %s",
instruction->operand_count(), instruction->ToString().c_str());
}
auto* init = instruction->operand(0);
auto* cond_param = while_cond->parameter_instruction(0);
if (!ShapeUtil::Compatible(init->shape(), cond_param->shape())) {
return Status::OK();
}
Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) {
if (instruction->true_computation()->num_parameters() != 1) {
return FailedPrecondition(
"While condition's parameter must have the same shape as the "
"loop's 'init'. init: %s, param: %s",
init->ToString().c_str(), cond_param->ToString().c_str());
"True computation %s of %s must have 1 parameter insted of %lld",
instruction->true_computation()->name().c_str(),
instruction->ToString().c_str(),
instruction->true_computation()->num_parameters());
}
auto* cond_root = while_cond->root_instruction();
if (!ShapeUtil::Compatible(cond_root->shape(),
ShapeUtil::MakeShape(PRED, {}))) {
return FailedPrecondition("While condition should have shape PRED: %s",
cond_root->ToString().c_str());
}
auto* body_param = while_body->parameter_instruction(0);
if (!ShapeUtil::Compatible(init->shape(), body_param->shape())) {
if (instruction->false_computation()->num_parameters() != 1) {
return FailedPrecondition(
"While body's parameter must have the same shape as the loop's"
" 'init'. init: %s, param: %s",
init->ToString().c_str(), body_param->ToString().c_str());
}
auto* body_root = while_body->root_instruction();
if (!ShapeUtil::Compatible(init->shape(), body_root->shape())) {
return FailedPrecondition(
"While body should have same shape as the loop's 'init'."
"init: %s, body: %s",
init->ToString().c_str(), body_root->ToString().c_str());
"False computation %s of %s must have 1 parameter insted of %lld",
instruction->false_computation()->name().c_str(),
instruction->ToString().c_str(),
instruction->false_computation()->num_parameters());
}
return Status::OK();
}
@ -924,6 +951,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
<< " != " << ShapeUtil::Rank(instruction->operand(0)->shape());
} else if (instruction->opcode() == HloOpcode::kWhile) {
TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction));
} else if (instruction->opcode() == HloOpcode::kConditional) {
TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction));
} else if (instruction->opcode() !=
HloOpcode::kRng /* Rng operands are always scalar. */
&& instruction->IsElementwise()) {

View File

@ -146,6 +146,8 @@ class HloVerifier : public HloPassInterface {
Status CheckWhileInstruction(HloInstruction* instruction);
Status CheckConditionalInstruction(HloInstruction* instruction);
// Checks that the non-scalar operand shapes are compatible to the output
// shape, i.e., that there are no implicit broadcasts of size-one dimensions.
Status CheckElementwiseInstruction(HloInstruction* instruction);

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@ -123,5 +124,55 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
EXPECT_FALSE(verifier().Run(module.get()).status().ok());
}
TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
const char* const hlo_string = R"(
HloModule Module
callme {
ROOT param = (s32[], f32[4]) parameter(0)
}
ENTRY entry {
p0 = (f32[4], s32[]) parameter(0)
ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("shape does not match parameter"));
}
TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
const char* const hlo_string = R"(
HloModule Module
true_branch {
tparam = (s32[], f32[4]) parameter(0)
ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
}
false_branch {
fparam = (s32[], f32[4]) parameter(0)
ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
}
ENTRY entry {
p0 = (f32[4], s32[]) parameter(0)
constant = pred[] constant(true)
ROOT conditional = f32[4] conditional(constant, p0, p0),
true_computation=true_branch, false_computation=false_branch
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("shape does not match parameter"));
}
} // namespace
} // namespace xla