diff --git a/tensorflow/core/common_runtime/pending_counts.h b/tensorflow/core/common_runtime/pending_counts.h
index 46b6509390f..9792bf8befd 100644
--- a/tensorflow/core/common_runtime/pending_counts.h
+++ b/tensorflow/core/common_runtime/pending_counts.h
@@ -16,6 +16,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <atomic>
+
 #include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/logging.h"
@@ -93,63 +95,75 @@ class PendingCounts {
   void set_initial_count(Handle h, size_t pending_count) {
     if (h.is_large_) {
-      LargeCounts* c = Large(h);
-      c->pending = pending_count;
-      c->dead_count = 0;
-      c->has_started = 0;
+      std::atomic<LargeCounts>* c_ptr = Large(h);
+      auto c = c_ptr->load(std::memory_order_relaxed);
+      c.pending = pending_count;
+      c.dead_count = 0;
+      c.has_started = 0;
+      c_ptr->store(c, std::memory_order_relaxed);
     } else {
-      PackedCounts* c = Packed(h);
       DCHECK_LE(pending_count, kMaxCountForPackedCounts);
-      c->pending = pending_count;
-      c->dead_count = 0;
-      c->has_started = 0;
+      std::atomic<PackedCounts>* c_ptr = Packed(h);
+      auto c = c_ptr->load(std::memory_order_relaxed);
+      c.pending = pending_count;
+      c.dead_count = 0;
+      c.has_started = 0;
+      c_ptr->store(c, std::memory_order_relaxed);
     }
   }
 
   NodeState node_state(Handle h) {
     if (h.is_large_) {
-      return NodeStateForStruct(Large(h));
+      return NodeStateForStruct(Large(h)->load(std::memory_order_relaxed));
     } else {
-      return NodeStateForStruct(Packed(h));
+      return NodeStateForStruct(Packed(h)->load(std::memory_order_relaxed));
     }
   }
 
   void mark_started(Handle h) {
     DCHECK_EQ(pending(h), 0);
     if (h.is_large_) {
-      LargeCounts* c = Large(h);
-      DCHECK_EQ(c->has_started, 0);
-      c->has_started = 1;
+      std::atomic<LargeCounts>* c_ptr = Large(h);
+      auto c = c_ptr->load(std::memory_order_relaxed);
+      DCHECK_EQ(c.has_started, 0);
+      c.has_started = 1;
+      c_ptr->store(c, std::memory_order_relaxed);
     } else {
-      PackedCounts* c = Packed(h);
-      DCHECK_EQ(c->has_started, 0);
-      c->has_started = 1;
+      std::atomic<PackedCounts>* c_ptr = Packed(h);
+      auto c = c_ptr->load(std::memory_order_relaxed);
+      DCHECK_EQ(c.has_started, 0);
+      c.has_started = 1;
+      c_ptr->store(c, std::memory_order_relaxed);
     }
   }
 
   void mark_completed(Handle h) {
     if (h.is_large_) {
-      LargeCounts* c = Large(h);
-      DCHECK_EQ(c->has_started, 1);
-      c->pending = 1;
+      std::atomic<LargeCounts>* c_ptr = Large(h);
+      auto c = c_ptr->load(std::memory_order_relaxed);
+      DCHECK_EQ(c.has_started, 1);
+      c.pending = 1;
+      c_ptr->store(c, std::memory_order_relaxed);
     } else {
-      PackedCounts* c = Packed(h);
-      DCHECK_EQ(c->has_started, 1);
-      c->pending = 1;
+      std::atomic<PackedCounts>* c_ptr = Packed(h);
+      auto c = c_ptr->load(std::memory_order_relaxed);
+      DCHECK_EQ(c.has_started, 1);
+      c.pending = 1;
+      c_ptr->store(c, std::memory_order_relaxed);
    }
   }
 
   int pending(Handle h) {
     if (h.is_large_) {
-      LargeCounts* c = Large(h);
+      LargeCounts c = Large(h)->load(std::memory_order_relaxed);
       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
-        return c->pending;
+        return c.pending;
       } else {
         // The pending count encodes the state once the node has
         // started, so just return 0.
         return 0;
       }
     } else {
-      PackedCounts* c = Packed(h);
+      PackedCounts c = Packed(h)->load(std::memory_order_relaxed);
       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
-        return c->pending;
+        return c.pending;
       } else {
         // The pending count encodes the state once the node has
         // started, so just return 0.
@@ -160,50 +174,63 @@ class PendingCounts {
   int decrement_pending(Handle h, int v) {
     DCHECK_GE(pending(h), v);
     if (h.is_large_) {
-      LargeCounts* c = Large(h);
-      c->pending -= v;
-      return c->pending;
+      std::atomic<LargeCounts>* c_ptr = Large(h);
+      auto c = c_ptr->load(std::memory_order_relaxed);
+      c.pending -= v;
+      c_ptr->store(c, std::memory_order_relaxed);
+      return c.pending;
     } else {
-      PackedCounts* c = Packed(h);
-      c->pending -= v;
-      return c->pending;
+      std::atomic<PackedCounts>* c_ptr = Packed(h);
+      auto c = c_ptr->load(std::memory_order_relaxed);
+      c.pending -= v;
+      c_ptr->store(c, std::memory_order_relaxed);
+      return c.pending;
     }
   }
 
   // Mark a merge node as live
   // REQUIRES: Node corresponding to "h" is a merge node
   void mark_live(Handle h) {
     if (h.is_large_) {
-      LargeCounts* c = Large(h);
+      std::atomic<LargeCounts>* c_ptr = Large(h);
+      auto c = c_ptr->load(std::memory_order_relaxed);
       // Only do anything if the node hasn't already started executing.
       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
-        c->pending &= ~static_cast<int>(0x1);
+        c.pending &= ~static_cast<int>(0x1);
+        c_ptr->store(c, std::memory_order_relaxed);
       }
     } else {
-      PackedCounts* c = Packed(h);
+      std::atomic<PackedCounts>* c_ptr = Packed(h);
+      auto c = c_ptr->load(std::memory_order_relaxed);
       // Only do anything if the node hasn't already started executing.
       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
         static_assert(7 == kMaxCountForPackedCounts,
                       "Live flag incorrect for max packed count");
-        c->pending &= 0x6;
+        c.pending &= 0x6;
+        c_ptr->store(c, std::memory_order_relaxed);
       }
     }
   }
 
   int dead_count(Handle h) {
-    int r = h.is_large_ ? Large(h)->dead_count : Packed(h)->dead_count;
+    int r = h.is_large_ ? Large(h)->load(std::memory_order_relaxed).dead_count
+                        : Packed(h)->load(std::memory_order_relaxed).dead_count;
     return r;
   }
 
   void increment_dead_count(Handle h) {
     if (h.is_large_) {
-      LargeCounts* c = Large(h);
+      std::atomic<LargeCounts>* c_ptr = Large(h);
+      auto c = c_ptr->load(std::memory_order_relaxed);
       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
-        c->dead_count++;
+        c.dead_count++;
+        c_ptr->store(c, std::memory_order_relaxed);
       }
     } else {
-      PackedCounts* c = Packed(h);
+      std::atomic<PackedCounts>* c_ptr = Packed(h);
+      auto c = c_ptr->load(std::memory_order_relaxed);
       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
-        DCHECK_LT(c->dead_count, kMaxCountForPackedCounts);
-        c->dead_count++;
+        DCHECK_LT(c.dead_count, kMaxCountForPackedCounts);
+        c.dead_count++;
+        c_ptr->store(c, std::memory_order_relaxed);
       }
     }
   }
@@ -230,6 +257,17 @@ class PendingCounts {
     }
   }
 
+  // The same as the above, but performs the operation atomically. This
+  // is thread-safe to run concurrently with other threads.
+  AdjustResult adjust_for_activation_atomic(Handle h, bool increment_dead) {
+    DCHECK_GE(pending(h), 1);
+    if (h.is_large_) {
+      return adjust_for_activation_shared_atomic(Large(h), increment_dead);
+    } else {
+      return adjust_for_activation_shared_atomic(Packed(h), increment_dead);
+    }
+  }
+
   class Handle {
    public:
     Handle() : byte_offset_(0), is_large_(0) {}
@@ -242,12 +280,31 @@ class PendingCounts {
 
  private:
   template <typename T>
-  inline AdjustResult adjust_for_activation_shared(T* c, bool increment_dead) {
-    if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(c)) {
-      c->dead_count++;
+  inline AdjustResult adjust_for_activation_shared(std::atomic<T>* c,
+                                                   bool increment_dead) {
+    T val = c->load(std::memory_order_relaxed);
+    if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(val)) {
+      val.dead_count++;
+    }
+    val.pending--;
+    c->store(val, std::memory_order_relaxed);
+    return AdjustResult(val.dead_count, val.pending);
+  }
+
+  template <typename T>
+  inline AdjustResult adjust_for_activation_shared_atomic(std::atomic<T>* c,
+                                                          bool increment_dead) {
+    T old_val = c->load(std::memory_order_relaxed);
+    while (true) {
+      T new_val = old_val;
+      if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(new_val)) {
+        new_val.dead_count++;
+      }
+      new_val.pending--;
+      AdjustResult ret(new_val.dead_count, new_val.pending);
+      if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val)))
+        return ret;
     }
-    c->pending -= 1;
-    return AdjustResult(c->dead_count, c->pending);
   }
 
   // We keep track of the pending count and dead input count for each
@@ -279,23 +336,24 @@ class PendingCounts {
   };
 
   template <typename T>
-  NodeState NodeStateForStruct(T* c) const {
-    if (c->has_started) {
-      return (c->pending == 0) ? STARTED : COMPLETED;
+  NodeState NodeStateForStruct(const T& c) const {
+    if (c.has_started) {
+      return (c.pending == 0) ? STARTED : COMPLETED;
     } else {
-      return (c->pending == 0) ? PENDING_READY : PENDING_NOTREADY;
+      return (c.pending == 0) ? PENDING_READY : PENDING_NOTREADY;
     }
   }
 
-  inline LargeCounts* Large(Handle h) {
+  inline std::atomic<LargeCounts>* Large(Handle h) {
     DCHECK(h.is_large_);
-    DCHECK_LE(h.byte_offset_ + sizeof(LargeCounts), num_bytes_);
-    DCHECK_EQ(h.byte_offset_ % alignof(LargeCounts), 0);
-    return reinterpret_cast<LargeCounts*>(bytes_ + h.byte_offset_);
+    DCHECK_LE(h.byte_offset_ + sizeof(std::atomic<LargeCounts>), num_bytes_);
+    DCHECK_EQ(h.byte_offset_ % alignof(std::atomic<LargeCounts>), 0);
+    return reinterpret_cast<std::atomic<LargeCounts>*>(bytes_ + h.byte_offset_);
   }
 
-  inline PackedCounts* Packed(Handle h) {
+  inline std::atomic<PackedCounts>* Packed(Handle h) {
     DCHECK(!h.is_large_);
     DCHECK_LE(h.byte_offset_ + sizeof(PackedCounts), num_bytes_);
-    return reinterpret_cast<PackedCounts*>(bytes_ + h.byte_offset_);
+    return reinterpret_cast<std::atomic<PackedCounts>*>(bytes_ +
+                                                        h.byte_offset_);
   }
 
   const int num_bytes_;  // Just for bounds checking in debug mode
@@ -309,9 +367,10 @@ inline PendingCounts::Handle PendingCounts::Layout::CreateHandle(
   Handle result;
   if ((max_pending_count > kMaxCountForPackedCounts) ||
       (max_dead_count > kMaxCountForPackedCounts)) {
-    int B = sizeof(LargeCounts);
+    constexpr int B = sizeof(std::atomic<LargeCounts>);
     // Round byte offset to proper alignment
-    DCHECK_GE(sizeof(LargeCounts), alignof(LargeCounts));
+    static_assert(sizeof(std::atomic<LargeCounts>) >=
+                  alignof(std::atomic<LargeCounts>));
     int64 offset = ((static_cast<int64>(next_offset_) + B - 1) / B) * B;
     result.byte_offset_ = offset;
     result.is_large_ = true;
@@ -319,8 +378,8 @@ inline PendingCounts::Handle PendingCounts::Layout::CreateHandle(
   } else {
     result.byte_offset_ = next_offset_;
     result.is_large_ = false;
-    DCHECK_EQ(sizeof(PackedCounts), 1);
-    next_offset_ += sizeof(PackedCounts);
+    static_assert(sizeof(std::atomic<PackedCounts>) == 1);
+    next_offset_ += sizeof(std::atomic<PackedCounts>);
   }
   return result;
 }
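
The header change above reduces every read-modify-write on a pending count to "load a small trivially-copyable struct, edit the copy, publish it", and the new adjust_for_activation_shared_atomic publishes with a compare-and-swap retry loop. The following standalone sketch is not part of the patch and does not use the TensorFlow types; PackedDemo and adjust() are illustrative stand-ins that assume the same relaxed-load / compare_exchange_weak shape as the code above.

// Minimal sketch of the CAS retry pattern used by
// adjust_for_activation_shared_atomic, applied to a hypothetical 8-bit
// packed counter (names are illustrative, not TensorFlow's).
#include <atomic>
#include <cstdint>
#include <iostream>

struct PackedDemo {
  uint8_t pending : 3;
  uint8_t dead_count : 3;
  uint8_t has_started : 1;
};

// Decrement `pending` (and optionally bump `dead_count`) without a lock.
// On failure, compare_exchange_weak refreshes `old_val` with the value a
// concurrent writer just published, so the loop simply retries against it.
PackedDemo adjust(std::atomic<PackedDemo>* cell, bool increment_dead) {
  PackedDemo old_val = cell->load(std::memory_order_relaxed);
  while (true) {
    PackedDemo new_val = old_val;
    if (increment_dead) new_val.dead_count++;
    new_val.pending--;
    if (cell->compare_exchange_weak(old_val, new_val)) return new_val;
  }
}

int main() {
  std::atomic<PackedDemo> cell{PackedDemo{3, 0, 0}};
  PackedDemo after = adjust(&cell, /*increment_dead=*/false);
  std::cout << "pending now " << int(after.pending) << "\n";  // prints 2
}

Because the whole struct fits inside one atomic cell, concurrent callers can never observe a half-updated pending/dead_count pair, which is the property the executor lock used to provide.
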
diff --git a/tensorflow/core/common_runtime/pending_counts_test.cc b/tensorflow/core/common_runtime/pending_counts_test.cc
index 5d5e7367c86..9debed4528a 100644
--- a/tensorflow/core/common_runtime/pending_counts_test.cc
+++ b/tensorflow/core/common_runtime/pending_counts_test.cc
@@ -13,12 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/core/common_runtime/pending_counts.h"
+
 #include <memory>
 #include <vector>
+
-#include "tensorflow/core/common_runtime/pending_counts.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/test.h"
 
+using std::unique_ptr;
+
 namespace tensorflow {
 
 TEST(PendingCounts, Simple) {
@@ -165,4 +170,36 @@ TEST(PendingCounts, AdjustForActivation) {
   }
 }
 
+TEST(PendingCounts, AdjustForActivationAtomic) {
+  PendingCounts::Layout layout;
+  PendingCounts::Handle handles[2];
+  const int kInitialCounts[2] = {6, 16};
+  handles[0] = layout.CreateHandle(kInitialCounts[0], 0);
+  handles[1] = layout.CreateHandle(kInitialCounts[1], 0);
+  PendingCounts c(layout);
+  c.set_initial_count(handles[0], kInitialCounts[0]);
+  c.set_initial_count(handles[1], kInitialCounts[1]);
+
+  Env* env = Env::Default();
+  std::atomic<bool> start{false};
+  std::vector<unique_ptr<Thread>> threads;
+  for (int t = 0; t < 2; t++) {
+    threads.emplace_back(env->StartThread({}, "tester", [&]() {
+      while (!start) {
+      }
+      for (int i = 0; i < kInitialCounts[0] / 2; i++) {
+        c.adjust_for_activation_atomic(handles[0], false);
+      }
+      for (int i = 0; i < kInitialCounts[1] / 2; i++) {
+        c.adjust_for_activation_atomic(handles[1], false);
+      }
+    }));
+  }
+  start = true;
+  threads.clear();  // Joins the threads.
+
+  EXPECT_EQ(c.pending(handles[0]), 0);
+  EXPECT_EQ(c.pending(handles[1]), 0);
+}
+
 }  // namespace tensorflow
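
The new test gates both worker threads on an atomic flag so they call adjust_for_activation_atomic concurrently, joins them by clearing the vector of unique_ptr<Thread>, and then checks that every decrement landed exactly once. Below is a standalone sketch of that same structure using plain std::thread instead of tensorflow::Env::StartThread; the counter and the kThreads/kPerThread constants are made-up stand-ins and are not part of the patch.

// Two threads concurrently decrement a shared counter; the atomic start
// flag maximizes overlap, and the final value proves no update was lost.
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

int main() {
  constexpr int kThreads = 2;
  constexpr int kPerThread = 1000;
  std::atomic<int> pending{kThreads * kPerThread};
  std::atomic<bool> start{false};

  std::vector<std::thread> threads;
  for (int t = 0; t < kThreads; ++t) {
    threads.emplace_back([&]() {
      while (!start) {  // spin until every worker has been created
      }
      for (int i = 0; i < kPerThread; ++i) {
        pending.fetch_sub(1, std::memory_order_relaxed);
      }
    });
  }
  start = true;
  for (auto& t : threads) t.join();  // the TF test joins via threads.clear()

  assert(pending.load() == 0);  // every decrement was applied exactly once
  return 0;
}
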
diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc
index a6639b1132e..fde47200282 100644
--- a/tensorflow/core/common_runtime/propagator_state.cc
+++ b/tensorflow/core/common_runtime/propagator_state.cc
@@ -89,13 +89,12 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
   IterationState* output_iter = input_iter;
 
   if (!item->is_enter_exit_or_next_iter) {
-    // Fast path for nodes types that don't need special handling
+    // Fast path for node types that don't need special handling.
+    // This is the case for most nodes.
     DCHECK_EQ(input_frame, output_frame);
-    // Normal path for most nodes
-    mutex_lock l(input_frame->mu);
-    output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
-    is_frame_done =
-        input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
+    FrameState* frame = input_frame;
+    is_frame_done = frame->ActivateNodesAndAdjustOutstanding(
+        item, is_dead, output_iter, outputs, ready);
   } else if (item->is_enter) {
     FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
     {
@@ -105,7 +104,9 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
         // Propagate to all active iterations if this is a loop invariant.
         output_frame->AddLoopInv(item, (*outputs)[0], ready);
       } else {
-        output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
+        int activated = output_frame->ActivateNodesLocked(
+            item, is_dead, output_iter, outputs, ready);
+        output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready);
       }
       output_frame->num_pending_inputs--;
     }
@@ -124,7 +125,9 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
       output_iter = input_frame->parent_iter;
       {
         mutex_lock l(output_frame->mu);
-        output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
+        int activated = output_frame->ActivateNodesLocked(
+            item, is_dead, output_iter, outputs, ready);
+        output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready);
       }
       is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
     }
@@ -153,7 +156,9 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
     if (output_frame != nullptr) {
       // This is the case when node is not Enter, Exit, or NextIteration.
       DCHECK(input_frame == output_frame);
-      output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
+      int activated = output_frame->ActivateNodesLocked(
+          item, is_dead, output_iter, outputs, ready);
+      output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready);
     }
     is_frame_done =
         input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
@@ -378,13 +383,15 @@ void PropagatorState::CleanupFramesIterations(FrameState* frame,
   }
 }
 
-void PropagatorState::FrameState::ActivateNodesFastPath(
+template <bool atomic>
+int PropagatorState::FrameState::ActivateNodesFastPathInternal(
     const NodeItem* item, const bool is_dead, IterationState* iter_state,
     EntryVector* outputs, TaggedNodeSeq* ready) {
   // If we know that none of the item's edge destinations require special
   // handling (i.e. none of the nodes is a merge or control trigger node), we
   // can take a fast path that avoids accessing the destination NodeItem.
   const GraphView& gview = immutable_state.graph_view();
+  int new_outstanding = 0;
 
   // Add dst to the ready queue if it's ready
   //
@@ -399,12 +406,11 @@ void PropagatorState::FrameState::ActivateNodesFastPath(
       t.input_frame = this;                                   \
       t.input_iter = iter_state;                              \
       t.is_dead = adjust_result.any_dead;                     \
-      iter_state->outstanding_ops++;                          \
+      new_outstanding++;                                      \
     }                                                         \
   } while (0);
 
   Entry* input_tensors = iter_state->input_tensors;
-
   for (const EdgeInfo& e : item->output_edges()) {
     const int dst_id = e.dst_id;
     const PendingCounts::Handle dst_pending_id =
@@ -413,14 +419,17 @@ void PropagatorState::FrameState::ActivateNodesFastPath(
 
     const bool increment_dead =
         (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
-    const PendingCounts::AdjustResult adjust_result =
-        iter_state->adjust_for_activation(dst_pending_id, increment_dead);
     const int dst_loc = e.input_slot;
     if (e.is_last) {
       input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
     } else {
      input_tensors[dst_loc] = (*outputs)[src_slot];
     }
+    const PendingCounts::AdjustResult adjust_result =
+        atomic
+            ? iter_state->adjust_for_activation_atomic(dst_pending_id,
+                                                       increment_dead)
+            : iter_state->adjust_for_activation(dst_pending_id, increment_dead);
     MAYBE_ADD_TO_READY(dst_id, adjust_result);
   }
 
@@ -429,27 +438,31 @@ void PropagatorState::FrameState::ActivateNodesFastPath(
     const PendingCounts::Handle dst_pending_id =
         immutable_state.pending_ids()[dst_id];
     const PendingCounts::AdjustResult adjust_result =
-        iter_state->adjust_for_activation(dst_pending_id, is_dead);
+        atomic
+            ? iter_state->adjust_for_activation_atomic(dst_pending_id, is_dead)
+            : iter_state->adjust_for_activation(dst_pending_id, is_dead);
     MAYBE_ADD_TO_READY(dst_id, adjust_result);
   }
+
+  return new_outstanding;
 #undef MAYBE_ADD_TO_READY
 }
 
-void PropagatorState::FrameState::ActivateNodesSlowPath(
+int PropagatorState::FrameState::ActivateNodesSlowPath(
     const NodeItem* item, const bool is_dead, IterationState* iter_state,
     EntryVector* outputs, TaggedNodeSeq* ready) {
   // If any of the edge destinations is a merge or a control trigger node,
   // we need to read each destination NodeItem to determine what action
   // to take.
   const GraphView& gview = immutable_state.graph_view();
-
+  int activated = 0;
   auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
                                 bool dst_ready, bool dst_dead) {
     // Add dst to the ready queue if it's ready
     if (dst_ready) {
       if (dst_item->is_control_trigger) dst_dead = false;
       ready->emplace_back(dst_item, this, iter_state, dst_dead);
-      iter_state->outstanding_ops++;
+      activated++;
     }
   };
 
@@ -544,43 +557,72 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(
     }
     maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
   }
+
+  return activated;
 }
 
-void PropagatorState::FrameState::ActivateNodes(const NodeItem* item,
-                                                const bool is_dead,
-                                                IterationState* iter_state,
-                                                EntryVector* outputs,
-                                                TaggedNodeSeq* ready) {
+bool PropagatorState::FrameState::ActivateNodesAndAdjustOutstanding(
+    const NodeItem* item, const bool is_dead, IterationState* iter_state,
+    EntryVector* outputs, TaggedNodeSeq* ready) {
   if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
-    ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready);
+    mutex_lock l(mu);
+    int activated =
+        ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready);
+    return AdjustOutstandingOpsLocked(iter_state, activated - 1, ready);
+  }
+  {
+    tf_shared_lock l(mu);
+    int activated =
+        ActivateNodesFastPathShared(item, is_dead, iter_state, outputs, ready);
+    bool iter_done = AdjustOutstandingOpsFastPath(iter_state, activated - 1);
+    if (!iter_done) return false;
+  }
+  mutex_lock l(mu);
+  return CleanupIterations(iter_state, ready);
+}
+
+int PropagatorState::FrameState::ActivateNodesLocked(const NodeItem* item,
+                                                     const bool is_dead,
+                                                     IterationState* iter_state,
+                                                     EntryVector* outputs,
+                                                     TaggedNodeSeq* ready) {
+  if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
+    return ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready);
   } else {
-    ActivateNodesFastPath(item, is_dead, iter_state, outputs, ready);
+    return ActivateNodesFastPathLocked(item, is_dead, iter_state, outputs,
+                                       ready);
  }
 }
 
 void PropagatorState::FrameState::ActivateNexts(IterationState* iter_state,
                                                 TaggedNodeSeq* ready) {
+  int activated = 0;
   // Propagate the deferred NextIteration nodes to the new iteration.
   for (auto& node_entry : next_iter_roots) {
     const NodeItem* item = node_entry.first;
     const Entry& entry = node_entry.second;
     const bool is_dead = entry.state == Entry::State::NO_VALUE;
     EntryVector outputs{entry};
-    ActivateNodes(item, is_dead, iter_state, &outputs, ready);
+    activated +=
+        ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready);
   }
   next_iter_roots.clear();
+  AdjustOutstandingOpsLocked(iter_state, activated, ready);
 }
 
 void PropagatorState::FrameState::ActivateLoopInvs(IterationState* iter_state,
                                                    TaggedNodeSeq* ready) {
   // Propagate loop invariants to the new iteration.
+  int activated = 0;
   for (auto& node_entry : inv_values) {
     const NodeItem* item = node_entry.first;
     const Entry& entry = node_entry.second;
     const bool is_dead = entry.state == Entry::State::NO_VALUE;
     EntryVector outputs{entry};
-    ActivateNodes(item, is_dead, iter_state, &outputs, ready);
+    activated +=
+        ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready);
   }
+  AdjustOutstandingOpsLocked(iter_state, activated, ready);
 }
 
 void PropagatorState::FrameState::AddLoopInv(const NodeItem* item,
@@ -593,7 +635,10 @@ void PropagatorState::FrameState::AddLoopInv(const NodeItem* item,
   const bool is_dead = entry.state == Entry::State::NO_VALUE;
   for (int i = 0; i <= iteration_count; ++i) {
     EntryVector outputs{entry};
-    ActivateNodes(item, is_dead, GetIteration(i), &outputs, ready);
+    IterationState* iter_state = GetIteration(i);
+    int activated =
+        ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready);
+    AdjustOutstandingOpsLocked(iter_state, activated, ready);
   }
 }
 
@@ -676,8 +721,50 @@ void PropagatorState::FrameState::SetIteration(int64 iter,
 // frame. Return true iff the execution of the frame is done.
 bool PropagatorState::FrameState::DecrementOutstandingOps(
     IterationState* iter_state, TaggedNodeSeq* ready) {
+  return AdjustOutstandingOps(iter_state, -1, ready);
+}
+
+bool PropagatorState::FrameState::AdjustOutstandingOps(
+    IterationState* iter_state, int delta, TaggedNodeSeq* ready) {
+  // Given the following profile of values of 'delta' for wide_deep model from
+  // the TF model garden:
+  //
+  //  Count  Value
+  //  ---------------
+  //  757938 delta=0x0
+  //  541713 delta=0xffffffff
+  //  138115 delta=0x1
+  //   58770 delta=0x2
+  //    5394 delta=0x3
+  //    4669 delta=0x4
+  //    2037 delta=0xa
+  //    1646 delta=0x7
+  //    1632 delta=0x6
+  //    1613 delta=0x6c
+  //    1224 delta=0x5
+  //     409 delta=0x53
+  //      17 delta=0x86
+  //
+  // ... it's worth no-opping out when delta == 0 to avoid the atomic
+  // instruction.
+  if (delta == 0) {
+    return false;
+  }
+  {
+    tf_shared_lock sl(mu);
+    if (TF_PREDICT_TRUE(!AdjustOutstandingOpsFastPath(iter_state, delta))) {
+      return false;
+    }
+  }
   mutex_lock l(mu);
-  return DecrementOutstandingOpsLocked(iter_state, ready);
+  DCHECK(IsIterationDone(iter_state));
+  return CleanupIterations(iter_state, ready);
+}
+
+bool PropagatorState::FrameState::AdjustOutstandingOpsFastPath(
+    IterationState* iter_state, int delta) {
+  auto old_val = iter_state->outstanding_ops.fetch_add(delta);
+  return (old_val + delta == 0) && IsIterationDone(iter_state);
 }
 
 // Decrement the outstanding op count and clean up the iterations in the
@@ -685,12 +772,22 @@ bool PropagatorState::FrameState::DecrementOutstandingOps(
 bool PropagatorState::FrameState::DecrementOutstandingOpsLocked(
     IterationState* iter_state, TaggedNodeSeq* ready)
     TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
-  iter_state->outstanding_ops--;
-  if (iter_state->outstanding_ops != 0) {
+  return AdjustOutstandingOpsLocked(iter_state, -1, ready);
+}
+
+bool PropagatorState::FrameState::AdjustOutstandingOpsLocked(
+    IterationState* iter_state, int delta, TaggedNodeSeq* ready) {
+  // We hold the lock, so we don't need to use an atomic modification.
+  auto cur_val = iter_state->outstanding_ops.load(std::memory_order_relaxed);
+  DCHECK(delta >= 0 || cur_val >= -delta)
+      << "cannot adjust outstanding_ops by " << delta
+      << " when current value is " << cur_val;
+  auto new_val = cur_val + delta;
+  iter_state->outstanding_ops.store(new_val, std::memory_order_relaxed);
+  if (new_val != 0) {
     return false;
-  } else {
-    return CleanupIterations(iter_state, ready);
   }
+  return CleanupIterations(iter_state, ready);
 }
 
 // Returns true if the computation in the frame is completed.
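
AdjustOutstandingOps above is a two-phase scheme: the hot path holds only the shared lock and performs a single fetch_add, and the exclusive lock is taken only in the rare case where the count just reached zero and iterations need cleanup (plus an early return for the profiled delta == 0 case). The sketch below is an assumed simplification rather than TensorFlow code: it uses std::shared_mutex in place of TF's mutex/tf_shared_lock, and FrameSketch and CleanupIterationsLocked are hypothetical stand-ins; the real code also re-verifies IsIterationDone before cleaning up.

// Shared-lock fast path with an exclusive-lock slow path for cleanup.
#include <atomic>
#include <mutex>
#include <shared_mutex>

class FrameSketch {
 public:
  // Returns true iff this adjustment completed the frame's work.
  bool AdjustOutstandingOps(int delta) {
    if (delta == 0) return false;  // most common case per the profile above
    {
      std::shared_lock<std::shared_mutex> sl(mu_);
      long old_val = outstanding_ops_.fetch_add(delta);
      if (old_val + delta != 0) return false;  // common: ops still in flight
    }
    // Rare path: the count reached zero, so clean up under the writer lock.
    std::unique_lock<std::shared_mutex> l(mu_);
    return CleanupIterationsLocked();
  }

 private:
  bool CleanupIterationsLocked() {
    // Placeholder: the real code walks completed IterationState objects here.
    return true;
  }

  std::shared_mutex mu_;
  std::atomic<long> outstanding_ops_{0};
};

int main() {
  FrameSketch frame;
  frame.AdjustOutstandingOps(+2);                  // two ops outstanding
  frame.AdjustOutstandingOps(-1);                  // one finishes: not done
  return frame.AdjustOutstandingOps(-1) ? 0 : 1;   // last one triggers cleanup
}
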
diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h
index 167519ccc73..4e66e709310 100644
--- a/tensorflow/core/common_runtime/propagator_state.h
+++ b/tensorflow/core/common_runtime/propagator_state.h
@@ -143,7 +143,7 @@ class PropagatorState {
     Entry* input_tensors;
 
     // The number of outstanding ops for each iteration.
-    size_t outstanding_ops;
+    std::atomic<size_t> outstanding_ops;
 
     // The number of outstanding frames for each iteration.
     int outstanding_frame_count;
@@ -170,6 +170,10 @@ class PropagatorState {
                                                       bool increment_dead) {
       return counts.adjust_for_activation(h, increment_dead);
     }
+    PendingCounts::AdjustResult adjust_for_activation_atomic(
+        PendingCounts::Handle h, bool increment_dead) {
+      return counts.adjust_for_activation_atomic(h, increment_dead);
+    }
 
     ~IterationState() { delete[] input_tensors; }
 
@@ -283,7 +287,7 @@ class PropagatorState {
     void InitializeFrameInfo(const ImmutableExecutorState::FrameInfo& finfo);
 
     inline IterationState* GetIteration(int64 iter)
-        TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
+        TF_SHARED_LOCKS_REQUIRED(mu) {
      if (TF_PREDICT_TRUE(iter == 0)) {
        return iterations_first;
      } else {
@@ -294,13 +298,26 @@ class PropagatorState {
 
     void SetIteration(int64 iter, IterationState* state);
 
-    // Decrement the outstanding op count and clean up the iterations in the
-    // frame. Return true iff the execution of the frame is done.
+    // Adjust the outstanding op count by 'delta' and clean up the iterations
+    // in the frame if no more ops are outstanding. Return true iff the
+    // execution of the frame is done.
+    //
+    // Avoids acquiring the lock in the common case that the frame is not done.
+    bool AdjustOutstandingOps(IterationState* iter_state, int delta,
+                              TaggedNodeSeq* ready);
+
+    bool AdjustOutstandingOpsLocked(IterationState* iter_state, int delta,
+                                    TaggedNodeSeq* ready)
+        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+
+    bool AdjustOutstandingOpsFastPath(IterationState* iter_state, int delta)
+        TF_SHARED_LOCKS_REQUIRED(mu);
+
+    // Convenience methods for the above 'Adjust' calls where delta takes the
+    // common value of -1.
     bool DecrementOutstandingOps(IterationState* iter_state,
                                  TaggedNodeSeq* ready);
 
-    // Decrement the outstanding op count and clean up the iterations in the
-    // frame. Return true iff the execution of the frame is done.
     bool DecrementOutstandingOpsLocked(IterationState* iter_state,
                                        TaggedNodeSeq* ready);
 
@@ -309,7 +326,7 @@ class PropagatorState {
 
     // Returns true if the iteration of the frame is completed.
     bool IsIterationDone(IterationState* iter_state)
-        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+        TF_SHARED_LOCKS_REQUIRED(mu);
 
     // Increments the iteration id. If this is a new iteration, initialize it.
     //
@@ -332,9 +349,23 @@ class PropagatorState {
 
     // Activate the successors of a node. Contents of *outputs are left in an
     // indeterminate state after returning from this method.
-    void ActivateNodes(const NodeItem* item, const bool is_dead,
-                       IterationState* iter_state, EntryVector* outputs,
-                       TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+    //
+    // In the case that 'item' is a simple node (no merge/control outputs) this
+    // will acquire a shared lock and can run concurrently with other
+    // invocations.
+    //
+    // Returns true if the frame is done after activation.
+    bool ActivateNodesAndAdjustOutstanding(const NodeItem* item,
+                                           const bool is_dead,
+                                           IterationState* iter_state,
+                                           EntryVector* outputs,
+                                           TaggedNodeSeq* ready);
+
+    // Same as the above, but requires 'mu' already held in exclusive mode.
+    int ActivateNodesLocked(const NodeItem* item, const bool is_dead,
+                            IterationState* iter_state, EntryVector* outputs,
+                            TaggedNodeSeq* ready)
+        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
 
     // Cleanup iterations of this frame starting from the given iteration.
     bool CleanupIterations(IterationState* iter_state, TaggedNodeSeq* ready)
@@ -359,14 +390,35 @@ class PropagatorState {
 
    private:
     // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
-    void ActivateNodesFastPath(const NodeItem* item, const bool is_dead,
-                               IterationState* iter_state, EntryVector* outputs,
-                               TaggedNodeSeq* ready)
-        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+    // This variant does not use atomic operations to modify the pending counts
+    // and thus must hold the exclusive lock.
+    int ActivateNodesFastPathLocked(const NodeItem* item, const bool is_dead,
+                                    IterationState* iter_state,
+                                    EntryVector* outputs, TaggedNodeSeq* ready)
+        TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
+      return ActivateNodesFastPathInternal<false>(item, is_dead, iter_state,
+                                                  outputs, ready);
+    }
 
-    void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
-                               IterationState* iter_state, EntryVector* outputs,
-                               TaggedNodeSeq* ready)
+    // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
+    // This variant uses atomic operations to modify the pending counts.
+    int ActivateNodesFastPathShared(const NodeItem* item, const bool is_dead,
+                                    IterationState* iter_state,
+                                    EntryVector* outputs, TaggedNodeSeq* ready)
+        TF_SHARED_LOCKS_REQUIRED(mu) {
+      return ActivateNodesFastPathInternal<true>(item, is_dead, iter_state,
+                                                 outputs, ready);
+    }
+
+    template <bool atomic>
+    int ActivateNodesFastPathInternal(const NodeItem* item, const bool is_dead,
+                                      IterationState* iter_state,
+                                      EntryVector* outputs,
+                                      TaggedNodeSeq* ready);
+
+    int ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
+                              IterationState* iter_state, EntryVector* outputs,
+                              TaggedNodeSeq* ready)
         TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
   };
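
The ActivateNodesFastPathLocked/ActivateNodesFastPathShared pair in the header shows the dispatch pattern introduced by this change: one templated fast-path implementation, instantiated once per locking mode, so the choice between the plain and the atomic pending-count update costs nothing at runtime. The standalone sketch below illustrates that pattern with made-up types (Counts, ActivateOne) and uses C++17 if constexpr where the patch uses a ternary on the constant template parameter; it is an illustration of the technique, not the TensorFlow implementation.

// Compile-time selection between a lock-protected update and an atomic one.
#include <atomic>
#include <iostream>

struct Counts {
  std::atomic<int> pending{1};

  int adjust_plain() {   // caller is assumed to hold an exclusive lock
    int v = pending.load(std::memory_order_relaxed) - 1;
    pending.store(v, std::memory_order_relaxed);
    return v;
  }
  int adjust_atomic() {  // caller holds only a shared lock
    return pending.fetch_sub(1, std::memory_order_relaxed) - 1;
  }
};

template <bool atomic>
int ActivateOne(Counts* c) {
  // The branch is resolved at instantiation time, so each wrapper below
  // compiles down to exactly one of the two updates.
  if constexpr (atomic) {
    return c->adjust_atomic();
  } else {
    return c->adjust_plain();
  }
}

int ActivateLocked(Counts* c) { return ActivateOne<false>(c); }
int ActivateShared(Counts* c) { return ActivateOne<true>(c); }

int main() {
  Counts a, b;
  std::cout << ActivateLocked(&a) << " " << ActivateShared(&b) << "\n";  // 0 0
}
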