[Executor] Reorganize code in ExecutorState::NodeDone() for efficiency.

Executor microbenchmarks show a 3.22% to 4.16% improvement with this change,
which avoids re-checking the status multiple times in the non-error case.

PiperOrigin-RevId: 304719934
Change-Id: I6a9e3d1db8b13f32eb558a57fcb272c07ba1079a

parent d06dd339b7
commit fc2d7fdacb
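The change below restructures NodeDone() so that the status is tested once, on a branch hinted as likely taken, instead of being re-checked at several points further down the function. The following is only a minimal standalone sketch of that pattern, not the TensorFlow implementation; Status, Schedule, NodeDoneSketch, and the PREDICT_TRUE macro are simplified stand-ins for the real types and for TF_PREDICT_TRUE.

// Sketch: handle the common, non-error case on a single likely-taken branch.
#include <iostream>
#include <string>
#include <vector>

#if defined(__GNUC__) || defined(__clang__)
#define PREDICT_TRUE(x) __builtin_expect(!!(x), 1)
#else
#define PREDICT_TRUE(x) (x)
#endif

struct Status {
  std::string error;  // Empty means OK.
  bool ok() const { return error.empty(); }
};

void Schedule(const std::vector<int>& ready) {
  std::cout << "scheduling " << ready.size() << " node(s)\n";
}

// Before the reorganization the status was effectively tested several times
// (recording the error, updating the counter, deciding whether to schedule);
// afterwards a single PREDICT_TRUE(s.ok()) guards the hot path.
void NodeDoneSketch(const Status& s, const std::vector<int>& ready) {
  if (PREDICT_TRUE(s.ok())) {
    Schedule(ready);  // Common case: one status check, then schedule and return.
    return;
  }
  // Error case: all abort/cleanup logic stays off the hot path.
  std::cerr << "aborting step: " << s.error << "\n";
}

int main() {
  NodeDoneSketch(Status{}, {1, 2, 3});
  NodeDoneSketch(Status{"cancelled"}, {});
}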
@@ -316,6 +316,8 @@ class ExecutorState {
   // nodes in 'ready' into 'inline_ready'.
   //
   // This method will clear `*ready` before returning.
+  //
+  // REQUIRES: `!ready->empty()`.
   void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);

   // Clean up when this executor is done.
@@ -1022,73 +1024,80 @@ template <class PropagatorStateType>
 bool ExecutorState<PropagatorStateType>::NodeDone(
     const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats,
     TaggedNodeReadyQueue* inline_ready) {
-  nodestats::SetAllEnd(stats);
   if (stats) {
-    if (stats_collector_) {
-      stats->Done(immutable_state_.params().device->name());
-    } else {
-      delete stats;
-    }
+    nodestats::SetAllEnd(stats);
+    DCHECK_NE(stats_collector_, nullptr);
+    stats->Done(immutable_state_.params().device->name());
   }

-  bool abort_run = false;
-  if (!s.ok()) {
-    // Some error happened. This thread of computation is done.
-    mutex_lock l(mu_);
-    if (status_.ok()) {
-      abort_run = true;
+  if (TF_PREDICT_TRUE(s.ok())) {
+    const size_t ready_size = ready->size();
+    if (ready_size == 0) {
+      return num_outstanding_ops_.fetch_sub(1) == 1;
+    } else {
+      // NOTE: Avoid touching the atomic counter if only one node becomes ready.
+      if (ready_size > 1) {
+        num_outstanding_ops_.fetch_add(ready_size - 1,
+                                       std::memory_order_relaxed);
+      }

-      // If execution has been cancelled, mark any new errors as being derived.
-      // This ensures any errors triggered by cancellation are marked as
-      // derived.
-      if (cancellation_manager_ && cancellation_manager_->IsCancelled()) {
-        status_ = StatusGroup::MakeDerived(s);
-      } else {
-        status_ = s;
-      }
-    }
-  }
-  if (abort_run) {
-    TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
-    if (cancellation_manager_) {
-      // only log when the abort happens during the actual run time.
-      auto device_name = immutable_state_.params().device->name();
-      // Use VLOG instead of LOG(warning) because error status is expected when
-      // the executor is run under the grappler optimization phase or when
-      // iterating through a tf.data input pipeline.
-      VLOG(1) << "[" << device_name << "] Executor start aborting: " << s;
-    }
+      // Schedule the ready nodes in 'ready'.
+      ScheduleReady(ready, inline_ready);
+
+      return false;
+    }
+  } else {
+    bool abort_run = false;
+
+    // Some error happened. This thread of computation is done.
+    {
+      mutex_lock l(mu_);
+      if (status_.ok()) {
+        // If this is the first node to fail in this run, we are responsible for
+        // aborting all other execution in the step.
+        abort_run = true;
+
+        // If execution has been cancelled, mark any new errors as being
+        // derived. This ensures any errors triggered by cancellation are marked
+        // as derived.
+        if (cancellation_manager_ && cancellation_manager_->IsCancelled()) {
+          status_ = StatusGroup::MakeDerived(s);
+        } else {
+          status_ = s;
+        }
+      }
+    }

-    if (rendezvous_) {
-      rendezvous_->StartAbort(s);
-    }
-    if (collective_executor_) {
-      collective_executor_->StartAbort(s);
-    }
-    if (cancellation_manager_) {
-      cancellation_manager_->StartCancel();
-    }
-  }
+    if (abort_run) {
+      TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
+      if (cancellation_manager_) {
+        // Only log when the abort happens during the actual run time.
+        // Use VLOG instead of LOG(warning) because error status is expected
+        // when the executor is run under the grappler optimization phase or
+        // when iterating through a tf.data input pipeline.
+        VLOG(1) << "[" << immutable_state_.params().device->name()
+                << "] Executor start aborting: " << s;
+      }

-  bool completed = false;
-  const size_t ready_size = ready->size();
-  if (ready_size == 0 || !s.ok()) {
-    completed = (num_outstanding_ops_.fetch_sub(1) == 1);
-  } else if (ready_size > 1) {
-    num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed);
-  }
+      if (rendezvous_) {
+        rendezvous_->StartAbort(s);
+      }
+      if (collective_executor_) {
+        collective_executor_->StartAbort(s);
+      }
+      if (cancellation_manager_) {
+        cancellation_manager_->StartCancel();
+      }
+    }

-  // Schedule the ready nodes in 'ready'.
-  if (s.ok()) {
-    ScheduleReady(ready, inline_ready);
-  }
-  return completed;
+    return num_outstanding_ops_.fetch_sub(1) == 1;
+  }
 }

 template <class PropagatorStateType>
 void ExecutorState<PropagatorStateType>::ScheduleReady(
     TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) {
-  if (ready->empty()) return;
+  DCHECK(!ready->empty());

   int64 scheduled_nsec = 0;
   if (stats_collector_) {
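The NOTE added in the fast path ("Avoid touching the atomic counter if only one node becomes ready") relies on how num_outstanding_ops_ tracks step completion: a finishing node hands its slot in the count to one of its ready successors, so the counter is only updated when the number of in-flight nodes actually changes, and the decrement that observes a previous value of 1 is the one that completes the step. Below is a minimal standalone sketch of that accounting under assumed, simplified names (OutstandingOps, OnNodeDone); it is an illustration, not the executor's code.

#include <atomic>
#include <cstddef>
#include <iostream>

class OutstandingOps {
 public:
  explicit OutstandingOps(int initial) : count_(initial) {}

  // Called when one node finishes and `ready_size` successors become ready.
  // Returns true if the whole step is now complete.
  bool OnNodeDone(std::size_t ready_size) {
    if (ready_size == 0) {
      // Nothing replaces the finished node: decrement, and the caller that
      // sees the previous value 1 has completed the step.
      return count_.fetch_sub(1) == 1;
    }
    if (ready_size > 1) {
      // One successor reuses the finished node's slot in the count; only the
      // extra successors require touching the atomic.
      count_.fetch_add(static_cast<int>(ready_size) - 1,
                       std::memory_order_relaxed);
    }
    return false;  // ready_size == 1 leaves the counter untouched.
  }

 private:
  std::atomic<int> count_;
};

int main() {
  OutstandingOps ops(/*initial=*/1);       // One root node in flight.
  std::cout << std::boolalpha;
  std::cout << ops.OnNodeDone(2) << "\n";  // Root fans out to two nodes: false.
  std::cout << ops.OnNodeDone(0) << "\n";  // First successor finishes: false.
  std::cout << ops.OnNodeDone(0) << "\n";  // Last successor finishes: true.
}

Because fetch_sub returns the previous value, only the final decrement observes 1, so exactly one caller reports completion even when many nodes finish concurrently.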