diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc
index a4c83083a9e..2c3fcf5dedb 100644
--- a/tensorflow/compiler/xla/python/local_client.cc
+++ b/tensorflow/compiler/xla/python/local_client.cc
@@ -291,21 +291,46 @@ StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment(
 }
 
 /* static */
-StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
-    std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
-    std::shared_ptr<void> leaves_reference,
+StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
+    const void* data, const Shape& shape, bool force_copy,
+    std::shared_ptr<void> buffer_reference,
     std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device) {
-  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals");
-  VLOG(2) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString()
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromHostBuffer");
+  VLOG(2) << "PyLocalBuffer::FromHostBuffer: shape: " << shape.ToString()
           << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
+
+  // If we are on the host platform and the input buffer is sufficiently
+  // aligned, we can simply point to the NumPy array's data without any further
+  // copies. We require a 64-byte alignment because XLA may generate AVX512
+  // code which requires it. Unfortunately NumPy's allocator doesn't align
+  // quite as aggressively, so there's a high chance this test will fail.
+  static constexpr int kMinimumAlignment = 64;
+  if (!force_copy &&
+      ((absl::bit_cast<std::uintptr_t>(data) & (kMinimumAlignment - 1)) == 0) &&
+      local_device->executor()->platform_kind() == se::PlatformKind::kHost) {
+    std::function<void()> on_delete_callback =
+        [buffer_reference{std::move(buffer_reference)}]() {
+          // Frees buffer_reference.
+        };
+    se::DeviceMemoryBase buffer(const_cast<void*>(data),
+                                ShapeUtil::ByteSizeOf(shape));
+    auto device_buffer = std::make_shared<SharedDeviceBuffer>(
+        /*allocator=*/nullptr, local_device->device_ordinal(),
+        std::initializer_list<se::DeviceMemoryBase>{buffer},
+        /*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
+        /*definition_event=*/nullptr, std::move(on_delete_callback));
+    return absl::make_unique<PyLocalBuffer>(
+        shape, shape, std::move(device_buffer), std::move(client),
+        std::move(device));
+  }
+
   TransferManager* transfer_manager =
       client->client()->backend().transfer_manager();
   se::DeviceMemoryAllocator* allocator = client->allocator();
-  TF_ASSIGN_OR_RETURN(
-      Shape compact_shape,
-      transfer_manager->ChooseCompactLayoutForShape(tuple_shape));
+  TF_ASSIGN_OR_RETURN(Shape compact_shape,
+                      transfer_manager->ChooseCompactLayoutForShape(shape));
   TF_ASSIGN_OR_RETURN(
       ScopedShapedBuffer scoped_buffer,
       transfer_manager->AllocateScopedShapedBuffer(
@@ -330,12 +355,9 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
           definition_event);
 
   Shape on_device_shape = scoped_buffer.on_device_shape();
-  // TODO(makro): Use move capture once C++ 14 features are available.
-  auto leaves = std::make_shared<std::vector<BorrowingLiteral>>(
-      std::move(leaves_literals));
   auto transfer_h2d = [client, transfer_manager, local_device, device_buffer,
-                       compact_shape, on_device_shape, leaves,
-                       leaves_reference]() {
+                       shape, compact_shape, on_device_shape, data,
+                       buffer_reference{std::move(buffer_reference)}]() {
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way to
     // report failures from a callback. However, the operations here are
     // unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -344,39 +366,27 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
         compact_shape, on_device_shape, client->client()->platform());
     TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
         local_device->host_to_device_stream(), buffer));
-    std::vector<std::shared_ptr<void>> staging_buffers;
-    staging_buffers.reserve(leaves->size());
-    auto it = leaves->begin();
-    for (const ShapeUtil::IndexedShape& indexed_shape :
-         ShapeUtil::GetLeafShapes(compact_shape)) {
-      CHECK(it != leaves->end());
-      ShapedBuffer leaf(
-          indexed_shape.shape,
-          transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
-          client->client()->platform(), local_device->device_ordinal());
-      leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
+    std::shared_ptr<void> staging_buffer;
 
-      // If applicable on the backend, stage the transfer via host memory
-      // allocated via the host_memory_allocator. On GPU, this is pinned memory.
-      if (client->host_memory_allocator()) {
-        int64 size = it->size_bytes({});
-        void* ptr = client->host_memory_allocator()->AllocateRaw(
-            tensorflow::Allocator::kAllocatorAlignment, size);
-        std::shared_ptr<void> staging_buffer(ptr, [client](void* ptr) {
-          client->host_memory_allocator()->DeallocateRaw(ptr);
-        });
-        std::memcpy(ptr, it->untyped_data({}), size);
-        BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
-                                 it->shape());
-        TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
-            local_device->host_to_device_stream(), literal, leaf));
-        staging_buffers.push_back(std::move(staging_buffer));
-      } else {
-        // Otherwise, just transfer the literal.
-        TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
-            local_device->host_to_device_stream(), *it, leaf));
-      }
-      ++it;
+    // If applicable on the backend, stage the transfer via host memory
+    // allocated via the host_memory_allocator. On GPU, this is pinned memory.
+    if (client->host_memory_allocator()) {
+      int64 size = ShapeUtil::ByteSizeOf(shape);
+      void* ptr = client->host_memory_allocator()->AllocateRaw(
+          tensorflow::Allocator::kAllocatorAlignment, size);
+      staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
+        client->host_memory_allocator()->DeallocateRaw(ptr);
+      });
+      std::memcpy(ptr, data, size);
+      BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
+                               shape);
+      TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
+          local_device->host_to_device_stream(), literal, buffer));
+    } else {
+      BorrowingLiteral literal(static_cast<const char*>(data), shape);
+      // Otherwise, just transfer the literal.
+      TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
+          local_device->host_to_device_stream(), literal, buffer));
     }
 
     EventPool::Handle event =
@@ -397,7 +407,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
 
     local_device->ThenRelease(
         local_device->host_to_device_stream(),
-        std::make_pair(leaves_reference, std::move(staging_buffers)));
+        std::make_pair(buffer_reference, std::move(staging_buffer)));
   };
   client->h2d_transfer_pool()->Schedule(transfer_h2d);
   return absl::make_unique<PyLocalBuffer>(
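Note: the key change above is the new zero-copy fast path. On the host platform, when `force_copy` is false and the source pointer is 64-byte aligned, the resulting `PyLocalBuffer` simply aliases the NumPy array's memory instead of staging a copy. A minimal Python sketch of the same alignment test (the helper name is ours, not part of xla_client):

```python
import numpy as np


def is_64_byte_aligned(arr):
  """Mirrors the kMinimumAlignment check in PyLocalBuffer::FromHostBuffer."""
  ptr = arr.__array_interface__["data"][0]
  return (ptr & 63) == 0


x = np.arange(16, dtype=np.float32)
# NumPy typically aligns to 16 bytes, not 64, so this is often False and the
# runtime falls back to the copying transfer path.
print(is_64_byte_aligned(x))
```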
diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h
index 72afa3d0135..9baece335fa 100644
--- a/tensorflow/compiler/xla/python/local_client.h
+++ b/tensorflow/compiler/xla/python/local_client.h
@@ -202,9 +202,14 @@ class PyLocalClient {
 // Thread-safe.
 class PyLocalBuffer {
  public:
-  static StatusOr<std::unique_ptr<PyLocalBuffer>> FromLiterals(
-      std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
-      std::shared_ptr<void> leaves_reference,
+  // If `force_copy` is true, forces a copy of the input buffer on CPU.
+  // Otherwise the library is free to alias the output buffer with `data`.
+  // `buffer_reference` is an optional shared pointer that should be kept alive
+  // by the runtime as long as the contents of `data` may still be accessed by
+  // the runtime (may be nullptr).
+  static StatusOr<std::unique_ptr<PyLocalBuffer>> FromHostBuffer(
+      const void* data, const Shape& shape, bool force_copy,
+      std::shared_ptr<void> buffer_reference,
       std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);
 
   static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
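The `buffer_reference` contract documented here is what makes the aliasing path safe from Python: the binding passes a reference to the NumPy array itself, so dropping the last Python reference to the array cannot free memory the device buffer may still alias. A sketch of the observable behavior, assuming a CPU backend:

```python
import numpy as np
from tensorflow.compiler.xla.python import xla_client

backend = xla_client.get_local_backend("cpu")
x = np.full((1024,), 7.0, dtype=np.float32)
buf = xla_client.Buffer.from_pyval(x, backend=backend)
del x  # Safe even if buf aliases x: the runtime holds the buffer_reference.
np.testing.assert_array_equal(buf.to_py(), np.full((1024,), 7.0, np.float32))
```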
diff --git a/tensorflow/compiler/xla/python/python_ref_manager.cc b/tensorflow/compiler/xla/python/python_ref_manager.cc
index 0a980f1a749..cf449801205 100644
--- a/tensorflow/compiler/xla/python/python_ref_manager.cc
+++ b/tensorflow/compiler/xla/python/python_ref_manager.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
 
+#include "absl/container/inlined_vector.h"
+
 namespace xla {
 
 namespace py = pybind11;
@@ -37,16 +39,27 @@ PythonRefManager::ManagedPyObjects::~ManagedPyObjects() {
   }
 }
 
+std::shared_ptr<PythonRefManager::ManagedPyObjects>
+PythonRefManager::ManageReference(py::object object) {
+  return std::make_shared<ManagedPyObjects>(this,
+                                            absl::Span<py::object>(&object, 1));
+}
+
 std::shared_ptr<PythonRefManager::ManagedPyObjects>
 PythonRefManager::ManageReferences(absl::Span<py::object> objects) {
   return std::make_shared<ManagedPyObjects>(this, objects);
 }
 
 void PythonRefManager::CollectGarbage() {
-  // TODO(phawkins): ideally we would assert that the GIL is held, but there is
-  // no API to do this across all Python versions.
-  absl::MutexLock lock(&mu_);
-  python_garbage_.clear();
+  // TODO(phawkins): we should CHECK(PyGILState_Check());
+  std::deque<pybind11::object> garbage;
+  {
+    absl::MutexLock lock(&mu_);
+    garbage.swap(python_garbage_);
+  }
+  // We defer deleting garbage until the lock is released. It's possible that
+  // deleting garbage will lead to more Python garbage being added; if we held
+  // the lock we would deadlock because absl::Mutex is not reentrant.
 }
 
 PythonRefManager* GlobalPyRefManager() {
diff --git a/tensorflow/compiler/xla/python/python_ref_manager.h b/tensorflow/compiler/xla/python/python_ref_manager.h
index 22d9f659e98..2c6ea16c7f7 100644
--- a/tensorflow/compiler/xla/python/python_ref_manager.h
+++ b/tensorflow/compiler/xla/python/python_ref_manager.h
@@ -62,6 +62,7 @@ class PythonRefManager {
   // Creates a managed std::shared_ptr to an object. When the shared_ptr is
   // destroyed, the reference to 'object' will be added to python_garbage_,
   // and collected next time CollectGarbage() is called.
+  std::shared_ptr<ManagedPyObjects> ManageReference(pybind11::object object);
   std::shared_ptr<ManagedPyObjects> ManageReferences(
       absl::Span<pybind11::object> objects);
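The rewritten `CollectGarbage` uses the classic swap-under-the-lock, destroy-outside-it pattern: dropping a Python reference can run a finalizer, which may re-enter the manager and enqueue more garbage, and `absl::Mutex` is not reentrant. A sketch of the same discipline in Python (`RefManager` is illustrative, not XLA code; `threading.Lock` is likewise non-reentrant):

```python
import threading


class RefManager:
  """Illustrates CollectGarbage's swap-then-release locking discipline."""

  def __init__(self):
    self._mu = threading.Lock()  # Non-reentrant, like absl::Mutex.
    self._garbage = []

  def add_garbage(self, obj):
    with self._mu:
      self._garbage.append(obj)

  def collect_garbage(self):
    with self._mu:
      garbage, self._garbage = self._garbage, []
    # The last references are dropped here, outside the lock, so an object
    # whose finalizer calls add_garbage() cannot deadlock on self._mu.
    del garbage
```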
 
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py
index 6096d3774f1..9e44a3d7aed 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py
@@ -81,7 +81,7 @@ class TpuBackend(xla_client.Backend):
   def host_id(self):
     return self.client.host_id()
 
-  def buffer_from_pyval(self, pyval, device=None):
+  def buffer_from_pyval(self, pyval, device=None, force_copy=False):
     if device is None:
       device = self.client.local_devices()[0]
     return _tpu_client.PyTpuBuffer.from_python(pyval, self.client, device)
diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h
index 713751d78d5..ceefbda4f90 100644
--- a/tensorflow/compiler/xla/python/types.h
+++ b/tensorflow/compiler/xla/python/types.h
@@ -96,7 +96,7 @@ std::vector<int64> IntSequenceToVector(const pybind11::object& sequence);
 // xla::BorrowingLiteral. Converts a Python array-like object into a buffer
 // pointer and shape.
 struct CastToArrayResult {
-  pybind11::array array;  // Holds a reference to the array to keep it alive.
+  pybind11::object array;  // Holds a reference to the array to keep it alive.
   const char* buf_ptr;
   xla::Shape shape;
 };
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index c1fb967fda9..15a60521096 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -655,8 +655,8 @@ PYBIND11_MODULE(xla_extension, m) {
           "from_python",
          [](const pybind11::object& argument,
              std::shared_ptr<PyLocalClient> client,
-             std::shared_ptr<Device> device)
-              -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
+             std::shared_ptr<Device> device,
+             bool force_copy) -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
            CHECK(device != nullptr);
            auto iter = client->id_to_device().find(device->id());
            if (iter->second != device) {
@@ -665,23 +665,24 @@ PYBIND11_MODULE(xla_extension, m) {
                  device->DebugString(), client->platform_name());
            }
            GlobalPyRefManager()->CollectGarbage();
+
+            absl::optional<CastToArrayResult> c = CastToArray(argument);
+            if (!c) {
+              return InvalidArgument("from_python argument must be an array.");
+            }
+
            TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
                                GetPythonBufferTree(argument));
            std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
-                GlobalPyRefManager()->ManageReferences(
-                    absl::MakeSpan(tree.arrays));
-            tree.arrays.clear();
-
-            std::vector<BorrowingLiteral> leaves;
-            leaves.insert(leaves.end(),
-                          std::make_move_iterator(tree.leaves.begin()),
-                          std::make_move_iterator(tree.leaves.end()));
+                GlobalPyRefManager()->ManageReference(std::move(c->array));
 
            py::gil_scoped_release gil_release;
-            return PyLocalBuffer::FromLiterals(
-                std::move(leaves), tree.shape, std::move(py_buffer_ref),
+            return PyLocalBuffer::FromHostBuffer(
+                c->buf_ptr, c->shape, force_copy, std::move(py_buffer_ref),
                std::move(client), std::move(device));
-          })
+          },
+          py::arg("argument"), py::arg("client"), py::arg("device"),
+          py::arg("force_copy") = false)
      .def_static("make_tuple",
                  [](const std::vector<PyLocalBuffer*> buffers,
                     std::shared_ptr<PyLocalClient> client,
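The binding now names its parameters and defaults `force_copy` to false, so existing three-argument callers are unaffected. A sketch of both call forms against the extension module, using the same import paths as xla_client.py (the array and backend here are placeholders):

```python
import numpy as np
from tensorflow.compiler.xla.python import xla_client
from tensorflow.compiler.xla.python import xla_extension as _xla

backend = xla_client.get_local_backend("cpu")
device = backend.devices()[0]
x = np.ones((2, 2), dtype=np.float32)

# Old three-argument form still works because force_copy defaults to False:
buf = _xla.PyLocalBuffer.from_python(x, backend.client, device)
# New form, passing force_copy by keyword:
buf2 = _xla.PyLocalBuffer.from_python(x, backend.client, device,
                                      force_copy=True)
```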
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 63e57f88803..7e10b660117 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -71,7 +71,7 @@ class Backend(object, metaclass=abc.ABCMeta):
     """Returns the integer ID of this host."""
 
   @abc.abstractmethod
-  def buffer_from_pyval(self, pyval, device=None):
+  def buffer_from_pyval(self, pyval, device=None, force_copy=False):
     """Allocates a fresh buffer and populates it with `pyval`."""
 
   @abc.abstractmethod
@@ -129,10 +129,11 @@ class LocalBackend(Backend):
   def host_id(self):
     return self.client.host_id()
 
-  def buffer_from_pyval(self, pyval, device=None):
+  def buffer_from_pyval(self, pyval, device=None, force_copy=False):
     if device is None:
       device = self.local_devices()[0]
-    return _xla.PyLocalBuffer.from_python(pyval, self.client, device)
+    return _xla.PyLocalBuffer.from_python(pyval, self.client, device,
+                                          force_copy)
 
   def make_tuple(self, c_buffers, device):
     return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device)
@@ -391,10 +392,10 @@ class Buffer(object):
     """
 
   @staticmethod
-  def from_pyval(pyval, device=None, backend=None):
+  def from_pyval(pyval, device=None, backend=None, force_copy=False):
     """Copies the `pyval` to a freshly allocated on-device buffer."""
     backend = backend or get_local_backend()
-    return backend.buffer_from_pyval(pyval, device)
+    return backend.buffer_from_pyval(pyval, device, force_copy=force_copy)
 
   @staticmethod
   def make_tuple(buffers, device, backend=None):
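From the user-facing API, `force_copy` chooses between the (possibly) aliasing default and a guaranteed private copy. A small usage sketch of the difference on CPU:

```python
import numpy as np
from tensorflow.compiler.xla.python import xla_client

backend = xla_client.get_local_backend("cpu")
x = np.arange(4, dtype=np.float32)

aliased = xla_client.Buffer.from_pyval(x, backend=backend)
copied = xla_client.Buffer.from_pyval(x, backend=backend, force_copy=True)

x[0] = 42.0
# A write to x may be visible through `aliased` (when x happened to be
# 64-byte aligned), but never through the forced copy:
np.testing.assert_array_equal(copied.to_py(), [0.0, 1.0, 2.0, 3.0])
```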
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index 53f94457005..0f97d06e5f7 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -471,15 +471,16 @@ class BufferTest(ComputationTest):
     compiled_c.Execute([arg_buffer])
 
   def testDestructureTupleEmpty(self):
-    t = ()
-    local_buffer = xla_client.Buffer.from_pyval(t)
+    device = xla_client.get_local_backend().devices()[0]
+    local_buffer = xla_client.Buffer.make_tuple((), device=device)
     pieces = local_buffer.destructure()
     self.assertFalse(local_buffer.is_deleted())
     self.assertEmpty(pieces)
 
   def testDestructureTupleOneArrayElement(self):
-    t = (np.array([1, 2, 3, 4], dtype=np.int32),)
-    local_buffer = xla_client.Buffer.from_pyval(t)
+    device = xla_client.get_local_backend().devices()[0]
+    t = xla_client.Buffer.from_pyval(np.array([1, 2, 3, 4], dtype=np.int32))
+    local_buffer = xla_client.Buffer.make_tuple((t,), device)
     pieces = local_buffer.destructure()
     self.assertFalse(local_buffer.is_deleted())
     self.assertLen(pieces, 1)
@@ -489,11 +490,13 @@ class BufferTest(ComputationTest):
       np.testing.assert_equal(want, got)
 
   def testDestructureTupleTwoArrayElementDifferentType(self):
+    device = xla_client.get_local_backend().devices()[0]
     t = (
-        np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
-        np.array([2, 3, 4, 5], dtype=np.int32),
+        xla_client.Buffer.from_pyval(
+            np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)),
+        xla_client.Buffer.from_pyval(np.array([2, 3, 4, 5], dtype=np.int32)),
     )
-    local_buffer = xla_client.Buffer.from_pyval(t)
+    local_buffer = xla_client.Buffer.make_tuple(t, device)
     # Run the test twice to verify that the original tuple buffer remains valid
     # even after destructuring.
     for _ in range(2):
@@ -509,8 +512,12 @@ class BufferTest(ComputationTest):
       np.testing.assert_equal(want, got)
 
   def testDestructureTupleNested(self):
-    t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5]))
-    local_buffer = xla_client.Buffer.from_pyval(t)
+    device = xla_client.get_local_backend().devices()[0]
+    t = xla_client.Buffer.make_tuple(
+        (xla_client.Buffer.from_pyval(NumpyArrayF32([1.0, 2.0])),
+         xla_client.Buffer.from_pyval(NumpyArrayS32([3, 4]))), device)
+    local_buffer = xla_client.Buffer.make_tuple(
+        (t, xla_client.Buffer.from_pyval(NumpyArrayS32([5]))), device)
     pieces = local_buffer.destructure()
     self.assertFalse(local_buffer.is_deleted())
     self.assertLen(pieces, 2)
@@ -2119,12 +2126,20 @@ class BufferProtocolTest(parameterized.TestCase):
      } for dtype in standard_dtypes for shape in testcase_shapes)
  def testRoundTrip(self, dtype, shape):
    x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
+    x_ptr = x.__array_interface__["data"][0]
    backend = xla_client.get_local_backend("cpu")
    buffer = xla_client.Buffer.from_pyval(x, backend=backend)
    y = np.array(buffer, copy=False)
+    y_ptr = y.__array_interface__["data"][0]
    np.testing.assert_array_equal(x, y)
-    self.assertEqual(y.__array_interface__["data"][0],
-                     buffer.unsafe_buffer_pointer())
+    # If the input was sufficiently aligned, the input and output should alias.
+    self.assertTrue((x_ptr & 63) != 0 or x_ptr == y_ptr)
+    self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())
+
+    buffer2 = xla_client.Buffer.from_pyval(x, backend=backend, force_copy=True)
+    z = np.array(buffer2, copy=False)
+    self.assertNotEqual(x.__array_interface__["data"][0],
+                        z.__array_interface__["data"][0])
 
  def testDeleteWithActiveView(self):
    x = np.random.randn(20, 10)
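The new round-trip assertions are deliberately tolerant: `x_ptr & 63` may be nonzero because NumPy does not guarantee 64-byte alignment, in which case aliasing legitimately does not happen. A quick, informal probe (not part of the test suite) of how often fresh NumPy allocations satisfy the requirement:

```python
import numpy as np

# Count how many freshly allocated arrays land on a 64-byte boundary; the
# exact ratio depends on the platform allocator.
aligned = sum(
    (np.empty(256, dtype=np.float32).__array_interface__["data"][0] & 63) == 0
    for _ in range(1000))
print("%d/1000 fresh arrays were 64-byte aligned" % aligned)
```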