diff --git a/tensorflow/compiler/xla/python/py_buffer.h b/tensorflow/compiler/xla/python/py_buffer.h index 50960fe5d37..efef7649926 100644 --- a/tensorflow/compiler/xla/python/py_buffer.h +++ b/tensorflow/compiler/xla/python/py_buffer.h @@ -65,6 +65,8 @@ class PyBuffer : public DeviceArrayBase { StatusOr> CopyToDevice( const ClientAndPtr& dst_device) const; + int64 OnDeviceSizeInBytes() { return buffer_->OnDeviceSizeInBytes(); } + void Delete() { buffer_->Delete(); npy_value_ = pybind11::none(); diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 4060df2b600..e6539ef3021 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -349,6 +349,7 @@ PYBIND11_MODULE(xla_extension, m) { return npy_value_; }) .def("copy_to_device", &PyBuffer::CopyToDevice) + .def("on_device_size_in_bytes", &PyBuffer::OnDeviceSizeInBytes) .def("delete", &PyBuffer::Delete) // The GIL is released within BlockHostUntilReady. .def("block_until_ready", diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index fc30883880e..bca3ca5907f 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -486,6 +486,21 @@ def TestFactory(xla_backend, cloud_tpu=False): with self.assertRaises(RuntimeError): buffer.block_until_ready() + def testOnDeviceSizeInBytes(self): + if not isinstance(self.backend, xla_client.Client): + self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.") + arg0 = np.array([]) + arg1 = np.array([[0., 1., 2.]], np.float32) + arg2 = np.array([[3., 4., 5.]], bfloat16) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + arg2_buffer = self.backend.buffer_from_pyval(arg2) + self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0) + # OnDeviceSizeInBytes varies depending on the platform. Confirm there's + # a reasonable value. + self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0) + self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0) + def testCopyToHost(self): arg0 = np.array([[1., 2.]], np.float32) arg1 = np.array([[3., 4.]], np.float32)