[XLA:Python] Plumb the PyBuffer's on-device size in bytes through to Python.

PiperOrigin-RevId: 347700150
Change-Id: Ia4da1a15fd8f9f0bdbd1569fa394f852739b3190
This commit is contained in:
A. Unique TensorFlower 2020-12-15 15:03:17 -08:00 committed by TensorFlower Gardener
parent deeb7f2e74
commit 3c374ed73b
3 changed files with 18 additions and 0 deletions

View File

@ -65,6 +65,8 @@ class PyBuffer : public DeviceArrayBase {
StatusOr<std::unique_ptr<PyBuffer>> CopyToDevice(
const ClientAndPtr<PjRtDevice>& dst_device) const;
int64 OnDeviceSizeInBytes() { return buffer_->OnDeviceSizeInBytes(); }
void Delete() {
buffer_->Delete();
npy_value_ = pybind11::none();

View File

@ -349,6 +349,7 @@ PYBIND11_MODULE(xla_extension, m) {
return npy_value_;
})
.def("copy_to_device", &PyBuffer::CopyToDevice)
.def("on_device_size_in_bytes", &PyBuffer::OnDeviceSizeInBytes)
.def("delete", &PyBuffer::Delete)
// The GIL is released within BlockHostUntilReady.
.def("block_until_ready",

View File

@ -486,6 +486,21 @@ def TestFactory(xla_backend, cloud_tpu=False):
with self.assertRaises(RuntimeError):
buffer.block_until_ready()
def testOnDeviceSizeInBytes(self):
if not isinstance(self.backend, xla_client.Client):
self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.")
arg0 = np.array([])
arg1 = np.array([[0., 1., 2.]], np.float32)
arg2 = np.array([[3., 4., 5.]], bfloat16)
arg0_buffer = self.backend.buffer_from_pyval(arg0)
arg1_buffer = self.backend.buffer_from_pyval(arg1)
arg2_buffer = self.backend.buffer_from_pyval(arg2)
self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0)
# OnDeviceSizeInBytes varies depending on the platform. Confirm there's
# a reasonable value.
self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0)
self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0)
def testCopyToHost(self):
arg0 = np.array([[1., 2.]], np.float32)
arg1 = np.array([[3., 4.]], np.float32)