[XLA] Bug fix for bfloat16 propagation where fixed-point repetition of resolving aliasing operands is needed.
PiperOrigin-RevId: 340723186 Change-Id: I9e8769bd4e35b7e41d2feea9a7036cc7e2c2d303
This commit is contained in:
parent
9c510da34b
commit
4e9997d049
@ -579,12 +579,15 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
|
||||
auto insts = computation->MakeInstructionPostOrder();
|
||||
// Do the adjustment on each instruction in the computation in reverse
|
||||
// topological order.
|
||||
while (true) {
|
||||
bool any_change = false;
|
||||
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
|
||||
auto hlo = *inst_it;
|
||||
auto adjust_hlo_output = [this, hlo, &parameter_changed](
|
||||
const Shape& /* subshape */,
|
||||
auto adjust_hlo_output = [&](const Shape& /* subshape */,
|
||||
const ShapeIndex& index) {
|
||||
auto output_type = OutputTypeAfterChange(hlo, index);
|
||||
VLOG(2) << "output_type is " << ((output_type == BF16) ? "BF16" : "F32")
|
||||
<< " for :" << hlo->ToString() << "\n";
|
||||
if (output_type != F32 && output_type != BF16) {
|
||||
return;
|
||||
}
|
||||
@ -594,6 +597,8 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
|
||||
if (value_type == BF16) {
|
||||
continue;
|
||||
}
|
||||
VLOG(2) << "Adjust to F32 due to aliased dataflow value: "
|
||||
<< value->ToString() << "\n";
|
||||
CHECK_EQ(value_type, F32);
|
||||
type = F32;
|
||||
break;
|
||||
@ -616,6 +621,8 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
|
||||
if (value_type == BF16) {
|
||||
continue;
|
||||
}
|
||||
VLOG(2) << "Adjust to F32 due to InputOutPair: "
|
||||
<< value->ToString() << "\n";
|
||||
CHECK_EQ(value_type, F32);
|
||||
type = F32;
|
||||
break;
|
||||
@ -627,10 +634,12 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
|
||||
// during this final adjustment pass, so we need to check
|
||||
// AllUsersConsumeBF16() again.
|
||||
if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
|
||||
VLOG(2) << "Adjust to F32 due to All user consumeBF16 fail\n";
|
||||
type = F32;
|
||||
}
|
||||
if (type == F32) {
|
||||
for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
|
||||
for (const auto* value :
|
||||
dataflow_->GetValueSet(hlo, index).values()) {
|
||||
// We rely on the fact that this adjustment works in reverse
|
||||
// topological order. Adding the value to
|
||||
// values_that_must_be_kept_as_f32_ will ensure the correctness
|
||||
@ -639,6 +648,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
|
||||
}
|
||||
}
|
||||
if (type != output_type) {
|
||||
any_change = true;
|
||||
AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
|
||||
VLOG(2) << "HloInstruction output at shape index " << index
|
||||
<< " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
|
||||
@ -651,17 +661,17 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
|
||||
ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
|
||||
AdjustCalledComputationRoot(hlo);
|
||||
if (hlo->opcode() == HloOpcode::kWhile) {
|
||||
// We need to run on the while body and condition repeatedly until a fixed
|
||||
// point is reached, i.e., the parameters do not change any more. We may
|
||||
// need more than one iteration because the while input and output alias
|
||||
// each other, so changing one input parameter requires changing the
|
||||
// corresponding output element and thus may transitively require changing
|
||||
// another input parameter. A fixed point will be reached because the
|
||||
// parameters can only be changed from BF16 to F32, not the other way
|
||||
// around.
|
||||
// We need to run on the while body and condition repeatedly until a
|
||||
// fixed point is reached, i.e., the parameters do not change any more.
|
||||
// We may need more than one iteration because the while input and
|
||||
// output alias each other, so changing one input parameter requires
|
||||
// changing the corresponding output element and thus may transitively
|
||||
// require changing another input parameter. A fixed point will be
|
||||
// reached because the parameters can only be changed from BF16 to F32,
|
||||
// not the other way around.
|
||||
absl::flat_hash_set<const HloComputation*> visited_in_while;
|
||||
while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(),
|
||||
&visited_in_while) ||
|
||||
while (ResolveInconsistencyOfAliasingBuffersHelper(
|
||||
hlo->while_condition(), &visited_in_while) ||
|
||||
ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
|
||||
&visited_in_while)) {
|
||||
visited_in_while.clear();
|
||||
@ -680,6 +690,10 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!any_change) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Now adjust parameters of called computations.
|
||||
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
|
||||
AdjustCalledComputationParameters(*inst_it);
|
||||
|
||||
@ -1182,4 +1182,87 @@ ENTRY main {
|
||||
EXPECT_FALSE(OutputsBF16(dus));
|
||||
}
|
||||
|
||||
// This test demonstrates the need for invoking the ResolveAliasingBuffer
|
||||
// multiple times via a fixed-point algorithm. The key was the aliasing of the
|
||||
// two output buffers of the conditional, at subshape 0 (first element). This
|
||||
// aliasing is not resolved until after the gte0 variable is already processed,
|
||||
// triggering incorrect type for gte0 if not repeating the aliasing analysis.
|
||||
TEST_F(BFloat16PropagationTest, ConditionalGTEWithFusion) {
|
||||
const string module_str = R"(
|
||||
HloModule module
|
||||
|
||||
%add.0 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
|
||||
x.1 = f32[4096,4096] parameter(0)
|
||||
y.1 = f32[4096,4096] parameter(1)
|
||||
ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
|
||||
lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
}
|
||||
|
||||
%add.1 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
|
||||
x.1 = f32[4096,4096] parameter(0)
|
||||
y.1 = f32[4096,4096] parameter(1)
|
||||
ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
|
||||
lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
}
|
||||
|
||||
%add.2 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
|
||||
x.1 = f32[4096,4096] parameter(0)
|
||||
y.1 = f32[4096,4096] parameter(1)
|
||||
ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
|
||||
lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
}
|
||||
|
||||
%add.3 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
|
||||
x.1 = f32[4096,4096] parameter(0)
|
||||
y.1 = f32[4096,4096] parameter(1)
|
||||
ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
|
||||
lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
}
|
||||
|
||||
true_branch {
|
||||
true_param = f32[4096,4096] parameter(0)
|
||||
constant.1 = f32[4096,4096] constant(0)
|
||||
add0 = f32[4096,4096] fusion(true_param,true_param), kind=kLoop, calls=add.0
|
||||
constant.2 = f32[4096,4096] constant(0)
|
||||
ROOT tuple.2 = (f32[4096,4096], f32[4096,4096], f32[]) tuple(true_param,add0,constant.2)
|
||||
}
|
||||
|
||||
false_branch {
|
||||
false_param = f32[4096,4096] parameter(0)
|
||||
add3 = f32[4096,4096] fusion(false_param,false_param), kind=kLoop, calls=add.1
|
||||
constant.1 = f32[4096,4096] constant(0)
|
||||
ROOT tuple.2 = (f32[4096,4096], f32[4096,4096], f32[]) tuple(add3, add3,constant.1)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
param0 = f32[4096,4096] parameter(0)
|
||||
copy0 = f32[4096,4096] copy(param0)
|
||||
param1 = pred[] parameter(1)
|
||||
conditional = (f32[4096,4096], f32[4096,4096], f32[4096,4096]) conditional(param1, param0, copy0),
|
||||
true_computation=true_branch, false_computation=false_branch
|
||||
gte = f32[4096,4096] get-tuple-element(conditional), index=0
|
||||
gte1 = f32[4096,4096] get-tuple-element(conditional), index=1
|
||||
gte2 = f32[4096,4096] get-tuple-element(conditional), index=2
|
||||
add2 = f32[4096,4096] fusion(gte, gte1), kind=kLoop, calls=add.2
|
||||
ROOT add3 = f32[4096,4096] fusion(add2, gte2), kind=kLoop, calls=add.3
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
EXPECT_TRUE(PropagatePrecision(module.get()));
|
||||
VLOG(2) << module->ToString() << "\n";
|
||||
EXPECT_TRUE(HloVerifier(/*layout_sensitive=*/false,
|
||||
/*allow_mixed_precision=*/true)
|
||||
.Run(module.get())
|
||||
.status()
|
||||
.ok());
|
||||
auto gte = FindInstruction(module.get(), "gte");
|
||||
auto gte1 = FindInstruction(module.get(), "gte1");
|
||||
auto gte2 = FindInstruction(module.get(), "gte2");
|
||||
EXPECT_FALSE(OutputsBF16(gte));
|
||||
EXPECT_FALSE(OutputsBF16(gte1));
|
||||
EXPECT_TRUE(OutputsBF16(gte2));
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user