[XLA] Rollback of rollback of "Implement LocalClient::Run which supports buffer donation"
PiperOrigin-RevId: 317400695 Change-Id: I56f1f8df347d5a3b2bad9526c7315c63ad6ddadb
This commit is contained in:
parent
a70ad66828
commit
62683d061c
@ -176,15 +176,23 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
|
||||
for (const ShapedBuffer* const arg : arguments) {
|
||||
argument_shapes.push_back(&arg->on_host_shape());
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto options_and_stream,
|
||||
RunHelper(argument_shapes, run_options));
|
||||
ExecutableRunOptions options = options_and_stream.first.run_options();
|
||||
options.set_device_ordinal(-1);
|
||||
auto result = RunAsync(arguments, options);
|
||||
Status block_status = options.stream()->BlockHostUntilDone();
|
||||
TF_RETURN_IF_ERROR(result.status());
|
||||
TF_RETURN_IF_ERROR(block_status);
|
||||
return result;
|
||||
return AsyncCallAndBlockHostUntilDone<xla::ScopedShapedBuffer>(
|
||||
argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
|
||||
return RunAsync(arguments, options);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<ExecutionOutput> LocalExecutable::Run(
|
||||
std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
|
||||
std::vector<const Shape*> argument_shapes;
|
||||
argument_shapes.reserve(arguments.size());
|
||||
for (const ExecutionInput& arg : arguments) {
|
||||
argument_shapes.push_back(&arg.shape());
|
||||
}
|
||||
return AsyncCallAndBlockHostUntilDone<ExecutionOutput>(
|
||||
argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
|
||||
return RunAsync(argument_shapes, std::move(arguments), options);
|
||||
});
|
||||
}
|
||||
|
||||
static std::shared_ptr<HloSnapshot> DumpArguments(
|
||||
|
@ -51,6 +51,11 @@ class LocalExecutable {
|
||||
const absl::Span<const ShapedBuffer* const> arguments,
|
||||
ExecutableRunOptions run_options);
|
||||
|
||||
// Similar to Run(), but allows for donating argument buffers to the
|
||||
// executable.
|
||||
StatusOr<ExecutionOutput> Run(std::vector<ExecutionInput> arguments,
|
||||
ExecutableRunOptions run_options);
|
||||
|
||||
// Similar to Run(), but need not block the host waiting for the computation
|
||||
// to complete before returning.
|
||||
StatusOr<ScopedShapedBuffer> RunAsync(
|
||||
@ -90,6 +95,22 @@ class LocalExecutable {
|
||||
// Backend::devices_equivalent).
|
||||
int build_device_ordinal() const { return build_options_.device_ordinal(); }
|
||||
|
||||
template <typename T>
|
||||
StatusOr<T> AsyncCallAndBlockHostUntilDone(
|
||||
absl::Span<Shape const* const> argument_shapes,
|
||||
const ExecutableRunOptions& run_options,
|
||||
std::function<StatusOr<T>(const ExecutableRunOptions&)> async_callback) {
|
||||
TF_ASSIGN_OR_RETURN(auto options_and_stream,
|
||||
RunHelper(argument_shapes, run_options));
|
||||
ExecutableRunOptions options = options_and_stream.first.run_options();
|
||||
options.set_device_ordinal(-1);
|
||||
StatusOr<T> result = async_callback(options);
|
||||
Status block_status = options.stream()->BlockHostUntilDone();
|
||||
TF_RETURN_IF_ERROR(result.status());
|
||||
TF_RETURN_IF_ERROR(block_status);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Compiled computation.
|
||||
std::unique_ptr<Executable> executable_;
|
||||
|
||||
|
@ -45,7 +45,8 @@ void CompileAndExecute(
|
||||
xla::ClientLibrary::GetXlaService(client->platform())
|
||||
->backend()
|
||||
.memory_allocator());
|
||||
StatusOr<ScopedShapedBuffer> result = executable->Run({}, execute_options);
|
||||
StatusOr<ScopedShapedBuffer> result =
|
||||
executable->Run(absl::Span<const ShapedBuffer* const>(), execute_options);
|
||||
{
|
||||
absl::MutexLock lock(results_mutex);
|
||||
results->emplace_back(device_ordinal, std::move(result));
|
||||
|
@ -1324,14 +1324,16 @@ void BM_WhileLoop(int num_iters) {
|
||||
options.set_allocator(&allocator);
|
||||
const int kWarmups = 2;
|
||||
for (int i = 0; i < kWarmups; ++i) {
|
||||
auto result = executable->Run({}, options);
|
||||
auto result =
|
||||
executable->Run(absl::Span<const ShapedBuffer* const>(), options);
|
||||
ASSERT_TRUE(result.ok());
|
||||
}
|
||||
|
||||
// Run benchmark.
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < num_iters; ++i) {
|
||||
auto result = executable->Run({}, options);
|
||||
auto result =
|
||||
executable->Run(absl::Span<const ShapedBuffer* const>(), options);
|
||||
ASSERT_TRUE(result.ok());
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user