Add compile-time option to tuple PyLocalExecutable arguments. For now

execute options also includes the flag and arguments are tupled if
either is set. Once the Python callers switch to using the compile
time option the executable option can be retired.

PiperOrigin-RevId: 303376062
Change-Id: I5c10e2997e4ab22013ccec7bb647c7c995f52406
This commit is contained in:
A. Unique TensorFlower 2020-03-27 11:54:06 -07:00 committed by TensorFlower Gardener
parent ad101dc7a8
commit d2b90c81c1
2 changed files with 17 additions and 6 deletions

View File

@ -667,10 +667,11 @@ static Device* LookupDevice(const PyLocalClient& client, int device_id) {
PyLocalExecutable::PyLocalExecutable(
std::vector<std::unique_ptr<LocalExecutable>> executables,
DeviceAssignment device_assignment, PyLocalClient* client)
bool tuple_arguments, DeviceAssignment device_assignment,
PyLocalClient* client)
: client_(client),
device_assignment_(
std::make_shared<DeviceAssignment>(device_assignment)) {
device_assignment_(std::make_shared<DeviceAssignment>(device_assignment)),
tuple_arguments_(tuple_arguments) {
executables_.reserve(executables.size());
for (auto& executable : executables) {
executables_.emplace_back(std::move(executable));
@ -727,7 +728,7 @@ PyLocalExecutable::ExecuteHelper(
std::unique_ptr<PyLocalBuffer> tuple_buffer;
std::vector<PyLocalBuffer*> tupled_arguments;
if (options.tuple_arguments) {
if (options.tuple_arguments || tuple_arguments_) {
TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple(
argument_handles, client_, device));
tupled_arguments = {tuple_buffer.get()};
@ -1037,7 +1038,8 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
build_options));
return absl::make_unique<PyLocalExecutable>(
std::move(local_executables), build_options.device_assignment(), client);
std::move(local_executables), options.tuple_arguments,
build_options.device_assignment(), client);
}
} // namespace xla

View File

@ -316,6 +316,10 @@ struct CompileOptions {
// The layouts of the arguments that the computation should expect.
absl::optional<std::vector<Shape>> argument_layouts;
// If true, the arguments to the computation will be wrapped in a tuple and
// passed as a single parameter.
bool tuple_arguments = false;
// XLA's compilation time options.
ExecutableBuildOptions executable_build_options;
};
@ -340,7 +344,8 @@ class PyLocalExecutable {
CompileOptions options);
PyLocalExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
DeviceAssignment device_assignment, PyLocalClient* client);
bool tuple_arguments, DeviceAssignment device_assignment,
PyLocalClient* client);
PyLocalClient* client() const { return client_; }
@ -404,6 +409,10 @@ class PyLocalExecutable {
std::vector<std::shared_ptr<LocalExecutable>> executables_;
std::shared_ptr<DeviceAssignment> device_assignment_;
// True if the executables were compiled expecting arguments in a single
// tuple.
const bool tuple_arguments_;
// The replica and partition indices of device_assignment_ to be run by this
// client. On single-host platforms without partitioning, this is all replicas
// (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case