[XLA] Make HloCostAnalysis account for nested shapes to calculate bytes accessed in fusion, infeed and outfeed

PiperOrigin-RevId: 356785877
Change-Id: Ic482891d04b7dfad63be11d01c3ecc7dae69916a
This commit is contained in:
Berkin Ilbeyi 2021-02-10 11:36:58 -08:00 committed by TensorFlower Gardener
parent b681785237
commit 2753c4ce37
2 changed files with 126 additions and 8 deletions

View File

@ -146,7 +146,8 @@ int64 HloCostAnalysis::FusionParameterReadBytes(
const HloInstruction* hlo) const {
int64 size = 0;
bool seen_trivial_user = false;
CHECK(hlo->IsFused() && hlo->opcode() == HloOpcode::kParameter);
CHECK(hlo->IsFused() && (hlo->opcode() == HloOpcode::kParameter ||
hlo->opcode() == HloOpcode::kGetTupleElement));
for (const HloInstruction* user : hlo->users()) {
switch (user->opcode()) {
case HloOpcode::kFusion: {
@ -335,11 +336,34 @@ Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
return Status::OK();
}
Status HloCostAnalysis::HandleInfeed(const HloInstruction*) {
// Accounts for the bytes of an infeed instruction. The output shape is walked
// leaf by leaf (via ShapeUtil::GetLeafShapes) so that nested tuples are
// covered: each leaf's size is recorded under that leaf's shape index, and the
// sum over all leaves is recorded both as the un-indexed output bytes accessed
// and as the instruction's kBytesAccessedKey property.
Status HloCostAnalysis::HandleInfeed(const HloInstruction* infeed) {
  int64 total_bytes = 0;
  for (const auto& leaf : ShapeUtil::GetLeafShapes(infeed->shape())) {
    // Size of this single leaf; used for both the running sum and the
    // per-index record.
    const int64 leaf_bytes = GetShapeSize(leaf.shape);
    total_bytes += leaf_bytes;
    SetOutputBytesAccessed(leaf.index, leaf_bytes);
  }
  SetOutputBytesAccessed(total_bytes);
  current_properties_[kBytesAccessedKey] = total_bytes;
  return Status::OK();
}
Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) {
// Accounts for the bytes of an outfeed instruction. For every operand, the
// operand's shape is walked leaf by leaf (via ShapeUtil::GetLeafShapes) so
// that nested tuples are covered: each leaf's size is recorded under
// (operand number, leaf shape index), the per-operand sum is recorded under
// the operand number alone, and the instruction's kBytesAccessedKey property
// accumulates the per-operand sums starting from zero.
Status HloCostAnalysis::HandleOutfeed(const HloInstruction* outfeed) {
  current_properties_[kBytesAccessedKey] = 0;
  for (int64 operand_num = 0; operand_num < outfeed->operand_count();
       ++operand_num) {
    const HloInstruction* fed_operand = outfeed->operand(operand_num);
    int64 operand_bytes = 0;
    for (const auto& leaf : ShapeUtil::GetLeafShapes(fed_operand->shape())) {
      // Size of this single leaf; used for both the running sum and the
      // per-index record.
      const int64 leaf_bytes = GetShapeSize(leaf.shape);
      operand_bytes += leaf_bytes;
      SetOperandBytesAccessed(operand_num, leaf.index, leaf_bytes);
    }
    SetOperandBytesAccessed(operand_num, operand_bytes);
    current_properties_[kBytesAccessedKey] += operand_bytes;
  }
  return Status::OK();
}
@ -872,9 +896,31 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
for (int64 i = 0; i < fusion->fused_parameters().size(); ++i) {
const HloInstruction* operand = fusion->fused_parameter(i);
int64 size = FusionParameterReadBytes(operand);
current_properties_[kBytesAccessedKey] += size;
SetOperandBytesAccessed(i, size);
int64 operand_size = 0;
if (!fusion->shape().IsTuple()) {
operand_size = FusionParameterReadBytes(operand);
} else {
// If the fusion parameter is a tuple type, find the gte for the leaf
// shape and calculate the bytes accessed for those array types.
for (const auto& indexed_shape :
ShapeUtil::GetLeafShapes(operand->shape())) {
const HloInstruction* gte = operand;
for (int64 index : indexed_shape.index) {
for (const HloInstruction* user : gte->users()) {
if (user->opcode() == HloOpcode::kGetTupleElement &&
user->tuple_index() == index) {
gte = user;
break;
}
}
}
int64 size = FusionParameterReadBytes(gte);
operand_size += size;
SetOperandBytesAccessed(i, indexed_shape.index, size);
}
}
current_properties_[kBytesAccessedKey] += operand_size;
SetOperandBytesAccessed(i, operand_size);
}
return Status::OK();

View File

@ -693,10 +693,10 @@ TEST_F(FusionCostAnalysis, LoopFusionTupleOutput) {
EXPECT_EQ(fusion_analysis.flop_count(), 16);
EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion),
sizeof(float) * (3 + 5) * 2 * 2 + kPointerSize * 2);
sizeof(float) * (5 + 5) * 2 * 2);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
kPointerSize * 2);
sizeof(float) * 2 * 2 * 2);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1),
sizeof(float) * 2 * 2);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 2),
@ -758,6 +758,78 @@ TEST_F(FusionCostAnalysis, NoLayout) {
sizeof(float) * 2 * 3 * 4 * 5);
}
// Runs HloCostAnalysis on a loop fusion whose single operand and whose result
// are both two-element tuples of f32[2,2] arrays, and checks the reported
// total, operand 0, and output bytes accessed.
TEST_F(FusionCostAnalysis, TupleBytesAccessed) {
  absl::string_view hlo_text = R"(
HloModule module, is_scheduled=true
fused_computation {
param = (f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
gte0 = f32[2,2]{1,0} get-tuple-element(param), index=0
gte1 = f32[2,2]{1,0} get-tuple-element(param), index=1
add = f32[2,2]{1,0} add(gte0, gte1)
mul = f32[2,2]{1,0} multiply(gte0, gte1)
ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(add, mul)
}
ENTRY entry {
param0 = f32[2,2]{1,0} parameter(0)
param1 = f32[2,2]{1,0} parameter(1)
tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(param0, param1)
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(tuple), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* fusion_instr =
      parsed_module->entry_computation()->root_instruction();

  HloCostAnalysis cost_analysis(ShapeSize);
  ASSERT_IS_OK(fusion_instr->Accept(&cost_analysis));

  // Byte size of one f32[2,2] array.
  const auto array_bytes = sizeof(float) * 2 * 2;
  // Expected total: four arrays' worth of bytes.
  EXPECT_EQ(cost_analysis.bytes_accessed(*fusion_instr), array_bytes * 4);
  // Expected for operand 0 (the input tuple): two arrays' worth of bytes.
  EXPECT_EQ(cost_analysis.operand_bytes_accessed(*fusion_instr, 0),
            array_bytes * 2);
  // Expected for the output tuple: two arrays' worth of bytes.
  EXPECT_EQ(cost_analysis.output_bytes_accessed(*fusion_instr),
            array_bytes * 2);
}
// Runs HloCostAnalysis on an infeed producing ((f32[2,3]), token[]) and on an
// outfeed consuming a (f32[2,3]) tuple plus a token, and checks the reported
// total, operand 0, and output bytes accessed for each of the two
// instructions.
TEST_F(FusionCostAnalysis, InfeedOutfeed) {
  absl::string_view hlo_text = R"(
HloModule module, is_scheduled=true
ENTRY entry {
after-all = token[] after-all()
infeed = ((f32[2,3]{1,0}), token[]) infeed(after-all)
gte0 = (f32[2,3]{1,0}) get-tuple-element(infeed), index=0
gte1 = f32[2,3]{1,0} get-tuple-element(gte0), index=0
add = f32[2,3]{1,0} add(gte1, gte1)
tuple = (f32[2,3]{1,0}) tuple(add)
tok = token[] get-tuple-element(infeed), index=1
ROOT outfeed = token[] outfeed(tuple, tok)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* infeed_instr =
      parsed_module->entry_computation()->GetInstructionWithName("infeed");
  HloInstruction* outfeed_instr =
      parsed_module->entry_computation()->GetInstructionWithName("outfeed");

  // Both instructions are visited with the same analysis object.
  HloCostAnalysis cost_analysis(ShapeSize);
  ASSERT_IS_OK(infeed_instr->Accept(&cost_analysis));
  ASSERT_IS_OK(outfeed_instr->Accept(&cost_analysis));

  // Byte size of one f32[2,3] array.
  const auto array_bytes = sizeof(float) * 2 * 3;

  // Infeed: expected to report the array as output and nothing for operand 0.
  EXPECT_EQ(cost_analysis.bytes_accessed(*infeed_instr), array_bytes);
  EXPECT_EQ(cost_analysis.operand_bytes_accessed(*infeed_instr, 0), 0);
  EXPECT_EQ(cost_analysis.output_bytes_accessed(*infeed_instr), array_bytes);

  // Outfeed: expected to report the array for operand 0 and nothing as output.
  EXPECT_EQ(cost_analysis.bytes_accessed(*outfeed_instr), array_bytes);
  EXPECT_EQ(cost_analysis.operand_bytes_accessed(*outfeed_instr, 0),
            array_bytes);
  EXPECT_EQ(cost_analysis.output_bytes_accessed(*outfeed_instr), 0);
}
TEST_F(HloCostAnalysisTest, TupleCost) {
HloCostAnalysis analysis(ShapeSize);