Allow Sort to share the buffer with the operand if it is the only user.
The BitonicSort algorithm works in-place, so we can take advantage of that. On GPU, we previously copied the operand to the output and then performed the algorithm in-place. Now we can skip that copy when we see that the buffer is shared. Also, device-to-device copies are now only needed when the buffer is not shared, because constants are now also assigned a buffer. PiperOrigin-RevId: 206745686
This commit is contained in:
parent
3bec2640dc
commit
2826d123a0
@ -2068,6 +2068,7 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
|
||||
|
||||
Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
|
||||
std::vector<std::unique_ptr<Thunk>> thunks;
|
||||
auto keys = sort->operand(0);
|
||||
auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
|
||||
ShapeIndex keys_shape_index({});
|
||||
ShapeIndex values_shape_index({});
|
||||
@ -2076,41 +2077,25 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
|
||||
values_shape_index = ShapeIndex({1});
|
||||
}
|
||||
auto keys_destination = GetAllocationSlice(*sort, keys_shape_index);
|
||||
auto values_destination = GetAllocationSlice(*sort, values_shape_index);
|
||||
|
||||
// First copy the operand(s) to the output, so that we can sort in-place.
|
||||
// TODO(b/26783907): Share buffer of output and operand when it is possible.
|
||||
if (sort->operand(0)->IsConstant()) {
|
||||
thunks.push_back(MakeUnique<HostToDeviceCopyThunk>(
|
||||
/*source_address=*/sort->operand(0)->literal().untyped_data(),
|
||||
/*destination_buffer=*/keys_destination,
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(0)->shape()),
|
||||
nullptr));
|
||||
} else {
|
||||
if (keys_destination != GetAllocationSlice(*keys)) {
|
||||
thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
|
||||
/*source_address=*/GetAllocationSlice(*sort->operand(0)),
|
||||
/*source_address=*/GetAllocationSlice(*keys),
|
||||
/*destination_buffer=*/keys_destination,
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(0)->shape()),
|
||||
nullptr));
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr));
|
||||
}
|
||||
if (values != nullptr) {
|
||||
if (values->IsConstant()) {
|
||||
thunks.push_back(MakeUnique<HostToDeviceCopyThunk>(
|
||||
/*source_address=*/sort->operand(1)->literal().untyped_data(),
|
||||
/*destination_buffer=*/GetAllocationSlice(*sort, values_shape_index),
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(1)->shape()),
|
||||
nullptr));
|
||||
} else {
|
||||
thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
|
||||
/*source_address=*/GetAllocationSlice(*sort->operand(1)),
|
||||
/*destination_buffer=*/GetAllocationSlice(*sort, values_shape_index),
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(1)->shape()),
|
||||
nullptr));
|
||||
}
|
||||
if (values != nullptr && values_destination != GetAllocationSlice(*values)) {
|
||||
// TODO(b/26783907): Figure out why we never seem to share buffers for
|
||||
// key/value sort.
|
||||
thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
|
||||
/*source_address=*/GetAllocationSlice(*values),
|
||||
/*destination_buffer=*/values_destination,
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr));
|
||||
}
|
||||
|
||||
int64 dimension_to_sort = sort->dimensions(0);
|
||||
int64 dimension_to_sort_bound =
|
||||
sort->operand(0)->shape().dimensions(dimension_to_sort);
|
||||
int64 dimension_to_sort_bound = keys->shape().dimensions(dimension_to_sort);
|
||||
int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
|
||||
auto index_type = b_.getInt64Ty();
|
||||
|
||||
@ -2134,7 +2119,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
|
||||
thunks.push_back(
|
||||
BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
|
||||
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
|
||||
sort->operand(0)->shape(), ir_emitter_context_->device_description());
|
||||
keys->shape(), ir_emitter_context_->device_description());
|
||||
UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
|
||||
ir_emitter_context_->llvm_module());
|
||||
|
||||
|
@ -1084,6 +1084,21 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
||||
std::vector<int64> operand_indices = user->OperandIndices(operand);
|
||||
return operand_indices.size() == 1 && operand_indices[0] == 0;
|
||||
}
|
||||
if (user->opcode() == HloOpcode::kSort) {
|
||||
// Only valid if there are no other users.
|
||||
if (operand->users().size() != 1) {
|
||||
return false;
|
||||
}
|
||||
// If we only sort keys, the output of sort is not a tuple, so we can always
|
||||
// share the buffer.
|
||||
if (user->operand_count() == 1) {
|
||||
return true;
|
||||
}
|
||||
CHECK(!user_index.empty());
|
||||
// Only share with the right tuple element buffer.
|
||||
std::vector<int64> operand_indices = user->OperandIndices(operand);
|
||||
return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
|
||||
}
|
||||
if (user->opcode() == HloOpcode::kCall) {
|
||||
// Get all uses of value defined by 'operand' at 'operand_index'.
|
||||
const auto& uses = GetValueDefinedAt(operand, operand_index).uses();
|
||||
|
@ -2232,6 +2232,48 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
|
||||
dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {}));
|
||||
}
|
||||
|
||||
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
||||
Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
|
||||
auto keys = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, keys_shape, "keys"));
|
||||
auto sort =
|
||||
builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys));
|
||||
|
||||
BuildModuleAndRunAnalysis(builder.Build());
|
||||
|
||||
EXPECT_TRUE(
|
||||
dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {}));
|
||||
}
|
||||
|
||||
TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
||||
Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
|
||||
Shape values_shape = ShapeUtil::MakeShape(F32, {8});
|
||||
auto keys = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, keys_shape, "keys"));
|
||||
auto values = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, values_shape, "values"));
|
||||
auto sort = builder.AddInstruction(HloInstruction::CreateSort(
|
||||
ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
|
||||
|
||||
BuildModuleAndRunAnalysis(builder.Build());
|
||||
|
||||
// The buffer for the keys can be shared with the first tuple entry.
|
||||
EXPECT_TRUE(
|
||||
dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0}));
|
||||
// The buffer for the values can be shared with the second tuple entry.
|
||||
EXPECT_TRUE(
|
||||
dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {1}));
|
||||
// Verify that the buffers are not shared with the "wrong" tuple entry.
|
||||
EXPECT_FALSE(
|
||||
dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1}));
|
||||
EXPECT_FALSE(
|
||||
dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {0}));
|
||||
}
|
||||
|
||||
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
|
@ -718,6 +718,7 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
|
||||
// root at operand 0 or 1. Or...
|
||||
// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index
|
||||
// 0.
|
||||
// (5) The 'user' of 'operand' is Sort, and it is the only user.
|
||||
//
|
||||
// (2) and (3) can only be determined if points-to analysis is available.
|
||||
bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
|
||||
@ -783,6 +784,21 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
|
||||
std::vector<int64> operand_indices = user->OperandIndices(operand);
|
||||
return operand_indices.size() == 1 && operand_indices[0] == 0;
|
||||
}
|
||||
if (user->opcode() == HloOpcode::kSort) {
|
||||
// Only valid if there are no other users.
|
||||
if (operand->users().size() != 1) {
|
||||
return false;
|
||||
}
|
||||
// If we only sort keys, the output of sort is not a tuple, so we can always
|
||||
// share the buffer.
|
||||
if (user->operand_count() == 1) {
|
||||
return true;
|
||||
}
|
||||
CHECK(!user_index.empty());
|
||||
// Only share with the right tuple element buffer.
|
||||
std::vector<int64> operand_indices = user->OperandIndices(operand);
|
||||
return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
|
||||
}
|
||||
if (user->opcode() == HloOpcode::kCall) {
|
||||
// TODO(b/62548313): Remove when buffer assignment is module scoped and
|
||||
// does not assign buffers to calls.
|
||||
|
@ -1012,6 +1012,48 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
|
||||
points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {}));
|
||||
}
|
||||
|
||||
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
||||
Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
|
||||
auto keys = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, keys_shape, "keys"));
|
||||
auto sort =
|
||||
builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys));
|
||||
|
||||
BuildModuleAndRunAnalysis(builder.Build());
|
||||
|
||||
EXPECT_TRUE(
|
||||
points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {}));
|
||||
}
|
||||
|
||||
TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
||||
Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
|
||||
Shape values_shape = ShapeUtil::MakeShape(F32, {8});
|
||||
auto keys = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, keys_shape, "keys"));
|
||||
auto values = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, values_shape, "values"));
|
||||
auto sort = builder.AddInstruction(HloInstruction::CreateSort(
|
||||
ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
|
||||
|
||||
BuildModuleAndRunAnalysis(builder.Build());
|
||||
|
||||
// The buffer for the keys can be shared with the first tuple entry.
|
||||
EXPECT_TRUE(
|
||||
points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0}));
|
||||
// The buffer for the values can be shared with the second tuple entry.
|
||||
EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(values, {},
|
||||
sort, {1}));
|
||||
// Verify that the buffers are not shared with the "wrong" tuple entry.
|
||||
EXPECT_FALSE(
|
||||
points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1}));
|
||||
EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(values, {},
|
||||
sort, {0}));
|
||||
}
|
||||
|
||||
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
|
Loading…
Reference in New Issue
Block a user