[Executor] Split ExecutorState into PropagatorState and ExecutorState<PropagatorStateType>.

This change is part of an ongoing refactoring to simplify "executor.cc" and enable the substitution of more efficient implementations of `PropagateOutputs()`.

PiperOrigin-RevId: 304262448
Change-Id: I46a2d7fcdde89a71c502d272f35adfd34b0c4cab
Derek Murray authored 2020-04-01 14:26:25 -07:00, committed by TensorFlower Gardener
parent 1778de6f64
commit bd530a65d5
5 changed files with 1550 additions and 1315 deletions
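The sketch below is a minimal illustration of the intended shape of this split, not code from this commit: the executor becomes a template over its propagator type, so an alternative implementation of `PropagateOutputs()` can be substituted without touching the scheduling logic. `LoggingPropagator`, `Process()`, and the trimmed-down interface are hypothetical stand-ins.

#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins for the executor's types; all names here are illustrative.
struct TaggedNode { std::string name; };
using TaggedNodeSeq = std::vector<TaggedNode>;

// One possible propagator. It mirrors (a trimmed-down version of) the
// interface the executor needs: PropagateOutputs() reports newly runnable
// nodes by appending them to `ready`.
class LoggingPropagator {
 public:
  void PropagateOutputs(const TaggedNode& node, TaggedNodeSeq* ready) {
    std::cout << "propagating outputs of " << node.name << "\n";
    // A real propagator would update pending counts and enqueue the
    // destinations that became runnable.
  }
};

// The executor is parameterized on its propagator, so a more efficient
// PropagateOutputs() implementation can be swapped in without changing
// the scheduling logic.
template <class PropagatorStateType>
class ExecutorState {
 public:
  void Process(const TaggedNode& node) {
    TaggedNodeSeq ready;
    propagator_.PropagateOutputs(node, &ready);
    // ...schedule the nodes in `ready`...
  }

 private:
  PropagatorStateType propagator_;
};

int main() {
  ExecutorState<LoggingPropagator> executor;
  executor.Process(TaggedNode{"MatMul"});
}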

tensorflow/core/BUILD

@@ -2546,6 +2546,7 @@ filegroup(
         "common_runtime/debugger_state_interface.h",
         "common_runtime/device_resolver_local.h",
         "common_runtime/dma_helper.h",
+        "common_runtime/entry.h",
         "common_runtime/executor.h",
         "common_runtime/executor_factory.h",
         "common_runtime/function_optimization_registry.h",
@@ -2553,6 +2554,7 @@ filegroup(
         "common_runtime/graph_view.h",
         "common_runtime/immutable_executor_state.h",
         "common_runtime/input_colocation_exemption_registry.h",
+        "common_runtime/inspecting_placer.h",
         "common_runtime/isolate_placer_inspection_required_ops_pass.h",
         "common_runtime/local_device.h",
         "common_runtime/lower_function_call_op.h",
@@ -2567,7 +2569,7 @@ filegroup(
         "common_runtime/partitioning_utils.h",
         "common_runtime/placer.h",
         "common_runtime/process_util.h",
-        "common_runtime/inspecting_placer.h",
+        "common_runtime/propagator_state.h",
         "common_runtime/profile_handler.h",
         "common_runtime/renamed_device.h",
         "common_runtime/rendezvous_mgr.h",
@@ -2640,6 +2642,7 @@ tf_cuda_library(
         "common_runtime/process_function_library_runtime.cc",
         "common_runtime/process_state.cc",
         "common_runtime/process_util.cc",
+        "common_runtime/propagator_state.cc",
         "common_runtime/renamed_device.cc",
         "common_runtime/rendezvous_mgr.cc",
         "common_runtime/rendezvous_util.cc",

tensorflow/core/common_runtime/entry.h

@@ -0,0 +1,142 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"
namespace tensorflow {
class mutex;
class Tensor;
// An Entry stores a single input value for an individual kernel invocation in
// an executor.
//
// Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
struct Entry {
enum class State {
NO_VALUE = 0, // The default state for a newly-created Entry.
HAS_VALUE, // `this->val` is valid.
HAS_CONST_TENSOR, // `this->const_tensor` is valid.
HAS_REF_TENSOR, // `this->ref_tensor` is valid.
};
Entry() : state(State::NO_VALUE) {}
Entry(const Entry& other) : state(other.state), alloc_attr(other.alloc_attr) {
switch (state) {
case State::NO_VALUE:
break;
case State::HAS_VALUE:
val.Init(*other.val);
break;
case State::HAS_CONST_TENSOR:
const_tensor = other.const_tensor;
break;
case State::HAS_REF_TENSOR:
ref_tensor = other.ref_tensor;
break;
}
}
~Entry() {
if (state == State::HAS_VALUE) val.Destroy();
}
Entry& operator=(const Entry& other) {
if (state == State::HAS_VALUE) {
val.Destroy();
}
state = other.state;
alloc_attr = other.alloc_attr;
switch (state) {
case State::NO_VALUE:
break;
case State::HAS_VALUE:
val.Init(*other.val);
break;
case State::HAS_CONST_TENSOR:
const_tensor = other.const_tensor;
break;
case State::HAS_REF_TENSOR:
ref_tensor = other.ref_tensor;
break;
}
return *this;
}
Entry& operator=(Entry&& other) {
if (state == State::HAS_VALUE) {
val.Destroy();
}
state = other.state;
alloc_attr = other.alloc_attr;
switch (state) {
case State::NO_VALUE:
break;
case State::HAS_VALUE:
val.Init(std::move(*other.val));
break;
case State::HAS_CONST_TENSOR:
const_tensor = other.const_tensor;
break;
case State::HAS_REF_TENSOR:
ref_tensor = other.ref_tensor;
break;
}
return *this;
}
// Clears the <val> field, and sets this entry to the `NO_VALUE` state.
void ClearVal() {
if (state == State::HAS_VALUE) {
val.Destroy();
}
state = State::NO_VALUE;
}
union {
// A tensor value. Valid iff `state == HAS_VALUE`.
ManualConstructor<Tensor> val;
// A pointer to a constant tensor value. Valid iff `state ==
// HAS_CONST_TENSOR`.
const Tensor* const_tensor;
// A tensor reference and associated mutex. Valid iff `state ==
// HAS_REF_TENSOR`.
struct {
Tensor* tensor;
mutex* mu;
} ref_tensor;
};
// The current state of this entry, indicating which member of the above
// union is active.
State state;
// The attributes of the allocator that creates the tensor.
AllocatorAttributes alloc_attr;
};
// TODO(b/152925936): Re-evaluate this constant with current usage patterns.
typedef gtl::InlinedVector<Entry, 4> EntryVector;
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_
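For illustration, here is a minimal standalone analog of the Entry pattern above, with std::string standing in for Tensor and placement new / explicit destructor calls standing in for gtl::ManualConstructor. `Slot` and its member names are hypothetical; this is a sketch of the tagged-union technique, not the real type.

#include <iostream>
#include <new>
#include <string>
#include <utility>

struct Slot {
  enum class State { NO_VALUE, HAS_VALUE };

  Slot() : state(State::NO_VALUE) {}
  Slot(const Slot& other) : state(other.state) {
    // Like Entry's copy constructor, dispatch on the active state.
    if (state == State::HAS_VALUE) new (&val) std::string(other.val);
  }
  Slot& operator=(const Slot& other) {
    if (this == &other) return *this;
    if (state == State::HAS_VALUE) val.~basic_string();
    state = other.state;
    if (state == State::HAS_VALUE) new (&val) std::string(other.val);
    return *this;
  }
  ~Slot() {
    // A union member with a non-trivial destructor must be destroyed
    // manually, and only when it is the active member.
    if (state == State::HAS_VALUE) val.~basic_string();
  }

  void Set(std::string s) {
    if (state == State::HAS_VALUE) val.~basic_string();
    new (&val) std::string(std::move(s));
    state = State::HAS_VALUE;
  }

  union {
    std::string val;  // valid iff state == HAS_VALUE
  };
  State state;
};

int main() {
  Slot a;
  a.Set("tensor payload");
  Slot b = a;  // copies the active member only
  std::cout << b.val << "\n";
}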

tensorflow/core/common_runtime/executor.cc (diff suppressed because it is too large)

tensorflow/core/common_runtime/propagator_state.cc

@@ -0,0 +1,777 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/propagator_state.h"
#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace tensorflow {
// 1-D, 0 element tensor.
static const Tensor* const kEmptyTensor = new Tensor;
typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
int64 step_id)
: immutable_state_(immutable_state),
step_id_(step_id),
vlog_(VLOG_IS_ON(1)) {
// We start the entire execution in iteration 0 of the root frame
// so let us create the root frame and the state for iteration 0.
// We assume root_frame_->frame_name.empty().
root_frame_ = new FrameState(immutable_state_, 1);
root_frame_->frame_id = 0; // must be 0
root_frame_->InitializeFrameInfo(root_frame_->frame_name);
// Initialize iteration 0.
root_frame_->SetIteration(
0, new PropagatorState::IterationState(root_frame_->pending_counts,
root_frame_->total_input_tensors));
outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
}
PropagatorState::~PropagatorState() {
for (auto name_frame : outstanding_frames_) {
delete name_frame.second;
}
}
void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
TaggedNodeSeq* ready) {
for (const NodeItem* item : roots) {
DCHECK_EQ(item->num_inputs, 0);
ready->push_back(TaggedNode{item, root_frame_, 0, false});
}
mutex_lock l(root_frame_->mu);
root_frame_->GetIteration(0)->outstanding_ops = ready->size();
}
void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
EntryVector* outputs,
TaggedNodeSeq* ready) {
profiler::TraceMe activity(
[&]() {
return strings::StrCat(
"ExecutorPropagateOutputs#", "id=", step_id_,
",kernel_name=", tagged_node.node_item->kernel->name_view(),
",num_output_edges=", tagged_node.node_item->num_output_edges,
",num_output_control_edges=",
tagged_node.node_item->num_output_control_edges, "#");
},
profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
const NodeItem* const item = tagged_node.node_item;
FrameState* const input_frame = tagged_node.input_frame;
const int64 input_iter = tagged_node.input_iter;
const bool is_dead = tagged_node.is_dead;
// Propagates outputs along out edges, and puts newly ready nodes
// into the ready queue.
DCHECK(ready->empty());
bool is_frame_done = false;
FrameState* output_frame = input_frame;
int64 output_iter = input_iter;
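// Dispatch on the node's control-flow role: the common case (not Enter,
// Exit, or NextIteration) stays in the same frame and iteration; Enter
// moves to iteration 0 of a child frame; Exit moves to the parent frame;
// NextIteration advances the iteration within the same frame.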
if (!item->is_enter_exit_or_next_iter) {
// Fast path for node types that don't need special handling
DCHECK_EQ(input_frame, output_frame);
// Normal path for most nodes
mutex_lock l(input_frame->mu);
output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
is_frame_done =
input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
} else if (item->is_enter) {
FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
output_iter = 0;
{
mutex_lock l(output_frame->mu);
if (item->is_constant_enter) {
// Propagate to all active iterations if this is a loop invariant.
output_frame->AddLoopInv(item, (*outputs)[0], ready);
} else {
output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
}
output_frame->num_pending_inputs--;
}
is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
} else if (item->is_exit) {
if (is_dead) {
mutex_lock l(input_frame->mu);
// Stop and remember this node if it is a dead exit.
if (input_iter == input_frame->iteration_count) {
input_frame->dead_exits.push_back(item);
}
is_frame_done =
input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
} else {
output_frame = input_frame->parent_frame;
output_iter = input_frame->parent_iter;
{
mutex_lock l(output_frame->mu);
output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
}
is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
}
} else {
DCHECK(item->is_next_iteration);
mutex_lock l(input_frame->mu);
if (is_dead) {
// Stop the deadness propagation.
output_frame = nullptr;
} else {
if (input_iter == input_frame->iteration_count &&
input_frame->num_outstanding_iterations ==
input_frame->max_parallel_iterations) {
// Reached the maximum for parallel iterations.
input_frame->next_iter_roots.push_back({item, (*outputs)[0]});
output_frame = nullptr;
} else {
// If this is a new iteration, start it.
if (input_iter == input_frame->iteration_count) {
input_frame->IncrementIteration(ready);
}
output_iter = input_iter + 1;
}
}
if (output_frame != nullptr) {
// NextIteration stays within the same frame, so the outputs are
// activated in the next iteration of the input frame.
DCHECK(input_frame == output_frame);
output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
}
is_frame_done =
input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
}
// At this point, this node is completely done. We also know if the
// completion of this node makes its frame completed.
if (is_frame_done) {
FrameState* parent_frame = input_frame->parent_frame;
const int64 parent_iter = input_frame->parent_iter;
DeleteFrame(input_frame, ready);
if (parent_frame != nullptr) {
// The completion of frame may cause completions in its parent frame.
// So clean things up recursively.
CleanupFramesIterations(parent_frame, parent_iter, ready);
}
}
}
const Tensor* PropagatorState::GetTensorValueForDump(const Entry& input) {
switch (input.state) {
case Entry::State::NO_VALUE:
return kEmptyTensor;
case Entry::State::HAS_VALUE:
return input.val.get();
case Entry::State::HAS_CONST_TENSOR:
return input.const_tensor;
case Entry::State::HAS_REF_TENSOR:
return input.ref_tensor.tensor;
}
}
void PropagatorState::DumpPendingNodeState(
const int node_id, const Entry* input_vector,
const bool show_nodes_with_no_ready_inputs) {
const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
const int input_base = node_item.input_start;
if (!show_nodes_with_no_ready_inputs) {
bool has_ready_input = false;
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
has_ready_input = true;
break;
}
}
if (!has_ready_input) {
return;
}
}
LOG(WARNING) << " Pending Node: " << node_item.DebugString();
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(), ">");
} else {
LOG(WARNING) << " Input " << i << ": not present";
}
}
}
void PropagatorState::DumpActiveNodeState(const int node_id,
const Entry* input_vector) {
const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
LOG(WARNING) << " Active Node: " << node_item.DebugString();
const int input_base = node_item.input_start;
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(), ">");
} else {
LOG(WARNING) << " Input " << i << ": not present";
}
}
}
void PropagatorState::DumpIterationState(const FrameState* frame,
IterationState* iteration) {
const std::vector<const NodeItem*>* nodes = frame->nodes;
// Dump any waiting nodes that are holding on to tensors.
for (const NodeItem* node : *nodes) {
PendingCounts::Handle pending_id =
immutable_state_.pending_ids()[node->node_id];
if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
DumpPendingNodeState(node->node_id, iteration->input_tensors, false);
}
}
// Then the active nodes.
for (const NodeItem* node : *nodes) {
PendingCounts::Handle pending_id =
immutable_state_.pending_ids()[node->node_id];
if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
DumpActiveNodeState(node->node_id, iteration->input_tensors);
}
}
// Show all input tensors in use.
const int total_input_tensors = frame->total_input_tensors;
size_t total_bytes = 0;
for (int i = 0; i < total_input_tensors; ++i) {
const Entry& input = iteration->input_tensors[i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(),
", bytes: ", tensor->TotalBytes(), ">");
total_bytes += tensor->TotalBytes();
}
}
LOG(WARNING) << " Total bytes " << total_bytes;
}
void PropagatorState::DumpState() {
mutex_lock l(mu_);
if (!dumped_on_error_) {
LOG(WARNING) << "Dumping state";
for (auto& frame : outstanding_frames_) {
LOG(WARNING) << frame.first;
FrameState* frame_state = frame.second;
frame_state->DumpIterationState(this);
}
dumped_on_error_ = true;
}
}
void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
const NodeItem& node_item,
FrameState** child) {
// Get the child frame name.
AttrSlice attrs(node_item.kernel->def());
const string& enter_name = GetNodeAttrString(attrs, "frame_name");
DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node "
<< node_item.kernel->name();
const string child_name =
strings::StrCat(frame->frame_name, ";", iter, ";", enter_name);
{
mutex_lock executor_lock(mu_);
auto it = outstanding_frames_.find(child_name);
if (it != outstanding_frames_.end()) {
*child = it->second;
return;
}
}
// Need to create a new frame instance.
// Note that this new frame instance is created without any locks.
if (vlog_) VLOG(2) << "Create frame: " << child_name;
int parallel_iters;
bool found_parallel_iters =
TryGetNodeAttr(attrs, "parallel_iterations", &parallel_iters);
DCHECK(found_parallel_iters)
<< "Could not find \"parallel_iterations\" attr in node "
<< node_item.kernel->name();
FrameState* temp = new FrameState(immutable_state_, parallel_iters);
temp->frame_name = child_name;
temp->frame_id = Hash64(child_name);
temp->parent_frame = frame;
temp->parent_iter = iter;
temp->InitializeFrameInfo(enter_name);
// Initialize iteration 0.
{
mutex_lock l(temp->mu);
temp->SetIteration(
0, new IterationState(temp->pending_counts, temp->total_input_tensors));
}
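// Double-checked insertion: another thread may have registered the same
// child frame while `temp` was being constructed without locks; if so,
// `temp` is discarded below.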
{
mutex_lock executor_lock(mu_);
auto it = outstanding_frames_.find(child_name);
if (it != outstanding_frames_.end()) {
*child = it->second;
} else {
mutex_lock frame_lock(frame->mu);
frame->GetIteration(iter)->outstanding_frame_count++;
outstanding_frames_[child_name] = temp;
*child = temp;
temp = nullptr;
}
}
delete temp; // Not used so delete it.
}
void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
// First, propagate dead_exits (if any) to the parent frame.
FrameState* parent_frame = frame->parent_frame;
const int64 parent_iter = frame->parent_iter;
if (parent_frame != nullptr) {
mutex_lock parent_frame_lock(parent_frame->mu);
// Propagate all the dead exits to the parent frame.
mutex_lock this_frame_lock(frame->mu);
for (const NodeItem* item : frame->dead_exits) {
auto parent_iter_state = parent_frame->GetIteration(parent_iter);
auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready,
bool dst_dead) {
if (dst_ready) {
if (dst_item.is_control_trigger) dst_dead = false;
ready->emplace_back(&dst_item, parent_frame, parent_iter, dst_dead);
parent_iter_state->outstanding_ops++;
}
};
auto propagate_to_non_merge = [&](PendingCounts::Handle dst_pending_id) {
parent_iter_state->increment_dead_count(dst_pending_id);
return parent_iter_state->decrement_pending(dst_pending_id, 1) == 0;
};
for (const EdgeInfo& e : item->output_edges()) {
const NodeItem& dst_item =
*immutable_state_.graph_view().node(e.dst_id);
const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
bool dst_dead = true;
bool dst_ready;
// We know this is a dead input to dst.
if (dst_item.is_merge) {
parent_iter_state->increment_dead_count(dst_pending_id);
const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
dst_dead = (dead_cnt == dst_item.num_inputs);
dst_ready =
(parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
} else {
dst_ready = propagate_to_non_merge(dst_pending_id);
}
maybe_add_to_ready(dst_item, dst_ready, dst_dead);
}
for (const ControlEdgeInfo& e : item->output_control_edges()) {
const NodeItem& dst_item =
*immutable_state_.graph_view().node(e.dst_id);
const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
bool dst_dead;
bool dst_ready;
// We know this is a dead input to dst.
if (dst_item.is_merge) {
parent_iter_state->decrement_pending(dst_pending_id, 2);
int count = parent_iter_state->pending(dst_pending_id);
int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
dst_dead = (dead_cnt == dst_item.num_inputs);
dst_ready = (count == 0) || ((count == 1) && dst_dead);
} else {
dst_dead = true;
dst_ready = propagate_to_non_merge(dst_pending_id);
}
maybe_add_to_ready(dst_item, dst_ready, dst_dead);
}
}
}
// Delete the frame.
const string& frame_name = frame->frame_name;
if (vlog_) VLOG(2) << "Delete frame " << frame_name;
{
mutex_lock executor_lock(mu_);
outstanding_frames_.erase(frame_name);
}
delete frame;
}
void PropagatorState::CleanupFramesIterations(FrameState* frame, int64 iter,
TaggedNodeSeq* ready) {
bool is_frame_done = false;
{
mutex_lock frame_lock(frame->mu);
frame->GetIteration(iter)->outstanding_frame_count--;
is_frame_done = frame->CleanupIterations(iter, ready);
}
if (is_frame_done) {
FrameState* parent_frame = frame->parent_frame;
const int64 parent_iter = frame->parent_iter;
DeleteFrame(frame, ready);
if (parent_frame != nullptr) {
// The completion of frame may cause completions in its parent frame.
// So clean things up recursively.
CleanupFramesIterations(parent_frame, parent_iter, ready);
}
}
}
void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
const bool is_dead,
int64 iter,
EntryVector* outputs,
TaggedNodeSeq* ready) {
// If we know that none of the item's edge destinations require special
// handling (i.e. none of the nodes is a merge or control trigger node), we
// can take a fast path that avoids accessing the destination NodeItem.
const GraphView& gview = immutable_state.graph_view();
IterationState* iter_state = GetIteration(iter);
// Add dst to the ready queue if it's ready
//
// NOTE(mrry): Use a macro here instead of a lambda, because this method is
// performance-critical and we need to ensure that the code is inlined.
#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \
do { \
if (!adjust_result.any_pending) { \
const NodeItem* dst_item = gview.node(dst_id); \
TaggedNode& t = ready->emplace_back(); \
t.node_item = dst_item; \
t.input_frame = this; \
t.input_iter = iter; \
t.is_dead = adjust_result.any_dead; \
iter_state->outstanding_ops++; \
} \
} while (0);
Entry* input_tensors = iter_state->input_tensors;
for (const EdgeInfo& e : item->output_edges()) {
const int dst_id = e.dst_id;
const PendingCounts::Handle dst_pending_id =
immutable_state.pending_ids()[dst_id];
const int src_slot = e.output_slot;
const bool increment_dead =
(is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
const PendingCounts::AdjustResult adjust_result =
iter_state->adjust_for_activation(dst_pending_id, increment_dead);
const int dst_loc = e.input_slot;
if (e.is_last) {
input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
} else {
input_tensors[dst_loc] = (*outputs)[src_slot];
}
MAYBE_ADD_TO_READY(dst_id, adjust_result);
}
for (const ControlEdgeInfo& e : item->output_control_edges()) {
const int dst_id = e.dst_id;
const PendingCounts::Handle dst_pending_id =
immutable_state.pending_ids()[dst_id];
const PendingCounts::AdjustResult adjust_result =
iter_state->adjust_for_activation(dst_pending_id, is_dead);
MAYBE_ADD_TO_READY(dst_id, adjust_result);
}
#undef MAYBE_ADD_TO_READY
}
void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
const bool is_dead,
int64 iter,
EntryVector* outputs,
TaggedNodeSeq* ready) {
// If any of the edge destinations is a merge or a control trigger node,
// we need to read each destination NodeItem to determine what action
// to take.
const GraphView& gview = immutable_state.graph_view();
IterationState* iter_state = GetIteration(iter);
auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
bool dst_ready, bool dst_dead) {
// Add dst to the ready queue if it's ready
if (dst_ready) {
if (dst_item->is_control_trigger) dst_dead = false;
ready->emplace_back(dst_item, this, iter, dst_dead);
iter_state->outstanding_ops++;
}
};
Entry* input_tensors = iter_state->input_tensors;
for (const EdgeInfo& e : item->output_edges()) {
const int dst_id = e.dst_id;
const NodeItem* dst_item = gview.node(dst_id);
const PendingCounts::Handle dst_pending_id =
immutable_state.pending_ids()[dst_id];
const int src_slot = e.output_slot;
bool dst_dead = false;
bool dst_ready = false;
bool dst_need_input = true;
if (dst_item->is_merge) {
// A merge node is ready if all control inputs have arrived and either
// a) a live data input becomes available or b) all data inputs are
// dead. For Merge, pending's LSB is set iff a live data input has
// arrived.
if ((*outputs)[src_slot].state != Entry::State::NO_VALUE) {
// This is a live data input.
int count = iter_state->pending(dst_pending_id);
iter_state->mark_live(dst_pending_id);
// Only the first live edge sets the input and (potentially)
// triggers execution. The low bit of count is set if and
// only if no live input has been used yet (mark_live clears
// it). The node should be started if and only if this is
// the first live input and there are no pending control
// edges, i.e. count == 1.
dst_ready = (count == 1);
dst_need_input = ((count & 0x1) == 1);
} else {
// This is a dead data input. Note that dst_node is dead if node is
// a dead enter. We need this to handle properly a while loop on
// the untaken branch of a conditional.
// TODO(yuanbyu): This is a bit hacky, but a good solution for
// now.
iter_state->increment_dead_count(dst_pending_id);
const int dead_cnt = iter_state->dead_count(dst_pending_id);
dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
dst_need_input = false;
}
} else {
// Handle all other (non-merge) nodes.
const bool increment_dead =
(is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
const PendingCounts::AdjustResult adjust_result =
iter_state->adjust_for_activation(dst_pending_id, increment_dead);
dst_dead = adjust_result.any_dead;
dst_ready = !adjust_result.any_pending;
}
if (dst_need_input) {
const int dst_loc = e.input_slot;
if (e.is_last) {
input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
} else {
input_tensors[dst_loc] = (*outputs)[src_slot];
}
}
maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
}
for (const ControlEdgeInfo& e : item->output_control_edges()) {
const int dst_id = e.dst_id;
const NodeItem* dst_item = gview.node(dst_id);
const PendingCounts::Handle dst_pending_id =
immutable_state.pending_ids()[dst_id];
bool dst_dead;
bool dst_ready;
if (dst_item->is_merge) {
// A merge node is ready if all control inputs have arrived and either
// a) a live data input becomes available or b) all data inputs are
// dead. For Merge, pending's LSB is set iff a live data input has
// arrived.
iter_state->decrement_pending(dst_pending_id, 2);
int count = iter_state->pending(dst_pending_id);
int dead_cnt = iter_state->dead_count(dst_pending_id);
dst_dead = (dead_cnt == dst_item->num_inputs);
dst_ready = (count == 0) || ((count == 1) && dst_dead);
} else {
// Handle all other (non-merge) nodes.
const PendingCounts::AdjustResult adjust_result =
iter_state->adjust_for_activation(dst_pending_id, is_dead);
dst_dead = adjust_result.any_dead;
dst_ready = !adjust_result.any_pending;
}
maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
}
}
void PropagatorState::FrameState::ActivateNodes(const NodeItem* item,
const bool is_dead, int64 iter,
EntryVector* outputs,
TaggedNodeSeq* ready) {
if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
ActivateNodesSlowPath(item, is_dead, iter, outputs, ready);
} else {
ActivateNodesFastPath(item, is_dead, iter, outputs, ready);
}
}
void PropagatorState::FrameState::ActivateNexts(int64 iter,
TaggedNodeSeq* ready) {
// Propagate the deferred NextIteration nodes to the new iteration.
for (auto& node_entry : next_iter_roots) {
const NodeItem* item = node_entry.first;
const Entry& entry = node_entry.second;
const bool is_dead = entry.state == Entry::State::NO_VALUE;
EntryVector outputs{entry};
ActivateNodes(item, is_dead, iter, &outputs, ready);
}
next_iter_roots.clear();
}
void PropagatorState::FrameState::ActivateLoopInvs(int64 iter,
TaggedNodeSeq* ready) {
// Propagate loop invariants to the new iteration.
for (auto& node_entry : inv_values) {
const NodeItem* item = node_entry.first;
const Entry& entry = node_entry.second;
const bool is_dead = entry.state == Entry::State::NO_VALUE;
EntryVector outputs{entry};
ActivateNodes(item, is_dead, iter, &outputs, ready);
}
}
void PropagatorState::FrameState::AddLoopInv(const NodeItem* item,
const Entry& entry,
TaggedNodeSeq* ready) {
// Store this value.
inv_values.push_back({item, entry});
// Make this value available to all iterations.
const bool is_dead = entry.state == Entry::State::NO_VALUE;
for (int i = 0; i <= iteration_count; ++i) {
EntryVector outputs{entry};
ActivateNodes(item, is_dead, i, &outputs, ready);
}
}
bool PropagatorState::FrameState::IsIterationDone(int64 iter) {
IterationState* iter_state = GetIteration(iter);
if (iter_state->outstanding_ops == 0 &&
iter_state->outstanding_frame_count == 0) {
if (iter == 0) {
// The enclosing frame has no pending input.
return num_pending_inputs == 0;
} else {
// The preceding iteration is deleted (and therefore done).
return (GetIteration(iter - 1) == nullptr);
}
}
return false;
}
void PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
iteration_count++;
const int64 next_iter = iteration_count;
// Initialize the next iteration.
IterationState* iter_state =
new IterationState(pending_counts, total_input_tensors);
SetIteration(next_iter, iter_state);
num_outstanding_iterations++;
dead_exits.clear();
// Activate the successors of the deferred roots in the new iteration.
ActivateNexts(next_iter, ready);
// Activate the loop invariants in the new iteration.
ActivateLoopInvs(next_iter, ready);
}
bool PropagatorState::FrameState::CleanupIterations(int64 iter,
TaggedNodeSeq* ready) {
int64 curr_iter = iter;
while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) {
// Delete the iteration curr_iter.
delete GetIteration(curr_iter);
SetIteration(curr_iter, nullptr);
--num_outstanding_iterations;
++curr_iter;
// When one iteration is completed, we check for deferred iteration,
// and start it if there is one.
if (!next_iter_roots.empty()) {
IncrementIteration(ready);
}
}
return IsFrameDone();
}
void PropagatorState::FrameState::InitializeFrameInfo(
const string& enter_name) {
const ImmutableExecutorState::FrameInfo* finfo =
immutable_state.get_frame_info(enter_name);
DCHECK_NE(finfo, nullptr);
pending_counts = finfo->pending_counts.get();
total_input_tensors = finfo->total_inputs;
num_pending_inputs = finfo->input_count;
nodes = finfo->nodes.get();
}
void PropagatorState::FrameState::SetIteration(int64 iter,
IterationState* state)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
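// `iterations_raw` is used as a ring buffer with (max_parallel_iterations
// + 1) slots: at most max_parallel_iterations iterations are live at once,
// so slot (iter % size) is free when a new iteration starts, as the DCHECK
// below verifies.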
size_t index = iter % (max_parallel_iterations + 1);
DCHECK(state == nullptr || iterations[index] == nullptr);
iterations_raw[index] = state;
if (index == 0) {
iterations_first = state;
}
}
// Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done.
bool PropagatorState::FrameState::DecrementOutstandingOps(
int64 iter, TaggedNodeSeq* ready) {
mutex_lock l(mu);
return DecrementOutstandingOpsLocked(iter, ready);
}
// Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done.
bool PropagatorState::FrameState::DecrementOutstandingOpsLocked(
int64 iter, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
IterationState* istate = GetIteration(iter);
istate->outstanding_ops--;
if (istate->outstanding_ops != 0) {
return false;
} else {
return CleanupIterations(iter, ready);
}
}
// Returns true if the computation in the frame is completed.
bool PropagatorState::FrameState::IsFrameDone()
TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
return (num_pending_inputs == 0 && num_outstanding_iterations == 0);
}
} // namespace tensorflow
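As a worked illustration of the child-frame key scheme in FindOrCreateChildFrame() above, this small standalone sketch composes the same "parent;iter;frame_name" key; the attr value and iteration number are illustrative only.

#include <iostream>
#include <string>

int main() {
  // The root frame name is empty; suppose an Enter node with a "frame_name"
  // attr of "while/while_context" (illustrative) fires in parent iteration 2.
  const std::string parent_frame_name = "";
  const long long parent_iter = 2;
  const std::string enter_name = "while/while_context";

  // Same composition as strings::StrCat(frame->frame_name, ";", iter, ";",
  // enter_name) in FindOrCreateChildFrame(), so the same Enter node reached
  // in different parent iterations yields distinct child frames.
  const std::string child_name =
      parent_frame_name + ";" + std::to_string(parent_iter) + ";" + enter_name;
  std::cout << child_name << "\n";  // prints ";2;while/while_context"
}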

tensorflow/core/common_runtime/propagator_state.h

@@ -0,0 +1,466 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
#include <vector>
#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
// Represents the ephemeral "edge state" associated with one invocation of
// `Executor::Run()`.
//
// `PropagatorState` is responsible for propagating values along dataflow
// edges in a TensorFlow graph and determining which nodes are runnable. The
// executor primarily updates `PropagatorState` by calling `PropagateOutputs()`
// after processing a node, and `PropagatorState` dispatches `TaggedNode`s by
// adding them to a `TaggedNodeSeq`.
class PropagatorState {
public:
PropagatorState(const ImmutableExecutorState& immutable_state, int64 step_id);
~PropagatorState();
private:
// Forward declaration so that `TaggedNode` can include a `FrameState*`.
struct FrameState;
public:
// A `TaggedNode` corresponds to a single invocation of a node's kernel,
// and it is created when the kernel becomes runnable (in a particular
// iteration of a particular frame).
struct TaggedNode {
const NodeItem* node_item;
FrameState* input_frame;
int64 input_iter;
bool is_dead;
TaggedNode() = default;
TaggedNode(const NodeItem* node_item, FrameState* in_frame, int64 in_iter,
bool dead)
: node_item(node_item),
input_frame(in_frame),
input_iter(in_iter),
is_dead(dead) {}
const NodeItem& get_node_item() const { return *node_item; }
bool get_is_dead() const { return is_dead; }
};
// A drop-in replacement for std::deque<TaggedNode>. We typically don't
// have that many nodes in the ready queue, so we just use a vector and
// don't free up memory from the queue as we consume nodes.
class TaggedNodeReadyQueue {
public:
TaggedNodeReadyQueue() : front_index_(0) {}
void push_back(const TaggedNode& node) { ready_.push_back(node); }
TaggedNode front() const {
DCHECK_LT(front_index_, ready_.size());
return ready_[front_index_];
}
void pop_front() {
DCHECK_LT(front_index_, ready_.size());
front_index_++;
if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) {
if (front_index_ == ready_.size()) {
ready_.clear();
} else {
// Lots of unused entries at beginning of vector: move everything
// down to start of vector.
ready_.erase(ready_.begin(), ready_.begin() + front_index_);
}
front_index_ = 0;
}
}
bool empty() const { return ready_.empty(); }
private:
// TODO(b/152925936): Re-evaluate these constants with current usage
// patterns.
static constexpr int kSpillThreshold = 16384;
gtl::InlinedVector<TaggedNode, 16> ready_;
int front_index_;
};
// TODO(b/152925936): Re-evaluate this constant with current usage patterns.
typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
private:
struct IterationState {
explicit IterationState(const PendingCounts* pending_counts,
int total_input_tensors)
: input_tensors(new Entry[total_input_tensors]),
outstanding_ops(0),
outstanding_frame_count(0),
counts(*pending_counts) { // Initialize with copy of *pending_counts
}
// The state of an iteration.
// One copy per iteration. For iteration k, i-th node's j-th input is in
// input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is
// either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
//
// NOTE: No need to protect input_tensors[i] by any locks because it
// is resized once. Each element of input_tensors is written once by the
// source node of an edge and is cleared by the destination of the same
// edge. The latter node is never run concurrently with the former node.
Entry* input_tensors;
// The number of outstanding ops for each iteration.
size_t outstanding_ops;
// The number of outstanding frames for each iteration.
int outstanding_frame_count;
int pending(PendingCounts::Handle h) { return counts.pending(h); }
int decrement_pending(PendingCounts::Handle h, int v) {
return counts.decrement_pending(h, v);
}
// Mark a merge node as live
// REQUIRES: Node corresponding to "h" is a merge node
void mark_live(PendingCounts::Handle h) { counts.mark_live(h); }
// Mark a node to show that processing has started.
void mark_started(PendingCounts::Handle h) { counts.mark_started(h); }
// Mark a node to show that processing has completed.
void mark_completed(PendingCounts::Handle h) { counts.mark_completed(h); }
PendingCounts::NodeState node_state(PendingCounts::Handle h) {
return counts.node_state(h);
}
int dead_count(PendingCounts::Handle h) { return counts.dead_count(h); }
void increment_dead_count(PendingCounts::Handle h) {
counts.increment_dead_count(h);
}
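// Adjusts the pending count for one incoming edge activation (and the
// dead count, if `increment_dead` is true); the result reports whether
// any inputs are still pending and whether any dead input has been seen.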
PendingCounts::AdjustResult adjust_for_activation(PendingCounts::Handle h,
bool increment_dead) {
return counts.adjust_for_activation(h, increment_dead);
}
~IterationState() { delete[] input_tensors; }
private:
PendingCounts counts;
};
struct FrameState {
explicit FrameState(const ImmutableExecutorState& immutable_state,
int parallel_iters)
: immutable_state(immutable_state),
max_parallel_iterations(parallel_iters),
num_outstanding_iterations(1),
iterations(parallel_iters + 1),
iterations_raw(iterations.data()) {}
// A new frame is created for each loop. Execution starts at iteration 0.
// When a value at iteration 0 passes through a NextIteration node,
// iteration 1 is created and starts running. Note that iteration 0 may
// still be running so multiple iterations may run in parallel. The
// frame maintains the state of iterations in several data structures
// such as pending_count and input_tensors. When iteration 0 completes,
// we garbage collect the state of iteration 0.
//
// A frame instance is considered "done" and can be garbage collected
// if all its inputs have entered and all its iterations are "done".
//
// A frame manages the live iterations of an iterative computation.
// Iteration i is considered "done" when there are no outstanding ops,
// frames at iteration i are done, all recvs for this iteration are
// completed, and iteration i-1 is done. For iteration 0, we instead
// wait for there to be no more pending inputs of the frame.
//
// Frames and iterations are garbage collected once they are done.
// The state we need to keep around is highly dependent on the
// parallelism enabled by the scheduler. We may want to have the
// scheduler dynamically control the outstanding number of live
// parallel frames and iterations. To reduce the state space, the
// scheduler might want to schedule ops in inner frames first and
// lower iterations first.
//
// This frame state is mostly initialized lazily on demand so we
// don't introduce unnecessary overhead.
// The immutable state of the executor the frame is in.
const ImmutableExecutorState& immutable_state;
// The name of this frame, which is the concatenation of its parent
// frame name, the iteration of the parent frame when this frame was
// created, and the value of the attr 'frame_name'.
string frame_name;
// The unique id for this frame. Generated by fingerprinting
// frame_name.
uint64 frame_id;
// The iteration id of its parent frame when this frame is created.
// -1 if there is no parent frame. The frame_name/parent_iter pair
// uniquely identifies this FrameState.
int64 parent_iter = -1;
// The FrameState of its parent frame.
FrameState* parent_frame = nullptr;
// The maximum allowed number of parallel iterations.
const int max_parallel_iterations;
// The number of inputs this frame is still waiting for.
int num_pending_inputs = 0;
// The highest iteration number we have reached so far in this frame.
int64 iteration_count TF_GUARDED_BY(mu) = 0;
// The number of outstanding iterations.
int num_outstanding_iterations TF_GUARDED_BY(mu) = 1;
private:
// The active iteration states of this frame.
gtl::InlinedVector<IterationState*, 12> iterations;
IterationState** const iterations_raw TF_GUARDED_BY(mu);
IterationState* iterations_first TF_GUARDED_BY(mu);
public:
// The NextIteration nodes to enter a new iteration. If the number of
// outstanding iterations reaches the limit, we will defer the start of
// the next iteration until the number of outstanding iterations falls
// below the limit.
std::vector<std::pair<const NodeItem*, Entry>> next_iter_roots
TF_GUARDED_BY(mu);
// The values of the loop invariants for this loop. They are added into
// this list as they "enter" the frame. When a loop invariant enters,
// we make it available to all active iterations. When the frame starts
// a new iteration, we make all the current loop invariants available
// to the new iteration.
std::vector<std::pair<const NodeItem*, Entry>> inv_values TF_GUARDED_BY(mu);
// The list of dead exit node items for the current highest iteration. We
// will only "execute" the dead exits of the final iteration.
std::vector<const NodeItem*> dead_exits TF_GUARDED_BY(mu);
// Static information specific to this frame.
PendingCounts* pending_counts = nullptr;
int total_input_tensors = 0;
std::vector<const NodeItem*>* nodes = nullptr;
// Lock ordering: ExecutorState.mu_ < mu;
// during structured traversal: parent_frame->mu < mu.
mutex mu;
void InitializeFrameInfo(const string& enter_name);
inline IterationState* GetIteration(int64 iter)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
if (TF_PREDICT_TRUE(iter == 0)) {
return iterations_first;
} else {
size_t index = iter % (max_parallel_iterations + 1);
return iterations_raw[index];
}
}
void SetIteration(int64 iter, IterationState* state);
// Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done.
bool DecrementOutstandingOps(int64 iter, TaggedNodeSeq* ready);
// Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done.
bool DecrementOutstandingOpsLocked(int64 iter, TaggedNodeSeq* ready);
// Returns true if the computation in the frame is completed.
bool IsFrameDone();
// Returns true if the iteration of the frame is completed.
bool IsIterationDone(int64 iter) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
// Increments the iteration id. If this is a new iteration, initialize it.
void IncrementIteration(TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
// Activate all the deferred NextIteration nodes in a new iteration.
void ActivateNexts(int64 iter, TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
// Activate all the current loop invariants in a new iteration.
void ActivateLoopInvs(int64 iter, TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
// Add a new loop invariant and make it available to all active
// iterations.
void AddLoopInv(const NodeItem* item, const Entry& entry,
TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
// Activate the successors of a node. Contents of *outputs are left in an
// indeterminate state after returning from this method.
void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter,
EntryVector* outputs, TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
// Cleanup iterations of this frame starting from iteration iter.
bool CleanupIterations(int64 iter, TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
void DumpIterationState(PropagatorState* parent) {
mutex_lock l(mu);
for (IterationState* iteration : iterations) {
if (iteration) {
LOG(WARNING) << " Iteration:";
parent->DumpIterationState(this, iteration);
}
}
}
~FrameState() {
for (size_t i = 0; i < iterations.size(); ++i) {
delete iterations[i];
iterations[i] = nullptr;
}
}
private:
// REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
void ActivateNodesFastPath(const NodeItem* item, const bool is_dead,
int64 iter, EntryVector* outputs,
TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
int64 iter, EntryVector* outputs,
TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
};
public:
// Creates and adds a `TaggedNode` for each node in `roots` to `*ready`.
void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
TaggedNodeSeq* ready);
// After processing the outputs, propagates the outputs to their dsts.
// Contents of *outputs are left in an indeterminate state after
// returning from this method.
void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
TaggedNodeSeq* ready);
// Returns an array of `Entry` objects corresponding to the inputs of
// `tagged_node`.
//
// NOTE: Thread safety analysis is disabled on this method, because the
// underlying `IterationState` and its array of `input_tensors` retain the
// same address while the iteration is live.
Entry* GetInputTensors(const TaggedNode& tagged_node) const
TF_NO_THREAD_SAFETY_ANALYSIS {
return tagged_node.input_frame->GetIteration(tagged_node.input_iter)
->input_tensors +
tagged_node.node_item->input_start;
}
FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
return {tagged_node.input_frame->frame_id, tagged_node.input_iter};
}
// Provide debugging output of the state of the executor.
void DumpState();
// For debugging/logging only.
void MaybeMarkStarted(const TaggedNode& tagged_node) {
// TODO(misard) Replace with a finer-grain enabling flag once we add better
// optional debugging support.
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
mutex_lock l(tagged_node.input_frame->mu);
tagged_node.input_frame->GetIteration(tagged_node.input_iter)
->mark_started(
immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
}
}
void MaybeMarkCompleted(const TaggedNode& tagged_node) {
// TODO(misard) Replace with a finer-grain enabling flag once we add better
// optional debugging support.
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
mutex_lock l(tagged_node.input_frame->mu);
tagged_node.input_frame->GetIteration(tagged_node.input_iter)
->mark_completed(
immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
}
}
private:
// Find an existing or create a new child frame in the frame 'frame' at
// iteration 'iter'.
void FindOrCreateChildFrame(FrameState* frame, int64 iter,
const NodeItem& node_item, FrameState** child);
// Delete a frame. Called when the frame is done.
void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);
// Cleanup frames and iterations starting from frame/iter. Called when
// a child frame is done.
void CleanupFramesIterations(FrameState* frame, int64 iter,
TaggedNodeSeq* ready);
// Provide debugging output about an outstanding node in the executor.
void DumpPendingNodeState(const int node_id, const Entry* input_vector,
bool show_nodes_with_no_ready_inputs);
void DumpActiveNodeState(const int node_id, const Entry* input_vector);
// Provide debugging output about an outstanding iteration in the executor.
void DumpIterationState(const FrameState* frame, IterationState* iteration);
const Tensor* GetTensorValueForDump(const Entry& input);
const ImmutableExecutorState& immutable_state_;
const int64 step_id_;
const bool vlog_;
mutex mu_;
// A flag that is set on error after the frame state has been
// dumped for diagnostic purposes.
bool dumped_on_error_ TF_GUARDED_BY(mu_) = false;
// The root frame in which the execution of this step is started.
FrameState* root_frame_;
// Mapping from frame name to outstanding frames. A new frame is created
// at some iteration of an active frame. So the unique key for the new
// child frame is composed of the name of the parent frame, the iteration
// number at which the parent frame is creating the new frame, and the
// name of the new frame from nodedef.
gtl::FlatMap<string, FrameState*> outstanding_frames_ TF_GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
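For illustration, a minimal standalone analog of the vector-as-FIFO strategy used by TaggedNodeReadyQueue above: pop_front() advances an index instead of erasing, and the consumed prefix is only compacted once it exceeds a spill threshold, so the common case is O(1) with no allocation. `IndexedQueue` and its names are hypothetical simplifications.

#include <cassert>
#include <vector>

template <typename T>
class IndexedQueue {
 public:
  void push_back(const T& v) { items_.push_back(v); }
  const T& front() const { return items_[front_]; }
  void pop_front() {
    ++front_;
    if (front_ == items_.size()) {
      // Everything consumed: reset without deallocating.
      items_.clear();
      front_ = 0;
    } else if (front_ > kSpillThreshold) {
      // Lots of unused entries at the beginning: move everything down.
      items_.erase(items_.begin(), items_.begin() + front_);
      front_ = 0;
    }
  }
  bool empty() const { return items_.empty(); }

 private:
  static constexpr size_t kSpillThreshold = 16384;
  std::vector<T> items_;
  size_t front_ = 0;
};

int main() {
  IndexedQueue<int> q;
  q.push_back(1);
  q.push_back(2);
  assert(q.front() == 1);
  q.pop_front();
  assert(q.front() == 2);
  q.pop_front();
  assert(q.empty());
}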