- Changed the Init() function of ReadyNodeManager to return Status.
- Renamed node_state to node_map in CompositeNodeManager to be consistent with VirtualScheduler.
PiperOrigin-RevId: 246942431
This commit is contained in:
parent
cc91dc6aaa
commit
2f9cc84ba3
@ -116,7 +116,7 @@ FirstReadyManager::FirstReadyManager() : ReadyNodeManager() {
|
||||
std::make_heap(nodes_.begin(), nodes_.end());
|
||||
}
|
||||
|
||||
void FirstReadyManager::Init(
|
||||
Status FirstReadyManager::Init(
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_map) {
|
||||
// Reset the node state since different instances of the scheduler can reuse
|
||||
// the same node_manager.
|
||||
@ -133,6 +133,7 @@ void FirstReadyManager::Init(
|
||||
return node_map_->at(a).time_ready > node_map_->at(b).time_ready;
|
||||
}
|
||||
};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const NodeDef* FirstReadyManager::GetCurrNode() {
|
||||
@ -172,12 +173,13 @@ void FirstReadyManager::DrainWaitingQueue() {
|
||||
CompositeNodeManager::CompositeNodeManager()
|
||||
: ReadyNodeManager(), send_manager_(), recv_manager_() {}
|
||||
|
||||
void CompositeNodeManager::Init(
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_state) {
|
||||
node_state_ = node_state;
|
||||
send_manager_.Init(node_state);
|
||||
recv_manager_.Init(node_state);
|
||||
Status CompositeNodeManager::Init(
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_map) {
|
||||
node_map_ = node_map;
|
||||
TF_RETURN_IF_ERROR(send_manager_.Init(node_map));
|
||||
TF_RETURN_IF_ERROR(recv_manager_.Init(node_map));
|
||||
curr_node_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CompositeNodeManager::AddNode(const NodeDef* node) {
|
||||
@ -186,7 +188,7 @@ void CompositeNodeManager::AddNode(const NodeDef* node) {
|
||||
} else if (IsRecv(*node)) {
|
||||
recv_manager_.AddNode(node);
|
||||
} else {
|
||||
const auto& device = node_state_->at(node).device_name;
|
||||
const auto& device = node_map_->at(node).device_name;
|
||||
ops_lifo_map_[device].AddNode(node);
|
||||
}
|
||||
}
|
||||
@ -203,16 +205,16 @@ const NodeDef* CompositeNodeManager::GetCurrNode() {
|
||||
for (auto& ops_lifo : ops_lifo_map_) {
|
||||
if (!ops_lifo.second.Empty()) {
|
||||
const auto* op = ops_lifo.second.GetCurrNode();
|
||||
candidates.emplace_back(op, node_state_->at(op).time_ready);
|
||||
candidates.emplace_back(op, node_map_->at(op).time_ready);
|
||||
}
|
||||
}
|
||||
if (!send_manager_.Empty()) {
|
||||
const auto* send = send_manager_.GetCurrNode();
|
||||
candidates.emplace_back(send, node_state_->at(send).time_ready);
|
||||
candidates.emplace_back(send, node_map_->at(send).time_ready);
|
||||
}
|
||||
if (!recv_manager_.Empty()) {
|
||||
const auto* recv = recv_manager_.GetCurrNode();
|
||||
candidates.emplace_back(recv, node_state_->at(recv).time_ready);
|
||||
candidates.emplace_back(recv, node_map_->at(recv).time_ready);
|
||||
}
|
||||
CHECK(!candidates.empty());
|
||||
auto first_ready = std::min_element(
|
||||
@ -251,7 +253,7 @@ void CompositeNodeManager::RemoveCurrNode() {
|
||||
} else if (IsRecv(*node)) {
|
||||
recv_manager_.RemoveCurrNode();
|
||||
} else {
|
||||
const auto device = node_state_->at(node).device_name;
|
||||
const auto device = node_map_->at(node).device_name;
|
||||
ops_lifo_map_[device].RemoveCurrNode();
|
||||
}
|
||||
// Reset curr_node_ so that GetCurrNode() finds another node.
|
||||
@ -300,9 +302,6 @@ VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
||||
}
|
||||
|
||||
Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||
grappler_item_ = item;
|
||||
graph_properties_ = absl::make_unique<GraphProperties>(*item);
|
||||
|
||||
initialized_ = false;
|
||||
|
||||
// Clear all internal states so that the VirtualScheduler is reusable for
|
||||
@ -322,9 +321,10 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||
// necessary information for emulating tensorflow op scheduling and
|
||||
// construct internal data structures (NodeState and DeviceState) for virtual
|
||||
// scheduling.
|
||||
ready_nodes_->Init(GetNodeStates());
|
||||
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
|
||||
|
||||
// Construct graph properties.
|
||||
// Constructs graph properties and performs shape inference.
|
||||
graph_properties_ = absl::make_unique<GraphProperties>(*item);
|
||||
if (use_static_shapes_) {
|
||||
TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
|
||||
true, use_aggressive_shape_inference_));
|
||||
@ -332,6 +332,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||
TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
|
||||
}
|
||||
|
||||
grappler_item_ = item;
|
||||
const auto& graph = grappler_item_->graph;
|
||||
const auto& fetch_nodes = grappler_item_->fetch;
|
||||
std::set<string> feed_nodes;
|
||||
|
@ -132,8 +132,10 @@ class ReadyNodeManager {
|
||||
public:
|
||||
ReadyNodeManager() {}
|
||||
virtual ~ReadyNodeManager() {}
|
||||
virtual void Init(
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_state) {}
|
||||
virtual Status Init(
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_map) {
|
||||
return Status::OK();
|
||||
}
|
||||
virtual void AddNode(const NodeDef* node) = 0;
|
||||
virtual const NodeDef* GetCurrNode() = 0;
|
||||
virtual void RemoveCurrNode() = 0;
|
||||
@ -144,8 +146,6 @@ class FIFOManager : public ReadyNodeManager {
|
||||
public:
|
||||
FIFOManager() : ReadyNodeManager() {}
|
||||
~FIFOManager() override {}
|
||||
void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
|
||||
override {}
|
||||
void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
|
||||
const NodeDef* GetCurrNode() override {
|
||||
CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
|
||||
@ -166,8 +166,6 @@ class LIFOManager : public ReadyNodeManager {
|
||||
public:
|
||||
LIFOManager() : ReadyNodeManager() {}
|
||||
~LIFOManager() override {}
|
||||
void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
|
||||
override {}
|
||||
void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
|
||||
const NodeDef* GetCurrNode() override;
|
||||
void RemoveCurrNode() override;
|
||||
@ -188,8 +186,8 @@ class LIFOManager : public ReadyNodeManager {
|
||||
class FirstReadyManager : public ReadyNodeManager {
|
||||
public:
|
||||
FirstReadyManager();
|
||||
void Init(
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_state) override;
|
||||
Status Init(
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_map) override;
|
||||
~FirstReadyManager() override {}
|
||||
void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); }
|
||||
const NodeDef* GetCurrNode() override;
|
||||
@ -227,8 +225,8 @@ class CompositeNodeManager : public ReadyNodeManager {
|
||||
CompositeNodeManager();
|
||||
~CompositeNodeManager() override {}
|
||||
|
||||
void Init(
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_state) override;
|
||||
Status Init(
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_map) override;
|
||||
void AddNode(const NodeDef* node) override;
|
||||
const NodeDef* GetCurrNode() override;
|
||||
void RemoveCurrNode() override;
|
||||
@ -245,8 +243,8 @@ class CompositeNodeManager : public ReadyNodeManager {
|
||||
FirstReadyManager recv_manager_;
|
||||
|
||||
// NodeState structure from VirtualScheduler to get time_ready of ready nodes.
|
||||
// Not owned by FirstReadyManager.
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_state_;
|
||||
// Not owned by CompositeNodeManager.
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_map_;
|
||||
|
||||
// Cached curr node. Set back to nullptr from RemoveCurrNode().
|
||||
const NodeDef* curr_node_;
|
||||
|
@ -207,14 +207,14 @@ TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleLIFOManager) {
|
||||
|
||||
TEST_F(ReadyNodeManagerTest, GetSingleNodeFirstReadyManager) {
|
||||
FirstReadyManager manager;
|
||||
manager.Init(&node_states_);
|
||||
TF_EXPECT_OK(manager.Init(&node_states_));
|
||||
manager.AddNode(&node1_);
|
||||
EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
|
||||
}
|
||||
|
||||
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) {
|
||||
FirstReadyManager manager;
|
||||
manager.Init(&node_states_);
|
||||
TF_EXPECT_OK(manager.Init(&node_states_));
|
||||
manager.AddNode(&node1_);
|
||||
manager.RemoveCurrNode();
|
||||
EXPECT_TRUE(manager.Empty());
|
||||
@ -222,7 +222,7 @@ TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) {
|
||||
|
||||
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) {
|
||||
FirstReadyManager manager;
|
||||
manager.Init(&node_states_);
|
||||
TF_EXPECT_OK(manager.Init(&node_states_));
|
||||
// Insert nodes in some random order.
|
||||
manager.AddNode(&node2_);
|
||||
manager.AddNode(&node1_);
|
||||
@ -250,7 +250,7 @@ TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) {
|
||||
|
||||
TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) {
|
||||
FirstReadyManager manager;
|
||||
manager.Init(&node_states_);
|
||||
TF_EXPECT_OK(manager.Init(&node_states_));
|
||||
|
||||
// Inserts nodes in some random order.
|
||||
manager.AddNode(&node2_);
|
||||
@ -308,9 +308,9 @@ TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) {
|
||||
|
||||
TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) {
|
||||
FirstReadyManager manager1;
|
||||
manager1.Init(&node_states_);
|
||||
TF_EXPECT_OK(manager1.Init(&node_states_));
|
||||
FirstReadyManager manager2;
|
||||
manager2.Init(&node_states_);
|
||||
TF_EXPECT_OK(manager2.Init(&node_states_));
|
||||
|
||||
// 6 nodes with same time_ready.
|
||||
NodeDef node7;
|
||||
@ -374,7 +374,7 @@ TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) {
|
||||
|
||||
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) {
|
||||
CompositeNodeManager manager;
|
||||
manager.Init(&node_states_);
|
||||
TF_EXPECT_OK(manager.Init(&node_states_));
|
||||
manager.AddNode(&node1_);
|
||||
manager.RemoveCurrNode();
|
||||
EXPECT_TRUE(manager.Empty());
|
||||
@ -382,7 +382,7 @@ TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) {
|
||||
|
||||
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleComopsiteNodeManager) {
|
||||
CompositeNodeManager manager;
|
||||
manager.Init(&node_states_);
|
||||
TF_EXPECT_OK(manager.Init(&node_states_));
|
||||
manager.AddNode(&node1_);
|
||||
manager.AddNode(&node2_);
|
||||
manager.AddNode(&node3_);
|
||||
@ -412,7 +412,7 @@ TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleComopsiteNodeManager) {
|
||||
|
||||
TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvComopsiteNodeManager) {
|
||||
CompositeNodeManager manager;
|
||||
manager.Init(&node_states_);
|
||||
TF_EXPECT_OK(manager.Init(&node_states_));
|
||||
// Additional nodes on kCPU1.
|
||||
NodeDef node7;
|
||||
NodeDef node8;
|
||||
@ -491,9 +491,9 @@ TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvComopsiteNodeManager) {
|
||||
|
||||
TEST_F(ReadyNodeManagerTest, DeterminismInCompositeNodeManager) {
|
||||
CompositeNodeManager manager;
|
||||
manager.Init(&node_states_);
|
||||
TF_EXPECT_OK(manager.Init(&node_states_));
|
||||
CompositeNodeManager manager2;
|
||||
manager2.Init(&node_states_);
|
||||
TF_EXPECT_OK(manager2.Init(&node_states_));
|
||||
|
||||
// 6 nodes with same time_ready.
|
||||
NodeDef node7;
|
||||
|
Loading…
Reference in New Issue
Block a user