From d9a8b15b5752f16ad425de3f3f5911e40263c311 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Fri, 26 Apr 2019 14:55:31 -0700 Subject: [PATCH] [XLA] Improve fusion cost analysis, it is not necessarily correct but it is better than before. PiperOrigin-RevId: 245492608 --- tensorflow/compiler/xla/service/BUILD | 1 + .../compiler/xla/service/hlo_cost_analysis.cc | 79 ++++++++++++++++++- .../compiler/xla/service/hlo_cost_analysis.h | 4 + 3 files changed, 81 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 344cd1bbc2b..bdea54a8ae4 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2239,6 +2239,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 372f015bb60..a0efc4fe8c9 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -17,6 +17,10 @@ limitations under the License. #include +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -129,6 +133,42 @@ int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const { return shape_size_(shape); } +int64 HloCostAnalysis::FusionParameterReadBytes( + const HloInstruction* hlo) const { + int64 size = 0; + bool seen_trivial_user = false; + CHECK(hlo->IsFused() && hlo->opcode() == HloOpcode::kParameter); + for (const HloInstruction* user : hlo->users()) { + switch (user->opcode()) { + case HloOpcode::kFusion: { + for (int64 idx : user->OperandIndices(hlo)) { + size += FusionParameterReadBytes(user->fused_parameter(idx)); + } + break; + } + case HloOpcode::kSlice: + size += GetShapeSize(user->shape()); + break; + case HloOpcode::kDynamicSlice: + size += hlo == user->operand(0) ? GetShapeSize(user->shape()) + : GetShapeSize(hlo->shape()); + break; + case HloOpcode::kBroadcast: + case HloOpcode::kReshape: + size += GetShapeSize(hlo->shape()); + break; + default: + // Other instructions reading this parameter are assumed to be able to + // share the read from memory. + if (!seen_trivial_user) { + seen_trivial_user = true; + size += GetShapeSize(hlo->shape()); + } + } + } + return size; +} + Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } @@ -612,6 +652,17 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) { } Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { + if (fusion->IsCustomFusion()) { + for (const HloInstruction* hlo : + fusion->fused_instructions_computation()->instructions()) { + if (hlo->opcode() == HloOpcode::kGather) { + return HandleGather(hlo); + } + if (hlo->opcode() == HloOpcode::kScatter) { + return HandleScatter(hlo); + } + } + } TF_ASSIGN_OR_RETURN( current_properties_, ProcessNestedSubcomputation(fusion->fused_instructions_computation())); @@ -622,12 +673,34 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { current_properties_[kBytesAccessedKey] = 0; ShapeUtil::ForEachSubshape( fusion->shape(), - [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) { + [this, fusion](const Shape& subshape, const ShapeIndex& shape_index) { + if (!subshape.IsArray()) { + return; + } + if (shape_index.empty()) { + if (fusion->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + current_properties_[kBytesAccessedKey] += GetShapeSize( + fusion->fused_expression_root()->operand(0)->shape()); + return; + } + } else if (shape_index.size() == 1) { + if (fusion->fused_expression_root() + ->operand(shape_index[0]) + ->opcode() == HloOpcode::kDynamicUpdateSlice) { + current_properties_[kBytesAccessedKey] += + GetShapeSize(fusion->fused_expression_root() + ->operand(shape_index[0]) + ->operand(0) + ->shape()); + return; + } + } current_properties_[kBytesAccessedKey] += GetShapeSize(subshape); }); - for (const HloInstruction* operand : fusion->operands()) { - current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape()); + for (const HloInstruction* operand : fusion->fused_parameters()) { + current_properties_[kBytesAccessedKey] += FusionParameterReadBytes(operand); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 4480554de50..ab96fa4796f 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -196,6 +196,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // a layout. int64 GetShapeSize(const Shape& shape) const; + // Traverses a fusion operand to find the actual bytes accessed by the fusion + // node. + int64 FusionParameterReadBytes(const HloInstruction* hlo) const; + // Function which computes the size of the top-level of a given shape (not // including nested elements, if any). If null then bytes_accessed methods // return an error.