From 1597edbeede9d7376626f0026cd94bd8eb7e50b3 Mon Sep 17 00:00:00 2001 From: Lifeng Nai Date: Tue, 2 Apr 2019 21:40:20 -0700 Subject: [PATCH] Update VirtualScheduler constructor. PiperOrigin-RevId: 241662817 --- .../grappler/costs/analytical_cost_estimator.cc | 17 ++++++++++++++++- .../grappler/costs/analytical_cost_estimator.h | 6 ++++++ .../core/grappler/costs/virtual_scheduler.cc | 12 +++++++----- .../core/grappler/costs/virtual_scheduler.h | 6 ++++-- .../grappler/costs/virtual_scheduler_test.cc | 6 ++++-- 5 files changed, 37 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc index 856ed414f29..81afe7aaf74 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc @@ -123,7 +123,22 @@ AnalyticalCostEstimator::AnalyticalCostEstimator( use_aggressive_shape_inference_(use_aggressive_shape_inference) { scheduler_ = absl::make_unique( use_static_shapes_, use_aggressive_shape_inference_, cluster, - node_manager_.get()); + node_manager_.get(), + absl::make_unique(cluster->GetDevices())); +} + +AnalyticalCostEstimator::AnalyticalCostEstimator( + Cluster* cluster, std::unique_ptr node_estimator, + std::unique_ptr node_manager, + std::unique_ptr placer, bool use_static_shapes, + bool use_aggressive_shape_inference) + : node_estimator_(std::move(node_estimator)), + node_manager_(std::move(node_manager)), + use_static_shapes_(use_static_shapes), + use_aggressive_shape_inference_(use_aggressive_shape_inference) { + scheduler_ = absl::make_unique( + use_static_shapes_, use_aggressive_shape_inference_, cluster, + node_manager_.get(), std::move(placer)); } Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) { diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.h b/tensorflow/core/grappler/costs/analytical_cost_estimator.h index 26ea964f528..914f5839ad5 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.h +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.h @@ -47,6 +47,12 @@ class AnalyticalCostEstimator : public CostEstimator { std::unique_ptr node_manager, bool use_static_shapes, bool use_aggressive_shape_inference); + AnalyticalCostEstimator(Cluster* cluster, + std::unique_ptr node_estimator, + std::unique_ptr node_manager, + std::unique_ptr placer, + bool use_static_shapes, + bool use_aggressive_shape_inference); ~AnalyticalCostEstimator() override {} // Initializes the estimator for the specified grappler item. diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 881f817da26..4e55f021fb7 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -259,13 +259,15 @@ std::unique_ptr ReadyNodeManagerFactory( VirtualScheduler::VirtualScheduler(const bool use_static_shapes, const bool use_aggressive_shape_inference, Cluster* cluster, - ReadyNodeManager* ready_nodes) + ReadyNodeManager* ready_nodes, + std::unique_ptr placer) : ready_nodes_(ready_nodes), graph_costs_(Costs::ZeroCosts()), cluster_(cluster), use_static_shapes_(use_static_shapes), use_aggressive_shape_inference_(use_aggressive_shape_inference), - placer_(cluster->GetDevices()) { + placer_(std::move(placer)) { + DCHECK(placer_); // check if the pointer is valid. graph_costs_.num_ops_total = 0; initialized_ = false; track_mem_usage_snapshot_ = VLOG_IS_ON(1); @@ -524,13 +526,13 @@ bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const { } string VirtualScheduler::DeviceName(const NodeDef* node) const { - return placer_.get_canonical_device_name(*node); + return placer_->get_canonical_device_name(*node); } string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const { // Replace the ":" characters that may be present in the device name with "_". // This makes it possible to then use the resulting string in a node name. - return str_util::StringReplace(placer_.get_canonical_device_name(*node), ":", + return str_util::StringReplace(placer_->get_canonical_device_name(*node), ":", "_", true); } @@ -620,7 +622,7 @@ OpContext VirtualScheduler::GetCurrNode() const { // Get the device from the placer. DeviceProperties device; - device = placer_.get_device(*node); + device = placer_->get_device(*node); // Special case for _Send op. if (IsSend(*node)) { diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index e8e16229633..9a67fa9effb 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -263,7 +263,9 @@ class VirtualScheduler { // Does not take ownership of cluster or ready_nodes. VirtualScheduler(const bool use_static_shapes, const bool use_aggressive_shape_inference, Cluster* cluster, - ReadyNodeManager* ready_nodes); + ReadyNodeManager* ready_nodes, + std::unique_ptr placer); + // Initializes the scheduler for the specific grappler item. // Should be called immediately after the c'tor or when the scheduler will be // reused for a new grappler item. All internal states of the scheduler @@ -356,7 +358,7 @@ class VirtualScheduler { bool track_mem_usage_snapshot_; const bool use_aggressive_shape_inference_; - VirtualPlacer placer_; // owned. + std::unique_ptr placer_; }; } // namespace grappler diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 38fd380a660..97b65860dd0 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -33,8 +33,10 @@ class TestVirtualScheduler : public VirtualScheduler { TestVirtualScheduler(const bool use_static_shapes, const bool use_aggressive_shape_inference, Cluster* cluster) - : VirtualScheduler(use_static_shapes, use_aggressive_shape_inference, - cluster, &ready_node_manager_) { + : VirtualScheduler( + use_static_shapes, use_aggressive_shape_inference, cluster, + &ready_node_manager_, + absl::make_unique(cluster->GetDevices())) { enable_mem_usage_tracking(); }