[XLA] Stop the bf16 conversion folding from converting unused tuple outputs.
The BF16 conversion folder thinks that and unused tuple output is a candidate for conversion folding (even if its not used by any convert). Stop it from doing that. Also constrain_layout() AllReduce shouldn't be optimized by bf16 conversion folding. Also add some extra AllReduce test cases. PiperOrigin-RevId: 341466256 Change-Id: I3f3cf2fb2fb7bb6c301af4e50171f36ea9ddb56e
This commit is contained in:
parent
5d3f9b173a
commit
6ef526e758
@ -188,8 +188,8 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) {
|
Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) {
|
||||||
if (crs->IsCrossModuleAllReduce()) {
|
if (crs->HasSideEffectNoRecurse()) {
|
||||||
// Cross-module all-reduce has side effect.
|
// Do not perform optimization on side-effected AllReduce.
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
// First use DefaultAction() to handle the operands. It can't handle
|
// First use DefaultAction() to handle the operands. It can't handle
|
||||||
@ -226,6 +226,10 @@ Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) {
|
|||||||
// Fold conversions only when all the get-tuple-elements' users are
|
// Fold conversions only when all the get-tuple-elements' users are
|
||||||
// conversions from F32 to BF16.
|
// conversions from F32 to BF16.
|
||||||
auto all_gte_users_are_bf16_convert = [&per_tuple_element_gtes, i]() {
|
auto all_gte_users_are_bf16_convert = [&per_tuple_element_gtes, i]() {
|
||||||
|
// If no uses then return false. (As no uses are bf16 converts).
|
||||||
|
if (per_tuple_element_gtes[i].empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
for (auto gte : per_tuple_element_gtes[i]) {
|
for (auto gte : per_tuple_element_gtes[i]) {
|
||||||
if (!AllUsersAreF32ToBF16Converts(gte)) {
|
if (!AllUsersAreF32ToBF16Converts(gte)) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -100,5 +100,88 @@ XLA_TEST_F(TrivialAllReduceTest, ConstantOperand) {
|
|||||||
ExecuteAndTransfer(std::move(module), {&literal0}));
|
ExecuteAndTransfer(std::move(module), {&literal0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(TrivialAllReduceTest, AllReduceU8) {
|
||||||
|
const char* module_str = R"(
|
||||||
|
HloModule test
|
||||||
|
|
||||||
|
%AddComputation.15 {
|
||||||
|
%x.16 = u8[] parameter(0)
|
||||||
|
%y.17 = u8[] parameter(1)
|
||||||
|
ROOT %add.18 = u8[] add(u8[] %x.16, u8[] %y.17)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY %test_computation {
|
||||||
|
%constant.4 = u8[] constant(0), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=17}
|
||||||
|
%reshape.5 = u8[1]{0} reshape(u8[] %constant.4), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
|
||||||
|
%broadcast.6 = u8[1]{0} broadcast(u8[1]{0} %reshape.5), dimensions={0}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
|
||||||
|
%reshape.7 = u8[] reshape(u8[1]{0} %broadcast.6), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
|
||||||
|
%broadcast.8 = u8[8]{0} broadcast(u8[] %reshape.7), dimensions={}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
|
||||||
|
%constant.2 = u8[] constant(1), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=18}
|
||||||
|
%reshape.3 = u8[1]{0} reshape(u8[] %constant.2), metadata={op_type="aten::view" source_file="__format__@tensor.py" source_line=563}
|
||||||
|
%constant.9 = s64[] constant(0), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
|
||||||
|
%dynamic-update-slice.10 = u8[8]{0} dynamic-update-slice(u8[8]{0} %broadcast.8, u8[1]{0} %reshape.3, s64[] %constant.9), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
|
||||||
|
%p0.1 = f32[] parameter(0), metadata={op_type="xla::device_data" source_file="_get_all_reduce_token@xla_model.py" source_line=463}
|
||||||
|
%convert.11 = u8[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%tuple.12 = (u8[8]{0}, u8[]) tuple(u8[8]{0} %dynamic-update-slice.10, u8[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%get-tuple-element.13 = u8[8]{0} get-tuple-element((u8[8]{0}, u8[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%get-tuple-element.14 = u8[] get-tuple-element((u8[8]{0}, u8[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%all-reduce.19 = (u8[8]{0}, u8[]) all-reduce(u8[8]{0} %get-tuple-element.13, u8[] %get-tuple-element.14), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%get-tuple-element.21 = u8[] get-tuple-element((u8[8]{0}, u8[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%convert.22 = f32[] convert(u8[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%get-tuple-element.20 = u8[8]{0} get-tuple-element((u8[8]{0}, u8[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
ROOT %tuple.23 = (u8[8]{0}) tuple(u8[8]{0} %get-tuple-element.20)
|
||||||
|
})";
|
||||||
|
|
||||||
|
auto module =
|
||||||
|
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
|
||||||
|
.ValueOrDie();
|
||||||
|
auto literal_in = LiteralUtil::CreateR0<float>(0);
|
||||||
|
auto literal0 = LiteralUtil::CreateR1<uint8_t>({1, 0, 0, 0, 0, 0, 0, 0});
|
||||||
|
EXPECT_EQ(LiteralUtil::MakeTuple({&literal0}),
|
||||||
|
ExecuteAndTransfer(std::move(module), {&literal_in}));
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(TrivialAllReduceTest, AllReduceS32) {
|
||||||
|
const char* module_str = R"(
|
||||||
|
|
||||||
|
HloModule test
|
||||||
|
|
||||||
|
%AddComputation.15 {
|
||||||
|
%x.16 = s32[] parameter(0)
|
||||||
|
%y.17 = s32[] parameter(1)
|
||||||
|
ROOT %add.18 = s32[] add(s32[] %x.16, s32[] %y.17)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY %test_computation {
|
||||||
|
%constant.4 = s32[] constant(0), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=17}
|
||||||
|
%reshape.5 = s32[1]{0} reshape(s32[] %constant.4), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
|
||||||
|
%broadcast.6 = s32[1]{0} broadcast(s32[1]{0} %reshape.5), dimensions={0}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
|
||||||
|
%reshape.7 = s32[] reshape(s32[1]{0} %broadcast.6), metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
|
||||||
|
%broadcast.8 = s32[8]{0} broadcast(s32[] %reshape.7), dimensions={}, metadata={op_type="aten::expand" source_file="main@test_all_reduce_int.py" source_line=17}
|
||||||
|
%constant.2 = s32[] constant(1), metadata={op_type="prim::Constant" source_file="main@test_all_reduce_int.py" source_line=18}
|
||||||
|
%reshape.3 = s32[1]{0} reshape(s32[] %constant.2), metadata={op_type="aten::view" source_file="__format__@tensor.py" source_line=563}
|
||||||
|
%constant.9 = s64[] constant(0), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
|
||||||
|
%dynamic-update-slice.10 = s32[8]{0} dynamic-update-slice(s32[8]{0} %broadcast.8, s32[1]{0} %reshape.3, s64[] %constant.9), metadata={op_type="xla::update_slice" source_file="__format__@tensor.py" source_line=563}
|
||||||
|
%p0.1 = f32[] parameter(0), metadata={op_type="xla::device_data" source_file="_get_all_reduce_token@xla_model.py" source_line=463}
|
||||||
|
%convert.11 = s32[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%tuple.12 = (s32[8]{0}, s32[]) tuple(s32[8]{0} %dynamic-update-slice.10, s32[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%get-tuple-element.13 = s32[8]{0} get-tuple-element((s32[8]{0}, s32[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%get-tuple-element.14 = s32[] get-tuple-element((s32[8]{0}, s32[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%all-reduce.19 = (s32[8]{0}, s32[]) all-reduce(s32[8]{0} %get-tuple-element.13, s32[] %get-tuple-element.14), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%get-tuple-element.21 = s32[] get-tuple-element((s32[8]{0}, s32[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%convert.22 = f32[] convert(s32[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
%get-tuple-element.20 = s32[8]{0} get-tuple-element((s32[8]{0}, s32[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file="all_reduce@xla_model.py" source_line=560}
|
||||||
|
ROOT %tuple.23 = (s32[8]{0}) tuple(s32[8]{0} %get-tuple-element.20)
|
||||||
|
})";
|
||||||
|
|
||||||
|
auto module =
|
||||||
|
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
|
||||||
|
.ValueOrDie();
|
||||||
|
auto literal_in = LiteralUtil::CreateR0<float>(0);
|
||||||
|
auto literal0 = LiteralUtil::CreateR1<int32>({1, 0, 0, 0, 0, 0, 0, 0});
|
||||||
|
EXPECT_EQ(LiteralUtil::MakeTuple({&literal0}),
|
||||||
|
ExecuteAndTransfer(std::move(module), {&literal_in}));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user