Verify parameter count of called computation in HloVerifier.

PiperOrigin-RevId: 219662775
This commit is contained in:
Mark Heffernan 2018-11-01 11:00:16 -07:00 committed by TensorFlower Gardener
parent 486dca315c
commit a169b7aa78

View File

@ -65,7 +65,9 @@ Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
return VerifyNotSparse(hlo->shape());
}
static Status CheckOperandCount(const HloInstruction* hlo, int expected) {
namespace {
Status CheckOperandCount(const HloInstruction* hlo, int expected) {
if (hlo->operand_count() != expected) {
return InternalError("Expected %d operands for %s instruction: %s",
expected, HloOpcodeString(hlo->opcode()),
@ -74,6 +76,19 @@ static Status CheckOperandCount(const HloInstruction* hlo, int expected) {
return Status::OK();
}
Status CheckParameterCount(const HloInstruction* calling_instruction,
const HloComputation* computation, int expected) {
if (computation->num_parameters() != expected) {
return InternalError(
"Expected computation %s called from %s to have %d parameters, has %d",
computation->name(), calling_instruction->name(), expected,
computation->num_parameters());
}
return Status::OK();
}
} // namespace
Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
return CheckUnaryShape(hlo);
}
@ -441,6 +456,8 @@ Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
}
Status ShapeVerifier::HandleCall(HloInstruction* call) {
TF_RETURN_IF_ERROR(
CheckParameterCount(call, call->to_apply(), call->operand_count()));
for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
}
@ -540,6 +557,10 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1));
TF_RETURN_IF_ERROR(
CheckParameterCount(xla_while, xla_while->while_body(), 1));
TF_RETURN_IF_ERROR(
CheckParameterCount(xla_while, xla_while->while_condition(), 1));
TF_RETURN_IF_ERROR(
CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
TF_RETURN_IF_ERROR(
@ -560,6 +581,10 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3));
TF_RETURN_IF_ERROR(
CheckParameterCount(conditional, conditional->true_computation(), 1));
TF_RETURN_IF_ERROR(
CheckParameterCount(conditional, conditional->false_computation(), 1));
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
conditional, 1, conditional->true_computation(), 0));
TF_RETURN_IF_ERROR(CheckOperandAndParameter(