[XLA/Client] Implement LocalClient::Run which supports buffer donation
PiperOrigin-RevId: 317195199 Change-Id: If4d35d0627fa068a0c2b522fdae52466abd21f51
This commit is contained in:
parent
834fe68f36
commit
a82b75c82b
@ -168,26 +168,6 @@ LocalExecutable::RunHelper(const absl::Span<const Shape* const> argument_shapes,
|
|||||||
return std::make_pair(service_options, std::move(stream));
|
return std::make_pair(service_options, std::move(stream));
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<ExecutableRunOptions> LocalExecutable::GetExecutableRunOptions(
|
|
||||||
absl::Span<Shape const* const> argument_shapes,
|
|
||||||
const ExecutableRunOptions& run_options) {
|
|
||||||
TF_ASSIGN_OR_RETURN(auto options_and_stream,
|
|
||||||
RunHelper(argument_shapes, run_options));
|
|
||||||
ExecutableRunOptions options = options_and_stream.first.run_options();
|
|
||||||
options.set_device_ordinal(-1);
|
|
||||||
return options;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static StatusOr<T> BlockHostUntilDoneAfterAsyncCall(
|
|
||||||
se::Stream* stream, std::function<StatusOr<T>()> async_callback) {
|
|
||||||
StatusOr<T> result = async_callback();
|
|
||||||
Status block_status = stream->BlockHostUntilDone();
|
|
||||||
TF_RETURN_IF_ERROR(result.status());
|
|
||||||
TF_RETURN_IF_ERROR(block_status);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
|
StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
|
||||||
const absl::Span<const ShapedBuffer* const> arguments,
|
const absl::Span<const ShapedBuffer* const> arguments,
|
||||||
ExecutableRunOptions run_options) {
|
ExecutableRunOptions run_options) {
|
||||||
@ -196,24 +176,15 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
|
|||||||
for (const ShapedBuffer* const arg : arguments) {
|
for (const ShapedBuffer* const arg : arguments) {
|
||||||
argument_shapes.push_back(&arg->on_host_shape());
|
argument_shapes.push_back(&arg->on_host_shape());
|
||||||
}
|
}
|
||||||
TF_ASSIGN_OR_RETURN(ExecutableRunOptions options,
|
TF_ASSIGN_OR_RETURN(auto options_and_stream,
|
||||||
GetExecutableRunOptions(argument_shapes, run_options));
|
RunHelper(argument_shapes, run_options));
|
||||||
return BlockHostUntilDoneAfterAsyncCall<xla::ScopedShapedBuffer>(
|
ExecutableRunOptions options = options_and_stream.first.run_options();
|
||||||
options.stream(), [&] { return RunAsync(arguments, options); });
|
options.set_device_ordinal(-1);
|
||||||
}
|
auto result = RunAsync(arguments, options);
|
||||||
|
Status block_status = options.stream()->BlockHostUntilDone();
|
||||||
StatusOr<ExecutionOutput> LocalExecutable::Run(
|
TF_RETURN_IF_ERROR(result.status());
|
||||||
std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
|
TF_RETURN_IF_ERROR(block_status);
|
||||||
std::vector<const Shape*> argument_shapes;
|
return result;
|
||||||
argument_shapes.reserve(arguments.size());
|
|
||||||
for (const ExecutionInput& arg : arguments) {
|
|
||||||
argument_shapes.push_back(&arg.shape());
|
|
||||||
}
|
|
||||||
TF_ASSIGN_OR_RETURN(ExecutableRunOptions options,
|
|
||||||
GetExecutableRunOptions(argument_shapes, run_options));
|
|
||||||
return BlockHostUntilDoneAfterAsyncCall<ExecutionOutput>(
|
|
||||||
options.stream(),
|
|
||||||
[&] { return RunAsync(argument_shapes, std::move(arguments), options); });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::shared_ptr<HloSnapshot> DumpArguments(
|
static std::shared_ptr<HloSnapshot> DumpArguments(
|
||||||
|
@ -51,11 +51,6 @@ class LocalExecutable {
|
|||||||
const absl::Span<const ShapedBuffer* const> arguments,
|
const absl::Span<const ShapedBuffer* const> arguments,
|
||||||
ExecutableRunOptions run_options);
|
ExecutableRunOptions run_options);
|
||||||
|
|
||||||
// Similar to Run(), but allows for donating argument buffers to the
|
|
||||||
// executable.
|
|
||||||
StatusOr<ExecutionOutput> Run(std::vector<ExecutionInput> arguments,
|
|
||||||
ExecutableRunOptions run_options);
|
|
||||||
|
|
||||||
// Similar to Run(), but need not block the host waiting for the computation
|
// Similar to Run(), but need not block the host waiting for the computation
|
||||||
// to complete before returning.
|
// to complete before returning.
|
||||||
StatusOr<ScopedShapedBuffer> RunAsync(
|
StatusOr<ScopedShapedBuffer> RunAsync(
|
||||||
@ -90,10 +85,6 @@ class LocalExecutable {
|
|||||||
const absl::Span<const Shape* const> argument_shapes,
|
const absl::Span<const Shape* const> argument_shapes,
|
||||||
ExecutableRunOptions run_options);
|
ExecutableRunOptions run_options);
|
||||||
|
|
||||||
StatusOr<ExecutableRunOptions> GetExecutableRunOptions(
|
|
||||||
absl::Span<Shape const* const> argument_shapes,
|
|
||||||
const ExecutableRunOptions& run_options);
|
|
||||||
|
|
||||||
// The ordinal of the device which this executable was compiled for. The
|
// The ordinal of the device which this executable was compiled for. The
|
||||||
// executable can run on all equivalent devices (as determined by
|
// executable can run on all equivalent devices (as determined by
|
||||||
// Backend::devices_equivalent).
|
// Backend::devices_equivalent).
|
||||||
|
@ -45,8 +45,7 @@ void CompileAndExecute(
|
|||||||
xla::ClientLibrary::GetXlaService(client->platform())
|
xla::ClientLibrary::GetXlaService(client->platform())
|
||||||
->backend()
|
->backend()
|
||||||
.memory_allocator());
|
.memory_allocator());
|
||||||
StatusOr<ScopedShapedBuffer> result =
|
StatusOr<ScopedShapedBuffer> result = executable->Run({}, execute_options);
|
||||||
executable->Run(absl::Span<const ShapedBuffer* const>(), execute_options);
|
|
||||||
{
|
{
|
||||||
absl::MutexLock lock(results_mutex);
|
absl::MutexLock lock(results_mutex);
|
||||||
results->emplace_back(device_ordinal, std::move(result));
|
results->emplace_back(device_ordinal, std::move(result));
|
||||||
|
@ -1324,16 +1324,14 @@ void BM_WhileLoop(int num_iters) {
|
|||||||
options.set_allocator(&allocator);
|
options.set_allocator(&allocator);
|
||||||
const int kWarmups = 2;
|
const int kWarmups = 2;
|
||||||
for (int i = 0; i < kWarmups; ++i) {
|
for (int i = 0; i < kWarmups; ++i) {
|
||||||
auto result =
|
auto result = executable->Run({}, options);
|
||||||
executable->Run(absl::Span<const ShapedBuffer* const>(), options);
|
|
||||||
ASSERT_TRUE(result.ok());
|
ASSERT_TRUE(result.ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run benchmark.
|
// Run benchmark.
|
||||||
tensorflow::testing::StartTiming();
|
tensorflow::testing::StartTiming();
|
||||||
for (int i = 0; i < num_iters; ++i) {
|
for (int i = 0; i < num_iters; ++i) {
|
||||||
auto result =
|
auto result = executable->Run({}, options);
|
||||||
executable->Run(absl::Span<const ShapedBuffer* const>(), options);
|
|
||||||
ASSERT_TRUE(result.ok());
|
ASSERT_TRUE(result.ok());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user