[XLA/CPU] Support buffer aliasing on XLA:CPU
PiperOrigin-RevId: 315633188 Change-Id: Id403065962b3151ebb6c741fcf9ddf4523490cde
This commit is contained in:
parent
0ef962b1a5
commit
584a042d35
@ -207,7 +207,8 @@ Status CpuExecutable::ExecuteComputeFunction(
|
||||
|
||||
StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<MaybeOwningDeviceMemory> buffers) {
|
||||
absl::Span<MaybeOwningDeviceMemory> buffers,
|
||||
absl::Span<ExecutionInput> arguments) {
|
||||
se::Stream* stream = run_options->stream();
|
||||
ExecutionOutput result(/*on_host_shape=*/result_shape(),
|
||||
/*on_device_shape=*/result_shape(),
|
||||
@ -221,7 +222,7 @@ StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
|
||||
// caller.
|
||||
for (auto& p : result.MutableResult()->buffers()) {
|
||||
const ShapeIndex& index = p.first;
|
||||
se::DeviceMemoryBase& device_memory = p.second;
|
||||
se::DeviceMemoryBase& result_buffer = p.second;
|
||||
const HloValueSet& sources = this->GetRootValueSet().element(index);
|
||||
// The points-to set is unambiguous so the set should be a
|
||||
// singleton.
|
||||
@ -229,39 +230,54 @@ StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
|
||||
const HloValue* value_source = sources.values()[0];
|
||||
HloInstruction* src = value_source->instruction();
|
||||
|
||||
// The source for this result buffer can be a nested buffer such as
|
||||
// a tuple element. The source instruction should have a
|
||||
// non-parameter buffer assigned.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
const BufferAllocation::Slice slice,
|
||||
this->assignment_->GetUniqueSlice(src, value_source->index()));
|
||||
const BufferAllocation::Index buffer_index = slice.index();
|
||||
MaybeOwningDeviceMemory& buffer = buffers[buffer_index];
|
||||
if (!slice.allocation()->is_entry_computation_parameter()) {
|
||||
// If the buffer coming out of the result is from a parameter, the
|
||||
// owning buffer will be null, and that means the caller aliased some
|
||||
// parameter buffer to an output one (via the
|
||||
// HloInputOutputAliasConfig API). If that is the case, the caller
|
||||
// will receive a partially complete scoped shaped buffer, which they
|
||||
// will have to fill up on return. Unfortunately the interface to the
|
||||
// execute APIs are ShapedBuffer pointer based, which assumes caller
|
||||
// ownership, and hence a buffer coming from there cannot be part of
|
||||
// the new ScopedShapedBuffer we create for the result (which assumes
|
||||
// ownership).
|
||||
absl::optional<se::OwningDeviceMemory> owned_buffer = buffer.Release();
|
||||
CHECK(owned_buffer);
|
||||
device_memory = owned_buffer->Release();
|
||||
buffer = device_memory;
|
||||
} else {
|
||||
auto output_alias = input_output_alias.GetAliasedOutput(
|
||||
slice.allocation()->parameter_number(),
|
||||
slice.allocation()->param_shape_index());
|
||||
CHECK(output_alias) << "Output buffer is coming from parameter "
|
||||
<< slice.allocation()->parameter_number()
|
||||
<< " at index "
|
||||
<< slice.allocation()->param_shape_index()
|
||||
<< ", but no alias exists";
|
||||
CHECK_EQ(*output_alias, index);
|
||||
// TODO(cheshire): duplication with other backends.
|
||||
absl::optional<HloInputOutputAliasConfig::Alias> alias =
|
||||
input_output_alias.GetAliasedParameter(index);
|
||||
if (alias) {
|
||||
CHECK_LT(alias->parameter_number, arguments.size());
|
||||
ExecutionInput& input = arguments[alias->parameter_number];
|
||||
MaybeOwningDeviceMemory* maybe_owning_memory =
|
||||
input.MutableBuffer(alias->parameter_index);
|
||||
if (absl::optional<se::OwningDeviceMemory> owning =
|
||||
maybe_owning_memory->Release()) {
|
||||
// If the caller passes the ownership of the device memory, reuse it
|
||||
// as the output buffer. It is up to the caller whether or not to
|
||||
// donate a buffer; the aliasing information describes which buffers
|
||||
// may alias, not buffers that must alias.
|
||||
se::DeviceMemoryBase argument_buffer = owning->Release();
|
||||
*maybe_owning_memory = argument_buffer;
|
||||
result_buffer = argument_buffer;
|
||||
if (alias->kind == HloInputOutputAliasConfig::kUserAlias) {
|
||||
// This is a user alias, so a must alias. The caller is giving us the
|
||||
// input buffer, but in case of error of the execute call, we should
|
||||
// not be releasing it as it contains valid data (for example, it is a
|
||||
// parameter which the user wants us to alias, in a gradient update
|
||||
// computation). So we store the index into the result in the aliased
|
||||
// vector, which will be fed to the ExecutionOutput, which will be
|
||||
// using the indices to drop the addresses from its own
|
||||
// ScopedShapedBuffer result, if the ExecutionOutput is not committed.
|
||||
result.AddAliasedIndex(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (result_buffer.is_null()) {
|
||||
// The source for this result buffer can be a nested buffer such as
|
||||
// a tuple element. The source instruction should have a
|
||||
// non-parameter buffer assigned.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
const BufferAllocation::Slice slice,
|
||||
this->assignment_->GetUniqueSlice(src, value_source->index()));
|
||||
const BufferAllocation::Index buffer_index = slice.index();
|
||||
MaybeOwningDeviceMemory& buffer = buffers[buffer_index];
|
||||
if (absl::optional<se::OwningDeviceMemory> owned_buffer =
|
||||
buffer.Release()) {
|
||||
result_buffer = owned_buffer->Release();
|
||||
buffer = result_buffer;
|
||||
} else {
|
||||
result_buffer = buffer.AsDeviceMemoryBase();
|
||||
result.AddAliasedIndex(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::move(result);
|
||||
@ -303,7 +319,8 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
ExecutionOutput result,
|
||||
CreateResultShapedBuffer(run_options, absl::MakeSpan(buffers)));
|
||||
CreateResultShapedBuffer(run_options, absl::MakeSpan(buffers),
|
||||
absl::MakeSpan(arguments)));
|
||||
|
||||
// Logically we want this lambda to capture `buffers` by move, ultimately our
|
||||
// functor needs to be wrapped in an std::function, and that requires its
|
||||
|
@ -118,7 +118,8 @@ class CpuExecutable : public Executable {
|
||||
// assignment.
|
||||
StatusOr<ExecutionOutput> CreateResultShapedBuffer(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<MaybeOwningDeviceMemory> buffers);
|
||||
absl::Span<MaybeOwningDeviceMemory> buffers,
|
||||
absl::Span<ExecutionInput> arguments);
|
||||
|
||||
// Returns the instruction value set of the root instruction of the entry
|
||||
// computation. Uses dataflow analysis from buffer assignment.
|
||||
|
@ -216,11 +216,8 @@ TEST_F(BufferDonationTest, SimpleWhileTupleTest) {
|
||||
HloInstruction::CreateGetTupleElement(f32v1_, while0, 1));
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
// Input output aliasing is supported on CPU and GPU.
|
||||
#if defined(XLA_TEST_BACKEND_TPU) || defined(XLA_TEST_BACKEND_GPU)
|
||||
TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
|
||||
TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));
|
||||
#endif
|
||||
|
||||
auto arg = LiteralUtil::MakeTupleFromSlices(
|
||||
{LiteralUtil::CreateR0<int>(0), LiteralUtil::CreateR1<float>({1.1f})});
|
||||
|
Loading…
Reference in New Issue
Block a user