[XLA:Python] Fix segfault when using newly-added Compile() overload.

PiperOrigin-RevId: 294512055
Change-Id: Idbb314fab390efd8dc4259b87f3e7e06d1e848a1
This commit is contained in:
Peter Hawkins 2020-02-11 13:50:57 -08:00 committed by TensorFlower Gardener
parent 836719056d
commit 4e5860d00c
2 changed files with 4 additions and 2 deletions

View File

@ -989,7 +989,8 @@ PyLocalExecutable::CompileForDevices(
device_assignment[replica].size(), replica,
device_assignment[0].size());
}
for (int partition = 0; partition < device_assignment.size(); ++partition) {
for (int partition = 0; partition < device_assignment[replica].size();
++partition) {
if (device_assignment[0][0]->platform_name() !=
device_assignment[replica][partition]->platform_name()) {
return InvalidArgument(

View File

@ -745,7 +745,8 @@ PyTpuExecutable::CompileForDevices(
device_assignment[replica].size(), replica,
device_assignment[0].size());
}
for (int partition = 0; partition < device_assignment.size(); ++partition) {
for (int partition = 0; partition < device_assignment[replica].size();
++partition) {
if (device_assignment[0][0]->platform_name() !=
device_assignment[replica][partition]->platform_name()) {
return InvalidArgument(