diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index 9b0f2b2a0f4..812db465c6d 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -127,6 +127,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops where it does not make sense to convert them. if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert || + opcode == HloOpcode::kBitcastConvert || opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index 5b633784e2f..4171f738620 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -176,5 +176,19 @@ ENTRY main { EXPECT_THAT(rng1->control_predecessors(), ElementsAre(rng0)); } +TEST_F(HloElementTypeConverterTest, BitcastConvertIsUnmodified) { + const string& hlo_string = R"( + HloModule test + + ENTRY test { + p = bf16[] parameter(0) + ROOT c = u16[] bitcast-convert(p) + })"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + HloElementTypeConverter converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, RunHloPass(&converter, module.get())); + EXPECT_FALSE(converted); +} + } // namespace } // namespace xla