From d1085a6e0062ecbd211a939af0bc7a7c769dbe70 Mon Sep 17 00:00:00 2001
From: Davide Libenzi
Date: Mon, 2 Mar 2020 18:35:45 -0800
Subject: [PATCH] Make XLA's kUserAlias work together with XRT's swap/compaction.

The XRT tuple allocation owns the device memory, which, in order for the
lower level aliasing to work, needs to be handed out as "owning" within the
parameter's shape tree. But if the parameter's shape tree gets destroyed
(due to an intermediate error before execute), the memory is released and
the tuple allocation is left pointing to freed memory.

This CL introduces an ExecutionInput data structure which wraps a
maybe-owning shape tree together with the indices which should be released
before the shape tree gets destroyed. This allows the data structure to
travel down to the point where the buffers land inside the ExecutionOutput,
which uses similar logic (until the result is finally consumed).

Unfortunately the situation of the device memory data structures got a bit
messy, with Owning, MaybeOwning, ShapedBuffer, ScopedShapedBuffer, ... none
of which can work nicely with buffer sharing. Ideally we should have
something like std::shared_ptr<se::OwningDeviceMemory> and
ShapeTree<std::shared_ptr<se::OwningDeviceMemory>> and be done with it.
Unfortunately the change towards that goal (I started down that route first)
is pretty major.

PiperOrigin-RevId: 298498866
Change-Id: I2e27c11b7187fa2992ae3b606ea95c18f312cb5a
---
 .../compiler/xla/client/local_client.cc       |  7 +-
 tensorflow/compiler/xla/client/local_client.h |  7 +-
 .../xla/service/cpu/cpu_executable.cc         | 22 +++--
 .../compiler/xla/service/cpu/cpu_executable.h |  5 +-
 tensorflow/compiler/xla/service/executable.cc | 20 ++---
 tensorflow/compiler/xla/service/executable.h  | 89 +++++++++++++++++--
 .../xla/service/gpu/gpu_executable.cc         | 13 ++-
 .../compiler/xla/service/gpu/gpu_executable.h |  2 +-
 .../xla/service/interpreter/executable.cc     | 33 ++++---
 .../xla/service/interpreter/executable.h      |  2 +-
 .../xla/tests/buffer_donation_test.cc         |  4 +-
 tensorflow/compiler/xrt/BUILD                 |  1 +
 tensorflow/compiler/xrt/xrt_state.cc          | 27 +++---
 tensorflow/compiler/xrt/xrt_state.h           |  8 +-
 14 files changed, 152 insertions(+), 88 deletions(-)

diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 7b29e9c4e90..df070d97ff7 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -271,8 +271,7 @@ static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
 
 StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
     absl::Span<const Shape* const> argument_host_shapes,
-    std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
-    ExecutableRunOptions run_options) {
+    std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
   if (argument_host_shapes.size() != arguments.size()) {
     return InvalidArgument(
         "Number of argument host shapes not equal to number of arguments (%d "
@@ -291,8 +290,8 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
   shaped_buffer_ptrs.reserve(arguments.size());
   for (size_t i = 0; i < arguments.size(); ++i) {
     shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
-        *argument_host_shapes[i], arguments[i], backend_->platform(),
-        stream->parent()->device_ordinal()));
+        *argument_host_shapes[i], arguments[i].Buffers(),
+        backend_->platform(), stream->parent()->device_ordinal()));
     shaped_buffer_ptrs.push_back(&shaped_buffers.back());
   }
 
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 3f9ed37b05f..7cdeb9dcbf6 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -61,8 +61,7 @@ class LocalExecutable {
   // executable.
   StatusOr<ScopedShapedBuffer> RunAsync(
       absl::Span<const Shape* const> argument_host_shapes,
-      std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
-      ExecutableRunOptions run_options);
+      std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);
 
   // Return the options used to build the executable.
   const ExecutableBuildOptions& build_options() const { return build_options_; }
 
@@ -76,8 +75,8 @@ class LocalExecutable {
   //
   // The given ExecutableRunOptions override any values from TF_XLA_FLAGS
   // environment variable.
-  Status ValidateExecutionOptions(
-      const ExecutableRunOptions& run_options, const Backend& backend);
+  Status ValidateExecutionOptions(const ExecutableRunOptions& run_options,
+                                  const Backend& backend);
 
   // Returns a literal containing the contents of the given ShapedBuffer.
   StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 4deae02ad2c..8c1ae0179c0 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -78,9 +78,9 @@ CpuExecutable::CpuExecutable(
 
 StatusOr<std::tuple<std::vector<se::DeviceMemoryBase>,
                     std::vector<se::OwningDeviceMemory>,
                     std::vector<se::OwningDeviceMemory>>>
-CpuExecutable::CreateBufferTable(
-    se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
-    std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments) {
+CpuExecutable::CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
+                                 int device_ordinal,
+                                 std::vector<ExecutionInput> arguments) {
   std::vector<se::DeviceMemoryBase> unowning_buffers(
       assignment_->Allocations().size());
   std::vector<se::OwningDeviceMemory> owning_buffers(
@@ -95,7 +95,7 @@ CpuExecutable::CreateBufferTable(
     if (allocation.is_entry_computation_parameter()) {
       unowning_buffers[i] = arguments[allocation.parameter_number()]
-                                .element(allocation.param_shape_index())
+                                .Buffer(allocation.param_shape_index())
                                 .AsDeviceMemoryBase();
       CHECK_EQ(allocation.size(), unowning_buffers[i].size())
           << "Size mismatch on param " << allocation.parameter_number()
@@ -139,9 +139,9 @@ CpuExecutable::CreateBufferTable(
   VLOG(3) << "result index: " << result_slice.index();
 
   std::vector<se::OwningDeviceMemory> buffers_to_free;
-  for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
-    for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
-      auto maybe_owning_buffer = buffer.second.Release();
+  for (auto& argument : arguments) {
+    for (auto& index_buffer : *argument.MutableBuffers()) {
+      auto maybe_owning_buffer = index_buffer.second.Release();
       if (maybe_owning_buffer) {
         buffers_to_free.push_back(std::move(*maybe_owning_buffer));
       }
@@ -284,7 +284,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
 
 StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
-    std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
+    std::vector<ExecutionInput> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   if (GetRootValueSet().IsAmbiguous()) {
     return Unimplemented("Points-to set of root instruction is ambiguous");
@@ -297,7 +297,7 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
   for (int64 i = 0; i < entry_comp->num_parameters(); ++i) {
     const Shape& expected_shape =
         entry_comp->parameter_instruction(i)->shape();
-    const Shape& actual_shape = arguments[i].shape();
+    const Shape& actual_shape = arguments[i].Buffers().shape();
     CHECK(
         Shape::Equal().IgnoreDynamicDimension()(expected_shape, actual_shape))
         << absl::StreamFormat(
@@ -355,9 +355,7 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
       std::make_shared<std::vector<se::OwningDeviceMemory>>(
          std::move(owning_buffers)),
       hlo_execution_profile});
-
-  return ExecutionOutput(std::move(result), std::move(buffers_to_release), {},
-                         se::OwningDeviceMemory());
+  return ExecutionOutput(std::move(result), std::move(buffers_to_release));
 }
 
 /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 6f8a7c3315a..4ec688c1016 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -57,7 +57,7 @@ class CpuExecutable : public Executable {
 
   StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
+      std::vector<ExecutionInput> arguments,
       HloExecutionProfile* hlo_execution_profile) override;
 
   // This should be called after set_ir_module_string.
@@ -103,8 +103,7 @@ class CpuExecutable : public Executable {
                       std::vector<se::OwningDeviceMemory>,
                       std::vector<se::OwningDeviceMemory>>>
   CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
-                    int device_ordinal,
-                    std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments);
+                    int device_ordinal, std::vector<ExecutionInput> arguments);
 
   // Calls the generated function performing the computation with the given
   // arguments using the supplied buffers.
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index 60fc7d50a36..f41c4b77cd1 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -44,15 +44,13 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
   return result;
 }
 
-static ShapeTree<MaybeOwningDeviceMemory> MakeMaybeOwningDeviceMemoryTree(
+static ExecutionInput MakeMaybeOwningDeviceMemoryTree(
     const ShapedBuffer& shaped_buffer) {
-  ShapeTree<MaybeOwningDeviceMemory> result(shaped_buffer.on_device_shape());
-  auto in_it = shaped_buffer.buffers().begin();
-  auto out_it = result.begin();
-  for (; in_it != shaped_buffer.buffers().end(); ++in_it, ++out_it) {
-    DCHECK(out_it != result.end());
-    out_it->second = MaybeOwningDeviceMemory(in_it->second);
-  }
+  ExecutionInput result(shaped_buffer.on_device_shape());
+  shaped_buffer.buffers().ForEachElement(
+      [&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) {
+        result.SetBuffer(index, MaybeOwningDeviceMemory(mem));
+      });
   return result;
 }
 
@@ -60,7 +58,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
     absl::Span<const ShapedBuffer* const> arguments,
     HloExecutionProfile* hlo_execution_profile) {
-  std::vector<ShapeTree<MaybeOwningDeviceMemory>> args(arguments.size());
+  std::vector<ExecutionInput> args(arguments.size());
   auto out_it = args.begin();
   for (const ShapedBuffer* arg : arguments) {
     *out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg);
@@ -73,7 +71,7 @@ StatusOr<ExecutionOutput> Executable::ExecuteOnStream(
     const ServiceExecutableRunOptions* run_options,
-    std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
+    std::vector<ExecutionInput> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   StatusOr<ExecutionOutput> result = ExecuteAsyncOnStream(
       run_options, std::move(arguments), hlo_execution_profile);
@@ -238,7 +236,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStreamWrapper(
 
 StatusOr<ExecutionOutput> Executable::ExecuteAsyncOnStreamWrapper(
     const ServiceExecutableRunOptions* run_options,
-    std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments) {
+    std::vector<ExecutionInput> arguments) {
   auto state = ExecuteWrapperBeforeExecution(*this, run_options);
   StatusOr<ExecutionOutput> return_value = ExecuteAsyncOnStream(
       run_options, std::move(arguments), state.profile_ptr.get());
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 1156a9f4ae9..4859759eba5 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -42,18 +42,75 @@ limitations under the License.
 
 namespace xla {
 
+// TODO(b/150633678): Both the ExecutionInput and ExecutionOutput need to be
+// revisited, with the execute APIs taking a data structure which can better
+// model shareable buffers.
+class ExecutionInput {
+ public:
+  ExecutionInput() = default;
+  explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {}
+  explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers)
+      : buffers_(std::move(buffers)) {}
+  ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
+                 std::vector<ShapeIndex> owner_held_indices)
+      : buffers_(std::move(buffers)),
+        unowned_indices_(std::move(owner_held_indices)) {}
+  ExecutionInput(ExecutionInput&&) = default;
+
+  ~ExecutionInput() {
+    for (auto& index : unowned_indices_) {
+      auto buffer = buffers_.mutable_element(index)->Release();
+      if (buffer) {
+        buffer->Release();
+      }
+    }
+  }
+
+  ExecutionInput& operator=(ExecutionInput&&) = default;
+
+  const Shape& shape() const { return buffers_.shape(); }
+
+  void SetBuffer(const ShapeIndex& index, MaybeOwningDeviceMemory buffer) {
+    *buffers_.mutable_element(index) = std::move(buffer);
+  }
+
+  void SetUnownedBuffer(const ShapeIndex& index,
+                        MaybeOwningDeviceMemory buffer) {
+    *buffers_.mutable_element(index) = std::move(buffer);
+    unowned_indices_.push_back(index);
+  }
+
+  const ShapeTree<MaybeOwningDeviceMemory>& Buffers() const { return buffers_; }
+
+  ShapeTree<MaybeOwningDeviceMemory>* MutableBuffers() { return &buffers_; }
+
+  MaybeOwningDeviceMemory* MutableBuffer(const ShapeIndex& index) {
+    return buffers_.mutable_element(index);
+  }
+
+  const MaybeOwningDeviceMemory& Buffer(const ShapeIndex& index) const {
+    return buffers_.element(index);
+  }
+
+ private:
+  ShapeTree<MaybeOwningDeviceMemory> buffers_;
+  std::vector<ShapeIndex> unowned_indices_;
+};
+
 // ExecutionOutput encapsulates the output buffers of an execution and the
 // leftover buffers to be released by the caller.
 class ExecutionOutput {
  public:
+  explicit ExecutionOutput(ScopedShapedBuffer result)
+      : result_(std::move(result)) {}
   ExecutionOutput(ScopedShapedBuffer result,
-                  std::vector<se::OwningDeviceMemory> to_be_released,
-                  std::vector<ShapeIndex> aliased_indices,
-                  se::OwningDeviceMemory output_shape_table)
+                  std::vector<se::OwningDeviceMemory> to_be_released)
       : result_(std::move(result)),
-        to_be_released_(std::move(to_be_released)),
-        aliased_indices_(std::move(aliased_indices)),
-        output_shape_table_(std::move(output_shape_table)) {}
+        to_be_released_(std::move(to_be_released)) {}
+  ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
+                  se::DeviceMemoryAllocator* allocator, int device_ordinal)
+      : result_(std::move(on_host_shape), std::move(on_device_shape), allocator,
+                device_ordinal) {}
   ExecutionOutput(ExecutionOutput&&) = default;
   ExecutionOutput& operator=(ExecutionOutput&&) = default;
 
@@ -66,6 +123,18 @@ class ExecutionOutput {
     }
   }
 
+  void AddAliasedIndex(ShapeIndex index) {
+    aliased_indices_.push_back(std::move(index));
+  }
+
+  void AddToBeReleased(se::OwningDeviceMemory mem) {
+    to_be_released_.push_back(std::move(mem));
+  }
+
+  void SetOutputShapeTable(se::OwningDeviceMemory output_shape_table) {
+    output_shape_table_ = std::move(output_shape_table);
+  }
+
   // Should be called once it is known that the execute operation succeeded,
   // before returning the ExecutionOutput to the caller.
   ExecutionOutput& Commit() {
@@ -75,6 +144,8 @@ class ExecutionOutput {
 
   const ScopedShapedBuffer& Result() const { return result_; }
 
+  ScopedShapedBuffer* MutableResult() { return &result_; }
+
   const se::OwningDeviceMemory& ShapeTable() const {
     return output_shape_table_;
   }
@@ -169,12 +240,12 @@ class Executable {
   // complete.
   StatusOr<ExecutionOutput> ExecuteOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
+      std::vector<ExecutionInput> arguments,
       HloExecutionProfile* hlo_execution_profile);
 
   virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
+      std::vector<ExecutionInput> arguments,
       HloExecutionProfile* hlo_execution_profile) = 0;
 
   // Same as ExecuteOnStream(), but runs this executable on multiple
@@ -208,7 +279,7 @@ class Executable {
 
   StatusOr<ExecutionOutput> ExecuteAsyncOnStreamWrapper(
       const ServiceExecutableRunOptions* run_options,
-      std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments);
+      std::vector<ExecutionInput> arguments);
 
   const HloProfilePrinterData& hlo_profile_printer_data() const {
     CHECK(hlo_profiling_enabled());
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index d4797e094fd..1f601712038 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -328,7 +328,7 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
 
 StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
-    std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
+    std::vector<ExecutionInput> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   XLA_SCOPED_LOGGING_TIMER(absl::StrCat("GpuExecutable::ExecuteAsyncOnStream(",
                                         module().name(), ")"));
@@ -367,7 +367,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
       auto param_no = allocation.parameter_number();
       se::DeviceMemoryBase buffer =
           arguments[param_no]
-              .element(allocation.param_shape_index())
+              .Buffer(allocation.param_shape_index())
              .AsDeviceMemoryBase();
 
       // All top-level buffers and sub-buffers must have an explicit, non-null
@@ -458,16 +458,15 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
   TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result));
 
   std::vector<se::OwningDeviceMemory> buffers_to_free;
-  for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
-    for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
-      auto maybe_owning_buffer = buffer.second.Release();
+  for (auto& argument : arguments) {
+    for (auto& index_buffer : *argument.MutableBuffers()) {
+      auto maybe_owning_buffer = index_buffer.second.Release();
       if (maybe_owning_buffer) {
         buffers_to_free.push_back(std::move(*maybe_owning_buffer));
       }
     }
   }
-  return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free),
-                         {}, {});
+  return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free));
 }
 
 const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index ca1d11b7b7d..3d3afe6168b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -84,7 +84,7 @@ class GpuExecutable : public Executable {
   // doesn't match the compute capability passed to this object's constructor.
   StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
+      std::vector<ExecutionInput> arguments,
       HloExecutionProfile* hlo_execution_profile) override;
 
   std::shared_ptr<const BufferAssignment> GetBufferAssignment() const {
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index f42623a1764..725cb437f8c 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -55,7 +55,7 @@ InterpreterExecutable::~InterpreterExecutable() {}
 
 StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
-    std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
+    std::vector<ExecutionInput> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   se::Stream* stream = run_options->stream();
   se::StreamExecutor* executor = stream->parent();
@@ -65,14 +65,14 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
   // TransferManager methods below.
   std::vector<ShapedBuffer> argument_buffers;
   argument_buffers.reserve(arguments.size());
-  for (const ShapeTree<MaybeOwningDeviceMemory>& arg : arguments) {
-    argument_buffers.push_back(
-        ShapedBuffer(arg.shape(), arg.shape(),
-                     /*platform=*/platform,
-                     /*device_ordinal=*/executor->device_ordinal()));
-    auto in_it = arg.begin();
+  for (auto& argument : arguments) {
+    const ShapeTree<MaybeOwningDeviceMemory>& buffers = argument.Buffers();
+    argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(),
+                                            /*platform=*/nullptr,
+                                            /*device_ordinal=*/0));
+    auto in_it = buffers.begin();
     auto out_it = argument_buffers.back().buffers().begin();
-    for (; in_it != arg.end(); ++in_it, ++out_it) {
+    for (; in_it != buffers.end(); ++in_it, ++out_it) {
       out_it->second = in_it->second.AsDeviceMemoryBase();
     }
   }
@@ -128,12 +128,13 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
   }
 
   // Transform the result literal back into a ShapedBuffer.
-  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
+  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers,
                       transfer_manager->AllocateScopedShapedBuffer(
                           result_literal.shape(), run_options->allocator(),
                           executor->device_ordinal()));
   TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
-      run_options->stream(), result_literal, result));
+      run_options->stream(), result_literal, result_buffers));
+  ExecutionOutput result(std::move(result_buffers));
 
   uint64 end_micros = tensorflow::Env::Default()->NowMicros();
 
@@ -142,17 +143,15 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
     const double nanoseconds = (end_micros - start_micros) * 1000.0;
     profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
   }
-
-  std::vector<se::OwningDeviceMemory> buffers_to_free;
-  for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
-    for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
-      auto maybe_owning_buffer = buffer.second.Release();
+  for (auto& argument : arguments) {
+    for (auto& index_buffer : *argument.MutableBuffers()) {
+      auto maybe_owning_buffer = index_buffer.second.Release();
       if (maybe_owning_buffer) {
-        buffers_to_free.push_back(std::move(*maybe_owning_buffer));
+        result.AddToBeReleased(std::move(*maybe_owning_buffer));
       }
     }
   }
-  return ExecutionOutput(std::move(result), std::move(buffers_to_free), {}, {});
+  return std::move(result);
 }
 
 /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h
index 5df13dfb368..5b2f41a884c 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.h
+++ b/tensorflow/compiler/xla/service/interpreter/executable.h
@@ -50,7 +50,7 @@ class InterpreterExecutable : public Executable {
 
   StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
-      std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
+      std::vector<ExecutionInput> arguments,
       HloExecutionProfile* hlo_execution_profile) override
       LOCKS_EXCLUDED(evaluator_lock_);
 
diff --git a/tensorflow/compiler/xla/tests/buffer_donation_test.cc b/tensorflow/compiler/xla/tests/buffer_donation_test.cc
index 44e958215a6..be76fa74ae2 100644
--- a/tensorflow/compiler/xla/tests/buffer_donation_test.cc
+++ b/tensorflow/compiler/xla/tests/buffer_donation_test.cc
@@ -96,8 +96,8 @@ class BufferDonationTest : public HloTestBase {
                              memory_allocator.get());
         });
 
-    std::vector<ShapeTree<MaybeOwningDeviceMemory>> args;
-    args.emplace_back(std::move(owned_buffers));
+    std::vector<ExecutionInput> args;
+    args.emplace_back(ExecutionInput(std::move(owned_buffers)));
 
     TF_ASSERT_OK_AND_ASSIGN(
         ExecutionOutput output,
diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD
index 93ad08fbfdf..d1445144b76 100644
--- a/tensorflow/compiler/xrt/BUILD
+++ b/tensorflow/compiler/xrt/BUILD
@@ -72,6 +72,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_proto_cc",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service:backend",
+        "//tensorflow/compiler/xla/service:executable",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
index cb8d9a1d4da..a0daa5c6c23 100644
--- a/tensorflow/compiler/xrt/xrt_state.cc
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -632,7 +632,7 @@ Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
   *buffers_.mutable_element(dest_index) = source_buffer;
   source_buffer->Ref();
   if (dest_buffer != nullptr) {
-    // If we handed over the ownership of a buffer in ToDeviceMemoryTree(), we
+    // If we handed over the ownership of a buffer in ToExecutionInput(), we
     // will be called here on the way back from execution, to alias back the
     // buffer at that index. In that case the buffers will be the same. So we
     // need to discard the memory at the destination buffer, before releasing
@@ -646,11 +646,10 @@ Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
   return Status::OK();
 }
 
-xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
-XRTTupleAllocation::ToDeviceMemoryTree(
+xla::StatusOr<xla::ExecutionInput> XRTTupleAllocation::ToExecutionInput(
     const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
-        release_checker) {
-  xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
+        alias_checker) {
+  xla::ExecutionInput result(on_device_shape());
   for (const auto& index_buffer : buffers_) {
     if (index_buffer.second == nullptr ||
         index_buffer.second->allocation().is_null()) {
@@ -658,18 +657,20 @@ XRTTupleAllocation::ToDeviceMemoryTree(
                                      index_buffer.first.ToString(),
                                      " has been released");
     }
-    TF_ASSIGN_OR_RETURN(bool should_release,
-                        release_checker(index_buffer.first));
-    if (!should_release) {
-      *shaped_tree.mutable_element(index_buffer.first) =
-          index_buffer.second->allocation();
+    TF_ASSIGN_OR_RETURN(bool should_alias, alias_checker(index_buffer.first));
+    if (!should_alias) {
+      result.SetBuffer(
+          index_buffer.first,
+          xla::MaybeOwningDeviceMemory(index_buffer.second->allocation()));
     } else {
       // We keep the ownership of the device memory here.
-      *shaped_tree.mutable_element(index_buffer.first) = se::OwningDeviceMemory(
-          index_buffer.second->allocation(), device_ordinal_, allocator_);
+      result.SetUnownedBuffer(
+          index_buffer.first,
+          xla::MaybeOwningDeviceMemory(se::OwningDeviceMemory(
+              index_buffer.second->allocation(), device_ordinal_, allocator_)));
     }
   }
-  return std::move(shaped_tree);
+  return std::move(result);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h
index 8b87a12cfd6..8d706ee6e30 100644
--- a/tensorflow/compiler/xrt/xrt_state.h
+++ b/tensorflow/compiler/xrt/xrt_state.h
@@ -26,6 +26,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/executable.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -214,7 +215,7 @@ class XRTTupleAllocation : public core::RefCounted {
       const xla::ShapeIndex& source_index,
       const xla::ShapeIndex& dest_index);
 
-  // Returns the device memory tree of this allocation. If the release_checker
+  // Returns the device memory tree of this allocation. If the alias_checker
   // function returns true for a given index, an owned device memory is returned
   // to the caller. But the tuple allocation cannot release the ownership in
   // full, as the execute operation might fail. So we rely on a call to
@@ -227,10 +228,9 @@ class XRTTupleAllocation : public core::RefCounted {
   // introduce a sharing concept (IOW shared_ptr model vs. unique_ptr).
   // We'd need something similar to XRTTupleAllocation instead of
   // ScopedShapedBuffer, which wants ownership and does not allow sharing.
-  xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
-  ToDeviceMemoryTree(
+  xla::StatusOr<xla::ExecutionInput> ToExecutionInput(
      const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
-          release_checker);
+          alias_checker);
 
 private:
  // Creates a new handle with (tuple) shape.
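
Editor's note: the sketch below is not part of the patch. It illustrates how a caller is expected to assemble arguments for the new execute API, using only the ExecutionInput/ExecutionOutput interfaces introduced above; the helper name MakeTupleArgument, the two-buffer tuple shape, and the surrounding variables are illustrative assumptions, not code from this CL.

    #include "tensorflow/compiler/xla/service/executable.h"

    // Builds an ExecutionInput for a two-element tuple parameter. The buffer
    // at index {0} is handed out as "owning" so the lower-level kUserAlias
    // logic can reuse it, but SetUnownedBuffer() records the index so that
    // ~ExecutionInput calls Release() on it (forgetting the memory without
    // freeing it) if execution fails before the buffer is consumed. The
    // buffer at index {1} is passed as a plain non-owning reference.
    xla::ExecutionInput MakeTupleArgument(const xla::Shape& on_device_shape,
                                          se::OwningDeviceMemory aliased,
                                          se::DeviceMemoryBase borrowed) {
      xla::ExecutionInput arg(on_device_shape);
      arg.SetUnownedBuffer(xla::ShapeIndex({0}),
                           xla::MaybeOwningDeviceMemory(std::move(aliased)));
      arg.SetBuffer(xla::ShapeIndex({1}),
                    xla::MaybeOwningDeviceMemory(borrowed));
      return arg;
    }

A consumer then moves the arguments into the executable and, mirroring what ToExecutionInput()/AliasBufferFrom() do on the XRT side, only commits the output once the execute is known to have succeeded:

    std::vector<xla::ExecutionInput> args;
    args.push_back(MakeTupleArgument(shape, std::move(owning), borrowed));
    TF_ASSIGN_OR_RETURN(
        xla::ExecutionOutput output,
        executable->ExecuteAsyncOnStream(&service_run_options, std::move(args),
                                         /*hlo_execution_profile=*/nullptr));
    output.Commit();  // Safe point: execution succeeded.
    const xla::ScopedShapedBuffer& result = output.Result();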