[XLA] Stop the bf16 conversion folding from converting unused tuple outputs.

The BF16 conversion folder thinks that and unused tuple output is a candidate for
conversion folding (even if its not used by any convert). Stop it from doing that.
Also constrain_layout() AllReduce shouldn't be optimized by bf16 conversion folding.

Also add some extra AllReduce test cases.

PiperOrigin-RevId: 341466256
Change-Id: I3f3cf2fb2fb7bb6c301af4e50171f36ea9ddb56e
This commit is contained in:
Marcello Maggioni 2020-11-09 12:52:38 -08:00 committed by TensorFlower Gardener
parent 5d3f9b173a
commit 6ef526e758
2 changed files with 89 additions and 2 deletions

View File

@ -188,8 +188,8 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
} }
Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) { Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) {
if (crs->IsCrossModuleAllReduce()) { if (crs->HasSideEffectNoRecurse()) {
// Cross-module all-reduce has side effect. // Do not perform optimization on side-effected AllReduce.
return Status::OK(); return Status::OK();
} }
// First use DefaultAction() to handle the operands. It can't handle // First use DefaultAction() to handle the operands. It can't handle
@ -226,6 +226,10 @@ Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) {
// Fold conversions only when all the get-tuple-elements' users are // Fold conversions only when all the get-tuple-elements' users are
// conversions from F32 to BF16. // conversions from F32 to BF16.
auto all_gte_users_are_bf16_convert = [&per_tuple_element_gtes, i]() { auto all_gte_users_are_bf16_convert = [&per_tuple_element_gtes, i]() {
// If no uses then return false. (As no uses are bf16 converts).
if (per_tuple_element_gtes[i].empty()) {
return false;
}
for (auto gte : per_tuple_element_gtes[i]) { for (auto gte : per_tuple_element_gtes[i]) {
if (!AllUsersAreF32ToBF16Converts(gte)) { if (!AllUsersAreF32ToBF16Converts(gte)) {
return false; return false;

View File

@ -100,5 +100,88 @@ XLA_TEST_F(TrivialAllReduceTest, ConstantOperand) {
ExecuteAndTransfer(std::move(module), {&literal0})); ExecuteAndTransfer(std::move(module), {&literal0}));
} }
XLA_TEST_F(TrivialAllReduceTest, AllReduceU8) {
const char* module_str = R"(
HloModule test
%AddComputation.15 {
%x.16 = u8[] parameter(0)
%y.17 = u8[] parameter(1)
ROOT %add.18 = u8[] add(u8[] %x.16, u8[] %y.17)
}
ENTRY %test_computation {
%constant.4 = u8[] constant(0), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=17}
%reshape.5 = u8[1]{0} reshape(u8[] %constant.4), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
%broadcast.6 = u8[1]{0} broadcast(u8[1]{0} %reshape.5), dimensions={0}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
%reshape.7 = u8[] reshape(u8[1]{0} %broadcast.6), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
%broadcast.8 = u8[8]{0} broadcast(u8[] %reshape.7), dimensions={}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
%constant.2 = u8[] constant(1), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=18}
%reshape.3 = u8[1]{0} reshape(u8[] %constant.2), metadata={op_type="aten::view" source_file="__format__@tensor.py" source_line=563}
%constant.9 = s64[] constant(0), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
%dynamic-update-slice.10 = u8[8]{0} dynamic-update-slice(u8[8]{0} %broadcast.8, u8[1]{0} %reshape.3, s64[] %constant.9), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
%p0.1 = f32[] parameter(0), metadata={op_type="xla::device_data" source_file="_get_all_reduce_token@xla_model.py" source_line=463}
%convert.11 = u8[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%tuple.12 = (u8[8]{0}, u8[]) tuple(u8[8]{0} %dynamic-update-slice.10, u8[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%get-tuple-element.13 = u8[8]{0} get-tuple-element((u8[8]{0}, u8[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%get-tuple-element.14 = u8[] get-tuple-element((u8[8]{0}, u8[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%all-reduce.19 = (u8[8]{0}, u8[]) all-reduce(u8[8]{0} %get-tuple-element.13, u8[] %get-tuple-element.14), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%get-tuple-element.21 = u8[] get-tuple-element((u8[8]{0}, u8[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%convert.22 = f32[] convert(u8[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%get-tuple-element.20 = u8[8]{0} get-tuple-element((u8[8]{0}, u8[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
ROOT %tuple.23 = (u8[8]{0}) tuple(u8[8]{0} %get-tuple-element.20)
})";
auto module =
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
.ValueOrDie();
auto literal_in = LiteralUtil::CreateR0<float>(0);
auto literal0 = LiteralUtil::CreateR1<uint8_t>({1, 0, 0, 0, 0, 0, 0, 0});
EXPECT_EQ(LiteralUtil::MakeTuple({&literal0}),
ExecuteAndTransfer(std::move(module), {&literal_in}));
}
XLA_TEST_F(TrivialAllReduceTest, AllReduceS32) {
const char* module_str = R"(
HloModule test
%AddComputation.15 {
%x.16 = s32[] parameter(0)
%y.17 = s32[] parameter(1)
ROOT %add.18 = s32[] add(s32[] %x.16, s32[] %y.17)
}
ENTRY %test_computation {
%constant.4 = s32[] constant(0), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=17}
%reshape.5 = s32[1]{0} reshape(s32[] %constant.4), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
%broadcast.6 = s32[1]{0} broadcast(s32[1]{0} %reshape.5), dimensions={0}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
%reshape.7 = s32[] reshape(s32[1]{0} %broadcast.6), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
%broadcast.8 = s32[8]{0} broadcast(s32[] %reshape.7), dimensions={}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
%constant.2 = s32[] constant(1), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=18}
%reshape.3 = s32[1]{0} reshape(s32[] %constant.2), metadata={op_type="aten::view" source_file="__format__@tensor.py" source_line=563}
%constant.9 = s64[] constant(0), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
%dynamic-update-slice.10 = s32[8]{0} dynamic-update-slice(s32[8]{0} %broadcast.8, s32[1]{0} %reshape.3, s64[] %constant.9), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
%p0.1 = f32[] parameter(0), metadata={op_type="xla::device_data" source_file="_get_all_reduce_token@xla_model.py" source_line=463}
%convert.11 = s32[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%tuple.12 = (s32[8]{0}, s32[]) tuple(s32[8]{0} %dynamic-update-slice.10, s32[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%get-tuple-element.13 = s32[8]{0} get-tuple-element((s32[8]{0}, s32[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%get-tuple-element.14 = s32[] get-tuple-element((s32[8]{0}, s32[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%all-reduce.19 = (s32[8]{0}, s32[]) all-reduce(s32[8]{0} %get-tuple-element.13, s32[] %get-tuple-element.14), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%get-tuple-element.21 = s32[] get-tuple-element((s32[8]{0}, s32[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%convert.22 = f32[] convert(s32[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
%get-tuple-element.20 = s32[8]{0} get-tuple-element((s32[8]{0}, s32[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
ROOT %tuple.23 = (s32[8]{0}) tuple(s32[8]{0} %get-tuple-element.20)
})";
auto module =
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
.ValueOrDie();
auto literal_in = LiteralUtil::CreateR0<float>(0);
auto literal0 = LiteralUtil::CreateR1<int32>({1, 0, 0, 0, 0, 0, 0, 0});
EXPECT_EQ(LiteralUtil::MakeTuple({&literal0}),
ExecuteAndTransfer(std::move(module), {&literal_in}));
}
} // namespace } // namespace
} // namespace xla } // namespace xla