From f089b3180ec7ddecd503d6d499cd6977c22b8f11 Mon Sep 17 00:00:00 2001 From: Lifeng Nai Date: Thu, 28 Mar 2019 22:54:35 -0700 Subject: [PATCH] update VirtualPlacer constructor interface. PiperOrigin-RevId: 240923787 --- .../core/grappler/costs/virtual_placer.cc | 14 +++++------- .../core/grappler/costs/virtual_placer.h | 2 +- .../grappler/costs/virtual_placer_test.cc | 22 +++++++++---------- .../core/grappler/costs/virtual_scheduler.cc | 2 +- .../grappler/optimizers/layout_optimizer.cc | 2 +- .../optimizers/layout_optimizer_test.cc | 2 +- .../grappler/optimizers/static_schedule.cc | 4 ++-- 7 files changed, 23 insertions(+), 25 deletions(-) diff --git a/tensorflow/core/grappler/costs/virtual_placer.cc b/tensorflow/core/grappler/costs/virtual_placer.cc index 146eecf5bcb..b492bed7a77 100644 --- a/tensorflow/core/grappler/costs/virtual_placer.cc +++ b/tensorflow/core/grappler/costs/virtual_placer.cc @@ -23,14 +23,12 @@ limitations under the License. namespace tensorflow { namespace grappler { -VirtualPlacer::VirtualPlacer(const Cluster* cluster) { - CHECK(cluster); - - // Default job name for canonical device name. Needs to be set before the - // first call to to_lfqn_or_empty() - default_job_name_lowercase_ = "localhost"; - - devices_ = cluster->GetDevices(); +VirtualPlacer::VirtualPlacer( + const std::unordered_map& devices) + : devices_(devices), + // Default job name for canonical device name. Needs to be set before the + // first call to to_lfqn_or_empty() + default_job_name_lowercase_("localhost") { lfqn_map_.reserve(devices_.size()); for (const auto& kv : devices_) { const auto lfqn = to_lfqn_or_empty(kv.first); diff --git a/tensorflow/core/grappler/costs/virtual_placer.h b/tensorflow/core/grappler/costs/virtual_placer.h index e17ece7c1a8..8761f40f707 100644 --- a/tensorflow/core/grappler/costs/virtual_placer.h +++ b/tensorflow/core/grappler/costs/virtual_placer.h @@ -28,7 +28,7 @@ class Cluster; // The virtual placer emulates the behavior of the TF placer. class VirtualPlacer { public: - VirtualPlacer(const Cluster* cluster); + VirtualPlacer(const std::unordered_map& devices); const DeviceProperties& get_device(const NodeDef& node) const; diff --git a/tensorflow/core/grappler/costs/virtual_placer_test.cc b/tensorflow/core/grappler/costs/virtual_placer_test.cc index d1f9cd2176b..1f6beca7bb5 100644 --- a/tensorflow/core/grappler/costs/virtual_placer_test.cc +++ b/tensorflow/core/grappler/costs/virtual_placer_test.cc @@ -33,7 +33,7 @@ TEST(VirtualPlacerTest, LocalDevices) { gpu_device.set_type("GPU"); devices["/job:localhost/replica:0/task:0/device:GPU:0"] = gpu_device; VirtualCluster cluster(devices); - VirtualPlacer placer(&cluster); + VirtualPlacer placer(devices); NodeDef node; node.set_op("Conv2D"); @@ -63,7 +63,7 @@ TEST(VirtualPlacerTest, ShortNames) { gpu_device.set_type("GPU"); devices["/GPU:0"] = gpu_device; VirtualCluster cluster(devices); - VirtualPlacer placer(&cluster); + VirtualPlacer placer(devices); NodeDef node; node.set_op("Conv2D"); @@ -93,7 +93,7 @@ TEST(VirtualPlacerTest, PlacementOnNonDefaultDevice) { tpu_device.set_type("TPU"); devices["/job:localhost/replica:0/task:0/device:TPU:0"] = tpu_device; VirtualCluster cluster(devices); - VirtualPlacer placer(&cluster); + VirtualPlacer placer(devices); NodeDef node; node.set_op("Conv2D"); @@ -123,7 +123,7 @@ TEST(VirtualPlacerTest, EmptyJobName) { devices[strings::StrCat("/job:", job_name, "/replica:0/task:0/device:GPU:0")] = gpu_device; VirtualCluster cluster(devices); - VirtualPlacer placer(&cluster); + VirtualPlacer placer(devices); NodeDef node; node.set_op("Conv2D"); @@ -145,7 +145,7 @@ TEST(VirtualPlacerTest, EmptyJobName) { devices["/job:ps/replica:0/task:0/cpu:0"] = cpu_device; devices["/job:worker/replica:0/task:0/cpu:0"] = cpu_device; VirtualCluster cluster(devices); - VirtualPlacer placer(&cluster); + VirtualPlacer placer(devices); NodeDef node; node.set_op("Conv2D"); @@ -157,7 +157,7 @@ TEST(VirtualPlacerTest, EmptyJobName) { string GetDefaultDeviceName( const std::unordered_map& devices) { VirtualCluster cluster(devices); - VirtualPlacer placer(&cluster); + VirtualPlacer placer(devices); NodeDef node; node.set_op("Conv2D"); // Device is not set to the node, so get_canonical_device_name() will return @@ -204,7 +204,7 @@ TEST(VirtualPlacerTest, MultiReplica) { } std::unique_ptr cluster(new VirtualCluster(devices)); - std::unique_ptr placer(new VirtualPlacer(cluster.get())); + std::unique_ptr placer(new VirtualPlacer(devices)); auto get_device_name = [&placer](const string& device) -> string { NodeDef node; @@ -235,7 +235,7 @@ TEST(VirtualPlacerTest, MultiReplica) { cpu_device; } cluster.reset(new VirtualCluster(devices)); - placer.reset(new VirtualPlacer(cluster.get())); + placer.reset(new VirtualPlacer(cluster->GetDevices())); EXPECT_EQ("/job:worker/replica:0/task:0/cpu:0", get_device_name("/job:worker/replica:0/cpu:0")); EXPECT_EQ("/job:worker/replica:7/task:0/gpu:3", @@ -255,7 +255,7 @@ TEST(VirtualPlacerTest, FallBackUnknown) { // cluster. std::unordered_map devices; VirtualCluster cluster(devices); - VirtualPlacer placer(&cluster); + VirtualPlacer placer(devices); NodeDef node; node.set_op("Conv2D"); @@ -271,7 +271,7 @@ TEST(VirtualPlacerTest, FallBackCPU) { cpu_device.set_type("CPU"); devices["/job:my_job/replica:0/task:0/cpu:0"] = cpu_device; VirtualCluster cluster(devices); - VirtualPlacer placer(&cluster); + VirtualPlacer placer(devices); NodeDef node; node.set_op("Conv2D"); @@ -291,7 +291,7 @@ TEST(VirtualPlacerTest, RemoteDevices) { gpu_device.set_type("GPU"); devices["/job:my_job/replica:0/task:0/device:GPU:0"] = gpu_device; VirtualCluster cluster(devices); - VirtualPlacer placer(&cluster); + VirtualPlacer placer(devices); NodeDef node; node.set_op("Conv2D"); diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 52c8f6f97db..881f817da26 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -265,7 +265,7 @@ VirtualScheduler::VirtualScheduler(const bool use_static_shapes, cluster_(cluster), use_static_shapes_(use_static_shapes), use_aggressive_shape_inference_(use_aggressive_shape_inference), - placer_(cluster) { + placer_(cluster->GetDevices()) { graph_costs_.num_ops_total = 0; initialized_ = false; track_mem_usage_snapshot_ = VLOG_IS_ON(1); diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index e9d622afbf4..cf1e42dfa44 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -2207,7 +2207,7 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, return Status::OK(); } - virtual_placer_.reset(new VirtualPlacer(cluster)); + virtual_placer_.reset(new VirtualPlacer(cluster->GetDevices())); nodes_to_preserve_ = item.NodesToPreserve(); GraphProperties graph_properties(item); auto status = graph_properties.InferStatically(false); diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc index eb2a8e87dde..3dc1bdf4763 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -1223,7 +1223,7 @@ TEST_F(LayoutOptimizerTest, DevicePlacement) { auto i = ops::Identity(s.WithOpName("i"), shape); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - VirtualPlacer virtual_placer(virtual_cluster_.get()); + VirtualPlacer virtual_placer(virtual_cluster_->GetDevices()); for (auto& node : *item.graph.mutable_node()) { string device = virtual_placer.get_canonical_device_name(node); node.set_device(device); diff --git a/tensorflow/core/grappler/optimizers/static_schedule.cc b/tensorflow/core/grappler/optimizers/static_schedule.cc index 5206e9957dc..9950db063d6 100644 --- a/tensorflow/core/grappler/optimizers/static_schedule.cc +++ b/tensorflow/core/grappler/optimizers/static_schedule.cc @@ -94,7 +94,7 @@ Status EstimateEarliestExecutionTimes( GraphProperties properties(item); TF_RETURN_IF_ERROR(properties.InferStatically(true)); OpLevelCostEstimator estimator; - VirtualPlacer placer(cluster); + VirtualPlacer placer(cluster->GetDevices()); while (!ready_nodes.empty()) { const NodeDef* node = ready_nodes.front(); @@ -162,7 +162,7 @@ Status EstimateRequiredTimes( GraphProperties properties(item); TF_RETURN_IF_ERROR(properties.InferStatically(true)); OpLevelCostEstimator estimator; - VirtualPlacer placer(cluster); + VirtualPlacer placer(cluster->GetDevices()); while (!ready_nodes.empty()) { const NodeDef* node = ready_nodes.front();