[XLA:Python] Remove the tuple_arguments argument from Execute().

tuple_arguments can be passed to Compile() instead.

PiperOrigin-RevId: 304083684
Change-Id: I310986631c1489d0e21cff73a75c15ed41b4c747
This commit is contained in:
Peter Hawkins 2020-03-31 17:44:59 -07:00 committed by TensorFlower Gardener
parent 486a20777c
commit 566c03a749
7 changed files with 17 additions and 27 deletions

View File

@ -728,7 +728,7 @@ PyLocalExecutable::ExecuteHelper(
std::unique_ptr<PyLocalBuffer> tuple_buffer;
std::vector<PyLocalBuffer*> tupled_arguments;
if (options.tuple_arguments || tuple_arguments_) {
if (tuple_arguments_) {
TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple(
argument_handles, client_, device));
tupled_arguments = {tuple_buffer.get()};

View File

@ -323,10 +323,6 @@ struct CompileOptions {
};
struct ExecuteOptions {
// If true, the arguments to the computation will be wrapped in a tuple and
// passed as a single parameter.
bool tuple_arguments = false;
// If true, the computation must return a tuple, which will be destructured
// into its elements.
bool untuple_result = false;

View File

@ -613,7 +613,7 @@ Status WaitForExecuteEvent(tpu_driver::Event* event) {
}
StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
absl::Span<PyTpuBuffer* const> argument_handles, bool tuple_arguments) {
absl::Span<PyTpuBuffer* const> argument_handles) {
if (num_replicas() != 1) {
return InvalidArgument(
"Attempted to execute computation with %d replicas using Execute().",
@ -628,7 +628,7 @@ StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
std::vector<PyTpuBuffer*> all_core_arguments;
std::unique_ptr<PyTpuBuffer> tupled_arguments;
if (tuple_arguments_ || tuple_arguments) {
if (tuple_arguments_) {
TF_ASSIGN_OR_RETURN(tupled_arguments,
PyTpuBuffer::MakeTuple(argument_handles, client_,
local_devices_.front()));
@ -659,8 +659,7 @@ StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
PyTpuExecutable::ExecuteOnLocalDevices(
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles,
bool tuple_arguments) {
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles) {
tensorflow::profiler::TraceMe traceme(
"PyTpuExecutable::ExecuteOnLocalDevices");
@ -680,7 +679,7 @@ PyTpuExecutable::ExecuteOnLocalDevices(
std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
if (tuple_arguments_ || tuple_arguments) {
if (tuple_arguments_) {
tupled_arguments.resize(argument_handles.size());
tupled_argument_pointers.resize(argument_handles.size());
for (int i = 0; i < num_local_devices; ++i) {

View File

@ -309,7 +309,7 @@ class PyTpuExecutable {
// inside for computation to finish. Coordinate with JAX code change to see if
// we can make both Execute and ExecutePerReplica non-blocking.
StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> Execute(
absl::Span<PyTpuBuffer* const> argument_handles, bool tuple_arguments);
absl::Span<PyTpuBuffer* const> argument_handles);
// Execute on local devices. Takes a sequence of argument lists (one argument
// list per local device) and returns a tuple of results (one result per local
@ -317,8 +317,7 @@ class PyTpuExecutable {
// count.
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
ExecuteOnLocalDevices(
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles,
bool tuple_arguments);
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles);
void Delete() { executables_.clear(); }

View File

@ -188,11 +188,9 @@ PYBIND11_MODULE(tpu_client_extension, m) {
&PyTpuExecutable::SizeOfGeneratedCodeInBytes)
.def("Delete", &PyTpuExecutable::Delete)
.def("Execute", &PyTpuExecutable::Execute,
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
py::arg("tuple_arguments") = false)
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
.def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices,
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
py::arg("tuple_arguments") = false);
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
.def_property_readonly("coords", &TpuDevice::coords)

View File

@ -1109,11 +1109,10 @@ PYBIND11_MODULE(xla_extension, m) {
.def(
"Execute",
[](const PyLocalExecutable& executable,
absl::Span<PyLocalBuffer* const> args, bool tuple_arguments)
absl::Span<PyLocalBuffer* const> args)
-> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
py::gil_scoped_release gil_release;
ExecuteOptions options;
options.tuple_arguments = tuple_arguments;
options.untuple_result = true;
TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<PyLocalBuffer>> output_buffers,
@ -1126,17 +1125,15 @@ PYBIND11_MODULE(xla_extension, m) {
}
return outputs;
},
py::arg("arguments"), py::arg("tuple_arguments") = false)
py::arg("arguments"))
.def(
"ExecuteOnLocalDevices",
[](const PyLocalExecutable& executable,
absl::Span<const std::vector<PyLocalBuffer*>> args,
bool tuple_arguments)
absl::Span<const std::vector<PyLocalBuffer*>> args)
-> StatusOr<
std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>>> {
py::gil_scoped_release gil_release;
ExecuteOptions options;
options.tuple_arguments = tuple_arguments;
options.untuple_result = true;
TF_ASSIGN_OR_RETURN(
std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
@ -1154,7 +1151,7 @@ PYBIND11_MODULE(xla_extension, m) {
}
return outputs;
},
py::arg("arguments"), py::arg("tuple_arguments") = false)
py::arg("arguments"))
.def(
"get_hlo_modules",
[](const PyLocalExecutable& executable)

View File

@ -615,7 +615,7 @@ def execute_with_python_values(executable, arguments=(), backend=None):
arg, device=executable.local_devices()[0], backend=backend)
arguments = [put(arg) for arg in arguments]
outputs = executable.Execute(arguments, tuple_arguments=False)
outputs = executable.Execute(arguments)
return [x.to_py() for x in outputs]
@ -644,8 +644,9 @@ def execute_with_python_values_replicated(executable, arguments, backend=None):
for replica_args in arguments:
arg_buffers.append(flat_arg_buffers[:len(replica_args)])
flat_arg_buffers = flat_arg_buffers[len(replica_args):]
return [[x.to_py() for x in xs] for xs in executable.ExecuteOnLocalDevices(
arg_buffers, tuple_arguments=False)]
return [[x.to_py()
for x in xs]
for xs in executable.ExecuteOnLocalDevices(arg_buffers)]
class PaddingType(enum.Enum):