From fcacb40d4c5e2874f176b27ca75e7a1ce31fd87c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 12 Sep 2017 12:18:54 -0700 Subject: [PATCH] FirstReadyManager for scheduling nodes in VirtualScheduler. The current FIFOManager may yield inefficient scheduling; _Recv pushed to the FIFO blocks other nodes that can run before _Recv due to the node order in FIFO. FirstReadyManager picks a node with the earliest time_ready in the queue, avoiding this problem. Also, fixed VirtualPlacer to properly set device when Node's device name does not include job name and to set GPU:0 as default device. PiperOrigin-RevId: 168418455 --- tensorflow/core/grappler/costs/BUILD | 1 + .../core/grappler/costs/virtual_placer.cc | 59 +++++- .../core/grappler/costs/virtual_placer.h | 1 + .../grappler/costs/virtual_placer_test.cc | 139 +++++++++++- .../core/grappler/costs/virtual_scheduler.cc | 19 +- .../core/grappler/costs/virtual_scheduler.h | 84 +++++++- .../grappler/costs/virtual_scheduler_test.cc | 197 ++++++++++++++++-- 7 files changed, 460 insertions(+), 40 deletions(-) diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 0940cca1aac..3a029e2d940 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -190,6 +190,7 @@ cc_test( deps = [ ":virtual_placer", "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/core/grappler/costs/virtual_placer.cc b/tensorflow/core/grappler/costs/virtual_placer.cc index 64e2778fc9d..24c45235ff2 100644 --- a/tensorflow/core/grappler/costs/virtual_placer.cc +++ b/tensorflow/core/grappler/costs/virtual_placer.cc @@ -34,15 +34,62 @@ VirtualPlacer::VirtualPlacer(const Cluster* cluster) { DeviceProperties& prop = devices_["UNKNOWN"]; prop.set_type("UNKNOWN"); - } else { + } else if (devices_.size() == 1) { + // If there is only one device in the 
cluster, use it as default device, + // whatever it is. default_device_ = devices_.begin()->first; - VLOG(1) << "Number of devices: " << devices_.size(); + } else { + // Default device is set from the devices in the cluster in the following + // priority: /gpu:0, /cpu:0, or any device. + // TODO(dyoon): This logic assumes single machine with CPU and GPU devices. + // Make it more general to support multiple machines, job types, and devices + // other than CPU and GPU. + std::map cpu_devices; // CPU device map: id -> device name. + std::map gpu_devices; // GPU device map: id -> device name. for (const auto& device : devices_) { - if (str_util::Lowercase(device.first).find("gpu") != string::npos) { - default_device_ = device.first; - break; + DeviceNameUtils::ParsedName parsed_name; + bool parsed = DeviceNameUtils::ParseFullName(device.first, &parsed_name); + if (parsed) { + // Parsed devices are stored to cpu_devices or gpu_devices map, + // addressed (and ordered) by device id. + if (str_util::Lowercase(parsed_name.type) == "gpu") { + gpu_devices[parsed_name.id] = device.first; + } else if (str_util::Lowercase(parsed_name.type) == "cpu") { + cpu_devices[parsed_name.id] = device.first; + } } } + if (!gpu_devices.empty()) { + // GPU:0 (or GPU with smallest device id). + default_device_ = gpu_devices.begin()->second; + } else if (!cpu_devices.empty()) { + // CPU:0 (or CPU with smallest device id). + default_device_ = cpu_devices.begin()->second; + } else { + default_device_ = devices_.begin()->first; // Any device. + } + } + + // Default job name for canonical device name. + default_job_name_ = "localhost"; + // Scan the device names from the cluster, and if there is one job name used, + // use it for canonical device name.
+ std::unordered_set job_names_from_cluster; + for (const auto& device : devices_) { + const auto& device_name = device.first; + DeviceNameUtils::ParsedName parsed_name; + bool parsed = DeviceNameUtils::ParseFullName(device_name, &parsed_name); + if (parsed && !parsed_name.job.empty()) { + job_names_from_cluster.insert(parsed_name.job); + } + } + // If there is only one type of job name in all the devices in the cluster, + // use that one as default job name; otherwise, use localhost. + // TODO(dyoon): this should be improved, especially when the cluster is + // composed of multiple worker, PS, and other types of jobs. + if (job_names_from_cluster.size() == 1) { + auto it = job_names_from_cluster.begin(); + default_job_name_ = *it; } } @@ -78,7 +125,7 @@ string VirtualPlacer::get_canonical_device_name(const NodeDef& node) const { return get_default_device_name(); } else { if (parsed_name.job.empty()) { - parsed_name.job = "localhost"; + parsed_name.job = default_job_name_; } device = strings::StrCat( "/job:", parsed_name.job, "/replica:", parsed_name.replica, diff --git a/tensorflow/core/grappler/costs/virtual_placer.h b/tensorflow/core/grappler/costs/virtual_placer.h index 6d814c95273..75ee496329d 100644 --- a/tensorflow/core/grappler/costs/virtual_placer.h +++ b/tensorflow/core/grappler/costs/virtual_placer.h @@ -41,6 +41,7 @@ class VirtualPlacer { private: std::unordered_map devices_; string default_device_; + string default_job_name_; const string& get_default_device_name() const; }; diff --git a/tensorflow/core/grappler/costs/virtual_placer_test.cc b/tensorflow/core/grappler/costs/virtual_placer_test.cc index a16455cb703..3a0510c44ae 100644 --- a/tensorflow/core/grappler/costs/virtual_placer_test.cc +++ b/tensorflow/core/grappler/costs/virtual_placer_test.cc @@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/virtual_placer.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/device_properties.pb.h" @@ -36,6 +37,7 @@ TEST(VirtualPlacerTest, LocalDevices) { NodeDef node; node.set_op("Conv2D"); + // node.device() is empty, but GPU is the default device if there is one. EXPECT_EQ("GPU", placer.get_device(node).type()); EXPECT_EQ("/job:localhost/replica:0/task:0/device:GPU:0", placer.get_canonical_device_name(node)); @@ -51,16 +53,42 @@ TEST(VirtualPlacerTest, LocalDevices) { placer.get_canonical_device_name(node)); } -TEST(VirtualPlacerTest, EmptyJobBecomesLocalhost) { - // Virtual placer should use "localhost" if device is empty. - // First create a cluster with only localhost devices. +TEST(VirtualPlacerTest, EmptyJobName) { + // Virtual placer chooses job name from the devices in cluster if a device + // name of an op is empty. In case there is more than one kind of job name + // or job names are missing in the devices in cluster, we use localhost.
+ for (const string& job_name : {"localhost", "worker", "worker_train"}) { + std::unordered_map devices; + DeviceProperties cpu_device; + cpu_device.set_type("CPU"); + devices[strings::StrCat("/job:", job_name, "/replica:0/task:0/cpu:0")] = + cpu_device; + DeviceProperties gpu_device; + gpu_device.set_type("GPU"); + devices[strings::StrCat("/job:", job_name, + "/replica:0/task:0/device:GPU:0")] = gpu_device; + VirtualCluster cluster(devices); + VirtualPlacer placer(&cluster); + + NodeDef node; + node.set_op("Conv2D"); + node.set_device("/device:CPU:0"); + EXPECT_EQ(strings::StrCat("/job:", job_name, "/replica:0/task:0/cpu:0"), + placer.get_canonical_device_name(node)); + node.set_device("/device:GPU:0"); + EXPECT_EQ( + strings::StrCat("/job:", job_name, "/replica:0/task:0/device:GPU:0"), + placer.get_canonical_device_name(node)); + } + + // When more than one job names are used, we use default "localhost" + // This may be improved later. std::unordered_map devices; DeviceProperties cpu_device; cpu_device.set_type("CPU"); devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device; - DeviceProperties gpu_device; - gpu_device.set_type("GPU"); - devices["/job:localhost/replica:0/task:0/device:GPU:0"] = gpu_device; + devices["/job:ps/replica:0/task:0/cpu:0"] = cpu_device; + devices["/job:worker/replica:0/task:0/cpu:0"] = cpu_device; VirtualCluster cluster(devices); VirtualPlacer placer(&cluster); @@ -69,9 +97,102 @@ TEST(VirtualPlacerTest, EmptyJobBecomesLocalhost) { node.set_device("/device:CPU:0"); EXPECT_EQ("/job:localhost/replica:0/task:0/cpu:0", placer.get_canonical_device_name(node)); - node.set_device("/device:GPU:0"); - EXPECT_EQ("/job:localhost/replica:0/task:0/device:GPU:0", - placer.get_canonical_device_name(node)); +} + +string GetDefaultDeviceName( + const std::unordered_map& devices) { + VirtualCluster cluster(devices); + VirtualPlacer placer(&cluster); + NodeDef node; + node.set_op("Conv2D"); + // Device is not set to the node, so 
get_canonical_device_name() will return + // the default_device_. + return placer.get_canonical_device_name(node); +} + +TEST(VirtualPlacerTest, DefaultDevice) { + std::unordered_map devices; + DeviceProperties cpu_device; + cpu_device.set_type("CPU"); + devices["/job:worker/replica:0/task:0/cpu:0"] = cpu_device; + + // CPU is default when there is only CPU. + EXPECT_EQ("/job:worker/replica:0/task:0/cpu:0", + GetDefaultDeviceName(devices)); + + DeviceProperties gpu_device; + gpu_device.set_type("GPU"); + + // If there is any GPU, then gpu:0 is default device. + for (int i = 0; i < 8; i++) { + devices[strings::StrCat("/job:worker/replica:0/task:0/gpu:", i)] = + gpu_device; + EXPECT_EQ("/job:worker/replica:0/task:0/gpu:0", + GetDefaultDeviceName(devices)); + } +} + +TEST(VirtualPlacerTest, MultiReplica) { + // Create a cluster with 8 workers, each with 8 GPUs. + std::unordered_map devices; + DeviceProperties cpu_device; + cpu_device.set_type("CPU"); + DeviceProperties gpu_device; + gpu_device.set_type("GPU"); + for (int i = 0; i < 8; i++) { + devices[strings::StrCat("/job:worker/replica:", i, "/task:0/cpu:0")] = + cpu_device; + for (int j = 0; j < 8; j++) { + devices[strings::StrCat("/job:worker/replica:", i, "/task:0/gpu:", j)] = + gpu_device; + } + } + + std::unique_ptr cluster(new VirtualCluster(devices)); + std::unique_ptr placer(new VirtualPlacer(cluster.get())); + + auto get_device_name = [&placer](const string& device) -> string { + NodeDef node; + node.set_op("Conv2D"); + node.set_device(device); + return placer->get_canonical_device_name(node); + }; + + // Validate device name is correct when we pass only replica ID and device + // name. 
+ EXPECT_EQ("/job:worker/replica:0/task:0/cpu:0", + get_device_name("/replica:0/cpu:0")); + EXPECT_EQ("/job:worker/replica:2/task:0/cpu:0", + get_device_name("/replica:2/cpu:0")); + EXPECT_EQ("/job:worker/replica:7/task:0/cpu:0", + get_device_name("/replica:7/cpu:0")); + EXPECT_EQ("/job:worker/replica:3/task:0/gpu:0", + get_device_name("/replica:3/gpu:0")); + EXPECT_EQ("/job:worker/replica:5/task:0/gpu:3", + get_device_name("/replica:5/gpu:3")); + EXPECT_EQ("/job:worker/replica:4/task:0/gpu:7", + get_device_name("/replica:4/gpu:7")); + + // Now add PS replicas; with multiple job names present in the cluster, + // device names in nodes should specify job names correctly. + for (int i = 0; i < 4; i++) { + devices[strings::StrCat("/job:ps/replica:", i, "/task:0/cpu:0")] = + cpu_device; + } + cluster.reset(new VirtualCluster(devices)); + placer.reset(new VirtualPlacer(cluster.get())); + EXPECT_EQ("/job:worker/replica:0/task:0/cpu:0", + get_device_name("/job:worker/replica:0/cpu:0")); + EXPECT_EQ("/job:worker/replica:7/task:0/gpu:3", + get_device_name("/job:worker/replica:7/gpu:3")); + EXPECT_EQ("/job:ps/replica:0/task:0/cpu:0", + get_device_name("/job:ps/replica:0/cpu:0")); + EXPECT_EQ("/job:ps/replica:1/task:0/cpu:0", + get_device_name("/job:ps/replica:1/cpu:0")); + EXPECT_EQ("/job:ps/replica:2/task:0/cpu:0", + get_device_name("/job:ps/replica:2/cpu:0")); + EXPECT_EQ("/job:ps/replica:3/task:0/cpu:0", + get_device_name("/job:ps/replica:3/cpu:0")); } TEST(VirtualPlacerTest, FallBackUnknown) { diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index ea5fead4b9c..16c434b0ad1 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -88,10 +88,7 @@ struct RecvNodeDescriptorEqual { VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item, const bool use_static_shapes, Cluster* cluster) - : // Allow LIFO as well as FIFO. 
LIFO allows an output node of an node to - // follow it in execution, saving addition memory time from having to - // write and read. For default cases, use FIFO for performance. - ready_nodes_(new FIFOManager()), + : ready_nodes_(ReadyNodeManagerFactory("FirstReady")), graph_costs_(Costs::ZeroCosts()), graph_properties_(*grappler_item), cluster_(cluster), @@ -101,6 +98,18 @@ VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item, initialized_ = false; } +ReadyNodeManager* VirtualScheduler::ReadyNodeManagerFactory( + const string& ready_node_manager) { + if (ready_node_manager == "FIFO") { + return new FIFOManager(); + } else if (ready_node_manager == "LIFO") { + return new LIFOManager(); + } else if (ready_node_manager == "FirstReady") { + return new FirstReadyManager(GetNodeStates()); + } + CHECK(false) << "Not a valid ready node manager: " << ready_node_manager; +} + Status VirtualScheduler::Init() { // Init() preprocesses the input grappler_item and graph_properties to extract // necessary information for emulating tensorflow op scheduling and @@ -210,7 +219,7 @@ Status VirtualScheduler::Init() { if (given_as_feed || has_no_inputs) { curr_node_state.time_ready = Costs::Duration(); ready_nodes_->AddNode(curr_node); - VLOG(1) << "Added ready node: " << curr_node->name(); + VLOG(3) << "Added ready node: " << curr_node->name(); } feed_nodes.erase(curr_node->name()); diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index e9abecb1223..0bbd2fd2eb9 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -179,6 +179,77 @@ class LIFOManager : public ReadyNodeManager { std::list::iterator curr_pos_ = nodes_.end(); }; +// FirstReadyManager picks a node with the minimum time_ready value. 
+// Behavior is unspecified if there is more than one node with the minimum +// time_ready value (it depends on C++ STL push_heap and pop_heap). +class FirstReadyManager : public ReadyNodeManager { + public: + FirstReadyManager( + const std::unordered_map* node_state) + : ReadyNodeManager(), node_state_(node_state) { + std::make_heap(nodes_.begin(), nodes_.end()); + greater_ = [this](const NodeDef* a, const NodeDef* b) -> bool { + // Note: we need a node with minimum time_ready, not + // maximum; hence, using a > b for comparison function. + return node_state_->at(a).time_ready > node_state_->at(b).time_ready; + }; + } + ~FirstReadyManager() override {} + + void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); } + + const NodeDef* GetCurrNode() override { + if (nodes_.empty()) { + // Nothing in nodes_; probably the very first call. Move + // waiting_queue_ to nodes_. + _DrainWaitingQueue(); + CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; + } + return nodes_.front(); + } + + void RemoveCurrNode() override { + if (nodes_.empty()) { + // Make sure that there is a node to be removed at the front of nodes_. + GetCurrNode(); + } + std::pop_heap(nodes_.begin(), nodes_.end(), greater_); + nodes_.pop_back(); + _DrainWaitingQueue(); + } + + bool Empty() const override { + return nodes_.empty() && waiting_queue_.empty(); + } + + private: + // Move all the nodes in the waiting_queue_ to nodes_. + void _DrainWaitingQueue() { + for (const auto* node : waiting_queue_) { + // push_heap here and pop_heap in RemoveCurrNode() guarantee that + // the first element is the node with minimum time_ready. + nodes_.push_back(node); + std::push_heap(nodes_.begin(), nodes_.end(), greater_); + } + waiting_queue_.clear(); + } + + // nodes_ is the main queue, where we construct heap, and the front is the + // current node. + std::vector nodes_; + // Newly added nodes are added to waiting_queue_.
That way, GetCurrNode(), + // which returns the front of the nodes_, always returns the same node, + // even if any of new nodes has time_ready smaller than the current node's. + std::vector waiting_queue_; + // Comparator functor for heap; stl heap is max heap, so we use "greater than" + // functor for keeping the smallest time_ready node at the front of heap. + std::function greater_; + + // NodeState structure from VirtualScheduler to get time_ready of ready nodes. + // Not owned by FirstReadyManager. + const std::unordered_map* node_state_; +}; + // A wrapper struct to OpInfo proto. // TODO(dyoon): once we extend OpInfo or implement a better interface, and then // delete this wrapper struct. @@ -211,13 +282,11 @@ class VirtualScheduler { Costs Summary(RunMetadata* metadata); protected: - // GetDeviceStates and GetNodeStates are currently for testing purpuse only. - // Retrieves detailed scheduling results. - const std::unordered_map& GetDeviceStates() const { - return device_; + const std::unordered_map* GetDeviceStates() const { + return &device_; } - const std::unordered_map& GetNodeStates() const { - return node_map_; + const std::unordered_map* GetNodeStates() const { + return &node_map_; } // Returns the size of output at port_num (unit: bytes). A special case is @@ -233,6 +302,9 @@ class VirtualScheduler { const string kAttrDstDevice = "dst_device_"; const string kChannelDevice = "Channel"; + // Methods called from constructor. + ReadyNodeManager* ReadyNodeManagerFactory(const string& ready_node_manager); + // Methods called from Init(). Fails if initialize_ is set.
void MaybeUpdateInputOutput(const NodeDef* node); NodeState& GetNodeStateOrCreateIt(const NodeDef* node); diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index a4f72b0f97c..cea00b04f26 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -42,6 +42,7 @@ class TestVirtualScheduler : public VirtualScheduler { class VirtualSchedulerTest : public ::testing::Test { protected: NodeDef node1_, node2_, node3_, node4_, node5_, node6_; + std::unordered_map node_states_; const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0"; const string kCPU1 = "/job:localhost/replica:0/task:0/cpu:1"; @@ -67,6 +68,21 @@ class VirtualSchedulerTest : public ::testing::Test { node5_.set_name("Node5"); node6_.set_name("Node6"); + // Initialize node_states, with time_ready in reverse order. + node_states_[&node1_] = NodeState(); + node_states_[&node2_] = NodeState(); + node_states_[&node3_] = NodeState(); + node_states_[&node4_] = NodeState(); + node_states_[&node5_] = NodeState(); + node_states_[&node6_] = NodeState(); + + node_states_[&node6_].time_ready = 1000; + node_states_[&node5_].time_ready = 2000; + node_states_[&node4_].time_ready = 3000; + node_states_[&node3_].time_ready = 4000; + node_states_[&node2_].time_ready = 5000; + node_states_[&node1_].time_ready = 6000; + // Initializes cluster_ and placer_. 
std::unordered_map devices; @@ -984,6 +1000,112 @@ TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleLIFOManager) { EXPECT_TRUE(manager.Empty()); } +TEST_F(VirtualSchedulerTest, GetSingleNodeFirstReadyManager) { + FirstReadyManager manager = FirstReadyManager(&node_states_); + + manager.AddNode(&node1_); + EXPECT_EQ("Node1", manager.GetCurrNode()->name()); +} + +TEST_F(VirtualSchedulerTest, RemoveSingleNodeFirstReadyManager) { + FirstReadyManager manager = FirstReadyManager(&node_states_); + + manager.AddNode(&node1_); + manager.RemoveCurrNode(); + EXPECT_TRUE(manager.Empty()); +} + +TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFirstReadyManager) { + FirstReadyManager manager = FirstReadyManager(&node_states_); + + // Insert nodes in some random order. + manager.AddNode(&node2_); + manager.AddNode(&node1_); + manager.AddNode(&node4_); + manager.AddNode(&node5_); + manager.AddNode(&node3_); + manager.AddNode(&node6_); + + // In whatever order we insert nodes, we get the same order based on nodes' + // time_ready. + EXPECT_EQ("Node6", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node5", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node4", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node3", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node2", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node1", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_TRUE(manager.Empty()); +} + +TEST_F(VirtualSchedulerTest, GetCurrNodeFirstReadyManager) { + FirstReadyManager manager = FirstReadyManager(&node_states_); + + // Insert nodes in some random order. + manager.AddNode(&node2_); + manager.AddNode(&node1_); + manager.AddNode(&node4_); + manager.AddNode(&node5_); + manager.AddNode(&node3_); + manager.AddNode(&node6_); + + // Among these nodes, node6 has the smallest time_ready, hence, GetCurrNode() + // should return it. 
+ EXPECT_EQ("Node6", manager.GetCurrNode()->name()); + // Now insert a few other nodes, but their time_ready's are even smaller than + // that of Node6. Before calling RemoveCurrNode(), GetCurrNode() should return + // the same node, Node6, in this case. + + NodeDef node7; + NodeDef node8; + NodeDef node9; + node7.set_name("Node7"); + node8.set_name("Node8"); + node9.set_name("Node9"); + node_states_[&node7] = NodeState(); + node_states_[&node8] = NodeState(); + node_states_[&node9] = NodeState(); + node_states_[&node7].time_ready = 5; + node_states_[&node8].time_ready = 4; + node_states_[&node9].time_ready = 3; + + manager.AddNode(&node7); + EXPECT_EQ("Node6", manager.GetCurrNode()->name()); + + manager.AddNode(&node8); + EXPECT_EQ("Node6", manager.GetCurrNode()->name()); + + manager.RemoveCurrNode(); + // Now Node6 is removed, and GetCurrNode() will return Node8. + EXPECT_EQ("Node8", manager.GetCurrNode()->name()); + + // Again, AddNode shouldn't change GetCurrNode(). + manager.AddNode(&node9); + EXPECT_EQ("Node8", manager.GetCurrNode()->name()); + + manager.RemoveCurrNode(); + EXPECT_EQ("Node9", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node7", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node5", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node4", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node3", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node2", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node1", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_TRUE(manager.Empty()); +} + // Create small graph, run predict costs on it, make sure the costs from the // summary match the hand-calculated costs. TEST_F(VirtualSchedulerTest, SummaryCostTest) { @@ -1129,8 +1251,8 @@ TEST_F(VirtualSchedulerTest, MemoryUsage) { // Run the scheduler.
RunScheduler(""); - const auto& device_states = scheduler_->GetDeviceStates(); - const auto& cpu_state = device_states.at(kCPU0); + const auto* device_states = scheduler_->GetDeviceStates(); + const auto& cpu_state = device_states->at(kCPU0); // out node adds 4 tensors, each with 10x10x10x10, so the peak memory usage // is 4 x the input tensor size while executing the out node. @@ -1150,8 +1272,8 @@ TEST_F(VirtualSchedulerTest, ControlDependency) { // Run the scheduler. RunScheduler(""); - const auto& device_states = scheduler_->GetDeviceStates(); - const auto& cpu_state = device_states.at(kCPU0); + const auto* device_states = scheduler_->GetDeviceStates(); + const auto& cpu_state = device_states->at(kCPU0); // The graph has a NoOp that takes control dependency from 7 NoOps. The peak // memory usage is when executing the final NoOp. @@ -1173,7 +1295,7 @@ TEST_F(VirtualSchedulerTest, ComplexDependency) { RunScheduler("bn"); const auto& device_states = scheduler_->GetDeviceStates(); - const auto& cpu_state = device_states.at(kCPU0); + const auto& cpu_state = device_states->at(kCPU0); // The graph is // bn = FusedBatchNorm(x, scale, offset, mean, var) @@ -1202,10 +1324,10 @@ TEST_F(VirtualSchedulerTest, ComplexDependency) { }; ExpectSetEq(expected, nodes_in_memory); - const auto& node_states = scheduler_->GetNodeStates(); + const auto* node_states = scheduler_->GetNodeStates(); const NodeState* bn_node = nullptr; const NodeState* x_node = nullptr; - for (const auto& nodedef_node_state : node_states) { + for (const auto& nodedef_node_state : *node_states) { const NodeDef* node = nodedef_node_state.first; const NodeState& node_state = nodedef_node_state.second; if (node->name() == "bn") { @@ -1233,8 +1355,8 @@ TEST_F(VirtualSchedulerTest, Variable) { // Run the scheduler. 
RunScheduler(""); - const auto& device_states = scheduler_->GetDeviceStates(); - const auto& cpu_state = device_states.at(kCPU0); + const auto* device_states = scheduler_->GetDeviceStates(); + const auto& cpu_state = device_states->at(kCPU0); // There is one Conv2D that takes x and f, but f is variable, so it should be // in persistent nodes. @@ -1258,25 +1380,67 @@ TEST_F(VirtualSchedulerTest, WhileLoop) { RunMetadata metadata; scheduler_->Summary(&metadata); + // Nodes in topological order (each node takes 1 usec) and possible start + // time usec: + // * const, ones: 0, 1 usec + // * while/Enter, while/Enter_1: 2, 3 usec + // * while/Merge, while/Merge_1: 4, 5 usec + // * while/Less/y: 6 usec + // * while/Less: 7 usec + // * while/LoopCond: 8 usec + // * while/Switch, while/Switch_1: 9, 10 usec + // * while/Identity, while/Identity_1, while/Exit, while/Exit_1: 11 - 14 usec + // * while/add/y, while/concat/Axis: 15, 16 usec + // * while/add, while/concat: 17, 18 usec + // * while/NextIteration, while/NextIteration_1: 19, 20 usec + int num_next_iteration = 0; int num_next_iteration_1 = 0; int num_exit = 0; int num_exit_1 = 0; + int64 next_iter_start_micro; + int64 next_iter_1_start_micro; + int64 exit_start_micro; + int64 exit_1_start_micro; for (const auto& device_step_stats : metadata.step_stats().dev_stats()) { for (const auto& stats : device_step_stats.node_stats()) { std::cout << stats.DebugString() << std::endl; - if (stats.node_name() == "while/NextIteration") { + // Start micro for while/Less/y, while/Less, and while/LoopCond are fixed + // regardless of scheduling method. 
+ if (stats.node_name() == "while/Less/y") { + EXPECT_EQ(6, stats.all_start_micros()); + } else if (stats.node_name() == "while/Less") { + EXPECT_EQ(7, stats.all_start_micros()); + } else if (stats.node_name() == "while/LoopCond") { + EXPECT_EQ(8, stats.all_start_micros()); + } else if (stats.node_name() == "while/NextIteration") { ++num_next_iteration; - EXPECT_EQ(19, stats.all_start_micros()); + // Start time can be either 19 or 20 depending on how the scheduler + // picks a node among ready nodes. + next_iter_start_micro = stats.all_start_micros(); + EXPECT_LE(19, next_iter_start_micro); + EXPECT_GE(20, next_iter_start_micro); } else if (stats.node_name() == "while/NextIteration_1") { ++num_next_iteration_1; - EXPECT_EQ(20, stats.all_start_micros()); + // Start time can be either 19 or 20 depending on how the scheduler + // picks a node among ready nodes. + next_iter_1_start_micro = stats.all_start_micros(); + EXPECT_LE(19, next_iter_1_start_micro); + EXPECT_GE(20, next_iter_1_start_micro); } else if (stats.node_name() == "while/Exit") { ++num_exit; - EXPECT_EQ(14, stats.all_start_micros()); + // Start time can be between 11 and 14 (inclusive) depending on how + // the scheduler picks a node among ready nodes. + exit_start_micro = stats.all_start_micros(); + EXPECT_LE(11, exit_start_micro); + EXPECT_GE(14, exit_start_micro); } else if (stats.node_name() == "while/Exit_1") { ++num_exit_1; - EXPECT_EQ(12, stats.all_start_micros()); + // Start time can be between 11 and 14 (inclusive) depending on how + // the scheduler picks a node among ready nodes. 
+ exit_1_start_micro = stats.all_start_micros(); + EXPECT_LE(11, exit_1_start_micro); + EXPECT_GE(14, exit_1_start_micro); } } } @@ -1287,6 +1451,11 @@ TEST_F(VirtualSchedulerTest, WhileLoop) { EXPECT_EQ(1, num_next_iteration_1); EXPECT_EQ(1, num_exit); EXPECT_EQ(1, num_exit_1); + + // Start times of while/NextIteration and while/NextIteration_1 should be + // different, so should be those of while/Exit and while/Exit_1. + EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro); + EXPECT_NE(exit_start_micro, exit_1_start_micro); } TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {