Add shape compatibility check for kCall and kConditional.
PiperOrigin-RevId: 204235233
commit 05ba5ceb8e
parent 26cd1d1d06
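All of the new checks bottom out in ShapeUtil::Compatible, which ignores layout but requires matching element types and dimensions, recursing into tuple elements. Below is a minimal standalone sketch (not part of this commit; the file name is hypothetical, and it assumes compilation inside the TensorFlow tree of this era) using the same mismatched tuple shapes as the new tests at the end of this change:

// shape_compat_sketch.cc -- hypothetical file, for illustration only.
#include <iostream>

#include "tensorflow/compiler/xla/shape_util.h"

int main() {
  // Operand side: (f32[4], s32[]); parameter side: (s32[], f32[4]).
  xla::Shape operand = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::F32, {4}),
       xla::ShapeUtil::MakeShape(xla::S32, {})});
  xla::Shape parameter = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::S32, {}),
       xla::ShapeUtil::MakeShape(xla::F32, {4})});
  std::cout << xla::ShapeUtil::Compatible(operand, operand) << "\n";    // 1
  std::cout << xla::ShapeUtil::Compatible(operand, parameter) << "\n";  // 0
  return 0;
}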
tensorflow/compiler/xla/service/BUILD
@@ -2121,6 +2121,7 @@ tf_cc_test(
     srcs = ["hlo_verifier_test.cc"],
     deps = [
         ":hlo",
+        ":hlo_parser",
         ":hlo_verifier",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -127,6 +127,22 @@ Status CheckIsTokenOperand(const HloInstruction* instruction,
   return Status::OK();
 }
 
+Status CheckOperandAndParameter(const HloInstruction* instruction,
+                                int64 operand_number,
+                                const HloComputation* computation,
+                                int64 parameter_number) {
+  const HloInstruction* operand = instruction->operand(operand_number);
+  const HloInstruction* parameter =
+      computation->parameter_instruction(parameter_number);
+  if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) {
+    return InternalError("Operand %s shape does not match parameter's %s in %s",
+                         operand->ToString().c_str(),
+                         parameter->ToString().c_str(),
+                         instruction->ToString().c_str());
+  }
+  return Status::OK();
+}
+
 }  // namespace
 
 Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
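The helper deliberately leaves the operand/parameter pairing to the caller: HandleCall below pairs operand i with parameter i of the called computation, HandleWhile pairs operand 0 with parameter 0 of both the body and the condition, and HandleConditional pairs operands 1 and 2 with parameter 0 of the true and false computations (operand 0 is the scalar predicate).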
@@ -253,8 +269,11 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
 Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); }
 
 Status ShapeVerifier::HandleCall(HloInstruction* call) {
+  for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
+    TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
+  }
   // The shape of kCall should match the shape of the computation it calls.
-  return CheckShape(call, call->to_apply()->ComputeProgramShape().result());
+  return CheckShape(call, call->to_apply()->root_instruction()->shape());
 }
 
 Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); }
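The switch from ComputeProgramShape().result() to root_instruction()->shape() here, and in the kWhile and kConditional handlers below, does not change the shape being checked: a computation's program-shape result is exactly its root instruction's shape.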
@@ -323,19 +342,37 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
 }
 
 Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
+  TF_RETURN_IF_ERROR(
+      CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
+  TF_RETURN_IF_ERROR(
+      CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
+  const Shape& conditional_shape =
+      xla_while->while_condition()->root_instruction()->shape();
+  if (!ShapeUtil::Compatible(conditional_shape,
+                             ShapeUtil::MakeShape(PRED, {}))) {
+    return InternalError(
+        "Conditional computation shape does not lead to a scalar predicate "
+        "shape: %s",
+        ShapeUtil::HumanString(conditional_shape).c_str());
+  }
   // The shape of kWhile should match the shape of the body computation it
   // calls.
   return CheckShape(xla_while,
-                    xla_while->while_body()->ComputeProgramShape().result());
+                    xla_while->while_body()->root_instruction()->shape());
 }
 
 Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
+  TF_RETURN_IF_ERROR(CheckOperandAndParameter(
+      conditional, 1, conditional->true_computation(), 0));
+  TF_RETURN_IF_ERROR(CheckOperandAndParameter(
+      conditional, 2, conditional->false_computation(), 0));
+  TF_RETURN_IF_ERROR(
+      CheckShape(conditional,
+                 conditional->true_computation()->root_instruction()->shape()));
   TF_RETURN_IF_ERROR(CheckShape(
       conditional,
-      conditional->true_computation()->ComputeProgramShape().result()));
-  return CheckShape(
-      conditional,
-      conditional->false_computation()->ComputeProgramShape().result());
+      conditional->false_computation()->root_instruction()->shape()));
+  return Status::OK();
 }
 
 Status ShapeVerifier::HandlePad(HloInstruction* pad) {
@@ -802,33 +839,23 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
         "While loop must have exactly one operand; had %lld : %s",
         instruction->operand_count(), instruction->ToString().c_str());
   }
-  auto* init = instruction->operand(0);
-  auto* cond_param = while_cond->parameter_instruction(0);
-  if (!ShapeUtil::Compatible(init->shape(), cond_param->shape())) {
+  return Status::OK();
+}
+
+Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) {
+  if (instruction->true_computation()->num_parameters() != 1) {
     return FailedPrecondition(
-        "While condition's parameter must have the same shape as the "
-        "loop's 'init'. init: %s, param: %s",
-        init->ToString().c_str(), cond_param->ToString().c_str());
+        "True computation %s of %s must have 1 parameter instead of %lld",
+        instruction->true_computation()->name().c_str(),
+        instruction->ToString().c_str(),
+        instruction->true_computation()->num_parameters());
   }
-  auto* cond_root = while_cond->root_instruction();
-  if (!ShapeUtil::Compatible(cond_root->shape(),
-                             ShapeUtil::MakeShape(PRED, {}))) {
-    return FailedPrecondition("While condition should have shape PRED: %s",
-                              cond_root->ToString().c_str());
-  }
-  auto* body_param = while_body->parameter_instruction(0);
-  if (!ShapeUtil::Compatible(init->shape(), body_param->shape())) {
+  if (instruction->false_computation()->num_parameters() != 1) {
     return FailedPrecondition(
-        "While body's parameter must have the same shape as the loop's"
-        " 'init'. init: %s, param: %s",
-        init->ToString().c_str(), body_param->ToString().c_str());
-  }
-  auto* body_root = while_body->root_instruction();
-  if (!ShapeUtil::Compatible(init->shape(), body_root->shape())) {
-    return FailedPrecondition(
-        "While body should have same shape as the loop's 'init'."
-        "init: %s, body: %s",
-        init->ToString().c_str(), body_root->ToString().c_str());
+        "False computation %s of %s must have 1 parameter instead of %lld",
+        instruction->false_computation()->name().c_str(),
+        instruction->ToString().c_str(),
+        instruction->false_computation()->num_parameters());
   }
   return Status::OK();
 }
@@ -924,6 +951,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
             << " != " << ShapeUtil::Rank(instruction->operand(0)->shape());
       } else if (instruction->opcode() == HloOpcode::kWhile) {
         TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction));
+      } else if (instruction->opcode() == HloOpcode::kConditional) {
+        TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction));
       } else if (instruction->opcode() !=
                      HloOpcode::kRng /* Rng operands are always scalar. */
                  && instruction->IsElementwise()) {
tensorflow/compiler/xla/service/hlo_verifier.h
@@ -146,6 +146,8 @@ class HloVerifier : public HloPassInterface {
 
   Status CheckWhileInstruction(HloInstruction* instruction);
 
+  Status CheckConditionalInstruction(HloInstruction* instruction);
+
   // Checks that the non-scalar operand shapes are compatible to the output
   // shape, i.e., that there are no implicit broadcasts of size-one dimensions.
   Status CheckElementwiseInstruction(HloInstruction* instruction);
tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -123,5 +124,55 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
   EXPECT_FALSE(verifier().Run(module.get()).status().ok());
 }
 
+TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
+  const char* const hlo_string = R"(
+HloModule Module
+
+callme {
+  ROOT param = (s32[], f32[4]) parameter(0)
+}
+
+ENTRY entry {
+  p0 = (f32[4], s32[]) parameter(0)
+  ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("shape does not match parameter"));
+}
+
+TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
+  const char* const hlo_string = R"(
+HloModule Module
+
+true_branch {
+  tparam = (s32[], f32[4]) parameter(0)
+  ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
+}
+
+false_branch {
+  fparam = (s32[], f32[4]) parameter(0)
+  ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
+}
+
+ENTRY entry {
+  p0 = (f32[4], s32[]) parameter(0)
+  constant = pred[] constant(true)
+  ROOT conditional = f32[4] conditional(constant, p0, p0),
+    true_computation=true_branch, false_computation=false_branch
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("shape does not match parameter"));
+}
+
 }  // namespace
 }  // namespace xla
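Both tests exercise the new check the same way: the entry parameter p0 has shape (f32[4], s32[]) while the callee's parameter 0 expects (s32[], f32[4]), so CheckOperandAndParameter rejects the pairing and the verifier reports the "shape does not match parameter" error the tests assert on.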