diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index ffdec02a0da..425b4d7c0a7 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -40,12 +40,6 @@ namespace {
 
 using ::tensorflow::strings::HumanReadableNumBytes;
 
-constexpr char kAttrInputSrc[] = "input_source_";
-constexpr char kAttrSrcDevice[] = "send_device";
-constexpr char kAttrDstDevice[] = "recv_device";
-constexpr char kAttrTensorName[] = "tensor_name";
-constexpr char kChannelDevice[] = "Channel";
-
 float Round2(const float x) {
   // Not using std::round from here because not all platforms seem to
   // support that (specifically Android).
@@ -347,13 +341,11 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
   return nullptr;
 }
 
-VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
-                                   const bool use_aggressive_shape_inference,
-                                   Cluster* cluster,
-                                   ReadyNodeManager* ready_nodes,
-                                   std::unique_ptr<VirtualPlacer> placer)
-    : ready_nodes_(ready_nodes),
-      graph_costs_(Costs::ZeroCosts()),
+SchedulerState::SchedulerState(const bool use_static_shapes,
+                               const bool use_aggressive_shape_inference,
+                               Cluster* cluster,
+                               std::unique_ptr<VirtualPlacer> placer)
+    : graph_costs_(Costs::ZeroCosts()),
       cluster_(cluster),
       use_static_shapes_(use_static_shapes),
       use_aggressive_shape_inference_(use_aggressive_shape_inference),
@@ -364,10 +356,12 @@ VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
   track_mem_usage_snapshot_ = VLOG_IS_ON(1);
 }
 
-Status VirtualScheduler::Init(const GrapplerItem* item) {
+Status SchedulerState::Init(const GrapplerItem* item,
+                            std::vector<const NodeDef*>* initial_nodes,
+                            bool create_explicit_channel_device) {
   initialized_ = false;
 
-  // Clear all internal states so that the VirtualScheduler is reusable for
+  // Clear all internal states so that the SchedulerState is reusable for
   // different GrapplerItems
   node_map_.clear();
   device_.clear();
@@ -380,14 +374,12 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
   op_counts_.clear();
   op_costs_.clear();
 
-  // Init() preprocesses the input grappler_item and graph_properties to extract
-  // necessary information for emulating tensorflow op scheduling and
-  // construct internal data structures (NodeState and DeviceState) for virtual
-  // scheduling.
-  TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
+  initial_nodes->clear();
 
   // Constructs graph properties and performs shape inference.
   graph_properties_ = absl::make_unique<GraphProperties>(*item);
+  // TODO(safeen,dyoon): Will we ever use InferDynamically? If not we may want
+  // to get rid of use_static_shapes_ and cluster_.
   if (use_static_shapes_) {
     TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
        true, use_aggressive_shape_inference_, true));
@@ -399,6 +391,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
   const auto& graph = grappler_item_->graph;
   const auto& fetch_nodes = grappler_item_->fetch;
   std::set<string> feed_nodes;
+
   for (const auto& f : grappler_item_->feed) {
     auto iter_and_inserted_flag = feed_nodes.insert(f.first);
     QCHECK(iter_and_inserted_flag.second)
@@ -486,8 +479,9 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
       } else {
         // Different device, no cached copy; transfer input_node to the
         // curr_node's device.
-        auto send_and_recv = CreateSendRecv(input_node, curr_node, input_node,
-                                            input_node_name);
+        auto send_and_recv =
+            CreateSendRecv(input_node, curr_node, input_node, input_node_name,
+                           create_explicit_channel_device);
         // Note that CreateSendRecv() already connected input/output between
         // _Send and _Recv ops.
         const auto* send = send_and_recv.first;
@@ -514,7 +508,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
 
     if (given_as_feed || has_no_inputs) {
       curr_node_state.time_ready = Costs::Duration();
-      ready_nodes_->AddNode(curr_node);
+      initial_nodes->push_back(curr_node);
       VLOG(3) << "Added ready node: " << curr_node->name();
     }
 
@@ -530,7 +524,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
     }
   }
 
-  if (ready_nodes_->Empty()) {
+  if (initial_nodes->empty()) {
     return errors::InvalidArgument("No ready nodes in the graph.");
   }
 
@@ -546,20 +540,20 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
   return Status::OK();
 }
 
-void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
+void SchedulerState::MaybeUpdateInputOutput(const NodeDef* node) {
   CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
   // This method is called when NodeState is created and adds input and output
   // properties for a few exceptional cases that GraphProperties cannot provide
   // input/output properties.
   if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
-    // _Send and _Recv ops created from VirtualScheduler have kAttrInputSrc
+    // _Send and _Recv ops created from SchedulerState have kAttrInputSrc
     // attr; normal _Send and _Recv ops (from the input graph) do not have that
     // attr.
     auto& node_state = node_map_[node];
     auto& inputs = node_state.input_properties;
     auto& outputs = node_state.output_properties;
 
-    // _Send and _Recv ops are created from VirtualScheduler, so
+    // _Send and _Recv ops are created from SchedulerState, so
     // there should be no inputs TensorProperties.
     CHECK(inputs.empty());
     CHECK(outputs.empty());
@@ -595,27 +589,27 @@ void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
   }
 }
 
-string VirtualScheduler::DeviceName(const NodeDef* node) const {
+string SchedulerState::DeviceName(const NodeDef* node) const {
   return placer_->get_canonical_device_name(*node);
 }
 
-string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
+string SchedulerState::SanitizedDeviceName(const NodeDef* node) const {
   // Replace the ":" characters that may be present in the device name with "_".
   // This makes it possible to then use the resulting string in a node name.
   return absl::StrReplaceAll(placer_->get_canonical_device_name(*node),
                              {{":", "_"}});
 }
 
-string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
-                                           const NodeDef* to) const {
+string SchedulerState::ChannelDeviceName(const NodeDef* from,
+                                         const NodeDef* to) const {
   CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
   return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from),
                       "_to_", SanitizedDeviceName(to));
 }
 
-std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
+std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv(
     const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
-    const string& input_name) {
+    const string& input_name, bool create_channel_device) {
   CHECK(!initialized_) << "CreateSendRecv is called after Init().";
 
   // Connect "from" node to "to" node with _Send and _Recv such that
@@ -643,7 +637,9 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
       "_to_" + SanitizedDeviceName(to));
   send->set_op("_Send");
   send->add_input(from->name());
-  send->set_device(ChannelDeviceName(from, to));
+  auto send_device =
+      create_channel_device ? ChannelDeviceName(from, to) : DeviceName(from);
+  send->set_device(send_device);
   auto& send_attr = *(send->mutable_attr());
   send_attr[kAttrInputSrc].set_s(input_name);
   send_attr[kAttrSrcDevice].set_s(DeviceName(from));
@@ -687,9 +683,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
   return std::make_pair(send, recv);
 }
 
-OpContext VirtualScheduler::GetCurrNode() const {
-  const NodeDef* node = ready_nodes_->GetCurrNode();
-
+OpContext SchedulerState::CreateOpContext(const NodeDef* node) const {
   // Get the device from the placer.
   DeviceProperties device;
   device = placer_->get_device(*node);
@@ -721,7 +715,7 @@ OpContext VirtualScheduler::GetCurrNode() const {
   return op_context;
 }
 
-NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
+NodeState& SchedulerState::GetNodeStateOrCreateIt(const NodeDef* node) {
   CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
 
   auto it = node_map_.find(node);
@@ -766,8 +760,9 @@ NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
   return it->second;
 }
 
-void VirtualScheduler::AddOutputNodesToReadyQueue(
-    const NodeDef* node, const Costs::Duration& curr_time) {
+void SchedulerState::GetOutputNodes(const NodeDef* node,
+                                    const Costs::Duration& curr_time,
+                                    std::vector<const NodeDef*>* output_nodes) {
   // Checks whether the Switch's output slots change over iterations.
   int slot = -1;
   if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
@@ -780,7 +775,6 @@ void VirtualScheduler::AddOutputNodesToReadyQueue(
       }
     }
   }
-
   // Increment num_inputs_ready of the output nodes and maybe add to ready
   // nodes.
   auto& node_state = node_map_[node];
@@ -799,16 +793,15 @@ void VirtualScheduler::AddOutputNodesToReadyQueue(
         IsMerge(*output_node)) {
       // This output node is now ready.
       output_state.time_ready = curr_time;
-      ready_nodes_->AddNode(output_node);
+      output_nodes->push_back(output_node);
      VLOG(3) << "  Add output: " << output_node->name();
     }
   }
 }
 
-bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
-  // Update graph_costs_ and per-op costs.
-  const NodeDef* node = ready_nodes_->GetCurrNode();
+std::vector<const NodeDef*> SchedulerState::MarkNodeExecuted(
+    const NodeDef* node, const Costs& node_costs, const OpContext& op_context) {
   auto& node_state = node_map_[node];
   // TODO(dyoon, andiryxu): Consider to revisit node execution w.r.t. Switch and
   // Merge -- it can create a loop which may include loop-carried dependency,
@@ -834,8 +827,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
 
   if (VLOG_IS_ON(2)) {
     // Also keep track of op counts and costs per op (with their shapes).
-    OpContext op_context = GetCurrNode();
-
     string node_description = GetOpDescription(op_context.op_info);
     op_counts_[node_description] += 1;
     op_costs_[node_description] =
@@ -886,7 +877,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
           << ", ready: " << node_state.time_ready.count()
           << ", scheduled: " << node_state.time_scheduled.count()
           << ", finished: " << node_state.time_finished.count();
-
+  std::vector<const NodeDef*> new_nodes;
   if (previously_executed_merge) {
     // Skip AddOutputNodesToReadyQueue; this is due to Switch-Merge.
     VLOG(1) << "node [ " << node->name() << ", " << node->op() << " ] "
@@ -894,7 +885,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
            << "is executed more than once. "
            << "Skip scheduling its output nodes.";
   } else {
     // Checks outputs, and adds ready nodes to queue.
-    AddOutputNodesToReadyQueue(node, curr_time);
+    GetOutputNodes(node, curr_time, &new_nodes);
   }
 
   // Increment num_outputs_executed of the input nodes and maybe update memory.
@@ -929,13 +920,10 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
       }
     }
   }
-
-  ready_nodes_->RemoveCurrNode();
-
-  return !ready_nodes_->Empty();
+  return new_nodes;
 }
 
-Costs VirtualScheduler::Summary() const {
+Costs SchedulerState::Summary() const {
   // Overall statement about accuracy
   VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
           << graph_costs_.num_ops_with_unknown_shapes
@@ -1109,12 +1097,12 @@ Costs VirtualScheduler::Summary() const {
   return critical_path_costs;
 }
 
-Costs VirtualScheduler::Summary(RunMetadata* metadata) {
+Costs SchedulerState::Summary(RunMetadata* metadata) {
   if (metadata) GenerateRunMetadata(metadata);
   return Summary();
 }
 
-void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
+void SchedulerState::GenerateRunMetadata(RunMetadata* metadata) {
   // Fill RunMetadata's step_stats and partition_graphs fields.
   StepStats* stepstats = metadata->mutable_step_stats();
   for (const auto& device : device_) {
@@ -1176,7 +1164,7 @@ void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
                                 nodestate.time_scheduled.count());
 
         auto* mem_stats = node_stats->mutable_memory_stats();
-        // VirtualScheduler does not specify scratch pad memory usage.
+        // SchedulerState does not specify scratch pad memory usage.
         mem_stats->set_temp_memory_size(0);
         int64 persistent_memory_size = 0;
         if (IsPersistent(*node_def)) {
@@ -1188,7 +1176,7 @@ void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
   }
 }
 
-const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
+const std::unordered_map<string, int64> SchedulerState::GetPeakMemoryUsage()
     const {
   std::unordered_map<string, int64> result;
   for (const auto& device : device_) {
@@ -1200,7 +1188,7 @@ const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
 }
 
 const std::unordered_map<string, int64>
-VirtualScheduler::GetPersistentMemoryUsage() const {
+SchedulerState::GetPersistentMemoryUsage() const {
   std::unordered_map<string, int64> result;
   for (const auto& device : device_) {
     const string& name = device.first;
@@ -1217,5 +1205,51 @@ VirtualScheduler::GetPersistentMemoryUsage() const {
   }
   return result;
 }
+
+VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
+                                   const bool use_aggressive_shape_inference,
+                                   Cluster* cluster,
+                                   ReadyNodeManager* ready_nodes,
+                                   std::unique_ptr<VirtualPlacer> placer)
+    : scheduler_state_(use_static_shapes, use_aggressive_shape_inference,
+                       cluster, std::move(placer)),
+      ready_nodes_(ready_nodes) {}
+
+Status VirtualScheduler::Init(const GrapplerItem* item) {
+  // SchedulerState::Init() preprocesses the input grappler_item and
+  // graph_properties to extract necessary information for emulating
+  // tensorflow op scheduling and constructing internal data structures
+  // (NodeState and DeviceState) for virtual scheduling.
+  TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
+  std::vector<const NodeDef*> initial_nodes;
+  auto status = scheduler_state_.Init(item, &initial_nodes);
+  if (status.ok()) {
+    // Add the set of initial nodes to ready_nodes_.
+    for (auto node : initial_nodes) {
+      ready_nodes_->AddNode(node);
+    }
+  }
+  return status;
+}
+
+OpContext VirtualScheduler::GetCurrNode() const {
+  const NodeDef* node = ready_nodes_->GetCurrNode();
+  return scheduler_state_.CreateOpContext(node);
+}
+
+bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
+  // Update graph_costs_ and per-op costs.
+  const NodeDef* node = ready_nodes_->GetCurrNode();
+  auto new_nodes = scheduler_state_.MarkNodeExecuted(
+      node, node_costs,
+      scheduler_state_.CreateOpContext(ready_nodes_->GetCurrNode()));
+  ready_nodes_->RemoveCurrNode();
+  // Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
+  for (auto node : new_nodes) {
+    ready_nodes_->AddNode(node);
+  }
+  return !ready_nodes_->Empty();
+}
+
 }  // end namespace grappler
 }  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index d380947f158..a287fc29d5e 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -32,6 +32,12 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+inline constexpr char kAttrInputSrc[] = "input_source_";
+inline constexpr char kAttrSrcDevice[] = "send_device";
+inline constexpr char kAttrDstDevice[] = "recv_device";
+inline constexpr char kAttrTensorName[] = "tensor_name";
+inline constexpr char kChannelDevice[] = "Channel";
+
 struct NodeState {
   // A node (i.e., an op) takes a set of input:port pairs and produces
   // a set of output ports.
@@ -233,7 +239,7 @@ class HeapReadyManager : public ReadyNodeManager {
   // functor for keeping the smallest time_ready node at the front of heap.
   std::function<bool(const NodeDef*, const NodeDef*)> greater_;
 
-  // NodeState structure from VirtualScheduler to get time_ready of ready nodes.
+  // NodeState structure from SchedulerState to get time_ready of ready nodes.
   // Not owned by FirstReadyManager.
   const std::unordered_map<const NodeDef*, NodeState>* node_map_;
 };
@@ -298,7 +304,7 @@ class CompositeNodeManager : public ReadyNodeManager {
   FirstReadyManager send_manager_;
   FirstReadyManager recv_manager_;
 
-  // NodeState structure from VirtualScheduler to get time_ready of ready nodes.
+  // NodeState structure from SchedulerState to get time_ready of ready nodes.
   // Not owned by CompositeReadyManager.
   const std::unordered_map<const NodeDef*, NodeState>* node_map_;
 
@@ -310,32 +316,22 @@ class CompositeNodeManager : public ReadyNodeManager {
 std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
     const string& ready_node_manager);
 
-// The virtual scheduler emulates execution of nodes in a graph, considering
-// dependencies, device, etc.
-class VirtualScheduler {
+// Encapsulates all of the various pieces used to track the state of a
+// scheduler; enables reuse of all scheduler state-related utilities across
+// different scheduler implementations.
+class SchedulerState {
  public:
-  // Does not take ownership of cluster or ready_nodes.
-  VirtualScheduler(const bool use_static_shapes,
-                   const bool use_aggressive_shape_inference, Cluster* cluster,
-                   ReadyNodeManager* ready_nodes,
-                   std::unique_ptr<VirtualPlacer> placer);
+  SchedulerState(const bool use_static_shapes,
+                 const bool use_aggressive_shape_inference, Cluster* cluster,
+                 std::unique_ptr<VirtualPlacer> placer);
+  // Sets up the graph while also performing some necessary transformations.
+  // initial_nodes is the set of nodes (primary inputs) discovered by Init(),
+  // which a ReadyNodeManager (or a related/derivative scheduler) may add to
+  // begin node scheduling and graph simulation.
+  Status Init(const GrapplerItem* item,
+              std::vector<const NodeDef*>* initial_nodes,
+              bool create_explicit_channel_device = true);
 
-  // Initializes the scheduler for the specific grappler item.
-  // Should be called immediately after the c'tor or when the scheduler will be
-  // reused for a new grappler item. All internal states of the scheduler
-  // related to the previous grappler item will be reset/cleared.
-  //
-  // This function should be called at least once after the scheduler is
-  // constructed. An uninitialized or failed-to-initialize scheduler will cause
-  // undefined behavior.
-  Status Init(const GrapplerItem* item);
-
-  OpContext GetCurrNode() const;
-
-  // Returns true if there is any node to be scheduled.
-  bool MarkCurrNodeExecuted(const Costs& node_costs);
-
-  // Prints out summary of execution (timing, memory usage, etc.)
   Costs Summary() const;
   // Like the above, but writes detailed stats to RunMetadata.
   // If metadata is nullptr, then just calls and return Summary().
@@ -347,34 +343,40 @@ class VirtualScheduler {
   // Returns per device memory usage.
   const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
   const std::unordered_map<string, int64> GetPersistentMemoryUsage() const;
-
-  // Returns VirtualScheduler (read only) device and node states.
+  void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }
+  // Returns (read only) device and node states.
   const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
     return &device_;
   }
+
   const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
     return &node_map_;
   }
 
-  void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }
+  OpContext CreateOpContext(const NodeDef* node) const;
+  std::vector<const NodeDef*> MarkNodeExecuted(const NodeDef* node,
+                                               const Costs& node_costs,
+                                               const OpContext& op_context);
 
  private:
   // Methods called from Init(). Fails if initialize_ is set.
+
   void MaybeUpdateInputOutput(const NodeDef* node);
   NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
+  // Creates a _Send and _Recv pair between from and to. The argument
+  // create_channel_device tells the function to create an explicit device for
+  // the channel.
   std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
       const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
-      const string& input_name);
+      const string& input_name, bool create_channel_device);
   string DeviceName(const NodeDef* node) const;
   string SanitizedDeviceName(const NodeDef* node) const;
   string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
 
   // Helper methods.
-  void AddOutputNodesToReadyQueue(const NodeDef* node,
-                                  const Costs::Duration& curr_time);
+  void GetOutputNodes(const NodeDef* node, const Costs::Duration& curr_time,
+                      std::vector<const NodeDef*>* output_nodes);
 
-  // Scheduler states:
-  ReadyNodeManager* ready_nodes_;  // Not owned.
   std::unordered_map<const NodeDef*, NodeState> node_map_;
   std::unordered_map<string, DeviceState> device_;
 
@@ -396,16 +398,81 @@ class VirtualScheduler {
   // Auxiliary data structures for constructing NodeState and DeviceState.
   std::unique_ptr<GraphProperties> graph_properties_;  // Initialized in Init().
   Cluster* cluster_;                                   // Not owned.
-
   const GrapplerItem* grappler_item_;  // Not owned.
   bool use_static_shapes_;
   bool initialized_;
   bool track_mem_usage_snapshot_;
   const bool use_aggressive_shape_inference_;
-
   std::unique_ptr<VirtualPlacer> placer_;
 };
 
+// The virtual scheduler emulates execution of nodes in a graph, considering
+// dependencies, device, etc.
+class VirtualScheduler {
+ public:
+  // Does not take ownership of cluster or ready_nodes.
+  VirtualScheduler(const bool use_static_shapes,
+                   const bool use_aggressive_shape_inference, Cluster* cluster,
+                   ReadyNodeManager* ready_nodes,
+                   std::unique_ptr<VirtualPlacer> placer);
+
+  // Initializes the scheduler for the specific grappler item.
+  // Should be called immediately after the c'tor or when the scheduler will be
+  // reused for a new grappler item. All internal states of the scheduler
+  // related to the previous grappler item will be reset/cleared.
+  //
+  // This function should be called at least once after the scheduler is
+  // constructed. An uninitialized or failed-to-initialize scheduler will cause
+  // undefined behavior.
+  Status Init(const GrapplerItem* item);
+
+  // Gets the current scheduled node for execution; the caller of this function
+  // can accordingly simulate the execution of the current scheduled node.
+  OpContext GetCurrNode() const;
+  // Marks the current scheduled node as executed. Note that we should call
+  // this function only after the execution of the node has been simulated;
+  // node_costs captures the simulated costs of the node.
+  // Returns true if there is any node to be scheduled.
+  bool MarkCurrNodeExecuted(const Costs& node_costs);
+
+  // Prints out summary of execution (timing, memory usage, etc.)
+  Costs Summary() const { return scheduler_state_.Summary(); }
+  // Like the above, but writes detailed stats to RunMetadata.
+  // If metadata is nullptr, then just calls and returns Summary().
+  Costs Summary(RunMetadata* metadata) {
+    return scheduler_state_.Summary(metadata);
+  }
+  // Generates RunMetadata's step_stats and partition_graphs fields from
+  // results of the virtual execution of the graph.
+  void GenerateRunMetadata(RunMetadata* metadata) {
+    scheduler_state_.GenerateRunMetadata(metadata);
+  }
+  // Returns per device memory usage.
+  const std::unordered_map<string, int64> GetPeakMemoryUsage() const {
+    return scheduler_state_.GetPeakMemoryUsage();
+  }
+  const std::unordered_map<string, int64> GetPersistentMemoryUsage() const {
+    return scheduler_state_.GetPersistentMemoryUsage();
+  }
+  // Returns VirtualScheduler (read only) device and node states.
+  const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
+    return scheduler_state_.GetDeviceStates();
+  }
+  const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
+    return scheduler_state_.GetNodeStates();
+  }
+  void enable_mem_usage_tracking() {
+    scheduler_state_.enable_mem_usage_tracking();
+  }
+
+ private:
+  // The state of the scheduler and the execution of the graph are
+  // encapsulated by the scheduler_state_ object.
+  SchedulerState scheduler_state_;
+  // ready_nodes_ is responsible for ordering the traversal of the graph.
+  ReadyNodeManager* ready_nodes_;  // Not owned.
+};
+
 }  // namespace grappler
 }  // end namespace tensorflow
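Reviewer note (not part of the patch): the refactor leaves the public driver contract unchanged; callers still run the same Init / GetCurrNode / MarkCurrNodeExecuted loop, with VirtualScheduler now forwarding state bookkeeping to SchedulerState and queue ordering to the ReadyNodeManager. Below is a minimal sketch of that loop as a client would drive it. `SimulateGraph` and `PredictNodeCosts` are hypothetical names used for illustration only; they are not symbols introduced by this change.

```cpp
#include <memory>
#include <utility>

#include "tensorflow/core/grappler/costs/virtual_scheduler.h"

namespace tensorflow {
namespace grappler {

// Hypothetical per-op cost model; stands in for whatever estimator the
// caller plugs in (e.g., an OpLevelCostEstimator-based callback).
Costs PredictNodeCosts(const OpContext& op_context);

// Runs one full virtual execution of `item` and returns the simulated
// execution costs.
Status SimulateGraph(const GrapplerItem& item, Cluster* cluster,
                     ReadyNodeManager* ready_nodes,
                     std::unique_ptr<VirtualPlacer> placer, Costs* total) {
  VirtualScheduler scheduler(/*use_static_shapes=*/true,
                             /*use_aggressive_shape_inference=*/true, cluster,
                             ready_nodes, std::move(placer));
  // Init() now delegates to SchedulerState::Init() and seeds ready_nodes
  // with the initial_nodes it returns.
  TF_RETURN_IF_ERROR(scheduler.Init(&item));
  bool more_nodes = true;
  while (more_nodes) {
    // Surfaces the next scheduled node via SchedulerState::CreateOpContext().
    OpContext op_context = scheduler.GetCurrNode();
    // Simulate the node's execution with the caller's cost model.
    Costs node_costs = PredictNodeCosts(op_context);
    // Forwards to SchedulerState::MarkNodeExecuted() and pushes the newly
    // ready output nodes back into the ReadyNodeManager.
    more_nodes = scheduler.MarkCurrNodeExecuted(node_costs);
  }
  *total = scheduler.Summary();
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow
```

Because SchedulerState::MarkNodeExecuted() returns the newly ready nodes instead of enqueueing them itself, an alternative scheduler can reuse SchedulerState with a different ordering policy without changing this loop.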