Verify parameter count of called computation in HloVerifier.
PiperOrigin-RevId: 219662775
This commit is contained in:
parent
486dca315c
commit
a169b7aa78
@ -65,7 +65,9 @@ Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
|
||||
return VerifyNotSparse(hlo->shape());
|
||||
}
|
||||
|
||||
static Status CheckOperandCount(const HloInstruction* hlo, int expected) {
|
||||
namespace {
|
||||
|
||||
Status CheckOperandCount(const HloInstruction* hlo, int expected) {
|
||||
if (hlo->operand_count() != expected) {
|
||||
return InternalError("Expected %d operands for %s instruction: %s",
|
||||
expected, HloOpcodeString(hlo->opcode()),
|
||||
@ -74,6 +76,19 @@ static Status CheckOperandCount(const HloInstruction* hlo, int expected) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckParameterCount(const HloInstruction* calling_instruction,
|
||||
const HloComputation* computation, int expected) {
|
||||
if (computation->num_parameters() != expected) {
|
||||
return InternalError(
|
||||
"Expected computation %s called from %s to have %d parameters, has %d",
|
||||
computation->name(), calling_instruction->name(), expected,
|
||||
computation->num_parameters());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
|
||||
return CheckUnaryShape(hlo);
|
||||
}
|
||||
@ -441,6 +456,8 @@ Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
|
||||
}
|
||||
|
||||
Status ShapeVerifier::HandleCall(HloInstruction* call) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckParameterCount(call, call->to_apply(), call->operand_count()));
|
||||
for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
|
||||
TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
|
||||
}
|
||||
@ -540,6 +557,10 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
|
||||
|
||||
Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
|
||||
TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckParameterCount(xla_while, xla_while->while_body(), 1));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckParameterCount(xla_while, xla_while->while_condition(), 1));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -560,6 +581,10 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
|
||||
|
||||
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
|
||||
TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckParameterCount(conditional, conditional->true_computation(), 1));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckParameterCount(conditional, conditional->false_computation(), 1));
|
||||
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
|
||||
conditional, 1, conditional->true_computation(), 0));
|
||||
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
|
||||
|
Loading…
Reference in New Issue
Block a user