From bd530a65d5712b0734c0b6c9af5aa83ccd9e7387 Mon Sep 17 00:00:00 2001 From: Derek Murray <mrry@google.com> Date: Wed, 1 Apr 2020 14:26:25 -0700 Subject: [PATCH] [Executor] Split `ExecutorState` into `PropagatorState` and `ExecutorState<PropagatorStateType>`. This change is part of an ongoing refactoring to simplify "executor.cc" and enable the substitution of more efficient implementations of `PropagateOutputs()`. PiperOrigin-RevId: 304262448 Change-Id: I46a2d7fcdde89a71c502d272f35adfd34b0c4cab --- tensorflow/core/BUILD | 5 +- tensorflow/core/common_runtime/entry.h | 142 ++ tensorflow/core/common_runtime/executor.cc | 1475 ++--------------- .../core/common_runtime/propagator_state.cc | 777 +++++++++ .../core/common_runtime/propagator_state.h | 466 ++++++ 5 files changed, 1550 insertions(+), 1315 deletions(-) create mode 100644 tensorflow/core/common_runtime/entry.h create mode 100644 tensorflow/core/common_runtime/propagator_state.cc create mode 100644 tensorflow/core/common_runtime/propagator_state.h diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 4fd816dae4e..95a7b4d8411 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2546,6 +2546,7 @@ filegroup( "common_runtime/debugger_state_interface.h", "common_runtime/device_resolver_local.h", "common_runtime/dma_helper.h", + "common_runtime/entry.h", "common_runtime/executor.h", "common_runtime/executor_factory.h", "common_runtime/function_optimization_registry.h", @@ -2553,6 +2554,7 @@ filegroup( "common_runtime/graph_view.h", "common_runtime/immutable_executor_state.h", "common_runtime/input_colocation_exemption_registry.h", + "common_runtime/inspecting_placer.h", "common_runtime/isolate_placer_inspection_required_ops_pass.h", "common_runtime/local_device.h", "common_runtime/lower_function_call_op.h", @@ -2567,7 +2569,7 @@ filegroup( "common_runtime/partitioning_utils.h", "common_runtime/placer.h", "common_runtime/process_util.h", - "common_runtime/inspecting_placer.h", + "common_runtime/propagator_state.h", "common_runtime/profile_handler.h", "common_runtime/renamed_device.h", "common_runtime/rendezvous_mgr.h", @@ -2640,6 +2642,7 @@ tf_cuda_library( "common_runtime/process_function_library_runtime.cc", "common_runtime/process_state.cc", "common_runtime/process_util.cc", + "common_runtime/propagator_state.cc", "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", "common_runtime/rendezvous_util.cc", diff --git a/tensorflow/core/common_runtime/entry.h b/tensorflow/core/common_runtime/entry.h new file mode 100644 index 00000000000..27c1838af3f --- /dev/null +++ b/tensorflow/core/common_runtime/entry.h @@ -0,0 +1,142 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/manual_constructor.h"
+
+namespace tensorflow {
+
+class mutex;
+class Tensor;
+
+// An Entry stores a single input value for an individual kernel invocation in
+// an executor.
+//
+// Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
+struct Entry {
+  enum class State {
+    NO_VALUE = 0,      // The default state for a newly-created Entry.
+    HAS_VALUE,         // `this->val` is valid.
+    HAS_CONST_TENSOR,  // `this->const_tensor` is valid.
+    HAS_REF_TENSOR,    // `this->ref_tensor` is valid.
+  };
+
+  Entry() : state(State::NO_VALUE) {}
+  Entry(const Entry& other) : state(other.state), alloc_attr(other.alloc_attr) {
+    switch (state) {
+      case State::NO_VALUE:
+        break;
+      case State::HAS_VALUE:
+        val.Init(*other.val);
+        break;
+      case State::HAS_CONST_TENSOR:
+        const_tensor = other.const_tensor;
+        break;
+      case State::HAS_REF_TENSOR:
+        ref_tensor = other.ref_tensor;
+        break;
+    }
+  }
+
+  ~Entry() {
+    if (state == State::HAS_VALUE) val.Destroy();
+  }
+
+  Entry& operator=(const Entry& other) {
+    if (state == State::HAS_VALUE) {
+      val.Destroy();
+    }
+    state = other.state;
+    alloc_attr = other.alloc_attr;
+    switch (state) {
+      case State::NO_VALUE:
+        break;
+      case State::HAS_VALUE:
+        val.Init(*other.val);
+        break;
+      case State::HAS_CONST_TENSOR:
+        const_tensor = other.const_tensor;
+        break;
+      case State::HAS_REF_TENSOR:
+        ref_tensor = other.ref_tensor;
+        break;
+    }
+    return *this;
+  }
+
+  Entry& operator=(Entry&& other) {
+    if (state == State::HAS_VALUE) {
+      val.Destroy();
+    }
+    state = other.state;
+    alloc_attr = other.alloc_attr;
+    switch (state) {
+      case State::NO_VALUE:
+        break;
+      case State::HAS_VALUE:
+        val.Init(std::move(*other.val));
+        break;
+      case State::HAS_CONST_TENSOR:
+        const_tensor = other.const_tensor;
+        break;
+      case State::HAS_REF_TENSOR:
+        ref_tensor = other.ref_tensor;
+        break;
+    }
+    return *this;
+  }
+
+  // Clears the <val> field, and sets this entry to the `NO_VALUE` state.
+  void ClearVal() {
+    if (state == State::HAS_VALUE) {
+      val.Destroy();
+    }
+    state = State::NO_VALUE;
+  }
+
+  union {
+    // A tensor value. Valid iff `state == State::HAS_VALUE`.
+    ManualConstructor<Tensor> val;
+
+    // A pointer to a constant tensor value. Valid iff `state ==
+    // State::HAS_CONST_TENSOR`.
+    const Tensor* const_tensor;
+
+    // A tensor reference and associated mutex. Valid iff `state ==
+    // State::HAS_REF_TENSOR`.
+    struct {
+      Tensor* tensor;
+      mutex* mu;
+    } ref_tensor;
+  };
+
+  // The current state of this entry, indicating which member of the above
+  // union is active.
+  State state;
+
+  // The attributes of the allocator that creates the tensor.
+  AllocatorAttributes alloc_attr;
+};
+
+// TODO(b/152925936): Re-evaluate this constant with current usage patterns.
+typedef gtl::InlinedVector<Entry, 4> EntryVector;
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_
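The `Entry` struct above is a manually-managed tagged union, so every read must dispatch on `state` before touching the corresponding union member. As a minimal sketch of that dispatch (mirroring the logic of the executor's own `GetTensorValueForDump`, which appears later in this patch), assuming a hypothetical free function that is not part of the change:

    // Hypothetical helper: returns the tensor held by `entry`, or nullptr if
    // the entry holds no value. For HAS_REF_TENSOR, callers that read the
    // tensor's contents would additionally need to lock `entry.ref_tensor.mu`.
    const Tensor* GetTensorOrNull(const Entry& entry) {
      switch (entry.state) {
        case Entry::State::NO_VALUE:
          return nullptr;  // Nothing has been produced for this input yet.
        case Entry::State::HAS_VALUE:
          return entry.val.get();  // Owned value in the ManualConstructor.
        case Entry::State::HAS_CONST_TENSOR:
          return entry.const_tensor;  // Borrowed pointer to a constant tensor.
        case Entry::State::HAS_REF_TENSOR:
          return entry.ref_tensor.tensor;  // Borrowed reference, guarded by mu.
      }
      return nullptr;
    }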
#include "absl/memory/memory.h" #include "tensorflow/core/common_runtime/costmodel_manager.h" +#include "tensorflow/core/common_runtime/entry.h" #include "tensorflow/core/common_runtime/executor_factory.h" #include "tensorflow/core/common_runtime/graph_view.h" #include "tensorflow/core/common_runtime/immutable_executor_state.h" #include "tensorflow/core/common_runtime/pending_counts.h" +#include "tensorflow/core/common_runtime/propagator_state.h" #include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/allocator.h" @@ -112,8 +114,6 @@ void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) { } // namespace nodestats -class ExecutorImpl; - // Time the execution of kernels (in CPU cycles). Used to dynamically identify // inexpensive kernels which can be dispatched inline. struct KernelTimer { @@ -124,6 +124,7 @@ struct KernelTimer { } }; +// TODO(b/152925936): Re-evaluate these constants with current usage patterns. typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec; typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec; @@ -140,6 +141,7 @@ class ExecutorImpl : public Executor { void RunAsync(const Args& args, DoneCallback done) override; private: + template <class PropagatorStateType> friend class ExecutorState; // Stores execution time information about the kernels in an executor's graph. @@ -212,8 +214,55 @@ class ExecutorImpl : public Executor { }; // The state associated with one invocation of ExecutorImpl::Run. -// ExecutorState dispatches nodes when they become ready and keeps -// track of how many predecessors of a node have not done (pending_). +// +// ExecutorState dispatches nodes when they become ready, and delegates to an +// instance of `PropagatorStateType` to keep track of how many predecessors of a +// are still pending. +// +// The template argument `class PropagatorStateType` must define the following +// public members: +// * A type `TaggedNode`, representing a node to be processed, with public +// members: +// * `const NodeItem& get_node_item() const` +// * `bool get_is_dead() const` +// * A type `TaggedNodeReadyQueue`, representing a queue of nodes to be +// processed, with public members (having the same meanings as in an +// `std::vector<TaggedNode>`): +// * `void push_back(const TaggedNode& node)` +// * `TaggedNode front() const` +// * `void pop_front()` +// * `bool empty() const` +// * A type `TaggedNodeSeq`, representing a list of nodes to be schedules, with +// public members (having the same meanings as in an +// `std::vector<TaggedNode>`): +// * `size_t size() const` +// * `bool empty() const` +// * `void clear()` +// * `const_iterator begin() const` +// * `const_iterator end() const` +// * A public constructor, `PropagatorStateType(const ImmutableExecutorState& +// immutable_state, int64 step_id)`. 
+// * The following public methods: +// * `void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots, +// TaggedNodeSeq* ready)`, which creates `TaggedNode` instances for the +// nodes in `roots` and adds them to `*ready` +// * `void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* +// outputs, TaggedNodeSeq* ready)`, which propagates `outputs` from the +// given `tagged_node` to the destinations of its output edges, and adds +// any newly runnable nodes to `*ready` +// * `Entry* GetInputTensors(const TaggedNode& tagged_node) const`, which +// returns a pointer to the input tensors for the given `tagged_node` +// * `FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const`, +// which creates a `FrameAndIter` for the given `tagged_node` +// * `void DumpState()`, which dumps the dynamic state of the executing graph +// * `void MaybeMarkStarted(const TaggedNode& tagged_node)`, which records +// that a node has started +// * `void MaybeMarkCompleted(const TaggedNode& tagged_node)`, which records +// that a node has completed +// +// See `PropagatorState` in "./propagator_state.h" for an example of a type that +// can be used to instantiate `PropagatorStateType`. +template <class PropagatorStateType> class ExecutorState { public: ExecutorState(const Executor::Args& args, @@ -224,452 +273,58 @@ class ExecutorState { void RunAsync(Executor::DoneCallback done); private: - // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value). - struct Entry { - enum class State { - NO_VALUE = 0, // The default state for a newly-created Entry. - HAS_VALUE, // `this->val` is valid. - HAS_CONST_TENSOR, // `this->const_tensor` is valid. - HAS_REF_TENSOR, // `this->ref_tensor` is valid. - }; + // Use `TaggedNode` types defined by `PropagatorStateType`. + typedef typename PropagatorStateType::TaggedNode TaggedNode; + typedef + typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue; + typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq; - Entry() : state(State::NO_VALUE) {} - Entry(const Entry& other) - : state(other.state), alloc_attr(other.alloc_attr) { - switch (state) { - case State::NO_VALUE: - break; - case State::HAS_VALUE: - val.Init(*other.val); - break; - case State::HAS_CONST_TENSOR: - const_tensor = other.const_tensor; - break; - case State::HAS_REF_TENSOR: - ref_tensor = other.ref_tensor; - break; - } - } + struct AsyncState; - ~Entry() { - if (state == State::HAS_VALUE) val.Destroy(); - } + // Process a ready node in current thread. 
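To make the `PropagatorStateType` contract documented above concrete, here is a rough, declaration-only sketch of a conforming type for graphs without control flow. The name `SimplePropagatorState` and all members are hypothetical; the real implementation is the `PropagatorState` class added in propagator_state.h:

    // Hypothetical sketch of a type satisfying the documented contract.
    class SimplePropagatorState {
     public:
      SimplePropagatorState(const ImmutableExecutorState& immutable_state,
                            int64 step_id);

      // A tagged node only needs to expose its NodeItem and dead-ness.
      struct TaggedNode {
        const NodeItem* node_item;
        bool is_dead;
        const NodeItem& get_node_item() const { return *node_item; }
        bool get_is_dead() const { return is_dead; }
      };

      // Standard containers already provide the required queue and sequence
      // operations, so a sketch can simply reuse them.
      typedef std::deque<TaggedNode> TaggedNodeReadyQueue;
      typedef std::vector<TaggedNode> TaggedNodeSeq;

      void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
                         TaggedNodeSeq* ready);
      void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
                            TaggedNodeSeq* ready);
      Entry* GetInputTensors(const TaggedNode& tagged_node) const;
      FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const;
      void DumpState();
      void MaybeMarkStarted(const TaggedNode& tagged_node) {}
      void MaybeMarkCompleted(const TaggedNode& tagged_node) {}
    };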
+ void Process(TaggedNode node, int64 scheduled_nsec); - Entry& operator=(const Entry& other) { - if (state == State::HAS_VALUE) { - val.Destroy(); - } - state = other.state; - alloc_attr = other.alloc_attr; - switch (state) { - case State::NO_VALUE: - break; - case State::HAS_VALUE: - val.Init(*other.val); - break; - case State::HAS_CONST_TENSOR: - const_tensor = other.const_tensor; - break; - case State::HAS_REF_TENSOR: - ref_tensor = other.ref_tensor; - break; - } - return *this; - } + Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params, + EntryVector* outputs, NodeExecStatsInterface* stats); + void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params, + const TaggedNode& tagged_node, Entry* first_input, + NodeExecStatsInterface* stats); + void ProcessNoop(NodeExecStatsInterface* stats); + void ProcessConstTensor(const NodeItem& item, EntryVector* outputs, + NodeExecStatsInterface* stats); - Entry& operator=(Entry&& other) { - if (state == State::HAS_VALUE) { - val.Destroy(); - } - state = other.state; - alloc_attr = other.alloc_attr; - switch (state) { - case State::NO_VALUE: - break; - case State::HAS_VALUE: - val.Init(std::move(*other.val)); - break; - case State::HAS_CONST_TENSOR: - const_tensor = other.const_tensor; - break; - case State::HAS_REF_TENSOR: - ref_tensor = other.ref_tensor; - break; - } - return *this; - } + // Before invoking item->kernel, fills in its "inputs". + Status PrepareInputs(const NodeItem& item, Entry* first_input, + TensorValueVec* inputs, + AllocatorAttributeVec* input_alloc_attrs, + bool* is_input_dead); - // Clears the <val> field, and sets this entry to the `NO_VALUE` state. - void ClearVal() { - if (state == State::HAS_VALUE) { - val.Destroy(); - } - state = State::NO_VALUE; - } + // After item->kernel computation is done, processes its outputs. + Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, + EntryVector* outputs, NodeExecStatsInterface* stats); - union { - // A tensor value. Valid iff `state_ == HAS_VALUE`. - ManualConstructor<Tensor> val; + // Called after each node finishes. Takes ownership of "stats". Returns true + // if execution has completed. + // + // This method will clear `*ready` before returning. + bool NodeDone(const Status& s, TaggedNodeSeq* ready, + NodeExecStatsInterface* stats, + TaggedNodeReadyQueue* inline_ready); - // A pointer to a constant tensor value. Valid iff `state_ == - // HAS_CONST_TENSOR`. - const Tensor* const_tensor; + // Schedule all the expensive nodes in '*ready', and put all the inexpensive + // nodes in 'ready' into 'inline_ready'. + // + // This method will clear `*ready` before returning. + void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready); - // A tensor reference and associated mutex. Valid iff `state_ == - // HAS_REF_TENSOR`. - struct { - Tensor* tensor; - mutex* mu; - } ref_tensor; - }; - - // The current state of this entry, indicating which member of the above - // union is active. - State state; - - // The attributes of the allocator that creates the tensor. - AllocatorAttributes alloc_attr; - }; + // Clean up when this executor is done. + void Finish(); + void ScheduleFinish(); // Contains the device context assigned by the device at the beginning of a // step. 
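The `ScheduleReady` contract declared above (expensive nodes go to the thread pool, inexpensive ones are queued for inline processing) can be illustrated with a small standalone sketch. The function name and functor parameters are hypothetical stand-ins for `kernel_stats_->IsExpensive()` and `runner_`:

    // Standalone sketch of ScheduleReady's partitioning step. Dead nodes are
    // always treated as inexpensive, since they do no real work.
    template <typename TaggedNodeSeq, typename TaggedNodeReadyQueue,
              typename IsExpensiveFn, typename DispatchFn>
    void PartitionReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready,
                        IsExpensiveFn is_expensive, DispatchFn dispatch) {
      for (auto& tagged_node : *ready) {
        if (tagged_node.get_is_dead() || !is_expensive(tagged_node)) {
          inline_ready->push_back(tagged_node);  // Process on this thread.
        } else {
          dispatch(tagged_node);  // Hand off to the thread pool.
        }
      }
      ready->clear();  // ScheduleReady clears `*ready` before returning.
    }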
DeviceContext* device_context_ = nullptr; - struct TaggedNode; - typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq; - typedef gtl::InlinedVector<Entry, 4> EntryVector; - - struct IterationState { - explicit IterationState(const PendingCounts* pending_counts, - int total_input_tensors) - : input_tensors(new Entry[total_input_tensors]), - outstanding_ops(0), - outstanding_frame_count(0), - counts(*pending_counts) { // Initialize with copy of *pending_counts - } - - // The state of an iteration. - - // One copy per iteration. For iteration k, i-th node's j-th input is in - // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is - // either a tensor pointer (pass-by-reference) or a tensor (pass-by-value). - // - // NOTE: No need to protect input_tensors[i] by any locks because it - // is resized once. Each element of tensors_ is written once by the - // source node of an edge and is cleared by the destination of the same - // edge. The latter node is never run concurrently with the former node. - Entry* input_tensors; - - // The number of outstanding ops for each iteration. - size_t outstanding_ops; - - // The number of outstanding frames for each iteration. - int outstanding_frame_count; - int pending(PendingCounts::Handle h) { return counts.pending(h); } - int decrement_pending(PendingCounts::Handle h, int v) { - return counts.decrement_pending(h, v); - } - // Mark a merge node as live - // REQUIRES: Node corresponding to "h" is a merge node - void mark_live(PendingCounts::Handle h) { counts.mark_live(h); } - // Mark a node to show that processing has started. - void mark_started(PendingCounts::Handle h) { counts.mark_started(h); } - // Mark a node to show that processing has completed. - void mark_completed(PendingCounts::Handle h) { counts.mark_completed(h); } - PendingCounts::NodeState node_state(PendingCounts::Handle h) { - return counts.node_state(h); - } - - int dead_count(PendingCounts::Handle h) { return counts.dead_count(h); } - void increment_dead_count(PendingCounts::Handle h) { - counts.increment_dead_count(h); - } - PendingCounts::AdjustResult adjust_for_activation(PendingCounts::Handle h, - bool increment_dead) { - return counts.adjust_for_activation(h, increment_dead); - } - - ~IterationState() { delete[] input_tensors; } - - private: - PendingCounts counts; - }; - - struct FrameState { - explicit FrameState(const ImmutableExecutorState& immutable_state, - int parallel_iters) - : immutable_state(immutable_state), - max_parallel_iterations(parallel_iters), - num_outstanding_iterations(1), - iterations(parallel_iters + 1), - iterations_raw(iterations.data()) {} - - // A new frame is created for each loop. Execution starts at iteration 0. - // When a value at iteration 0 passes through a NextIteration node, - // iteration 1 is created and starts running. Note that iteration 0 may - // still be running so multiple iterations may run in parallel. The - // frame maintains the state of iterations in several data structures - // such as pending_count and input_tensors. When iteration 0 completes, - // we garbage collect the state of iteration 0. - // - // A frame instance is considered "done" and can be garbage collected - // if all its inputs have entered and all its iterations are "done". - // - // A frame manages the live iterations of an iterative computation. - // Iteration i is considered "done" when there are no outstanding ops, - // frames at iteration i are done, all recvs for this iteration are - // completed, and iteration i-1 is done. 
For iteration 0, we instead - // wait for there to be no more pending inputs of the frame. - // - // Frames and iterations are garbage collected once they are done. - // The state we need to keep around is highly dependent on the - // parallelism enabled by the scheduler. We may want to have the - // scheduler dynamically control the outstanding number of live - // parallel frames and iterations. To reduce the state space, the - // scheduler might want to schedule ops in inner frames first and - // lower iterations first. - // - // This frame state is mostly initialized lazily on demand so we - // don't introduce unnecessary overhead. - - // The immutable state of the executor the frame is in. - const ImmutableExecutorState& immutable_state; - - // The name of this frame, which is the concatenation of its parent - // frame name, the iteration of the parent frame when this frame was - // created, and the value of the attr 'frame_name'. - string frame_name; - - // The unique id for this frame. Generated by fingerprinting - // frame_name. - uint64 frame_id; - - // The iteration id of its parent frame when this frame is created. - // -1 if there is no parent frame. The frame_name/parent_iter pair - // uniquely identifies this FrameState. - int64 parent_iter = -1; - - // The FrameState of its parent frame. - FrameState* parent_frame = nullptr; - - // The maximum allowed number of parallel iterations. - const int max_parallel_iterations; - - // The number of inputs this frame is still waiting. - int num_pending_inputs = 0; - - // The highest iteration number we have reached so far in this frame. - int64 iteration_count TF_GUARDED_BY(mu) = 0; - - // The number of outstanding iterations. - int num_outstanding_iterations TF_GUARDED_BY(mu) = 1; - - private: - // The active iteration states of this frame. - gtl::InlinedVector<IterationState*, 12> iterations; - IterationState** const iterations_raw TF_GUARDED_BY(mu); - IterationState* iterations_first TF_GUARDED_BY(mu); - - public: - // The NextIteration nodes to enter a new iteration. If the number of - // outstanding iterations reaches the limit, we will defer the start of - // the next iteration until the number of outstanding iterations falls - // below the limit. - std::vector<std::pair<const NodeItem*, Entry>> next_iter_roots - TF_GUARDED_BY(mu); - - // The values of the loop invariants for this loop. They are added into - // this list as they "enter" the frame. When a loop invariant enters, - // we make it available to all active iterations. When the frame starts - // a new iteration, we make all the current loop invariants available - // to the new iteration. - std::vector<std::pair<const NodeItem*, Entry>> inv_values TF_GUARDED_BY(mu); - - // The list of dead exit node items for the current highest iteration. We - // will only "execute" the dead exits of the final iteration. - std::vector<const NodeItem*> dead_exits TF_GUARDED_BY(mu); - - // Static information specific to this frame. - PendingCounts* pending_counts = nullptr; - int total_input_tensors = 0; - std::vector<const NodeItem*>* nodes = nullptr; - - // Lock ordering: ExecutorState.mu_ < mu; - // during structured traversal: parent_frame->mu < mu. 
- mutex mu; - - void InitializeFrameInfo(const string& enter_name) { - const ImmutableExecutorState::FrameInfo* finfo = - immutable_state.get_frame_info(enter_name); - DCHECK_NE(finfo, nullptr); - pending_counts = finfo->pending_counts.get(); - total_input_tensors = finfo->total_inputs; - num_pending_inputs = finfo->input_count; - nodes = finfo->nodes.get(); - } - - inline IterationState* GetIteration(int64 iter) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { - if (TF_PREDICT_TRUE(iter == 0)) { - return iterations_first; - } else { - size_t index = iter % (max_parallel_iterations + 1); - return iterations_raw[index]; - } - } - - inline void SetIteration(int64 iter, IterationState* state) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { - size_t index = iter % (max_parallel_iterations + 1); - DCHECK(state == nullptr || iterations[index] == nullptr); - iterations_raw[index] = state; - if (index == 0) { - iterations_first = state; - } - } - - // Decrement the outstanding op count and clean up the iterations in the - // frame. Return true iff the execution of the frame is done. - inline bool DecrementOutstandingOps(const GraphView* gview, int64 iter, - TaggedNodeSeq* ready) { - mutex_lock l(mu); - return DecrementOutstandingOpsLocked(gview, iter, ready); - } - - // Decrement the outstanding op count and clean up the iterations in the - // frame. Return true iff the execution of the frame is done. - inline bool DecrementOutstandingOpsLocked(const GraphView* gview, - int64 iter, TaggedNodeSeq* ready) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { - IterationState* istate = GetIteration(iter); - istate->outstanding_ops--; - if (istate->outstanding_ops != 0) { - return false; - } else { - return CleanupIterations(gview, iter, ready); - } - } - - // Returns true if the computation in the frame is completed. - inline bool IsFrameDone() TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { - return (num_pending_inputs == 0 && num_outstanding_iterations == 0); - } - - // Returns true if the iteration of the frame is completed. - bool IsIterationDone(int64 iter) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); - - // Increments the iteration id. If this is a new iteration, initialize it. - void IncrementIteration(const GraphView* gview, TaggedNodeSeq* ready) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu); - - // Activate all the deferred NextIteration nodes in a new iteration. - void ActivateNexts(const GraphView* gview, int64 iter, TaggedNodeSeq* ready) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu); - - // Activate all the current loop invariants in a new iteration. - void ActivateLoopInvs(const GraphView* gview, int64 iter, - TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); - - // Add a new loop invariant and make it available to all active - // iterations. - void AddLoopInv(const NodeItem* item, const Entry& entry, - TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); - - // Activate the successors of a node. Contents of *outputs are left in an - // indeterminate state after returning from this method. - void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter, - EntryVector* outputs, TaggedNodeSeq* ready) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu); - - // Cleanup iterations of this frame starting from iteration iter. 
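The `GetIteration`/`SetIteration` pair above treats the `iterations` vector as a ring buffer with `max_parallel_iterations + 1` slots: iterations `k` and `k + max_parallel_iterations + 1` share a slot, which is safe because iteration `k` must be garbage-collected before an iteration that far ahead can start. A standalone sketch of just the indexing, with hypothetical names:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Ring buffer of per-iteration state pointers, as in FrameState.
    template <typename IterationStatePtr>
    class IterationRing {
     public:
      explicit IterationRing(int max_parallel_iterations)
          : slots_(max_parallel_iterations + 1, nullptr) {}

      IterationStatePtr Get(int64_t iter) const {
        return slots_[iter % slots_.size()];
      }

      void Set(int64_t iter, IterationStatePtr state) {
        std::size_t index = iter % slots_.size();
        // A slot may only go from null to non-null or back, never be clobbered.
        assert(state == nullptr || slots_[index] == nullptr);
        slots_[index] = state;
      }

     private:
      std::vector<IterationStatePtr> slots_;
    };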
- bool CleanupIterations(const GraphView* gview, int64 iter, - TaggedNodeSeq* ready) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu); - - void DumpIterationState(ExecutorState* parent) { - mutex_lock l(mu); - for (IterationState* iteration : iterations) { - if (iteration) { - LOG(WARNING) << " Iteration:"; - parent->DumpIterationState(this, iteration); - } - } - } - - ~FrameState() { - for (size_t i = 0; i < iterations.size(); ++i) { - delete iterations[i]; - iterations[i] = nullptr; - } - } - - private: - // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`. - void ActivateNodesFastPath(const NodeItem* item, const bool is_dead, - int64 iter, EntryVector* outputs, - TaggedNodeSeq* ready) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu); - - void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead, - int64 iter, EntryVector* outputs, - TaggedNodeSeq* ready) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu); - }; - - // A tagged node: <frame*, iter, node*>. - struct TaggedNode { - const NodeItem* node_item; - FrameState* input_frame; // = nullptr; - int64 input_iter; // = -1; - bool is_dead; // = false; - - TaggedNode() {} - - TaggedNode(const NodeItem* node_item, FrameState* in_frame, int64 in_iter, - bool dead) - : node_item(node_item), - input_frame(in_frame), - input_iter(in_iter), - is_dead(dead) {} - }; - - // A drop-in replacement for std::deque<TaggedNode>. We typically don't - // have that many nodes in the ready queue, so we just use a vector and - // don't free up memory from the queue as we consume nodes. - class TaggedNodeReadyQueue { - public: - TaggedNodeReadyQueue() : front_index_(0) {} - - void push_back(const TaggedNode& node) { ready_.push_back(node); } - TaggedNode front() const { - DCHECK_LT(front_index_, ready_.size()); - return ready_[front_index_]; - } - void pop_front() { - DCHECK_LT(front_index_, ready_.size()); - front_index_++; - if ((front_index_ == ready_.size()) || (front_index_ > 16384)) { - if (front_index_ == ready_.size()) { - ready_.clear(); - } else { - // Lots of unused entries at beginning of vector: move everything - // down to start of vector. - ready_.erase(ready_.begin(), ready_.begin() + front_index_); - } - front_index_ = 0; - } - } - bool empty() const { return ready_.empty(); } - const TaggedNode* begin() const { return ready_.begin() + front_index_; } - const TaggedNode* end() const { return ready_.end(); } - - private: - gtl::InlinedVector<TaggedNode, 16> ready_; - int front_index_; - }; - - struct AsyncState; - const bool vlog_; // true if VLOG_IS_ON(1). Used to check vlog cheaply. // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply. @@ -702,14 +357,7 @@ class ExecutorState { bool sync_on_finish_; const bool run_all_kernels_inline_; - // Owned. - - // A flag that is set on error after the frame state has been - // dumped for diagnostic purposes. - bool dumped_on_error_ = false; - - // The root frame in which the execution of this step is started. - FrameState* root_frame_; + PropagatorStateType propagator_; // Invoked when the execution finishes. Executor::DoneCallback done_cb_; @@ -724,110 +372,12 @@ class ExecutorState { mutex mu_; Status status_ TF_GUARDED_BY(mu_); - - // Mapping from frame name to outstanding frames. A new frame is created - // at some iteration of an active frame. So the unique key for the new - // child frame is composed of the name of the parent frame, the iteration - // number at which the parent frame is creating the new frame, and the - // name of the new frame from nodedef. 
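The `TaggedNodeReadyQueue` deleted above (and now expected from `PropagatorStateType`) replaces `std::deque` with a vector plus a front index: `pop_front` merely advances the index, and storage is reclaimed only when the queue drains or the consumed prefix passes 16384 entries. A generic standalone sketch of that policy:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Standalone sketch of the vector-plus-front-index queue. pop_front is
    // O(1) amortized: consumed slots are left in place until the queue is
    // fully drained (clear everything) or the consumed prefix grows past a
    // threshold (shift the live suffix down to the front).
    template <typename T>
    class CompactingQueue {
     public:
      void push_back(const T& v) { items_.push_back(v); }
      bool empty() const { return front_index_ == items_.size(); }

      T front() const {
        assert(!empty());
        return items_[front_index_];
      }

      void pop_front() {
        assert(!empty());
        ++front_index_;
        if (front_index_ == items_.size()) {
          items_.clear();  // Fully drained: drop all consumed slots at once.
          front_index_ = 0;
        } else if (front_index_ > kMaxConsumedPrefix) {
          items_.erase(items_.begin(), items_.begin() + front_index_);
          front_index_ = 0;  // Long consumed prefix: compact the live suffix.
        }
      }

     private:
      static constexpr std::size_t kMaxConsumedPrefix = 16384;
      std::vector<T> items_;
      std::size_t front_index_ = 0;
    };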
- gtl::FlatMap<string, FrameState*> outstanding_frames_ TF_GUARDED_BY(mu_); - - // The unique name of a frame. - inline string MakeFrameName(FrameState* frame, int64 iter_id, - const string& name) { - return strings::StrCat(frame->frame_name, ";", iter_id, ";", name); - } - - // Find an existing or create a new child frame in the frame 'frame' at - // iteration 'iter'. - void FindOrCreateChildFrame(FrameState* frame, int64 iter, - const NodeItem& node_item, FrameState** child); - - // Delete a frame. Called when the frame is done. - void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready); - - // Cleanup frames and iterations starting from frame/iter. Called when - // a child frame is done. - void CleanupFramesIterations(FrameState* frame, int64 iter, - TaggedNodeSeq* ready); - - // Process a ready node in current thread. - void Process(TaggedNode node, int64 scheduled_nsec); - - Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params, - EntryVector* outputs, - NodeExecStatsInterface* stats); - void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params, - const TaggedNode& tagged_node, Entry* first_input, - NodeExecStatsInterface* stats); - void ProcessNoop(NodeExecStatsInterface* stats); - void ProcessConstTensor(const NodeItem& item, EntryVector* outputs, - NodeExecStatsInterface* stats); - - // Before invoking item->kernel, fills in its "inputs". - Status PrepareInputs(const NodeItem& item, Entry* first_input, - TensorValueVec* inputs, - AllocatorAttributeVec* input_alloc_attrs, - bool* is_input_dead); - - // After item->kernel computation is done, processes its outputs. - Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, - EntryVector* outputs, NodeExecStatsInterface* stats); - - // After processing the outputs, propagates the outputs to their dsts. - // Contents of *outputs are left in an indeterminate state after - // returning from this method. - void PropagateOutputs(const TaggedNode& tagged_node, const NodeItem* item, - EntryVector* outputs, TaggedNodeSeq* ready); - - // Called after each node finishes. Takes ownership of "stats". Returns true - // if execution has completed. - // - // This method will clear `*ready` before returning. - bool NodeDone(const Status& s, TaggedNodeSeq* ready, - NodeExecStatsInterface* stats, - TaggedNodeReadyQueue* inline_ready); - - // Schedule all the expensive nodes in '*ready', and put all the inexpensive - // nodes in 'ready' into 'inline_ready'. - // - // This method will clear `*ready` before returning. - void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready); - - // For debugging/logging only. - inline void MaybeMarkCompleted(FrameState* frame, int64 iter, - const int node_id); - - // Provide debugging output about an outstanding node in the executor. - void DumpPendingNodeState(const int node_id, const Entry* input_vector, - bool show_nodes_with_no_ready_inputs); - void DumpActiveNodeState(const int node_id, const Entry* input_vector); - - // Provide debugging output about an outstanding iteration in the executor. - void DumpIterationState(const FrameState* frame, IterationState* iteration); - - // Provide debugging output of the state of the executor. - void DumpState(); - const Tensor* GetTensorValueForDump(const Entry& input); - - // Clean up when this executor is done. 
- void Finish(); - void ScheduleFinish(); - - // A standalone routine for this expression so that we can express - // that we don't want thread safety analysis on this reference (it's - // safe to do without the lock because the iterations array never - // resizes and this particular iteration's array element will not - // be changed out from under us because the iteration is still alive). - Entry* GetInputTensors(FrameState* input_frame, - int64 input_iter) const TF_NO_THREAD_SAFETY_ANALYSIS { - return input_frame->GetIteration(input_iter)->input_tensors; - } }; -ExecutorState::ExecutorState(const Executor::Args& args, - const ImmutableExecutorState& immutable_state, - ExecutorImpl::KernelStats* kernel_stats) +template <class PropagatorStateType> +ExecutorState<PropagatorStateType>::ExecutorState( + const Executor::Args& args, const ImmutableExecutorState& immutable_state, + ExecutorImpl::KernelStats* kernel_stats) : vlog_(VLOG_IS_ON(1)), log_memory_(LogMemory::IsEnabled()), step_id_(args.step_id), @@ -850,39 +400,25 @@ ExecutorState::ExecutorState(const Executor::Args& args, runner_(args.runner), sync_on_finish_(args.sync_on_finish), run_all_kernels_inline_(args.run_all_kernels_inline), + propagator_(immutable_state, step_id_), num_outstanding_ops_(0) { if (args.user_intra_op_threadpool != nullptr) { Device* device = immutable_state_.params().device; user_device_ = RenamedDevice::NewRenamedDevice( device->name(), device, false, false, args.user_intra_op_threadpool); } - - // We start the entire execution in iteration 0 of the root frame - // so let us create the root frame and the state for iteration 0. - // We assume root_frame_->frame_name.empty(). - root_frame_ = new FrameState(immutable_state_, 1); - root_frame_->frame_id = 0; // must be 0 - root_frame_->InitializeFrameInfo(root_frame_->frame_name); - - // Initialize iteration 0. - root_frame_->SetIteration( - 0, new IterationState(root_frame_->pending_counts, - root_frame_->total_input_tensors)); - - outstanding_frames_.insert({root_frame_->frame_name, root_frame_}); } -ExecutorState::~ExecutorState() { - for (auto name_frame : outstanding_frames_) { - delete name_frame.second; - } +template <class PropagatorStateType> +ExecutorState<PropagatorStateType>::~ExecutorState() { if (device_context_) { device_context_->Unref(); } delete slice_reader_cache_; } -void ExecutorState::RunAsync(Executor::DoneCallback done) { +template <class PropagatorStateType> +void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) { TaggedNodeSeq ready; // Ask the device to fill in the device context map. @@ -897,19 +433,12 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) { // Initialize the ready queue. ready.reserve(immutable_state_.root_nodes().size()); - for (const NodeItem* item : immutable_state_.root_nodes()) { - DCHECK_EQ(item->num_inputs, 0); - ready.push_back(TaggedNode{item, root_frame_, 0, false}); - } + propagator_.ActivateRoots(immutable_state_.root_nodes(), &ready); + num_outstanding_ops_ = ready.size(); if (ready.empty()) { delete this; done(Status::OK()); } else { - num_outstanding_ops_ = ready.size(); - { - mutex_lock l(root_frame_->mu); - root_frame_->GetIteration(0)->outstanding_ops = ready.size(); - } done_cb_ = std::move(done); // Schedule to run all the ready ops in thread pool. 
ScheduleReady(&ready, nullptr); @@ -921,7 +450,8 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) { // asynchronous kernels because OpKernelContext methods like input_type(i) needs // the param points to valid input type vector. It's not an issue for // sync kernels because these vectors are kept on the stack. -struct ExecutorState::AsyncState { +template <class PropagatorStateType> +struct ExecutorState<PropagatorStateType>::AsyncState { AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node, const NodeItem* _item, Entry* _first_input, NodeExecStatsInterface* _stats) @@ -978,10 +508,10 @@ bool MightTrace(const tracing::EventCollector* event_collector, return profiler::TraceMe::Active(profiler::GetTFTraceMeLevel(is_expensive)); } -Status ExecutorState::ProcessSync(const NodeItem& item, - OpKernelContext::Params* params, - EntryVector* outputs, - NodeExecStatsInterface* stats) { +template <class PropagatorStateType> +Status ExecutorState<PropagatorStateType>::ProcessSync( + const NodeItem& item, OpKernelContext::Params* params, EntryVector* outputs, + NodeExecStatsInterface* stats) { Status s; OpKernelContext ctx(params, item.num_outputs); nodestats::SetOpStart(stats); @@ -1018,11 +548,11 @@ Status ExecutorState::ProcessSync(const NodeItem& item, return s; } -void ExecutorState::ProcessAsync(const NodeItem& item, - const OpKernelContext::Params& params, - const TaggedNode& tagged_node, - Entry* first_input, - NodeExecStatsInterface* stats) { +template <class PropagatorStateType> +void ExecutorState<PropagatorStateType>::ProcessAsync( + const NodeItem& item, const OpKernelContext::Params& params, + const TaggedNode& tagged_node, Entry* first_input, + NodeExecStatsInterface* stats) { AsyncOpKernel* async_kernel = item.kernel->AsAsync(); DCHECK(async_kernel != nullptr); AsyncState* state = @@ -1040,7 +570,7 @@ void ExecutorState::ProcessAsync(const NodeItem& item, if (vlog_) { VLOG(2) << "Async kernel done: " << state->item->node_id << " step " << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def()) - << (state->tagged_node.is_dead ? " is dead" : "") + << (state->tagged_node.get_is_dead() ? 
" is dead" : "") << " device: " << device->name(); } @@ -1049,12 +579,10 @@ void ExecutorState::ProcessAsync(const NodeItem& item, for (int i = 0; i < num_inputs; ++i) { (first_input + i)->ClearVal(); } - FrameState* input_frame = state->tagged_node.input_frame; - const int64 input_iter = state->tagged_node.input_iter; - MaybeMarkCompleted(input_frame, input_iter, state->item->node_id); + propagator_.MaybeMarkCompleted(state->tagged_node); TaggedNodeSeq ready; if (s.ok()) { - PropagateOutputs(state->tagged_node, state->item, &outputs, &ready); + propagator_.PropagateOutputs(state->tagged_node, &outputs, &ready); } outputs.clear(); const bool completed = NodeDone(s, &ready, stats, nullptr); @@ -1074,14 +602,16 @@ void ExecutorState::ProcessAsync(const NodeItem& item, } } -void ExecutorState::ProcessNoop(NodeExecStatsInterface* stats) { +template <class PropagatorStateType> +void ExecutorState<PropagatorStateType>::ProcessNoop( + NodeExecStatsInterface* stats) { nodestats::SetOpStart(stats); nodestats::SetOpEnd(stats); } -void ExecutorState::ProcessConstTensor(const NodeItem& item, - EntryVector* outputs, - NodeExecStatsInterface* stats) { +template <class PropagatorStateType> +void ExecutorState<PropagatorStateType>::ProcessConstTensor( + const NodeItem& item, EntryVector* outputs, NodeExecStatsInterface* stats) { nodestats::SetOpStart(stats); nodestats::SetOpEnd(stats); outputs->resize(1); @@ -1091,12 +621,11 @@ void ExecutorState::ProcessConstTensor(const NodeItem& item, output.alloc_attr = item.output_attrs()[0]; } -void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { +template <class PropagatorStateType> +void ExecutorState<PropagatorStateType>::Process(TaggedNode tagged_node, + int64 scheduled_nsec) { profiler::TraceMe activity( - [&] { - return absl::StrCat("ExecutorState::Process#id=", step_id_, - ",iter_num=", tagged_node.input_iter, "#"); - }, + [&] { return absl::StrCat("ExecutorState::Process#id=", step_id_, "#"); }, 2); WithContext wc(context_); TaggedNodeSeq ready; @@ -1164,22 +693,14 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { while (!inline_ready.empty()) { tagged_node = inline_ready.front(); inline_ready.pop_front(); - const NodeItem& item = *tagged_node.node_item; - FrameState* input_frame = tagged_node.input_frame; - const int64 input_iter = tagged_node.input_iter; + const NodeItem& item = tagged_node.get_node_item(); const int id = item.node_id; - // TODO(misard) Replace with a finer-grain enabling flag once we - // add better optional debugging support. - if (vlog_ && VLOG_IS_ON(1)) { - mutex_lock l(input_frame->mu); - input_frame->GetIteration(input_iter) - ->mark_started(immutable_state_.pending_ids()[id]); - } + propagator_.MaybeMarkStarted(tagged_node); params.track_allocations = false; stats = nullptr; - if (stats_collector_ && !tagged_node.is_dead) { + if (stats_collector_ && !tagged_node.get_is_dead()) { stats = stats_collector_->CreateNodeExecStats(&item.kernel->def()); // Track allocations if and only if we are collecting statistics, and // `stats` object is expecting allocations to be tracked. @@ -1191,19 +712,18 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { if (vlog_) { VLOG(1) << "Process node: " << id << " step " << params.step_id << " " << SummarizeNodeDef(item.kernel->def()) - << (tagged_node.is_dead ? " is dead" : "") + << (tagged_node.get_is_dead() ? 
" is dead" : "") << " device: " << device->name(); } - Entry* input_tensors = GetInputTensors(input_frame, input_iter); - Entry* first_input = input_tensors + item.input_start; + Entry* first_input = propagator_.GetInputTensors(tagged_node); outputs.clear(); // Only execute this node if it is not dead or it is a send/recv // transfer node. For transfer nodes, we need to propagate the "dead" // bit even when the node is dead. bool launched_asynchronously = false; - if (tagged_node.is_dead && !item.is_transfer_node) { + if (tagged_node.get_is_dead() && !item.is_transfer_node) { outputs.resize(item.num_outputs); } else if (TF_PREDICT_FALSE(item.is_noop)) { ProcessNoop(stats); @@ -1220,7 +740,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { for (int i = 0; i < num_inputs; ++i) { (first_input + i)->ClearVal(); } - MaybeMarkCompleted(input_frame, input_iter, id); + propagator_.MaybeMarkCompleted(tagged_node); // Continue to process the nodes in 'inline_ready'. completed = NodeDone(s, &ready, stats, &inline_ready); continue; @@ -1228,7 +748,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { // Set up compute params. params.op_kernel = item.kernel; - params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter); + params.frame_iter = propagator_.GetFrameAndIter(tagged_node); params.is_input_dead = is_input_dead; params.output_attr_array = item.output_attrs(); params.forward_from_array = item.forward_from(); @@ -1246,7 +766,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { if (vlog_) { VLOG(2) << "Synchronous kernel done: " << id << " step " << params.step_id << " " << SummarizeNodeDef(item.kernel->def()) - << (tagged_node.is_dead ? " is dead: " : "") + << (tagged_node.get_is_dead() ? " is dead: " : "") << " device: " << device->name(); } @@ -1255,10 +775,10 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { for (int i = 0; i < num_inputs; ++i) { (first_input + i)->ClearVal(); } - MaybeMarkCompleted(input_frame, input_iter, id); + propagator_.MaybeMarkCompleted(tagged_node); // Propagates outputs. 
if (s.ok()) { - PropagateOutputs(tagged_node, &item, &outputs, &ready); + propagator_.PropagateOutputs(tagged_node, &outputs, &ready); } outputs.clear(); if (stats) { @@ -1273,10 +793,10 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { if (completed) ScheduleFinish(); } -Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, - TensorValueVec* inputs, - AllocatorAttributeVec* input_alloc_attrs, - bool* is_input_dead) { +template <class PropagatorStateType> +Status ExecutorState<PropagatorStateType>::PrepareInputs( + const NodeItem& item, Entry* first_input, TensorValueVec* inputs, + AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) { inputs->clear(); inputs->resize(item.num_inputs); input_alloc_attrs->clear(); @@ -1384,9 +904,10 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, return Status::OK(); } -Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, - EntryVector* outputs, - NodeExecStatsInterface* stats) { +template <class PropagatorStateType> +Status ExecutorState<PropagatorStateType>::ProcessOutputs( + const NodeItem& item, OpKernelContext* ctx, EntryVector* outputs, + NodeExecStatsInterface* stats) { DCHECK_EQ(0, outputs->size()); outputs->resize(item.num_outputs); @@ -1397,7 +918,7 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, // add better optional debugging support. if (vlog_ && VLOG_IS_ON(1)) { LOG(WARNING) << this << " Compute status: " << s; - DumpState(); + propagator_.DumpState(); } if (s.code() == error::RESOURCE_EXHAUSTED) { if (stats_collector_) { @@ -1481,120 +1002,10 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, return s; } -void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, - const NodeItem* item, EntryVector* outputs, - TaggedNodeSeq* ready) { - profiler::TraceMe activity( - [&]() { - return strings::StrCat( - "ExecutorPropagateOutputs#", "id=", step_id_, - ",kernel_name=", item->kernel->name_view(), - ",num_output_edges=", item->num_output_edges, - ",num_output_control_edges=", item->num_output_control_edges, "#"); - }, - profiler::GetTFTraceMeLevel(/*is_expensive=*/false)); - - FrameState* input_frame = tagged_node.input_frame; - const int64 input_iter = tagged_node.input_iter; - const bool is_dead = tagged_node.is_dead; - - // Propagates outputs along out edges, and puts newly ready nodes - // into the ready queue. - DCHECK(ready->empty()); - bool is_frame_done = false; - FrameState* output_frame = input_frame; - int64 output_iter = input_iter; - - if (!item->is_enter_exit_or_next_iter) { - // Fast path for nodes types that don't need special handling - DCHECK_EQ(input_frame, output_frame); - // Normal path for most nodes - mutex_lock l(input_frame->mu); - output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); - is_frame_done = input_frame->DecrementOutstandingOpsLocked( - &immutable_state_.graph_view(), input_iter, ready); - } else if (item->is_enter) { - FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame); - output_iter = 0; - { - mutex_lock l(output_frame->mu); - if (item->is_constant_enter) { - // Propagate to all active iterations if this is a loop invariant. 
- output_frame->AddLoopInv(item, (*outputs)[0], ready); - } else { - output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); - } - output_frame->num_pending_inputs--; - } - is_frame_done = input_frame->DecrementOutstandingOps( - &immutable_state_.graph_view(), input_iter, ready); - } else if (item->is_exit) { - if (is_dead) { - mutex_lock l(input_frame->mu); - // Stop and remember this node if it is a dead exit. - if (input_iter == input_frame->iteration_count) { - input_frame->dead_exits.push_back(item); - } - is_frame_done = input_frame->DecrementOutstandingOpsLocked( - &immutable_state_.graph_view(), input_iter, ready); - } else { - output_frame = input_frame->parent_frame; - output_iter = input_frame->parent_iter; - { - mutex_lock l(output_frame->mu); - output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); - } - is_frame_done = input_frame->DecrementOutstandingOps( - &immutable_state_.graph_view(), input_iter, ready); - } - } else { - DCHECK(item->is_next_iteration); - mutex_lock l(input_frame->mu); - if (is_dead) { - // Stop the deadness propagation. - output_frame = nullptr; - } else { - if (input_iter == input_frame->iteration_count && - input_frame->num_outstanding_iterations == - input_frame->max_parallel_iterations) { - // Reached the maximum for parallel iterations. - input_frame->next_iter_roots.push_back({item, (*outputs)[0]}); - output_frame = nullptr; - } else { - // If this is a new iteration, start it. - if (input_iter == input_frame->iteration_count) { - input_frame->IncrementIteration(&immutable_state_.graph_view(), - ready); - } - output_iter = input_iter + 1; - } - } - if (output_frame != nullptr) { - // This is the case when node is not Enter, Exit, or NextIteration. - DCHECK(input_frame == output_frame); - output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); - } - is_frame_done = input_frame->DecrementOutstandingOpsLocked( - &immutable_state_.graph_view(), input_iter, ready); - } - - // At this point, this node is completely done. We also know if the - // completion of this node makes its frame completed. - if (is_frame_done) { - FrameState* parent_frame = input_frame->parent_frame; - const int64 parent_iter = input_frame->parent_iter; - DeleteFrame(input_frame, ready); - if (parent_frame != nullptr) { - // The completion of frame may cause completions in its parent frame. - // So clean things up recursively. 
- CleanupFramesIterations(parent_frame, parent_iter, ready); - } - } -} - -bool ExecutorState::NodeDone(const Status& s, TaggedNodeSeq* ready, - NodeExecStatsInterface* stats, - TaggedNodeReadyQueue* inline_ready) { +template <class PropagatorStateType> +bool ExecutorState<PropagatorStateType>::NodeDone( + const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats, + TaggedNodeReadyQueue* inline_ready) { nodestats::SetAllEnd(stats); if (stats) { if (stats_collector_) { @@ -1658,8 +1069,9 @@ bool ExecutorState::NodeDone(const Status& s, TaggedNodeSeq* ready, return completed; } -void ExecutorState::ScheduleReady(TaggedNodeSeq* ready, - TaggedNodeReadyQueue* inline_ready) { +template <class PropagatorStateType> +void ExecutorState<PropagatorStateType>::ScheduleReady( + TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) { if (ready->empty()) return; int64 scheduled_nsec = 0; @@ -1693,7 +1105,7 @@ void ExecutorState::ScheduleReady(TaggedNodeSeq* ready, } else { for (auto& tagged_node : *ready) { const NodeItem& item = *tagged_node.node_item; - if (tagged_node.is_dead || !kernel_stats_->IsExpensive(item)) { + if (tagged_node.get_is_dead() || !kernel_stats_->IsExpensive(item)) { // Inline this inexpensive node. inline_ready->push_back(tagged_node); } else { @@ -1721,135 +1133,8 @@ void ExecutorState::ScheduleReady(TaggedNodeSeq* ready, ready->clear(); } -inline void ExecutorState::MaybeMarkCompleted(FrameState* frame, int64 iter, - const int node_id) { - // TODO(misard) Replace with a finer-grain enabling flag once we - // add better optional debugging support. - if (vlog_ && VLOG_IS_ON(1)) { - mutex_lock l(frame->mu); - frame->GetIteration(iter)->mark_completed( - immutable_state_.pending_ids()[node_id]); - } -} - -const Tensor* ExecutorState::GetTensorValueForDump(const Entry& input) { - switch (input.state) { - case Entry::State::NO_VALUE: - return kEmptyTensor; - case Entry::State::HAS_VALUE: - return input.val.get(); - case Entry::State::HAS_CONST_TENSOR: - return input.const_tensor; - case Entry::State::HAS_REF_TENSOR: - return input.ref_tensor.tensor; - } -} - -void ExecutorState::DumpPendingNodeState( - const int node_id, const Entry* input_vector, - const bool show_nodes_with_no_ready_inputs) { - const NodeItem& node_item = *immutable_state_.graph_view().node(node_id); - const int input_base = node_item.input_start; - if (!show_nodes_with_no_ready_inputs) { - bool has_ready_input = false; - for (int i = 0; i < node_item.num_inputs; ++i) { - const Entry& input = input_vector[input_base + i]; - const Tensor* tensor = GetTensorValueForDump(input); - if (tensor->IsInitialized()) { - has_ready_input = true; - break; - } - } - if (!has_ready_input) { - return; - } - } - LOG(WARNING) << " Pending Node: " << node_item.DebugString(); - for (int i = 0; i < node_item.num_inputs; ++i) { - const Entry& input = input_vector[input_base + i]; - const Tensor* tensor = GetTensorValueForDump(input); - if (tensor->IsInitialized()) { - LOG(WARNING) << " Input " << i << ": " - << strings::StrCat( - "Tensor<type: ", DataTypeString(tensor->dtype()), - " shape: ", tensor->shape().DebugString(), ">"); - } else { - LOG(WARNING) << " Input " << i << ": not present"; - } - } -} - -void ExecutorState::DumpActiveNodeState(const int node_id, - const Entry* input_vector) { - const NodeItem& node_item = *immutable_state_.graph_view().node(node_id); - LOG(WARNING) << " Active Node: " << node_item.DebugString(); - const int input_base = node_item.input_start; - for (int i = 0; i < node_item.num_inputs; 
++i) { - const Entry& input = input_vector[input_base + i]; - const Tensor* tensor = GetTensorValueForDump(input); - if (tensor->IsInitialized()) { - LOG(WARNING) << " Input " << i << ": " - << strings::StrCat( - "Tensor<type: ", DataTypeString(tensor->dtype()), - " shape: ", tensor->shape().DebugString(), ">"); - } else { - LOG(WARNING) << " Input " << i << ": not present"; - } - } -} - -void ExecutorState::DumpIterationState(const FrameState* frame, - IterationState* iteration) { - const std::vector<const NodeItem*>* nodes = frame->nodes; - // Dump any waiting nodes that are holding on to tensors. - for (const NodeItem* node : *nodes) { - PendingCounts::Handle pending_id = - immutable_state_.pending_ids()[node->node_id]; - if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY || - iteration->node_state(pending_id) == PendingCounts::PENDING_READY) { - DumpPendingNodeState(node->node_id, iteration->input_tensors, false); - } - } - // Then the active nodes. - for (const NodeItem* node : *nodes) { - PendingCounts::Handle pending_id = - immutable_state_.pending_ids()[node->node_id]; - if (iteration->node_state(pending_id) == PendingCounts::STARTED) { - DumpActiveNodeState(node->node_id, iteration->input_tensors); - } - } - // Show all input tensors in use. - const int total_input_tensors = frame->total_input_tensors; - size_t total_bytes = 0; - for (int i = 0; i < total_input_tensors; ++i) { - const Entry& input = iteration->input_tensors[i]; - const Tensor* tensor = GetTensorValueForDump(input); - if (tensor->IsInitialized()) { - LOG(WARNING) << " Input " << i << ": " - << strings::StrCat( - "Tensor<type: ", DataTypeString(tensor->dtype()), - " shape: ", tensor->shape().DebugString(), - ", bytes: ", tensor->TotalBytes(), ">"); - total_bytes += tensor->TotalBytes(); - } - } - LOG(WARNING) << " Total bytes " << total_bytes; -} - -void ExecutorState::DumpState() { - mutex_lock l(mu_); - if (!dumped_on_error_) { - LOG(WARNING) << "Dumping state"; - for (auto& frame : outstanding_frames_) { - LOG(WARNING) << frame.first; - FrameState* frame_state = frame.second; - frame_state->DumpIterationState(this); - } - dumped_on_error_ = true; - } -} - -void ExecutorState::ScheduleFinish() { +template <class PropagatorStateType> +void ExecutorState<PropagatorStateType>::ScheduleFinish() { // Checks condition to decide if needs to invoke Finish(). If there are // in-flight deffered ops, wait for `num_deferred_ops_` reaches 0 to invoke // Finish(). Otherwise, invoke Finish() directly. @@ -1868,7 +1153,8 @@ void ExecutorState::ScheduleFinish() { Finish(); } -void ExecutorState::Finish() { +template <class PropagatorStateType> +void ExecutorState<PropagatorStateType>::Finish() { mu_.lock(); auto status = status_; auto done_cb = std::move(done_cb_); @@ -1962,447 +1248,8 @@ void ExecutorState::Finish() { } } -void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, - const NodeItem& node_item, - FrameState** child) { - // Get the child frame name. - AttrSlice attrs(node_item.kernel->def()); - const string& enter_name = GetNodeAttrString(attrs, "frame_name"); - DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node " - << node_item.kernel->name(); - const string child_name = MakeFrameName(frame, iter, enter_name); - - { - mutex_lock executor_lock(mu_); - auto it = outstanding_frames_.find(child_name); - if (it != outstanding_frames_.end()) { - *child = it->second; - return; - } - } - - // Need to create a new frame instance. 
- // Note that this new frame instance is created without any locks. - if (vlog_) VLOG(2) << "Create frame: " << child_name; - - int parallel_iters; - bool found_parallel_iters = - TryGetNodeAttr(attrs, "parallel_iterations", ¶llel_iters); - DCHECK(found_parallel_iters) - << "Could not find \"parallel_iterations\" attr in node " - << node_item.kernel->name(); - FrameState* temp = new FrameState(immutable_state_, parallel_iters); - temp->frame_name = child_name; - temp->frame_id = Hash64(child_name); - temp->parent_frame = frame; - temp->parent_iter = iter; - temp->InitializeFrameInfo(enter_name); - - // Initialize iteration 0. - { - mutex_lock l(temp->mu); - temp->SetIteration( - 0, new IterationState(temp->pending_counts, temp->total_input_tensors)); - } - - { - mutex_lock executor_lock(mu_); - auto it = outstanding_frames_.find(child_name); - if (it != outstanding_frames_.end()) { - *child = it->second; - } else { - mutex_lock frame_lock(frame->mu); - frame->GetIteration(iter)->outstanding_frame_count++; - outstanding_frames_[child_name] = temp; - *child = temp; - temp = nullptr; - } - } - delete temp; // Not used so delete it. -} - -void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { - // First, propagate dead_exits (if any) to the parent frame. - FrameState* parent_frame = frame->parent_frame; - const int64 parent_iter = frame->parent_iter; - if (parent_frame != nullptr) { - mutex_lock parent_frame_lock(parent_frame->mu); - // Propagate all the dead exits to the parent frame. - mutex_lock this_frame_lock(frame->mu); - - for (const NodeItem* item : frame->dead_exits) { - auto parent_iter_state = parent_frame->GetIteration(parent_iter); - - auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready, - bool dst_dead) { - if (dst_ready) { - if (dst_item.is_control_trigger) dst_dead = false; - ready->emplace_back(&dst_item, parent_frame, parent_iter, dst_dead); - parent_iter_state->outstanding_ops++; - } - }; - - auto propagate_to_non_merge = [&](PendingCounts::Handle dst_pending_id) { - parent_iter_state->increment_dead_count(dst_pending_id); - return parent_iter_state->decrement_pending(dst_pending_id, 1) == 0; - }; - - for (const EdgeInfo& e : item->output_edges()) { - const NodeItem& dst_item = - *immutable_state_.graph_view().node(e.dst_id); - const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id]; - - bool dst_dead = true; - bool dst_ready; - // We know this is a dead input to dst. - if (dst_item.is_merge) { - parent_iter_state->increment_dead_count(dst_pending_id); - const int dead_cnt = parent_iter_state->dead_count(dst_pending_id); - dst_dead = (dead_cnt == dst_item.num_inputs); - dst_ready = - (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead; - } else { - dst_ready = propagate_to_non_merge(dst_pending_id); - } - maybe_add_to_ready(dst_item, dst_ready, dst_dead); - } - - for (const ControlEdgeInfo& e : item->output_control_edges()) { - const NodeItem& dst_item = - *immutable_state_.graph_view().node(e.dst_id); - const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id]; - - bool dst_dead; - bool dst_ready; - // We know this is a dead input to dst. 
- if (dst_item.is_merge) { - parent_iter_state->decrement_pending(dst_pending_id, 2); - int count = parent_iter_state->pending(dst_pending_id); - int dead_cnt = parent_iter_state->dead_count(dst_pending_id); - dst_dead = (dead_cnt == dst_item.num_inputs); - dst_ready = (count == 0) || ((count == 1) && dst_dead); - } else { - dst_dead = true; - dst_ready = propagate_to_non_merge(dst_pending_id); - } - maybe_add_to_ready(dst_item, dst_ready, dst_dead); - } - } - } - - // Delete the frame. - const string& frame_name = frame->frame_name; - if (vlog_) VLOG(2) << "Delete frame " << frame_name; - { - mutex_lock executor_lock(mu_); - outstanding_frames_.erase(frame_name); - } - delete frame; -} - -void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter, - TaggedNodeSeq* ready) { - bool is_frame_done = false; - { - mutex_lock frame_lock(frame->mu); - frame->GetIteration(iter)->outstanding_frame_count--; - is_frame_done = - frame->CleanupIterations(&immutable_state_.graph_view(), iter, ready); - } - if (is_frame_done) { - FrameState* parent_frame = frame->parent_frame; - const int64 parent_iter = frame->parent_iter; - DeleteFrame(frame, ready); - if (parent_frame != nullptr) { - // The completion of frame may cause completions in its parent frame. - // So clean things up recursively. - CleanupFramesIterations(parent_frame, parent_iter, ready); - } - } -} - -void ExecutorState::FrameState::ActivateNodesFastPath(const NodeItem* item, - const bool is_dead, - int64 iter, - EntryVector* outputs, - TaggedNodeSeq* ready) { - // If we know that none of the item's edge destinations require special - // handling (i.e. none of the nodes is a merge or control trigger node), we - // can take a fast path that avoids accessing the destination NodeItem. - const GraphView& gview = immutable_state.graph_view(); - IterationState* iter_state = GetIteration(iter); - -// Add dst to the ready queue if it's ready -// -// NOTE(mrry): Use a macro here instead of a lambda, because this method is -// performance-critical and we need to ensure that the code is inlined. 
-#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \ - do { \ - if (!adjust_result.any_pending) { \ - const NodeItem* dst_item = gview.node(dst_id); \ - TaggedNode& t = ready->emplace_back(); \ - t.node_item = dst_item; \ - t.input_frame = this; \ - t.input_iter = iter; \ - t.is_dead = adjust_result.any_dead; \ - iter_state->outstanding_ops++; \ - } \ - } while (0); - - Entry* input_tensors = iter_state->input_tensors; - - for (const EdgeInfo& e : item->output_edges()) { - const int dst_id = e.dst_id; - const PendingCounts::Handle dst_pending_id = - immutable_state.pending_ids()[dst_id]; - const int src_slot = e.output_slot; - - const bool increment_dead = - (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE)); - const PendingCounts::AdjustResult adjust_result = - iter_state->adjust_for_activation(dst_pending_id, increment_dead); - const int dst_loc = e.input_slot; - if (e.is_last) { - input_tensors[dst_loc] = std::move((*outputs)[src_slot]); - } else { - input_tensors[dst_loc] = (*outputs)[src_slot]; - } - MAYBE_ADD_TO_READY(dst_id, adjust_result); - } - - for (const ControlEdgeInfo& e : item->output_control_edges()) { - const int dst_id = e.dst_id; - const PendingCounts::Handle dst_pending_id = - immutable_state.pending_ids()[dst_id]; - const PendingCounts::AdjustResult adjust_result = - iter_state->adjust_for_activation(dst_pending_id, is_dead); - MAYBE_ADD_TO_READY(dst_id, adjust_result); - } -#undef MAYBE_ADD_TO_READY -} - -void ExecutorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, - const bool is_dead, - int64 iter, - EntryVector* outputs, - TaggedNodeSeq* ready) { - // If any of the edge destinations is a merge or a control trigger node, - // we need to read each destination NodeItem to determine what action - // to take. - const GraphView& gview = immutable_state.graph_view(); - IterationState* iter_state = GetIteration(iter); - - auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item, - bool dst_ready, bool dst_dead) { - // Add dst to the ready queue if it's ready - if (dst_ready) { - if (dst_item->is_control_trigger) dst_dead = false; - ready->emplace_back(dst_item, this, iter, dst_dead); - iter_state->outstanding_ops++; - } - }; - - Entry* input_tensors = iter_state->input_tensors; - - for (const EdgeInfo& e : item->output_edges()) { - const int dst_id = e.dst_id; - const NodeItem* dst_item = gview.node(dst_id); - const PendingCounts::Handle dst_pending_id = - immutable_state.pending_ids()[dst_id]; - const int src_slot = e.output_slot; - - bool dst_dead = false; - bool dst_ready = false; - bool dst_need_input = true; - - if (dst_item->is_merge) { - // A merge node is ready if all control inputs have arrived and either - // a) a live data input becomes available or b) all data inputs are - // dead. For Merge, pending's LSB is set iff a live data input has - // arrived. - if ((*outputs)[src_slot].state != Entry::State::NO_VALUE) { - // This is a live data input. - int count = iter_state->pending(dst_pending_id); - iter_state->mark_live(dst_pending_id); - // Only the first live edge sets the input and (potentially) - // triggers execution. The low bit of count is set if and - // only if no live input has been used yet (mark_live clears - // it). The node should be started if and only if this is - // the first live input and there are no pending control - // edges, i.e. count == 1. - dst_ready = (count == 1); - dst_need_input = ((count & 0x1) == 1); - } else { - // This is a dead data input. 
Note that dst_node is dead if node is - // a dead enter. We need this to handle properly a while loop on - // the untaken branch of a conditional. - // TODO(yuanbyu): This is a bit hacky, but a good solution for - // now. - iter_state->increment_dead_count(dst_pending_id); - const int dead_cnt = iter_state->dead_count(dst_pending_id); - dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter; - dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead; - dst_need_input = false; - } - } else { - // Handle all other (non-merge) nodes. - const bool increment_dead = - (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE)); - const PendingCounts::AdjustResult adjust_result = - iter_state->adjust_for_activation(dst_pending_id, increment_dead); - dst_dead = adjust_result.any_dead; - dst_ready = !adjust_result.any_pending; - } - - if (dst_need_input) { - const int dst_loc = e.input_slot; - if (e.is_last) { - input_tensors[dst_loc] = std::move((*outputs)[src_slot]); - } else { - input_tensors[dst_loc] = (*outputs)[src_slot]; - } - } - - maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead); - } - - for (const ControlEdgeInfo& e : item->output_control_edges()) { - const int dst_id = e.dst_id; - const NodeItem* dst_item = gview.node(dst_id); - const PendingCounts::Handle dst_pending_id = - immutable_state.pending_ids()[dst_id]; - - bool dst_dead; - bool dst_ready; - if (dst_item->is_merge) { - // A merge node is ready if all control inputs have arrived and either - // a) a live data input becomes available or b) all data inputs are - // dead. For Merge, pending's LSB is set iff a live data input has - // arrived. - iter_state->decrement_pending(dst_pending_id, 2); - int count = iter_state->pending(dst_pending_id); - int dead_cnt = iter_state->dead_count(dst_pending_id); - dst_dead = (dead_cnt == dst_item->num_inputs); - dst_ready = (count == 0) || ((count == 1) && dst_dead); - } else { - // Handle all other (non-merge) nodes. - const PendingCounts::AdjustResult adjust_result = - iter_state->adjust_for_activation(dst_pending_id, is_dead); - dst_dead = adjust_result.any_dead; - dst_ready = !adjust_result.any_pending; - } - maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead); - } -} - -void ExecutorState::FrameState::ActivateNodes(const NodeItem* item, - const bool is_dead, int64 iter, - EntryVector* outputs, - TaggedNodeSeq* ready) { - if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) { - ActivateNodesSlowPath(item, is_dead, iter, outputs, ready); - } else { - ActivateNodesFastPath(item, is_dead, iter, outputs, ready); - } -} - -void ExecutorState::FrameState::ActivateNexts(const GraphView* gview, - int64 iter, - TaggedNodeSeq* ready) { - // Propagate the deferred NextIteration nodes to the new iteration. - for (auto& node_entry : next_iter_roots) { - const NodeItem* item = node_entry.first; - const Entry& entry = node_entry.second; - const bool is_dead = entry.state == Entry::State::NO_VALUE; - EntryVector outputs{entry}; - ActivateNodes(item, is_dead, iter, &outputs, ready); - } - next_iter_roots.clear(); -} - -void ExecutorState::FrameState::ActivateLoopInvs(const GraphView* gview, - int64 iter, - TaggedNodeSeq* ready) { - // Propagate loop invariants to the new iteration. 
- for (auto& node_entry : inv_values) { - const NodeItem* item = node_entry.first; - const Entry& entry = node_entry.second; - const bool is_dead = entry.state == Entry::State::NO_VALUE; - EntryVector outputs{entry}; - ActivateNodes(item, is_dead, iter, &outputs, ready); - } -} - -void ExecutorState::FrameState::AddLoopInv(const NodeItem* item, - const Entry& entry, - TaggedNodeSeq* ready) { - // Store this value. - inv_values.push_back({item, entry}); - - // Make this value available to all iterations. - const bool is_dead = entry.state == Entry::State::NO_VALUE; - for (int i = 0; i <= iteration_count; ++i) { - EntryVector outputs{entry}; - ActivateNodes(item, is_dead, i, &outputs, ready); - } -} - -bool ExecutorState::FrameState::IsIterationDone(int64 iter) { - IterationState* iter_state = GetIteration(iter); - if (iter_state->outstanding_ops == 0 && - iter_state->outstanding_frame_count == 0) { - if (iter == 0) { - // The enclosing frame has no pending input. - return num_pending_inputs == 0; - } else { - // The preceding iteration is deleted (and therefore done). - return (GetIteration(iter - 1) == nullptr); - } - } - return false; -} - -void ExecutorState::FrameState::IncrementIteration(const GraphView* gview, - TaggedNodeSeq* ready) { - iteration_count++; - const int64 next_iter = iteration_count; - - // Initialize the next iteration. - IterationState* iter_state = - new IterationState(pending_counts, total_input_tensors); - SetIteration(next_iter, iter_state); - num_outstanding_iterations++; - dead_exits.clear(); - - // Activate the successors of the deferred roots in the new iteration. - ActivateNexts(gview, next_iter, ready); - - // Activate the loop invariants in the new iteration. - ActivateLoopInvs(gview, next_iter, ready); -} - -bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview, - int64 iter, - TaggedNodeSeq* ready) { - int64 curr_iter = iter; - while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) { - // Delete the iteration curr_iter. - delete GetIteration(curr_iter); - SetIteration(curr_iter, nullptr); - --num_outstanding_iterations; - ++curr_iter; - - // When one iteration is completed, we check for deferred iteration, - // and start it if there is one. - if (!next_iter_roots.empty()) { - IncrementIteration(gview, ready); - } - } - return IsFrameDone(); -} - void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { - (new ExecutorState(args, immutable_state_, &kernel_stats_)) + (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_)) ->RunAsync(std::move(done)); } diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc new file mode 100644 index 00000000000..e2827a8eb1f --- /dev/null +++ b/tensorflow/core/common_runtime/propagator_state.cc @@ -0,0 +1,777 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/propagator_state.h"
+
+#include "tensorflow/core/common_runtime/graph_view.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+
+namespace tensorflow {
+
+// 1-D, 0 element tensor.
+static const Tensor* const kEmptyTensor = new Tensor;
+
+typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
+typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
+
+PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
+                                 int64 step_id)
+    : immutable_state_(immutable_state),
+      step_id_(step_id),
+      vlog_(VLOG_IS_ON(1)) {
+  // We start the entire execution in iteration 0 of the root frame,
+  // so let us create the root frame and the state for iteration 0.
+  // We assume root_frame_->frame_name.empty().
+  root_frame_ = new FrameState(immutable_state_, 1);
+  root_frame_->frame_id = 0;  // must be 0
+  root_frame_->InitializeFrameInfo(root_frame_->frame_name);
+
+  // Initialize iteration 0.
+  root_frame_->SetIteration(
+      0, new PropagatorState::IterationState(root_frame_->pending_counts,
+                                             root_frame_->total_input_tensors));
+
+  outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
+}
+
+PropagatorState::~PropagatorState() {
+  for (auto name_frame : outstanding_frames_) {
+    delete name_frame.second;
+  }
+}
+
+void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
+                                    TaggedNodeSeq* ready) {
+  for (const NodeItem* item : roots) {
+    DCHECK_EQ(item->num_inputs, 0);
+    ready->push_back(TaggedNode{item, root_frame_, 0, false});
+  }
+  mutex_lock l(root_frame_->mu);
+  root_frame_->GetIteration(0)->outstanding_ops = ready->size();
+}
+
+void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
+                                       EntryVector* outputs,
+                                       TaggedNodeSeq* ready) {
+  profiler::TraceMe activity(
+      [&]() {
+        return strings::StrCat(
+            "ExecutorPropagateOutputs#", "id=", step_id_,
+            ",kernel_name=", tagged_node.node_item->kernel->name_view(),
+            ",num_output_edges=", tagged_node.node_item->num_output_edges,
+            ",num_output_control_edges=",
+            tagged_node.node_item->num_output_control_edges, "#");
+      },
+      profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
+
+  const NodeItem* const item = tagged_node.node_item;
+  FrameState* const input_frame = tagged_node.input_frame;
+  const int64 input_iter = tagged_node.input_iter;
+  const bool is_dead = tagged_node.is_dead;
+
+  // Propagates outputs along out edges, and puts newly ready nodes
+  // into the ready queue.
+  DCHECK(ready->empty());
+  bool is_frame_done = false;
+  FrameState* output_frame = input_frame;
+  int64 output_iter = input_iter;
+
+  if (!item->is_enter_exit_or_next_iter) {
+    // Fast path for node types that don't need special handling.
+    DCHECK_EQ(input_frame, output_frame);
+    // Normal path for most nodes.
+    mutex_lock l(input_frame->mu);
+    output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
+    is_frame_done =
+        input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
+  } else if (item->is_enter) {
+    FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
+    output_iter = 0;
+    {
+      mutex_lock l(output_frame->mu);
+      if (item->is_constant_enter) {
+        // Propagate to all active iterations if this is a loop invariant.
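+        // (For illustration: a constant Enter carries a loop invariant.
+        // AddLoopInv below both records the value, so that iterations
+        // created later can activate it via ActivateLoopInvs, and
+        // activates it in every currently-live iteration, so consumers
+        // in any iteration observe the same Entry.)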
+        output_frame->AddLoopInv(item, (*outputs)[0], ready);
+      } else {
+        output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
+      }
+      output_frame->num_pending_inputs--;
+    }
+    is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
+  } else if (item->is_exit) {
+    if (is_dead) {
+      mutex_lock l(input_frame->mu);
+      // Stop and remember this node if it is a dead exit.
+      if (input_iter == input_frame->iteration_count) {
+        input_frame->dead_exits.push_back(item);
+      }
+      is_frame_done =
+          input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
+    } else {
+      output_frame = input_frame->parent_frame;
+      output_iter = input_frame->parent_iter;
+      {
+        mutex_lock l(output_frame->mu);
+        output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
+      }
+      is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
+    }
+  } else {
+    DCHECK(item->is_next_iteration);
+    mutex_lock l(input_frame->mu);
+    if (is_dead) {
+      // Stop the deadness propagation.
+      output_frame = nullptr;
+    } else {
+      if (input_iter == input_frame->iteration_count &&
+          input_frame->num_outstanding_iterations ==
+              input_frame->max_parallel_iterations) {
+        // Reached the maximum for parallel iterations.
+        input_frame->next_iter_roots.push_back({item, (*outputs)[0]});
+        output_frame = nullptr;
+      } else {
+        // If this is a new iteration, start it.
+        if (input_iter == input_frame->iteration_count) {
+          input_frame->IncrementIteration(ready);
+        }
+        output_iter = input_iter + 1;
+      }
+    }
+    if (output_frame != nullptr) {
+      // Propagate the NextIteration output to the next iteration of the
+      // same frame.
+      DCHECK(input_frame == output_frame);
+      output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
+    }
+    is_frame_done =
+        input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
+  }
+
+  // At this point, this node is completely done. We also know if the
+  // completion of this node makes its frame completed.
+  if (is_frame_done) {
+    FrameState* parent_frame = input_frame->parent_frame;
+    const int64 parent_iter = input_frame->parent_iter;
+    DeleteFrame(input_frame, ready);
+    if (parent_frame != nullptr) {
+      // The completion of frame may cause completions in its parent frame.
+      // So clean things up recursively.
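+      // (The recursion in CleanupFramesIterations terminates at the
+      // root frame, whose parent_frame is nullptr.)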
+ CleanupFramesIterations(parent_frame, parent_iter, ready); + } + } +} + +const Tensor* PropagatorState::GetTensorValueForDump(const Entry& input) { + switch (input.state) { + case Entry::State::NO_VALUE: + return kEmptyTensor; + case Entry::State::HAS_VALUE: + return input.val.get(); + case Entry::State::HAS_CONST_TENSOR: + return input.const_tensor; + case Entry::State::HAS_REF_TENSOR: + return input.ref_tensor.tensor; + } +} + +void PropagatorState::DumpPendingNodeState( + const int node_id, const Entry* input_vector, + const bool show_nodes_with_no_ready_inputs) { + const NodeItem& node_item = *immutable_state_.graph_view().node(node_id); + const int input_base = node_item.input_start; + if (!show_nodes_with_no_ready_inputs) { + bool has_ready_input = false; + for (int i = 0; i < node_item.num_inputs; ++i) { + const Entry& input = input_vector[input_base + i]; + const Tensor* tensor = GetTensorValueForDump(input); + if (tensor->IsInitialized()) { + has_ready_input = true; + break; + } + } + if (!has_ready_input) { + return; + } + } + LOG(WARNING) << " Pending Node: " << node_item.DebugString(); + for (int i = 0; i < node_item.num_inputs; ++i) { + const Entry& input = input_vector[input_base + i]; + const Tensor* tensor = GetTensorValueForDump(input); + if (tensor->IsInitialized()) { + LOG(WARNING) << " Input " << i << ": " + << strings::StrCat( + "Tensor<type: ", DataTypeString(tensor->dtype()), + " shape: ", tensor->shape().DebugString(), ">"); + } else { + LOG(WARNING) << " Input " << i << ": not present"; + } + } +} + +void PropagatorState::DumpActiveNodeState(const int node_id, + const Entry* input_vector) { + const NodeItem& node_item = *immutable_state_.graph_view().node(node_id); + LOG(WARNING) << " Active Node: " << node_item.DebugString(); + const int input_base = node_item.input_start; + for (int i = 0; i < node_item.num_inputs; ++i) { + const Entry& input = input_vector[input_base + i]; + const Tensor* tensor = GetTensorValueForDump(input); + if (tensor->IsInitialized()) { + LOG(WARNING) << " Input " << i << ": " + << strings::StrCat( + "Tensor<type: ", DataTypeString(tensor->dtype()), + " shape: ", tensor->shape().DebugString(), ">"); + } else { + LOG(WARNING) << " Input " << i << ": not present"; + } + } +} + +void PropagatorState::DumpIterationState(const FrameState* frame, + IterationState* iteration) { + const std::vector<const NodeItem*>* nodes = frame->nodes; + // Dump any waiting nodes that are holding on to tensors. + for (const NodeItem* node : *nodes) { + PendingCounts::Handle pending_id = + immutable_state_.pending_ids()[node->node_id]; + if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY || + iteration->node_state(pending_id) == PendingCounts::PENDING_READY) { + DumpPendingNodeState(node->node_id, iteration->input_tensors, false); + } + } + // Then the active nodes. + for (const NodeItem* node : *nodes) { + PendingCounts::Handle pending_id = + immutable_state_.pending_ids()[node->node_id]; + if (iteration->node_state(pending_id) == PendingCounts::STARTED) { + DumpActiveNodeState(node->node_id, iteration->input_tensors); + } + } + // Show all input tensors in use. 
+ const int total_input_tensors = frame->total_input_tensors; + size_t total_bytes = 0; + for (int i = 0; i < total_input_tensors; ++i) { + const Entry& input = iteration->input_tensors[i]; + const Tensor* tensor = GetTensorValueForDump(input); + if (tensor->IsInitialized()) { + LOG(WARNING) << " Input " << i << ": " + << strings::StrCat( + "Tensor<type: ", DataTypeString(tensor->dtype()), + " shape: ", tensor->shape().DebugString(), + ", bytes: ", tensor->TotalBytes(), ">"); + total_bytes += tensor->TotalBytes(); + } + } + LOG(WARNING) << " Total bytes " << total_bytes; +} + +void PropagatorState::DumpState() { + mutex_lock l(mu_); + if (!dumped_on_error_) { + LOG(WARNING) << "Dumping state"; + for (auto& frame : outstanding_frames_) { + LOG(WARNING) << frame.first; + FrameState* frame_state = frame.second; + frame_state->DumpIterationState(this); + } + dumped_on_error_ = true; + } +} + +void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, + const NodeItem& node_item, + FrameState** child) { + // Get the child frame name. + AttrSlice attrs(node_item.kernel->def()); + const string& enter_name = GetNodeAttrString(attrs, "frame_name"); + DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node " + << node_item.kernel->name(); + const string child_name = + strings::StrCat(frame->frame_name, ";", iter, ";", enter_name); + + { + mutex_lock executor_lock(mu_); + auto it = outstanding_frames_.find(child_name); + if (it != outstanding_frames_.end()) { + *child = it->second; + return; + } + } + + // Need to create a new frame instance. + // Note that this new frame instance is created without any locks. + if (vlog_) VLOG(2) << "Create frame: " << child_name; + + int parallel_iters; + bool found_parallel_iters = + TryGetNodeAttr(attrs, "parallel_iterations", ¶llel_iters); + DCHECK(found_parallel_iters) + << "Could not find \"parallel_iterations\" attr in node " + << node_item.kernel->name(); + FrameState* temp = new FrameState(immutable_state_, parallel_iters); + temp->frame_name = child_name; + temp->frame_id = Hash64(child_name); + temp->parent_frame = frame; + temp->parent_iter = iter; + temp->InitializeFrameInfo(enter_name); + + // Initialize iteration 0. + { + mutex_lock l(temp->mu); + temp->SetIteration( + 0, new IterationState(temp->pending_counts, temp->total_input_tensors)); + } + + { + mutex_lock executor_lock(mu_); + auto it = outstanding_frames_.find(child_name); + if (it != outstanding_frames_.end()) { + *child = it->second; + } else { + mutex_lock frame_lock(frame->mu); + frame->GetIteration(iter)->outstanding_frame_count++; + outstanding_frames_[child_name] = temp; + *child = temp; + temp = nullptr; + } + } + delete temp; // Not used so delete it. +} + +void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { + // First, propagate dead_exits (if any) to the parent frame. + FrameState* parent_frame = frame->parent_frame; + const int64 parent_iter = frame->parent_iter; + if (parent_frame != nullptr) { + mutex_lock parent_frame_lock(parent_frame->mu); + // Propagate all the dead exits to the parent frame. 
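+    // (Taking parent_frame->mu before frame->mu matches the lock
+    // ordering documented on FrameState::mu, where during structured
+    // traversal the parent's mutex is acquired before the child's, so
+    // this nesting cannot deadlock.)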
+ mutex_lock this_frame_lock(frame->mu); + + for (const NodeItem* item : frame->dead_exits) { + auto parent_iter_state = parent_frame->GetIteration(parent_iter); + + auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready, + bool dst_dead) { + if (dst_ready) { + if (dst_item.is_control_trigger) dst_dead = false; + ready->emplace_back(&dst_item, parent_frame, parent_iter, dst_dead); + parent_iter_state->outstanding_ops++; + } + }; + + auto propagate_to_non_merge = [&](PendingCounts::Handle dst_pending_id) { + parent_iter_state->increment_dead_count(dst_pending_id); + return parent_iter_state->decrement_pending(dst_pending_id, 1) == 0; + }; + + for (const EdgeInfo& e : item->output_edges()) { + const NodeItem& dst_item = + *immutable_state_.graph_view().node(e.dst_id); + const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id]; + + bool dst_dead = true; + bool dst_ready; + // We know this is a dead input to dst. + if (dst_item.is_merge) { + parent_iter_state->increment_dead_count(dst_pending_id); + const int dead_cnt = parent_iter_state->dead_count(dst_pending_id); + dst_dead = (dead_cnt == dst_item.num_inputs); + dst_ready = + (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead; + } else { + dst_ready = propagate_to_non_merge(dst_pending_id); + } + maybe_add_to_ready(dst_item, dst_ready, dst_dead); + } + + for (const ControlEdgeInfo& e : item->output_control_edges()) { + const NodeItem& dst_item = + *immutable_state_.graph_view().node(e.dst_id); + const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id]; + + bool dst_dead; + bool dst_ready; + // We know this is a dead input to dst. + if (dst_item.is_merge) { + parent_iter_state->decrement_pending(dst_pending_id, 2); + int count = parent_iter_state->pending(dst_pending_id); + int dead_cnt = parent_iter_state->dead_count(dst_pending_id); + dst_dead = (dead_cnt == dst_item.num_inputs); + dst_ready = (count == 0) || ((count == 1) && dst_dead); + } else { + dst_dead = true; + dst_ready = propagate_to_non_merge(dst_pending_id); + } + maybe_add_to_ready(dst_item, dst_ready, dst_dead); + } + } + } + + // Delete the frame. + const string& frame_name = frame->frame_name; + if (vlog_) VLOG(2) << "Delete frame " << frame_name; + { + mutex_lock executor_lock(mu_); + outstanding_frames_.erase(frame_name); + } + delete frame; +} + +void PropagatorState::CleanupFramesIterations(FrameState* frame, int64 iter, + TaggedNodeSeq* ready) { + bool is_frame_done = false; + { + mutex_lock frame_lock(frame->mu); + frame->GetIteration(iter)->outstanding_frame_count--; + is_frame_done = frame->CleanupIterations(iter, ready); + } + if (is_frame_done) { + FrameState* parent_frame = frame->parent_frame; + const int64 parent_iter = frame->parent_iter; + DeleteFrame(frame, ready); + if (parent_frame != nullptr) { + // The completion of frame may cause completions in its parent frame. + // So clean things up recursively. + CleanupFramesIterations(parent_frame, parent_iter, ready); + } + } +} + +void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item, + const bool is_dead, + int64 iter, + EntryVector* outputs, + TaggedNodeSeq* ready) { + // If we know that none of the item's edge destinations require special + // handling (i.e. none of the nodes is a merge or control trigger node), we + // can take a fast path that avoids accessing the destination NodeItem. 
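+  // (For illustration: on this path a destination's readiness is decided
+  // purely from PendingCounts::AdjustResult, as used below. A destination
+  // becomes ready iff `!adjust_result.any_pending`, and it is propagated
+  // as dead iff `adjust_result.any_dead`, so no per-destination NodeItem
+  // lookup is needed until a node is actually enqueued.)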
+ const GraphView& gview = immutable_state.graph_view(); + IterationState* iter_state = GetIteration(iter); + +// Add dst to the ready queue if it's ready +// +// NOTE(mrry): Use a macro here instead of a lambda, because this method is +// performance-critical and we need to ensure that the code is inlined. +#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \ + do { \ + if (!adjust_result.any_pending) { \ + const NodeItem* dst_item = gview.node(dst_id); \ + TaggedNode& t = ready->emplace_back(); \ + t.node_item = dst_item; \ + t.input_frame = this; \ + t.input_iter = iter; \ + t.is_dead = adjust_result.any_dead; \ + iter_state->outstanding_ops++; \ + } \ + } while (0); + + Entry* input_tensors = iter_state->input_tensors; + + for (const EdgeInfo& e : item->output_edges()) { + const int dst_id = e.dst_id; + const PendingCounts::Handle dst_pending_id = + immutable_state.pending_ids()[dst_id]; + const int src_slot = e.output_slot; + + const bool increment_dead = + (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE)); + const PendingCounts::AdjustResult adjust_result = + iter_state->adjust_for_activation(dst_pending_id, increment_dead); + const int dst_loc = e.input_slot; + if (e.is_last) { + input_tensors[dst_loc] = std::move((*outputs)[src_slot]); + } else { + input_tensors[dst_loc] = (*outputs)[src_slot]; + } + MAYBE_ADD_TO_READY(dst_id, adjust_result); + } + + for (const ControlEdgeInfo& e : item->output_control_edges()) { + const int dst_id = e.dst_id; + const PendingCounts::Handle dst_pending_id = + immutable_state.pending_ids()[dst_id]; + const PendingCounts::AdjustResult adjust_result = + iter_state->adjust_for_activation(dst_pending_id, is_dead); + MAYBE_ADD_TO_READY(dst_id, adjust_result); + } +#undef MAYBE_ADD_TO_READY +} + +void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, + const bool is_dead, + int64 iter, + EntryVector* outputs, + TaggedNodeSeq* ready) { + // If any of the edge destinations is a merge or a control trigger node, + // we need to read each destination NodeItem to determine what action + // to take. + const GraphView& gview = immutable_state.graph_view(); + IterationState* iter_state = GetIteration(iter); + + auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item, + bool dst_ready, bool dst_dead) { + // Add dst to the ready queue if it's ready + if (dst_ready) { + if (dst_item->is_control_trigger) dst_dead = false; + ready->emplace_back(dst_item, this, iter, dst_dead); + iter_state->outstanding_ops++; + } + }; + + Entry* input_tensors = iter_state->input_tensors; + + for (const EdgeInfo& e : item->output_edges()) { + const int dst_id = e.dst_id; + const NodeItem* dst_item = gview.node(dst_id); + const PendingCounts::Handle dst_pending_id = + immutable_state.pending_ids()[dst_id]; + const int src_slot = e.output_slot; + + bool dst_dead = false; + bool dst_ready = false; + bool dst_need_input = true; + + if (dst_item->is_merge) { + // A merge node is ready if all control inputs have arrived and either + // a) a live data input becomes available or b) all data inputs are + // dead. For Merge, pending's LSB is set iff a live data input has + // arrived. + if ((*outputs)[src_slot].state != Entry::State::NO_VALUE) { + // This is a live data input. + int count = iter_state->pending(dst_pending_id); + iter_state->mark_live(dst_pending_id); + // Only the first live edge sets the input and (potentially) + // triggers execution. 
The low bit of count is set if and
+        // only if no live input has been used yet (mark_live clears
+        // it). The node should be started if and only if this is
+        // the first live input and there are no pending control
+        // edges, i.e. count == 1.
+        dst_ready = (count == 1);
+        dst_need_input = ((count & 0x1) == 1);
+      } else {
+        // This is a dead data input. Note that dst is dead if the input
+        // node is a dead Enter. We need this to properly handle a while
+        // loop on the untaken branch of a conditional.
+        // TODO(yuanbyu): This is a bit hacky, but a good solution for
+        // now.
+        iter_state->increment_dead_count(dst_pending_id);
+        const int dead_cnt = iter_state->dead_count(dst_pending_id);
+        dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
+        dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
+        dst_need_input = false;
+      }
+    } else {
+      // Handle all other (non-merge) nodes.
+      const bool increment_dead =
+          (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
+      const PendingCounts::AdjustResult adjust_result =
+          iter_state->adjust_for_activation(dst_pending_id, increment_dead);
+      dst_dead = adjust_result.any_dead;
+      dst_ready = !adjust_result.any_pending;
+    }
+
+    if (dst_need_input) {
+      const int dst_loc = e.input_slot;
+      if (e.is_last) {
+        input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
+      } else {
+        input_tensors[dst_loc] = (*outputs)[src_slot];
+      }
+    }
+
+    maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
+  }
+
+  for (const ControlEdgeInfo& e : item->output_control_edges()) {
+    const int dst_id = e.dst_id;
+    const NodeItem* dst_item = gview.node(dst_id);
+    const PendingCounts::Handle dst_pending_id =
+        immutable_state.pending_ids()[dst_id];
+
+    bool dst_dead;
+    bool dst_ready;
+    if (dst_item->is_merge) {
+      // A merge node is ready if all control inputs have arrived and either
+      // a) a live data input becomes available or b) all data inputs are
+      // dead. For Merge, pending's LSB is set iff a live data input has
+      // arrived.
+      iter_state->decrement_pending(dst_pending_id, 2);
+      int count = iter_state->pending(dst_pending_id);
+      int dead_cnt = iter_state->dead_count(dst_pending_id);
+      dst_dead = (dead_cnt == dst_item->num_inputs);
+      dst_ready = (count == 0) || ((count == 1) && dst_dead);
+    } else {
+      // Handle all other (non-merge) nodes.
+      const PendingCounts::AdjustResult adjust_result =
+          iter_state->adjust_for_activation(dst_pending_id, is_dead);
+      dst_dead = adjust_result.any_dead;
+      dst_ready = !adjust_result.any_pending;
+    }
+    maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
+  }
+}
+
+void PropagatorState::FrameState::ActivateNodes(const NodeItem* item,
+                                                const bool is_dead, int64 iter,
+                                                EntryVector* outputs,
+                                                TaggedNodeSeq* ready) {
+  if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
+    ActivateNodesSlowPath(item, is_dead, iter, outputs, ready);
+  } else {
+    ActivateNodesFastPath(item, is_dead, iter, outputs, ready);
+  }
+}
+
+void PropagatorState::FrameState::ActivateNexts(int64 iter,
+                                                TaggedNodeSeq* ready) {
+  // Propagate the deferred NextIteration nodes to the new iteration.
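+  // (These roots were parked by PropagateOutputs when a NextIteration
+  // output arrived for the frame's newest iteration while
+  // num_outstanding_iterations had already reached
+  // max_parallel_iterations; starting this iteration releases them.)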
+ for (auto& node_entry : next_iter_roots) { + const NodeItem* item = node_entry.first; + const Entry& entry = node_entry.second; + const bool is_dead = entry.state == Entry::State::NO_VALUE; + EntryVector outputs{entry}; + ActivateNodes(item, is_dead, iter, &outputs, ready); + } + next_iter_roots.clear(); +} + +void PropagatorState::FrameState::ActivateLoopInvs(int64 iter, + TaggedNodeSeq* ready) { + // Propagate loop invariants to the new iteration. + for (auto& node_entry : inv_values) { + const NodeItem* item = node_entry.first; + const Entry& entry = node_entry.second; + const bool is_dead = entry.state == Entry::State::NO_VALUE; + EntryVector outputs{entry}; + ActivateNodes(item, is_dead, iter, &outputs, ready); + } +} + +void PropagatorState::FrameState::AddLoopInv(const NodeItem* item, + const Entry& entry, + TaggedNodeSeq* ready) { + // Store this value. + inv_values.push_back({item, entry}); + + // Make this value available to all iterations. + const bool is_dead = entry.state == Entry::State::NO_VALUE; + for (int i = 0; i <= iteration_count; ++i) { + EntryVector outputs{entry}; + ActivateNodes(item, is_dead, i, &outputs, ready); + } +} + +bool PropagatorState::FrameState::IsIterationDone(int64 iter) { + IterationState* iter_state = GetIteration(iter); + if (iter_state->outstanding_ops == 0 && + iter_state->outstanding_frame_count == 0) { + if (iter == 0) { + // The enclosing frame has no pending input. + return num_pending_inputs == 0; + } else { + // The preceding iteration is deleted (and therefore done). + return (GetIteration(iter - 1) == nullptr); + } + } + return false; +} + +void PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) { + iteration_count++; + const int64 next_iter = iteration_count; + + // Initialize the next iteration. + IterationState* iter_state = + new IterationState(pending_counts, total_input_tensors); + SetIteration(next_iter, iter_state); + num_outstanding_iterations++; + dead_exits.clear(); + + // Activate the successors of the deferred roots in the new iteration. + ActivateNexts(next_iter, ready); + + // Activate the loop invariants in the new iteration. + ActivateLoopInvs(next_iter, ready); +} + +bool PropagatorState::FrameState::CleanupIterations(int64 iter, + TaggedNodeSeq* ready) { + int64 curr_iter = iter; + while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) { + // Delete the iteration curr_iter. + delete GetIteration(curr_iter); + SetIteration(curr_iter, nullptr); + --num_outstanding_iterations; + ++curr_iter; + + // When one iteration is completed, we check for deferred iteration, + // and start it if there is one. + if (!next_iter_roots.empty()) { + IncrementIteration(ready); + } + } + return IsFrameDone(); +} + +void PropagatorState::FrameState::InitializeFrameInfo( + const string& enter_name) { + const ImmutableExecutorState::FrameInfo* finfo = + immutable_state.get_frame_info(enter_name); + DCHECK_NE(finfo, nullptr); + pending_counts = finfo->pending_counts.get(); + total_input_tensors = finfo->total_inputs; + num_pending_inputs = finfo->input_count; + nodes = finfo->nodes.get(); +} + +void PropagatorState::FrameState::SetIteration(int64 iter, + IterationState* state) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { + size_t index = iter % (max_parallel_iterations + 1); + DCHECK(state == nullptr || iterations[index] == nullptr); + iterations_raw[index] = state; + if (index == 0) { + iterations_first = state; + } +} + +// Decrement the outstanding op count and clean up the iterations in the +// frame. 
Return true iff the execution of the frame is done. +bool PropagatorState::FrameState::DecrementOutstandingOps( + int64 iter, TaggedNodeSeq* ready) { + mutex_lock l(mu); + return DecrementOutstandingOpsLocked(iter, ready); +} + +// Decrement the outstanding op count and clean up the iterations in the +// frame. Return true iff the execution of the frame is done. +bool PropagatorState::FrameState::DecrementOutstandingOpsLocked( + int64 iter, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { + IterationState* istate = GetIteration(iter); + istate->outstanding_ops--; + if (istate->outstanding_ops != 0) { + return false; + } else { + return CleanupIterations(iter, ready); + } +} + +// Returns true if the computation in the frame is completed. +bool PropagatorState::FrameState::IsFrameDone() + TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { + return (num_pending_inputs == 0 && num_outstanding_iterations == 0); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h new file mode 100644 index 00000000000..d82d3bf7261 --- /dev/null +++ b/tensorflow/core/common_runtime/propagator_state.h @@ -0,0 +1,466 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_ + +#include <vector> + +#include "tensorflow/core/common_runtime/entry.h" +#include "tensorflow/core/common_runtime/immutable_executor_state.h" +#include "tensorflow/core/common_runtime/pending_counts.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec; +typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec; + +// Represents the ephemeral "edge state" associated with one invocation of +// `Executor::Run()`. +// +// `PropagatorState` is responsible for propagating values along dataflow +// edges in a TensorFlow graph and determining which nodes are runnable. The +// executor primarily updates `PropagatorState` by calling `PropagateOutputs()` +// after processing a node, and `PropagatorState` dispatches `TaggedNode`s by +// adding them to a `TaggedNodeSeq`. +class PropagatorState { + public: + PropagatorState(const ImmutableExecutorState& immutable_state, int64 step_id); + ~PropagatorState(); + + private: + // Forward declaration so that `TaggedNode` can include a `FrameState*`. 
+ struct FrameState; + + public: + // A `TaggedNode` corresponds to a single invocation of a node's kernel, + // and it is created when the kernel becomes runnable (in a particular + // iteration of a particular frame). + struct TaggedNode { + const NodeItem* node_item; + FrameState* input_frame; + int64 input_iter; + bool is_dead; + + TaggedNode() = default; + TaggedNode(const NodeItem* node_item, FrameState* in_frame, int64 in_iter, + bool dead) + : node_item(node_item), + input_frame(in_frame), + input_iter(in_iter), + is_dead(dead) {} + + const NodeItem& get_node_item() const { return *node_item; } + + bool get_is_dead() const { return is_dead; } + }; + + // A drop-in replacement for std::deque<TaggedNode>. We typically don't + // have that many nodes in the ready queue, so we just use a vector and + // don't free up memory from the queue as we consume nodes. + class TaggedNodeReadyQueue { + public: + TaggedNodeReadyQueue() : front_index_(0) {} + + void push_back(const TaggedNode& node) { ready_.push_back(node); } + TaggedNode front() const { + DCHECK_LT(front_index_, ready_.size()); + return ready_[front_index_]; + } + void pop_front() { + DCHECK_LT(front_index_, ready_.size()); + front_index_++; + if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) { + if (front_index_ == ready_.size()) { + ready_.clear(); + } else { + // Lots of unused entries at beginning of vector: move everything + // down to start of vector. + ready_.erase(ready_.begin(), ready_.begin() + front_index_); + } + front_index_ = 0; + } + } + bool empty() const { return ready_.empty(); } + + private: + // TODO(b/152925936): Re-evaluate these constants with current usage + // patterns. + static constexpr int kSpillThreshold = 16384; + gtl::InlinedVector<TaggedNode, 16> ready_; + int front_index_; + }; + + // TODO(b/152925936): Re-evaluate this constant with current usage patterns. + typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq; + + private: + struct IterationState { + explicit IterationState(const PendingCounts* pending_counts, + int total_input_tensors) + : input_tensors(new Entry[total_input_tensors]), + outstanding_ops(0), + outstanding_frame_count(0), + counts(*pending_counts) { // Initialize with copy of *pending_counts + } + + // The state of an iteration. + + // One copy per iteration. For iteration k, i-th node's j-th input is in + // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is + // either a tensor pointer (pass-by-reference) or a tensor (pass-by-value). + // + // NOTE: No need to protect input_tensors[i] by any locks because it + // is resized once. Each element of tensors_ is written once by the + // source node of an edge and is cleared by the destination of the same + // edge. The latter node is never run concurrently with the former node. + Entry* input_tensors; + + // The number of outstanding ops for each iteration. + size_t outstanding_ops; + + // The number of outstanding frames for each iteration. + int outstanding_frame_count; + int pending(PendingCounts::Handle h) { return counts.pending(h); } + int decrement_pending(PendingCounts::Handle h, int v) { + return counts.decrement_pending(h, v); + } + // Mark a merge node as live + // REQUIRES: Node corresponding to "h" is a merge node + void mark_live(PendingCounts::Handle h) { counts.mark_live(h); } + // Mark a node to show that processing has started. + void mark_started(PendingCounts::Handle h) { counts.mark_started(h); } + // Mark a node to show that processing has completed. 
+    void mark_completed(PendingCounts::Handle h) { counts.mark_completed(h); }
+    PendingCounts::NodeState node_state(PendingCounts::Handle h) {
+      return counts.node_state(h);
+    }
+
+    int dead_count(PendingCounts::Handle h) { return counts.dead_count(h); }
+    void increment_dead_count(PendingCounts::Handle h) {
+      counts.increment_dead_count(h);
+    }
+    PendingCounts::AdjustResult adjust_for_activation(PendingCounts::Handle h,
+                                                      bool increment_dead) {
+      return counts.adjust_for_activation(h, increment_dead);
+    }
+
+    ~IterationState() { delete[] input_tensors; }
+
+   private:
+    PendingCounts counts;
+  };
+
+  struct FrameState {
+    explicit FrameState(const ImmutableExecutorState& immutable_state,
+                        int parallel_iters)
+        : immutable_state(immutable_state),
+          max_parallel_iterations(parallel_iters),
+          num_outstanding_iterations(1),
+          iterations(parallel_iters + 1),
+          iterations_raw(iterations.data()) {}
+
+    // A new frame is created for each loop. Execution starts at iteration 0.
+    // When a value at iteration 0 passes through a NextIteration node,
+    // iteration 1 is created and starts running. Note that iteration 0 may
+    // still be running so multiple iterations may run in parallel. The
+    // frame maintains the state of iterations in several data structures
+    // such as pending_count and input_tensors. When iteration 0 completes,
+    // we garbage collect the state of iteration 0.
+    //
+    // A frame instance is considered "done" and can be garbage collected
+    // if all its inputs have entered and all its iterations are "done".
+    //
+    // A frame manages the live iterations of an iterative computation.
+    // Iteration i is considered "done" when there are no outstanding ops,
+    // frames at iteration i are done, all recvs for this iteration are
+    // completed, and iteration i-1 is done. For iteration 0, we instead
+    // wait for there to be no more pending inputs of the frame.
+    //
+    // Frames and iterations are garbage collected once they are done.
+    // The state we need to keep around is highly dependent on the
+    // parallelism enabled by the scheduler. We may want to have the
+    // scheduler dynamically control the outstanding number of live
+    // parallel frames and iterations. To reduce the state space, the
+    // scheduler might want to schedule ops in inner frames first and
+    // lower iterations first.
+    //
+    // This frame state is mostly initialized lazily on demand so we
+    // don't introduce unnecessary overhead.
+
+    // The immutable state of the executor the frame is in.
+    const ImmutableExecutorState& immutable_state;
+
+    // The name of this frame, which is the concatenation of its parent
+    // frame name, the iteration of the parent frame when this frame was
+    // created, and the value of the attr 'frame_name'.
+    string frame_name;
+
+    // The unique id for this frame. Generated by fingerprinting
+    // frame_name.
+    uint64 frame_id;
+
+    // The iteration id of its parent frame when this frame is created.
+    // -1 if there is no parent frame. The frame_name/parent_iter pair
+    // uniquely identifies this FrameState.
+    int64 parent_iter = -1;
+
+    // The FrameState of its parent frame.
+    FrameState* parent_frame = nullptr;
+
+    // The maximum allowed number of parallel iterations.
+    const int max_parallel_iterations;
+
+    // The number of inputs this frame is still waiting for.
+    int num_pending_inputs = 0;
+
+    // The highest iteration number we have reached so far in this frame.
+    int64 iteration_count TF_GUARDED_BY(mu) = 0;
+
+    // The number of outstanding iterations.
+ int num_outstanding_iterations TF_GUARDED_BY(mu) = 1; + + private: + // The active iteration states of this frame. + gtl::InlinedVector<IterationState*, 12> iterations; + IterationState** const iterations_raw TF_GUARDED_BY(mu); + IterationState* iterations_first TF_GUARDED_BY(mu); + + public: + // The NextIteration nodes to enter a new iteration. If the number of + // outstanding iterations reaches the limit, we will defer the start of + // the next iteration until the number of outstanding iterations falls + // below the limit. + std::vector<std::pair<const NodeItem*, Entry>> next_iter_roots + TF_GUARDED_BY(mu); + + // The values of the loop invariants for this loop. They are added into + // this list as they "enter" the frame. When a loop invariant enters, + // we make it available to all active iterations. When the frame starts + // a new iteration, we make all the current loop invariants available + // to the new iteration. + std::vector<std::pair<const NodeItem*, Entry>> inv_values TF_GUARDED_BY(mu); + + // The list of dead exit node items for the current highest iteration. We + // will only "execute" the dead exits of the final iteration. + std::vector<const NodeItem*> dead_exits TF_GUARDED_BY(mu); + + // Static information specific to this frame. + PendingCounts* pending_counts = nullptr; + int total_input_tensors = 0; + std::vector<const NodeItem*>* nodes = nullptr; + + // Lock ordering: ExecutorState.mu_ < mu; + // during structured traversal: parent_frame->mu < mu. + mutex mu; + + void InitializeFrameInfo(const string& enter_name); + + inline IterationState* GetIteration(int64 iter) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { + if (TF_PREDICT_TRUE(iter == 0)) { + return iterations_first; + } else { + size_t index = iter % (max_parallel_iterations + 1); + return iterations_raw[index]; + } + } + + void SetIteration(int64 iter, IterationState* state); + + // Decrement the outstanding op count and clean up the iterations in the + // frame. Return true iff the execution of the frame is done. + bool DecrementOutstandingOps(int64 iter, TaggedNodeSeq* ready); + + // Decrement the outstanding op count and clean up the iterations in the + // frame. Return true iff the execution of the frame is done. + bool DecrementOutstandingOpsLocked(int64 iter, TaggedNodeSeq* ready); + + // Returns true if the computation in the frame is completed. + bool IsFrameDone(); + + // Returns true if the iteration of the frame is completed. + bool IsIterationDone(int64 iter) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Increments the iteration id. If this is a new iteration, initialize it. + void IncrementIteration(TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Activate all the deferred NextIteration nodes in a new iteration. + void ActivateNexts(int64 iter, TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Activate all the current loop invariants in a new iteration. + void ActivateLoopInvs(int64 iter, TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Add a new loop invariant and make it available to all active + // iterations. + void AddLoopInv(const NodeItem* item, const Entry& entry, + TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Activate the successors of a node. Contents of *outputs are left in an + // indeterminate state after returning from this method. 
+ void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter, + EntryVector* outputs, TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Cleanup iterations of this frame starting from iteration iter. + bool CleanupIterations(int64 iter, TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + void DumpIterationState(PropagatorState* parent) { + mutex_lock l(mu); + for (IterationState* iteration : iterations) { + if (iteration) { + LOG(WARNING) << " Iteration:"; + parent->DumpIterationState(this, iteration); + } + } + } + + ~FrameState() { + for (size_t i = 0; i < iterations.size(); ++i) { + delete iterations[i]; + iterations[i] = nullptr; + } + } + + private: + // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`. + void ActivateNodesFastPath(const NodeItem* item, const bool is_dead, + int64 iter, EntryVector* outputs, + TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead, + int64 iter, EntryVector* outputs, + TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + }; + + public: + // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`. + void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots, + TaggedNodeSeq* ready); + + // After processing the outputs, propagates the outputs to their dsts. + // Contents of *outputs are left in an indeterminate state after + // returning from this method. + void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs, + TaggedNodeSeq* ready); + + // Returns an array of `Entry` objects corresponding to the inputs of + // `tagged_node`. + // + // NOTE: Thread safety analysis is disabled on this method, because the + // underlying `IterationState` and its array of `input_tensors` retain the + // same address while the iteration is live. + Entry* GetInputTensors(const TaggedNode& tagged_node) const + TF_NO_THREAD_SAFETY_ANALYSIS { + return tagged_node.input_frame->GetIteration(tagged_node.input_iter) + ->input_tensors + + tagged_node.node_item->input_start; + } + + FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const { + return {tagged_node.input_frame->frame_id, tagged_node.input_iter}; + } + + // Provide debugging output of the state of the executor. + void DumpState(); + + // For debugging/logging only. + void MaybeMarkStarted(const TaggedNode& tagged_node) { + // TODO(misard) Replace with a finer-grain enabling flag once we add better + // optional debugging support. + if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { + mutex_lock l(tagged_node.input_frame->mu); + tagged_node.input_frame->GetIteration(tagged_node.input_iter) + ->mark_started( + immutable_state_.pending_ids()[tagged_node.node_item->node_id]); + } + } + + void MaybeMarkCompleted(const TaggedNode& tagged_node) { + // TODO(misard) Replace with a finer-grain enabling flag once we add better + // optional debugging support. + if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { + mutex_lock l(tagged_node.input_frame->mu); + tagged_node.input_frame->GetIteration(tagged_node.input_iter) + ->mark_completed( + immutable_state_.pending_ids()[tagged_node.node_item->node_id]); + } + } + + private: + // Find an existing or create a new child frame in the frame 'frame' at + // iteration 'iter'. + void FindOrCreateChildFrame(FrameState* frame, int64 iter, + const NodeItem& node_item, FrameState** child); + + // Delete a frame. Called when the frame is done. 
+ void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready); + + // Cleanup frames and iterations starting from frame/iter. Called when + // a child frame is done. + void CleanupFramesIterations(FrameState* frame, int64 iter, + TaggedNodeSeq* ready); + + // Provide debugging output about an outstanding node in the executor. + void DumpPendingNodeState(const int node_id, const Entry* input_vector, + bool show_nodes_with_no_ready_inputs); + void DumpActiveNodeState(const int node_id, const Entry* input_vector); + + // Provide debugging output about an outstanding iteration in the executor. + void DumpIterationState(const FrameState* frame, IterationState* iteration); + + const Tensor* GetTensorValueForDump(const Entry& input); + + const ImmutableExecutorState& immutable_state_; + const int64 step_id_; + const bool vlog_; + + mutex mu_; + + // A flag that is set on error after the frame state has been + // dumped for diagnostic purposes. + bool dumped_on_error_ TF_GUARDED_BY(mu_) = false; + + // The root frame in which the execution of this step is started. + FrameState* root_frame_; + + // Mapping from frame name to outstanding frames. A new frame is created + // at some iteration of an active frame. So the unique key for the new + // child frame is composed of the name of the parent frame, the iteration + // number at which the parent frame is creating the new frame, and the + // name of the new frame from nodedef. + gtl::FlatMap<string, FrameState*> outstanding_frames_ TF_GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
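
For reviewers tracing the new split, here is a minimal sketch (not part of this
patch) of how an `ExecutorState`-like driver is expected to exercise the public
`PropagatorState` interface. `RunKernel` is a hypothetical stand-in for the
executor's kernel-invocation logic, and `root_nodes()` stands in for however
the caller obtains the graph's root `NodeItem`s:

// Sketch only: drives one step using PropagatorState's public interface.
void DriveOneStep(const ImmutableExecutorState& immutable_state,
                  PropagatorState* propagator) {
  PropagatorState::TaggedNodeSeq ready;

  // Seed the ready queue with the graph's roots (nodes with no inputs).
  propagator->ActivateRoots(immutable_state.root_nodes(), &ready);

  while (!ready.empty()) {
    const PropagatorState::TaggedNode tagged_node = ready.back();
    ready.pop_back();

    // Inputs live in the iteration's `input_tensors` array, offset by the
    // node's `input_start`.
    Entry* inputs = propagator->GetInputTensors(tagged_node);

    // Hypothetical kernel invocation: consumes `inputs`, fills `outputs`.
    EntryVector outputs(tagged_node.node_item->num_outputs);
    RunKernel(tagged_node, inputs, &outputs);

    // Propagation may create or advance frames and iterations, and appends
    // newly runnable nodes to `ready`; `outputs` is left indeterminate.
    propagator->PropagateOutputs(tagged_node, &outputs, &ready);
  }
}

The real executor also handles scheduling, per-kernel state, and completion
(`ScheduleFinish()`/`Finish()` above); the sketch shows only the propagation
contract: activate the roots, run whatever is ready, and let
`PropagateOutputs()` refill the ready queue.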