[XLA] Refactor Executable::ExecuteAsyncOnStream.

Change implementations of Executable to always implement the overload that takes a std::vector<ShapeTree<MaybeOwningDeviceMemory>>. Make the non-owning version a wrapper around the maybe-owning version.

This is a simplification in preparation for plumbing buffer donation into JAX. This change is also a necessary preparatory step for implementing buffer donation on CPU and GPU.

PiperOrigin-RevId: 283615681
Change-Id: I0d3c65bee506822d23e5827493213e0921b4ef9e
This commit is contained in:
Amit Patankar 2019-12-03 13:55:22 -08:00 committed by TensorFlower Gardener
parent 17f3e8ad39
commit af79ee35f5
13 changed files with 69 additions and 146 deletions

View File

@ -242,16 +242,9 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
"//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:hlo_execution_profile",
"//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/compiler/xla/service:logical_buffer",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:macros",
"//tensorflow/core/platform:mutex",
"//tensorflow/core/platform:platform_port",
"//tensorflow/core/platform:types",
"//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor/host:host_stream", "//tensorflow/stream_executor/host:host_stream",

View File

@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
@ -45,7 +44,6 @@ limitations under the License.
#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/host/host_stream.h" #include "tensorflow/stream_executor/host/host_stream.h"
namespace xla { namespace xla {
@ -75,12 +73,11 @@ CpuExecutable::CpuExecutable(
<< reinterpret_cast<void*>(compute_function_); << reinterpret_cast<void*>(compute_function_);
} }
StatusOr<std::tuple<std::vector<se::DeviceMemoryBase>, StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<se::OwningDeviceMemory>,
std::vector<se::OwningDeviceMemory>>> std::vector<se::OwningDeviceMemory>>>
CpuExecutable::CreateBufferTable( CpuExecutable::CreateBufferTable(
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal, se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments) { absl::Span<const ShapedBuffer* const> arguments) {
std::vector<se::DeviceMemoryBase> unowning_buffers( std::vector<se::DeviceMemoryBase> unowning_buffers(
assignment_->Allocations().size()); assignment_->Allocations().size());
std::vector<se::OwningDeviceMemory> owning_buffers( std::vector<se::OwningDeviceMemory> owning_buffers(
@ -94,9 +91,8 @@ CpuExecutable::CreateBufferTable(
VLOG(3) << allocation.ToString(); VLOG(3) << allocation.ToString();
if (allocation.is_entry_computation_parameter()) { if (allocation.is_entry_computation_parameter()) {
unowning_buffers[i] = arguments[allocation.parameter_number()] unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer(
.element(allocation.param_shape_index()) allocation.param_shape_index());
.AsDeviceMemoryBase();
CHECK_EQ(allocation.size(), unowning_buffers[i].size()) CHECK_EQ(allocation.size(), unowning_buffers[i].size())
<< "Size mismatch on param " << allocation.parameter_number() << "Size mismatch on param " << allocation.parameter_number()
<< " at shape index " << allocation.param_shape_index().ToString(); << " at shape index " << allocation.param_shape_index().ToString();
@ -138,17 +134,7 @@ CpuExecutable::CreateBufferTable(
assignment_->GetUniqueTopLevelOutputSlice()); assignment_->GetUniqueTopLevelOutputSlice());
VLOG(3) << "result index: " << result_slice.index(); VLOG(3) << "result index: " << result_slice.index();
std::vector<se::OwningDeviceMemory> buffers_to_free; return {{std::move(unowning_buffers), std::move(owning_buffers)}};
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
auto maybe_owning_buffer = buffer.second.Release();
if (maybe_owning_buffer) {
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
}
}
}
return {{std::move(unowning_buffers), std::move(owning_buffers),
std::move(buffers_to_free)}};
} }
Status CpuExecutable::ExecuteComputeFunction( Status CpuExecutable::ExecuteComputeFunction(
@ -282,9 +268,9 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
return std::move(result_buffer); return std::move(result_buffer);
} }
StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream( StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments, absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) { HloExecutionProfile* hlo_execution_profile) {
if (GetRootValueSet().IsAmbiguous()) { if (GetRootValueSet().IsAmbiguous()) {
return Unimplemented("Points-to set of root instruction is ambiguous"); return Unimplemented("Points-to set of root instruction is ambiguous");
@ -297,7 +283,7 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
for (int64 i = 0; i < entry_comp->num_parameters(); ++i) { for (int64 i = 0; i < entry_comp->num_parameters(); ++i) {
const Shape& expected_shape = const Shape& expected_shape =
entry_comp->parameter_instruction(i)->shape(); entry_comp->parameter_instruction(i)->shape();
const Shape& actual_shape = arguments[i].shape(); const Shape& actual_shape = arguments[i]->on_device_shape();
CHECK(expected_shape == actual_shape) << absl::StreamFormat( CHECK(expected_shape == actual_shape) << absl::StreamFormat(
"Shape mismatch on argument %d. Expected %s, but was %s.", i, "Shape mismatch on argument %d. Expected %s, but was %s.", i,
expected_shape.ToString(/*print_layout=*/true), expected_shape.ToString(/*print_layout=*/true),
@ -311,11 +297,10 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
se::DeviceMemoryAllocator* memory_allocator = run_options->allocator(); se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
std::vector<se::OwningDeviceMemory> owning_buffers; std::vector<se::OwningDeviceMemory> owning_buffers;
std::vector<se::DeviceMemoryBase> unowning_buffers; std::vector<se::DeviceMemoryBase> unowning_buffers;
std::vector<se::OwningDeviceMemory> buffers_to_release;
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::tie(unowning_buffers, owning_buffers, buffers_to_release), std::tie(unowning_buffers, owning_buffers),
CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(), CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
std::move(arguments))); arguments));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer result, ScopedShapedBuffer result,
@ -354,8 +339,7 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
std::move(owning_buffers)), std::move(owning_buffers)),
hlo_execution_profile}); hlo_execution_profile});
return ExecutionOutput(std::move(result), std::move(buffers_to_release), {}, return std::move(result);
se::OwningDeviceMemory());
} }
/*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {

View File

@ -55,9 +55,9 @@ class CpuExecutable : public Executable {
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map); std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
~CpuExecutable() override {} ~CpuExecutable() override {}
StatusOr<ExecutionOutput> ExecuteAsyncOnStream( StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments, absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) override; HloExecutionProfile* hlo_execution_profile) override;
// This should be called after set_ir_module_string. // This should be called after set_ir_module_string.
@ -96,15 +96,11 @@ class CpuExecutable : public Executable {
// allocated by this routine. This routine allocates buffers for temporary // allocated by this routine. This routine allocates buffers for temporary
// storage and the live-out buffer into which the computation writes its // storage and the live-out buffer into which the computation writes its
// result. // result.
// StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
// - buffers_to_free: buffers whose ownership was donated by the caller that
// are to be freed by the caller.
StatusOr<std::tuple<std::vector<se::DeviceMemoryBase>,
std::vector<se::OwningDeviceMemory>,
std::vector<se::OwningDeviceMemory>>> std::vector<se::OwningDeviceMemory>>>
CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator, CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
int device_ordinal, int device_ordinal,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments); absl::Span<const ShapedBuffer* const> arguments);
// Calls the generated function performing the computation with the given // Calls the generated function performing the computation with the given
// arguments using the supplied buffers. // arguments using the supplied buffers.

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
@ -44,36 +43,9 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
return result; return result;
} }
static ShapeTree<MaybeOwningDeviceMemory> MakeMaybeOwningDeviceMemoryTree(
const ShapedBuffer& shaped_buffer) {
ShapeTree<MaybeOwningDeviceMemory> result(shaped_buffer.on_device_shape());
auto in_it = shaped_buffer.buffers().begin();
auto out_it = result.begin();
for (; in_it != shaped_buffer.buffers().end(); ++in_it, ++out_it) {
DCHECK(out_it != result.end());
out_it->second = MaybeOwningDeviceMemory(in_it->second);
}
return result;
}
StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
std::vector<ShapeTree<MaybeOwningDeviceMemory>> args(arguments.size());
auto out_it = args.begin();
for (const ShapedBuffer* arg : arguments) {
*out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg);
}
TF_ASSIGN_OR_RETURN(ExecutionOutput out,
ExecuteAsyncOnStream(run_options, std::move(args),
hlo_execution_profile));
return out.ConsumeResult();
}
StatusOr<ExecutionOutput> Executable::ExecuteOnStream( StatusOr<ExecutionOutput> Executable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments, std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
HloExecutionProfile* hlo_execution_profile) { HloExecutionProfile* hlo_execution_profile) {
StatusOr<ExecutionOutput> result = ExecuteAsyncOnStream( StatusOr<ExecutionOutput> result = ExecuteAsyncOnStream(
run_options, std::move(arguments), hlo_execution_profile); run_options, std::move(arguments), hlo_execution_profile);
@ -83,6 +55,14 @@ StatusOr<ExecutionOutput> Executable::ExecuteOnStream(
return result; return result;
} }
StatusOr<ExecutionOutput> Executable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* /*run_options*/,
std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> /*arguments*/,
HloExecutionProfile* /*hlo_execution_profile*/) {
return Unimplemented(
"MaybeOwningDeviceMemory version of overload is not implemented ");
}
StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams( StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
absl::Span<const ServiceExecutableRunOptions> run_options, absl::Span<const ServiceExecutableRunOptions> run_options,
absl::Span<const absl::Span<const ShapedBuffer* const>> arguments) { absl::Span<const absl::Span<const ShapedBuffer* const>> arguments) {

View File

@ -160,22 +160,22 @@ class Executable {
// If the hlo_execution_profile is provided as non-nullptr, profiling will be // If the hlo_execution_profile is provided as non-nullptr, profiling will be
// enabled. Note that profiling is tricky to use correctly, as the profiling // enabled. Note that profiling is tricky to use correctly, as the profiling
// objects (when they exist) must out-live the task. // objects (when they exist) must out-live the task.
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream( virtual StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
absl::Span<const ShapedBuffer* const> arguments, absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile); HloExecutionProfile* hlo_execution_profile) = 0;
// Same as ExecuteAsyncOnStream(), but blocks waiting for the computation to // Same as ExecuteAsyncOnStream(), but blocks waiting for the computation to
// complete. // complete.
StatusOr<ExecutionOutput> ExecuteOnStream( StatusOr<ExecutionOutput> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments, std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
HloExecutionProfile* hlo_execution_profile); HloExecutionProfile* hlo_execution_profile);
virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream( virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments, std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
HloExecutionProfile* hlo_execution_profile) = 0; HloExecutionProfile* hlo_execution_profile);
// Same as ExecuteOnStream(), but runs this executable on multiple // Same as ExecuteOnStream(), but runs this executable on multiple
// streams. arguments[i] contains the arguments to the execution on // streams. arguments[i] contains the arguments to the execution on

View File

@ -299,14 +299,11 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
return &module_globals_.emplace(executor, std::move(globals)).first->second; return &module_globals_.emplace(executor, std::move(globals)).first->second;
} }
StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream( StatusOr<ScopedShapedBuffer> GpuExecutable::Execute(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments, absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) { HloExecutionProfile* hlo_execution_profile, bool block_host_until_done) {
se::DeviceMemoryAllocator* const memory_allocator = run_options->allocator(); se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
// Force synchronous execution if the allocator requires it.
const bool block_host_until_done =
!memory_allocator->AllowsAsynchronousDeallocation();
if (GetRootValueSet().IsAmbiguous()) { if (GetRootValueSet().IsAmbiguous()) {
return Unimplemented("Points-to set of root instruction is ambiguous"); return Unimplemented("Points-to set of root instruction is ambiguous");
@ -337,9 +334,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
if (allocation.is_entry_computation_parameter()) { if (allocation.is_entry_computation_parameter()) {
auto param_no = allocation.parameter_number(); auto param_no = allocation.parameter_number();
se::DeviceMemoryBase buffer = se::DeviceMemoryBase buffer =
arguments[param_no] arguments[param_no]->buffer(allocation.param_shape_index());
.element(allocation.param_shape_index())
.AsDeviceMemoryBase();
// All top-level buffers and sub-buffers must have an explicit, non-null // All top-level buffers and sub-buffers must have an explicit, non-null
// pointer, except for zero-sized buffers, which may be null. // pointer, except for zero-sized buffers, which may be null.
@ -428,17 +423,19 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
})); }));
TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result)); TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result));
std::vector<se::OwningDeviceMemory> buffers_to_free; return std::move(shaped_buffer);
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) { }
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
auto maybe_owning_buffer = buffer.second.Release(); StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
if (maybe_owning_buffer) { const ServiceExecutableRunOptions* run_options,
buffers_to_free.push_back(std::move(*maybe_owning_buffer)); absl::Span<const ShapedBuffer* const> arguments,
} HloExecutionProfile* hlo_execution_profile) {
} se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
} // Force synchronous execution if the allocator requires it.
return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free), bool block_host_until_done =
{}, {}); !memory_allocator->AllowsAsynchronousDeallocation();
return Execute(run_options, arguments, hlo_execution_profile,
block_host_until_done);
} }
const InstructionValueSet& GpuExecutable::GetRootValueSet() const { const InstructionValueSet& GpuExecutable::GetRootValueSet() const {

View File

@ -82,9 +82,9 @@ class GpuExecutable : public Executable {
// ExecuteAsyncOnStream will fail if the compute capability of the stream // ExecuteAsyncOnStream will fail if the compute capability of the stream
// doesn't match the compute capability passed to this object's constructor. // doesn't match the compute capability passed to this object's constructor.
StatusOr<ExecutionOutput> ExecuteAsyncOnStream( StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments, absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) override; HloExecutionProfile* hlo_execution_profile) override;
std::shared_ptr<const BufferAssignment> GetBufferAssignment() const { std::shared_ptr<const BufferAssignment> GetBufferAssignment() const {
@ -92,6 +92,11 @@ class GpuExecutable : public Executable {
} }
private: private:
StatusOr<ScopedShapedBuffer> Execute(
const ServiceExecutableRunOptions* run_options,
absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile, bool block_host_until_done);
// If `block_host_until_done` is false, execution will not block the host // If `block_host_until_done` is false, execution will not block the host
// until the kernels have completed. This is used as an optimization for // until the kernels have completed. This is used as an optimization for
// clients, such as Tensorflow, that use a single stream of execution for // clients, such as Tensorflow, that use a single stream of execution for

View File

@ -151,8 +151,7 @@ absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
absl::optional<HloInputOutputAliasConfig::Alias> absl::optional<HloInputOutputAliasConfig::Alias>
HloInputOutputAliasConfig::GetAliasedParameter( HloInputOutputAliasConfig::GetAliasedParameter(
const ShapeIndex& output_index) const { const ShapeIndex& output_index) const {
CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index));
<< ToString() << " " << alias_.shape().ToString() << " " << output_index;
return alias_.element(output_index); return alias_.element(output_index);
} }

View File

@ -89,15 +89,10 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:hlo_evaluator",
"//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:hlo_execution_profile",
"//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:macros",
"//tensorflow/core/platform:mutex",
"//tensorflow/core/platform:types",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
@ -40,39 +39,24 @@ namespace interpreter {
InterpreterExecutable::InterpreterExecutable( InterpreterExecutable::InterpreterExecutable(
std::unique_ptr<HloModule> hlo_module, std::unique_ptr<HloModule> hlo_module,
std::unique_ptr<HloEvaluator> evaluator) std::unique_ptr<HloEvaluator> evaluator)
: Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr, : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr,
/*hlo_profile_index_map=*/nullptr), /*hlo_profile_index_map=*/nullptr),
evaluator_(std::move(evaluator)) {} evaluator_(std::move(evaluator)) {}
InterpreterExecutable::~InterpreterExecutable() {} InterpreterExecutable::~InterpreterExecutable() {}
StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream( StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments, absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) { HloExecutionProfile* hlo_execution_profile) {
se::Stream* stream = run_options->stream(); se::Stream* stream = run_options->stream();
se::StreamExecutor* executor = stream->parent(); se::StreamExecutor* executor = stream->parent();
const se::Platform* platform = executor->platform(); const se::Platform* platform = executor->platform();
// Convert the ShapeTree to a ShapedBuffer. We do this so we can call
// TransferManager methods below.
std::vector<ShapedBuffer> argument_buffers;
argument_buffers.reserve(arguments.size());
for (const ShapeTree<MaybeOwningDeviceMemory>& arg : arguments) {
argument_buffers.push_back(ShapedBuffer(arg.shape(), arg.shape(),
/*platform=*/nullptr,
/*device_ordinal=*/0));
auto in_it = arg.begin();
auto out_it = argument_buffers.back().buffers().begin();
for (; in_it != arg.end(); ++in_it, ++out_it) {
out_it->second = in_it->second.AsDeviceMemoryBase();
}
}
VLOG(1) << "Execute " << module().name(); VLOG(1) << "Execute " << module().name();
if (VLOG_IS_ON(2)) { if (VLOG_IS_ON(2)) {
for (const auto& a : argument_buffers) { for (const auto& a : arguments) {
VLOG(2) << "-- argument " << a; VLOG(2) << "-- argument " << *a;
} }
} }
@ -87,7 +71,7 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
// Check that the args have the right shape. // Check that the args have the right shape.
for (int64 i = 0; i < computation->num_parameters(); ++i) { for (int64 i = 0; i < computation->num_parameters(); ++i) {
const auto& expected_shape = computation->parameter_instruction(i)->shape(); const auto& expected_shape = computation->parameter_instruction(i)->shape();
const auto& actual_shape = argument_buffers[i].on_device_shape(); const auto& actual_shape = arguments[i]->on_device_shape();
if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape, if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
actual_shape)) { actual_shape)) {
return InvalidArgument( return InvalidArgument(
@ -106,7 +90,7 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
for (int64 p = 0; p < computation->num_parameters(); ++p) { for (int64 p = 0; p < computation->num_parameters(); ++p) {
TF_ASSIGN_OR_RETURN(Literal arg_literal, TF_ASSIGN_OR_RETURN(Literal arg_literal,
transfer_manager->TransferLiteralFromDevice( transfer_manager->TransferLiteralFromDevice(
run_options->stream(), argument_buffers[p])); run_options->stream(), *arguments[p]));
arg_literals.push_back(std::move(arg_literal)); arg_literals.push_back(std::move(arg_literal));
} }
@ -135,16 +119,7 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
profile->set_compute_time_ns(std::max(nanoseconds, 1.0)); profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
} }
std::vector<se::OwningDeviceMemory> buffers_to_free; return std::move(result);
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
auto maybe_owning_buffer = buffer.second.Release();
if (maybe_owning_buffer) {
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
}
}
}
return ExecutionOutput(std::move(result), std::move(buffers_to_free), {}, {});
} }
/*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) { /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {

View File

@ -46,9 +46,9 @@ class InterpreterExecutable : public Executable {
std::unique_ptr<HloEvaluator> evaluator); std::unique_ptr<HloEvaluator> evaluator);
~InterpreterExecutable() override; ~InterpreterExecutable() override;
StatusOr<ExecutionOutput> ExecuteAsyncOnStream( StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments, absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) override HloExecutionProfile* hlo_execution_profile) override
LOCKS_EXCLUDED(evaluator_lock_); LOCKS_EXCLUDED(evaluator_lock_);

View File

@ -17,8 +17,7 @@ limitations under the License.
#include "absl/types/variant.h" #include "absl/types/variant.h"
namespace xla { namespace xla {
tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() {
const {
if (HasOwnership()) { if (HasOwnership()) {
return *absl::get<tensorflow::se::OwningDeviceMemory>(mem_); return *absl::get<tensorflow::se::OwningDeviceMemory>(mem_);
} else { } else {

View File

@ -49,7 +49,7 @@ class MaybeOwningDeviceMemory {
// Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The // Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The
// caller of this function is *not* responsible for freeing the memory. // caller of this function is *not* responsible for freeing the memory.
tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase() const; tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase();
// Release the tensorflow::se::OwningDeviceMemory without freeing it, and // Release the tensorflow::se::OwningDeviceMemory without freeing it, and
// moves the ownership of the memory buffer from the object to the caller. // moves the ownership of the memory buffer from the object to the caller.