FirstReadyManager for scheduling nodes in VirtualScheduler.

The current FIFOManager may yield inefficient scheduling: a _Recv node pushed
to the FIFO blocks other nodes that could run before it, simply because of the
order of nodes in the FIFO. FirstReadyManager instead picks the node with the
earliest time_ready in the queue, which avoids this problem.

Also, fixed VirtualPlacer to properly set device when Node's device name does not
include job name and to set GPU:0 as default device.

PiperOrigin-RevId: 168418455
This commit is contained in:
A. Unique TensorFlower 2017-09-12 12:18:54 -07:00 committed by TensorFlower Gardener
parent 7e47624f5f
commit fcacb40d4c
7 changed files with 460 additions and 40 deletions

View File

@ -190,6 +190,7 @@ cc_test(
deps = [
":virtual_placer",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@ -34,15 +34,62 @@ VirtualPlacer::VirtualPlacer(const Cluster* cluster) {
DeviceProperties& prop = devices_["UNKNOWN"];
prop.set_type("UNKNOWN");
} else {
} else if (devices_.size() == 1) {
// If there is only one device in the cluster, use it as default device,
// whatever it is.
default_device_ = devices_.begin()->first;
VLOG(1) << "Number of devices: " << devices_.size();
} else {
// Default device is set from the devices in the cluster in the following
// priority: /gpu:0, /cpu:0, or any device.
// TODO(dyoon): This logic assumes single machine with CPU and GPU devices.
// Make it more general to support multiple machines, job types, and devices
// other than CPU and GPU.
std::map<int, string> cpu_devices; // CPU device map: id -> device name.
std::map<int, string> gpu_devices; // GPU device map: id -> device name.
for (const auto& device : devices_) {
if (str_util::Lowercase(device.first).find("gpu") != string::npos) {
default_device_ = device.first;
break;
DeviceNameUtils::ParsedName parsed_name;
bool parsed = DeviceNameUtils::ParseFullName(device.first, &parsed_name);
if (parsed) {
// Parsed devices are stored to cpu_devices or gpu_devices map,
// addressed (and orderd) by device id.
if (str_util::Lowercase(parsed_name.type) == "gpu") {
gpu_devices[parsed_name.id] = device.first;
} else if (str_util::Lowercase(parsed_name.type) == "cpu") {
cpu_devices[parsed_name.id] = device.first;
}
}
}
if (!gpu_devices.empty()) {
// GPU:0 (or GPU with smallest device id).
default_device_ = gpu_devices.begin()->second;
} else if (!cpu_devices.empty()) {
// CPU:0 (or CPU with smallest device id).
default_device_ = cpu_devices.begin()->second;
} else {
default_device_ = devices_.begin()->first; // Any device.
}
}
// Default job name for canonical device name.
default_job_name_ = "localhost";
// Scan the device names from the cluster, and if there is one job name used,
// use it for canonical device name.
std::unordered_set<string> job_names_from_cluster;
for (const auto& device : devices_) {
const auto& device_name = device.first;
DeviceNameUtils::ParsedName parsed_name;
bool parsed = DeviceNameUtils::ParseFullName(device_name, &parsed_name);
if (parsed && !parsed_name.job.empty()) {
job_names_from_cluster.insert(parsed_name.job);
}
}
// If there is only type of job name in all the devices in the cluster, use
// that one as default job name; otherwise, use localhost.
// TODO(dyoon): this should be improved, especially when the cluster is
// composed of multiple worker, PS, and other types of jobs.
if (job_names_from_cluster.size() == 1) {
auto it = job_names_from_cluster.begin();
default_job_name_ = *it;
}
}
@ -78,7 +125,7 @@ string VirtualPlacer::get_canonical_device_name(const NodeDef& node) const {
return get_default_device_name();
} else {
if (parsed_name.job.empty()) {
parsed_name.job = "localhost";
parsed_name.job = default_job_name_;
}
device = strings::StrCat(
"/job:", parsed_name.job, "/replica:", parsed_name.replica,

View File

@ -41,6 +41,7 @@ class VirtualPlacer {
private:
std::unordered_map<string, DeviceProperties> devices_;
string default_device_;
string default_job_name_;
const string& get_default_device_name() const;
};

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
@ -36,6 +37,7 @@ TEST(VirtualPlacerTest, LocalDevices) {
NodeDef node;
node.set_op("Conv2D");
// node.device() is empty, but GPU is default device if there is.
EXPECT_EQ("GPU", placer.get_device(node).type());
EXPECT_EQ("/job:localhost/replica:0/task:0/device:GPU:0",
placer.get_canonical_device_name(node));
@ -51,16 +53,42 @@ TEST(VirtualPlacerTest, LocalDevices) {
placer.get_canonical_device_name(node));
}
TEST(VirtualPlacerTest, EmptyJobBecomesLocalhost) {
// Virtual placer should use "localhost" if device is empty.
// First create a cluster with only localhost devices.
TEST(VirtualPlacerTest, EmptyJobName) {
// Virtual placer choose job name from the devices in cluster if a device name
// of an op is empty. In case there are more than one kind of job name
// or job names are missin in the devices in cluster, we use local_host.
for (const string& job_name : {"localhost", "worker", "worker_train"}) {
std::unordered_map<string, DeviceProperties> devices;
DeviceProperties cpu_device;
cpu_device.set_type("CPU");
devices[strings::StrCat("/job:", job_name, "/replica:0/task:0/cpu:0")] =
cpu_device;
DeviceProperties gpu_device;
gpu_device.set_type("GPU");
devices[strings::StrCat("/job:", job_name,
"/replica:0/task:0/device:GPU:0")] = gpu_device;
VirtualCluster cluster(devices);
VirtualPlacer placer(&cluster);
NodeDef node;
node.set_op("Conv2D");
node.set_device("/device:CPU:0");
EXPECT_EQ(strings::StrCat("/job:", job_name, "/replica:0/task:0/cpu:0"),
placer.get_canonical_device_name(node));
node.set_device("/device:GPU:0");
EXPECT_EQ(
strings::StrCat("/job:", job_name, "/replica:0/task:0/device:GPU:0"),
placer.get_canonical_device_name(node));
}
// When more than one job names are used, we use default "localhost"
// This may be improved later.
std::unordered_map<string, DeviceProperties> devices;
DeviceProperties cpu_device;
cpu_device.set_type("CPU");
devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
DeviceProperties gpu_device;
gpu_device.set_type("GPU");
devices["/job:localhost/replica:0/task:0/device:GPU:0"] = gpu_device;
devices["/job:ps/replica:0/task:0/cpu:0"] = cpu_device;
devices["/job:worker/replica:0/task:0/cpu:0"] = cpu_device;
VirtualCluster cluster(devices);
VirtualPlacer placer(&cluster);
@ -69,9 +97,102 @@ TEST(VirtualPlacerTest, EmptyJobBecomesLocalhost) {
node.set_device("/device:CPU:0");
EXPECT_EQ("/job:localhost/replica:0/task:0/cpu:0",
placer.get_canonical_device_name(node));
node.set_device("/device:GPU:0");
EXPECT_EQ("/job:localhost/replica:0/task:0/device:GPU:0",
placer.get_canonical_device_name(node));
}
// Test helper: builds a VirtualCluster and a VirtualPlacer from `devices` and
// returns the canonical device name the placer reports for a Conv2D node that
// has no device assigned.
string GetDefaultDeviceName(
    const std::unordered_map<string, DeviceProperties>& devices) {
  VirtualCluster virtual_cluster(devices);
  VirtualPlacer virtual_placer(&virtual_cluster);

  // The node deliberately carries no device string, so the name returned by
  // get_canonical_device_name() is expected to be the placer's default device.
  NodeDef unplaced_node;
  unplaced_node.set_op("Conv2D");
  return virtual_placer.get_canonical_device_name(unplaced_node);
}
// Checks which device the placer reports for a node with no device set: the
// lone CPU when it is the only device, and gpu:0 as soon as any GPU exists.
TEST(VirtualPlacerTest, DefaultDevice) {
  std::unordered_map<string, DeviceProperties> devices;

  // Start with a single CPU; it is the only candidate for the default.
  DeviceProperties cpu_props;
  cpu_props.set_type("CPU");
  devices["/job:worker/replica:0/task:0/cpu:0"] = cpu_props;
  EXPECT_EQ("/job:worker/replica:0/task:0/cpu:0",
            GetDefaultDeviceName(devices));

  // Add GPUs one at a time; after each addition the default must be gpu:0.
  DeviceProperties gpu_props;
  gpu_props.set_type("GPU");
  for (int gpu_id = 0; gpu_id < 8; ++gpu_id) {
    const string gpu_name =
        strings::StrCat("/job:worker/replica:0/task:0/gpu:", gpu_id);
    devices[gpu_name] = gpu_props;
    EXPECT_EQ("/job:worker/replica:0/task:0/gpu:0",
              GetDefaultDeviceName(devices));
  }
}
// Checks canonical device names in a multi-replica cluster, first with a
// single job name ("worker") and then with a second job ("ps") added.
TEST(VirtualPlacerTest, MultiReplica) {
  DeviceProperties cpu_props;
  cpu_props.set_type("CPU");
  DeviceProperties gpu_props;
  gpu_props.set_type("GPU");

  // Cluster of 8 worker replicas, each with one CPU and 8 GPUs.
  std::unordered_map<string, DeviceProperties> devices;
  for (int replica = 0; replica < 8; ++replica) {
    devices[strings::StrCat("/job:worker/replica:", replica,
                            "/task:0/cpu:0")] = cpu_props;
    for (int gpu = 0; gpu < 8; ++gpu) {
      devices[strings::StrCat("/job:worker/replica:", replica,
                              "/task:0/gpu:", gpu)] = gpu_props;
    }
  }

  std::unique_ptr<VirtualCluster> cluster(new VirtualCluster(devices));
  std::unique_ptr<VirtualPlacer> placer(new VirtualPlacer(cluster.get()));

  // Returns the canonical name for a Conv2D node placed on `device`. The
  // placer is captured by reference so the lambda follows placer.reset().
  auto get_device_name = [&placer](const string& device) -> string {
    NodeDef conv_node;
    conv_node.set_op("Conv2D");
    conv_node.set_device(device);
    return placer->get_canonical_device_name(conv_node);
  };

  // Pairs of (expected canonical name, device string set on the node).
  struct NameCase {
    const char* expected;
    const char* device;
  };

  // With only one job in the cluster, replica id plus device is enough.
  const NameCase single_job_cases[] = {
      {"/job:worker/replica:0/task:0/cpu:0", "/replica:0/cpu:0"},
      {"/job:worker/replica:2/task:0/cpu:0", "/replica:2/cpu:0"},
      {"/job:worker/replica:7/task:0/cpu:0", "/replica:7/cpu:0"},
      {"/job:worker/replica:3/task:0/gpu:0", "/replica:3/gpu:0"},
      {"/job:worker/replica:5/task:0/gpu:3", "/replica:5/gpu:3"},
      {"/job:worker/replica:4/task:0/gpu:7", "/replica:4/gpu:7"},
  };
  for (const NameCase& name_case : single_job_cases) {
    EXPECT_EQ(name_case.expected, get_device_name(name_case.device));
  }

  // Add 4 PS replicas. Once several job names exist in the cluster, nodes
  // have to spell out the job name themselves.
  for (int replica = 0; replica < 4; ++replica) {
    devices[strings::StrCat("/job:ps/replica:", replica, "/task:0/cpu:0")] =
        cpu_props;
  }
  cluster.reset(new VirtualCluster(devices));
  placer.reset(new VirtualPlacer(cluster.get()));

  const NameCase multi_job_cases[] = {
      {"/job:worker/replica:0/task:0/cpu:0", "/job:worker/replica:0/cpu:0"},
      {"/job:worker/replica:7/task:0/gpu:3", "/job:worker/replica:7/gpu:3"},
      {"/job:ps/replica:0/task:0/cpu:0", "/job:ps/replica:0/cpu:0"},
      {"/job:ps/replica:1/task:0/cpu:0", "/job:ps/replica:1/cpu:0"},
      {"/job:ps/replica:2/task:0/cpu:0", "/job:ps/replica:2/cpu:0"},
      {"/job:ps/replica:3/task:0/cpu:0", "/job:ps/replica:3/cpu:0"},
  };
  for (const NameCase& name_case : multi_job_cases) {
    EXPECT_EQ(name_case.expected, get_device_name(name_case.device));
  }
}
TEST(VirtualPlacerTest, FallBackUnknown) {

View File

@ -88,10 +88,7 @@ struct RecvNodeDescriptorEqual {
VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
const bool use_static_shapes,
Cluster* cluster)
: // Allow LIFO as well as FIFO. LIFO allows an output node of an node to
// follow it in execution, saving addition memory time from having to
// write and read. For default cases, use FIFO for performance.
ready_nodes_(new FIFOManager()),
: ready_nodes_(ReadyNodeManagerFactory("FirstReady")),
graph_costs_(Costs::ZeroCosts()),
graph_properties_(*grappler_item),
cluster_(cluster),
@ -101,6 +98,18 @@ VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
initialized_ = false;
}
// Creates the ready-node manager selected by `ready_node_manager`.
//
// Recognized names are "FIFO", "LIFO", and "FirstReady"; the FirstReady
// manager is given the scheduler's node-state map (GetNodeStates()) so that it
// can look up time_ready. Any other name is a fatal CHECK failure.
//
// Returns a manager allocated with `new`.
ReadyNodeManager* VirtualScheduler::ReadyNodeManagerFactory(
    const string& ready_node_manager) {
  if (ready_node_manager == "FIFO") {
    return new FIFOManager();
  } else if (ready_node_manager == "LIFO") {
    return new LIFOManager();
  } else if (ready_node_manager == "FirstReady") {
    return new FirstReadyManager(GetNodeStates());
  }
  CHECK(false) << "Not a valid ready node manager: " << ready_node_manager;
  // Not expected to be reached: the CHECK above aborts. The explicit return
  // keeps every path of this non-void function well-formed without relying on
  // the CHECK macro being seen as noreturn by the compiler.
  return nullptr;
}
Status VirtualScheduler::Init() {
// Init() preprocesses the input grappler_item and graph_properties to extract
// necessary information for emulating tensorflow op scheduling and
@ -210,7 +219,7 @@ Status VirtualScheduler::Init() {
if (given_as_feed || has_no_inputs) {
curr_node_state.time_ready = Costs::Duration();
ready_nodes_->AddNode(curr_node);
VLOG(1) << "Added ready node: " << curr_node->name();
VLOG(3) << "Added ready node: " << curr_node->name();
}
feed_nodes.erase(curr_node->name());

View File

@ -179,6 +179,77 @@ class LIFOManager : public ReadyNodeManager {
std::list<const NodeDef*>::iterator curr_pos_ = nodes_.end();
};
// FirstReadyManager picks a node with the minimum time_ready value.
// When several nodes share the minimum time_ready value, which of them is
// returned is unspecified (it depends on C++ STL push_heap and pop_heap).
//
// Newly added nodes are parked in a waiting queue and only merged into the
// heap when the current node is removed (or when the heap is empty), so
// repeated GetCurrNode() calls keep returning the same node until
// RemoveCurrNode() is called.
class FirstReadyManager : public ReadyNodeManager {
 public:
  // `node_state` maps nodes to their NodeState and is not owned. It is read
  // with at() on every heap comparison, so it must hold an entry for each
  // added node by the time that node is moved into the heap.
  FirstReadyManager(
      const std::unordered_map<const NodeDef*, NodeState>* node_state)
      : ReadyNodeManager(), node_state_(node_state) {
    // Capture the map pointer by value rather than `this`: the comparator is
    // stored in a copyable std::function member, and a `this` capture would
    // keep pointing at the source object after the manager is copied or moved.
    greater_ = [node_state](const NodeDef* a, const NodeDef* b) -> bool {
      // Note: we need a node with minimum time_ready, not
      // maximum; hence, using a > b for comparison function.
      return node_state->at(a).time_ready > node_state->at(b).time_ready;
    };
  }
  ~FirstReadyManager() override {}

  // Queues `node`; it becomes eligible for GetCurrNode() only after the next
  // drain of the waiting queue.
  void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); }

  // Returns the node at the front of the heap. CHECK-fails if there is no
  // ready node at all.
  const NodeDef* GetCurrNode() override {
    if (nodes_.empty()) {
      // Nothing in nodes_; probably the very first call. Move waiting_queue_
      // to nodes_.
      DrainWaitingQueue();
      CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
    }
    return nodes_.front();
  }

  // Removes the current (front) node and then merges all waiting nodes into
  // the heap.
  void RemoveCurrNode() override {
    if (nodes_.empty()) {
      // Make sure that there is a node to be removed at the front of nodes_.
      GetCurrNode();
    }
    std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
    nodes_.pop_back();
    DrainWaitingQueue();
  }

  // True when neither the heap nor the waiting queue holds a node.
  bool Empty() const override {
    return nodes_.empty() && waiting_queue_.empty();
  }

 private:
  // Move all the nodes in the waiting_queue_ to nodes_.
  void DrainWaitingQueue() {
    for (const auto* node : waiting_queue_) {
      // push_heap here and pop_heap in RemoveCurrNode() guarantee that the
      // first element of nodes_ is a node with minimum time_ready.
      // NOTE(review): this presumes time_ready of a node does not change
      // while the node sits in nodes_ -- confirm against the scheduler.
      nodes_.push_back(node);
      std::push_heap(nodes_.begin(), nodes_.end(), greater_);
    }
    waiting_queue_.clear();
  }

  // nodes_ is the main queue, where we construct heap, and the front is the
  // current node.
  std::vector<const NodeDef*> nodes_;
  // Newly added nodes are added to waiting_queue_. That way, GetCurrNode(),
  // which returns the front of the nodes_, always returns the same node,
  // even if any of new nodes has time_ready smaller than the current node's.
  std::vector<const NodeDef*> waiting_queue_;
  // Comparator functor for heap; stl heap is max heap, so we use "greater than"
  // functor for keeping the smallest time_ready node at the front of heap.
  std::function<bool(const NodeDef*, const NodeDef*)> greater_;
  // NodeState structure from VirtualScheduler to get time_ready of ready nodes.
  // Not owned by FirstReadyManager.
  const std::unordered_map<const NodeDef*, NodeState>* node_state_;
};
// A wrapper struct to OpInfo proto.
// TODO(dyoon): once we extend OpInfo or implement a better interface, and then
// delete this wrapper struct.
@ -211,13 +282,11 @@ class VirtualScheduler {
Costs Summary(RunMetadata* metadata);
protected:
// GetDeviceStates and GetNodeStates are currently for testing purpuse only.
// Retrieves detailed scheduling results.
const std::unordered_map<string, DeviceState>& GetDeviceStates() const {
return device_;
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
return &device_;
}
const std::unordered_map<const NodeDef*, NodeState>& GetNodeStates() const {
return node_map_;
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
return &node_map_;
}
// Returns the size of output at port_num (unit: bytes). A special case is
@ -233,6 +302,9 @@ class VirtualScheduler {
const string kAttrDstDevice = "dst_device_";
const string kChannelDevice = "Channel";
// Methods called from constructor.
ReadyNodeManager* ReadyNodeManagerFactory(const string& ready_node_manager);
// Methods called from Init(). Fails if initialize_ is set.
void MaybeUpdateInputOutput(const NodeDef* node);
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);

View File

@ -42,6 +42,7 @@ class TestVirtualScheduler : public VirtualScheduler {
class VirtualSchedulerTest : public ::testing::Test {
protected:
NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
std::unordered_map<const NodeDef*, NodeState> node_states_;
const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0";
const string kCPU1 = "/job:localhost/replica:0/task:0/cpu:1";
@ -67,6 +68,21 @@ class VirtualSchedulerTest : public ::testing::Test {
node5_.set_name("Node5");
node6_.set_name("Node6");
// Initialize node_states, with time_ready in reverse order.
node_states_[&node1_] = NodeState();
node_states_[&node2_] = NodeState();
node_states_[&node3_] = NodeState();
node_states_[&node4_] = NodeState();
node_states_[&node5_] = NodeState();
node_states_[&node6_] = NodeState();
node_states_[&node6_].time_ready = 1000;
node_states_[&node5_].time_ready = 2000;
node_states_[&node4_].time_ready = 3000;
node_states_[&node3_].time_ready = 4000;
node_states_[&node2_].time_ready = 5000;
node_states_[&node1_].time_ready = 6000;
// Initializes cluster_ and placer_.
std::unordered_map<string, DeviceProperties> devices;
@ -984,6 +1000,112 @@ TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleLIFOManager) {
EXPECT_TRUE(manager.Empty());
}
// A manager holding exactly one node must report that node as current.
TEST_F(VirtualSchedulerTest, GetSingleNodeFirstReadyManager) {
  auto first_ready = FirstReadyManager(&node_states_);
  first_ready.AddNode(&node1_);
  const NodeDef* current = first_ready.GetCurrNode();
  EXPECT_EQ("Node1", current->name());
}
// Adding one node and removing it must leave the manager empty.
TEST_F(VirtualSchedulerTest, RemoveSingleNodeFirstReadyManager) {
  auto first_ready = FirstReadyManager(&node_states_);
  first_ready.AddNode(&node1_);
  first_ready.RemoveCurrNode();
  const bool is_empty = first_ready.Empty();
  EXPECT_TRUE(is_empty);
}
// Nodes inserted in arbitrary order must come back out sorted by ascending
// time_ready, which the fixture assigns as Node6 < Node5 < ... < Node1.
TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFirstReadyManager) {
  auto first_ready = FirstReadyManager(&node_states_);

  // Insertion order is intentionally unrelated to time_ready.
  const NodeDef* const insertion_order[] = {&node2_, &node1_, &node4_,
                                            &node5_, &node3_, &node6_};
  for (const NodeDef* node : insertion_order) {
    first_ready.AddNode(node);
  }

  // Whatever the insertion order, removal order follows time_ready.
  const char* const expected_order[] = {"Node6", "Node5", "Node4",
                                        "Node3", "Node2", "Node1"};
  for (const char* expected_name : expected_order) {
    EXPECT_EQ(expected_name, first_ready.GetCurrNode()->name());
    first_ready.RemoveCurrNode();
  }
  EXPECT_TRUE(first_ready.Empty());
}
// GetCurrNode() must keep returning the same node until RemoveCurrNode() is
// called, even when nodes with a smaller time_ready are added in between.
TEST_F(VirtualSchedulerTest, GetCurrNodeFirstReadyManager) {
  auto first_ready = FirstReadyManager(&node_states_);

  // Insertion order is intentionally unrelated to time_ready.
  const NodeDef* const initial_nodes[] = {&node2_, &node1_, &node4_,
                                          &node5_, &node3_, &node6_};
  for (const NodeDef* node : initial_nodes) {
    first_ready.AddNode(node);
  }

  // Node6 has the smallest time_ready of the six, so it is current.
  EXPECT_EQ("Node6", first_ready.GetCurrNode()->name());

  // Three extra nodes whose time_ready values are all below Node6's.
  NodeDef node7;
  node7.set_name("Node7");
  node_states_[&node7] = NodeState();
  node_states_[&node7].time_ready = 5;

  NodeDef node8;
  node8.set_name("Node8");
  node_states_[&node8] = NodeState();
  node_states_[&node8].time_ready = 4;

  NodeDef node9;
  node9.set_name("Node9");
  node_states_[&node9] = NodeState();
  node_states_[&node9].time_ready = 3;

  // Adding earlier-ready nodes before RemoveCurrNode() must not change the
  // current node: it stays Node6.
  first_ready.AddNode(&node7);
  EXPECT_EQ("Node6", first_ready.GetCurrNode()->name());
  first_ready.AddNode(&node8);
  EXPECT_EQ("Node6", first_ready.GetCurrNode()->name());

  // Once Node6 is removed, Node8 (the smallest time_ready so far) is current.
  first_ready.RemoveCurrNode();
  EXPECT_EQ("Node8", first_ready.GetCurrNode()->name());

  // Again, AddNode() alone must not displace the current node.
  first_ready.AddNode(&node9);
  EXPECT_EQ("Node8", first_ready.GetCurrNode()->name());

  // From here on the remaining nodes come out in ascending time_ready order.
  const char* const remaining_order[] = {"Node9", "Node7", "Node5", "Node4",
                                         "Node3", "Node2", "Node1"};
  for (const char* expected_name : remaining_order) {
    first_ready.RemoveCurrNode();
    EXPECT_EQ(expected_name, first_ready.GetCurrNode()->name());
  }
  first_ready.RemoveCurrNode();
  EXPECT_TRUE(first_ready.Empty());
}
// Create small graph, run predict costs on it, make sure the costs from the
// summary match the hand-calculated costs.
TEST_F(VirtualSchedulerTest, SummaryCostTest) {
@ -1129,8 +1251,8 @@ TEST_F(VirtualSchedulerTest, MemoryUsage) {
// Run the scheduler.
RunScheduler("");
const auto& device_states = scheduler_->GetDeviceStates();
const auto& cpu_state = device_states.at(kCPU0);
const auto* device_states = scheduler_->GetDeviceStates();
const auto& cpu_state = device_states->at(kCPU0);
// out node adds 4 tensors, each with 10x10x10x10, so the peak memory usage
// is 4 x the input tensor size while executing the out node.
@ -1150,8 +1272,8 @@ TEST_F(VirtualSchedulerTest, ControlDependency) {
// Run the scheduler.
RunScheduler("");
const auto& device_states = scheduler_->GetDeviceStates();
const auto& cpu_state = device_states.at(kCPU0);
const auto* device_states = scheduler_->GetDeviceStates();
const auto& cpu_state = device_states->at(kCPU0);
// The graph has a NoOp that takes control dependency from 7 NoOps. The peak
// memory usage is when executing the final NoOp.
@ -1173,7 +1295,7 @@ TEST_F(VirtualSchedulerTest, ComplexDependency) {
RunScheduler("bn");
const auto& device_states = scheduler_->GetDeviceStates();
const auto& cpu_state = device_states.at(kCPU0);
const auto& cpu_state = device_states->at(kCPU0);
// The graph is
// bn = FusedBatchNorm(x, scale, offset, mean, var)
@ -1202,10 +1324,10 @@ TEST_F(VirtualSchedulerTest, ComplexDependency) {
};
ExpectSetEq(expected, nodes_in_memory);
const auto& node_states = scheduler_->GetNodeStates();
const auto* node_states = scheduler_->GetNodeStates();
const NodeState* bn_node = nullptr;
const NodeState* x_node = nullptr;
for (const auto& nodedef_node_state : node_states) {
for (const auto& nodedef_node_state : *node_states) {
const NodeDef* node = nodedef_node_state.first;
const NodeState& node_state = nodedef_node_state.second;
if (node->name() == "bn") {
@ -1233,8 +1355,8 @@ TEST_F(VirtualSchedulerTest, Variable) {
// Run the scheduler.
RunScheduler("");
const auto& device_states = scheduler_->GetDeviceStates();
const auto& cpu_state = device_states.at(kCPU0);
const auto* device_states = scheduler_->GetDeviceStates();
const auto& cpu_state = device_states->at(kCPU0);
// There is one Conv2D that takes x and f, but f is variable, so it should be
// in persistent nodes.
@ -1258,25 +1380,67 @@ TEST_F(VirtualSchedulerTest, WhileLoop) {
RunMetadata metadata;
scheduler_->Summary(&metadata);
// Nodes in topological order (each node takes 1 usec) and possible start
// time usec:
// * const, ones: 0, 1 usec
// * while/Enter, while/Enter_1: 2, 3 usec
// * while/Merge, while/Merge_1: 4, 5 usec
// * while/Less/y: 6 usec
// * while/Less: 7 usec
// * while/LoopCond: 8 usec
// * while/Switch, while/Switch_1: 9, 10 usec
// * while/Identity, while/Identity_1, while/Exit, while/Exit_1: 11 - 14 usec
// * while/add/y, while/concat/Axis: 15, 16 usec
// * while/add, while/concat: 17, 18 usec
// * while/NextIteration, while/NextIteration_1: 19, 20 usec
int num_next_iteration = 0;
int num_next_iteration_1 = 0;
int num_exit = 0;
int num_exit_1 = 0;
int64 next_iter_start_micro;
int64 next_iter_1_start_micro;
int64 exit_start_micro;
int64 exit_1_start_micro;
for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
for (const auto& stats : device_step_stats.node_stats()) {
std::cout << stats.DebugString() << std::endl;
if (stats.node_name() == "while/NextIteration") {
// Start micro for while/Less/y, while/Less, and while/LoopCond are fixed
// regardless of scheduling method.
if (stats.node_name() == "while/Less/y") {
EXPECT_EQ(6, stats.all_start_micros());
} else if (stats.node_name() == "while/Less") {
EXPECT_EQ(7, stats.all_start_micros());
} else if (stats.node_name() == "while/LoopCond") {
EXPECT_EQ(8, stats.all_start_micros());
} else if (stats.node_name() == "while/NextIteration") {
++num_next_iteration;
EXPECT_EQ(19, stats.all_start_micros());
// Start time can be either 19 or 20 depending on how the scheduler
// picks a node among ready nodes.
next_iter_start_micro = stats.all_start_micros();
EXPECT_LE(19, next_iter_start_micro);
EXPECT_GE(20, next_iter_start_micro);
} else if (stats.node_name() == "while/NextIteration_1") {
++num_next_iteration_1;
EXPECT_EQ(20, stats.all_start_micros());
// Start time can be either 19 or 20 depending on how the scheduler
// picks a node among ready nodes.
next_iter_1_start_micro = stats.all_start_micros();
EXPECT_LE(19, next_iter_1_start_micro);
EXPECT_GE(20, next_iter_1_start_micro);
} else if (stats.node_name() == "while/Exit") {
++num_exit;
EXPECT_EQ(14, stats.all_start_micros());
// Start time can be between 11 and 14 (inclusive) depending on how
// the scheduler picks a node among ready nodes.
exit_start_micro = stats.all_start_micros();
EXPECT_LE(11, exit_start_micro);
EXPECT_GE(14, exit_start_micro);
} else if (stats.node_name() == "while/Exit_1") {
++num_exit_1;
EXPECT_EQ(12, stats.all_start_micros());
// Start time can be between 11 and 14 (inclusive) depending on how
// the scheduler picks a node among ready nodes.
exit_1_start_micro = stats.all_start_micros();
EXPECT_LE(11, exit_1_start_micro);
EXPECT_GE(14, exit_1_start_micro);
}
}
}
@ -1287,6 +1451,11 @@ TEST_F(VirtualSchedulerTest, WhileLoop) {
EXPECT_EQ(1, num_next_iteration_1);
EXPECT_EQ(1, num_exit);
EXPECT_EQ(1, num_exit_1);
// Start times of while/NextIteration and while/NextIteration_1 should be
// different, so should be those of while/Exit and while/Exit_1.
EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
EXPECT_NE(exit_start_micro, exit_1_start_micro);
}
TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {