[XLA:Python] Add a PyClient wrapper object around PjRtClient.

Refactoring in preparation for adding Python-specific logic to clients, but it also is more readable than inlining this kind of logic into the binding code.

PiperOrigin-RevId: 314994078
Change-Id: I3b7253d4a14ff418068e49f9bd321d85695f9c8b
This commit is contained in:
Peter Hawkins 2020-06-05 14:12:51 -07:00 committed by TensorFlower Gardener
parent e96c412005
commit 3f7adb0d48
13 changed files with 355 additions and 239 deletions

View File

@ -180,9 +180,17 @@ py_test(
)
cc_library(
name = "py_buffer",
srcs = ["py_buffer.cc"],
hdrs = ["py_buffer.h"],
name = "py_client",
srcs = [
"py_buffer.cc",
"py_client.cc",
"py_executable.cc",
],
hdrs = [
"py_buffer.h",
"py_client.h",
"py_executable.h",
],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -195,29 +203,10 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "py_executable",
srcs = ["py_executable.cc"],
hdrs = ["py_executable.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":py_buffer",
":traceback_manager",
":types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@pybind11",
],
)
@ -231,7 +220,7 @@ cc_library(
],
features = ["-use_header_modules"],
deps = [
":py_buffer",
":py_client",
":traceback_manager",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
@ -325,6 +314,7 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":outfeed_receiver",
":py_client",
":types",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
@ -355,8 +345,7 @@ pybind_extension(
":bfloat16",
":dlpack",
":ops",
":py_buffer",
":py_executable",
":py_client",
":python_ref_manager",
":outfeed_receiver_py",
":traceback_manager",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "tensorflow/compiler/xla/types.h"
@ -298,7 +299,7 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(PyBuffer* buffer) {
}
StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
const pybind11::capsule& tensor, std::shared_ptr<PjRtClient> client) {
const pybind11::capsule& tensor, std::shared_ptr<PyClient> client) {
if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
return InvalidArgument(
"DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
@ -311,8 +312,9 @@ StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
"Number of dimensions in DLManagedTensor must be nonnegative, got %d",
dlmt->dl_tensor.ndim);
}
TF_ASSIGN_OR_RETURN(Device * device,
DeviceForDLContext(*client, dlmt->dl_tensor.ctx));
TF_ASSIGN_OR_RETURN(
Device * device,
DeviceForDLContext(*client->pjrt_client(), dlmt->dl_tensor.ctx));
absl::Span<int64 const> dimensions(
reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
@ -349,7 +351,7 @@ StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
PyCapsule_SetName(tensor.ptr(), "used_dltensor");
PyCapsule_SetDestructor(tensor.ptr(), nullptr);
auto pjrt_buffer = std::make_unique<PjRtBuffer>(
shape, shape, std::move(device_buffer), client.get(), device);
shape, shape, std::move(device_buffer), client->pjrt_client(), device);
return std::make_unique<PyBuffer>(std::move(client), std::move(pjrt_buffer),
TracebackManager::Get()->GetTraceback());
}

View File

@ -17,15 +17,15 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_client.h"
namespace xla {
StatusOr<pybind11::capsule> BufferToDLPackManagedTensor(PyBuffer* buffer);
StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
const pybind11::capsule& tensor, std::shared_ptr<PjRtClient> client);
const pybind11::capsule& tensor, std::shared_ptr<PyClient> client);
} // namespace xla

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/outfeed_receiver.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/types.h"
namespace xla {
@ -42,7 +43,7 @@ class OutfeedReceiverForPython {
std::function<void(ClientAndPtr<Device>, uint32_t, pybind11::object)>;
OutfeedReceiverForPython(CallbackToPython callback_python,
std::vector<std::shared_ptr<PjRtClient>> clients,
std::vector<std::shared_ptr<PyClient>> clients,
ssize_t max_callback_queue_size_bytes)
: callback_python_(std::move(callback_python)),
clients_(std::move(clients)) {
@ -52,9 +53,10 @@ class OutfeedReceiverForPython {
this->Callback(device, consumer_id, std::move(literal));
};
std::vector<PjRtClient*> client_ptrs(clients.size());
absl::c_transform(
clients_, client_ptrs.begin(),
[](const std::shared_ptr<PjRtClient>& client) { return client.get(); });
absl::c_transform(clients_, client_ptrs.begin(),
[](const std::shared_ptr<PyClient>& client) {
return client->pjrt_client();
});
outfeed_receiver_ = absl::make_unique<OutfeedReceiver>(
callback, client_ptrs, max_callback_queue_size_bytes);
}
@ -95,8 +97,8 @@ class OutfeedReceiverForPython {
}
// We expect the number of clients to be small, so an O(n) search is fine.
auto it = absl::c_find_if(
clients_, [device](const std::shared_ptr<PjRtClient>& client) {
return client.get() == device->client();
clients_, [device](const std::shared_ptr<PyClient>& client) {
return client->pjrt_client() == device->client();
});
CHECK(it != clients_.end());
py::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython
@ -112,7 +114,7 @@ class OutfeedReceiverForPython {
CallbackToPython callback_python_;
absl::Mutex mu_;
bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_) = false;
std::vector<std::shared_ptr<PjRtClient>> clients_;
std::vector<std::shared_ptr<PyClient>> clients_;
std::unique_ptr<OutfeedReceiver> outfeed_receiver_;
};
@ -124,7 +126,7 @@ void BuildOutfeedReceiverSubmodule(py::module* m) {
outfeed_receiver.def(
"start",
[](OutfeedReceiverForPython::CallbackToPython callback_to_python,
std::vector<std::shared_ptr<PjRtClient>> clients,
std::vector<std::shared_ptr<PyClient>> clients,
ssize_t max_callback_queue_size_bytes)
-> std::unique_ptr<OutfeedReceiverForPython> {
auto server = absl::make_unique<OutfeedReceiverForPython>(

View File

@ -15,13 +15,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
namespace xla {
namespace py = pybind11;
PyBuffer::PyBuffer(std::shared_ptr<PjRtClient> client,
PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
std::unique_ptr<PjRtBuffer> buffer,
absl::optional<TracebackManager::Traceback> traceback)
: client_(std::move(client)),

View File

@ -20,9 +20,8 @@ limitations under the License.
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@ -33,11 +32,10 @@ namespace xla {
// b) to add Python-specific functionality.
class PyBuffer {
public:
PyBuffer(std::shared_ptr<PjRtClient> client,
std::unique_ptr<PjRtBuffer> buffer,
PyBuffer(std::shared_ptr<PyClient> client, std::unique_ptr<PjRtBuffer> buffer,
absl::optional<TracebackManager::Traceback> traceback);
std::shared_ptr<PjRtClient> client() const { return client_; }
std::shared_ptr<PyClient> client() const { return client_; }
PjRtBuffer* buffer() const { return buffer_.get(); }
ClientAndPtr<Device> device() const;
@ -68,7 +66,7 @@ class PyBuffer {
}
private:
std::shared_ptr<PjRtClient> client_;
std::shared_ptr<PyClient> client_;
std::unique_ptr<PjRtBuffer> buffer_;
absl::optional<TracebackManager::Traceback> traceback_;
};

View File

@ -0,0 +1,130 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
namespace xla {
namespace py = pybind11;
PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
: pjrt_client_(std::move(pjrt_client)) {}
std::vector<ClientAndPtr<Device>> PyClient::Devices() {
std::vector<ClientAndPtr<Device>> devices;
devices.reserve(pjrt_client_->devices().size());
for (const auto& device : pjrt_client_->devices()) {
devices.push_back(WrapWithClient(shared_from_this(), device.get()));
}
return devices;
}
std::vector<ClientAndPtr<Device>> PyClient::LocalDevices() {
std::vector<ClientAndPtr<Device>> devices;
devices.reserve(pjrt_client_->local_devices().size());
for (Device* device : pjrt_client_->local_devices()) {
devices.push_back(WrapWithClient(shared_from_this(), device));
}
return devices;
}
StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
TF_ASSIGN_OR_RETURN(
DeviceAssignment device_assignment,
pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions));
std::vector<std::vector<ClientAndPtr<Device>>> result;
result.resize(num_replicas);
for (int r = 0; r < num_replicas; ++r) {
result[r].resize(num_partitions);
for (int p = 0; p < num_partitions; ++p) {
int device_id = device_assignment(r, p);
auto iter = pjrt_client_->id_to_device().find(device_id);
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
result[r][p] = WrapWithClient(shared_from_this(), iter->second);
}
}
return result;
}
StatusOr<std::vector<ClientAndPtr<Device>>>
PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
pjrt_client_->GetDefaultDeviceAssignment(
num_replicas, /*num_partitions=*/1));
std::vector<ClientAndPtr<Device>> result;
for (int i = 0; i < num_replicas; ++i) {
int device_id = device_assignment(i, 0);
auto iter = pjrt_client_->id_to_device().find(device_id);
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
result.push_back(WrapWithClient(shared_from_this(), iter->second));
}
return result;
}
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
const pybind11::object& argument, Device* device, bool force_copy) {
if (device == nullptr) {
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
device = pjrt_client_->local_devices().front();
}
CHECK(device != nullptr);
auto iter = pjrt_client_->id_to_device().find(device->id());
if (iter->second != device) {
return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
device->DebugString(),
pjrt_client_->platform_name());
}
GlobalPyRefManager()->CollectGarbage();
absl::optional<CastToArrayResult> c = CastToArray(argument);
if (!c) {
return InvalidArgument("from_python argument must be an array.");
}
TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
GlobalPyRefManager()->ManageReference(std::move(c->array));
auto traceback = TracebackManager::Get()->GetTraceback();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> buffer,
PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
std::move(py_buffer_ref), pjrt_client_.get(),
device));
return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
traceback);
}
StatusOr<std::unique_ptr<PyExecutable>> PyClient::Compile(
const XlaComputation& computation, CompileOptions options) {
auto traceback = TracebackManager::Get()->GetTraceback();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
PjRtExecutable::Compile(computation, pjrt_client_.get(),
std::move(options)));
return std::make_unique<PyExecutable>(
shared_from_this(), std::move(executable), std::move(traceback));
}
} // namespace xla

View File

@ -0,0 +1,136 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
#include <memory>
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
class PyBuffer;
class PyClient;
class PyExecutable;
// Custom holder types.
//
// We must keep the PyClient object alive as long as any of the runtime
// objects are alive. Since we don't have a lot of control over Python
// destructor ordering, we keep the PyClient object as a std::shared_ptr<>,
// and ensure that each Python runtime object holds a reference to the
// PyClient. An alternative design would be to keep a single global
// singleton PyClient, although this seems less flexible, especially for
// writing tests.
//
// To maintain PyClient references, we define pybind11 holder classes that
// are custom smart pointers that also keep a reference to a PyClient.
// pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't
// seem sufficiently flexible to describe ownership relationships in cases where
// the ownership doesn't pertain to a direct argument or return value of a
// function. Another alternative to the holder classes would be to create proxy
// objects that contain both a reference and a runtime class; holder classes
// seem less tedious to define.
// A pair of a PyClient reference and an unowned pointer to T.
template <typename T>
struct ClientAndPtr {
ClientAndPtr() = default;
// pybind11 requires that we define a constructor that takes a raw pointer,
// but it should be unreachable.
explicit ClientAndPtr(T*) {
LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient.";
}
ClientAndPtr(const ClientAndPtr&) = default;
ClientAndPtr(ClientAndPtr&&) = default;
ClientAndPtr& operator=(const ClientAndPtr&) = default;
ClientAndPtr& operator=(ClientAndPtr&&) = default;
std::shared_ptr<PyClient> client;
T* contents;
T* get() const { return contents; }
T* operator->() const { return contents; }
T& operator*() const { return *contents; }
};
// By defining a templated helper function, we can use return type deduction
// and avoid specifying types at the caller.
template <typename T>
ClientAndPtr<T> WrapWithClient(std::shared_ptr<PyClient> client, T* contents) {
ClientAndPtr<T> result;
result.client = std::move(client);
result.contents = contents;
return result;
}
// Python wrapper around PjRtClient.
// We use a wrapper class to add Python-specific functionality.
class PyClient : public std::enable_shared_from_this<PyClient> {
public:
explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
const std::string& platform_name() const {
return pjrt_client_->platform_name();
}
int local_device_count() const { return pjrt_client_->local_device_count(); }
int device_count() const { return pjrt_client_->device_count(); }
int host_id() const { return pjrt_client_->host_id(); }
std::vector<ClientAndPtr<Device>> Devices();
std::vector<ClientAndPtr<Device>> LocalDevices();
StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
GetDefaultDeviceAssignment(int num_replicas, int num_partitions);
// TODO(skye): delete after all callers can handle 2D output
StatusOr<std::vector<ClientAndPtr<Device>>> GetDefaultDeviceAssignment1D(
int num_replicas);
StatusOr<ChannelHandle> CreateChannelHandle() {
return pjrt_client_->client()->CreateChannelHandle();
}
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
return pjrt_client_->client()->CreateDeviceToHostChannelHandle();
}
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
return pjrt_client_->client()->CreateHostToDeviceChannelHandle();
}
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyal(
const pybind11::object& argument, Device* device, bool force_copy);
StatusOr<std::unique_ptr<PyExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);
private:
std::shared_ptr<PjRtClient> pjrt_client_;
};
} // namespace xla
PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_

View File

@ -22,7 +22,7 @@ namespace xla {
namespace py = pybind11;
PyExecutable::PyExecutable(
std::shared_ptr<PjRtClient> client,
std::shared_ptr<PyClient> client,
std::unique_ptr<PjRtExecutable> executable,
absl::optional<TracebackManager::Traceback> traceback)
: client_(std::move(client)),

View File

@ -24,23 +24,23 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
// Python wrapper around PjRtExecutable. We use a wrapper class:
// a) to keep the PjRtClient alive via a std::shared_ptr<>
// a) to keep the PyClient alive via a std::shared_ptr<>
// b) to add Python-specific functionality.
class PyExecutable {
public:
PyExecutable(std::shared_ptr<PjRtClient> client,
PyExecutable(std::shared_ptr<PyClient> client,
std::unique_ptr<PjRtExecutable> executable,
absl::optional<TracebackManager::Traceback> traceback);
std::shared_ptr<PjRtClient> client() const { return client_; }
std::shared_ptr<PyClient> client() const { return client_; }
const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
return executable_->local_logical_device_ids();
@ -67,7 +67,7 @@ class PyExecutable {
}
private:
std::shared_ptr<PjRtClient> client_;
std::shared_ptr<PyClient> client_;
std::unique_ptr<PjRtExecutable> executable_;
absl::optional<TracebackManager::Traceback> traceback_;
};

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -36,65 +35,6 @@ limitations under the License.
namespace xla {
// Custom holder types.
//
// We must keep the PjRtClient object alive as long as any of the runtime
// objects are alive. Since we don't have a lot of control over Python
// destructor ordering, we keep the PjRtClient object as a std::shared_ptr<>,
// and ensure that each Python runtime object holds a reference to the
// PjRtClient. An alternative design would be to keep a single global
// singleton PjRtClient, although this seems less flexible, especially for
// writing tests.
//
// To maintain PjRtClient references, we define pybind11 holder classes that
// are custom smart pointers that also keep a reference to a PjRtClient.
// pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't
// seem sufficiently flexible to describe ownership relationships in cases where
// the ownership doesn't pertain to a direct argument or return value of a
// function. Another alternative to the holder classes would be to create proxy
// objects that contain both a reference and a runtime class; holder classes
// seem less tedious to define.
// A pair of a PjRtClient reference and an unowned pointer to T.
template <typename T>
struct ClientAndPtr {
ClientAndPtr() = default;
// pybind11 requires that we define a constructor that takes a raw pointer,
// but it should be unreachable.
explicit ClientAndPtr(T*) {
LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient.";
}
ClientAndPtr(const ClientAndPtr&) = default;
ClientAndPtr(ClientAndPtr&&) = default;
ClientAndPtr& operator=(const ClientAndPtr&) = default;
ClientAndPtr& operator=(ClientAndPtr&&) = default;
std::shared_ptr<PjRtClient> client;
T* contents;
T* get() const { return contents; }
T* operator->() const { return contents; }
T& operator*() const { return *contents; }
};
// By defining a templated helper function, we can use return type deduction
// and avoid specifying types at the caller.
template <typename T>
ClientAndPtr<T> WrapWithClient(std::shared_ptr<PjRtClient> client,
T* contents) {
ClientAndPtr<T> result;
result.client = std::move(client);
result.contents = contents;
return result;
}
} // namespace xla
PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);
namespace xla {
// Initializes the NumPy API for the use of the types module.
bool InitializeNumpyAPIForTypes();

View File

@ -501,138 +501,55 @@ PYBIND11_MODULE(xla_extension, m) {
.value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
.value("BFC", GpuAllocatorConfig::Kind::kBFC);
py::class_<PjRtClient, std::shared_ptr<PjRtClient>> py_local_client(
m, "LocalClient");
py_local_client.def_property_readonly("platform", &PjRtClient::platform_name)
.def("device_count", &PjRtClient::device_count)
.def("local_device_count", &PjRtClient::local_device_count)
.def("devices",
[](std::shared_ptr<PjRtClient> client) {
std::vector<ClientAndPtr<Device>> devices;
devices.reserve(client->devices().size());
for (const auto& device : client->devices()) {
devices.push_back(WrapWithClient(client, device.get()));
}
return devices;
})
.def("local_devices",
[](std::shared_ptr<PjRtClient> client) {
std::vector<ClientAndPtr<Device>> devices;
devices.reserve(client->local_devices().size());
for (Device* device : client->local_devices()) {
devices.push_back(WrapWithClient(client, device));
}
return devices;
})
.def("host_id", &PjRtClient::host_id)
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
py_local_client.def_property_readonly("platform", &PyClient::platform_name)
.def("device_count", &PyClient::device_count)
.def("local_device_count", &PyClient::local_device_count)
.def("devices", &PyClient::Devices)
.def("local_devices", &PyClient::LocalDevices)
.def("host_id", &PyClient::host_id)
.def("get_default_device_assignment",
[](std::shared_ptr<PjRtClient> client, int num_replicas,
int num_partitions)
-> StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>> {
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
client->GetDefaultDeviceAssignment(
num_replicas, num_partitions));
std::vector<std::vector<ClientAndPtr<Device>>> result;
result.resize(num_replicas);
for (int r = 0; r < num_replicas; ++r) {
result[r].resize(num_partitions);
for (int p = 0; p < num_partitions; ++p) {
int device_id = device_assignment(r, p);
auto iter = client->id_to_device().find(device_id);
CHECK(iter != client->id_to_device().end()) << device_id;
result[r][p] = WrapWithClient(client, iter->second);
}
}
return result;
})
&PyClient::GetDefaultDeviceAssignment)
// TODO(skye): delete after all callers can handle 2D output
.def("get_default_device_assignment",
[](std::shared_ptr<PjRtClient> client,
int num_replicas) -> StatusOr<std::vector<ClientAndPtr<Device>>> {
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
client->GetDefaultDeviceAssignment(
num_replicas, /*num_partitions=*/1));
std::vector<ClientAndPtr<Device>> result;
for (int i = 0; i < num_replicas; ++i) {
int device_id = device_assignment(i, 0);
auto iter = client->id_to_device().find(device_id);
CHECK(iter != client->id_to_device().end()) << device_id;
result.push_back(WrapWithClient(client, iter->second));
}
return result;
})
.def("create_channel_handle",
[](PjRtClient* client) {
return client->client()->CreateChannelHandle();
})
&PyClient::GetDefaultDeviceAssignment1D)
.def("create_channel_handle", &PyClient::CreateChannelHandle)
.def("create_device_to_host_channel_handle",
[](PjRtClient* client) {
return client->client()->CreateDeviceToHostChannelHandle();
})
.def("create_host_to_device_channel_handle", [](PjRtClient* client) {
return client->client()->CreateHostToDeviceChannelHandle();
});
py_local_client.def(
"buffer_from_pyval",
[](std::shared_ptr<PjRtClient> client, const pybind11::object& argument,
Device* device,
bool force_copy) -> StatusOr<std::unique_ptr<PyBuffer>> {
if (device == nullptr) {
TF_RET_CHECK(!client->local_devices().empty());
device = client->local_devices().front();
}
CHECK(device != nullptr);
auto iter = client->id_to_device().find(device->id());
if (iter->second != device) {
return InvalidArgument(
"Cannot copy value to device '%s' with '%s' backend",
device->DebugString(), client->platform_name());
}
GlobalPyRefManager()->CollectGarbage();
&PyClient::CreateDeviceToHostChannelHandle)
.def("create_host_to_device_channel_handle",
&PyClient::CreateHostToDeviceChannelHandle)
.def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"),
py::arg("device") = nullptr, py::arg("force_copy") = false)
.def("compile", &PyClient::Compile, py::arg("computation"),
py::arg("compile_options") = CompileOptions());
absl::optional<CastToArrayResult> c = CastToArray(argument);
if (!c) {
return InvalidArgument("from_python argument must be an array.");
}
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
GetPythonBufferTree(argument));
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
GlobalPyRefManager()->ManageReference(std::move(c->array));
auto traceback = TracebackManager::Get()->GetTraceback();
py::gil_scoped_release gil_release;
m.def(
"get_cpu_client",
[](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client,
GetCpuClient(asynchronous));
return std::make_shared<PyClient>(std::move(client));
},
py::arg("asynchronous") = true);
m.def("get_interpreter_client", []() -> StatusOr<std::shared_ptr<PyClient>> {
TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client,
GetInterpreterClient());
return std::make_shared<PyClient>(std::move(client));
});
m.def(
"get_nvidia_gpu_client",
[](bool asynchronous, const GpuAllocatorConfig& allocator_config,
std::shared_ptr<DistributedRuntimeClient> distributed_client,
int node_id) -> StatusOr<std::shared_ptr<PyClient>> {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> buffer,
PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
std::move(py_buffer_ref), client.get(),
device));
return std::make_unique<PyBuffer>(std::move(client), std::move(buffer),
traceback);
std::shared_ptr<PjRtClient> client,
GetNvidiaGpuClient(asynchronous, allocator_config,
std::move(distributed_client), node_id));
return std::make_shared<PyClient>(std::move(client));
},
py::arg("argument"), py::arg("device") = nullptr,
py::arg("force_copy") = false);
py_local_client.def(
"compile",
[](std::shared_ptr<PjRtClient> client, const XlaComputation& computation,
CompileOptions options) -> StatusOr<std::unique_ptr<PyExecutable>> {
auto traceback = TracebackManager::Get()->GetTraceback();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
PjRtExecutable::Compile(computation, client.get(),
std::move(options)));
return std::make_unique<PyExecutable>(
std::move(client), std::move(executable), std::move(traceback));
},
py::arg("computation"), py::arg("compile_options") = CompileOptions());
m.def("get_cpu_client", &GetCpuClient, py::arg("asynchronous") = true);
m.def("get_interpreter_client", &GetInterpreterClient);
m.def("get_nvidia_gpu_client", &GetNvidiaGpuClient,
py::arg("asynchronous") = true,
py::arg("allocator_config") = GpuAllocatorConfig(),
py::arg("distributed_client") = nullptr, py::arg("node_id") = 0);
py::arg("asynchronous") = true,
py::arg("allocator_config") = GpuAllocatorConfig(),
py::arg("distributed_client") = nullptr, py::arg("node_id") = 0);
py::class_<TracebackManager::Frame>(m, "Frame")
.def_readonly("file_name", &TracebackManager::Frame::file_name)

View File

@ -408,7 +408,7 @@ def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
XlaBuilder = _xla.XlaBuilder
XlaComputation = _xla.XlaComputation
FftType = _xla.FftType
Client = _xla.LocalClient
Client = _xla.Client
Buffer = _xla.Buffer
Executable = _xla.Executable