diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 4303b0f5763..29ad2943f2a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -188,8 +188,8 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { } Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) { - if (crs->IsCrossModuleAllReduce()) { - // Cross-module all-reduce has side effect. + if (crs->HasSideEffectNoRecurse()) { + // Do not perform optimization on side-effected AllReduce. return Status::OK(); } // First use DefaultAction() to handle the operands. It can't handle @@ -226,6 +226,10 @@ Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) { // Fold conversions only when all the get-tuple-elements' users are // conversions from F32 to BF16. auto all_gte_users_are_bf16_convert = [&per_tuple_element_gtes, i]() { + // If no uses then return false. (As no uses are bf16 converts). + if (per_tuple_element_gtes[i].empty()) { + return false; + } for (auto gte : per_tuple_element_gtes[i]) { if (!AllUsersAreF32ToBF16Converts(gte)) { return false; diff --git a/tensorflow/compiler/xla/tests/all_reduce_test.cc b/tensorflow/compiler/xla/tests/all_reduce_test.cc index 33a8db8de32..9f1f0030860 100644 --- a/tensorflow/compiler/xla/tests/all_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/all_reduce_test.cc @@ -100,5 +100,88 @@ XLA_TEST_F(TrivialAllReduceTest, ConstantOperand) { ExecuteAndTransfer(std::move(module), {&literal0})); } +XLA_TEST_F(TrivialAllReduceTest, AllReduceU8) { + const char* module_str = R"( +HloModule test + +%AddComputation.15 { + %x.16 = u8[] parameter(0) + %y.17 = u8[] parameter(1) + ROOT %add.18 = u8[] add(u8[] %x.16, u8[] %y.17) +} + +ENTRY %test_computation { + %constant.4 = u8[] constant(0), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=17} + %reshape.5 = u8[1]{0} reshape(u8[] %constant.4), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17} + %broadcast.6 = u8[1]{0} broadcast(u8[1]{0} %reshape.5), dimensions={0}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17} + %reshape.7 = u8[] reshape(u8[1]{0} %broadcast.6), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17} + %broadcast.8 = u8[8]{0} broadcast(u8[] %reshape.7), dimensions={}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17} + %constant.2 = u8[] constant(1), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=18} + %reshape.3 = u8[1]{0} reshape(u8[] %constant.2), metadata={op_type="aten::view" source_file="__format__@tensor.py" source_line=563} + %constant.9 = s64[] constant(0), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563} + %dynamic-update-slice.10 = u8[8]{0} dynamic-update-slice(u8[8]{0} %broadcast.8, u8[1]{0} %reshape.3, s64[] %constant.9), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563} + %p0.1 = f32[] parameter(0), metadata={op_type="xla::device_data" source_file="_get_all_reduce_token@xla_model.py" source_line=463} + %convert.11 = u8[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %tuple.12 = (u8[8]{0}, u8[]) tuple(u8[8]{0} %dynamic-update-slice.10, u8[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %get-tuple-element.13 = u8[8]{0} get-tuple-element((u8[8]{0}, u8[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %get-tuple-element.14 = u8[] get-tuple-element((u8[8]{0}, u8[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %all-reduce.19 = (u8[8]{0}, u8[]) all-reduce(u8[8]{0} %get-tuple-element.13, u8[] %get-tuple-element.14), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %get-tuple-element.21 = u8[] get-tuple-element((u8[8]{0}, u8[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %convert.22 = f32[] convert(u8[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %get-tuple-element.20 = u8[8]{0} get-tuple-element((u8[8]{0}, u8[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + ROOT %tuple.23 = (u8[8]{0}) tuple(u8[8]{0} %get-tuple-element.20) +})"; + + auto module = + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()) + .ValueOrDie(); + auto literal_in = LiteralUtil::CreateR0(0); + auto literal0 = LiteralUtil::CreateR1({1, 0, 0, 0, 0, 0, 0, 0}); + EXPECT_EQ(LiteralUtil::MakeTuple({&literal0}), + ExecuteAndTransfer(std::move(module), {&literal_in})); +} + +XLA_TEST_F(TrivialAllReduceTest, AllReduceS32) { + const char* module_str = R"( + +HloModule test + +%AddComputation.15 { + %x.16 = s32[] parameter(0) + %y.17 = s32[] parameter(1) + ROOT %add.18 = s32[] add(s32[] %x.16, s32[] %y.17) +} + +ENTRY %test_computation { + %constant.4 = s32[] constant(0), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=17} + %reshape.5 = s32[1]{0} reshape(s32[] %constant.4), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17} + %broadcast.6 = s32[1]{0} broadcast(s32[1]{0} %reshape.5), dimensions={0}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17} + %reshape.7 = s32[] reshape(s32[1]{0} %broadcast.6), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17} + %broadcast.8 = s32[8]{0} broadcast(s32[] %reshape.7), dimensions={}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17} + %constant.2 = s32[] constant(1), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=18} + %reshape.3 = s32[1]{0} reshape(s32[] %constant.2), metadata={op_type="aten::view" source_file="__format__@tensor.py" source_line=563} + %constant.9 = s64[] constant(0), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563} + %dynamic-update-slice.10 = s32[8]{0} dynamic-update-slice(s32[8]{0} %broadcast.8, s32[1]{0} %reshape.3, s64[] %constant.9), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563} + %p0.1 = f32[] parameter(0), metadata={op_type="xla::device_data" source_file="_get_all_reduce_token@xla_model.py" source_line=463} + %convert.11 = s32[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %tuple.12 = (s32[8]{0}, s32[]) tuple(s32[8]{0} %dynamic-update-slice.10, s32[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %get-tuple-element.13 = s32[8]{0} get-tuple-element((s32[8]{0}, s32[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %get-tuple-element.14 = s32[] get-tuple-element((s32[8]{0}, s32[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %all-reduce.19 = (s32[8]{0}, s32[]) all-reduce(s32[8]{0} %get-tuple-element.13, s32[] %get-tuple-element.14), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %get-tuple-element.21 = s32[] get-tuple-element((s32[8]{0}, s32[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %convert.22 = f32[] convert(s32[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + %get-tuple-element.20 = s32[8]{0} get-tuple-element((s32[8]{0}, s32[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560} + ROOT %tuple.23 = (s32[8]{0}) tuple(s32[8]{0} %get-tuple-element.20) +})"; + + auto module = + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()) + .ValueOrDie(); + auto literal_in = LiteralUtil::CreateR0(0); + auto literal0 = LiteralUtil::CreateR1({1, 0, 0, 0, 0, 0, 0, 0}); + EXPECT_EQ(LiteralUtil::MakeTuple({&literal0}), + ExecuteAndTransfer(std::move(module), {&literal_in})); +} + } // namespace } // namespace xla