[XLA] Make mixed-precision DUS a verifier failure even in mixed precision mode.

PiperOrigin-RevId: 332915423
Change-Id: I1d611890022b80b24f73b5fcb09af9cd4f5feb9d
This commit is contained in:
Berkin Ilbeyi 2020-09-21 13:04:15 -07:00 committed by TensorFlower Gardener
parent 33f47ab16d
commit aa53e46dc6
2 changed files with 23 additions and 0 deletions

View File

@ -1160,6 +1160,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
case HloOpcode::kCopyDone:
case HloOpcode::kCopyStart:
case HloOpcode::kCustomCall:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kGetTupleElement:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:

View File

@ -494,6 +494,28 @@ TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) {
ASSERT_TRUE(status.ok());
}
TEST_F(HloVerifierTestAllowMixedPrecision, DynamicUpdateSliceMixedPrecision) {
const char* const kDynamicUpdateSliceMixedPrecision = R"(
HloModule kDynamicUpdateSliceMixedPrecision
ENTRY %entry (parameter.0: f32[32,511,2048], parameter.1: bf16[32,511,512], parameter.2: s32[], parameter.3: s32[], parameter.4: s32[]) -> bf16[32,511,2048] {
%parameter.0 = f32[32,511,2048] parameter(0)
%parameter.1 = bf16[32,511,512] parameter(1)
%parameter.2 = s32[] parameter(2)
%parameter.3 = s32[] parameter(3)
%parameter.4 = s32[] parameter(4)
ROOT %dus = bf16[32,511,2048] dynamic-update-slice(f32[32,511,2048] %parameter.0, bf16[32,511,512] %parameter.1, s32[] %parameter.2, s32[] %parameter.3, s32[] %parameter.4)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(
kDynamicUpdateSliceMixedPrecision));
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("Expected instruction to have shape equal to "
"f32[32,511,2048], actual shape is bf16[32,511,2048]"));
}
TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo));