- Changed to return Status for Init() function of ReadyNodeManager.

- Changed node_state to node_map for CompositeNodeManager to be consistent with VirtualScheduler.

PiperOrigin-RevId: 246942431
This commit is contained in:
Peter Ma 2019-05-06 19:15:04 -07:00 committed by TensorFlower Gardener
parent cc91dc6aaa
commit 2f9cc84ba3
3 changed files with 38 additions and 39 deletions

View File

@ -116,7 +116,7 @@ FirstReadyManager::FirstReadyManager() : ReadyNodeManager() {
std::make_heap(nodes_.begin(), nodes_.end()); std::make_heap(nodes_.begin(), nodes_.end());
} }
void FirstReadyManager::Init( Status FirstReadyManager::Init(
const std::unordered_map<const NodeDef*, NodeState>* node_map) { const std::unordered_map<const NodeDef*, NodeState>* node_map) {
// Reset the node state since different instances of the scheduler can reuse // Reset the node state since different instances of the scheduler can reuse
// the same node_manager. // the same node_manager.
@ -133,6 +133,7 @@ void FirstReadyManager::Init(
return node_map_->at(a).time_ready > node_map_->at(b).time_ready; return node_map_->at(a).time_ready > node_map_->at(b).time_ready;
} }
}; };
return Status::OK();
} }
const NodeDef* FirstReadyManager::GetCurrNode() { const NodeDef* FirstReadyManager::GetCurrNode() {
@ -172,12 +173,13 @@ void FirstReadyManager::DrainWaitingQueue() {
CompositeNodeManager::CompositeNodeManager() CompositeNodeManager::CompositeNodeManager()
: ReadyNodeManager(), send_manager_(), recv_manager_() {} : ReadyNodeManager(), send_manager_(), recv_manager_() {}
void CompositeNodeManager::Init( Status CompositeNodeManager::Init(
const std::unordered_map<const NodeDef*, NodeState>* node_state) { const std::unordered_map<const NodeDef*, NodeState>* node_map) {
node_state_ = node_state; node_map_ = node_map;
send_manager_.Init(node_state); TF_RETURN_IF_ERROR(send_manager_.Init(node_map));
recv_manager_.Init(node_state); TF_RETURN_IF_ERROR(recv_manager_.Init(node_map));
curr_node_ = nullptr; curr_node_ = nullptr;
return Status::OK();
} }
void CompositeNodeManager::AddNode(const NodeDef* node) { void CompositeNodeManager::AddNode(const NodeDef* node) {
@ -186,7 +188,7 @@ void CompositeNodeManager::AddNode(const NodeDef* node) {
} else if (IsRecv(*node)) { } else if (IsRecv(*node)) {
recv_manager_.AddNode(node); recv_manager_.AddNode(node);
} else { } else {
const auto& device = node_state_->at(node).device_name; const auto& device = node_map_->at(node).device_name;
ops_lifo_map_[device].AddNode(node); ops_lifo_map_[device].AddNode(node);
} }
} }
@ -203,16 +205,16 @@ const NodeDef* CompositeNodeManager::GetCurrNode() {
for (auto& ops_lifo : ops_lifo_map_) { for (auto& ops_lifo : ops_lifo_map_) {
if (!ops_lifo.second.Empty()) { if (!ops_lifo.second.Empty()) {
const auto* op = ops_lifo.second.GetCurrNode(); const auto* op = ops_lifo.second.GetCurrNode();
candidates.emplace_back(op, node_state_->at(op).time_ready); candidates.emplace_back(op, node_map_->at(op).time_ready);
} }
} }
if (!send_manager_.Empty()) { if (!send_manager_.Empty()) {
const auto* send = send_manager_.GetCurrNode(); const auto* send = send_manager_.GetCurrNode();
candidates.emplace_back(send, node_state_->at(send).time_ready); candidates.emplace_back(send, node_map_->at(send).time_ready);
} }
if (!recv_manager_.Empty()) { if (!recv_manager_.Empty()) {
const auto* recv = recv_manager_.GetCurrNode(); const auto* recv = recv_manager_.GetCurrNode();
candidates.emplace_back(recv, node_state_->at(recv).time_ready); candidates.emplace_back(recv, node_map_->at(recv).time_ready);
} }
CHECK(!candidates.empty()); CHECK(!candidates.empty());
auto first_ready = std::min_element( auto first_ready = std::min_element(
@ -251,7 +253,7 @@ void CompositeNodeManager::RemoveCurrNode() {
} else if (IsRecv(*node)) { } else if (IsRecv(*node)) {
recv_manager_.RemoveCurrNode(); recv_manager_.RemoveCurrNode();
} else { } else {
const auto device = node_state_->at(node).device_name; const auto device = node_map_->at(node).device_name;
ops_lifo_map_[device].RemoveCurrNode(); ops_lifo_map_[device].RemoveCurrNode();
} }
// Reset curr_node_ so that GetCurrNode() finds another node. // Reset curr_node_ so that GetCurrNode() finds another node.
@ -300,9 +302,6 @@ VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
} }
Status VirtualScheduler::Init(const GrapplerItem* item) { Status VirtualScheduler::Init(const GrapplerItem* item) {
grappler_item_ = item;
graph_properties_ = absl::make_unique<GraphProperties>(*item);
initialized_ = false; initialized_ = false;
// Clear all internal states so that the VirtualScheduler is reusable for // Clear all internal states so that the VirtualScheduler is reusable for
@ -322,9 +321,10 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
// necessary information for emulating tensorflow op scheduling and // necessary information for emulating tensorflow op scheduling and
// construct internal data structures (NodeState and DeviceState) for virtual // construct internal data structures (NodeState and DeviceState) for virtual
// scheduling. // scheduling.
ready_nodes_->Init(GetNodeStates()); TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
// Construct graph properties. // Constructs graph properties and performs shape inference.
graph_properties_ = absl::make_unique<GraphProperties>(*item);
if (use_static_shapes_) { if (use_static_shapes_) {
TF_RETURN_IF_ERROR(graph_properties_->InferStatically( TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
true, use_aggressive_shape_inference_)); true, use_aggressive_shape_inference_));
@ -332,6 +332,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_)); TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
} }
grappler_item_ = item;
const auto& graph = grappler_item_->graph; const auto& graph = grappler_item_->graph;
const auto& fetch_nodes = grappler_item_->fetch; const auto& fetch_nodes = grappler_item_->fetch;
std::set<string> feed_nodes; std::set<string> feed_nodes;

View File

@ -132,8 +132,10 @@ class ReadyNodeManager {
public: public:
ReadyNodeManager() {} ReadyNodeManager() {}
virtual ~ReadyNodeManager() {} virtual ~ReadyNodeManager() {}
virtual void Init( virtual Status Init(
const std::unordered_map<const NodeDef*, NodeState>* node_state) {} const std::unordered_map<const NodeDef*, NodeState>* node_map) {
return Status::OK();
}
virtual void AddNode(const NodeDef* node) = 0; virtual void AddNode(const NodeDef* node) = 0;
virtual const NodeDef* GetCurrNode() = 0; virtual const NodeDef* GetCurrNode() = 0;
virtual void RemoveCurrNode() = 0; virtual void RemoveCurrNode() = 0;
@ -144,8 +146,6 @@ class FIFOManager : public ReadyNodeManager {
public: public:
FIFOManager() : ReadyNodeManager() {} FIFOManager() : ReadyNodeManager() {}
~FIFOManager() override {} ~FIFOManager() override {}
void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
override {}
void AddNode(const NodeDef* node) override { nodes_.push_back(node); } void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
const NodeDef* GetCurrNode() override { const NodeDef* GetCurrNode() override {
CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
@ -166,8 +166,6 @@ class LIFOManager : public ReadyNodeManager {
public: public:
LIFOManager() : ReadyNodeManager() {} LIFOManager() : ReadyNodeManager() {}
~LIFOManager() override {} ~LIFOManager() override {}
void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
override {}
void AddNode(const NodeDef* node) override { nodes_.push_back(node); } void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
const NodeDef* GetCurrNode() override; const NodeDef* GetCurrNode() override;
void RemoveCurrNode() override; void RemoveCurrNode() override;
@ -188,8 +186,8 @@ class LIFOManager : public ReadyNodeManager {
class FirstReadyManager : public ReadyNodeManager { class FirstReadyManager : public ReadyNodeManager {
public: public:
FirstReadyManager(); FirstReadyManager();
void Init( Status Init(
const std::unordered_map<const NodeDef*, NodeState>* node_state) override; const std::unordered_map<const NodeDef*, NodeState>* node_map) override;
~FirstReadyManager() override {} ~FirstReadyManager() override {}
void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); } void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); }
const NodeDef* GetCurrNode() override; const NodeDef* GetCurrNode() override;
@ -227,8 +225,8 @@ class CompositeNodeManager : public ReadyNodeManager {
CompositeNodeManager(); CompositeNodeManager();
~CompositeNodeManager() override {} ~CompositeNodeManager() override {}
void Init( Status Init(
const std::unordered_map<const NodeDef*, NodeState>* node_state) override; const std::unordered_map<const NodeDef*, NodeState>* node_map) override;
void AddNode(const NodeDef* node) override; void AddNode(const NodeDef* node) override;
const NodeDef* GetCurrNode() override; const NodeDef* GetCurrNode() override;
void RemoveCurrNode() override; void RemoveCurrNode() override;
@ -245,8 +243,8 @@ class CompositeNodeManager : public ReadyNodeManager {
FirstReadyManager recv_manager_; FirstReadyManager recv_manager_;
// NodeState structure from VirtualScheduler to get time_ready of ready nodes. // NodeState structure from VirtualScheduler to get time_ready of ready nodes.
// Not owned by FirstReadyManager. // Not owned by CompositeReadyManager.
const std::unordered_map<const NodeDef*, NodeState>* node_state_; const std::unordered_map<const NodeDef*, NodeState>* node_map_;
// Cached curr node. Set back to nullptr from RemoveCurrNode(). // Cached curr node. Set back to nullptr from RemoveCurrNode().
const NodeDef* curr_node_; const NodeDef* curr_node_;

View File

@ -207,14 +207,14 @@ TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleLIFOManager) {
TEST_F(ReadyNodeManagerTest, GetSingleNodeFirstReadyManager) { TEST_F(ReadyNodeManagerTest, GetSingleNodeFirstReadyManager) {
FirstReadyManager manager; FirstReadyManager manager;
manager.Init(&node_states_); TF_EXPECT_OK(manager.Init(&node_states_));
manager.AddNode(&node1_); manager.AddNode(&node1_);
EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
} }
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) { TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) {
FirstReadyManager manager; FirstReadyManager manager;
manager.Init(&node_states_); TF_EXPECT_OK(manager.Init(&node_states_));
manager.AddNode(&node1_); manager.AddNode(&node1_);
manager.RemoveCurrNode(); manager.RemoveCurrNode();
EXPECT_TRUE(manager.Empty()); EXPECT_TRUE(manager.Empty());
@ -222,7 +222,7 @@ TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) {
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) { TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) {
FirstReadyManager manager; FirstReadyManager manager;
manager.Init(&node_states_); TF_EXPECT_OK(manager.Init(&node_states_));
// Insert nodes in some random order. // Insert nodes in some random order.
manager.AddNode(&node2_); manager.AddNode(&node2_);
manager.AddNode(&node1_); manager.AddNode(&node1_);
@ -250,7 +250,7 @@ TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) {
TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) { TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) {
FirstReadyManager manager; FirstReadyManager manager;
manager.Init(&node_states_); TF_EXPECT_OK(manager.Init(&node_states_));
// Inserts nodes in some random order. // Inserts nodes in some random order.
manager.AddNode(&node2_); manager.AddNode(&node2_);
@ -308,9 +308,9 @@ TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) {
TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) { TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) {
FirstReadyManager manager1; FirstReadyManager manager1;
manager1.Init(&node_states_); TF_EXPECT_OK(manager1.Init(&node_states_));
FirstReadyManager manager2; FirstReadyManager manager2;
manager2.Init(&node_states_); TF_EXPECT_OK(manager2.Init(&node_states_));
// 6 nodes with same time_ready. // 6 nodes with same time_ready.
NodeDef node7; NodeDef node7;
@ -374,7 +374,7 @@ TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) {
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) { TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) {
CompositeNodeManager manager; CompositeNodeManager manager;
manager.Init(&node_states_); TF_EXPECT_OK(manager.Init(&node_states_));
manager.AddNode(&node1_); manager.AddNode(&node1_);
manager.RemoveCurrNode(); manager.RemoveCurrNode();
EXPECT_TRUE(manager.Empty()); EXPECT_TRUE(manager.Empty());
@ -382,7 +382,7 @@ TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) {
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleComopsiteNodeManager) { TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleComopsiteNodeManager) {
CompositeNodeManager manager; CompositeNodeManager manager;
manager.Init(&node_states_); TF_EXPECT_OK(manager.Init(&node_states_));
manager.AddNode(&node1_); manager.AddNode(&node1_);
manager.AddNode(&node2_); manager.AddNode(&node2_);
manager.AddNode(&node3_); manager.AddNode(&node3_);
@ -412,7 +412,7 @@ TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleComopsiteNodeManager) {
TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvComopsiteNodeManager) { TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvComopsiteNodeManager) {
CompositeNodeManager manager; CompositeNodeManager manager;
manager.Init(&node_states_); TF_EXPECT_OK(manager.Init(&node_states_));
// Additional nodes on kCPU1. // Additional nodes on kCPU1.
NodeDef node7; NodeDef node7;
NodeDef node8; NodeDef node8;
@ -491,9 +491,9 @@ TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvComopsiteNodeManager) {
TEST_F(ReadyNodeManagerTest, DeterminismInCompositeNodeManager) { TEST_F(ReadyNodeManagerTest, DeterminismInCompositeNodeManager) {
CompositeNodeManager manager; CompositeNodeManager manager;
manager.Init(&node_states_); TF_EXPECT_OK(manager.Init(&node_states_));
CompositeNodeManager manager2; CompositeNodeManager manager2;
manager2.Init(&node_states_); TF_EXPECT_OK(manager2.Init(&node_states_));
// 6 nodes with same time_ready. // 6 nodes with same time_ready.
NodeDef node7; NodeDef node7;