[XLA] Improve fusion cost analysis. It is not necessarily correct, but it is
better than before. PiperOrigin-RevId: 245492608
This commit is contained in:
parent
85c4d2348a
commit
d9a8b15b57
@ -2239,6 +2239,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -17,6 +17,10 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
@ -129,6 +133,42 @@ int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const {
|
|||||||
return shape_size_(shape);
|
return shape_size_(shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Estimates the number of bytes read through the fused parameter `hlo` by
// summing a per-user contribution over all of its users.
//
// `hlo` must be a parameter instruction that lives inside a fusion
// computation; this is enforced with a CHECK.
//
// Per-user contribution:
//   * kFusion:        recurse into every fused parameter of the nested fusion
//                     that is fed by `hlo`.
//   * kSlice:         the size of the slice result.
//   * kDynamicSlice:  the size of the slice result when `hlo` is operand 0,
//                     otherwise the full size of `hlo`.
//   * kBroadcast,
//     kReshape:       the full size of `hlo`, counted once per such user.
//   * anything else:  the full size of `hlo`, counted at most once across all
//                     such users, on the assumption that they can share a
//                     single read from memory.
int64 HloCostAnalysis::FusionParameterReadBytes(
    const HloInstruction* hlo) const {
  CHECK(hlo->IsFused() && hlo->opcode() == HloOpcode::kParameter);

  // Evaluated lazily so the shape-size function is only invoked for users
  // that actually need the full parameter size.
  const auto whole_parameter_bytes = [this, hlo]() -> int64 {
    return GetShapeSize(hlo->shape());
  };

  int64 total_bytes = 0;
  bool shared_read_counted = false;
  for (const HloInstruction* consumer : hlo->users()) {
    const HloOpcode consumer_opcode = consumer->opcode();
    if (consumer_opcode == HloOpcode::kFusion) {
      // The parameter may be passed to the nested fusion more than once;
      // account for each operand position separately.
      for (int64 operand_index : consumer->OperandIndices(hlo)) {
        total_bytes +=
            FusionParameterReadBytes(consumer->fused_parameter(operand_index));
      }
    } else if (consumer_opcode == HloOpcode::kSlice) {
      total_bytes += GetShapeSize(consumer->shape());
    } else if (consumer_opcode == HloOpcode::kDynamicSlice) {
      if (hlo == consumer->operand(0)) {
        total_bytes += GetShapeSize(consumer->shape());
      } else {
        total_bytes += whole_parameter_bytes();
      }
    } else if (consumer_opcode == HloOpcode::kBroadcast ||
               consumer_opcode == HloOpcode::kReshape) {
      total_bytes += whole_parameter_bytes();
    } else if (!shared_read_counted) {
      // Other instructions reading this parameter are assumed to be able to
      // share the read from memory, so only the first one is charged.
      shared_read_counted = true;
      total_bytes += whole_parameter_bytes();
    }
  }
  return total_bytes;
}
|
||||||
|
|
||||||
Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
|
Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
|
||||||
return HandleElementwiseOp(hlo);
|
return HandleElementwiseOp(hlo);
|
||||||
}
|
}
|
||||||
@ -612,6 +652,17 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
|
Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
|
||||||
|
if (fusion->IsCustomFusion()) {
|
||||||
|
for (const HloInstruction* hlo :
|
||||||
|
fusion->fused_instructions_computation()->instructions()) {
|
||||||
|
if (hlo->opcode() == HloOpcode::kGather) {
|
||||||
|
return HandleGather(hlo);
|
||||||
|
}
|
||||||
|
if (hlo->opcode() == HloOpcode::kScatter) {
|
||||||
|
return HandleScatter(hlo);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
current_properties_,
|
current_properties_,
|
||||||
ProcessNestedSubcomputation(fusion->fused_instructions_computation()));
|
ProcessNestedSubcomputation(fusion->fused_instructions_computation()));
|
||||||
@ -622,12 +673,34 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
|
|||||||
current_properties_[kBytesAccessedKey] = 0;
|
current_properties_[kBytesAccessedKey] = 0;
|
||||||
ShapeUtil::ForEachSubshape(
|
ShapeUtil::ForEachSubshape(
|
||||||
fusion->shape(),
|
fusion->shape(),
|
||||||
[this](const Shape& subshape, const ShapeIndex& /*shape_index*/) {
|
[this, fusion](const Shape& subshape, const ShapeIndex& shape_index) {
|
||||||
|
if (!subshape.IsArray()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (shape_index.empty()) {
|
||||||
|
if (fusion->fused_expression_root()->opcode() ==
|
||||||
|
HloOpcode::kDynamicUpdateSlice) {
|
||||||
|
current_properties_[kBytesAccessedKey] += GetShapeSize(
|
||||||
|
fusion->fused_expression_root()->operand(0)->shape());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else if (shape_index.size() == 1) {
|
||||||
|
if (fusion->fused_expression_root()
|
||||||
|
->operand(shape_index[0])
|
||||||
|
->opcode() == HloOpcode::kDynamicUpdateSlice) {
|
||||||
|
current_properties_[kBytesAccessedKey] +=
|
||||||
|
GetShapeSize(fusion->fused_expression_root()
|
||||||
|
->operand(shape_index[0])
|
||||||
|
->operand(0)
|
||||||
|
->shape());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
|
current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
|
||||||
});
|
});
|
||||||
|
|
||||||
for (const HloInstruction* operand : fusion->operands()) {
|
for (const HloInstruction* operand : fusion->fused_parameters()) {
|
||||||
current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape());
|
current_properties_[kBytesAccessedKey] += FusionParameterReadBytes(operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -196,6 +196,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
|
|||||||
// a layout.
|
// a layout.
|
||||||
int64 GetShapeSize(const Shape& shape) const;
|
int64 GetShapeSize(const Shape& shape) const;
|
||||||
|
|
||||||
|
// Traverses a fusion operand to find the actual bytes accessed by the fusion
|
||||||
|
// node.
|
||||||
|
int64 FusionParameterReadBytes(const HloInstruction* hlo) const;
|
||||||
|
|
||||||
// Function which computes the size of the top-level of a given shape (not
|
// Function which computes the size of the top-level of a given shape (not
|
||||||
// including nested elements, if any). If null then bytes_accessed methods
|
// including nested elements, if any). If null then bytes_accessed methods
|
||||||
// return an error.
|
// return an error.
|
||||||
|
Loading…
Reference in New Issue
Block a user