[XLA:Python] Add support for tupling arguments and untupling results in Execute()/ExecuteOnLocalDevices().

Change in preparation for removing tuples from the runtime.

PiperOrigin-RevId: 300606797
Change-Id: Ic59986d0494355380146b47f6a16770d4aa1a688
Author: Peter Hawkins, 2020-03-12 12:57:16 -07:00
Committed by: TensorFlower Gardener
Parent: 8b80da1235
Commit: 46b7331fb1
4 changed files with 130 additions and 46 deletions
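
In rough terms, the new calling convention looks like the sketch below. This is an illustration, not code from this commit: `executable`, `arg0`, and `arg1` are assumed to be an already-compiled PyLocalExecutable and two device buffers.

    ExecuteOptions options;
    options.tuple_arguments = true;  // wrap the arguments into one tuple parameter
    options.untuple_result = true;   // destructure a tuple result into elements

    // Execute() now returns one buffer per result element rather than a
    // single, possibly tuple-shaped, buffer.
    TF_ASSIGN_OR_RETURN(
        std::vector<std::unique_ptr<PyLocalBuffer>> outputs,
        executable->Execute({arg0.get(), arg1.get()}, options));

With both flags false, Execute() behaves as before, except that the result now arrives as a one-element vector.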


@@ -87,11 +87,11 @@ TEST(GpuMultiStream, Basics) {
           /*buffer_reference=*/nullptr, client.get(), device));
   // The execution may be enqueued before the transfers complete, requiring
   // adequate device-side synchronization.
+  ExecuteOptions options;
+  options.untuple_result = true;
   TF_ASSERT_OK_AND_ASSIGN(
-      auto out_buffer,
-      executable->Execute({in_buffer0.get(), in_buffer1.get()}));
-  TF_ASSERT_OK_AND_ASSIGN(auto out_buffers, out_buffer->DestructureTuple());
+      auto out_buffers,
+      executable->Execute({in_buffer0.get(), in_buffer1.get()}, options));
   TF_ASSERT_OK_AND_ASSIGN(auto out_literal, out_buffers[0]->ToLiteral());
   LiteralTestUtil::ExpectR1Equal<int32>(expected_outputs, *out_literal);


@@ -319,7 +319,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
 }
 
 /* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple(
-    const std::vector<PyLocalBuffer*> buffers, PyLocalClient* client,
+    absl::Span<PyLocalBuffer* const> buffers, PyLocalClient* client,
     Device* device) {
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
@@ -716,11 +716,21 @@ const std::string& PyLocalExecutable::name() const {
   }
 }
 
-StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
+StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
+PyLocalExecutable::ExecuteHelper(
     absl::Span<PyLocalBuffer* const> argument_handles, int replica,
-    int partition, const RunId& run_id) const {
+    int partition, const RunId& run_id, const ExecuteOptions& options) const {
   const int device_id = (*device_assignment_)(replica, partition);
   Device* device = LookupDevice(*client_, device_id);
+  std::unique_ptr<PyLocalBuffer> tuple_buffer;
+  std::vector<PyLocalBuffer*> tupled_arguments;
+  if (options.tuple_arguments) {
+    TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple(
+                                          argument_handles, client_, device));
+    tupled_arguments = {tuple_buffer.get()};
+    argument_handles = tupled_arguments;
+  }
+
   CHECK_EQ(device->host_id(), client_->host_id());
   int device_ordinal = device->local_device_state()->device_ordinal();
   tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
@@ -763,16 +773,16 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
     event->WaitForEventOnStream(device_state->compute_stream());
   }
 
-  ExecutableRunOptions options;
-  options.set_stream(device_state->compute_stream());
-  options.set_host_to_device_stream(device_state->host_to_device_stream());
-  options.set_allocator(client_->allocator());
-  options.set_intra_op_thread_pool(
+  ExecutableRunOptions run_options;
+  run_options.set_stream(device_state->compute_stream());
+  run_options.set_host_to_device_stream(device_state->host_to_device_stream());
+  run_options.set_allocator(client_->allocator());
+  run_options.set_intra_op_thread_pool(
       client_->client()->backend().eigen_intra_op_thread_pool_device());
-  options.set_device_assignment(device_assignment_.get());
-  options.set_run_id(run_id);
-  options.set_rng_seed(device_state->GetNewPrngSeed());
-  options.set_gpu_executable_run_options(client_->gpu_run_options());
+  run_options.set_device_assignment(device_assignment_.get());
+  run_options.set_run_id(run_id);
+  run_options.set_rng_seed(device_state->GetNewPrngSeed());
+  run_options.set_gpu_executable_run_options(client_->gpu_run_options());
 
   // The choice of where we wait is arbitrary; the reason for the wait is pacing
   // to avoid problems such as memory fragmentation and running ahead too far,
@@ -785,7 +795,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
   int executable_idx = executables_.size() > 1 ? partition : 0;
 
   StatusOr<ScopedShapedBuffer> result_buffer_or_status =
-      executables_[executable_idx]->RunAsync(argument_buffer_ptrs, options);
+      executables_[executable_idx]->RunAsync(argument_buffer_ptrs, run_options);
 
   VLOG(1) << "Replica " << replica << " partition " << partition
           << " completed; ok=" << result_buffer_or_status.ok();
@@ -817,13 +827,19 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
       device_state->compute_stream(),
       std::make_tuple(executables_[executable_idx], compute_reservation,
                       device_assignment_));
-  return absl::make_unique<PyLocalBuffer>(
+  std::vector<std::unique_ptr<PyLocalBuffer>> outputs;
+  outputs.push_back(absl::make_unique<PyLocalBuffer>(
       result_buffer.on_host_shape(), result_buffer.on_device_shape(),
-      std::move(out_buffer), client_, device);
+      std::move(out_buffer), client_, device));
+  if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
+    TF_ASSIGN_OR_RETURN(outputs, outputs.front()->DestructureTuple());
+  }
+  return outputs;
 }
 
-StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::Execute(
-    absl::Span<PyLocalBuffer* const> argument_handles) const {
+StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
+PyLocalExecutable::Execute(absl::Span<PyLocalBuffer* const> argument_handles,
+                           const ExecuteOptions& options) const {
   if (num_replicas() != 1) {
     return InvalidArgument(
         "Attempted to execute computation with %d replicas using Execute()",
@@ -836,12 +852,13 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::Execute(
   }
 
   VLOG(1) << "Executing computation " << name();
   return ExecuteHelper(argument_handles, /*replica=*/0, /*partition=*/0,
-                       RunId());
+                       RunId(), options);
 }
 
-StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
+StatusOr<std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>>
 PyLocalExecutable::ExecuteOnLocalDevices(
-    absl::Span<const std::vector<PyLocalBuffer*>> argument_handles) const {
+    absl::Span<const std::vector<PyLocalBuffer*>> argument_handles,
+    const ExecuteOptions& options) const {
   tensorflow::profiler::TraceMe traceme(
       "LocalExecutable::ExecuteOnLocalDevices");
@@ -857,17 +874,17 @@ PyLocalExecutable::ExecuteOnLocalDevices(
 
   VLOG(1) << "Executing computation " << name()
           << "; num_replicas=" << num_replicas()
-          << " num_partitions=" << num_partitions()
-          << " num_local_devices=" << num_local_devices;
-  std::vector<StatusOr<std::unique_ptr<PyLocalBuffer>>> results(
+          << " num_partitions=" << num_partitions() << " num_local_devices="
+          << num_local_devices;
+  std::vector<StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>> results(
       num_local_devices);
   if (num_local_devices == 1) {
     // Fast-path if there is only one device - run the computation on the
     // current thread.
    const int replica = local_logical_device_ids_[0].first;
    const int partition = local_logical_device_ids_[0].second;
-    results[0] =
-        ExecuteHelper(argument_handles[0], replica, partition, RunId());
+    results[0] = ExecuteHelper(argument_handles[0], replica, partition, RunId(),
+                               options);
   } else {
     RunId run_id;
     absl::Mutex mu;
@@ -881,8 +898,8 @@ PyLocalExecutable::ExecuteOnLocalDevices(
       Device* device = local_devices_[i];
       const LocalDeviceState& device_state = *device->local_device_state();
       device_state.execute_thread()->Schedule([&, replica, partition, i] {
-        results[i] =
-            ExecuteHelper(argument_handles[i], replica, partition, run_id);
+        results[i] = ExecuteHelper(argument_handles[i], replica, partition,
+                                   run_id, options);
 
         absl::MutexLock lock(&mu);
         --running;
@@ -923,7 +940,7 @@ PyLocalExecutable::ExecuteOnLocalDevices(
   }
   VLOG(1) << "Replicated execution complete.";
 
-  std::vector<std::unique_ptr<PyLocalBuffer>> wrapped_results(
+  std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>> wrapped_results(
       num_local_devices);
   for (int i = 0; i < num_local_devices; ++i) {
     const int replica = local_logical_device_ids_[i].first;


@@ -214,7 +214,7 @@ class PyLocalBuffer {
       Device* device);
 
   static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
-      const std::vector<PyLocalBuffer*> buffers, PyLocalClient* client,
+      absl::Span<PyLocalBuffer* const> buffers, PyLocalClient* client,
       Device* device);
 
   // Asynchronously makes a vector of PyLocalBuffers that can be used to receive
@@ -320,6 +320,16 @@ struct CompileOptions {
   ExecutableBuildOptions executable_build_options;
 };
 
+struct ExecuteOptions {
+  // If true, the arguments to the computation will be wrapped in a tuple and
+  // passed as a single parameter.
+  bool tuple_arguments = false;
+
+  // If true, the computation must return a tuple, which will be destructured
+  // into its elements.
+  bool untuple_result = false;
+};
+
 // Represents a compiled computation that can be executed given handles to
 // device-allocated literals. Wraps one or more XLA LocalExecutables (one per
 // partition, as specified by the build options).
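
To make the two new flags concrete, here is a hedged sketch of a computation built to match them. It uses the standard XlaBuilder API; the shapes and names are illustrative and not taken from this commit.

    XlaBuilder builder("add");
    Shape s = ShapeUtil::MakeShape(F32, {2});
    // tuple_arguments = true: N argument buffers are wrapped into one tuple,
    // so the computation declares a single tuple-shaped parameter 0 instead
    // of N separate parameters.
    XlaOp args = Parameter(&builder, 0, ShapeUtil::MakeTupleShape({s, s}), "args");
    XlaOp sum = Add(GetTupleElement(args, 0), GetTupleElement(args, 1));
    // untuple_result = true: the computation must return a tuple; each tuple
    // element then comes back from Execute() as its own PyLocalBuffer.
    Tuple(&builder, {sum});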
@@ -364,24 +374,27 @@ class PyLocalExecutable {
   const std::vector<Device*>& local_devices() const { return local_devices_; }
 
-  StatusOr<std::unique_ptr<PyLocalBuffer>> Execute(
-      absl::Span<PyLocalBuffer* const> argument_handles) const;
+  StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> Execute(
+      absl::Span<PyLocalBuffer* const> argument_handles,
+      const ExecuteOptions& options) const;
 
   // Execute on local devices. Takes a sequence of argument lists (one argument
   // list per local device) and returns a tuple of results (one result per local
   // device). The number of argument lists must be equal to the local device
   // count.
-  StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> ExecuteOnLocalDevices(
-      absl::Span<const std::vector<PyLocalBuffer*>> argument_handles) const;
+  StatusOr<std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>>
+  ExecuteOnLocalDevices(
+      absl::Span<const std::vector<PyLocalBuffer*>> argument_handles,
+      const ExecuteOptions& options) const;
 
   void Delete() { executables_.clear(); }
 
   const string& name() const;
 
  private:
-  StatusOr<std::unique_ptr<PyLocalBuffer>> ExecuteHelper(
+  StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> ExecuteHelper(
       absl::Span<PyLocalBuffer* const> argument_handles, int replica,
-      int partition, const RunId& run_id) const;
+      int partition, const RunId& run_id, const ExecuteOptions& options) const;
 
   // Create shared pointers so we can free them after the execution: with
   // asynchronous execution, the process being executed can outlive the
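
Given the new declarations, a usage sketch for the replicated path (the device count and buffer names are assumptions for illustration):

    // One argument list per local device; the result is one vector of output
    // buffers per device.
    std::vector<std::vector<PyLocalBuffer*>> args = {{arg_d0.get()},
                                                     {arg_d1.get()}};
    ExecuteOptions options;
    options.untuple_result = true;
    TF_ASSIGN_OR_RETURN(
        std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>> results,
        executable->ExecuteOnLocalDevices(args, options));
    // results[device][element] indexes first by local device, then by result
    // element.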


@@ -964,7 +964,7 @@ PYBIND11_MODULE(xla_extension, m) {
           py::arg("force_copy") = false)
       .def_static(
           "make_tuple",
-          [](const std::vector<PyLocalBuffer*> buffers,
+          [](std::vector<PyLocalBuffer*> buffers,
             std::shared_ptr<PyLocalClient> client,
             Device* device) -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
            CHECK(device != nullptr);
@@ -1141,21 +1141,26 @@ PYBIND11_MODULE(xla_extension, m) {
             absl::Span<PyLocalBuffer* const> args)
              -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
            py::gil_scoped_release gil_release;
-            TF_ASSIGN_OR_RETURN(std::unique_ptr<PyLocalBuffer> output,
-                                executable.Execute(args));
+            TF_ASSIGN_OR_RETURN(
+                std::vector<std::unique_ptr<PyLocalBuffer>> output,
+                executable.Execute(args, ExecuteOptions()));
             return WrapWithClient(executable.client()->shared_from_this(),
-                                  std::move(output));
+                                  std::move(output.front()));
           },
           py::arg("arguments"))
+      // TODO(phawkins): remove in favor of overload that returns a vector.
       .def(
-          "ExecuteOnLocalDevices",
+          "Execute",
           [](const PyLocalExecutable& executable,
-             absl::Span<const std::vector<PyLocalBuffer*>> args)
+             absl::Span<PyLocalBuffer* const> args, bool tuple_arguments)
              -> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
            py::gil_scoped_release gil_release;
+            ExecuteOptions options;
+            options.tuple_arguments = tuple_arguments;
+            options.untuple_result = true;
             TF_ASSIGN_OR_RETURN(
                 std::vector<std::unique_ptr<PyLocalBuffer>> output_buffers,
-                executable.ExecuteOnLocalDevices(args));
+                executable.Execute(args, options));
             std::vector<ClientAndUniquePtr<PyLocalBuffer>> outputs;
             outputs.reserve(output_buffers.size());
             for (auto& buffer : output_buffers) {
@@ -1164,7 +1169,56 @@ PYBIND11_MODULE(xla_extension, m) {
             }
             return outputs;
           },
-          py::arg("arguments"))
+          py::arg("arguments"), py::arg("tuple_arguments"))
+      // TODO(phawkins): remove in favor of overload that returns a vector.
+      .def(
+          "ExecuteOnLocalDevices",
+          [](const PyLocalExecutable& executable,
+             absl::Span<const std::vector<PyLocalBuffer*>> args)
+              -> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
+            py::gil_scoped_release gil_release;
+            TF_ASSIGN_OR_RETURN(
+                std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
+                    output_buffers,
+                executable.ExecuteOnLocalDevices(args, ExecuteOptions()));
+            std::vector<ClientAndUniquePtr<PyLocalBuffer>> outputs;
+            outputs.reserve(output_buffers.size());
+            for (auto& buffers : output_buffers) {
+              outputs.push_back(
+                  WrapWithClient(executable.client()->shared_from_this(),
+                                 std::move(buffers.front())));
+            }
+            return outputs;
+          },
+          py::arg("arguments"))
+      .def(
+          "ExecuteOnLocalDevices",
+          [](const PyLocalExecutable& executable,
+             absl::Span<const std::vector<PyLocalBuffer*>> args,
+             bool tuple_arguments)
+              -> StatusOr<
+                  std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>>> {
+            py::gil_scoped_release gil_release;
+            ExecuteOptions options;
+            options.tuple_arguments = tuple_arguments;
+            options.untuple_result = true;
+            TF_ASSIGN_OR_RETURN(
+                std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
+                    output_buffers,
+                executable.ExecuteOnLocalDevices(args, options));
+            std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> outputs;
+            outputs.resize(output_buffers.size());
+            for (int computation = 0; computation < output_buffers.size();
+                 ++computation) {
+              for (auto& buffer : output_buffers[computation]) {
+                outputs[computation].push_back(
+                    WrapWithClient(executable.client()->shared_from_this(),
+                                   std::move(buffer)));
+              }
+            }
+            return outputs;
+          },
+          py::arg("arguments"), py::arg("tuple_arguments"))
       .def(
           "get_hlo_modules",
           [](const PyLocalExecutable& executable)