Rolling forward "[Executor] Implement SimplePropagatorState, and use it when a graph has no v1-style control flow."

The previous version had undefined behavior when executing a graph comprised solely of operations with no input tensors, in which case the `SimplePropagatorState::input_tensors_` vector would be empty. Using `vector::data()` instead of `operator[]` avoids forming a reference to a nonexistent element when the vector is empty.

PiperOrigin-RevId: 304510355
Change-Id: I12785a6ac72355d0e5b1a312d41e227dbdf9dc96
This commit is contained in:
parent 9ba3ff2df6
commit 0bde6a68f2
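For reference, a minimal standalone C++ sketch (not TensorFlow code) of the distinction the fix relies on: indexing an empty std::vector with operator[] is undefined behavior, while data() is well defined for any size (it may return nullptr when the vector is empty) and can safely be used to form a base pointer that is never dereferenced.

// Sketch only; `Entry` here is a stand-in for tensorflow::Entry.
#include <cassert>
#include <vector>

struct Entry {};

int main() {
  std::vector<Entry> input_tensors;  // empty, as in a graph whose ops take no input tensors

  // Undefined behavior if the vector is empty: operator[] requires a valid index.
  // Entry* bad = &input_tensors[0];

  // Well defined for any size, including zero; callers must not dereference
  // the result while the vector is empty.
  Entry* base = input_tensors.data();
  assert(input_tensors.empty());
  (void)base;
  return 0;
}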
@@ -232,6 +232,7 @@ filegroup(
"process_util.h",
"inspecting_placer.h",
"profile_handler.h",
"propagator_debug_utils.h",
"propagator_state.h",
"renamed_device.h",
"rendezvous_mgr.h",
@@ -241,6 +242,7 @@ filegroup(
"ring_alg.h",
"ring_gatherer.h",
"session_factory.h",
"simple_propagator_state.h",
"single_threaded_cpu_device.h",
"stats_publisher_interface.h",
"step_stats_collector.h",
@@ -304,6 +306,7 @@ tf_cuda_library(
"process_function_library_runtime.cc",
"process_state.cc",
"process_util.cc",
"propagator_debug_utils.cc",
"propagator_state.cc",
"renamed_device.cc",
"rendezvous_mgr.cc",
@@ -316,6 +319,7 @@ tf_cuda_library(
"session_factory.cc",
"session_options.cc",
"session_state.cc",
"simple_propagator_state.cc",
"single_threaded_cpu_device.cc",
"stats_publisher_interface.cc",
"step_stats_collector.cc",
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/common_runtime/propagator_state.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/simple_propagator_state.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
@@ -372,6 +373,10 @@ class ExecutorState {

mutex mu_;
Status status_ TF_GUARDED_BY(mu_);

// A flag that is set on error after the propagator state has been
// dumped for diagnostic purposes.
bool dumped_on_error_ TF_GUARDED_BY(mu_) = false;
};

template <class PropagatorStateType>
@@ -925,7 +930,11 @@ Status ExecutorState<PropagatorStateType>::ProcessOutputs(
// add better optional debugging support.
if (vlog_ && VLOG_IS_ON(1)) {
LOG(WARNING) << this << " Compute status: " << s;
propagator_.DumpState();
mutex_lock l(mu_);
if (!dumped_on_error_) {
propagator_.DumpState();
dumped_on_error_ = true;
}
}
if (s.code() == error::RESOURCE_EXHAUSTED) {
if (stats_collector_) {
@@ -1256,8 +1265,14 @@ void ExecutorState<PropagatorStateType>::Finish() {
}

void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
(new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
->RunAsync(std::move(done));
if (immutable_state_.requires_control_flow_support()) {
(new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
->RunAsync(std::move(done));
} else {
(new ExecutorState<SimplePropagatorState>(args, immutable_state_,
&kernel_stats_))
->RunAsync(std::move(done));
}
}

} // namespace
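As an aside, an illustrative sketch (not TensorFlow code) of the dispatch pattern used in RunAsync above: the executor is templated over the propagator type, and the cheaper propagator is selected at runtime only when the graph needs no v1-style control-flow support. Names below are hypothetical stand-ins.

#include <iostream>

// Stand-ins for PropagatorState and SimplePropagatorState.
struct FullPropagator { static const char* name() { return "PropagatorState"; } };
struct SimplePropagator { static const char* name() { return "SimplePropagatorState"; } };

// Analogous to ExecutorState<PropagatorStateType>: the propagator is a compile-time parameter.
template <class PropagatorType>
void RunWith() { std::cout << "running with " << PropagatorType::name() << "\n"; }

// Analogous to ExecutorImpl::RunAsync: pick the template instantiation at runtime.
void RunAsync(bool requires_control_flow) {
  if (requires_control_flow) {
    RunWith<FullPropagator>();
  } else {
    RunWith<SimplePropagator>();
  }
}

int main() {
  RunAsync(/*requires_control_flow=*/true);
  RunAsync(/*requires_control_flow=*/false);
}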
@@ -413,6 +413,14 @@ TEST_F(ExecutorTest, RecvInvalidRefDtype) {
rendez->Unref();
}

TEST_F(ExecutorTest, NoInputTensors) {
// Create a graph where none of the nodes have input tensors.
auto g = absl::make_unique<Graph>(OpRegistry::Global());
test::graph::Constant(g.get(), V(1.0));
Create(std::move(g));
TF_ASSERT_OK(Run(rendez_));
}

// Create a graph that is 'depth' deep. At each level, fan-in and fan-out a
// maximum of 'width' nodes. All nodes are no-ops and all dependencies are
// control dependencies.
@@ -211,6 +211,8 @@ class GraphView {
Status SetAllocAttrs(const Graph* g, const Device* device);
void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);

// Returns a mutable pointer to the `NodeItem` with the given `id` if it
// exists in the graph, or `nullptr` if it does not.
NodeItem* node(int32 id) const {
DCHECK_GE(id, 0);
DCHECK_LT(id, num_nodes_);
@@ -220,6 +222,17 @@ class GraphView {
: reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
}

// Returns the `NodeItem` with the given `id`.
//
// REQUIRES: `id` must be the ID of a valid node in the graph.
const NodeItem& node_ref(int32 id) const {
DCHECK_GE(id, 0);
DCHECK_LT(id, num_nodes_);
uint32 offset = node_offsets_[id];
DCHECK_NE(offset, kuint32max);
return *reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]);
}

int32 num_nodes() const { return num_nodes_; }

private:
@@ -17,6 +17,7 @@ limitations under the License.

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/graph/graph.h"
@@ -88,13 +89,33 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) {
EnsureFrameInfo(it)->nodes =
absl::make_unique<std::vector<const NodeItem*>>();
}
root_frame_info_ = frame_info_[""];

pending_ids_.resize(gview_.num_nodes());

// Preprocess every node in the graph to create an instance of op
// kernel for each node.
requires_control_flow_ = false;
for (const Node* n : graph.nodes()) {
if (IsSink(n)) continue;
if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) {
requires_control_flow_ = true;
} else if (IsRecv(n)) {
// A Recv node from a different device may produce dead tensors from
// non-local control-flow nodes.
//
// TODO(mrry): Track whether control flow was present in the
// pre-partitioned graph, and enable the caller (e.g.
// `DirectSession`) to relax this constraint.
string send_device;
string recv_device;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "send_device", &send_device));
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "recv_device", &recv_device));
if (send_device != recv_device) {
requires_control_flow_ = true;
}
}

const int id = n->id();
const string& frame_name = cf_info.frame_names[id];
FrameInfo* frame_info = EnsureFrameInfo(frame_name);
@@ -91,6 +91,10 @@ class ImmutableExecutorState {
}
}

const FrameInfo& get_root_frame_info() const { return *root_frame_info_; }

bool requires_control_flow_support() const { return requires_control_flow_; }

private:
struct ControlFlowInfo {
gtl::FlatSet<string> unique_frame_names;
@@ -106,6 +110,7 @@ class ImmutableExecutorState {
// Owned.
LocalExecutorParams params_;
GraphView gview_;
bool requires_control_flow_;
std::vector<PendingCounts::Handle> pending_ids_;

// Root nodes (with no in edges) that should form the initial ready queue
@@ -115,6 +120,7 @@ class ImmutableExecutorState {
// TODO(yuanbyu): We could cache it along with the graph so to avoid
// the overhead of constructing it for each executor instance.
gtl::FlatMap<string, FrameInfo*> frame_info_;
const FrameInfo* root_frame_info_;  // Not owned.

// Shallow copies of the constant tensors used in the graph.
std::vector<Tensor> const_tensors_;
tensorflow/core/common_runtime/propagator_debug_utils.cc (new file, 95 lines)
@@ -0,0 +1,95 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"

#include <vector>

#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// 1-D, 0 element tensor.
static const Tensor* const kEmptyTensor = new Tensor;

const Tensor* GetTensorValueForDump(const Entry& input) {
switch (input.state) {
case Entry::State::NO_VALUE:
return kEmptyTensor;
case Entry::State::HAS_VALUE:
return input.val.get();
case Entry::State::HAS_CONST_TENSOR:
return input.const_tensor;
case Entry::State::HAS_REF_TENSOR:
return input.ref_tensor.tensor;
}
}

void DumpPendingNodeState(const ImmutableExecutorState& immutable_state,
const int node_id, const Entry* input_vector,
const bool show_nodes_with_no_ready_inputs) {
const NodeItem& node_item = immutable_state.graph_view().node_ref(node_id);
const int input_base = node_item.input_start;
if (!show_nodes_with_no_ready_inputs) {
bool has_ready_input = false;
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor && tensor->IsInitialized()) {
has_ready_input = true;
break;
}
}
if (!has_ready_input) {
return;
}
}
LOG(WARNING) << " Pending Node: " << node_item.DebugString();
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(), ">");
} else {
LOG(WARNING) << " Input " << i << ": not present";
}
}
}

void DumpActiveNodeState(const ImmutableExecutorState& immutable_state,
const int node_id, const Entry* input_vector) {
const NodeItem& node_item = immutable_state.graph_view().node_ref(node_id);
LOG(WARNING) << " Active Node: " << node_item.DebugString();
const int input_base = node_item.input_start;
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(), ">");
} else {
LOG(WARNING) << " Input " << i << ": not present";
}
}
}

} // namespace tensorflow
tensorflow/core/common_runtime/propagator_debug_utils.h (new file, 40 lines)
@@ -0,0 +1,40 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_

namespace tensorflow {

struct Entry;
class ImmutableExecutorState;
class Tensor;

// Returns a pointer to the tensor in `input` if one exists, or `nullptr`.
const Tensor* GetTensorValueForDump(const Entry& input);

// Writes a LOG(WARNING) message describing the state of the pending node
// `node_id` in the graph described by `immutable_state`.
void DumpPendingNodeState(const ImmutableExecutorState& immutable_state,
const int node_id, const Entry* input_vector,
const bool show_nodes_with_no_ready_inputs);

// Writes a LOG(WARNING) message describing the state of the active node
// `node_id` in the graph described by `immutable_state`.
void DumpActiveNodeState(const ImmutableExecutorState& immutable_state,
const int node_id, const Entry* input_vector);

} // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_
@@ -16,17 +16,12 @@ limitations under the License.
#include "tensorflow/core/common_runtime/propagator_state.h"

#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

// 1-D, 0 element tensor.
static const Tensor* const kEmptyTensor = new Tensor;

typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;

PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
int64 step_id)
: immutable_state_(immutable_state),
@@ -57,7 +52,7 @@ void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
TaggedNodeSeq* ready) {
for (const NodeItem* item : roots) {
DCHECK_EQ(item->num_inputs, 0);
ready->push_back(TaggedNode{item, root_frame_, 0, false});
ready->emplace_back(item, root_frame_, 0, false);
}
mutex_lock l(root_frame_->mu);
root_frame_->GetIteration(0)->outstanding_ops = ready->size();
@@ -173,72 +168,6 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
}
}

const Tensor* PropagatorState::GetTensorValueForDump(const Entry& input) {
switch (input.state) {
case Entry::State::NO_VALUE:
return kEmptyTensor;
case Entry::State::HAS_VALUE:
return input.val.get();
case Entry::State::HAS_CONST_TENSOR:
return input.const_tensor;
case Entry::State::HAS_REF_TENSOR:
return input.ref_tensor.tensor;
}
}

void PropagatorState::DumpPendingNodeState(
const int node_id, const Entry* input_vector,
const bool show_nodes_with_no_ready_inputs) {
const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
const int input_base = node_item.input_start;
if (!show_nodes_with_no_ready_inputs) {
bool has_ready_input = false;
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
has_ready_input = true;
break;
}
}
if (!has_ready_input) {
return;
}
}
LOG(WARNING) << " Pending Node: " << node_item.DebugString();
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(), ">");
} else {
LOG(WARNING) << " Input " << i << ": not present";
}
}
}

void PropagatorState::DumpActiveNodeState(const int node_id,
const Entry* input_vector) {
const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
LOG(WARNING) << " Active Node: " << node_item.DebugString();
const int input_base = node_item.input_start;
for (int i = 0; i < node_item.num_inputs; ++i) {
const Entry& input = input_vector[input_base + i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(), ">");
} else {
LOG(WARNING) << " Input " << i << ": not present";
}
}
}

void PropagatorState::DumpIterationState(const FrameState* frame,
IterationState* iteration) {
const std::vector<const NodeItem*>* nodes = frame->nodes;
@@ -248,7 +177,8 @@ void PropagatorState::DumpIterationState(const FrameState* frame,
immutable_state_.pending_ids()[node->node_id];
if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
DumpPendingNodeState(node->node_id, iteration->input_tensors, false);
DumpPendingNodeState(immutable_state_, node->node_id,
iteration->input_tensors, false);
}
}
// Then the active nodes.
@@ -256,7 +186,8 @@ void PropagatorState::DumpIterationState(const FrameState* frame,
PendingCounts::Handle pending_id =
immutable_state_.pending_ids()[node->node_id];
if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
DumpActiveNodeState(node->node_id, iteration->input_tensors);
DumpActiveNodeState(immutable_state_, node->node_id,
iteration->input_tensors);
}
}
// Show all input tensors in use.
@@ -279,14 +210,11 @@ void PropagatorState::DumpIterationState(const FrameState* frame,

void PropagatorState::DumpState() {
mutex_lock l(mu_);
if (!dumped_on_error_) {
LOG(WARNING) << "Dumping state";
for (auto& frame : outstanding_frames_) {
LOG(WARNING) << frame.first;
FrameState* frame_state = frame.second;
frame_state->DumpIterationState(this);
}
dumped_on_error_ = true;
LOG(WARNING) << "Dumping state";
for (auto& frame : outstanding_frames_) {
LOG(WARNING) << frame.first;
FrameState* frame_state = frame.second;
frame_state->DumpIterationState(this);
}
}

@@ -378,7 +306,7 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {

for (const EdgeInfo& e : item->output_edges()) {
const NodeItem& dst_item =
*immutable_state_.graph_view().node(e.dst_id);
immutable_state_.graph_view().node_ref(e.dst_id);
const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];

bool dst_dead = true;
@@ -398,7 +326,7 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {

for (const ControlEdgeInfo& e : item->output_control_edges()) {
const NodeItem& dst_item =
*immutable_state_.graph_view().node(e.dst_id);
immutable_state_.graph_view().node_ref(e.dst_id);
const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];

bool dst_dead;
@@ -464,17 +392,17 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
//
// NOTE(mrry): Use a macro here instead of a lambda, because this method is
// performance-critical and we need to ensure that the code is inlined.
#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \
do { \
if (!adjust_result.any_pending) { \
const NodeItem* dst_item = gview.node(dst_id); \
TaggedNode& t = ready->emplace_back(); \
t.node_item = dst_item; \
t.input_frame = this; \
t.input_iter = iter; \
t.is_dead = adjust_result.any_dead; \
iter_state->outstanding_ops++; \
} \
#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \
do { \
if (!adjust_result.any_pending) { \
const NodeItem* dst_item = &gview.node_ref(dst_id); \
TaggedNode& t = ready->emplace_back(); \
t.node_item = dst_item; \
t.input_frame = this; \
t.input_iter = iter; \
t.is_dead = adjust_result.any_dead; \
iter_state->outstanding_ops++; \
} \
} while (0);

Entry* input_tensors = iter_state->input_tensors;
@@ -534,7 +462,7 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,

for (const EdgeInfo& e : item->output_edges()) {
const int dst_id = e.dst_id;
const NodeItem* dst_item = gview.node(dst_id);
const NodeItem* dst_item = &gview.node_ref(dst_id);
const PendingCounts::Handle dst_pending_id =
immutable_state.pending_ids()[dst_id];
const int src_slot = e.output_slot;
@@ -596,7 +524,7 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,

for (const ControlEdgeInfo& e : item->output_control_edges()) {
const int dst_id = e.dst_id;
const NodeItem* dst_item = gview.node(dst_id);
const NodeItem* dst_item = &gview.node_ref(dst_id);
const PendingCounts::Handle dst_pending_id =
immutable_state.pending_ids()[dst_id];
@@ -429,26 +429,15 @@ class PropagatorState {
void CleanupFramesIterations(FrameState* frame, int64 iter,
TaggedNodeSeq* ready);

// Provide debugging output about an outstanding node in the executor.
void DumpPendingNodeState(const int node_id, const Entry* input_vector,
bool show_nodes_with_no_ready_inputs);
void DumpActiveNodeState(const int node_id, const Entry* input_vector);

// Provide debugging output about an outstanding iteration in the executor.
void DumpIterationState(const FrameState* frame, IterationState* iteration);

const Tensor* GetTensorValueForDump(const Entry& input);

const ImmutableExecutorState& immutable_state_;
const int64 step_id_;
const bool vlog_;

mutex mu_;

// A flag that is set on error after the frame state has been
// dumped for diagnostic purposes.
bool dumped_on_error_ TF_GUARDED_BY(mu_) = false;

// The root frame in which the execution of this step is started.
FrameState* root_frame_;
tensorflow/core/common_runtime/simple_propagator_state.cc (new file, 134 lines)
@@ -0,0 +1,134 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/simple_propagator_state.h"

#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

SimplePropagatorState::SimplePropagatorState(
const ImmutableExecutorState& immutable_state, int64 step_id)
: SimplePropagatorState(immutable_state, step_id,
immutable_state.get_root_frame_info()) {}

SimplePropagatorState::SimplePropagatorState(
const ImmutableExecutorState& immutable_state, int64 step_id,
const ImmutableExecutorState::FrameInfo& finfo)
: immutable_state_(immutable_state),
step_id_(step_id),
vlog_(VLOG_IS_ON(1)),
input_tensors_(finfo.total_inputs),
counts_(*finfo.pending_counts),
nodes_(finfo.nodes.get()) {}

SimplePropagatorState::~SimplePropagatorState() {}

void SimplePropagatorState::ActivateRoots(
gtl::ArraySlice<const NodeItem*> roots, TaggedNodeSeq* ready) {
for (const NodeItem* item : roots) {
DCHECK_EQ(item->num_inputs, 0);
ready->emplace_back(item);
}
}

void SimplePropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
EntryVector* outputs,
TaggedNodeSeq* ready) {
profiler::TraceMe activity(
[&]() {
return strings::StrCat(
"ExecutorPropagateOutputs#", "id=", step_id_,
",kernel_name=", tagged_node.node_item->kernel->name_view(),
",num_output_edges=", tagged_node.node_item->num_output_edges,
",num_output_control_edges=",
tagged_node.node_item->num_output_control_edges, "#");
},
profiler::GetTFTraceMeLevel(/*is_expensive=*/false));

// Propagates outputs along out edges, and puts newly ready nodes
// into the ready queue.
DCHECK(ready->empty());

const GraphView& gview = immutable_state_.graph_view();
const NodeItem* item = tagged_node.node_item;

mutex_lock l(mu_);

for (const EdgeInfo& e : item->output_edges()) {
const int dst_id = e.dst_id;
const PendingCounts::Handle dst_pending_id =
immutable_state_.pending_ids()[dst_id];
const int src_slot = e.output_slot;
int num_pending = counts_.decrement_pending(dst_pending_id, 1);
const int dst_loc = e.input_slot;
if (e.is_last) {
input_tensors_[dst_loc] = std::move((*outputs)[src_slot]);
} else {
input_tensors_[dst_loc] = (*outputs)[src_slot];
}
if (num_pending == 0) ready->emplace_back(&gview.node_ref(dst_id));
}

for (const ControlEdgeInfo& e : item->output_control_edges()) {
const int dst_id = e.dst_id;
const PendingCounts::Handle dst_pending_id =
immutable_state_.pending_ids()[dst_id];
int num_pending = counts_.decrement_pending(dst_pending_id, 1);
if (num_pending == 0) ready->emplace_back(&gview.node_ref(dst_id));
}
}

void SimplePropagatorState::DumpState() {
mutex_lock l(mu_);
LOG(WARNING) << "Dumping state";

// Dump any waiting nodes that are holding on to tensors.
for (const NodeItem* node : *nodes_) {
PendingCounts::Handle pending_id =
immutable_state_.pending_ids()[node->node_id];
if (counts_.node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
counts_.node_state(pending_id) == PendingCounts::PENDING_READY) {
DumpPendingNodeState(immutable_state_, node->node_id,
input_tensors_.data(), false);
}
}
// Then the active nodes.
for (const NodeItem* node : *nodes_) {
PendingCounts::Handle pending_id =
immutable_state_.pending_ids()[node->node_id];
if (counts_.node_state(pending_id) == PendingCounts::STARTED) {
DumpActiveNodeState(immutable_state_, node->node_id,
input_tensors_.data());
}
}
// Show all input tensors in use.
size_t total_bytes = 0;
for (size_t i = 0; i < input_tensors_.size(); ++i) {
const Entry& input = input_tensors_[i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor && tensor->IsInitialized()) {
LOG(WARNING) << " Input " << i << ": "
<< strings::StrCat(
"Tensor<type: ", DataTypeString(tensor->dtype()),
" shape: ", tensor->shape().DebugString(),
", bytes: ", tensor->TotalBytes(), ">");
total_bytes += tensor->TotalBytes();
}
}
LOG(WARNING) << " Total bytes " << total_bytes;
}

} // namespace tensorflow
tensorflow/core/common_runtime/simple_propagator_state.h (new file, 186 lines)
@@ -0,0 +1,186 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_

#include <vector>

#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Represents the ephemeral "edge state" associated with one invocation of
// `Executor::Run()`.
//
// NOTE: `SimplePropagatorState` does not support "v1-style" control flow,
// including "dead tensors", "Switch" and "Merge" nodes, and cycles in the
// graph. Use `PropagatorState` for graphs with those features.
// `SimplePropagatorState` *does* support "v2-style" or "functional" control
// flow.
//
// `SimplePropagatorState` is responsible for propagating values along dataflow
// edges in a TensorFlow graph and determining which nodes are runnable. The
// executor primarily updates `SimplePropagatorState` by calling
// `PropagateOutputs()` after processing a node, and `SimplePropagatorState`
// dispatches `TaggedNode`s by adding them to a `TaggedNodeSeq`.
class SimplePropagatorState {
public:
SimplePropagatorState(const ImmutableExecutorState& immutable_state,
int64 step_id);
~SimplePropagatorState();

// A `TaggedNode` corresponds to a single invocation of a node's kernel,
// and it is created when the kernel becomes runnable.
struct TaggedNode {
const NodeItem* node_item;

explicit TaggedNode(const NodeItem* node_item) : node_item(node_item) {}

const NodeItem& get_node_item() const { return *node_item; }

bool get_is_dead() const { return false; }
int64 get_iter_num() const { return 0; }
};

// A drop-in replacement for std::deque<TaggedNode>. We typically don't
// have that many nodes in the ready queue, so we just use a vector and
// don't free up memory from the queue as we consume nodes.
// TODO(mrry): Extract this and share it with the version in
// `PropagatorState`. The correct constants might be different, since
// sizeof(TaggedNode) is smaller in this version.
class TaggedNodeReadyQueue {
public:
TaggedNodeReadyQueue() : front_index_(0) {}

void push_back(const TaggedNode& node) { ready_.push_back(node); }
TaggedNode front() const {
DCHECK_LT(front_index_, ready_.size());
return ready_[front_index_];
}
void pop_front() {
DCHECK_LT(front_index_, ready_.size());
front_index_++;
if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) {
if (front_index_ == ready_.size()) {
ready_.clear();
} else {
// Lots of unused entries at beginning of vector: move everything
// down to start of vector.
ready_.erase(ready_.begin(), ready_.begin() + front_index_);
}
front_index_ = 0;
}
}
bool empty() const { return ready_.empty(); }

private:
// TODO(b/152925936): Re-evaluate these constants with current usage
// patterns.
static constexpr int kSpillThreshold = 16384;
gtl::InlinedVector<TaggedNode, 16> ready_;
int front_index_;
};

// TODO(b/152925936): Re-evaluate this constant with current usage patterns.
typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;

// Creates and adds a `TaggedNode` for each node in `roots` to `*ready`.
void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
TaggedNodeSeq* ready);

// After processing the outputs, propagates the outputs to their dsts.
// Contents of *outputs are left in an indeterminate state after
// returning from this method.
void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
TaggedNodeSeq* ready);

// Returns an array of `Entry` objects corresponding to the inputs of
// `tagged_node`.
//
// NOTE: Thread safety analysis is disabled on this method, because the
// underlying `IterationState` and its array of `input_tensors` retain the
// same address while the iteration is live.
Entry* GetInputTensors(const TaggedNode& tagged_node)
TF_NO_THREAD_SAFETY_ANALYSIS {
return input_tensors_.data() + tagged_node.node_item->input_start;
}

FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
return {0, 0};
}

// Provide debugging output of the state of the executor.
void DumpState();

// For debugging/logging only.
void MaybeMarkStarted(const TaggedNode& tagged_node) {
// TODO(misard) Replace with a finer-grain enabling flag once we add better
// optional debugging support.
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
mutex_lock l(mu_);
counts_.mark_started(
immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
}
}
void MaybeMarkCompleted(const TaggedNode& tagged_node) {
// TODO(misard) Replace with a finer-grain enabling flag once we add better
// optional debugging support.
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
mutex_lock l(mu_);
counts_.mark_completed(
immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
}
}

private:
SimplePropagatorState(const ImmutableExecutorState& immutable_state_,
int64 step_id,
const ImmutableExecutorState::FrameInfo& finfo);

const ImmutableExecutorState& immutable_state_;
const int64 step_id_;
const bool vlog_;

mutex mu_;

// The i-th node's j-th input is stored at
// `input_tensors_[impl_->nodes[i].input_start + j]`.
//
// NOTE: No need to protect input_tensors[i] by any locks because it
// is resized once. Each element of input_tensors is written once by the
// source node of an edge and is cleared by the destination of the same
// edge. The destination node always runs after the source node, so there
// is never concurrent access to the same entry.
std::vector<Entry> input_tensors_;

PendingCounts counts_ TF_GUARDED_BY(mu_);

const std::vector<const NodeItem*>* const nodes_;

TF_DISALLOW_COPY_AND_ASSIGN(SimplePropagatorState);
};

} // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_