[XLA/Client] Implement LocalClient::Run which supports buffer donation

PiperOrigin-RevId: 317195199
Change-Id: If4d35d0627fa068a0c2b522fdae52466abd21f51
A. Unique TensorFlower 2020-06-18 15:35:42 -07:00 committed by TensorFlower Gardener
parent 834fe68f36
commit a82b75c82b
4 changed files with 12 additions and 53 deletions
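For orientation (not part of the diff itself), here is a rough usage sketch of the two LocalExecutable::Run entry points involved in this change, as declared in the header diff further down. The non-donating call uses only types visible in this diff; the ExecutionInput construction (MaybeOwningDeviceMemory, SetBuffer, and the owned_buffer variable) is an assumption about the surrounding XLA API of this era, so treat it as an illustrative fragment rather than the exact interface.

// Non-donating overload: the caller keeps ownership of the argument buffers
// and receives the result as a ScopedShapedBuffer.
const xla::ShapedBuffer* argument_ptrs[] = {&argument};
xla::StatusOr<xla::ScopedShapedBuffer> result =
    executable->Run(argument_ptrs, run_options);

// Buffer-donating overload (see the header diff below): wrapping a device
// allocation as an owning MaybeOwningDeviceMemory hands it to the executable,
// which may reuse it for the output. Names here are assumptions, not code
// shown in this diff.
xla::ExecutionInput input(argument.on_host_shape());
input.SetBuffer(/*index=*/{},
                xla::MaybeOwningDeviceMemory(std::move(owned_buffer)));
std::vector<xla::ExecutionInput> donated_args;
donated_args.push_back(std::move(input));
xla::StatusOr<xla::ExecutionOutput> donated_result =
    executable->Run(std::move(donated_args), run_options);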

View File

@@ -168,26 +168,6 @@ LocalExecutable::RunHelper(const absl::Span<const Shape* const> argument_shapes,
   return std::make_pair(service_options, std::move(stream));
 }
 
-StatusOr<ExecutableRunOptions> LocalExecutable::GetExecutableRunOptions(
-    absl::Span<Shape const* const> argument_shapes,
-    const ExecutableRunOptions& run_options) {
-  TF_ASSIGN_OR_RETURN(auto options_and_stream,
-                      RunHelper(argument_shapes, run_options));
-  ExecutableRunOptions options = options_and_stream.first.run_options();
-  options.set_device_ordinal(-1);
-  return options;
-}
-
-template <typename T>
-static StatusOr<T> BlockHostUntilDoneAfterAsyncCall(
-    se::Stream* stream, std::function<StatusOr<T>()> async_callback) {
-  StatusOr<T> result = async_callback();
-  Status block_status = stream->BlockHostUntilDone();
-  TF_RETURN_IF_ERROR(result.status());
-  TF_RETURN_IF_ERROR(block_status);
-  return result;
-}
-
 StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
     const absl::Span<const ShapedBuffer* const> arguments,
     ExecutableRunOptions run_options) {
@@ -196,24 +176,15 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
   for (const ShapedBuffer* const arg : arguments) {
     argument_shapes.push_back(&arg->on_host_shape());
   }
-  TF_ASSIGN_OR_RETURN(ExecutableRunOptions options,
-                      GetExecutableRunOptions(argument_shapes, run_options));
-  return BlockHostUntilDoneAfterAsyncCall<xla::ScopedShapedBuffer>(
-      options.stream(), [&] { return RunAsync(arguments, options); });
-}
-
-StatusOr<ExecutionOutput> LocalExecutable::Run(
-    std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
-  std::vector<const Shape*> argument_shapes;
-  argument_shapes.reserve(arguments.size());
-  for (const ExecutionInput& arg : arguments) {
-    argument_shapes.push_back(&arg.shape());
-  }
-  TF_ASSIGN_OR_RETURN(ExecutableRunOptions options,
-                      GetExecutableRunOptions(argument_shapes, run_options));
-  return BlockHostUntilDoneAfterAsyncCall<ExecutionOutput>(
-      options.stream(),
-      [&] { return RunAsync(argument_shapes, std::move(arguments), options); });
+  TF_ASSIGN_OR_RETURN(auto options_and_stream,
+                      RunHelper(argument_shapes, run_options));
+  ExecutableRunOptions options = options_and_stream.first.run_options();
+  options.set_device_ordinal(-1);
+  auto result = RunAsync(arguments, options);
+  Status block_status = options.stream()->BlockHostUntilDone();
+  TF_RETURN_IF_ERROR(result.status());
+  TF_RETURN_IF_ERROR(block_status);
+  return result;
 }
 
 static std::shared_ptr<HloSnapshot> DumpArguments(
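The inlined Run body above follows a deliberate ordering: kick off the asynchronous execution, always block the host until the stream is done, and only then return an error, giving the async call's status precedence over the blocking status. A minimal, self-contained C++ analogue of that ordering, using a plain std::optional stand-in rather than xla::Status, is sketched below.

#include <functional>
#include <iostream>
#include <optional>
#include <string>

// Stand-in for a Status type: nullopt means OK, a string carries the error.
using Status = std::optional<std::string>;

// Mirrors the pattern in the new LocalExecutable::Run body: run the async
// call, always wait for the stream, then surface the async error before the
// blocking error.
Status RunAndBlock(const std::function<Status()>& async_call,
                   const std::function<Status()>& block_host_until_done) {
  Status result = async_call();
  Status block_status = block_host_until_done();  // always executed
  if (result) return result;                      // async error wins
  if (block_status) return block_status;
  return std::nullopt;                            // OK
}

int main() {
  Status s = RunAndBlock([] { return Status{"launch failed"}; },
                         [] { return Status{}; });
  std::cout << (s ? *s : "ok") << "\n";  // prints "launch failed"
}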

View File

@@ -51,11 +51,6 @@ class LocalExecutable {
       const absl::Span<const ShapedBuffer* const> arguments,
       ExecutableRunOptions run_options);
 
-  // Similar to Run(), but allows for donating argument buffers to the
-  // executable.
-  StatusOr<ExecutionOutput> Run(std::vector<ExecutionInput> arguments,
-                                ExecutableRunOptions run_options);
-
   // Similar to Run(), but need not block the host waiting for the computation
   // to complete before returning.
   StatusOr<ScopedShapedBuffer> RunAsync(
@@ -90,10 +85,6 @@ class LocalExecutable {
       const absl::Span<const Shape* const> argument_shapes,
       ExecutableRunOptions run_options);
 
-  StatusOr<ExecutableRunOptions> GetExecutableRunOptions(
-      absl::Span<Shape const* const> argument_shapes,
-      const ExecutableRunOptions& run_options);
-
   // The ordinal of the device which this executable was compiled for. The
   // executable can run on all equivalent devices (as determined by
   // Backend::devices_equivalent).

View File

@@ -45,8 +45,7 @@ void CompileAndExecute(
       xla::ClientLibrary::GetXlaService(client->platform())
           ->backend()
           .memory_allocator());
-  StatusOr<ScopedShapedBuffer> result =
-      executable->Run(absl::Span<const ShapedBuffer* const>(), execute_options);
+  StatusOr<ScopedShapedBuffer> result = executable->Run({}, execute_options);
   {
     absl::MutexLock lock(results_mutex);
     results->emplace_back(device_ordinal, std::move(result));

View File

@@ -1324,16 +1324,14 @@ void BM_WhileLoop(int num_iters) {
   options.set_allocator(&allocator);
   const int kWarmups = 2;
   for (int i = 0; i < kWarmups; ++i) {
-    auto result =
-        executable->Run(absl::Span<const ShapedBuffer* const>(), options);
+    auto result = executable->Run({}, options);
     ASSERT_TRUE(result.ok());
   }
 
   // Run benchmark.
   tensorflow::testing::StartTiming();
   for (int i = 0; i < num_iters; ++i) {
-    auto result =
-        executable->Run(absl::Span<const ShapedBuffer* const>(), options);
+    auto result = executable->Run({}, options);
     ASSERT_TRUE(result.ok());
   }
 }
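One plausible reading of the two call-site changes above (an inference, not stated in the commit message): when both Run overloads are declared, an empty braced-init-list argument is ambiguous between absl::Span<const ShapedBuffer* const> and std::vector<ExecutionInput>, so callers must spell out the argument type; with a single overload, Run({}, options) binds cleanly. A standalone illustration of that overload-resolution behavior:

// Hypothetical stand-ins, chosen only to demonstrate the ambiguity.
struct SpanLike {};    // stands in for absl::Span<const ShapedBuffer* const>
struct VectorLike {};  // stands in for std::vector<ExecutionInput>

int RunOne(SpanLike, int) { return 1; }

int RunTwo(SpanLike, int) { return 1; }
int RunTwo(VectorLike, int) { return 2; }

int main() {
  RunOne({}, 0);          // OK: only one candidate, {} initializes SpanLike
  // RunTwo({}, 0);       // error: ambiguous, {} converts to both parameter types
  RunTwo(SpanLike{}, 0);  // OK once the argument type is spelled out explicitly
  return 0;
}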