[XLA:Python] Remove PyLocalBuffer.make_tuple and PyLocalBuffer.destructure() from the API.
Since Execute() now supports tupling and untupling, we no longer need tuples in the Python API. This change is in preparation for changing the aliasing behavior of Execute(). PiperOrigin-RevId: 301631669 Change-Id: Idec8c5ebf0025052d6c0cef523f2c77c92e89e0a
This commit is contained in:
parent
85f7677b4a
commit
d8bccdb1b8
@ -86,9 +86,6 @@ class TpuBackend(xla_client.Backend):
|
||||
device = self.client.local_devices()[0]
|
||||
return _tpu_client.PyTpuBuffer.from_python(pyval, self.client, device)
|
||||
|
||||
def make_tuple(self, c_buffers, device):
|
||||
return _tpu_client.PyTpuBuffer.make_tuple(c_buffers, self.client, device)
|
||||
|
||||
def compile(self, c_computation, compile_options):
|
||||
options = _xla.ExecutableBuildOptions()
|
||||
options.num_replicas = compile_options.num_replicas
|
||||
|
@ -125,21 +125,6 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
||||
std::move(leaves), tree.shape, std::move(py_buffer_ref),
|
||||
std::move(client), std::move(device));
|
||||
})
|
||||
.def_static("make_tuple",
|
||||
[](const std::vector<PyTpuBuffer*> buffers,
|
||||
std::shared_ptr<PyTpuClient> client,
|
||||
std::shared_ptr<Device> device)
|
||||
-> StatusOr<std::unique_ptr<PyTpuBuffer>> {
|
||||
CHECK(device != nullptr);
|
||||
auto iter = client->id_to_device().find(device->id());
|
||||
if (iter->second != device) {
|
||||
return InvalidArgument(
|
||||
"Cannot make tuple on device '%s' with '%s' backend",
|
||||
device->DebugString(), client->platform_name());
|
||||
}
|
||||
return PyTpuBuffer::MakeTuple(buffers, std::move(client),
|
||||
std::move(device));
|
||||
})
|
||||
.def("copy_to_device",
|
||||
[](PyTpuBuffer* buffer, std::shared_ptr<Device> dst_device) {
|
||||
CHECK(dst_device != nullptr);
|
||||
@ -148,7 +133,6 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
||||
return buffer->CopyToDevice(std::move(dst_device));
|
||||
})
|
||||
.def("delete", &PyTpuBuffer::Delete)
|
||||
.def("destructure", &PyTpuBuffer::DestructureTuple)
|
||||
.def("block_host_until_ready",
|
||||
[](PyTpuBuffer* buffer) {
|
||||
GlobalPyRefManager()->CollectGarbage();
|
||||
|
@ -960,23 +960,6 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
},
|
||||
py::arg("argument"), py::arg("client"), py::arg("device"),
|
||||
py::arg("force_copy") = false)
|
||||
.def_static(
|
||||
"make_tuple",
|
||||
[](std::vector<PyLocalBuffer*> buffers,
|
||||
std::shared_ptr<PyLocalClient> client,
|
||||
Device* device) -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
|
||||
CHECK(device != nullptr);
|
||||
auto iter = client->id_to_device().find(device->id());
|
||||
if (iter->second != device) {
|
||||
return InvalidArgument(
|
||||
"Cannot make tuple on device '%s' with '%s' backend",
|
||||
device->DebugString(), client->platform_name());
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<PyLocalBuffer> buffer,
|
||||
PyLocalBuffer::MakeTuple(buffers, client.get(), device));
|
||||
return WrapWithClient(std::move(client), std::move(buffer));
|
||||
})
|
||||
.def("copy_to_device",
|
||||
[](PyLocalBuffer* buffer, const ClientAndPtr<Device>& dst_device)
|
||||
-> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
|
||||
@ -988,20 +971,6 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
return WrapWithClient(dst_device.client, std::move(out));
|
||||
})
|
||||
.def("delete", &PyLocalBuffer::Delete)
|
||||
.def("destructure",
|
||||
[](const PyLocalBuffer& buffer)
|
||||
-> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<std::unique_ptr<PyLocalBuffer>> parts,
|
||||
buffer.DestructureTuple());
|
||||
std::vector<ClientAndUniquePtr<PyLocalBuffer>> output;
|
||||
output.reserve(parts.size());
|
||||
for (auto& part : parts) {
|
||||
output.push_back(WrapWithClient(
|
||||
buffer.client()->shared_from_this(), std::move(part)));
|
||||
}
|
||||
return std::move(output);
|
||||
})
|
||||
.def("block_host_until_ready",
|
||||
[](PyLocalBuffer* buffer) {
|
||||
GlobalPyRefManager()->CollectGarbage();
|
||||
|
@ -76,10 +76,6 @@ class Backend(object, metaclass=abc.ABCMeta):
|
||||
def buffer_from_pyval(self, pyval, device=None, force_copy=False):
|
||||
"""Allocates a fresh buffer and populates it with `pyval`."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def make_tuple(self, c_buffers, device):
|
||||
"""Makes a tuple from a sequence of backend buffer objects."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def compile(self, computation, compile_options):
|
||||
"""Compiles a computation. Returns an executable."""
|
||||
@ -137,9 +133,6 @@ class LocalBackend(Backend):
|
||||
return _xla.PyLocalBuffer.from_python(pyval, self.client, device,
|
||||
force_copy)
|
||||
|
||||
def make_tuple(self, c_buffers, device):
|
||||
return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device)
|
||||
|
||||
def compile(self, c_computation, compile_options):
|
||||
options = _xla.ExecutableBuildOptions()
|
||||
options.num_replicas = compile_options.num_replicas
|
||||
@ -396,18 +389,12 @@ class Buffer(object):
|
||||
backend = backend or get_local_backend()
|
||||
return backend.buffer_from_pyval(pyval, device, force_copy=force_copy)
|
||||
|
||||
@staticmethod
|
||||
def make_tuple(buffers, device, backend=None):
|
||||
backend = backend or get_local_backend()
|
||||
return backend.make_tuple(buffers, device)
|
||||
|
||||
# Buffer is not an instantiable type and exists only for its static methods.
|
||||
# The underlying buffer objects are C++ object with the following
|
||||
# API:
|
||||
# def shape(self) -> Shape:
|
||||
# def device(self) -> int:
|
||||
# def delete(self):
|
||||
# def destructure(self) -> [Buffer]
|
||||
# def is_deleted(self) -> bool:
|
||||
# def block_host_until_ready(self):
|
||||
# """Blocks the calling thread until the buffer is ready on device."""
|
||||
@ -426,11 +413,6 @@ class Buffer(object):
|
||||
# clients call methods on Backend to create buffers.
|
||||
|
||||
|
||||
# TODO(phawkins): Alias for backward compatibility. Remove after JAX drops
|
||||
# compatibility with Jaxlib versions older than 0.1.13.
|
||||
LocalBuffer = Buffer
|
||||
|
||||
|
||||
def shape_from_pyval(pyval):
|
||||
"""Returns a Shape that describes a tuple-tree of Numpy arrays."""
|
||||
|
||||
|
@ -496,84 +496,6 @@ class BufferTest(ComputationTest):
|
||||
with self.assertRaises(RuntimeError):
|
||||
compiled_c.Execute([arg_buffer], tuple_arguments=False)
|
||||
|
||||
def testDestructureTupleEmpty(self):
|
||||
device = xla_client.get_local_backend().devices()[0]
|
||||
local_buffer = xla_client.Buffer.make_tuple((), device=device)
|
||||
pieces = local_buffer.destructure()
|
||||
self.assertFalse(local_buffer.is_deleted())
|
||||
self.assertEmpty(pieces)
|
||||
|
||||
def testDestructureTupleOneArrayElement(self):
|
||||
device = xla_client.get_local_backend().devices()[0]
|
||||
t = xla_client.Buffer.from_pyval(np.array([1, 2, 3, 4], dtype=np.int32))
|
||||
local_buffer = xla_client.Buffer.make_tuple((t,), device)
|
||||
pieces = local_buffer.destructure()
|
||||
self.assertFalse(local_buffer.is_deleted())
|
||||
self.assertLen(pieces, 1)
|
||||
array = pieces[0]
|
||||
got = array.to_py()
|
||||
want = NumpyArrayS32([1, 2, 3, 4])
|
||||
np.testing.assert_equal(want, got)
|
||||
|
||||
def testDestructureTupleTwoArrayElementDifferentType(self):
|
||||
device = xla_client.get_local_backend().devices()[0]
|
||||
t = (
|
||||
xla_client.Buffer.from_pyval(
|
||||
np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)),
|
||||
xla_client.Buffer.from_pyval(np.array([2, 3, 4, 5], dtype=np.int32)),
|
||||
)
|
||||
local_buffer = xla_client.Buffer.make_tuple(t, device)
|
||||
# Run the test twice to verify that the original tuple buffer remains valid
|
||||
# even after destructuring.
|
||||
for _ in range(2):
|
||||
pieces = local_buffer.destructure()
|
||||
self.assertFalse(local_buffer.is_deleted())
|
||||
self.assertLen(pieces, 2)
|
||||
array0, array1 = pieces
|
||||
got = array0.to_py()
|
||||
want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0])
|
||||
np.testing.assert_equal(want, got)
|
||||
got = array1.to_py()
|
||||
want = NumpyArrayS32([2, 3, 4, 5])
|
||||
np.testing.assert_equal(want, got)
|
||||
|
||||
def testDestructureTupleNested(self):
|
||||
device = xla_client.get_local_backend().devices()[0]
|
||||
t = xla_client.Buffer.make_tuple(
|
||||
(xla_client.Buffer.from_pyval(NumpyArrayF32([1.0, 2.0])),
|
||||
xla_client.Buffer.from_pyval(NumpyArrayS32([3, 4]))), device)
|
||||
local_buffer = xla_client.Buffer.make_tuple(
|
||||
(t, xla_client.Buffer.from_pyval(NumpyArrayS32([5]))), device)
|
||||
pieces = local_buffer.destructure()
|
||||
self.assertFalse(local_buffer.is_deleted())
|
||||
self.assertLen(pieces, 2)
|
||||
tuple0, array1 = pieces
|
||||
got = array1.to_py()
|
||||
want = NumpyArrayS32([5])
|
||||
np.testing.assert_equal(want, got)
|
||||
got = tuple0.to_py()
|
||||
self.assertEqual(type(got), tuple)
|
||||
self.assertLen(got, 2)
|
||||
np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0])
|
||||
np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1])
|
||||
|
||||
def testMakeTuple(self):
|
||||
t = (
|
||||
np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
|
||||
np.array([2, 3, 4, 5], dtype=np.int32),
|
||||
)
|
||||
b0 = xla_client.Buffer.from_pyval(t[0])
|
||||
b1 = xla_client.Buffer.from_pyval(t[1])
|
||||
device = xla_client.get_local_backend().local_devices()[0]
|
||||
btup = xla_client.Buffer.make_tuple([b0, b1], device=device)
|
||||
pieces = btup.destructure()
|
||||
self.assertLen(pieces, 2)
|
||||
array0, array1 = pieces
|
||||
np.testing.assert_equal(
|
||||
np.array([1, 2, 3, 4], dtype=np.float32), array0.to_py())
|
||||
np.testing.assert_equal(
|
||||
np.array([2, 3, 4, 5], dtype=np.int32), array1.to_py())
|
||||
|
||||
def testShape(self):
|
||||
pyval = np.array([[1., 2.]], np.float32)
|
||||
local_buffer = xla_client.Buffer.from_pyval(pyval)
|
||||
@ -581,23 +503,6 @@ class BufferTest(ComputationTest):
|
||||
self.assertEqual(xla_shape.dimensions(), (1, 2))
|
||||
self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))
|
||||
|
||||
def testTupleShape(self):
|
||||
t = (
|
||||
np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32),
|
||||
np.array([2, 3, 4, 5], dtype=np.int32),
|
||||
)
|
||||
b0 = xla_client.Buffer.from_pyval(t[0])
|
||||
b1 = xla_client.Buffer.from_pyval(t[1])
|
||||
device = xla_client.get_local_backend().local_devices()[0]
|
||||
tuple_buffer = xla_client.Buffer.make_tuple([b0, b1], device=device)
|
||||
tuple_shape = tuple_buffer.shape()
|
||||
self.assertEqual(tuple_shape.leaf_count(), 2)
|
||||
shapes = tuple_shape.tuple_shapes()
|
||||
self.assertLen(shapes, 2)
|
||||
shape1, shape2 = shapes
|
||||
self.assertEqual(shape1.dimensions(), (1, 4))
|
||||
self.assertEqual(shape2.dimensions(), (4,))
|
||||
|
||||
def testBlockHostUntilReadyWorks(self):
|
||||
arg = np.array([[1., 2.]], np.float32)
|
||||
arg_buffer = xla_client.Buffer.from_pyval(arg)
|
||||
|
Loading…
x
Reference in New Issue
Block a user