From f77ec5d5ee3eba39f79fb768d9d5b3e3788b7edb Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Tue, 5 Feb 2019 14:05:59 -0800 Subject: [PATCH] [XLA] Fix HloElementTypeConverter's handling of bitcast convert. Previously, if you asked HloElementTypeConverter to e.g. convert BF16 to F32, it would convert a BF16 input to bitcast-convert to F32. This don't work, because the output of the bitcast-convert still has width 16 bits, and the input/output widths must match. The right thing is to ignore bitcast-convert operations for the purposes of this pass. PiperOrigin-RevId: 232554328 --- .../xla/service/hlo_element_type_converter.cc | 1 + .../xla/service/hlo_element_type_converter_test.cc | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index 9b0f2b2a0f4..812db465c6d 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -127,6 +127,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops where it does not make sense to convert them. if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert || + opcode == HloOpcode::kBitcastConvert || opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index 5b633784e2f..4171f738620 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -176,5 +176,19 @@ ENTRY main { EXPECT_THAT(rng1->control_predecessors(), ElementsAre(rng0)); } +TEST_F(HloElementTypeConverterTest, BitcastConvertIsUnmodified) { + const string& hlo_string = R"( + HloModule test + + ENTRY test { + p = bf16[] parameter(0) + ROOT c = u16[] bitcast-convert(p) + })"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + HloElementTypeConverter converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, RunHloPass(&converter, module.get())); + EXPECT_FALSE(converted); +} + } // namespace } // namespace xla