[XLA] Bug fix for bfloat16 propagation where fixed-point repetition is needed to resolve aliasing operands.

PiperOrigin-RevId: 340723186
Change-Id: I9e8769bd4e35b7e41d2feea9a7036cc7e2c2d303
A. Unique TensorFlower 2020-11-04 14:02:18 -08:00 committed by TensorFlower Gardener
parent 9c510da34b
commit 4e9997d049
2 changed files with 189 additions and 92 deletions

tensorflow/compiler/xla/service/bfloat16_propagation.cc

@@ -579,105 +579,119 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
   auto insts = computation->MakeInstructionPostOrder();
   // Do the adjustment on each instruction in the computation in reverse
   // topological order.
-  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
-    auto hlo = *inst_it;
-    auto adjust_hlo_output = [this, hlo, &parameter_changed](
-                                 const Shape& /* subshape */,
-                                 const ShapeIndex& index) {
-      auto output_type = OutputTypeAfterChange(hlo, index);
-      if (output_type != F32 && output_type != BF16) {
-        return;
-      }
-      PrimitiveType type = BF16;
-      for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
-        auto value_type = ValueTypeAfterChange(value);
-        if (value_type == BF16) {
-          continue;
-        }
-        CHECK_EQ(value_type, F32);
-        type = F32;
-        break;
-      }
-      // In order to find aliases due to in-place operations, use
-      // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here,
-      // but this code works with HloModules that aren't ready yet to use
-      // HloAliasAnalysis (e.g., their computation graphs may not have been
-      // flattened yet).
-      for (const auto& operand_and_output_index :
-           HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) {
-        if (operand_and_output_index.second == index) {
-          const HloUse& operand = operand_and_output_index.first;
-          for (const auto* value :
-               dataflow_
-                   ->GetValueSet(hlo->operand(operand.operand_number),
-                                 operand.operand_index)
-                   .values()) {
-            auto value_type = ValueTypeAfterChange(value);
-            if (value_type == BF16) {
-              continue;
-            }
-            CHECK_EQ(value_type, F32);
-            type = F32;
-            break;
-          }
-        }
-      }
-      // It's possible that a user has been changed from BF16 to F32
-      // during this final adjustment pass, so we need to check
-      // AllUsersConsumeBF16() again.
-      if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
-        type = F32;
-      }
-      if (type == F32) {
-        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
-          // We rely on the fact that this adjustment works in reverse
-          // topological order. Adding the value to
-          // values_that_must_be_kept_as_f32_ will ensure the correctness
-          // of the adjustment for HLOs that will be processed later.
-          values_that_must_be_kept_as_f32_.insert(value);
-        }
-      }
-      if (type != output_type) {
-        AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
-        VLOG(2) << "HloInstruction output at shape index " << index
-                << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
-                << hlo->ToString();
-        if (hlo->opcode() == HloOpcode::kParameter) {
-          parameter_changed = true;
-        }
-      }
-    };
-    ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
-    AdjustCalledComputationRoot(hlo);
-    if (hlo->opcode() == HloOpcode::kWhile) {
-      // We need to run on the while body and condition repeatedly until a fixed
-      // point is reached, i.e., the parameters do not change any more. We may
-      // need more than one iteration because the while input and output alias
-      // each other, so changing one input parameter requires changing the
-      // corresponding output element and thus may transitively require changing
-      // another input parameter. A fixed point will be reached because the
-      // parameters can only be changed from BF16 to F32, not the other way
-      // around.
-      absl::flat_hash_set<const HloComputation*> visited_in_while;
-      while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(),
-                                                         &visited_in_while) ||
-             ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
-                                                         &visited_in_while)) {
-        visited_in_while.clear();
-        ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
-        AdjustCalledComputationRoot(hlo);
-      }
-      visited_computations->insert(visited_in_while.begin(),
-                                   visited_in_while.end());
-    } else if (hlo->opcode() == HloOpcode::kFusion) {
-      ResolveInconsistencyOfAliasingBuffersHelper(
-          hlo->fused_instructions_computation(), visited_computations);
-    } else if (hlo->opcode() == HloOpcode::kConditional) {
-      for (auto* branch : hlo->branch_computations()) {
-        ResolveInconsistencyOfAliasingBuffersHelper(branch,
-                                                    visited_computations);
-      }
-    }
-  }
+  while (true) {
+    bool any_change = false;
+    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
+      auto hlo = *inst_it;
+      auto adjust_hlo_output = [&](const Shape& /* subshape */,
+                                   const ShapeIndex& index) {
+        auto output_type = OutputTypeAfterChange(hlo, index);
+        VLOG(2) << "output_type is " << ((output_type == BF16) ? "BF16" : "F32")
+                << " for :" << hlo->ToString() << "\n";
+        if (output_type != F32 && output_type != BF16) {
+          return;
+        }
+        PrimitiveType type = BF16;
+        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
+          auto value_type = ValueTypeAfterChange(value);
+          if (value_type == BF16) {
+            continue;
+          }
+          VLOG(2) << "Adjust to F32 due to aliased dataflow value: "
+                  << value->ToString() << "\n";
+          CHECK_EQ(value_type, F32);
+          type = F32;
+          break;
+        }
+        // In order to find aliases due to in-place operations, use
+        // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here,
+        // but this code works with HloModules that aren't ready yet to use
+        // HloAliasAnalysis (e.g., their computation graphs may not have been
+        // flattened yet).
+        for (const auto& operand_and_output_index :
+             HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) {
+          if (operand_and_output_index.second == index) {
+            const HloUse& operand = operand_and_output_index.first;
+            for (const auto* value :
+                 dataflow_
+                     ->GetValueSet(hlo->operand(operand.operand_number),
+                                   operand.operand_index)
+                     .values()) {
+              auto value_type = ValueTypeAfterChange(value);
+              if (value_type == BF16) {
+                continue;
+              }
+              VLOG(2) << "Adjust to F32 due to InputOutPair: "
+                      << value->ToString() << "\n";
+              CHECK_EQ(value_type, F32);
+              type = F32;
+              break;
+            }
+          }
+        }
+        // It's possible that a user has been changed from BF16 to F32
+        // during this final adjustment pass, so we need to check
+        // AllUsersConsumeBF16() again.
+        if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
+          VLOG(2) << "Adjust to F32 due to All user consumeBF16 fail\n";
+          type = F32;
+        }
+        if (type == F32) {
+          for (const auto* value :
+               dataflow_->GetValueSet(hlo, index).values()) {
+            // We rely on the fact that this adjustment works in reverse
+            // topological order. Adding the value to
+            // values_that_must_be_kept_as_f32_ will ensure the correctness
+            // of the adjustment for HLOs that will be processed later.
+            values_that_must_be_kept_as_f32_.insert(value);
+          }
+        }
+        if (type != output_type) {
+          any_change = true;
+          AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
+          VLOG(2) << "HloInstruction output at shape index " << index
+                  << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
+                  << hlo->ToString();
+          if (hlo->opcode() == HloOpcode::kParameter) {
+            parameter_changed = true;
+          }
+        }
+      };
+      ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
+      AdjustCalledComputationRoot(hlo);
+      if (hlo->opcode() == HloOpcode::kWhile) {
+        // We need to run on the while body and condition repeatedly until a
+        // fixed point is reached, i.e., the parameters do not change any more.
+        // We may need more than one iteration because the while input and
+        // output alias each other, so changing one input parameter requires
+        // changing the corresponding output element and thus may transitively
+        // require changing another input parameter. A fixed point will be
+        // reached because the parameters can only be changed from BF16 to F32,
+        // not the other way around.
+        absl::flat_hash_set<const HloComputation*> visited_in_while;
+        while (ResolveInconsistencyOfAliasingBuffersHelper(
+                   hlo->while_condition(), &visited_in_while) ||
+               ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
+                                                           &visited_in_while)) {
+          visited_in_while.clear();
+          ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
+          AdjustCalledComputationRoot(hlo);
+        }
+        visited_computations->insert(visited_in_while.begin(),
+                                     visited_in_while.end());
+      } else if (hlo->opcode() == HloOpcode::kFusion) {
+        ResolveInconsistencyOfAliasingBuffersHelper(
+            hlo->fused_instructions_computation(), visited_computations);
+      } else if (hlo->opcode() == HloOpcode::kConditional) {
+        for (auto* branch : hlo->branch_computations()) {
+          ResolveInconsistencyOfAliasingBuffersHelper(branch,
+                                                      visited_computations);
+        }
+      }
+    }
+    if (!any_change) {
+      break;
+    }
+  }
   // Now adjust parameters of called computations.
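The fix above wraps the whole reverse-topological adjustment in a fixed-point loop: sweeps repeat until one full pass over the computation makes no change (any_change stays false), and termination is guaranteed because an output's type can only move from BF16 to F32, never back. The standalone sketch below distills that pattern; the Node type and ResolveAliases function are hypothetical names for illustration, not XLA code.

#include <cstdio>
#include <vector>

enum class Type { BF16, F32 };

struct Node {
  Type type = Type::BF16;
  std::vector<int> aliases;  // indices of nodes whose buffers alias this one
};

void ResolveAliases(std::vector<Node>& nodes) {
  while (true) {
    bool any_change = false;
    // Reverse order mirrors the pass's reverse-topological sweep.
    for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
      for (int alias : it->aliases) {
        if (nodes[alias].type == Type::F32 && it->type == Type::BF16) {
          it->type = Type::F32;  // widen so aliased buffers agree on type
          any_change = true;
        }
      }
    }
    // Fixed point: a full sweep made no change. This terminates because a
    // node only ever moves from BF16 to F32, never back.
    if (!any_change) {
      break;
    }
  }
}

int main() {
  // node2 aliases node1, node1 aliases node0, and node0 is pinned to F32.
  // The reverse sweep visits node2 before node1 has been widened, so the
  // first sweep leaves node2 stale and a second sweep is required: the same
  // ordering problem this commit fixes for aliased conditional outputs.
  std::vector<Node> nodes(3);
  nodes[0].type = Type::F32;
  nodes[1].aliases = {0};
  nodes[2].aliases = {1};
  ResolveAliases(nodes);
  std::printf("%d %d %d\n", nodes[0].type == Type::F32,
              nodes[1].type == Type::F32, nodes[2].type == Type::F32);
  return 0;  // prints "1 1 1": all nodes widened to F32
}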

tensorflow/compiler/xla/service/bfloat16_propagation_test.cc

@@ -1182,4 +1182,87 @@ ENTRY main {
   EXPECT_FALSE(OutputsBF16(dus));
 }
 
+// This test demonstrates the need for invoking
+// ResolveInconsistencyOfAliasingBuffers multiple times via a fixed-point
+// algorithm. The key is the aliasing of the two output buffers of the
+// conditional at subshape 0 (first element). This aliasing is not resolved
+// until after the gte value (tuple element 0) has already been processed,
+// which yields an incorrect type for it unless the aliasing analysis is
+// repeated.
+TEST_F(BFloat16PropagationTest, ConditionalGTEWithFusion) {
+  const string module_str = R"(
+HloModule module
+
+%add.0 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
+  x.1 = f32[4096,4096] parameter(0)
+  y.1 = f32[4096,4096] parameter(1)
+  ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+%add.1 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
+  x.1 = f32[4096,4096] parameter(0)
+  y.1 = f32[4096,4096] parameter(1)
+  ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+%add.2 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
+  x.1 = f32[4096,4096] parameter(0)
+  y.1 = f32[4096,4096] parameter(1)
+  ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+%add.3 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
+  x.1 = f32[4096,4096] parameter(0)
+  y.1 = f32[4096,4096] parameter(1)
+  ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+true_branch {
+  true_param = f32[4096,4096] parameter(0)
+  constant.1 = f32[4096,4096] constant(0)
+  add0 = f32[4096,4096] fusion(true_param, true_param), kind=kLoop, calls=add.0
+  constant.2 = f32[4096,4096] constant(0)
+  ROOT tuple.2 = (f32[4096,4096], f32[4096,4096], f32[4096,4096])
+    tuple(true_param, add0, constant.2)
+}
+
+false_branch {
+  false_param = f32[4096,4096] parameter(0)
+  add3 = f32[4096,4096] fusion(false_param, false_param), kind=kLoop, calls=add.1
+  constant.1 = f32[4096,4096] constant(0)
+  ROOT tuple.2 = (f32[4096,4096], f32[4096,4096], f32[4096,4096])
+    tuple(add3, add3, constant.1)
+}
+
+ENTRY entry {
+  param0 = f32[4096,4096] parameter(0)
+  copy0 = f32[4096,4096] copy(param0)
+  param1 = pred[] parameter(1)
+  conditional = (f32[4096,4096], f32[4096,4096], f32[4096,4096])
+    conditional(param1, param0, copy0),
+    true_computation=true_branch, false_computation=false_branch
+  gte = f32[4096,4096] get-tuple-element(conditional), index=0
+  gte1 = f32[4096,4096] get-tuple-element(conditional), index=1
+  gte2 = f32[4096,4096] get-tuple-element(conditional), index=2
+  add2 = f32[4096,4096] fusion(gte, gte1), kind=kLoop, calls=add.2
+  ROOT add3 = f32[4096,4096] fusion(add2, gte2), kind=kLoop, calls=add.3
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+
+  EXPECT_TRUE(PropagatePrecision(module.get()));
+  VLOG(2) << module->ToString() << "\n";
+  EXPECT_TRUE(HloVerifier(/*layout_sensitive=*/false,
+                          /*allow_mixed_precision=*/true)
+                  .Run(module.get())
+                  .status()
+                  .ok());
+
+  auto gte = FindInstruction(module.get(), "gte");
+  auto gte1 = FindInstruction(module.get(), "gte1");
+  auto gte2 = FindInstruction(module.get(), "gte2");
+  EXPECT_FALSE(OutputsBF16(gte));
+  EXPECT_FALSE(OutputsBF16(gte1));
+  EXPECT_TRUE(OutputsBF16(gte2));
+}
+
 }  // namespace xla
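The expectations encode the fix: tuple element 0 aliases the entry parameter through true_param, so gte is pinned to F32, and because the false branch returns add3 for both elements 0 and 1, gte1 aliases gte and must follow it to F32. That second constraint is only discovered on the repeated sweep, while gte2 remains free to narrow to BF16. The test drives the pass through the fixture's PropagatePrecision helper; outside the fixture, invoking it on a parsed module looks roughly like the sketch below, which assumes the 2020-era header paths and the BFloat16Propagation(const BFloat16Support*) constructor and should be checked against the actual source before use.

// Rough sketch of running the pass standalone; an approximation under the
// assumptions stated above, not the project's canonical usage.
#include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"

xla::StatusOr<bool> RunBF16Propagation(xla::HloModule* module) {
  // Default support object: describes which ops tolerate BF16 operands.
  xla::BFloat16Support bfloat16_support;
  xla::BFloat16Propagation propagation(&bfloat16_support);
  // Returns true when any instruction's precision was changed.
  return propagation.Run(module);
}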