[XLA:Python] Refactor Python specifics out of PyLocalClient and PyLocalBuffer to remove dependency on pybind11.
PiperOrigin-RevId: 259312456
This commit is contained in:
parent
1ee51a3b86
commit
384e7f8c86
@ -69,7 +69,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@pybind11",
|
||||
|
@ -85,7 +85,6 @@ limitations under the License.
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
@ -106,8 +105,6 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
|
||||
se::Platform* platform, LocalClient* client, double memory_fraction,
|
||||
bool preallocate) {
|
||||
@ -222,47 +219,21 @@ PyLocalClient::PyLocalClient(
|
||||
|
||||
Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
|
||||
int device_ordinal) {
|
||||
py_ref_manager().CollectGarbage();
|
||||
py::gil_scoped_release gil_release;
|
||||
return client_->TransferToInfeedLocal(literal, device_ordinal);
|
||||
}
|
||||
|
||||
StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
|
||||
const Shape& shape, int device_ordinal) {
|
||||
py_ref_manager().CollectGarbage();
|
||||
Literal literal;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
literal, client_->TransferFromOutfeedLocal(shape, device_ordinal));
|
||||
}
|
||||
return LiteralToPython(std::make_shared<Literal>(std::move(literal)));
|
||||
StatusOr<Literal> PyLocalClient::TransferFromOutfeed(const Shape& shape,
|
||||
int device_ordinal) {
|
||||
return client_->TransferFromOutfeedLocal(shape, device_ordinal);
|
||||
}
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
|
||||
const py::object& argument, std::shared_ptr<PyLocalClient> client,
|
||||
int device_ordinal) {
|
||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPython");
|
||||
struct H2DTransfer {
|
||||
PythonBufferTree tree;
|
||||
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref;
|
||||
};
|
||||
auto transfer = std::make_shared<H2DTransfer>();
|
||||
TF_ASSIGN_OR_RETURN(transfer->tree, GetPythonBufferTree(argument));
|
||||
|
||||
client->py_ref_manager().CollectGarbage();
|
||||
|
||||
// Take a reference to the buffer to ensure that the inputs in host memory
|
||||
// remain live until the transfer is complete.
|
||||
transfer->py_buffer_ref = client->py_ref_manager().ManageReferences(
|
||||
absl::MakeSpan(transfer->tree.arrays));
|
||||
transfer->tree.arrays.clear();
|
||||
|
||||
// We are done manipulating Python objects; release the GIL.
|
||||
py::gil_scoped_release gil_release;
|
||||
VLOG(1) << "PyLocalBuffer::FromPython: shape: "
|
||||
<< transfer->tree.shape.ToString()
|
||||
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
|
||||
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
|
||||
std::shared_ptr<void> leaves_reference,
|
||||
std::shared_ptr<PyLocalClient> client, int device_ordinal) {
|
||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals");
|
||||
VLOG(1) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString()
|
||||
<< " device ordinal: " << device_ordinal;
|
||||
|
||||
Device* device = &client->device(device_ordinal);
|
||||
@ -270,11 +241,11 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
|
||||
client->client()->backend().transfer_manager();
|
||||
se::DeviceMemoryAllocator* allocator = client->allocator();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
transfer->tree.shape,
|
||||
transfer_manager->ChooseCompactLayoutForShape(transfer->tree.shape));
|
||||
Shape compact_shape,
|
||||
transfer_manager->ChooseCompactLayoutForShape(tuple_shape));
|
||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer scoped_buffer,
|
||||
transfer_manager->AllocateScopedShapedBuffer(
|
||||
transfer->tree.shape, allocator, device_ordinal));
|
||||
compact_shape, allocator, device_ordinal));
|
||||
|
||||
// Make the host to device stream wait for the newly allocated buffer to be
|
||||
// available on the compute stream. We schedule this wait synchronously; while
|
||||
@ -293,21 +264,25 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
|
||||
SharedDeviceBuffer::FromScopedShapedBuffer(std::move(scoped_buffer),
|
||||
definition_event);
|
||||
|
||||
// TODO(makro): Use move capture once C++ 14 features are available.
|
||||
auto leaves = std::make_shared<std::vector<BorrowingLiteral>>(
|
||||
std::move(leaves_literals));
|
||||
auto transfer_h2d = [client, transfer_manager, device, device_ordinal,
|
||||
device_buffer, transfer]() {
|
||||
device_buffer, compact_shape, leaves,
|
||||
leaves_reference]() {
|
||||
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way to
|
||||
// report failures from a callback. However, the operations here are
|
||||
// unlikely to fail and not recoverable even if we were to fail: DMAs to
|
||||
// memory that has already been allocated, and a possible Event allocation.
|
||||
ShapedBuffer buffer = device_buffer->AsShapedBuffer(transfer->tree.shape);
|
||||
ShapedBuffer buffer = device_buffer->AsShapedBuffer(compact_shape);
|
||||
TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
|
||||
device->host_to_device_stream(), buffer));
|
||||
std::vector<std::shared_ptr<void>> staging_buffers;
|
||||
staging_buffers.reserve(transfer->tree.leaves.size());
|
||||
auto it = transfer->tree.leaves.begin();
|
||||
staging_buffers.reserve(leaves->size());
|
||||
auto it = leaves->begin();
|
||||
for (const ShapeUtil::IndexedShape& indexed_shape :
|
||||
ShapeUtil::GetLeafShapes(transfer->tree.shape)) {
|
||||
CHECK(it != transfer->tree.leaves.end());
|
||||
ShapeUtil::GetLeafShapes(compact_shape)) {
|
||||
CHECK(it != leaves->end());
|
||||
ShapedBuffer leaf(
|
||||
indexed_shape.shape,
|
||||
transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
|
||||
@ -352,19 +327,19 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
|
||||
device->ThenRelease(device->host_to_device_stream(), device_buffer);
|
||||
}
|
||||
|
||||
device->ThenRelease(device->host_to_device_stream(),
|
||||
std::make_pair(std::move(transfer->py_buffer_ref),
|
||||
std::move(staging_buffers)));
|
||||
device->ThenRelease(
|
||||
device->host_to_device_stream(),
|
||||
std::make_pair(leaves_reference, std::move(staging_buffers)));
|
||||
};
|
||||
client->h2d_transfer_pool()->Schedule(transfer_h2d);
|
||||
return absl::make_unique<PyLocalBuffer>(
|
||||
transfer->tree.shape, std::move(device_buffer), std::move(client));
|
||||
compact_shape, std::move(device_buffer), std::move(client));
|
||||
}
|
||||
|
||||
/* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple(
|
||||
const std::vector<PyLocalBuffer*> buffers,
|
||||
std::shared_ptr<PyLocalClient> client, int device_ordinal) {
|
||||
std::vector<xla::Shape> host_shapes;
|
||||
std::vector<Shape> host_shapes;
|
||||
std::vector<std::shared_ptr<SharedDeviceBuffer>> device_buffers;
|
||||
host_shapes.reserve(buffers.size());
|
||||
device_buffers.reserve(buffers.size());
|
||||
@ -458,17 +433,13 @@ Status PyLocalBuffer::CopyToHostAsync() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<py::object> PyLocalBuffer::ToPython() {
|
||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToPython");
|
||||
StatusOr<std::shared_ptr<Literal>> PyLocalBuffer::ToLiteral() {
|
||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToLiteral");
|
||||
std::shared_ptr<SharedDeviceBuffer> device_buffer = DeviceBuffer();
|
||||
if (!device_buffer) {
|
||||
return InvalidArgument("ToPython() called on invalid buffer.");
|
||||
return InvalidArgument("ToLiteral() called on invalid buffer.");
|
||||
}
|
||||
|
||||
client_->py_ref_manager().CollectGarbage();
|
||||
std::shared_ptr<Literal> literal;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_RETURN_IF_ERROR(CopyToHostAsync());
|
||||
std::shared_ptr<HostValue> host_value;
|
||||
{
|
||||
@ -477,10 +448,7 @@ StatusOr<py::object> PyLocalBuffer::ToPython() {
|
||||
}
|
||||
host_value->ready.WaitForNotification();
|
||||
TF_RETURN_IF_ERROR(host_value->status);
|
||||
literal = host_value->value;
|
||||
}
|
||||
|
||||
return LiteralToPython(std::move(literal));
|
||||
return host_value->value;
|
||||
}
|
||||
|
||||
std::shared_ptr<SharedDeviceBuffer> PyLocalBuffer::DeviceBuffer() const {
|
||||
@ -524,8 +492,6 @@ PyLocalBuffer::DestructureTuple() {
|
||||
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
|
||||
int dst_device_ordinal) {
|
||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice");
|
||||
client_->py_ref_manager().CollectGarbage();
|
||||
py::gil_scoped_release gil_release;
|
||||
std::shared_ptr<SharedDeviceBuffer> src_device_buffer = DeviceBuffer();
|
||||
if (dst_device_ordinal == device_ordinal_) {
|
||||
return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
|
||||
@ -554,7 +520,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
|
||||
|
||||
// Copy the leaf buffers.
|
||||
for (const auto& leaf : src_buffer.buffers().leaves()) {
|
||||
const xla::ShapeIndex& index = leaf.first;
|
||||
const ShapeIndex& index = leaf.first;
|
||||
const se::DeviceMemoryBase& input_buffer = leaf.second;
|
||||
const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
|
||||
TF_RET_CHECK(input_buffer.size() == output_buffer.size())
|
||||
@ -603,9 +569,6 @@ Status PyLocalBuffer::BlockHostUntilReady() {
|
||||
return InvalidArgument("BlockHostUntilReady() called on invalid buffer.");
|
||||
}
|
||||
|
||||
client_->py_ref_manager().CollectGarbage();
|
||||
py::gil_scoped_release gil_release;
|
||||
|
||||
// This code waits at least until the buffer is ready, but it may wait longer
|
||||
// if there are other device to host transfers scheduled. If this proves to
|
||||
// be an issue, we could either use a separate stream for this purpose, or
|
||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
@ -78,8 +77,7 @@ class PyLocalClient {
|
||||
virtual ~PyLocalClient() = default;
|
||||
|
||||
Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);
|
||||
StatusOr<pybind11::object> TransferFromOutfeed(const Shape& shape,
|
||||
int device_ordinal);
|
||||
StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_ordinal);
|
||||
|
||||
int device_count() const { return client_->device_count(); }
|
||||
Device& device(int device_ordinal) const {
|
||||
@ -128,9 +126,10 @@ class PyLocalClient {
|
||||
// Thread-safe.
|
||||
class PyLocalBuffer {
|
||||
public:
|
||||
static StatusOr<std::unique_ptr<PyLocalBuffer>> FromPython(
|
||||
const pybind11::object& argument, std::shared_ptr<PyLocalClient> client,
|
||||
int device_ordinal);
|
||||
static StatusOr<std::unique_ptr<PyLocalBuffer>> FromLiterals(
|
||||
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
|
||||
std::shared_ptr<void> leaves_reference,
|
||||
std::shared_ptr<PyLocalClient> client, int device_ordinal);
|
||||
|
||||
static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
|
||||
const std::vector<PyLocalBuffer*> buffers,
|
||||
@ -149,15 +148,19 @@ class PyLocalBuffer {
|
||||
const Shape& on_host_shape() const { return on_host_shape_; }
|
||||
int device_ordinal() const { return device_ordinal_; }
|
||||
|
||||
// TODO(makro): Make `client` private once `PythonRefManager` is refactored
|
||||
// out of `PyLocalClient`.
|
||||
PyLocalClient* client() const { return client_.get(); }
|
||||
|
||||
// Returns the buffer's value as a tuple DAG of Python arrays. If the value
|
||||
// has previously been prefetched to the host, then returns the prefetched
|
||||
// version, otherwise copies the buffer to the host. Blocks until the
|
||||
// value is ready.
|
||||
StatusOr<pybind11::object> ToPython();
|
||||
StatusOr<std::shared_ptr<Literal>> ToLiteral();
|
||||
|
||||
// Initiates a copy of the buffer to the host. Does not block waiting for
|
||||
// the transfer to complete. The value can be retrieved by a later call to
|
||||
// ToPython().
|
||||
// ToLiteral().
|
||||
Status CopyToHostAsync();
|
||||
|
||||
// Returns the associated device buffer. Returns a nullptr if the buffer is
|
||||
@ -190,14 +193,14 @@ class PyLocalBuffer {
|
||||
std::shared_ptr<SharedDeviceBuffer> device_buffer_ GUARDED_BY(mu_);
|
||||
|
||||
// The cached value of the buffer on the host, produced either from a call to
|
||||
// CopyToHost or from a call to ToPython. Once a value has been fetched to
|
||||
// CopyToHost or from a call to ToLiteral. Once a value has been fetched to
|
||||
// the host, it persists Delete() is called or the PyLocalBuffer is destroyed.
|
||||
struct HostValue {
|
||||
absl::Notification ready;
|
||||
// status and value are valid for reading only after `ready` has been
|
||||
// notified.
|
||||
Status status;
|
||||
std::shared_ptr<xla::Literal> value;
|
||||
std::shared_ptr<Literal> value;
|
||||
};
|
||||
std::shared_ptr<HostValue> host_value_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
@ -312,18 +312,77 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
py::arg("xla_platform_id"), py::arg("asynchronous"),
|
||||
py::arg("allocator_config") = AllocatorConfig())
|
||||
.def("DeviceCount", &PyLocalClient::device_count)
|
||||
.def("TransferToInfeed", &PyLocalClient::TransferToInfeed)
|
||||
.def("TransferFromOutfeed", &PyLocalClient::TransferFromOutfeed);
|
||||
.def("TransferToInfeed",
|
||||
[](PyLocalClient* client, const LiteralSlice& literal,
|
||||
int device_ordinal) {
|
||||
client->py_ref_manager().CollectGarbage();
|
||||
py::gil_scoped_release gil_release;
|
||||
return client->TransferToInfeed(literal, device_ordinal);
|
||||
})
|
||||
.def("TransferFromOutfeed",
|
||||
[](PyLocalClient* client, const Shape& shape,
|
||||
int device_ordinal) -> StatusOr<py::object> {
|
||||
client->py_ref_manager().CollectGarbage();
|
||||
std::shared_ptr<Literal> literal_shared;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(Literal literal, client->TransferFromOutfeed(
|
||||
shape, device_ordinal));
|
||||
literal_shared = std::make_shared<Literal>(std::move(literal));
|
||||
}
|
||||
return LiteralToPython(std::move(literal_shared));
|
||||
});
|
||||
|
||||
py::class_<PyLocalBuffer>(m, "PyLocalBuffer")
|
||||
.def_static("from_python", &PyLocalBuffer::FromPython)
|
||||
.def_static(
|
||||
"from_python",
|
||||
[](const pybind11::object& argument,
|
||||
std::shared_ptr<PyLocalClient> client,
|
||||
int device_ordinal) -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
|
||||
client->py_ref_manager().CollectGarbage();
|
||||
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
|
||||
GetPythonBufferTree(argument));
|
||||
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
|
||||
client->py_ref_manager().ManageReferences(
|
||||
absl::MakeSpan(tree.arrays));
|
||||
tree.arrays.clear();
|
||||
|
||||
std::vector<BorrowingLiteral> leaves;
|
||||
leaves.insert(leaves.end(),
|
||||
std::make_move_iterator(tree.leaves.begin()),
|
||||
std::make_move_iterator(tree.leaves.end()));
|
||||
|
||||
py::gil_scoped_release gil_release;
|
||||
return PyLocalBuffer::FromLiterals(
|
||||
std::move(leaves), tree.shape, std::move(py_buffer_ref),
|
||||
std::move(client), device_ordinal);
|
||||
})
|
||||
.def_static("make_tuple", &PyLocalBuffer::MakeTuple)
|
||||
.def("copy_to_device", &PyLocalBuffer::CopyToDevice)
|
||||
.def("copy_to_device",
|
||||
[](PyLocalBuffer* buffer, int dst_device_ordinal) {
|
||||
buffer->client()->py_ref_manager().CollectGarbage();
|
||||
py::gil_scoped_release gil_release;
|
||||
return buffer->CopyToDevice(dst_device_ordinal);
|
||||
})
|
||||
.def("delete", &PyLocalBuffer::Delete)
|
||||
.def("destructure", &PyLocalBuffer::DestructureTuple)
|
||||
.def("block_host_until_ready", &PyLocalBuffer::BlockHostUntilReady)
|
||||
.def("block_host_until_ready",
|
||||
[](PyLocalBuffer* buffer) {
|
||||
buffer->client()->py_ref_manager().CollectGarbage();
|
||||
py::gil_scoped_release gil_release;
|
||||
return buffer->BlockHostUntilReady();
|
||||
})
|
||||
.def("copy_to_host_async", &PyLocalBuffer::CopyToHostAsync)
|
||||
.def("to_py", &PyLocalBuffer::ToPython)
|
||||
.def("to_py",
|
||||
[](PyLocalBuffer* buffer) -> StatusOr<py::object> {
|
||||
buffer->client()->py_ref_manager().CollectGarbage();
|
||||
std::shared_ptr<Literal> literal;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(literal, buffer->ToLiteral());
|
||||
}
|
||||
return LiteralToPython(std::move(literal));
|
||||
})
|
||||
.def("shape", &PyLocalBuffer::on_host_shape)
|
||||
.def("device", &PyLocalBuffer::device_ordinal)
|
||||
.def("is_deleted",
|
||||
@ -640,6 +699,6 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
py::class_<ChannelHandle>(m, "ChannelHandle");
|
||||
|
||||
tensorflow::AddXrtSubmodule(&m);
|
||||
}
|
||||
} // NOLINT(readability/fn_size)
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user