[PJRT] Don't mention replica/partition numbers when reporting errors about computations with only a single local replica.

Should improve error messages in cases like https://github.com/google/jax/issues/5733

PiperOrigin-RevId: 357769664
Change-Id: Iaba7fa5229c54b5acda9d5379676835bf5851af6
This commit is contained in:
Peter Hawkins 2021-02-16 11:44:44 -08:00 committed by TensorFlower Gardener
parent 3af7f6d31f
commit 4feda2aa5e

View File

@ -2029,12 +2029,16 @@ PjRtStreamExecutorExecutable::Execute(
const int partition = addressable_device_logical_ids_[i].partition;
auto& statusor = results[i];
if (!statusor.ok()) {
return AppendStatus(
statusor.status(),
absl::StrFormat("while running replica %d and partition %d of a "
"replicated computation (other "
"replicas may have failed as well).",
replica, partition));
if (num_addressable_devices == 1) {
return statusor.status();
} else {
return AppendStatus(
statusor.status(),
absl::StrFormat("while running replica %d and partition %d of a "
"replicated computation (other "
"replicas may have failed as well).",
replica, partition));
}
}
wrapped_results[i] = std::move(statusor.ValueOrDie());
}