[XLA:Python] Plumb the PyBuffer's on-device size in bytes through to Python.
PiperOrigin-RevId: 347700150 Change-Id: Ia4da1a15fd8f9f0bdbd1569fa394f852739b3190
This commit is contained in:
parent
deeb7f2e74
commit
3c374ed73b
@ -65,6 +65,8 @@ class PyBuffer : public DeviceArrayBase {
|
||||
StatusOr<std::unique_ptr<PyBuffer>> CopyToDevice(
|
||||
const ClientAndPtr<PjRtDevice>& dst_device) const;
|
||||
|
||||
int64 OnDeviceSizeInBytes() { return buffer_->OnDeviceSizeInBytes(); }
|
||||
|
||||
void Delete() {
|
||||
buffer_->Delete();
|
||||
npy_value_ = pybind11::none();
|
||||
|
@ -349,6 +349,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
return npy_value_;
|
||||
})
|
||||
.def("copy_to_device", &PyBuffer::CopyToDevice)
|
||||
.def("on_device_size_in_bytes", &PyBuffer::OnDeviceSizeInBytes)
|
||||
.def("delete", &PyBuffer::Delete)
|
||||
// The GIL is released within BlockHostUntilReady.
|
||||
.def("block_until_ready",
|
||||
|
@ -486,6 +486,21 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
with self.assertRaises(RuntimeError):
|
||||
buffer.block_until_ready()
|
||||
|
||||
def testOnDeviceSizeInBytes(self):
|
||||
if not isinstance(self.backend, xla_client.Client):
|
||||
self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.")
|
||||
arg0 = np.array([])
|
||||
arg1 = np.array([[0., 1., 2.]], np.float32)
|
||||
arg2 = np.array([[3., 4., 5.]], bfloat16)
|
||||
arg0_buffer = self.backend.buffer_from_pyval(arg0)
|
||||
arg1_buffer = self.backend.buffer_from_pyval(arg1)
|
||||
arg2_buffer = self.backend.buffer_from_pyval(arg2)
|
||||
self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0)
|
||||
# OnDeviceSizeInBytes varies depending on the platform. Confirm there's
|
||||
# a reasonable value.
|
||||
self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0)
|
||||
self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0)
|
||||
|
||||
def testCopyToHost(self):
|
||||
arg0 = np.array([[1., 2.]], np.float32)
|
||||
arg1 = np.array([[3., 4.]], np.float32)
|
||||
|
Loading…
Reference in New Issue
Block a user