[XLA:Python] Remove the tuple_arguments argument from Execute().
tuple_arguments can be passed to Compile() instead. PiperOrigin-RevId: 304083684 Change-Id: I310986631c1489d0e21cff73a75c15ed41b4c747
This commit is contained in:
parent
486a20777c
commit
566c03a749
@ -728,7 +728,7 @@ PyLocalExecutable::ExecuteHelper(
|
|||||||
|
|
||||||
std::unique_ptr<PyLocalBuffer> tuple_buffer;
|
std::unique_ptr<PyLocalBuffer> tuple_buffer;
|
||||||
std::vector<PyLocalBuffer*> tupled_arguments;
|
std::vector<PyLocalBuffer*> tupled_arguments;
|
||||||
if (options.tuple_arguments || tuple_arguments_) {
|
if (tuple_arguments_) {
|
||||||
TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple(
|
TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple(
|
||||||
argument_handles, client_, device));
|
argument_handles, client_, device));
|
||||||
tupled_arguments = {tuple_buffer.get()};
|
tupled_arguments = {tuple_buffer.get()};
|
||||||
|
@ -323,10 +323,6 @@ struct CompileOptions {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct ExecuteOptions {
|
struct ExecuteOptions {
|
||||||
// If true, the arguments to the computation will be wrapped in a tuple and
|
|
||||||
// passed as a single parameter.
|
|
||||||
bool tuple_arguments = false;
|
|
||||||
|
|
||||||
// If true, the computation must return a tuple, which will be destructured
|
// If true, the computation must return a tuple, which will be destructured
|
||||||
// into its elements.
|
// into its elements.
|
||||||
bool untuple_result = false;
|
bool untuple_result = false;
|
||||||
|
@ -613,7 +613,7 @@ Status WaitForExecuteEvent(tpu_driver::Event* event) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
|
StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
|
||||||
absl::Span<PyTpuBuffer* const> argument_handles, bool tuple_arguments) {
|
absl::Span<PyTpuBuffer* const> argument_handles) {
|
||||||
if (num_replicas() != 1) {
|
if (num_replicas() != 1) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Attempted to execute computation with %d replicas using Execute().",
|
"Attempted to execute computation with %d replicas using Execute().",
|
||||||
@ -628,7 +628,7 @@ StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
|
|||||||
std::vector<PyTpuBuffer*> all_core_arguments;
|
std::vector<PyTpuBuffer*> all_core_arguments;
|
||||||
|
|
||||||
std::unique_ptr<PyTpuBuffer> tupled_arguments;
|
std::unique_ptr<PyTpuBuffer> tupled_arguments;
|
||||||
if (tuple_arguments_ || tuple_arguments) {
|
if (tuple_arguments_) {
|
||||||
TF_ASSIGN_OR_RETURN(tupled_arguments,
|
TF_ASSIGN_OR_RETURN(tupled_arguments,
|
||||||
PyTpuBuffer::MakeTuple(argument_handles, client_,
|
PyTpuBuffer::MakeTuple(argument_handles, client_,
|
||||||
local_devices_.front()));
|
local_devices_.front()));
|
||||||
@ -659,8 +659,7 @@ StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
|
|||||||
|
|
||||||
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
|
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
|
||||||
PyTpuExecutable::ExecuteOnLocalDevices(
|
PyTpuExecutable::ExecuteOnLocalDevices(
|
||||||
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles,
|
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles) {
|
||||||
bool tuple_arguments) {
|
|
||||||
tensorflow::profiler::TraceMe traceme(
|
tensorflow::profiler::TraceMe traceme(
|
||||||
"PyTpuExecutable::ExecuteOnLocalDevices");
|
"PyTpuExecutable::ExecuteOnLocalDevices");
|
||||||
|
|
||||||
@ -680,7 +679,7 @@ PyTpuExecutable::ExecuteOnLocalDevices(
|
|||||||
|
|
||||||
std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
|
std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
|
||||||
std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
|
std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
|
||||||
if (tuple_arguments_ || tuple_arguments) {
|
if (tuple_arguments_) {
|
||||||
tupled_arguments.resize(argument_handles.size());
|
tupled_arguments.resize(argument_handles.size());
|
||||||
tupled_argument_pointers.resize(argument_handles.size());
|
tupled_argument_pointers.resize(argument_handles.size());
|
||||||
for (int i = 0; i < num_local_devices; ++i) {
|
for (int i = 0; i < num_local_devices; ++i) {
|
||||||
|
@ -309,7 +309,7 @@ class PyTpuExecutable {
|
|||||||
// inside for computation to finish. Coordinate with JAX code change to see if
|
// inside for computation to finish. Coordinate with JAX code change to see if
|
||||||
// we can make both Execute and ExecutePerReplica non-blocking.
|
// we can make both Execute and ExecutePerReplica non-blocking.
|
||||||
StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> Execute(
|
StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> Execute(
|
||||||
absl::Span<PyTpuBuffer* const> argument_handles, bool tuple_arguments);
|
absl::Span<PyTpuBuffer* const> argument_handles);
|
||||||
|
|
||||||
// Execute on local devices. Takes a sequence of argument lists (one argument
|
// Execute on local devices. Takes a sequence of argument lists (one argument
|
||||||
// list per local device) and returns a tuple of results (one result per local
|
// list per local device) and returns a tuple of results (one result per local
|
||||||
@ -317,8 +317,7 @@ class PyTpuExecutable {
|
|||||||
// count.
|
// count.
|
||||||
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
|
StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
|
||||||
ExecuteOnLocalDevices(
|
ExecuteOnLocalDevices(
|
||||||
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles,
|
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles);
|
||||||
bool tuple_arguments);
|
|
||||||
|
|
||||||
void Delete() { executables_.clear(); }
|
void Delete() { executables_.clear(); }
|
||||||
|
|
||||||
|
@ -188,11 +188,9 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
|||||||
&PyTpuExecutable::SizeOfGeneratedCodeInBytes)
|
&PyTpuExecutable::SizeOfGeneratedCodeInBytes)
|
||||||
.def("Delete", &PyTpuExecutable::Delete)
|
.def("Delete", &PyTpuExecutable::Delete)
|
||||||
.def("Execute", &PyTpuExecutable::Execute,
|
.def("Execute", &PyTpuExecutable::Execute,
|
||||||
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
|
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
|
||||||
py::arg("tuple_arguments") = false)
|
|
||||||
.def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices,
|
.def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices,
|
||||||
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
|
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
|
||||||
py::arg("tuple_arguments") = false);
|
|
||||||
|
|
||||||
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
|
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
|
||||||
.def_property_readonly("coords", &TpuDevice::coords)
|
.def_property_readonly("coords", &TpuDevice::coords)
|
||||||
|
@ -1109,11 +1109,10 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
.def(
|
.def(
|
||||||
"Execute",
|
"Execute",
|
||||||
[](const PyLocalExecutable& executable,
|
[](const PyLocalExecutable& executable,
|
||||||
absl::Span<PyLocalBuffer* const> args, bool tuple_arguments)
|
absl::Span<PyLocalBuffer* const> args)
|
||||||
-> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
|
-> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
|
||||||
py::gil_scoped_release gil_release;
|
py::gil_scoped_release gil_release;
|
||||||
ExecuteOptions options;
|
ExecuteOptions options;
|
||||||
options.tuple_arguments = tuple_arguments;
|
|
||||||
options.untuple_result = true;
|
options.untuple_result = true;
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::vector<std::unique_ptr<PyLocalBuffer>> output_buffers,
|
std::vector<std::unique_ptr<PyLocalBuffer>> output_buffers,
|
||||||
@ -1126,17 +1125,15 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
}
|
}
|
||||||
return outputs;
|
return outputs;
|
||||||
},
|
},
|
||||||
py::arg("arguments"), py::arg("tuple_arguments") = false)
|
py::arg("arguments"))
|
||||||
.def(
|
.def(
|
||||||
"ExecuteOnLocalDevices",
|
"ExecuteOnLocalDevices",
|
||||||
[](const PyLocalExecutable& executable,
|
[](const PyLocalExecutable& executable,
|
||||||
absl::Span<const std::vector<PyLocalBuffer*>> args,
|
absl::Span<const std::vector<PyLocalBuffer*>> args)
|
||||||
bool tuple_arguments)
|
|
||||||
-> StatusOr<
|
-> StatusOr<
|
||||||
std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>>> {
|
std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>>> {
|
||||||
py::gil_scoped_release gil_release;
|
py::gil_scoped_release gil_release;
|
||||||
ExecuteOptions options;
|
ExecuteOptions options;
|
||||||
options.tuple_arguments = tuple_arguments;
|
|
||||||
options.untuple_result = true;
|
options.untuple_result = true;
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
|
std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
|
||||||
@ -1154,7 +1151,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
}
|
}
|
||||||
return outputs;
|
return outputs;
|
||||||
},
|
},
|
||||||
py::arg("arguments"), py::arg("tuple_arguments") = false)
|
py::arg("arguments"))
|
||||||
.def(
|
.def(
|
||||||
"get_hlo_modules",
|
"get_hlo_modules",
|
||||||
[](const PyLocalExecutable& executable)
|
[](const PyLocalExecutable& executable)
|
||||||
|
@ -615,7 +615,7 @@ def execute_with_python_values(executable, arguments=(), backend=None):
|
|||||||
arg, device=executable.local_devices()[0], backend=backend)
|
arg, device=executable.local_devices()[0], backend=backend)
|
||||||
|
|
||||||
arguments = [put(arg) for arg in arguments]
|
arguments = [put(arg) for arg in arguments]
|
||||||
outputs = executable.Execute(arguments, tuple_arguments=False)
|
outputs = executable.Execute(arguments)
|
||||||
return [x.to_py() for x in outputs]
|
return [x.to_py() for x in outputs]
|
||||||
|
|
||||||
|
|
||||||
@ -644,8 +644,9 @@ def execute_with_python_values_replicated(executable, arguments, backend=None):
|
|||||||
for replica_args in arguments:
|
for replica_args in arguments:
|
||||||
arg_buffers.append(flat_arg_buffers[:len(replica_args)])
|
arg_buffers.append(flat_arg_buffers[:len(replica_args)])
|
||||||
flat_arg_buffers = flat_arg_buffers[len(replica_args):]
|
flat_arg_buffers = flat_arg_buffers[len(replica_args):]
|
||||||
return [[x.to_py() for x in xs] for xs in executable.ExecuteOnLocalDevices(
|
return [[x.to_py()
|
||||||
arg_buffers, tuple_arguments=False)]
|
for x in xs]
|
||||||
|
for xs in executable.ExecuteOnLocalDevices(arg_buffers)]
|
||||||
|
|
||||||
|
|
||||||
class PaddingType(enum.Enum):
|
class PaddingType(enum.Enum):
|
||||||
|
Loading…
Reference in New Issue
Block a user