[XLA] Fix dot/conv in ShapeVerifier for !allow_mixed_precision

It is not a verifier error for conv/dot to have differing operands due to features like preferred element type.

PiperOrigin-RevId: 356609105
Change-Id: I0e66cb86c96cb5bc8be56af21467d3638fd39047
This commit is contained in:
David Majnemer 2021-02-09 15:42:29 -08:00 committed by TensorFlower Gardener
parent 26ed22700d
commit b060f14d2f
3 changed files with 20 additions and 3 deletions

View File

@ -601,8 +601,8 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
/*only_compare_minor_to_major_in_layout=*/true);
}
Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
auto* iota = Cast<HloIotaInstruction>(instruction);
Status ShapeVerifier::HandleIota(HloInstruction* hlo) {
auto* iota = Cast<HloIotaInstruction>(hlo);
if (!iota->shape().IsArray()) {
return InternalError("Iota does not support non-array result.");
}
@ -1077,6 +1077,8 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kConstant:
case HloOpcode::kConvolution:
case HloOpcode::kDot:
case HloOpcode::kAllReduce:
case HloOpcode::kCopyDone:
case HloOpcode::kCopyStart:

View File

@ -47,7 +47,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleSelect(HloInstruction* select) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleConcatenate(HloInstruction* concatenate) override;
Status HandleIota(HloInstruction* iota) override;
Status HandleIota(HloInstruction* hlo) override;
Status HandleConvert(HloInstruction* convert) override;
Status HandleBitcastConvert(HloInstruction* convert) override;
Status HandleCopy(HloInstruction* copy) override;

View File

@ -394,6 +394,21 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
HasSubstr("Interior padding cannot be negative"));
}
TEST_F(HloVerifierTest, DotMixedPrecisionAllowed) {
static const char* const kDotHloString = R"(
HloModule module
ENTRY entry_computation {
a = f32[2,10] parameter(0)
b = bf16[10,2] parameter(1)
ROOT dot = f32[2,2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kDotHloString));
auto status = verifier().Run(module.get()).status();
EXPECT_TRUE(status.ok()) << status;
}
// Simple module containing a convolution as the root.
static const char* const kConvHloString = R"(
HloModule module