Clean up the PjRtExecutable API.

- Regroup the Execute API into 3 distinct methods: Execute, ExecuteSharded, and ExecutePortable (sketched below).
- Rename local_devices to addressable_devices.
- Introduce a LogicalDeviceIds struct to name replica and partition, replacing std::pair<int, int>.
- Return absl::Span instead of const std::vector&.

PiperOrigin-RevId: 345551501
Change-Id: I2f7b50101849af02c7188d547d78a53dc7d030be
Authored by Qiao Zhang on 2020-12-03 15:43:20 -08:00; committed by TensorFlower Gardener.
parent 077fe29d9d
commit 5bbe185466
12 changed files with 196 additions and 176 deletions
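
Note: for orientation, here is a minimal caller-side sketch of the three regrouped entry points. It is illustrative only, not code from this commit; `executable`, `args`, and `device` are assumed to come from the surrounding program, includes are abbreviated, and a real executable supports either the assigned paths (Execute/ExecuteSharded) or the portable path (ExecutePortable), never both.

// Sketch only (not from this commit). Headers abbreviated; TF_ASSIGN_OR_RETURN
// is the usual XLA status macro.
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

xla::Status RunExamples(xla::PjRtExecutable* executable,
                        const std::vector<xla::PjRtBuffer*>& args,
                        xla::PjRtDevice* device) {
  xla::ExecuteOptions options;
  // Execute: one argument list per addressable device; results come back as
  // outputs[device_index][output_index].
  TF_ASSIGN_OR_RETURN(auto outputs, executable->Execute({args}, options));
  // ExecuteSharded: run the replica/partition that the device_assignment
  // maps to `device`; `device` must be addressable by this client.
  TF_ASSIGN_OR_RETURN(auto sharded,
                      executable->ExecuteSharded(args, device, options));
  // ExecutePortable: for a 1-replica/1-partition executable compiled without
  // a device_assignment; the target device is chosen at call time.
  TF_ASSIGN_OR_RETURN(auto portable,
                      executable->ExecutePortable(args, device, options));
  return xla::Status::OK();
}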

tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc

@@ -93,11 +93,11 @@ TEST(GpuMultiStream, Basics) {
options.untuple_result = true;
TF_ASSERT_OK_AND_ASSIGN(
auto out_buffers,
- executable->Execute({in_buffer0.get(), in_buffer1.get()}, options));
+ executable->Execute({{in_buffer0.get(), in_buffer1.get()}}, options));
- TF_ASSERT_OK_AND_ASSIGN(auto out_literal, out_buffers[0]->ToLiteral());
+ TF_ASSERT_OK_AND_ASSIGN(auto out_literal, out_buffers[0][0]->ToLiteral());
LiteralTestUtil::ExpectR1Equal<int32>(expected_outputs, *out_literal);
- TF_ASSERT_OK_AND_ASSIGN(out_literal, out_buffers[1]->ToLiteral());
+ TF_ASSERT_OK_AND_ASSIGN(out_literal, out_buffers[0][1]->ToLiteral());
LiteralTestUtil::ExpectR1Equal<int32>(expected_outputs, *out_literal);
}
}
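
Note: the test change reflects the new result shape. Execute now takes one argument list per addressable device (hence the extra braces) and returns one output vector per device, so a single-device test reads its k-th output at out_buffers[0][k]. A hedged sketch of consuming the shape:

// Sketch (illustrative): outputs[d][k] is the k-th output buffer computed
// on addressable device d.
#include <memory>
#include <vector>
void ConsumeOutputs(
    std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>& outputs) {
  for (size_t d = 0; d < outputs.size(); ++d) {
    for (std::unique_ptr<xla::PjRtBuffer>& buffer : outputs[d]) {
      // e.g. auto literal = buffer->ToLiteral();
      (void)buffer;
    }
  }
}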

tensorflow/compiler/xla/pjrt/pjrt_client.cc

@@ -1458,13 +1458,14 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
std::vector<std::unique_ptr<LocalExecutable>> executables,
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
- std::vector<std::pair<int, int>> local_logical_device_ids,
- std::vector<PjRtDevice*> local_devices, PjRtClient* client)
+ std::vector<LogicalDeviceIds> addressable_device_logical_ids,
+ std::vector<PjRtDevice*> addressable_devices, PjRtClient* client)
: client_(client),
device_assignment_(std::move(device_assignment)),
parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
- local_logical_device_ids_(std::move(local_logical_device_ids)),
- local_devices_(std::move(local_devices)) {
+ addressable_device_logical_ids_(
+ std::move(addressable_device_logical_ids)),
+ addressable_devices_(std::move(addressable_devices)) {
executables_.reserve(executables.size());
for (auto& executable : executables) {
executables_.emplace_back(std::move(executable));
@@ -1475,13 +1476,13 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
// This must go after `executables_` is initialized.
VLOG(1) << "PjRtStreamExecutorExecutable portable single-core";
num_partitions = 1;
- CHECK(local_devices_.empty());
+ CHECK(addressable_devices_.empty());
} else {
// This must go after `executables_` is initialized.
VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
<< device_assignment_->ToString();
- CHECK_GE(local_devices_.size(), 1) << device_assignment_->ToString();
- CHECK_LE(local_devices_.size(), client_->local_device_count())
+ CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
+ CHECK_LE(addressable_devices_.size(), client_->local_device_count())
<< "Inconsistent local device count.";
num_partitions = device_assignment_->computation_count();
}
@@ -1807,7 +1808,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
CHECK(device_assignment_ == nullptr);
CHECK_EQ(replica, 0);
CHECK_EQ(partition, 0);
- CHECK(local_devices_.empty());
+ CHECK(addressable_devices_.empty());
device_assignment = std::make_shared<DeviceAssignment>(1, 1);
(*device_assignment)(0, 0) = device->id();
}
@@ -1875,94 +1876,52 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
return outputs;
}
- StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
- PjRtStreamExecutorExecutable::Execute(
- absl::Span<PjRtBuffer* const> argument_handles,
- const ExecuteOptions& options) const {
- if (num_replicas() != 1) {
- return InvalidArgument(
- "Attempted to execute computation with %d replicas using Execute()",
- num_replicas());
- }
- if (num_partitions() != 1) {
- return InvalidArgument(
- "Attempted to execute computation with %d partitions using Execute()",
- num_partitions());
- }
- VLOG(1) << "Executing computation " << name();
- return ExecuteHelper(argument_handles, /*replica=*/0, /*partition=*/0,
- RunId(), options);
- }
- StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
- PjRtStreamExecutorExecutable::ExecuteOnLocalDevice(
- absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
- const ExecuteOptions& options) const {
- if (device_assignment_ == nullptr) {
- VLOG(1) << "Executing portable single-core program on "
- << device->DebugString();
- return ExecuteHelper(argument_handles,
- /*replica=*/0,
- /*partition=*/0, RunId(), options, device);
- }
- for (int i = 0; i < local_devices_.size(); ++i) {
- if (local_devices_[i] == device) {
- VLOG(1) << "Executing computation " << name();
- return ExecuteHelper(argument_handles,
- /*replica=*/local_logical_device_ids_[i].first,
- /*partition=*/local_logical_device_ids_[i].second,
- RunId(), options);
- }
- }
- return InvalidArgument(
- "Attempted to execute on device id %d which is not a local device",
- device->id());
- }
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
- PjRtStreamExecutorExecutable::ExecuteOnLocalDevices(
+ PjRtStreamExecutorExecutable::Execute(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) const {
- CHECK(device_assignment_ != nullptr);
+ if (device_assignment_ == nullptr) {
+ return InvalidArgument("Execute expects a non-null device_assignment");
+ }
RunId run_id;
tensorflow::profiler::TraceMeProducer activity(
- "LocalExecutable::ExecuteOnLocalDevices",
+ "PjRtStreamExecutorExecutable::Execute",
tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
- const int num_local_devices = local_devices_.size();
+ const int num_addressable_devices = addressable_devices_.size();
- if (argument_handles.size() != num_local_devices) {
+ if (argument_handles.size() != num_addressable_devices) {
return InvalidArgument(
"Attempted to execute with %d argument lists when local device "
"count is %d (total replica count: %d, partition count: %d)",
- argument_handles.size(), num_local_devices, num_replicas(),
+ argument_handles.size(), num_addressable_devices, num_replicas(),
num_partitions());
}
VLOG(1) << "Executing computation " << name()
<< "; num_replicas=" << num_replicas()
<< " num_partitions=" << num_partitions()
- << " num_local_devices=" << num_local_devices;
+ << " num_addressable_devices=" << num_addressable_devices;
std::vector<StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>> results(
- num_local_devices);
- if (num_local_devices == 1) {
+ num_addressable_devices);
+ if (num_addressable_devices == 1) {
// Fast-path if there is only one device — run the computation on the
// current thread.
- const int replica = local_logical_device_ids_[0].first;
- const int partition = local_logical_device_ids_[0].second;
+ const int replica = addressable_device_logical_ids_[0].replica;
+ const int partition = addressable_device_logical_ids_[0].partition;
results[0] =
ExecuteHelper(argument_handles[0], replica, partition, run_id, options);
} else {
absl::Mutex mu;
- int running = num_local_devices;
+ int running = num_addressable_devices;
int failed = 0;
Status first_failure_status;
- for (int i = 0; i < num_local_devices; ++i) {
- const int replica = local_logical_device_ids_[i].first;
- const int partition = local_logical_device_ids_[i].second;
- PjRtDevice* device = local_devices_[i];
+ for (int i = 0; i < num_addressable_devices; ++i) {
+ const int replica = addressable_device_logical_ids_[i].replica;
+ const int partition = addressable_device_logical_ids_[i].partition;
+ PjRtDevice* device = addressable_devices_[i];
const LocalDeviceState& device_state = *device->local_device_state();
device_state.execute_thread()->Schedule([&, replica, partition, i] {
results[i] = ExecuteHelper(argument_handles[i], replica, partition,
@@ -2008,10 +1967,10 @@ PjRtStreamExecutorExecutable::ExecuteOnLocalDevices(
VLOG(1) << "Replicated execution complete.";
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results(
- num_local_devices);
- for (int i = 0; i < num_local_devices; ++i) {
- const int replica = local_logical_device_ids_[i].first;
- const int partition = local_logical_device_ids_[i].second;
+ num_addressable_devices);
+ for (int i = 0; i < num_addressable_devices; ++i) {
+ const int replica = addressable_device_logical_ids_[i].replica;
+ const int partition = addressable_device_logical_ids_[i].partition;
auto& statusor = results[i];
if (!statusor.ok()) {
return AppendStatus(
@@ -2026,6 +1985,52 @@ PjRtStreamExecutorExecutable::ExecuteOnLocalDevices(
return wrapped_results;
}
+ StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
+ PjRtStreamExecutorExecutable::ExecuteSharded(
+ absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
+ const ExecuteOptions& options) const {
+ if (device_assignment_ == nullptr) {
+ return InvalidArgument("ExecuteShard expects a non-null device_assignment");
+ }
+ for (int i = 0; i < addressable_devices_.size(); ++i) {
+ if (addressable_devices_[i] == device) {
+ VLOG(1) << "ExecuteShard executes computation " << name()
+ << " on assigned replica/partition on device "
+ << device->DebugString();
+ return ExecuteHelper(
+ argument_handles, addressable_device_logical_ids_[i].replica,
+ addressable_device_logical_ids_[i].partition, RunId(), options);
+ }
+ }
+ return InvalidArgument(
+ "ExecuteShard attempted to execute on device id %d which is not "
+ "addressable by this client",
+ device->id());
+ }
+ StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
+ PjRtStreamExecutorExecutable::ExecutePortable(
+ absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
+ const ExecuteOptions& options) const {
+ if (device_assignment_ != nullptr) {
+ return InvalidArgument("ExecutePortable gets a non-portable executable");
+ }
+ if (num_replicas() != 1 || num_partitions() != 1) {
+ return InvalidArgument(
+ "ExecutePortable expects a single-core executable but gets "
+ "one with %d replica %d partition",
+ num_replicas(), num_partitions());
+ }
+ if (device == nullptr) {
+ return InvalidArgument("ExecutePortable expects a device to be specified");
+ }
+ VLOG(1) << "ExecutePortable executes single-core portable executable "
+ << name();
+ return ExecuteHelper(argument_handles,
+ /*replica=*/0,
+ /*partition=*/0, RunId(), options, device);
+ }
StatusOr<std::vector<std::shared_ptr<HloModule>>>
PjRtStreamExecutorExecutable::GetHloModules() const {
std::vector<std::shared_ptr<HloModule>> modules;
@@ -2220,9 +2225,12 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
build_options.set_result_layout(result_layout);
- std::vector<std::pair<int, int>> local_logical_device_ids;
- std::vector<PjRtDevice*> local_devices;
+ // Find devices that are addressable by this client/task.
+ std::vector<PjRtExecutable::LogicalDeviceIds> addressable_device_logical_ids;
+ std::vector<PjRtDevice*> addressable_devices;
if (device_assignment != nullptr) {
+ addressable_device_logical_ids.reserve(num_replicas * num_partitions);
+ addressable_devices.reserve(num_replicas * num_partitions);
for (int replica = 0; replica < num_replicas; ++replica) {
for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = (*device_assignment)(replica, partition);
@@ -2231,11 +2239,13 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
VLOG(3) << "Non-local device: " << device_id;
continue;
}
- local_logical_device_ids.emplace_back(replica, partition);
- local_devices.push_back(device);
+ addressable_device_logical_ids.push_back(
+ PjRtExecutable::LogicalDeviceIds{.replica = replica,
+ .partition = partition});
+ addressable_devices.push_back(device);
}
}
- if (local_devices.empty()) {
+ if (addressable_devices.empty()) {
return InvalidArgument(
"Device assignment (%s) does not have any local devices.",
device_assignment->ToString());
@@ -2243,7 +2253,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
if (build_options.device_ordinal() < 0) {
build_options.set_device_ordinal(
- local_devices.front()->local_device_state()->device_ordinal());
+ addressable_devices.front()->local_device_state()->device_ordinal());
}
}
@@ -2253,8 +2263,8 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
auto executable = absl::make_unique<PjRtStreamExecutorExecutable>(
std::move(local_executables), options.parameter_is_tupled_arguments,
- std::move(device_assignment), std::move(local_logical_device_ids),
- std::move(local_devices), this);
+ std::move(device_assignment), std::move(addressable_device_logical_ids),
+ std::move(addressable_devices), this);
TF_RETURN_IF_ERROR(
executable->SetUpDonation(options.parameter_is_tupled_arguments));
return std::unique_ptr<PjRtExecutable>(std::move(executable));
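
Note: since the new Execute requires exactly one argument list per addressable device, multi-device callers assemble inputs in the order of addressable_devices(). A minimal sketch, with a hypothetical per_device_inputs (one pre-placed buffer per device):

// Sketch (illustrative names, not from this commit).
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
xla::StatusOr<std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>>
RunOnAllAddressableDevices(
    xla::PjRtExecutable* executable,
    const std::vector<std::unique_ptr<xla::PjRtBuffer>>& per_device_inputs,
    const xla::ExecuteOptions& options) {
  const int n = executable->addressable_devices().size();
  std::vector<std::vector<xla::PjRtBuffer*>> argument_handles(n);
  for (int i = 0; i < n; ++i) {
    // One input per device here; real callers may pass several.
    argument_handles[i] = {per_device_inputs[i].get()};
  }
  // results[i] holds the outputs produced on addressable_devices()[i].
  return executable->Execute(argument_handles, options);
}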

tensorflow/compiler/xla/pjrt/pjrt_client.h

@@ -781,41 +781,43 @@ class PjRtExecutable {
// The replica and partition indices of device_assignment to be run by this
// client. On single-host platforms without partitioning, this is all replicas
- // (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
- // on multi-host platforms. If there are 4 replicas and 2 partitions on a
- // single host platform, size of local_logical_device_ids_ is 4*2 = 8.
- // TODO(zhangqiaorjc): Add a struct for the pair and return a span.
- virtual const std::vector<std::pair<int, int>>& local_logical_device_ids()
+ // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
+ // case on multi-host platforms. If there are 4 replicas and 2 partitions on a
+ // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
+ struct LogicalDeviceIds {
+ int replica;
+ int partition;
+ };
+ virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
const = 0;
- // local_devices()[i] is the Device to which local_logical_device_ids()[i] is
- // assigned.
- virtual const std::vector<PjRtDevice*>& local_devices() const = 0;
+ // addressable_devices()[i] is the Device to which
+ // addressable_device_logical_ids()[i] is assigned.
+ virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
// Return an HloModule (optimized) per partition.
virtual StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
const = 0;
- // Execute on replica 0 and partition 0 with the requirement that there's a
- // single replica and partition.
- // TODO(zhangqiaorjc): Merge with ExecuteOnLocalDevice. Remove "local".
- virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> Execute(
- absl::Span<PjRtBuffer* const> argument_handles,
+ // Executes on devices addressable by the client. Requires executable has a
+ // device_assignment and all devices in the device_assignment are addressable
+ // by the client.
+ virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
+ Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) const = 0;
+ // Execute the assigned replica/partition on a given `device`. Requires
+ // executable has a device_assignment, `device` is present in the
+ // device_assignment and addressable by the client.
+ virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
+ absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
+ const ExecuteOptions& options) const = 0;
- // Execute on a given local device.
- virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
- ExecuteOnLocalDevice(absl::Span<PjRtBuffer* const> argument_handles,
- PjRtDevice* device,
- const ExecuteOptions& options) const = 0;
- // Execute on local devices. Takes a sequence of argument lists (one argument
- // list per local device) and returns a tuple of results (one result per local
- // device). The number of argument lists must be equal to the local device
- // count.
- virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
- ExecuteOnLocalDevices(
- absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
+ // Execute on a given `device`. Requires `device` to be addressable by client.
+ // Requires executable has exactly 1 replica and 1 partition and no
+ // device_assignment (thus portable).
+ virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
+ absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const = 0;
// Asynchronously free resources after the last execution completes.
@@ -830,8 +832,8 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
std::vector<std::unique_ptr<LocalExecutable>> executables,
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
- std::vector<std::pair<int, int>> local_logical_device_ids,
- std::vector<PjRtDevice*> local_devices, PjRtClient* client);
+ std::vector<LogicalDeviceIds> addressable_device_logical_ids,
+ std::vector<PjRtDevice*> addressable_devices, PjRtClient* client);
~PjRtStreamExecutorExecutable() override = default;
@@ -859,39 +861,34 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
return *device_assignment_;
}
- const std::vector<std::pair<int, int>>& local_logical_device_ids()
+ absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
const override {
- return local_logical_device_ids_;
+ return addressable_device_logical_ids_;
}
- const std::vector<PjRtDevice*>& local_devices() const override {
- return local_devices_;
+ absl::Span<PjRtDevice* const> addressable_devices() const override {
+ return addressable_devices_;
}
// Return an HloModule per partition.
StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
const override;
- StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> Execute(
- absl::Span<PjRtBuffer* const> argument_handles,
+ StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
+ absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) const override;
- StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteOnLocalDevice(
+ StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const override;
- // Execute on local devices. Takes a sequence of argument lists (one argument
- // list per local device) and returns a tuple of results (one result per local
- // device). The number of argument lists must be equal to the local device
- // count.
- StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
- ExecuteOnLocalDevices(
- absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
+ StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
+ absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const override;
void Delete() override { executables_.clear(); }
- const std::vector<std::shared_ptr<LocalExecutable>>& executables() const {
+ absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
return executables_;
}
@@ -950,17 +947,15 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
// The replica and partition indices of device_assignment_ to be run by this
// client. On single-host platforms without partitioning, this is all replicas
- // (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
- // on multi-host platforms.
- // If there are 4 replicas and 2 partitions on a single host platform, size of
- // local_logical_device_ids_ is 4*2 = 8.
- std::vector<std::pair<int, int>> local_logical_device_ids_;
+ // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
+ // case on multi-host platforms. If there are 4 replicas and 2 partitions on a
+ // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
+ std::vector<LogicalDeviceIds> addressable_device_logical_ids_;
- // local_devices_[i] is the Device to which local_logical_device_ids_[i] is
- // assigned.
- // shared_ptrs instead of unique_ptrs to play well with the Python bindings
- // (see xla.cc).
- std::vector<PjRtDevice*> local_devices_;
+ // addressable_devices_[i] is the Device to which
+ // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
+ // unique_ptrs to play well with the Python bindings (see xla.cc).
+ std::vector<PjRtDevice*> addressable_devices_;
};
// Executables can donate buffers so that buffers can be aliased from inputs
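
Note: the two accessors form parallel arrays: addressable_devices()[i] is the device that runs the logical coordinates in addressable_device_logical_ids()[i]. A small sketch of walking them together:

// Sketch (illustrative): pair each addressable device with its
// (replica, partition) coordinates.
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
void LogDeviceAssignments(const xla::PjRtExecutable& executable) {
  absl::Span<const xla::PjRtExecutable::LogicalDeviceIds> ids =
      executable.addressable_device_logical_ids();
  absl::Span<xla::PjRtDevice* const> devices =
      executable.addressable_devices();
  CHECK_EQ(ids.size(), devices.size());
  for (int i = 0; i < ids.size(); ++i) {
    VLOG(1) << devices[i]->DebugString() << " runs replica "
            << ids[i].replica << ", partition " << ids[i].partition;
  }
}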

tensorflow/compiler/xla/python/jax_jit.cc

@@ -781,7 +781,7 @@ CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args,
cache_entry->executable = std::move(executable);
int num_devices =
- cache_entry->executable->pjrt_executable().local_devices().size();
+ cache_entry->executable->pjrt_executable().addressable_devices().size();
// The presence of jit(pmap) is detected from Python.
CHECK_EQ(num_devices, 1);

tensorflow/compiler/xla/python/outfeed_receiver.cc

@@ -413,8 +413,9 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
devices_[device_idx]->client()->Compile(
computation, std::move(compile_options)));
ExecuteOptions execute_options;
- TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
- executable->Execute({}, execute_options));
+ TF_ASSIGN_OR_RETURN(
+ std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers,
+ executable->Execute({{}}, execute_options));
return Status::OK();
}
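
Note: the doubled braces in Execute({{}}, execute_options) pass a single empty argument list, i.e. one addressable device with zero arguments, matching the new per-device signature; the output type gains the corresponding per-device nesting.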

tensorflow/compiler/xla/python/outfeed_receiver_test.cc

@@ -43,8 +43,9 @@ Status CompileAndExecute(XlaBuilder* builder, XlaOp root, int device_id,
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
client->Compile(computation, std::move(compile_options)));
ExecuteOptions execute_options;
- TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
- executable->Execute({}, execute_options));
+ TF_ASSIGN_OR_RETURN(
+ std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers,
+ executable->Execute({{}}, execute_options));
return Status::OK();
}

tensorflow/compiler/xla/python/py_executable.cc

@@ -58,26 +58,29 @@ PyExecutable::~PyExecutable() {
}
}
- std::vector<ClientAndPtr<PjRtDevice>> PyExecutable::LocalDevices() const {
+ std::vector<ClientAndPtr<PjRtDevice>> PyExecutable::AddressableDevices() const {
std::vector<ClientAndPtr<PjRtDevice>> devices;
- devices.reserve(executable_->local_devices().size());
- for (PjRtDevice* device : executable_->local_devices()) {
+ devices.reserve(executable_->addressable_devices().size());
+ for (PjRtDevice* device : executable_->addressable_devices()) {
devices.push_back(WrapWithClient(client_, device));
}
return devices;
}
+ // Used by JAX JIT which has C++ PjRtBuffers as inputs (Numpy to PjRtBuffer is
+ // faster and simpler than Numpy to PyBuffer to PjRtBuffer) and requires
+ // PyBuffer as outputs as it will return to Python.
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::PjRtExecute(
- absl::Span<PjRtBuffer* const> args) {
- std::vector<std::unique_ptr<PjRtBuffer>> output_buffers;
+ const std::vector<PjRtBuffer*>& args) {
+ std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
{
py::gil_scoped_release gil_release;
- TF_ASSIGN_OR_RETURN(output_buffers, executable_->Execute(args, options_));
+ TF_ASSIGN_OR_RETURN(output_buffers, executable_->Execute({args}, options_));
}
auto traceback = Traceback::Get();
std::vector<std::unique_ptr<PyBuffer>> outputs;
- outputs.reserve(output_buffers.size());
- for (auto& buffer : output_buffers) {
+ outputs.reserve(output_buffers[0].size());
+ for (auto& buffer : output_buffers[0]) {
outputs.push_back(
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
}
@@ -86,19 +89,19 @@ StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::PjRtExecute(
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
absl::Span<PyBuffer* const> args) {
- std::vector<std::unique_ptr<PjRtBuffer>> output_buffers;
+ std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
{
py::gil_scoped_release gil_release;
std::vector<PjRtBuffer*> arg_buffers(args.size());
absl::c_transform(args, arg_buffers.begin(),
[](PyBuffer* buf) { return buf->buffer(); });
TF_ASSIGN_OR_RETURN(output_buffers,
- executable_->Execute(arg_buffers, options_));
+ executable_->Execute({arg_buffers}, options_));
}
auto traceback = Traceback::Get();
std::vector<std::unique_ptr<PyBuffer>> outputs;
- outputs.reserve(output_buffers.size());
- for (auto& buffer : output_buffers) {
+ outputs.reserve(output_buffers[0].size());
+ for (auto& buffer : output_buffers[0]) {
outputs.push_back(
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
}
@@ -117,8 +120,8 @@ PyExecutable::ExecuteOnLocalDevices(
absl::c_transform(args[computation], arg_buffers[computation].begin(),
[](PyBuffer* buf) { return buf->buffer(); });
}
- TF_ASSIGN_OR_RETURN(output_buffers, executable_->ExecuteOnLocalDevices(
- arg_buffers, options_));
+ TF_ASSIGN_OR_RETURN(output_buffers,
+ executable_->Execute(arg_buffers, options_));
}
auto traceback = Traceback::Get();
std::vector<std::vector<std::unique_ptr<PyBuffer>>> outputs;
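
Note: the single-device Python paths (PjRtExecute and Execute) wrap their one argument list as {args} and unwrap output_buffers[0], since exactly one addressable device is involved; ExecuteOnLocalDevices already holds per-device argument lists and forwards them to the renamed Execute unchanged.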

tensorflow/compiler/xla/python/py_executable.h

@@ -43,11 +43,12 @@ class PyExecutable {
std::shared_ptr<PyClient> client() const { return client_; }
- const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
- return executable_->local_logical_device_ids();
+ absl::Span<const PjRtExecutable::LogicalDeviceIds>
+ addressable_device_logical_ids() const {
+ return executable_->addressable_device_logical_ids();
}
- std::vector<ClientAndPtr<PjRtDevice>> LocalDevices() const;
+ std::vector<ClientAndPtr<PjRtDevice>> AddressableDevices() const;
int64 SizeOfGeneratedCodeInBytes() const {
return executable_->SizeOfGeneratedCodeInBytes();
@@ -60,7 +61,7 @@ class PyExecutable {
// Same as above, but take as inputs `PjRtBuffer*`. Only targets C++ code.
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PjRtExecute(
- absl::Span<PjRtBuffer* const> args);
+ const std::vector<PjRtBuffer*>& args);
StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
ExecuteOnLocalDevices(absl::Span<const std::vector<PyBuffer*>> args);

tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc

@@ -531,7 +531,7 @@ PyTpuExecutable::PyTpuExecutable(
<< "Inserting duplicate replica:" << replica;
executables_[replica] =
client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
- local_logical_device_ids_.emplace_back(replica, partition);
+ addressable_device_logical_ids_.emplace_back(replica, partition);
local_devices_.push_back(device);
}
}
@@ -711,8 +711,8 @@ PyTpuExecutable::ExecuteOnLocalDevices(
// long time and we want all cores to be scheduled in parallel.
thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
&execute_semaphore]() {
- const int replica = local_logical_device_ids_[i].first;
- const int partition = local_logical_device_ids_[i].second;
+ const int replica = addressable_device_logical_ids_[i].first;
+ const int partition = addressable_device_logical_ids_[i].second;
RunId run_id;
auto result = ExecuteHelper(argument_handles, argument_handles[i],
replica, partition, run_id);

tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h

@@ -298,8 +298,9 @@ class PyTpuExecutable {
return device_assignment_;
}
- const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
- return local_logical_device_ids_;
+ const std::vector<std::pair<int, int>>& addressable_device_logical_ids()
+ const {
+ return addressable_device_logical_ids_;
}
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
@@ -340,16 +341,14 @@ class PyTpuExecutable {
// The replica and partition indices of device_assignment_ to be run by this
// client. On single-host platforms without partitioning, this is all replicas
- // (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
- // on multi-host platforms.
- // If there are 4 replicas and 2 partitions on a single host platform, size of
- // local_logical_device_ids_ is 4*2 = 8.
- std::vector<std::pair<int, int>> local_logical_device_ids_;
+ // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
+ // case on multi-host platforms. If there are 4 replicas and 2 partitions on a
+ // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
+ std::vector<std::pair<int, int>> addressable_device_logical_ids_;
- // local_devices_[i] is the Device to which local_logical_device_ids_[i] is
- // assigned.
- // shared_ptrs instead of unique_ptrs to play well with the Python bindings
- // (see xla.cc).
+ // local_devices_[i] is the Device to which addressable_device_logical_ids_[i]
+ // is assigned. shared_ptrs instead of unique_ptrs to play well with the
+ // Python bindings (see xla.cc).
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
xla::Shape result_shape_;

tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc

@@ -186,7 +186,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
py::class_<PyTpuExecutable>(m, "TpuExecutable")
.def("local_logical_device_ids",
- &PyTpuExecutable::local_logical_device_ids)
+ &PyTpuExecutable::addressable_device_logical_ids)
.def("local_devices", &PyTpuExecutable::local_devices)
.def_property_readonly("client", &PyTpuExecutable::client)
.def("size_of_generated_code_in_bytes",

tensorflow/compiler/xla/python/xla.cc

@@ -377,8 +377,18 @@ PYBIND11_MODULE(xla_extension, m) {
py::class_<PyExecutable, std::shared_ptr<PyExecutable>> executable(
m, "Executable");
executable.def_property_readonly("client", &PyExecutable::client)
- .def("local_logical_device_ids", &PyExecutable::local_logical_device_ids)
- .def("local_devices", &PyExecutable::LocalDevices)
+ .def("local_logical_device_ids",
+ [](PyExecutable* exec) {
+ auto span = exec->addressable_device_logical_ids();
+ // Not on dispatch critical path, so ok to have heap allocation.
+ std::vector<std::pair<int, int>> addressable_device_logical_ids;
+ addressable_device_logical_ids.reserve(span.size());
+ for (const auto& logical_device_id : span) {
+ addressable_device_logical_ids.push_back(std::make_pair(
+ logical_device_id.replica, logical_device_id.partition));
+ }
+ return addressable_device_logical_ids;
+ })
+ .def("local_devices", &PyExecutable::AddressableDevices)
.def("size_of_generated_code_in_bytes",
&PyExecutable::SizeOfGeneratedCodeInBytes)
.def("delete", &PyExecutable::Delete)