[XLA:Python] Update tpu_driver to add the same automatic tupling of arguments and untupling of results present in the local client.
Update tests to use the automatic untupling support. PiperOrigin-RevId: 301623333 Change-Id: I1233e6a63eaea2bfef2ac7a85bf1b55b820361d1
This commit is contained in:
parent
4679feb3ce
commit
85f7677b4a
@ -227,8 +227,8 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
|
||||
const std::vector<PyTpuBuffer*> buffers,
|
||||
std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device) {
|
||||
absl::Span<PyTpuBuffer* const> buffers, std::shared_ptr<PyTpuClient> client,
|
||||
std::shared_ptr<Device> device) {
|
||||
std::vector<Shape> child_shapes;
|
||||
std::vector<std::shared_ptr<TpuSharedBuffer>> child_device_buffers;
|
||||
std::vector<tpu_driver::BufferHandle*> child_handle_ptrs;
|
||||
@ -611,8 +611,8 @@ Status WaitForExecuteEvent(tpu_driver::Event* event) {
|
||||
return opt_status.value();
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuExecutable::Execute(
|
||||
absl::Span<PyTpuBuffer* const> argument_handles) {
|
||||
StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
|
||||
absl::Span<PyTpuBuffer* const> argument_handles, bool tuple_arguments) {
|
||||
if (num_replicas() != 1) {
|
||||
return InvalidArgument(
|
||||
"Attempted to execute computation with %d replicas using Execute().",
|
||||
@ -624,9 +624,18 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuExecutable::Execute(
|
||||
num_partitions());
|
||||
}
|
||||
|
||||
std::vector<PyTpuBuffer*> all_core_arguments(argument_handles.begin(),
|
||||
argument_handles.end());
|
||||
std::vector<PyTpuBuffer*> all_core_arguments;
|
||||
|
||||
std::unique_ptr<PyTpuBuffer> tupled_arguments;
|
||||
if (tuple_arguments) {
|
||||
TF_ASSIGN_OR_RETURN(tupled_arguments,
|
||||
PyTpuBuffer::MakeTuple(argument_handles, client_,
|
||||
local_devices_.front()));
|
||||
all_core_arguments = {tupled_arguments.get()};
|
||||
} else {
|
||||
all_core_arguments = std::vector<PyTpuBuffer*>(argument_handles.begin(),
|
||||
argument_handles.end());
|
||||
}
|
||||
ExecuteResult result =
|
||||
ExecuteHelper(absl::MakeSpan(&all_core_arguments, 1), argument_handles,
|
||||
/*replica=*/0, /*partition=*/0, RunId());
|
||||
@ -638,12 +647,19 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuExecutable::Execute(
|
||||
return status;
|
||||
}
|
||||
|
||||
return std::move(result.buffer);
|
||||
if (result.buffer->on_host_shape().IsTuple()) {
|
||||
return result.buffer->DestructureTuple();
|
||||
} else {
|
||||
std::vector<std::unique_ptr<PyTpuBuffer>> outputs;
|
||||
outputs.push_back(std::move(result.buffer));
|
||||
return outputs;
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>>
|
||||
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
|
||||
PyTpuExecutable::ExecuteOnLocalDevices(
|
||||
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles) {
|
||||
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles,
|
||||
bool tuple_arguments) {
|
||||
tensorflow::profiler::TraceMe traceme(
|
||||
"PyTpuExecutable::ExecuteOnLocalDevices");
|
||||
|
||||
@ -661,6 +677,20 @@ PyTpuExecutable::ExecuteOnLocalDevices(
|
||||
<< " num_partitions=" << num_partitions()
|
||||
<< " num_local_devices=" << num_local_devices;
|
||||
|
||||
std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
|
||||
std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
|
||||
if (tuple_arguments) {
|
||||
tupled_arguments.resize(argument_handles.size());
|
||||
tupled_argument_pointers.resize(argument_handles.size());
|
||||
for (int i = 0; i < num_local_devices; ++i) {
|
||||
TF_ASSIGN_OR_RETURN(tupled_arguments[i],
|
||||
PyTpuBuffer::MakeTuple(argument_handles[i], client_,
|
||||
local_devices_.at(i)));
|
||||
tupled_argument_pointers[i] = {tupled_arguments[i].get()};
|
||||
}
|
||||
argument_handles = tupled_argument_pointers;
|
||||
}
|
||||
|
||||
absl::Mutex results_lock;
|
||||
std::vector<ExecuteResult> results(num_local_devices);
|
||||
|
||||
@ -702,9 +732,15 @@ PyTpuExecutable::ExecuteOnLocalDevices(
|
||||
}
|
||||
VLOG(1) << "Replicated execution complete.";
|
||||
|
||||
std::vector<std::unique_ptr<PyTpuBuffer>> wrapped_results(num_local_devices);
|
||||
std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> wrapped_results(
|
||||
num_local_devices);
|
||||
for (int i = 0; i < num_local_devices; ++i) {
|
||||
wrapped_results[i] = std::move(results[i].buffer);
|
||||
if (results[i].buffer->on_host_shape().IsTuple()) {
|
||||
TF_ASSIGN_OR_RETURN(wrapped_results[i],
|
||||
results[i].buffer->DestructureTuple());
|
||||
} else {
|
||||
wrapped_results[i].push_back(std::move(results[i].buffer));
|
||||
}
|
||||
}
|
||||
return wrapped_results;
|
||||
}
|
||||
|
||||
@ -166,7 +166,7 @@ class PyTpuBuffer {
|
||||
|
||||
// Supports nested tuple creation.
|
||||
static StatusOr<std::unique_ptr<PyTpuBuffer>> MakeTuple(
|
||||
const std::vector<PyTpuBuffer*> buffers,
|
||||
absl::Span<PyTpuBuffer* const> buffers,
|
||||
std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device);
|
||||
|
||||
PyTpuBuffer() = delete;
|
||||
@ -308,15 +308,17 @@ class PyTpuExecutable {
|
||||
// TODO(power): Both Execute and ExecutePerOnLocalDevices block and wait
|
||||
// inside for computation to finish. Coordinate with JAX code change to see if
|
||||
// we can make both Execute and ExecutePerReplica non-blocking.
|
||||
StatusOr<std::unique_ptr<PyTpuBuffer>> Execute(
|
||||
absl::Span<PyTpuBuffer* const> argument_handles);
|
||||
StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> Execute(
|
||||
absl::Span<PyTpuBuffer* const> argument_handles, bool tuple_arguments);
|
||||
|
||||
// Execute on local devices. Takes a sequence of argument lists (one argument
|
||||
// list per local device) and returns a tuple of results (one result per local
|
||||
// device). The number of argument lists must be equal to the local device
|
||||
// count.
|
||||
StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> ExecuteOnLocalDevices(
|
||||
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles);
|
||||
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
|
||||
ExecuteOnLocalDevices(
|
||||
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles,
|
||||
bool tuple_arguments);
|
||||
|
||||
void Delete() { executables_.clear(); }
|
||||
|
||||
|
||||
@ -203,9 +203,11 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
||||
&PyTpuExecutable::SizeOfGeneratedCodeInBytes)
|
||||
.def("Delete", &PyTpuExecutable::Delete)
|
||||
.def("Execute", &PyTpuExecutable::Execute,
|
||||
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
|
||||
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
|
||||
py::arg("tuple_arguments"))
|
||||
.def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices,
|
||||
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
|
||||
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
|
||||
py::arg("tuple_arguments"));
|
||||
|
||||
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
|
||||
.def_property_readonly("coords", &TpuDevice::coords)
|
||||
|
||||
@ -1133,20 +1133,6 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def("SizeOfGeneratedCodeInBytes",
|
||||
&PyLocalExecutable::SizeOfGeneratedCodeInBytes)
|
||||
.def("Delete", &PyLocalExecutable::Delete)
|
||||
.def(
|
||||
"Execute",
|
||||
[](const PyLocalExecutable& executable,
|
||||
absl::Span<PyLocalBuffer* const> args)
|
||||
-> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<std::unique_ptr<PyLocalBuffer>> output,
|
||||
executable.Execute(args, ExecuteOptions()));
|
||||
return WrapWithClient(executable.client()->shared_from_this(),
|
||||
std::move(output.front()));
|
||||
},
|
||||
py::arg("arguments"))
|
||||
// TODO(phawkins): remove in favor of overload that returns a vector.
|
||||
.def(
|
||||
"Execute",
|
||||
[](const PyLocalExecutable& executable,
|
||||
@ -1168,27 +1154,6 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
return outputs;
|
||||
},
|
||||
py::arg("arguments"), py::arg("tuple_arguments"))
|
||||
// TODO(phawkins): remove in favor of overload that returns a vector.
|
||||
.def(
|
||||
"ExecuteOnLocalDevices",
|
||||
[](const PyLocalExecutable& executable,
|
||||
absl::Span<const std::vector<PyLocalBuffer*>> args)
|
||||
-> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
|
||||
output_buffers,
|
||||
executable.ExecuteOnLocalDevices(args, ExecuteOptions()));
|
||||
std::vector<ClientAndUniquePtr<PyLocalBuffer>> outputs;
|
||||
outputs.reserve(output_buffers.size());
|
||||
for (auto& buffers : output_buffers) {
|
||||
outputs.push_back(
|
||||
WrapWithClient(executable.client()->shared_from_this(),
|
||||
std::move(buffers.front())));
|
||||
}
|
||||
return outputs;
|
||||
},
|
||||
py::arg("arguments"))
|
||||
.def(
|
||||
"ExecuteOnLocalDevices",
|
||||
[](const PyLocalExecutable& executable,
|
||||
|
||||
@ -42,7 +42,6 @@ from tensorflow.compiler.xla.python.xla_extension import ops
|
||||
# consistency with XLA.
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
|
||||
profiler = _xla.profiler
|
||||
|
||||
|
||||
@ -454,8 +453,8 @@ def transfer_to_infeed(value, device=None):
|
||||
Args:
|
||||
value: the value that the caller would like to enqueue into the XLA infeed
|
||||
queue
|
||||
device: the device to infeed the value to. Each device has a
|
||||
distinct infeed queue.
|
||||
device: the device to infeed the value to. Each device has a distinct infeed
|
||||
queue.
|
||||
"""
|
||||
# TODO(phawkins): support non-default backends.
|
||||
backend = get_local_backend()
|
||||
@ -501,7 +500,6 @@ def computation_count():
|
||||
'''Returns the number of computations per replica.'''
|
||||
"""
|
||||
|
||||
|
||||
Device = _xla.Device
|
||||
|
||||
|
||||
@ -633,7 +631,8 @@ def execute_with_python_values(executable, arguments=(), backend=None):
|
||||
arg, device=executable.local_devices()[0], backend=backend)
|
||||
|
||||
arguments = [put(arg) for arg in arguments]
|
||||
return executable.Execute(arguments).to_py()
|
||||
outputs = executable.Execute(arguments, tuple_arguments=False)
|
||||
return [x.to_py() for x in outputs]
|
||||
|
||||
|
||||
def execute_with_python_values_replicated(executable, arguments, backend=None):
|
||||
@ -641,8 +640,8 @@ def execute_with_python_values_replicated(executable, arguments, backend=None):
|
||||
|
||||
Arguments:
|
||||
executable: the program to run.
|
||||
arguments: a list of lists of Python values indexed by
|
||||
`[replica][arg_num]` to pass as inputs.
|
||||
arguments: a list of lists of Python values indexed by `[replica][arg_num]`
|
||||
to pass as inputs.
|
||||
backend: the backend we are targeting.
|
||||
|
||||
Returns:
|
||||
@ -661,7 +660,8 @@ def execute_with_python_values_replicated(executable, arguments, backend=None):
|
||||
for replica_args in arguments:
|
||||
arg_buffers.append(flat_arg_buffers[:len(replica_args)])
|
||||
flat_arg_buffers = flat_arg_buffers[len(replica_args):]
|
||||
return [out.to_py() for out in executable.ExecuteOnLocalDevices(arg_buffers)]
|
||||
return [[x.to_py() for x in xs] for xs in executable.ExecuteOnLocalDevices(
|
||||
arg_buffers, tuple_arguments=False)]
|
||||
|
||||
|
||||
class PaddingType(enum.Enum):
|
||||
@ -787,6 +787,7 @@ class ComputationBuilder(object):
|
||||
shape: a `Shape` describing the shape of the infed value.
|
||||
token: an optional `XlaOp` representing a token after which the infeed
|
||||
effect should be sequenced.
|
||||
|
||||
Returns:
|
||||
An XlaOp, representing a (value, token) pair.
|
||||
"""
|
||||
@ -805,6 +806,7 @@ class ComputationBuilder(object):
|
||||
operand: an `XlaOp` representing the data to outfeed.
|
||||
token: an `XlaOp` representing a token after which the outfeed should be
|
||||
sequenced.
|
||||
|
||||
Returns:
|
||||
An `XlaOp` representing a token.
|
||||
"""
|
||||
@ -880,7 +882,10 @@ class ComputationBuilder(object):
|
||||
"""
|
||||
return self.Constant(np.array(value, dtype=np.bool))
|
||||
|
||||
def ParameterWithShape(self, shape, name=None, parameter_num=None,
|
||||
def ParameterWithShape(self,
|
||||
shape,
|
||||
name=None,
|
||||
parameter_num=None,
|
||||
replicated=False):
|
||||
"""Enqueues a Parameter op onto the computation, given a shape.
|
||||
|
||||
@ -891,8 +896,8 @@ class ComputationBuilder(object):
|
||||
next linear parameter number is used. The default value capability can
|
||||
be used for auto-numbering. If you're using auto-numbering for some
|
||||
parameters, use it for *all* parameters to avoid clashes.
|
||||
replicated: whether to mark the parameter's leaves as replicated. May be
|
||||
a bool, in which case it applies to all leaves, or an iterable of bools.
|
||||
replicated: whether to mark the parameter's leaves as replicated. May be a
|
||||
bool, in which case it applies to all leaves, or an iterable of bools.
|
||||
|
||||
Returns:
|
||||
An XlaOp.
|
||||
@ -1791,6 +1796,7 @@ def register_custom_call_target(name, fn, platform='cpu'):
|
||||
"""
|
||||
_xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform])
|
||||
|
||||
|
||||
# Deprecated. Use register_custom_call_target instead.
|
||||
register_cpu_custom_call_target = register_custom_call_target
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user