Shorten grappler per-node report by default with previous behavior behind --verbose.

PiperOrigin-RevId: 186369380
This commit is contained in:
A. Unique TensorFlower 2018-02-20 15:08:09 -08:00 committed by TensorFlower Gardener
parent af4b7d75c4
commit 776fa148b4
6 changed files with 95 additions and 16 deletions

View File

@ -30,11 +30,12 @@ CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
analytical_estimator_(cluster, false),
suffix_(suffix) {}
Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report) {
Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report,
bool verbose) {
GatherCosts();
PreprocessCosts();
AnalyzeCosts();
PrintAnalysis(os, per_node_report);
PrintAnalysis(os, per_node_report, verbose);
return Status::OK();
}
@ -158,7 +159,8 @@ void CostAnalyzer::AnalyzeCosts() {
}
}
void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report) const {
void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report,
bool verbose) const {
os << std::endl;
os << std::left << std::setw(50)
<< "Total time measured in ns (serialized): " << std::right
@ -227,10 +229,55 @@ void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report) const {
os << std::endl;
if (per_node_report) {
os << "Below is the per-node report:" << std::endl;
os << op_perf_.DebugString();
if (verbose) {
os << "Below is the full per-node report:" << std::endl;
os << op_perf_.DebugString();
} else {
os << "Below is the per-node report summary:" << std::endl;
int width = 35;
int width_narrow = 15;
int width_wide = 20;
os << std::setw(width + 1) << "Op,";
os << std::setw(width_wide + 1) << "Measured time (ns),";
os << std::setw(width_wide + 1) << "Compute time (ns),";
os << std::setw(width_wide + 1) << "Memory time (ns),";
os << std::setw(width_narrow + 2) << "Compute eff,";
os << std::setw(width_narrow + 2) << "Memory eff,";
os << " Inputs" << std::endl;
for (int i = 0; i < op_perf_.op_performance_size(); i++) {
const auto& perf = op_perf_.op_performance(i);
string op_name = perf.op().op();
os << std::setw(width) << op_name << ",";
os << std::setw(width_wide) << perf.compute_cost() << ",";
os << std::setw(width_wide) << perf.compute_time() << ",";
os << std::setw(width_wide) << perf.memory_time() << ",";
os << std::setw(width_narrow) << std::setprecision(2)
<< perf.compute_efficiency() * 100 << "%,";
os << std::setw(width_narrow) << std::setprecision(2)
<< perf.memory_efficiency() * 100 << "%,";
os << " [";
for (int j = 0; j < perf.op().inputs_size(); j++) {
const auto& shape = perf.op().inputs(j).shape();
if (shape.dim_size() > 0) {
os << "(";
std::vector<int> dims;
for (int k = 0; k < shape.dim_size(); k++) {
os << shape.dim(k).size();
if (k < shape.dim_size() - 1) {
os << ", ";
}
}
os << ")";
if (j < perf.op().inputs_size() - 1) {
os << ", ";
}
}
}
os << "]" << std::endl;
}
os << std::endl;
}
}
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <iostream>
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
@ -50,7 +51,7 @@ class CostAnalyzer {
public:
explicit CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
const string& suffix);
Status GenerateReport(std::ostream& os, bool per_node_report);
Status GenerateReport(std::ostream& os, bool per_node_report, bool verbose);
private:
void PredictCosts(CostEstimator* cost_estimator, CostGraphDef* cost_graph,
@ -59,7 +60,8 @@ class CostAnalyzer {
void PreprocessCosts();
void AnalyzeCosts();
void SortOpsByTime(std::map<string, OpPerfSummary> ops);
void PrintAnalysis(std::ostream& os, bool per_node_report) const;
void PrintAnalysis(std::ostream& os, bool per_node_report,
bool verbose) const;
const GrapplerItem* item_;
MeasuringCostEstimator measure_estimator_;

View File

@ -44,7 +44,7 @@ limitations under the License.
%{
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report,
GCluster cluster) {
bool verbose, GCluster cluster) {
tensorflow::grappler::ItemConfig cfg;
cfg.apply_optimizations = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
@ -57,11 +57,11 @@ string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_no
tensorflow::grappler::CostAnalyzer analyzer(*item, cluster.get(), suffix);
std::stringstream os;
analyzer.GenerateReport(os, per_node_report);
analyzer.GenerateReport(os, per_node_report, verbose);
return os.str();
}
%}
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report,
GCluster cluster);
bool verbose, GCluster cluster);

View File

@ -24,7 +24,10 @@ from tensorflow.python.grappler import cluster as gcluster
from tensorflow.python.grappler import item as gitem
def GenerateCostReport(metagraph, per_node_report=False, cluster=None):
def GenerateCostReport(metagraph,
per_node_report=False,
verbose=False,
cluster=None):
"""Analyze the cost of each TensorFlow op and node in the provided metagraph.
Args:
@ -32,6 +35,7 @@ def GenerateCostReport(metagraph, per_node_report=False, cluster=None):
per_node_report: by default the report contains stats aggregated on a per op
type basis, setting per_node_report to True adds results for each
individual node to the report.
verbose: Prints out the entire operation proto instead of a summary table.
cluster: Analyze the costs using the specified cluster, or the local machine
if no cluster was specified.
@ -42,8 +46,9 @@ def GenerateCostReport(metagraph, per_node_report=False, cluster=None):
cluster = gcluster.Cluster(disable_detailed_stats=False)
with errors.raise_exception_on_not_ok_status():
ret_from_swig = tf_wrap.GenerateCostReport(
metagraph.SerializeToString(), per_node_report, cluster.tf_cluster)
ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
per_node_report, verbose,
cluster.tf_cluster)
return ret_from_swig

View File

@ -48,7 +48,7 @@ class CostAnalysisTest(test.TestCase):
train_op.append(d)
mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
report = cost_analyzer.GenerateCostReport(mg)
report = cost_analyzer.GenerateCostReport(mg, per_node_report=True)
# Check the report headers
self.assertTrue(b"Total time measured in ns (serialized):" in report)
@ -57,6 +57,26 @@ class CostAnalysisTest(test.TestCase):
self.assertTrue(b"Total time analytical in ns (lower bound):" in report)
self.assertTrue(b"Overall efficiency (analytical upper/actual):" in report)
self.assertTrue(b"Overall efficiency (analytical lower/actual):" in report)
self.assertTrue(b"Below is the per-node report summary:" in report)
# Also print the report to make it easier to debug
print("{}".format(report))
def testVerbose(self):
"""Make sure the full report is generated with verbose=True."""
a = constant_op.constant(10, name="a")
b = constant_op.constant(20, name="b")
c = math_ops.add_n([a, b], name="c")
d = math_ops.add_n([b, c], name="d")
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(d)
mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
report = cost_analyzer.GenerateCostReport(
mg, per_node_report=True, verbose=True)
# Check the report headers
self.assertTrue(b"Below is the full per-node report:" in report)
# Also print the report to make it easier to debug
print("{}".format(report))

View File

@ -74,7 +74,8 @@ def main(_):
optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
metagraph.graph_def.CopyFrom(optimized_graph)
report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report,
FLAGS.verbose)
print(report)
if FLAGS.memory_report:
report = cost_analyzer.GenerateMemoryReport(metagraph)
@ -117,5 +118,9 @@ if __name__ == "__main__":
"--memory_report",
action="store_true",
help="Generate memory usage report.")
parser.add_argument(
"--verbose",
action="store_true",
help="Generate verbose reports. By default, succinct reports are used.")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)