[XLA:Python] Add support for tupling arguments and untupling results in Execute()/ExecuteOnLocalDevices().

Change in preparation for removing tuples from the runtime.

PiperOrigin-RevId: 300606797
Change-Id: Ic59986d0494355380146b47f6a16770d4aa1a688
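For orientation, a minimal sketch of the new calling convention (hypothetical setup: `executable` is a compiled PyLocalExecutable and `in_buffer0`/`in_buffer1` are existing device buffers; only ExecuteOptions, the extra options parameter, and the vector-of-buffers return type come from this change):

    // Sketch under the assumptions above, mirroring the updated multi-stream test.
    ExecuteOptions options;
    options.tuple_arguments = true;  // wrap all arguments into one tuple parameter
    options.untuple_result = true;   // split a tuple result into per-element buffers
    TF_ASSIGN_OR_RETURN(
        std::vector<std::unique_ptr<PyLocalBuffer>> outputs,
        executable->Execute({in_buffer0.get(), in_buffer1.get()}, options));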
@@ -87,11 +87,11 @@ TEST(GpuMultiStream, Basics) {
           /*buffer_reference=*/nullptr, client.get(), device));
   // The execution may be enqueued before the transfers complete, requiring
   // adequate device-side synchronization.
+  ExecuteOptions options;
+  options.untuple_result = true;
   TF_ASSERT_OK_AND_ASSIGN(
-      auto out_buffer,
-      executable->Execute({in_buffer0.get(), in_buffer1.get()}));
+      auto out_buffers,
+      executable->Execute({in_buffer0.get(), in_buffer1.get()}, options));
 
-  TF_ASSERT_OK_AND_ASSIGN(auto out_buffers, out_buffer->DestructureTuple());
-
   TF_ASSERT_OK_AND_ASSIGN(auto out_literal, out_buffers[0]->ToLiteral());
   LiteralTestUtil::ExpectR1Equal<int32>(expected_outputs, *out_literal);
@@ -319,7 +319,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
 }
 
 /* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple(
-    const std::vector<PyLocalBuffer*> buffers, PyLocalClient* client,
+    absl::Span<PyLocalBuffer* const> buffers, PyLocalClient* client,
     Device* device) {
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
@@ -716,11 +716,21 @@ const std::string& PyLocalExecutable::name() const {
   }
 }
 
-StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
+StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
+PyLocalExecutable::ExecuteHelper(
     absl::Span<PyLocalBuffer* const> argument_handles, int replica,
-    int partition, const RunId& run_id) const {
+    int partition, const RunId& run_id, const ExecuteOptions& options) const {
   const int device_id = (*device_assignment_)(replica, partition);
   Device* device = LookupDevice(*client_, device_id);
 
+  std::unique_ptr<PyLocalBuffer> tuple_buffer;
+  std::vector<PyLocalBuffer*> tupled_arguments;
+  if (options.tuple_arguments) {
+    TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple(
+                                          argument_handles, client_, device));
+    tupled_arguments = {tuple_buffer.get()};
+    argument_handles = tupled_arguments;
+  }
   CHECK_EQ(device->host_id(), client_->host_id());
   int device_ordinal = device->local_device_state()->device_ordinal();
   tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
@@ -763,16 +773,16 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
     event->WaitForEventOnStream(device_state->compute_stream());
   }
 
-  ExecutableRunOptions options;
-  options.set_stream(device_state->compute_stream());
-  options.set_host_to_device_stream(device_state->host_to_device_stream());
-  options.set_allocator(client_->allocator());
-  options.set_intra_op_thread_pool(
+  ExecutableRunOptions run_options;
+  run_options.set_stream(device_state->compute_stream());
+  run_options.set_host_to_device_stream(device_state->host_to_device_stream());
+  run_options.set_allocator(client_->allocator());
+  run_options.set_intra_op_thread_pool(
       client_->client()->backend().eigen_intra_op_thread_pool_device());
-  options.set_device_assignment(device_assignment_.get());
-  options.set_run_id(run_id);
-  options.set_rng_seed(device_state->GetNewPrngSeed());
-  options.set_gpu_executable_run_options(client_->gpu_run_options());
+  run_options.set_device_assignment(device_assignment_.get());
+  run_options.set_run_id(run_id);
+  run_options.set_rng_seed(device_state->GetNewPrngSeed());
+  run_options.set_gpu_executable_run_options(client_->gpu_run_options());
 
   // The choice of where we wait is arbitrary; the reason for the wait is pacing
   // to avoid problems such as memory fragmentation and running ahead too far,
@@ -785,7 +795,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
   int executable_idx = executables_.size() > 1 ? partition : 0;
 
   StatusOr<ScopedShapedBuffer> result_buffer_or_status =
-      executables_[executable_idx]->RunAsync(argument_buffer_ptrs, options);
+      executables_[executable_idx]->RunAsync(argument_buffer_ptrs, run_options);
 
   VLOG(1) << "Replica " << replica << " partition " << partition
           << " completed; ok=" << result_buffer_or_status.ok();
@@ -817,13 +827,19 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
       device_state->compute_stream(),
       std::make_tuple(executables_[executable_idx], compute_reservation,
                       device_assignment_));
-  return absl::make_unique<PyLocalBuffer>(
+  std::vector<std::unique_ptr<PyLocalBuffer>> outputs;
+  outputs.push_back(absl::make_unique<PyLocalBuffer>(
       result_buffer.on_host_shape(), result_buffer.on_device_shape(),
-      std::move(out_buffer), client_, device);
+      std::move(out_buffer), client_, device));
+  if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
+    TF_ASSIGN_OR_RETURN(outputs, outputs.front()->DestructureTuple());
+  }
+  return outputs;
 }
 
-StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::Execute(
-    absl::Span<PyLocalBuffer* const> argument_handles) const {
+StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
+PyLocalExecutable::Execute(absl::Span<PyLocalBuffer* const> argument_handles,
+                           const ExecuteOptions& options) const {
   if (num_replicas() != 1) {
     return InvalidArgument(
         "Attempted to execute computation with %d replicas using Execute()",
@@ -836,12 +852,13 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::Execute(
   }
   VLOG(1) << "Executing computation " << name();
   return ExecuteHelper(argument_handles, /*replica=*/0, /*partition=*/0,
-                       RunId());
+                       RunId(), options);
 }
 
-StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
+StatusOr<std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>>
 PyLocalExecutable::ExecuteOnLocalDevices(
-    absl::Span<const std::vector<PyLocalBuffer*>> argument_handles) const {
+    absl::Span<const std::vector<PyLocalBuffer*>> argument_handles,
+    const ExecuteOptions& options) const {
   tensorflow::profiler::TraceMe traceme(
       "LocalExecutable::ExecuteOnLocalDevices");
 
@@ -857,17 +874,17 @@ PyLocalExecutable::ExecuteOnLocalDevices(
 
   VLOG(1) << "Executing computation " << name()
           << "; num_replicas=" << num_replicas()
-          << " num_partitions=" << num_partitions()
-          << " num_local_devices=" << num_local_devices;
-  std::vector<StatusOr<std::unique_ptr<PyLocalBuffer>>> results(
+          << " num_partitions=" << num_partitions() << " num_local_devices="
+          << num_local_devices;
+  std::vector<StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>> results(
       num_local_devices);
   if (num_local_devices == 1) {
     // Fast-path if there is only one device — run the computation on the
     // current thread.
     const int replica = local_logical_device_ids_[0].first;
     const int partition = local_logical_device_ids_[0].second;
-    results[0] =
-        ExecuteHelper(argument_handles[0], replica, partition, RunId());
+    results[0] = ExecuteHelper(argument_handles[0], replica, partition, RunId(),
+                               options);
   } else {
     RunId run_id;
     absl::Mutex mu;
@@ -881,8 +898,8 @@ PyLocalExecutable::ExecuteOnLocalDevices(
       Device* device = local_devices_[i];
       const LocalDeviceState& device_state = *device->local_device_state();
       device_state.execute_thread()->Schedule([&, replica, partition, i] {
-        results[i] =
-            ExecuteHelper(argument_handles[i], replica, partition, run_id);
+        results[i] = ExecuteHelper(argument_handles[i], replica, partition,
+                                   run_id, options);
 
         absl::MutexLock lock(&mu);
         --running;
@@ -923,7 +940,7 @@ PyLocalExecutable::ExecuteOnLocalDevices(
   }
   VLOG(1) << "Replicated execution complete.";
 
-  std::vector<std::unique_ptr<PyLocalBuffer>> wrapped_results(
+  std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>> wrapped_results(
       num_local_devices);
   for (int i = 0; i < num_local_devices; ++i) {
     const int replica = local_logical_device_ids_[i].first;
@@ -214,7 +214,7 @@ class PyLocalBuffer {
       Device* device);
 
   static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
-      const std::vector<PyLocalBuffer*> buffers, PyLocalClient* client,
+      absl::Span<PyLocalBuffer* const> buffers, PyLocalClient* client,
       Device* device);
 
   // Asynchronously makes a vector of PyLocalBuffers that can be used to receive
@@ -320,6 +320,16 @@ struct CompileOptions {
   ExecutableBuildOptions executable_build_options;
 };
 
+struct ExecuteOptions {
+  // If true, the arguments to the computation will be wrapped in a tuple and
+  // passed as a single parameter.
+  bool tuple_arguments = false;
+
+  // If true, the computation must return a tuple, which will be destructured
+  // into its elements.
+  bool untuple_result = false;
+};
+
 // Represents a compiled computation that can be executed given handles to
 // device-allocated literals. Wraps one or more XLA LocalExecutables (one per
 // partition, as specified by the build options).
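For illustration only (not part of this commit): with tuple_arguments = true, the executable's underlying computation takes all of its logical arguments as one tuple-shaped parameter. A hypothetical builder snippet shaped that way, with invented shapes and names:

    // Hypothetical computation matching tuple_arguments = true: a single tuple
    // parameter carrying both logical arguments.
    XlaBuilder builder("tupled_add");
    Shape args_shape = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(F32, {4}), ShapeUtil::MakeShape(F32, {4})});
    XlaOp args = Parameter(&builder, 0, args_shape, "args");
    Add(GetTupleElement(args, 0), GetTupleElement(args, 1));
    StatusOr<XlaComputation> computation = builder.Build();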
@@ -364,24 +374,27 @@ class PyLocalExecutable {
 
   const std::vector<Device*>& local_devices() const { return local_devices_; }
 
-  StatusOr<std::unique_ptr<PyLocalBuffer>> Execute(
-      absl::Span<PyLocalBuffer* const> argument_handles) const;
+  StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> Execute(
+      absl::Span<PyLocalBuffer* const> argument_handles,
+      const ExecuteOptions& options) const;
 
   // Execute on local devices. Takes a sequence of argument lists (one argument
   // list per local device) and returns a tuple of results (one result per local
   // device). The number of argument lists must be equal to the local device
   // count.
-  StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> ExecuteOnLocalDevices(
-      absl::Span<const std::vector<PyLocalBuffer*>> argument_handles) const;
+  StatusOr<std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>>
+  ExecuteOnLocalDevices(
+      absl::Span<const std::vector<PyLocalBuffer*>> argument_handles,
+      const ExecuteOptions& options) const;
 
   void Delete() { executables_.clear(); }
 
   const string& name() const;
 
  private:
-  StatusOr<std::unique_ptr<PyLocalBuffer>> ExecuteHelper(
+  StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> ExecuteHelper(
       absl::Span<PyLocalBuffer* const> argument_handles, int replica,
-      int partition, const RunId& run_id) const;
+      int partition, const RunId& run_id, const ExecuteOptions& options) const;
 
   // Create shared pointers so we can free them after the execution: with
   // asynchronous execution, the process being executed can outlive the
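A hedged sketch of calling the new ExecuteOnLocalDevices() signature (two local devices and the buffers `arg_d0`/`arg_d1` are assumed; results now arrive as one vector of output buffers per local device):

    // Sketch only: `executable`, `arg_d0`, and `arg_d1` are assumed to exist.
    ExecuteOptions options;
    options.untuple_result = true;
    std::vector<std::vector<PyLocalBuffer*>> args = {{arg_d0.get()},
                                                     {arg_d1.get()}};
    TF_ASSIGN_OR_RETURN(
        std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>> per_device,
        executable->ExecuteOnLocalDevices(args, options));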
@@ -964,7 +964,7 @@ PYBIND11_MODULE(xla_extension, m) {
           py::arg("force_copy") = false)
       .def_static(
           "make_tuple",
-          [](const std::vector<PyLocalBuffer*> buffers,
+          [](std::vector<PyLocalBuffer*> buffers,
              std::shared_ptr<PyLocalClient> client,
              Device* device) -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
             CHECK(device != nullptr);
@@ -1141,21 +1141,26 @@ PYBIND11_MODULE(xla_extension, m) {
              absl::Span<PyLocalBuffer* const> args)
               -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
             py::gil_scoped_release gil_release;
-            TF_ASSIGN_OR_RETURN(std::unique_ptr<PyLocalBuffer> output,
-                                executable.Execute(args));
+            TF_ASSIGN_OR_RETURN(
+                std::vector<std::unique_ptr<PyLocalBuffer>> output,
+                executable.Execute(args, ExecuteOptions()));
             return WrapWithClient(executable.client()->shared_from_this(),
-                                  std::move(output));
+                                  std::move(output.front()));
           },
           py::arg("arguments"))
+      // TODO(phawkins): remove in favor of overload that returns a vector.
       .def(
-          "ExecuteOnLocalDevices",
+          "Execute",
           [](const PyLocalExecutable& executable,
-             absl::Span<const std::vector<PyLocalBuffer*>> args)
+             absl::Span<PyLocalBuffer* const> args, bool tuple_arguments)
               -> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
             py::gil_scoped_release gil_release;
+            ExecuteOptions options;
+            options.tuple_arguments = tuple_arguments;
+            options.untuple_result = true;
             TF_ASSIGN_OR_RETURN(
                 std::vector<std::unique_ptr<PyLocalBuffer>> output_buffers,
-                executable.ExecuteOnLocalDevices(args));
+                executable.Execute(args, options));
             std::vector<ClientAndUniquePtr<PyLocalBuffer>> outputs;
             outputs.reserve(output_buffers.size());
             for (auto& buffer : output_buffers) {
@@ -1164,7 +1169,56 @@ PYBIND11_MODULE(xla_extension, m) {
             }
             return outputs;
           },
+          py::arg("arguments"), py::arg("tuple_arguments"))
+      // TODO(phawkins): remove in favor of overload that returns a vector.
+      .def(
+          "ExecuteOnLocalDevices",
+          [](const PyLocalExecutable& executable,
+             absl::Span<const std::vector<PyLocalBuffer*>> args)
+              -> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
+            py::gil_scoped_release gil_release;
+            TF_ASSIGN_OR_RETURN(
+                std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
+                    output_buffers,
+                executable.ExecuteOnLocalDevices(args, ExecuteOptions()));
+            std::vector<ClientAndUniquePtr<PyLocalBuffer>> outputs;
+            outputs.reserve(output_buffers.size());
+            for (auto& buffers : output_buffers) {
+              outputs.push_back(
+                  WrapWithClient(executable.client()->shared_from_this(),
+                                 std::move(buffers.front())));
+            }
+            return outputs;
+          },
           py::arg("arguments"))
+      .def(
+          "ExecuteOnLocalDevices",
+          [](const PyLocalExecutable& executable,
+             absl::Span<const std::vector<PyLocalBuffer*>> args,
+             bool tuple_arguments)
+              -> StatusOr<
+                  std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>>> {
+            py::gil_scoped_release gil_release;
+            ExecuteOptions options;
+            options.tuple_arguments = tuple_arguments;
+            options.untuple_result = true;
+            TF_ASSIGN_OR_RETURN(
+                std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
+                    output_buffers,
+                executable.ExecuteOnLocalDevices(args, options));
+            std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> outputs;
+            outputs.resize(output_buffers.size());
+            for (int computation = 0; computation < output_buffers.size();
+                 ++computation) {
+              for (auto& buffer : output_buffers[computation]) {
+                outputs[computation].push_back(
+                    WrapWithClient(executable.client()->shared_from_this(),
+                                   std::move(buffer)));
+              }
+            }
+            return outputs;
+          },
+          py::arg("arguments"), py::arg("tuple_arguments"))
       .def(
           "get_hlo_modules",
           [](const PyLocalExecutable& executable)