Add a standalone cost analysis tool. Improved logging.

PiperOrigin-RevId: 158933442
This commit is contained in:
Yao Zhang 2017-06-13 19:47:59 -07:00 committed by TensorFlower Gardener
parent 838571b0a0
commit 6ffa51f1e0
11 changed files with 110 additions and 16 deletions

View File

@ -36,6 +36,8 @@ SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus)
num_gpus_(num_gpus),
expected_init_time_s_(0),
closing_(false) {
VLOG(1) << "Number of CPU cores: " << num_cpu_cores
<< " Number of GPUs: " << num_gpus;
thread_pool_.reset(new thread::ThreadPool(
Env::Default(), SanitizeThreadSuffix("single_machine"), 2));
@ -73,9 +75,12 @@ Status SingleMachine::Provision() {
DeviceProperties attr = GetLocalCPUInfo();
devices_["/job:localhost/replica:0/task:0/cpu:0"] = GetLocalCPUInfo();
VLOG(1) << "Number of GPUs: " << num_gpus_;
for (int i = 0; i < num_gpus_; ++i) {
devices_[strings::StrCat("/job:localhost/replica:0/task:0/gpu:", i)] =
GetLocalGPUInfo(i);
string device_name =
strings::StrCat("/job:localhost/replica:0/task:0/gpu:", i);
VLOG(1) << "Adding GPU device " << device_name;
devices_[device_name] = GetLocalGPUInfo(i);
}
return Status::OK();
}

View File

@ -101,6 +101,7 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph,
}
// Run "measurement_steps_" and measure the time.
VLOG(1) << "Number of measurement steps: " << measurement_steps_;
if (measurement_threads_ > 0) {
for (int i = 0; i < measurement_steps_; ++i) {
thread_pool_->Schedule([i, &measurement_fn]() { measurement_fn(i); });

View File

@ -314,6 +314,8 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
bandwidth = 100;
}
}
VLOG(1) << "Device: " << device.type() << " GFLOPS: " << gflops
<< " Bandwidth: " << bandwidth;
return std::make_pair(gflops, bandwidth);
}

View File

@ -36,6 +36,7 @@ VirtualPlacer::VirtualPlacer(const Cluster* cluster) {
} else {
default_device_ = devices_.begin()->first;
VLOG(1) << "Number of devices: " << devices_.size();
for (const auto& device : devices_) {
if (str_util::Lowercase(device.first).find("gpu") != string::npos) {
default_device_ = device.first;
@ -47,6 +48,7 @@ VirtualPlacer::VirtualPlacer(const Cluster* cluster) {
const DeviceProperties& VirtualPlacer::get_device(const NodeDef& node) const {
string device = get_canonical_device_name(node);
VLOG(3) << "Device name: " << device;
auto it = devices_.find(device);
DCHECK(it != devices_.end());
return it->second;

View File

@ -31,7 +31,7 @@ struct ItemConfig {
: ignore_user_placement(true),
ignore_colocation(true),
placeholder_unknown_output_shape_dim(-1),
apply_optimizations(true),
apply_optimizations(false),
inline_functions(true) {}
// If true, ignore all user specified node placement.

View File

@ -3832,6 +3832,19 @@ py_library(
deps = [":pywrap_tensorflow_internal"],
)
py_binary(
name = "cost_analyzer_tool",
srcs = [
"grappler/cost_analyzer_tool.py",
],
srcs_version = "PY2AND3",
deps = [
":cost_analyzer",
":framework_for_generated_wrappers",
"//tensorflow/core:protos_all_py",
],
)
py_test(
name = "cost_analyzer_test",
size = "small",

View File

@ -30,11 +30,11 @@ CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
analytical_estimator_(cluster, false),
suffix_(suffix) {}
Status CostAnalyzer::GenerateReport(std::ostream& os) {
Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report) {
GatherCosts();
PreprocessCosts();
AnalyzeCosts();
PrintAnalysis(os);
PrintAnalysis(os, per_node_report);
return Status::OK();
}
@ -158,7 +158,7 @@ void CostAnalyzer::AnalyzeCosts() {
}
}
void CostAnalyzer::PrintAnalysis(std::ostream& os) const {
void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report) const {
os << std::endl;
os << std::left << std::setw(50)
<< "Total time measured in ns (serialized): " << std::right
@ -225,6 +225,11 @@ void CostAnalyzer::PrintAnalysis(std::ostream& os) const {
os << std::endl;
}
os << std::endl;
if (per_node_report) {
os << "Below is the per-node report:" << std::endl;
os << op_perf_.DebugString();
}
}
} // end namespace grappler

View File

@ -50,7 +50,7 @@ class CostAnalyzer {
public:
explicit CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
const string& suffix);
Status GenerateReport(std::ostream& os);
Status GenerateReport(std::ostream& os, bool per_node_report);
private:
void PredictCosts(CostEstimator* cost_estimator, CostGraphDef* cost_graph,
@ -59,7 +59,7 @@ class CostAnalyzer {
void PreprocessCosts();
void AnalyzeCosts();
void SortOpsByTime(std::map<string, OpPerfSummary> ops);
void PrintAnalysis(std::ostream& os) const;
void PrintAnalysis(std::ostream& os, bool per_node_report) const;
const GrapplerItem* item_;
MeasuringCostEstimator measure_estimator_;

View File

@ -42,8 +42,10 @@ limitations under the License.
%}
%{
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph) {
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool
per_node_report) {
tensorflow::grappler::ItemConfig cfg;
cfg.apply_optimizations = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph", metagraph, cfg);
@ -53,16 +55,20 @@ string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph) {
int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
tensorflow::grappler::SingleMachine cluster(timeout_s, num_cpu_cores, num_gpus);
cluster.SetNumWarmupSteps(10);
cluster.AllowSoftPlacement(true);
cluster.DisableDetailedStats(false);
TF_CHECK_OK(cluster.Provision());
string suffix;
tensorflow::grappler::CostAnalyzer analyzer(*item, &cluster, suffix);
std::stringstream os;
analyzer.GenerateReport(os);
analyzer.GenerateReport(os, per_node_report);
return os.str();
}
%}
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph);
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool
per_node_report);

View File

@ -22,8 +22,19 @@ from tensorflow.python import pywrap_tensorflow as tf_wrap
from tensorflow.python.framework import errors
def GenerateCostReport(metagraph):
"""Analyze the cost of each TensorFlow operation in the provided metagraph."""
def GenerateCostReport(metagraph, per_node_report=False):
"""Analyze the cost of each TensorFlow op and node in the provided metagraph.
Args:
metagraph: An TensorFlow MetaGraphDef.
per_node_report: by default the report contains stats aggregated on a per op
type basis, setting per_node_report to True adds results for each
individual node to the report.
Returns:
A string of cost report.
"""
with errors.raise_exception_on_not_ok_status():
ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString())
ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
per_node_report)
return ret_from_swig

View File

@ -0,0 +1,49 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""A tool for cost analysis."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.grappler import cost_analyzer
from tensorflow.python.platform import app
def main(_):
with open(FLAGS.input) as input_file:
metagraph = meta_graph_pb2.MetaGraphDef()
metagraph.ParseFromString(input_file.read())
report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
print(report)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--input", type=str, default=None, help="Input .meta file path.")
parser.add_argument(
"--per_node_report",
action="store_true",
help="Generate per-node report. By default the report contains stats "
"aggregated on a per op type basis, per_node_report adds results "
"for each individual node to the report.")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)