From 62683d061cf31d05588a94cc333b53542cea9568 Mon Sep 17 00:00:00 2001
From: George Karpenkov <cheshire@google.com>
Date: Fri, 19 Jun 2020 16:25:08 -0700
Subject: [PATCH] [XLA] Rollback of rollback of "Implement LocalClient::Run
 which supports buffer donation"

PiperOrigin-RevId: 317400695
Change-Id: I56f1f8df347d5a3b2bad9526c7315c63ad6ddadb
---
 .../compiler/xla/client/local_client.cc       | 26 +++++++++++++++++---------
 tensorflow/compiler/xla/client/local_client.h | 21 +++++++++++++++++++++
 .../tests/multiple_devices_on_host_test.cc    |  3 ++-
 tensorflow/compiler/xla/tests/while_test.cc   |  6 ++++--
 4 files changed, 44 insertions(+), 12 deletions(-)

diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index afe115deda8..aa252067e19 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -176,15 +176,23 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
   for (const ShapedBuffer* const arg : arguments) {
     argument_shapes.push_back(&arg->on_host_shape());
   }
-  TF_ASSIGN_OR_RETURN(auto options_and_stream,
-                      RunHelper(argument_shapes, run_options));
-  ExecutableRunOptions options = options_and_stream.first.run_options();
-  options.set_device_ordinal(-1);
-  auto result = RunAsync(arguments, options);
-  Status block_status = options.stream()->BlockHostUntilDone();
-  TF_RETURN_IF_ERROR(result.status());
-  TF_RETURN_IF_ERROR(block_status);
-  return result;
+  return AsyncCallAndBlockHostUntilDone<ScopedShapedBuffer>(
+      argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
+        return RunAsync(arguments, options);
+      });
+}
+
+StatusOr<ExecutionOutput> LocalExecutable::Run(
+    std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
+  std::vector<const Shape*> argument_shapes;
+  argument_shapes.reserve(arguments.size());
+  for (const ExecutionInput& arg : arguments) {
+    argument_shapes.push_back(&arg.shape());
+  }
+  return AsyncCallAndBlockHostUntilDone<ExecutionOutput>(
+      argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
+        return RunAsync(argument_shapes, std::move(arguments), options);
+      });
 }
 
 static std::shared_ptr<HloSnapshot> DumpArguments(
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 7cdeb9dcbf6..3241ac73d54 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -51,6 +51,11 @@ class LocalExecutable {
       const absl::Span<const ShapedBuffer* const> arguments,
       ExecutableRunOptions run_options);
 
+  // Similar to Run(), but allows for donating argument buffers to the
+  // executable.
+  StatusOr<ExecutionOutput> Run(std::vector<ExecutionInput> arguments,
+                                ExecutableRunOptions run_options);
+
   // Similar to Run(), but need not block the host waiting for the computation
   // to complete before returning.
   StatusOr<ScopedShapedBuffer> RunAsync(
@@ -90,6 +95,22 @@ class LocalExecutable {
   // Backend::devices_equivalent).
   int build_device_ordinal() const { return build_options_.device_ordinal(); }
 
+  template <typename T>
+  StatusOr<T> AsyncCallAndBlockHostUntilDone(
+      absl::Span<Shape const* const> argument_shapes,
+      const ExecutableRunOptions& run_options,
+      std::function<StatusOr<T>(const ExecutableRunOptions&)> async_callback) {
+    TF_ASSIGN_OR_RETURN(auto options_and_stream,
+                        RunHelper(argument_shapes, run_options));
+    ExecutableRunOptions options = options_and_stream.first.run_options();
+    options.set_device_ordinal(-1);
+    StatusOr<T> result = async_callback(options);
+    Status block_status = options.stream()->BlockHostUntilDone();
+    TF_RETURN_IF_ERROR(result.status());
+    TF_RETURN_IF_ERROR(block_status);
+    return result;
+  }
+
   // Compiled computation.
   std::unique_ptr<Executable> executable_;
 
diff --git a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc
index 2b19aaded9c..2231fc6feab 100644
--- a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc
+++ b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc
@@ -45,7 +45,8 @@ void CompileAndExecute(
       xla::ClientLibrary::GetXlaService(client->platform())
           ->backend()
           .memory_allocator());
-  StatusOr<ScopedShapedBuffer> result = executable->Run({}, execute_options);
+  StatusOr<ScopedShapedBuffer> result =
+      executable->Run(absl::Span<const ShapedBuffer* const>(), execute_options);
   {
     absl::MutexLock lock(results_mutex);
     results->emplace_back(device_ordinal, std::move(result));
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index d575bbb1f3e..8e8c3605cc7 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -1324,14 +1324,16 @@ void BM_WhileLoop(int num_iters) {
   options.set_allocator(&allocator);
   const int kWarmups = 2;
   for (int i = 0; i < kWarmups; ++i) {
-    auto result = executable->Run({}, options);
+    auto result =
+        executable->Run(absl::Span<const ShapedBuffer* const>(), options);
     ASSERT_TRUE(result.ok());
   }
 
   // Run benchmark.
   tensorflow::testing::StartTiming();
   for (int i = 0; i < num_iters; ++i) {
-    auto result = executable->Run({}, options);
+    auto result =
+        executable->Run(absl::Span<const ShapedBuffer* const>(), options);
     ASSERT_TRUE(result.ok());
   }
 }