Make XLA's kUserAlias work together with XRT's swap/compaction.

The XRT tuple allocation owns the device memory which, in order for the lower level
aliasing to work, needs to be handed out as "owning" within the parameter's shape tree.
But if the parameter's shape tree gets destroyed (due to an intermediate error before
execute), the memory gets released and the tuple allocation is left pointing to freed
memory.
This CL introduces an ExecutionInput data structure which wraps a maybe-owning shape
tree together with the indices which should be released before the shape tree gets
destroyed. This allows the data structure to travel down to the point where the buffers
land inside the ExecutionOutput, which uses similar logic (until the result is finally
consumed).
Unfortunately the device memory data structures have gotten a bit messy, with
Owning, MaybeOwning, ShapedBuffer, ScopedShapedBuffer, ... none of which works nicely
with buffer sharing.
Ideally we should have something like std::shared_ptr<OwningDeviceMemory> and
ShapeTree<std::shared_ptr<OwningDeviceMemory>>, and be done with it.
Unfortunately the change toward that goal (I started down that route first) is pretty
major.

PiperOrigin-RevId: 298498866
Change-Id: I2e27c11b7187fa2992ae3b606ea95c18f312cb5a
Davide Libenzi 2020-03-02 18:35:45 -08:00, committed by TensorFlower Gardener
parent 9cef7b78b0
commit d1085a6e00
14 changed files with 152 additions and 88 deletions
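
To make the intended usage concrete, here is a minimal sketch (not part of the commit
below) of how a caller hands buffers to the new ExecutionInput. The MakeInput helper,
the two-leaf tuple shape, and the buffer variables are hypothetical; only the
ExecutionInput API added in this change is real.

#include "tensorflow/compiler/xla/service/executable.h"

// Hypothetical usage sketch, assuming a tuple-shaped parameter with two leaves.
xla::ExecutionInput MakeInput(const xla::Shape& on_device_shape,
                              se::DeviceMemoryBase borrowed,
                              se::OwningDeviceMemory owned_by_tuple) {
  xla::ExecutionInput input(on_device_shape);
  // Leaf {0}: a plain non-owning buffer; the executable never frees it.
  input.SetBuffer({0}, xla::MaybeOwningDeviceMemory(borrowed));
  // Leaf {1}: handed out as "owning" so the lower level aliasing can donate it,
  // while the real owner remains the XRT tuple allocation. If the
  // ExecutionInput is destroyed before execute consumes the buffers (e.g. on an
  // intermediate error), this index is Release()d instead of deallocated, so
  // the tuple allocation never ends up pointing at freed memory.
  input.SetUnownedBuffer(
      {1}, xla::MaybeOwningDeviceMemory(std::move(owned_by_tuple)));
  return input;
}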


@@ -271,8 +271,7 @@ static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
absl::Span<Shape const* const> argument_host_shapes,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
ExecutableRunOptions run_options) {
std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
if (argument_host_shapes.size() != arguments.size()) {
return InvalidArgument(
"Number of argument host shapes not equal to number of arguments (%d "
@@ -291,8 +290,8 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
shaped_buffer_ptrs.reserve(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
*argument_host_shapes[i], arguments[i], backend_->platform(),
stream->parent()->device_ordinal()));
*argument_host_shapes[i], arguments[i].Buffers(),
backend_->platform(), stream->parent()->device_ordinal()));
shaped_buffer_ptrs.push_back(&shaped_buffers.back());
}


@@ -61,8 +61,7 @@ class LocalExecutable {
// executable.
StatusOr<ExecutionOutput> RunAsync(
absl::Span<Shape const* const> argument_host_shapes,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
ExecutableRunOptions run_options);
std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);
// Return the options used to build the executable.
const ExecutableBuildOptions& build_options() const { return build_options_; }
@@ -76,8 +75,8 @@ class LocalExecutable {
//
// The given ExecutableRunOptions override any values from TF_XLA_FLAGS
// environment variable.
Status ValidateExecutionOptions(
const ExecutableRunOptions& run_options, const Backend& backend);
Status ValidateExecutionOptions(const ExecutableRunOptions& run_options,
const Backend& backend);
// Returns a literal containing the contents of the given ShapedBuffer.
StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);


@@ -78,9 +78,9 @@ CpuExecutable::CpuExecutable(
StatusOr<std::tuple<std::vector<se::DeviceMemoryBase>,
std::vector<se::OwningDeviceMemory>,
std::vector<se::OwningDeviceMemory>>>
CpuExecutable::CreateBufferTable(
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments) {
CpuExecutable::CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
int device_ordinal,
std::vector<ExecutionInput> arguments) {
std::vector<se::DeviceMemoryBase> unowning_buffers(
assignment_->Allocations().size());
std::vector<se::OwningDeviceMemory> owning_buffers(
@@ -95,7 +95,7 @@ CpuExecutable::CreateBufferTable(
if (allocation.is_entry_computation_parameter()) {
unowning_buffers[i] = arguments[allocation.parameter_number()]
.element(allocation.param_shape_index())
.Buffer(allocation.param_shape_index())
.AsDeviceMemoryBase();
CHECK_EQ(allocation.size(), unowning_buffers[i].size())
<< "Size mismatch on param " << allocation.parameter_number()
@@ -139,9 +139,9 @@ CpuExecutable::CreateBufferTable(
VLOG(3) << "result index: " << result_slice.index();
std::vector<se::OwningDeviceMemory> buffers_to_free;
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
auto maybe_owning_buffer = buffer.second.Release();
for (auto& argument : arguments) {
for (auto& index_buffer : *argument.MutableBuffers()) {
auto maybe_owning_buffer = index_buffer.second.Release();
if (maybe_owning_buffer) {
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
}
@@ -284,7 +284,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) {
if (GetRootValueSet().IsAmbiguous()) {
return Unimplemented("Points-to set of root instruction is ambiguous");
@@ -297,7 +297,7 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
for (int64 i = 0; i < entry_comp->num_parameters(); ++i) {
const Shape& expected_shape =
entry_comp->parameter_instruction(i)->shape();
const Shape& actual_shape = arguments[i].shape();
const Shape& actual_shape = arguments[i].Buffers().shape();
CHECK(
Shape::Equal().IgnoreDynamicDimension()(expected_shape, actual_shape))
<< absl::StreamFormat(
@@ -355,9 +355,7 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
std::make_shared<std::vector<se::OwningDeviceMemory>>(
std::move(owning_buffers)),
hlo_execution_profile});
return ExecutionOutput(std::move(result), std::move(buffers_to_release), {},
se::OwningDeviceMemory());
return ExecutionOutput(std::move(result), std::move(buffers_to_release));
}
/*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {


@@ -57,7 +57,7 @@ class CpuExecutable : public Executable {
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) override;
// This should be called after set_ir_module_string.
@@ -103,8 +103,7 @@ class CpuExecutable : public Executable {
std::vector<se::OwningDeviceMemory>,
std::vector<se::OwningDeviceMemory>>>
CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
int device_ordinal,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments);
int device_ordinal, std::vector<ExecutionInput> arguments);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.


@@ -44,15 +44,13 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
return result;
}
static ShapeTree<MaybeOwningDeviceMemory> MakeMaybeOwningDeviceMemoryTree(
static ExecutionInput MakeMaybeOwningDeviceMemoryTree(
const ShapedBuffer& shaped_buffer) {
ShapeTree<MaybeOwningDeviceMemory> result(shaped_buffer.on_device_shape());
auto in_it = shaped_buffer.buffers().begin();
auto out_it = result.begin();
for (; in_it != shaped_buffer.buffers().end(); ++in_it, ++out_it) {
DCHECK(out_it != result.end());
out_it->second = MaybeOwningDeviceMemory(in_it->second);
}
ExecutionInput result(shaped_buffer.on_device_shape());
shaped_buffer.buffers().ForEachElement(
[&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) {
result.SetBuffer(index, MaybeOwningDeviceMemory(mem));
});
return result;
}
@@ -60,7 +58,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
std::vector<ShapeTree<MaybeOwningDeviceMemory>> args(arguments.size());
std::vector<ExecutionInput> args(arguments.size());
auto out_it = args.begin();
for (const ShapedBuffer* arg : arguments) {
*out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg);
@@ -73,7 +71,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStream(
StatusOr<ExecutionOutput> Executable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) {
StatusOr<ExecutionOutput> result = ExecuteAsyncOnStream(
run_options, std::move(arguments), hlo_execution_profile);
@@ -238,7 +236,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStreamWrapper(
StatusOr<ExecutionOutput> Executable::ExecuteAsyncOnStreamWrapper(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments) {
std::vector<ExecutionInput> arguments) {
auto state = ExecuteWrapperBeforeExecution(*this, run_options);
StatusOr<ExecutionOutput> return_value = ExecuteAsyncOnStream(
run_options, std::move(arguments), state.profile_ptr.get());


@@ -42,18 +42,75 @@ limitations under the License.
namespace xla {
// TODO(b/150633678): Both the ExecutionInput and ExecutionOutput need to be
// revisited, with the execute APIs taking a data structure which can better model
// shareable buffers.
class ExecutionInput {
public:
ExecutionInput() = default;
explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {}
explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers)
: buffers_(std::move(buffers)) {}
ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
std::vector<ShapeIndex> owner_held_indices)
: buffers_(std::move(buffers)),
unowned_indices_(std::move(owner_held_indices)) {}
ExecutionInput(ExecutionInput&&) = default;
~ExecutionInput() {
for (auto& index : unowned_indices_) {
auto buffer = buffers_.mutable_element(index)->Release();
if (buffer) {
buffer->Release();
}
}
}
ExecutionInput& operator=(ExecutionInput&&) = default;
const Shape& shape() const { return buffers_.shape(); }
void SetBuffer(const ShapeIndex& index, MaybeOwningDeviceMemory buffer) {
*buffers_.mutable_element(index) = std::move(buffer);
}
void SetUnownedBuffer(const ShapeIndex& index,
MaybeOwningDeviceMemory buffer) {
*buffers_.mutable_element(index) = std::move(buffer);
unowned_indices_.push_back(index);
}
const ShapeTree<MaybeOwningDeviceMemory>& Buffers() const { return buffers_; }
ShapeTree<MaybeOwningDeviceMemory>* MutableBuffers() { return &buffers_; }
MaybeOwningDeviceMemory* MutableBuffer(const ShapeIndex& index) {
return buffers_.mutable_element(index);
}
const MaybeOwningDeviceMemory& Buffer(const ShapeIndex& index) const {
return buffers_.element(index);
}
private:
ShapeTree<MaybeOwningDeviceMemory> buffers_;
std::vector<ShapeIndex> unowned_indices_;
};
// ExecutionOutput encapsulates the output buffers of an execution and the
// leftover buffers to be released by the caller.
class ExecutionOutput {
public:
explicit ExecutionOutput(ScopedShapedBuffer result)
: result_(std::move(result)) {}
ExecutionOutput(ScopedShapedBuffer result,
std::vector<se::OwningDeviceMemory> to_be_released,
std::vector<ShapeIndex> aliased_indices,
se::OwningDeviceMemory output_shape_table)
std::vector<se::OwningDeviceMemory> to_be_released)
: result_(std::move(result)),
to_be_released_(std::move(to_be_released)),
aliased_indices_(std::move(aliased_indices)),
output_shape_table_(std::move(output_shape_table)) {}
to_be_released_(std::move(to_be_released)) {}
ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
se::DeviceMemoryAllocator* allocator, int device_ordinal)
: result_(std::move(on_host_shape), std::move(on_device_shape), allocator,
device_ordinal) {}
ExecutionOutput(ExecutionOutput&&) = default;
ExecutionOutput& operator=(ExecutionOutput&&) = default;
@@ -66,6 +123,18 @@ class ExecutionOutput {
}
}
void AddAliasedIndex(ShapeIndex index) {
aliased_indices_.push_back(std::move(index));
}
void AddToBeReleased(se::OwningDeviceMemory mem) {
to_be_released_.push_back(std::move(mem));
}
void SetOutputShapeTable(se::OwningDeviceMemory output_shape_table) {
output_shape_table_ = std::move(output_shape_table);
}
// Should be called once it is known that the execute operation succeeded,
// before returning the ExecutionOutput to the caller.
ExecutionOutput& Commit() {
@@ -75,6 +144,8 @@ class ExecutionOutput {
const ScopedShapedBuffer& Result() const { return result_; }
ScopedShapedBuffer* MutableResult() { return &result_; }
const se::OwningDeviceMemory& ShapeTable() const {
return output_shape_table_;
}
@@ -169,12 +240,12 @@ class Executable {
// complete.
StatusOr<ExecutionOutput> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile);
virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) = 0;
// Same as ExecuteOnStream(), but runs this executable on multiple
@@ -208,7 +279,7 @@ class Executable {
StatusOr<ExecutionOutput> ExecuteAsyncOnStreamWrapper(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments);
std::vector<ExecutionInput> arguments);
const HloProfilePrinterData& hlo_profile_printer_data() const {
CHECK(hlo_profiling_enabled());
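
The new mutators on ExecutionOutput let a backend accumulate aliasing and release
state while it runs, with Commit() called only once the execute operation is known to
have succeeded. Below is a rough sketch of the expected call pattern; the PackageResult
function and the assumption that output index {0} aliases a donated input are
hypothetical, not part of this commit.

#include "tensorflow/compiler/xla/service/executable.h"

// Hypothetical packaging step at the end of a backend's ExecuteAsyncOnStream.
xla::StatusOr<xla::ExecutionOutput> PackageResult(
    xla::ScopedShapedBuffer result, std::vector<xla::ExecutionInput> arguments) {
  xla::ExecutionOutput out(std::move(result));
  // Record output indices which alias donated input buffers (assumed here to
  // be just {0}), so they can be protected until the caller takes ownership.
  out.AddAliasedIndex({0});
  // Move input buffers whose ownership was transferred to the execution into
  // the output, to be freed once the result has been consumed.
  for (xla::ExecutionInput& argument : arguments) {
    for (auto& index_buffer : *argument.MutableBuffers()) {
      if (auto owned = index_buffer.second.Release()) {
        out.AddToBeReleased(std::move(*owned));
      }
    }
  }
  // Only on the success path, right before handing the output to the caller.
  out.Commit();
  return std::move(out);
}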


@@ -328,7 +328,7 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) {
XLA_SCOPED_LOGGING_TIMER(absl::StrCat("GpuExecutable::ExecuteAsyncOnStream(",
module().name(), ")"));
@@ -367,7 +367,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
auto param_no = allocation.parameter_number();
se::DeviceMemoryBase buffer =
arguments[param_no]
.element(allocation.param_shape_index())
.Buffer(allocation.param_shape_index())
.AsDeviceMemoryBase();
// All top-level buffers and sub-buffers must have an explicit, non-null
@@ -458,16 +458,15 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result));
std::vector<se::OwningDeviceMemory> buffers_to_free;
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
auto maybe_owning_buffer = buffer.second.Release();
for (auto& argument : arguments) {
for (auto& index_buffer : *argument.MutableBuffers()) {
auto maybe_owning_buffer = index_buffer.second.Release();
if (maybe_owning_buffer) {
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
}
}
}
return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free),
{}, {});
return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free));
}
const InstructionValueSet& GpuExecutable::GetRootValueSet() const {


@@ -84,7 +84,7 @@ class GpuExecutable : public Executable {
// doesn't match the compute capability passed to this object's constructor.
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) override;
std::shared_ptr<const BufferAssignment> GetBufferAssignment() const {


@@ -55,7 +55,7 @@ InterpreterExecutable::~InterpreterExecutable() {}
StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) {
se::Stream* stream = run_options->stream();
se::StreamExecutor* executor = stream->parent();
@@ -65,14 +65,14 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
// TransferManager methods below.
std::vector<ShapedBuffer> argument_buffers;
argument_buffers.reserve(arguments.size());
for (const ShapeTree<MaybeOwningDeviceMemory>& arg : arguments) {
argument_buffers.push_back(
ShapedBuffer(arg.shape(), arg.shape(),
/*platform=*/platform,
/*device_ordinal=*/executor->device_ordinal()));
auto in_it = arg.begin();
for (auto& argument : arguments) {
const ShapeTree<MaybeOwningDeviceMemory>& buffers = argument.Buffers();
argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(),
/*platform=*/nullptr,
/*device_ordinal=*/0));
auto in_it = buffers.begin();
auto out_it = argument_buffers.back().buffers().begin();
for (; in_it != arg.end(); ++in_it, ++out_it) {
for (; in_it != buffers.end(); ++in_it, ++out_it) {
out_it->second = in_it->second.AsDeviceMemoryBase();
}
}
@@ -128,12 +128,13 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
}
// Transform the result literal back into a ShapedBuffer.
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers,
transfer_manager->AllocateScopedShapedBuffer(
result_literal.shape(), run_options->allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
run_options->stream(), result_literal, result));
run_options->stream(), result_literal, result_buffers));
ExecutionOutput result(std::move(result_buffers));
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -142,17 +143,15 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
const double nanoseconds = (end_micros - start_micros) * 1000.0;
profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
}
std::vector<se::OwningDeviceMemory> buffers_to_free;
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
auto maybe_owning_buffer = buffer.second.Release();
for (auto& argument : arguments) {
for (auto& index_buffer : *argument.MutableBuffers()) {
auto maybe_owning_buffer = index_buffer.second.Release();
if (maybe_owning_buffer) {
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
result.AddToBeReleased(std::move(*maybe_owning_buffer));
}
}
}
return ExecutionOutput(std::move(result), std::move(buffers_to_free), {}, {});
return std::move(result);
}
/*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {


@@ -50,7 +50,7 @@ class InterpreterExecutable : public Executable {
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) override
LOCKS_EXCLUDED(evaluator_lock_);


@@ -96,8 +96,8 @@ class BufferDonationTest : public HloTestBase {
memory_allocator.get());
});
std::vector<ShapeTree<MaybeOwningDeviceMemory>> args;
args.emplace_back(std::move(owned_buffers));
std::vector<ExecutionInput> args;
args.emplace_back(ExecutionInput(std::move(owned_buffers)));
TF_ASSERT_OK_AND_ASSIGN(
ExecutionOutput output,


@@ -72,6 +72,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",

View File

@@ -632,7 +632,7 @@ Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
*buffers_.mutable_element(dest_index) = source_buffer;
source_buffer->Ref();
if (dest_buffer != nullptr) {
// If we handed over the ownership of a buffer in ToDeviceMemoryTree(), we
// If we handed over the ownership of a buffer in ToExecutionInput(), we
// will be called here on the way back from execution, to alias back the
// buffer at that index. In that case the buffers will be the same. So we
// need to discard the memory at the destination buffer, before releasing
@@ -646,11 +646,10 @@
return Status::OK();
}
xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
XRTTupleAllocation::ToDeviceMemoryTree(
xla::StatusOr<xla::ExecutionInput> XRTTupleAllocation::ToExecutionInput(
const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
release_checker) {
xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
alias_checker) {
xla::ExecutionInput result(on_device_shape());
for (const auto& index_buffer : buffers_) {
if (index_buffer.second == nullptr ||
index_buffer.second->allocation().is_null()) {
@@ -658,18 +657,20 @@ XRTTupleAllocation::ToDeviceMemoryTree(
index_buffer.first.ToString(),
" has been released");
}
TF_ASSIGN_OR_RETURN(bool should_release,
release_checker(index_buffer.first));
if (!should_release) {
*shaped_tree.mutable_element(index_buffer.first) =
index_buffer.second->allocation();
TF_ASSIGN_OR_RETURN(bool should_alias, alias_checker(index_buffer.first));
if (!should_alias) {
result.SetBuffer(
index_buffer.first,
xla::MaybeOwningDeviceMemory(index_buffer.second->allocation()));
} else {
// We keep the ownership of the device memory here.
*shaped_tree.mutable_element(index_buffer.first) = se::OwningDeviceMemory(
index_buffer.second->allocation(), device_ordinal_, allocator_);
result.SetUnownedBuffer(
index_buffer.first,
xla::MaybeOwningDeviceMemory(se::OwningDeviceMemory(
index_buffer.second->allocation(), device_ordinal_, allocator_)));
}
}
return std::move(shaped_tree);
return std::move(result);
}
} // namespace tensorflow
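
Putting the XRT pieces together, the round trip looks roughly like the sketch below.
The RunWithUserAliasing function, the InputIsUserAliased() predicate, and the
surrounding plumbing are placeholders; only ToExecutionInput(), ExecuteOnStream(), and
AliasBufferFrom() come from this commit.

#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xrt/xrt_state.h"

// Hypothetical XRT execute path. On an error before ExecuteOnStream() runs,
// ~ExecutionInput releases (rather than frees) the indices handed out as
// owning, so the tuple allocation stays valid.
xla::Status RunWithUserAliasing(
    XRTTupleAllocation* tuple, xla::Executable* executable,
    const xla::ServiceExecutableRunOptions* run_options) {
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionInput input,
      tuple->ToExecutionInput([](const xla::ShapeIndex& index)
                                  -> xla::StatusOr<bool> {
        return InputIsUserAliased(index);  // true for kUserAlias parameters
      }));
  std::vector<xla::ExecutionInput> arguments;
  arguments.push_back(std::move(input));
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      executable->ExecuteOnStream(run_options, std::move(arguments),
                                  /*hlo_execution_profile=*/nullptr));
  // On the way back, AliasBufferFrom() re-aliases each donated buffer into
  // the result tuple; it recognizes buffers which made the round trip and
  // discards the stale destination reference instead of double freeing.
  return xla::Status::OK();
}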


@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -214,7 +215,7 @@ class XRTTupleAllocation : public core::RefCounted {
const xla::ShapeIndex& source_index,
const xla::ShapeIndex& dest_index);
// Returns the device memory tree of this allocation. If the release_checker
// Returns the device memory tree of this allocation. If the alias_checker
// function returns true for a given index, an owned device memory is returned
// to the caller. But the tuple allocation cannot release the ownership in
// full, as the execute operation might fail. So we rely on a call to
@@ -227,10 +228,9 @@ class XRTTupleAllocation : public core::RefCounted {
// introduce a sharing concept (IOW shared_ptr model vs. unique_ptr).
// We'd need something similar to XRTTupleAllocation instead of
// ScopedShapedBuffer, which wants ownership and does not allow sharing.
xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
ToDeviceMemoryTree(
xla::StatusOr<xla::ExecutionInput> ToExecutionInput(
const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
release_checker);
alias_checker);
private:
// Creates a new handle with (tuple) shape.