Add compile-time option to tuple PyLocalExecutable arguments. For now
execute options also includes the flag and arguments are tupled if either is set. Once the Python callers switch to using the compile time option the executable option can be retired. PiperOrigin-RevId: 303376062 Change-Id: I5c10e2997e4ab22013ccec7bb647c7c995f52406
This commit is contained in:
parent
ad101dc7a8
commit
d2b90c81c1
@ -667,10 +667,11 @@ static Device* LookupDevice(const PyLocalClient& client, int device_id) {
|
||||
|
||||
PyLocalExecutable::PyLocalExecutable(
|
||||
std::vector<std::unique_ptr<LocalExecutable>> executables,
|
||||
DeviceAssignment device_assignment, PyLocalClient* client)
|
||||
bool tuple_arguments, DeviceAssignment device_assignment,
|
||||
PyLocalClient* client)
|
||||
: client_(client),
|
||||
device_assignment_(
|
||||
std::make_shared<DeviceAssignment>(device_assignment)) {
|
||||
device_assignment_(std::make_shared<DeviceAssignment>(device_assignment)),
|
||||
tuple_arguments_(tuple_arguments) {
|
||||
executables_.reserve(executables.size());
|
||||
for (auto& executable : executables) {
|
||||
executables_.emplace_back(std::move(executable));
|
||||
@ -727,7 +728,7 @@ PyLocalExecutable::ExecuteHelper(
|
||||
|
||||
std::unique_ptr<PyLocalBuffer> tuple_buffer;
|
||||
std::vector<PyLocalBuffer*> tupled_arguments;
|
||||
if (options.tuple_arguments) {
|
||||
if (options.tuple_arguments || tuple_arguments_) {
|
||||
TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple(
|
||||
argument_handles, client_, device));
|
||||
tupled_arguments = {tuple_buffer.get()};
|
||||
@ -1037,7 +1038,8 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
|
||||
build_options));
|
||||
|
||||
return absl::make_unique<PyLocalExecutable>(
|
||||
std::move(local_executables), build_options.device_assignment(), client);
|
||||
std::move(local_executables), options.tuple_arguments,
|
||||
build_options.device_assignment(), client);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -316,6 +316,10 @@ struct CompileOptions {
|
||||
// The layouts of the arguments that the computation should expect.
|
||||
absl::optional<std::vector<Shape>> argument_layouts;
|
||||
|
||||
// If true, the arguments to the computation will be wrapped in a tuple and
|
||||
// passed as a single parameter.
|
||||
bool tuple_arguments = false;
|
||||
|
||||
// XLA's compilation time options.
|
||||
ExecutableBuildOptions executable_build_options;
|
||||
};
|
||||
@ -340,7 +344,8 @@ class PyLocalExecutable {
|
||||
CompileOptions options);
|
||||
|
||||
PyLocalExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
|
||||
DeviceAssignment device_assignment, PyLocalClient* client);
|
||||
bool tuple_arguments, DeviceAssignment device_assignment,
|
||||
PyLocalClient* client);
|
||||
|
||||
PyLocalClient* client() const { return client_; }
|
||||
|
||||
@ -404,6 +409,10 @@ class PyLocalExecutable {
|
||||
std::vector<std::shared_ptr<LocalExecutable>> executables_;
|
||||
std::shared_ptr<DeviceAssignment> device_assignment_;
|
||||
|
||||
// True if the executables were compiled expecting arguments in a single
|
||||
// tuple.
|
||||
const bool tuple_arguments_;
|
||||
|
||||
// The replica and partition indices of device_assignment_ to be run by this
|
||||
// client. On single-host platforms without partitioning, this is all replicas
|
||||
// (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
|
||||
|
Loading…
x
Reference in New Issue
Block a user