[XLA] BF16 propagation: do not change if propagation is confined inside a fusion.

We now use a set to track all the potential changes, and do the actual changes
on the HLOs at the end. This also makes the boolean return value (whether
anything is changed) correct.

PiperOrigin-RevId: 195160025
This commit is contained in:
Yuanzhong Xu 2018-05-02 15:14:08 -07:00 committed by TensorFlower Gardener
parent 1d92d5037e
commit 9180cc254d
3 changed files with 387 additions and 158 deletions

View File

@ -33,7 +33,7 @@ BFloat16Propagation::BFloat16Propagation(
const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}
void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision(
void BFloat16Propagation::DetermineFusionComputationPrecision(
HloInstruction* fusion) {
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) {
@ -48,15 +48,13 @@ void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision(
auto root = fusion->fused_instructions_computation()->root_instruction();
// Adjust root's element types according to the fusion's output shape.
ShapeUtil::ForEachMutableSubshape(
root->mutable_shape(), [&](Shape* subshape, const ShapeIndex& index) {
if (subshape->element_type() != F32) {
ShapeUtil::ForEachSubshape(
root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
if (subshape.element_type() != F32) {
return;
}
if (ShapeUtil::GetSubshape(fusion->shape(), index).element_type() ==
BF16) {
subshape->set_element_type(BF16);
changed_ = true;
if (OutputTypeAfterChange(fusion, index) == BF16) {
AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
VLOG(2) << "Fused root " << root->ToString() << " at shape index "
<< index << " changed to BF16 precision for fusion "
<< fusion->ToString();
@ -67,13 +65,101 @@ void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision(
auto insts =
fusion->fused_instructions_computation()->MakeInstructionPostOrder();
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
}
computations_visited_in_mutation_pass_.insert(
computations_visited_in_backward_pass_.insert(
fusion->fused_instructions_computation());
RevertIfFusionInternalBF16Changes(fusion);
}
void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision(
void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
HloInstruction* fusion) {
auto has_changes = [this](HloInstruction* inst) {
auto it = changes_to_bf16_.find(inst);
return it != changes_to_bf16_.end() && !it->second.empty();
};
auto root = fusion->fused_instructions_computation()->root_instruction();
tensorflow::gtl::FlatSet<const HloValue*> changed_root_buffers;
auto root_changes_it = changes_to_bf16_.find(root);
if (root_changes_it != changes_to_bf16_.end()) {
for (const auto& index : root_changes_it->second) {
for (const HloValue* value :
dataflow_->GetValueSet(root, index).values()) {
changed_root_buffers.insert(value);
}
}
}
auto aliases_changed_root_buffer =
[this, &changed_root_buffers](const HloInstruction* inst) {
bool aliasing = false;
ShapeUtil::ForEachSubshape(
inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
if (aliasing) {
// Skip if aliasing is already found.
return;
}
// Only F32 buffers are considered for changing to BF16 in this
// pass.
if (subshape.element_type() != F32) {
return;
}
for (const HloValue* value :
dataflow_->GetValueSet(inst, index).values()) {
if (ContainsKey(changed_root_buffers, value)) {
aliasing = true;
break;
}
}
});
return aliasing;
};
for (auto inst :
fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
if (inst->opcode() == HloOpcode::kParameter) {
continue;
}
if (aliases_changed_root_buffer(inst)) {
continue;
}
if (inst->opcode() == HloOpcode::kFusion) {
bool parameter_reverted = false;
for (int64 i = 0; i < inst->operand_count(); ++i) {
if (has_changes(inst->mutable_operand(i))) {
// Changes on the operand have not been reverted.
continue;
}
auto* fused_parameter = inst->fused_parameter(i);
if (has_changes(fused_parameter)) {
changes_to_bf16_.erase(fused_parameter);
parameter_reverted = true;
}
}
if (parameter_reverted) {
RevertIfFusionInternalBF16Changes(inst);
}
}
if (!has_changes(inst)) {
continue;
}
bool revert_changes = true;
for (auto operand : inst->operands()) {
if (has_changes(operand)) {
revert_changes = false;
break;
}
}
if (revert_changes) {
changes_to_bf16_.erase(inst);
}
}
}
void BFloat16Propagation::DetermineWhileComputationsPrecision(
HloInstruction* while_hlo) {
CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
@ -86,16 +172,14 @@ void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision(
auto body_root = body->root_instruction();
HloComputation* condition = while_hlo->while_condition();
ShapeUtil::ForEachMutableSubshape(
body_root->mutable_shape(),
[this, while_hlo, body_root](Shape* subshape, const ShapeIndex& index) {
if (subshape->element_type() != F32) {
ShapeUtil::ForEachSubshape(
body_root->shape(), [this, while_hlo, body_root](
const Shape& subshape, const ShapeIndex& index) {
if (subshape.element_type() != F32) {
return;
}
if (ShapeUtil::GetSubshape(while_hlo->shape(), index).element_type() ==
BF16) {
subshape->set_element_type(BF16);
changed_ = true;
if (OutputTypeAfterChange(while_hlo, index) == BF16) {
AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16);
VLOG(2) << "While body root " << body_root->ToString()
<< " at shape index " << index
<< " changed to BF16 precision for while "
@ -106,30 +190,30 @@ void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision(
auto body_insts = body->MakeInstructionPostOrder();
for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
++inst_it) {
DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
}
computations_visited_in_mutation_pass_.insert(body);
computations_visited_in_backward_pass_.insert(body);
auto condition_insts = condition->MakeInstructionPostOrder();
for (auto inst_it = condition_insts.rbegin();
inst_it != condition_insts.rend(); ++inst_it) {
DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
}
computations_visited_in_mutation_pass_.insert(condition);
computations_visited_in_backward_pass_.insert(condition);
}
bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
const ShapeIndex& index) const {
auto value_set = dataflow_->GetValueSet(&hlo, index);
auto& value_set = dataflow_->GetValueSet(&hlo, index);
for (const HloValue* value : value_set.values()) {
if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
return false;
}
if (value->shape().element_type() == BF16) {
if (ValueTypeAfterChange(value) == BF16) {
continue;
}
for (const HloUse& use : value->uses()) {
if (!ContainsKey(instructions_visited_in_mutation_pass_,
if (!ContainsKey(instructions_visited_in_backward_pass_,
use.instruction)) {
// We don't know yet whether use.instruction will consume BF16 since it
// hasn't been visited. Although we visit instructions in reverse
@ -145,26 +229,23 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
// precision, or a called computation's parameters have been changed to
// BF16 for fusions or whiles.
if (use.instruction->opcode() == HloOpcode::kFusion) {
const auto* fused_parameter =
auto* fused_parameter =
use.instruction->fused_parameter(use.operand_number);
if (ShapeUtil::GetSubshape(fused_parameter->shape(), use.operand_index)
.element_type() != BF16) {
if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) {
return false;
}
continue;
} else if (use.instruction->opcode() == HloOpcode::kWhile) {
const auto* cond_parameter =
auto* cond_parameter =
use.instruction->while_condition()->parameter_instruction(
use.operand_number);
if (ShapeUtil::GetSubshape(cond_parameter->shape(), use.operand_index)
.element_type() != BF16) {
if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
return false;
}
const auto* body_parameter =
auto* body_parameter =
use.instruction->while_body()->parameter_instruction(
use.operand_number);
if (ShapeUtil::GetSubshape(body_parameter->shape(), use.operand_index)
.element_type() != BF16) {
if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) {
return false;
}
continue;
@ -174,19 +255,20 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
continue;
}
// If the op propagates precision and it outputs a BF16, then it's OK to
// supply BF16 also as the input. In the backward mutation pass, the users
// shapes should have already been processed.
// supply BF16 also as the input. In the backward pass, the users shapes
// should have already been processed.
PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID;
if (use.instruction->opcode() == HloOpcode::kTuple ||
(use.instruction->opcode() == HloOpcode::kCrossReplicaSum &&
ShapeUtil::IsTuple(use.instruction->shape()))) {
user_output_type = ShapeUtil::GetSubshape(
ShapeUtil::GetSubshape(use.instruction->shape(),
{use.operand_number}),
use.operand_index)
.element_type();
ShapeIndex use_output_index{use.operand_number};
for (int64 i : use.operand_index) {
use_output_index.push_back(i);
}
user_output_type =
OutputTypeAfterChange(use.instruction, use_output_index);
} else {
user_output_type = use.instruction->shape().element_type();
user_output_type = OutputTypeAfterChange(use.instruction, {});
}
if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
*use.instruction, use.operand_number) &&
@ -199,8 +281,8 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
return true;
}
void BFloat16Propagation::DetermineAndMutateInstructionPrecision(
HloInstruction* hlo, bool skip_parameters) {
void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
bool skip_parameters) {
// We handle any fusion computation or while body/condition after the
// instruction is handled, because we need to know the output shape of a
// fusion or while before propagating inside its computations.
@ -209,12 +291,12 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision(
[this, hlo, &postpone_processing_called_computations] {
if (!postpone_processing_called_computations) {
if (hlo->opcode() == HloOpcode::kFusion) {
DetermineAndMutateFusionComputationPrecision(hlo);
DetermineFusionComputationPrecision(hlo);
} else if (hlo->opcode() == HloOpcode::kWhile) {
DetermineAndMutateWhileComputationsPrecision(hlo);
DetermineWhileComputationsPrecision(hlo);
}
}
instructions_visited_in_mutation_pass_.insert(hlo);
instructions_visited_in_backward_pass_.insert(hlo);
});
if (hlo->opcode() == HloOpcode::kWhile &&
@ -245,9 +327,9 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision(
CHECK(hlo->parent() != nullptr);
if (hlo == hlo->parent()->root_instruction()) {
if (!hlo->parent()->IsFusionComputation()) {
ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& subshape,
ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */,
const ShapeIndex& index) {
if (subshape.element_type() != F32) {
if (OutputTypeAfterChange(hlo, index) != F32) {
return;
}
for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
@ -269,13 +351,12 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision(
return;
}
ShapeUtil::ForEachMutableSubshape(
hlo->mutable_shape(),
[hlo, this](Shape* subshape, const ShapeIndex& index) {
if (subshape->element_type() == F32 &&
ShapeUtil::ForEachSubshape(
hlo->shape(),
[hlo, this](const Shape& /* subshape */, const ShapeIndex& index) {
if (OutputTypeAfterChange(hlo, index) == F32 &&
AllUsersConsumeBF16(*hlo, index)) {
subshape->set_element_type(BF16);
changed_ = true;
AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16);
VLOG(2) << "HloInstruction output at shape index " << index
<< " changed to BF16 precision: " << hlo->ToString();
}
@ -308,26 +389,24 @@ void BFloat16Propagation::AdjustCalledComputationParameters(
CHECK_EQ(operands.size(), computation->num_parameters());
for (int64 i = 0; i < operands.size(); ++i) {
auto parameter = computation->parameter_instruction(i);
ShapeUtil::ForEachMutableSubshape(
parameter->mutable_shape(),
[this, i, hlo, &operands, parameter](Shape* subshape,
ShapeUtil::ForEachSubshape(
parameter->shape(),
[this, i, hlo, &operands, parameter](const Shape& /* subshape */,
const ShapeIndex& index) {
if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
return;
}
PrimitiveType operand_type =
ShapeUtil::GetSubshape(operands[i]->shape(), index)
.element_type();
if (subshape->element_type() == operand_type) {
OutputTypeAfterChange(operands[i], index);
if (OutputTypeAfterChange(parameter, index) == operand_type) {
return;
}
CHECK(operand_type == F32 || operand_type == BF16);
subshape->set_element_type(operand_type);
changed_ = true;
AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type);
VLOG(2) << "Called computation parameter "
<< parameter->ToString() << " at shape index " << index
<< " adjusted to match operand in HLO "
<< hlo->ToString();
<< " adjusted to "
<< (operand_type == BF16 ? "BF16" : "F32")
<< " to match operand in HLO " << hlo->ToString();
});
}
};
@ -348,51 +427,48 @@ void BFloat16Propagation::AdjustCalledComputationParameters(
void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
auto adjust_computation = [this, hlo](HloComputation* computation,
const Shape& output_shape) {
HloInstruction* output) {
// Adjust root.
HloInstruction* root = computation->root_instruction();
ShapeUtil::ForEachMutableSubshape(
root->mutable_shape(), [this, hlo, root, &output_shape](
Shape* subshape, const ShapeIndex& index) {
if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
return;
}
const PrimitiveType output_type =
ShapeUtil::GetSubshape(output_shape, index).element_type();
if (subshape->element_type() == output_type) {
return;
}
CHECK(output_type == F32 || output_type == BF16);
subshape->set_element_type(output_type);
// It's possible that output_type is F32, but the root instruction's
// type is BF16; e.g., a fusion node's output was changed to BF16
// initially but then adjusted back to F32, and the fusion computation
// is now being adjusted after the fusion node.
if (output_type == F32) {
for (const auto* value :
dataflow_->GetValueSet(root, index).values()) {
// We rely on the fact that this adjustment works in reverse
// topological order so that called computation will be
// processed later. Adding the value to
// values_that_must_be_kept_as_f32_ will ensure the
// correctness of the adjustment for HLOs that will be
// processed later.
values_that_must_be_kept_as_f32_.insert(value);
}
}
changed_ = true;
VLOG(2) << "Called computation root " << root->ToString()
<< " at shape index " << index
<< " adjusted to match output shape of " << hlo->ToString();
});
ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output](
const Shape& /* subshape */,
const ShapeIndex& index) {
if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
return;
}
const PrimitiveType output_type = OutputTypeAfterChange(output, index);
if (OutputTypeAfterChange(root, index) == output_type) {
return;
}
AddToOrRemoveFromBF16ChangeSet(root, index, output_type);
// It's possible that output_type is F32, but the root instruction's
// type is BF16; e.g., a fusion node's output was changed to BF16
// initially but then adjusted back to F32, and the fusion computation
// is now being adjusted after the fusion node.
if (output_type == F32) {
for (const auto* value : dataflow_->GetValueSet(root, index).values()) {
// We rely on the fact that this adjustment works in reverse
// topological order so that called computation will be
// processed later. Adding the value to
// values_that_must_be_kept_as_f32_ will ensure the
// correctness of the adjustment for HLOs that will be
// processed later.
values_that_must_be_kept_as_f32_.insert(value);
}
}
VLOG(2) << "Called computation root " << root->ToString()
<< " at shape index " << index << " adjusted to "
<< (output_type == BF16 ? "BF16" : "F32")
<< " to match output shape of " << hlo->ToString();
});
};
switch (hlo->opcode()) {
case HloOpcode::kFusion:
adjust_computation(hlo->fused_instructions_computation(), hlo->shape());
adjust_computation(hlo->fused_instructions_computation(), hlo);
break;
case HloOpcode::kWhile:
adjust_computation(hlo->while_body(), hlo->shape());
adjust_computation(hlo->while_body(), hlo);
break;
default:
break;
@ -409,16 +485,19 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
auto hlo = *inst_it;
auto adjust_hlo_output = [this, hlo, &parameter_changed](
Shape* subshape, const ShapeIndex& index) {
if (subshape->element_type() != F32 && subshape->element_type() != BF16) {
const Shape& /* subshape */,
const ShapeIndex& index) {
auto output_type = OutputTypeAfterChange(hlo, index);
if (output_type != F32 && output_type != BF16) {
return;
}
PrimitiveType type = BF16;
for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
if (value->shape().element_type() == BF16) {
auto value_type = ValueTypeAfterChange(value);
if (value_type == BF16) {
continue;
}
CHECK_EQ(value->shape().element_type(), F32);
CHECK_EQ(value_type, F32);
type = F32;
break;
}
@ -437,16 +516,17 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
values_that_must_be_kept_as_f32_.insert(value);
}
}
if (type != subshape->element_type()) {
subshape->set_element_type(type);
if (type != output_type) {
AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
VLOG(2) << "HloInstruction output at shape index " << index
<< " adjusted to " << *subshape << ": " << hlo->ToString();
<< " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
<< hlo->ToString();
if (hlo->opcode() == HloOpcode::kParameter) {
parameter_changed = true;
}
}
};
ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_hlo_output);
ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
AdjustCalledComputationRoot(hlo);
if (hlo->opcode() == HloOpcode::kWhile) {
// We need to run on the while body and condition repeatedly until a fixed
@ -463,8 +543,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
&visited_in_while)) {
visited_in_while.clear();
ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(),
adjust_hlo_output);
ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
AdjustCalledComputationRoot(hlo);
}
visited_computations->insert(visited_in_while.begin(),
@ -478,7 +557,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
return parameter_changed;
}
Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
HloModule* module) {
std::list<HloComputation*> computations_topological_order =
module->MakeComputationPostOrder();
@ -490,7 +569,9 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
}
ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
}
}
Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
// We could have changed a fusion computation's root shape to have a different
// precision than the fusion node's output, if the fusion root does not
// define a buffer (e.g., a tuple). Now we add conversions after such fusion
@ -517,7 +598,7 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
// (2) after adding conversion
// (3) after tuple simplifier and DCE.
bool needs_tuple_simplifier = false;
for (auto computation : computations_topological_order) {
for (auto computation : module->MakeComputationPostOrder()) {
auto insts = computation->MakeInstructionPostOrder();
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
auto hlo = *inst_it;
@ -587,7 +668,14 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape());
}
}
if (needs_tuple_simplifier) {
TupleSimplifier tuple_simplifier;
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
}
return Status::OK();
}
Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
// We may have converted some constants from F32 to BF16, so adjust the
// constant literals in such cases. We do this here instead of when the
// constant node's is changed because 1) the HloInstruction interface does not
@ -598,8 +686,7 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
// can avoid repeated conversions.
//
// TODO(b/73833576): Consider resetting literal in HloInstruction.
bool needs_dce = needs_tuple_simplifier;
for (auto computation : computations_topological_order) {
for (auto computation : module->MakeComputationPostOrder()) {
for (auto hlo : computation->MakeInstructionPostOrder()) {
if (hlo->opcode() != HloOpcode::kConstant) {
continue;
@ -612,23 +699,13 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
auto new_constant = computation->AddInstruction(
HloInstruction::CreateConstant(std::move(converted_literal)));
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
needs_dce = true;
}
}
}
if (needs_tuple_simplifier) {
TupleSimplifier tuple_simplifier;
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
}
if (needs_dce) {
HloDCE dce;
TF_RETURN_IF_ERROR(dce.Run(module).status());
}
return Status::OK();
}
Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) {
Status BFloat16Propagation::SkipNoopConversions(HloModule* module) {
for (auto computation : module->computations()) {
for (auto hlo : computation->MakeInstructionPostOrder()) {
if (hlo->opcode() != HloOpcode::kConvert) {
@ -643,7 +720,6 @@ Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) {
if (is_root) {
computation->set_root_instruction(source);
}
TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(hlo));
}
}
return Status::OK();
@ -652,8 +728,18 @@ Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) {
// The algorithm first does a forward pass (parameters to root) to determine a
// set of instructions to consider using bfloat16, then does a backward pass to
// determine the precisions of those instructions according to the need of
// their users.
// their users. During the backward pass, the potential changes are stored in
// changes_to_bf16_ which are subject to further adjustments then applied to the
// HLOs.
StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
consider_using_bfloat16_.clear();
instructions_visited_in_backward_pass_.clear();
computations_visited_in_backward_pass_.clear();
values_that_must_be_kept_as_f32_.clear();
caller_counts_.clear();
changes_to_bf16_.clear();
changed_ = false;
TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));
std::list<HloComputation*> computations_topological_order =
@ -686,8 +772,24 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
}
auto insts = (*comp_it)->MakeInstructionPostOrder();
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
DetermineAndMutateInstructionPrecision(*inst_it,
/*skip_parameters=*/true);
DetermineInstructionPrecision(*inst_it,
/*skip_parameters=*/true);
}
}
// It's possible that an instruction does not define a buffer, but the
// defining instruction's shape has changed. So we need to adjust the output
// shapes of instructions according to the HLO values they refer to.
ResolveInconsistencyOfAliasingBuffers(module);
// Apply the changes in changes_to_bf16_.
for (auto& change : changes_to_bf16_) {
auto shape = change.first->mutable_shape();
for (const auto& index : change.second) {
auto subshape = ShapeUtil::GetMutableSubshape(shape, index);
CHECK_EQ(subshape->element_type(), F32);
subshape->set_element_type(BF16);
changed_ = true;
}
}
@ -695,15 +797,56 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
return false;
}
// It's possible that an instruction does not define a buffer, but the
// defining instruction's shape has changed. So we need to adjust the output
// shapes of instructions according to the HLO values they refer to.
TF_RETURN_IF_ERROR(ResolveInconsistencyOfAliasingBuffers(module));
TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module));
TF_RETURN_IF_ERROR(ResolveConvertedConstants(module));
// This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 ->
// BF16), so we remove them now.
TF_RETURN_IF_ERROR(RemoveNoopConversions(module));
// BF16), so we skip them now.
TF_RETURN_IF_ERROR(SkipNoopConversions(module));
{
// We may have dead HLOs after ResolveInconsistentFusions,
// ResolveConvertedConstants and SkipNoopConversions.
HloDCE dce;
TF_RETURN_IF_ERROR(dce.Run(module).status());
}
return true;
}
PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
HloInstruction* hlo, const ShapeIndex& index) const {
PrimitiveType type_on_hlo =
ShapeUtil::GetSubshape(hlo->shape(), index).element_type();
if (type_on_hlo != F32) {
return type_on_hlo;
}
auto it = changes_to_bf16_.find(hlo);
if (it == changes_to_bf16_.end()) {
return type_on_hlo;
}
return ContainsKey(it->second, index) ? BF16 : F32;
}
PrimitiveType BFloat16Propagation::ValueTypeAfterChange(
const HloValue* value) const {
auto hlo = value->defining_instruction();
const auto& position = value->defining_position();
return OutputTypeAfterChange(hlo, position.index);
}
void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet(
HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) {
if (target_type == BF16) {
auto& entry = changes_to_bf16_[hlo];
entry.insert(index);
} else {
CHECK_EQ(target_type, F32);
auto it = changes_to_bf16_.find(hlo);
if (it == changes_to_bf16_.end()) {
return;
}
it->second.erase(index);
}
}
} // namespace xla

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/lib/hash/hash.h"
namespace xla {
@ -85,30 +86,39 @@ class BFloat16Propagation : public HloPassInterface {
tensorflow::gtl::FlatSet<const HloInstruction*> consider_using_bfloat16_;
// ***************************
// Functions called and state produced by the backward mutation pass (from
// root to parameters).
// Functions called and state produced by the backward pass (from root to
// parameters) that finds opportunities to use BF16.
// Determines the precision for the given instruction in the mutation pass.
void DetermineAndMutateInstructionPrecision(HloInstruction* hlo,
bool skip_parameters);
// Determines the precision for the given instruction in the
// opportunity-finding pass.
void DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters);
// Special handling in the mutation pass for fusion computations.
// Special handling in the opportunity-finding pass for fusion computations.
//
// Precondition: hlo->opcode() == kFusion
void DetermineAndMutateFusionComputationPrecision(HloInstruction* fusion);
void DetermineFusionComputationPrecision(HloInstruction* fusion);
// Special handling in the mutation pass for while computations.
// Reverts changes to BF16 that will not propagate outside a fusion
// computation. This avoids BF16 casts overhead inside a fusion which won't
// save memory bandwidth.
//
// Precondition: hlo->opcode() == kFusion
void RevertIfFusionInternalBF16Changes(HloInstruction* fusion);
// Special handling in the opportunity-finding pass for while computations.
//
// Precondition: hlo->opcode() == kWhile
void DetermineAndMutateWhileComputationsPrecision(HloInstruction* while_hlo);
void DetermineWhileComputationsPrecision(HloInstruction* while_hlo);
// The set of HloInstructions that have been visited in the mutation pass.
// The set of HloInstructions that have been visited in the
// opportunity-finding pass.
tensorflow::gtl::FlatSet<const HloInstruction*>
instructions_visited_in_mutation_pass_;
instructions_visited_in_backward_pass_;
// The set of HloComputations that have been visited in the mutation pass.
// The set of HloComputations that have been visited in the
// opportunity-finding pass.
tensorflow::gtl::FlatSet<const HloComputation*>
computations_visited_in_mutation_pass_;
computations_visited_in_backward_pass_;
// ***************************
// Functions called by the final inconsistency resolving pass.
@ -116,7 +126,7 @@ class BFloat16Propagation : public HloPassInterface {
// Adjusts the output shapes of HloInstructions such that if two
// HloInstructions have aliasing buffers in their outputs, they must have the
// same precision.
Status ResolveInconsistencyOfAliasingBuffers(HloModule* module);
void ResolveInconsistencyOfAliasingBuffers(HloModule* module);
// Resolves inconsistency of aliasing buffers for the given computation, and
// recursively runs on a while instruction's condition and body until a fixed
@ -134,9 +144,19 @@ class BFloat16Propagation : public HloPassInterface {
void AdjustCalledComputationRoot(HloInstruction* hlo);
// ***************************
// Removes no-op conversions (same source and target shapes) that can be
// produced this pass.
Status RemoveNoopConversions(HloModule* module);
// Functions called after changes in changes_to_bf16_ are applied.
// Resolves inconsistencies introduced by this pass for fusions with
// tuple-type output.
Status ResolveInconsistentFusions(HloModule* module);
// Converts the literals in kConstant HLOs which have their types changed to
// BF16 by this pass.
Status ResolveConvertedConstants(HloModule* module);
// Skips no-op conversions (same source and target shapes) that can be
// produced this pass, i.e., replaces them in their uses with their operands.
Status SkipNoopConversions(HloModule* module);
// ***************************
// Functions called and state used by two or more passes.
@ -146,6 +166,23 @@ class BFloat16Propagation : public HloPassInterface {
bool AllUsersConsumeBF16(const HloInstruction& hlo,
const ShapeIndex& index) const;
// The output element type of the HLO at the given shape index after changes
// in changes_to_bf16_ are applied.
PrimitiveType OutputTypeAfterChange(HloInstruction* hlo,
const ShapeIndex& index) const;
// The element type of the HLO value after changes in changes_to_bf16_ are
// applied.
PrimitiveType ValueTypeAfterChange(const HloValue* value) const;
// If target_type == BF16, adds the HLO at the given index to
// changes_to_bf16_; otherwise, target_type must be F32 and this function
// removes the HLO at the given index from changes_to_bf16_ if it was earlier
// added.
void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo,
const ShapeIndex& index,
PrimitiveType target_type);
// The set of F32 HLO values that must be kept in F32.
tensorflow::gtl::FlatSet<const HloValue*> values_that_must_be_kept_as_f32_;
@ -153,10 +190,28 @@ class BFloat16Propagation : public HloPassInterface {
// module. Populated at the beginning of this pass.
tensorflow::gtl::FlatMap<const HloComputation*, int64> caller_counts_;
// We first store the potential F32-to-BF16 changes to changes_to_bf16_, which
// are subject to further adjustment, then finally applied to the HLOs. This
// avoids setting changed_ to true but all changes are reverted during
// adjustment.
struct IndexHasher {
int64 operator()(const ShapeIndex& index) const {
int64 hash = 0;
for (int64 i : index) {
hash = tensorflow::Hash64Combine(hash, std::hash<int64>()(i));
}
return hash;
}
};
tensorflow::gtl::FlatMap<HloInstruction*,
tensorflow::gtl::FlatSet<ShapeIndex, IndexHasher>>
changes_to_bf16_;
// Whether the last processed HLO module has been changed by this pass.
bool changed_ = false;
const BFloat16Support* bfloat16_support_;
std::unique_ptr<HloDataflowAnalysis> dataflow_;
bool changed_ = false;
};
} // namespace xla

View File

@ -323,6 +323,37 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
EXPECT_TRUE(OutputsBF16(b_f1));
}
// Tests that changes to BF16 that cannot be propagated outside a fusion are
// discarded.
TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
auto builder_f = HloComputation::Builder("fusion");
HloInstruction* a_f =
builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* b_f =
builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
HloInstruction* add_f = builder_f.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f));
auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f));
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_FALSE(PropagatePrecision(module.get()));
EXPECT_EQ(computation->root_instruction(), fusion);
}
// Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion
// outputs are only used by a dot, and 3) one element of the tuple is used by
// an add in the fusion computation, then the propagation pass should create a