diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 62d3614ab1f..80d4c2bbebf 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -278,7 +278,7 @@ class Executable {
   // If the hlo_execution_profile is provided as non-nullptr, profiling will be
   // enabled. Note that profiling is tricky to use correctly, as the profiling
   // objects (when they exist) must out-live the task.
-  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
+  virtual StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
       const ServiceExecutableRunOptions* run_options,
       absl::Span<const ShapedBuffer* const> arguments,
       HloExecutionProfile* hlo_execution_profile);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index ffd62fe9cb3..1513d84ea01 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -299,7 +299,7 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
 }
 
 StatusOr<se::DeviceMemoryBase> GpuExecutable::BufferForAllocation(
-    absl::Span<ExecutionInput> arguments,
+    VariantArguments arguments,
     const GpuExecutable::BufferAllocToDeviceMemoryMap* globals,
     const BufferAllocation& allocation,
     se::DeviceMemoryAllocator* const memory_allocator, int device_ordinal,
@@ -308,10 +308,17 @@ StatusOr<se::DeviceMemoryBase> GpuExecutable::BufferForAllocation(
     return se::DeviceMemoryBase{};
   } else if (allocation.is_entry_computation_parameter()) {
     int64 param_no = allocation.parameter_number();
-    se::DeviceMemoryBase registered_buffer =
-        arguments[param_no]
+    se::DeviceMemoryBase registered_buffer = [&] {
+      if (auto unowned_shapedbuffers =
+              absl::get_if<absl::Span<const ShapedBuffer* const>>(&arguments)) {
+        return (*unowned_shapedbuffers)[param_no]->buffers().element(
+            allocation.param_shape_index());
+      } else {
+        return absl::get<absl::Span<ExecutionInput>>(arguments)[param_no]
             .Buffer(allocation.param_shape_index())
             .AsDeviceMemoryBase();
+      }
+    }();
     if (registered_buffer.is_null() && registered_buffer.size() > 0) {
       return FailedPrecondition(
           "Cannot run XLA computation because pointer to (sub-)buffer at "
@@ -364,7 +371,7 @@ static Status CheckAlignment(const BufferAllocation& allocation,
 }
 
 StatusOr<BufferAllocations> GpuExecutable::GenerateBufferAllocations(
-    absl::Span<ExecutionInput> arguments,
+    VariantArguments arguments,
     const GpuExecutable::BufferAllocToDeviceMemoryMap* globals,
     se::DeviceMemoryAllocator* const memory_allocator,
     se::StreamExecutor* executor) {
@@ -391,8 +398,25 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
     std::vector<ExecutionInput> arguments,
     HloExecutionProfile* hlo_execution_profile) {
-  XLA_SCOPED_LOGGING_TIMER(
-      absl::StrCat("GpuExecutable::ExecuteAsyncOnStream(", module_name_, ")"));
+  return ExecuteAsyncOnStreamImpl(run_options, absl::MakeSpan(arguments),
+                                  hlo_execution_profile);
+}
+
+StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
+    const ServiceExecutableRunOptions* run_options,
+    absl::Span<const ShapedBuffer* const> arguments,
+    HloExecutionProfile* hlo_execution_profile) {
+  TF_ASSIGN_OR_RETURN(
+      ExecutionOutput out,
+      ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile));
+  return out.ConsumeResult();
+}
+
+StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStreamImpl(
+    const ServiceExecutableRunOptions* run_options, VariantArguments arguments,
+    HloExecutionProfile* hlo_execution_profile) {
+  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
+      "GpuExecutable::ExecuteAsyncOnStreamImpl(", module_name_, ")"));
   se::DeviceMemoryAllocator* const memory_allocator = run_options->allocator();
   // Force synchronous execution if the allocator requires it.
   const bool block_host_until_done =
@@ -443,18 +467,31 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
             << " @ index: " << index.ToString();
 
     if (output_info.alias_config) {
-      ExecutionInput& input = arguments[allocation->parameter_number()];
       MaybeOwningDeviceMemory* maybe_owning_memory =
-          input.MutableBuffer(allocation->param_shape_index());
-      if (output_info.alias_config->must_alias() &&
+          [&]() -> xla::MaybeOwningDeviceMemory* {
+        // ScopedBuffer is never an owned buffer.
+        if (auto* unowned_shapedbuffers =
+                absl::get_if<absl::Span<const ShapedBuffer* const>>(
+                    &arguments)) {
+          return nullptr;
+        } else {
+          auto unowned_execution_input =
+              absl::get<absl::Span<ExecutionInput>>(arguments);
+          ExecutionInput& input =
+              unowned_execution_input[allocation->parameter_number()];
+          return input.MutableBuffer(allocation->param_shape_index());
+        }
+      }();
+      if (output_info.alias_config->must_alias() && maybe_owning_memory &&
           !maybe_owning_memory->HasOwnership()) {
         return InvalidArgument(
             "An input was configured to be must-alias at "
             "compile time but not donated at runtime: allocation %d",
             output_info.allocation_index);
       }
-      if (absl::optional<se::OwningDeviceMemory> owning =
-              maybe_owning_memory->Release()) {
+      if (maybe_owning_memory && maybe_owning_memory->HasOwnership()) {
+        absl::optional<se::OwningDeviceMemory> owning =
+            maybe_owning_memory->Release();
         // If the caller passes the ownership of the device memory, reuse it
         // as the output buffer. It is up to the caller whether or not to
         // donate a buffer; the aliasing information describes which buffers
@@ -520,7 +557,9 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
       buffer_allocations.TearDown(buffers_in_result, allocations_));
 
   // Free allocations for arguments.
-  MarkToBeReleasedArguments(absl::MakeSpan(arguments), result);
+  if (auto args = absl::get_if<absl::Span<ExecutionInput>>(&arguments)) {
+    MarkToBeReleasedArguments(*args, result);
+  }
 
   return std::move(result);
 }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 846e2e88ac3..c18f4e6be03 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -116,6 +116,17 @@ class GpuExecutable : public Executable {
       std::vector<ExecutionInput> arguments,
       HloExecutionProfile* hlo_execution_profile) override;
 
+  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
+      const ServiceExecutableRunOptions* run_options,
+      absl::Span<const ShapedBuffer* const> arguments,
+      HloExecutionProfile* hlo_execution_profile);
+
+  using VariantArguments = absl::variant<absl::Span<const ShapedBuffer* const>,
+                                         absl::Span<ExecutionInput>>;
+  StatusOr<ExecutionOutput> ExecuteAsyncOnStreamImpl(
+      const ServiceExecutableRunOptions* run_options,
+      VariantArguments arguments, HloExecutionProfile* hlo_execution_profile);
+
   absl::Span<const BufferAllocation> GetAllocations() const {
     return allocations_;
   }
@@ -146,13 +157,13 @@ class GpuExecutable : public Executable {
       const ServiceExecutableRunOptions* run_options);
 
   StatusOr<BufferAllocations> GenerateBufferAllocations(
-      absl::Span<ExecutionInput> arguments,
+      VariantArguments arguments,
       const GpuExecutable::BufferAllocToDeviceMemoryMap* globals,
       se::DeviceMemoryAllocator* const memory_allocator,
       se::StreamExecutor* executor);
 
   StatusOr<se::DeviceMemoryBase> BufferForAllocation(
-      absl::Span<ExecutionInput> arguments,
+      VariantArguments arguments,
       const GpuExecutable::BufferAllocToDeviceMemoryMap* globals,
       const BufferAllocation& allocation,
       se::DeviceMemoryAllocator* const memory_allocator, int device_ordinal,
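
Note: the core pattern in this patch is an absl::variant of two span types, unpacked by an
immediately invoked lambda so each alternative gets its own accessor path. The following is a
minimal, self-contained sketch of that pattern only; BorrowedBuffer, OwnedInput, VariantArgs,
and NameForParam are illustrative stand-ins, not XLA types.

    // Minimal sketch of variant-of-spans dispatch, as used by
    // BufferForAllocation above. Compiles standalone against Abseil.
    #include <iostream>
    #include <string>

    #include "absl/types/span.h"
    #include "absl/types/variant.h"

    // Stand-ins for ShapedBuffer (borrowed) and ExecutionInput (owned).
    struct BorrowedBuffer { std::string name; };
    struct OwnedInput { std::string name; };

    using VariantArgs = absl::variant<absl::Span<const BorrowedBuffer* const>,
                                      absl::Span<OwnedInput>>;

    // Mirrors the patch's lambda: absl::get_if tests which alternative the
    // caller passed, and each branch reads the parameter its own way.
    std::string NameForParam(VariantArgs args, int param_no) {
      return [&] {
        if (auto* borrowed =
                absl::get_if<absl::Span<const BorrowedBuffer* const>>(&args)) {
          return (*borrowed)[param_no]->name;
        } else {
          return absl::get<absl::Span<OwnedInput>>(args)[param_no].name;
        }
      }();
    }

    int main() {
      OwnedInput owned[] = {{"owned0"}, {"owned1"}};
      BorrowedBuffer b0{"borrowed0"}, b1{"borrowed1"};
      const BorrowedBuffer* borrowed[] = {&b0, &b1};

      absl::Span<OwnedInput> owned_span = absl::MakeSpan(owned);
      absl::Span<const BorrowedBuffer* const> borrowed_span = borrowed;

      std::cout << NameForParam(owned_span, 1) << "\n";     // prints owned1
      std::cout << NameForParam(borrowed_span, 0) << "\n";  // prints borrowed0
      return 0;
    }

The variant keeps both call paths (ExecutionInput-based and ShapedBuffer-based) funneling into
one ExecuteAsyncOnStreamImpl without duplicating the buffer-resolution logic.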
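A second standalone sketch, for the aliasing change: the patch replaces an unconditional
Release() with a guard that first checks the holder exists (borrowed ShapedBuffer arguments
yield nullptr, since they are never donated) and then that it actually owns memory. Types here
(Memory, MaybeOwning, ReuseIfDonated) are again stand-ins, not XLA's MaybeOwningDeviceMemory.

    #include <iostream>

    #include "absl/types/optional.h"

    struct Memory { int handle = 0; };

    // Stand-in for MaybeOwningDeviceMemory: may or may not own a Memory.
    class MaybeOwning {
     public:
      explicit MaybeOwning(bool owns) : owns_(owns) {}
      bool HasOwnership() const { return owns_; }
      absl::optional<Memory> Release() {
        if (!owns_) return absl::nullopt;
        owns_ = false;
        return Memory{42};
      }

     private:
      bool owns_;
    };

    void ReuseIfDonated(MaybeOwning* maybe_owning) {
      // Mirrors the patched control flow: null-check and ownership check
      // before Release(), so the borrowed-argument path never dereferences
      // a null holder.
      if (maybe_owning && maybe_owning->HasOwnership()) {
        absl::optional<Memory> owning = maybe_owning->Release();
        std::cout << "reusing donated buffer " << owning->handle << "\n";
      } else {
        std::cout << "no donated buffer; allocating output\n";
      }
    }

    int main() {
      MaybeOwning donated(true), retained(false);
      ReuseIfDonated(&donated);   // donated: output reuses the input buffer
      ReuseIfDonated(&retained);  // not donated: fresh allocation
      ReuseIfDonated(nullptr);    // borrowed ShapedBuffer case: holder is null
      return 0;
    }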