[XLA] Improve fusion cost analysis. It is not necessarily correct, but it is

better than before.

PiperOrigin-RevId: 245492608
This commit is contained in:
Blake Hechtman 2019-04-26 14:55:31 -07:00 committed by TensorFlower Gardener
parent 85c4d2348a
commit d9a8b15b57
3 changed files with 81 additions and 3 deletions

View File

@ -2239,6 +2239,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/types:span",
],
)

View File

@ -17,6 +17,10 @@ limitations under the License.
#include <cmath>
#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@ -129,6 +133,42 @@ int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const {
return shape_size_(shape);
}
// Estimates how many bytes are read from the fused parameter `hlo`.
//
// `hlo` must be a kParameter instruction that lives inside a fusion; this is
// enforced with a CHECK. The estimate is the sum of a per-user contribution:
//   * kFusion:       recurse into the nested fusion's parameter for every
//                    operand position at which `hlo` is passed.
//   * kSlice:        the size of the slice's result.
//   * kDynamicSlice: the size of the result when `hlo` is operand 0, otherwise
//                    the full size of `hlo`.
//   * kBroadcast / kReshape: the full size of `hlo`, once per such user.
//   * anything else: the full size of `hlo`, counted at most once in total,
//                    because those users are assumed to share a single read.
// Returns 0 when `hlo` has no users.
int64 HloCostAnalysis::FusionParameterReadBytes(
    const HloInstruction* hlo) const {
  CHECK(hlo->IsFused() && hlo->opcode() == HloOpcode::kParameter);

  int64 bytes_read = 0;
  // Set once the single shared read for "ordinary" users has been counted.
  bool shared_read_counted = false;

  for (const HloInstruction* consumer : hlo->users()) {
    const HloOpcode consumer_opcode = consumer->opcode();

    if (consumer_opcode == HloOpcode::kFusion) {
      // The parameter may feed the nested fusion at several operand positions;
      // each position maps to its own fused parameter.
      for (int64 operand_index : consumer->OperandIndices(hlo)) {
        bytes_read +=
            FusionParameterReadBytes(consumer->fused_parameter(operand_index));
      }
    } else if (consumer_opcode == HloOpcode::kSlice) {
      bytes_read += GetShapeSize(consumer->shape());
    } else if (consumer_opcode == HloOpcode::kDynamicSlice) {
      // Operand 0 is charged by the result shape; any other operand position
      // is charged by the full shape of `hlo`.
      const bool is_first_operand = (consumer->operand(0) == hlo);
      bytes_read +=
          GetShapeSize(is_first_operand ? consumer->shape() : hlo->shape());
    } else if (consumer_opcode == HloOpcode::kBroadcast ||
               consumer_opcode == HloOpcode::kReshape) {
      bytes_read += GetShapeSize(hlo->shape());
    } else if (!shared_read_counted) {
      // Other instructions reading this parameter are assumed to be able to
      // share the read from memory, so charge the full size only once.
      shared_read_counted = true;
      bytes_read += GetShapeSize(hlo->shape());
    }
  }

  return bytes_read;
}
// Handles `hlo` by delegating to HandleElementwiseOp and returning its Status
// unchanged.
Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
return HandleElementwiseOp(hlo);
}
@ -612,6 +652,17 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
}
Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
if (fusion->IsCustomFusion()) {
for (const HloInstruction* hlo :
fusion->fused_instructions_computation()->instructions()) {
if (hlo->opcode() == HloOpcode::kGather) {
return HandleGather(hlo);
}
if (hlo->opcode() == HloOpcode::kScatter) {
return HandleScatter(hlo);
}
}
}
TF_ASSIGN_OR_RETURN(
current_properties_,
ProcessNestedSubcomputation(fusion->fused_instructions_computation()));
@ -622,12 +673,34 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
current_properties_[kBytesAccessedKey] = 0;
ShapeUtil::ForEachSubshape(
fusion->shape(),
[this](const Shape& subshape, const ShapeIndex& /*shape_index*/) {
[this, fusion](const Shape& subshape, const ShapeIndex& shape_index) {
if (!subshape.IsArray()) {
return;
}
if (shape_index.empty()) {
if (fusion->fused_expression_root()->opcode() ==
HloOpcode::kDynamicUpdateSlice) {
current_properties_[kBytesAccessedKey] += GetShapeSize(
fusion->fused_expression_root()->operand(0)->shape());
return;
}
} else if (shape_index.size() == 1) {
if (fusion->fused_expression_root()
->operand(shape_index[0])
->opcode() == HloOpcode::kDynamicUpdateSlice) {
current_properties_[kBytesAccessedKey] +=
GetShapeSize(fusion->fused_expression_root()
->operand(shape_index[0])
->operand(0)
->shape());
return;
}
}
current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
});
for (const HloInstruction* operand : fusion->operands()) {
current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape());
for (const HloInstruction* operand : fusion->fused_parameters()) {
current_properties_[kBytesAccessedKey] += FusionParameterReadBytes(operand);
}
return Status::OK();

View File

@ -196,6 +196,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
// a layout.
int64 GetShapeSize(const Shape& shape) const;
// Traverses the users of a fused parameter to estimate the bytes the fusion
// actually reads from it. `hlo` must be a kParameter instruction inside a
// fusion (this is CHECKed). Slice users are charged by their result size,
// nested fusion users are traversed recursively, broadcast/reshape users are
// charged the full parameter size each, and all remaining users together are
// charged one full read of the parameter.
int64 FusionParameterReadBytes(const HloInstruction* hlo) const;
// Function which computes the size of the top-level of a given shape (not
// including nested elements, if any). If null then bytes_accessed methods
// return an error.