diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index 4c12626f987..ee22cd9ebcc 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -989,7 +989,8 @@ PyLocalExecutable::CompileForDevices( device_assignment[replica].size(), replica, device_assignment[0].size()); } - for (int partition = 0; partition < device_assignment.size(); ++partition) { + for (int partition = 0; partition < device_assignment[replica].size(); + ++partition) { if (device_assignment[0][0]->platform_name() != device_assignment[replica][partition]->platform_name()) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index 6682866bbf0..33573c1c8d8 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -745,7 +745,8 @@ PyTpuExecutable::CompileForDevices( device_assignment[replica].size(), replica, device_assignment[0].size()); } - for (int partition = 0; partition < device_assignment.size(); ++partition) { + for (int partition = 0; partition < device_assignment[replica].size(); + ++partition) { if (device_assignment[0][0]->platform_name() != device_assignment[replica][partition]->platform_name()) { return InvalidArgument(