Verify parameter count of called computation in HloVerifier.
PiperOrigin-RevId: 219662775
This commit is contained in:
parent
486dca315c
commit
a169b7aa78
@ -65,7 +65,9 @@ Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
|
|||||||
return VerifyNotSparse(hlo->shape());
|
return VerifyNotSparse(hlo->shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
static Status CheckOperandCount(const HloInstruction* hlo, int expected) {
|
namespace {
|
||||||
|
|
||||||
|
Status CheckOperandCount(const HloInstruction* hlo, int expected) {
|
||||||
if (hlo->operand_count() != expected) {
|
if (hlo->operand_count() != expected) {
|
||||||
return InternalError("Expected %d operands for %s instruction: %s",
|
return InternalError("Expected %d operands for %s instruction: %s",
|
||||||
expected, HloOpcodeString(hlo->opcode()),
|
expected, HloOpcodeString(hlo->opcode()),
|
||||||
@ -74,6 +76,19 @@ static Status CheckOperandCount(const HloInstruction* hlo, int expected) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status CheckParameterCount(const HloInstruction* calling_instruction,
|
||||||
|
const HloComputation* computation, int expected) {
|
||||||
|
if (computation->num_parameters() != expected) {
|
||||||
|
return InternalError(
|
||||||
|
"Expected computation %s called from %s to have %d parameters, has %d",
|
||||||
|
computation->name(), calling_instruction->name(), expected,
|
||||||
|
computation->num_parameters());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
|
Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
|
||||||
return CheckUnaryShape(hlo);
|
return CheckUnaryShape(hlo);
|
||||||
}
|
}
|
||||||
@ -441,6 +456,8 @@ Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status ShapeVerifier::HandleCall(HloInstruction* call) {
|
Status ShapeVerifier::HandleCall(HloInstruction* call) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CheckParameterCount(call, call->to_apply(), call->operand_count()));
|
||||||
for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
|
for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
|
||||||
TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
|
TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
|
||||||
}
|
}
|
||||||
@ -540,6 +557,10 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
|
|||||||
|
|
||||||
Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
|
Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
|
||||||
TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1));
|
TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CheckParameterCount(xla_while, xla_while->while_body(), 1));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CheckParameterCount(xla_while, xla_while->while_condition(), 1));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
|
CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
@ -560,6 +581,10 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
|
|||||||
|
|
||||||
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
|
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
|
||||||
TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3));
|
TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CheckParameterCount(conditional, conditional->true_computation(), 1));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CheckParameterCount(conditional, conditional->false_computation(), 1));
|
||||||
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
|
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
|
||||||
conditional, 1, conditional->true_computation(), 0));
|
conditional, 1, conditional->true_computation(), 0));
|
||||||
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
|
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user