[XLA] Try committing DUS buffer sharing again with fixes.
There were two sources of failures with the previous CL (cl/331779239): 1- Some models triggered dynamic update slice input and output to be different precisions. We don't want to disable mixed precision since the update can be a different precision than the input/output and there is a performance cost to disabling this mode. To fix, in bfloat16 propagation, we now force the DUS input and output to be the same precision. 2- Two models triggered patterns where two dynamic update slices that shared the same input were fused to the same multi-output fusion. This is wrong because the DUS happens in-place, so we can't have DUS multi-output fusions that use the same operand. We now explicitly check this pattern and disallow this in multi-output fusions. PiperOrigin-RevId: 332247949 Change-Id: If76f15a0d26e45e256269b2ec51f9c648d2f203f
This commit is contained in:
parent
105f48a75e
commit
593bf310d5
@ -83,6 +83,7 @@ cc_library(
|
||||
deps = [
|
||||
":bfloat16_support",
|
||||
":hlo",
|
||||
":hlo_dataflow_analysis",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
@ -1684,6 +1685,7 @@ cc_library(
|
||||
hdrs = ["multi_output_fusion.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_dataflow_analysis",
|
||||
":hlo_dce",
|
||||
":hlo_pass",
|
||||
":hlo_reachability",
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -159,19 +160,20 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
|
||||
|
||||
Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
|
||||
// Do not fold BF16 conversions for instructions related to tuples, entry and
|
||||
// exit of a computation, fusion, convert, side-effecting instructions and
|
||||
// control flow.
|
||||
if (hlo->opcode() == HloOpcode::kTuple || //
|
||||
hlo->opcode() == HloOpcode::kGetTupleElement || //
|
||||
hlo->opcode() == HloOpcode::kConstant || //
|
||||
hlo->opcode() == HloOpcode::kParameter || //
|
||||
hlo->opcode() == HloOpcode::kFusion || //
|
||||
hlo->opcode() == HloOpcode::kBitcastConvert || //
|
||||
hlo->opcode() == HloOpcode::kConvert || //
|
||||
hlo->opcode() == HloOpcode::kCall || //
|
||||
hlo->opcode() == HloOpcode::kCustomCall || //
|
||||
hlo->opcode() == HloOpcode::kWhile || //
|
||||
hlo->opcode() == HloOpcode::kConditional || //
|
||||
// exit of a computation, fusion, convert, side-effecting instructions,
|
||||
// in-place operations and control flow.
|
||||
if (hlo->opcode() == HloOpcode::kTuple || //
|
||||
hlo->opcode() == HloOpcode::kGetTupleElement || //
|
||||
hlo->opcode() == HloOpcode::kConstant || //
|
||||
hlo->opcode() == HloOpcode::kParameter || //
|
||||
hlo->opcode() == HloOpcode::kFusion || //
|
||||
hlo->opcode() == HloOpcode::kBitcastConvert || //
|
||||
hlo->opcode() == HloOpcode::kConvert || //
|
||||
hlo->opcode() == HloOpcode::kCall || //
|
||||
hlo->opcode() == HloOpcode::kCustomCall || //
|
||||
hlo->opcode() == HloOpcode::kWhile || //
|
||||
hlo->opcode() == HloOpcode::kConditional || //
|
||||
HloDataflowAnalysis::IsInPlaceOperation(hlo->opcode()) || //
|
||||
hlo->HasSideEffectNoRecurse()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -598,6 +598,31 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
|
||||
type = F32;
|
||||
break;
|
||||
}
|
||||
// In order to find aliases due to in-place operations, use
|
||||
// GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here,
|
||||
// but this code works with HloModules that aren't ready yet to use
|
||||
// HloAliasAnalysis (e.g., their computation graphs may not have been
|
||||
// flattened yet).
|
||||
for (const auto& operand_and_output_index :
|
||||
HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) {
|
||||
if (operand_and_output_index.second == index) {
|
||||
const HloUse& operand = operand_and_output_index.first;
|
||||
for (const auto* value :
|
||||
dataflow_
|
||||
->GetValueSet(hlo->operand(operand.operand_number),
|
||||
operand.operand_index)
|
||||
.values()) {
|
||||
auto value_type = ValueTypeAfterChange(value);
|
||||
if (value_type == BF16) {
|
||||
continue;
|
||||
}
|
||||
CHECK_EQ(value_type, F32);
|
||||
type = F32;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// It's possible that a user has been changed from BF16 to F32
|
||||
// during this final adjustment pass, so we need to check
|
||||
// AllUsersConsumeBF16() again.
|
||||
|
@ -1156,4 +1156,30 @@ ENTRY entry {
|
||||
EXPECT_FALSE(PropagatePrecision(module.get()));
|
||||
}
|
||||
|
||||
TEST_F(BFloat16PropagationTest, DynamicUpdateSlice) {
|
||||
// This test is crafted so that the DUS has an f32 input (due to parameter)
|
||||
// and bf16 output (due to dot). But we should enforce DUS operand 0 and
|
||||
// output to get the same precision since it's an in-place operation.
|
||||
const string module_str = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
param = f32[128,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
dynamic-update-slice = f32[128,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3)
|
||||
ROOT dot = f32[128,128] dot(dynamic-update-slice, dynamic-update-slice), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
EXPECT_FALSE(PropagatePrecision(module.get()));
|
||||
|
||||
HloInstruction* dus = module->entry_computation()->GetInstructionWithName(
|
||||
"dynamic-update-slice");
|
||||
EXPECT_FALSE(OutputsBF16(dus));
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -1007,102 +1007,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
|
||||
return true;
|
||||
} // namespace xla
|
||||
|
||||
Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
|
||||
// Try allocate same buffer for dynamic update slice's operand and output.
|
||||
|
||||
// If memory_space_assignment is run and there is information about a color in
|
||||
// preset assignments, don't merge those buffers. We expect
|
||||
// memory_space_assignment to have merged these buffers. If
|
||||
// memory_space_assignment didn't merge these buffers and have assigned
|
||||
// different offsets to the operand and the output buffer, merging the buffers
|
||||
// can cause memory corruption if memory_space_assignment assigned a different
|
||||
// buffer at the same offset.
|
||||
absl::flat_hash_set<int64> excluded_colors;
|
||||
if (preset_assignments_) {
|
||||
for (const auto& color_and_info :
|
||||
preset_assignments_->assignment_informations()) {
|
||||
excluded_colors.insert(color_and_info.first);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(yunxing): Moving this logic to alias analysis and add must-alias rule
|
||||
// to operations that can be done in place.
|
||||
for (HloComputation* computation : assignment->module().computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
if (!(instruction->opcode() == HloOpcode::kDynamicUpdateSlice ||
|
||||
(instruction->opcode() == HloOpcode::kFusion &&
|
||||
(instruction->fused_expression_root()->opcode() ==
|
||||
HloOpcode::kDynamicUpdateSlice)))) {
|
||||
continue;
|
||||
}
|
||||
if (instruction->parent()->IsFusionComputation()) {
|
||||
continue;
|
||||
}
|
||||
if (instruction->operand_count() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// The operand can't share the same buffer with the user based on dataflow
|
||||
// analysis.
|
||||
if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser(
|
||||
instruction->mutable_operand(0), {}, instruction, {})) {
|
||||
continue;
|
||||
}
|
||||
HloBuffer& instruction_buffer =
|
||||
assignment->alias_analysis().GetUniqueBufferAt(instruction, {});
|
||||
|
||||
HloBuffer& operand_buffer =
|
||||
assignment->alias_analysis().GetUniqueBufferAt(
|
||||
instruction->operand(0), {});
|
||||
|
||||
// The instruction or operand color is excluded because it was assigned by
|
||||
// memory_space_assignment.
|
||||
if (excluded_colors.contains(instruction_buffer.color()) ||
|
||||
excluded_colors.contains(operand_buffer.color())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Already have the same buffer. No need to merge those.
|
||||
if (instruction_buffer.id() == operand_buffer.id()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do not perform in-place dynamic update slice if the operand buffer is
|
||||
// read-only.
|
||||
if (HloBufferIsReadOnly(operand_buffer)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool interfere = false;
|
||||
|
||||
for (const HloValue* instruction_value : instruction_buffer.values()) {
|
||||
for (const HloValue* operand_value : operand_buffer.values()) {
|
||||
if (assignment->hlo_ordering().MayInterfere(
|
||||
*instruction_value, *operand_value,
|
||||
assignment->dataflow_analysis())) {
|
||||
interfere = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (interfere) {
|
||||
continue;
|
||||
}
|
||||
if (assignment->alias_analysis().BufferLivesOut(instruction_buffer)) {
|
||||
continue;
|
||||
}
|
||||
if (instruction_buffer.color() != operand_buffer.color()) {
|
||||
continue;
|
||||
}
|
||||
VLOG(3) << "Merging inplace " << instruction_buffer << " and "
|
||||
<< operand_buffer;
|
||||
assignment->alias_analysis().MergeBuffers(instruction_buffer,
|
||||
operand_buffer);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BufferAssigner::AssignSingleHloBuffer(
|
||||
const HloBuffer* hlo_buffer, bool is_thread_local,
|
||||
absl::flat_hash_map<const HloComputation*,
|
||||
@ -1654,7 +1558,6 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
|
||||
VLOG(3) << "After coloring:";
|
||||
XLA_VLOG_LINES(3,
|
||||
assignment->alias_analysis().dataflow_analysis().ToString());
|
||||
TF_RETURN_IF_ERROR(MergeInplaceOpBuffers(assignment.get()));
|
||||
|
||||
std::vector<const HloComputation*> thread_local_computations;
|
||||
std::vector<const HloComputation*> global_computations;
|
||||
|
@ -635,10 +635,6 @@ class BufferAssigner {
|
||||
absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
|
||||
BufferAssignment* assignment);
|
||||
|
||||
// Promotes operations (DUS, scatter) to be done in place: If an operation can
|
||||
// be done in place, merge its buffer with its operand buffer.
|
||||
Status MergeInplaceOpBuffers(BufferAssignment* assignment);
|
||||
|
||||
// Assigns a single hlo buffer to an HLO allocation.
|
||||
Status AssignSingleHloBuffer(
|
||||
const HloBuffer* hlo_buffer, bool is_thread_local,
|
||||
|
@ -1925,8 +1925,10 @@ ENTRY main {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
|
||||
HloInstruction* parameter =
|
||||
m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
|
||||
HloInstruction* dus =
|
||||
HloInstruction* dus1 =
|
||||
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
|
||||
HloInstruction* dus2 =
|
||||
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9");
|
||||
|
||||
auto buffers = RunBufferAssignment(m.get());
|
||||
|
||||
@ -1934,8 +1936,10 @@ ENTRY main {
|
||||
const BufferAllocation& parameter_alloc =
|
||||
GetTopLevelAllocation(*buffers, parameter);
|
||||
|
||||
const BufferAllocation& dus_alloc = GetTopLevelAllocation(*buffers, dus);
|
||||
EXPECT_NE(parameter_alloc, dus_alloc);
|
||||
const BufferAllocation& dus1_alloc = GetTopLevelAllocation(*buffers, dus1);
|
||||
EXPECT_EQ(parameter_alloc, dus1_alloc);
|
||||
const BufferAllocation& dus2_alloc = GetTopLevelAllocation(*buffers, dus2);
|
||||
EXPECT_EQ(parameter_alloc, dus2_alloc);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -362,6 +362,19 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
|
||||
// will remove the unnecessary copies.
|
||||
Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
|
||||
HloInstruction* in_place_op,
|
||||
int64 operand_number) {
|
||||
VLOG(2) << "Adding copies for in-place operation " << in_place_op->name();
|
||||
HloInstruction* operand = in_place_op->mutable_operand(operand_number);
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
|
||||
in_place_op->parent()->DeepCopyInstruction(operand));
|
||||
TF_RETURN_IF_ERROR(operand->ReplaceUseWith(in_place_op, deep_copy));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Conservatively adds copies before root instruction of entry computation and
|
||||
// each aliased parameter to resolve interference of aliased input and output
|
||||
// buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
|
||||
@ -509,6 +522,12 @@ class CopyRemover {
|
||||
// value. The map is used to construct the copy info map below.
|
||||
absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
|
||||
for (const HloBuffer& buffer : alias_analysis.buffers()) {
|
||||
// No copies should have been inserted within fused computations, so no
|
||||
// need to remove them. HloOrdering isn't compatible with HloValues inside
|
||||
// fusions, so skip copy removal for them.
|
||||
if (buffer.values().at(0)->defining_instruction()->IsFused()) {
|
||||
continue;
|
||||
}
|
||||
// Verify values contained in the buffer are strictly ordered. This
|
||||
// should always be the case after adding copies to eliminate
|
||||
// interference. Specifically, the addition of the control flow edges
|
||||
@ -591,7 +610,7 @@ class CopyRemover {
|
||||
void CreateCopyMap(
|
||||
const HloModule& module,
|
||||
const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
|
||||
for (HloComputation* computation : module.computations()) {
|
||||
for (HloComputation* computation : module.MakeNonfusionComputations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
// Add copies with unambiguous source values to the map. Copies with
|
||||
// ambiguous sources are not removable.
|
||||
@ -1005,7 +1024,7 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
HloAliasAnalysis::Run(module, can_share_buffer_));
|
||||
|
||||
for (HloComputation* computation : module->MakeComputationPostOrder()) {
|
||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||
for (HloInstruction* instruction :
|
||||
computation->MakeInstructionPostOrder()) {
|
||||
if (instruction->opcode() == HloOpcode::kWhile) {
|
||||
@ -1013,6 +1032,15 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
|
||||
} else if (instruction->opcode() == HloOpcode::kConditional) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddCopiesForConditional(*alias_analysis, instruction));
|
||||
} else {
|
||||
for (const auto& operand_and_output_index :
|
||||
HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
|
||||
const HloUse& operand = operand_and_output_index.first;
|
||||
CHECK_EQ(operand.operand_index, ShapeIndex{})
|
||||
<< "Support for non-{} shape operand not currently implemented.";
|
||||
TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation(
|
||||
*alias_analysis, instruction, operand.operand_number));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2530,5 +2530,250 @@ ENTRY Entry {
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 0);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 0);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, DynamicUpdateSliceCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
add = f32[1280,1,128] add(negate, negate)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
add = f32[1280,1,128] add(negate, negate)
|
||||
fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
|
||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
|
||||
get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
|
||||
get-tuple-element.3 = s32[] get-tuple-element(state), index=0
|
||||
constant.2 = s32[] constant(128)
|
||||
add.5 = s32[] add(get-tuple-element.3, constant.2)
|
||||
constant.3 = s32[] constant(0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation.1 {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
fused_computation.2 {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
param1 = f32[1280,1,128] parameter(1)
|
||||
slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
add = f32[1280,1,128] add(negate, negate)
|
||||
fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1
|
||||
ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) {
|
||||
// Tests multi-output fusion with two DUS outputs, requiring two copies.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
param1 = f32[1280,1,128] parameter(1)
|
||||
param2 = f32[1280,1,128] parameter(2)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
add.1 = f32[1280,1,128] add(param0, param0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate0 = f32[1280,1,128] negate(param)
|
||||
negate1 = f32[1280,1,128] negate(param)
|
||||
negate2 = f32[1280,1,128] negate(param)
|
||||
fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
|
||||
gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
|
||||
gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
|
||||
gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
|
||||
add0 = f32[1280,1,128] add(negate0, gte0)
|
||||
add1 = f32[1280,1,128] add(negate1, gte1)
|
||||
add2 = f32[1280,1,128] add(negate2, gte2)
|
||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 2);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) {
|
||||
// Same as above, but negate1 is not used beyond fusion, so it only needs one
|
||||
// copy for negate0.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
param1 = f32[1280,1,128] parameter(1)
|
||||
param2 = f32[1280,1,128] parameter(2)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
add.1 = f32[1280,1,128] add(param0, param0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate0 = f32[1280,1,128] negate(param)
|
||||
negate1 = f32[1280,1,128] negate(param)
|
||||
negate2 = f32[1280,1,128] negate(param)
|
||||
fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
|
||||
gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
|
||||
gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
|
||||
gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
|
||||
add0 = f32[1280,1,128] add(negate0, gte0)
|
||||
add1 = f32[1280,1,128] add(gte1, gte1)
|
||||
add2 = f32[1280,1,128] add(negate2, gte2)
|
||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -1,6 +1,6 @@
|
||||
// RUN: hlo_to_llvm_ir %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) {
|
||||
// CHECK: entry:
|
||||
// CHECK: %[[VAL_32:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
|
||||
@ -43,8 +43,8 @@
|
||||
// CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4
|
||||
// CHECK: br label %[[VAL_23]]
|
||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{i32 0, i32 6}
|
||||
// CHECK: !4 = !{}
|
||||
@ -72,7 +72,7 @@ ENTRY main {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 64 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 %alloc3) {
|
||||
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 16 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 %alloc2) {
|
||||
// CHECK: entry:
|
||||
// CHECK: %[[VAL_60:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
|
||||
@ -104,8 +104,8 @@ ENTRY main {
|
||||
// CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4
|
||||
// CHECK: br label %[[VAL_57]]
|
||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{}
|
||||
|
||||
@ -131,7 +131,7 @@ ENTRY main {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) {
|
||||
// CHECK: %[[VAL_63:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_64:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_98:.*]] = alloca i32, align 4
|
||||
@ -188,8 +188,8 @@ ENTRY main {
|
||||
// CHECK: %[[VAL_109:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 1
|
||||
// CHECK: br i1 %[[VAL_109]], label %[[VAL_96]], label %[[VAL_104]]
|
||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{i32 0, i32 6}
|
||||
// CHECK: !4 = !{}
|
||||
@ -216,7 +216,7 @@ ENTRY main {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 64 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(16) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 dereferenceable(4) %alloc3) {
|
||||
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 16 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2) {
|
||||
// CHECK: entry:
|
||||
// CHECK: %[[VAL_146:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
|
||||
@ -253,8 +253,8 @@ ENTRY main {
|
||||
// CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4
|
||||
// CHECK: br label %[[VAL_138]]
|
||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{}
|
||||
|
||||
|
@ -308,6 +308,39 @@ class BufferValueMap {
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeInPlaceOperationAliasedBuffers(
|
||||
const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
|
||||
VLOG(3) << "Compute aliases for in-place operations (e.g. "
|
||||
"kDynamicUpdateSlice and kScatter)";
|
||||
for (const HloPosition& position : value.positions()) {
|
||||
HloInstruction* instruction = position.instruction;
|
||||
for (const auto& operand_and_output_index :
|
||||
HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
|
||||
if (position.index == operand_and_output_index.second) {
|
||||
const HloUse& operand = operand_and_output_index.first;
|
||||
const HloValue& operand_value = dataflow_.GetUniqueValueAt(
|
||||
instruction->operand(operand.operand_number),
|
||||
operand.operand_index);
|
||||
VLOG(3) << " operand value " << operand_value.ToShortString()
|
||||
<< " aliases.";
|
||||
aliased_buffers->push_back(GetBufferForValue(operand_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const HloUse& use : value.uses()) {
|
||||
for (const auto& operand_and_output_index :
|
||||
HloDataflowAnalysis::GetInPlaceInputOutputPairs(use.instruction)) {
|
||||
if (use == operand_and_output_index.first) {
|
||||
const HloValue& use_value = dataflow_.GetUniqueValueAt(
|
||||
use.instruction, operand_and_output_index.second);
|
||||
VLOG(3) << " use value " << use_value.ToShortString() << " aliases.";
|
||||
aliased_buffers->push_back(GetBufferForValue(use_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute and return a vector of buffers that the given value must be
|
||||
// contained in due to HLO aliasing rules.
|
||||
std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
|
||||
@ -318,6 +351,7 @@ class BufferValueMap {
|
||||
ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
|
||||
ComputeWhileAliasedBuffers(value, &aliased_buffers);
|
||||
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
|
||||
ComputeInPlaceOperationAliasedBuffers(value, &aliased_buffers);
|
||||
// Uniquify aliased buffers.
|
||||
absl::c_sort(aliased_buffers);
|
||||
aliased_buffers.erase(
|
||||
|
@ -1062,6 +1062,118 @@ TEST_F(HloAliasAnalysisTest, MergeBuffersReverse) {
|
||||
analysis.BufferLivesOut(analysis.buffers()[0]);
|
||||
}
|
||||
|
||||
TEST_F(HloAliasAnalysisTest, DynamicUpdateSlice) {
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {8});
|
||||
Shape update_shape = ShapeUtil::MakeShape(F32, {4});
|
||||
Shape index_shape = ShapeUtil::MakeShape(S32, {});
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "param0"));
|
||||
auto param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, update_shape, "param1"));
|
||||
auto param2 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(2, index_shape, "param2"));
|
||||
auto copy0 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param0));
|
||||
auto dynamic_update_slice = builder.AddInstruction(
|
||||
HloInstruction::CreateDynamicUpdateSlice(shape, copy0, param1, {param2}));
|
||||
|
||||
module_->AddEntryComputation(builder.Build());
|
||||
SCOPED_TRACE(module_->ToString());
|
||||
|
||||
HloAliasAnalysis& analysis = RunAnalysis();
|
||||
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(copy0),
|
||||
analysis.GetUniqueBufferAt(dynamic_update_slice));
|
||||
}
|
||||
|
||||
TEST_F(HloAliasAnalysisTest, DynamicUpdateSliceMultiOutputFusion) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
param1 = f32[1280,1,128] parameter(1)
|
||||
param2 = f32[1280,1,128] parameter(2)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
add.1 = f32[1280,1,128] add(param0, param0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate0 = f32[1280,1,128] negate(param)
|
||||
negate1 = f32[1280,1,128] negate(param)
|
||||
negate2 = f32[1280,1,128] negate(param)
|
||||
ROOT fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
SCOPED_TRACE(module_->ToString());
|
||||
|
||||
HloAliasAnalysis& analysis = RunAnalysis();
|
||||
LOG(INFO) << analysis.ToString();
|
||||
|
||||
// Expect negate1 and negate2 to alias with fusion{1} and fusion{2}
|
||||
// respectively (due to DUS), but not negate0 and fusion{0}.
|
||||
const HloInstruction* fusion =
|
||||
module_->entry_computation()->GetInstructionWithName("fusion");
|
||||
const HloInstruction* negate0 =
|
||||
module_->entry_computation()->GetInstructionWithName("negate0");
|
||||
const HloInstruction* negate1 =
|
||||
module_->entry_computation()->GetInstructionWithName("negate1");
|
||||
const HloInstruction* negate2 =
|
||||
module_->entry_computation()->GetInstructionWithName("negate2");
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(negate1),
|
||||
analysis.GetUniqueBufferAt(fusion, {1}));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(negate2),
|
||||
analysis.GetUniqueBufferAt(fusion, {2}));
|
||||
EXPECT_NE(analysis.GetUniqueBufferAt(negate0),
|
||||
analysis.GetUniqueBufferAt(fusion, {0}));
|
||||
}
|
||||
|
||||
TEST_F(HloAliasAnalysisTest, ChainedDynamicUpdateSliceFusion) {
|
||||
// CPU and GPU backends may generate fusions with dynamic update slices
|
||||
// feeding each other. They expect the fusion to not be in-place if that is
|
||||
// the case.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate0 = f32[1280,1,128] negate(param)
|
||||
ROOT fusion = f32[1280,1,128] fusion(negate0), kind=kLoop, calls=fused_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
SCOPED_TRACE(module_->ToString());
|
||||
|
||||
HloAliasAnalysis& analysis = RunAnalysis();
|
||||
LOG(INFO) << analysis.ToString();
|
||||
|
||||
const HloInstruction* fusion =
|
||||
module_->entry_computation()->GetInstructionWithName("fusion");
|
||||
const HloInstruction* negate0 =
|
||||
module_->entry_computation()->GetInstructionWithName("negate0");
|
||||
EXPECT_NE(analysis.GetUniqueBufferAt(negate0),
|
||||
analysis.GetUniqueBufferAt(fusion));
|
||||
}
|
||||
|
||||
TEST_F(HloAliasAnalysisTest, BitcastInterference) {
|
||||
// A bitcast value simultaneously live with its operand should not cause
|
||||
// interference.
|
||||
|
@ -1178,69 +1178,49 @@ bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
|
||||
return true;
|
||||
}
|
||||
|
||||
// Given a fusion whose root is a dynamic-update-slice op, determines whether
|
||||
// the fusion's output buffer can be shared with the buffer of fusion_param,
|
||||
// which must be a fused parameter of the fusion.
|
||||
//
|
||||
// Preconditions:
|
||||
//
|
||||
// - fusion's root is a dynamic-update-slice op.
|
||||
// - fusion_param is a parameter within the fusion.
|
||||
//
|
||||
// fusion_param may point to a subelement of the actual parameter instruction if
|
||||
// the param is a tuple; i.e. fusion_param->index() need not be the empty list.
|
||||
//
|
||||
// Returns true if:
|
||||
//
|
||||
// * fusion_param is used by the root of dynamic-update-slice as the "base" of
|
||||
// the update, i.e. the thing being updated, AND
|
||||
// * all other uses of fusion_param are dynamic-slices that slice the same
|
||||
// indices as are overwritten in the dynamic-update-slice.
|
||||
//
|
||||
// In the case that there are no other uses of fusion_param (last bullet point
|
||||
// is vacuously true) it's easy to see why an in-place DUS is safe; this is just
|
||||
// the "natural" implementation of DUS. If there are other users, in-place DUS
|
||||
// is safe on the assumption that the thread which writes element i of the
|
||||
// output will be the only one to read element i of fusion_param (via the
|
||||
// dynamic-slice ops).
|
||||
static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion,
|
||||
const HloValue& fusion_param_value) {
|
||||
auto* root =
|
||||
Cast<HloDynamicUpdateSliceInstruction>(fusion->fused_expression_root());
|
||||
auto* fusion_param = fusion_param_value.instruction();
|
||||
CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter);
|
||||
CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation());
|
||||
/*static*/ bool HloDataflowAnalysis::IsInPlaceOperation(HloOpcode opcode) {
|
||||
return opcode == HloOpcode::kDynamicUpdateSlice ||
|
||||
opcode == HloOpcode::kScatter;
|
||||
}
|
||||
|
||||
// fusion_param must be used by the root as the "base" of the
|
||||
// dynamic-update-slice. The natural way to check this would be
|
||||
//
|
||||
// `if (root->operand(0) != fusion_param)`
|
||||
//
|
||||
// but we also have to handle the case where the fusion parameter is
|
||||
// tuple-shaped and we're considering just one element of that tuple, i.e.
|
||||
// fusion_param.index() != {}.
|
||||
if (absl::c_count_if(fusion_param_value.uses(), [&](const HloUse& use) {
|
||||
return use.instruction == root;
|
||||
}) != 1) {
|
||||
return false;
|
||||
/*static*/ std::vector<std::pair<HloUse, ShapeIndex>>
|
||||
HloDataflowAnalysis::GetInPlaceInputOutputPairs(HloInstruction* instruction) {
|
||||
if (IsInPlaceOperation(instruction->opcode())) {
|
||||
return {{HloUse{instruction, 0, {}}, {}}};
|
||||
} else if (instruction->opcode() != HloOpcode::kFusion) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// All other uses of fusion_param must be dynamic-slices that slice the same
|
||||
// indices as are overwritten by the dynamic-update-slice.
|
||||
for (const HloUse& use : fusion_param_value.uses()) {
|
||||
auto* user = use.instruction;
|
||||
if (user == root) {
|
||||
continue;
|
||||
std::vector<std::pair<HloUse, ShapeIndex>> input_output_pairs;
|
||||
for (auto& indexed_shape : ShapeUtil::GetLeafShapes(instruction->shape())) {
|
||||
const HloInstruction* hlo_generating_output =
|
||||
instruction->fused_expression_root();
|
||||
for (int64 i = 0; i < indexed_shape.index.size(); ++i) {
|
||||
if (hlo_generating_output->opcode() == HloOpcode::kTuple) {
|
||||
hlo_generating_output =
|
||||
hlo_generating_output->operand(indexed_shape.index[i]);
|
||||
} else {
|
||||
CHECK_EQ(i, indexed_shape.index.size() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Check that `user` is a dynamic-slice op and has the same slice indices as
|
||||
// `root`.
|
||||
auto* ds = DynCast<HloDynamicSliceInstruction>(user);
|
||||
if (!ds || ds->index_operands() != root->index_operands()) {
|
||||
return false;
|
||||
if (IsInPlaceOperation(hlo_generating_output->opcode())) {
|
||||
ShapeIndex operand_index;
|
||||
const HloInstruction* fusion_parameter =
|
||||
hlo_generating_output->operand(0);
|
||||
while (fusion_parameter->opcode() == HloOpcode::kGetTupleElement) {
|
||||
operand_index.push_front(fusion_parameter->tuple_index());
|
||||
fusion_parameter = fusion_parameter->operand(0);
|
||||
}
|
||||
|
||||
if (fusion_parameter->opcode() == HloOpcode::kParameter) {
|
||||
input_output_pairs.emplace_back(
|
||||
HloUse{instruction, fusion_parameter->parameter_number(),
|
||||
operand_index},
|
||||
indexed_shape.index);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return input_output_pairs;
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
||||
@ -1261,24 +1241,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
||||
return false;
|
||||
}
|
||||
|
||||
if (user->opcode() == HloOpcode::kFusion) {
|
||||
// Get the parameter associated with 'operand';
|
||||
HloInstruction* fusion_param =
|
||||
user->fused_parameter(user->operand_index(operand));
|
||||
|
||||
const HloValue& fusion_param_value =
|
||||
GetValueDefinedAt(fusion_param, operand_index);
|
||||
|
||||
// TODO(b/80315712): This code is in a bit of a weird intermediate state
|
||||
// at the moment. The in-place DUS check really needs to be common to all
|
||||
// backends, so it runs first. Then we run the backend-specific check if
|
||||
// provided, or go through the target-independent check if not.
|
||||
// Unfortunately, the notionally "target-independent" path actually contains
|
||||
// some target-specific code, so we can't run all of it *in addition* to the
|
||||
// target-specific function, like the interface documentation says.
|
||||
if (user->fused_expression_root()->opcode() ==
|
||||
HloOpcode::kDynamicUpdateSlice) {
|
||||
return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value);
|
||||
// Must-alias relationship returns true for in-place operations (DUS and DUS
|
||||
// fusions), regardless of the backend.
|
||||
for (const auto& operand_and_output_index :
|
||||
GetInPlaceInputOutputPairs(user)) {
|
||||
if (operand_and_output_index.second != user_index) {
|
||||
continue;
|
||||
}
|
||||
for (const HloUse& use : GetUniqueValueAt(operand, operand_index).uses()) {
|
||||
if (use == operand_and_output_index.first) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -49,6 +49,9 @@ class HloDataflowAnalysis {
|
||||
// Infrastructure for passing may-alias hints: HLO passes can populate the
|
||||
// may-alias table. If an empty optional is returned, default rules are used.
|
||||
//
|
||||
// Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be
|
||||
// overriden using backend-specific overrides.
|
||||
//
|
||||
// The first parameter of the function should be the instruction, the
|
||||
// second parameter should be an operand of the instruction. The third
|
||||
// parameter should be the output index of the instruction.
|
||||
@ -160,6 +163,15 @@ class HloDataflowAnalysis {
|
||||
|
||||
const HloModule& module() const { return module_; }
|
||||
|
||||
// Returns true if the operation is an in-place operation and its operand 0
|
||||
// must alias with the output.
|
||||
static bool IsInPlaceOperation(HloOpcode opcode);
|
||||
|
||||
// Returns a vector consisting of the HloUse (operand number and shape index)
|
||||
// and output shape index of the in-place operations within this HLO.
|
||||
static std::vector<std::pair<HloUse, ShapeIndex>> GetInPlaceInputOutputPairs(
|
||||
HloInstruction* instruction);
|
||||
|
||||
protected:
|
||||
HloDataflowAnalysis(const HloModule& module, bool ssa_form,
|
||||
bool bitcast_defines_value = false,
|
||||
|
@ -2324,36 +2324,6 @@ TEST_F(CanShareOperandBufferWithUserTest,
|
||||
dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {}));
|
||||
}
|
||||
|
||||
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithDifferentIndices) {
|
||||
const char* kModule = R"(
|
||||
HloModule test
|
||||
|
||||
fused_computation {
|
||||
p0 = f32[10,20,30] parameter(0)
|
||||
p1 = s32[] parameter(1)
|
||||
p2 = s32[] parameter(2)
|
||||
p3 = s32[] parameter(3)
|
||||
slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30}
|
||||
ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p3, p2)
|
||||
}
|
||||
|
||||
ENTRY test {
|
||||
p0 = f32[10,20,30] parameter(0)
|
||||
p1 = s32[] parameter(1)
|
||||
p2 = s32[] parameter(2)
|
||||
p3 = s32[] parameter(3)
|
||||
ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));
|
||||
auto* fusion = module_->entry_computation()->root_instruction();
|
||||
auto* param = module_->entry_computation()->parameter_instruction(0);
|
||||
|
||||
RunAnalysis();
|
||||
EXPECT_FALSE(
|
||||
dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {}));
|
||||
}
|
||||
|
||||
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) {
|
||||
const char* kModule = R"(
|
||||
HloModule test
|
||||
|
@ -232,6 +232,24 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
||||
return copies;
|
||||
}
|
||||
|
||||
int64 GetAlternateMemoryOffset(const PresetAssignments& preset_assignments,
|
||||
const HloInstruction* instruction,
|
||||
const ShapeIndex& index = {}) const {
|
||||
// Returns the offset of the assignment, -1 if it's not in the alternate
|
||||
// memory.
|
||||
const HloModule* module = instruction->parent()->parent();
|
||||
auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
|
||||
HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index);
|
||||
for (auto& pos_and_chunk : preset_assignments.chunks()) {
|
||||
for (auto& value : buffer.values()) {
|
||||
if (pos_and_chunk.first == value->defining_position()) {
|
||||
return pos_and_chunk.second.offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
|
||||
HloComputation::Builder builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
|
||||
@ -4415,6 +4433,47 @@ TEST_P(MemorySpaceAssignmentTest, Determinism) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, InPlaceOp) {
|
||||
// Tests that in-place ops like DynamicUpdateSlice get the same allocation as
|
||||
// its input.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module, is_scheduled=true
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[2,3] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast = f32[2,1] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[2,3] parameter(0)
|
||||
negate = f32[2,3] negate(param)
|
||||
fusion = f32[2,3] fusion(negate), kind=kLoop, calls=fused_computation
|
||||
ROOT add = f32[2,3] add(fusion, fusion)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
auto preset_assignments = AssignMemorySpace(module.get());
|
||||
HloInstruction* negate_instruction =
|
||||
module->entry_computation()->GetInstructionWithName("negate");
|
||||
int64 negate_offset =
|
||||
GetAlternateMemoryOffset(*preset_assignments, negate_instruction);
|
||||
HloInstruction* fusion_instruction =
|
||||
module->entry_computation()->GetInstructionWithName("fusion");
|
||||
int64 fusion_offset =
|
||||
GetAlternateMemoryOffset(*preset_assignments, fusion_instruction);
|
||||
// We expect negate and fusion to get the same offsets.
|
||||
EXPECT_EQ(negate_offset, fusion_offset);
|
||||
const bool allocate_across_sequential_calls = GetParam();
|
||||
if (allocate_across_sequential_calls) {
|
||||
EXPECT_NE(negate_offset, -1);
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
|
||||
MemorySpaceAssignmentTest,
|
||||
::testing::Values(false, true));
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
@ -338,6 +339,21 @@ bool MultiOutputFusion::LegalToFuseMainConstraints(HloInstruction* instr1,
|
||||
if (!ShapesCompatibleForFusion(instr1, instr2)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If both nodes are in-place operations and they use a common in-place
|
||||
// operand, we can't fuse these two.
|
||||
for (const auto& operand_and_output_index1 :
|
||||
HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr1)) {
|
||||
const HloInstruction* operand =
|
||||
instr1->operand(operand_and_output_index1.first.operand_number);
|
||||
for (const auto& operand_and_output_index2 :
|
||||
HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr2)) {
|
||||
if (operand ==
|
||||
instr2->operand(operand_and_output_index2.first.operand_number)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user