[Executor] Split ExecutorState into PropagatorState and ExecutorState<PropagatorStateType>.

This change is part of an ongoing refactoring to simplify "executor.cc" and enable the substitution of more efficient implementations of `PropagateOutputs()`.

PiperOrigin-RevId: 304262448
Change-Id: I46a2d7fcdde89a71c502d272f35adfd34b0c4cab

parent 1778de6f64
commit bd530a65d5
tensorflow/core
@@ -2546,6 +2546,7 @@ filegroup(
        "common_runtime/debugger_state_interface.h",
        "common_runtime/device_resolver_local.h",
        "common_runtime/dma_helper.h",
        "common_runtime/entry.h",
        "common_runtime/executor.h",
        "common_runtime/executor_factory.h",
        "common_runtime/function_optimization_registry.h",
@@ -2553,6 +2554,7 @@ filegroup(
        "common_runtime/graph_view.h",
        "common_runtime/immutable_executor_state.h",
        "common_runtime/input_colocation_exemption_registry.h",
        "common_runtime/inspecting_placer.h",
        "common_runtime/isolate_placer_inspection_required_ops_pass.h",
        "common_runtime/local_device.h",
        "common_runtime/lower_function_call_op.h",
@@ -2567,7 +2569,7 @@ filegroup(
        "common_runtime/partitioning_utils.h",
        "common_runtime/placer.h",
        "common_runtime/process_util.h",
        "common_runtime/inspecting_placer.h",
        "common_runtime/propagator_state.h",
        "common_runtime/profile_handler.h",
        "common_runtime/renamed_device.h",
        "common_runtime/rendezvous_mgr.h",
@@ -2640,6 +2642,7 @@ tf_cuda_library(
        "common_runtime/process_function_library_runtime.cc",
        "common_runtime/process_state.cc",
        "common_runtime/process_util.cc",
        "common_runtime/propagator_state.cc",
        "common_runtime/renamed_device.cc",
        "common_runtime/rendezvous_mgr.cc",
        "common_runtime/rendezvous_util.cc",
tensorflow/core/common_runtime/entry.h (new file, 142 lines)
@@ -0,0 +1,142 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"

namespace tensorflow {

class mutex;
class Tensor;

// An Entry stores a single input value for an individual kernel invocation in
// an executor.
//
// Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
struct Entry {
  enum class State {
    NO_VALUE = 0,      // The default state for a newly-created Entry.
    HAS_VALUE,         // `this->val` is valid.
    HAS_CONST_TENSOR,  // `this->const_tensor` is valid.
    HAS_REF_TENSOR,    // `this->ref_tensor` is valid.
  };

  Entry() : state(State::NO_VALUE) {}
  Entry(const Entry& other) : state(other.state), alloc_attr(other.alloc_attr) {
    switch (state) {
      case State::NO_VALUE:
        break;
      case State::HAS_VALUE:
        val.Init(*other.val);
        break;
      case State::HAS_CONST_TENSOR:
        const_tensor = other.const_tensor;
        break;
      case State::HAS_REF_TENSOR:
        ref_tensor = other.ref_tensor;
        break;
    }
  }

  ~Entry() {
    if (state == State::HAS_VALUE) val.Destroy();
  }

  Entry& operator=(const Entry& other) {
    if (state == State::HAS_VALUE) {
      val.Destroy();
    }
    state = other.state;
    alloc_attr = other.alloc_attr;
    switch (state) {
      case State::NO_VALUE:
        break;
      case State::HAS_VALUE:
        val.Init(*other.val);
        break;
      case State::HAS_CONST_TENSOR:
        const_tensor = other.const_tensor;
        break;
      case State::HAS_REF_TENSOR:
        ref_tensor = other.ref_tensor;
        break;
    }
    return *this;
  }

  Entry& operator=(Entry&& other) {
    if (state == State::HAS_VALUE) {
      val.Destroy();
    }
    state = other.state;
    alloc_attr = other.alloc_attr;
    switch (state) {
      case State::NO_VALUE:
        break;
      case State::HAS_VALUE:
        val.Init(std::move(*other.val));
        break;
      case State::HAS_CONST_TENSOR:
        const_tensor = other.const_tensor;
        break;
      case State::HAS_REF_TENSOR:
        ref_tensor = other.ref_tensor;
        break;
    }
    return *this;
  }

  // Clears the <val> field, and sets this entry to the `NO_VALUE` state.
  void ClearVal() {
    if (state == State::HAS_VALUE) {
      val.Destroy();
    }
    state = State::NO_VALUE;
  }

  union {
    // A tensor value. Valid iff `state_ == HAS_VALUE`.
    ManualConstructor<Tensor> val;

    // A pointer to a constant tensor value. Valid iff `state_ ==
    // HAS_CONST_TENSOR`.
    const Tensor* const_tensor;

    // A tensor reference and associated mutex. Valid iff `state_ ==
    // HAS_REF_TENSOR`.
    struct {
      Tensor* tensor;
      mutex* mu;
    } ref_tensor;
  };

  // The current state of this entry, indicating which member of the above
  // union is active.
  State state;

  // The attributes of the allocator that creates the tensor.
  AllocatorAttributes alloc_attr;
};

// TODO(b/152925936): Re-evaluate this constant with current usage patterns.
typedef gtl::InlinedVector<Entry, 4> EntryVector;

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_
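The interesting part of `Entry` is the manually managed union: `state` records which member is live, and `ManualConstructor<Tensor>` defers the `Tensor`'s constructor and destructor so a non-trivially-constructible type can live in a union. Below is a self-contained sketch of the same technique using only standard C++ (placement new over raw storage, which is essentially what ManualConstructor wraps). `Value` and `SmallEntry` are invented stand-ins for illustration, not TF code:

#include <iostream>
#include <new>
#include <string>

// Stand-in for Tensor: a non-trivially-constructible type.
using Value = std::string;

struct SmallEntry {
  enum class State { NO_VALUE, HAS_VALUE, HAS_CONST_PTR } state = State::NO_VALUE;

  union {
    // Raw storage whose construction/destruction we manage by hand, like
    // ManualConstructor<Tensor> does in entry.h.
    alignas(Value) unsigned char val[sizeof(Value)];
    const Value* const_ptr;
  };

  SmallEntry() {}
  ~SmallEntry() { Clear(); }

  void SetValue(const Value& v) {
    Clear();
    new (val) Value(v);  // placement new: construct in the raw storage
    state = State::HAS_VALUE;
  }

  void SetConstPtr(const Value* p) {
    Clear();
    const_ptr = p;
    state = State::HAS_CONST_PTR;
  }

  const Value* Get() const {
    switch (state) {
      case State::HAS_VALUE:     return reinterpret_cast<const Value*>(val);
      case State::HAS_CONST_PTR: return const_ptr;
      default:                   return nullptr;
    }
  }

  void Clear() {
    // Only the by-value member owns a live object that needs destroying;
    // this mirrors Entry's `if (state == State::HAS_VALUE) val.Destroy();`.
    if (state == State::HAS_VALUE) reinterpret_cast<Value*>(val)->~Value();
    state = State::NO_VALUE;
  }
};

int main() {
  SmallEntry e;
  e.SetValue("hello");
  std::cout << *e.Get() << "\n";  // prints "hello"
  static const Value kShared = "shared";
  e.SetConstPtr(&kShared);        // previous value is destroyed first
  std::cout << *e.Get() << "\n";  // prints "shared"
}

As in `Entry`, the payoff is that a pass-by-value member and cheap pointer members share storage, with the tag deciding which destructor (if any) must run.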
File diff suppressed because it is too large
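The diff for the executor itself is suppressed above. To make the shape of the refactoring concrete, here is a minimal, self-contained sketch of the pattern the commit title describes: an executor templated on a propagator-state policy, so that a more efficient implementation of `PropagateOutputs()` can be substituted without touching the scheduling loop. This is not the suppressed TensorFlow code; every name except the PropagateOutputs idea is invented for illustration:

#include <iostream>
#include <vector>

// Minimal stand-in for the executor's node type (an assumption, not TF code).
struct Node { int id; };

// A propagator policy: decides which nodes become runnable after one runs.
class SimplePropagator {
 public:
  using TaggedNodeSeq = std::vector<Node>;

  void PropagateOutputs(const Node& done, TaggedNodeSeq* ready) {
    // Toy rule: node k makes node k+1 ready, up to a bound of 3 nodes.
    if (done.id + 1 < 3) ready->push_back(Node{done.id + 1});
  }
};

// The executor is parameterized on the propagator type, analogous to
// ExecutorState<PropagatorStateType>: the scheduling loop is written once,
// and alternative propagators can be swapped in as template arguments.
template <class PropagatorStateType>
class Executor {
 public:
  void Run() {
    typename PropagatorStateType::TaggedNodeSeq ready{Node{0}};
    while (!ready.empty()) {
      Node n = ready.back();
      ready.pop_back();
      std::cout << "run node " << n.id << "\n";  // "execute" the kernel
      propagator_.PropagateOutputs(n, &ready);   // delegate edge propagation
    }
  }

 private:
  PropagatorStateType propagator_;  // e.g. PropagatorState in TensorFlow
};

int main() {
  Executor<SimplePropagator>().Run();  // prints: run node 0, 1, 2
}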
tensorflow/core/common_runtime/propagator_state.cc (new file, 777 lines)
@@ -0,0 +1,777 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/propagator_state.h"

#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

// 1-D, 0 element tensor.
static const Tensor* const kEmptyTensor = new Tensor;

typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;

PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
                                 int64 step_id)
    : immutable_state_(immutable_state),
      step_id_(step_id),
      vlog_(VLOG_IS_ON(1)) {
  // We start the entire execution in iteration 0 of the root frame
  // so let us create the root frame and the state for iteration 0.
  // We assume root_frame_->frame_name.empty().
  root_frame_ = new FrameState(immutable_state_, 1);
  root_frame_->frame_id = 0;  // must be 0
  root_frame_->InitializeFrameInfo(root_frame_->frame_name);

  // Initialize iteration 0.
  root_frame_->SetIteration(
      0, new PropagatorState::IterationState(root_frame_->pending_counts,
                                             root_frame_->total_input_tensors));

  outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
}

PropagatorState::~PropagatorState() {
  for (auto name_frame : outstanding_frames_) {
    delete name_frame.second;
  }
}

void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
                                    TaggedNodeSeq* ready) {
  for (const NodeItem* item : roots) {
    DCHECK_EQ(item->num_inputs, 0);
    ready->push_back(TaggedNode{item, root_frame_, 0, false});
  }
  mutex_lock l(root_frame_->mu);
  root_frame_->GetIteration(0)->outstanding_ops = ready->size();
}

void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
                                       EntryVector* outputs,
                                       TaggedNodeSeq* ready) {
  profiler::TraceMe activity(
      [&]() {
        return strings::StrCat(
            "ExecutorPropagateOutputs#", "id=", step_id_,
            ",kernel_name=", tagged_node.node_item->kernel->name_view(),
            ",num_output_edges=", tagged_node.node_item->num_output_edges,
            ",num_output_control_edges=",
            tagged_node.node_item->num_output_control_edges, "#");
      },
      profiler::GetTFTraceMeLevel(/*is_expensive=*/false));

  const NodeItem* const item = tagged_node.node_item;
  FrameState* const input_frame = tagged_node.input_frame;
  const int64 input_iter = tagged_node.input_iter;
  const bool is_dead = tagged_node.is_dead;

  // Propagates outputs along out edges, and puts newly ready nodes
  // into the ready queue.
  DCHECK(ready->empty());
  bool is_frame_done = false;
  FrameState* output_frame = input_frame;
  int64 output_iter = input_iter;

  if (!item->is_enter_exit_or_next_iter) {
    // Fast path for node types that don't need special handling.
    DCHECK_EQ(input_frame, output_frame);
    // Normal path for most nodes.
    mutex_lock l(input_frame->mu);
    output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
    is_frame_done =
        input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
  } else if (item->is_enter) {
    FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
    output_iter = 0;
    {
      mutex_lock l(output_frame->mu);
      if (item->is_constant_enter) {
        // Propagate to all active iterations if this is a loop invariant.
        output_frame->AddLoopInv(item, (*outputs)[0], ready);
      } else {
        output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
      }
      output_frame->num_pending_inputs--;
    }
    is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
  } else if (item->is_exit) {
    if (is_dead) {
      mutex_lock l(input_frame->mu);
      // Stop and remember this node if it is a dead exit.
      if (input_iter == input_frame->iteration_count) {
        input_frame->dead_exits.push_back(item);
      }
      is_frame_done =
          input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
    } else {
      output_frame = input_frame->parent_frame;
      output_iter = input_frame->parent_iter;
      {
        mutex_lock l(output_frame->mu);
        output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
      }
      is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
    }
  } else {
    DCHECK(item->is_next_iteration);
    mutex_lock l(input_frame->mu);
    if (is_dead) {
      // Stop the deadness propagation.
      output_frame = nullptr;
    } else {
      if (input_iter == input_frame->iteration_count &&
          input_frame->num_outstanding_iterations ==
              input_frame->max_parallel_iterations) {
        // Reached the maximum for parallel iterations.
        input_frame->next_iter_roots.push_back({item, (*outputs)[0]});
        output_frame = nullptr;
      } else {
        // If this is a new iteration, start it.
        if (input_iter == input_frame->iteration_count) {
          input_frame->IncrementIteration(ready);
        }
        output_iter = input_iter + 1;
      }
    }
    if (output_frame != nullptr) {
      // This is the case when node is not Enter, Exit, or NextIteration.
      DCHECK(input_frame == output_frame);
      output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
    }
    is_frame_done =
        input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
  }

  // At this point, this node is completely done. We also know if the
  // completion of this node makes its frame completed.
  if (is_frame_done) {
    FrameState* parent_frame = input_frame->parent_frame;
    const int64 parent_iter = input_frame->parent_iter;
    DeleteFrame(input_frame, ready);
    if (parent_frame != nullptr) {
      // The completion of frame may cause completions in its parent frame.
      // So clean things up recursively.
      CleanupFramesIterations(parent_frame, parent_iter, ready);
    }
  }
}

const Tensor* PropagatorState::GetTensorValueForDump(const Entry& input) {
  switch (input.state) {
    case Entry::State::NO_VALUE:
      return kEmptyTensor;
    case Entry::State::HAS_VALUE:
      return input.val.get();
    case Entry::State::HAS_CONST_TENSOR:
      return input.const_tensor;
    case Entry::State::HAS_REF_TENSOR:
      return input.ref_tensor.tensor;
  }
}

void PropagatorState::DumpPendingNodeState(
    const int node_id, const Entry* input_vector,
    const bool show_nodes_with_no_ready_inputs) {
  const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
  const int input_base = node_item.input_start;
  if (!show_nodes_with_no_ready_inputs) {
    bool has_ready_input = false;
    for (int i = 0; i < node_item.num_inputs; ++i) {
      const Entry& input = input_vector[input_base + i];
      const Tensor* tensor = GetTensorValueForDump(input);
      if (tensor->IsInitialized()) {
        has_ready_input = true;
        break;
      }
    }
    if (!has_ready_input) {
      return;
    }
  }
  LOG(WARNING) << "    Pending Node: " << node_item.DebugString();
  for (int i = 0; i < node_item.num_inputs; ++i) {
    const Entry& input = input_vector[input_base + i];
    const Tensor* tensor = GetTensorValueForDump(input);
    if (tensor->IsInitialized()) {
      LOG(WARNING) << "      Input " << i << ": "
                   << strings::StrCat(
                          "Tensor<type: ", DataTypeString(tensor->dtype()),
                          " shape: ", tensor->shape().DebugString(), ">");
    } else {
      LOG(WARNING) << "      Input " << i << ": not present";
    }
  }
}

void PropagatorState::DumpActiveNodeState(const int node_id,
                                          const Entry* input_vector) {
  const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
  LOG(WARNING) << "    Active Node: " << node_item.DebugString();
  const int input_base = node_item.input_start;
  for (int i = 0; i < node_item.num_inputs; ++i) {
    const Entry& input = input_vector[input_base + i];
    const Tensor* tensor = GetTensorValueForDump(input);
    if (tensor->IsInitialized()) {
      LOG(WARNING) << "      Input " << i << ": "
                   << strings::StrCat(
                          "Tensor<type: ", DataTypeString(tensor->dtype()),
                          " shape: ", tensor->shape().DebugString(), ">");
    } else {
      LOG(WARNING) << "      Input " << i << ": not present";
    }
  }
}

void PropagatorState::DumpIterationState(const FrameState* frame,
                                         IterationState* iteration) {
  const std::vector<const NodeItem*>* nodes = frame->nodes;
  // Dump any waiting nodes that are holding on to tensors.
  for (const NodeItem* node : *nodes) {
    PendingCounts::Handle pending_id =
        immutable_state_.pending_ids()[node->node_id];
    if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
        iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
      DumpPendingNodeState(node->node_id, iteration->input_tensors, false);
    }
  }
  // Then the active nodes.
  for (const NodeItem* node : *nodes) {
    PendingCounts::Handle pending_id =
        immutable_state_.pending_ids()[node->node_id];
    if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
      DumpActiveNodeState(node->node_id, iteration->input_tensors);
    }
  }
  // Show all input tensors in use.
  const int total_input_tensors = frame->total_input_tensors;
  size_t total_bytes = 0;
  for (int i = 0; i < total_input_tensors; ++i) {
    const Entry& input = iteration->input_tensors[i];
    const Tensor* tensor = GetTensorValueForDump(input);
    if (tensor->IsInitialized()) {
      LOG(WARNING) << "    Input " << i << ": "
                   << strings::StrCat(
                          "Tensor<type: ", DataTypeString(tensor->dtype()),
                          " shape: ", tensor->shape().DebugString(),
                          ", bytes: ", tensor->TotalBytes(), ">");
      total_bytes += tensor->TotalBytes();
    }
  }
  LOG(WARNING) << "    Total bytes " << total_bytes;
}

void PropagatorState::DumpState() {
  mutex_lock l(mu_);
  if (!dumped_on_error_) {
    LOG(WARNING) << "Dumping state";
    for (auto& frame : outstanding_frames_) {
      LOG(WARNING) << frame.first;
      FrameState* frame_state = frame.second;
      frame_state->DumpIterationState(this);
    }
    dumped_on_error_ = true;
  }
}

void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
                                             const NodeItem& node_item,
                                             FrameState** child) {
  // Get the child frame name.
  AttrSlice attrs(node_item.kernel->def());
  const string& enter_name = GetNodeAttrString(attrs, "frame_name");
  DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node "
                              << node_item.kernel->name();
  const string child_name =
      strings::StrCat(frame->frame_name, ";", iter, ";", enter_name);

  {
    mutex_lock executor_lock(mu_);
    auto it = outstanding_frames_.find(child_name);
    if (it != outstanding_frames_.end()) {
      *child = it->second;
      return;
    }
  }

  // Need to create a new frame instance.
  // Note that this new frame instance is created without any locks.
  if (vlog_) VLOG(2) << "Create frame: " << child_name;

  int parallel_iters;
  bool found_parallel_iters =
      TryGetNodeAttr(attrs, "parallel_iterations", &parallel_iters);
  DCHECK(found_parallel_iters)
      << "Could not find \"parallel_iterations\" attr in node "
      << node_item.kernel->name();
  FrameState* temp = new FrameState(immutable_state_, parallel_iters);
  temp->frame_name = child_name;
  temp->frame_id = Hash64(child_name);
  temp->parent_frame = frame;
  temp->parent_iter = iter;
  temp->InitializeFrameInfo(enter_name);

  // Initialize iteration 0.
  {
    mutex_lock l(temp->mu);
    temp->SetIteration(
        0, new IterationState(temp->pending_counts, temp->total_input_tensors));
  }

  {
    mutex_lock executor_lock(mu_);
    auto it = outstanding_frames_.find(child_name);
    if (it != outstanding_frames_.end()) {
      *child = it->second;
    } else {
      mutex_lock frame_lock(frame->mu);
      frame->GetIteration(iter)->outstanding_frame_count++;
      outstanding_frames_[child_name] = temp;
      *child = temp;
      temp = nullptr;
    }
  }
  delete temp;  // Not used so delete it.
}

void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
  // First, propagate dead_exits (if any) to the parent frame.
  FrameState* parent_frame = frame->parent_frame;
  const int64 parent_iter = frame->parent_iter;
  if (parent_frame != nullptr) {
    mutex_lock parent_frame_lock(parent_frame->mu);
    // Propagate all the dead exits to the parent frame.
    mutex_lock this_frame_lock(frame->mu);

    for (const NodeItem* item : frame->dead_exits) {
      auto parent_iter_state = parent_frame->GetIteration(parent_iter);

      auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready,
                                    bool dst_dead) {
        if (dst_ready) {
          if (dst_item.is_control_trigger) dst_dead = false;
          ready->emplace_back(&dst_item, parent_frame, parent_iter, dst_dead);
          parent_iter_state->outstanding_ops++;
        }
      };

      auto propagate_to_non_merge = [&](PendingCounts::Handle dst_pending_id) {
        parent_iter_state->increment_dead_count(dst_pending_id);
        return parent_iter_state->decrement_pending(dst_pending_id, 1) == 0;
      };

      for (const EdgeInfo& e : item->output_edges()) {
        const NodeItem& dst_item =
            *immutable_state_.graph_view().node(e.dst_id);
        const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];

        bool dst_dead = true;
        bool dst_ready;
        // We know this is a dead input to dst.
        if (dst_item.is_merge) {
          parent_iter_state->increment_dead_count(dst_pending_id);
          const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
          dst_dead = (dead_cnt == dst_item.num_inputs);
          dst_ready =
              (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
        } else {
          dst_ready = propagate_to_non_merge(dst_pending_id);
        }
        maybe_add_to_ready(dst_item, dst_ready, dst_dead);
      }

      for (const ControlEdgeInfo& e : item->output_control_edges()) {
        const NodeItem& dst_item =
            *immutable_state_.graph_view().node(e.dst_id);
        const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];

        bool dst_dead;
        bool dst_ready;
        // We know this is a dead input to dst.
        if (dst_item.is_merge) {
          parent_iter_state->decrement_pending(dst_pending_id, 2);
          int count = parent_iter_state->pending(dst_pending_id);
          int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
          dst_dead = (dead_cnt == dst_item.num_inputs);
          dst_ready = (count == 0) || ((count == 1) && dst_dead);
        } else {
          dst_dead = true;
          dst_ready = propagate_to_non_merge(dst_pending_id);
        }
        maybe_add_to_ready(dst_item, dst_ready, dst_dead);
      }
    }
  }

  // Delete the frame.
  const string& frame_name = frame->frame_name;
  if (vlog_) VLOG(2) << "Delete frame " << frame_name;
  {
    mutex_lock executor_lock(mu_);
    outstanding_frames_.erase(frame_name);
  }
  delete frame;
}

void PropagatorState::CleanupFramesIterations(FrameState* frame, int64 iter,
                                              TaggedNodeSeq* ready) {
  bool is_frame_done = false;
  {
    mutex_lock frame_lock(frame->mu);
    frame->GetIteration(iter)->outstanding_frame_count--;
    is_frame_done = frame->CleanupIterations(iter, ready);
  }
  if (is_frame_done) {
    FrameState* parent_frame = frame->parent_frame;
    const int64 parent_iter = frame->parent_iter;
    DeleteFrame(frame, ready);
    if (parent_frame != nullptr) {
      // The completion of frame may cause completions in its parent frame.
      // So clean things up recursively.
      CleanupFramesIterations(parent_frame, parent_iter, ready);
    }
  }
}

void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
                                                        const bool is_dead,
                                                        int64 iter,
                                                        EntryVector* outputs,
                                                        TaggedNodeSeq* ready) {
  // If we know that none of the item's edge destinations require special
  // handling (i.e. none of the nodes is a merge or control trigger node), we
  // can take a fast path that avoids accessing the destination NodeItem.
  const GraphView& gview = immutable_state.graph_view();
  IterationState* iter_state = GetIteration(iter);

  // Add dst to the ready queue if it's ready
  //
  // NOTE(mrry): Use a macro here instead of a lambda, because this method is
  // performance-critical and we need to ensure that the code is inlined.
#define MAYBE_ADD_TO_READY(dst_id, adjust_result)     \
  do {                                                \
    if (!adjust_result.any_pending) {                 \
      const NodeItem* dst_item = gview.node(dst_id);  \
      TaggedNode& t = ready->emplace_back();          \
      t.node_item = dst_item;                         \
      t.input_frame = this;                           \
      t.input_iter = iter;                            \
      t.is_dead = adjust_result.any_dead;             \
      iter_state->outstanding_ops++;                  \
    }                                                 \
  } while (0);

  Entry* input_tensors = iter_state->input_tensors;

  for (const EdgeInfo& e : item->output_edges()) {
    const int dst_id = e.dst_id;
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];
    const int src_slot = e.output_slot;

    const bool increment_dead =
        (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
    const PendingCounts::AdjustResult adjust_result =
        iter_state->adjust_for_activation(dst_pending_id, increment_dead);
    const int dst_loc = e.input_slot;
    if (e.is_last) {
      input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
    } else {
      input_tensors[dst_loc] = (*outputs)[src_slot];
    }
    MAYBE_ADD_TO_READY(dst_id, adjust_result);
  }

  for (const ControlEdgeInfo& e : item->output_control_edges()) {
    const int dst_id = e.dst_id;
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];
    const PendingCounts::AdjustResult adjust_result =
        iter_state->adjust_for_activation(dst_pending_id, is_dead);
    MAYBE_ADD_TO_READY(dst_id, adjust_result);
  }
#undef MAYBE_ADD_TO_READY
}

void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
                                                        const bool is_dead,
                                                        int64 iter,
                                                        EntryVector* outputs,
                                                        TaggedNodeSeq* ready) {
  // If any of the edge destinations is a merge or a control trigger node,
  // we need to read each destination NodeItem to determine what action
  // to take.
  const GraphView& gview = immutable_state.graph_view();
  IterationState* iter_state = GetIteration(iter);

  auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
                                bool dst_ready, bool dst_dead) {
    // Add dst to the ready queue if it's ready
    if (dst_ready) {
      if (dst_item->is_control_trigger) dst_dead = false;
      ready->emplace_back(dst_item, this, iter, dst_dead);
      iter_state->outstanding_ops++;
    }
  };

  Entry* input_tensors = iter_state->input_tensors;

  for (const EdgeInfo& e : item->output_edges()) {
    const int dst_id = e.dst_id;
    const NodeItem* dst_item = gview.node(dst_id);
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];
    const int src_slot = e.output_slot;

    bool dst_dead = false;
    bool dst_ready = false;
    bool dst_need_input = true;

    if (dst_item->is_merge) {
      // A merge node is ready if all control inputs have arrived and either
      // a) a live data input becomes available or b) all data inputs are
      // dead. For Merge, pending's LSB is set iff a live data input has
      // arrived.
      if ((*outputs)[src_slot].state != Entry::State::NO_VALUE) {
        // This is a live data input.
        int count = iter_state->pending(dst_pending_id);
        iter_state->mark_live(dst_pending_id);
        // Only the first live edge sets the input and (potentially)
        // triggers execution. The low bit of count is set if and
        // only if no live input has been used yet (mark_live clears
        // it). The node should be started if and only if this is
        // the first live input and there are no pending control
        // edges, i.e. count == 1.
        dst_ready = (count == 1);
        dst_need_input = ((count & 0x1) == 1);
      } else {
        // This is a dead data input. Note that dst_node is dead if node is
        // a dead enter. We need this to handle properly a while loop on
        // the untaken branch of a conditional.
        // TODO(yuanbyu): This is a bit hacky, but a good solution for
        // now.
        iter_state->increment_dead_count(dst_pending_id);
        const int dead_cnt = iter_state->dead_count(dst_pending_id);
        dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
        dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
        dst_need_input = false;
      }
    } else {
      // Handle all other (non-merge) nodes.
      const bool increment_dead =
          (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
      const PendingCounts::AdjustResult adjust_result =
          iter_state->adjust_for_activation(dst_pending_id, increment_dead);
      dst_dead = adjust_result.any_dead;
      dst_ready = !adjust_result.any_pending;
    }

    if (dst_need_input) {
      const int dst_loc = e.input_slot;
      if (e.is_last) {
        input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
      } else {
        input_tensors[dst_loc] = (*outputs)[src_slot];
      }
    }

    maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
  }

  for (const ControlEdgeInfo& e : item->output_control_edges()) {
    const int dst_id = e.dst_id;
    const NodeItem* dst_item = gview.node(dst_id);
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];

    bool dst_dead;
    bool dst_ready;
    if (dst_item->is_merge) {
      // A merge node is ready if all control inputs have arrived and either
      // a) a live data input becomes available or b) all data inputs are
      // dead. For Merge, pending's LSB is set iff a live data input has
      // arrived.
      iter_state->decrement_pending(dst_pending_id, 2);
      int count = iter_state->pending(dst_pending_id);
      int dead_cnt = iter_state->dead_count(dst_pending_id);
      dst_dead = (dead_cnt == dst_item->num_inputs);
      dst_ready = (count == 0) || ((count == 1) && dst_dead);
    } else {
      // Handle all other (non-merge) nodes.
      const PendingCounts::AdjustResult adjust_result =
          iter_state->adjust_for_activation(dst_pending_id, is_dead);
      dst_dead = adjust_result.any_dead;
      dst_ready = !adjust_result.any_pending;
    }
    maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
  }
}

void PropagatorState::FrameState::ActivateNodes(const NodeItem* item,
                                                const bool is_dead, int64 iter,
                                                EntryVector* outputs,
                                                TaggedNodeSeq* ready) {
  if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
    ActivateNodesSlowPath(item, is_dead, iter, outputs, ready);
  } else {
    ActivateNodesFastPath(item, is_dead, iter, outputs, ready);
  }
}

void PropagatorState::FrameState::ActivateNexts(int64 iter,
                                                TaggedNodeSeq* ready) {
  // Propagate the deferred NextIteration nodes to the new iteration.
  for (auto& node_entry : next_iter_roots) {
    const NodeItem* item = node_entry.first;
    const Entry& entry = node_entry.second;
    const bool is_dead = entry.state == Entry::State::NO_VALUE;
    EntryVector outputs{entry};
    ActivateNodes(item, is_dead, iter, &outputs, ready);
  }
  next_iter_roots.clear();
}

void PropagatorState::FrameState::ActivateLoopInvs(int64 iter,
                                                   TaggedNodeSeq* ready) {
  // Propagate loop invariants to the new iteration.
  for (auto& node_entry : inv_values) {
    const NodeItem* item = node_entry.first;
    const Entry& entry = node_entry.second;
    const bool is_dead = entry.state == Entry::State::NO_VALUE;
    EntryVector outputs{entry};
    ActivateNodes(item, is_dead, iter, &outputs, ready);
  }
}

void PropagatorState::FrameState::AddLoopInv(const NodeItem* item,
                                             const Entry& entry,
                                             TaggedNodeSeq* ready) {
  // Store this value.
  inv_values.push_back({item, entry});

  // Make this value available to all iterations.
  const bool is_dead = entry.state == Entry::State::NO_VALUE;
  for (int i = 0; i <= iteration_count; ++i) {
    EntryVector outputs{entry};
    ActivateNodes(item, is_dead, i, &outputs, ready);
  }
}

bool PropagatorState::FrameState::IsIterationDone(int64 iter) {
  IterationState* iter_state = GetIteration(iter);
  if (iter_state->outstanding_ops == 0 &&
      iter_state->outstanding_frame_count == 0) {
    if (iter == 0) {
      // The enclosing frame has no pending input.
      return num_pending_inputs == 0;
    } else {
      // The preceding iteration is deleted (and therefore done).
      return (GetIteration(iter - 1) == nullptr);
    }
  }
  return false;
}

void PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
  iteration_count++;
  const int64 next_iter = iteration_count;

  // Initialize the next iteration.
  IterationState* iter_state =
      new IterationState(pending_counts, total_input_tensors);
  SetIteration(next_iter, iter_state);
  num_outstanding_iterations++;
  dead_exits.clear();

  // Activate the successors of the deferred roots in the new iteration.
  ActivateNexts(next_iter, ready);

  // Activate the loop invariants in the new iteration.
  ActivateLoopInvs(next_iter, ready);
}

bool PropagatorState::FrameState::CleanupIterations(int64 iter,
                                                    TaggedNodeSeq* ready) {
  int64 curr_iter = iter;
  while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) {
    // Delete the iteration curr_iter.
    delete GetIteration(curr_iter);
    SetIteration(curr_iter, nullptr);
    --num_outstanding_iterations;
    ++curr_iter;

    // When one iteration is completed, we check for deferred iteration,
    // and start it if there is one.
    if (!next_iter_roots.empty()) {
      IncrementIteration(ready);
    }
  }
  return IsFrameDone();
}

void PropagatorState::FrameState::InitializeFrameInfo(
    const string& enter_name) {
  const ImmutableExecutorState::FrameInfo* finfo =
      immutable_state.get_frame_info(enter_name);
  DCHECK_NE(finfo, nullptr);
  pending_counts = finfo->pending_counts.get();
  total_input_tensors = finfo->total_inputs;
  num_pending_inputs = finfo->input_count;
  nodes = finfo->nodes.get();
}

void PropagatorState::FrameState::SetIteration(int64 iter,
                                               IterationState* state)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
  size_t index = iter % (max_parallel_iterations + 1);
  DCHECK(state == nullptr || iterations[index] == nullptr);
  iterations_raw[index] = state;
  if (index == 0) {
    iterations_first = state;
  }
}

// Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done.
bool PropagatorState::FrameState::DecrementOutstandingOps(
    int64 iter, TaggedNodeSeq* ready) {
  mutex_lock l(mu);
  return DecrementOutstandingOpsLocked(iter, ready);
}

// Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done.
bool PropagatorState::FrameState::DecrementOutstandingOpsLocked(
    int64 iter, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
  IterationState* istate = GetIteration(iter);
  istate->outstanding_ops--;
  if (istate->outstanding_ops != 0) {
    return false;
  } else {
    return CleanupIterations(iter, ready);
  }
}

// Returns true if the computation in the frame is completed.
bool PropagatorState::FrameState::IsFrameDone()
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
  return (num_pending_inputs == 0 && num_outstanding_iterations == 0);
}

}  // namespace tensorflow
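The Merge bookkeeping in `ActivateNodesSlowPath` and `DeleteFrame` packs two facts into one pending counter: the low bit records whether a live data input has arrived yet, and the higher bits count unfinished control edges (each contributing 2). The following self-contained toy reproduces just that encoding, under the assumption (consistent with the comments above, but not taken from the actual PendingCounts implementation) that a Merge starts with pending = 2 * num_control_edges + 1:

#include <iostream>

// Toy model of the Merge pending-count encoding described in the comments:
// pending = (2 * remaining control edges) + (1 if no live data input yet).
struct MergePending {
  int pending;
  explicit MergePending(int num_control_edges)
      : pending(2 * num_control_edges + 1) {}

  // A live data input arrived: clear the low bit (like mark_live). The node
  // becomes ready iff this was the first live input and no control edges
  // remain, i.e. the counter was exactly 1 beforehand.
  bool OnLiveDataInput() {
    int before = pending;
    pending &= ~0x1;
    return before == 1;
  }

  // A control edge completed: decrement by 2 (like decrement_pending(h, 2)).
  // Ready iff the counter hits 0: a live input already arrived and this was
  // the last pending control edge.
  bool OnControlEdgeDone() {
    pending -= 2;
    return pending == 0;
  }
};

int main() {
  MergePending m(/*num_control_edges=*/2);     // pending = 5
  std::cout << m.OnControlEdgeDone() << "\n";  // 0: pending = 3
  std::cout << m.OnLiveDataInput() << "\n";    // 0: pending = 2 (LSB cleared)
  std::cout << m.OnControlEdgeDone() << "\n";  // 1: pending = 0 -> ready
}

(The real code additionally tracks dead counts so a Merge can also fire when all of its data inputs are dead; that path is omitted here.)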
tensorflow/core/common_runtime/propagator_state.h (new file, 466 lines)
@@ -0,0 +1,466 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_

#include <vector>

#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;

// Represents the ephemeral "edge state" associated with one invocation of
// `Executor::Run()`.
//
// `PropagatorState` is responsible for propagating values along dataflow
// edges in a TensorFlow graph and determining which nodes are runnable. The
// executor primarily updates `PropagatorState` by calling `PropagateOutputs()`
// after processing a node, and `PropagatorState` dispatches `TaggedNode`s by
// adding them to a `TaggedNodeSeq`.
class PropagatorState {
 public:
  PropagatorState(const ImmutableExecutorState& immutable_state, int64 step_id);
  ~PropagatorState();

 private:
  // Forward declaration so that `TaggedNode` can include a `FrameState*`.
  struct FrameState;

 public:
  // A `TaggedNode` corresponds to a single invocation of a node's kernel,
  // and it is created when the kernel becomes runnable (in a particular
  // iteration of a particular frame).
  struct TaggedNode {
    const NodeItem* node_item;
    FrameState* input_frame;
    int64 input_iter;
    bool is_dead;

    TaggedNode() = default;
    TaggedNode(const NodeItem* node_item, FrameState* in_frame, int64 in_iter,
               bool dead)
        : node_item(node_item),
          input_frame(in_frame),
          input_iter(in_iter),
          is_dead(dead) {}

    const NodeItem& get_node_item() const { return *node_item; }

    bool get_is_dead() const { return is_dead; }
  };

  // A drop-in replacement for std::deque<TaggedNode>. We typically don't
  // have that many nodes in the ready queue, so we just use a vector and
  // don't free up memory from the queue as we consume nodes.
  class TaggedNodeReadyQueue {
   public:
    TaggedNodeReadyQueue() : front_index_(0) {}

    void push_back(const TaggedNode& node) { ready_.push_back(node); }
    TaggedNode front() const {
      DCHECK_LT(front_index_, ready_.size());
      return ready_[front_index_];
    }
    void pop_front() {
      DCHECK_LT(front_index_, ready_.size());
      front_index_++;
      if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) {
        if (front_index_ == ready_.size()) {
          ready_.clear();
        } else {
          // Lots of unused entries at beginning of vector: move everything
          // down to start of vector.
          ready_.erase(ready_.begin(), ready_.begin() + front_index_);
        }
        front_index_ = 0;
      }
    }
    bool empty() const { return ready_.empty(); }

   private:
    // TODO(b/152925936): Re-evaluate these constants with current usage
    // patterns.
    static constexpr int kSpillThreshold = 16384;
    gtl::InlinedVector<TaggedNode, 16> ready_;
    int front_index_;
  };

  // TODO(b/152925936): Re-evaluate this constant with current usage patterns.
  typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;

 private:
  struct IterationState {
    explicit IterationState(const PendingCounts* pending_counts,
                            int total_input_tensors)
        : input_tensors(new Entry[total_input_tensors]),
          outstanding_ops(0),
          outstanding_frame_count(0),
          counts(*pending_counts) {  // Initialize with copy of *pending_counts
    }

    // The state of an iteration.

    // One copy per iteration. For iteration k, i-th node's j-th input is in
    // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is
    // either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
    //
    // NOTE: No need to protect input_tensors[i] by any locks because it
    // is resized once. Each element of tensors_ is written once by the
    // source node of an edge and is cleared by the destination of the same
    // edge. The latter node is never run concurrently with the former node.
    Entry* input_tensors;

    // The number of outstanding ops for each iteration.
    size_t outstanding_ops;

    // The number of outstanding frames for each iteration.
    int outstanding_frame_count;
    int pending(PendingCounts::Handle h) { return counts.pending(h); }
    int decrement_pending(PendingCounts::Handle h, int v) {
      return counts.decrement_pending(h, v);
    }
    // Mark a merge node as live
    // REQUIRES: Node corresponding to "h" is a merge node
    void mark_live(PendingCounts::Handle h) { counts.mark_live(h); }
    // Mark a node to show that processing has started.
    void mark_started(PendingCounts::Handle h) { counts.mark_started(h); }
    // Mark a node to show that processing has completed.
    void mark_completed(PendingCounts::Handle h) { counts.mark_completed(h); }
    PendingCounts::NodeState node_state(PendingCounts::Handle h) {
      return counts.node_state(h);
    }

    int dead_count(PendingCounts::Handle h) { return counts.dead_count(h); }
    void increment_dead_count(PendingCounts::Handle h) {
      counts.increment_dead_count(h);
    }
    PendingCounts::AdjustResult adjust_for_activation(PendingCounts::Handle h,
                                                      bool increment_dead) {
      return counts.adjust_for_activation(h, increment_dead);
    }

    ~IterationState() { delete[] input_tensors; }

   private:
    PendingCounts counts;
  };

  struct FrameState {
    explicit FrameState(const ImmutableExecutorState& immutable_state,
                        int parallel_iters)
        : immutable_state(immutable_state),
          max_parallel_iterations(parallel_iters),
          num_outstanding_iterations(1),
          iterations(parallel_iters + 1),
          iterations_raw(iterations.data()) {}

    // A new frame is created for each loop. Execution starts at iteration 0.
    // When a value at iteration 0 passes through a NextIteration node,
    // iteration 1 is created and starts running. Note that iteration 0 may
    // still be running so multiple iterations may run in parallel. The
    // frame maintains the state of iterations in several data structures
    // such as pending_count and input_tensors. When iteration 0 completes,
    // we garbage collect the state of iteration 0.
    //
    // A frame instance is considered "done" and can be garbage collected
    // if all its inputs have entered and all its iterations are "done".
    //
    // A frame manages the live iterations of an iterative computation.
    // Iteration i is considered "done" when there are no outstanding ops,
    // frames at iteration i are done, all recvs for this iteration are
    // completed, and iteration i-1 is done. For iteration 0, we instead
    // wait for there to be no more pending inputs of the frame.
    //
    // Frames and iterations are garbage collected once they are done.
    // The state we need to keep around is highly dependent on the
    // parallelism enabled by the scheduler. We may want to have the
    // scheduler dynamically control the outstanding number of live
    // parallel frames and iterations. To reduce the state space, the
    // scheduler might want to schedule ops in inner frames first and
    // lower iterations first.
    //
    // This frame state is mostly initialized lazily on demand so we
    // don't introduce unnecessary overhead.

    // The immutable state of the executor the frame is in.
    const ImmutableExecutorState& immutable_state;

    // The name of this frame, which is the concatenation of its parent
    // frame name, the iteration of the parent frame when this frame was
    // created, and the value of the attr 'frame_name'.
    string frame_name;

    // The unique id for this frame. Generated by fingerprinting
    // frame_name.
    uint64 frame_id;

    // The iteration id of its parent frame when this frame is created.
    // -1 if there is no parent frame. The frame_name/parent_iter pair
    // uniquely identifies this FrameState.
    int64 parent_iter = -1;

    // The FrameState of its parent frame.
    FrameState* parent_frame = nullptr;

    // The maximum allowed number of parallel iterations.
    const int max_parallel_iterations;

    // The number of inputs this frame is still waiting for.
    int num_pending_inputs = 0;

    // The highest iteration number we have reached so far in this frame.
    int64 iteration_count TF_GUARDED_BY(mu) = 0;

    // The number of outstanding iterations.
    int num_outstanding_iterations TF_GUARDED_BY(mu) = 1;

   private:
    // The active iteration states of this frame.
    gtl::InlinedVector<IterationState*, 12> iterations;
    IterationState** const iterations_raw TF_GUARDED_BY(mu);
    IterationState* iterations_first TF_GUARDED_BY(mu);

   public:
    // The NextIteration nodes to enter a new iteration. If the number of
    // outstanding iterations reaches the limit, we will defer the start of
    // the next iteration until the number of outstanding iterations falls
    // below the limit.
    std::vector<std::pair<const NodeItem*, Entry>> next_iter_roots
        TF_GUARDED_BY(mu);

    // The values of the loop invariants for this loop. They are added into
    // this list as they "enter" the frame. When a loop invariant enters,
    // we make it available to all active iterations. When the frame starts
    // a new iteration, we make all the current loop invariants available
    // to the new iteration.
    std::vector<std::pair<const NodeItem*, Entry>> inv_values TF_GUARDED_BY(mu);

    // The list of dead exit node items for the current highest iteration. We
    // will only "execute" the dead exits of the final iteration.
    std::vector<const NodeItem*> dead_exits TF_GUARDED_BY(mu);

    // Static information specific to this frame.
    PendingCounts* pending_counts = nullptr;
    int total_input_tensors = 0;
    std::vector<const NodeItem*>* nodes = nullptr;

    // Lock ordering: ExecutorState.mu_ < mu;
    // during structured traversal: parent_frame->mu < mu.
    mutex mu;

    void InitializeFrameInfo(const string& enter_name);

    inline IterationState* GetIteration(int64 iter)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
      if (TF_PREDICT_TRUE(iter == 0)) {
        return iterations_first;
      } else {
        size_t index = iter % (max_parallel_iterations + 1);
        return iterations_raw[index];
      }
    }

    void SetIteration(int64 iter, IterationState* state);

    // Decrement the outstanding op count and clean up the iterations in the
    // frame. Return true iff the execution of the frame is done.
    bool DecrementOutstandingOps(int64 iter, TaggedNodeSeq* ready);

    // Decrement the outstanding op count and clean up the iterations in the
    // frame. Return true iff the execution of the frame is done.
    bool DecrementOutstandingOpsLocked(int64 iter, TaggedNodeSeq* ready);

    // Returns true if the computation in the frame is completed.
    bool IsFrameDone();

    // Returns true if the iteration of the frame is completed.
    bool IsIterationDone(int64 iter) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Increments the iteration id. If this is a new iteration, initialize it.
    void IncrementIteration(TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate all the deferred NextIteration nodes in a new iteration.
    void ActivateNexts(int64 iter, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate all the current loop invariants in a new iteration.
    void ActivateLoopInvs(int64 iter, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Add a new loop invariant and make it available to all active
    // iterations.
    void AddLoopInv(const NodeItem* item, const Entry& entry,
                    TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate the successors of a node. Contents of *outputs are left in an
    // indeterminate state after returning from this method.
    void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter,
                       EntryVector* outputs, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Cleanup iterations of this frame starting from iteration iter.
    bool CleanupIterations(int64 iter, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    void DumpIterationState(PropagatorState* parent) {
      mutex_lock l(mu);
      for (IterationState* iteration : iterations) {
        if (iteration) {
          LOG(WARNING) << "  Iteration:";
          parent->DumpIterationState(this, iteration);
        }
      }
    }

    ~FrameState() {
      for (size_t i = 0; i < iterations.size(); ++i) {
        delete iterations[i];
        iterations[i] = nullptr;
      }
    }

   private:
    // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
    void ActivateNodesFastPath(const NodeItem* item, const bool is_dead,
                               int64 iter, EntryVector* outputs,
                               TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
                               int64 iter, EntryVector* outputs,
                               TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
  };

 public:
  // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`.
  void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
                     TaggedNodeSeq* ready);

  // After processing the outputs, propagates the outputs to their dsts.
  // Contents of *outputs are left in an indeterminate state after
  // returning from this method.
  void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
                        TaggedNodeSeq* ready);

  // Returns an array of `Entry` objects corresponding to the inputs of
  // `tagged_node`.
  //
  // NOTE: Thread safety analysis is disabled on this method, because the
  // underlying `IterationState` and its array of `input_tensors` retain the
  // same address while the iteration is live.
  Entry* GetInputTensors(const TaggedNode& tagged_node) const
      TF_NO_THREAD_SAFETY_ANALYSIS {
    return tagged_node.input_frame->GetIteration(tagged_node.input_iter)
               ->input_tensors +
           tagged_node.node_item->input_start;
  }

  FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
    return {tagged_node.input_frame->frame_id, tagged_node.input_iter};
  }

  // Provide debugging output of the state of the executor.
  void DumpState();

  // For debugging/logging only.
  void MaybeMarkStarted(const TaggedNode& tagged_node) {
    // TODO(misard) Replace with a finer-grain enabling flag once we add better
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(tagged_node.input_frame->mu);
      tagged_node.input_frame->GetIteration(tagged_node.input_iter)
          ->mark_started(
              immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
    }
  }

  void MaybeMarkCompleted(const TaggedNode& tagged_node) {
    // TODO(misard) Replace with a finer-grain enabling flag once we add better
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(tagged_node.input_frame->mu);
      tagged_node.input_frame->GetIteration(tagged_node.input_iter)
          ->mark_completed(
              immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
    }
  }

 private:
  // Find an existing or create a new child frame in the frame 'frame' at
  // iteration 'iter'.
  void FindOrCreateChildFrame(FrameState* frame, int64 iter,
                              const NodeItem& node_item, FrameState** child);

  // Delete a frame. Called when the frame is done.
  void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);

  // Cleanup frames and iterations starting from frame/iter. Called when
  // a child frame is done.
  void CleanupFramesIterations(FrameState* frame, int64 iter,
                               TaggedNodeSeq* ready);

  // Provide debugging output about an outstanding node in the executor.
  void DumpPendingNodeState(const int node_id, const Entry* input_vector,
                            bool show_nodes_with_no_ready_inputs);
  void DumpActiveNodeState(const int node_id, const Entry* input_vector);

  // Provide debugging output about an outstanding iteration in the executor.
  void DumpIterationState(const FrameState* frame, IterationState* iteration);

  const Tensor* GetTensorValueForDump(const Entry& input);

  const ImmutableExecutorState& immutable_state_;
  const int64 step_id_;
  const bool vlog_;

  mutex mu_;

  // A flag that is set on error after the frame state has been
  // dumped for diagnostic purposes.
  bool dumped_on_error_ TF_GUARDED_BY(mu_) = false;

  // The root frame in which the execution of this step is started.
  FrameState* root_frame_;

  // Mapping from frame name to outstanding frames. A new frame is created
  // at some iteration of an active frame. So the unique key for the new
  // child frame is composed of the name of the parent frame, the iteration
  // number at which the parent frame is creating the new frame, and the
  // name of the new frame from nodedef.
  gtl::FlatMap<string, FrameState*> outstanding_frames_ TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
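`TaggedNodeReadyQueue` trades the generality of std::deque for a vector plus a front index: pop_front() just advances the index, and the consumed prefix is only compacted once the queue drains or the dead prefix grows past kSpillThreshold. Below is a self-contained sketch of that design using std::vector in place of gtl::InlinedVector, with a tiny threshold so the compaction is actually exercised (both substitutions are assumptions for illustration):

#include <cstddef>
#include <iostream>
#include <vector>

// Vector-backed queue in the style of TaggedNodeReadyQueue: pop_front() only
// bumps an index; memory for consumed entries is reclaimed lazily, either
// when the queue drains or when the dead prefix exceeds kSpillThreshold.
template <typename T>
class LazyQueue {
 public:
  void push_back(const T& v) { items_.push_back(v); }
  const T& front() const { return items_[front_]; }
  bool empty() const { return front_ == items_.size(); }

  void pop_front() {
    ++front_;
    if (front_ == items_.size()) {
      items_.clear();  // fully drained: drop everything at once
      front_ = 0;
    } else if (front_ > kSpillThreshold) {
      // Large dead prefix: slide the live elements down to the start,
      // mirroring the erase() branch in TaggedNodeReadyQueue::pop_front().
      items_.erase(items_.begin(), items_.begin() + front_);
      front_ = 0;
    }
  }

 private:
  static constexpr size_t kSpillThreshold = 4;  // 16384 in propagator_state.h
  std::vector<T> items_;
  size_t front_ = 0;
};

int main() {
  LazyQueue<int> q;
  for (int i = 0; i < 10; ++i) q.push_back(i);
  while (!q.empty()) {
    std::cout << q.front() << ' ';
    q.pop_front();
  }
  std::cout << '\n';  // prints 0..9; compaction happens transparently
}

The design makes sense for the executor's access pattern: ready queues are usually short-lived and mostly FIFO, so amortized O(1) pops with rare O(n) compactions beat a deque's per-block allocations.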