Add shape compatibility check for kCall and kConditional.
PiperOrigin-RevId: 204235233
This commit is contained in:
parent
26cd1d1d06
commit
05ba5ceb8e
@ -2121,6 +2121,7 @@ tf_cc_test(
|
|||||||
srcs = ["hlo_verifier_test.cc"],
|
srcs = ["hlo_verifier_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
|
":hlo_parser",
|
||||||
":hlo_verifier",
|
":hlo_verifier",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
|
@ -127,6 +127,22 @@ Status CheckIsTokenOperand(const HloInstruction* instruction,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status CheckOperandAndParameter(const HloInstruction* instruction,
|
||||||
|
int64 operand_number,
|
||||||
|
const HloComputation* computation,
|
||||||
|
int64 parameter_number) {
|
||||||
|
const HloInstruction* operand = instruction->operand(operand_number);
|
||||||
|
const HloInstruction* parameter =
|
||||||
|
computation->parameter_instruction(parameter_number);
|
||||||
|
if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) {
|
||||||
|
return InternalError("Operand %s shape does not match parameter's %s in %s",
|
||||||
|
operand->ToString().c_str(),
|
||||||
|
parameter->ToString().c_str(),
|
||||||
|
instruction->ToString().c_str());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
|
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
|
||||||
@ -253,8 +269,11 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
|
|||||||
Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); }
|
Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); }
|
||||||
|
|
||||||
Status ShapeVerifier::HandleCall(HloInstruction* call) {
|
Status ShapeVerifier::HandleCall(HloInstruction* call) {
|
||||||
|
for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
|
||||||
|
TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
|
||||||
|
}
|
||||||
// The shape of kCall should match the shape of the computation it calls.
|
// The shape of kCall should match the shape of the computation it calls.
|
||||||
return CheckShape(call, call->to_apply()->ComputeProgramShape().result());
|
return CheckShape(call, call->to_apply()->root_instruction()->shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); }
|
Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); }
|
||||||
@ -323,19 +342,37 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
|
Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
|
||||||
|
const Shape& conditional_shape =
|
||||||
|
xla_while->while_condition()->root_instruction()->shape();
|
||||||
|
if (!ShapeUtil::Compatible(conditional_shape,
|
||||||
|
ShapeUtil::MakeShape(PRED, {}))) {
|
||||||
|
return InternalError(
|
||||||
|
"Conditional computation shape does not lead to a scalar predicate "
|
||||||
|
"shape: %s",
|
||||||
|
ShapeUtil::HumanString(conditional_shape).c_str());
|
||||||
|
}
|
||||||
// The shape of kWhile should match the shape of the body computation it
|
// The shape of kWhile should match the shape of the body computation it
|
||||||
// calls.
|
// calls.
|
||||||
return CheckShape(xla_while,
|
return CheckShape(xla_while,
|
||||||
xla_while->while_body()->ComputeProgramShape().result());
|
xla_while->while_body()->root_instruction()->shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
|
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
|
||||||
|
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
|
||||||
|
conditional, 1, conditional->true_computation(), 0));
|
||||||
|
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
|
||||||
|
conditional, 2, conditional->false_computation(), 0));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CheckShape(conditional,
|
||||||
|
conditional->true_computation()->root_instruction()->shape()));
|
||||||
TF_RETURN_IF_ERROR(CheckShape(
|
TF_RETURN_IF_ERROR(CheckShape(
|
||||||
conditional,
|
conditional,
|
||||||
conditional->true_computation()->ComputeProgramShape().result()));
|
conditional->false_computation()->root_instruction()->shape()));
|
||||||
return CheckShape(
|
return Status::OK();
|
||||||
conditional,
|
|
||||||
conditional->false_computation()->ComputeProgramShape().result());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ShapeVerifier::HandlePad(HloInstruction* pad) {
|
Status ShapeVerifier::HandlePad(HloInstruction* pad) {
|
||||||
@ -802,33 +839,23 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
|
|||||||
"While loop must have exactly one operand; had %lld : %s",
|
"While loop must have exactly one operand; had %lld : %s",
|
||||||
instruction->operand_count(), instruction->ToString().c_str());
|
instruction->operand_count(), instruction->ToString().c_str());
|
||||||
}
|
}
|
||||||
auto* init = instruction->operand(0);
|
return Status::OK();
|
||||||
auto* cond_param = while_cond->parameter_instruction(0);
|
}
|
||||||
if (!ShapeUtil::Compatible(init->shape(), cond_param->shape())) {
|
|
||||||
|
Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) {
|
||||||
|
if (instruction->true_computation()->num_parameters() != 1) {
|
||||||
return FailedPrecondition(
|
return FailedPrecondition(
|
||||||
"While condition's parameter must have the same shape as the "
|
"True computation %s of %s must have 1 parameter insted of %lld",
|
||||||
"loop's 'init'. init: %s, param: %s",
|
instruction->true_computation()->name().c_str(),
|
||||||
init->ToString().c_str(), cond_param->ToString().c_str());
|
instruction->ToString().c_str(),
|
||||||
|
instruction->true_computation()->num_parameters());
|
||||||
}
|
}
|
||||||
auto* cond_root = while_cond->root_instruction();
|
if (instruction->false_computation()->num_parameters() != 1) {
|
||||||
if (!ShapeUtil::Compatible(cond_root->shape(),
|
|
||||||
ShapeUtil::MakeShape(PRED, {}))) {
|
|
||||||
return FailedPrecondition("While condition should have shape PRED: %s",
|
|
||||||
cond_root->ToString().c_str());
|
|
||||||
}
|
|
||||||
auto* body_param = while_body->parameter_instruction(0);
|
|
||||||
if (!ShapeUtil::Compatible(init->shape(), body_param->shape())) {
|
|
||||||
return FailedPrecondition(
|
return FailedPrecondition(
|
||||||
"While body's parameter must have the same shape as the loop's"
|
"False computation %s of %s must have 1 parameter insted of %lld",
|
||||||
" 'init'. init: %s, param: %s",
|
instruction->false_computation()->name().c_str(),
|
||||||
init->ToString().c_str(), body_param->ToString().c_str());
|
instruction->ToString().c_str(),
|
||||||
}
|
instruction->false_computation()->num_parameters());
|
||||||
auto* body_root = while_body->root_instruction();
|
|
||||||
if (!ShapeUtil::Compatible(init->shape(), body_root->shape())) {
|
|
||||||
return FailedPrecondition(
|
|
||||||
"While body should have same shape as the loop's 'init'."
|
|
||||||
"init: %s, body: %s",
|
|
||||||
init->ToString().c_str(), body_root->ToString().c_str());
|
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -924,6 +951,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
|
|||||||
<< " != " << ShapeUtil::Rank(instruction->operand(0)->shape());
|
<< " != " << ShapeUtil::Rank(instruction->operand(0)->shape());
|
||||||
} else if (instruction->opcode() == HloOpcode::kWhile) {
|
} else if (instruction->opcode() == HloOpcode::kWhile) {
|
||||||
TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction));
|
TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction));
|
||||||
|
} else if (instruction->opcode() == HloOpcode::kConditional) {
|
||||||
|
TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction));
|
||||||
} else if (instruction->opcode() !=
|
} else if (instruction->opcode() !=
|
||||||
HloOpcode::kRng /* Rng operands are always scalar. */
|
HloOpcode::kRng /* Rng operands are always scalar. */
|
||||||
&& instruction->IsElementwise()) {
|
&& instruction->IsElementwise()) {
|
||||||
|
@ -146,6 +146,8 @@ class HloVerifier : public HloPassInterface {
|
|||||||
|
|
||||||
Status CheckWhileInstruction(HloInstruction* instruction);
|
Status CheckWhileInstruction(HloInstruction* instruction);
|
||||||
|
|
||||||
|
Status CheckConditionalInstruction(HloInstruction* instruction);
|
||||||
|
|
||||||
// Checks that the non-scalar operand shapes are compatible to the output
|
// Checks that the non-scalar operand shapes are compatible to the output
|
||||||
// shape, i.e., that there are no implicit broadcasts of size-one dimensions.
|
// shape, i.e., that there are no implicit broadcasts of size-one dimensions.
|
||||||
Status CheckElementwiseInstruction(HloInstruction* instruction);
|
Status CheckElementwiseInstruction(HloInstruction* instruction);
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||||
@ -123,5 +124,55 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
|
|||||||
EXPECT_FALSE(verifier().Run(module.get()).status().ok());
|
EXPECT_FALSE(verifier().Run(module.get()).status().ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
|
||||||
|
const char* const hlo_string = R"(
|
||||||
|
HloModule Module
|
||||||
|
|
||||||
|
callme {
|
||||||
|
ROOT param = (s32[], f32[4]) parameter(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY entry {
|
||||||
|
p0 = (f32[4], s32[]) parameter(0)
|
||||||
|
ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
|
||||||
|
|
||||||
|
auto status = verifier().Run(module.get()).status();
|
||||||
|
ASSERT_FALSE(status.ok());
|
||||||
|
EXPECT_THAT(status.error_message(),
|
||||||
|
HasSubstr("shape does not match parameter"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
|
||||||
|
const char* const hlo_string = R"(
|
||||||
|
HloModule Module
|
||||||
|
|
||||||
|
true_branch {
|
||||||
|
tparam = (s32[], f32[4]) parameter(0)
|
||||||
|
ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
|
||||||
|
}
|
||||||
|
|
||||||
|
false_branch {
|
||||||
|
fparam = (s32[], f32[4]) parameter(0)
|
||||||
|
ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY entry {
|
||||||
|
p0 = (f32[4], s32[]) parameter(0)
|
||||||
|
constant = pred[] constant(true)
|
||||||
|
ROOT conditional = f32[4] conditional(constant, p0, p0),
|
||||||
|
true_computation=true_branch, false_computation=false_branch
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
|
||||||
|
|
||||||
|
auto status = verifier().Run(module.get()).status();
|
||||||
|
ASSERT_FALSE(status.ok());
|
||||||
|
EXPECT_THAT(status.error_message(),
|
||||||
|
HasSubstr("shape does not match parameter"));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user