[XLA:Python] Refactor Python specifics out of PyLocalClient and PyLocalBuffer to remove dependency on pybind11.

PiperOrigin-RevId: 259312456
commit 384e7f8c86 (parent 1ee51a3b86)
Author: A. Unique TensorFlower, 2019-07-22 06:03:34 -07:00; committed by TensorFlower Gardener

4 changed files with 117 additions and 93 deletions
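The shape of the change, before the per-file diffs: PyLocalClient and PyLocalBuffer now traffic only in `Literal`/`BorrowingLiteral`, while the pybind11 layer in xla.cc wraps each call with the Python-specific steps (reference-manager garbage collection, GIL release, `LiteralToPython` conversion). Here is a self-contained sketch of that split, using a hypothetical `Engine` class rather than the real XLA types:

```cpp
// Sketch of the refactoring pattern, not the real XLA code: the core class
// exposes a pure C++ API, and only the binding layer touches Python state.
#include <string>

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Core class: no pybind11 includes, returns plain C++ values. (A hypothetical
// stand-in for PyLocalClient, whose methods now return StatusOr<Literal>.)
class Engine {
 public:
  std::string Transfer() { return "payload"; }  // may block on the device
};

PYBIND11_MODULE(engine_ext, m) {
  py::class_<Engine>(m, "Engine")
      .def(py::init<>())
      // Python specifics live only in the binding lambda: release the GIL
      // around the blocking call, convert the result afterwards.
      .def("transfer", [](Engine* e) {
        std::string result;
        {
          py::gil_scoped_release release;  // other Python threads can run
          result = e->Transfer();
        }
        return py::bytes(result);  // conversion happens with the GIL held
      });
}
```

Keeping `gil_scoped_release` in the binding lambda is what lets local_client.cc and local_client.h drop their pybind11 includes entirely, as the diffs below show.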

tensorflow/compiler/xla/python/BUILD

@@ -69,7 +69,6 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
-        "//tensorflow/stream_executor:device_memory_allocator",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/types:optional",
         "@pybind11",

tensorflow/compiler/xla/python/local_client.cc

@@ -85,7 +85,6 @@ limitations under the License.
 #include "absl/strings/str_format.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/time/time.h"
-#include "include/pybind11/pybind11.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/executable_run_options.h"
@@ -106,8 +105,6 @@ limitations under the License.
 namespace xla {
 
-namespace py = pybind11;
-
 static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
     se::Platform* platform, LocalClient* client, double memory_fraction,
     bool preallocate) {
@@ -222,47 +219,21 @@ PyLocalClient::PyLocalClient(
 
 Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
                                        int device_ordinal) {
-  py_ref_manager().CollectGarbage();
-  py::gil_scoped_release gil_release;
   return client_->TransferToInfeedLocal(literal, device_ordinal);
 }
 
-StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
-    const Shape& shape, int device_ordinal) {
-  py_ref_manager().CollectGarbage();
-  Literal literal;
-  {
-    py::gil_scoped_release gil_release;
-    TF_ASSIGN_OR_RETURN(
-        literal, client_->TransferFromOutfeedLocal(shape, device_ordinal));
-  }
-  return LiteralToPython(std::make_shared<Literal>(std::move(literal)));
+StatusOr<Literal> PyLocalClient::TransferFromOutfeed(const Shape& shape,
+                                                     int device_ordinal) {
+  return client_->TransferFromOutfeedLocal(shape, device_ordinal);
 }
 
 /* static */
-StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
-    const py::object& argument, std::shared_ptr<PyLocalClient> client,
-    int device_ordinal) {
-  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPython");
-  struct H2DTransfer {
-    PythonBufferTree tree;
-    std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref;
-  };
-  auto transfer = std::make_shared<H2DTransfer>();
-  TF_ASSIGN_OR_RETURN(transfer->tree, GetPythonBufferTree(argument));
-  client->py_ref_manager().CollectGarbage();
-
-  // Take a reference to the buffer to ensure that the inputs in host memory
-  // remain live until the transfer is complete.
-  transfer->py_buffer_ref = client->py_ref_manager().ManageReferences(
-      absl::MakeSpan(transfer->tree.arrays));
-  transfer->tree.arrays.clear();
-
-  // We are done manipulating Python objects; release the GIL.
-  py::gil_scoped_release gil_release;
-
-  VLOG(1) << "PyLocalBuffer::FromPython: shape: "
-          << transfer->tree.shape.ToString()
+StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
+    std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
+    std::shared_ptr<void> leaves_reference,
+    std::shared_ptr<PyLocalClient> client, int device_ordinal) {
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals");
+  VLOG(1) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString()
           << " device ordinal: " << device_ordinal;
 
   Device* device = &client->device(device_ordinal);
@@ -270,11 +241,11 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
       client->client()->backend().transfer_manager();
   se::DeviceMemoryAllocator* allocator = client->allocator();
   TF_ASSIGN_OR_RETURN(
-      transfer->tree.shape,
-      transfer_manager->ChooseCompactLayoutForShape(transfer->tree.shape));
+      Shape compact_shape,
+      transfer_manager->ChooseCompactLayoutForShape(tuple_shape));
   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer scoped_buffer,
                       transfer_manager->AllocateScopedShapedBuffer(
-                          transfer->tree.shape, allocator, device_ordinal));
+                          compact_shape, allocator, device_ordinal));
 
   // Make the host to device stream wait for the newly allocated buffer to be
   // available on the compute stream. We schedule this wait synchronously; while
@@ -293,21 +264,25 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
       SharedDeviceBuffer::FromScopedShapedBuffer(std::move(scoped_buffer),
                                                  definition_event);
 
+  // TODO(makro): Use move capture once C++ 14 features are available.
+  auto leaves = std::make_shared<std::vector<BorrowingLiteral>>(
+      std::move(leaves_literals));
   auto transfer_h2d = [client, transfer_manager, device, device_ordinal,
-                       device_buffer, transfer]() {
+                       device_buffer, compact_shape, leaves,
+                       leaves_reference]() {
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way to
     // report failures from a callback. However, the operations here are
    // unlikely to fail and not recoverable even if we were to fail: DMAs to
     // memory that has already been allocated, and a possible Event allocation.
-    ShapedBuffer buffer = device_buffer->AsShapedBuffer(transfer->tree.shape);
+    ShapedBuffer buffer = device_buffer->AsShapedBuffer(compact_shape);
     TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
         device->host_to_device_stream(), buffer));
     std::vector<std::shared_ptr<void>> staging_buffers;
-    staging_buffers.reserve(transfer->tree.leaves.size());
-    auto it = transfer->tree.leaves.begin();
+    staging_buffers.reserve(leaves->size());
+    auto it = leaves->begin();
     for (const ShapeUtil::IndexedShape& indexed_shape :
-         ShapeUtil::GetLeafShapes(transfer->tree.shape)) {
-      CHECK(it != transfer->tree.leaves.end());
+         ShapeUtil::GetLeafShapes(compact_shape)) {
+      CHECK(it != leaves->end());
       ShapedBuffer leaf(
           indexed_shape.shape,
           transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
@@ -352,19 +327,19 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
       device->ThenRelease(device->host_to_device_stream(), device_buffer);
     }
 
-    device->ThenRelease(device->host_to_device_stream(),
-                        std::make_pair(std::move(transfer->py_buffer_ref),
-                                       std::move(staging_buffers)));
+    device->ThenRelease(
+        device->host_to_device_stream(),
+        std::make_pair(leaves_reference, std::move(staging_buffers)));
   };
   client->h2d_transfer_pool()->Schedule(transfer_h2d);
   return absl::make_unique<PyLocalBuffer>(
-      transfer->tree.shape, std::move(device_buffer), std::move(client));
+      compact_shape, std::move(device_buffer), std::move(client));
 }
 
 /* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple(
     const std::vector<PyLocalBuffer*> buffers,
     std::shared_ptr<PyLocalClient> client, int device_ordinal) {
-  std::vector<xla::Shape> host_shapes;
+  std::vector<Shape> host_shapes;
   std::vector<std::shared_ptr<SharedDeviceBuffer>> device_buffers;
   host_shapes.reserve(buffers.size());
   device_buffers.reserve(buffers.size());
@@ -458,17 +433,13 @@ Status PyLocalBuffer::CopyToHostAsync() {
   return Status::OK();
 }
 
-StatusOr<py::object> PyLocalBuffer::ToPython() {
-  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToPython");
+StatusOr<std::shared_ptr<Literal>> PyLocalBuffer::ToLiteral() {
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToLiteral");
   std::shared_ptr<SharedDeviceBuffer> device_buffer = DeviceBuffer();
   if (!device_buffer) {
-    return InvalidArgument("ToPython() called on invalid buffer.");
+    return InvalidArgument("ToLiteral() called on invalid buffer.");
   }
 
-  client_->py_ref_manager().CollectGarbage();
-  std::shared_ptr<Literal> literal;
-  {
-    py::gil_scoped_release gil_release;
-    TF_RETURN_IF_ERROR(CopyToHostAsync());
-    std::shared_ptr<HostValue> host_value;
-    {
+  TF_RETURN_IF_ERROR(CopyToHostAsync());
+  std::shared_ptr<HostValue> host_value;
+  {
@@ -477,10 +448,7 @@ StatusOr<py::object> PyLocalBuffer::ToPython() {
   }
   host_value->ready.WaitForNotification();
   TF_RETURN_IF_ERROR(host_value->status);
-  literal = host_value->value;
-  }
-
-  return LiteralToPython(std::move(literal));
+  return host_value->value;
 }
 
 std::shared_ptr<SharedDeviceBuffer> PyLocalBuffer::DeviceBuffer() const {
@@ -524,8 +492,6 @@ PyLocalBuffer::DestructureTuple() {
 StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
     int dst_device_ordinal) {
   tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice");
-  client_->py_ref_manager().CollectGarbage();
-  py::gil_scoped_release gil_release;
   std::shared_ptr<SharedDeviceBuffer> src_device_buffer = DeviceBuffer();
   if (dst_device_ordinal == device_ordinal_) {
     return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
@@ -554,7 +520,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
 
   // Copy the leaf buffers.
   for (const auto& leaf : src_buffer.buffers().leaves()) {
-    const xla::ShapeIndex& index = leaf.first;
+    const ShapeIndex& index = leaf.first;
     const se::DeviceMemoryBase& input_buffer = leaf.second;
     const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
     TF_RET_CHECK(input_buffer.size() == output_buffer.size())
@@ -603,9 +569,6 @@ Status PyLocalBuffer::BlockHostUntilReady() {
     return InvalidArgument("BlockHostUntilReady() called on invalid buffer.");
   }
 
-  client_->py_ref_manager().CollectGarbage();
-  py::gil_scoped_release gil_release;
-
   // This code waits at least until the buffer is ready, but it may wait longer
   // if there are other device to host transfers scheduled. If this proves to
   // be an issue, we could either use a separate stream for this purpose, or
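The `TODO(makro)` in `FromLiterals` above marks a C++11 workaround: without C++14 init-capture, a payload cannot be moved directly into the `transfer_h2d` lambda, so it is boxed in a `std::shared_ptr` that the lambda captures by value. A minimal standalone sketch of the same trick, with illustrative names rather than XLA's:

```cpp
// Sketch of the shared_ptr workaround for moving a container into a lambda
// under C++11 (no init-capture); names here are hypothetical, not XLA code.
#include <functional>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

std::function<void()> MakeCallback(std::vector<int> data) {
  // C++14 would allow: [data = std::move(data)]() { ... }
  // Under C++11 we move the payload onto the heap and capture the
  // shared_ptr by value; copies of the lambda share one payload.
  auto shared = std::make_shared<std::vector<int>>(std::move(data));
  return [shared]() { std::cout << shared->size() << " elements\n"; };
}

int main() {
  auto cb = MakeCallback({1, 2, 3});
  cb();  // prints "3 elements"
}
```

Once C++14 is available, the box disappears: `[data = std::move(data)]` captures the vector directly.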

tensorflow/compiler/xla/python/local_client.h

@@ -23,7 +23,6 @@ limitations under the License.
 #include "absl/synchronization/mutex.h"
 #include "absl/synchronization/notification.h"
 #include "absl/types/span.h"
-#include "include/pybind11/pybind11.h"
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -78,8 +77,7 @@ class PyLocalClient {
   virtual ~PyLocalClient() = default;
 
   Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);
-  StatusOr<pybind11::object> TransferFromOutfeed(const Shape& shape,
-                                                 int device_ordinal);
+  StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_ordinal);
 
   int device_count() const { return client_->device_count(); }
   Device& device(int device_ordinal) const {
@@ -128,9 +126,10 @@ class PyLocalClient {
 // Thread-safe.
 class PyLocalBuffer {
  public:
-  static StatusOr<std::unique_ptr<PyLocalBuffer>> FromPython(
-      const pybind11::object& argument, std::shared_ptr<PyLocalClient> client,
-      int device_ordinal);
+  static StatusOr<std::unique_ptr<PyLocalBuffer>> FromLiterals(
+      std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
+      std::shared_ptr<void> leaves_reference,
+      std::shared_ptr<PyLocalClient> client, int device_ordinal);
 
   static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
       const std::vector<PyLocalBuffer*> buffers,
@@ -149,15 +148,19 @@ class PyLocalBuffer {
   const Shape& on_host_shape() const { return on_host_shape_; }
   int device_ordinal() const { return device_ordinal_; }
 
+  // TODO(makro): Make `client` private once `PythonRefManager` is refactored
+  // out of `PyLocalClient`.
+  PyLocalClient* client() const { return client_.get(); }
+
   // Returns the buffer's value as a tuple DAG of Python arrays. If the value
   // has previously been prefetched to the host, then returns the prefetched
   // version, otherwise copies the buffer to the host. Blocks until the
   // value is ready.
-  StatusOr<pybind11::object> ToPython();
+  StatusOr<std::shared_ptr<Literal>> ToLiteral();
 
   // Initiates a copy of the buffer to the host. Does not block waiting for
   // the transfer to complete. The value can be retrieved by a later call to
-  // ToPython().
+  // ToLiteral().
   Status CopyToHostAsync();
 
   // Returns the associated device buffer. Returns a nullptr if the buffer is
@@ -190,14 +193,14 @@ class PyLocalBuffer {
   std::shared_ptr<SharedDeviceBuffer> device_buffer_ GUARDED_BY(mu_);
 
   // The cached value of the buffer on the host, produced either from a call to
-  // CopyToHost or from a call to ToPython. Once a value has been fetched to
+  // CopyToHost or from a call to ToLiteral. Once a value has been fetched to
   // the host, it persists Delete() is called or the PyLocalBuffer is destroyed.
   struct HostValue {
     absl::Notification ready;
     // status and value are valid for reading only after `ready` has been
     // notified.
     Status status;
-    std::shared_ptr<xla::Literal> value;
+    std::shared_ptr<Literal> value;
   };
   std::shared_ptr<HostValue> host_value_ GUARDED_BY(mu_);
 };

tensorflow/compiler/xla/python/xla.cc

@@ -312,18 +312,77 @@ PYBIND11_MODULE(xla_extension, m) {
           py::arg("xla_platform_id"), py::arg("asynchronous"),
           py::arg("allocator_config") = AllocatorConfig())
       .def("DeviceCount", &PyLocalClient::device_count)
-      .def("TransferToInfeed", &PyLocalClient::TransferToInfeed)
-      .def("TransferFromOutfeed", &PyLocalClient::TransferFromOutfeed);
+      .def("TransferToInfeed",
+           [](PyLocalClient* client, const LiteralSlice& literal,
+              int device_ordinal) {
+             client->py_ref_manager().CollectGarbage();
+             py::gil_scoped_release gil_release;
+             return client->TransferToInfeed(literal, device_ordinal);
+           })
+      .def("TransferFromOutfeed",
+           [](PyLocalClient* client, const Shape& shape,
+              int device_ordinal) -> StatusOr<py::object> {
+             client->py_ref_manager().CollectGarbage();
+             std::shared_ptr<Literal> literal_shared;
+             {
+               py::gil_scoped_release gil_release;
+               TF_ASSIGN_OR_RETURN(Literal literal, client->TransferFromOutfeed(
+                                                        shape, device_ordinal));
+               literal_shared = std::make_shared<Literal>(std::move(literal));
+             }
+             return LiteralToPython(std::move(literal_shared));
+           });
 
   py::class_<PyLocalBuffer>(m, "PyLocalBuffer")
-      .def_static("from_python", &PyLocalBuffer::FromPython)
+      .def_static(
+          "from_python",
+          [](const pybind11::object& argument,
+             std::shared_ptr<PyLocalClient> client,
+             int device_ordinal) -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
+            client->py_ref_manager().CollectGarbage();
+            TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
+                                GetPythonBufferTree(argument));
+            std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
+                client->py_ref_manager().ManageReferences(
+                    absl::MakeSpan(tree.arrays));
+            tree.arrays.clear();
+            std::vector<BorrowingLiteral> leaves;
+            leaves.insert(leaves.end(),
+                          std::make_move_iterator(tree.leaves.begin()),
+                          std::make_move_iterator(tree.leaves.end()));
+            py::gil_scoped_release gil_release;
+            return PyLocalBuffer::FromLiterals(
+                std::move(leaves), tree.shape, std::move(py_buffer_ref),
+                std::move(client), device_ordinal);
+          })
       .def_static("make_tuple", &PyLocalBuffer::MakeTuple)
-      .def("copy_to_device", &PyLocalBuffer::CopyToDevice)
+      .def("copy_to_device",
+           [](PyLocalBuffer* buffer, int dst_device_ordinal) {
+             buffer->client()->py_ref_manager().CollectGarbage();
+             py::gil_scoped_release gil_release;
+             return buffer->CopyToDevice(dst_device_ordinal);
+           })
       .def("delete", &PyLocalBuffer::Delete)
       .def("destructure", &PyLocalBuffer::DestructureTuple)
-      .def("block_host_until_ready", &PyLocalBuffer::BlockHostUntilReady)
+      .def("block_host_until_ready",
+           [](PyLocalBuffer* buffer) {
+             buffer->client()->py_ref_manager().CollectGarbage();
+             py::gil_scoped_release gil_release;
+             return buffer->BlockHostUntilReady();
+           })
       .def("copy_to_host_async", &PyLocalBuffer::CopyToHostAsync)
-      .def("to_py", &PyLocalBuffer::ToPython)
+      .def("to_py",
+           [](PyLocalBuffer* buffer) -> StatusOr<py::object> {
+             buffer->client()->py_ref_manager().CollectGarbage();
+             std::shared_ptr<Literal> literal;
+             {
+               py::gil_scoped_release gil_release;
+               TF_ASSIGN_OR_RETURN(literal, buffer->ToLiteral());
+             }
+             return LiteralToPython(std::move(literal));
+           })
       .def("shape", &PyLocalBuffer::on_host_shape)
       .def("device", &PyLocalBuffer::device_ordinal)
       .def("is_deleted",
@@ -640,6 +699,6 @@ PYBIND11_MODULE(xla_extension, m) {
   py::class_<ChannelHandle>(m, "ChannelHandle");
 
   tensorflow::AddXrtSubmodule(&m);
-}
+}  // NOLINT(readability/fn_size)
 
 }  // namespace xla