[XLA:Python] Add a PyClient wrapper object around PjRtClient.
Refactoring in preparation for adding Python-specific logic to clients; it is also more readable than inlining this kind of logic into the binding code.

PiperOrigin-RevId: 314994078
Change-Id: I3b7253d4a14ff418068e49f9bd321d85695f9c8b
Parent: e96c412005
Commit: 3f7adb0d48
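For orientation before the diff: the core idea of the commit is that a single std::shared_ptr-owned PyClient wraps the PjRtClient, and every Python-visible runtime object (device, buffer, executable) carries a shared_ptr back to that PyClient so the runtime cannot be destroyed out from under it. The sketch below is not part of the commit; it is a minimal, self-contained C++ illustration of that ownership pattern, using a FakePjRtClient stand-in instead of the real PjRt types and mirroring the ClientAndPtr/WrapWithClient helpers added in py_client.h further down.

    // Standalone sketch of the PyClient ownership pattern (not from the commit).
    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    // Stand-ins for the real PjRtClient/Device, just so this sketch compiles.
    struct Device {
      std::string name = "cpu:0";
    };
    struct FakePjRtClient {
      FakePjRtClient() { devices.push_back(std::make_unique<Device>()); }
      std::vector<std::unique_ptr<Device>> devices;
    };

    class PyClient;

    // A pair of a PyClient reference and an unowned pointer, mirroring
    // ClientAndPtr<T> from py_client.h.
    template <typename T>
    struct ClientAndPtr {
      std::shared_ptr<PyClient> client;
      T* contents = nullptr;
      T* get() const { return contents; }
      T* operator->() const { return contents; }
    };

    template <typename T>
    ClientAndPtr<T> WrapWithClient(std::shared_ptr<PyClient> client, T* contents) {
      ClientAndPtr<T> result;
      result.client = std::move(client);
      result.contents = contents;
      return result;
    }

    // The wrapper: owns the underlying client and hands out devices paired
    // with a reference back to itself.
    class PyClient : public std::enable_shared_from_this<PyClient> {
     public:
      explicit PyClient(std::shared_ptr<FakePjRtClient> pjrt_client)
          : pjrt_client_(std::move(pjrt_client)) {}

      std::vector<ClientAndPtr<Device>> Devices() {
        std::vector<ClientAndPtr<Device>> out;
        for (const auto& d : pjrt_client_->devices) {
          out.push_back(WrapWithClient(shared_from_this(), d.get()));
        }
        return out;
      }

     private:
      std::shared_ptr<FakePjRtClient> pjrt_client_;
    };

    int main() {
      // PyClient must be shared_ptr-owned so shared_from_this() works.
      auto client = std::make_shared<PyClient>(std::make_shared<FakePjRtClient>());
      ClientAndPtr<Device> device = client->Devices().front();
      // Dropping our handle does not destroy the client: `device` still holds it.
      client.reset();
      std::cout << device->name << "\n";
      return 0;
    }

In the real bindings, the same pattern is what keeps a Buffer or Executable returned to Python from outliving its client.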
tensorflow/compiler/xla/python/BUILD

@@ -180,9 +180,17 @@ py_test(
 )

 cc_library(
-    name = "py_buffer",
-    srcs = ["py_buffer.cc"],
-    hdrs = ["py_buffer.h"],
+    name = "py_client",
+    srcs = [
+        "py_buffer.cc",
+        "py_client.cc",
+        "py_executable.cc",
+    ],
+    hdrs = [
+        "py_buffer.h",
+        "py_client.h",
+        "py_executable.h",
+    ],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",

@@ -195,29 +203,10 @@ cc_library(
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla/pjrt:pjrt_client",
        "@com_google_absl//absl/types:optional",
    ],
)

-cc_library(
-    name = "py_executable",
-    srcs = ["py_executable.cc"],
-    hdrs = ["py_executable.h"],
-    copts = [
-        "-fexceptions",
-        "-fno-strict-aliasing",
-    ],
-    features = ["-use_header_modules"],
-    deps = [
-        ":py_buffer",
-        ":traceback_manager",
-        ":types",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:types",
-        "//tensorflow/compiler/xla/pjrt:pjrt_client",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/types:optional",
-        "@com_google_absl//absl/types:span",
-        "@pybind11",
-    ],
-)

@@ -231,7 +220,7 @@ cc_library(
    ],
    features = ["-use_header_modules"],
    deps = [
-        ":py_buffer",
+        ":py_client",
        ":traceback_manager",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",

@@ -325,6 +314,7 @@ cc_library(
    features = ["-use_header_modules"],
    deps = [
        ":outfeed_receiver",
+        ":py_client",
        ":types",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/pjrt:pjrt_client",

@@ -355,8 +345,7 @@ pybind_extension(
        ":bfloat16",
        ":dlpack",
        ":ops",
-        ":py_buffer",
-        ":py_executable",
+        ":py_client",
        ":python_ref_manager",
        ":outfeed_receiver_py",
        ":traceback_manager",
tensorflow/compiler/xla/python/dlpack.cc

@@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h"  // from @dlpack
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "tensorflow/compiler/xla/types.h"

@@ -298,7 +299,7 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(PyBuffer* buffer) {
}

StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
-    const pybind11::capsule& tensor, std::shared_ptr<PjRtClient> client) {
+    const pybind11::capsule& tensor, std::shared_ptr<PyClient> client) {
  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
    return InvalidArgument(
        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "

@@ -311,8 +312,9 @@ StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
        "Number of dimensions in DLManagedTensor must be nonnegative, got %d",
        dlmt->dl_tensor.ndim);
  }
-  TF_ASSIGN_OR_RETURN(Device * device,
-                      DeviceForDLContext(*client, dlmt->dl_tensor.ctx));
+  TF_ASSIGN_OR_RETURN(
+      Device * device,
+      DeviceForDLContext(*client->pjrt_client(), dlmt->dl_tensor.ctx));
  absl::Span<int64 const> dimensions(
      reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
  TF_ASSIGN_OR_RETURN(PrimitiveType element_type,

@@ -349,7 +351,7 @@ StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
  PyCapsule_SetName(tensor.ptr(), "used_dltensor");
  PyCapsule_SetDestructor(tensor.ptr(), nullptr);
  auto pjrt_buffer = std::make_unique<PjRtBuffer>(
-      shape, shape, std::move(device_buffer), client.get(), device);
+      shape, shape, std::move(device_buffer), client->pjrt_client(), device);
  return std::make_unique<PyBuffer>(std::move(client), std::move(pjrt_buffer),
                                    TracebackManager::Get()->GetTraceback());
}
tensorflow/compiler/xla/python/dlpack.h

@@ -17,15 +17,15 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_

#include "pybind11/pybind11.h"
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
+#include "tensorflow/compiler/xla/python/py_client.h"

namespace xla {

StatusOr<pybind11::capsule> BufferToDLPackManagedTensor(PyBuffer* buffer);

StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
-    const pybind11::capsule& tensor, std::shared_ptr<PjRtClient> client);
+    const pybind11::capsule& tensor, std::shared_ptr<PyClient> client);

}  // namespace xla
tensorflow/compiler/xla/python/outfeed_receiver_py.cc

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/outfeed_receiver.h"
+#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/types.h"

namespace xla {

@@ -42,7 +43,7 @@ class OutfeedReceiverForPython {
      std::function<void(ClientAndPtr<Device>, uint32_t, pybind11::object)>;

  OutfeedReceiverForPython(CallbackToPython callback_python,
-                          std::vector<std::shared_ptr<PjRtClient>> clients,
+                          std::vector<std::shared_ptr<PyClient>> clients,
                           ssize_t max_callback_queue_size_bytes)
      : callback_python_(std::move(callback_python)),
        clients_(std::move(clients)) {

@@ -52,9 +53,10 @@ class OutfeedReceiverForPython {
      this->Callback(device, consumer_id, std::move(literal));
    };
    std::vector<PjRtClient*> client_ptrs(clients.size());
-    absl::c_transform(
-        clients_, client_ptrs.begin(),
-        [](const std::shared_ptr<PjRtClient>& client) { return client.get(); });
+    absl::c_transform(clients_, client_ptrs.begin(),
+                      [](const std::shared_ptr<PyClient>& client) {
+                        return client->pjrt_client();
+                      });
    outfeed_receiver_ = absl::make_unique<OutfeedReceiver>(
        callback, client_ptrs, max_callback_queue_size_bytes);
  }

@@ -95,8 +97,8 @@ class OutfeedReceiverForPython {
    }
    // We expect the number of clients to be small, so an O(n) search is fine.
    auto it = absl::c_find_if(
-        clients_, [device](const std::shared_ptr<PjRtClient>& client) {
-          return client.get() == device->client();
+        clients_, [device](const std::shared_ptr<PyClient>& client) {
+          return client->pjrt_client() == device->client();
        });
    CHECK(it != clients_.end());
    py::gil_scoped_acquire gil_acquire;  // Need GIL also for LiteralToPython

@@ -112,7 +114,7 @@ class OutfeedReceiverForPython {
  CallbackToPython callback_python_;
  absl::Mutex mu_;
  bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_) = false;
-  std::vector<std::shared_ptr<PjRtClient>> clients_;
+  std::vector<std::shared_ptr<PyClient>> clients_;
  std::unique_ptr<OutfeedReceiver> outfeed_receiver_;
};

@@ -124,7 +126,7 @@ void BuildOutfeedReceiverSubmodule(py::module* m) {
  outfeed_receiver.def(
      "start",
      [](OutfeedReceiverForPython::CallbackToPython callback_to_python,
-         std::vector<std::shared_ptr<PjRtClient>> clients,
+         std::vector<std::shared_ptr<PyClient>> clients,
         ssize_t max_callback_queue_size_bytes)
          -> std::unique_ptr<OutfeedReceiverForPython> {
        auto server = absl::make_unique<OutfeedReceiverForPython>(
tensorflow/compiler/xla/python/py_buffer.cc

@@ -15,13 +15,15 @@ limitations under the License.

#include "tensorflow/compiler/xla/python/py_buffer.h"

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/types.h"

namespace xla {

namespace py = pybind11;

-PyBuffer::PyBuffer(std::shared_ptr<PjRtClient> client,
+PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
                   std::unique_ptr<PjRtBuffer> buffer,
                   absl::optional<TracebackManager::Traceback> traceback)
    : client_(std::move(client)),
tensorflow/compiler/xla/python/py_buffer.h

@@ -20,9 +20,8 @@ limitations under the License.
#include <vector>

#include "absl/types/optional.h"
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

@@ -33,11 +32,10 @@ namespace xla {
// b) to add Python-specific functionality.
class PyBuffer {
 public:
-  PyBuffer(std::shared_ptr<PjRtClient> client,
-           std::unique_ptr<PjRtBuffer> buffer,
+  PyBuffer(std::shared_ptr<PyClient> client, std::unique_ptr<PjRtBuffer> buffer,
           absl::optional<TracebackManager::Traceback> traceback);

-  std::shared_ptr<PjRtClient> client() const { return client_; }
+  std::shared_ptr<PyClient> client() const { return client_; }
  PjRtBuffer* buffer() const { return buffer_.get(); }

  ClientAndPtr<Device> device() const;

@@ -68,7 +66,7 @@ class PyBuffer {
  }

 private:
-  std::shared_ptr<PjRtClient> client_;
+  std::shared_ptr<PyClient> client_;
  std::unique_ptr<PjRtBuffer> buffer_;
  absl::optional<TracebackManager::Traceback> traceback_;
};
tensorflow/compiler/xla/python/py_client.cc (new file, 130 lines)

@@ -0,0 +1,130 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_client.h"

#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "tensorflow/compiler/xla/python/types.h"

namespace xla {

namespace py = pybind11;

PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
    : pjrt_client_(std::move(pjrt_client)) {}

std::vector<ClientAndPtr<Device>> PyClient::Devices() {
  std::vector<ClientAndPtr<Device>> devices;
  devices.reserve(pjrt_client_->devices().size());
  for (const auto& device : pjrt_client_->devices()) {
    devices.push_back(WrapWithClient(shared_from_this(), device.get()));
  }
  return devices;
}

std::vector<ClientAndPtr<Device>> PyClient::LocalDevices() {
  std::vector<ClientAndPtr<Device>> devices;
  devices.reserve(pjrt_client_->local_devices().size());
  for (Device* device : pjrt_client_->local_devices()) {
    devices.push_back(WrapWithClient(shared_from_this(), device));
  }
  return devices;
}

StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
  TF_ASSIGN_OR_RETURN(
      DeviceAssignment device_assignment,
      pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions));
  std::vector<std::vector<ClientAndPtr<Device>>> result;
  result.resize(num_replicas);
  for (int r = 0; r < num_replicas; ++r) {
    result[r].resize(num_partitions);
    for (int p = 0; p < num_partitions; ++p) {
      int device_id = device_assignment(r, p);
      auto iter = pjrt_client_->id_to_device().find(device_id);
      CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
      result[r][p] = WrapWithClient(shared_from_this(), iter->second);
    }
  }
  return result;
}

StatusOr<std::vector<ClientAndPtr<Device>>>
PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
                      pjrt_client_->GetDefaultDeviceAssignment(
                          num_replicas, /*num_partitions=*/1));
  std::vector<ClientAndPtr<Device>> result;
  for (int i = 0; i < num_replicas; ++i) {
    int device_id = device_assignment(i, 0);
    auto iter = pjrt_client_->id_to_device().find(device_id);
    CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
    result.push_back(WrapWithClient(shared_from_this(), iter->second));
  }
  return result;
}

StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
    const pybind11::object& argument, Device* device, bool force_copy) {
  if (device == nullptr) {
    TF_RET_CHECK(!pjrt_client_->local_devices().empty());
    device = pjrt_client_->local_devices().front();
  }
  CHECK(device != nullptr);
  auto iter = pjrt_client_->id_to_device().find(device->id());
  if (iter->second != device) {
    return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
                           device->DebugString(),
                           pjrt_client_->platform_name());
  }
  GlobalPyRefManager()->CollectGarbage();

  absl::optional<CastToArrayResult> c = CastToArray(argument);
  if (!c) {
    return InvalidArgument("from_python argument must be an array.");
  }

  TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));
  std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
      GlobalPyRefManager()->ManageReference(std::move(c->array));

  auto traceback = TracebackManager::Get()->GetTraceback();

  py::gil_scoped_release gil_release;
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtBuffer> buffer,
      PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
                                 std::move(py_buffer_ref), pjrt_client_.get(),
                                 device));
  return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
                                    traceback);
}

StatusOr<std::unique_ptr<PyExecutable>> PyClient::Compile(
    const XlaComputation& computation, CompileOptions options) {
  auto traceback = TracebackManager::Get()->GetTraceback();
  py::gil_scoped_release gil_release;
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
                      PjRtExecutable::Compile(computation, pjrt_client_.get(),
                                              std::move(options)));
  return std::make_unique<PyExecutable>(
      shared_from_this(), std::move(executable), std::move(traceback));
}

}  // namespace xla
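A note on the implementation above (editorial, not part of the commit): every wrapper PyClient hands out goes through WrapWithClient(shared_from_this(), ...), so a PyClient only works correctly once it is owned by a std::shared_ptr; the bindings in xla.cc further down always create it that way, via std::make_shared<PyClient>(...). The generic snippet below illustrates the constraint: under C++17, calling shared_from_this() on an object that is not shared_ptr-owned throws std::bad_weak_ptr (before C++17 it is undefined behavior).

    #include <iostream>
    #include <memory>

    struct Widget : std::enable_shared_from_this<Widget> {
      std::shared_ptr<Widget> Self() { return shared_from_this(); }
    };

    int main() {
      // Owned by a shared_ptr: shared_from_this() is valid.
      auto owned = std::make_shared<Widget>();
      std::shared_ptr<Widget> alias = owned->Self();
      std::cout << "use_count: " << alias.use_count() << "\n";  // prints 2

      // Not owned by any shared_ptr: throws std::bad_weak_ptr under C++17.
      Widget on_stack;
      try {
        on_stack.Self();
      } catch (const std::bad_weak_ptr&) {
        std::cout << "object was not owned by a shared_ptr\n";
      }
      return 0;
    }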
tensorflow/compiler/xla/python/py_client.h (new file, 136 lines)

@@ -0,0 +1,136 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

class PyBuffer;
class PyClient;
class PyExecutable;

// Custom holder types.
//
// We must keep the PyClient object alive as long as any of the runtime
// objects are alive. Since we don't have a lot of control over Python
// destructor ordering, we keep the PyClient object as a std::shared_ptr<>,
// and ensure that each Python runtime object holds a reference to the
// PyClient. An alternative design would be to keep a single global
// singleton PyClient, although this seems less flexible, especially for
// writing tests.
//
// To maintain PyClient references, we define pybind11 holder classes that
// are custom smart pointers that also keep a reference to a PyClient.
// pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't
// seem sufficiently flexible to describe ownership relationships in cases where
// the ownership doesn't pertain to a direct argument or return value of a
// function. Another alternative to the holder classes would be to create proxy
// objects that contain both a reference and a runtime class; holder classes
// seem less tedious to define.

// A pair of a PyClient reference and an unowned pointer to T.
template <typename T>
struct ClientAndPtr {
  ClientAndPtr() = default;
  // pybind11 requires that we define a constructor that takes a raw pointer,
  // but it should be unreachable.
  explicit ClientAndPtr(T*) {
    LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient.";
  }

  ClientAndPtr(const ClientAndPtr&) = default;
  ClientAndPtr(ClientAndPtr&&) = default;
  ClientAndPtr& operator=(const ClientAndPtr&) = default;
  ClientAndPtr& operator=(ClientAndPtr&&) = default;

  std::shared_ptr<PyClient> client;
  T* contents;

  T* get() const { return contents; }
  T* operator->() const { return contents; }
  T& operator*() const { return *contents; }
};

// By defining a templated helper function, we can use return type deduction
// and avoid specifying types at the caller.
template <typename T>
ClientAndPtr<T> WrapWithClient(std::shared_ptr<PyClient> client, T* contents) {
  ClientAndPtr<T> result;
  result.client = std::move(client);
  result.contents = contents;
  return result;
}

// Python wrapper around PjRtClient.
// We use a wrapper class to add Python-specific functionality.
class PyClient : public std::enable_shared_from_this<PyClient> {
 public:
  explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);

  PjRtClient* pjrt_client() const { return pjrt_client_.get(); }

  const std::string& platform_name() const {
    return pjrt_client_->platform_name();
  }
  int local_device_count() const { return pjrt_client_->local_device_count(); }
  int device_count() const { return pjrt_client_->device_count(); }
  int host_id() const { return pjrt_client_->host_id(); }

  std::vector<ClientAndPtr<Device>> Devices();
  std::vector<ClientAndPtr<Device>> LocalDevices();

  StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
  GetDefaultDeviceAssignment(int num_replicas, int num_partitions);

  // TODO(skye): delete after all callers can handle 2D output
  StatusOr<std::vector<ClientAndPtr<Device>>> GetDefaultDeviceAssignment1D(
      int num_replicas);

  StatusOr<ChannelHandle> CreateChannelHandle() {
    return pjrt_client_->client()->CreateChannelHandle();
  }
  StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
    return pjrt_client_->client()->CreateDeviceToHostChannelHandle();
  }
  StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
    return pjrt_client_->client()->CreateHostToDeviceChannelHandle();
  }

  StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyal(
      const pybind11::object& argument, Device* device, bool force_copy);

  StatusOr<std::unique_ptr<PyExecutable>> Compile(
      const XlaComputation& computation, CompileOptions options);

 private:
  std::shared_ptr<PjRtClient> pjrt_client_;
};

}  // namespace xla

PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
tensorflow/compiler/xla/python/py_executable.cc

@@ -22,7 +22,7 @@ namespace xla {
namespace py = pybind11;

PyExecutable::PyExecutable(
-    std::shared_ptr<PjRtClient> client,
+    std::shared_ptr<PyClient> client,
    std::unique_ptr<PjRtExecutable> executable,
    absl::optional<TracebackManager::Traceback> traceback)
    : client_(std::move(client)),
tensorflow/compiler/xla/python/py_executable.h

@@ -24,23 +24,23 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

// Python wrapper around PjRtExecutable. We use a wrapper class:
-// a) to keep the PjRtClient alive via a std::shared_ptr<>
+// a) to keep the PyClient alive via a std::shared_ptr<>
// b) to add Python-specific functionality.
class PyExecutable {
 public:
-  PyExecutable(std::shared_ptr<PjRtClient> client,
+  PyExecutable(std::shared_ptr<PyClient> client,
               std::unique_ptr<PjRtExecutable> executable,
               absl::optional<TracebackManager::Traceback> traceback);

-  std::shared_ptr<PjRtClient> client() const { return client_; }
+  std::shared_ptr<PyClient> client() const { return client_; }

  const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
    return executable_->local_logical_device_ids();

@@ -67,7 +67,7 @@ class PyExecutable {
  }

 private:
-  std::shared_ptr<PjRtClient> client_;
+  std::shared_ptr<PyClient> client_;
  std::unique_ptr<PjRtExecutable> executable_;
  absl::optional<TracebackManager::Traceback> traceback_;
};
tensorflow/compiler/xla/python/types.h

@@ -26,7 +26,6 @@ limitations under the License.
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"

@@ -36,65 +35,6 @@ limitations under the License.

namespace xla {

-// Custom holder types.
-//
-// We must keep the PjRtClient object alive as long as any of the runtime
-// objects are alive. Since we don't have a lot of control over Python
-// destructor ordering, we keep the PjRtClient object as a std::shared_ptr<>,
-// and ensure that each Python runtime object holds a reference to the
-// PjRtClient. An alternative design would be to keep a single global
-// singleton PjRtClient, although this seems less flexible, especially for
-// writing tests.
-//
-// To maintain PjRtClient references, we define pybind11 holder classes that
-// are custom smart pointers that also keep a reference to a PjRtClient.
-// pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't
-// seem sufficiently flexible to describe ownership relationships in cases where
-// the ownership doesn't pertain to a direct argument or return value of a
-// function. Another alternative to the holder classes would be to create proxy
-// objects that contain both a reference and a runtime class; holder classes
-// seem less tedious to define.
-
-// A pair of a PjRtClient reference and an unowned pointer to T.
-template <typename T>
-struct ClientAndPtr {
-  ClientAndPtr() = default;
-  // pybind11 requires that we define a constructor that takes a raw pointer,
-  // but it should be unreachable.
-  explicit ClientAndPtr(T*) {
-    LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient.";
-  }
-
-  ClientAndPtr(const ClientAndPtr&) = default;
-  ClientAndPtr(ClientAndPtr&&) = default;
-  ClientAndPtr& operator=(const ClientAndPtr&) = default;
-  ClientAndPtr& operator=(ClientAndPtr&&) = default;
-
-  std::shared_ptr<PjRtClient> client;
-  T* contents;
-
-  T* get() const { return contents; }
-  T* operator->() const { return contents; }
-  T& operator*() const { return *contents; }
-};
-
-// By defining a templated helper function, we can use return type deduction
-// and avoid specifying types at the caller.
-template <typename T>
-ClientAndPtr<T> WrapWithClient(std::shared_ptr<PjRtClient> client,
-                               T* contents) {
-  ClientAndPtr<T> result;
-  result.client = std::move(client);
-  result.contents = contents;
-  return result;
-}
-
-}  // namespace xla
-
-PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);
-
-namespace xla {
-
// Initializes the NumPy API for the use of the types module.
bool InitializeNumpyAPIForTypes();
tensorflow/compiler/xla/python/xla.cc

@@ -501,138 +501,55 @@ PYBIND11_MODULE(xla_extension, m) {
      .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
      .value("BFC", GpuAllocatorConfig::Kind::kBFC);

-  py::class_<PjRtClient, std::shared_ptr<PjRtClient>> py_local_client(
-      m, "LocalClient");
-  py_local_client.def_property_readonly("platform", &PjRtClient::platform_name)
-      .def("device_count", &PjRtClient::device_count)
-      .def("local_device_count", &PjRtClient::local_device_count)
-      .def("devices",
-           [](std::shared_ptr<PjRtClient> client) {
-             std::vector<ClientAndPtr<Device>> devices;
-             devices.reserve(client->devices().size());
-             for (const auto& device : client->devices()) {
-               devices.push_back(WrapWithClient(client, device.get()));
-             }
-             return devices;
-           })
-      .def("local_devices",
-           [](std::shared_ptr<PjRtClient> client) {
-             std::vector<ClientAndPtr<Device>> devices;
-             devices.reserve(client->local_devices().size());
-             for (Device* device : client->local_devices()) {
-               devices.push_back(WrapWithClient(client, device));
-             }
-             return devices;
-           })
-      .def("host_id", &PjRtClient::host_id)
-      .def("get_default_device_assignment",
-           [](std::shared_ptr<PjRtClient> client, int num_replicas,
-              int num_partitions)
-               -> StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>> {
-             TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
-                                 client->GetDefaultDeviceAssignment(
-                                     num_replicas, num_partitions));
-             std::vector<std::vector<ClientAndPtr<Device>>> result;
-             result.resize(num_replicas);
-             for (int r = 0; r < num_replicas; ++r) {
-               result[r].resize(num_partitions);
-               for (int p = 0; p < num_partitions; ++p) {
-                 int device_id = device_assignment(r, p);
-                 auto iter = client->id_to_device().find(device_id);
-                 CHECK(iter != client->id_to_device().end()) << device_id;
-                 result[r][p] = WrapWithClient(client, iter->second);
-               }
-             }
-             return result;
-           })
-      // TODO(skye): delete after all callers can handle 2D output
-      .def("get_default_device_assignment",
-           [](std::shared_ptr<PjRtClient> client,
-              int num_replicas) -> StatusOr<std::vector<ClientAndPtr<Device>>> {
-             TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
-                                 client->GetDefaultDeviceAssignment(
-                                     num_replicas, /*num_partitions=*/1));
-             std::vector<ClientAndPtr<Device>> result;
-             for (int i = 0; i < num_replicas; ++i) {
-               int device_id = device_assignment(i, 0);
-               auto iter = client->id_to_device().find(device_id);
-               CHECK(iter != client->id_to_device().end()) << device_id;
-               result.push_back(WrapWithClient(client, iter->second));
-             }
-             return result;
-           })
-      .def("create_channel_handle",
-           [](PjRtClient* client) {
-             return client->client()->CreateChannelHandle();
-           })
-      .def("create_device_to_host_channel_handle",
-           [](PjRtClient* client) {
-             return client->client()->CreateDeviceToHostChannelHandle();
-           })
-      .def("create_host_to_device_channel_handle", [](PjRtClient* client) {
-        return client->client()->CreateHostToDeviceChannelHandle();
-      });
-  py_local_client.def(
-      "buffer_from_pyval",
-      [](std::shared_ptr<PjRtClient> client, const pybind11::object& argument,
-         Device* device,
-         bool force_copy) -> StatusOr<std::unique_ptr<PyBuffer>> {
-        if (device == nullptr) {
-          TF_RET_CHECK(!client->local_devices().empty());
-          device = client->local_devices().front();
-        }
-        CHECK(device != nullptr);
-        auto iter = client->id_to_device().find(device->id());
-        if (iter->second != device) {
-          return InvalidArgument(
-              "Cannot copy value to device '%s' with '%s' backend",
-              device->DebugString(), client->platform_name());
-        }
-        GlobalPyRefManager()->CollectGarbage();
-
-        absl::optional<CastToArrayResult> c = CastToArray(argument);
-        if (!c) {
-          return InvalidArgument("from_python argument must be an array.");
-        }
-
-        TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
-                            GetPythonBufferTree(argument));
-        std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
-            GlobalPyRefManager()->ManageReference(std::move(c->array));
-
-        auto traceback = TracebackManager::Get()->GetTraceback();
-
-        py::gil_scoped_release gil_release;
-        TF_ASSIGN_OR_RETURN(
-            std::unique_ptr<PjRtBuffer> buffer,
-            PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
-                                       std::move(py_buffer_ref), client.get(),
-                                       device));
-        return std::make_unique<PyBuffer>(std::move(client), std::move(buffer),
-                                          traceback);
-      },
-      py::arg("argument"), py::arg("device") = nullptr,
-      py::arg("force_copy") = false);
-  py_local_client.def(
-      "compile",
-      [](std::shared_ptr<PjRtClient> client, const XlaComputation& computation,
-         CompileOptions options) -> StatusOr<std::unique_ptr<PyExecutable>> {
-        auto traceback = TracebackManager::Get()->GetTraceback();
-        py::gil_scoped_release gil_release;
-        TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
-                            PjRtExecutable::Compile(computation, client.get(),
-                                                    std::move(options)));
-        return std::make_unique<PyExecutable>(
-            std::move(client), std::move(executable), std::move(traceback));
-      },
-      py::arg("computation"), py::arg("compile_options") = CompileOptions());
-
-  m.def("get_cpu_client", &GetCpuClient, py::arg("asynchronous") = true);
-  m.def("get_interpreter_client", &GetInterpreterClient);
-  m.def("get_nvidia_gpu_client", &GetNvidiaGpuClient,
-        py::arg("asynchronous") = true,
-        py::arg("allocator_config") = GpuAllocatorConfig(),
-        py::arg("distributed_client") = nullptr, py::arg("node_id") = 0);
+  py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
+  py_local_client.def_property_readonly("platform", &PyClient::platform_name)
+      .def("device_count", &PyClient::device_count)
+      .def("local_device_count", &PyClient::local_device_count)
+      .def("devices", &PyClient::Devices)
+      .def("local_devices", &PyClient::LocalDevices)
+      .def("host_id", &PyClient::host_id)
+      .def("get_default_device_assignment",
+           &PyClient::GetDefaultDeviceAssignment)
+      // TODO(skye): delete after all callers can handle 2D output
+      .def("get_default_device_assignment",
+           &PyClient::GetDefaultDeviceAssignment1D)
+      .def("create_channel_handle", &PyClient::CreateChannelHandle)
+      .def("create_device_to_host_channel_handle",
+           &PyClient::CreateDeviceToHostChannelHandle)
+      .def("create_host_to_device_channel_handle",
+           &PyClient::CreateHostToDeviceChannelHandle)
+      .def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"),
+           py::arg("device") = nullptr, py::arg("force_copy") = false)
+      .def("compile", &PyClient::Compile, py::arg("computation"),
+           py::arg("compile_options") = CompileOptions());
+
+  m.def(
+      "get_cpu_client",
+      [](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
+        TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client,
+                            GetCpuClient(asynchronous));
+        return std::make_shared<PyClient>(std::move(client));
+      },
+      py::arg("asynchronous") = true);
+  m.def("get_interpreter_client", []() -> StatusOr<std::shared_ptr<PyClient>> {
+    TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client,
+                        GetInterpreterClient());
+    return std::make_shared<PyClient>(std::move(client));
+  });
+  m.def(
+      "get_nvidia_gpu_client",
+      [](bool asynchronous, const GpuAllocatorConfig& allocator_config,
+         std::shared_ptr<DistributedRuntimeClient> distributed_client,
+         int node_id) -> StatusOr<std::shared_ptr<PyClient>> {
+        TF_ASSIGN_OR_RETURN(
+            std::shared_ptr<PjRtClient> client,
+            GetNvidiaGpuClient(asynchronous, allocator_config,
+                               std::move(distributed_client), node_id));
+        return std::make_shared<PyClient>(std::move(client));
+      },
+      py::arg("asynchronous") = true,
+      py::arg("allocator_config") = GpuAllocatorConfig(),
+      py::arg("distributed_client") = nullptr, py::arg("node_id") = 0);

  py::class_<TracebackManager::Frame>(m, "Frame")
      .def_readonly("file_name", &TracebackManager::Frame::file_name)
tensorflow/compiler/xla/python/xla_client.py

@@ -408,7 +408,7 @@ def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
XlaBuilder = _xla.XlaBuilder
XlaComputation = _xla.XlaComputation
FftType = _xla.FftType
-Client = _xla.LocalClient
+Client = _xla.Client
Buffer = _xla.Buffer
Executable = _xla.Executable