[pod driver] Add temporary restriction to replicate computation across all cores or not at all
PiperOrigin-RevId: 334933314 Change-Id: I57dc3889af8acc335e592fc4ea37a27c8b960119
This commit is contained in:
parent
8564160d1f
commit
d1399ed7cf
@ -54,6 +54,26 @@ class PodEvent : public Event {
|
||||
const int64_t operation_id_;
|
||||
};
|
||||
|
||||
class ErrorEvent : public PodEvent {
|
||||
public:
|
||||
explicit ErrorEvent(PodTpuDriver* driver, int64_t operation_id, Status status)
|
||||
: PodEvent(driver, operation_id) {
|
||||
status_ = status;
|
||||
}
|
||||
|
||||
xla::Status Await() override { return status_; }
|
||||
absl::optional<xla::Status> AwaitWithTimeout(
|
||||
absl::Duration duration) override {
|
||||
return status_;
|
||||
}
|
||||
void AddCallback(std::function<void(Status)> callback) override {
|
||||
callback(status_);
|
||||
}
|
||||
|
||||
private:
|
||||
Status status_;
|
||||
};
|
||||
|
||||
class CombinedEvent : public PodEvent {
|
||||
public:
|
||||
explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id,
|
||||
@ -543,6 +563,19 @@ class PodTpuDriver : public TpuDriver {
|
||||
const xla::DeviceAssignmentProto& device_assignment,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
int64_t operation_id = GetOperationId();
|
||||
|
||||
if (device_assignment.replica_count() != 1 &&
|
||||
device_assignment.replica_count() != core_to_driver_id_.size()) {
|
||||
// TODO(frankchn): Remove restriction once we figure out what's wrong.
|
||||
std::string error_msg =
|
||||
absl::StrCat("Programs must be replicated across all ",
|
||||
core_to_driver_id_.size(), " cores. Program specified ",
|
||||
device_assignment.replica_count(), " replicas.");
|
||||
LOG(WARNING) << error_msg;
|
||||
return std::make_shared<ErrorEvent>(
|
||||
this, operation_id, tensorflow::errors::InvalidArgument(error_msg));
|
||||
}
|
||||
|
||||
auto deps = GetDependencyOperationIds(wait_for);
|
||||
deps.insert(static_cast<PodLoadedProgramHandle*>(program)->operation_id());
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user