diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index 0ab4a223916..4f210442005 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -124,6 +124,17 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
   return result;
 }
 
+StatusOr<ExecutionOutput> Executable::ExecuteOnStreamWrapper(
+    const ServiceExecutableRunOptions* run_options,
+    std::vector<ExecutionInput> arguments) {
+  StatusOr<ExecutionOutput> result =
+      ExecuteAsyncOnStreamWrapper(run_options, std::move(arguments));
+  Status block_status = run_options->stream()->BlockHostUntilDone();
+  TF_RETURN_IF_ERROR(result.status());
+  TF_RETURN_IF_ERROR(block_status);
+  return result;
+}
+
 struct ExecuteAsyncOnStreamWrapperState {
   ExecutionProfile* profile;
   std::shared_ptr<se::Timer> timer;
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index e6b26b4fdae..49614c1af00 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -276,6 +276,10 @@ class Executable {
       const ServiceExecutableRunOptions* run_options,
       absl::Span<const ShapedBuffer* const> arguments);
 
+  StatusOr<ExecutionOutput> ExecuteOnStreamWrapper(
+      const ServiceExecutableRunOptions* run_options,
+      std::vector<ExecutionInput> arguments);
+
   StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamWrapper(
       const ServiceExecutableRunOptions* run_options,
      absl::Span<const ShapedBuffer* const> arguments);
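
The overload added above is the synchronous counterpart of the existing ExecuteAsyncOnStreamWrapper(..., std::vector<ExecutionInput>): it takes the arguments by value so ownership of aliased buffers can transfer into the computation, blocks on the stream, and surfaces the execution error ahead of the blocking error. A minimal call-site sketch; `run_options`, `executable`, and the BuildInputs() helper are illustrative names assumed here, not part of the patch:

    // Sketch only: `run_options` points at a valid ServiceExecutableRunOptions
    // whose stream is set, and BuildInputs() stands in for whatever produces
    // the ExecutionInput vector (hypothetical helper).
    std::vector<ExecutionInput> inputs = BuildInputs();
    TF_ASSIGN_OR_RETURN(
        ExecutionOutput output,
        executable->ExecuteOnStreamWrapper(&run_options, std::move(inputs)));
    // The result buffers are owned by `output`.
    const ScopedShapedBuffer& result = output.Result();
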
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index daeb5943fda..30a7916c408 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -22,9 +22,11 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/executable.h"
 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/platform/logging.h"
@@ -144,13 +146,13 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
                                      ExecutionProfile* profile) {
   TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
                       TransferLiteralsToDevice(arguments));
-  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
+  TF_ASSIGN_OR_RETURN(ExecutionOutput result,
                       ExecuteWithDeviceBuffers(
                           /*module=*/std::move(module),
                           /*arguments=*/argument_buffers,
                           /*run_hlo_passes=*/run_hlo_passes,
                           /*profile=*/profile));
-  return TransferLiteralFromDevice(result);
+  return TransferLiteralFromDevice(result.Result());
 }
 
 StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
@@ -171,72 +173,60 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
 }
 
 StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
-                                     absl::Span<const Literal* const> arguments,
+                                     absl::Span<const Literal> arguments,
                                      ExecutionProfile* profile) {
   TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
                       TransferLiteralsToDevice(arguments));
-  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
+  TF_ASSIGN_OR_RETURN(ExecutionOutput result,
                       ExecuteWithDeviceBuffers(
                           /*executable=*/executable.get(),
                           /*arguments=*/argument_buffers,
                           /*profile=*/profile));
-  return TransferLiteralFromDevice(result);
+  return TransferLiteralFromDevice(result.Result());
 }
 
-StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
-                                     absl::Span<const Literal> arguments,
-                                     ExecutionProfile* profile) {
-  // Construct a vector of plain pointers for the arguments.
-  std::vector<const Literal*> argument_pointers;
-  argument_pointers.reserve(arguments.size());
-  for (const auto& argument : arguments) {
-    argument_pointers.push_back(&argument);
+// Convert the owning buffer of inputs into a (partially) owning vector of
+// ExecutionInputs, and an owning vector of `OwningDeviceMemory`'s.
+static std::vector<ExecutionInput> ExecutionInputsFromScopedShapedBuffers(
+    absl::Span<ScopedShapedBuffer const> inputs,
+    HloInputOutputAliasConfig alias_config, int device_ordinal,
+    se::DeviceMemoryAllocator* allocator) {
+  std::vector<ExecutionInput> execution_inputs;
+  std::vector<se::OwningDeviceMemory> owned_args;
+
+  for (int param_num = 0; param_num < inputs.size(); param_num++) {
+    const ScopedShapedBuffer& input_buffer = inputs[param_num];
+    ShapeTree<MaybeOwningDeviceMemory> buffer_tree(
+        input_buffer.on_device_shape());
+
+    input_buffer.buffers().ForEachElement(
+        [&](const ShapeIndex& index,
+            const se::DeviceMemoryBase& execution_input_buffer) {
+          if (alias_config.ParameterHasAlias(param_num, index)) {
+            // Store owned.
+            *buffer_tree.mutable_element(index) = se::OwningDeviceMemory{
+                execution_input_buffer, device_ordinal, allocator};
+          } else {
+            // Store unowned.
+            *buffer_tree.mutable_element(index) = execution_input_buffer;
+          }
+        });
+    execution_inputs.emplace_back(std::move(buffer_tree));
   }
-  return Execute(
-      /*module=*/std::move(executable),
-      /*arguments=*/argument_pointers,
-      /*profile=*/profile);
+  return execution_inputs;
 }
 
-StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
+StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
     std::unique_ptr<HloModule> module,
-    absl::Span<const ShapedBuffer* const> arguments, bool run_hlo_passes,
+    absl::Span<const ScopedShapedBuffer> arguments, bool run_hlo_passes,
     ExecutionProfile* profile) {
-  // Get service run options.
-  se::Stream stream(backend().default_stream_executor());
-  stream.Init();
-  ServiceExecutableRunOptions service_run_options =
-      GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
-                                    nullptr, RunId());
-  service_run_options.mutable_run_options()->set_execution_profile(profile);
-
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                       CreateExecutable(std::move(module), run_hlo_passes));
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer retval,
-      executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
-  TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
-  return std::move(retval);
+  return ExecuteWithDeviceBuffers(executable.get(), arguments, profile);
 }
 
-StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
-    std::unique_ptr<HloModule> module,
-    absl::Span<const ScopedShapedBuffer> arguments, bool run_hlo_passes,
-    ExecutionProfile* profile) {
-  std::vector<const ShapedBuffer*> argument_pointers;
-  argument_pointers.reserve(arguments.size());
-  for (const auto& argument : arguments) {
-    argument_pointers.push_back(&argument);
-  }
-  return ExecuteWithDeviceBuffers(
-      /*module=*/std::move(module),
-      /*arguments=*/argument_pointers,
-      /*run_hlo_passes=*/run_hlo_passes,
-      /*profile=*/profile);
-}
-
-StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
-    Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
+StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
+    Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
     ExecutionProfile* profile) {
   // Get service run options.
   se::Stream stream(backend().default_stream_executor());
@@ -246,27 +236,19 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
                                     nullptr, RunId());
   service_run_options.mutable_run_options()->set_execution_profile(profile);
 
+  std::vector<ExecutionInput> execution_arguments =
+      ExecutionInputsFromScopedShapedBuffers(
+          arguments, executable->module().input_output_alias_config(),
+          stream.parent()->device_ordinal(), stream.parent()->GetAllocator());
+
   TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer retval,
-      executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
+      ExecutionOutput retval,
+      executable->ExecuteOnStreamWrapper(&service_run_options,
+                                         std::move(execution_arguments)));
   TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
   return std::move(retval);
 }
 
-StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
-    Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
-    ExecutionProfile* profile) {
-  std::vector<const ShapedBuffer*> argument_pointers;
-  argument_pointers.reserve(arguments.size());
-  for (const auto& argument : arguments) {
-    argument_pointers.push_back(&argument);
-  }
-  return ExecuteWithDeviceBuffers(
-      /*executable=*/std::move(executable),
-      /*arguments=*/argument_pointers,
-      /*profile=*/profile);
-}
-
 StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
     std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
     DeviceAssignment* device_assignment) {
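
ExecutionInputsFromScopedShapedBuffers is where the aliasing contract is honored: a parameter buffer is stored into the ShapeTree<MaybeOwningDeviceMemory> as owning (se::OwningDeviceMemory) only when HloInputOutputAliasConfig::ParameterHasAlias reports the output may reuse that input; every other buffer stays an unowned se::DeviceMemoryBase view, so the caller's ScopedShapedBuffers keep their allocations. A hedged sketch of driving the reworked runner entry point; `runner`, `executable`, and `literals` are assumed names at the call site, not identifiers from the patch:

    // Sketch: execute on preallocated device buffers, then read the result
    // back to the host. All three input names are assumptions.
    TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> buffers,
                        runner.TransferLiteralsToDevice(literals));
    TF_ASSIGN_OR_RETURN(ExecutionOutput out,
                        runner.ExecuteWithDeviceBuffers(executable, buffers));
    TF_ASSIGN_OR_RETURN(Literal result,
                        runner.TransferLiteralFromDevice(out.Result()));
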
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 3b5a80ce33b..7e8b301ab54 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -134,35 +134,19 @@ class HloRunner {
                            bool run_hlo_passes = true,
                            ExecutionProfile* profile = nullptr);
 
-  StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
-                            absl::Span<const Literal* const> arguments,
-                            ExecutionProfile* profile = nullptr);
-
   StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
                             absl::Span<const Literal> arguments,
                             ExecutionProfile* profile = nullptr);
 
   // As Execute(), but accepts and returns device buffers instead of host
   // buffers.
-  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
+  StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
       std::unique_ptr<HloModule> module,
-      absl::Span<const ShapedBuffer* const> arguments,
+      absl::Span<const ScopedShapedBuffer> arguments,
       bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
 
-  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      std::unique_ptr<HloModule> module,
-      absl::Span<const ScopedShapedBuffer> arguments,
-      bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
-
-  // In the following two calls, "executable" is not a unique_ptr to allow
-  // reuse of the Executable. This call may update the profile information in
-  // *executable.
-  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
-      ExecutionProfile* profile = nullptr);
-
-  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
+  StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
+      Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
       ExecutionProfile* profile = nullptr);
 
   // Creates an executable object given an HLO module. If run_hlo_passes is
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 7b64be5597b..d0b6e5f80ed 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -421,9 +421,6 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
         std::move(module_or_status.ValueOrDie());
     fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
-    absl::c_transform(
-        fake_arguments[i], std::back_inserter(fake_argument_ptrs[i]),
-        [](const Literal& literal) { return const_cast<Literal*>(&literal); });
 
     if (profiles != nullptr) {
       // We have to enable HLO profiling since otherwise currently the
@@ -457,7 +454,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
     absl::optional<Literal> canonical_output;
     for (int i = 0; i < n; ++i) {
       StatusOr<Literal> output =
-          test_runner_.Execute(std::move(executables[i]), fake_argument_ptrs[i],
+          test_runner_.Execute(std::move(executables[i]), fake_arguments[i],
                                /*profile=*/&((*profiles)[i]));
       if (!output.ok()) {
         return ::testing::AssertionFailure() << output.status().error_message();
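
With the pointer-vector overloads gone, test code can hand the std::vector<Literal> of fake arguments straight to the runner, since a vector binds implicitly to the absl::Span<const Literal> parameter. A sketch under assumed names (`module`, `executable`, `test_runner_` stand in for the test fixture's members):

    // Sketch: no const_cast pointer vector is needed anymore; the Literal
    // vector converts directly to absl::Span<const Literal>.
    std::vector<Literal> args =
        MakeFakeArguments(module.get()).ConsumeValueOrDie();
    TF_ASSIGN_OR_RETURN(Literal result,
                        test_runner_.Execute(std::move(executable), args));
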