[XLA:Python] Add an option to Compile() that determines whether to tuple arguments.
This change prepares for moving tuple_arguments from Execute() to Compile(), which is preferable because it makes it easier to reason about aliasing properties in Compile(). During the transition, arguments are tupled if tupling was requested at either compile time or run time.

PiperOrigin-RevId: 303807478
Change-Id: Ifa0b4a74bb1aaf422fc618727e45cb3ba46ee707
commit 433d40514b
parent 763b3b8ee4
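As a rough usage sketch (not part of this change): with the new option, tupling can be requested once at compile time via CompileOptions rather than on every Execute() call. The get_local_backend() helper, the `computation` object, and the `arg_buffers` list below are illustrative assumptions, not code added by this commit.

    options = xla_client.CompileOptions()
    options.tuple_arguments = True  # new field added by this change; defaults to False
    backend = xla_client.get_local_backend()  # assumed way to obtain a LocalBackend
    executable = backend.compile(computation, options)  # 'computation' is a placeholder XlaComputation
    # tuple_arguments is now optional on Execute(); it defaults to False, and during
    # the transition arguments are tupled if either the compile-time or the
    # run-time flag was set.
    outputs = executable.Execute(arg_buffers)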
@@ -503,9 +503,10 @@ static std::shared_ptr<Device> LookupDevice(const PyTpuClient& client,
 PyTpuExecutable::PyTpuExecutable(
     std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
     DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
-    xla::Shape result_shape)
+    xla::Shape result_shape, bool tuple_arguments)
     : client_(std::move(client)),
       device_assignment_(std::move(device_assignment)),
+      tuple_arguments_(tuple_arguments),
       result_shape_(std::move(result_shape)) {
   VLOG(1) << "DeviceAssignment. " << device_assignment_.ToString();
   const int num_replicas = device_assignment_.replica_count();
@@ -627,7 +628,7 @@ StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
   std::vector<PyTpuBuffer*> all_core_arguments;

   std::unique_ptr<PyTpuBuffer> tupled_arguments;
-  if (tuple_arguments) {
+  if (tuple_arguments_ || tuple_arguments) {
     TF_ASSIGN_OR_RETURN(tupled_arguments,
                         PyTpuBuffer::MakeTuple(argument_handles, client_,
                                                local_devices_.front()));
@@ -679,7 +680,7 @@ PyTpuExecutable::ExecuteOnLocalDevices(

   std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
   std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
-  if (tuple_arguments) {
+  if (tuple_arguments_ || tuple_arguments) {
     tupled_arguments.resize(argument_handles.size());
     tupled_argument_pointers.resize(argument_handles.size());
     for (int i = 0; i < num_local_devices; ++i) {
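The two hunks above implement the transitional rule from the commit message: tupling happens if it was requested at either compile time or run time. A minimal Python sketch of that decision, using placeholder names rather than the real members:

    # Placeholder names: 'compiled_with_tupling' stands for the new tuple_arguments_
    # member, 'execute_tuple_arguments' for the existing per-call Execute() flag.
    def should_tuple(compiled_with_tupling, execute_tuple_arguments):
      # During the transition, tupling is applied if requested at either stage.
      return compiled_with_tupling or execute_tuple_arguments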
@@ -750,7 +751,7 @@ PyTpuExecutable::ExecuteOnLocalDevices(
     absl::optional<std::vector<Shape>> argument_layouts,
     const ExecutableBuildOptions* build_options,
     std::shared_ptr<PyTpuClient> client,
-    absl::optional<DeviceAssignment> device_assignment) {
+    absl::optional<DeviceAssignment> device_assignment, bool tuple_arguments) {
   tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Compile");

   VLOG(1) << "Compile: "
@@ -814,7 +815,7 @@ PyTpuExecutable::ExecuteOnLocalDevices(

   return absl::make_unique<PyTpuExecutable>(
       std::move(compiled_program), std::move(*device_assignment),
-      std::move(client), std::move(result_layout));
+      std::move(client), std::move(result_layout), tuple_arguments);
 }

 }  // namespace xla
@@ -268,12 +268,12 @@ class PyTpuExecutable {
       absl::optional<std::vector<Shape>> argument_layouts,
       const ExecutableBuildOptions* build_options,
       std::shared_ptr<PyTpuClient> client,
-      absl::optional<DeviceAssignment> device_assignment);
+      absl::optional<DeviceAssignment> device_assignment, bool tuple_arguments);

   PyTpuExecutable(
       std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
       DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
-      xla::Shape result_shape);
+      xla::Shape result_shape, bool tuple_arguments);
   virtual ~PyTpuExecutable() {
     for (auto it = executables_.begin(); it != executables_.end(); ++it) {
       client_->driver()->UnloadProgram(std::move(it->second), {});
@@ -336,6 +336,7 @@ class PyTpuExecutable {
   std::shared_ptr<PyTpuClient> const client_;
   std::map<int, std::unique_ptr<tpu_driver::LoadedProgramHandle>> executables_;
   const DeviceAssignment device_assignment_;
+  const bool tuple_arguments_;

   // The replica and partition indices of device_assignment_ to be run by this
   // client. On single-host platforms without partitioning, this is all replicas
@@ -100,7 +100,8 @@ class TpuBackend(xla_client.Backend):
     return _tpu_client.TpuExecutable.Compile(c_computation,
                                              compile_options.argument_layouts,
                                              options, self.client,
-                                             compile_options.device_assignment)
+                                             compile_options.device_assignment,
+                                             compile_options.tuple_arguments)

   def get_default_device_assignment(self, num_replicas, num_partitions=None):
     if num_partitions is not None:
@@ -167,7 +167,8 @@ PYBIND11_MODULE(tpu_client_extension, m) {
             const ExecutableBuildOptions* build_options,
             std::shared_ptr<PyTpuClient> client,
             absl::optional<std::vector<std::vector<Device*>>>
-                device_assignment)
+                device_assignment,
+            bool tuple_arguments)
              -> StatusOr<std::unique_ptr<PyTpuExecutable>> {
            py::gil_scoped_release gil_release;
            absl::optional<DeviceAssignment> xla_device_assignment;
@@ -178,7 +179,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
            }
            return PyTpuExecutable::Compile(
                computation, argument_layouts, build_options, client,
-               std::move(xla_device_assignment));
+               std::move(xla_device_assignment), tuple_arguments);
          })
      .def("local_logical_device_ids",
           &PyTpuExecutable::local_logical_device_ids)
@@ -188,10 +189,10 @@ PYBIND11_MODULE(tpu_client_extension, m) {
      .def("Delete", &PyTpuExecutable::Delete)
      .def("Execute", &PyTpuExecutable::Execute,
           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
-          py::arg("tuple_arguments"))
+          py::arg("tuple_arguments") = false)
      .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices,
           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
-          py::arg("tuple_arguments"));
+          py::arg("tuple_arguments") = false);

  py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
      .def_property_readonly("coords", &TpuDevice::coords)
@@ -1040,7 +1040,8 @@ PYBIND11_MODULE(xla_extension, m) {
             absl::optional<std::vector<Shape>> argument_layouts,
             const ExecutableBuildOptions* build_options,
             std::shared_ptr<PyLocalClient> client,
-            absl::optional<DeviceAssignment> device_assignment)
+            absl::optional<DeviceAssignment> device_assignment,
+            bool tuple_arguments)
              -> StatusOr<ClientAndUniquePtr<PyLocalExecutable>> {
            py::gil_scoped_release gil_release;
            CompileOptions options;
@@ -1048,6 +1049,7 @@ PYBIND11_MODULE(xla_extension, m) {
            if (build_options) {
              options.executable_build_options = *build_options;
            }
+           options.tuple_arguments = tuple_arguments;
            if (device_assignment) {
              options.executable_build_options.set_device_assignment(
                  *device_assignment);
@@ -1065,7 +1067,8 @@ PYBIND11_MODULE(xla_extension, m) {
             const ExecutableBuildOptions* build_options,
             std::shared_ptr<PyLocalClient> client,
             absl::optional<std::vector<std::vector<Device*>>>
-                device_assignment)
+                device_assignment,
+            bool tuple_arguments)
              -> StatusOr<ClientAndUniquePtr<PyLocalExecutable>> {
            py::gil_scoped_release gil_release;
            CompileOptions options;
@@ -1073,6 +1076,7 @@ PYBIND11_MODULE(xla_extension, m) {
            if (build_options) {
              options.executable_build_options = *build_options;
            }
+           options.tuple_arguments = tuple_arguments;
            if (device_assignment) {
              TF_ASSIGN_OR_RETURN(
                  DeviceAssignment xla_assignment,
@@ -1122,7 +1126,7 @@ PYBIND11_MODULE(xla_extension, m) {
            }
            return outputs;
          },
-         py::arg("arguments"), py::arg("tuple_arguments"))
+         py::arg("arguments"), py::arg("tuple_arguments") = false)
      .def(
          "ExecuteOnLocalDevices",
          [](const PyLocalExecutable& executable,
@@ -1150,7 +1154,7 @@ PYBIND11_MODULE(xla_extension, m) {
            }
            return outputs;
          },
-         py::arg("arguments"), py::arg("tuple_arguments"))
+         py::arg("arguments"), py::arg("tuple_arguments") = false)
      .def(
          "get_hlo_modules",
          [](const PyLocalExecutable& executable)
@@ -147,7 +147,8 @@ class LocalBackend(Backend):
     return _xla.LocalExecutable.Compile(c_computation,
                                         compile_options.argument_layouts,
                                         options, self.client,
-                                        compile_options.device_assignment)
+                                        compile_options.device_assignment,
+                                        compile_options.tuple_arguments)

   def get_default_device_assignment(self, num_replicas, num_partitions=None):
     if num_partitions is not None:
@@ -504,6 +505,7 @@ class CompileOptions(object):
     self.argument_layouts = None
     self.result_layout = None
     self.device_assignment = None
+    self.tuple_arguments = False


 class Computation(object):
@@ -494,7 +494,7 @@ class BufferTest(ComputationTest):
     arg_buffer = xla_client.Buffer.from_pyval(arg)
     arg_buffer.delete()
     with self.assertRaises(RuntimeError):
-      compiled_c.Execute([arg_buffer], tuple_arguments=False)
+      compiled_c.Execute([arg_buffer])

   def testShape(self):
     pyval = np.array([[1., 2.]], np.float32)
@@ -1903,8 +1903,7 @@ class EmbeddedComputationsTest(ComputationTest):
     compiled_c = c.Build().Compile()

     for want in to_round_trip:
-      execution = threading.Thread(
-          target=lambda: compiled_c.Execute([], tuple_arguments=False))
+      execution = threading.Thread(target=lambda: compiled_c.Execute([]))
       execution.start()
       xla_client.transfer_to_infeed(want)
       got = xla_client.transfer_from_outfeed(