Add a standalone cost analysis tool. Improved logging.
PiperOrigin-RevId: 158933442
This commit is contained in:
parent
838571b0a0
commit
6ffa51f1e0
@ -36,6 +36,8 @@ SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus)
|
||||
num_gpus_(num_gpus),
|
||||
expected_init_time_s_(0),
|
||||
closing_(false) {
|
||||
VLOG(1) << "Number of CPU cores: " << num_cpu_cores
|
||||
<< " Number of GPUs: " << num_gpus;
|
||||
thread_pool_.reset(new thread::ThreadPool(
|
||||
Env::Default(), SanitizeThreadSuffix("single_machine"), 2));
|
||||
|
||||
@ -73,9 +75,12 @@ Status SingleMachine::Provision() {
|
||||
DeviceProperties attr = GetLocalCPUInfo();
|
||||
devices_["/job:localhost/replica:0/task:0/cpu:0"] = GetLocalCPUInfo();
|
||||
|
||||
VLOG(1) << "Number of GPUs: " << num_gpus_;
|
||||
for (int i = 0; i < num_gpus_; ++i) {
|
||||
devices_[strings::StrCat("/job:localhost/replica:0/task:0/gpu:", i)] =
|
||||
GetLocalGPUInfo(i);
|
||||
string device_name =
|
||||
strings::StrCat("/job:localhost/replica:0/task:0/gpu:", i);
|
||||
VLOG(1) << "Adding GPU device " << device_name;
|
||||
devices_[device_name] = GetLocalGPUInfo(i);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -101,6 +101,7 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph,
|
||||
}
|
||||
|
||||
// Run "measurement_steps_" and measure the time.
|
||||
VLOG(1) << "Number of measurement steps: " << measurement_steps_;
|
||||
if (measurement_threads_ > 0) {
|
||||
for (int i = 0; i < measurement_steps_; ++i) {
|
||||
thread_pool_->Schedule([i, &measurement_fn]() { measurement_fn(i); });
|
||||
|
@ -314,6 +314,8 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
|
||||
bandwidth = 100;
|
||||
}
|
||||
}
|
||||
VLOG(1) << "Device: " << device.type() << " GFLOPS: " << gflops
|
||||
<< " Bandwidth: " << bandwidth;
|
||||
|
||||
return std::make_pair(gflops, bandwidth);
|
||||
}
|
||||
|
@ -36,6 +36,7 @@ VirtualPlacer::VirtualPlacer(const Cluster* cluster) {
|
||||
|
||||
} else {
|
||||
default_device_ = devices_.begin()->first;
|
||||
VLOG(1) << "Number of devices: " << devices_.size();
|
||||
for (const auto& device : devices_) {
|
||||
if (str_util::Lowercase(device.first).find("gpu") != string::npos) {
|
||||
default_device_ = device.first;
|
||||
@ -47,6 +48,7 @@ VirtualPlacer::VirtualPlacer(const Cluster* cluster) {
|
||||
|
||||
const DeviceProperties& VirtualPlacer::get_device(const NodeDef& node) const {
|
||||
string device = get_canonical_device_name(node);
|
||||
VLOG(3) << "Device name: " << device;
|
||||
auto it = devices_.find(device);
|
||||
DCHECK(it != devices_.end());
|
||||
return it->second;
|
||||
|
@ -31,7 +31,7 @@ struct ItemConfig {
|
||||
: ignore_user_placement(true),
|
||||
ignore_colocation(true),
|
||||
placeholder_unknown_output_shape_dim(-1),
|
||||
apply_optimizations(true),
|
||||
apply_optimizations(false),
|
||||
inline_functions(true) {}
|
||||
|
||||
// If true, ignore all user specified node placement.
|
||||
|
@ -3832,6 +3832,19 @@ py_library(
|
||||
deps = [":pywrap_tensorflow_internal"],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "cost_analyzer_tool",
|
||||
srcs = [
|
||||
"grappler/cost_analyzer_tool.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":cost_analyzer",
|
||||
":framework_for_generated_wrappers",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "cost_analyzer_test",
|
||||
size = "small",
|
||||
|
@ -30,11 +30,11 @@ CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
|
||||
analytical_estimator_(cluster, false),
|
||||
suffix_(suffix) {}
|
||||
|
||||
Status CostAnalyzer::GenerateReport(std::ostream& os) {
|
||||
Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report) {
|
||||
GatherCosts();
|
||||
PreprocessCosts();
|
||||
AnalyzeCosts();
|
||||
PrintAnalysis(os);
|
||||
PrintAnalysis(os, per_node_report);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -158,7 +158,7 @@ void CostAnalyzer::AnalyzeCosts() {
|
||||
}
|
||||
}
|
||||
|
||||
void CostAnalyzer::PrintAnalysis(std::ostream& os) const {
|
||||
void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report) const {
|
||||
os << std::endl;
|
||||
os << std::left << std::setw(50)
|
||||
<< "Total time measured in ns (serialized): " << std::right
|
||||
@ -225,6 +225,11 @@ void CostAnalyzer::PrintAnalysis(std::ostream& os) const {
|
||||
os << std::endl;
|
||||
}
|
||||
os << std::endl;
|
||||
|
||||
if (per_node_report) {
|
||||
os << "Below is the per-node report:" << std::endl;
|
||||
os << op_perf_.DebugString();
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
|
@ -50,7 +50,7 @@ class CostAnalyzer {
|
||||
public:
|
||||
explicit CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
|
||||
const string& suffix);
|
||||
Status GenerateReport(std::ostream& os);
|
||||
Status GenerateReport(std::ostream& os, bool per_node_report);
|
||||
|
||||
private:
|
||||
void PredictCosts(CostEstimator* cost_estimator, CostGraphDef* cost_graph,
|
||||
@ -59,7 +59,7 @@ class CostAnalyzer {
|
||||
void PreprocessCosts();
|
||||
void AnalyzeCosts();
|
||||
void SortOpsByTime(std::map<string, OpPerfSummary> ops);
|
||||
void PrintAnalysis(std::ostream& os) const;
|
||||
void PrintAnalysis(std::ostream& os, bool per_node_report) const;
|
||||
|
||||
const GrapplerItem* item_;
|
||||
MeasuringCostEstimator measure_estimator_;
|
||||
|
@ -42,8 +42,10 @@ limitations under the License.
|
||||
%}
|
||||
|
||||
%{
|
||||
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph) {
|
||||
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool
|
||||
per_node_report) {
|
||||
tensorflow::grappler::ItemConfig cfg;
|
||||
cfg.apply_optimizations = false;
|
||||
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
|
||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph", metagraph, cfg);
|
||||
|
||||
@ -53,16 +55,20 @@ string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph) {
|
||||
int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
|
||||
int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
|
||||
tensorflow::grappler::SingleMachine cluster(timeout_s, num_cpu_cores, num_gpus);
|
||||
cluster.SetNumWarmupSteps(10);
|
||||
cluster.AllowSoftPlacement(true);
|
||||
cluster.DisableDetailedStats(false);
|
||||
TF_CHECK_OK(cluster.Provision());
|
||||
|
||||
string suffix;
|
||||
tensorflow::grappler::CostAnalyzer analyzer(*item, &cluster, suffix);
|
||||
|
||||
std::stringstream os;
|
||||
analyzer.GenerateReport(os);
|
||||
analyzer.GenerateReport(os, per_node_report);
|
||||
return os.str();
|
||||
}
|
||||
|
||||
%}
|
||||
|
||||
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph);
|
||||
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool
|
||||
per_node_report);
|
||||
|
@ -22,8 +22,19 @@ from tensorflow.python import pywrap_tensorflow as tf_wrap
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
|
||||
def GenerateCostReport(metagraph):
|
||||
"""Analyze the cost of each TensorFlow operation in the provided metagraph."""
|
||||
def GenerateCostReport(metagraph, per_node_report=False):
|
||||
"""Analyze the cost of each TensorFlow op and node in the provided metagraph.
|
||||
|
||||
Args:
|
||||
metagraph: An TensorFlow MetaGraphDef.
|
||||
per_node_report: by default the report contains stats aggregated on a per op
|
||||
type basis, setting per_node_report to True adds results for each
|
||||
individual node to the report.
|
||||
|
||||
Returns:
|
||||
A string of cost report.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status():
|
||||
ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString())
|
||||
ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
|
||||
per_node_report)
|
||||
return ret_from_swig
|
||||
|
49
tensorflow/python/grappler/cost_analyzer_tool.py
Normal file
49
tensorflow/python/grappler/cost_analyzer_tool.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""A tool for cost analysis."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.python.grappler import cost_analyzer
|
||||
from tensorflow.python.platform import app
|
||||
|
||||
|
||||
def main(_):
|
||||
with open(FLAGS.input) as input_file:
|
||||
metagraph = meta_graph_pb2.MetaGraphDef()
|
||||
metagraph.ParseFromString(input_file.read())
|
||||
|
||||
report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
|
||||
print(report)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input", type=str, default=None, help="Input .meta file path.")
|
||||
parser.add_argument(
|
||||
"--per_node_report",
|
||||
action="store_true",
|
||||
help="Generate per-node report. By default the report contains stats "
|
||||
"aggregated on a per op type basis, per_node_report adds results "
|
||||
"for each individual node to the report.")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
Loading…
Reference in New Issue
Block a user