[XLA:Python] Remove the tuple_arguments argument from Execute().
tuple_arguments can be passed to Compile() instead.

PiperOrigin-RevId: 304083684
Change-Id: I310986631c1489d0e21cff73a75c15ed41b4c747
commit 566c03a749
parent 486a20777c
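The user-visible effect: whether arguments are wrapped in a tuple is now a compile-time property of the executable rather than a per-call flag. A minimal migration sketch, assuming a Python-side CompileOptions that mirrors the C++ struct in the diff below; the backend.compile entry point and the backend/computation names are illustrative placeholders, not part of this change:

    from tensorflow.compiler.xla.python import xla_client

    # Before: tupling was requested at every call site.
    #   outputs = executable.Execute(args, tuple_arguments=True)

    # After: tupling is requested once, at compile time.
    compile_options = xla_client.CompileOptions()
    compile_options.tuple_arguments = True  # assumed to mirror CompileOptions::tuple_arguments
    executable = backend.compile(computation, compile_options)  # placeholder compile call
    outputs = executable.Execute(args)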
@@ -728,7 +728,7 @@ PyLocalExecutable::ExecuteHelper(
   std::unique_ptr<PyLocalBuffer> tuple_buffer;
   std::vector<PyLocalBuffer*> tupled_arguments;
-  if (options.tuple_arguments || tuple_arguments_) {
+  if (tuple_arguments_) {
     TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple(
                                           argument_handles, client_, device));
     tupled_arguments = {tuple_buffer.get()};
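Note the gating change above: ExecuteHelper now consults only the executable's own tuple_arguments_ member, which is fixed when the executable is compiled, so no per-call override remains. If a caller needs both behaviors, one way under the new scheme is to compile twice; a sketch using the same illustrative Python API as above:

    opts_tupled = xla_client.CompileOptions()
    opts_tupled.tuple_arguments = True  # assumed field
    exe_tupled = backend.compile(computation, opts_tupled)  # arguments passed as one tuple
    exe_plain = backend.compile(computation, xla_client.CompileOptions())  # arguments passed individually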
@@ -323,10 +323,6 @@ struct CompileOptions {
 };
 
 struct ExecuteOptions {
-  // If true, the arguments to the computation will be wrapped in a tuple and
-  // passed as a single parameter.
-  bool tuple_arguments = false;
-
   // If true, the computation must return a tuple, which will be destructured
   // into its elements.
   bool untuple_result = false;
@@ -613,7 +613,7 @@ Status WaitForExecuteEvent(tpu_driver::Event* event) {
 }
 
 StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
-    absl::Span<PyTpuBuffer* const> argument_handles, bool tuple_arguments) {
+    absl::Span<PyTpuBuffer* const> argument_handles) {
   if (num_replicas() != 1) {
     return InvalidArgument(
         "Attempted to execute computation with %d replicas using Execute().",
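The single-replica check above is untouched: Execute() still rejects replicated executables, which go through ExecuteOnLocalDevices() instead. A hedged sketch of that dispatch from Python (the num_replicas accessor shown is an assumption, not part of this diff):

    if executable.num_replicas == 1:  # assumed accessor
        outputs = executable.Execute(args)  # one argument list
    else:
        outputs = executable.ExecuteOnLocalDevices(per_device_args)  # one list per local device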
@@ -628,7 +628,7 @@ StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
   std::vector<PyTpuBuffer*> all_core_arguments;
 
   std::unique_ptr<PyTpuBuffer> tupled_arguments;
-  if (tuple_arguments_ || tuple_arguments) {
+  if (tuple_arguments_) {
     TF_ASSIGN_OR_RETURN(tupled_arguments,
                         PyTpuBuffer::MakeTuple(argument_handles, client_,
                                                local_devices_.front()));
@@ -659,8 +659,7 @@ StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
 
 StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
 PyTpuExecutable::ExecuteOnLocalDevices(
-    absl::Span<const std::vector<PyTpuBuffer*>> argument_handles,
-    bool tuple_arguments) {
+    absl::Span<const std::vector<PyTpuBuffer*>> argument_handles) {
   tensorflow::profiler::TraceMe traceme(
       "PyTpuExecutable::ExecuteOnLocalDevices");
@@ -680,7 +679,7 @@ PyTpuExecutable::ExecuteOnLocalDevices(
 
   std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
   std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
-  if (tuple_arguments_ || tuple_arguments) {
+  if (tuple_arguments_) {
     tupled_arguments.resize(argument_handles.size());
     tupled_argument_pointers.resize(argument_handles.size());
     for (int i = 0; i < num_local_devices; ++i) {
@@ -309,7 +309,7 @@ class PyTpuExecutable {
   // inside for computation to finish. Coordinate with JAX code change to see if
   // we can make both Execute and ExecutePerReplica non-blocking.
   StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> Execute(
-      absl::Span<PyTpuBuffer* const> argument_handles, bool tuple_arguments);
+      absl::Span<PyTpuBuffer* const> argument_handles);
 
   // Execute on local devices. Takes a sequence of argument lists (one argument
   // list per local device) and returns a tuple of results (one result per local
@@ -317,8 +317,7 @@ class PyTpuExecutable {
   // count.
   StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
   ExecuteOnLocalDevices(
-      absl::Span<const std::vector<PyTpuBuffer*>> argument_handles,
-      bool tuple_arguments);
+      absl::Span<const std::vector<PyTpuBuffer*>> argument_handles);
 
   void Delete() { executables_.clear(); }
@@ -188,11 +188,9 @@ PYBIND11_MODULE(tpu_client_extension, m) {
            &PyTpuExecutable::SizeOfGeneratedCodeInBytes)
       .def("Delete", &PyTpuExecutable::Delete)
       .def("Execute", &PyTpuExecutable::Execute,
-           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
-           py::arg("tuple_arguments") = false)
+           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
       .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices,
-           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
-           py::arg("tuple_arguments") = false);
+           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
 
   py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
       .def_property_readonly("coords", &TpuDevice::coords)
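With py::arg("tuple_arguments") removed from both TPU bindings, pybind11 now rejects the old keyword outright instead of accepting a flag. A sketch of the resulting failure mode (executable and args are assumed to exist):

    try:
        executable.Execute(args, tuple_arguments=True)  # old convention: now a TypeError
    except TypeError:
        outputs = executable.Execute(args)  # new convention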
@@ -1109,11 +1109,10 @@ PYBIND11_MODULE(xla_extension, m) {
       .def(
           "Execute",
           [](const PyLocalExecutable& executable,
-             absl::Span<PyLocalBuffer* const> args, bool tuple_arguments)
+             absl::Span<PyLocalBuffer* const> args)
               -> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
             py::gil_scoped_release gil_release;
             ExecuteOptions options;
-            options.tuple_arguments = tuple_arguments;
             options.untuple_result = true;
             TF_ASSIGN_OR_RETURN(
                 std::vector<std::unique_ptr<PyLocalBuffer>> output_buffers,
@@ -1126,17 +1125,15 @@ PYBIND11_MODULE(xla_extension, m) {
             }
             return outputs;
           },
-          py::arg("arguments"), py::arg("tuple_arguments") = false)
+          py::arg("arguments"))
       .def(
           "ExecuteOnLocalDevices",
           [](const PyLocalExecutable& executable,
-             absl::Span<const std::vector<PyLocalBuffer*>> args,
-             bool tuple_arguments)
+             absl::Span<const std::vector<PyLocalBuffer*>> args)
               -> StatusOr<
                   std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>>> {
             py::gil_scoped_release gil_release;
             ExecuteOptions options;
-            options.tuple_arguments = tuple_arguments;
             options.untuple_result = true;
             TF_ASSIGN_OR_RETURN(
                 std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
@@ -1154,7 +1151,7 @@ PYBIND11_MODULE(xla_extension, m) {
             }
             return outputs;
           },
-          py::arg("arguments"), py::arg("tuple_arguments") = false)
+          py::arg("arguments"))
       .def(
           "get_hlo_modules",
           [](const PyLocalExecutable& executable)
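Both lambdas above still hard-code options.untuple_result = true, so from Python the result of Execute remains a flat list with one buffer per output element. A consumption sketch (names assumed):

    outputs = executable.Execute(args)  # list of output buffers
    values = [buf.to_py() for buf in outputs]  # to_py() as used in xla_client below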
@@ -615,7 +615,7 @@ def execute_with_python_values(executable, arguments=(), backend=None):
         arg, device=executable.local_devices()[0], backend=backend)
 
   arguments = [put(arg) for arg in arguments]
-  outputs = executable.Execute(arguments, tuple_arguments=False)
+  outputs = executable.Execute(arguments)
   return [x.to_py() for x in outputs]
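A usage sketch for this helper, assuming an already-compiled single-replica executable and a backend object (the argument value is hypothetical):

    import numpy as np

    results = execute_with_python_values(
        executable, arguments=[np.float32(1.0)], backend=backend)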
@@ -644,8 +644,9 @@ def execute_with_python_values_replicated(executable, arguments, backend=None):
   for replica_args in arguments:
     arg_buffers.append(flat_arg_buffers[:len(replica_args)])
     flat_arg_buffers = flat_arg_buffers[len(replica_args):]
-  return [[x.to_py() for x in xs] for xs in executable.ExecuteOnLocalDevices(
-      arg_buffers, tuple_arguments=False)]
+  return [[x.to_py()
+           for x in xs]
+          for xs in executable.ExecuteOnLocalDevices(arg_buffers)]
 
 
 class PaddingType(enum.Enum):
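A matching sketch for the replicated helper, with one hypothetical argument list per local device:

    import numpy as np

    per_device_args = [[np.float32(1.0)], [np.float32(2.0)]]
    results = execute_with_python_values_replicated(
        executable, per_device_args, backend=backend)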