[XLA] Make BF16 normalization better handle tuple-shaped inputs/outputs.
Use per-subshape accounting and conversion for mismatched types.

PiperOrigin-RevId: 275410631
Change-Id: I098fcb7cbdfebc015ba68156563f8350077fd4ea
parent 94794f435d
commit 976e7d7f58
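For readers skimming the diff below: the core idea of the change is to count and convert BF16/F32 element types per subshape, instead of assuming a single array-shaped output. A minimal sketch of that accounting, mirroring the CountSubshapesWithMatchingType helper added in this commit (the name CountBF16Leaves and the standalone framing are illustrative, not part of the change):

    // Walk every nested subshape of a (possibly tuple-shaped) XLA shape and
    // count how many have BF16 element type. A tuple shape contains its
    // element shapes as subshapes, so each leaf is accounted individually.
    int64 CountBF16Leaves(const Shape& shape) {
      int64 count = 0;
      ShapeUtil::ForEachSubshape(
          shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) {
            if (subshape.element_type() == BF16) {
              ++count;
            }
          });
      return count;
    }

A tuple shape like (f32[2,4], bf16[2,4]) reports one BF16 subshape, which is what lets the pass convert only the mismatched tuple elements rather than bailing out or converting the whole output.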
tensorflow/compiler/xla/service/BUILD
@@ -116,7 +116,9 @@ cc_library(
     deps = [
         ":bfloat16_support",
         ":hlo",
+        ":hlo_dce",
         ":hlo_pass",
+        ":tuple_simplifier",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:xla_data_proto",
tensorflow/compiler/xla/service/bfloat16_normalization.cc
@@ -18,8 +18,10 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -29,6 +31,8 @@ limitations under the License.
 
 namespace xla {
 
+namespace {
+
 class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
  public:
  explicit BFloat16NormalizationVisitor(
@@ -51,19 +55,30 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
   // independently.
   Status HandleMultipleOutputs(HloInstruction* hlo);
 
-  // Inserts a conversion HLO that changes the given HLO's output type.
-  Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to,
+  // Creates a copy of `hlo` with subshapes matching `from` type converted to
+  // `to` type. If no matching subshapes are found, returns the original `hlo`.
+  StatusOr<HloInstruction*> ConvertType(HloInstruction* hlo, PrimitiveType from,
+                                        PrimitiveType to,
+                                        HloComputation* computation);
+
+  // Inserts a conversion HLO that changes the given HLO's output type. If the
+  // output is a tuple, change all elements that match the from type.
+  Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType from,
+                                  PrimitiveType to,
                                   HloComputation* computation);
 
   // Changes the output type to the specified type, then inserts a conversion
-  // to the original type.
+  // to the original type. If the output is a tuple, change all elements that
+  // match the from type.
   Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo,
+                                               PrimitiveType from,
                                                PrimitiveType to,
                                                HloComputation* computation);
 
-  // Inserts a conversion HLO that changes the given HLO's operand type.
+  // Inserts a conversion HLO that changes the given HLO's operand type. If the
+  // operand is a tuple, change all elements that match the from type.
   Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx,
-                                    PrimitiveType to,
+                                    PrimitiveType from, PrimitiveType to,
                                     HloComputation* computation);
 
   // Inserts conversion HLOs to replace the called computations' BF16
@@ -77,47 +92,140 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
   bool changed_ = false;
 };
 
+int64 CountSubshapesWithMatchingType(const Shape& shape, PrimitiveType type) {
+  int64 count = 0;
+  ShapeUtil::ForEachSubshape(
+      shape, [&](const Shape& subshape, const ShapeIndex& index) {
+        if (subshape.element_type() == type) {
+          ++count;
+        }
+      });
+  return count;
+}
+
+int64 ShapeLeafCount(const Shape& shape) {
+  int64 count = 0;
+  ShapeUtil::ForEachSubshape(
+      shape, [&](const Shape& subshape, const ShapeIndex& index) {
+        if (ShapeUtil::IsLeafIndex(shape, index)) {
+          ++count;
+        }
+      });
+  return count;
+}
+
+StatusOr<HloInstruction*> BFloat16NormalizationVisitor::ConvertType(
+    HloInstruction* hlo, PrimitiveType from, PrimitiveType to,
+    HloComputation* computation) {
+  if (CountSubshapesWithMatchingType(hlo->shape(), from) == 0) {
+    return hlo;
+  }
+  // If `hlo` is a convert from `to` to `from`, then we can return its operand,
+  // if it is a BF16->F32 convert which doesn't do rounding.
+  if (hlo->opcode() == HloOpcode::kConvert &&
+      hlo->operand(0)->shape().element_type() == to && to == BF16 &&
+      from == F32) {
+    return hlo->mutable_operand(0);
+  }
+  TF_ASSIGN_OR_RETURN(
+      auto new_hlo,
+      computation->DeepCopyInstructionWithCustomCopier(
+          hlo, [&](HloInstruction* leaf, const ShapeIndex& leaf_index,
+                   HloComputation* comp) {
+            const auto& original_subshape =
+                ShapeUtil::GetSubshape(hlo->shape(), leaf_index);
+            if (original_subshape.element_type() != from) {
+              return leaf;
+            }
+            auto new_subshape =
+                ShapeUtil::ChangeElementType(original_subshape, to);
+            bfloat16_normalization_->UpdateLayout(&new_subshape);
+            return computation->AddInstruction(
+                HloInstruction::CreateConvert(new_subshape, leaf));
+          }));
+  return new_hlo;
+}
+
 Status BFloat16NormalizationVisitor::InsertConvertAfterOutput(
-    HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
+    HloInstruction* hlo, PrimitiveType from, PrimitiveType to,
+    HloComputation* computation) {
   bool is_root = computation->root_instruction() == hlo;
   std::vector<HloInstruction*> materialized_users = hlo->users();
-  // Use inst's shape temporarily, in order to pass checks in ReplaceUseWith.
-  auto convert = computation->AddInstruction(
-      HloInstruction::CreateConvert(hlo->shape(), hlo));
+
+  TF_ASSIGN_OR_RETURN(auto new_hlo, ConvertType(hlo, from, to, computation));
+  if (new_hlo == hlo) {
+    return Status::OK();
+  }
+
   for (auto* user : materialized_users) {
-    if (user->opcode() == HloOpcode::kConvert &&
-        user->shape().element_type() == F32) {
-      TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
-    } else {
-      TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert));
-    }
+    TF_RETURN_IF_ERROR(hlo->ReplaceUseWithDifferentShape(user, new_hlo));
   }
   if (is_root) {
-    computation->set_root_instruction(convert);
+    computation->set_root_instruction(new_hlo, /*accept_different_shape=*/true);
   }
-  convert->mutable_shape()->set_element_type(to);
-  bfloat16_normalization_->UpdateLayout(convert->mutable_shape());
   changed_ = true;
   return Status::OK();
 }
 
 Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
-    HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
-  auto original_type = hlo->shape().element_type();
-  hlo->mutable_shape()->set_element_type(to);
+    HloInstruction* hlo, PrimitiveType from, PrimitiveType to,
+    HloComputation* computation) {
+  auto original_shape = hlo->shape();
+  if (CountSubshapesWithMatchingType(original_shape, from) == 0) {
+    return Status::OK();
+  }
+  ShapeUtil::ForEachMutableSubshape(
+      hlo->mutable_shape(), [&](Shape* subshape, const xla::ShapeIndex& index) {
+        if (subshape->element_type() == from) {
+          subshape->set_element_type(to);
+        }
+      });
   bfloat16_normalization_->UpdateLayout(hlo->mutable_shape());
-  return InsertConvertAfterOutput(hlo, original_type, computation);
+  bool is_root = computation->root_instruction() == hlo;
+  std::vector<HloInstruction*> materialized_users = hlo->users();
+  TF_ASSIGN_OR_RETURN(
+      auto new_hlo,
+      computation->DeepCopyInstructionWithCustomCopier(
+          hlo, [&](HloInstruction* leaf, const ShapeIndex& leaf_index,
+                   HloComputation* comp) {
+            const auto& original_subshape =
+                ShapeUtil::GetSubshape(original_shape, leaf_index);
+            if (original_subshape.element_type() ==
+                leaf->shape().element_type()) {
+              return leaf;
+            }
+            return computation->AddInstruction(
+                HloInstruction::CreateConvert(original_subshape, leaf));
+          }));
+
+  for (auto* user : materialized_users) {
+    // If the user is a BF16 -> F32 convert, we can replace it with `hlo`,
+    // which has its input changed to F32.
+    if (user->opcode() == HloOpcode::kConvert &&
+        user->shape().element_type() == to && to == F32 && from == BF16) {
+      TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
+    } else {
+      TF_RETURN_IF_ERROR(hlo->ReplaceUseWithDifferentShape(user, new_hlo));
+    }
+  }
+  if (is_root) {
+    computation->set_root_instruction(new_hlo, /*accept_different_shape=*/true);
+  }
+  changed_ = true;
+  return Status::OK();
 }
 
 Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
-    HloInstruction* hlo, int64 operand_idx, PrimitiveType to,
-    HloComputation* computation) {
+    HloInstruction* hlo, int64 operand_idx, PrimitiveType from,
+    PrimitiveType to, HloComputation* computation) {
   auto operand = hlo->mutable_operand(operand_idx);
-  auto shape = ShapeUtil::ChangeElementType(operand->shape(), to);
-  bfloat16_normalization_->UpdateLayout(&shape);
-  auto convert = computation->AddInstruction(
-      HloInstruction::CreateConvert(shape, operand));
-  TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert));
+  TF_ASSIGN_OR_RETURN(auto new_operand,
+                      ConvertType(operand, from, to, computation));
+  if (new_operand == operand) {
+    return Status::OK();
+  }
+  TF_RETURN_IF_ERROR(
+      hlo->ReplaceOperandWithDifferentShape(operand_idx, new_operand));
   changed_ = true;
   return Status::OK();
 }
@@ -139,16 +247,12 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations(
   });
   for (auto& comp_pair : cloned_computations) {
     auto comp = comp_pair.second;
-    if (comp->root_instruction()->shape().element_type() == BF16) {
-      TF_RETURN_IF_ERROR(
-          InsertConvertAfterOutput(comp->root_instruction(), F32, comp));
-    }
+    TF_RETURN_IF_ERROR(
+        InsertConvertAfterOutput(comp->root_instruction(), BF16, F32, comp));
     for (auto* param : comp->parameter_instructions()) {
-      if (param->shape().element_type() == BF16) {
-        // This changes the parameter to F32 then inserts a convert after it.
-        TF_RETURN_IF_ERROR(
-            ChangeOutputTypeThenInsertConvertBack(param, F32, comp));
-      }
+      // This changes the parameter to F32 then inserts a convert after it.
+      TF_RETURN_IF_ERROR(
+          ChangeOutputTypeThenInsertConvertBack(param, BF16, F32, comp));
     }
   }
   return Status::OK();
@@ -163,6 +267,8 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
   bool has_unsupported_bf16_operand = false;
   bool has_unsupported_bf16_output = false;
   for (int64 i = 0; i < hlo->operand_count(); ++i) {
+    CHECK(hlo->operand(i)->shape().IsArray());
+    CHECK(ShapeUtil::GetSubshape(hlo->shape(), {i}).IsArray());
     operand_types[i] = hlo->operand(i)->shape().element_type();
     output_types[i] = ShapeUtil::GetSubshape(hlo->shape(), {i}).element_type();
     if (operand_types[i] == F32) {
@@ -203,7 +309,8 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
 
   for (int64 i = 0; i < hlo->operand_count(); ++i) {
     if (should_convert_operand(i)) {
-      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
+      TF_RETURN_IF_ERROR(
+          InsertConvertBeforeOperand(hlo, i, BF16, F32, computation_));
       f32_count += 1;
       bf16_count -= 1;
     }
@@ -275,36 +382,34 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
 
 Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
   int f32_count = 0;
-  int bf16_count = 1;
+  int bf16_count = 0;
 
   for (int64 i = 0; i < hlo->operand_count(); ++i) {
-    if (hlo->operand(i)->shape().element_type() == F32) {
-      f32_count += 1;
-    } else if (hlo->operand(i)->shape().element_type() == BF16) {
-      bf16_count += 1;
-    }
+    f32_count += CountSubshapesWithMatchingType(hlo->operand(i)->shape(), F32);
+    bf16_count +=
+        CountSubshapesWithMatchingType(hlo->operand(i)->shape(), BF16);
   }
 
-  if (hlo->shape().element_type() == F32) {
-    f32_count += 1;
-  } else if (hlo->shape().element_type() == BF16) {
-    bf16_count += 1;
-  }
+  f32_count += CountSubshapesWithMatchingType(hlo->shape(), F32);
+  bf16_count += CountSubshapesWithMatchingType(hlo->shape(), BF16);
 
   std::vector<HloComputation*> bf16_called_comps;
   for (auto* comp : hlo->called_computations()) {
     bool comp_has_bf16 = false;
-    if (comp->root_instruction()->shape().element_type() == F32) {
-      f32_count += 1;
-    } else if (comp->root_instruction()->shape().element_type() == BF16) {
-      bf16_count += 1;
-      comp_has_bf16 = true;
-    }
+    f32_count +=
+        CountSubshapesWithMatchingType(comp->root_instruction()->shape(), F32);
+    int64 bf16_count_comp_root =
+        CountSubshapesWithMatchingType(comp->root_instruction()->shape(), BF16);
+    if (bf16_count_comp_root > 0) {
+      bf16_count += bf16_count_comp_root;
+      comp_has_bf16 = true;
+    }
     for (auto* param : comp->parameter_instructions()) {
-      if (param->shape().element_type() == F32) {
-        f32_count += 1;
-      } else if (param->shape().element_type() == BF16) {
-        bf16_count += 1;
-        comp_has_bf16 = true;
-      }
+      f32_count += CountSubshapesWithMatchingType(param->shape(), F32);
+      int64 bf16_count_comp_param =
+          CountSubshapesWithMatchingType(param->shape(), BF16);
+      if (bf16_count_comp_param > 0) {
+        bf16_count += bf16_count_comp_param;
+        comp_has_bf16 = true;
+      }
     }
@@ -315,21 +420,27 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
 
   // Resolve unsupported BF16 operands.
   for (int i = 0; i < hlo->operand_count(); ++i) {
-    if (hlo->operand(i)->shape().element_type() == BF16 &&
+    int64 bf16_count_in_operand =
+        CountSubshapesWithMatchingType(hlo->operand(i)->shape(), BF16);
+    if (bf16_count_in_operand > 0 &&
         !bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
-      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
-      bf16_count -= 1;
-      f32_count += 1;
+      TF_RETURN_IF_ERROR(
+          InsertConvertBeforeOperand(hlo, i, BF16, F32, computation_));
+      bf16_count -= bf16_count_in_operand;
+      f32_count += bf16_count_in_operand;
     }
   }
 
   // Resolve unsupported BF16 output.
-  if (hlo->shape().element_type() == BF16 &&
-      !bfloat16_support_->SupportsBF16Output(*hlo)) {
-    TF_RETURN_IF_ERROR(
-        ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
-    bf16_count -= 1;
-    f32_count += 1;
+  if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
+    int64 bf16_count_in_hlo =
+        CountSubshapesWithMatchingType(hlo->shape(), BF16);
+    if (bf16_count_in_hlo > 0) {
+      TF_RETURN_IF_ERROR(
+          ChangeOutputTypeThenInsertConvertBack(hlo, BF16, F32, computation_));
+      bf16_count -= bf16_count_in_hlo;
+      f32_count += bf16_count_in_hlo;
+    }
   }
 
   // Resolve unsupported mixed precision after resolving unsupported BF16
@@ -341,10 +452,12 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
   }
   // See if we can change everything to BF16.
   if (hlo->called_computations().empty() &&
-      hlo->shape().element_type() == BF16) {
+      CountSubshapesWithMatchingType(hlo->shape(), BF16) ==
+          ShapeLeafCount(hlo->shape())) {
     bool can_use_bf16 = true;
     for (int i = 0; i < hlo->operand_count(); ++i) {
-      if (hlo->operand(i)->shape().element_type() == BF16) {
+      if (CountSubshapesWithMatchingType(hlo->operand(i)->shape(), BF16) ==
+          ShapeLeafCount(hlo->operand(i)->shape())) {
         continue;
       }
       if ((bfloat16_support_->EffectiveOperandPrecisionIsBF16(*hlo, i) ||
@@ -358,22 +471,17 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
     }
     if (can_use_bf16) {
      for (int i = 0; i < hlo->operand_count(); ++i) {
-        if (hlo->operand(i)->shape().element_type() == F32) {
-          TF_RETURN_IF_ERROR(
-              InsertConvertBeforeOperand(hlo, i, BF16, computation_));
-        }
+        TF_RETURN_IF_ERROR(
+            InsertConvertBeforeOperand(hlo, i, F32, BF16, computation_));
       }
       return Status::OK();
     }
   }
-  if (hlo->shape().element_type() == BF16) {
-    TF_RETURN_IF_ERROR(
-        ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
-  }
+  TF_RETURN_IF_ERROR(
+      ChangeOutputTypeThenInsertConvertBack(hlo, BF16, F32, computation_));
   for (int i = 0; i < hlo->operand_count(); ++i) {
-    if (hlo->operand(i)->shape().element_type() == BF16) {
-      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
-    }
+    TF_RETURN_IF_ERROR(
+        InsertConvertBeforeOperand(hlo, i, BF16, F32, computation_));
   }
   return ConvertCalledComputations(hlo, bf16_called_comps);
 }
@@ -385,6 +493,7 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
   if (hlo->opcode() == HloOpcode::kTuple ||            //
       hlo->opcode() == HloOpcode::kGetTupleElement ||  //
       hlo->opcode() == HloOpcode::kConstant ||         //
+      hlo->opcode() == HloOpcode::kDomain ||           //
       hlo->opcode() == HloOpcode::kParameter ||        //
       hlo->opcode() == HloOpcode::kFusion ||           //
       hlo->opcode() == HloOpcode::kConvert ||          //
@@ -410,6 +519,8 @@ Status BFloat16NormalizationVisitor::Preprocess(HloInstruction* hlo) {
   return Status::OK();
 }
 
+}  // namespace
+
 StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
   XLA_VLOG_LINES(
       2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
@@ -419,6 +530,12 @@ StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
   }
   XLA_VLOG_LINES(2,
                  "BFloat16Normalization::Run(), after:\n" + module->ToString());
+  if (visitor.changed()) {
+    TupleSimplifier tuple_simplifier;
+    TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
+    HloDCE dce;
+    TF_RETURN_IF_ERROR(dce.Run(module).status());
+  }
   return visitor.changed();
 }
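To make the new tuple handling concrete before the tests: a hedged before/after sketch in HLO pseudo-text (illustrative only; it mirrors the ResolveMixedPrecisionTupleAllToAllToF32 test below, and the instruction names are made up). On a backend whose all-to-all cannot produce a mixed-precision tuple, an op with a (bf16, f32) tuple output is rewritten to pure F32, with per-element converts and a tuple restoring the original types for users:

    // Before: mixed-precision tuple output.
    %a2a   = (bf16[2,4], f32[2,4]) all-to-all(%a, %a), replica_groups={{0,1}}

    // After normalization: the op itself is all-F32; a get-tuple-element,
    // convert, and tuple rebuild the original (bf16, f32) type.
    %a2a   = (f32[2,4], f32[2,4]) all-to-all(%a, %a), replica_groups={{0,1}}
    %gte0  = f32[2,4] get-tuple-element(%a2a), index=0
    %cvt0  = bf16[2,4] convert(%gte0)
    %gte1  = f32[2,4] get-tuple-element(%a2a), index=1
    %tuple = (bf16[2,4], f32[2,4]) tuple(%cvt0, %gte1)

This is exactly the per-subshape conversion that ConvertType performs via DeepCopyInstructionWithCustomCopier in the diff above; TupleSimplifier and HloDCE then clean up any converts and tuples that turn out to be redundant.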
tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -40,7 +40,8 @@ class TestBFloat16Support : public BFloat16Support {
         hlo.opcode() == HloOpcode::kSubtract ||
         hlo.opcode() == HloOpcode::kReduce ||
         hlo.opcode() == HloOpcode::kTuple ||
-        hlo.opcode() == HloOpcode::kGetTupleElement) {
+        hlo.opcode() == HloOpcode::kGetTupleElement ||
+        hlo.opcode() == HloOpcode::kAllToAll) {
       return true;
     }
     if (hlo.opcode() == HloOpcode::kDot) {
@@ -54,7 +55,8 @@ class TestBFloat16Support : public BFloat16Support {
     if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce ||
         hlo.opcode() == HloOpcode::kSubtract ||
         hlo.opcode() == HloOpcode::kDot || hlo.opcode() == HloOpcode::kTuple ||
-        hlo.opcode() == HloOpcode::kGetTupleElement) {
+        hlo.opcode() == HloOpcode::kGetTupleElement ||
+        hlo.opcode() == HloOpcode::kAllToAll) {
       return true;
     }
     return false;
@@ -258,19 +260,76 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
       ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
       /*replica_groups=*/{},
       /*channel_id=*/absl::nullopt));
-  HloInstruction* gte = builder.AddInstruction(
+  builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
 
   auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_TRUE(Normalize(module.get()));
 
-  EXPECT_EQ(computation->root_instruction(), gte);
-  EXPECT_EQ(gte->shape().element_type(), BF16);
+  EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16);
   EXPECT_EQ(crs->operand(1)->shape().element_type(), F32);
   EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32);
 }
 
+TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) {
+  auto module = CreateNewVerifiedModule();
+
+  auto builder = HloComputation::Builder(TestName());
+  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_shape, "a"));
+
+  std::vector<ReplicaGroup> replica_groups(1);
+  replica_groups[0].add_replica_ids(0);
+  replica_groups[0].add_replica_ids(1);
+  HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll(
+      ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), {a, a},
+      replica_groups));
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(Normalize(module.get()));
+
+  EXPECT_EQ(computation->root_instruction(), a2a);
+  EXPECT_EQ(ShapeUtil::GetSubshape(a2a->shape(), {0}).element_type(), BF16);
+  EXPECT_EQ(ShapeUtil::GetSubshape(a2a->shape(), {1}).element_type(), BF16);
+  EXPECT_EQ(a2a->operand(0)->opcode(), HloOpcode::kConvert);
+  EXPECT_EQ(a2a->operand(0)->shape().element_type(), BF16);
+  EXPECT_EQ(a2a->operand(1)->opcode(), HloOpcode::kConvert);
+  EXPECT_EQ(a2a->operand(1)->shape().element_type(), BF16);
+}
+
+TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) {
+  auto module = CreateNewVerifiedModule();
+
+  auto builder = HloComputation::Builder(TestName());
+  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_shape, "a"));
+
+  std::vector<ReplicaGroup> replica_groups(1);
+  replica_groups[0].add_replica_ids(0);
+  replica_groups[0].add_replica_ids(1);
+  HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll(
+      ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), {a, a},
+      replica_groups));
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(Normalize(module.get()));
+
+  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kTuple);
+  EXPECT_EQ(ShapeUtil::GetSubshape(a2a->shape(), {0}).element_type(), F32);
+  EXPECT_EQ(ShapeUtil::GetSubshape(a2a->shape(), {1}).element_type(), F32);
+  EXPECT_EQ(a2a->operand(0)->opcode(), HloOpcode::kParameter);
+  EXPECT_EQ(a2a->operand(0)->shape().element_type(), F32);
+  EXPECT_EQ(a2a->operand(1)->opcode(), HloOpcode::kParameter);
+  EXPECT_EQ(a2a->operand(1)->shape().element_type(), F32);
+}
+
 TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
   auto module = CreateNewVerifiedModule();
   auto builder = HloComputation::Builder(TestName());
@@ -288,15 +347,14 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
       MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}),
                   {key, value}, 0, /*is_stable=*/false, &builder,
                   module.get()));
-  HloInstruction* gte = builder.AddInstruction(
+  builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0));
 
   auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_TRUE(Normalize(module.get()));
 
-  EXPECT_EQ(computation->root_instruction(), gte);
-  EXPECT_EQ(gte->shape().element_type(), BF16);
+  EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16);
   EXPECT_EQ(sort->operand(0)->shape().element_type(), F32);
   EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32);
 }
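For completeness, a hedged sketch of driving the pass, patterned on the tests above. TestBFloat16Support comes from the test file in this diff; the surrounding driver code and the constructor argument are assumptions for illustration, not shown in this commit:

    // Run BF16 normalization on an HloModule (module: HloModule*). After this
    // commit, Run() itself invokes TupleSimplifier and HloDCE when the visitor
    // changed anything, so leftover tuple/convert chains are cleaned up
    // inside the pass rather than by a separate caller-side cleanup.
    TestBFloat16Support bfloat16_support;
    BFloat16Normalization normalization(&bfloat16_support);
    StatusOr<bool> result = normalization.Run(module);
    TF_CHECK_OK(result.status());
    bool changed = result.ValueOrDie();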