- Changed to return Status for Init() function of ReadyNodeManager.

- Changed node_state to node_map for CompositeNodeManager to be consistent with VirtualScheduler.

PiperOrigin-RevId: 246942431
This commit is contained in:
Peter Ma 2019-05-06 19:15:04 -07:00 committed by TensorFlower Gardener
parent cc91dc6aaa
commit 2f9cc84ba3
3 changed files with 38 additions and 39 deletions

View File

@ -116,7 +116,7 @@ FirstReadyManager::FirstReadyManager() : ReadyNodeManager() {
std::make_heap(nodes_.begin(), nodes_.end());
}
void FirstReadyManager::Init(
Status FirstReadyManager::Init(
const std::unordered_map<const NodeDef*, NodeState>* node_map) {
// Reset the node state since different instances of the scheduler can reuse
// the same node_manager.
@ -133,6 +133,7 @@ void FirstReadyManager::Init(
return node_map_->at(a).time_ready > node_map_->at(b).time_ready;
}
};
return Status::OK();
}
const NodeDef* FirstReadyManager::GetCurrNode() {
@ -172,12 +173,13 @@ void FirstReadyManager::DrainWaitingQueue() {
CompositeNodeManager::CompositeNodeManager()
: ReadyNodeManager(), send_manager_(), recv_manager_() {}
void CompositeNodeManager::Init(
const std::unordered_map<const NodeDef*, NodeState>* node_state) {
node_state_ = node_state;
send_manager_.Init(node_state);
recv_manager_.Init(node_state);
Status CompositeNodeManager::Init(
const std::unordered_map<const NodeDef*, NodeState>* node_map) {
node_map_ = node_map;
TF_RETURN_IF_ERROR(send_manager_.Init(node_map));
TF_RETURN_IF_ERROR(recv_manager_.Init(node_map));
curr_node_ = nullptr;
return Status::OK();
}
void CompositeNodeManager::AddNode(const NodeDef* node) {
@ -186,7 +188,7 @@ void CompositeNodeManager::AddNode(const NodeDef* node) {
} else if (IsRecv(*node)) {
recv_manager_.AddNode(node);
} else {
const auto& device = node_state_->at(node).device_name;
const auto& device = node_map_->at(node).device_name;
ops_lifo_map_[device].AddNode(node);
}
}
@ -203,16 +205,16 @@ const NodeDef* CompositeNodeManager::GetCurrNode() {
for (auto& ops_lifo : ops_lifo_map_) {
if (!ops_lifo.second.Empty()) {
const auto* op = ops_lifo.second.GetCurrNode();
candidates.emplace_back(op, node_state_->at(op).time_ready);
candidates.emplace_back(op, node_map_->at(op).time_ready);
}
}
if (!send_manager_.Empty()) {
const auto* send = send_manager_.GetCurrNode();
candidates.emplace_back(send, node_state_->at(send).time_ready);
candidates.emplace_back(send, node_map_->at(send).time_ready);
}
if (!recv_manager_.Empty()) {
const auto* recv = recv_manager_.GetCurrNode();
candidates.emplace_back(recv, node_state_->at(recv).time_ready);
candidates.emplace_back(recv, node_map_->at(recv).time_ready);
}
CHECK(!candidates.empty());
auto first_ready = std::min_element(
@ -251,7 +253,7 @@ void CompositeNodeManager::RemoveCurrNode() {
} else if (IsRecv(*node)) {
recv_manager_.RemoveCurrNode();
} else {
const auto device = node_state_->at(node).device_name;
const auto device = node_map_->at(node).device_name;
ops_lifo_map_[device].RemoveCurrNode();
}
// Reset curr_node_ so that GetCurrNode() finds another node.
@ -300,9 +302,6 @@ VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
}
Status VirtualScheduler::Init(const GrapplerItem* item) {
grappler_item_ = item;
graph_properties_ = absl::make_unique<GraphProperties>(*item);
initialized_ = false;
// Clear all internal states so that the VirtualScheduler is reusable for
@ -322,9 +321,10 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
// necessary information for emulating tensorflow op scheduling and
// construct internal data structures (NodeState and DeviceState) for virtual
// scheduling.
ready_nodes_->Init(GetNodeStates());
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
// Construct graph properties.
// Constructs graph properties and performs shape inference.
graph_properties_ = absl::make_unique<GraphProperties>(*item);
if (use_static_shapes_) {
TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
true, use_aggressive_shape_inference_));
@ -332,6 +332,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
}
grappler_item_ = item;
const auto& graph = grappler_item_->graph;
const auto& fetch_nodes = grappler_item_->fetch;
std::set<string> feed_nodes;

View File

@ -132,8 +132,10 @@ class ReadyNodeManager {
public:
ReadyNodeManager() {}
virtual ~ReadyNodeManager() {}
virtual void Init(
const std::unordered_map<const NodeDef*, NodeState>* node_state) {}
virtual Status Init(
const std::unordered_map<const NodeDef*, NodeState>* node_map) {
return Status::OK();
}
virtual void AddNode(const NodeDef* node) = 0;
virtual const NodeDef* GetCurrNode() = 0;
virtual void RemoveCurrNode() = 0;
@ -144,8 +146,6 @@ class FIFOManager : public ReadyNodeManager {
public:
FIFOManager() : ReadyNodeManager() {}
~FIFOManager() override {}
void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
override {}
void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
const NodeDef* GetCurrNode() override {
CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
@ -166,8 +166,6 @@ class LIFOManager : public ReadyNodeManager {
public:
LIFOManager() : ReadyNodeManager() {}
~LIFOManager() override {}
void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
override {}
void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
const NodeDef* GetCurrNode() override;
void RemoveCurrNode() override;
@ -188,8 +186,8 @@ class LIFOManager : public ReadyNodeManager {
class FirstReadyManager : public ReadyNodeManager {
public:
FirstReadyManager();
void Init(
const std::unordered_map<const NodeDef*, NodeState>* node_state) override;
Status Init(
const std::unordered_map<const NodeDef*, NodeState>* node_map) override;
~FirstReadyManager() override {}
void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); }
const NodeDef* GetCurrNode() override;
@ -227,8 +225,8 @@ class CompositeNodeManager : public ReadyNodeManager {
CompositeNodeManager();
~CompositeNodeManager() override {}
void Init(
const std::unordered_map<const NodeDef*, NodeState>* node_state) override;
Status Init(
const std::unordered_map<const NodeDef*, NodeState>* node_map) override;
void AddNode(const NodeDef* node) override;
const NodeDef* GetCurrNode() override;
void RemoveCurrNode() override;
@ -245,8 +243,8 @@ class CompositeNodeManager : public ReadyNodeManager {
FirstReadyManager recv_manager_;
// NodeState structure from VirtualScheduler to get time_ready of ready nodes.
// Not owned by FirstReadyManager.
const std::unordered_map<const NodeDef*, NodeState>* node_state_;
// Not owned by CompositeReadyManager.
const std::unordered_map<const NodeDef*, NodeState>* node_map_;
// Cached curr node. Set back to nullptr from RemoveCurrNode().
const NodeDef* curr_node_;

View File

@ -207,14 +207,14 @@ TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleLIFOManager) {
TEST_F(ReadyNodeManagerTest, GetSingleNodeFirstReadyManager) {
FirstReadyManager manager;
manager.Init(&node_states_);
TF_EXPECT_OK(manager.Init(&node_states_));
manager.AddNode(&node1_);
EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
}
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) {
FirstReadyManager manager;
manager.Init(&node_states_);
TF_EXPECT_OK(manager.Init(&node_states_));
manager.AddNode(&node1_);
manager.RemoveCurrNode();
EXPECT_TRUE(manager.Empty());
@ -222,7 +222,7 @@ TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) {
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) {
FirstReadyManager manager;
manager.Init(&node_states_);
TF_EXPECT_OK(manager.Init(&node_states_));
// Insert nodes in some random order.
manager.AddNode(&node2_);
manager.AddNode(&node1_);
@ -250,7 +250,7 @@ TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) {
TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) {
FirstReadyManager manager;
manager.Init(&node_states_);
TF_EXPECT_OK(manager.Init(&node_states_));
// Inserts nodes in some random order.
manager.AddNode(&node2_);
@ -308,9 +308,9 @@ TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) {
TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) {
FirstReadyManager manager1;
manager1.Init(&node_states_);
TF_EXPECT_OK(manager1.Init(&node_states_));
FirstReadyManager manager2;
manager2.Init(&node_states_);
TF_EXPECT_OK(manager2.Init(&node_states_));
// 6 nodes with same time_ready.
NodeDef node7;
@ -374,7 +374,7 @@ TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) {
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) {
CompositeNodeManager manager;
manager.Init(&node_states_);
TF_EXPECT_OK(manager.Init(&node_states_));
manager.AddNode(&node1_);
manager.RemoveCurrNode();
EXPECT_TRUE(manager.Empty());
@ -382,7 +382,7 @@ TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) {
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleComopsiteNodeManager) {
CompositeNodeManager manager;
manager.Init(&node_states_);
TF_EXPECT_OK(manager.Init(&node_states_));
manager.AddNode(&node1_);
manager.AddNode(&node2_);
manager.AddNode(&node3_);
@ -412,7 +412,7 @@ TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleComopsiteNodeManager) {
TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvComopsiteNodeManager) {
CompositeNodeManager manager;
manager.Init(&node_states_);
TF_EXPECT_OK(manager.Init(&node_states_));
// Additional nodes on kCPU1.
NodeDef node7;
NodeDef node8;
@ -491,9 +491,9 @@ TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvComopsiteNodeManager) {
TEST_F(ReadyNodeManagerTest, DeterminismInCompositeNodeManager) {
CompositeNodeManager manager;
manager.Init(&node_states_);
TF_EXPECT_OK(manager.Init(&node_states_));
CompositeNodeManager manager2;
manager2.Init(&node_states_);
TF_EXPECT_OK(manager2.Init(&node_states_));
// 6 nodes with same time_ready.
NodeDef node7;