[pod driver] Add temporary restriction to replicate computation across all cores or not at all
PiperOrigin-RevId: 334933314 Change-Id: I57dc3889af8acc335e592fc4ea37a27c8b960119
This commit is contained in:
parent
8564160d1f
commit
d1399ed7cf
@ -54,6 +54,26 @@ class PodEvent : public Event {
|
|||||||
const int64_t operation_id_;
|
const int64_t operation_id_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ErrorEvent : public PodEvent {
|
||||||
|
public:
|
||||||
|
explicit ErrorEvent(PodTpuDriver* driver, int64_t operation_id, Status status)
|
||||||
|
: PodEvent(driver, operation_id) {
|
||||||
|
status_ = status;
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::Status Await() override { return status_; }
|
||||||
|
absl::optional<xla::Status> AwaitWithTimeout(
|
||||||
|
absl::Duration duration) override {
|
||||||
|
return status_;
|
||||||
|
}
|
||||||
|
void AddCallback(std::function<void(Status)> callback) override {
|
||||||
|
callback(status_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Status status_;
|
||||||
|
};
|
||||||
|
|
||||||
class CombinedEvent : public PodEvent {
|
class CombinedEvent : public PodEvent {
|
||||||
public:
|
public:
|
||||||
explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id,
|
explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id,
|
||||||
@ -543,6 +563,19 @@ class PodTpuDriver : public TpuDriver {
|
|||||||
const xla::DeviceAssignmentProto& device_assignment,
|
const xla::DeviceAssignmentProto& device_assignment,
|
||||||
absl::Span<Event* const> wait_for) override {
|
absl::Span<Event* const> wait_for) override {
|
||||||
int64_t operation_id = GetOperationId();
|
int64_t operation_id = GetOperationId();
|
||||||
|
|
||||||
|
if (device_assignment.replica_count() != 1 &&
|
||||||
|
device_assignment.replica_count() != core_to_driver_id_.size()) {
|
||||||
|
// TODO(frankchn): Remove restriction once we figure out what's wrong.
|
||||||
|
std::string error_msg =
|
||||||
|
absl::StrCat("Programs must be replicated across all ",
|
||||||
|
core_to_driver_id_.size(), " cores. Program specified ",
|
||||||
|
device_assignment.replica_count(), " replicas.");
|
||||||
|
LOG(WARNING) << error_msg;
|
||||||
|
return std::make_shared<ErrorEvent>(
|
||||||
|
this, operation_id, tensorflow::errors::InvalidArgument(error_msg));
|
||||||
|
}
|
||||||
|
|
||||||
auto deps = GetDependencyOperationIds(wait_for);
|
auto deps = GetDependencyOperationIds(wait_for);
|
||||||
deps.insert(static_cast<PodLoadedProgramHandle*>(program)->operation_id());
|
deps.insert(static_cast<PodLoadedProgramHandle*>(program)->operation_id());
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user