Rolling forward "[Executor] Implement SimplePropagatorState and use it when a graph has no v1-style control flow."

The previous version had undefined behavior when executing a graph comprised solely of operations with no input tensors, in which case the `SimpleExecutorState::input_tensors_` vector would be empty. Using `vector::data()` instead of `operator[]` avoids creating a reference to a null element when the vector is empty.

PiperOrigin-RevId: 304510355
Change-Id: I12785a6ac72355d0e5b1a312d41e227dbdf9dc96
This commit is contained in:
Derek Murray 2020-04-02 17:22:07 -07:00 committed by TensorFlower Gardener
parent 9ba3ff2df6
commit 0bde6a68f2
12 changed files with 551 additions and 112 deletions

View File

@ -232,6 +232,7 @@ filegroup(
"process_util.h",
"inspecting_placer.h",
"profile_handler.h",
"propagator_debug_utils.h",
"propagator_state.h",
"renamed_device.h",
"rendezvous_mgr.h",
@ -241,6 +242,7 @@ filegroup(
"ring_alg.h",
"ring_gatherer.h",
"session_factory.h",
"simple_propagator_state.h",
"single_threaded_cpu_device.h",
"stats_publisher_interface.h",
"step_stats_collector.h",
@ -304,6 +306,7 @@ tf_cuda_library(
"process_function_library_runtime.cc",
"process_state.cc",
"process_util.cc",
"propagator_debug_utils.cc",
"propagator_state.cc",
"renamed_device.cc",
"rendezvous_mgr.cc",
@ -316,6 +319,7 @@ tf_cuda_library(
"session_factory.cc",
"session_options.cc",
"session_state.cc",
"simple_propagator_state.cc",
"single_threaded_cpu_device.cc",
"stats_publisher_interface.cc",
"step_stats_collector.cc",

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/common_runtime/propagator_state.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/simple_propagator_state.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
@ -372,6 +373,10 @@ class ExecutorState {
mutex mu_;
Status status_ TF_GUARDED_BY(mu_);
// A flag that is set on error after the propagator state has been
// dumped for diagnostic purposes.
bool dumped_on_error_ TF_GUARDED_BY(mu_) = false;
};
template <class PropagatorStateType>
@ -925,7 +930,11 @@ Status ExecutorState<PropagatorStateType>::ProcessOutputs(
// add better optional debugging support.
if (vlog_ && VLOG_IS_ON(1)) {
LOG(WARNING) << this << " Compute status: " << s;
propagator_.DumpState();
mutex_lock l(mu_);
if (!dumped_on_error_) {
propagator_.DumpState();
dumped_on_error_ = true;
}
}
if (s.code() == error::RESOURCE_EXHAUSTED) {
if (stats_collector_) {
@ -1256,8 +1265,14 @@ void ExecutorState<PropagatorStateType>::Finish() {
}
void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
(new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
->RunAsync(std::move(done));
if (immutable_state_.requires_control_flow_support()) {
(new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
->RunAsync(std::move(done));
} else {
(new ExecutorState<SimplePropagatorState>(args, immutable_state_,
&kernel_stats_))
->RunAsync(std::move(done));
}
}
} // namespace

View File

@ -413,6 +413,14 @@ TEST_F(ExecutorTest, RecvInvalidRefDtype) {
rendez->Unref();
}
TEST_F(ExecutorTest, NoInputTensors) {
// Create a graph where none of the nodes have input tensors.
auto g = absl::make_unique<Graph>(OpRegistry::Global());
test::graph::Constant(g.get(), V(1.0));
Create(std::move(g));
TF_ASSERT_OK(Run(rendez_));
}
// Create a graph that is 'depth' deep. At each level, fan-in and fan-out a
// maximum of 'width' nodes. All nodes are no-ops and all dependencies are
// control dependencies.

View File

@ -211,6 +211,8 @@ class GraphView {
Status SetAllocAttrs(const Graph* g, const Device* device);
void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
// Returns a mutable pointer to the `NodeItem` with the given `id` if it
// exists in the graph, or `nullptr` if it does not.
NodeItem* node(int32 id) const {
DCHECK_GE(id, 0);
DCHECK_LT(id, num_nodes_);
@ -220,6 +222,17 @@ class GraphView {
: reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
}
// Returns the `NodeItem` with the given `id`.
//
// REQUIRES: `id` must be the ID of a valid node in the graph.
const NodeItem& node_ref(int32 id) const {
DCHECK_GE(id, 0);
DCHECK_LT(id, num_nodes_);
uint32 offset = node_offsets_[id];
DCHECK_NE(offset, kuint32max);
return *reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]);
}
int32 num_nodes() const { return num_nodes_; }
private:

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/graph/graph.h"
@ -88,13 +89,33 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) {
EnsureFrameInfo(it)->nodes =
absl::make_unique<std::vector<const NodeItem*>>();
}
root_frame_info_ = frame_info_[""];
pending_ids_.resize(gview_.num_nodes());
// Preprocess every node in the graph to create an instance of op
// kernel for each node.
requires_control_flow_ = false;
for (const Node* n : graph.nodes()) {
if (IsSink(n)) continue;
if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) {
requires_control_flow_ = true;
} else if (IsRecv(n)) {
// A Recv node from a different device may produce dead tensors from
// non-local control-flow nodes.
//
// TODO(mrry): Track whether control flow was present in the
// pre-partitioned graph, and enable the caller (e.g.
// `DirectSession`) to relax this constraint.
string send_device;
string recv_device;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "send_device", &send_device));
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "recv_device", &recv_device));
if (send_device != recv_device) {
requires_control_flow_ = true;
}
}
const int id = n->id();
const string& frame_name = cf_info.frame_names[id];
FrameInfo* frame_info = EnsureFrameInfo(frame_name);

View File

@ -91,6 +91,10 @@ class ImmutableExecutorState {
}
}
const FrameInfo& get_root_frame_info() const { return *root_frame_info_; }
bool requires_control_flow_support() const { return requires_control_flow_; }
private:
struct ControlFlowInfo {
gtl::FlatSet<string> unique_frame_names;
@ -106,6 +110,7 @@ class ImmutableExecutorState {
// Owned.
LocalExecutorParams params_;
GraphView gview_;
bool requires_control_flow_;
std::vector<PendingCounts::Handle> pending_ids_;
// Root nodes (with no in edges) that should form the initial ready queue
@ -115,6 +120,7 @@ class ImmutableExecutorState {
// TODO(yuanbyu): We could cache it along with the graph so to avoid
// the overhead of constructing it for each executor instance.
gtl::FlatMap<string, FrameInfo*> frame_info_;
const FrameInfo* root_frame_info_; // Not owned.
// Shallow copies of the constant tensors used in the graph.
std::vector<Tensor> const_tensors_;

View File

@ -0,0 +1,95 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
#include <vector>
#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
// 1-D, 0 element tensor.
static const Tensor* const kEmptyTensor = new Tensor;
const Tensor* GetTensorValueForDump(const Entry& input) {
switch (input.state) {
case Entry::State::NO_VALUE:
return kEmptyTensor;
case Entry::State::HAS_VALUE:
return input.val.get();
case Entry::State::HAS_CONST_TENSOR:
return input.const_tensor;
case Entry::State::HAS_REF_TENSOR:
return input.ref_tensor.tensor;
}
}
void DumpPendingNodeState(const ImmutableExecutorState& immutable_state,
const int node_id, const Entry* input_vector,
const bool show_nodes_with_no_ready_inputs) {
const NodeItem& node_item = immutable_state.graph_view().node_ref(node_id);
const int input_base = node_item.input_start;
if (!show_nodes_with_no_ready_inputs) {
bool has_ready_input = false;
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor && tensor->IsInitialized()) {
has_ready_input = true;
break;
}
}
if (!has_ready_input) {
return;
}
}
LOG(WARNING) << " Pending Node: " << node_item.DebugString();
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(), ">");
} else {
LOG(WARNING) << " Input " << i << ": not present";
}
}
}
void DumpActiveNodeState(const ImmutableExecutorState& immutable_state,
const int node_id, const Entry* input_vector) {
const NodeItem& node_item = immutable_state.graph_view().node_ref(node_id);
LOG(WARNING) << " Active Node: " << node_item.DebugString();
const int input_base = node_item.input_start;
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(), ">");
} else {
LOG(WARNING) << " Input " << i << ": not present";
}
}
}
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_
namespace tensorflow {
struct Entry;
class ImmutableExecutorState;
class Tensor;
// Returns a pointer to the tensor in `input` if one exists, or `nullptr`.
const Tensor* GetTensorValueForDump(const Entry& input);
// Writes a LOG(WARNING) message describing the state of the pending node
// `node_id` in the graph described by `immutable_state`.
void DumpPendingNodeState(const ImmutableExecutorState& immutable_state,
const int node_id, const Entry* input_vector,
const bool show_nodes_with_no_ready_inputs);
// Writes a LOG(WARNING) message describing the state of the active node
// `node_id` in the graph described by `immutable_state`.
void DumpActiveNodeState(const ImmutableExecutorState& immutable_state,
const int node_id, const Entry* input_vector);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_

View File

@ -16,17 +16,12 @@ limitations under the License.
#include "tensorflow/core/common_runtime/propagator_state.h"
#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace tensorflow {
// 1-D, 0 element tensor.
static const Tensor* const kEmptyTensor = new Tensor;
typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
int64 step_id)
: immutable_state_(immutable_state),
@ -57,7 +52,7 @@ void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
TaggedNodeSeq* ready) {
for (const NodeItem* item : roots) {
DCHECK_EQ(item->num_inputs, 0);
ready->push_back(TaggedNode{item, root_frame_, 0, false});
ready->emplace_back(item, root_frame_, 0, false);
}
mutex_lock l(root_frame_->mu);
root_frame_->GetIteration(0)->outstanding_ops = ready->size();
@ -173,72 +168,6 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
}
}
const Tensor* PropagatorState::GetTensorValueForDump(const Entry& input) {
switch (input.state) {
case Entry::State::NO_VALUE:
return kEmptyTensor;
case Entry::State::HAS_VALUE:
return input.val.get();
case Entry::State::HAS_CONST_TENSOR:
return input.const_tensor;
case Entry::State::HAS_REF_TENSOR:
return input.ref_tensor.tensor;
}
}
void PropagatorState::DumpPendingNodeState(
const int node_id, const Entry* input_vector,
const bool show_nodes_with_no_ready_inputs) {
const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
const int input_base = node_item.input_start;
if (!show_nodes_with_no_ready_inputs) {
bool has_ready_input = false;
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
has_ready_input = true;
break;
}
}
if (!has_ready_input) {
return;
}
}
LOG(WARNING) << " Pending Node: " << node_item.DebugString();
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(), ">");
} else {
LOG(WARNING) << " Input " << i << ": not present";
}
}
}
void PropagatorState::DumpActiveNodeState(const int node_id,
const Entry* input_vector) {
const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
LOG(WARNING) << " Active Node: " << node_item.DebugString();
const int input_base = node_item.input_start;
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(), ">");
} else {
LOG(WARNING) << " Input " << i << ": not present";
}
}
}
void PropagatorState::DumpIterationState(const FrameState* frame,
IterationState* iteration) {
const std::vector<const NodeItem*>* nodes = frame->nodes;
@ -248,7 +177,8 @@ void PropagatorState::DumpIterationState(const FrameState* frame,
immutable_state_.pending_ids()[node->node_id];
if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
DumpPendingNodeState(node->node_id, iteration->input_tensors, false);
DumpPendingNodeState(immutable_state_, node->node_id,
iteration->input_tensors, false);
}
}
// Then the active nodes.
@ -256,7 +186,8 @@ void PropagatorState::DumpIterationState(const FrameState* frame,
PendingCounts::Handle pending_id =
immutable_state_.pending_ids()[node->node_id];
if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
DumpActiveNodeState(node->node_id, iteration->input_tensors);
DumpActiveNodeState(immutable_state_, node->node_id,
iteration->input_tensors);
}
}
// Show all input tensors in use.
@ -279,14 +210,11 @@ void PropagatorState::DumpIterationState(const FrameState* frame,
void PropagatorState::DumpState() {
mutex_lock l(mu_);
if (!dumped_on_error_) {
LOG(WARNING) << "Dumping state";
for (auto& frame : outstanding_frames_) {
LOG(WARNING) << frame.first;
FrameState* frame_state = frame.second;
frame_state->DumpIterationState(this);
}
dumped_on_error_ = true;
LOG(WARNING) << "Dumping state";
for (auto& frame : outstanding_frames_) {
LOG(WARNING) << frame.first;
FrameState* frame_state = frame.second;
frame_state->DumpIterationState(this);
}
}
@ -378,7 +306,7 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
for (const EdgeInfo& e : item->output_edges()) {
const NodeItem& dst_item =
*immutable_state_.graph_view().node(e.dst_id);
immutable_state_.graph_view().node_ref(e.dst_id);
const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
bool dst_dead = true;
@ -398,7 +326,7 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
for (const ControlEdgeInfo& e : item->output_control_edges()) {
const NodeItem& dst_item =
*immutable_state_.graph_view().node(e.dst_id);
immutable_state_.graph_view().node_ref(e.dst_id);
const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
bool dst_dead;
@ -464,17 +392,17 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
//
// NOTE(mrry): Use a macro here instead of a lambda, because this method is
// performance-critical and we need to ensure that the code is inlined.
#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \
do { \
if (!adjust_result.any_pending) { \
const NodeItem* dst_item = gview.node(dst_id); \
TaggedNode& t = ready->emplace_back(); \
t.node_item = dst_item; \
t.input_frame = this; \
t.input_iter = iter; \
t.is_dead = adjust_result.any_dead; \
iter_state->outstanding_ops++; \
} \
#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \
do { \
if (!adjust_result.any_pending) { \
const NodeItem* dst_item = &gview.node_ref(dst_id); \
TaggedNode& t = ready->emplace_back(); \
t.node_item = dst_item; \
t.input_frame = this; \
t.input_iter = iter; \
t.is_dead = adjust_result.any_dead; \
iter_state->outstanding_ops++; \
} \
} while (0);
Entry* input_tensors = iter_state->input_tensors;
@ -534,7 +462,7 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
for (const EdgeInfo& e : item->output_edges()) {
const int dst_id = e.dst_id;
const NodeItem* dst_item = gview.node(dst_id);
const NodeItem* dst_item = &gview.node_ref(dst_id);
const PendingCounts::Handle dst_pending_id =
immutable_state.pending_ids()[dst_id];
const int src_slot = e.output_slot;
@ -596,7 +524,7 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
for (const ControlEdgeInfo& e : item->output_control_edges()) {
const int dst_id = e.dst_id;
const NodeItem* dst_item = gview.node(dst_id);
const NodeItem* dst_item = &gview.node_ref(dst_id);
const PendingCounts::Handle dst_pending_id =
immutable_state.pending_ids()[dst_id];

View File

@ -429,26 +429,15 @@ class PropagatorState {
void CleanupFramesIterations(FrameState* frame, int64 iter,
TaggedNodeSeq* ready);
// Provide debugging output about an outstanding node in the executor.
void DumpPendingNodeState(const int node_id, const Entry* input_vector,
bool show_nodes_with_no_ready_inputs);
void DumpActiveNodeState(const int node_id, const Entry* input_vector);
// Provide debugging output about an outstanding iteration in the executor.
void DumpIterationState(const FrameState* frame, IterationState* iteration);
const Tensor* GetTensorValueForDump(const Entry& input);
const ImmutableExecutorState& immutable_state_;
const int64 step_id_;
const bool vlog_;
mutex mu_;
// A flag that is set on error after the frame state has been
// dumped for diagnostic purposes.
bool dumped_on_error_ TF_GUARDED_BY(mu_) = false;
// The root frame in which the execution of this step is started.
FrameState* root_frame_;

View File

@ -0,0 +1,134 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/simple_propagator_state.h"
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace tensorflow {
SimplePropagatorState::SimplePropagatorState(
const ImmutableExecutorState& immutable_state, int64 step_id)
: SimplePropagatorState(immutable_state, step_id,
immutable_state.get_root_frame_info()) {}
SimplePropagatorState::SimplePropagatorState(
const ImmutableExecutorState& immutable_state, int64 step_id,
const ImmutableExecutorState::FrameInfo& finfo)
: immutable_state_(immutable_state),
step_id_(step_id),
vlog_(VLOG_IS_ON(1)),
input_tensors_(finfo.total_inputs),
counts_(*finfo.pending_counts),
nodes_(finfo.nodes.get()) {}
SimplePropagatorState::~SimplePropagatorState() {}
void SimplePropagatorState::ActivateRoots(
gtl::ArraySlice<const NodeItem*> roots, TaggedNodeSeq* ready) {
for (const NodeItem* item : roots) {
DCHECK_EQ(item->num_inputs, 0);
ready->emplace_back(item);
}
}
void SimplePropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
EntryVector* outputs,
TaggedNodeSeq* ready) {
profiler::TraceMe activity(
[&]() {
return strings::StrCat(
"ExecutorPropagateOutputs#", "id=", step_id_,
",kernel_name=", tagged_node.node_item->kernel->name_view(),
",num_output_edges=", tagged_node.node_item->num_output_edges,
",num_output_control_edges=",
tagged_node.node_item->num_output_control_edges, "#");
},
profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
// Propagates outputs along out edges, and puts newly ready nodes
// into the ready queue.
DCHECK(ready->empty());
const GraphView& gview = immutable_state_.graph_view();
const NodeItem* item = tagged_node.node_item;
mutex_lock l(mu_);
for (const EdgeInfo& e : item->output_edges()) {
const int dst_id = e.dst_id;
const PendingCounts::Handle dst_pending_id =
immutable_state_.pending_ids()[dst_id];
const int src_slot = e.output_slot;
int num_pending = counts_.decrement_pending(dst_pending_id, 1);
const int dst_loc = e.input_slot;
if (e.is_last) {
input_tensors_[dst_loc] = std::move((*outputs)[src_slot]);
} else {
input_tensors_[dst_loc] = (*outputs)[src_slot];
}
if (num_pending == 0) ready->emplace_back(&gview.node_ref(dst_id));
}
for (const ControlEdgeInfo& e : item->output_control_edges()) {
const int dst_id = e.dst_id;
const PendingCounts::Handle dst_pending_id =
immutable_state_.pending_ids()[dst_id];
int num_pending = counts_.decrement_pending(dst_pending_id, 1);
if (num_pending == 0) ready->emplace_back(&gview.node_ref(dst_id));
}
}
void SimplePropagatorState::DumpState() {
mutex_lock l(mu_);
LOG(WARNING) << "Dumping state";
// Dump any waiting nodes that are holding on to tensors.
for (const NodeItem* node : *nodes_) {
PendingCounts::Handle pending_id =
immutable_state_.pending_ids()[node->node_id];
if (counts_.node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
counts_.node_state(pending_id) == PendingCounts::PENDING_READY) {
DumpPendingNodeState(immutable_state_, node->node_id,
input_tensors_.data(), false);
}
}
// Then the active nodes.
for (const NodeItem* node : *nodes_) {
PendingCounts::Handle pending_id =
immutable_state_.pending_ids()[node->node_id];
if (counts_.node_state(pending_id) == PendingCounts::STARTED) {
DumpActiveNodeState(immutable_state_, node->node_id,
input_tensors_.data());
}
}
// Show all input tensors in use.
size_t total_bytes = 0;
for (size_t i = 0; i < input_tensors_.size(); ++i) {
const Entry& input = input_tensors_[i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor && tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(),
", bytes: ", tensor->TotalBytes(), ">");
total_bytes += tensor->TotalBytes();
}
}
LOG(WARNING) << " Total bytes " << total_bytes;
}
} // namespace tensorflow

View File

@ -0,0 +1,186 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
#include <vector>
#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// Represents the ephemeral "edge state" associated with one invocation of
// `Executor::Run()`.
//
// NOTE: `SimplePropagatorState` does not support "v1-style" control flow,
// including "dead tensors", "Switch" and "Merge" nodes, and cycles in the
// graph. Use `PropagatorState` for graphs with those features.
// `SimplePropagatorState` *does* support "v2-style" or "functional" control
// flow.
//
// `SimplePropagatorState` is responsible for propagating values along dataflow
// edges in a TensorFlow graph and determining which nodes are runnable. The
// executor primarily updates `SimplePropagatorState` by calling
// `PropagateOutputs()` after processing a node, and `SimplePropagatorState`
// dispatches `TaggedNode`s by adding them to a `TaggedNodeSeq`.
class SimplePropagatorState {
public:
SimplePropagatorState(const ImmutableExecutorState& immutable_state,
int64 step_id);
~SimplePropagatorState();
// A `TaggedNode` corresponds to a single invocation of a node's kernel,
// and it is created when the kernel becomes runnable.
struct TaggedNode {
const NodeItem* node_item;
explicit TaggedNode(const NodeItem* node_item) : node_item(node_item) {}
const NodeItem& get_node_item() const { return *node_item; }
bool get_is_dead() const { return false; }
int64 get_iter_num() const { return 0; }
};
// A drop-in replacement for std::deque<TaggedNode>. We typically don't
// have that many nodes in the ready queue, so we just use a vector and
// don't free up memory from the queue as we consume nodes.
// TODO(mrry): Extract this and share it with the version in
// `PropagatorState`. The correct constants might be different, since
// sizeof(TaggedNode) is smaller in this version.
class TaggedNodeReadyQueue {
public:
TaggedNodeReadyQueue() : front_index_(0) {}
void push_back(const TaggedNode& node) { ready_.push_back(node); }
TaggedNode front() const {
DCHECK_LT(front_index_, ready_.size());
return ready_[front_index_];
}
void pop_front() {
DCHECK_LT(front_index_, ready_.size());
front_index_++;
if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) {
if (front_index_ == ready_.size()) {
ready_.clear();
} else {
// Lots of unused entries at beginning of vector: move everything
// down to start of vector.
ready_.erase(ready_.begin(), ready_.begin() + front_index_);
}
front_index_ = 0;
}
}
bool empty() const { return ready_.empty(); }
private:
// TODO(b/152925936): Re-evaluate these constants with current usage
// patterns.
static constexpr int kSpillThreshold = 16384;
gtl::InlinedVector<TaggedNode, 16> ready_;
int front_index_;
};
// TODO(b/152925936): Re-evaluate this constant with current usage patterns.
typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
// Creates and adds a `TaggedNode` for each node in `roots` to `*ready`.
void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
TaggedNodeSeq* ready);
// After processing the outputs, propagates the outputs to their dsts.
// Contents of *outputs are left in an indeterminate state after
// returning from this method.
void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
TaggedNodeSeq* ready);
// Returns an array of `Entry` objects corresponding to the inputs of
// `tagged_node`.
//
// NOTE: Thread safety analysis is disabled on this method, because the
// underlying `IterationState` and its array of `input_tensors` retain the
// same address while the iteration is live.
Entry* GetInputTensors(const TaggedNode& tagged_node)
TF_NO_THREAD_SAFETY_ANALYSIS {
return input_tensors_.data() + tagged_node.node_item->input_start;
}
FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
return {0, 0};
}
// Provide debugging output of the state of the executor.
void DumpState();
// For debugging/logging only.
void MaybeMarkStarted(const TaggedNode& tagged_node) {
// TODO(misard) Replace with a finer-grain enabling flag once we add better
// optional debugging support.
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
mutex_lock l(mu_);
counts_.mark_started(
immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
}
}
void MaybeMarkCompleted(const TaggedNode& tagged_node) {
// TODO(misard) Replace with a finer-grain enabling flag once we add better
// optional debugging support.
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
mutex_lock l(mu_);
counts_.mark_completed(
immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
}
}
private:
SimplePropagatorState(const ImmutableExecutorState& immutable_state_,
int64 step_id,
const ImmutableExecutorState::FrameInfo& finfo);
const ImmutableExecutorState& immutable_state_;
const int64 step_id_;
const bool vlog_;
mutex mu_;
// The i-th node's j-th input is stored at
// `input_tensors_[impl_->nodes[i].input_start + j]`.
//
// NOTE: No need to protect input_tensors[i] by any locks because it
// is resized once. Each element of input_tensors is written once by the
// source node of an edge and is cleared by the destination of the same
// edge. The destination node always runs after the source node, so there
// is never concurrent access to the same entry.
std::vector<Entry> input_tensors_;
PendingCounts counts_ TF_GUARDED_BY(mu_);
const std::vector<const NodeItem*>* const nodes_;
TF_DISALLOW_COPY_AND_ASSIGN(SimplePropagatorState);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_