From 7273a08672c29739cee9f9aa91fb4d92ec1e2682 Mon Sep 17 00:00:00 2001 From: Peter Ma Date: Sun, 10 Mar 2019 17:14:06 -0700 Subject: [PATCH] - Added input argument for aggressive shape inference mode in AnalyticalCostEstimator. - Unified the logic on VirtualCluster in AnalyticalCostEstimator and VirtualCluster. PiperOrigin-RevId: 237718010 --- tensorflow/core/grappler/clusters/BUILD | 1 + .../core/grappler/clusters/virtual_cluster.cc | 72 +++++-------------- .../core/grappler/clusters/virtual_cluster.h | 6 +- .../costs/analytical_cost_estimator.cc | 38 ++++++---- .../costs/analytical_cost_estimator.h | 14 ++-- .../costs/analytical_cost_estimator_test.cc | 3 +- tensorflow/python/grappler/cluster.i | 2 +- tensorflow/python/grappler/cluster_test.py | 4 +- tensorflow/python/grappler/cost_analyzer.cc | 4 +- 9 files changed, 61 insertions(+), 83 deletions(-) diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index b9ffb1726df..de27cf4ba2a 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ b/tensorflow/core/grappler/clusters/BUILD @@ -81,6 +81,7 @@ cc_library( "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler/costs:analytical_cost_estimator", "//tensorflow/core/grappler/costs:op_level_cost_estimator", "//tensorflow/core/grappler/costs:virtual_scheduler", ], diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc index 118f74e8b01..2839d33c552 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -14,32 +14,33 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/clusters/virtual_cluster.h" + #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/clusters/utils.h" #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" -#include "tensorflow/core/grappler/costs/virtual_scheduler.h" namespace tensorflow { namespace grappler { VirtualCluster::VirtualCluster( const std::unordered_map& devices) - : Cluster(0), - node_estimator_(new OpLevelCostEstimator()), - node_manager_(new FirstReadyManager()) { - devices_ = devices; -} + : VirtualCluster(devices, absl::make_unique(), + ReadyNodeManagerFactory("FirstReady")) {} VirtualCluster::VirtualCluster( const std::unordered_map& devices, std::unique_ptr node_estimator, std::unique_ptr node_manager) - : Cluster(0), - node_estimator_(std::move(node_estimator)), - node_manager_(std::move(node_manager)) { + : Cluster(0) { devices_ = devices; + + // Note that we do not use aggressive shape inference to preserve unknown + // shapes from the input graph. + estimator_ = absl::make_unique( + this, std::move(node_estimator), std::move(node_manager), + /*use_static_shapes=*/true, /*use_aggressive_shape_inference=*/false); } VirtualCluster::VirtualCluster(const DeviceSet* device_set) @@ -66,19 +67,13 @@ Status VirtualCluster::Run(const GraphDef& graph, const std::vector>& feed, const std::vector& fetch, RunMetadata* metadata) { - // Initialize a virtual scheduler to process the graph. Make sure to use - // static shape inference to prevent the scheduler from calling the Run - // method on the cluster and creating an infinite loop. + // Initializes an analytical cost estimator to estimate the graph cost. Makes + // sure to use static shape inference to prevent the virtual scheduler from + // calling the Run method on the cluster and creating an infinite loop. GrapplerItem item; item.graph = graph; item.feed = feed; item.fetch = fetch; - // Note that we do not use aggressive shape inference to preserve unknown - // shapes from the input graph. - VirtualScheduler scheduler(/*use_static_shapes=*/true, - /*use_aggressive_shape_inference=*/false, this, - node_manager_.get()); - TF_RETURN_IF_ERROR(scheduler.Init(&item)); if (metadata) { metadata->clear_step_stats(); @@ -86,45 +81,14 @@ Status VirtualCluster::Run(const GraphDef& graph, metadata->clear_partition_graphs(); } - Costs node_costs; - int node_id = 0; - do { - OpContext op_context = scheduler.GetCurrNode(); - node_costs = node_estimator_->PredictCosts(op_context); - if (metadata) { - CostGraphDef::Node* cost_node = - metadata->mutable_cost_graph()->add_node(); - const string& op_name = op_context.name; - cost_node->set_id(node_id++); - cost_node->set_name(op_name); - cost_node->set_device(op_context.device_name); - cost_node->set_compute_cost( - node_costs.execution_time.asMicroSeconds().count()); - cost_node->set_compute_time( - node_costs.compute_time.asMicroSeconds().count()); - cost_node->set_memory_time( - node_costs.memory_time.asMicroSeconds().count()); - for (const auto& output : op_context.op_info.outputs()) { - auto output_info = cost_node->add_output_info(); - output_info->set_dtype(output.dtype()); - *output_info->mutable_shape() = output.shape(); - - int64 size = DataTypeSize(output.dtype()); - for (const auto& dim : output.shape().dim()) { - size *= std::max(1, dim.size()); - } - output_info->set_size(size); - } - } - } while (scheduler.MarkCurrNodeExecuted(node_costs)); - - if (metadata) { - scheduler.Summary(metadata); - } + TF_RETURN_IF_ERROR(estimator_->Initialize(item)); + Costs ignored_costs; + TF_RETURN_IF_ERROR( + estimator_->PredictCosts(item.graph, metadata, &ignored_costs)); const std::unordered_map& device = GetDevices(); std::unordered_map peak_mem_usage = - scheduler.GetPeakMemoryUsage(); + estimator_->GetScheduler()->GetPeakMemoryUsage(); for (const auto& mem_usage : peak_mem_usage) { const string& device_name = mem_usage.first; auto it = device.find(device_name); diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h index d19e39cd292..94446a998a6 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.h +++ b/tensorflow/core/grappler/clusters/virtual_cluster.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h" #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" #include "tensorflow/core/grappler/costs/virtual_scheduler.h" #include "tensorflow/core/protobuf/device_properties.pb.h" @@ -50,9 +51,8 @@ class VirtualCluster : public Cluster { const DeviceSet* GetDeviceSet() const override { return device_set_; } private: - std::unique_ptr node_estimator_; - std::unique_ptr node_manager_; - const DeviceSet* device_set_ = nullptr; // Not owned + std::unique_ptr estimator_; + const DeviceSet* device_set_ = nullptr; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc index 5baf306f6fe..cb8b9964436 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc @@ -56,7 +56,7 @@ void AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context, (*name_to_id)[node->name()] = node->id(); } // For nodes we have seen before (e.g. Merge nodes are executed twice by - // VirtualScheduler), the following fields will be overwritten/updated + // VirtualScheduler), the following fields will be overwritten/updated. node->set_device(op_context.device_name); node->set_compute_cost(node_costs.execution_time.asMicroSeconds().count()); node->set_compute_time(node_costs.compute_time.asMicroSeconds().count()); @@ -67,7 +67,7 @@ void AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context, int input_port; string input_name = ParseNodeName(input, &input_port); - // All inputs should have been seen already unless this is a Merge node + // All inputs should have been seen already unless this is a Merge node. if (name_to_id->find(input_name) == name_to_id->end()) { if (!IsMerge(*node_manager->GetCurrNode())) LOG(ERROR) << "input: " << input @@ -76,7 +76,7 @@ void AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context, // For Merge node, some of inputs may not be seen before // For example, for a typical while loop in tensorflow, Merge node // will be executed twice by VirtualScheduler (one for Enter, the - // other for NextIteration), so eventually both inputs will be added + // other for NextIteration), so eventually both inputs will be added. continue; } @@ -93,30 +93,38 @@ void AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context, auto output_info = node->add_output_info(); output_info->set_alias_input_port(-1); output_info->set_dtype(output.dtype()); - auto shape = output_info->mutable_shape(); - *shape = output.shape(); + *output_info->mutable_shape() = output.shape(); + + int64 size = DataTypeSize(output.dtype()); + for (const auto& dim : output.shape().dim()) { + size *= std::max(1, dim.size()); + } + output_info->set_size(size); } } } // namespace -AnalyticalCostEstimator::AnalyticalCostEstimator(Cluster* cluster, - bool use_static_shapes) +AnalyticalCostEstimator::AnalyticalCostEstimator( + Cluster* cluster, bool use_static_shapes, + bool use_aggressive_shape_inference) : AnalyticalCostEstimator( cluster, absl::make_unique(), - ReadyNodeManagerFactory("FirstReady"), use_static_shapes) {} + ReadyNodeManagerFactory("FirstReady"), use_static_shapes, + use_aggressive_shape_inference) {} AnalyticalCostEstimator::AnalyticalCostEstimator( Cluster* cluster, std::unique_ptr node_estimator, - std::unique_ptr node_manager, bool use_static_shapes) + std::unique_ptr node_manager, bool use_static_shapes, + bool use_aggressive_shape_inference) : cluster_(cluster), node_estimator_(std::move(node_estimator)), node_manager_(std::move(node_manager)), - use_static_shapes_(use_static_shapes) { - // Use aggressive static shape inference to minimize unknown shapes. + use_static_shapes_(use_static_shapes), + use_aggressive_shape_inference_(use_aggressive_shape_inference) { scheduler_ = absl::make_unique( - use_static_shapes_, - /*use_aggressive_shape_inference=*/true, cluster_, node_manager_.get()); + use_static_shapes_, use_aggressive_shape_inference_, cluster_, + node_manager_.get()); } Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) { @@ -142,7 +150,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph, cost_graph = run_metadata->mutable_cost_graph(); // TODO(pcma): Clear nodes in cost_graph after we make sure we always pass // in an empty cost_graph (a non-empty but incomplete cost_graph will cause - // problems, e.g., no node_id in cost_graph) + // problems, e.g., no node_id in cost_graph). for (auto& node : *cost_graph->mutable_node()) { name_to_cost_node[node.name()] = &node; } @@ -165,7 +173,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph, << node_costs.num_ops_with_unknown_shapes << " unknown shapes"; } - // TODO(pcma): Add unit tests for generating CostGraphDef + // TODO(pcma): Add unit tests for generating CostGraphDef. if (cost_graph) { AddCostNode(node_manager_.get(), op_context, node_id++, node_costs, &name_to_cost_node, &name_to_id, cost_graph); diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.h b/tensorflow/core/grappler/costs/analytical_cost_estimator.h index d058ba41152..c9028efe0db 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.h +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.h @@ -35,15 +35,19 @@ struct GrapplerItem; // Estimate the cost of running a Grappler item based on the theoretical // performance of the hardware that will run the model. Note that this -// internally uses aggressive shape inference with static shape inference. +// internally uses static shape inference. An option for aggressive shape +// inference is provided to minimize unknown shapes, and this is only applicable +// with static shape inference. class AnalyticalCostEstimator : public CostEstimator { public: // Does not take ownership of cluster. - AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes); + AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes, + bool use_aggressive_shape_inference); AnalyticalCostEstimator(Cluster* cluster, std::unique_ptr node_estimator, std::unique_ptr node_manager, - bool use_static_shapes); + bool use_static_shapes, + bool use_aggressive_shape_inference); ~AnalyticalCostEstimator() override {} // Initializes the estimator for the specified grappler item. @@ -63,8 +67,10 @@ class AnalyticalCostEstimator : public CostEstimator { GrapplerItem item_; std::unique_ptr node_estimator_; std::unique_ptr node_manager_; - bool use_static_shapes_; std::unique_ptr scheduler_; + + bool use_static_shapes_; + bool use_aggressive_shape_inference_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc index fdc6b79c829..e558558d00a 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc @@ -94,7 +94,8 @@ class AnalyticalCostEstimatorTest : public ::testing::Test { TEST_F(AnalyticalCostEstimatorTest, SimpleTest) { GrapplerItem item = CreateMiniGraph(); - AnalyticalCostEstimator estimator(cluster_.get(), true); + AnalyticalCostEstimator estimator(cluster_.get(), /*use_static_shapes=*/true, + /*use_aggressive_shape_inference=*/true); TF_ASSERT_OK(estimator.Initialize(item)); RunMetadata run_metadata; diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i index af9276c508b..2a53da734a6 100644 --- a/tensorflow/python/grappler/cluster.i +++ b/tensorflow/python/grappler/cluster.i @@ -153,7 +153,7 @@ static GCluster TF_NewVirtualCluster( for (const auto& named_device : named_devices) { devices[named_device.name()]= named_device.properties(); } - tensorflow::grappler::Cluster*cluster_ = + tensorflow::grappler::Cluster* cluster_ = new tensorflow::grappler::VirtualCluster(devices); PyGILState_STATE gstate = PyGILState_Ensure(); tensorflow::Status status = cluster_->Provision(); diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py index 541747867fa..2014c0dde3f 100644 --- a/tensorflow/python/grappler/cluster_test.py +++ b/tensorflow/python/grappler/cluster_test.py @@ -99,9 +99,7 @@ class ClusterTest(test.TestCase): type='GPU', frequency=1000, num_cores=60, - environment={ - 'architecture': '7' - }) + environment={'architecture': '7'}) named_device = device_properties_pb2.NamedDevice( properties=device_properties, name='/device:GPU:0') grappler_cluster = cluster.Cluster( diff --git a/tensorflow/python/grappler/cost_analyzer.cc b/tensorflow/python/grappler/cost_analyzer.cc index 9aa5fbca383..de4b82c84dc 100644 --- a/tensorflow/python/grappler/cost_analyzer.cc +++ b/tensorflow/python/grappler/cost_analyzer.cc @@ -27,7 +27,8 @@ CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster, const string& suffix) : item_(&item), measure_estimator_(cluster, 10, 0), - analytical_estimator_(cluster, false), + analytical_estimator_(cluster, /*use_static_shapes=*/false, + /*use_aggressive_shape_inference=*/true), suffix_(suffix) {} Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report, @@ -125,7 +126,6 @@ void CostAnalyzer::PreprocessCosts() { } } - void CostAnalyzer::SortOpsByTime(std::map ops) { for (const auto& op : ops) { ops_.push_back(op.second);