[XLA] Stop the bf16 conversion folding from converting unused tuple outputs.
The BF16 conversion folder treats an unused tuple output as a candidate for conversion folding (even though it is not used by any convert). Stop it from doing that. Also, an AllReduce with constrain_layout() should not be optimized by BF16 conversion folding. Also add some extra AllReduce test cases.

PiperOrigin-RevId: 341466256
Change-Id: I3f3cf2fb2fb7bb6c301af4e50171f36ea9ddb56e
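For illustration, here is a small, self-contained C++ sketch of the folding-eligibility rule this change enforces. The types and names (ToyInstr, ToyAllReduce, OutputFoldable) are hypothetical stand-ins, not the XLA implementation: an all-reduce tuple output is a folding candidate only if the all-reduce has no side effects (e.g. it is not layout-constrained) and the output has at least one get-tuple-element user whose users are all F32-to-BF16 converts.

#include <vector>

// Toy stand-ins for HLO instructions and their use edges (hypothetical).
struct ToyInstr {
  bool is_f32_to_bf16_convert = false;   // models convert(f32 -> bf16)
  std::vector<const ToyInstr*> users;    // models HLO users
};

struct ToyAllReduce {
  bool has_side_effect = false;  // e.g. constrain_layout() or cross-module
  // get-tuple-element users of the all-reduce, grouped per tuple output index
  std::vector<std::vector<const ToyInstr*>> per_output_gtes;
};

// Returns true if tuple output `i` of `crs` may be folded to BF16.
bool OutputFoldable(const ToyAllReduce& crs, int i) {
  if (crs.has_side_effect) {
    return false;  // never rewrite a side-effecting all-reduce
  }
  const auto& gtes = crs.per_output_gtes[i];
  if (gtes.empty()) {
    return false;  // unused output: nothing consumes it as BF16
  }
  for (const ToyInstr* gte : gtes) {
    for (const ToyInstr* user : gte->users) {
      if (!user->is_f32_to_bf16_convert) {
        return false;  // some user is not an f32->bf16 convert
      }
    }
  }
  return true;
}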
@@ -188,8 +188,8 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
 }
 
 Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) {
-  if (crs->IsCrossModuleAllReduce()) {
-    // Cross-module all-reduce has side effect.
+  if (crs->HasSideEffectNoRecurse()) {
+    // Do not perform optimization on side-effected AllReduce.
     return Status::OK();
   }
   // First use DefaultAction() to handle the operands. It can't handle
@@ -226,6 +226,10 @@ Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) {
     // Fold conversions only when all the get-tuple-elements' users are
     // conversions from F32 to BF16.
     auto all_gte_users_are_bf16_convert = [&per_tuple_element_gtes, i]() {
+      // If no uses then return false. (As no uses are bf16 converts).
+      if (per_tuple_element_gtes[i].empty()) {
+        return false;
+      }
       for (auto gte : per_tuple_element_gtes[i]) {
         if (!AllUsersAreF32ToBF16Converts(gte)) {
           return false;
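A rough sketch of a unit test exercising the new behavior is shown below. The header paths, the BFloat16Support/BFloat16ConversionFolding usage, the test fixture name, and the HLO text are assumptions drawn from the surrounding XLA code, not part of this commit: an all-reduce tuple output with no users must not be rewritten to BF16 by the folder, so running the pass on such a module should report no change.

#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"

namespace xla {
namespace {

class Bf16FoldingUnusedOutputTest : public HloTestBase {};

TEST_F(Bf16FoldingUnusedOutputTest, UnusedTupleOutputIsNotFolded) {
  // Tuple output 1 of the all-reduce has no users and no converts, so the
  // folder should leave the module unchanged.
  const char* const hlo = R"(
HloModule test

%add {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %sum = f32[] add(f32[] %a, f32[] %b)
}

ENTRY %entry {
  %p0 = f32[4]{0} parameter(0)
  %p1 = f32[4]{0} parameter(1)
  %crs = (f32[4]{0}, f32[4]{0}) all-reduce(f32[4]{0} %p0, f32[4]{0} %p1), replica_groups={}, to_apply=%add
  ROOT %gte0 = f32[4]{0} get-tuple-element((f32[4]{0}, f32[4]{0}) %crs), index=0
})";
  auto module = ParseAndReturnVerifiedModule(hlo).ValueOrDie();
  // A real test would likely use a BFloat16Support subclass that reports BF16
  // as supported for all-reduce, as the existing folding tests do.
  BFloat16Support support;
  BFloat16ConversionFolding folder(&support);
  EXPECT_FALSE(folder.Run(module.get()).ValueOrDie());
}

}  // namespace
}  // namespace xla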
@@ -100,5 +100,88 @@ XLA_TEST_F(TrivialAllReduceTest, ConstantOperand) {
             ExecuteAndTransfer(std::move(module), {&literal0}));
 }
 
+XLA_TEST_F(TrivialAllReduceTest, AllReduceU8) {
+  const char* module_str = R"(
+HloModule test
+
+%AddComputation.15 {
+  %x.16 = u8[] parameter(0)
+  %y.17 = u8[] parameter(1)
+  ROOT %add.18 = u8[] add(u8[] %x.16, u8[] %y.17)
+}
+
+ENTRY %test_computation {
+  %constant.4 = u8[] constant(0), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=17}
+  %reshape.5 = u8[1]{0} reshape(u8[] %constant.4), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
+  %broadcast.6 = u8[1]{0} broadcast(u8[1]{0} %reshape.5), dimensions={0}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
+  %reshape.7 = u8[] reshape(u8[1]{0} %broadcast.6), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
+  %broadcast.8 = u8[8]{0} broadcast(u8[] %reshape.7), dimensions={}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
+  %constant.2 = u8[] constant(1), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=18}
+  %reshape.3 = u8[1]{0} reshape(u8[] %constant.2), metadata={op_type="aten::view" source_file="__format__@tensor.py" source_line=563}
+  %constant.9 = s64[] constant(0), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
+  %dynamic-update-slice.10 = u8[8]{0} dynamic-update-slice(u8[8]{0} %broadcast.8, u8[1]{0} %reshape.3, s64[] %constant.9), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
+  %p0.1 = f32[] parameter(0), metadata={op_type="xla::device_data" source_file="_get_all_reduce_token@xla_model.py" source_line=463}
+  %convert.11 = u8[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %tuple.12 = (u8[8]{0}, u8[]) tuple(u8[8]{0} %dynamic-update-slice.10, u8[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %get-tuple-element.13 = u8[8]{0} get-tuple-element((u8[8]{0}, u8[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %get-tuple-element.14 = u8[] get-tuple-element((u8[8]{0}, u8[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %all-reduce.19 = (u8[8]{0}, u8[]) all-reduce(u8[8]{0} %get-tuple-element.13, u8[] %get-tuple-element.14), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %get-tuple-element.21 = u8[] get-tuple-element((u8[8]{0}, u8[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %convert.22 = f32[] convert(u8[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %get-tuple-element.20 = u8[8]{0} get-tuple-element((u8[8]{0}, u8[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  ROOT %tuple.23 = (u8[8]{0}) tuple(u8[8]{0} %get-tuple-element.20)
+})";
+
+  auto module =
+      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
+          .ValueOrDie();
+  auto literal_in = LiteralUtil::CreateR0<float>(0);
+  auto literal0 = LiteralUtil::CreateR1<uint8_t>({1, 0, 0, 0, 0, 0, 0, 0});
+  EXPECT_EQ(LiteralUtil::MakeTuple({&literal0}),
+            ExecuteAndTransfer(std::move(module), {&literal_in}));
+}
+
+XLA_TEST_F(TrivialAllReduceTest, AllReduceS32) {
+  const char* module_str = R"(
+HloModule test
+
+%AddComputation.15 {
+  %x.16 = s32[] parameter(0)
+  %y.17 = s32[] parameter(1)
+  ROOT %add.18 = s32[] add(s32[] %x.16, s32[] %y.17)
+}
+
+ENTRY %test_computation {
+  %constant.4 = s32[] constant(0), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=17}
+  %reshape.5 = s32[1]{0} reshape(s32[] %constant.4), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
+  %broadcast.6 = s32[1]{0} broadcast(s32[1]{0} %reshape.5), dimensions={0}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
+  %reshape.7 = s32[] reshape(s32[1]{0} %broadcast.6), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
+  %broadcast.8 = s32[8]{0} broadcast(s32[] %reshape.7), dimensions={}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
+  %constant.2 = s32[] constant(1), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=18}
+  %reshape.3 = s32[1]{0} reshape(s32[] %constant.2), metadata={op_type="aten::view" source_file="__format__@tensor.py" source_line=563}
+  %constant.9 = s64[] constant(0), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
+  %dynamic-update-slice.10 = s32[8]{0} dynamic-update-slice(s32[8]{0} %broadcast.8, s32[1]{0} %reshape.3, s64[] %constant.9), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
+  %p0.1 = f32[] parameter(0), metadata={op_type="xla::device_data" source_file="_get_all_reduce_token@xla_model.py" source_line=463}
+  %convert.11 = s32[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %tuple.12 = (s32[8]{0}, s32[]) tuple(s32[8]{0} %dynamic-update-slice.10, s32[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %get-tuple-element.13 = s32[8]{0} get-tuple-element((s32[8]{0}, s32[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %get-tuple-element.14 = s32[] get-tuple-element((s32[8]{0}, s32[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %all-reduce.19 = (s32[8]{0}, s32[]) all-reduce(s32[8]{0} %get-tuple-element.13, s32[] %get-tuple-element.14), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %get-tuple-element.21 = s32[] get-tuple-element((s32[8]{0}, s32[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %convert.22 = f32[] convert(s32[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  %get-tuple-element.20 = s32[8]{0} get-tuple-element((s32[8]{0}, s32[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
+  ROOT %tuple.23 = (s32[8]{0}) tuple(s32[8]{0} %get-tuple-element.20)
+})";
+
+  auto module =
+      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
+          .ValueOrDie();
+  auto literal_in = LiteralUtil::CreateR0<float>(0);
+  auto literal0 = LiteralUtil::CreateR1<int32>({1, 0, 0, 0, 0, 0, 0, 0});
+  EXPECT_EQ(LiteralUtil::MakeTuple({&literal0}),
+            ExecuteAndTransfer(std::move(module), {&literal_in}));
+}
+
 }  // namespace
 }  // namespace xla