[XLA] BF16 propagation: do not change if propagation is confined inside a fusion.
We now use a set to track all the potential changes, and do the actual changes on the HLOs at the end. This also makes the boolean return value (whether anything is changed) correct. PiperOrigin-RevId: 195160025
This commit is contained in:
parent
1d92d5037e
commit
9180cc254d
@ -33,7 +33,7 @@ BFloat16Propagation::BFloat16Propagation(
|
|||||||
const BFloat16Support* bfloat16_support)
|
const BFloat16Support* bfloat16_support)
|
||||||
: bfloat16_support_(bfloat16_support) {}
|
: bfloat16_support_(bfloat16_support) {}
|
||||||
|
|
||||||
void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision(
|
void BFloat16Propagation::DetermineFusionComputationPrecision(
|
||||||
HloInstruction* fusion) {
|
HloInstruction* fusion) {
|
||||||
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
|
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
|
||||||
if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) {
|
if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) {
|
||||||
@ -48,15 +48,13 @@ void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision(
|
|||||||
auto root = fusion->fused_instructions_computation()->root_instruction();
|
auto root = fusion->fused_instructions_computation()->root_instruction();
|
||||||
|
|
||||||
// Adjust root's element types according to the fusion's output shape.
|
// Adjust root's element types according to the fusion's output shape.
|
||||||
ShapeUtil::ForEachMutableSubshape(
|
ShapeUtil::ForEachSubshape(
|
||||||
root->mutable_shape(), [&](Shape* subshape, const ShapeIndex& index) {
|
root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
|
||||||
if (subshape->element_type() != F32) {
|
if (subshape.element_type() != F32) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (ShapeUtil::GetSubshape(fusion->shape(), index).element_type() ==
|
if (OutputTypeAfterChange(fusion, index) == BF16) {
|
||||||
BF16) {
|
AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
|
||||||
subshape->set_element_type(BF16);
|
|
||||||
changed_ = true;
|
|
||||||
VLOG(2) << "Fused root " << root->ToString() << " at shape index "
|
VLOG(2) << "Fused root " << root->ToString() << " at shape index "
|
||||||
<< index << " changed to BF16 precision for fusion "
|
<< index << " changed to BF16 precision for fusion "
|
||||||
<< fusion->ToString();
|
<< fusion->ToString();
|
||||||
@ -67,13 +65,101 @@ void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision(
|
|||||||
auto insts =
|
auto insts =
|
||||||
fusion->fused_instructions_computation()->MakeInstructionPostOrder();
|
fusion->fused_instructions_computation()->MakeInstructionPostOrder();
|
||||||
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
|
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
|
||||||
DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
|
DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
|
||||||
}
|
}
|
||||||
computations_visited_in_mutation_pass_.insert(
|
computations_visited_in_backward_pass_.insert(
|
||||||
fusion->fused_instructions_computation());
|
fusion->fused_instructions_computation());
|
||||||
|
|
||||||
|
RevertIfFusionInternalBF16Changes(fusion);
|
||||||
}
|
}
|
||||||
|
|
||||||
void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision(
|
void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
|
||||||
|
HloInstruction* fusion) {
|
||||||
|
auto has_changes = [this](HloInstruction* inst) {
|
||||||
|
auto it = changes_to_bf16_.find(inst);
|
||||||
|
return it != changes_to_bf16_.end() && !it->second.empty();
|
||||||
|
};
|
||||||
|
|
||||||
|
auto root = fusion->fused_instructions_computation()->root_instruction();
|
||||||
|
tensorflow::gtl::FlatSet<const HloValue*> changed_root_buffers;
|
||||||
|
|
||||||
|
auto root_changes_it = changes_to_bf16_.find(root);
|
||||||
|
if (root_changes_it != changes_to_bf16_.end()) {
|
||||||
|
for (const auto& index : root_changes_it->second) {
|
||||||
|
for (const HloValue* value :
|
||||||
|
dataflow_->GetValueSet(root, index).values()) {
|
||||||
|
changed_root_buffers.insert(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto aliases_changed_root_buffer =
|
||||||
|
[this, &changed_root_buffers](const HloInstruction* inst) {
|
||||||
|
bool aliasing = false;
|
||||||
|
ShapeUtil::ForEachSubshape(
|
||||||
|
inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
|
||||||
|
if (aliasing) {
|
||||||
|
// Skip if aliasing is already found.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Only F32 buffers are considered for changing to BF16 in this
|
||||||
|
// pass.
|
||||||
|
if (subshape.element_type() != F32) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (const HloValue* value :
|
||||||
|
dataflow_->GetValueSet(inst, index).values()) {
|
||||||
|
if (ContainsKey(changed_root_buffers, value)) {
|
||||||
|
aliasing = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return aliasing;
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto inst :
|
||||||
|
fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
|
||||||
|
if (inst->opcode() == HloOpcode::kParameter) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (aliases_changed_root_buffer(inst)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (inst->opcode() == HloOpcode::kFusion) {
|
||||||
|
bool parameter_reverted = false;
|
||||||
|
for (int64 i = 0; i < inst->operand_count(); ++i) {
|
||||||
|
if (has_changes(inst->mutable_operand(i))) {
|
||||||
|
// Changes on the operand have not been reverted.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto* fused_parameter = inst->fused_parameter(i);
|
||||||
|
if (has_changes(fused_parameter)) {
|
||||||
|
changes_to_bf16_.erase(fused_parameter);
|
||||||
|
parameter_reverted = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (parameter_reverted) {
|
||||||
|
RevertIfFusionInternalBF16Changes(inst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!has_changes(inst)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
bool revert_changes = true;
|
||||||
|
for (auto operand : inst->operands()) {
|
||||||
|
if (has_changes(operand)) {
|
||||||
|
revert_changes = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (revert_changes) {
|
||||||
|
changes_to_bf16_.erase(inst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void BFloat16Propagation::DetermineWhileComputationsPrecision(
|
||||||
HloInstruction* while_hlo) {
|
HloInstruction* while_hlo) {
|
||||||
CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
|
CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
|
||||||
|
|
||||||
@ -86,16 +172,14 @@ void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision(
|
|||||||
auto body_root = body->root_instruction();
|
auto body_root = body->root_instruction();
|
||||||
HloComputation* condition = while_hlo->while_condition();
|
HloComputation* condition = while_hlo->while_condition();
|
||||||
|
|
||||||
ShapeUtil::ForEachMutableSubshape(
|
ShapeUtil::ForEachSubshape(
|
||||||
body_root->mutable_shape(),
|
body_root->shape(), [this, while_hlo, body_root](
|
||||||
[this, while_hlo, body_root](Shape* subshape, const ShapeIndex& index) {
|
const Shape& subshape, const ShapeIndex& index) {
|
||||||
if (subshape->element_type() != F32) {
|
if (subshape.element_type() != F32) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (ShapeUtil::GetSubshape(while_hlo->shape(), index).element_type() ==
|
if (OutputTypeAfterChange(while_hlo, index) == BF16) {
|
||||||
BF16) {
|
AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16);
|
||||||
subshape->set_element_type(BF16);
|
|
||||||
changed_ = true;
|
|
||||||
VLOG(2) << "While body root " << body_root->ToString()
|
VLOG(2) << "While body root " << body_root->ToString()
|
||||||
<< " at shape index " << index
|
<< " at shape index " << index
|
||||||
<< " changed to BF16 precision for while "
|
<< " changed to BF16 precision for while "
|
||||||
@ -106,30 +190,30 @@ void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision(
|
|||||||
auto body_insts = body->MakeInstructionPostOrder();
|
auto body_insts = body->MakeInstructionPostOrder();
|
||||||
for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
|
for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
|
||||||
++inst_it) {
|
++inst_it) {
|
||||||
DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
|
DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
|
||||||
}
|
}
|
||||||
computations_visited_in_mutation_pass_.insert(body);
|
computations_visited_in_backward_pass_.insert(body);
|
||||||
|
|
||||||
auto condition_insts = condition->MakeInstructionPostOrder();
|
auto condition_insts = condition->MakeInstructionPostOrder();
|
||||||
for (auto inst_it = condition_insts.rbegin();
|
for (auto inst_it = condition_insts.rbegin();
|
||||||
inst_it != condition_insts.rend(); ++inst_it) {
|
inst_it != condition_insts.rend(); ++inst_it) {
|
||||||
DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
|
DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
|
||||||
}
|
}
|
||||||
computations_visited_in_mutation_pass_.insert(condition);
|
computations_visited_in_backward_pass_.insert(condition);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
|
bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
|
||||||
const ShapeIndex& index) const {
|
const ShapeIndex& index) const {
|
||||||
auto value_set = dataflow_->GetValueSet(&hlo, index);
|
auto& value_set = dataflow_->GetValueSet(&hlo, index);
|
||||||
for (const HloValue* value : value_set.values()) {
|
for (const HloValue* value : value_set.values()) {
|
||||||
if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
|
if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (value->shape().element_type() == BF16) {
|
if (ValueTypeAfterChange(value) == BF16) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (const HloUse& use : value->uses()) {
|
for (const HloUse& use : value->uses()) {
|
||||||
if (!ContainsKey(instructions_visited_in_mutation_pass_,
|
if (!ContainsKey(instructions_visited_in_backward_pass_,
|
||||||
use.instruction)) {
|
use.instruction)) {
|
||||||
// We don't know yet whether use.instruction will consume BF16 since it
|
// We don't know yet whether use.instruction will consume BF16 since it
|
||||||
// hasn't been visited. Although we visit instructions in reverse
|
// hasn't been visited. Although we visit instructions in reverse
|
||||||
@ -145,26 +229,23 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
|
|||||||
// precision, or a called computation's parameters have been changed to
|
// precision, or a called computation's parameters have been changed to
|
||||||
// BF16 for fusions or whiles.
|
// BF16 for fusions or whiles.
|
||||||
if (use.instruction->opcode() == HloOpcode::kFusion) {
|
if (use.instruction->opcode() == HloOpcode::kFusion) {
|
||||||
const auto* fused_parameter =
|
auto* fused_parameter =
|
||||||
use.instruction->fused_parameter(use.operand_number);
|
use.instruction->fused_parameter(use.operand_number);
|
||||||
if (ShapeUtil::GetSubshape(fused_parameter->shape(), use.operand_index)
|
if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) {
|
||||||
.element_type() != BF16) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
} else if (use.instruction->opcode() == HloOpcode::kWhile) {
|
} else if (use.instruction->opcode() == HloOpcode::kWhile) {
|
||||||
const auto* cond_parameter =
|
auto* cond_parameter =
|
||||||
use.instruction->while_condition()->parameter_instruction(
|
use.instruction->while_condition()->parameter_instruction(
|
||||||
use.operand_number);
|
use.operand_number);
|
||||||
if (ShapeUtil::GetSubshape(cond_parameter->shape(), use.operand_index)
|
if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
|
||||||
.element_type() != BF16) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const auto* body_parameter =
|
auto* body_parameter =
|
||||||
use.instruction->while_body()->parameter_instruction(
|
use.instruction->while_body()->parameter_instruction(
|
||||||
use.operand_number);
|
use.operand_number);
|
||||||
if (ShapeUtil::GetSubshape(body_parameter->shape(), use.operand_index)
|
if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) {
|
||||||
.element_type() != BF16) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
@ -174,19 +255,20 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// If the op propagates precision and it outputs a BF16, then it's OK to
|
// If the op propagates precision and it outputs a BF16, then it's OK to
|
||||||
// supply BF16 also as the input. In the backward mutation pass, the users
|
// supply BF16 also as the input. In the backward pass, the users shapes
|
||||||
// shapes should have already been processed.
|
// should have already been processed.
|
||||||
PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID;
|
PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID;
|
||||||
if (use.instruction->opcode() == HloOpcode::kTuple ||
|
if (use.instruction->opcode() == HloOpcode::kTuple ||
|
||||||
(use.instruction->opcode() == HloOpcode::kCrossReplicaSum &&
|
(use.instruction->opcode() == HloOpcode::kCrossReplicaSum &&
|
||||||
ShapeUtil::IsTuple(use.instruction->shape()))) {
|
ShapeUtil::IsTuple(use.instruction->shape()))) {
|
||||||
user_output_type = ShapeUtil::GetSubshape(
|
ShapeIndex use_output_index{use.operand_number};
|
||||||
ShapeUtil::GetSubshape(use.instruction->shape(),
|
for (int64 i : use.operand_index) {
|
||||||
{use.operand_number}),
|
use_output_index.push_back(i);
|
||||||
use.operand_index)
|
}
|
||||||
.element_type();
|
user_output_type =
|
||||||
|
OutputTypeAfterChange(use.instruction, use_output_index);
|
||||||
} else {
|
} else {
|
||||||
user_output_type = use.instruction->shape().element_type();
|
user_output_type = OutputTypeAfterChange(use.instruction, {});
|
||||||
}
|
}
|
||||||
if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
|
if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
|
||||||
*use.instruction, use.operand_number) &&
|
*use.instruction, use.operand_number) &&
|
||||||
@ -199,8 +281,8 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BFloat16Propagation::DetermineAndMutateInstructionPrecision(
|
void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
|
||||||
HloInstruction* hlo, bool skip_parameters) {
|
bool skip_parameters) {
|
||||||
// We handle any fusion computation or while body/condition after the
|
// We handle any fusion computation or while body/condition after the
|
||||||
// instruction is handled, because we need to know the output shape of a
|
// instruction is handled, because we need to know the output shape of a
|
||||||
// fusion or while before propagating inside its computations.
|
// fusion or while before propagating inside its computations.
|
||||||
@ -209,12 +291,12 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision(
|
|||||||
[this, hlo, &postpone_processing_called_computations] {
|
[this, hlo, &postpone_processing_called_computations] {
|
||||||
if (!postpone_processing_called_computations) {
|
if (!postpone_processing_called_computations) {
|
||||||
if (hlo->opcode() == HloOpcode::kFusion) {
|
if (hlo->opcode() == HloOpcode::kFusion) {
|
||||||
DetermineAndMutateFusionComputationPrecision(hlo);
|
DetermineFusionComputationPrecision(hlo);
|
||||||
} else if (hlo->opcode() == HloOpcode::kWhile) {
|
} else if (hlo->opcode() == HloOpcode::kWhile) {
|
||||||
DetermineAndMutateWhileComputationsPrecision(hlo);
|
DetermineWhileComputationsPrecision(hlo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
instructions_visited_in_mutation_pass_.insert(hlo);
|
instructions_visited_in_backward_pass_.insert(hlo);
|
||||||
});
|
});
|
||||||
|
|
||||||
if (hlo->opcode() == HloOpcode::kWhile &&
|
if (hlo->opcode() == HloOpcode::kWhile &&
|
||||||
@ -245,9 +327,9 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision(
|
|||||||
CHECK(hlo->parent() != nullptr);
|
CHECK(hlo->parent() != nullptr);
|
||||||
if (hlo == hlo->parent()->root_instruction()) {
|
if (hlo == hlo->parent()->root_instruction()) {
|
||||||
if (!hlo->parent()->IsFusionComputation()) {
|
if (!hlo->parent()->IsFusionComputation()) {
|
||||||
ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& subshape,
|
ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */,
|
||||||
const ShapeIndex& index) {
|
const ShapeIndex& index) {
|
||||||
if (subshape.element_type() != F32) {
|
if (OutputTypeAfterChange(hlo, index) != F32) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
|
for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
|
||||||
@ -269,13 +351,12 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapeUtil::ForEachMutableSubshape(
|
ShapeUtil::ForEachSubshape(
|
||||||
hlo->mutable_shape(),
|
hlo->shape(),
|
||||||
[hlo, this](Shape* subshape, const ShapeIndex& index) {
|
[hlo, this](const Shape& /* subshape */, const ShapeIndex& index) {
|
||||||
if (subshape->element_type() == F32 &&
|
if (OutputTypeAfterChange(hlo, index) == F32 &&
|
||||||
AllUsersConsumeBF16(*hlo, index)) {
|
AllUsersConsumeBF16(*hlo, index)) {
|
||||||
subshape->set_element_type(BF16);
|
AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16);
|
||||||
changed_ = true;
|
|
||||||
VLOG(2) << "HloInstruction output at shape index " << index
|
VLOG(2) << "HloInstruction output at shape index " << index
|
||||||
<< " changed to BF16 precision: " << hlo->ToString();
|
<< " changed to BF16 precision: " << hlo->ToString();
|
||||||
}
|
}
|
||||||
@ -308,26 +389,24 @@ void BFloat16Propagation::AdjustCalledComputationParameters(
|
|||||||
CHECK_EQ(operands.size(), computation->num_parameters());
|
CHECK_EQ(operands.size(), computation->num_parameters());
|
||||||
for (int64 i = 0; i < operands.size(); ++i) {
|
for (int64 i = 0; i < operands.size(); ++i) {
|
||||||
auto parameter = computation->parameter_instruction(i);
|
auto parameter = computation->parameter_instruction(i);
|
||||||
ShapeUtil::ForEachMutableSubshape(
|
ShapeUtil::ForEachSubshape(
|
||||||
parameter->mutable_shape(),
|
parameter->shape(),
|
||||||
[this, i, hlo, &operands, parameter](Shape* subshape,
|
[this, i, hlo, &operands, parameter](const Shape& /* subshape */,
|
||||||
const ShapeIndex& index) {
|
const ShapeIndex& index) {
|
||||||
if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
|
if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
PrimitiveType operand_type =
|
PrimitiveType operand_type =
|
||||||
ShapeUtil::GetSubshape(operands[i]->shape(), index)
|
OutputTypeAfterChange(operands[i], index);
|
||||||
.element_type();
|
if (OutputTypeAfterChange(parameter, index) == operand_type) {
|
||||||
if (subshape->element_type() == operand_type) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
CHECK(operand_type == F32 || operand_type == BF16);
|
AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type);
|
||||||
subshape->set_element_type(operand_type);
|
|
||||||
changed_ = true;
|
|
||||||
VLOG(2) << "Called computation parameter "
|
VLOG(2) << "Called computation parameter "
|
||||||
<< parameter->ToString() << " at shape index " << index
|
<< parameter->ToString() << " at shape index " << index
|
||||||
<< " adjusted to match operand in HLO "
|
<< " adjusted to "
|
||||||
<< hlo->ToString();
|
<< (operand_type == BF16 ? "BF16" : "F32")
|
||||||
|
<< " to match operand in HLO " << hlo->ToString();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -348,51 +427,48 @@ void BFloat16Propagation::AdjustCalledComputationParameters(
|
|||||||
|
|
||||||
void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
|
void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
|
||||||
auto adjust_computation = [this, hlo](HloComputation* computation,
|
auto adjust_computation = [this, hlo](HloComputation* computation,
|
||||||
const Shape& output_shape) {
|
HloInstruction* output) {
|
||||||
// Adjust root.
|
// Adjust root.
|
||||||
HloInstruction* root = computation->root_instruction();
|
HloInstruction* root = computation->root_instruction();
|
||||||
ShapeUtil::ForEachMutableSubshape(
|
ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output](
|
||||||
root->mutable_shape(), [this, hlo, root, &output_shape](
|
const Shape& /* subshape */,
|
||||||
Shape* subshape, const ShapeIndex& index) {
|
const ShapeIndex& index) {
|
||||||
if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
|
if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const PrimitiveType output_type =
|
const PrimitiveType output_type = OutputTypeAfterChange(output, index);
|
||||||
ShapeUtil::GetSubshape(output_shape, index).element_type();
|
if (OutputTypeAfterChange(root, index) == output_type) {
|
||||||
if (subshape->element_type() == output_type) {
|
return;
|
||||||
return;
|
}
|
||||||
}
|
AddToOrRemoveFromBF16ChangeSet(root, index, output_type);
|
||||||
CHECK(output_type == F32 || output_type == BF16);
|
// It's possible that output_type is F32, but the root instruction's
|
||||||
subshape->set_element_type(output_type);
|
// type is BF16; e.g., a fusion node's output was changed to BF16
|
||||||
// It's possible that output_type is F32, but the root instruction's
|
// initially but then adjusted back to F32, and the fusion computation
|
||||||
// type is BF16; e.g., a fusion node's output was changed to BF16
|
// is now being adjusted after the fusion node.
|
||||||
// initially but then adjusted back to F32, and the fusion computation
|
if (output_type == F32) {
|
||||||
// is now being adjusted after the fusion node.
|
for (const auto* value : dataflow_->GetValueSet(root, index).values()) {
|
||||||
if (output_type == F32) {
|
// We rely on the fact that this adjustment works in reverse
|
||||||
for (const auto* value :
|
// topological order so that called computation will be
|
||||||
dataflow_->GetValueSet(root, index).values()) {
|
// processed later. Adding the value to
|
||||||
// We rely on the fact that this adjustment works in reverse
|
// values_that_must_be_kept_as_f32_ will ensure the
|
||||||
// topological order so that called computation will be
|
// correctness of the adjustment for HLOs that will be
|
||||||
// processed later. Adding the value to
|
// processed later.
|
||||||
// values_that_must_be_kept_as_f32_ will ensure the
|
values_that_must_be_kept_as_f32_.insert(value);
|
||||||
// correctness of the adjustment for HLOs that will be
|
}
|
||||||
// processed later.
|
}
|
||||||
values_that_must_be_kept_as_f32_.insert(value);
|
VLOG(2) << "Called computation root " << root->ToString()
|
||||||
}
|
<< " at shape index " << index << " adjusted to "
|
||||||
}
|
<< (output_type == BF16 ? "BF16" : "F32")
|
||||||
changed_ = true;
|
<< " to match output shape of " << hlo->ToString();
|
||||||
VLOG(2) << "Called computation root " << root->ToString()
|
});
|
||||||
<< " at shape index " << index
|
|
||||||
<< " adjusted to match output shape of " << hlo->ToString();
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
|
|
||||||
switch (hlo->opcode()) {
|
switch (hlo->opcode()) {
|
||||||
case HloOpcode::kFusion:
|
case HloOpcode::kFusion:
|
||||||
adjust_computation(hlo->fused_instructions_computation(), hlo->shape());
|
adjust_computation(hlo->fused_instructions_computation(), hlo);
|
||||||
break;
|
break;
|
||||||
case HloOpcode::kWhile:
|
case HloOpcode::kWhile:
|
||||||
adjust_computation(hlo->while_body(), hlo->shape());
|
adjust_computation(hlo->while_body(), hlo);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
@ -409,16 +485,19 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
|
|||||||
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
|
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
|
||||||
auto hlo = *inst_it;
|
auto hlo = *inst_it;
|
||||||
auto adjust_hlo_output = [this, hlo, ¶meter_changed](
|
auto adjust_hlo_output = [this, hlo, ¶meter_changed](
|
||||||
Shape* subshape, const ShapeIndex& index) {
|
const Shape& /* subshape */,
|
||||||
if (subshape->element_type() != F32 && subshape->element_type() != BF16) {
|
const ShapeIndex& index) {
|
||||||
|
auto output_type = OutputTypeAfterChange(hlo, index);
|
||||||
|
if (output_type != F32 && output_type != BF16) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
PrimitiveType type = BF16;
|
PrimitiveType type = BF16;
|
||||||
for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
|
for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
|
||||||
if (value->shape().element_type() == BF16) {
|
auto value_type = ValueTypeAfterChange(value);
|
||||||
|
if (value_type == BF16) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
CHECK_EQ(value->shape().element_type(), F32);
|
CHECK_EQ(value_type, F32);
|
||||||
type = F32;
|
type = F32;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -437,16 +516,17 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
|
|||||||
values_that_must_be_kept_as_f32_.insert(value);
|
values_that_must_be_kept_as_f32_.insert(value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (type != subshape->element_type()) {
|
if (type != output_type) {
|
||||||
subshape->set_element_type(type);
|
AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
|
||||||
VLOG(2) << "HloInstruction output at shape index " << index
|
VLOG(2) << "HloInstruction output at shape index " << index
|
||||||
<< " adjusted to " << *subshape << ": " << hlo->ToString();
|
<< " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
|
||||||
|
<< hlo->ToString();
|
||||||
if (hlo->opcode() == HloOpcode::kParameter) {
|
if (hlo->opcode() == HloOpcode::kParameter) {
|
||||||
parameter_changed = true;
|
parameter_changed = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_hlo_output);
|
ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
|
||||||
AdjustCalledComputationRoot(hlo);
|
AdjustCalledComputationRoot(hlo);
|
||||||
if (hlo->opcode() == HloOpcode::kWhile) {
|
if (hlo->opcode() == HloOpcode::kWhile) {
|
||||||
// We need to run on the while body and condition repeatedly until a fixed
|
// We need to run on the while body and condition repeatedly until a fixed
|
||||||
@ -463,8 +543,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
|
|||||||
ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
|
ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
|
||||||
&visited_in_while)) {
|
&visited_in_while)) {
|
||||||
visited_in_while.clear();
|
visited_in_while.clear();
|
||||||
ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(),
|
ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
|
||||||
adjust_hlo_output);
|
|
||||||
AdjustCalledComputationRoot(hlo);
|
AdjustCalledComputationRoot(hlo);
|
||||||
}
|
}
|
||||||
visited_computations->insert(visited_in_while.begin(),
|
visited_computations->insert(visited_in_while.begin(),
|
||||||
@ -478,7 +557,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
|
|||||||
return parameter_changed;
|
return parameter_changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
|
void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
|
||||||
HloModule* module) {
|
HloModule* module) {
|
||||||
std::list<HloComputation*> computations_topological_order =
|
std::list<HloComputation*> computations_topological_order =
|
||||||
module->MakeComputationPostOrder();
|
module->MakeComputationPostOrder();
|
||||||
@ -490,7 +569,9 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
|
|||||||
}
|
}
|
||||||
ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
|
ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
|
||||||
// We could have changed a fusion computation's root shape to have a different
|
// We could have changed a fusion computation's root shape to have a different
|
||||||
// precision than the fusion node's output, if the fusion root does not
|
// precision than the fusion node's output, if the fusion root does not
|
||||||
// define a buffer (e.g., a tuple). Now we add conversions after such fusion
|
// define a buffer (e.g., a tuple). Now we add conversions after such fusion
|
||||||
@ -517,7 +598,7 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
|
|||||||
// (2) after adding conversion
|
// (2) after adding conversion
|
||||||
// (3) after tuple simplifier and DCE.
|
// (3) after tuple simplifier and DCE.
|
||||||
bool needs_tuple_simplifier = false;
|
bool needs_tuple_simplifier = false;
|
||||||
for (auto computation : computations_topological_order) {
|
for (auto computation : module->MakeComputationPostOrder()) {
|
||||||
auto insts = computation->MakeInstructionPostOrder();
|
auto insts = computation->MakeInstructionPostOrder();
|
||||||
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
|
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
|
||||||
auto hlo = *inst_it;
|
auto hlo = *inst_it;
|
||||||
@ -587,7 +668,14 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
|
|||||||
needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape());
|
needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (needs_tuple_simplifier) {
|
||||||
|
TupleSimplifier tuple_simplifier;
|
||||||
|
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
|
||||||
// We may have converted some constants from F32 to BF16, so adjust the
|
// We may have converted some constants from F32 to BF16, so adjust the
|
||||||
// constant literals in such cases. We do this here instead of when the
|
// constant literals in such cases. We do this here instead of when the
|
||||||
// constant node's is changed because 1) the HloInstruction interface does not
|
// constant node's is changed because 1) the HloInstruction interface does not
|
||||||
@ -598,8 +686,7 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
|
|||||||
// can avoid repeated conversions.
|
// can avoid repeated conversions.
|
||||||
//
|
//
|
||||||
// TODO(b/73833576): Consider resetting literal in HloInstruction.
|
// TODO(b/73833576): Consider resetting literal in HloInstruction.
|
||||||
bool needs_dce = needs_tuple_simplifier;
|
for (auto computation : module->MakeComputationPostOrder()) {
|
||||||
for (auto computation : computations_topological_order) {
|
|
||||||
for (auto hlo : computation->MakeInstructionPostOrder()) {
|
for (auto hlo : computation->MakeInstructionPostOrder()) {
|
||||||
if (hlo->opcode() != HloOpcode::kConstant) {
|
if (hlo->opcode() != HloOpcode::kConstant) {
|
||||||
continue;
|
continue;
|
||||||
@ -612,23 +699,13 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
|
|||||||
auto new_constant = computation->AddInstruction(
|
auto new_constant = computation->AddInstruction(
|
||||||
HloInstruction::CreateConstant(std::move(converted_literal)));
|
HloInstruction::CreateConstant(std::move(converted_literal)));
|
||||||
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
|
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
|
||||||
needs_dce = true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (needs_tuple_simplifier) {
|
|
||||||
TupleSimplifier tuple_simplifier;
|
|
||||||
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
|
|
||||||
}
|
|
||||||
if (needs_dce) {
|
|
||||||
HloDCE dce;
|
|
||||||
TF_RETURN_IF_ERROR(dce.Run(module).status());
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) {
|
Status BFloat16Propagation::SkipNoopConversions(HloModule* module) {
|
||||||
for (auto computation : module->computations()) {
|
for (auto computation : module->computations()) {
|
||||||
for (auto hlo : computation->MakeInstructionPostOrder()) {
|
for (auto hlo : computation->MakeInstructionPostOrder()) {
|
||||||
if (hlo->opcode() != HloOpcode::kConvert) {
|
if (hlo->opcode() != HloOpcode::kConvert) {
|
||||||
@ -643,7 +720,6 @@ Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) {
|
|||||||
if (is_root) {
|
if (is_root) {
|
||||||
computation->set_root_instruction(source);
|
computation->set_root_instruction(source);
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(hlo));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -652,8 +728,18 @@ Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) {
|
|||||||
// The algorithm first does a forward pass (parameters to root) to determine a
|
// The algorithm first does a forward pass (parameters to root) to determine a
|
||||||
// set of instructions to consider using bfloat16, then does a backward pass to
|
// set of instructions to consider using bfloat16, then does a backward pass to
|
||||||
// determine the precisions of those instructions according to the need of
|
// determine the precisions of those instructions according to the need of
|
||||||
// their users.
|
// their users. During the backward pass, the potential changes are stored in
|
||||||
|
// changes_to_bf16_ which are subject to further adjustments then applied to the
|
||||||
|
// HLOs.
|
||||||
StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
|
StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
|
||||||
|
consider_using_bfloat16_.clear();
|
||||||
|
instructions_visited_in_backward_pass_.clear();
|
||||||
|
computations_visited_in_backward_pass_.clear();
|
||||||
|
values_that_must_be_kept_as_f32_.clear();
|
||||||
|
caller_counts_.clear();
|
||||||
|
changes_to_bf16_.clear();
|
||||||
|
changed_ = false;
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));
|
TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));
|
||||||
|
|
||||||
std::list<HloComputation*> computations_topological_order =
|
std::list<HloComputation*> computations_topological_order =
|
||||||
@ -686,8 +772,24 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
|
|||||||
}
|
}
|
||||||
auto insts = (*comp_it)->MakeInstructionPostOrder();
|
auto insts = (*comp_it)->MakeInstructionPostOrder();
|
||||||
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
|
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
|
||||||
DetermineAndMutateInstructionPrecision(*inst_it,
|
DetermineInstructionPrecision(*inst_it,
|
||||||
/*skip_parameters=*/true);
|
/*skip_parameters=*/true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// It's possible that an instruction does not define a buffer, but the
|
||||||
|
// defining instruction's shape has changed. So we need to adjust the output
|
||||||
|
// shapes of instructions according to the HLO values they refer to.
|
||||||
|
ResolveInconsistencyOfAliasingBuffers(module);
|
||||||
|
|
||||||
|
// Apply the changes in changes_to_bf16_.
|
||||||
|
for (auto& change : changes_to_bf16_) {
|
||||||
|
auto shape = change.first->mutable_shape();
|
||||||
|
for (const auto& index : change.second) {
|
||||||
|
auto subshape = ShapeUtil::GetMutableSubshape(shape, index);
|
||||||
|
CHECK_EQ(subshape->element_type(), F32);
|
||||||
|
subshape->set_element_type(BF16);
|
||||||
|
changed_ = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -695,15 +797,56 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// It's possible that an instruction does not define a buffer, but the
|
TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module));
|
||||||
// defining instruction's shape has changed. So we need to adjust the output
|
TF_RETURN_IF_ERROR(ResolveConvertedConstants(module));
|
||||||
// shapes of instructions according to the HLO values they refer to.
|
|
||||||
TF_RETURN_IF_ERROR(ResolveInconsistencyOfAliasingBuffers(module));
|
|
||||||
|
|
||||||
// This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 ->
|
// This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 ->
|
||||||
// BF16), so we remove them now.
|
// BF16), so we skip them now.
|
||||||
TF_RETURN_IF_ERROR(RemoveNoopConversions(module));
|
TF_RETURN_IF_ERROR(SkipNoopConversions(module));
|
||||||
|
|
||||||
|
{
|
||||||
|
// We may have dead HLOs after ResolveInconsistentFusions,
|
||||||
|
// ResolveConvertedConstants and SkipNoopConversions.
|
||||||
|
HloDCE dce;
|
||||||
|
TF_RETURN_IF_ERROR(dce.Run(module).status());
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
|
||||||
|
HloInstruction* hlo, const ShapeIndex& index) const {
|
||||||
|
PrimitiveType type_on_hlo =
|
||||||
|
ShapeUtil::GetSubshape(hlo->shape(), index).element_type();
|
||||||
|
if (type_on_hlo != F32) {
|
||||||
|
return type_on_hlo;
|
||||||
|
}
|
||||||
|
auto it = changes_to_bf16_.find(hlo);
|
||||||
|
if (it == changes_to_bf16_.end()) {
|
||||||
|
return type_on_hlo;
|
||||||
|
}
|
||||||
|
return ContainsKey(it->second, index) ? BF16 : F32;
|
||||||
|
}
|
||||||
|
|
||||||
|
PrimitiveType BFloat16Propagation::ValueTypeAfterChange(
|
||||||
|
const HloValue* value) const {
|
||||||
|
auto hlo = value->defining_instruction();
|
||||||
|
const auto& position = value->defining_position();
|
||||||
|
return OutputTypeAfterChange(hlo, position.index);
|
||||||
|
}
|
||||||
|
|
||||||
|
void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet(
|
||||||
|
HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) {
|
||||||
|
if (target_type == BF16) {
|
||||||
|
auto& entry = changes_to_bf16_[hlo];
|
||||||
|
entry.insert(index);
|
||||||
|
} else {
|
||||||
|
CHECK_EQ(target_type, F32);
|
||||||
|
auto it = changes_to_bf16_.find(hlo);
|
||||||
|
if (it == changes_to_bf16_.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
it->second.erase(index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||||
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -85,30 +86,39 @@ class BFloat16Propagation : public HloPassInterface {
|
|||||||
tensorflow::gtl::FlatSet<const HloInstruction*> consider_using_bfloat16_;
|
tensorflow::gtl::FlatSet<const HloInstruction*> consider_using_bfloat16_;
|
||||||
|
|
||||||
// ***************************
|
// ***************************
|
||||||
// Functions called and state produced by the backward mutation pass (from
|
// Functions called and state produced by the backward pass (from root to
|
||||||
// root to parameters).
|
// parameters) that finds opportunities to use BF16.
|
||||||
|
|
||||||
// Determines the precision for the given instruction in the mutation pass.
|
// Determines the precision for the given instruction in the
|
||||||
void DetermineAndMutateInstructionPrecision(HloInstruction* hlo,
|
// opportunity-finding pass.
|
||||||
bool skip_parameters);
|
void DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters);
|
||||||
|
|
||||||
// Special handling in the mutation pass for fusion computations.
|
// Special handling in the opportunity-finding pass for fusion computations.
|
||||||
//
|
//
|
||||||
// Precondition: hlo->opcode() == kFusion
|
// Precondition: hlo->opcode() == kFusion
|
||||||
void DetermineAndMutateFusionComputationPrecision(HloInstruction* fusion);
|
void DetermineFusionComputationPrecision(HloInstruction* fusion);
|
||||||
|
|
||||||
// Special handling in the mutation pass for while computations.
|
// Reverts changes to BF16 that will not propagate outside a fusion
|
||||||
|
// computation. This avoids BF16 casts overhead inside a fusion which won't
|
||||||
|
// save memory bandwidth.
|
||||||
|
//
|
||||||
|
// Precondition: hlo->opcode() == kFusion
|
||||||
|
void RevertIfFusionInternalBF16Changes(HloInstruction* fusion);
|
||||||
|
|
||||||
|
// Special handling in the opportunity-finding pass for while computations.
|
||||||
//
|
//
|
||||||
// Precondition: hlo->opcode() == kWhile
|
// Precondition: hlo->opcode() == kWhile
|
||||||
void DetermineAndMutateWhileComputationsPrecision(HloInstruction* while_hlo);
|
void DetermineWhileComputationsPrecision(HloInstruction* while_hlo);
|
||||||
|
|
||||||
// The set of HloInstructions that have been visited in the mutation pass.
|
// The set of HloInstructions that have been visited in the
|
||||||
|
// opportunity-finding pass.
|
||||||
tensorflow::gtl::FlatSet<const HloInstruction*>
|
tensorflow::gtl::FlatSet<const HloInstruction*>
|
||||||
instructions_visited_in_mutation_pass_;
|
instructions_visited_in_backward_pass_;
|
||||||
|
|
||||||
// The set of HloComputations that have been visited in the mutation pass.
|
// The set of HloComputations that have been visited in the
|
||||||
|
// opportunity-finding pass.
|
||||||
tensorflow::gtl::FlatSet<const HloComputation*>
|
tensorflow::gtl::FlatSet<const HloComputation*>
|
||||||
computations_visited_in_mutation_pass_;
|
computations_visited_in_backward_pass_;
|
||||||
|
|
||||||
// ***************************
|
// ***************************
|
||||||
// Functions called by the final inconsistency resolving pass.
|
// Functions called by the final inconsistency resolving pass.
|
||||||
@ -116,7 +126,7 @@ class BFloat16Propagation : public HloPassInterface {
|
|||||||
// Adjusts the output shapes of HloInstructions such that if two
|
// Adjusts the output shapes of HloInstructions such that if two
|
||||||
// HloInstructions have aliasing buffers in their outputs, they must have the
|
// HloInstructions have aliasing buffers in their outputs, they must have the
|
||||||
// same precision.
|
// same precision.
|
||||||
Status ResolveInconsistencyOfAliasingBuffers(HloModule* module);
|
void ResolveInconsistencyOfAliasingBuffers(HloModule* module);
|
||||||
|
|
||||||
// Resolves inconsistency of aliasing buffers for the given computation, and
|
// Resolves inconsistency of aliasing buffers for the given computation, and
|
||||||
// recursively runs on a while instruction's condition and body until a fixed
|
// recursively runs on a while instruction's condition and body until a fixed
|
||||||
@ -134,9 +144,19 @@ class BFloat16Propagation : public HloPassInterface {
|
|||||||
void AdjustCalledComputationRoot(HloInstruction* hlo);
|
void AdjustCalledComputationRoot(HloInstruction* hlo);
|
||||||
|
|
||||||
// ***************************
|
// ***************************
|
||||||
// Removes no-op conversions (same source and target shapes) that can be
|
// Functions called after changes in changes_to_bf16_ are applied.
|
||||||
// produced this pass.
|
|
||||||
Status RemoveNoopConversions(HloModule* module);
|
// Resolves inconsistencies introduced by this pass for fusions with
|
||||||
|
// tuple-type output.
|
||||||
|
Status ResolveInconsistentFusions(HloModule* module);
|
||||||
|
|
||||||
|
// Converts the literals in kConstant HLOs which have their types changed to
|
||||||
|
// BF16 by this pass.
|
||||||
|
Status ResolveConvertedConstants(HloModule* module);
|
||||||
|
|
||||||
|
// Skips no-op conversions (same source and target shapes) that can be
|
||||||
|
// produced this pass, i.e., replaces them in their uses with their operands.
|
||||||
|
Status SkipNoopConversions(HloModule* module);
|
||||||
|
|
||||||
// ***************************
|
// ***************************
|
||||||
// Functions called and state used by two or more passes.
|
// Functions called and state used by two or more passes.
|
||||||
@ -146,6 +166,23 @@ class BFloat16Propagation : public HloPassInterface {
|
|||||||
bool AllUsersConsumeBF16(const HloInstruction& hlo,
|
bool AllUsersConsumeBF16(const HloInstruction& hlo,
|
||||||
const ShapeIndex& index) const;
|
const ShapeIndex& index) const;
|
||||||
|
|
||||||
|
// The output element type of the HLO at the given shape index after changes
|
||||||
|
// in changes_to_bf16_ are applied.
|
||||||
|
PrimitiveType OutputTypeAfterChange(HloInstruction* hlo,
|
||||||
|
const ShapeIndex& index) const;
|
||||||
|
|
||||||
|
// The element type of the HLO value after changes in changes_to_bf16_ are
|
||||||
|
// applied.
|
||||||
|
PrimitiveType ValueTypeAfterChange(const HloValue* value) const;
|
||||||
|
|
||||||
|
// If target_type == BF16, adds the HLO at the given index to
|
||||||
|
// changes_to_bf16_; otherwise, target_type must be F32 and this function
|
||||||
|
// removes the HLO at the given index from changes_to_bf16_ if it was earlier
|
||||||
|
// added.
|
||||||
|
void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo,
|
||||||
|
const ShapeIndex& index,
|
||||||
|
PrimitiveType target_type);
|
||||||
|
|
||||||
// The set of F32 HLO values that must be kept in F32.
|
// The set of F32 HLO values that must be kept in F32.
|
||||||
tensorflow::gtl::FlatSet<const HloValue*> values_that_must_be_kept_as_f32_;
|
tensorflow::gtl::FlatSet<const HloValue*> values_that_must_be_kept_as_f32_;
|
||||||
|
|
||||||
@ -153,10 +190,28 @@ class BFloat16Propagation : public HloPassInterface {
|
|||||||
// module. Populated at the beginning of this pass.
|
// module. Populated at the beginning of this pass.
|
||||||
tensorflow::gtl::FlatMap<const HloComputation*, int64> caller_counts_;
|
tensorflow::gtl::FlatMap<const HloComputation*, int64> caller_counts_;
|
||||||
|
|
||||||
|
// We first store the potential F32-to-BF16 changes to changes_to_bf16_, which
|
||||||
|
// are subject to further adjustment, then finally applied to the HLOs. This
|
||||||
|
// avoids setting changed_ to true but all changes are reverted during
|
||||||
|
// adjustment.
|
||||||
|
struct IndexHasher {
|
||||||
|
int64 operator()(const ShapeIndex& index) const {
|
||||||
|
int64 hash = 0;
|
||||||
|
for (int64 i : index) {
|
||||||
|
hash = tensorflow::Hash64Combine(hash, std::hash<int64>()(i));
|
||||||
|
}
|
||||||
|
return hash;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
tensorflow::gtl::FlatMap<HloInstruction*,
|
||||||
|
tensorflow::gtl::FlatSet<ShapeIndex, IndexHasher>>
|
||||||
|
changes_to_bf16_;
|
||||||
|
|
||||||
|
// Whether the last processed HLO module has been changed by this pass.
|
||||||
|
bool changed_ = false;
|
||||||
|
|
||||||
const BFloat16Support* bfloat16_support_;
|
const BFloat16Support* bfloat16_support_;
|
||||||
std::unique_ptr<HloDataflowAnalysis> dataflow_;
|
std::unique_ptr<HloDataflowAnalysis> dataflow_;
|
||||||
|
|
||||||
bool changed_ = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -323,6 +323,37 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
|
|||||||
EXPECT_TRUE(OutputsBF16(b_f1));
|
EXPECT_TRUE(OutputsBF16(b_f1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests that changes to BF16 that cannot be propagated outside a fusion are
|
||||||
|
// discarded.
|
||||||
|
TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
|
||||||
|
|
||||||
|
HloInstruction* param = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, shape, "param"));
|
||||||
|
HloInstruction* add = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
|
||||||
|
|
||||||
|
auto builder_f = HloComputation::Builder("fusion");
|
||||||
|
HloInstruction* a_f =
|
||||||
|
builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||||
|
HloInstruction* b_f =
|
||||||
|
builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
|
||||||
|
HloInstruction* add_f = builder_f.AddInstruction(
|
||||||
|
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
|
||||||
|
HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary(
|
||||||
|
ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f));
|
||||||
|
auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
|
||||||
|
auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
|
||||||
|
dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f));
|
||||||
|
|
||||||
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
|
EXPECT_FALSE(PropagatePrecision(module.get()));
|
||||||
|
EXPECT_EQ(computation->root_instruction(), fusion);
|
||||||
|
}
|
||||||
|
|
||||||
// Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion
|
// Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion
|
||||||
// outputs are only used by a dot, and 3) one element of the tuple is used by
|
// outputs are only used by a dot, and 3) one element of the tuple is used by
|
||||||
// an add in the fusion computation, then the propagation pass should create a
|
// an add in the fusion computation, then the propagation pass should create a
|
||||||
|
Loading…
Reference in New Issue
Block a user