From 2f9cc84ba3f5d59753d843f167adee2e2534c143 Mon Sep 17 00:00:00 2001 From: Peter Ma Date: Mon, 6 May 2019 19:15:04 -0700 Subject: [PATCH] - Changed to return Status for Init() function of ReadyNodeManager. - Changed node_state to node_map for CompositeNodeManager to be consistent with VirtualScheduler. PiperOrigin-RevId: 246942431 --- .../core/grappler/costs/virtual_scheduler.cc | 33 ++++++++++--------- .../core/grappler/costs/virtual_scheduler.h | 22 ++++++------- .../grappler/costs/virtual_scheduler_test.cc | 22 ++++++------- 3 files changed, 38 insertions(+), 39 deletions(-) diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index c384084e6cf..cd47188b65e 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -116,7 +116,7 @@ FirstReadyManager::FirstReadyManager() : ReadyNodeManager() { std::make_heap(nodes_.begin(), nodes_.end()); } -void FirstReadyManager::Init( +Status FirstReadyManager::Init( const std::unordered_map* node_map) { // Reset the node state since different instances of the scheduler can reuse // the same node_manager. 
@@ -133,6 +133,7 @@ void FirstReadyManager::Init( return node_map_->at(a).time_ready > node_map_->at(b).time_ready; } }; + return Status::OK(); } const NodeDef* FirstReadyManager::GetCurrNode() { @@ -172,12 +173,13 @@ void FirstReadyManager::DrainWaitingQueue() { CompositeNodeManager::CompositeNodeManager() : ReadyNodeManager(), send_manager_(), recv_manager_() {} -void CompositeNodeManager::Init( - const std::unordered_map* node_state) { - node_state_ = node_state; - send_manager_.Init(node_state); - recv_manager_.Init(node_state); +Status CompositeNodeManager::Init( + const std::unordered_map* node_map) { + node_map_ = node_map; + TF_RETURN_IF_ERROR(send_manager_.Init(node_map)); + TF_RETURN_IF_ERROR(recv_manager_.Init(node_map)); curr_node_ = nullptr; + return Status::OK(); } void CompositeNodeManager::AddNode(const NodeDef* node) { @@ -186,7 +188,7 @@ void CompositeNodeManager::AddNode(const NodeDef* node) { } else if (IsRecv(*node)) { recv_manager_.AddNode(node); } else { - const auto& device = node_state_->at(node).device_name; + const auto& device = node_map_->at(node).device_name; ops_lifo_map_[device].AddNode(node); } } @@ -203,16 +205,16 @@ const NodeDef* CompositeNodeManager::GetCurrNode() { for (auto& ops_lifo : ops_lifo_map_) { if (!ops_lifo.second.Empty()) { const auto* op = ops_lifo.second.GetCurrNode(); - candidates.emplace_back(op, node_state_->at(op).time_ready); + candidates.emplace_back(op, node_map_->at(op).time_ready); } } if (!send_manager_.Empty()) { const auto* send = send_manager_.GetCurrNode(); - candidates.emplace_back(send, node_state_->at(send).time_ready); + candidates.emplace_back(send, node_map_->at(send).time_ready); } if (!recv_manager_.Empty()) { const auto* recv = recv_manager_.GetCurrNode(); - candidates.emplace_back(recv, node_state_->at(recv).time_ready); + candidates.emplace_back(recv, node_map_->at(recv).time_ready); } CHECK(!candidates.empty()); auto first_ready = std::min_element( @@ -251,7 +253,7 @@ void 
CompositeNodeManager::RemoveCurrNode() { } else if (IsRecv(*node)) { recv_manager_.RemoveCurrNode(); } else { - const auto device = node_state_->at(node).device_name; + const auto device = node_map_->at(node).device_name; ops_lifo_map_[device].RemoveCurrNode(); } // Reset curr_node_ so that GetCurrNode() finds another node. @@ -300,9 +302,6 @@ VirtualScheduler::VirtualScheduler(const bool use_static_shapes, } Status VirtualScheduler::Init(const GrapplerItem* item) { - grappler_item_ = item; - graph_properties_ = absl::make_unique(*item); - initialized_ = false; // Clear all internal states so that the VirtualScheduler is reusable for @@ -322,9 +321,10 @@ Status VirtualScheduler::Init(const GrapplerItem* item) { // necessary information for emulating tensorflow op scheduling and // construct internal data structures (NodeState and DeviceState) for virtual // scheduling. - ready_nodes_->Init(GetNodeStates()); + TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates())); - // Construct graph properties. + // Constructs graph properties and performs shape inference. 
+ graph_properties_ = absl::make_unique(*item); if (use_static_shapes_) { TF_RETURN_IF_ERROR(graph_properties_->InferStatically( true, use_aggressive_shape_inference_)); @@ -332,6 +332,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) { TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_)); } + grappler_item_ = item; const auto& graph = grappler_item_->graph; const auto& fetch_nodes = grappler_item_->fetch; std::set feed_nodes; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index afc39be4e55..47d5cc23a5a 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -132,8 +132,10 @@ class ReadyNodeManager { public: ReadyNodeManager() {} virtual ~ReadyNodeManager() {} - virtual void Init( - const std::unordered_map* node_state) {} + virtual Status Init( + const std::unordered_map* node_map) { + return Status::OK(); + } virtual void AddNode(const NodeDef* node) = 0; virtual const NodeDef* GetCurrNode() = 0; virtual void RemoveCurrNode() = 0; @@ -144,8 +146,6 @@ class FIFOManager : public ReadyNodeManager { public: FIFOManager() : ReadyNodeManager() {} ~FIFOManager() override {} - void Init(const std::unordered_map* node_state) - override {} void AddNode(const NodeDef* node) override { nodes_.push_back(node); } const NodeDef* GetCurrNode() override { CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; @@ -166,8 +166,6 @@ class LIFOManager : public ReadyNodeManager { public: LIFOManager() : ReadyNodeManager() {} ~LIFOManager() override {} - void Init(const std::unordered_map* node_state) - override {} void AddNode(const NodeDef* node) override { nodes_.push_back(node); } const NodeDef* GetCurrNode() override; void RemoveCurrNode() override; @@ -188,8 +186,8 @@ class LIFOManager : public ReadyNodeManager { class FirstReadyManager : public ReadyNodeManager { public: FirstReadyManager(); - void Init( - 
const std::unordered_map* node_state) override; + Status Init( + const std::unordered_map* node_map) override; ~FirstReadyManager() override {} void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); } const NodeDef* GetCurrNode() override; @@ -227,8 +225,8 @@ class CompositeNodeManager : public ReadyNodeManager { CompositeNodeManager(); ~CompositeNodeManager() override {} - void Init( - const std::unordered_map* node_state) override; + Status Init( + const std::unordered_map* node_map) override; void AddNode(const NodeDef* node) override; const NodeDef* GetCurrNode() override; void RemoveCurrNode() override; @@ -245,8 +243,8 @@ class CompositeNodeManager : public ReadyNodeManager { FirstReadyManager recv_manager_; // NodeState structure from VirtualScheduler to get time_ready of ready nodes. - // Not owned by FirstReadyManager. - const std::unordered_map* node_state_; + // Not owned by CompositeNodeManager. + const std::unordered_map* node_map_; // Cached curr node. Set back to nullptr from RemoveCurrNode(). 
const NodeDef* curr_node_; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index ce0b993f23d..dd7c0c2c583 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -207,14 +207,14 @@ TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleLIFOManager) { TEST_F(ReadyNodeManagerTest, GetSingleNodeFirstReadyManager) { FirstReadyManager manager; - manager.Init(&node_states_); + TF_EXPECT_OK(manager.Init(&node_states_)); manager.AddNode(&node1_); EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); } TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) { FirstReadyManager manager; - manager.Init(&node_states_); + TF_EXPECT_OK(manager.Init(&node_states_)); manager.AddNode(&node1_); manager.RemoveCurrNode(); EXPECT_TRUE(manager.Empty()); @@ -222,7 +222,7 @@ TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) { TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) { FirstReadyManager manager; - manager.Init(&node_states_); + TF_EXPECT_OK(manager.Init(&node_states_)); // Insert nodes in some random order. manager.AddNode(&node2_); manager.AddNode(&node1_); @@ -250,7 +250,7 @@ TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) { TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) { FirstReadyManager manager; - manager.Init(&node_states_); + TF_EXPECT_OK(manager.Init(&node_states_)); // Inserts nodes in some random order. manager.AddNode(&node2_); @@ -308,9 +308,9 @@ TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) { TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) { FirstReadyManager manager1; - manager1.Init(&node_states_); + TF_EXPECT_OK(manager1.Init(&node_states_)); FirstReadyManager manager2; - manager2.Init(&node_states_); + TF_EXPECT_OK(manager2.Init(&node_states_)); // 6 nodes with same time_ready. 
NodeDef node7; @@ -374,7 +374,7 @@ TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) { TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) { CompositeNodeManager manager; - manager.Init(&node_states_); + TF_EXPECT_OK(manager.Init(&node_states_)); manager.AddNode(&node1_); manager.RemoveCurrNode(); EXPECT_TRUE(manager.Empty()); @@ -382,7 +382,7 @@ TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) { TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleComopsiteNodeManager) { CompositeNodeManager manager; - manager.Init(&node_states_); + TF_EXPECT_OK(manager.Init(&node_states_)); manager.AddNode(&node1_); manager.AddNode(&node2_); manager.AddNode(&node3_); @@ -412,7 +412,7 @@ TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleComopsiteNodeManager) { TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvComopsiteNodeManager) { CompositeNodeManager manager; - manager.Init(&node_states_); + TF_EXPECT_OK(manager.Init(&node_states_)); // Additional nodes on kCPU1. NodeDef node7; NodeDef node8; @@ -491,9 +491,9 @@ TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvComopsiteNodeManager) { TEST_F(ReadyNodeManagerTest, DeterminismInCompositeNodeManager) { CompositeNodeManager manager; - manager.Init(&node_states_); + TF_EXPECT_OK(manager.Init(&node_states_)); CompositeNodeManager manager2; - manager2.Init(&node_states_); + TF_EXPECT_OK(manager2.Init(&node_states_)); // 6 nodes with same time_ready. NodeDef node7;