Don't fuse in CpuInstructionFusion if it causes ineffiencies in FusedIrEmitter.
Also move the function which checks for inefficiencies to FusedIrEmitter. PiperOrigin-RevId: 237419419
This commit is contained in:
parent
760862f24d
commit
bc05464062
@ -68,8 +68,8 @@ class RGBToHSVTest(xla_test.XLATestCase):
|
|||||||
{batch0: inp})
|
{batch0: inp})
|
||||||
|
|
||||||
# Verify that processing batch elements together is the same as separate
|
# Verify that processing batch elements together is the same as separate
|
||||||
self.assertAllClose(batch1, join1)
|
self.assertAllCloseAccordingToType(batch1, join1, half_rtol=0.000002)
|
||||||
self.assertAllClose(batch2, join2)
|
self.assertAllCloseAccordingToType(batch2, join2, half_rtol=0.000002)
|
||||||
self.assertAllCloseAccordingToType(
|
self.assertAllCloseAccordingToType(
|
||||||
batch2, inp, bfloat16_atol=0.03, half_rtol=0.02)
|
batch2, inp, bfloat16_atol=0.03, half_rtol=0.02)
|
||||||
|
|
||||||
|
@ -748,6 +748,7 @@ cc_library(
|
|||||||
":ir_emission_utils",
|
":ir_emission_utils",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:instruction_fusion",
|
"//tensorflow/compiler/xla/service:instruction_fusion",
|
||||||
|
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
|
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace cpu {
|
namespace cpu {
|
||||||
@ -117,6 +118,14 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Don't fuse if fusing would cause too much code duplication because of
|
||||||
|
// inefficiencies in the fusion emitter.
|
||||||
|
// TODO(b/119692968): Remove this once the fusion emitter can handle
|
||||||
|
// arbitrary fusion nodes.
|
||||||
|
if (FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (consumer->opcode() == HloOpcode::kDot) {
|
if (consumer->opcode() == HloOpcode::kDot) {
|
||||||
// In the general case we call out to optimized "black box" GEMM routines
|
// In the general case we call out to optimized "black box" GEMM routines
|
||||||
// for Dot, which precludes fusion. However, in very specific cases, we try
|
// for Dot, which precludes fusion. However, in very specific cases, we try
|
||||||
|
@ -582,6 +582,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:instruction_fusion",
|
"//tensorflow/compiler/xla/service:instruction_fusion",
|
||||||
"//tensorflow/compiler/xla/service:pattern_matcher",
|
"//tensorflow/compiler/xla/service:pattern_matcher",
|
||||||
|
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -703,6 +704,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
||||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||||
|
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
@ -1102,8 +1104,6 @@ cc_library(
|
|||||||
":ir_emission_utils",
|
":ir_emission_utils",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
|
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
|
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
@ -289,8 +290,8 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
|
|||||||
// TODO(b/119692968): Remove this once the fusion emitter can handle arbitrary
|
// TODO(b/119692968): Remove this once the fusion emitter can handle arbitrary
|
||||||
// fusion nodes.
|
// fusion nodes.
|
||||||
if (absl::c_any_of(fusion->users(), [fusion](const HloInstruction* user) {
|
if (absl::c_any_of(fusion->users(), [fusion](const HloInstruction* user) {
|
||||||
return IsFusionEmitterInefficient(/*consumer=*/user,
|
return FusedIrEmitter::IsFusedIrEmitterInefficient(/*consumer=*/user,
|
||||||
/*producer=*/fusion);
|
/*producer=*/fusion);
|
||||||
})) {
|
})) {
|
||||||
VLOG(3) << "Not merging " << fusion->name()
|
VLOG(3) << "Not merging " << fusion->name()
|
||||||
<< ": Contains one or more users where fusing would cause "
|
<< ": Contains one or more users where fusing would cause "
|
||||||
|
@ -15,12 +15,10 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <iterator>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/shape.h"
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
@ -186,102 +184,5 @@ bool IsFusible(const HloInstruction& instr) {
|
|||||||
return IsInputFusible(instr) || IsLoopFusible(instr);
|
return IsInputFusible(instr) || IsLoopFusible(instr);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsFusionEmitterInefficient(const HloInstruction* consumer,
|
|
||||||
const HloInstruction* producer) {
|
|
||||||
if (consumer->opcode() != HloOpcode::kFusion) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// Collects for each instruction in the fusion node from which (indirect)
|
|
||||||
// users newly created index values are passed. Roughly speaking, we reuse
|
|
||||||
// index values if the shapes are equal when ignoring the element type (we may
|
|
||||||
// reuse also if the shape change is a bitcast, but we don't consider that
|
|
||||||
// here). By ignoring potential reuses our estimate whether the fusion emitter
|
|
||||||
// is inefficient is a bit more conservative than necessary.
|
|
||||||
absl::flat_hash_map<const HloInstruction*,
|
|
||||||
absl::flat_hash_set<const HloInstruction*>>
|
|
||||||
indexing_users;
|
|
||||||
// Stores the number of different index accesses for each instruction in the
|
|
||||||
// fusion node. The fusion emitter caches access with the same index, so this
|
|
||||||
// value indicates how many times a specific instruction will be emitted.
|
|
||||||
absl::flat_hash_map<const HloInstruction*, int64> index_usage_count;
|
|
||||||
index_usage_count[consumer] = 1;
|
|
||||||
|
|
||||||
auto evaluate_fusion_computation = [&indexing_users, &index_usage_count](
|
|
||||||
const HloInstruction* fusion) {
|
|
||||||
auto postorder =
|
|
||||||
fusion->fused_instructions_computation()->MakeInstructionPostOrder();
|
|
||||||
std::reverse(postorder.begin(), postorder.end());
|
|
||||||
for (const auto* instruction : postorder) {
|
|
||||||
if (instruction->opcode() == HloOpcode::kParameter) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
int64& total = index_usage_count[instruction];
|
|
||||||
if (indexing_users[instruction].empty()) {
|
|
||||||
total = index_usage_count[fusion];
|
|
||||||
} else {
|
|
||||||
total = 0;
|
|
||||||
for (const auto* user : indexing_users[instruction]) {
|
|
||||||
int64 weight = 1;
|
|
||||||
// Concatenate is special: the index differs for each operand, so
|
|
||||||
// in the worst case we have to deal with as many index values as
|
|
||||||
// the number of operands of Concatenate. By considering the worst
|
|
||||||
// case, we are more conservative than necessary regarding
|
|
||||||
// refusing to fuse.
|
|
||||||
if (user->opcode() == HloOpcode::kConcatenate) {
|
|
||||||
weight = user->operand_count();
|
|
||||||
}
|
|
||||||
total += index_usage_count[user] * weight;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (const auto* operand : instruction->operands()) {
|
|
||||||
// For simplicity we assume that all shape and layout changing
|
|
||||||
// operations invalidate index reuse.
|
|
||||||
if (Shape::Equal().IgnoreElementType()(operand->shape(),
|
|
||||||
instruction->shape())) {
|
|
||||||
// If the index is reused, it means the operand gets index values
|
|
||||||
// from the same set of (indirect) users as 'instruction' itself.
|
|
||||||
indexing_users[operand].insert(indexing_users[instruction].begin(),
|
|
||||||
indexing_users[instruction].end());
|
|
||||||
} else {
|
|
||||||
// If the index is not reused, it means 'instruction' computes a
|
|
||||||
// new index derived from the index it gets.
|
|
||||||
indexing_users[operand].insert(instruction);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
evaluate_fusion_computation(consumer);
|
|
||||||
|
|
||||||
// Also account for the 'producer' if it would be fused. Find the operand it
|
|
||||||
// corresponds to.
|
|
||||||
for (int64 operand_num = 0; operand_num < consumer->operand_count();
|
|
||||||
++operand_num) {
|
|
||||||
if (consumer->operand(operand_num) == producer) {
|
|
||||||
auto instruction = consumer->fused_parameter(operand_num);
|
|
||||||
int64& total = index_usage_count[producer];
|
|
||||||
total = 0;
|
|
||||||
for (const auto* user : indexing_users[instruction]) {
|
|
||||||
total += index_usage_count[user];
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If 'producer' is a fusion node as well, also evaluate it.
|
|
||||||
if (producer->opcode() == HloOpcode::kFusion) {
|
|
||||||
evaluate_fusion_computation(producer);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sum up the total number of emitted ops.
|
|
||||||
int64 total = 0;
|
|
||||||
for (const auto& entry : index_usage_count) {
|
|
||||||
total += entry.second;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that the code duplication has at most a factor of 8 (where 8 is an
|
|
||||||
// arbitrary constant that seems to work).
|
|
||||||
return total > 8 * index_usage_count.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -69,14 +69,6 @@ bool IsInputFusibleScatter(const HloInstruction& instr);
|
|||||||
bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
|
bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
|
||||||
const HloInstruction& instr2);
|
const HloInstruction& instr2);
|
||||||
|
|
||||||
// Evaluates whether fusing 'producer' into 'consumer' might cause exponential
|
|
||||||
// behavior in the fusion emitter. We currently can have exponential time/memory
|
|
||||||
// requirements for emitting certain fusion kernels, in which case we don't want
|
|
||||||
// to fuse.
|
|
||||||
// TODO(b/119692968): Remove this once we have fixed our fusion emitter.
|
|
||||||
bool IsFusionEmitterInefficient(const HloInstruction* consumer,
|
|
||||||
const HloInstruction* producer);
|
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
@ -267,7 +268,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
|
|||||||
// have exponential time/memory requirements for emitting certain fusion
|
// have exponential time/memory requirements for emitting certain fusion
|
||||||
// kernels, in which case we don't want to fuse.
|
// kernels, in which case we don't want to fuse.
|
||||||
// TODO(b/119692968): Remove this once we have fixed our fusion emitter.
|
// TODO(b/119692968): Remove this once we have fixed our fusion emitter.
|
||||||
return !IsFusionEmitterInefficient(consumer, producer);
|
return !FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer,
|
bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer,
|
||||||
|
@ -161,6 +161,7 @@ cc_library(
|
|||||||
":llvm_util",
|
":llvm_util",
|
||||||
":loop_emitter",
|
":loop_emitter",
|
||||||
":tuple_ops",
|
":tuple_ops",
|
||||||
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
@ -169,6 +170,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
"@llvm//:core",
|
"@llvm//:core",
|
||||||
|
@ -15,14 +15,22 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "llvm/IR/BasicBlock.h"
|
#include "llvm/IR/BasicBlock.h"
|
||||||
#include "llvm/IR/Value.h"
|
#include "llvm/IR/Value.h"
|
||||||
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
|
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
@ -195,4 +203,101 @@ FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetGenerator(
|
|||||||
return indexed_generators_.at(instruction);
|
return indexed_generators_.at(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool FusedIrEmitter::IsFusedIrEmitterInefficient(
|
||||||
|
const HloInstruction* consumer, const HloInstruction* producer) {
|
||||||
|
if (consumer->opcode() != HloOpcode::kFusion) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Collects for each instruction in the fusion node from which (indirect)
|
||||||
|
// users newly created index values are passed. Roughly speaking, we reuse
|
||||||
|
// index values if the shapes are equal when ignoring the element type (we may
|
||||||
|
// reuse also if the shape change is a bitcast, but we don't consider that
|
||||||
|
// here). By ignoring potential reuses our estimate whether the fusion emitter
|
||||||
|
// is inefficient is a bit more conservative than necessary.
|
||||||
|
absl::flat_hash_map<const HloInstruction*,
|
||||||
|
absl::flat_hash_set<const HloInstruction*>>
|
||||||
|
indexing_users;
|
||||||
|
// Stores the number of different index accesses for each instruction in the
|
||||||
|
// fusion node. The fusion emitter caches access with the same index, so this
|
||||||
|
// value indicates how many times a specific instruction will be emitted.
|
||||||
|
absl::flat_hash_map<const HloInstruction*, int64> index_usage_count;
|
||||||
|
index_usage_count[consumer] = 1;
|
||||||
|
|
||||||
|
auto evaluate_fusion_computation = [&indexing_users, &index_usage_count](
|
||||||
|
const HloInstruction* fusion) {
|
||||||
|
auto postorder =
|
||||||
|
fusion->fused_instructions_computation()->MakeInstructionPostOrder();
|
||||||
|
std::reverse(postorder.begin(), postorder.end());
|
||||||
|
for (const auto* instruction : postorder) {
|
||||||
|
if (instruction->opcode() == HloOpcode::kParameter) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
int64& total = index_usage_count[instruction];
|
||||||
|
if (indexing_users[instruction].empty()) {
|
||||||
|
total = index_usage_count[fusion];
|
||||||
|
} else {
|
||||||
|
total = 0;
|
||||||
|
for (const auto* user : indexing_users[instruction]) {
|
||||||
|
int64 weight = 1;
|
||||||
|
// Concatenate is special: the index differs for each operand, so
|
||||||
|
// in the worst case we have to deal with as many index values as
|
||||||
|
// the number of operands of Concatenate. By considering the worst
|
||||||
|
// case, we are more conservative than necessary regarding
|
||||||
|
// refusing to fuse.
|
||||||
|
if (user->opcode() == HloOpcode::kConcatenate) {
|
||||||
|
weight = user->operand_count();
|
||||||
|
}
|
||||||
|
total += index_usage_count[user] * weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const auto* operand : instruction->operands()) {
|
||||||
|
// For simplicity we assume that all shape and layout changing
|
||||||
|
// operations invalidate index reuse.
|
||||||
|
if (Shape::Equal().IgnoreElementType()(operand->shape(),
|
||||||
|
instruction->shape())) {
|
||||||
|
// If the index is reused, it means the operand gets index values
|
||||||
|
// from the same set of (indirect) users as 'instruction' itself.
|
||||||
|
indexing_users[operand].insert(indexing_users[instruction].begin(),
|
||||||
|
indexing_users[instruction].end());
|
||||||
|
} else {
|
||||||
|
// If the index is not reused, it means 'instruction' computes a
|
||||||
|
// new index derived from the index it gets.
|
||||||
|
indexing_users[operand].insert(instruction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
evaluate_fusion_computation(consumer);
|
||||||
|
|
||||||
|
// Also account for the 'producer' if it would be fused. Find the operand it
|
||||||
|
// corresponds to.
|
||||||
|
for (int64 operand_num = 0; operand_num < consumer->operand_count();
|
||||||
|
++operand_num) {
|
||||||
|
if (consumer->operand(operand_num) == producer) {
|
||||||
|
auto instruction = consumer->fused_parameter(operand_num);
|
||||||
|
int64& total = index_usage_count[producer];
|
||||||
|
total = 0;
|
||||||
|
for (const auto* user : indexing_users[instruction]) {
|
||||||
|
total += index_usage_count[user];
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If 'producer' is a fusion node as well, also evaluate it.
|
||||||
|
if (producer->opcode() == HloOpcode::kFusion) {
|
||||||
|
evaluate_fusion_computation(producer);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sum up the total number of emitted ops.
|
||||||
|
int64 total = 0;
|
||||||
|
for (const auto& entry : index_usage_count) {
|
||||||
|
total += entry.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the code duplication has at most a factor of 8 (where 8 is an
|
||||||
|
// arbitrary constant that seems to work).
|
||||||
|
return total > 8 * index_usage_count.size();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -91,6 +91,14 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
|
|||||||
tiled_parameter_info_ = info;
|
tiled_parameter_info_ = info;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Evaluates whether fusing 'producer' into 'consumer' might cause exponential
|
||||||
|
// behavior in FusedIrEmitter. We currently can have exponential time/memory
|
||||||
|
// requirements for emitting certain fusion kernels, in which case we don't
|
||||||
|
// want to fuse.
|
||||||
|
// TODO(b/119692968): Remove this once we have fixed our fusion emitter.
|
||||||
|
static bool IsFusedIrEmitterInefficient(const HloInstruction* consumer,
|
||||||
|
const HloInstruction* producer);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Returns the IrArrays for the fusion instruction operands.
|
// Returns the IrArrays for the fusion instruction operands.
|
||||||
llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) {
|
llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) {
|
||||||
|
@ -349,8 +349,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
|
|||||||
error_spec_);
|
error_spec_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(b/119692968): This test runs OOM on the CPU backend.
|
XLA_TEST_F(ArrayElementwiseOpTest, DeeplyNestedAddWithSlices) {
|
||||||
XLA_TEST_F(ArrayElementwiseOpTest, DISABLED_ON_CPU(DeeplyNestedAddWithSlices)) {
|
|
||||||
XlaBuilder builder(TestName());
|
XlaBuilder builder(TestName());
|
||||||
std::vector<float> values(30, 0.0);
|
std::vector<float> values(30, 0.0);
|
||||||
auto a_literal = LiteralUtil::CreateR1<float>(values);
|
auto a_literal = LiteralUtil::CreateR1<float>(values);
|
||||||
|
Loading…
Reference in New Issue
Block a user