[XLA:Python] Add a PyClient wrapper object around PjRtClient.
Refactoring in preparation for adding Python-specific logic to clients, but it also is more readable than inlining this kind of logic into the binding code. PiperOrigin-RevId: 314994078 Change-Id: I3b7253d4a14ff418068e49f9bd321d85695f9c8b
This commit is contained in:
parent
e96c412005
commit
3f7adb0d48
@ -180,9 +180,17 @@ py_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "py_buffer",
|
name = "py_client",
|
||||||
srcs = ["py_buffer.cc"],
|
srcs = [
|
||||||
hdrs = ["py_buffer.h"],
|
"py_buffer.cc",
|
||||||
|
"py_client.cc",
|
||||||
|
"py_executable.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"py_buffer.h",
|
||||||
|
"py_client.h",
|
||||||
|
"py_executable.h",
|
||||||
|
],
|
||||||
copts = [
|
copts = [
|
||||||
"-fexceptions",
|
"-fexceptions",
|
||||||
"-fno-strict-aliasing",
|
"-fno-strict-aliasing",
|
||||||
@ -195,29 +203,10 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||||
"@com_google_absl//absl/types:optional",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "py_executable",
|
|
||||||
srcs = ["py_executable.cc"],
|
|
||||||
hdrs = ["py_executable.h"],
|
|
||||||
copts = [
|
|
||||||
"-fexceptions",
|
|
||||||
"-fno-strict-aliasing",
|
|
||||||
],
|
|
||||||
features = ["-use_header_modules"],
|
|
||||||
deps = [
|
|
||||||
":py_buffer",
|
|
||||||
":traceback_manager",
|
|
||||||
":types",
|
|
||||||
"//tensorflow/compiler/xla:statusor",
|
|
||||||
"//tensorflow/compiler/xla:types",
|
|
||||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
|
"@pybind11",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -231,7 +220,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
features = ["-use_header_modules"],
|
features = ["-use_header_modules"],
|
||||||
deps = [
|
deps = [
|
||||||
":py_buffer",
|
":py_client",
|
||||||
":traceback_manager",
|
":traceback_manager",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
@ -325,6 +314,7 @@ cc_library(
|
|||||||
features = ["-use_header_modules"],
|
features = ["-use_header_modules"],
|
||||||
deps = [
|
deps = [
|
||||||
":outfeed_receiver",
|
":outfeed_receiver",
|
||||||
|
":py_client",
|
||||||
":types",
|
":types",
|
||||||
"//tensorflow/compiler/xla/client:xla_builder",
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||||
@ -355,8 +345,7 @@ pybind_extension(
|
|||||||
":bfloat16",
|
":bfloat16",
|
||||||
":dlpack",
|
":dlpack",
|
||||||
":ops",
|
":ops",
|
||||||
":py_buffer",
|
":py_client",
|
||||||
":py_executable",
|
|
||||||
":python_ref_manager",
|
":python_ref_manager",
|
||||||
":outfeed_receiver_py",
|
":outfeed_receiver_py",
|
||||||
":traceback_manager",
|
":traceback_manager",
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "include/dlpack/dlpack.h" // from @dlpack
|
#include "include/dlpack/dlpack.h" // from @dlpack
|
||||||
|
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
||||||
#include "tensorflow/compiler/xla/python/traceback_manager.h"
|
#include "tensorflow/compiler/xla/python/traceback_manager.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
@ -298,7 +299,7 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(PyBuffer* buffer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
|
StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
|
||||||
const pybind11::capsule& tensor, std::shared_ptr<PjRtClient> client) {
|
const pybind11::capsule& tensor, std::shared_ptr<PyClient> client) {
|
||||||
if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
|
if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
|
"DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
|
||||||
@ -311,8 +312,9 @@ StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
|
|||||||
"Number of dimensions in DLManagedTensor must be nonnegative, got %d",
|
"Number of dimensions in DLManagedTensor must be nonnegative, got %d",
|
||||||
dlmt->dl_tensor.ndim);
|
dlmt->dl_tensor.ndim);
|
||||||
}
|
}
|
||||||
TF_ASSIGN_OR_RETURN(Device * device,
|
TF_ASSIGN_OR_RETURN(
|
||||||
DeviceForDLContext(*client, dlmt->dl_tensor.ctx));
|
Device * device,
|
||||||
|
DeviceForDLContext(*client->pjrt_client(), dlmt->dl_tensor.ctx));
|
||||||
absl::Span<int64 const> dimensions(
|
absl::Span<int64 const> dimensions(
|
||||||
reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
|
reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
|
||||||
TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
|
TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
|
||||||
@ -349,7 +351,7 @@ StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
|
|||||||
PyCapsule_SetName(tensor.ptr(), "used_dltensor");
|
PyCapsule_SetName(tensor.ptr(), "used_dltensor");
|
||||||
PyCapsule_SetDestructor(tensor.ptr(), nullptr);
|
PyCapsule_SetDestructor(tensor.ptr(), nullptr);
|
||||||
auto pjrt_buffer = std::make_unique<PjRtBuffer>(
|
auto pjrt_buffer = std::make_unique<PjRtBuffer>(
|
||||||
shape, shape, std::move(device_buffer), client.get(), device);
|
shape, shape, std::move(device_buffer), client->pjrt_client(), device);
|
||||||
return std::make_unique<PyBuffer>(std::move(client), std::move(pjrt_buffer),
|
return std::make_unique<PyBuffer>(std::move(client), std::move(pjrt_buffer),
|
||||||
TracebackManager::Get()->GetTraceback());
|
TracebackManager::Get()->GetTraceback());
|
||||||
}
|
}
|
||||||
|
@ -17,15 +17,15 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
|
#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
|
||||||
|
|
||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
|
||||||
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
||||||
|
#include "tensorflow/compiler/xla/python/py_client.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
StatusOr<pybind11::capsule> BufferToDLPackManagedTensor(PyBuffer* buffer);
|
StatusOr<pybind11::capsule> BufferToDLPackManagedTensor(PyBuffer* buffer);
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
|
StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
|
||||||
const pybind11::capsule& tensor, std::shared_ptr<PjRtClient> client);
|
const pybind11::capsule& tensor, std::shared_ptr<PyClient> client);
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||||
#include "tensorflow/compiler/xla/python/outfeed_receiver.h"
|
#include "tensorflow/compiler/xla/python/outfeed_receiver.h"
|
||||||
|
#include "tensorflow/compiler/xla/python/py_client.h"
|
||||||
#include "tensorflow/compiler/xla/python/types.h"
|
#include "tensorflow/compiler/xla/python/types.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -42,7 +43,7 @@ class OutfeedReceiverForPython {
|
|||||||
std::function<void(ClientAndPtr<Device>, uint32_t, pybind11::object)>;
|
std::function<void(ClientAndPtr<Device>, uint32_t, pybind11::object)>;
|
||||||
|
|
||||||
OutfeedReceiverForPython(CallbackToPython callback_python,
|
OutfeedReceiverForPython(CallbackToPython callback_python,
|
||||||
std::vector<std::shared_ptr<PjRtClient>> clients,
|
std::vector<std::shared_ptr<PyClient>> clients,
|
||||||
ssize_t max_callback_queue_size_bytes)
|
ssize_t max_callback_queue_size_bytes)
|
||||||
: callback_python_(std::move(callback_python)),
|
: callback_python_(std::move(callback_python)),
|
||||||
clients_(std::move(clients)) {
|
clients_(std::move(clients)) {
|
||||||
@ -52,9 +53,10 @@ class OutfeedReceiverForPython {
|
|||||||
this->Callback(device, consumer_id, std::move(literal));
|
this->Callback(device, consumer_id, std::move(literal));
|
||||||
};
|
};
|
||||||
std::vector<PjRtClient*> client_ptrs(clients.size());
|
std::vector<PjRtClient*> client_ptrs(clients.size());
|
||||||
absl::c_transform(
|
absl::c_transform(clients_, client_ptrs.begin(),
|
||||||
clients_, client_ptrs.begin(),
|
[](const std::shared_ptr<PyClient>& client) {
|
||||||
[](const std::shared_ptr<PjRtClient>& client) { return client.get(); });
|
return client->pjrt_client();
|
||||||
|
});
|
||||||
outfeed_receiver_ = absl::make_unique<OutfeedReceiver>(
|
outfeed_receiver_ = absl::make_unique<OutfeedReceiver>(
|
||||||
callback, client_ptrs, max_callback_queue_size_bytes);
|
callback, client_ptrs, max_callback_queue_size_bytes);
|
||||||
}
|
}
|
||||||
@ -95,8 +97,8 @@ class OutfeedReceiverForPython {
|
|||||||
}
|
}
|
||||||
// We expect the number of clients to be small, so an O(n) search is fine.
|
// We expect the number of clients to be small, so an O(n) search is fine.
|
||||||
auto it = absl::c_find_if(
|
auto it = absl::c_find_if(
|
||||||
clients_, [device](const std::shared_ptr<PjRtClient>& client) {
|
clients_, [device](const std::shared_ptr<PyClient>& client) {
|
||||||
return client.get() == device->client();
|
return client->pjrt_client() == device->client();
|
||||||
});
|
});
|
||||||
CHECK(it != clients_.end());
|
CHECK(it != clients_.end());
|
||||||
py::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython
|
py::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython
|
||||||
@ -112,7 +114,7 @@ class OutfeedReceiverForPython {
|
|||||||
CallbackToPython callback_python_;
|
CallbackToPython callback_python_;
|
||||||
absl::Mutex mu_;
|
absl::Mutex mu_;
|
||||||
bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_) = false;
|
bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_) = false;
|
||||||
std::vector<std::shared_ptr<PjRtClient>> clients_;
|
std::vector<std::shared_ptr<PyClient>> clients_;
|
||||||
std::unique_ptr<OutfeedReceiver> outfeed_receiver_;
|
std::unique_ptr<OutfeedReceiver> outfeed_receiver_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -124,7 +126,7 @@ void BuildOutfeedReceiverSubmodule(py::module* m) {
|
|||||||
outfeed_receiver.def(
|
outfeed_receiver.def(
|
||||||
"start",
|
"start",
|
||||||
[](OutfeedReceiverForPython::CallbackToPython callback_to_python,
|
[](OutfeedReceiverForPython::CallbackToPython callback_to_python,
|
||||||
std::vector<std::shared_ptr<PjRtClient>> clients,
|
std::vector<std::shared_ptr<PyClient>> clients,
|
||||||
ssize_t max_callback_queue_size_bytes)
|
ssize_t max_callback_queue_size_bytes)
|
||||||
-> std::unique_ptr<OutfeedReceiverForPython> {
|
-> std::unique_ptr<OutfeedReceiverForPython> {
|
||||||
auto server = absl::make_unique<OutfeedReceiverForPython>(
|
auto server = absl::make_unique<OutfeedReceiverForPython>(
|
||||||
|
@ -15,13 +15,15 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||||
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
||||||
|
#include "tensorflow/compiler/xla/python/types.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
PyBuffer::PyBuffer(std::shared_ptr<PjRtClient> client,
|
PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
|
||||||
std::unique_ptr<PjRtBuffer> buffer,
|
std::unique_ptr<PjRtBuffer> buffer,
|
||||||
absl::optional<TracebackManager::Traceback> traceback)
|
absl::optional<TracebackManager::Traceback> traceback)
|
||||||
: client_(std::move(client)),
|
: client_(std::move(client)),
|
||||||
|
@ -20,9 +20,8 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/python/py_client.h"
|
||||||
#include "tensorflow/compiler/xla/python/traceback_manager.h"
|
#include "tensorflow/compiler/xla/python/traceback_manager.h"
|
||||||
#include "tensorflow/compiler/xla/python/types.h"
|
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
|
||||||
@ -33,11 +32,10 @@ namespace xla {
|
|||||||
// b) to add Python-specific functionality.
|
// b) to add Python-specific functionality.
|
||||||
class PyBuffer {
|
class PyBuffer {
|
||||||
public:
|
public:
|
||||||
PyBuffer(std::shared_ptr<PjRtClient> client,
|
PyBuffer(std::shared_ptr<PyClient> client, std::unique_ptr<PjRtBuffer> buffer,
|
||||||
std::unique_ptr<PjRtBuffer> buffer,
|
|
||||||
absl::optional<TracebackManager::Traceback> traceback);
|
absl::optional<TracebackManager::Traceback> traceback);
|
||||||
|
|
||||||
std::shared_ptr<PjRtClient> client() const { return client_; }
|
std::shared_ptr<PyClient> client() const { return client_; }
|
||||||
PjRtBuffer* buffer() const { return buffer_.get(); }
|
PjRtBuffer* buffer() const { return buffer_.get(); }
|
||||||
|
|
||||||
ClientAndPtr<Device> device() const;
|
ClientAndPtr<Device> device() const;
|
||||||
@ -68,7 +66,7 @@ class PyBuffer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<PjRtClient> client_;
|
std::shared_ptr<PyClient> client_;
|
||||||
std::unique_ptr<PjRtBuffer> buffer_;
|
std::unique_ptr<PjRtBuffer> buffer_;
|
||||||
absl::optional<TracebackManager::Traceback> traceback_;
|
absl::optional<TracebackManager::Traceback> traceback_;
|
||||||
};
|
};
|
||||||
|
130
tensorflow/compiler/xla/python/py_client.cc
Normal file
130
tensorflow/compiler/xla/python/py_client.cc
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/python/py_client.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
||||||
|
#include "tensorflow/compiler/xla/python/py_executable.h"
|
||||||
|
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
||||||
|
#include "tensorflow/compiler/xla/python/traceback_manager.h"
|
||||||
|
#include "tensorflow/compiler/xla/python/types.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
|
||||||
|
: pjrt_client_(std::move(pjrt_client)) {}
|
||||||
|
|
||||||
|
std::vector<ClientAndPtr<Device>> PyClient::Devices() {
|
||||||
|
std::vector<ClientAndPtr<Device>> devices;
|
||||||
|
devices.reserve(pjrt_client_->devices().size());
|
||||||
|
for (const auto& device : pjrt_client_->devices()) {
|
||||||
|
devices.push_back(WrapWithClient(shared_from_this(), device.get()));
|
||||||
|
}
|
||||||
|
return devices;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ClientAndPtr<Device>> PyClient::LocalDevices() {
|
||||||
|
std::vector<ClientAndPtr<Device>> devices;
|
||||||
|
devices.reserve(pjrt_client_->local_devices().size());
|
||||||
|
for (Device* device : pjrt_client_->local_devices()) {
|
||||||
|
devices.push_back(WrapWithClient(shared_from_this(), device));
|
||||||
|
}
|
||||||
|
return devices;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
|
||||||
|
PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
DeviceAssignment device_assignment,
|
||||||
|
pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions));
|
||||||
|
std::vector<std::vector<ClientAndPtr<Device>>> result;
|
||||||
|
result.resize(num_replicas);
|
||||||
|
for (int r = 0; r < num_replicas; ++r) {
|
||||||
|
result[r].resize(num_partitions);
|
||||||
|
for (int p = 0; p < num_partitions; ++p) {
|
||||||
|
int device_id = device_assignment(r, p);
|
||||||
|
auto iter = pjrt_client_->id_to_device().find(device_id);
|
||||||
|
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
|
||||||
|
result[r][p] = WrapWithClient(shared_from_this(), iter->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<std::vector<ClientAndPtr<Device>>>
|
||||||
|
PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
|
||||||
|
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
||||||
|
pjrt_client_->GetDefaultDeviceAssignment(
|
||||||
|
num_replicas, /*num_partitions=*/1));
|
||||||
|
std::vector<ClientAndPtr<Device>> result;
|
||||||
|
for (int i = 0; i < num_replicas; ++i) {
|
||||||
|
int device_id = device_assignment(i, 0);
|
||||||
|
auto iter = pjrt_client_->id_to_device().find(device_id);
|
||||||
|
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
|
||||||
|
result.push_back(WrapWithClient(shared_from_this(), iter->second));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
|
||||||
|
const pybind11::object& argument, Device* device, bool force_copy) {
|
||||||
|
if (device == nullptr) {
|
||||||
|
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
|
||||||
|
device = pjrt_client_->local_devices().front();
|
||||||
|
}
|
||||||
|
CHECK(device != nullptr);
|
||||||
|
auto iter = pjrt_client_->id_to_device().find(device->id());
|
||||||
|
if (iter->second != device) {
|
||||||
|
return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
|
||||||
|
device->DebugString(),
|
||||||
|
pjrt_client_->platform_name());
|
||||||
|
}
|
||||||
|
GlobalPyRefManager()->CollectGarbage();
|
||||||
|
|
||||||
|
absl::optional<CastToArrayResult> c = CastToArray(argument);
|
||||||
|
if (!c) {
|
||||||
|
return InvalidArgument("from_python argument must be an array.");
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));
|
||||||
|
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
|
||||||
|
GlobalPyRefManager()->ManageReference(std::move(c->array));
|
||||||
|
|
||||||
|
auto traceback = TracebackManager::Get()->GetTraceback();
|
||||||
|
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
std::unique_ptr<PjRtBuffer> buffer,
|
||||||
|
PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
|
||||||
|
std::move(py_buffer_ref), pjrt_client_.get(),
|
||||||
|
device));
|
||||||
|
return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
|
||||||
|
traceback);
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<PyExecutable>> PyClient::Compile(
|
||||||
|
const XlaComputation& computation, CompileOptions options) {
|
||||||
|
auto traceback = TracebackManager::Get()->GetTraceback();
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
|
||||||
|
PjRtExecutable::Compile(computation, pjrt_client_.get(),
|
||||||
|
std::move(options)));
|
||||||
|
return std::make_unique<PyExecutable>(
|
||||||
|
shared_from_this(), std::move(executable), std::move(traceback));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
136
tensorflow/compiler/xla/python/py_client.h
Normal file
136
tensorflow/compiler/xla/python/py_client.h
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
class PyBuffer;
|
||||||
|
class PyClient;
|
||||||
|
class PyExecutable;
|
||||||
|
|
||||||
|
// Custom holder types.
|
||||||
|
//
|
||||||
|
// We must keep the PyClient object alive as long as any of the runtime
|
||||||
|
// objects are alive. Since we don't have a lot of control over Python
|
||||||
|
// destructor ordering, we keep the PyClient object as a std::shared_ptr<>,
|
||||||
|
// and ensure that each Python runtime object holds a reference to the
|
||||||
|
// PyClient. An alternative design would be to keep a single global
|
||||||
|
// singleton PyClient, although this seems less flexible, especially for
|
||||||
|
// writing tests.
|
||||||
|
//
|
||||||
|
// To maintain PyClient references, we define pybind11 holder classes that
|
||||||
|
// are custom smart pointers that also keep a reference to a PyClient.
|
||||||
|
// pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't
|
||||||
|
// seem sufficiently flexible to describe ownership relationships in cases where
|
||||||
|
// the ownership doesn't pertain to a direct argument or return value of a
|
||||||
|
// function. Another alternative to the holder classes would be to create proxy
|
||||||
|
// objects that contain both a reference and a runtime class; holder classes
|
||||||
|
// seem less tedious to define.
|
||||||
|
|
||||||
|
// A pair of a PyClient reference and an unowned pointer to T.
|
||||||
|
template <typename T>
|
||||||
|
struct ClientAndPtr {
|
||||||
|
ClientAndPtr() = default;
|
||||||
|
// pybind11 requires that we define a constructor that takes a raw pointer,
|
||||||
|
// but it should be unreachable.
|
||||||
|
explicit ClientAndPtr(T*) {
|
||||||
|
LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient.";
|
||||||
|
}
|
||||||
|
|
||||||
|
ClientAndPtr(const ClientAndPtr&) = default;
|
||||||
|
ClientAndPtr(ClientAndPtr&&) = default;
|
||||||
|
ClientAndPtr& operator=(const ClientAndPtr&) = default;
|
||||||
|
ClientAndPtr& operator=(ClientAndPtr&&) = default;
|
||||||
|
|
||||||
|
std::shared_ptr<PyClient> client;
|
||||||
|
T* contents;
|
||||||
|
|
||||||
|
T* get() const { return contents; }
|
||||||
|
T* operator->() const { return contents; }
|
||||||
|
T& operator*() const { return *contents; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// By defining a templated helper function, we can use return type deduction
|
||||||
|
// and avoid specifying types at the caller.
|
||||||
|
template <typename T>
|
||||||
|
ClientAndPtr<T> WrapWithClient(std::shared_ptr<PyClient> client, T* contents) {
|
||||||
|
ClientAndPtr<T> result;
|
||||||
|
result.client = std::move(client);
|
||||||
|
result.contents = contents;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Python wrapper around PjRtClient.
|
||||||
|
// We use a wrapper class to add Python-specific functionality.
|
||||||
|
class PyClient : public std::enable_shared_from_this<PyClient> {
|
||||||
|
public:
|
||||||
|
explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
|
||||||
|
|
||||||
|
PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
|
||||||
|
|
||||||
|
const std::string& platform_name() const {
|
||||||
|
return pjrt_client_->platform_name();
|
||||||
|
}
|
||||||
|
int local_device_count() const { return pjrt_client_->local_device_count(); }
|
||||||
|
int device_count() const { return pjrt_client_->device_count(); }
|
||||||
|
int host_id() const { return pjrt_client_->host_id(); }
|
||||||
|
|
||||||
|
std::vector<ClientAndPtr<Device>> Devices();
|
||||||
|
std::vector<ClientAndPtr<Device>> LocalDevices();
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
|
||||||
|
GetDefaultDeviceAssignment(int num_replicas, int num_partitions);
|
||||||
|
|
||||||
|
// TODO(skye): delete after all callers can handle 2D output
|
||||||
|
StatusOr<std::vector<ClientAndPtr<Device>>> GetDefaultDeviceAssignment1D(
|
||||||
|
int num_replicas);
|
||||||
|
|
||||||
|
StatusOr<ChannelHandle> CreateChannelHandle() {
|
||||||
|
return pjrt_client_->client()->CreateChannelHandle();
|
||||||
|
}
|
||||||
|
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
|
||||||
|
return pjrt_client_->client()->CreateDeviceToHostChannelHandle();
|
||||||
|
}
|
||||||
|
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
|
||||||
|
return pjrt_client_->client()->CreateHostToDeviceChannelHandle();
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyal(
|
||||||
|
const pybind11::object& argument, Device* device, bool force_copy);
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<PyExecutable>> Compile(
|
||||||
|
const XlaComputation& computation, CompileOptions options);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<PjRtClient> pjrt_client_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
|
@ -22,7 +22,7 @@ namespace xla {
|
|||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
PyExecutable::PyExecutable(
|
PyExecutable::PyExecutable(
|
||||||
std::shared_ptr<PjRtClient> client,
|
std::shared_ptr<PyClient> client,
|
||||||
std::unique_ptr<PjRtExecutable> executable,
|
std::unique_ptr<PjRtExecutable> executable,
|
||||||
absl::optional<TracebackManager::Traceback> traceback)
|
absl::optional<TracebackManager::Traceback> traceback)
|
||||||
: client_(std::move(client)),
|
: client_(std::move(client)),
|
||||||
|
@ -24,23 +24,23 @@ limitations under the License.
|
|||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||||
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
||||||
|
#include "tensorflow/compiler/xla/python/py_client.h"
|
||||||
#include "tensorflow/compiler/xla/python/traceback_manager.h"
|
#include "tensorflow/compiler/xla/python/traceback_manager.h"
|
||||||
#include "tensorflow/compiler/xla/python/types.h"
|
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
// Python wrapper around PjRtExecutable. We use a wrapper class:
|
// Python wrapper around PjRtExecutable. We use a wrapper class:
|
||||||
// a) to keep the PjRtClient alive via a std::shared_ptr<>
|
// a) to keep the PyClient alive via a std::shared_ptr<>
|
||||||
// b) to add Python-specific functionality.
|
// b) to add Python-specific functionality.
|
||||||
class PyExecutable {
|
class PyExecutable {
|
||||||
public:
|
public:
|
||||||
PyExecutable(std::shared_ptr<PjRtClient> client,
|
PyExecutable(std::shared_ptr<PyClient> client,
|
||||||
std::unique_ptr<PjRtExecutable> executable,
|
std::unique_ptr<PjRtExecutable> executable,
|
||||||
absl::optional<TracebackManager::Traceback> traceback);
|
absl::optional<TracebackManager::Traceback> traceback);
|
||||||
|
|
||||||
std::shared_ptr<PjRtClient> client() const { return client_; }
|
std::shared_ptr<PyClient> client() const { return client_; }
|
||||||
|
|
||||||
const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
|
const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
|
||||||
return executable_->local_logical_device_ids();
|
return executable_->local_logical_device_ids();
|
||||||
@ -67,7 +67,7 @@ class PyExecutable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<PjRtClient> client_;
|
std::shared_ptr<PyClient> client_;
|
||||||
std::unique_ptr<PjRtExecutable> executable_;
|
std::unique_ptr<PjRtExecutable> executable_;
|
||||||
absl::optional<TracebackManager::Traceback> traceback_;
|
absl::optional<TracebackManager::Traceback> traceback_;
|
||||||
};
|
};
|
||||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
|||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
#include "pybind11/stl.h"
|
#include "pybind11/stl.h"
|
||||||
#include "tensorflow/compiler/xla/literal.h"
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
|
||||||
#include "tensorflow/compiler/xla/shape.h"
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
#include "tensorflow/compiler/xla/status.h"
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
@ -36,65 +35,6 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
// Custom holder types.
|
|
||||||
//
|
|
||||||
// We must keep the PjRtClient object alive as long as any of the runtime
|
|
||||||
// objects are alive. Since we don't have a lot of control over Python
|
|
||||||
// destructor ordering, we keep the PjRtClient object as a std::shared_ptr<>,
|
|
||||||
// and ensure that each Python runtime object holds a reference to the
|
|
||||||
// PjRtClient. An alternative design would be to keep a single global
|
|
||||||
// singleton PjRtClient, although this seems less flexible, especially for
|
|
||||||
// writing tests.
|
|
||||||
//
|
|
||||||
// To maintain PjRtClient references, we define pybind11 holder classes that
|
|
||||||
// are custom smart pointers that also keep a reference to a PjRtClient.
|
|
||||||
// pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't
|
|
||||||
// seem sufficiently flexible to describe ownership relationships in cases where
|
|
||||||
// the ownership doesn't pertain to a direct argument or return value of a
|
|
||||||
// function. Another alternative to the holder classes would be to create proxy
|
|
||||||
// objects that contain both a reference and a runtime class; holder classes
|
|
||||||
// seem less tedious to define.
|
|
||||||
|
|
||||||
// A pair of a PjRtClient reference and an unowned pointer to T.
|
|
||||||
template <typename T>
|
|
||||||
struct ClientAndPtr {
|
|
||||||
ClientAndPtr() = default;
|
|
||||||
// pybind11 requires that we define a constructor that takes a raw pointer,
|
|
||||||
// but it should be unreachable.
|
|
||||||
explicit ClientAndPtr(T*) {
|
|
||||||
LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient.";
|
|
||||||
}
|
|
||||||
|
|
||||||
ClientAndPtr(const ClientAndPtr&) = default;
|
|
||||||
ClientAndPtr(ClientAndPtr&&) = default;
|
|
||||||
ClientAndPtr& operator=(const ClientAndPtr&) = default;
|
|
||||||
ClientAndPtr& operator=(ClientAndPtr&&) = default;
|
|
||||||
|
|
||||||
std::shared_ptr<PjRtClient> client;
|
|
||||||
T* contents;
|
|
||||||
|
|
||||||
T* get() const { return contents; }
|
|
||||||
T* operator->() const { return contents; }
|
|
||||||
T& operator*() const { return *contents; }
|
|
||||||
};
|
|
||||||
|
|
||||||
// By defining a templated helper function, we can use return type deduction
|
|
||||||
// and avoid specifying types at the caller.
|
|
||||||
template <typename T>
|
|
||||||
ClientAndPtr<T> WrapWithClient(std::shared_ptr<PjRtClient> client,
|
|
||||||
T* contents) {
|
|
||||||
ClientAndPtr<T> result;
|
|
||||||
result.client = std::move(client);
|
|
||||||
result.contents = contents;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace xla
|
|
||||||
|
|
||||||
PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);
|
|
||||||
|
|
||||||
namespace xla {
|
|
||||||
|
|
||||||
// Initializes the NumPy API for the use of the types module.
|
// Initializes the NumPy API for the use of the types module.
|
||||||
bool InitializeNumpyAPIForTypes();
|
bool InitializeNumpyAPIForTypes();
|
||||||
|
|
||||||
|
@ -501,135 +501,52 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
.value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
|
.value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
|
||||||
.value("BFC", GpuAllocatorConfig::Kind::kBFC);
|
.value("BFC", GpuAllocatorConfig::Kind::kBFC);
|
||||||
|
|
||||||
py::class_<PjRtClient, std::shared_ptr<PjRtClient>> py_local_client(
|
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
|
||||||
m, "LocalClient");
|
py_local_client.def_property_readonly("platform", &PyClient::platform_name)
|
||||||
py_local_client.def_property_readonly("platform", &PjRtClient::platform_name)
|
.def("device_count", &PyClient::device_count)
|
||||||
.def("device_count", &PjRtClient::device_count)
|
.def("local_device_count", &PyClient::local_device_count)
|
||||||
.def("local_device_count", &PjRtClient::local_device_count)
|
.def("devices", &PyClient::Devices)
|
||||||
.def("devices",
|
.def("local_devices", &PyClient::LocalDevices)
|
||||||
[](std::shared_ptr<PjRtClient> client) {
|
.def("host_id", &PyClient::host_id)
|
||||||
std::vector<ClientAndPtr<Device>> devices;
|
|
||||||
devices.reserve(client->devices().size());
|
|
||||||
for (const auto& device : client->devices()) {
|
|
||||||
devices.push_back(WrapWithClient(client, device.get()));
|
|
||||||
}
|
|
||||||
return devices;
|
|
||||||
})
|
|
||||||
.def("local_devices",
|
|
||||||
[](std::shared_ptr<PjRtClient> client) {
|
|
||||||
std::vector<ClientAndPtr<Device>> devices;
|
|
||||||
devices.reserve(client->local_devices().size());
|
|
||||||
for (Device* device : client->local_devices()) {
|
|
||||||
devices.push_back(WrapWithClient(client, device));
|
|
||||||
}
|
|
||||||
return devices;
|
|
||||||
})
|
|
||||||
.def("host_id", &PjRtClient::host_id)
|
|
||||||
.def("get_default_device_assignment",
|
.def("get_default_device_assignment",
|
||||||
[](std::shared_ptr<PjRtClient> client, int num_replicas,
|
&PyClient::GetDefaultDeviceAssignment)
|
||||||
int num_partitions)
|
|
||||||
-> StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>> {
|
|
||||||
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
|
||||||
client->GetDefaultDeviceAssignment(
|
|
||||||
num_replicas, num_partitions));
|
|
||||||
std::vector<std::vector<ClientAndPtr<Device>>> result;
|
|
||||||
result.resize(num_replicas);
|
|
||||||
for (int r = 0; r < num_replicas; ++r) {
|
|
||||||
result[r].resize(num_partitions);
|
|
||||||
for (int p = 0; p < num_partitions; ++p) {
|
|
||||||
int device_id = device_assignment(r, p);
|
|
||||||
auto iter = client->id_to_device().find(device_id);
|
|
||||||
CHECK(iter != client->id_to_device().end()) << device_id;
|
|
||||||
result[r][p] = WrapWithClient(client, iter->second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
})
|
|
||||||
// TODO(skye): delete after all callers can handle 2D output
|
// TODO(skye): delete after all callers can handle 2D output
|
||||||
.def("get_default_device_assignment",
|
.def("get_default_device_assignment",
|
||||||
[](std::shared_ptr<PjRtClient> client,
|
&PyClient::GetDefaultDeviceAssignment1D)
|
||||||
int num_replicas) -> StatusOr<std::vector<ClientAndPtr<Device>>> {
|
.def("create_channel_handle", &PyClient::CreateChannelHandle)
|
||||||
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
|
||||||
client->GetDefaultDeviceAssignment(
|
|
||||||
num_replicas, /*num_partitions=*/1));
|
|
||||||
std::vector<ClientAndPtr<Device>> result;
|
|
||||||
for (int i = 0; i < num_replicas; ++i) {
|
|
||||||
int device_id = device_assignment(i, 0);
|
|
||||||
auto iter = client->id_to_device().find(device_id);
|
|
||||||
CHECK(iter != client->id_to_device().end()) << device_id;
|
|
||||||
result.push_back(WrapWithClient(client, iter->second));
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
})
|
|
||||||
.def("create_channel_handle",
|
|
||||||
[](PjRtClient* client) {
|
|
||||||
return client->client()->CreateChannelHandle();
|
|
||||||
})
|
|
||||||
.def("create_device_to_host_channel_handle",
|
.def("create_device_to_host_channel_handle",
|
||||||
[](PjRtClient* client) {
|
&PyClient::CreateDeviceToHostChannelHandle)
|
||||||
return client->client()->CreateDeviceToHostChannelHandle();
|
.def("create_host_to_device_channel_handle",
|
||||||
})
|
&PyClient::CreateHostToDeviceChannelHandle)
|
||||||
.def("create_host_to_device_channel_handle", [](PjRtClient* client) {
|
.def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"),
|
||||||
return client->client()->CreateHostToDeviceChannelHandle();
|
py::arg("device") = nullptr, py::arg("force_copy") = false)
|
||||||
|
.def("compile", &PyClient::Compile, py::arg("computation"),
|
||||||
|
py::arg("compile_options") = CompileOptions());
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"get_cpu_client",
|
||||||
|
[](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
|
||||||
|
TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client,
|
||||||
|
GetCpuClient(asynchronous));
|
||||||
|
return std::make_shared<PyClient>(std::move(client));
|
||||||
|
},
|
||||||
|
py::arg("asynchronous") = true);
|
||||||
|
m.def("get_interpreter_client", []() -> StatusOr<std::shared_ptr<PyClient>> {
|
||||||
|
TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client,
|
||||||
|
GetInterpreterClient());
|
||||||
|
return std::make_shared<PyClient>(std::move(client));
|
||||||
});
|
});
|
||||||
py_local_client.def(
|
m.def(
|
||||||
"buffer_from_pyval",
|
"get_nvidia_gpu_client",
|
||||||
[](std::shared_ptr<PjRtClient> client, const pybind11::object& argument,
|
[](bool asynchronous, const GpuAllocatorConfig& allocator_config,
|
||||||
Device* device,
|
std::shared_ptr<DistributedRuntimeClient> distributed_client,
|
||||||
bool force_copy) -> StatusOr<std::unique_ptr<PyBuffer>> {
|
int node_id) -> StatusOr<std::shared_ptr<PyClient>> {
|
||||||
if (device == nullptr) {
|
|
||||||
TF_RET_CHECK(!client->local_devices().empty());
|
|
||||||
device = client->local_devices().front();
|
|
||||||
}
|
|
||||||
CHECK(device != nullptr);
|
|
||||||
auto iter = client->id_to_device().find(device->id());
|
|
||||||
if (iter->second != device) {
|
|
||||||
return InvalidArgument(
|
|
||||||
"Cannot copy value to device '%s' with '%s' backend",
|
|
||||||
device->DebugString(), client->platform_name());
|
|
||||||
}
|
|
||||||
GlobalPyRefManager()->CollectGarbage();
|
|
||||||
|
|
||||||
absl::optional<CastToArrayResult> c = CastToArray(argument);
|
|
||||||
if (!c) {
|
|
||||||
return InvalidArgument("from_python argument must be an array.");
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
|
|
||||||
GetPythonBufferTree(argument));
|
|
||||||
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
|
|
||||||
GlobalPyRefManager()->ManageReference(std::move(c->array));
|
|
||||||
|
|
||||||
auto traceback = TracebackManager::Get()->GetTraceback();
|
|
||||||
|
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<PjRtBuffer> buffer,
|
std::shared_ptr<PjRtClient> client,
|
||||||
PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
|
GetNvidiaGpuClient(asynchronous, allocator_config,
|
||||||
std::move(py_buffer_ref), client.get(),
|
std::move(distributed_client), node_id));
|
||||||
device));
|
return std::make_shared<PyClient>(std::move(client));
|
||||||
return std::make_unique<PyBuffer>(std::move(client), std::move(buffer),
|
|
||||||
traceback);
|
|
||||||
},
|
},
|
||||||
py::arg("argument"), py::arg("device") = nullptr,
|
|
||||||
py::arg("force_copy") = false);
|
|
||||||
py_local_client.def(
|
|
||||||
"compile",
|
|
||||||
[](std::shared_ptr<PjRtClient> client, const XlaComputation& computation,
|
|
||||||
CompileOptions options) -> StatusOr<std::unique_ptr<PyExecutable>> {
|
|
||||||
auto traceback = TracebackManager::Get()->GetTraceback();
|
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
|
|
||||||
PjRtExecutable::Compile(computation, client.get(),
|
|
||||||
std::move(options)));
|
|
||||||
return std::make_unique<PyExecutable>(
|
|
||||||
std::move(client), std::move(executable), std::move(traceback));
|
|
||||||
},
|
|
||||||
py::arg("computation"), py::arg("compile_options") = CompileOptions());
|
|
||||||
|
|
||||||
m.def("get_cpu_client", &GetCpuClient, py::arg("asynchronous") = true);
|
|
||||||
m.def("get_interpreter_client", &GetInterpreterClient);
|
|
||||||
m.def("get_nvidia_gpu_client", &GetNvidiaGpuClient,
|
|
||||||
py::arg("asynchronous") = true,
|
py::arg("asynchronous") = true,
|
||||||
py::arg("allocator_config") = GpuAllocatorConfig(),
|
py::arg("allocator_config") = GpuAllocatorConfig(),
|
||||||
py::arg("distributed_client") = nullptr, py::arg("node_id") = 0);
|
py::arg("distributed_client") = nullptr, py::arg("node_id") = 0);
|
||||||
|
@ -408,7 +408,7 @@ def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
|
|||||||
XlaBuilder = _xla.XlaBuilder
|
XlaBuilder = _xla.XlaBuilder
|
||||||
XlaComputation = _xla.XlaComputation
|
XlaComputation = _xla.XlaComputation
|
||||||
FftType = _xla.FftType
|
FftType = _xla.FftType
|
||||||
Client = _xla.LocalClient
|
Client = _xla.Client
|
||||||
Buffer = _xla.Buffer
|
Buffer = _xla.Buffer
|
||||||
Executable = _xla.Executable
|
Executable = _xla.Executable
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user