[XLA] Improve fusion cost analysis; it is not necessarily correct, but it is
better than before. PiperOrigin-RevId: 245492608
This commit is contained in:
parent
85c4d2348a
commit
d9a8b15b57
@ -2239,6 +2239,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -17,6 +17,10 @@ limitations under the License.
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -129,6 +133,42 @@ int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const {
|
||||
return shape_size_(shape);
|
||||
}
|
||||
|
||||
// Returns an estimate of the bytes read from the fused parameter `hlo`,
// obtained by adding up a contribution from each of its users. `hlo` must be a
// kParameter instruction that lives inside a fusion (CHECK-enforced).
int64 HloCostAnalysis::FusionParameterReadBytes(
    const HloInstruction* hlo) const {
  CHECK(hlo->IsFused() && hlo->opcode() == HloOpcode::kParameter);

  // Evaluated lazily so that the full parameter size is only queried for the
  // users that are actually charged for it.
  const auto whole_parameter_bytes = [this, hlo]() {
    return GetShapeSize(hlo->shape());
  };

  int64 bytes_read = 0;
  bool shared_read_counted = false;
  for (const HloInstruction* consumer : hlo->users()) {
    const HloOpcode consumer_opcode = consumer->opcode();
    if (consumer_opcode == HloOpcode::kFusion) {
      // Nested fusion: recurse into the nested fused parameter for every
      // operand position through which `hlo` is passed.
      for (int64 operand_index : consumer->OperandIndices(hlo)) {
        bytes_read +=
            FusionParameterReadBytes(consumer->fused_parameter(operand_index));
      }
    } else if (consumer_opcode == HloOpcode::kSlice) {
      // A slice is charged for the size of its own result shape.
      bytes_read += GetShapeSize(consumer->shape());
    } else if (consumer_opcode == HloOpcode::kDynamicSlice) {
      // When the parameter is operand 0 of the dynamic-slice, charge the
      // result shape; when it is any other operand, charge the whole parameter.
      if (hlo == consumer->operand(0)) {
        bytes_read += GetShapeSize(consumer->shape());
      } else {
        bytes_read += whole_parameter_bytes();
      }
    } else if (consumer_opcode == HloOpcode::kBroadcast ||
               consumer_opcode == HloOpcode::kReshape) {
      // Each broadcast/reshape user is charged for the full parameter.
      bytes_read += whole_parameter_bytes();
    } else if (!shared_read_counted) {
      // Other instructions reading this parameter are assumed to be able to
      // share the read from memory, so the full parameter is charged once.
      shared_read_counted = true;
      bytes_read += whole_parameter_bytes();
    }
  }
  return bytes_read;
}
|
||||
|
||||
// Visitor hook for elementwise unary instructions: forwards `hlo` unchanged to
// HandleElementwiseOp and returns its status.
Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}
|
||||
@ -612,6 +652,17 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
|
||||
if (fusion->IsCustomFusion()) {
|
||||
for (const HloInstruction* hlo :
|
||||
fusion->fused_instructions_computation()->instructions()) {
|
||||
if (hlo->opcode() == HloOpcode::kGather) {
|
||||
return HandleGather(hlo);
|
||||
}
|
||||
if (hlo->opcode() == HloOpcode::kScatter) {
|
||||
return HandleScatter(hlo);
|
||||
}
|
||||
}
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
current_properties_,
|
||||
ProcessNestedSubcomputation(fusion->fused_instructions_computation()));
|
||||
@ -622,12 +673,34 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
|
||||
current_properties_[kBytesAccessedKey] = 0;
|
||||
ShapeUtil::ForEachSubshape(
|
||||
fusion->shape(),
|
||||
[this](const Shape& subshape, const ShapeIndex& /*shape_index*/) {
|
||||
[this, fusion](const Shape& subshape, const ShapeIndex& shape_index) {
|
||||
if (!subshape.IsArray()) {
|
||||
return;
|
||||
}
|
||||
if (shape_index.empty()) {
|
||||
if (fusion->fused_expression_root()->opcode() ==
|
||||
HloOpcode::kDynamicUpdateSlice) {
|
||||
current_properties_[kBytesAccessedKey] += GetShapeSize(
|
||||
fusion->fused_expression_root()->operand(0)->shape());
|
||||
return;
|
||||
}
|
||||
} else if (shape_index.size() == 1) {
|
||||
if (fusion->fused_expression_root()
|
||||
->operand(shape_index[0])
|
||||
->opcode() == HloOpcode::kDynamicUpdateSlice) {
|
||||
current_properties_[kBytesAccessedKey] +=
|
||||
GetShapeSize(fusion->fused_expression_root()
|
||||
->operand(shape_index[0])
|
||||
->operand(0)
|
||||
->shape());
|
||||
return;
|
||||
}
|
||||
}
|
||||
current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
|
||||
});
|
||||
|
||||
for (const HloInstruction* operand : fusion->operands()) {
|
||||
current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape());
|
||||
for (const HloInstruction* operand : fusion->fused_parameters()) {
|
||||
current_properties_[kBytesAccessedKey] += FusionParameterReadBytes(operand);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -196,6 +196,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
|
||||
// a layout.
|
||||
int64 GetShapeSize(const Shape& shape) const;
|
||||
|
||||
// Traverses a fusion operand to find the actual bytes accessed by the fusion
|
||||
// node.
|
||||
int64 FusionParameterReadBytes(const HloInstruction* hlo) const;
|
||||
|
||||
// Function which computes the size of the top-level of a given shape (not
|
||||
// including nested elements, if any). If null then bytes_accessed methods
|
||||
// return an error.
|
||||
|
Loading…
Reference in New Issue
Block a user