From 4c961fb918d771f95a25fec53b6f90a7caa0d71f Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 11 Sep 2020 12:53:07 -0700
Subject: [PATCH] This follows a similar approach to the earlier work by Derek
 Murray (change ID I9058c55898079cb3ab6011513b4468ddb293f59f) and allows using
 atomic operations on the pending counts.

It extends that work to also apply to graphs that include control/merge
nodes (e.g. the 'wide_deep' model in the tf models repo).

The approach here is to note that, even in a graph that has control/merge
nodes, it's likely that many nodes have no merge/control outputs. There is
already a fast path for this case with simpler logic, and that simpler logic
has the advantage of touching the pending count in only one spot, where it
decrements the incoming pending values. Only the last such decrementer
propagates the output and activates the node. Given this, we can use atomic
operations for all such "fast path" nodes and avoid holding an exclusive
lock, falling back to the exclusive lock for any nodes that have slow-path
outputs. As a conservative measure, this patch still acquires the shared lock
for the fast path, though it may not be necessary.

The other piece of this patch is the management of the 'outstanding_ops'
member. In order to allow concurrent completion of ops without the exclusive
lock, this also becomes atomic. Simply making it atomic would be sufficient
for correctness, but it leaves a lot of performance on the table. This patch
therefore batches the updates of this variable so that each op completion
touches it at most once, and no-ops out entirely when the outstanding op
count doesn't need to be modified (e.g. when an op completion activates
exactly one downstream node).

I benchmarked this change using tensorflow-models/official/r1/wide_deep and
collected the distribution of "steps/sec" values reported, using commands
like:

  # switch to old code
  $ python -u census_main.py -ebe=20 -te 20 -mt wide 2>&1 | tee /tmp/before/wide.txt

  # switch to new code
  $ python -u census_main.py -ebe=20 -te 20 -mt wide 2>&1 | tee /tmp/after/wide.txt

... for the 'wide', 'wide_deep', and 'deep' models. I then exported the
'steps/second' metrics to TSV with:

  $ grep 'tensorflow:global_step/sec' \
      /tmp/{before,after}/*.txt | perl -p -e \
      's,/tmp/(.+)/(.+).txt:.*?([\d\.]+)$,$1 $2 $3,g' > /tmp/data.tsv

Mean steps/sec are as follows on my AMD Ryzen 9 3900X desktop (12 physical
cores, 'performance' CPU governor enabled):

Before:

  model      steps/sec
  ---------------------
  deep       876
  wide       776
  wide_deep  625

After:

  model      steps/sec  improvement
  ----------------------------------
  deep       897        (+2.5%)
  wide       928        (+19.5%)
  wide_deep  760        (+21.6%)

A few notes worth considering:

Using atomic operations has some fixed overhead compared to non-atomics,
particularly when the cache line being operated on is in the "modified" (M)
state in another core. This increased latency when operating on a cache line
in remote M state (aka a "HitM" access) also applies to non-atomic
operations, but non-atomics can potentially benefit from instruction-level
parallelism and pipeline multiple such accesses together, whereas atomics on
current x86 architectures do not [1]. So there is possibly a crossover point
on some graph shapes where the mutex-based synchronization with non-atomic
operations actually beats the atomic-based implementation here. That said,
the current code path for the non-atomic (mutex-protected) case is complex
and branchy enough that it's not clear that any significant pipelining of
reads can actually occur without some more explicit unrolling, etc.
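To make the mechanics above concrete, here is a minimal standalone C++ sketch
(not the TensorFlow code itself) of the two ideas: a compare-and-swap loop
that folds the pending decrement and dead-count bump into one atomic update so
the last arriving input can activate a node without an exclusive lock, and a
batched update of the per-iteration outstanding-op counter that is skipped
entirely when the net delta is zero. All names here (PendingCell,
AdjustForActivationAtomic, Iteration, OnOpDone) are hypothetical stand-ins for
PendingCounts and IterationState in the actual patch below.

  #include <atomic>
  #include <cassert>
  #include <cstdint>

  // Packed per-node state, analogous in spirit to a PendingCounts entry.
  struct PendingCell {
    uint32_t pending;     // Inputs not yet delivered.
    uint32_t dead_count;  // Inputs delivered as "dead".
  };

  struct AdjustResult {
    uint32_t dead_count;
    uint32_t pending;
  };

  // CAS loop: apply the dead-count bump and the pending decrement in a single
  // atomic update, retrying if another thread raced in between.
  AdjustResult AdjustForActivationAtomic(std::atomic<PendingCell>* cell,
                                         bool increment_dead) {
    PendingCell old_val = cell->load(std::memory_order_relaxed);
    while (true) {
      PendingCell new_val = old_val;
      if (increment_dead) new_val.dead_count++;
      new_val.pending--;
      // On failure, compare_exchange_weak refreshes old_val with the value
      // currently stored, so the next attempt starts from fresh state.
      if (cell->compare_exchange_weak(old_val, new_val)) {
        return {new_val.dead_count, new_val.pending};
      }
    }
  }

  // Per-iteration bookkeeping; the outstanding-op count is now atomic.
  struct Iteration {
    std::atomic<int64_t> outstanding_ops{0};
  };

  // Called once per op completion: -1 for the op that finished, +1 for each
  // consumer it activated.  Returns true iff this was the last outstanding op.
  bool OnOpDone(Iteration* iter, int newly_activated) {
    const int64_t delta = newly_activated - 1;
    if (delta == 0) {
      // The one activated consumer keeps the count nonzero, so the iteration
      // cannot have finished; skip the atomic read-modify-write entirely.
      return false;
    }
    const int64_t old_val = iter->outstanding_ops.fetch_add(delta);
    return old_val + delta == 0;
  }

  int main() {
    // Two producer ops feeding one consumer node with two pending inputs.
    std::atomic<PendingCell> cell;
    cell.store({/*pending=*/2, /*dead_count=*/0});
    Iteration iter;
    iter.outstanding_ops.store(2);

    // Producer A finishes: consumer not ready yet, net delta is -1.
    AdjustResult a = AdjustForActivationAtomic(&cell, /*increment_dead=*/false);
    assert(a.pending == 1);
    assert(!OnOpDone(&iter, /*newly_activated=*/0));

    // Producer B finishes: consumer becomes ready, delta nets out to zero.
    AdjustResult b = AdjustForActivationAtomic(&cell, /*increment_dead=*/false);
    assert(b.pending == 0);
    assert(!OnOpDone(&iter, /*newly_activated=*/1));

    // The consumer finishes: nothing else to activate, iteration is done.
    assert(OnOpDone(&iter, /*newly_activated=*/0));
    return 0;
  }

The compare_exchange_weak loop tolerates spurious failures by simply
retrying, and the delta == 0 early-out is safe because the single newly
activated consumer keeps the iteration's outstanding count nonzero, so the
iteration cannot have just completed.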
It's also worth noting that, in uncontended cases, the atomics will be slower than the non-atomics. However, this is unlikely an issue in real workloads, considering that these code paths are only uncontended when the actual work of op execution dominates the runtime. In other words, this is only a perf bottleneck when it's under contention, and a small regression for uncontended cases is likely in the noise. [1] https://spcl.inf.ethz.ch/Publications/.pdf/atomic-bench.pdf PiperOrigin-RevId: 331204846 Change-Id: I8e4968ad653e1973bdd7ef95d0747348677f8b4b --- .../core/common_runtime/pending_counts.h | 181 ++++++++++++------ .../common_runtime/pending_counts_test.cc | 39 +++- .../core/common_runtime/propagator_state.cc | 163 ++++++++++++---- .../core/common_runtime/propagator_state.h | 86 +++++++-- 4 files changed, 357 insertions(+), 112 deletions(-) diff --git a/tensorflow/core/common_runtime/pending_counts.h b/tensorflow/core/common_runtime/pending_counts.h index 46b6509390f..9792bf8befd 100644 --- a/tensorflow/core/common_runtime/pending_counts.h +++ b/tensorflow/core/common_runtime/pending_counts.h @@ -16,6 +16,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -93,63 +95,75 @@ class PendingCounts { void set_initial_count(Handle h, size_t pending_count) { if (h.is_large_) { - LargeCounts* c = Large(h); - c->pending = pending_count; - c->dead_count = 0; - c->has_started = 0; + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); + c.pending = pending_count; + c.dead_count = 0; + c.has_started = 0; + c_ptr->store(c, std::memory_order_relaxed); } else { - PackedCounts* c = Packed(h); DCHECK_LE(pending_count, kMaxCountForPackedCounts); - c->pending = pending_count; - c->dead_count = 0; - c->has_started = 0; + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); + c.pending = pending_count; + c.dead_count = 0; + c.has_started = 0; + c_ptr->store(c, std::memory_order_relaxed); } } NodeState node_state(Handle h) { if (h.is_large_) { - return NodeStateForStruct(Large(h)); + return NodeStateForStruct(Large(h)->load(std::memory_order_relaxed)); } else { - return NodeStateForStruct(Packed(h)); + return NodeStateForStruct(Packed(h)->load(std::memory_order_relaxed)); } } void mark_started(Handle h) { DCHECK_EQ(pending(h), 0); if (h.is_large_) { - LargeCounts* c = Large(h); - DCHECK_EQ(c->has_started, 0); - c->has_started = 1; + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); + DCHECK_EQ(c.has_started, 0); + c.has_started = 1; + c_ptr->store(c, std::memory_order_relaxed); } else { - PackedCounts* c = Packed(h); - DCHECK_EQ(c->has_started, 0); - c->has_started = 1; + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); + DCHECK_EQ(c.has_started, 0); + c.has_started = 1; + c_ptr->store(c, std::memory_order_relaxed); } } void mark_completed(Handle h) { if (h.is_large_) { - LargeCounts* c = Large(h); - DCHECK_EQ(c->has_started, 1); - c->pending = 1; + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); + DCHECK_EQ(c.has_started, 1); + c.pending = 1; + c_ptr->store(c, std::memory_order_relaxed); } else { - PackedCounts* c = Packed(h); - DCHECK_EQ(c->has_started, 1); - c->pending = 
1; + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); + DCHECK_EQ(c.has_started, 1); + c.pending = 1; + c_ptr->store(c, std::memory_order_relaxed); } } int pending(Handle h) { if (h.is_large_) { - LargeCounts* c = Large(h); + LargeCounts c = Large(h)->load(std::memory_order_relaxed); if (PENDING_NOTREADY == NodeStateForStruct(c)) { - return c->pending; + return c.pending; } else { // The pending count encodes the state once the node has // started, so just return 0. return 0; } } else { - PackedCounts* c = Packed(h); + PackedCounts c = Packed(h)->load(std::memory_order_relaxed); if (PENDING_NOTREADY == NodeStateForStruct(c)) { - return c->pending; + return c.pending; } else { // The pending count encodes the state once the node has // started, so just return 0. @@ -160,50 +174,63 @@ class PendingCounts { int decrement_pending(Handle h, int v) { DCHECK_GE(pending(h), v); if (h.is_large_) { - LargeCounts* c = Large(h); - c->pending -= v; - return c->pending; + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); + c.pending -= v; + c_ptr->store(c, std::memory_order_relaxed); + return c.pending; } else { - PackedCounts* c = Packed(h); - c->pending -= v; - return c->pending; + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); + c.pending -= v; + c_ptr->store(c, std::memory_order_relaxed); + return c.pending; } } // Mark a merge node as live // REQUIRES: Node corresponding to "h" is a merge node void mark_live(Handle h) { if (h.is_large_) { - LargeCounts* c = Large(h); + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); // Only do anything if the node hasn't already started executing. if (PENDING_NOTREADY == NodeStateForStruct(c)) { - c->pending &= ~static_cast(0x1); + c.pending &= ~static_cast(0x1); + c_ptr->store(c, std::memory_order_relaxed); } } else { - PackedCounts* c = Packed(h); + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); // Only do anything if the node hasn't already started executing. if (PENDING_NOTREADY == NodeStateForStruct(c)) { static_assert(7 == kMaxCountForPackedCounts, "Live flag incorrect for max packed count"); - c->pending &= 0x6; + c.pending &= 0x6; + c_ptr->store(c, std::memory_order_relaxed); } } } int dead_count(Handle h) { - int r = h.is_large_ ? Large(h)->dead_count : Packed(h)->dead_count; + int r = h.is_large_ ? Large(h)->load(std::memory_order_relaxed).dead_count + : Packed(h)->load(std::memory_order_relaxed).dead_count; return r; } void increment_dead_count(Handle h) { if (h.is_large_) { - LargeCounts* c = Large(h); + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); if (PENDING_NOTREADY == NodeStateForStruct(c)) { - c->dead_count++; + c.dead_count++; + c_ptr->store(c, std::memory_order_relaxed); } } else { - PackedCounts* c = Packed(h); + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); if (PENDING_NOTREADY == NodeStateForStruct(c)) { - DCHECK_LT(c->dead_count, kMaxCountForPackedCounts); - c->dead_count++; + DCHECK_LT(c.dead_count, kMaxCountForPackedCounts); + c.dead_count++; + c_ptr->store(c, std::memory_order_relaxed); } } } @@ -230,6 +257,17 @@ class PendingCounts { } } + // The same as the above, but performs the operation atomically. This + // is thread-safe to run concurrently with other threads. 
+ AdjustResult adjust_for_activation_atomic(Handle h, bool increment_dead) { + DCHECK_GE(pending(h), 1); + if (h.is_large_) { + return adjust_for_activation_shared_atomic(Large(h), increment_dead); + } else { + return adjust_for_activation_shared_atomic(Packed(h), increment_dead); + } + } + class Handle { public: Handle() : byte_offset_(0), is_large_(0) {} @@ -242,12 +280,31 @@ class PendingCounts { private: template - inline AdjustResult adjust_for_activation_shared(T* c, bool increment_dead) { - if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(c)) { - c->dead_count++; + inline AdjustResult adjust_for_activation_shared(std::atomic* c, + bool increment_dead) { + T val = c->load(std::memory_order_relaxed); + if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(val)) { + val.dead_count++; + } + val.pending--; + c->store(val, std::memory_order_relaxed); + return AdjustResult(val.dead_count, val.pending); + } + + template + inline AdjustResult adjust_for_activation_shared_atomic(std::atomic* c, + bool increment_dead) { + T old_val = c->load(std::memory_order_relaxed); + while (true) { + T new_val = old_val; + if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(new_val)) { + new_val.dead_count++; + } + new_val.pending--; + AdjustResult ret(new_val.dead_count, new_val.pending); + if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val))) + return ret; } - c->pending -= 1; - return AdjustResult(c->dead_count, c->pending); } // We keep track of the pending count and dead input count for each @@ -279,23 +336,24 @@ class PendingCounts { }; template - NodeState NodeStateForStruct(T* c) const { - if (c->has_started) { - return (c->pending == 0) ? STARTED : COMPLETED; + NodeState NodeStateForStruct(const T& c) const { + if (c.has_started) { + return (c.pending == 0) ? STARTED : COMPLETED; } else { - return (c->pending == 0) ? PENDING_READY : PENDING_NOTREADY; + return (c.pending == 0) ? 
PENDING_READY : PENDING_NOTREADY; } } - inline LargeCounts* Large(Handle h) { + inline std::atomic* Large(Handle h) { DCHECK(h.is_large_); - DCHECK_LE(h.byte_offset_ + sizeof(LargeCounts), num_bytes_); - DCHECK_EQ(h.byte_offset_ % alignof(LargeCounts), 0); - return reinterpret_cast(bytes_ + h.byte_offset_); + DCHECK_LE(h.byte_offset_ + sizeof(std::atomic), num_bytes_); + DCHECK_EQ(h.byte_offset_ % alignof(std::atomic), 0); + return reinterpret_cast*>(bytes_ + h.byte_offset_); } - inline PackedCounts* Packed(Handle h) { + inline std::atomic* Packed(Handle h) { DCHECK(!h.is_large_); DCHECK_LE(h.byte_offset_ + sizeof(PackedCounts), num_bytes_); - return reinterpret_cast(bytes_ + h.byte_offset_); + return reinterpret_cast*>(bytes_ + + h.byte_offset_); } const int num_bytes_; // Just for bounds checking in debug mode @@ -309,9 +367,10 @@ inline PendingCounts::Handle PendingCounts::Layout::CreateHandle( Handle result; if ((max_pending_count > kMaxCountForPackedCounts) || (max_dead_count > kMaxCountForPackedCounts)) { - int B = sizeof(LargeCounts); + constexpr int B = sizeof(std::atomic); // Round byte offset to proper alignment - DCHECK_GE(sizeof(LargeCounts), alignof(LargeCounts)); + static_assert(sizeof(std::atomic) >= + alignof(std::atomic)); int64 offset = ((static_cast(next_offset_) + B - 1) / B) * B; result.byte_offset_ = offset; result.is_large_ = true; @@ -319,8 +378,8 @@ inline PendingCounts::Handle PendingCounts::Layout::CreateHandle( } else { result.byte_offset_ = next_offset_; result.is_large_ = false; - DCHECK_EQ(sizeof(PackedCounts), 1); - next_offset_ += sizeof(PackedCounts); + static_assert(sizeof(std::atomic) == 1); + next_offset_ += sizeof(std::atomic); } return result; } diff --git a/tensorflow/core/common_runtime/pending_counts_test.cc b/tensorflow/core/common_runtime/pending_counts_test.cc index 5d5e7367c86..9debed4528a 100644 --- a/tensorflow/core/common_runtime/pending_counts_test.cc +++ b/tensorflow/core/common_runtime/pending_counts_test.cc @@ -13,12 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/common_runtime/pending_counts.h" + #include #include +#include -#include "tensorflow/core/common_runtime/pending_counts.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" +using std::unique_ptr; + namespace tensorflow { TEST(PendingCounts, Simple) { @@ -165,4 +170,36 @@ TEST(PendingCounts, AdjustForActivation) { } } +TEST(PendingCounts, AdjustForActivationAtomic) { + PendingCounts::Layout layout; + PendingCounts::Handle handles[2]; + const int kInitialCounts[2] = {6, 16}; + handles[0] = layout.CreateHandle(kInitialCounts[0], 0); + handles[1] = layout.CreateHandle(kInitialCounts[1], 0); + PendingCounts c(layout); + c.set_initial_count(handles[0], kInitialCounts[0]); + c.set_initial_count(handles[1], kInitialCounts[1]); + + Env* env = Env::Default(); + std::atomic start{false}; + std::vector> threads; + for (int t = 0; t < 2; t++) { + threads.emplace_back(env->StartThread({}, "tester", [&]() { + while (!start) { + } + for (int i = 0; i < kInitialCounts[0] / 2; i++) { + c.adjust_for_activation_atomic(handles[0], false); + } + for (int i = 0; i < kInitialCounts[1] / 2; i++) { + c.adjust_for_activation_atomic(handles[1], false); + } + })); + } + start = true; + threads.clear(); // Joins the threads. 
+ + EXPECT_EQ(c.pending(handles[0]), 0); + EXPECT_EQ(c.pending(handles[1]), 0); +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc index a6639b1132e..fde47200282 100644 --- a/tensorflow/core/common_runtime/propagator_state.cc +++ b/tensorflow/core/common_runtime/propagator_state.cc @@ -89,13 +89,12 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, IterationState* output_iter = input_iter; if (!item->is_enter_exit_or_next_iter) { - // Fast path for nodes types that don't need special handling + // Fast path for node types that don't need special handling. + // This is the case for most nodes. DCHECK_EQ(input_frame, output_frame); - // Normal path for most nodes - mutex_lock l(input_frame->mu); - output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); - is_frame_done = - input_frame->DecrementOutstandingOpsLocked(input_iter, ready); + FrameState* frame = input_frame; + is_frame_done = frame->ActivateNodesAndAdjustOutstanding( + item, is_dead, output_iter, outputs, ready); } else if (item->is_enter) { FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame); { @@ -105,7 +104,9 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, // Propagate to all active iterations if this is a loop invariant. output_frame->AddLoopInv(item, (*outputs)[0], ready); } else { - output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); + int activated = output_frame->ActivateNodesLocked( + item, is_dead, output_iter, outputs, ready); + output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready); } output_frame->num_pending_inputs--; } @@ -124,7 +125,9 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, output_iter = input_frame->parent_iter; { mutex_lock l(output_frame->mu); - output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); + int activated = output_frame->ActivateNodesLocked( + item, is_dead, output_iter, outputs, ready); + output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready); } is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready); } @@ -153,7 +156,9 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, if (output_frame != nullptr) { // This is the case when node is not Enter, Exit, or NextIteration. DCHECK(input_frame == output_frame); - output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); + int activated = output_frame->ActivateNodesLocked( + item, is_dead, output_iter, outputs, ready); + output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready); } is_frame_done = input_frame->DecrementOutstandingOpsLocked(input_iter, ready); @@ -378,13 +383,15 @@ void PropagatorState::CleanupFramesIterations(FrameState* frame, } } -void PropagatorState::FrameState::ActivateNodesFastPath( +template +int PropagatorState::FrameState::ActivateNodesFastPathInternal( const NodeItem* item, const bool is_dead, IterationState* iter_state, EntryVector* outputs, TaggedNodeSeq* ready) { // If we know that none of the item's edge destinations require special // handling (i.e. none of the nodes is a merge or control trigger node), we // can take a fast path that avoids accessing the destination NodeItem. 
const GraphView& gview = immutable_state.graph_view(); + int new_outstanding = 0; // Add dst to the ready queue if it's ready // @@ -399,12 +406,11 @@ void PropagatorState::FrameState::ActivateNodesFastPath( t.input_frame = this; \ t.input_iter = iter_state; \ t.is_dead = adjust_result.any_dead; \ - iter_state->outstanding_ops++; \ + new_outstanding++; \ } \ } while (0); Entry* input_tensors = iter_state->input_tensors; - for (const EdgeInfo& e : item->output_edges()) { const int dst_id = e.dst_id; const PendingCounts::Handle dst_pending_id = @@ -413,14 +419,17 @@ void PropagatorState::FrameState::ActivateNodesFastPath( const bool increment_dead = (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE)); - const PendingCounts::AdjustResult adjust_result = - iter_state->adjust_for_activation(dst_pending_id, increment_dead); const int dst_loc = e.input_slot; if (e.is_last) { input_tensors[dst_loc] = std::move((*outputs)[src_slot]); } else { input_tensors[dst_loc] = (*outputs)[src_slot]; } + const PendingCounts::AdjustResult adjust_result = + atomic + ? iter_state->adjust_for_activation_atomic(dst_pending_id, + increment_dead) + : iter_state->adjust_for_activation(dst_pending_id, increment_dead); MAYBE_ADD_TO_READY(dst_id, adjust_result); } @@ -429,27 +438,31 @@ void PropagatorState::FrameState::ActivateNodesFastPath( const PendingCounts::Handle dst_pending_id = immutable_state.pending_ids()[dst_id]; const PendingCounts::AdjustResult adjust_result = - iter_state->adjust_for_activation(dst_pending_id, is_dead); + atomic + ? iter_state->adjust_for_activation_atomic(dst_pending_id, is_dead) + : iter_state->adjust_for_activation(dst_pending_id, is_dead); MAYBE_ADD_TO_READY(dst_id, adjust_result); } + + return new_outstanding; #undef MAYBE_ADD_TO_READY } -void PropagatorState::FrameState::ActivateNodesSlowPath( +int PropagatorState::FrameState::ActivateNodesSlowPath( const NodeItem* item, const bool is_dead, IterationState* iter_state, EntryVector* outputs, TaggedNodeSeq* ready) { // If any of the edge destinations is a merge or a control trigger node, // we need to read each destination NodeItem to determine what action // to take. 
const GraphView& gview = immutable_state.graph_view(); - + int activated = 0; auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item, bool dst_ready, bool dst_dead) { // Add dst to the ready queue if it's ready if (dst_ready) { if (dst_item->is_control_trigger) dst_dead = false; ready->emplace_back(dst_item, this, iter_state, dst_dead); - iter_state->outstanding_ops++; + activated++; } }; @@ -544,43 +557,72 @@ void PropagatorState::FrameState::ActivateNodesSlowPath( } maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead); } + + return activated; } -void PropagatorState::FrameState::ActivateNodes(const NodeItem* item, - const bool is_dead, - IterationState* iter_state, - EntryVector* outputs, - TaggedNodeSeq* ready) { +bool PropagatorState::FrameState::ActivateNodesAndAdjustOutstanding( + const NodeItem* item, const bool is_dead, IterationState* iter_state, + EntryVector* outputs, TaggedNodeSeq* ready) { if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) { - ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready); + mutex_lock l(mu); + int activated = + ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready); + return AdjustOutstandingOpsLocked(iter_state, activated - 1, ready); + } + { + tf_shared_lock l(mu); + int activated = + ActivateNodesFastPathShared(item, is_dead, iter_state, outputs, ready); + bool iter_done = AdjustOutstandingOpsFastPath(iter_state, activated - 1); + if (!iter_done) return false; + } + mutex_lock l(mu); + return CleanupIterations(iter_state, ready); +} + +int PropagatorState::FrameState::ActivateNodesLocked(const NodeItem* item, + const bool is_dead, + IterationState* iter_state, + EntryVector* outputs, + TaggedNodeSeq* ready) { + if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) { + return ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready); } else { - ActivateNodesFastPath(item, is_dead, iter_state, outputs, ready); + return ActivateNodesFastPathLocked(item, is_dead, iter_state, outputs, + ready); } } void PropagatorState::FrameState::ActivateNexts(IterationState* iter_state, TaggedNodeSeq* ready) { + int activated = 0; // Propagate the deferred NextIteration nodes to the new iteration. for (auto& node_entry : next_iter_roots) { const NodeItem* item = node_entry.first; const Entry& entry = node_entry.second; const bool is_dead = entry.state == Entry::State::NO_VALUE; EntryVector outputs{entry}; - ActivateNodes(item, is_dead, iter_state, &outputs, ready); + activated += + ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready); } next_iter_roots.clear(); + AdjustOutstandingOpsLocked(iter_state, activated, ready); } void PropagatorState::FrameState::ActivateLoopInvs(IterationState* iter_state, TaggedNodeSeq* ready) { // Propagate loop invariants to the new iteration. 
+ int activated = 0; for (auto& node_entry : inv_values) { const NodeItem* item = node_entry.first; const Entry& entry = node_entry.second; const bool is_dead = entry.state == Entry::State::NO_VALUE; EntryVector outputs{entry}; - ActivateNodes(item, is_dead, iter_state, &outputs, ready); + activated += + ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready); } + AdjustOutstandingOpsLocked(iter_state, activated, ready); } void PropagatorState::FrameState::AddLoopInv(const NodeItem* item, @@ -593,7 +635,10 @@ void PropagatorState::FrameState::AddLoopInv(const NodeItem* item, const bool is_dead = entry.state == Entry::State::NO_VALUE; for (int i = 0; i <= iteration_count; ++i) { EntryVector outputs{entry}; - ActivateNodes(item, is_dead, GetIteration(i), &outputs, ready); + IterationState* iter_state = GetIteration(i); + int activated = + ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready); + AdjustOutstandingOpsLocked(iter_state, activated, ready); } } @@ -676,8 +721,50 @@ void PropagatorState::FrameState::SetIteration(int64 iter, // frame. Return true iff the execution of the frame is done. bool PropagatorState::FrameState::DecrementOutstandingOps( IterationState* iter_state, TaggedNodeSeq* ready) { + return AdjustOutstandingOps(iter_state, -1, ready); +} + +bool PropagatorState::FrameState::AdjustOutstandingOps( + IterationState* iter_state, int delta, TaggedNodeSeq* ready) { + // Given the following profile of values of 'delta' for wide_deep model from + // the TF model garden: + // + // Count Value + // --------------- + // 757938 delta=0x0 + // 541713 delta=0xffffffff + // 138115 delta=0x1 + // 58770 delta=0x2 + // 5394 delta=0x3 + // 4669 delta=0x4 + // 2037 delta=0xa + // 1646 delta=0x7 + // 1632 delta=0x6 + // 1613 delta=0x6c + // 1224 delta=0x5 + // 409 delta=0x53 + // 17 delta=0x86 + // + // ... it's worth no-opping out when delta == 0 to avoid the atomic + // instruction. + if (delta == 0) { + return false; + } + { + tf_shared_lock sl(mu); + if (TF_PREDICT_TRUE(!AdjustOutstandingOpsFastPath(iter_state, delta))) { + return false; + } + } mutex_lock l(mu); - return DecrementOutstandingOpsLocked(iter_state, ready); + DCHECK(IsIterationDone(iter_state)); + return CleanupIterations(iter_state, ready); +} + +bool PropagatorState::FrameState::AdjustOutstandingOpsFastPath( + IterationState* iter_state, int delta) { + auto old_val = iter_state->outstanding_ops.fetch_add(delta); + return (old_val + delta == 0) && IsIterationDone(iter_state); } // Decrement the outstanding op count and clean up the iterations in the @@ -685,12 +772,22 @@ bool PropagatorState::FrameState::DecrementOutstandingOps( bool PropagatorState::FrameState::DecrementOutstandingOpsLocked( IterationState* iter_state, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { - iter_state->outstanding_ops--; - if (iter_state->outstanding_ops != 0) { + return AdjustOutstandingOpsLocked(iter_state, -1, ready); +} + +bool PropagatorState::FrameState::AdjustOutstandingOpsLocked( + IterationState* iter_state, int delta, TaggedNodeSeq* ready) { + // We hold the lock, so we don't need to use an atomic modification. 
+ auto cur_val = iter_state->outstanding_ops.load(std::memory_order_relaxed); + DCHECK(delta >= 0 || cur_val >= -delta) + << "cannot adjust outstanding_ops by " << delta + << " when current value is " << cur_val; + auto new_val = cur_val + delta; + iter_state->outstanding_ops.store(new_val, std::memory_order_relaxed); + if (new_val != 0) { return false; - } else { - return CleanupIterations(iter_state, ready); } + return CleanupIterations(iter_state, ready); } // Returns true if the computation in the frame is completed. diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h index 167519ccc73..4e66e709310 100644 --- a/tensorflow/core/common_runtime/propagator_state.h +++ b/tensorflow/core/common_runtime/propagator_state.h @@ -143,7 +143,7 @@ class PropagatorState { Entry* input_tensors; // The number of outstanding ops for each iteration. - size_t outstanding_ops; + std::atomic outstanding_ops; // The number of outstanding frames for each iteration. int outstanding_frame_count; @@ -170,6 +170,10 @@ class PropagatorState { bool increment_dead) { return counts.adjust_for_activation(h, increment_dead); } + PendingCounts::AdjustResult adjust_for_activation_atomic( + PendingCounts::Handle h, bool increment_dead) { + return counts.adjust_for_activation_atomic(h, increment_dead); + } ~IterationState() { delete[] input_tensors; } @@ -283,7 +287,7 @@ class PropagatorState { void InitializeFrameInfo(const ImmutableExecutorState::FrameInfo& finfo); inline IterationState* GetIteration(int64 iter) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { + TF_SHARED_LOCKS_REQUIRED(mu) { if (TF_PREDICT_TRUE(iter == 0)) { return iterations_first; } else { @@ -294,13 +298,26 @@ class PropagatorState { void SetIteration(int64 iter, IterationState* state); - // Decrement the outstanding op count and clean up the iterations in the - // frame. Return true iff the execution of the frame is done. + // Adjust the outstanding op count by 'delta' and clean up the iterations in + // the frame if no more ops are oustanding. Return true iff the execution of + // the frame is done. + // + // Avoids acquiring the lock in the common case that the frame is not done. + bool AdjustOutstandingOps(IterationState* iter_state, int delta, + TaggedNodeSeq* ready); + + bool AdjustOutstandingOpsLocked(IterationState* iter_state, int delta, + TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + bool AdjustOutstandingOpsFastPath(IterationState* iter_state, int delta) + TF_SHARED_LOCKS_REQUIRED(mu); + + // Convenience methods for the above 'Adjust' calls where delta takes the + // common value of -1. bool DecrementOutstandingOps(IterationState* iter_state, TaggedNodeSeq* ready); - // Decrement the outstanding op count and clean up the iterations in the - // frame. Return true iff the execution of the frame is done. bool DecrementOutstandingOpsLocked(IterationState* iter_state, TaggedNodeSeq* ready); @@ -309,7 +326,7 @@ class PropagatorState { // Returns true if the iteration of the frame is completed. bool IsIterationDone(IterationState* iter_state) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + TF_SHARED_LOCKS_REQUIRED(mu); // Increments the iteration id. If this is a new iteration, initialize it. // @@ -332,9 +349,23 @@ class PropagatorState { // Activate the successors of a node. Contents of *outputs are left in an // indeterminate state after returning from this method. 
- void ActivateNodes(const NodeItem* item, const bool is_dead, - IterationState* iter_state, EntryVector* outputs, - TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + // + // In the case that 'item' is a simple node (no merge/control outputs) this + // will acquire a shared lock and can run concurrently with other + // invocations. + // + // Return true if the frame is done after activation. + bool ActivateNodesAndAdjustOutstanding(const NodeItem* item, + const bool is_dead, + IterationState* iter_state, + EntryVector* outputs, + TaggedNodeSeq* ready); + + // Same as the above, but requires 'mu' already held in exclusive mode. + int ActivateNodesLocked(const NodeItem* item, const bool is_dead, + IterationState* iter_state, EntryVector* outputs, + TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); // Cleanup iterations of this frame starting from the given iteration. bool CleanupIterations(IterationState* iter_state, TaggedNodeSeq* ready) @@ -359,14 +390,35 @@ class PropagatorState { private: // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`. - void ActivateNodesFastPath(const NodeItem* item, const bool is_dead, - IterationState* iter_state, EntryVector* outputs, - TaggedNodeSeq* ready) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + // This variant does not use atomic operations to modify the pending counts + // and thus must hold the exclusive lock. + int ActivateNodesFastPathLocked(const NodeItem* item, const bool is_dead, + IterationState* iter_state, + EntryVector* outputs, TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { + return ActivateNodesFastPathInternal(item, is_dead, iter_state, + outputs, ready); + } - void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead, - IterationState* iter_state, EntryVector* outputs, - TaggedNodeSeq* ready) + // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`. + // This variant uses atomic operations to modify the pending counts. + int ActivateNodesFastPathShared(const NodeItem* item, const bool is_dead, + IterationState* iter_state, + EntryVector* outputs, TaggedNodeSeq* ready) + TF_SHARED_LOCKS_REQUIRED(mu) { + return ActivateNodesFastPathInternal(item, is_dead, iter_state, + outputs, ready); + } + + template + int ActivateNodesFastPathInternal(const NodeItem* item, const bool is_dead, + IterationState* iter_state, + EntryVector* outputs, + TaggedNodeSeq* ready); + + int ActivateNodesSlowPath(const NodeItem* item, const bool is_dead, + IterationState* iter_state, EntryVector* outputs, + TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); };