[XLA] Fix dot/conv in ShapeVerifier for !allow_mixed_precision
It is not a verifier error for conv/dot to have differing operands due to features like preferred element type. PiperOrigin-RevId: 356609105 Change-Id: I0e66cb86c96cb5bc8be56af21467d3638fd39047
This commit is contained in:
parent
26ed22700d
commit
b060f14d2f
tensorflow/compiler/xla/service
@ -601,8 +601,8 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
|
||||
/*only_compare_minor_to_major_in_layout=*/true);
|
||||
}
|
||||
|
||||
Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
|
||||
auto* iota = Cast<HloIotaInstruction>(instruction);
|
||||
Status ShapeVerifier::HandleIota(HloInstruction* hlo) {
|
||||
auto* iota = Cast<HloIotaInstruction>(hlo);
|
||||
if (!iota->shape().IsArray()) {
|
||||
return InternalError("Iota does not support non-array result.");
|
||||
}
|
||||
@ -1077,6 +1077,8 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
|
||||
case HloOpcode::kCall:
|
||||
case HloOpcode::kConditional:
|
||||
case HloOpcode::kConstant:
|
||||
case HloOpcode::kConvolution:
|
||||
case HloOpcode::kDot:
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kCopyDone:
|
||||
case HloOpcode::kCopyStart:
|
||||
|
@ -47,7 +47,7 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
Status HandleSelect(HloInstruction* select) override;
|
||||
Status HandleTupleSelect(HloInstruction* tuple_select) override;
|
||||
Status HandleConcatenate(HloInstruction* concatenate) override;
|
||||
Status HandleIota(HloInstruction* iota) override;
|
||||
Status HandleIota(HloInstruction* hlo) override;
|
||||
Status HandleConvert(HloInstruction* convert) override;
|
||||
Status HandleBitcastConvert(HloInstruction* convert) override;
|
||||
Status HandleCopy(HloInstruction* copy) override;
|
||||
|
@ -394,6 +394,21 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
|
||||
HasSubstr("Interior padding cannot be negative"));
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTest, DotMixedPrecisionAllowed) {
|
||||
static const char* const kDotHloString = R"(
|
||||
HloModule module
|
||||
ENTRY entry_computation {
|
||||
a = f32[2,10] parameter(0)
|
||||
b = bf16[10,2] parameter(1)
|
||||
ROOT dot = f32[2,2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(kDotHloString));
|
||||
|
||||
auto status = verifier().Run(module.get()).status();
|
||||
EXPECT_TRUE(status.ok()) << status;
|
||||
}
|
||||
|
||||
// Simple module containing a convolution as the root.
|
||||
static const char* const kConvHloString = R"(
|
||||
HloModule module
|
||||
|
Loading…
Reference in New Issue
Block a user