Rolling forward "[Executor] Implement SimplePropagatorState
and use it when a graph has no v1-style control flow."
The previous version had undefined behavior when executing a graph comprised solely of operations with no input tensors, in which case the `SimpleExecutorState::input_tensors_` vector would be empty. Using `vector::data()` instead of `operator[]` avoids creating a reference to a null element when the vector is empty. PiperOrigin-RevId: 304510355 Change-Id: I12785a6ac72355d0e5b1a312d41e227dbdf9dc96
This commit is contained in:
parent
9ba3ff2df6
commit
0bde6a68f2
@ -232,6 +232,7 @@ filegroup(
|
|||||||
"process_util.h",
|
"process_util.h",
|
||||||
"inspecting_placer.h",
|
"inspecting_placer.h",
|
||||||
"profile_handler.h",
|
"profile_handler.h",
|
||||||
|
"propagator_debug_utils.h",
|
||||||
"propagator_state.h",
|
"propagator_state.h",
|
||||||
"renamed_device.h",
|
"renamed_device.h",
|
||||||
"rendezvous_mgr.h",
|
"rendezvous_mgr.h",
|
||||||
@ -241,6 +242,7 @@ filegroup(
|
|||||||
"ring_alg.h",
|
"ring_alg.h",
|
||||||
"ring_gatherer.h",
|
"ring_gatherer.h",
|
||||||
"session_factory.h",
|
"session_factory.h",
|
||||||
|
"simple_propagator_state.h",
|
||||||
"single_threaded_cpu_device.h",
|
"single_threaded_cpu_device.h",
|
||||||
"stats_publisher_interface.h",
|
"stats_publisher_interface.h",
|
||||||
"step_stats_collector.h",
|
"step_stats_collector.h",
|
||||||
@ -304,6 +306,7 @@ tf_cuda_library(
|
|||||||
"process_function_library_runtime.cc",
|
"process_function_library_runtime.cc",
|
||||||
"process_state.cc",
|
"process_state.cc",
|
||||||
"process_util.cc",
|
"process_util.cc",
|
||||||
|
"propagator_debug_utils.cc",
|
||||||
"propagator_state.cc",
|
"propagator_state.cc",
|
||||||
"renamed_device.cc",
|
"renamed_device.cc",
|
||||||
"rendezvous_mgr.cc",
|
"rendezvous_mgr.cc",
|
||||||
@ -316,6 +319,7 @@ tf_cuda_library(
|
|||||||
"session_factory.cc",
|
"session_factory.cc",
|
||||||
"session_options.cc",
|
"session_options.cc",
|
||||||
"session_state.cc",
|
"session_state.cc",
|
||||||
|
"simple_propagator_state.cc",
|
||||||
"single_threaded_cpu_device.cc",
|
"single_threaded_cpu_device.cc",
|
||||||
"stats_publisher_interface.cc",
|
"stats_publisher_interface.cc",
|
||||||
"step_stats_collector.cc",
|
"step_stats_collector.cc",
|
||||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/pending_counts.h"
|
#include "tensorflow/core/common_runtime/pending_counts.h"
|
||||||
#include "tensorflow/core/common_runtime/propagator_state.h"
|
#include "tensorflow/core/common_runtime/propagator_state.h"
|
||||||
#include "tensorflow/core/common_runtime/renamed_device.h"
|
#include "tensorflow/core/common_runtime/renamed_device.h"
|
||||||
|
#include "tensorflow/core/common_runtime/simple_propagator_state.h"
|
||||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||||
#include "tensorflow/core/framework/allocator.h"
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
#include "tensorflow/core/framework/cancellation.h"
|
#include "tensorflow/core/framework/cancellation.h"
|
||||||
@ -372,6 +373,10 @@ class ExecutorState {
|
|||||||
|
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
Status status_ TF_GUARDED_BY(mu_);
|
Status status_ TF_GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
// A flag that is set on error after the propagator state has been
|
||||||
|
// dumped for diagnostic purposes.
|
||||||
|
bool dumped_on_error_ TF_GUARDED_BY(mu_) = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class PropagatorStateType>
|
template <class PropagatorStateType>
|
||||||
@ -925,7 +930,11 @@ Status ExecutorState<PropagatorStateType>::ProcessOutputs(
|
|||||||
// add better optional debugging support.
|
// add better optional debugging support.
|
||||||
if (vlog_ && VLOG_IS_ON(1)) {
|
if (vlog_ && VLOG_IS_ON(1)) {
|
||||||
LOG(WARNING) << this << " Compute status: " << s;
|
LOG(WARNING) << this << " Compute status: " << s;
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
if (!dumped_on_error_) {
|
||||||
propagator_.DumpState();
|
propagator_.DumpState();
|
||||||
|
dumped_on_error_ = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (s.code() == error::RESOURCE_EXHAUSTED) {
|
if (s.code() == error::RESOURCE_EXHAUSTED) {
|
||||||
if (stats_collector_) {
|
if (stats_collector_) {
|
||||||
@ -1256,8 +1265,14 @@ void ExecutorState<PropagatorStateType>::Finish() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
|
void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
|
||||||
|
if (immutable_state_.requires_control_flow_support()) {
|
||||||
(new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
|
(new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
|
||||||
->RunAsync(std::move(done));
|
->RunAsync(std::move(done));
|
||||||
|
} else {
|
||||||
|
(new ExecutorState<SimplePropagatorState>(args, immutable_state_,
|
||||||
|
&kernel_stats_))
|
||||||
|
->RunAsync(std::move(done));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -413,6 +413,14 @@ TEST_F(ExecutorTest, RecvInvalidRefDtype) {
|
|||||||
rendez->Unref();
|
rendez->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ExecutorTest, NoInputTensors) {
|
||||||
|
// Create a graph where none of the nodes have input tensors.
|
||||||
|
auto g = absl::make_unique<Graph>(OpRegistry::Global());
|
||||||
|
test::graph::Constant(g.get(), V(1.0));
|
||||||
|
Create(std::move(g));
|
||||||
|
TF_ASSERT_OK(Run(rendez_));
|
||||||
|
}
|
||||||
|
|
||||||
// Create a graph that is 'depth' deep. At each level, fan-in and fan-out a
|
// Create a graph that is 'depth' deep. At each level, fan-in and fan-out a
|
||||||
// maximum of 'width' nodes. All nodes are no-ops and all dependencies are
|
// maximum of 'width' nodes. All nodes are no-ops and all dependencies are
|
||||||
// control dependencies.
|
// control dependencies.
|
||||||
|
@ -211,6 +211,8 @@ class GraphView {
|
|||||||
Status SetAllocAttrs(const Graph* g, const Device* device);
|
Status SetAllocAttrs(const Graph* g, const Device* device);
|
||||||
void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
|
void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
|
||||||
|
|
||||||
|
// Returns a mutable pointer to the `NodeItem` with the given `id` if it
|
||||||
|
// exists in the graph, or `nullptr` if it does not.
|
||||||
NodeItem* node(int32 id) const {
|
NodeItem* node(int32 id) const {
|
||||||
DCHECK_GE(id, 0);
|
DCHECK_GE(id, 0);
|
||||||
DCHECK_LT(id, num_nodes_);
|
DCHECK_LT(id, num_nodes_);
|
||||||
@ -220,6 +222,17 @@ class GraphView {
|
|||||||
: reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
|
: reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the `NodeItem` with the given `id`.
|
||||||
|
//
|
||||||
|
// REQUIRES: `id` must be the ID of a valid node in the graph.
|
||||||
|
const NodeItem& node_ref(int32 id) const {
|
||||||
|
DCHECK_GE(id, 0);
|
||||||
|
DCHECK_LT(id, num_nodes_);
|
||||||
|
uint32 offset = node_offsets_[id];
|
||||||
|
DCHECK_NE(offset, kuint32max);
|
||||||
|
return *reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]);
|
||||||
|
}
|
||||||
|
|
||||||
int32 num_nodes() const { return num_nodes_; }
|
int32 num_nodes() const { return num_nodes_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/core/common_runtime/metrics.h"
|
#include "tensorflow/core/common_runtime/metrics.h"
|
||||||
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/graph/edgeset.h"
|
#include "tensorflow/core/graph/edgeset.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
@ -88,13 +89,33 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) {
|
|||||||
EnsureFrameInfo(it)->nodes =
|
EnsureFrameInfo(it)->nodes =
|
||||||
absl::make_unique<std::vector<const NodeItem*>>();
|
absl::make_unique<std::vector<const NodeItem*>>();
|
||||||
}
|
}
|
||||||
|
root_frame_info_ = frame_info_[""];
|
||||||
|
|
||||||
pending_ids_.resize(gview_.num_nodes());
|
pending_ids_.resize(gview_.num_nodes());
|
||||||
|
|
||||||
// Preprocess every node in the graph to create an instance of op
|
// Preprocess every node in the graph to create an instance of op
|
||||||
// kernel for each node.
|
// kernel for each node.
|
||||||
|
requires_control_flow_ = false;
|
||||||
for (const Node* n : graph.nodes()) {
|
for (const Node* n : graph.nodes()) {
|
||||||
if (IsSink(n)) continue;
|
if (IsSink(n)) continue;
|
||||||
|
if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) {
|
||||||
|
requires_control_flow_ = true;
|
||||||
|
} else if (IsRecv(n)) {
|
||||||
|
// A Recv node from a different device may produce dead tensors from
|
||||||
|
// non-local control-flow nodes.
|
||||||
|
//
|
||||||
|
// TODO(mrry): Track whether control flow was present in the
|
||||||
|
// pre-partitioned graph, and enable the caller (e.g.
|
||||||
|
// `DirectSession`) to relax this constraint.
|
||||||
|
string send_device;
|
||||||
|
string recv_device;
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "send_device", &send_device));
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "recv_device", &recv_device));
|
||||||
|
if (send_device != recv_device) {
|
||||||
|
requires_control_flow_ = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const int id = n->id();
|
const int id = n->id();
|
||||||
const string& frame_name = cf_info.frame_names[id];
|
const string& frame_name = cf_info.frame_names[id];
|
||||||
FrameInfo* frame_info = EnsureFrameInfo(frame_name);
|
FrameInfo* frame_info = EnsureFrameInfo(frame_name);
|
||||||
|
@ -91,6 +91,10 @@ class ImmutableExecutorState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const FrameInfo& get_root_frame_info() const { return *root_frame_info_; }
|
||||||
|
|
||||||
|
bool requires_control_flow_support() const { return requires_control_flow_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct ControlFlowInfo {
|
struct ControlFlowInfo {
|
||||||
gtl::FlatSet<string> unique_frame_names;
|
gtl::FlatSet<string> unique_frame_names;
|
||||||
@ -106,6 +110,7 @@ class ImmutableExecutorState {
|
|||||||
// Owned.
|
// Owned.
|
||||||
LocalExecutorParams params_;
|
LocalExecutorParams params_;
|
||||||
GraphView gview_;
|
GraphView gview_;
|
||||||
|
bool requires_control_flow_;
|
||||||
std::vector<PendingCounts::Handle> pending_ids_;
|
std::vector<PendingCounts::Handle> pending_ids_;
|
||||||
|
|
||||||
// Root nodes (with no in edges) that should form the initial ready queue
|
// Root nodes (with no in edges) that should form the initial ready queue
|
||||||
@ -115,6 +120,7 @@ class ImmutableExecutorState {
|
|||||||
// TODO(yuanbyu): We could cache it along with the graph so to avoid
|
// TODO(yuanbyu): We could cache it along with the graph so to avoid
|
||||||
// the overhead of constructing it for each executor instance.
|
// the overhead of constructing it for each executor instance.
|
||||||
gtl::FlatMap<string, FrameInfo*> frame_info_;
|
gtl::FlatMap<string, FrameInfo*> frame_info_;
|
||||||
|
const FrameInfo* root_frame_info_; // Not owned.
|
||||||
|
|
||||||
// Shallow copies of the constant tensors used in the graph.
|
// Shallow copies of the constant tensors used in the graph.
|
||||||
std::vector<Tensor> const_tensors_;
|
std::vector<Tensor> const_tensors_;
|
||||||
|
95
tensorflow/core/common_runtime/propagator_debug_utils.cc
Normal file
95
tensorflow/core/common_runtime/propagator_debug_utils.cc
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/entry.h"
|
||||||
|
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// 1-D, 0 element tensor.
|
||||||
|
static const Tensor* const kEmptyTensor = new Tensor;
|
||||||
|
|
||||||
|
|
||||||
|
const Tensor* GetTensorValueForDump(const Entry& input) {
|
||||||
|
switch (input.state) {
|
||||||
|
case Entry::State::NO_VALUE:
|
||||||
|
return kEmptyTensor;
|
||||||
|
case Entry::State::HAS_VALUE:
|
||||||
|
return input.val.get();
|
||||||
|
case Entry::State::HAS_CONST_TENSOR:
|
||||||
|
return input.const_tensor;
|
||||||
|
case Entry::State::HAS_REF_TENSOR:
|
||||||
|
return input.ref_tensor.tensor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void DumpPendingNodeState(const ImmutableExecutorState& immutable_state,
|
||||||
|
const int node_id, const Entry* input_vector,
|
||||||
|
const bool show_nodes_with_no_ready_inputs) {
|
||||||
|
const NodeItem& node_item = immutable_state.graph_view().node_ref(node_id);
|
||||||
|
const int input_base = node_item.input_start;
|
||||||
|
if (!show_nodes_with_no_ready_inputs) {
|
||||||
|
bool has_ready_input = false;
|
||||||
|
for (int i = 0; i < node_item.num_inputs; ++i) {
|
||||||
|
const Entry& input = input_vector[input_base + i];
|
||||||
|
const Tensor* tensor = GetTensorValueForDump(input);
|
||||||
|
if (tensor && tensor->IsInitialized()) {
|
||||||
|
has_ready_input = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!has_ready_input) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG(WARNING) << " Pending Node: " << node_item.DebugString();
|
||||||
|
for (int i = 0; i < node_item.num_inputs; ++i) {
|
||||||
|
const Entry& input = input_vector[input_base + i];
|
||||||
|
const Tensor* tensor = GetTensorValueForDump(input);
|
||||||
|
if (tensor->IsInitialized()) {
|
||||||
|
LOG(WARNING) << " Input " << i << ": "
|
||||||
|
<< strings::StrCat(
|
||||||
|
"Tensor<type: ", DataTypeString(tensor->dtype()),
|
||||||
|
" shape: ", tensor->shape().DebugString(), ">");
|
||||||
|
} else {
|
||||||
|
LOG(WARNING) << " Input " << i << ": not present";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void DumpActiveNodeState(const ImmutableExecutorState& immutable_state,
|
||||||
|
const int node_id, const Entry* input_vector) {
|
||||||
|
const NodeItem& node_item = immutable_state.graph_view().node_ref(node_id);
|
||||||
|
LOG(WARNING) << " Active Node: " << node_item.DebugString();
|
||||||
|
const int input_base = node_item.input_start;
|
||||||
|
for (int i = 0; i < node_item.num_inputs; ++i) {
|
||||||
|
const Entry& input = input_vector[input_base + i];
|
||||||
|
const Tensor* tensor = GetTensorValueForDump(input);
|
||||||
|
if (tensor->IsInitialized()) {
|
||||||
|
LOG(WARNING) << " Input " << i << ": "
|
||||||
|
<< strings::StrCat(
|
||||||
|
"Tensor<type: ", DataTypeString(tensor->dtype()),
|
||||||
|
" shape: ", tensor->shape().DebugString(), ">");
|
||||||
|
} else {
|
||||||
|
LOG(WARNING) << " Input " << i << ": not present";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
40
tensorflow/core/common_runtime/propagator_debug_utils.h
Normal file
40
tensorflow/core/common_runtime/propagator_debug_utils.h
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_
|
||||||
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
struct Entry;
|
||||||
|
class ImmutableExecutorState;
|
||||||
|
class Tensor;
|
||||||
|
|
||||||
|
// Returns a pointer to the tensor in `input` if one exists, or `nullptr`.
|
||||||
|
const Tensor* GetTensorValueForDump(const Entry& input);
|
||||||
|
|
||||||
|
// Writes a LOG(WARNING) message describing the state of the pending node
|
||||||
|
// `node_id` in the graph described by `immutable_state`.
|
||||||
|
void DumpPendingNodeState(const ImmutableExecutorState& immutable_state,
|
||||||
|
const int node_id, const Entry* input_vector,
|
||||||
|
const bool show_nodes_with_no_ready_inputs);
|
||||||
|
|
||||||
|
// Writes a LOG(WARNING) message describing the state of the active node
|
||||||
|
// `node_id` in the graph described by `immutable_state`.
|
||||||
|
void DumpActiveNodeState(const ImmutableExecutorState& immutable_state,
|
||||||
|
const int node_id, const Entry* input_vector);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_
|
@ -16,17 +16,12 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/propagator_state.h"
|
#include "tensorflow/core/common_runtime/propagator_state.h"
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/graph_view.h"
|
#include "tensorflow/core/common_runtime/graph_view.h"
|
||||||
|
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// 1-D, 0 element tensor.
|
|
||||||
static const Tensor* const kEmptyTensor = new Tensor;
|
|
||||||
|
|
||||||
typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
|
|
||||||
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
|
|
||||||
|
|
||||||
PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
|
PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
|
||||||
int64 step_id)
|
int64 step_id)
|
||||||
: immutable_state_(immutable_state),
|
: immutable_state_(immutable_state),
|
||||||
@ -57,7 +52,7 @@ void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
|
|||||||
TaggedNodeSeq* ready) {
|
TaggedNodeSeq* ready) {
|
||||||
for (const NodeItem* item : roots) {
|
for (const NodeItem* item : roots) {
|
||||||
DCHECK_EQ(item->num_inputs, 0);
|
DCHECK_EQ(item->num_inputs, 0);
|
||||||
ready->push_back(TaggedNode{item, root_frame_, 0, false});
|
ready->emplace_back(item, root_frame_, 0, false);
|
||||||
}
|
}
|
||||||
mutex_lock l(root_frame_->mu);
|
mutex_lock l(root_frame_->mu);
|
||||||
root_frame_->GetIteration(0)->outstanding_ops = ready->size();
|
root_frame_->GetIteration(0)->outstanding_ops = ready->size();
|
||||||
@ -173,72 +168,6 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const Tensor* PropagatorState::GetTensorValueForDump(const Entry& input) {
|
|
||||||
switch (input.state) {
|
|
||||||
case Entry::State::NO_VALUE:
|
|
||||||
return kEmptyTensor;
|
|
||||||
case Entry::State::HAS_VALUE:
|
|
||||||
return input.val.get();
|
|
||||||
case Entry::State::HAS_CONST_TENSOR:
|
|
||||||
return input.const_tensor;
|
|
||||||
case Entry::State::HAS_REF_TENSOR:
|
|
||||||
return input.ref_tensor.tensor;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void PropagatorState::DumpPendingNodeState(
|
|
||||||
const int node_id, const Entry* input_vector,
|
|
||||||
const bool show_nodes_with_no_ready_inputs) {
|
|
||||||
const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
|
|
||||||
const int input_base = node_item.input_start;
|
|
||||||
if (!show_nodes_with_no_ready_inputs) {
|
|
||||||
bool has_ready_input = false;
|
|
||||||
for (int i = 0; i < node_item.num_inputs; ++i) {
|
|
||||||
const Entry& input = input_vector[input_base + i];
|
|
||||||
const Tensor* tensor = GetTensorValueForDump(input);
|
|
||||||
if (tensor->IsInitialized()) {
|
|
||||||
has_ready_input = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!has_ready_input) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
LOG(WARNING) << " Pending Node: " << node_item.DebugString();
|
|
||||||
for (int i = 0; i < node_item.num_inputs; ++i) {
|
|
||||||
const Entry& input = input_vector[input_base + i];
|
|
||||||
const Tensor* tensor = GetTensorValueForDump(input);
|
|
||||||
if (tensor->IsInitialized()) {
|
|
||||||
LOG(WARNING) << " Input " << i << ": "
|
|
||||||
<< strings::StrCat(
|
|
||||||
"Tensor<type: ", DataTypeString(tensor->dtype()),
|
|
||||||
" shape: ", tensor->shape().DebugString(), ">");
|
|
||||||
} else {
|
|
||||||
LOG(WARNING) << " Input " << i << ": not present";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void PropagatorState::DumpActiveNodeState(const int node_id,
|
|
||||||
const Entry* input_vector) {
|
|
||||||
const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
|
|
||||||
LOG(WARNING) << " Active Node: " << node_item.DebugString();
|
|
||||||
const int input_base = node_item.input_start;
|
|
||||||
for (int i = 0; i < node_item.num_inputs; ++i) {
|
|
||||||
const Entry& input = input_vector[input_base + i];
|
|
||||||
const Tensor* tensor = GetTensorValueForDump(input);
|
|
||||||
if (tensor->IsInitialized()) {
|
|
||||||
LOG(WARNING) << " Input " << i << ": "
|
|
||||||
<< strings::StrCat(
|
|
||||||
"Tensor<type: ", DataTypeString(tensor->dtype()),
|
|
||||||
" shape: ", tensor->shape().DebugString(), ">");
|
|
||||||
} else {
|
|
||||||
LOG(WARNING) << " Input " << i << ": not present";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void PropagatorState::DumpIterationState(const FrameState* frame,
|
void PropagatorState::DumpIterationState(const FrameState* frame,
|
||||||
IterationState* iteration) {
|
IterationState* iteration) {
|
||||||
const std::vector<const NodeItem*>* nodes = frame->nodes;
|
const std::vector<const NodeItem*>* nodes = frame->nodes;
|
||||||
@ -248,7 +177,8 @@ void PropagatorState::DumpIterationState(const FrameState* frame,
|
|||||||
immutable_state_.pending_ids()[node->node_id];
|
immutable_state_.pending_ids()[node->node_id];
|
||||||
if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
|
if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
|
||||||
iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
|
iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
|
||||||
DumpPendingNodeState(node->node_id, iteration->input_tensors, false);
|
DumpPendingNodeState(immutable_state_, node->node_id,
|
||||||
|
iteration->input_tensors, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Then the active nodes.
|
// Then the active nodes.
|
||||||
@ -256,7 +186,8 @@ void PropagatorState::DumpIterationState(const FrameState* frame,
|
|||||||
PendingCounts::Handle pending_id =
|
PendingCounts::Handle pending_id =
|
||||||
immutable_state_.pending_ids()[node->node_id];
|
immutable_state_.pending_ids()[node->node_id];
|
||||||
if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
|
if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
|
||||||
DumpActiveNodeState(node->node_id, iteration->input_tensors);
|
DumpActiveNodeState(immutable_state_, node->node_id,
|
||||||
|
iteration->input_tensors);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Show all input tensors in use.
|
// Show all input tensors in use.
|
||||||
@ -279,15 +210,12 @@ void PropagatorState::DumpIterationState(const FrameState* frame,
|
|||||||
|
|
||||||
void PropagatorState::DumpState() {
|
void PropagatorState::DumpState() {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
if (!dumped_on_error_) {
|
|
||||||
LOG(WARNING) << "Dumping state";
|
LOG(WARNING) << "Dumping state";
|
||||||
for (auto& frame : outstanding_frames_) {
|
for (auto& frame : outstanding_frames_) {
|
||||||
LOG(WARNING) << frame.first;
|
LOG(WARNING) << frame.first;
|
||||||
FrameState* frame_state = frame.second;
|
FrameState* frame_state = frame.second;
|
||||||
frame_state->DumpIterationState(this);
|
frame_state->DumpIterationState(this);
|
||||||
}
|
}
|
||||||
dumped_on_error_ = true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
|
void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
|
||||||
@ -378,7 +306,7 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
|
|||||||
|
|
||||||
for (const EdgeInfo& e : item->output_edges()) {
|
for (const EdgeInfo& e : item->output_edges()) {
|
||||||
const NodeItem& dst_item =
|
const NodeItem& dst_item =
|
||||||
*immutable_state_.graph_view().node(e.dst_id);
|
immutable_state_.graph_view().node_ref(e.dst_id);
|
||||||
const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
|
const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
|
||||||
|
|
||||||
bool dst_dead = true;
|
bool dst_dead = true;
|
||||||
@ -398,7 +326,7 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
|
|||||||
|
|
||||||
for (const ControlEdgeInfo& e : item->output_control_edges()) {
|
for (const ControlEdgeInfo& e : item->output_control_edges()) {
|
||||||
const NodeItem& dst_item =
|
const NodeItem& dst_item =
|
||||||
*immutable_state_.graph_view().node(e.dst_id);
|
immutable_state_.graph_view().node_ref(e.dst_id);
|
||||||
const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
|
const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
|
||||||
|
|
||||||
bool dst_dead;
|
bool dst_dead;
|
||||||
@ -467,7 +395,7 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
|
|||||||
#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \
|
#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \
|
||||||
do { \
|
do { \
|
||||||
if (!adjust_result.any_pending) { \
|
if (!adjust_result.any_pending) { \
|
||||||
const NodeItem* dst_item = gview.node(dst_id); \
|
const NodeItem* dst_item = &gview.node_ref(dst_id); \
|
||||||
TaggedNode& t = ready->emplace_back(); \
|
TaggedNode& t = ready->emplace_back(); \
|
||||||
t.node_item = dst_item; \
|
t.node_item = dst_item; \
|
||||||
t.input_frame = this; \
|
t.input_frame = this; \
|
||||||
@ -534,7 +462,7 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
|
|||||||
|
|
||||||
for (const EdgeInfo& e : item->output_edges()) {
|
for (const EdgeInfo& e : item->output_edges()) {
|
||||||
const int dst_id = e.dst_id;
|
const int dst_id = e.dst_id;
|
||||||
const NodeItem* dst_item = gview.node(dst_id);
|
const NodeItem* dst_item = &gview.node_ref(dst_id);
|
||||||
const PendingCounts::Handle dst_pending_id =
|
const PendingCounts::Handle dst_pending_id =
|
||||||
immutable_state.pending_ids()[dst_id];
|
immutable_state.pending_ids()[dst_id];
|
||||||
const int src_slot = e.output_slot;
|
const int src_slot = e.output_slot;
|
||||||
@ -596,7 +524,7 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
|
|||||||
|
|
||||||
for (const ControlEdgeInfo& e : item->output_control_edges()) {
|
for (const ControlEdgeInfo& e : item->output_control_edges()) {
|
||||||
const int dst_id = e.dst_id;
|
const int dst_id = e.dst_id;
|
||||||
const NodeItem* dst_item = gview.node(dst_id);
|
const NodeItem* dst_item = &gview.node_ref(dst_id);
|
||||||
const PendingCounts::Handle dst_pending_id =
|
const PendingCounts::Handle dst_pending_id =
|
||||||
immutable_state.pending_ids()[dst_id];
|
immutable_state.pending_ids()[dst_id];
|
||||||
|
|
||||||
|
@ -429,26 +429,15 @@ class PropagatorState {
|
|||||||
void CleanupFramesIterations(FrameState* frame, int64 iter,
|
void CleanupFramesIterations(FrameState* frame, int64 iter,
|
||||||
TaggedNodeSeq* ready);
|
TaggedNodeSeq* ready);
|
||||||
|
|
||||||
// Provide debugging output about an outstanding node in the executor.
|
|
||||||
void DumpPendingNodeState(const int node_id, const Entry* input_vector,
|
|
||||||
bool show_nodes_with_no_ready_inputs);
|
|
||||||
void DumpActiveNodeState(const int node_id, const Entry* input_vector);
|
|
||||||
|
|
||||||
// Provide debugging output about an outstanding iteration in the executor.
|
// Provide debugging output about an outstanding iteration in the executor.
|
||||||
void DumpIterationState(const FrameState* frame, IterationState* iteration);
|
void DumpIterationState(const FrameState* frame, IterationState* iteration);
|
||||||
|
|
||||||
const Tensor* GetTensorValueForDump(const Entry& input);
|
|
||||||
|
|
||||||
const ImmutableExecutorState& immutable_state_;
|
const ImmutableExecutorState& immutable_state_;
|
||||||
const int64 step_id_;
|
const int64 step_id_;
|
||||||
const bool vlog_;
|
const bool vlog_;
|
||||||
|
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
|
|
||||||
// A flag that is set on error after the frame state has been
|
|
||||||
// dumped for diagnostic purposes.
|
|
||||||
bool dumped_on_error_ TF_GUARDED_BY(mu_) = false;
|
|
||||||
|
|
||||||
// The root frame in which the execution of this step is started.
|
// The root frame in which the execution of this step is started.
|
||||||
FrameState* root_frame_;
|
FrameState* root_frame_;
|
||||||
|
|
||||||
|
134
tensorflow/core/common_runtime/simple_propagator_state.cc
Normal file
134
tensorflow/core/common_runtime/simple_propagator_state.cc
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/common_runtime/simple_propagator_state.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
|
||||||
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Delegates to the private constructor using the root frame's info.
// `SimplePropagatorState` only supports graphs without v1-style control
// flow, so the root frame is the only frame in the step.
SimplePropagatorState::SimplePropagatorState(
    const ImmutableExecutorState& immutable_state, int64 step_id)
    : SimplePropagatorState(immutable_state, step_id,
                            immutable_state.get_root_frame_info()) {}

// Initializes the per-step propagator state.
//
// NOTE: `input_tensors_` is sized to `finfo.total_inputs`, which may be
// zero when the graph consists solely of ops with no input tensors.
// Accessors must therefore use `input_tensors_.data()` rather than
// `&input_tensors_[0]`, which would be undefined behavior on an empty
// vector.
SimplePropagatorState::SimplePropagatorState(
    const ImmutableExecutorState& immutable_state, int64 step_id,
    const ImmutableExecutorState::FrameInfo& finfo)
    : immutable_state_(immutable_state),
      step_id_(step_id),
      vlog_(VLOG_IS_ON(1)),
      input_tensors_(finfo.total_inputs),
      counts_(*finfo.pending_counts),
      nodes_(finfo.nodes.get()) {}

SimplePropagatorState::~SimplePropagatorState() {}
|
||||||
|
|
||||||
|
// Seeds `*ready` with a `TaggedNode` for every root of the graph. In a
// graph without v1-style control flow, a root has no inputs and is
// therefore immediately runnable.
void SimplePropagatorState::ActivateRoots(
    gtl::ArraySlice<const NodeItem*> roots, TaggedNodeSeq* ready) {
  for (size_t i = 0; i < roots.size(); ++i) {
    const NodeItem* root = roots[i];
    // A root node must not have any pending input tensors.
    DCHECK_EQ(root->num_inputs, 0);
    ready->emplace_back(root);
  }
}
|
||||||
|
|
||||||
|
// Propagates the `outputs` of `tagged_node` to the inputs of its
// destination nodes, decrementing each destination's pending count and
// appending any node whose count reaches zero to `*ready`. Contents of
// `*outputs` are left in an indeterminate state after returning.
void SimplePropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
                                             EntryVector* outputs,
                                             TaggedNodeSeq* ready) {
  profiler::TraceMe activity(
      [&]() {
        return strings::StrCat(
            "ExecutorPropagateOutputs#", "id=", step_id_,
            ",kernel_name=", tagged_node.node_item->kernel->name_view(),
            ",num_output_edges=", tagged_node.node_item->num_output_edges,
            ",num_output_control_edges=",
            tagged_node.node_item->num_output_control_edges, "#");
      },
      profiler::GetTFTraceMeLevel(/*is_expensive=*/false));

  // Propagates outputs along out edges, and puts newly ready nodes
  // into the ready queue.
  DCHECK(ready->empty());

  const GraphView& gview = immutable_state_.graph_view();
  const NodeItem* item = tagged_node.node_item;

  // `counts_` is guarded by `mu_`; hold the lock for all pending-count
  // updates below.
  mutex_lock l(mu_);

  // Data edges: store the produced entry into the flat `input_tensors_`
  // array at the destination's input slot. The last consumer of an output
  // slot may move instead of copy.
  for (const EdgeInfo& e : item->output_edges()) {
    const int dst_id = e.dst_id;
    const PendingCounts::Handle dst_pending_id =
        immutable_state_.pending_ids()[dst_id];
    const int src_slot = e.output_slot;
    int num_pending = counts_.decrement_pending(dst_pending_id, 1);
    const int dst_loc = e.input_slot;
    if (e.is_last) {
      input_tensors_[dst_loc] = std::move((*outputs)[src_slot]);
    } else {
      input_tensors_[dst_loc] = (*outputs)[src_slot];
    }
    if (num_pending == 0) ready->emplace_back(&gview.node_ref(dst_id));
  }

  // Control edges: no tensor to propagate, only a pending-count decrement.
  for (const ControlEdgeInfo& e : item->output_control_edges()) {
    const int dst_id = e.dst_id;
    const PendingCounts::Handle dst_pending_id =
        immutable_state_.pending_ids()[dst_id];
    int num_pending = counts_.decrement_pending(dst_pending_id, 1);
    if (num_pending == 0) ready->emplace_back(&gview.node_ref(dst_id));
  }
}
|
||||||
|
|
||||||
|
// Logs (at WARNING) a summary of the executor's state for debugging:
// nodes that are still pending, nodes that have started but not completed,
// and all initialized input tensors currently held, with their total size.
//
// NOTE(review): start/completion state in `counts_` appears to be
// maintained only when VLOG(1) is enabled (see `MaybeMarkStarted` /
// `MaybeMarkCompleted` in the header), so the "active nodes" section may
// be empty otherwise — confirm against the executor's usage.
void SimplePropagatorState::DumpState() {
  mutex_lock l(mu_);
  LOG(WARNING) << "Dumping state";

  // Dump any waiting nodes that are holding on to tensors.
  for (const NodeItem* node : *nodes_) {
    PendingCounts::Handle pending_id =
        immutable_state_.pending_ids()[node->node_id];
    if (counts_.node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
        counts_.node_state(pending_id) == PendingCounts::PENDING_READY) {
      // `data()` is well-defined even when `input_tensors_` is empty,
      // unlike `&input_tensors_[0]`.
      DumpPendingNodeState(immutable_state_, node->node_id,
                           input_tensors_.data(), false);
    }
  }
  // Then the active nodes.
  for (const NodeItem* node : *nodes_) {
    PendingCounts::Handle pending_id =
        immutable_state_.pending_ids()[node->node_id];
    if (counts_.node_state(pending_id) == PendingCounts::STARTED) {
      DumpActiveNodeState(immutable_state_, node->node_id,
                          input_tensors_.data());
    }
  }
  // Show all input tensors in use.
  size_t total_bytes = 0;
  for (size_t i = 0; i < input_tensors_.size(); ++i) {
    const Entry& input = input_tensors_[i];
    const Tensor* tensor = GetTensorValueForDump(input);
    if (tensor && tensor->IsInitialized()) {
      LOG(WARNING) << " Input " << i << ": "
                   << strings::StrCat(
                          "Tensor<type: ", DataTypeString(tensor->dtype()),
                          " shape: ", tensor->shape().DebugString(),
                          ", bytes: ", tensor->TotalBytes(), ">");
      total_bytes += tensor->TotalBytes();
    }
  }
  LOG(WARNING) << " Total bytes " << total_bytes;
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
186
tensorflow/core/common_runtime/simple_propagator_state.h
Normal file
186
tensorflow/core/common_runtime/simple_propagator_state.h
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
|
||||||
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/entry.h"
|
||||||
|
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
|
||||||
|
#include "tensorflow/core/common_runtime/pending_counts.h"
|
||||||
|
#include "tensorflow/core/framework/control_flow.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Represents the ephemeral "edge state" associated with one invocation of
// `Executor::Run()`.
//
// NOTE: `SimplePropagatorState` does not support "v1-style" control flow,
// including "dead tensors", "Switch" and "Merge" nodes, and cycles in the
// graph. Use `PropagatorState` for graphs with those features.
// `SimplePropagatorState` *does* support "v2-style" or "functional" control
// flow.
//
// `SimplePropagatorState` is responsible for propagating values along dataflow
// edges in a TensorFlow graph and determining which nodes are runnable. The
// executor primarily updates `SimplePropagatorState` by calling
// `PropagateOutputs()` after processing a node, and `SimplePropagatorState`
// dispatches `TaggedNode`s by adding them to a `TaggedNodeSeq`.
class SimplePropagatorState {
 public:
  SimplePropagatorState(const ImmutableExecutorState& immutable_state,
                        int64 step_id);
  ~SimplePropagatorState();

  // A `TaggedNode` corresponds to a single invocation of a node's kernel,
  // and it is created when the kernel becomes runnable.
  //
  // Since this class does not support v1-style control flow, a node is
  // never "dead" and there is only a single (implicit) iteration, so
  // `get_is_dead()` and `get_iter_num()` return constants.
  struct TaggedNode {
    const NodeItem* node_item;

    explicit TaggedNode(const NodeItem* node_item) : node_item(node_item) {}

    const NodeItem& get_node_item() const { return *node_item; }

    bool get_is_dead() const { return false; }
    int64 get_iter_num() const { return 0; }
  };

  // A drop-in replacement for std::deque<TaggedNode>. We typically don't
  // have that many nodes in the ready queue, so we just use a vector and
  // don't free up memory from the queue as we consume nodes.
  // TODO(mrry): Extract this and share it with the version in
  // `PropagatorState`. The correct constants might be different, since
  // sizeof(TaggedNode) is smaller in this version.
  class TaggedNodeReadyQueue {
   public:
    TaggedNodeReadyQueue() : front_index_(0) {}

    void push_back(const TaggedNode& node) { ready_.push_back(node); }
    TaggedNode front() const {
      DCHECK_LT(front_index_, ready_.size());
      return ready_[front_index_];
    }
    // Logically removes the front element by advancing `front_index_`.
    // Storage is only reclaimed when the queue becomes empty, or when the
    // number of consumed entries exceeds `kSpillThreshold` (in which case
    // the live tail is compacted to the front of the vector).
    void pop_front() {
      DCHECK_LT(front_index_, ready_.size());
      front_index_++;
      if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) {
        if (front_index_ == ready_.size()) {
          ready_.clear();
        } else {
          // Lots of unused entries at beginning of vector: move everything
          // down to start of vector.
          ready_.erase(ready_.begin(), ready_.begin() + front_index_);
        }
        front_index_ = 0;
      }
    }
    bool empty() const { return ready_.empty(); }

   private:
    // TODO(b/152925936): Re-evaluate these constants with current usage
    // patterns.
    static constexpr int kSpillThreshold = 16384;
    gtl::InlinedVector<TaggedNode, 16> ready_;
    // Index of the logical front element within `ready_`.
    int front_index_;
  };

  // TODO(b/152925936): Re-evaluate this constant with current usage patterns.
  typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;

  // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`.
  void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
                     TaggedNodeSeq* ready);

  // After processing the outputs, propagates the outputs to their dsts.
  // Contents of *outputs are left in an indeterminate state after
  // returning from this method.
  void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
                        TaggedNodeSeq* ready);

  // Returns an array of `Entry` objects corresponding to the inputs of
  // `tagged_node`.
  //
  // NOTE: Thread safety analysis is disabled on this method because
  // `input_tensors_` retains the same address for the lifetime of the step
  // (it is sized once at construction; see the comment on `input_tensors_`).
  //
  // Uses `vector::data()` rather than `&input_tensors_[0]` so the result is
  // well-defined even when the vector is empty (i.e. a graph comprised
  // solely of ops with no input tensors).
  Entry* GetInputTensors(const TaggedNode& tagged_node)
      TF_NO_THREAD_SAFETY_ANALYSIS {
    return input_tensors_.data() + tagged_node.node_item->input_start;
  }

  // There is only the root frame and a single iteration, so the result is
  // a constant {frame_id=0, iter=0}.
  FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
    return {0, 0};
  }

  // Provide debugging output of the state of the executor.
  void DumpState();

  // For debugging/logging only. Records in `counts_` that the node's kernel
  // has started, so `DumpState()` can distinguish started nodes. Only does
  // the bookkeeping when VLOG level 1 is enabled.
  void MaybeMarkStarted(const TaggedNode& tagged_node) {
    // TODO(misard) Replace with a finer-grain enabling flag once we add better
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(mu_);
      counts_.mark_started(
          immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
    }
  }
  // For debugging/logging only. Counterpart of `MaybeMarkStarted()`:
  // records that the node's kernel has completed.
  void MaybeMarkCompleted(const TaggedNode& tagged_node) {
    // TODO(misard) Replace with a finer-grain enabling flag once we add better
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(mu_);
      counts_.mark_completed(
          immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
    }
  }

 private:
  SimplePropagatorState(const ImmutableExecutorState& immutable_state_,
                        int64 step_id,
                        const ImmutableExecutorState::FrameInfo& finfo);

  const ImmutableExecutorState& immutable_state_;
  const int64 step_id_;
  // Cached result of VLOG_IS_ON(1) at construction time.
  const bool vlog_;

  // Guards `counts_`.
  mutex mu_;

  // The i-th node's j-th input is stored at
  // `input_tensors_[impl_->nodes[i].input_start + j]`.
  //
  // NOTE: No need to protect input_tensors[i] by any locks because it
  // is resized once. Each element of input_tensors is written once by the
  // source node of an edge and is cleared by the destination of the same
  // edge. The destination node always runs after the source node, so there
  // is never concurrent access to the same entry.
  //
  // NOTE: May be empty when no node in the graph has input tensors; always
  // access the underlying array via `data()`, never `&input_tensors_[0]`.
  std::vector<Entry> input_tensors_;

  // Pending-input counts for each node, indexed by the handles in
  // `immutable_state_.pending_ids()`.
  PendingCounts counts_ TF_GUARDED_BY(mu_);

  // Unowned list of all nodes in the graph; iterated by `DumpState()`.
  const std::vector<const NodeItem*>* const nodes_;

  TF_DISALLOW_COPY_AND_ASSIGN(SimplePropagatorState);
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
|
Loading…
x
Reference in New Issue
Block a user