From aa53e46dc6053dfe38b1e4ce415d931601f96c4b Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Mon, 21 Sep 2020 13:04:15 -0700 Subject: [PATCH] [XLA] Make mixed-precision DUS a verifier failure even in mixed precision mode. PiperOrigin-RevId: 332915423 Change-Id: I1d611890022b80b24f73b5fcb09af9cd4f5feb9d --- .../compiler/xla/service/hlo_verifier.cc | 1 + .../compiler/xla/service/hlo_verifier_test.cc | 22 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 0346e9077a0..b3603e4fe5c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -1160,6 +1160,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, case HloOpcode::kCopyDone: case HloOpcode::kCopyStart: case HloOpcode::kCustomCall: + case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kGetTupleElement: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 1f71c9586d5..0df30166a1c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -494,6 +494,28 @@ TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) { ASSERT_TRUE(status.ok()); } +TEST_F(HloVerifierTestAllowMixedPrecision, DynamicUpdateSliceMixedPrecision) { + const char* const kDynamicUpdateSliceMixedPrecision = R"( + HloModule kDynamicUpdateSliceMixedPrecision + + ENTRY %entry (parameter.0: f32[32,511,2048], parameter.1: bf16[32,511,512], parameter.2: s32[], parameter.3: s32[], parameter.4: s32[]) -> bf16[32,511,2048] { + %parameter.0 = f32[32,511,2048] parameter(0) + %parameter.1 = bf16[32,511,512] parameter(1) + %parameter.2 = s32[] parameter(2) + %parameter.3 = s32[] parameter(3) + %parameter.4 = s32[] parameter(4) + ROOT %dus = bf16[32,511,2048] dynamic-update-slice(f32[32,511,2048] %parameter.0, bf16[32,511,512] %parameter.1, s32[] %parameter.2, s32[] %parameter.3, s32[] %parameter.4) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule( + kDynamicUpdateSliceMixedPrecision)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected instruction to have shape equal to " + "f32[32,511,2048], actual shape is bf16[32,511,2048]")); +} + TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) { TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo));