[XLA:Python] Add an option to Compile() that determines whether to tuple arguments.

This change is in preparation for moving tuple_arguments from Execute() to Compile(), which is preferable because it makes it easier to reason about aliasing properties at Compile() time.
During the transition, we tuple arguments if it was requested at either compile-time or runtime.
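
As an illustrative sketch (not part of this change), the transitional behavior can be exercised from Python roughly as follows; `backend`, `c_computation`, and `args` are hypothetical placeholders for a backend object, a built XLA computation, and a list of argument buffers:

    # Sketch only: `backend`, `c_computation`, and `args` are placeholders.
    from tensorflow.compiler.xla.python import xla_client  # import path at the time of this change

    options = xla_client.CompileOptions()
    options.tuple_arguments = True                    # request tupling at compile time
    executable = backend.compile(c_computation, options)

    # Tupling can also still be requested at run time during the transition;
    # tuple_arguments now defaults to False, so it can be omitted otherwise.
    outputs = executable.Execute(args, tuple_arguments=True)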

PiperOrigin-RevId: 303807478
Change-Id: Ifa0b4a74bb1aaf422fc618727e45cb3ba46ee707
Peter Hawkins 2020-03-30 13:13:15 -07:00 committed by TensorFlower Gardener
parent 763b3b8ee4
commit 433d40514b
7 changed files with 29 additions and 20 deletions


@@ -503,9 +503,10 @@ static std::shared_ptr<Device> LookupDevice(const PyTpuClient& client,
 PyTpuExecutable::PyTpuExecutable(
     std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
     DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
-    xla::Shape result_shape)
+    xla::Shape result_shape, bool tuple_arguments)
     : client_(std::move(client)),
       device_assignment_(std::move(device_assignment)),
+      tuple_arguments_(tuple_arguments),
       result_shape_(std::move(result_shape)) {
   VLOG(1) << "DeviceAssignment. " << device_assignment_.ToString();
   const int num_replicas = device_assignment_.replica_count();
@@ -627,7 +628,7 @@ StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
   std::vector<PyTpuBuffer*> all_core_arguments;
   std::unique_ptr<PyTpuBuffer> tupled_arguments;
-  if (tuple_arguments) {
+  if (tuple_arguments_ || tuple_arguments) {
     TF_ASSIGN_OR_RETURN(tupled_arguments,
                         PyTpuBuffer::MakeTuple(argument_handles, client_,
                                                local_devices_.front()));
@@ -679,7 +680,7 @@ PyTpuExecutable::ExecuteOnLocalDevices(
   std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
   std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
-  if (tuple_arguments) {
+  if (tuple_arguments_ || tuple_arguments) {
     tupled_arguments.resize(argument_handles.size());
     tupled_argument_pointers.resize(argument_handles.size());
     for (int i = 0; i < num_local_devices; ++i) {
@@ -750,7 +751,7 @@ PyTpuExecutable::ExecuteOnLocalDevices(
     absl::optional<std::vector<Shape>> argument_layouts,
     const ExecutableBuildOptions* build_options,
     std::shared_ptr<PyTpuClient> client,
-    absl::optional<DeviceAssignment> device_assignment) {
+    absl::optional<DeviceAssignment> device_assignment, bool tuple_arguments) {
   tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Compile");
   VLOG(1) << "Compile: "
@@ -814,7 +815,7 @@ PyTpuExecutable::ExecuteOnLocalDevices(
   return absl::make_unique<PyTpuExecutable>(
       std::move(compiled_program), std::move(*device_assignment),
-      std::move(client), std::move(result_layout));
+      std::move(client), std::move(result_layout), tuple_arguments);
 }
 }  // namespace xla


@@ -268,12 +268,12 @@ class PyTpuExecutable {
       absl::optional<std::vector<Shape>> argument_layouts,
       const ExecutableBuildOptions* build_options,
       std::shared_ptr<PyTpuClient> client,
-      absl::optional<DeviceAssignment> device_assignment);
+      absl::optional<DeviceAssignment> device_assignment, bool tuple_arguments);
   PyTpuExecutable(
       std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
       DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
-      xla::Shape result_shape);
+      xla::Shape result_shape, bool tuple_arguments);
   virtual ~PyTpuExecutable() {
     for (auto it = executables_.begin(); it != executables_.end(); ++it) {
       client_->driver()->UnloadProgram(std::move(it->second), {});
@@ -336,6 +336,7 @@ class PyTpuExecutable {
   std::shared_ptr<PyTpuClient> const client_;
   std::map<int, std::unique_ptr<tpu_driver::LoadedProgramHandle>> executables_;
   const DeviceAssignment device_assignment_;
+  const bool tuple_arguments_;
   // The replica and partition indices of device_assignment_ to be run by this
   // client. On single-host platforms without partitioning, this is all replicas


@@ -100,7 +100,8 @@ class TpuBackend(xla_client.Backend):
     return _tpu_client.TpuExecutable.Compile(c_computation,
                                              compile_options.argument_layouts,
                                              options, self.client,
-                                             compile_options.device_assignment)
+                                             compile_options.device_assignment,
+                                             compile_options.tuple_arguments)
   def get_default_device_assignment(self, num_replicas, num_partitions=None):
     if num_partitions is not None:


@@ -167,7 +167,8 @@ PYBIND11_MODULE(tpu_client_extension, m) {
              const ExecutableBuildOptions* build_options,
              std::shared_ptr<PyTpuClient> client,
              absl::optional<std::vector<std::vector<Device*>>>
-                 device_assignment)
+                 device_assignment,
+             bool tuple_arguments)
              -> StatusOr<std::unique_ptr<PyTpuExecutable>> {
            py::gil_scoped_release gil_release;
            absl::optional<DeviceAssignment> xla_device_assignment;
@@ -178,7 +179,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
            }
            return PyTpuExecutable::Compile(
                computation, argument_layouts, build_options, client,
-               std::move(xla_device_assignment));
+               std::move(xla_device_assignment), tuple_arguments);
          })
      .def("local_logical_device_ids",
           &PyTpuExecutable::local_logical_device_ids)
@@ -188,10 +189,10 @@ PYBIND11_MODULE(tpu_client_extension, m) {
      .def("Delete", &PyTpuExecutable::Delete)
      .def("Execute", &PyTpuExecutable::Execute,
           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
-          py::arg("tuple_arguments"))
+          py::arg("tuple_arguments") = false)
      .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices,
           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
-          py::arg("tuple_arguments"));
+          py::arg("tuple_arguments") = false);
  py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
      .def_property_readonly("coords", &TpuDevice::coords)


@@ -1040,7 +1040,8 @@ PYBIND11_MODULE(xla_extension, m) {
              absl::optional<std::vector<Shape>> argument_layouts,
              const ExecutableBuildOptions* build_options,
              std::shared_ptr<PyLocalClient> client,
-             absl::optional<DeviceAssignment> device_assignment)
+             absl::optional<DeviceAssignment> device_assignment,
+             bool tuple_arguments)
              -> StatusOr<ClientAndUniquePtr<PyLocalExecutable>> {
            py::gil_scoped_release gil_release;
            CompileOptions options;
@@ -1048,6 +1049,7 @@ PYBIND11_MODULE(xla_extension, m) {
            if (build_options) {
              options.executable_build_options = *build_options;
            }
+           options.tuple_arguments = tuple_arguments;
            if (device_assignment) {
              options.executable_build_options.set_device_assignment(
                  *device_assignment);
@@ -1065,7 +1067,8 @@ PYBIND11_MODULE(xla_extension, m) {
              const ExecutableBuildOptions* build_options,
              std::shared_ptr<PyLocalClient> client,
              absl::optional<std::vector<std::vector<Device*>>>
-                 device_assignment)
+                 device_assignment,
+             bool tuple_arguments)
              -> StatusOr<ClientAndUniquePtr<PyLocalExecutable>> {
            py::gil_scoped_release gil_release;
            CompileOptions options;
@@ -1073,6 +1076,7 @@ PYBIND11_MODULE(xla_extension, m) {
            if (build_options) {
              options.executable_build_options = *build_options;
            }
+           options.tuple_arguments = tuple_arguments;
            if (device_assignment) {
              TF_ASSIGN_OR_RETURN(
                  DeviceAssignment xla_assignment,
@@ -1122,7 +1126,7 @@ PYBIND11_MODULE(xla_extension, m) {
            }
            return outputs;
          },
-         py::arg("arguments"), py::arg("tuple_arguments"))
+         py::arg("arguments"), py::arg("tuple_arguments") = false)
      .def(
          "ExecuteOnLocalDevices",
          [](const PyLocalExecutable& executable,
@@ -1150,7 +1154,7 @@ PYBIND11_MODULE(xla_extension, m) {
            }
            return outputs;
          },
-         py::arg("arguments"), py::arg("tuple_arguments"))
+         py::arg("arguments"), py::arg("tuple_arguments") = false)
      .def(
          "get_hlo_modules",
          [](const PyLocalExecutable& executable)


@@ -147,7 +147,8 @@ class LocalBackend(Backend):
     return _xla.LocalExecutable.Compile(c_computation,
                                         compile_options.argument_layouts,
                                         options, self.client,
-                                        compile_options.device_assignment)
+                                        compile_options.device_assignment,
+                                        compile_options.tuple_arguments)
   def get_default_device_assignment(self, num_replicas, num_partitions=None):
     if num_partitions is not None:
@@ -504,6 +505,7 @@ class CompileOptions(object):
     self.argument_layouts = None
     self.result_layout = None
     self.device_assignment = None
+    self.tuple_arguments = False
 class Computation(object):


@@ -494,7 +494,7 @@ class BufferTest(ComputationTest):
     arg_buffer = xla_client.Buffer.from_pyval(arg)
     arg_buffer.delete()
     with self.assertRaises(RuntimeError):
-      compiled_c.Execute([arg_buffer], tuple_arguments=False)
+      compiled_c.Execute([arg_buffer])
   def testShape(self):
     pyval = np.array([[1., 2.]], np.float32)
@@ -1903,8 +1903,7 @@ class EmbeddedComputationsTest(ComputationTest):
     compiled_c = c.Build().Compile()
     for want in to_round_trip:
-      execution = threading.Thread(
-          target=lambda: compiled_c.Execute([], tuple_arguments=False))
+      execution = threading.Thread(target=lambda: compiled_c.Execute([]))
       execution.start()
       xla_client.transfer_to_infeed(want)
       got = xla_client.transfer_from_outfeed(