[XLA:Python] Refactor Python specifics out of PyLocalClient and PyLocalBuffer to remove dependency on pybind11.
PiperOrigin-RevId: 259312456
This commit is contained in:
parent
1ee51a3b86
commit
384e7f8c86
@ -69,7 +69,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/stream_executor:device_memory_allocator",
|
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
"@pybind11",
|
"@pybind11",
|
||||||
|
@ -85,7 +85,6 @@ limitations under the License.
|
|||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "absl/time/time.h"
|
#include "absl/time/time.h"
|
||||||
#include "include/pybind11/pybind11.h"
|
|
||||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||||
@ -106,8 +105,6 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
|
static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
|
||||||
se::Platform* platform, LocalClient* client, double memory_fraction,
|
se::Platform* platform, LocalClient* client, double memory_fraction,
|
||||||
bool preallocate) {
|
bool preallocate) {
|
||||||
@ -222,47 +219,21 @@ PyLocalClient::PyLocalClient(
|
|||||||
|
|
||||||
Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
|
Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
|
||||||
int device_ordinal) {
|
int device_ordinal) {
|
||||||
py_ref_manager().CollectGarbage();
|
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
return client_->TransferToInfeedLocal(literal, device_ordinal);
|
return client_->TransferToInfeedLocal(literal, device_ordinal);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
|
StatusOr<Literal> PyLocalClient::TransferFromOutfeed(const Shape& shape,
|
||||||
const Shape& shape, int device_ordinal) {
|
int device_ordinal) {
|
||||||
py_ref_manager().CollectGarbage();
|
return client_->TransferFromOutfeedLocal(shape, device_ordinal);
|
||||||
Literal literal;
|
|
||||||
{
|
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
|
||||||
literal, client_->TransferFromOutfeedLocal(shape, device_ordinal));
|
|
||||||
}
|
|
||||||
return LiteralToPython(std::make_shared<Literal>(std::move(literal)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */
|
/* static */
|
||||||
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
|
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
|
||||||
const py::object& argument, std::shared_ptr<PyLocalClient> client,
|
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
|
||||||
int device_ordinal) {
|
std::shared_ptr<void> leaves_reference,
|
||||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPython");
|
std::shared_ptr<PyLocalClient> client, int device_ordinal) {
|
||||||
struct H2DTransfer {
|
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals");
|
||||||
PythonBufferTree tree;
|
VLOG(1) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString()
|
||||||
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref;
|
|
||||||
};
|
|
||||||
auto transfer = std::make_shared<H2DTransfer>();
|
|
||||||
TF_ASSIGN_OR_RETURN(transfer->tree, GetPythonBufferTree(argument));
|
|
||||||
|
|
||||||
client->py_ref_manager().CollectGarbage();
|
|
||||||
|
|
||||||
// Take a reference to the buffer to ensure that the inputs in host memory
|
|
||||||
// remain live until the transfer is complete.
|
|
||||||
transfer->py_buffer_ref = client->py_ref_manager().ManageReferences(
|
|
||||||
absl::MakeSpan(transfer->tree.arrays));
|
|
||||||
transfer->tree.arrays.clear();
|
|
||||||
|
|
||||||
// We are done manipulating Python objects; release the GIL.
|
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
VLOG(1) << "PyLocalBuffer::FromPython: shape: "
|
|
||||||
<< transfer->tree.shape.ToString()
|
|
||||||
<< " device ordinal: " << device_ordinal;
|
<< " device ordinal: " << device_ordinal;
|
||||||
|
|
||||||
Device* device = &client->device(device_ordinal);
|
Device* device = &client->device(device_ordinal);
|
||||||
@ -270,11 +241,11 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
|
|||||||
client->client()->backend().transfer_manager();
|
client->client()->backend().transfer_manager();
|
||||||
se::DeviceMemoryAllocator* allocator = client->allocator();
|
se::DeviceMemoryAllocator* allocator = client->allocator();
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
transfer->tree.shape,
|
Shape compact_shape,
|
||||||
transfer_manager->ChooseCompactLayoutForShape(transfer->tree.shape));
|
transfer_manager->ChooseCompactLayoutForShape(tuple_shape));
|
||||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer scoped_buffer,
|
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer scoped_buffer,
|
||||||
transfer_manager->AllocateScopedShapedBuffer(
|
transfer_manager->AllocateScopedShapedBuffer(
|
||||||
transfer->tree.shape, allocator, device_ordinal));
|
compact_shape, allocator, device_ordinal));
|
||||||
|
|
||||||
// Make the host to device stream wait for the newly allocated buffer to be
|
// Make the host to device stream wait for the newly allocated buffer to be
|
||||||
// available on the compute stream. We schedule this wait synchronously; while
|
// available on the compute stream. We schedule this wait synchronously; while
|
||||||
@ -293,21 +264,25 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
|
|||||||
SharedDeviceBuffer::FromScopedShapedBuffer(std::move(scoped_buffer),
|
SharedDeviceBuffer::FromScopedShapedBuffer(std::move(scoped_buffer),
|
||||||
definition_event);
|
definition_event);
|
||||||
|
|
||||||
|
// TODO(makro): Use move capture once C++14 features are available.
|
||||||
|
auto leaves = std::make_shared<std::vector<BorrowingLiteral>>(
|
||||||
|
std::move(leaves_literals));
|
||||||
auto transfer_h2d = [client, transfer_manager, device, device_ordinal,
|
auto transfer_h2d = [client, transfer_manager, device, device_ordinal,
|
||||||
device_buffer, transfer]() {
|
device_buffer, compact_shape, leaves,
|
||||||
|
leaves_reference]() {
|
||||||
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way to
|
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way to
|
||||||
// report failures from a callback. However, the operations here are
|
// report failures from a callback. However, the operations here are
|
||||||
// unlikely to fail and not recoverable even if we were to fail: DMAs to
|
// unlikely to fail and not recoverable even if we were to fail: DMAs to
|
||||||
// memory that has already been allocated, and a possible Event allocation.
|
// memory that has already been allocated, and a possible Event allocation.
|
||||||
ShapedBuffer buffer = device_buffer->AsShapedBuffer(transfer->tree.shape);
|
ShapedBuffer buffer = device_buffer->AsShapedBuffer(compact_shape);
|
||||||
TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
|
TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
|
||||||
device->host_to_device_stream(), buffer));
|
device->host_to_device_stream(), buffer));
|
||||||
std::vector<std::shared_ptr<void>> staging_buffers;
|
std::vector<std::shared_ptr<void>> staging_buffers;
|
||||||
staging_buffers.reserve(transfer->tree.leaves.size());
|
staging_buffers.reserve(leaves->size());
|
||||||
auto it = transfer->tree.leaves.begin();
|
auto it = leaves->begin();
|
||||||
for (const ShapeUtil::IndexedShape& indexed_shape :
|
for (const ShapeUtil::IndexedShape& indexed_shape :
|
||||||
ShapeUtil::GetLeafShapes(transfer->tree.shape)) {
|
ShapeUtil::GetLeafShapes(compact_shape)) {
|
||||||
CHECK(it != transfer->tree.leaves.end());
|
CHECK(it != leaves->end());
|
||||||
ShapedBuffer leaf(
|
ShapedBuffer leaf(
|
||||||
indexed_shape.shape,
|
indexed_shape.shape,
|
||||||
transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
|
transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
|
||||||
@ -352,19 +327,19 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
|
|||||||
device->ThenRelease(device->host_to_device_stream(), device_buffer);
|
device->ThenRelease(device->host_to_device_stream(), device_buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
device->ThenRelease(device->host_to_device_stream(),
|
device->ThenRelease(
|
||||||
std::make_pair(std::move(transfer->py_buffer_ref),
|
device->host_to_device_stream(),
|
||||||
std::move(staging_buffers)));
|
std::make_pair(leaves_reference, std::move(staging_buffers)));
|
||||||
};
|
};
|
||||||
client->h2d_transfer_pool()->Schedule(transfer_h2d);
|
client->h2d_transfer_pool()->Schedule(transfer_h2d);
|
||||||
return absl::make_unique<PyLocalBuffer>(
|
return absl::make_unique<PyLocalBuffer>(
|
||||||
transfer->tree.shape, std::move(device_buffer), std::move(client));
|
compact_shape, std::move(device_buffer), std::move(client));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple(
|
/* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple(
|
||||||
const std::vector<PyLocalBuffer*> buffers,
|
const std::vector<PyLocalBuffer*> buffers,
|
||||||
std::shared_ptr<PyLocalClient> client, int device_ordinal) {
|
std::shared_ptr<PyLocalClient> client, int device_ordinal) {
|
||||||
std::vector<xla::Shape> host_shapes;
|
std::vector<Shape> host_shapes;
|
||||||
std::vector<std::shared_ptr<SharedDeviceBuffer>> device_buffers;
|
std::vector<std::shared_ptr<SharedDeviceBuffer>> device_buffers;
|
||||||
host_shapes.reserve(buffers.size());
|
host_shapes.reserve(buffers.size());
|
||||||
device_buffers.reserve(buffers.size());
|
device_buffers.reserve(buffers.size());
|
||||||
@ -458,17 +433,13 @@ Status PyLocalBuffer::CopyToHostAsync() {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<py::object> PyLocalBuffer::ToPython() {
|
StatusOr<std::shared_ptr<Literal>> PyLocalBuffer::ToLiteral() {
|
||||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToPython");
|
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToLiteral");
|
||||||
std::shared_ptr<SharedDeviceBuffer> device_buffer = DeviceBuffer();
|
std::shared_ptr<SharedDeviceBuffer> device_buffer = DeviceBuffer();
|
||||||
if (!device_buffer) {
|
if (!device_buffer) {
|
||||||
return InvalidArgument("ToPython() called on invalid buffer.");
|
return InvalidArgument("ToLiteral() called on invalid buffer.");
|
||||||
}
|
}
|
||||||
|
|
||||||
client_->py_ref_manager().CollectGarbage();
|
|
||||||
std::shared_ptr<Literal> literal;
|
|
||||||
{
|
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
TF_RETURN_IF_ERROR(CopyToHostAsync());
|
TF_RETURN_IF_ERROR(CopyToHostAsync());
|
||||||
std::shared_ptr<HostValue> host_value;
|
std::shared_ptr<HostValue> host_value;
|
||||||
{
|
{
|
||||||
@ -477,10 +448,7 @@ StatusOr<py::object> PyLocalBuffer::ToPython() {
|
|||||||
}
|
}
|
||||||
host_value->ready.WaitForNotification();
|
host_value->ready.WaitForNotification();
|
||||||
TF_RETURN_IF_ERROR(host_value->status);
|
TF_RETURN_IF_ERROR(host_value->status);
|
||||||
literal = host_value->value;
|
return host_value->value;
|
||||||
}
|
|
||||||
|
|
||||||
return LiteralToPython(std::move(literal));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<SharedDeviceBuffer> PyLocalBuffer::DeviceBuffer() const {
|
std::shared_ptr<SharedDeviceBuffer> PyLocalBuffer::DeviceBuffer() const {
|
||||||
@ -524,8 +492,6 @@ PyLocalBuffer::DestructureTuple() {
|
|||||||
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
|
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
|
||||||
int dst_device_ordinal) {
|
int dst_device_ordinal) {
|
||||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice");
|
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice");
|
||||||
client_->py_ref_manager().CollectGarbage();
|
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
std::shared_ptr<SharedDeviceBuffer> src_device_buffer = DeviceBuffer();
|
std::shared_ptr<SharedDeviceBuffer> src_device_buffer = DeviceBuffer();
|
||||||
if (dst_device_ordinal == device_ordinal_) {
|
if (dst_device_ordinal == device_ordinal_) {
|
||||||
return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
|
return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
|
||||||
@ -554,7 +520,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
|
|||||||
|
|
||||||
// Copy the leaf buffers.
|
// Copy the leaf buffers.
|
||||||
for (const auto& leaf : src_buffer.buffers().leaves()) {
|
for (const auto& leaf : src_buffer.buffers().leaves()) {
|
||||||
const xla::ShapeIndex& index = leaf.first;
|
const ShapeIndex& index = leaf.first;
|
||||||
const se::DeviceMemoryBase& input_buffer = leaf.second;
|
const se::DeviceMemoryBase& input_buffer = leaf.second;
|
||||||
const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
|
const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
|
||||||
TF_RET_CHECK(input_buffer.size() == output_buffer.size())
|
TF_RET_CHECK(input_buffer.size() == output_buffer.size())
|
||||||
@ -603,9 +569,6 @@ Status PyLocalBuffer::BlockHostUntilReady() {
|
|||||||
return InvalidArgument("BlockHostUntilReady() called on invalid buffer.");
|
return InvalidArgument("BlockHostUntilReady() called on invalid buffer.");
|
||||||
}
|
}
|
||||||
|
|
||||||
client_->py_ref_manager().CollectGarbage();
|
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
|
|
||||||
// This code waits at least until the buffer is ready, but it may wait longer
|
// This code waits at least until the buffer is ready, but it may wait longer
|
||||||
// if there are other device to host transfers scheduled. If this proves to
|
// if there are other device to host transfers scheduled. If this proves to
|
||||||
// be an issue, we could either use a separate stream for this purpose, or
|
// be an issue, we could either use a separate stream for this purpose, or
|
||||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
|||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "absl/synchronization/notification.h"
|
#include "absl/synchronization/notification.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "include/pybind11/pybind11.h"
|
|
||||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||||
@ -78,8 +77,7 @@ class PyLocalClient {
|
|||||||
virtual ~PyLocalClient() = default;
|
virtual ~PyLocalClient() = default;
|
||||||
|
|
||||||
Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);
|
Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);
|
||||||
StatusOr<pybind11::object> TransferFromOutfeed(const Shape& shape,
|
StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_ordinal);
|
||||||
int device_ordinal);
|
|
||||||
|
|
||||||
int device_count() const { return client_->device_count(); }
|
int device_count() const { return client_->device_count(); }
|
||||||
Device& device(int device_ordinal) const {
|
Device& device(int device_ordinal) const {
|
||||||
@ -128,9 +126,10 @@ class PyLocalClient {
|
|||||||
// Thread-safe.
|
// Thread-safe.
|
||||||
class PyLocalBuffer {
|
class PyLocalBuffer {
|
||||||
public:
|
public:
|
||||||
static StatusOr<std::unique_ptr<PyLocalBuffer>> FromPython(
|
static StatusOr<std::unique_ptr<PyLocalBuffer>> FromLiterals(
|
||||||
const pybind11::object& argument, std::shared_ptr<PyLocalClient> client,
|
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
|
||||||
int device_ordinal);
|
std::shared_ptr<void> leaves_reference,
|
||||||
|
std::shared_ptr<PyLocalClient> client, int device_ordinal);
|
||||||
|
|
||||||
static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
|
static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
|
||||||
const std::vector<PyLocalBuffer*> buffers,
|
const std::vector<PyLocalBuffer*> buffers,
|
||||||
@ -149,15 +148,19 @@ class PyLocalBuffer {
|
|||||||
const Shape& on_host_shape() const { return on_host_shape_; }
|
const Shape& on_host_shape() const { return on_host_shape_; }
|
||||||
int device_ordinal() const { return device_ordinal_; }
|
int device_ordinal() const { return device_ordinal_; }
|
||||||
|
|
||||||
|
// TODO(makro): Make `client` private once `PythonRefManager` is refactored
|
||||||
|
// out of `PyLocalClient`.
|
||||||
|
PyLocalClient* client() const { return client_.get(); }
|
||||||
|
|
||||||
// Returns the buffer's value as a tuple DAG of Python arrays. If the value
|
// Returns the buffer's value as an XLA Literal. If the value
|
||||||
// has previously been prefetched to the host, then returns the prefetched
|
// has previously been prefetched to the host, then returns the prefetched
|
||||||
// version, otherwise copies the buffer to the host. Blocks until the
|
// version, otherwise copies the buffer to the host. Blocks until the
|
||||||
// value is ready.
|
// value is ready.
|
||||||
StatusOr<pybind11::object> ToPython();
|
StatusOr<std::shared_ptr<Literal>> ToLiteral();
|
||||||
|
|
||||||
// Initiates a copy of the buffer to the host. Does not block waiting for
|
// Initiates a copy of the buffer to the host. Does not block waiting for
|
||||||
// the transfer to complete. The value can be retrieved by a later call to
|
// the transfer to complete. The value can be retrieved by a later call to
|
||||||
// ToPython().
|
// ToLiteral().
|
||||||
Status CopyToHostAsync();
|
Status CopyToHostAsync();
|
||||||
|
|
||||||
// Returns the associated device buffer. Returns a nullptr if the buffer is
|
// Returns the associated device buffer. Returns a nullptr if the buffer is
|
||||||
@ -190,14 +193,14 @@ class PyLocalBuffer {
|
|||||||
std::shared_ptr<SharedDeviceBuffer> device_buffer_ GUARDED_BY(mu_);
|
std::shared_ptr<SharedDeviceBuffer> device_buffer_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
// The cached value of the buffer on the host, produced either from a call to
|
// The cached value of the buffer on the host, produced either from a call to
|
||||||
// CopyToHost or from a call to ToPython. Once a value has been fetched to
|
// CopyToHost or from a call to ToLiteral. Once a value has been fetched to
|
||||||
// the host, it persists until Delete() is called or the PyLocalBuffer is destroyed.
|
// the host, it persists until Delete() is called or the PyLocalBuffer is destroyed.
|
||||||
struct HostValue {
|
struct HostValue {
|
||||||
absl::Notification ready;
|
absl::Notification ready;
|
||||||
// status and value are valid for reading only after `ready` has been
|
// status and value are valid for reading only after `ready` has been
|
||||||
// notified.
|
// notified.
|
||||||
Status status;
|
Status status;
|
||||||
std::shared_ptr<xla::Literal> value;
|
std::shared_ptr<Literal> value;
|
||||||
};
|
};
|
||||||
std::shared_ptr<HostValue> host_value_ GUARDED_BY(mu_);
|
std::shared_ptr<HostValue> host_value_ GUARDED_BY(mu_);
|
||||||
};
|
};
|
||||||
|
@ -312,18 +312,77 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
py::arg("xla_platform_id"), py::arg("asynchronous"),
|
py::arg("xla_platform_id"), py::arg("asynchronous"),
|
||||||
py::arg("allocator_config") = AllocatorConfig())
|
py::arg("allocator_config") = AllocatorConfig())
|
||||||
.def("DeviceCount", &PyLocalClient::device_count)
|
.def("DeviceCount", &PyLocalClient::device_count)
|
||||||
.def("TransferToInfeed", &PyLocalClient::TransferToInfeed)
|
.def("TransferToInfeed",
|
||||||
.def("TransferFromOutfeed", &PyLocalClient::TransferFromOutfeed);
|
[](PyLocalClient* client, const LiteralSlice& literal,
|
||||||
|
int device_ordinal) {
|
||||||
|
client->py_ref_manager().CollectGarbage();
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
return client->TransferToInfeed(literal, device_ordinal);
|
||||||
|
})
|
||||||
|
.def("TransferFromOutfeed",
|
||||||
|
[](PyLocalClient* client, const Shape& shape,
|
||||||
|
int device_ordinal) -> StatusOr<py::object> {
|
||||||
|
client->py_ref_manager().CollectGarbage();
|
||||||
|
std::shared_ptr<Literal> literal_shared;
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
TF_ASSIGN_OR_RETURN(Literal literal, client->TransferFromOutfeed(
|
||||||
|
shape, device_ordinal));
|
||||||
|
literal_shared = std::make_shared<Literal>(std::move(literal));
|
||||||
|
}
|
||||||
|
return LiteralToPython(std::move(literal_shared));
|
||||||
|
});
|
||||||
|
|
||||||
py::class_<PyLocalBuffer>(m, "PyLocalBuffer")
|
py::class_<PyLocalBuffer>(m, "PyLocalBuffer")
|
||||||
.def_static("from_python", &PyLocalBuffer::FromPython)
|
.def_static(
|
||||||
|
"from_python",
|
||||||
|
[](const pybind11::object& argument,
|
||||||
|
std::shared_ptr<PyLocalClient> client,
|
||||||
|
int device_ordinal) -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
|
||||||
|
client->py_ref_manager().CollectGarbage();
|
||||||
|
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
|
||||||
|
GetPythonBufferTree(argument));
|
||||||
|
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
|
||||||
|
client->py_ref_manager().ManageReferences(
|
||||||
|
absl::MakeSpan(tree.arrays));
|
||||||
|
tree.arrays.clear();
|
||||||
|
|
||||||
|
std::vector<BorrowingLiteral> leaves;
|
||||||
|
leaves.insert(leaves.end(),
|
||||||
|
std::make_move_iterator(tree.leaves.begin()),
|
||||||
|
std::make_move_iterator(tree.leaves.end()));
|
||||||
|
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
return PyLocalBuffer::FromLiterals(
|
||||||
|
std::move(leaves), tree.shape, std::move(py_buffer_ref),
|
||||||
|
std::move(client), device_ordinal);
|
||||||
|
})
|
||||||
.def_static("make_tuple", &PyLocalBuffer::MakeTuple)
|
.def_static("make_tuple", &PyLocalBuffer::MakeTuple)
|
||||||
.def("copy_to_device", &PyLocalBuffer::CopyToDevice)
|
.def("copy_to_device",
|
||||||
|
[](PyLocalBuffer* buffer, int dst_device_ordinal) {
|
||||||
|
buffer->client()->py_ref_manager().CollectGarbage();
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
return buffer->CopyToDevice(dst_device_ordinal);
|
||||||
|
})
|
||||||
.def("delete", &PyLocalBuffer::Delete)
|
.def("delete", &PyLocalBuffer::Delete)
|
||||||
.def("destructure", &PyLocalBuffer::DestructureTuple)
|
.def("destructure", &PyLocalBuffer::DestructureTuple)
|
||||||
.def("block_host_until_ready", &PyLocalBuffer::BlockHostUntilReady)
|
.def("block_host_until_ready",
|
||||||
|
[](PyLocalBuffer* buffer) {
|
||||||
|
buffer->client()->py_ref_manager().CollectGarbage();
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
return buffer->BlockHostUntilReady();
|
||||||
|
})
|
||||||
.def("copy_to_host_async", &PyLocalBuffer::CopyToHostAsync)
|
.def("copy_to_host_async", &PyLocalBuffer::CopyToHostAsync)
|
||||||
.def("to_py", &PyLocalBuffer::ToPython)
|
.def("to_py",
|
||||||
|
[](PyLocalBuffer* buffer) -> StatusOr<py::object> {
|
||||||
|
buffer->client()->py_ref_manager().CollectGarbage();
|
||||||
|
std::shared_ptr<Literal> literal;
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
TF_ASSIGN_OR_RETURN(literal, buffer->ToLiteral());
|
||||||
|
}
|
||||||
|
return LiteralToPython(std::move(literal));
|
||||||
|
})
|
||||||
.def("shape", &PyLocalBuffer::on_host_shape)
|
.def("shape", &PyLocalBuffer::on_host_shape)
|
||||||
.def("device", &PyLocalBuffer::device_ordinal)
|
.def("device", &PyLocalBuffer::device_ordinal)
|
||||||
.def("is_deleted",
|
.def("is_deleted",
|
||||||
@ -640,6 +699,6 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
py::class_<ChannelHandle>(m, "ChannelHandle");
|
py::class_<ChannelHandle>(m, "ChannelHandle");
|
||||||
|
|
||||||
tensorflow::AddXrtSubmodule(&m);
|
tensorflow::AddXrtSubmodule(&m);
|
||||||
}
|
} // NOLINT(readability/fn_size)
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
x
Reference in New Issue
Block a user