From 0bde6a68f2e3ab93a266c905465f87a1c6b3b0c0 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 2 Apr 2020 17:22:07 -0700 Subject: [PATCH] Rolling forward "[Executor] Implement `SimplePropagatorState` and use it when a graph has no v1-style control flow." The previous version had undefined behavior when executing a graph comprised solely of operations with no input tensors, in which case the `SimpleExecutorState::input_tensors_` vector would be empty. Using `vector::data()` instead of `operator[]` avoids creating a reference to a null element when the vector is empty. PiperOrigin-RevId: 304510355 Change-Id: I12785a6ac72355d0e5b1a312d41e227dbdf9dc96 --- tensorflow/core/common_runtime/BUILD | 4 + tensorflow/core/common_runtime/executor.cc | 21 +- .../core/common_runtime/executor_test.cc | 8 + tensorflow/core/common_runtime/graph_view.h | 13 ++ .../immutable_executor_state.cc | 21 ++ .../common_runtime/immutable_executor_state.h | 6 + .../common_runtime/propagator_debug_utils.cc | 95 +++++++++ .../common_runtime/propagator_debug_utils.h | 40 ++++ .../core/common_runtime/propagator_state.cc | 124 +++--------- .../core/common_runtime/propagator_state.h | 11 -- .../common_runtime/simple_propagator_state.cc | 134 +++++++++++++ .../common_runtime/simple_propagator_state.h | 186 ++++++++++++++++++ 12 files changed, 551 insertions(+), 112 deletions(-) create mode 100644 tensorflow/core/common_runtime/propagator_debug_utils.cc create mode 100644 tensorflow/core/common_runtime/propagator_debug_utils.h create mode 100644 tensorflow/core/common_runtime/simple_propagator_state.cc create mode 100644 tensorflow/core/common_runtime/simple_propagator_state.h diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 45f4f2fee6e..93804bc8889 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -232,6 +232,7 @@ filegroup( "process_util.h", "inspecting_placer.h", "profile_handler.h", + "propagator_debug_utils.h", "propagator_state.h", "renamed_device.h", "rendezvous_mgr.h", @@ -241,6 +242,7 @@ filegroup( "ring_alg.h", "ring_gatherer.h", "session_factory.h", + "simple_propagator_state.h", "single_threaded_cpu_device.h", "stats_publisher_interface.h", "step_stats_collector.h", @@ -304,6 +306,7 @@ tf_cuda_library( "process_function_library_runtime.cc", "process_state.cc", "process_util.cc", + "propagator_debug_utils.cc", "propagator_state.cc", "renamed_device.cc", "rendezvous_mgr.cc", @@ -316,6 +319,7 @@ tf_cuda_library( "session_factory.cc", "session_options.cc", "session_state.cc", + "simple_propagator_state.cc", "single_threaded_cpu_device.cc", "stats_publisher_interface.cc", "step_stats_collector.cc", diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index d2cb7961454..84bf02eae02 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/pending_counts.h" #include "tensorflow/core/common_runtime/propagator_state.h" #include "tensorflow/core/common_runtime/renamed_device.h" +#include "tensorflow/core/common_runtime/simple_propagator_state.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/cancellation.h" @@ -372,6 +373,10 @@ class ExecutorState { mutex mu_; Status status_ TF_GUARDED_BY(mu_); + + // A flag that is set on error after the propagator state has been + // dumped for diagnostic purposes. + bool dumped_on_error_ TF_GUARDED_BY(mu_) = false; }; template @@ -925,7 +930,11 @@ Status ExecutorState::ProcessOutputs( // add better optional debugging support. if (vlog_ && VLOG_IS_ON(1)) { LOG(WARNING) << this << " Compute status: " << s; - propagator_.DumpState(); + mutex_lock l(mu_); + if (!dumped_on_error_) { + propagator_.DumpState(); + dumped_on_error_ = true; + } } if (s.code() == error::RESOURCE_EXHAUSTED) { if (stats_collector_) { @@ -1256,8 +1265,14 @@ void ExecutorState::Finish() { } void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { - (new ExecutorState(args, immutable_state_, &kernel_stats_)) - ->RunAsync(std::move(done)); + if (immutable_state_.requires_control_flow_support()) { + (new ExecutorState(args, immutable_state_, &kernel_stats_)) + ->RunAsync(std::move(done)); + } else { + (new ExecutorState(args, immutable_state_, + &kernel_stats_)) + ->RunAsync(std::move(done)); + } } } // namespace diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc index 74febf43287..fe62a8459f1 100644 --- a/tensorflow/core/common_runtime/executor_test.cc +++ b/tensorflow/core/common_runtime/executor_test.cc @@ -413,6 +413,14 @@ TEST_F(ExecutorTest, RecvInvalidRefDtype) { rendez->Unref(); } +TEST_F(ExecutorTest, NoInputTensors) { + // Create a graph where none of the nodes have input tensors. + auto g = absl::make_unique(OpRegistry::Global()); + test::graph::Constant(g.get(), V(1.0)); + Create(std::move(g)); + TF_ASSERT_OK(Run(rendez_)); +} + // Create a graph that is 'depth' deep. At each level, fan-in and fan-out a // maximum of 'width' nodes. All nodes are no-ops and all dependencies are // control dependencies. diff --git a/tensorflow/core/common_runtime/graph_view.h b/tensorflow/core/common_runtime/graph_view.h index b0bc0f4b6de..6d31555ed9a 100644 --- a/tensorflow/core/common_runtime/graph_view.h +++ b/tensorflow/core/common_runtime/graph_view.h @@ -211,6 +211,8 @@ class GraphView { Status SetAllocAttrs(const Graph* g, const Device* device); void SetScopedAllocatorAttrs(const std::vector& sa_nodes); + // Returns a mutable pointer to the `NodeItem` with the given `id` if it + // exists in the graph, or `nullptr` if it does not. NodeItem* node(int32 id) const { DCHECK_GE(id, 0); DCHECK_LT(id, num_nodes_); @@ -220,6 +222,17 @@ class GraphView { : reinterpret_cast(space_ + node_offsets_[id])); } + // Returns the `NodeItem` with the given `id`. + // + // REQUIRES: `id` must be the ID of a valid node in the graph. + const NodeItem& node_ref(int32 id) const { + DCHECK_GE(id, 0); + DCHECK_LT(id, num_nodes_); + uint32 offset = node_offsets_[id]; + DCHECK_NE(offset, kuint32max); + return *reinterpret_cast(space_ + node_offsets_[id]); + } + int32 num_nodes() const { return num_nodes_; } private: diff --git a/tensorflow/core/common_runtime/immutable_executor_state.cc b/tensorflow/core/common_runtime/immutable_executor_state.cc index 97c17aa287d..8ec84c73a59 100644 --- a/tensorflow/core/common_runtime/immutable_executor_state.cc +++ b/tensorflow/core/common_runtime/immutable_executor_state.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/core/common_runtime/metrics.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/edgeset.h" #include "tensorflow/core/graph/graph.h" @@ -88,13 +89,33 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) { EnsureFrameInfo(it)->nodes = absl::make_unique>(); } + root_frame_info_ = frame_info_[""]; pending_ids_.resize(gview_.num_nodes()); // Preprocess every node in the graph to create an instance of op // kernel for each node. + requires_control_flow_ = false; for (const Node* n : graph.nodes()) { if (IsSink(n)) continue; + if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) { + requires_control_flow_ = true; + } else if (IsRecv(n)) { + // A Recv node from a different device may produce dead tensors from + // non-local control-flow nodes. + // + // TODO(mrry): Track whether control flow was present in the + // pre-partitioned graph, and enable the caller (e.g. + // `DirectSession`) to relax this constraint. + string send_device; + string recv_device; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "send_device", &send_device)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "recv_device", &recv_device)); + if (send_device != recv_device) { + requires_control_flow_ = true; + } + } + const int id = n->id(); const string& frame_name = cf_info.frame_names[id]; FrameInfo* frame_info = EnsureFrameInfo(frame_name); diff --git a/tensorflow/core/common_runtime/immutable_executor_state.h b/tensorflow/core/common_runtime/immutable_executor_state.h index c9c23e55a21..a03d795166e 100644 --- a/tensorflow/core/common_runtime/immutable_executor_state.h +++ b/tensorflow/core/common_runtime/immutable_executor_state.h @@ -91,6 +91,10 @@ class ImmutableExecutorState { } } + const FrameInfo& get_root_frame_info() const { return *root_frame_info_; } + + bool requires_control_flow_support() const { return requires_control_flow_; } + private: struct ControlFlowInfo { gtl::FlatSet unique_frame_names; @@ -106,6 +110,7 @@ class ImmutableExecutorState { // Owned. LocalExecutorParams params_; GraphView gview_; + bool requires_control_flow_; std::vector pending_ids_; // Root nodes (with no in edges) that should form the initial ready queue @@ -115,6 +120,7 @@ class ImmutableExecutorState { // TODO(yuanbyu): We could cache it along with the graph so to avoid // the overhead of constructing it for each executor instance. gtl::FlatMap frame_info_; + const FrameInfo* root_frame_info_; // Not owned. // Shallow copies of the constant tensors used in the graph. std::vector const_tensors_; diff --git a/tensorflow/core/common_runtime/propagator_debug_utils.cc b/tensorflow/core/common_runtime/propagator_debug_utils.cc new file mode 100644 index 00000000000..27f9da7ea52 --- /dev/null +++ b/tensorflow/core/common_runtime/propagator_debug_utils.cc @@ -0,0 +1,95 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/propagator_debug_utils.h" + +#include + +#include "tensorflow/core/common_runtime/entry.h" +#include "tensorflow/core/common_runtime/immutable_executor_state.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// 1-D, 0 element tensor. +static const Tensor* const kEmptyTensor = new Tensor; + + +const Tensor* GetTensorValueForDump(const Entry& input) { + switch (input.state) { + case Entry::State::NO_VALUE: + return kEmptyTensor; + case Entry::State::HAS_VALUE: + return input.val.get(); + case Entry::State::HAS_CONST_TENSOR: + return input.const_tensor; + case Entry::State::HAS_REF_TENSOR: + return input.ref_tensor.tensor; + } +} + +void DumpPendingNodeState(const ImmutableExecutorState& immutable_state, + const int node_id, const Entry* input_vector, + const bool show_nodes_with_no_ready_inputs) { + const NodeItem& node_item = immutable_state.graph_view().node_ref(node_id); + const int input_base = node_item.input_start; + if (!show_nodes_with_no_ready_inputs) { + bool has_ready_input = false; + for (int i = 0; i < node_item.num_inputs; ++i) { + const Entry& input = input_vector[input_base + i]; + const Tensor* tensor = GetTensorValueForDump(input); + if (tensor && tensor->IsInitialized()) { + has_ready_input = true; + break; + } + } + if (!has_ready_input) { + return; + } + } + LOG(WARNING) << " Pending Node: " << node_item.DebugString(); + for (int i = 0; i < node_item.num_inputs; ++i) { + const Entry& input = input_vector[input_base + i]; + const Tensor* tensor = GetTensorValueForDump(input); + if (tensor->IsInitialized()) { + LOG(WARNING) << " Input " << i << ": " + << strings::StrCat( + "Tensordtype()), + " shape: ", tensor->shape().DebugString(), ">"); + } else { + LOG(WARNING) << " Input " << i << ": not present"; + } + } +} + +void DumpActiveNodeState(const ImmutableExecutorState& immutable_state, + const int node_id, const Entry* input_vector) { + const NodeItem& node_item = immutable_state.graph_view().node_ref(node_id); + LOG(WARNING) << " Active Node: " << node_item.DebugString(); + const int input_base = node_item.input_start; + for (int i = 0; i < node_item.num_inputs; ++i) { + const Entry& input = input_vector[input_base + i]; + const Tensor* tensor = GetTensorValueForDump(input); + if (tensor->IsInitialized()) { + LOG(WARNING) << " Input " << i << ": " + << strings::StrCat( + "Tensordtype()), + " shape: ", tensor->shape().DebugString(), ">"); + } else { + LOG(WARNING) << " Input " << i << ": not present"; + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/propagator_debug_utils.h b/tensorflow/core/common_runtime/propagator_debug_utils.h new file mode 100644 index 00000000000..8f1204998ff --- /dev/null +++ b/tensorflow/core/common_runtime/propagator_debug_utils.h @@ -0,0 +1,40 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_ + +namespace tensorflow { + +struct Entry; +class ImmutableExecutorState; +class Tensor; + +// Returns a pointer to the tensor in `input` if one exists, or `nullptr`. +const Tensor* GetTensorValueForDump(const Entry& input); + +// Writes a LOG(WARNING) message describing the state of the pending node +// `node_id` in the graph described by `immutable_state`. +void DumpPendingNodeState(const ImmutableExecutorState& immutable_state, + const int node_id, const Entry* input_vector, + const bool show_nodes_with_no_ready_inputs); + +// Writes a LOG(WARNING) message describing the state of the active node +// `node_id` in the graph described by `immutable_state`. +void DumpActiveNodeState(const ImmutableExecutorState& immutable_state, + const int node_id, const Entry* input_vector); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_ diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc index e2827a8eb1f..a4e311cbc6b 100644 --- a/tensorflow/core/common_runtime/propagator_state.cc +++ b/tensorflow/core/common_runtime/propagator_state.cc @@ -16,17 +16,12 @@ limitations under the License. #include "tensorflow/core/common_runtime/propagator_state.h" #include "tensorflow/core/common_runtime/graph_view.h" +#include "tensorflow/core/common_runtime/propagator_debug_utils.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { -// 1-D, 0 element tensor. -static const Tensor* const kEmptyTensor = new Tensor; - -typedef gtl::InlinedVector TensorValueVec; -typedef gtl::InlinedVector AllocatorAttributeVec; - PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state, int64 step_id) : immutable_state_(immutable_state), @@ -57,7 +52,7 @@ void PropagatorState::ActivateRoots(gtl::ArraySlice roots, TaggedNodeSeq* ready) { for (const NodeItem* item : roots) { DCHECK_EQ(item->num_inputs, 0); - ready->push_back(TaggedNode{item, root_frame_, 0, false}); + ready->emplace_back(item, root_frame_, 0, false); } mutex_lock l(root_frame_->mu); root_frame_->GetIteration(0)->outstanding_ops = ready->size(); @@ -173,72 +168,6 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, } } -const Tensor* PropagatorState::GetTensorValueForDump(const Entry& input) { - switch (input.state) { - case Entry::State::NO_VALUE: - return kEmptyTensor; - case Entry::State::HAS_VALUE: - return input.val.get(); - case Entry::State::HAS_CONST_TENSOR: - return input.const_tensor; - case Entry::State::HAS_REF_TENSOR: - return input.ref_tensor.tensor; - } -} - -void PropagatorState::DumpPendingNodeState( - const int node_id, const Entry* input_vector, - const bool show_nodes_with_no_ready_inputs) { - const NodeItem& node_item = *immutable_state_.graph_view().node(node_id); - const int input_base = node_item.input_start; - if (!show_nodes_with_no_ready_inputs) { - bool has_ready_input = false; - for (int i = 0; i < node_item.num_inputs; ++i) { - const Entry& input = input_vector[input_base + i]; - const Tensor* tensor = GetTensorValueForDump(input); - if (tensor->IsInitialized()) { - has_ready_input = true; - break; - } - } - if (!has_ready_input) { - return; - } - } - LOG(WARNING) << " Pending Node: " << node_item.DebugString(); - for (int i = 0; i < node_item.num_inputs; ++i) { - const Entry& input = input_vector[input_base + i]; - const Tensor* tensor = GetTensorValueForDump(input); - if (tensor->IsInitialized()) { - LOG(WARNING) << " Input " << i << ": " - << strings::StrCat( - "Tensordtype()), - " shape: ", tensor->shape().DebugString(), ">"); - } else { - LOG(WARNING) << " Input " << i << ": not present"; - } - } -} - -void PropagatorState::DumpActiveNodeState(const int node_id, - const Entry* input_vector) { - const NodeItem& node_item = *immutable_state_.graph_view().node(node_id); - LOG(WARNING) << " Active Node: " << node_item.DebugString(); - const int input_base = node_item.input_start; - for (int i = 0; i < node_item.num_inputs; ++i) { - const Entry& input = input_vector[input_base + i]; - const Tensor* tensor = GetTensorValueForDump(input); - if (tensor->IsInitialized()) { - LOG(WARNING) << " Input " << i << ": " - << strings::StrCat( - "Tensordtype()), - " shape: ", tensor->shape().DebugString(), ">"); - } else { - LOG(WARNING) << " Input " << i << ": not present"; - } - } -} - void PropagatorState::DumpIterationState(const FrameState* frame, IterationState* iteration) { const std::vector* nodes = frame->nodes; @@ -248,7 +177,8 @@ void PropagatorState::DumpIterationState(const FrameState* frame, immutable_state_.pending_ids()[node->node_id]; if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY || iteration->node_state(pending_id) == PendingCounts::PENDING_READY) { - DumpPendingNodeState(node->node_id, iteration->input_tensors, false); + DumpPendingNodeState(immutable_state_, node->node_id, + iteration->input_tensors, false); } } // Then the active nodes. @@ -256,7 +186,8 @@ void PropagatorState::DumpIterationState(const FrameState* frame, PendingCounts::Handle pending_id = immutable_state_.pending_ids()[node->node_id]; if (iteration->node_state(pending_id) == PendingCounts::STARTED) { - DumpActiveNodeState(node->node_id, iteration->input_tensors); + DumpActiveNodeState(immutable_state_, node->node_id, + iteration->input_tensors); } } // Show all input tensors in use. @@ -279,14 +210,11 @@ void PropagatorState::DumpIterationState(const FrameState* frame, void PropagatorState::DumpState() { mutex_lock l(mu_); - if (!dumped_on_error_) { - LOG(WARNING) << "Dumping state"; - for (auto& frame : outstanding_frames_) { - LOG(WARNING) << frame.first; - FrameState* frame_state = frame.second; - frame_state->DumpIterationState(this); - } - dumped_on_error_ = true; + LOG(WARNING) << "Dumping state"; + for (auto& frame : outstanding_frames_) { + LOG(WARNING) << frame.first; + FrameState* frame_state = frame.second; + frame_state->DumpIterationState(this); } } @@ -378,7 +306,7 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { for (const EdgeInfo& e : item->output_edges()) { const NodeItem& dst_item = - *immutable_state_.graph_view().node(e.dst_id); + immutable_state_.graph_view().node_ref(e.dst_id); const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id]; bool dst_dead = true; @@ -398,7 +326,7 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { for (const ControlEdgeInfo& e : item->output_control_edges()) { const NodeItem& dst_item = - *immutable_state_.graph_view().node(e.dst_id); + immutable_state_.graph_view().node_ref(e.dst_id); const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id]; bool dst_dead; @@ -464,17 +392,17 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item, // // NOTE(mrry): Use a macro here instead of a lambda, because this method is // performance-critical and we need to ensure that the code is inlined. -#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \ - do { \ - if (!adjust_result.any_pending) { \ - const NodeItem* dst_item = gview.node(dst_id); \ - TaggedNode& t = ready->emplace_back(); \ - t.node_item = dst_item; \ - t.input_frame = this; \ - t.input_iter = iter; \ - t.is_dead = adjust_result.any_dead; \ - iter_state->outstanding_ops++; \ - } \ +#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \ + do { \ + if (!adjust_result.any_pending) { \ + const NodeItem* dst_item = &gview.node_ref(dst_id); \ + TaggedNode& t = ready->emplace_back(); \ + t.node_item = dst_item; \ + t.input_frame = this; \ + t.input_iter = iter; \ + t.is_dead = adjust_result.any_dead; \ + iter_state->outstanding_ops++; \ + } \ } while (0); Entry* input_tensors = iter_state->input_tensors; @@ -534,7 +462,7 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, for (const EdgeInfo& e : item->output_edges()) { const int dst_id = e.dst_id; - const NodeItem* dst_item = gview.node(dst_id); + const NodeItem* dst_item = &gview.node_ref(dst_id); const PendingCounts::Handle dst_pending_id = immutable_state.pending_ids()[dst_id]; const int src_slot = e.output_slot; @@ -596,7 +524,7 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, for (const ControlEdgeInfo& e : item->output_control_edges()) { const int dst_id = e.dst_id; - const NodeItem* dst_item = gview.node(dst_id); + const NodeItem* dst_item = &gview.node_ref(dst_id); const PendingCounts::Handle dst_pending_id = immutable_state.pending_ids()[dst_id]; diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h index 4a5a26ba0f6..6d5abd02afa 100644 --- a/tensorflow/core/common_runtime/propagator_state.h +++ b/tensorflow/core/common_runtime/propagator_state.h @@ -429,26 +429,15 @@ class PropagatorState { void CleanupFramesIterations(FrameState* frame, int64 iter, TaggedNodeSeq* ready); - // Provide debugging output about an outstanding node in the executor. - void DumpPendingNodeState(const int node_id, const Entry* input_vector, - bool show_nodes_with_no_ready_inputs); - void DumpActiveNodeState(const int node_id, const Entry* input_vector); - // Provide debugging output about an outstanding iteration in the executor. void DumpIterationState(const FrameState* frame, IterationState* iteration); - const Tensor* GetTensorValueForDump(const Entry& input); - const ImmutableExecutorState& immutable_state_; const int64 step_id_; const bool vlog_; mutex mu_; - // A flag that is set on error after the frame state has been - // dumped for diagnostic purposes. - bool dumped_on_error_ TF_GUARDED_BY(mu_) = false; - // The root frame in which the execution of this step is started. FrameState* root_frame_; diff --git a/tensorflow/core/common_runtime/simple_propagator_state.cc b/tensorflow/core/common_runtime/simple_propagator_state.cc new file mode 100644 index 00000000000..bb9e33979b8 --- /dev/null +++ b/tensorflow/core/common_runtime/simple_propagator_state.cc @@ -0,0 +1,134 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/simple_propagator_state.h" + +#include "tensorflow/core/common_runtime/propagator_debug_utils.h" +#include "tensorflow/core/profiler/lib/traceme.h" + +namespace tensorflow { + +SimplePropagatorState::SimplePropagatorState( + const ImmutableExecutorState& immutable_state, int64 step_id) + : SimplePropagatorState(immutable_state, step_id, + immutable_state.get_root_frame_info()) {} + +SimplePropagatorState::SimplePropagatorState( + const ImmutableExecutorState& immutable_state, int64 step_id, + const ImmutableExecutorState::FrameInfo& finfo) + : immutable_state_(immutable_state), + step_id_(step_id), + vlog_(VLOG_IS_ON(1)), + input_tensors_(finfo.total_inputs), + counts_(*finfo.pending_counts), + nodes_(finfo.nodes.get()) {} + +SimplePropagatorState::~SimplePropagatorState() {} + +void SimplePropagatorState::ActivateRoots( + gtl::ArraySlice roots, TaggedNodeSeq* ready) { + for (const NodeItem* item : roots) { + DCHECK_EQ(item->num_inputs, 0); + ready->emplace_back(item); + } +} + +void SimplePropagatorState::PropagateOutputs(const TaggedNode& tagged_node, + EntryVector* outputs, + TaggedNodeSeq* ready) { + profiler::TraceMe activity( + [&]() { + return strings::StrCat( + "ExecutorPropagateOutputs#", "id=", step_id_, + ",kernel_name=", tagged_node.node_item->kernel->name_view(), + ",num_output_edges=", tagged_node.node_item->num_output_edges, + ",num_output_control_edges=", + tagged_node.node_item->num_output_control_edges, "#"); + }, + profiler::GetTFTraceMeLevel(/*is_expensive=*/false)); + + // Propagates outputs along out edges, and puts newly ready nodes + // into the ready queue. + DCHECK(ready->empty()); + + const GraphView& gview = immutable_state_.graph_view(); + const NodeItem* item = tagged_node.node_item; + + mutex_lock l(mu_); + + for (const EdgeInfo& e : item->output_edges()) { + const int dst_id = e.dst_id; + const PendingCounts::Handle dst_pending_id = + immutable_state_.pending_ids()[dst_id]; + const int src_slot = e.output_slot; + int num_pending = counts_.decrement_pending(dst_pending_id, 1); + const int dst_loc = e.input_slot; + if (e.is_last) { + input_tensors_[dst_loc] = std::move((*outputs)[src_slot]); + } else { + input_tensors_[dst_loc] = (*outputs)[src_slot]; + } + if (num_pending == 0) ready->emplace_back(&gview.node_ref(dst_id)); + } + + for (const ControlEdgeInfo& e : item->output_control_edges()) { + const int dst_id = e.dst_id; + const PendingCounts::Handle dst_pending_id = + immutable_state_.pending_ids()[dst_id]; + int num_pending = counts_.decrement_pending(dst_pending_id, 1); + if (num_pending == 0) ready->emplace_back(&gview.node_ref(dst_id)); + } +} + +void SimplePropagatorState::DumpState() { + mutex_lock l(mu_); + LOG(WARNING) << "Dumping state"; + + // Dump any waiting nodes that are holding on to tensors. + for (const NodeItem* node : *nodes_) { + PendingCounts::Handle pending_id = + immutable_state_.pending_ids()[node->node_id]; + if (counts_.node_state(pending_id) == PendingCounts::PENDING_NOTREADY || + counts_.node_state(pending_id) == PendingCounts::PENDING_READY) { + DumpPendingNodeState(immutable_state_, node->node_id, + input_tensors_.data(), false); + } + } + // Then the active nodes. + for (const NodeItem* node : *nodes_) { + PendingCounts::Handle pending_id = + immutable_state_.pending_ids()[node->node_id]; + if (counts_.node_state(pending_id) == PendingCounts::STARTED) { + DumpActiveNodeState(immutable_state_, node->node_id, + input_tensors_.data()); + } + } + // Show all input tensors in use. + size_t total_bytes = 0; + for (size_t i = 0; i < input_tensors_.size(); ++i) { + const Entry& input = input_tensors_[i]; + const Tensor* tensor = GetTensorValueForDump(input); + if (tensor && tensor->IsInitialized()) { + LOG(WARNING) << " Input " << i << ": " + << strings::StrCat( + "Tensordtype()), + " shape: ", tensor->shape().DebugString(), + ", bytes: ", tensor->TotalBytes(), ">"); + total_bytes += tensor->TotalBytes(); + } + } + LOG(WARNING) << " Total bytes " << total_bytes; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/simple_propagator_state.h b/tensorflow/core/common_runtime/simple_propagator_state.h new file mode 100644 index 00000000000..0df9bd53018 --- /dev/null +++ b/tensorflow/core/common_runtime/simple_propagator_state.h @@ -0,0 +1,186 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ + +#include + +#include "tensorflow/core/common_runtime/entry.h" +#include "tensorflow/core/common_runtime/immutable_executor_state.h" +#include "tensorflow/core/common_runtime/pending_counts.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Represents the ephemeral "edge state" associated with one invocation of +// `Executor::Run()`. +// +// NOTE: `SimplePropagatorState` does not support "v1-style" control flow, +// including "dead tensors", "Switch" and "Merge" nodes, and cycles in the +// graph. Use `PropagatorState` for graphs with those features. +// `SimplePropagatorState` *does* support "v2-style" or "functional" control +// flow. +// +// `SimplePropagatorState` is responsible for propagating values along dataflow +// edges in a TensorFlow graph and determining which nodes are runnable. The +// executor primarily updates `SimplePropagatorState` by calling +// `PropagateOutputs()` after processing a node, and `SimplePropagatorState` +// dispatches `TaggedNode`s by adding them to a `TaggedNodeSeq`. +class SimplePropagatorState { + public: + SimplePropagatorState(const ImmutableExecutorState& immutable_state, + int64 step_id); + ~SimplePropagatorState(); + + // A `TaggedNode` corresponds to a single invocation of a node's kernel, + // and it is created when the kernel becomes runnable. + struct TaggedNode { + const NodeItem* node_item; + + explicit TaggedNode(const NodeItem* node_item) : node_item(node_item) {} + + const NodeItem& get_node_item() const { return *node_item; } + + bool get_is_dead() const { return false; } + int64 get_iter_num() const { return 0; } + }; + + // A drop-in replacement for std::deque. We typically don't + // have that many nodes in the ready queue, so we just use a vector and + // don't free up memory from the queue as we consume nodes. + // TODO(mrry): Extract this and share it with the version in + // `PropagatorState`. The correct constants might be different, since + // sizeof(TaggedNode) is smaller in this version. + class TaggedNodeReadyQueue { + public: + TaggedNodeReadyQueue() : front_index_(0) {} + + void push_back(const TaggedNode& node) { ready_.push_back(node); } + TaggedNode front() const { + DCHECK_LT(front_index_, ready_.size()); + return ready_[front_index_]; + } + void pop_front() { + DCHECK_LT(front_index_, ready_.size()); + front_index_++; + if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) { + if (front_index_ == ready_.size()) { + ready_.clear(); + } else { + // Lots of unused entries at beginning of vector: move everything + // down to start of vector. + ready_.erase(ready_.begin(), ready_.begin() + front_index_); + } + front_index_ = 0; + } + } + bool empty() const { return ready_.empty(); } + + private: + // TODO(b/152925936): Re-evaluate these constants with current usage + // patterns. + static constexpr int kSpillThreshold = 16384; + gtl::InlinedVector ready_; + int front_index_; + }; + + // TODO(b/152925936): Re-evaluate this constant with current usage patterns. + typedef gtl::InlinedVector TaggedNodeSeq; + + // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`. + void ActivateRoots(gtl::ArraySlice roots, + TaggedNodeSeq* ready); + + // After processing the outputs, propagates the outputs to their dsts. + // Contents of *outputs are left in an indeterminate state after + // returning from this method. + void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs, + TaggedNodeSeq* ready); + + // Returns an array of `Entry` objects corresponding to the inputs of + // `tagged_node`. + // + // NOTE: Thread safety analysis is disabled on this method, because the + // underlying `IterationState` and its array of `input_tensors` retain the + // same address while the iteration is live. + Entry* GetInputTensors(const TaggedNode& tagged_node) + TF_NO_THREAD_SAFETY_ANALYSIS { + return input_tensors_.data() + tagged_node.node_item->input_start; + } + + FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const { + return {0, 0}; + } + + // Provide debugging output of the state of the executor. + void DumpState(); + + // For debugging/logging only. + void MaybeMarkStarted(const TaggedNode& tagged_node) { + // TODO(misard) Replace with a finer-grain enabling flag once we add better + // optional debugging support. + if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { + mutex_lock l(mu_); + counts_.mark_started( + immutable_state_.pending_ids()[tagged_node.node_item->node_id]); + } + } + void MaybeMarkCompleted(const TaggedNode& tagged_node) { + // TODO(misard) Replace with a finer-grain enabling flag once we add better + // optional debugging support. + if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { + mutex_lock l(mu_); + counts_.mark_completed( + immutable_state_.pending_ids()[tagged_node.node_item->node_id]); + } + } + + private: + SimplePropagatorState(const ImmutableExecutorState& immutable_state_, + int64 step_id, + const ImmutableExecutorState::FrameInfo& finfo); + + const ImmutableExecutorState& immutable_state_; + const int64 step_id_; + const bool vlog_; + + mutex mu_; + + // The i-th node's j-th input is stored at + // `input_tensors_[impl_->nodes[i].input_start + j]`. + // + // NOTE: No need to protect input_tensors[i] by any locks because it + // is resized once. Each element of input_tensors is written once by the + // source node of an edge and is cleared by the destination of the same + // edge. The destination node always runs after the source node, so there + // is never concurrent access to the same entry. + std::vector input_tensors_; + + PendingCounts counts_ TF_GUARDED_BY(mu_); + + const std::vector* const nodes_; + + TF_DISALLOW_COPY_AND_ASSIGN(SimplePropagatorState); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_