[XLA:Python] Update tpu_driver to add the same automatic tupling of arguments and untupling of results present in the local client.

Update tests to use the automatic untupling support.
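
For illustration, a minimal sketch of the Python-side calling convention after this change (the `executable` and argument-buffer variables here are assumed to already exist; only the `tuple_arguments` keyword and the list-shaped results come from this change):

    # Hedged usage sketch, not part of the commit itself.
    # With tuple_arguments=False the arguments are passed through unchanged;
    # with tuple_arguments=True the client first packs them into a single
    # tuple buffer before execution.
    outputs = executable.Execute(arg_buffers, tuple_arguments=False)
    # A tuple-shaped result is automatically destructured, so Execute now
    # always returns a list of buffers rather than a single (tuple) buffer.
    host_values = [out.to_py() for out in outputs]

    # ExecuteOnLocalDevices likewise now returns one list of output buffers
    # per local device.
    per_device_outputs = executable.ExecuteOnLocalDevices(
        per_device_args, tuple_arguments=False)
    host = [[x.to_py() for x in xs] for xs in per_device_outputs]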

PiperOrigin-RevId: 301623333
Change-Id: I1233e6a63eaea2bfef2ac7a85bf1b55b820361d1
Peter Hawkins 2020-03-18 10:56:23 -07:00 committed by TensorFlower Gardener
parent 4679feb3ce
commit 85f7677b4a
6 changed files with 275 additions and 251 deletions

View File

@@ -227,8 +227,8 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
 /* static */
 StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
-    const std::vector<PyTpuBuffer*> buffers,
-    std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device) {
+    absl::Span<PyTpuBuffer* const> buffers, std::shared_ptr<PyTpuClient> client,
+    std::shared_ptr<Device> device) {
   std::vector<Shape> child_shapes;
   std::vector<std::shared_ptr<TpuSharedBuffer>> child_device_buffers;
   std::vector<tpu_driver::BufferHandle*> child_handle_ptrs;
@@ -611,8 +611,8 @@ Status WaitForExecuteEvent(tpu_driver::Event* event) {
   return opt_status.value();
 }
 
-StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuExecutable::Execute(
-    absl::Span<PyTpuBuffer* const> argument_handles) {
+StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
+    absl::Span<PyTpuBuffer* const> argument_handles, bool tuple_arguments) {
   if (num_replicas() != 1) {
     return InvalidArgument(
         "Attempted to execute computation with %d replicas using Execute().",
@@ -624,9 +624,18 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuExecutable::Execute(
         num_partitions());
   }
 
-  std::vector<PyTpuBuffer*> all_core_arguments(argument_handles.begin(),
-                                               argument_handles.end());
+  std::vector<PyTpuBuffer*> all_core_arguments;
+  std::unique_ptr<PyTpuBuffer> tupled_arguments;
+  if (tuple_arguments) {
+    TF_ASSIGN_OR_RETURN(tupled_arguments,
+                        PyTpuBuffer::MakeTuple(argument_handles, client_,
+                                               local_devices_.front()));
+    all_core_arguments = {tupled_arguments.get()};
+  } else {
+    all_core_arguments = std::vector<PyTpuBuffer*>(argument_handles.begin(),
+                                                   argument_handles.end());
+  }
 
   ExecuteResult result =
       ExecuteHelper(absl::MakeSpan(&all_core_arguments, 1), argument_handles,
                     /*replica=*/0, /*partition=*/0, RunId());
@@ -638,12 +647,19 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuExecutable::Execute(
     return status;
   }
 
-  return std::move(result.buffer);
+  if (result.buffer->on_host_shape().IsTuple()) {
+    return result.buffer->DestructureTuple();
+  } else {
+    std::vector<std::unique_ptr<PyTpuBuffer>> outputs;
+    outputs.push_back(std::move(result.buffer));
+    return outputs;
+  }
 }
 
-StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>>
+StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
 PyTpuExecutable::ExecuteOnLocalDevices(
-    absl::Span<const std::vector<PyTpuBuffer*>> argument_handles) {
+    absl::Span<const std::vector<PyTpuBuffer*>> argument_handles,
+    bool tuple_arguments) {
   tensorflow::profiler::TraceMe traceme(
       "PyTpuExecutable::ExecuteOnLocalDevices");
@@ -661,6 +677,20 @@ PyTpuExecutable::ExecuteOnLocalDevices(
           << " num_partitions=" << num_partitions()
          << " num_local_devices=" << num_local_devices;
 
+  std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
+  std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
+  if (tuple_arguments) {
+    tupled_arguments.resize(argument_handles.size());
+    tupled_argument_pointers.resize(argument_handles.size());
+    for (int i = 0; i < num_local_devices; ++i) {
+      TF_ASSIGN_OR_RETURN(tupled_arguments[i],
+                          PyTpuBuffer::MakeTuple(argument_handles[i], client_,
+                                                 local_devices_.at(i)));
+      tupled_argument_pointers[i] = {tupled_arguments[i].get()};
+    }
+    argument_handles = tupled_argument_pointers;
+  }
+
   absl::Mutex results_lock;
   std::vector<ExecuteResult> results(num_local_devices);
@@ -702,9 +732,15 @@ PyTpuExecutable::ExecuteOnLocalDevices(
   }
   VLOG(1) << "Replicated execution complete.";
 
-  std::vector<std::unique_ptr<PyTpuBuffer>> wrapped_results(num_local_devices);
+  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> wrapped_results(
+      num_local_devices);
   for (int i = 0; i < num_local_devices; ++i) {
-    wrapped_results[i] = std::move(results[i].buffer);
+    if (results[i].buffer->on_host_shape().IsTuple()) {
+      TF_ASSIGN_OR_RETURN(wrapped_results[i],
+                          results[i].buffer->DestructureTuple());
+    } else {
+      wrapped_results[i].push_back(std::move(results[i].buffer));
+    }
   }
   return wrapped_results;
 }

View File

@@ -166,7 +166,7 @@ class PyTpuBuffer {
   // Supports nested tuple creation.
   static StatusOr<std::unique_ptr<PyTpuBuffer>> MakeTuple(
-      const std::vector<PyTpuBuffer*> buffers,
+      absl::Span<PyTpuBuffer* const> buffers,
       std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device);
 
   PyTpuBuffer() = delete;
@@ -308,15 +308,17 @@ class PyTpuExecutable {
   // TODO(power): Both Execute and ExecutePerOnLocalDevices block and wait
   // inside for computation to finish. Coordinate with JAX code change to see if
   // we can make both Execute and ExecutePerReplica non-blocking.
-  StatusOr<std::unique_ptr<PyTpuBuffer>> Execute(
-      absl::Span<PyTpuBuffer* const> argument_handles);
+  StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> Execute(
+      absl::Span<PyTpuBuffer* const> argument_handles, bool tuple_arguments);
 
   // Execute on local devices. Takes a sequence of argument lists (one argument
   // list per local device) and returns a tuple of results (one result per local
   // device). The number of argument lists must be equal to the local device
   // count.
-  StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> ExecuteOnLocalDevices(
-      absl::Span<const std::vector<PyTpuBuffer*>> argument_handles);
+  StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
+  ExecuteOnLocalDevices(
+      absl::Span<const std::vector<PyTpuBuffer*>> argument_handles,
+      bool tuple_arguments);
 
   void Delete() { executables_.clear(); }

View File

@@ -203,9 +203,11 @@ PYBIND11_MODULE(tpu_client_extension, m) {
            &PyTpuExecutable::SizeOfGeneratedCodeInBytes)
       .def("Delete", &PyTpuExecutable::Delete)
       .def("Execute", &PyTpuExecutable::Execute,
-           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
+           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
+           py::arg("tuple_arguments"))
       .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices,
-           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
+           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
+           py::arg("tuple_arguments"));
 
   py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
       .def_property_readonly("coords", &TpuDevice::coords)

View File

@@ -1133,20 +1133,6 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("SizeOfGeneratedCodeInBytes",
            &PyLocalExecutable::SizeOfGeneratedCodeInBytes)
       .def("Delete", &PyLocalExecutable::Delete)
-      .def(
-          "Execute",
-          [](const PyLocalExecutable& executable,
-             absl::Span<PyLocalBuffer* const> args)
-              -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
-            py::gil_scoped_release gil_release;
-            TF_ASSIGN_OR_RETURN(
-                std::vector<std::unique_ptr<PyLocalBuffer>> output,
-                executable.Execute(args, ExecuteOptions()));
-            return WrapWithClient(executable.client()->shared_from_this(),
-                                  std::move(output.front()));
-          },
-          py::arg("arguments"))
-      // TODO(phawkins): remove in favor of overload that returns a vector.
       .def(
           "Execute",
           [](const PyLocalExecutable& executable,
@@ -1168,27 +1154,6 @@ PYBIND11_MODULE(xla_extension, m) {
             return outputs;
           },
           py::arg("arguments"), py::arg("tuple_arguments"))
-      // TODO(phawkins): remove in favor of overload that returns a vector.
-      .def(
-          "ExecuteOnLocalDevices",
-          [](const PyLocalExecutable& executable,
-             absl::Span<const std::vector<PyLocalBuffer*>> args)
-              -> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
-            py::gil_scoped_release gil_release;
-            TF_ASSIGN_OR_RETURN(
-                std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
-                    output_buffers,
-                executable.ExecuteOnLocalDevices(args, ExecuteOptions()));
-            std::vector<ClientAndUniquePtr<PyLocalBuffer>> outputs;
-            outputs.reserve(output_buffers.size());
-            for (auto& buffers : output_buffers) {
-              outputs.push_back(
-                  WrapWithClient(executable.client()->shared_from_this(),
-                                 std::move(buffers.front())));
-            }
-            return outputs;
-          },
-          py::arg("arguments"))
       .def(
           "ExecuteOnLocalDevices",
           [](const PyLocalExecutable& executable,

View File

@@ -42,7 +42,6 @@ from tensorflow.compiler.xla.python.xla_extension import ops
 # consistency with XLA.
 # pylint: disable=invalid-name
-
 profiler = _xla.profiler
@@ -454,8 +453,8 @@ def transfer_to_infeed(value, device=None):
   Args:
     value: the value that the caller would like to enqueue into the XLA infeed
       queue
-    device: the device to infeed the value to. Each device has a
-      distinct infeed queue.
+    device: the device to infeed the value to. Each device has a distinct infeed
+      queue.
   """
   # TODO(phawkins): support non-default backends.
   backend = get_local_backend()
@@ -501,7 +500,6 @@ def computation_count():
     '''Returns the number of computations per replica.'''
   """
-
 Device = _xla.Device
@@ -633,7 +631,8 @@ def execute_with_python_values(executable, arguments=(), backend=None):
         arg, device=executable.local_devices()[0], backend=backend)
 
   arguments = [put(arg) for arg in arguments]
-  return executable.Execute(arguments).to_py()
+  outputs = executable.Execute(arguments, tuple_arguments=False)
+  return [x.to_py() for x in outputs]
 
 
 def execute_with_python_values_replicated(executable, arguments, backend=None):
@@ -641,8 +640,8 @@ def execute_with_python_values_replicated(executable, arguments, backend=None):
   Arguments:
     executable: the program to run.
-    arguments: a list of lists of Python values indexed by
-      `[replica][arg_num]` to pass as inputs.
+    arguments: a list of lists of Python values indexed by `[replica][arg_num]`
+      to pass as inputs.
     backend: the backend we are targeting.
 
   Returns:
@@ -661,7 +660,8 @@ def execute_with_python_values_replicated(executable, arguments, backend=None):
   for replica_args in arguments:
     arg_buffers.append(flat_arg_buffers[:len(replica_args)])
     flat_arg_buffers = flat_arg_buffers[len(replica_args):]
-  return [out.to_py() for out in executable.ExecuteOnLocalDevices(arg_buffers)]
+  return [[x.to_py() for x in xs] for xs in executable.ExecuteOnLocalDevices(
+      arg_buffers, tuple_arguments=False)]
 
 
 class PaddingType(enum.Enum):
@@ -787,6 +787,7 @@ class ComputationBuilder(object):
       shape: a `Shape` describing the shape of the infed value.
       token: an optional `XlaOp` representing a token after which the infeed
         effect should be sequenced.
+
     Returns:
       An XlaOp, representing a (value, token) pair.
     """
@@ -805,6 +806,7 @@ class ComputationBuilder(object):
      operand: an `XlaOp` representing the data to outfeed.
      token: an `XlaOp` representing a token after which the outfeed should be
        sequenced.
+
    Returns:
      An `XlaOp` representing a token.
    """
@@ -880,7 +882,10 @@ class ComputationBuilder(object):
     """
     return self.Constant(np.array(value, dtype=np.bool))
 
-  def ParameterWithShape(self, shape, name=None, parameter_num=None,
+  def ParameterWithShape(self,
+                         shape,
+                         name=None,
+                         parameter_num=None,
                          replicated=False):
     """Enqueues a Parameter op onto the computation, given a shape.
@@ -891,8 +896,8 @@ class ComputationBuilder(object):
         next linear parameter number is used. The default value capability can
         be used for auto-numbering. If you're using auto-numbering for some
         parameters, use it for *all* parameters to avoid clashes.
-      replicated: whether to mark the parameter's leaves as replicated. May be
-        a bool, in which case it applies to all leaves, or an iterable of bools.
+      replicated: whether to mark the parameter's leaves as replicated. May be a
+        bool, in which case it applies to all leaves, or an iterable of bools.
 
     Returns:
       An XlaOp.
@@ -1791,6 +1796,7 @@ def register_custom_call_target(name, fn, platform='cpu'):
   """
   _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform])
 
+
 # Deprecated. Use register_custom_call_target instead.
 register_cpu_custom_call_target = register_custom_call_target

File diff suppressed because it is too large.
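
Since the test diff above is suppressed, here is a hedged sketch of the kind of check the automatic untupling enables (the names and shapes are hypothetical, not taken from the suppressed file):

    # Assume `executable` was compiled from a computation whose root is a
    # two-element tuple (f32[], f32[]) and that takes no arguments.
    outputs = executable.Execute([], tuple_arguments=False)
    # The tuple root arrives as a list of two separate buffers; the caller
    # no longer has to destructure the tuple result by hand.
    assert len(outputs) == 2
    v0, v1 = (x.to_py() for x in outputs)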