Open sourced the cost analyzer

PiperOrigin-RevId: 157178951
This commit is contained in:
Benoit Steiner 2017-05-25 18:29:14 -07:00 committed by TensorFlower Gardener
parent 3e767e9db0
commit 2251633a50
12 changed files with 591 additions and 0 deletions

View File

@ -258,6 +258,7 @@ endif()
# We include tf_cc_ops first, because tf_c depends on tf_cc.
include(tf_cc_ops.cmake)
include(tf_c.cmake)
include(tf_grappler.cmake)
if(tensorflow_BUILD_CC_EXAMPLE)
include(tf_tutorials.cmake)
include(tf_label_image_example.cmake)

View File

@ -0,0 +1,27 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
########################################################
# tf_grappler library
########################################################
# Sources for the grappler cost-analysis Python wrapper. Note: GLOB is given
# exact paths here, so it acts as an explicit list — new sources must be
# added manually.
file(GLOB tf_grappler_srcs
    "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.cc"
    "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h"
    "${tensorflow_source_dir}/tensorflow/python/grappler/cost_analyzer.cc"
    "${tensorflow_source_dir}/tensorflow/python/grappler/cost_analyzer.h"
)
# OBJECT library: the objects are linked directly into consumers
# (pywrap_tensorflow_internal) rather than producing a standalone archive.
add_library(tf_grappler OBJECT ${tf_grappler_srcs})
# Ensure core CPU targets (and anything they generate) build first.
add_dependencies(tf_grappler tf_core_cpu)

View File

@ -721,6 +721,7 @@ if(WIN32)
$<TARGET_OBJECTS:tf_cc_ops>
$<TARGET_OBJECTS:tf_core_ops>
$<TARGET_OBJECTS:tf_core_direct_session>
$<TARGET_OBJECTS:tf_grappler>
$<TARGET_OBJECTS:tf_tools_transform_graph_lib>
$<$<BOOL:${tensorflow_ENABLE_GRPC_SUPPORT}>:$<TARGET_OBJECTS:tf_core_distributed_runtime>>
$<TARGET_OBJECTS:tf_core_kernels>
@ -767,6 +768,7 @@ add_library(pywrap_tensorflow_internal SHARED
$<TARGET_OBJECTS:tf_cc_ops>
$<TARGET_OBJECTS:tf_core_ops>
$<TARGET_OBJECTS:tf_core_direct_session>
$<TARGET_OBJECTS:tf_grappler>
$<TARGET_OBJECTS:tf_tools_transform_graph_lib>
$<$<BOOL:${tensorflow_ENABLE_GRPC_SUPPORT}>:$<TARGET_OBJECTS:tf_core_distributed_runtime>>
$<TARGET_OBJECTS:tf_core_kernels>

View File

@ -230,5 +230,60 @@ string GetOpDescription(const OpInfo& op_info) {
return description;
}
// Converts a CostGraphDef (paired with the GraphDef it was produced from)
// into an OpPerformanceList carrying per-node timing and memory data.
OpPerformanceList CostGraphToOpPerformanceData(const CostGraphDef& cost_graph,
                                               const GraphDef& graph) {
  OpPerformanceList ret;
  // Index both graphs by node name for constant-time cross-lookups.
  std::unordered_map<string, const CostGraphDef::Node*> cost_by_name;
  std::unordered_map<string, const NodeDef*> node_by_name;
  for (auto& cost_node : cost_graph.node()) {
    cost_by_name[cost_node.name()] = &cost_node;
  }
  for (auto& graph_node : graph.node()) {
    node_by_name[graph_node.name()] = &graph_node;
  }

  for (const auto& node : graph.node()) {
    // Skip the nodes that are not in the cost graph: these are nodes that
    // aren't run, because they aren't in the intersection of transitive
    // fan-in of a fetch node and the transitive fan-out of an input, or nodes
    // that were optimized away by the optimizer. Since they don't contribute
    // to the execution time we simply discard them.
    const auto found = cost_by_name.find(node.name());
    if (found == cost_by_name.end()) {
      continue;
    }
    const CostGraphDef::Node& cost_node = *found->second;

    OpPerformance* perf = ret.add_op_performance();
    perf->set_node(node.name());
    const std::vector<OpInfo::TensorProperties> inputs =
        FindInputFeatures(node, cost_by_name, node_by_name);
    *perf->mutable_op() =
        BuildOpInfo(node, cost_node.device(), node_by_name, inputs);
    perf->set_temporary_memory_size(cost_node.temporary_memory_size());

    // Note that CostGraphDef::Node::compute_cost is microseconds, while
    // OpPerformance.compute_cost is nanoseconds.
    perf->set_compute_cost(cost_node.compute_cost() * 1000);
    perf->set_compute_time(cost_node.compute_time() * 1000);
    perf->set_memory_time(cost_node.memory_time() * 1000);

    auto* op_memory = perf->mutable_op_memory();
    for (const auto& output_info : cost_node.output_info()) {
      op_memory->add_output_memory(output_info.size());
    }
    op_memory->set_host_temp_memory(cost_node.host_temp_memory_size());
    op_memory->set_device_temp_memory(cost_node.device_temp_memory_size());
    op_memory->set_host_persistent_memory(
        cost_node.host_persistent_memory_size());
    op_memory->set_device_persistent_memory(
        cost_node.device_persistent_memory_size());
  }
  return ret;
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/graph/types.h"
@ -56,6 +57,10 @@ OpInfo BuildOpInfo(
const std::unordered_map<string, const NodeDef*>& name_to_node,
const std::vector<OpInfo::TensorProperties>& inputs);
// Gather performance data from a cost graph.
OpPerformanceList CostGraphToOpPerformanceData(const CostGraphDef& cost_graph,
const GraphDef& graph);
} // end namespace grappler
} // end namespace tensorflow

View File

@ -155,6 +155,22 @@ tf_py_test(
],
)
# Native implementation of the cost analyzer; linked into
# pywrap_tensorflow_internal (see the tf_py_wrap_cc rule below in this file).
cc_library(
    name = "cost_analyzer_lib",
    srcs = ["grappler/cost_analyzer.cc"],
    hdrs = ["grappler/cost_analyzer.h"],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler/clusters:cluster",
        "//tensorflow/core/grappler/clusters:single_machine",
        "//tensorflow/core/grappler/costs:analytical_cost_estimator",
        "//tensorflow/core/grappler/costs:measuring_cost_estimator",
        "//tensorflow/core/grappler/costs:op_performance_data_cc",
        "//tensorflow/core/grappler/costs:utils",
    ],
)
cc_library(
name = "numpy_lib",
srcs = ["lib/core/numpy.cc"],
@ -2644,6 +2660,7 @@ tf_py_wrap_cc(
"client/tf_session.i",
"framework/cpp_shape_inference.i",
"framework/python_op_gen.i",
"grappler/cost_analyzer.i",
"grappler/tf_optimizer.i",
"lib/core/py_func.i",
"lib/core/strings.i",
@ -2660,6 +2677,7 @@ tf_py_wrap_cc(
"util/transform_graph.i",
],
deps = [
":cost_analyzer_lib",
":cpp_shape_inference",
":kernel_registry",
":numpy_lib",
@ -3673,3 +3691,28 @@ cuda_py_test(
"//tensorflow/core:protos_all_py",
],
)
# Python-facing API over the SWIG-exported cost analyzer symbols.
py_library(
    name = "cost_analyzer",
    srcs = [
        "grappler/cost_analyzer.py",
    ],
    srcs_version = "PY2AND3",
    deps = [":pywrap_tensorflow_internal"],
)

# no_pip: presumably excluded because the test needs targets not shipped in
# the pip package — TODO confirm.
py_test(
    name = "cost_analyzer_test",
    size = "small",
    srcs = ["grappler/cost_analyzer_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":client_testlib",
        ":cost_analyzer",
        ":framework_for_generated_wrappers",
        ":math_ops",
        "//tensorflow/core:protos_all_py",
        "//third_party/py/numpy",
    ],
)

View File

@ -0,0 +1,224 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/grappler/cost_analyzer.h"
#include <iomanip>
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
// Keeps a pointer to `item` (caller must keep it alive) and builds the two
// estimators on top of `cluster`. The measuring estimator is configured with
// (10, 0) — presumably measurement steps and warmup steps; TODO confirm
// against MeasuringCostEstimator. The analytical estimator's `false` flag is
// likewise assumed to disable verbose/per-node output — TODO confirm.
// `suffix` is stored for later use.
CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
                           const string& suffix)
    : item_(&item),
      measure_estimator_(cluster, 10, 0),
      analytical_estimator_(cluster, false),
      suffix_(suffix) {}
// Runs the full analysis pipeline and streams the textual report to `os`.
// The four stages are strictly order-dependent: each consumes member state
// written by the previous one. Always returns OK; prediction failures are
// logged inside GatherCosts rather than surfaced here.
Status CostAnalyzer::GenerateReport(std::ostream& os) {
  GatherCosts();      // measured + analytical cost graphs -> op_perf_ lists
  PreprocessCosts();  // merge analytical times, compute efficiencies
  AnalyzeCosts();     // aggregate per op type, compute totals
  PrintAnalysis(os);  // format everything
  return Status::OK();
}
// Runs `cost_estimator` over the item's graph, filling `cost_graph` and
// writing the estimated total execution time into `total_time`.
// On failure the error is logged and `total_time` is set to 0: previously
// the value was copied out of a possibly partially-filled Costs struct
// before the status was even checked.
void CostAnalyzer::PredictCosts(CostEstimator* cost_estimator,
                                CostGraphDef* cost_graph, int64* total_time) {
  TF_CHECK_OK(cost_estimator->Initialize(*item_));
  Costs costs;
  const Status status =
      cost_estimator->PredictCosts(item_->graph, cost_graph, &costs);
  if (!status.ok()) {
    LOG(ERROR) << "Could not estimate the cost for item " << item_->id << ": "
               << status.error_message();
    // Don't report timings from a failed prediction.
    *total_time = 0;
    return;
  }
  *total_time = costs.execution_time.count();
}
void CostAnalyzer::GatherCosts() {
CostGraphDef cost_graph_measured;
PredictCosts(&measure_estimator_, &cost_graph_measured,
&total_time_measured_);
VLOG(1) << "cost_graph_measured size: " << cost_graph_measured.node_size();
op_perf_ = CostGraphToOpPerformanceData(cost_graph_measured, item_->graph);
CostGraphDef cost_graph_analytical;
PredictCosts(&analytical_estimator_, &cost_graph_analytical,
&total_time_analytical_);
VLOG(1) << "cost_graph_analytical size: "
<< cost_graph_analytical.node_size();
CostGraphDef cost_graph_analytical_filtered;
std::set<string> cost_nodes;
for (auto& node : cost_graph_measured.node()) {
cost_nodes.insert(node.name());
}
for (const auto& node : cost_graph_analytical.node()) {
auto it = cost_nodes.find(node.name());
// Filter the nodes that are not the cost nodes returned by
// MeasuringCostEstimator.
if (it == cost_nodes.end()) {
continue;
}
auto added_node = cost_graph_analytical_filtered.add_node();
*added_node = node;
}
VLOG(1) << "cost_graph_analytical_filtered size: "
<< cost_graph_analytical_filtered.node_size();
op_perf_analytical_ = CostGraphToOpPerformanceData(
cost_graph_analytical_filtered, item_->graph);
}
// Overwrites each measured record's compute/memory times with the analytical
// estimates and derives efficiency ratios (analytical time / measured cost).
// NOTE(review): indexes op_perf_ and op_perf_analytical_ in parallel by `i`,
// which assumes both lists were built from the same filtered node set in the
// same order (see GatherCosts) — confirm if that invariant ever changes.
void CostAnalyzer::PreprocessCosts() {
  for (int i = 0; i < op_perf_.op_performance_size(); i++) {
    OpPerformance* perf = op_perf_.mutable_op_performance(i);
    const OpPerformance& analytical = op_perf_analytical_.op_performance(i);
    perf->set_compute_time(analytical.compute_time());
    perf->set_memory_time(analytical.memory_time());
    double measured_cost = perf->compute_cost();

    double analytical_compute_cost = analytical.compute_time();
    if (analytical_compute_cost == 0) {
      // Negative infinity indicates unavailable data.
      perf->set_compute_efficiency(-INFINITY);
    } else {
      perf->set_compute_efficiency(analytical_compute_cost / measured_cost);
    }

    double analytical_memory_cost = analytical.memory_time();
    if (analytical_memory_cost == 0) {
      // Negative infinity indicates unavailable data.
      perf->set_memory_efficiency(-INFINITY);
    } else {
      perf->set_memory_efficiency(analytical_memory_cost / measured_cost);
    }
  }
}
void CostAnalyzer::SortOpsByTime(std::map<string, OpPerfSummary> ops) {
for (const auto& op : ops) {
ops_.push_back(op.second);
}
struct CompareByTime {
bool operator()(const OpPerfSummary& a, const OpPerfSummary& b) const {
return a.time > b.time;
}
};
std::stable_sort(ops_.begin(), ops_.end(), CompareByTime());
}
void CostAnalyzer::AnalyzeCosts() {
std::map<string, OpPerfSummary> ops;
for (const auto& op_perf : op_perf_.op_performance()) {
string op_name = op_perf.op().op();
ops[op_name].count++;
ops[op_name].time += op_perf.compute_cost();
ops[op_name].compute_time += op_perf.compute_time();
ops[op_name].memory_time += op_perf.memory_time();
ops[op_name].time_upper += op_perf.compute_time() + op_perf.memory_time();
ops[op_name].time_lower +=
std::max(op_perf.compute_time(), op_perf.memory_time());
ops[op_name].name = op_name;
}
SortOpsByTime(ops);
total_time_measured_serialized_ = 0;
total_time_analytical_upper_ = 0;
total_time_analytical_lower_ = 0;
for (const auto& op : ops_) {
total_time_measured_serialized_ += op.time;
total_time_analytical_upper_ += op.time_upper;
total_time_analytical_lower_ += op.time_lower;
}
}
// Streams the report: graph-level totals and efficiency ratios first, then a
// fixed-width comma-delimited table with one row per op type, in the order
// established by SortOpsByTime (descending measured time).
void CostAnalyzer::PrintAnalysis(std::ostream& os) const {
  os << std::endl;
  // Summary lines: label left-aligned in a 50-char field, value
  // right-aligned in a 20-char field.
  os << std::left << std::setw(50)
     << "Total time measured in ns (serialized): " << std::right
     << std::setw(20) << total_time_measured_serialized_ << std::endl;
  os << std::left << std::setw(50)
     << "Total time measured in ns (actual): " << std::right << std::setw(20)
     << total_time_measured_ << std::endl;
  os << std::left << std::setw(50)
     << "Total time analytical in ns (upper bound): " << std::right
     << std::setw(20) << total_time_analytical_upper_ << std::endl;
  os << std::left << std::setw(50)
     << "Total time analytical in ns (lower bound): " << std::right
     << std::setw(20) << total_time_analytical_lower_ << std::endl;
  // Overall efficiency ratios: analytical bound over measured wall time.
  double efficiency_upper = static_cast<double>(total_time_analytical_upper_) /
                            static_cast<double>(total_time_measured_);
  os << std::left << std::setw(50)
     << "Overall efficiency (analytical upper/actual): " << std::right
     << std::setw(20) << efficiency_upper << std::endl;
  double efficiency_lower = static_cast<double>(total_time_analytical_lower_) /
                            static_cast<double>(total_time_measured_);
  os << std::left << std::setw(50)
     << "Overall efficiency (analytical lower/actual): " << std::right
     << std::setw(20) << efficiency_lower << std::endl;
  os << std::endl;
  // Table header. The embedded trailing commas make the columns parseable
  // as loose CSV while staying readable when fixed-width padded.
  int width = 35;
  int width_narrow = 15;
  int width_wide = 20;
  os << std::setw(width + 1) << "Op,";
  os << std::setw(width_narrow + 1) << "Count,";
  os << std::setw(width_wide + 1) << "Measured time (ns),";
  os << std::setw(width_narrow + 2) << "Time percent,";
  os << std::setw(width_narrow + 2) << "Acc percent,";
  os << std::setw(width_wide + 1) << "Analytical upper,";
  os << std::setw(width_wide + 1) << "Analytical lower,";
  os << std::setw(width_narrow + 2) << "Overall eff";
  os << std::setw(width_narrow + 2) << "Compute eff";
  os << std::setw(width_narrow + 2) << "Memory eff" << std::endl;
  float acc_percent = 0;
  for (const auto& op : ops_) {
    // Share of the serialized measured time attributable to this op type.
    double percent = static_cast<double>(op.time) /
                     static_cast<double>(total_time_measured_serialized_);
    // Per-op efficiency ratios relative to the measured time.
    double eff =
        static_cast<double>(op.time_upper) / static_cast<double>(op.time);
    double compute_eff =
        static_cast<double>(op.compute_time) / static_cast<double>(op.time);
    double memory_eff =
        static_cast<double>(op.memory_time) / static_cast<double>(op.time);
    os << std::setw(width) << op.name << ",";
    os << std::setw(width_narrow) << op.count << ",";
    os << std::setw(width_wide) << op.time << ",";
    os << std::setw(width_narrow) << std::setprecision(2) << percent * 100
       << "%,";
    // Running cumulative percentage across the sorted ops.
    acc_percent += percent;
    os << std::setw(width_narrow) << std::setprecision(2) << acc_percent * 100
       << "%,";
    os << std::setw(width_wide) << op.time_upper << ",";
    os << std::setw(width_wide) << op.time_lower << ",";
    os << std::setw(width_narrow) << std::setprecision(2) << eff * 100 << "%,";
    os << std::setw(width_narrow) << std::setprecision(2) << compute_eff * 100
       << "%,";
    os << std::setw(width_narrow) << std::setprecision(2) << memory_eff * 100
       << "%,";
    os << std::endl;
  }
  os << std::endl;
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -0,0 +1,81 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ANALYZER_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ANALYZER_H_
#include <iostream>
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
namespace tensorflow {
class GraphDef;
class CostGraphDef;
namespace grappler {
struct GrapplerItem;
// Aggregated perf summary for ops of the same type in a graph.
// Counters are explicitly zero-initialized: previously the int64 fields were
// left uninitialized, which was only safe because instances happened to be
// created via std::map value-initialization.
struct OpPerfSummary {
  string name;
  int64 count = 0;
  // Accumulated measured time and its compute/memory components.
  int64 time = 0;
  int64 compute_time = 0;
  int64 memory_time = 0;
  // Upper and lower bound for estimated time.
  int64 time_upper = 0;
  int64 time_lower = 0;
};
// Generate op-level performance insights on compute/memory
// efficiency, as well as graph-level aggregated performance statistics.
class CostAnalyzer {
 public:
  // `item` and `cluster` must outlive this object; only pointers are kept.
  explicit CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
                        const string& suffix);

  // Runs the full analysis and streams a human-readable report to `os`.
  Status GenerateReport(std::ostream& os);

 private:
  // Runs `cost_estimator` over the item, filling `cost_graph` and
  // `total_time`.
  void PredictCosts(CostEstimator* cost_estimator, CostGraphDef* cost_graph,
                    int64* total_time);
  void GatherCosts();
  void PreprocessCosts();
  void AnalyzeCosts();
  void SortOpsByTime(std::map<string, OpPerfSummary> ops);
  void PrintAnalysis(std::ostream& os) const;

  const GrapplerItem* item_;
  MeasuringCostEstimator measure_estimator_;
  AnalyticalCostEstimator analytical_estimator_;
  OpPerformanceList op_perf_;
  OpPerformanceList op_perf_analytical_;
  // Timing totals are zero-initialized so the report prints zeros instead of
  // garbage if a stage fails before populating them (previously these int64
  // members were uninitialized).
  int64 total_time_measured_ = 0;
  int64 total_time_analytical_ = 0;
  std::vector<OpPerfSummary> ops_;
  int64 total_time_measured_serialized_ = 0;
  int64 total_time_analytical_upper_ = 0;
  int64 total_time_analytical_lower_ = 0;
  string suffix_;
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ANALYZER_H_

View File

@ -0,0 +1,67 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
// Input typemap: accept a Python bytes object holding a serialized
// MetaGraphDef and parse it into a stack temporary that is passed to the
// wrapped function by const reference.
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
  char* c_string;
  Py_ssize_t py_size;
  if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
    // Python has raised an error (likely TypeError or UnicodeEncodeError).
    SWIG_fail;
  }
  if (!temp.ParseFromString(string(c_string, py_size))) {
    PyErr_SetString(
        PyExc_TypeError,
        "The MetaGraphDef could not be parsed as a valid protocol buffer");
    SWIG_fail;
  }
  $1 = &temp;
}
%{
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/python/grappler/cost_analyzer.h"
%}
%{
// Builds a GrapplerItem from the already-parsed MetaGraphDef, runs the cost
// analyzer against a local SingleMachine cluster, and returns the textual
// report. Returns an error message string if the item cannot be constructed
// (previously a null item was dereferenced, crashing the interpreter —
// assumes GrapplerItemFromMetaGraphDef signals failure by returning null;
// TODO confirm).
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph) {
  tensorflow::grappler::ItemConfig cfg;
  std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
      tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph",
                                                         metagraph, cfg);
  if (!item) {
    return "Error: could not build a GrapplerItem from the MetaGraphDef";
  }
  // TODO(bsteiner): we should wrap the tf session instead to properly handle the case of a
  // distributed setup.
  const int timeout_s = 3600;
  const int num_cpu_cores =
      tensorflow::grappler::GetNumAvailableLogicalCPUCores();
  const int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
  tensorflow::grappler::SingleMachine cluster(timeout_s, num_cpu_cores,
                                              num_gpus);
  string suffix;
  tensorflow::grappler::CostAnalyzer analyzer(*item, &cluster, suffix);
  std::stringstream os;
  analyzer.GenerateReport(os);
  return os.str();
}
%}
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph);

View File

@ -0,0 +1,29 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Provides a proper python API for the symbols exported through swig."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as tf_wrap
from tensorflow.python.framework import errors
def GenerateCostReport(metagraph):
  """Analyze the cost of each TensorFlow operation in the provided metagraph.

  Args:
    metagraph: A MetaGraphDef protocol buffer describing the model to
      analyze.

  Returns:
    The report produced by the native cost analyzer, as returned by the SWIG
    wrapper.
  """
  serialized = metagraph.SerializeToString()
  with errors.raise_exception_on_not_ok_status():
    return tf_wrap.GenerateCostReport(serialized)

View File

@ -0,0 +1,56 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the cost analyzer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cost_analyzer
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class PyWrapOptimizeGraphTest(test.TestCase):
  """End-to-end smoke test for the cost analyzer SWIG wrapper."""

  def testBasic(self):
    """Make sure arguments can be passed correctly."""
    a = constant_op.constant(10, name="a")
    b = constant_op.constant(20, name="b")
    c = math_ops.add_n([a, b], name="c")
    d = math_ops.add_n([b, c], name="d")
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    report = cost_analyzer.GenerateCostReport(mg)

    # Check the report headers. assertIn gives a clearer failure message
    # than assertTrue(x in y).
    self.assertIn(b"Total time measured in ns (serialized):", report)
    self.assertIn(b"Total time measured in ns (actual):", report)
    self.assertIn(b"Total time analytical in ns (upper bound):", report)
    self.assertIn(b"Total time analytical in ns (lower bound):", report)
    self.assertIn(b"Overall efficiency (analytical upper/actual):", report)
    self.assertIn(b"Overall efficiency (analytical lower/actual):", report)

    # Also print the report to make it easier to debug
    print("{}".format(report))


if __name__ == "__main__":
  test.main()

View File

@ -42,3 +42,4 @@ limitations under the License.
%include "tensorflow/python/util/transform_graph.i"
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"