[XLA:Python] Add an option to Compile() that determines whether to tuple arguments.

This change prepares for moving tuple_arguments from Execute() to Compile(), which is preferable because it makes it easier to reason about aliasing properties at compile time.
During the transition, arguments are tupled if tupling was requested at either compile time or run time (see the usage sketch below).

PiperOrigin-RevId: 303807478
Change-Id: Ifa0b4a74bb1aaf422fc618727e45cb3ba46ee707
Author: Peter Hawkins (2020-03-30 13:13:15 -07:00), committed by TensorFlower Gardener
Commit: 433d40514b, parent 763b3b8ee4
7 changed files with 29 additions and 20 deletions
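As an illustration of the intended usage (an editorial sketch, not part of this commit): the snippet below assumes the xla_client bindings from tensorflow/compiler/xla/python and that Computation.Compile() accepts a compile_options keyword in this era's API; the computation construction is elided. Only CompileOptions.tuple_arguments and the now-optional tuple_arguments flag on Execute() come from this change.

    # Minimal usage sketch (illustrative; assumes this era's xla_client API).
    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    computation = ...  # an xla_client.Computation, built as usual (elided here)

    options = xla_client.CompileOptions()
    options.tuple_arguments = True  # new: request argument tupling at compile time
    compiled = computation.Compile(compile_options=options)  # compile_options keyword assumed

    arg = xla_client.Buffer.from_pyval(np.float32(41.0))
    # During the transition, arguments are tupled if the flag was set at either
    # compile time or run time; the runtime flag now defaults to False, so it
    # can simply be omitted (as the updated tests below do).
    result = compiled.Execute([arg])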

@@ -503,9 +503,10 @@ static std::shared_ptr<Device> LookupDevice(const PyTpuClient& client,
 PyTpuExecutable::PyTpuExecutable(
     std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
     DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
-    xla::Shape result_shape)
+    xla::Shape result_shape, bool tuple_arguments)
     : client_(std::move(client)),
       device_assignment_(std::move(device_assignment)),
+      tuple_arguments_(tuple_arguments),
       result_shape_(std::move(result_shape)) {
   VLOG(1) << "DeviceAssignment. " << device_assignment_.ToString();
   const int num_replicas = device_assignment_.replica_count();
@@ -627,7 +628,7 @@ StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
   std::vector<PyTpuBuffer*> all_core_arguments;
   std::unique_ptr<PyTpuBuffer> tupled_arguments;
-  if (tuple_arguments) {
+  if (tuple_arguments_ || tuple_arguments) {
     TF_ASSIGN_OR_RETURN(tupled_arguments,
                         PyTpuBuffer::MakeTuple(argument_handles, client_,
                                                local_devices_.front()));
@@ -679,7 +680,7 @@ PyTpuExecutable::ExecuteOnLocalDevices(
   std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
   std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
-  if (tuple_arguments) {
+  if (tuple_arguments_ || tuple_arguments) {
     tupled_arguments.resize(argument_handles.size());
     tupled_argument_pointers.resize(argument_handles.size());
     for (int i = 0; i < num_local_devices; ++i) {
@@ -750,7 +751,7 @@ PyTpuExecutable::ExecuteOnLocalDevices(
     absl::optional<std::vector<Shape>> argument_layouts,
     const ExecutableBuildOptions* build_options,
     std::shared_ptr<PyTpuClient> client,
-    absl::optional<DeviceAssignment> device_assignment) {
+    absl::optional<DeviceAssignment> device_assignment, bool tuple_arguments) {
   tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Compile");
   VLOG(1) << "Compile: "
@@ -814,7 +815,7 @@ PyTpuExecutable::ExecuteOnLocalDevices(
   return absl::make_unique<PyTpuExecutable>(
       std::move(compiled_program), std::move(*device_assignment),
-      std::move(client), std::move(result_layout));
+      std::move(client), std::move(result_layout), tuple_arguments);
 }
 
 }  // namespace xla

@@ -268,12 +268,12 @@ class PyTpuExecutable {
       absl::optional<std::vector<Shape>> argument_layouts,
       const ExecutableBuildOptions* build_options,
       std::shared_ptr<PyTpuClient> client,
-      absl::optional<DeviceAssignment> device_assignment);
+      absl::optional<DeviceAssignment> device_assignment, bool tuple_arguments);
 
   PyTpuExecutable(
       std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
       DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
-      xla::Shape result_shape);
+      xla::Shape result_shape, bool tuple_arguments);
 
   virtual ~PyTpuExecutable() {
     for (auto it = executables_.begin(); it != executables_.end(); ++it) {
       client_->driver()->UnloadProgram(std::move(it->second), {});
@@ -336,6 +336,7 @@ class PyTpuExecutable {
   std::shared_ptr<PyTpuClient> const client_;
   std::map<int, std::unique_ptr<tpu_driver::LoadedProgramHandle>> executables_;
   const DeviceAssignment device_assignment_;
+  const bool tuple_arguments_;
 
   // The replica and partition indices of device_assignment_ to be run by this
   // client. On single-host platforms without partitioning, this is all replicas

@@ -100,7 +100,8 @@ class TpuBackend(xla_client.Backend):
     return _tpu_client.TpuExecutable.Compile(c_computation,
                                              compile_options.argument_layouts,
                                              options, self.client,
-                                             compile_options.device_assignment)
+                                             compile_options.device_assignment,
+                                             compile_options.tuple_arguments)
 
   def get_default_device_assignment(self, num_replicas, num_partitions=None):
     if num_partitions is not None:

@@ -167,7 +167,8 @@ PYBIND11_MODULE(tpu_client_extension, m) {
              const ExecutableBuildOptions* build_options,
              std::shared_ptr<PyTpuClient> client,
              absl::optional<std::vector<std::vector<Device*>>>
-                 device_assignment)
+                 device_assignment,
+             bool tuple_arguments)
               -> StatusOr<std::unique_ptr<PyTpuExecutable>> {
             py::gil_scoped_release gil_release;
             absl::optional<DeviceAssignment> xla_device_assignment;
@@ -178,7 +179,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
             }
             return PyTpuExecutable::Compile(
                 computation, argument_layouts, build_options, client,
-                std::move(xla_device_assignment));
+                std::move(xla_device_assignment), tuple_arguments);
           })
       .def("local_logical_device_ids",
            &PyTpuExecutable::local_logical_device_ids)
@@ -188,10 +189,10 @@ PYBIND11_MODULE(tpu_client_extension, m) {
       .def("Delete", &PyTpuExecutable::Delete)
       .def("Execute", &PyTpuExecutable::Execute,
            py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
-           py::arg("tuple_arguments"))
+           py::arg("tuple_arguments") = false)
       .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices,
            py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
-           py::arg("tuple_arguments"));
+           py::arg("tuple_arguments") = false);
 
   py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
       .def_property_readonly("coords", &TpuDevice::coords)

@@ -1040,7 +1040,8 @@ PYBIND11_MODULE(xla_extension, m) {
              absl::optional<std::vector<Shape>> argument_layouts,
              const ExecutableBuildOptions* build_options,
              std::shared_ptr<PyLocalClient> client,
-             absl::optional<DeviceAssignment> device_assignment)
+             absl::optional<DeviceAssignment> device_assignment,
+             bool tuple_arguments)
               -> StatusOr<ClientAndUniquePtr<PyLocalExecutable>> {
             py::gil_scoped_release gil_release;
             CompileOptions options;
@@ -1048,6 +1049,7 @@ PYBIND11_MODULE(xla_extension, m) {
             if (build_options) {
               options.executable_build_options = *build_options;
             }
+            options.tuple_arguments = tuple_arguments;
             if (device_assignment) {
               options.executable_build_options.set_device_assignment(
                   *device_assignment);
@@ -1065,7 +1067,8 @@ PYBIND11_MODULE(xla_extension, m) {
              const ExecutableBuildOptions* build_options,
              std::shared_ptr<PyLocalClient> client,
              absl::optional<std::vector<std::vector<Device*>>>
-                 device_assignment)
+                 device_assignment,
+             bool tuple_arguments)
               -> StatusOr<ClientAndUniquePtr<PyLocalExecutable>> {
             py::gil_scoped_release gil_release;
             CompileOptions options;
@@ -1073,6 +1076,7 @@ PYBIND11_MODULE(xla_extension, m) {
             if (build_options) {
               options.executable_build_options = *build_options;
             }
+            options.tuple_arguments = tuple_arguments;
             if (device_assignment) {
               TF_ASSIGN_OR_RETURN(
                   DeviceAssignment xla_assignment,
@@ -1122,7 +1126,7 @@ PYBIND11_MODULE(xla_extension, m) {
             }
             return outputs;
           },
-          py::arg("arguments"), py::arg("tuple_arguments"))
+          py::arg("arguments"), py::arg("tuple_arguments") = false)
       .def(
          "ExecuteOnLocalDevices",
          [](const PyLocalExecutable& executable,
@@ -1150,7 +1154,7 @@ PYBIND11_MODULE(xla_extension, m) {
             }
             return outputs;
           },
-          py::arg("arguments"), py::arg("tuple_arguments"))
+          py::arg("arguments"), py::arg("tuple_arguments") = false)
       .def(
           "get_hlo_modules",
           [](const PyLocalExecutable& executable)
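A note on the binding change above (an editorial illustration, not part of the diff): because py::arg("tuple_arguments") now defaults to false, the runtime flag is optional from Python, so both calling forms below remain valid during the transition, assuming a compiled executable `compiled` and a list of argument buffers `args`:

    compiled.Execute(args)                        # tupling governed by the compile-time option
    compiled.Execute(args, tuple_arguments=True)  # still forces tupling at run time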

@@ -147,7 +147,8 @@ class LocalBackend(Backend):
     return _xla.LocalExecutable.Compile(c_computation,
                                         compile_options.argument_layouts,
                                         options, self.client,
-                                        compile_options.device_assignment)
+                                        compile_options.device_assignment,
+                                        compile_options.tuple_arguments)
 
   def get_default_device_assignment(self, num_replicas, num_partitions=None):
     if num_partitions is not None:
@@ -504,6 +505,7 @@ class CompileOptions(object):
     self.argument_layouts = None
     self.result_layout = None
     self.device_assignment = None
+    self.tuple_arguments = False
 
 
 class Computation(object):

@@ -494,7 +494,7 @@ class BufferTest(ComputationTest):
     arg_buffer = xla_client.Buffer.from_pyval(arg)
     arg_buffer.delete()
     with self.assertRaises(RuntimeError):
-      compiled_c.Execute([arg_buffer], tuple_arguments=False)
+      compiled_c.Execute([arg_buffer])
 
   def testShape(self):
     pyval = np.array([[1., 2.]], np.float32)
@@ -1903,8 +1903,7 @@ class EmbeddedComputationsTest(ComputationTest):
     compiled_c = c.Build().Compile()
 
     for want in to_round_trip:
-      execution = threading.Thread(
-          target=lambda: compiled_c.Execute([], tuple_arguments=False))
+      execution = threading.Thread(target=lambda: compiled_c.Execute([]))
       execution.start()
       xla_client.transfer_to_infeed(want)
       got = xla_client.transfer_from_outfeed(