Don't fuse in CpuInstructionFusion if it causes ineffiencies in FusedIrEmitter.
Also move the function which checks for inefficiencies to FusedIrEmitter. PiperOrigin-RevId: 237419419
This commit is contained in:
parent
760862f24d
commit
bc05464062
@ -68,8 +68,8 @@ class RGBToHSVTest(xla_test.XLATestCase):
|
||||
{batch0: inp})
|
||||
|
||||
# Verify that processing batch elements together is the same as separate
|
||||
self.assertAllClose(batch1, join1)
|
||||
self.assertAllClose(batch2, join2)
|
||||
self.assertAllCloseAccordingToType(batch1, join1, half_rtol=0.000002)
|
||||
self.assertAllCloseAccordingToType(batch2, join2, half_rtol=0.000002)
|
||||
self.assertAllCloseAccordingToType(
|
||||
batch2, inp, bfloat16_atol=0.03, half_rtol=0.02)
|
||||
|
||||
|
@ -748,6 +748,7 @@ cc_library(
|
||||
":ir_emission_utils",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:instruction_fusion",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
@ -117,6 +118,14 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
|
||||
return false;
|
||||
}
|
||||
|
||||
// Don't fuse if fusing would cause too much code duplication because of
|
||||
// inefficiencies in the fusion emitter.
|
||||
// TODO(b/119692968): Remove this once the fusion emitter can handle
|
||||
// arbitrary fusion nodes.
|
||||
if (FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (consumer->opcode() == HloOpcode::kDot) {
|
||||
// In the general case we call out to optimized "black box" GEMM routines
|
||||
// for Dot, which precludes fusion. However, in very specific cases, we try
|
||||
|
@ -582,6 +582,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:instruction_fusion",
|
||||
"//tensorflow/compiler/xla/service:pattern_matcher",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
@ -703,6 +704,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -1102,8 +1104,6 @@ cc_library(
|
||||
":ir_emission_utils",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -289,7 +290,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
|
||||
// TODO(b/119692968): Remove this once the fusion emitter can handle arbitrary
|
||||
// fusion nodes.
|
||||
if (absl::c_any_of(fusion->users(), [fusion](const HloInstruction* user) {
|
||||
return IsFusionEmitterInefficient(/*consumer=*/user,
|
||||
return FusedIrEmitter::IsFusedIrEmitterInefficient(/*consumer=*/user,
|
||||
/*producer=*/fusion);
|
||||
})) {
|
||||
VLOG(3) << "Not merging " << fusion->name()
|
||||
|
@ -15,12 +15,10 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
@ -186,102 +184,5 @@ bool IsFusible(const HloInstruction& instr) {
|
||||
return IsInputFusible(instr) || IsLoopFusible(instr);
|
||||
}
|
||||
|
||||
bool IsFusionEmitterInefficient(const HloInstruction* consumer,
|
||||
const HloInstruction* producer) {
|
||||
if (consumer->opcode() != HloOpcode::kFusion) {
|
||||
return false;
|
||||
}
|
||||
// Collects for each instruction in the fusion node from which (indirect)
|
||||
// users newly created index values are passed. Roughly speaking, we reuse
|
||||
// index values if the shapes are equal when ignoring the element type (we may
|
||||
// reuse also if the shape change is a bitcast, but we don't consider that
|
||||
// here). By ignoring potential reuses our estimate whether the fusion emitter
|
||||
// is inefficient is a bit more conservative than necessary.
|
||||
absl::flat_hash_map<const HloInstruction*,
|
||||
absl::flat_hash_set<const HloInstruction*>>
|
||||
indexing_users;
|
||||
// Stores the number of different index accesses for each instruction in the
|
||||
// fusion node. The fusion emitter caches access with the same index, so this
|
||||
// value indicates how many times a specific instruction will be emitted.
|
||||
absl::flat_hash_map<const HloInstruction*, int64> index_usage_count;
|
||||
index_usage_count[consumer] = 1;
|
||||
|
||||
auto evaluate_fusion_computation = [&indexing_users, &index_usage_count](
|
||||
const HloInstruction* fusion) {
|
||||
auto postorder =
|
||||
fusion->fused_instructions_computation()->MakeInstructionPostOrder();
|
||||
std::reverse(postorder.begin(), postorder.end());
|
||||
for (const auto* instruction : postorder) {
|
||||
if (instruction->opcode() == HloOpcode::kParameter) {
|
||||
continue;
|
||||
}
|
||||
int64& total = index_usage_count[instruction];
|
||||
if (indexing_users[instruction].empty()) {
|
||||
total = index_usage_count[fusion];
|
||||
} else {
|
||||
total = 0;
|
||||
for (const auto* user : indexing_users[instruction]) {
|
||||
int64 weight = 1;
|
||||
// Concatenate is special: the index differs for each operand, so
|
||||
// in the worst case we have to deal with as many index values as
|
||||
// the number of operands of Concatenate. By considering the worst
|
||||
// case, we are more conservative than necessary regarding
|
||||
// refusing to fuse.
|
||||
if (user->opcode() == HloOpcode::kConcatenate) {
|
||||
weight = user->operand_count();
|
||||
}
|
||||
total += index_usage_count[user] * weight;
|
||||
}
|
||||
}
|
||||
for (const auto* operand : instruction->operands()) {
|
||||
// For simplicity we assume that all shape and layout changing
|
||||
// operations invalidate index reuse.
|
||||
if (Shape::Equal().IgnoreElementType()(operand->shape(),
|
||||
instruction->shape())) {
|
||||
// If the index is reused, it means the operand gets index values
|
||||
// from the same set of (indirect) users as 'instruction' itself.
|
||||
indexing_users[operand].insert(indexing_users[instruction].begin(),
|
||||
indexing_users[instruction].end());
|
||||
} else {
|
||||
// If the index is not reused, it means 'instruction' computes a
|
||||
// new index derived from the index it gets.
|
||||
indexing_users[operand].insert(instruction);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
evaluate_fusion_computation(consumer);
|
||||
|
||||
// Also account for the 'producer' if it would be fused. Find the operand it
|
||||
// corresponds to.
|
||||
for (int64 operand_num = 0; operand_num < consumer->operand_count();
|
||||
++operand_num) {
|
||||
if (consumer->operand(operand_num) == producer) {
|
||||
auto instruction = consumer->fused_parameter(operand_num);
|
||||
int64& total = index_usage_count[producer];
|
||||
total = 0;
|
||||
for (const auto* user : indexing_users[instruction]) {
|
||||
total += index_usage_count[user];
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If 'producer' is a fusion node as well, also evaluate it.
|
||||
if (producer->opcode() == HloOpcode::kFusion) {
|
||||
evaluate_fusion_computation(producer);
|
||||
}
|
||||
|
||||
// Sum up the total number of emitted ops.
|
||||
int64 total = 0;
|
||||
for (const auto& entry : index_usage_count) {
|
||||
total += entry.second;
|
||||
}
|
||||
|
||||
// Check that the code duplication has at most a factor of 8 (where 8 is an
|
||||
// arbitrary constant that seems to work).
|
||||
return total > 8 * index_usage_count.size();
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -69,14 +69,6 @@ bool IsInputFusibleScatter(const HloInstruction& instr);
|
||||
bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
|
||||
const HloInstruction& instr2);
|
||||
|
||||
// Evaluates whether fusing 'producer' into 'consumer' might cause exponential
|
||||
// behavior in the fusion emitter. We currently can have exponential time/memory
|
||||
// requirements for emitting certain fusion kernels, in which case we don't want
|
||||
// to fuse.
|
||||
// TODO(b/119692968): Remove this once we have fixed our fusion emitter.
|
||||
bool IsFusionEmitterInefficient(const HloInstruction* consumer,
|
||||
const HloInstruction* producer);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -267,7 +268,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
|
||||
// have exponential time/memory requirements for emitting certain fusion
|
||||
// kernels, in which case we don't want to fuse.
|
||||
// TODO(b/119692968): Remove this once we have fixed our fusion emitter.
|
||||
return !IsFusionEmitterInefficient(consumer, producer);
|
||||
return !FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer);
|
||||
}
|
||||
|
||||
bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer,
|
||||
|
@ -161,6 +161,7 @@ cc_library(
|
||||
":llvm_util",
|
||||
":loop_emitter",
|
||||
":tuple_ops",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -169,6 +170,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm//:core",
|
||||
|
@ -15,14 +15,22 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -195,4 +203,101 @@ FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetGenerator(
|
||||
return indexed_generators_.at(instruction);
|
||||
}
|
||||
|
||||
bool FusedIrEmitter::IsFusedIrEmitterInefficient(
|
||||
const HloInstruction* consumer, const HloInstruction* producer) {
|
||||
if (consumer->opcode() != HloOpcode::kFusion) {
|
||||
return false;
|
||||
}
|
||||
// Collects for each instruction in the fusion node from which (indirect)
|
||||
// users newly created index values are passed. Roughly speaking, we reuse
|
||||
// index values if the shapes are equal when ignoring the element type (we may
|
||||
// reuse also if the shape change is a bitcast, but we don't consider that
|
||||
// here). By ignoring potential reuses our estimate whether the fusion emitter
|
||||
// is inefficient is a bit more conservative than necessary.
|
||||
absl::flat_hash_map<const HloInstruction*,
|
||||
absl::flat_hash_set<const HloInstruction*>>
|
||||
indexing_users;
|
||||
// Stores the number of different index accesses for each instruction in the
|
||||
// fusion node. The fusion emitter caches access with the same index, so this
|
||||
// value indicates how many times a specific instruction will be emitted.
|
||||
absl::flat_hash_map<const HloInstruction*, int64> index_usage_count;
|
||||
index_usage_count[consumer] = 1;
|
||||
|
||||
auto evaluate_fusion_computation = [&indexing_users, &index_usage_count](
|
||||
const HloInstruction* fusion) {
|
||||
auto postorder =
|
||||
fusion->fused_instructions_computation()->MakeInstructionPostOrder();
|
||||
std::reverse(postorder.begin(), postorder.end());
|
||||
for (const auto* instruction : postorder) {
|
||||
if (instruction->opcode() == HloOpcode::kParameter) {
|
||||
continue;
|
||||
}
|
||||
int64& total = index_usage_count[instruction];
|
||||
if (indexing_users[instruction].empty()) {
|
||||
total = index_usage_count[fusion];
|
||||
} else {
|
||||
total = 0;
|
||||
for (const auto* user : indexing_users[instruction]) {
|
||||
int64 weight = 1;
|
||||
// Concatenate is special: the index differs for each operand, so
|
||||
// in the worst case we have to deal with as many index values as
|
||||
// the number of operands of Concatenate. By considering the worst
|
||||
// case, we are more conservative than necessary regarding
|
||||
// refusing to fuse.
|
||||
if (user->opcode() == HloOpcode::kConcatenate) {
|
||||
weight = user->operand_count();
|
||||
}
|
||||
total += index_usage_count[user] * weight;
|
||||
}
|
||||
}
|
||||
for (const auto* operand : instruction->operands()) {
|
||||
// For simplicity we assume that all shape and layout changing
|
||||
// operations invalidate index reuse.
|
||||
if (Shape::Equal().IgnoreElementType()(operand->shape(),
|
||||
instruction->shape())) {
|
||||
// If the index is reused, it means the operand gets index values
|
||||
// from the same set of (indirect) users as 'instruction' itself.
|
||||
indexing_users[operand].insert(indexing_users[instruction].begin(),
|
||||
indexing_users[instruction].end());
|
||||
} else {
|
||||
// If the index is not reused, it means 'instruction' computes a
|
||||
// new index derived from the index it gets.
|
||||
indexing_users[operand].insert(instruction);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
evaluate_fusion_computation(consumer);
|
||||
|
||||
// Also account for the 'producer' if it would be fused. Find the operand it
|
||||
// corresponds to.
|
||||
for (int64 operand_num = 0; operand_num < consumer->operand_count();
|
||||
++operand_num) {
|
||||
if (consumer->operand(operand_num) == producer) {
|
||||
auto instruction = consumer->fused_parameter(operand_num);
|
||||
int64& total = index_usage_count[producer];
|
||||
total = 0;
|
||||
for (const auto* user : indexing_users[instruction]) {
|
||||
total += index_usage_count[user];
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If 'producer' is a fusion node as well, also evaluate it.
|
||||
if (producer->opcode() == HloOpcode::kFusion) {
|
||||
evaluate_fusion_computation(producer);
|
||||
}
|
||||
|
||||
// Sum up the total number of emitted ops.
|
||||
int64 total = 0;
|
||||
for (const auto& entry : index_usage_count) {
|
||||
total += entry.second;
|
||||
}
|
||||
|
||||
// Check that the code duplication has at most a factor of 8 (where 8 is an
|
||||
// arbitrary constant that seems to work).
|
||||
return total > 8 * index_usage_count.size();
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -91,6 +91,14 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
|
||||
tiled_parameter_info_ = info;
|
||||
}
|
||||
|
||||
// Evaluates whether fusing 'producer' into 'consumer' might cause exponential
|
||||
// behavior in FusedIrEmitter. We currently can have exponential time/memory
|
||||
// requirements for emitting certain fusion kernels, in which case we don't
|
||||
// want to fuse.
|
||||
// TODO(b/119692968): Remove this once we have fixed our fusion emitter.
|
||||
static bool IsFusedIrEmitterInefficient(const HloInstruction* consumer,
|
||||
const HloInstruction* producer);
|
||||
|
||||
protected:
|
||||
// Returns the IrArrays for the fusion instruction operands.
|
||||
llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) {
|
||||
|
@ -349,8 +349,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
// TODO(b/119692968): This test runs OOM on the CPU backend.
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, DISABLED_ON_CPU(DeeplyNestedAddWithSlices)) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, DeeplyNestedAddWithSlices) {
|
||||
XlaBuilder builder(TestName());
|
||||
std::vector<float> values(30, 0.0);
|
||||
auto a_literal = LiteralUtil::CreateR1<float>(values);
|
||||
|
Loading…
Reference in New Issue
Block a user