[pod driver] Add temporary restriction to replicate computation across all cores or not at all

PiperOrigin-RevId: 334933314
Change-Id: I57dc3889af8acc335e592fc4ea37a27c8b960119
This commit is contained in:
Frank Chen 2020-10-01 17:37:11 -07:00 committed by TensorFlower Gardener
parent 8564160d1f
commit d1399ed7cf

View File

@ -54,6 +54,26 @@ class PodEvent : public Event {
const int64_t operation_id_; const int64_t operation_id_;
}; };
// An Event that is already resolved with a fixed status supplied at
// construction. Await(), AwaitWithTimeout() and AddCallback() all report that
// stored status immediately; nothing in this class waits or defers work.
class ErrorEvent : public PodEvent {
 public:
  // `driver` and `operation_id` are forwarded unchanged to PodEvent. `status`
  // is the result that every accessor below reports.
  explicit ErrorEvent(PodTpuDriver* driver, int64_t operation_id, Status status)
      : PodEvent(driver, operation_id), status_(status) {}

  // Returns the stored status.
  xla::Status Await() override { return status_; }

  // Returns the stored status. `duration` is ignored, so the returned
  // optional always holds a value.
  absl::optional<xla::Status> AwaitWithTimeout(
      absl::Duration duration) override {
    return status_;
  }

  // Invokes `callback` synchronously, from within this call, passing the
  // stored status. The callback is not retained.
  void AddCallback(std::function<void(Status)> callback) override {
    callback(status_);
  }

 private:
  // Initialized in the constructor's member initializer list (rather than
  // default-constructed and then assigned) and never modified afterwards.
  Status status_;
};
class CombinedEvent : public PodEvent { class CombinedEvent : public PodEvent {
public: public:
explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id, explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id,
@ -543,6 +563,19 @@ class PodTpuDriver : public TpuDriver {
const xla::DeviceAssignmentProto& device_assignment, const xla::DeviceAssignmentProto& device_assignment,
absl::Span<Event* const> wait_for) override { absl::Span<Event* const> wait_for) override {
int64_t operation_id = GetOperationId(); int64_t operation_id = GetOperationId();
if (device_assignment.replica_count() != 1 &&
device_assignment.replica_count() != core_to_driver_id_.size()) {
// TODO(frankchn): Remove restriction once we figure out what's wrong.
std::string error_msg =
absl::StrCat("Programs must be replicated across all ",
core_to_driver_id_.size(), " cores. Program specified ",
device_assignment.replica_count(), " replicas.");
LOG(WARNING) << error_msg;
return std::make_shared<ErrorEvent>(
this, operation_id, tensorflow::errors::InvalidArgument(error_msg));
}
auto deps = GetDependencyOperationIds(wait_for); auto deps = GetDependencyOperationIds(wait_for);
deps.insert(static_cast<PodLoadedProgramHandle*>(program)->operation_id()); deps.insert(static_cast<PodLoadedProgramHandle*>(program)->operation_id());