[XLA:Python] Refactor Python specifics out of PyLocalClient and PyLocalBuffer to remove dependency on pybind11.

PiperOrigin-RevId: 259312456
This commit is contained in:
A. Unique TensorFlower 2019-07-22 06:03:34 -07:00 committed by TensorFlower Gardener
parent 1ee51a3b86
commit 384e7f8c86
4 changed files with 117 additions and 93 deletions

View File

@ -69,7 +69,6 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:optional",
"@pybind11",

View File

@ -85,7 +85,6 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "include/pybind11/pybind11.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
@ -106,8 +105,6 @@ limitations under the License.
namespace xla {
namespace py = pybind11;
static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
se::Platform* platform, LocalClient* client, double memory_fraction,
bool preallocate) {
@ -222,47 +219,21 @@ PyLocalClient::PyLocalClient(
Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
int device_ordinal) {
py_ref_manager().CollectGarbage();
py::gil_scoped_release gil_release;
return client_->TransferToInfeedLocal(literal, device_ordinal);
}
StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
const Shape& shape, int device_ordinal) {
py_ref_manager().CollectGarbage();
Literal literal;
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(
literal, client_->TransferFromOutfeedLocal(shape, device_ordinal));
}
return LiteralToPython(std::make_shared<Literal>(std::move(literal)));
StatusOr<Literal> PyLocalClient::TransferFromOutfeed(const Shape& shape,
int device_ordinal) {
return client_->TransferFromOutfeedLocal(shape, device_ordinal);
}
/* static */
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
const py::object& argument, std::shared_ptr<PyLocalClient> client,
int device_ordinal) {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPython");
struct H2DTransfer {
PythonBufferTree tree;
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref;
};
auto transfer = std::make_shared<H2DTransfer>();
TF_ASSIGN_OR_RETURN(transfer->tree, GetPythonBufferTree(argument));
client->py_ref_manager().CollectGarbage();
// Take a reference to the buffer to ensure that the inputs in host memory
// remain live until the transfer is complete.
transfer->py_buffer_ref = client->py_ref_manager().ManageReferences(
absl::MakeSpan(transfer->tree.arrays));
transfer->tree.arrays.clear();
// We are done manipulating Python objects; release the GIL.
py::gil_scoped_release gil_release;
VLOG(1) << "PyLocalBuffer::FromPython: shape: "
<< transfer->tree.shape.ToString()
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
std::shared_ptr<void> leaves_reference,
std::shared_ptr<PyLocalClient> client, int device_ordinal) {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals");
VLOG(1) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString()
<< " device ordinal: " << device_ordinal;
Device* device = &client->device(device_ordinal);
@ -270,11 +241,11 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
client->client()->backend().transfer_manager();
se::DeviceMemoryAllocator* allocator = client->allocator();
TF_ASSIGN_OR_RETURN(
transfer->tree.shape,
transfer_manager->ChooseCompactLayoutForShape(transfer->tree.shape));
Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(tuple_shape));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer scoped_buffer,
transfer_manager->AllocateScopedShapedBuffer(
transfer->tree.shape, allocator, device_ordinal));
compact_shape, allocator, device_ordinal));
// Make the host to device stream wait for the newly allocated buffer to be
// available on the compute stream. We schedule this wait synchronously; while
@ -293,21 +264,25 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
SharedDeviceBuffer::FromScopedShapedBuffer(std::move(scoped_buffer),
definition_event);
// TODO(makro): Use move capture once C++ 14 features are available.
auto leaves = std::make_shared<std::vector<BorrowingLiteral>>(
std::move(leaves_literals));
auto transfer_h2d = [client, transfer_manager, device, device_ordinal,
device_buffer, transfer]() {
device_buffer, compact_shape, leaves,
leaves_reference]() {
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way to
// report failures from a callback. However, the operations here are
// unlikely to fail and not recoverable even if we were to fail: DMAs to
// memory that has already been allocated, and a possible Event allocation.
ShapedBuffer buffer = device_buffer->AsShapedBuffer(transfer->tree.shape);
ShapedBuffer buffer = device_buffer->AsShapedBuffer(compact_shape);
TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
device->host_to_device_stream(), buffer));
std::vector<std::shared_ptr<void>> staging_buffers;
staging_buffers.reserve(transfer->tree.leaves.size());
auto it = transfer->tree.leaves.begin();
staging_buffers.reserve(leaves->size());
auto it = leaves->begin();
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(transfer->tree.shape)) {
CHECK(it != transfer->tree.leaves.end());
ShapeUtil::GetLeafShapes(compact_shape)) {
CHECK(it != leaves->end());
ShapedBuffer leaf(
indexed_shape.shape,
transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
@ -352,19 +327,19 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
device->ThenRelease(device->host_to_device_stream(), device_buffer);
}
device->ThenRelease(device->host_to_device_stream(),
std::make_pair(std::move(transfer->py_buffer_ref),
std::move(staging_buffers)));
device->ThenRelease(
device->host_to_device_stream(),
std::make_pair(leaves_reference, std::move(staging_buffers)));
};
client->h2d_transfer_pool()->Schedule(transfer_h2d);
return absl::make_unique<PyLocalBuffer>(
transfer->tree.shape, std::move(device_buffer), std::move(client));
compact_shape, std::move(device_buffer), std::move(client));
}
/* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple(
const std::vector<PyLocalBuffer*> buffers,
std::shared_ptr<PyLocalClient> client, int device_ordinal) {
std::vector<xla::Shape> host_shapes;
std::vector<Shape> host_shapes;
std::vector<std::shared_ptr<SharedDeviceBuffer>> device_buffers;
host_shapes.reserve(buffers.size());
device_buffers.reserve(buffers.size());
@ -458,17 +433,13 @@ Status PyLocalBuffer::CopyToHostAsync() {
return Status::OK();
}
StatusOr<py::object> PyLocalBuffer::ToPython() {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToPython");
StatusOr<std::shared_ptr<Literal>> PyLocalBuffer::ToLiteral() {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToLiteral");
std::shared_ptr<SharedDeviceBuffer> device_buffer = DeviceBuffer();
if (!device_buffer) {
return InvalidArgument("ToPython() called on invalid buffer.");
return InvalidArgument("ToLiteral() called on invalid buffer.");
}
client_->py_ref_manager().CollectGarbage();
std::shared_ptr<Literal> literal;
{
py::gil_scoped_release gil_release;
TF_RETURN_IF_ERROR(CopyToHostAsync());
std::shared_ptr<HostValue> host_value;
{
@ -477,10 +448,7 @@ StatusOr<py::object> PyLocalBuffer::ToPython() {
}
host_value->ready.WaitForNotification();
TF_RETURN_IF_ERROR(host_value->status);
literal = host_value->value;
}
return LiteralToPython(std::move(literal));
return host_value->value;
}
std::shared_ptr<SharedDeviceBuffer> PyLocalBuffer::DeviceBuffer() const {
@ -524,8 +492,6 @@ PyLocalBuffer::DestructureTuple() {
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
int dst_device_ordinal) {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice");
client_->py_ref_manager().CollectGarbage();
py::gil_scoped_release gil_release;
std::shared_ptr<SharedDeviceBuffer> src_device_buffer = DeviceBuffer();
if (dst_device_ordinal == device_ordinal_) {
return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
@ -554,7 +520,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
// Copy the leaf buffers.
for (const auto& leaf : src_buffer.buffers().leaves()) {
const xla::ShapeIndex& index = leaf.first;
const ShapeIndex& index = leaf.first;
const se::DeviceMemoryBase& input_buffer = leaf.second;
const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
TF_RET_CHECK(input_buffer.size() == output_buffer.size())
@ -603,9 +569,6 @@ Status PyLocalBuffer::BlockHostUntilReady() {
return InvalidArgument("BlockHostUntilReady() called on invalid buffer.");
}
client_->py_ref_manager().CollectGarbage();
py::gil_scoped_release gil_release;
// This code waits at least until the buffer is ready, but it may wait longer
// if there are other device to host transfers scheduled. If this proves to
// be an issue, we could either use a separate stream for this purpose, or

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "include/pybind11/pybind11.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@ -78,8 +77,7 @@ class PyLocalClient {
virtual ~PyLocalClient() = default;
Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);
StatusOr<pybind11::object> TransferFromOutfeed(const Shape& shape,
int device_ordinal);
StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_ordinal);
int device_count() const { return client_->device_count(); }
Device& device(int device_ordinal) const {
@ -128,9 +126,10 @@ class PyLocalClient {
// Thread-safe.
class PyLocalBuffer {
public:
static StatusOr<std::unique_ptr<PyLocalBuffer>> FromPython(
const pybind11::object& argument, std::shared_ptr<PyLocalClient> client,
int device_ordinal);
static StatusOr<std::unique_ptr<PyLocalBuffer>> FromLiterals(
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
std::shared_ptr<void> leaves_reference,
std::shared_ptr<PyLocalClient> client, int device_ordinal);
static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
const std::vector<PyLocalBuffer*> buffers,
@ -149,15 +148,19 @@ class PyLocalBuffer {
const Shape& on_host_shape() const { return on_host_shape_; }
int device_ordinal() const { return device_ordinal_; }
// TODO(makro): Make `client` private once `PythonRefManager` is refactored
// out of `PyLocalClient`.
PyLocalClient* client() const { return client_.get(); }
// Returns the buffer's value as a tuple DAG of Python arrays. If the value
// has previously been prefetched to the host, then returns the prefetched
// version, otherwise copies the buffer to the host. Blocks until the
// value is ready.
StatusOr<pybind11::object> ToPython();
StatusOr<std::shared_ptr<Literal>> ToLiteral();
// Initiates a copy of the buffer to the host. Does not block waiting for
// the transfer to complete. The value can be retrieved by a later call to
// ToPython().
// ToLiteral().
Status CopyToHostAsync();
// Returns the associated device buffer. Returns a nullptr if the buffer is
@ -190,14 +193,14 @@ class PyLocalBuffer {
std::shared_ptr<SharedDeviceBuffer> device_buffer_ GUARDED_BY(mu_);
// The cached value of the buffer on the host, produced either from a call to
// CopyToHost or from a call to ToPython. Once a value has been fetched to
// CopyToHost or from a call to ToLiteral. Once a value has been fetched to
// the host, it persists Delete() is called or the PyLocalBuffer is destroyed.
struct HostValue {
absl::Notification ready;
// status and value are valid for reading only after `ready` has been
// notified.
Status status;
std::shared_ptr<xla::Literal> value;
std::shared_ptr<Literal> value;
};
std::shared_ptr<HostValue> host_value_ GUARDED_BY(mu_);
};

View File

@ -312,18 +312,77 @@ PYBIND11_MODULE(xla_extension, m) {
py::arg("xla_platform_id"), py::arg("asynchronous"),
py::arg("allocator_config") = AllocatorConfig())
.def("DeviceCount", &PyLocalClient::device_count)
.def("TransferToInfeed", &PyLocalClient::TransferToInfeed)
.def("TransferFromOutfeed", &PyLocalClient::TransferFromOutfeed);
.def("TransferToInfeed",
[](PyLocalClient* client, const LiteralSlice& literal,
int device_ordinal) {
client->py_ref_manager().CollectGarbage();
py::gil_scoped_release gil_release;
return client->TransferToInfeed(literal, device_ordinal);
})
.def("TransferFromOutfeed",
[](PyLocalClient* client, const Shape& shape,
int device_ordinal) -> StatusOr<py::object> {
client->py_ref_manager().CollectGarbage();
std::shared_ptr<Literal> literal_shared;
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(Literal literal, client->TransferFromOutfeed(
shape, device_ordinal));
literal_shared = std::make_shared<Literal>(std::move(literal));
}
return LiteralToPython(std::move(literal_shared));
});
py::class_<PyLocalBuffer>(m, "PyLocalBuffer")
.def_static("from_python", &PyLocalBuffer::FromPython)
.def_static(
"from_python",
[](const pybind11::object& argument,
std::shared_ptr<PyLocalClient> client,
int device_ordinal) -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
client->py_ref_manager().CollectGarbage();
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
GetPythonBufferTree(argument));
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
client->py_ref_manager().ManageReferences(
absl::MakeSpan(tree.arrays));
tree.arrays.clear();
std::vector<BorrowingLiteral> leaves;
leaves.insert(leaves.end(),
std::make_move_iterator(tree.leaves.begin()),
std::make_move_iterator(tree.leaves.end()));
py::gil_scoped_release gil_release;
return PyLocalBuffer::FromLiterals(
std::move(leaves), tree.shape, std::move(py_buffer_ref),
std::move(client), device_ordinal);
})
.def_static("make_tuple", &PyLocalBuffer::MakeTuple)
.def("copy_to_device", &PyLocalBuffer::CopyToDevice)
.def("copy_to_device",
[](PyLocalBuffer* buffer, int dst_device_ordinal) {
buffer->client()->py_ref_manager().CollectGarbage();
py::gil_scoped_release gil_release;
return buffer->CopyToDevice(dst_device_ordinal);
})
.def("delete", &PyLocalBuffer::Delete)
.def("destructure", &PyLocalBuffer::DestructureTuple)
.def("block_host_until_ready", &PyLocalBuffer::BlockHostUntilReady)
.def("block_host_until_ready",
[](PyLocalBuffer* buffer) {
buffer->client()->py_ref_manager().CollectGarbage();
py::gil_scoped_release gil_release;
return buffer->BlockHostUntilReady();
})
.def("copy_to_host_async", &PyLocalBuffer::CopyToHostAsync)
.def("to_py", &PyLocalBuffer::ToPython)
.def("to_py",
[](PyLocalBuffer* buffer) -> StatusOr<py::object> {
buffer->client()->py_ref_manager().CollectGarbage();
std::shared_ptr<Literal> literal;
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(literal, buffer->ToLiteral());
}
return LiteralToPython(std::move(literal));
})
.def("shape", &PyLocalBuffer::on_host_shape)
.def("device", &PyLocalBuffer::device_ordinal)
.def("is_deleted",
@ -640,6 +699,6 @@ PYBIND11_MODULE(xla_extension, m) {
py::class_<ChannelHandle>(m, "ChannelHandle");
tensorflow::AddXrtSubmodule(&m);
}
} // NOLINT(readability/fn_size)
} // namespace xla