diff --git a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc
index cb1647832f7..8c3fe6df5e6 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc
@@ -54,6 +54,26 @@ class PodEvent : public Event {
   const int64_t operation_id_;
 };
 
+class ErrorEvent : public PodEvent {
+ public:
+  explicit ErrorEvent(PodTpuDriver* driver, int64_t operation_id, Status status)
+      : PodEvent(driver, operation_id) {
+    status_ = status;
+  }
+
+  xla::Status Await() override { return status_; }
+  absl::optional<xla::Status> AwaitWithTimeout(
+      absl::Duration duration) override {
+    return status_;
+  }
+  void AddCallback(std::function<void(Status)> callback) override {
+    callback(status_);
+  }
+
+ private:
+  Status status_;
+};
+
 class CombinedEvent : public PodEvent {
  public:
   explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id,
@@ -543,6 +563,19 @@ class PodTpuDriver : public TpuDriver {
       const xla::DeviceAssignmentProto& device_assignment,
       absl::Span<Event* const> wait_for) override {
     int64_t operation_id = GetOperationId();
+
+    if (device_assignment.replica_count() != 1 &&
+        device_assignment.replica_count() != core_to_driver_id_.size()) {
+      // TODO(frankchn): Remove restriction once we figure out what's wrong.
+      std::string error_msg =
+          absl::StrCat("Programs must be replicated across all ",
+                       core_to_driver_id_.size(), " cores. Program specified ",
+                       device_assignment.replica_count(), " replicas.");
+      LOG(WARNING) << error_msg;
+      return std::make_shared<ErrorEvent>(
+          this, operation_id, tensorflow::errors::InvalidArgument(error_msg));
+    }
+
     auto deps = GetDependencyOperationIds(wait_for);
     deps.insert(static_cast<PodLoadedProgramHandle*>(program)->operation_id());