diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index bade45e96a3..269ac86fd67 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -258,6 +258,7 @@ endif() # We include tf_cc_ops first, because tf_c depends on tf_cc. include(tf_cc_ops.cmake) include(tf_c.cmake) +include(tf_grappler.cmake) if(tensorflow_BUILD_CC_EXAMPLE) include(tf_tutorials.cmake) include(tf_label_image_example.cmake) diff --git a/tensorflow/contrib/cmake/tf_grappler.cmake b/tensorflow/contrib/cmake/tf_grappler.cmake new file mode 100644 index 00000000000..4811c8cce9c --- /dev/null +++ b/tensorflow/contrib/cmake/tf_grappler.cmake @@ -0,0 +1,27 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +######################################################## +# tf_grappler library +######################################################## +file(GLOB tf_grappler_srcs + "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.cc" + "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h" + "${tensorflow_source_dir}/tensorflow/python/grappler/cost_analyzer.cc" + "${tensorflow_source_dir}/tensorflow/python/grappler/cost_analyzer.h" + ) + +add_library(tf_grappler OBJECT ${tf_grappler_srcs}) + +add_dependencies(tf_grappler tf_core_cpu) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 6d129e70638..a2aa2a8102f 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -721,6 +721,7 @@ if(WIN32) $ $ $ + $ $ $<$:$> $ @@ -767,6 +768,7 @@ add_library(pywrap_tensorflow_internal SHARED $ $ $ + $ $ $<$:$> $ diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 7b7d79fc7ed..3cc92b56d20 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -230,5 +230,60 @@ string GetOpDescription(const OpInfo& op_info) { return description; } +OpPerformanceList CostGraphToOpPerformanceData(const CostGraphDef& cost_graph, + const GraphDef& graph) { + OpPerformanceList ret; + std::unordered_map name_to_cost; + std::unordered_map name_to_node; + for (auto& node : cost_graph.node()) { + name_to_cost[node.name()] = &node; + } + for (auto& node : graph.node()) { + name_to_node[node.name()] = &node; + } + + for (const auto& node : graph.node()) { + // Skip the nodes that are not in the cost graph: these are nodes that + // aren't run, because they aren't in the intersection of transitive + // fan-in of a fetch node and the transitive fan-out of an input, or nodes + 
// that were optimized away by the optimizer. Since they don't contribute + // to the execution time we simply discard them. + auto it = name_to_cost.find(node.name()); + if (it == name_to_cost.end()) { + continue; + } + const CostGraphDef::Node* cost_node = it->second; + + OpPerformance* perf = ret.add_op_performance(); + perf->set_node(node.name()); + + std::vector inputs = + FindInputFeatures(node, name_to_cost, name_to_node); + (*perf->mutable_op()) = + BuildOpInfo(node, cost_node->device(), name_to_node, inputs); + + perf->set_temporary_memory_size(cost_node->temporary_memory_size()); + // Note that CostGraphDef::Node::compute_cost is microseconds, while + // OpPerformance.compute_cost is nanoseconds. + perf->set_compute_cost(cost_node->compute_cost() * 1000); + perf->set_compute_time(cost_node->compute_time() * 1000); + perf->set_memory_time(cost_node->memory_time() * 1000); + + for (const auto& output_info : cost_node->output_info()) { + perf->mutable_op_memory()->add_output_memory(output_info.size()); + } + + perf->mutable_op_memory()->set_host_temp_memory( + cost_node->host_temp_memory_size()); + perf->mutable_op_memory()->set_device_temp_memory( + cost_node->device_temp_memory_size()); + perf->mutable_op_memory()->set_host_persistent_memory( + cost_node->host_persistent_memory_size()); + perf->mutable_op_memory()->set_device_persistent_memory( + cost_node->device_persistent_memory_size()); + } + return ret; +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h index cb23ac83553..17214a55cb6 100644 --- a/tensorflow/core/grappler/costs/utils.h +++ b/tensorflow/core/grappler/costs/utils.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/graph/types.h" @@ -56,6 +57,10 @@ OpInfo BuildOpInfo( const std::unordered_map& name_to_node, const std::vector& inputs); +// Gather performance data from a cost graph. +OpPerformanceList CostGraphToOpPerformanceData(const CostGraphDef& cost_graph, + const GraphDef& graph); + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 6ff414b8eca..aa1c7774f62 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -155,6 +155,22 @@ tf_py_test( ], ) +cc_library( + name = "cost_analyzer_lib", + srcs = ["grappler/cost_analyzer.cc"], + hdrs = ["grappler/cost_analyzer.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/clusters:single_machine", + "//tensorflow/core/grappler/costs:analytical_cost_estimator", + "//tensorflow/core/grappler/costs:measuring_cost_estimator", + "//tensorflow/core/grappler/costs:op_performance_data_cc", + "//tensorflow/core/grappler/costs:utils", + ], +) + cc_library( name = "numpy_lib", srcs = ["lib/core/numpy.cc"], @@ -2644,6 +2660,7 @@ tf_py_wrap_cc( "client/tf_session.i", "framework/cpp_shape_inference.i", "framework/python_op_gen.i", + "grappler/cost_analyzer.i", "grappler/tf_optimizer.i", "lib/core/py_func.i", "lib/core/strings.i", @@ -2660,6 +2677,7 @@ tf_py_wrap_cc( "util/transform_graph.i", ], deps = [ + ":cost_analyzer_lib", ":cpp_shape_inference", ":kernel_registry", ":numpy_lib", @@ -3673,3 +3691,28 @@ cuda_py_test( "//tensorflow/core:protos_all_py", ], ) + +py_library( + name = "cost_analyzer", + srcs = [ + "grappler/cost_analyzer.py", + ], + srcs_version = "PY2AND3", + deps = 
[":pywrap_tensorflow_internal"], +) + +py_test( + name = "cost_analyzer_test", + size = "small", + srcs = ["grappler/cost_analyzer_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":client_testlib", + ":cost_analyzer", + ":framework_for_generated_wrappers", + ":math_ops", + "//tensorflow/core:protos_all_py", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/python/grappler/cost_analyzer.cc b/tensorflow/python/grappler/cost_analyzer.cc new file mode 100644 index 00000000000..273a74dd286 --- /dev/null +++ b/tensorflow/python/grappler/cost_analyzer.cc @@ -0,0 +1,224 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/python/grappler/cost_analyzer.h" + +#include +#include "tensorflow/core/grappler/costs/utils.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster, + const string& suffix) + : item_(&item), + measure_estimator_(cluster, 10, 0), + analytical_estimator_(cluster, false), + suffix_(suffix) {} + +Status CostAnalyzer::GenerateReport(std::ostream& os) { + GatherCosts(); + PreprocessCosts(); + AnalyzeCosts(); + PrintAnalysis(os); + return Status::OK(); +} + +void CostAnalyzer::PredictCosts(CostEstimator* cost_estimator, + CostGraphDef* cost_graph, int64* total_time) { + TF_CHECK_OK(cost_estimator->Initialize(*item_)); + Costs costs; + const Status status = + cost_estimator->PredictCosts(item_->graph, cost_graph, &costs); + *total_time = costs.execution_time.count(); + if (!status.ok()) { + LOG(ERROR) << "Could not estimate the cost for item " << item_->id << ": " + << status.error_message(); + return; + } +} + +void CostAnalyzer::GatherCosts() { + CostGraphDef cost_graph_measured; + PredictCosts(&measure_estimator_, &cost_graph_measured, + &total_time_measured_); + VLOG(1) << "cost_graph_measured size: " << cost_graph_measured.node_size(); + op_perf_ = CostGraphToOpPerformanceData(cost_graph_measured, item_->graph); + + CostGraphDef cost_graph_analytical; + PredictCosts(&analytical_estimator_, &cost_graph_analytical, + &total_time_analytical_); + VLOG(1) << "cost_graph_analytical size: " + << cost_graph_analytical.node_size(); + + CostGraphDef cost_graph_analytical_filtered; + std::set cost_nodes; + for (auto& node : cost_graph_measured.node()) { + cost_nodes.insert(node.name()); + } + for (const auto& node : cost_graph_analytical.node()) { + auto it = cost_nodes.find(node.name()); + // Filter the nodes 
that are not the cost nodes returned by + // MeasuringCostEstimator. + if (it == cost_nodes.end()) { + continue; + } + auto added_node = cost_graph_analytical_filtered.add_node(); + *added_node = node; + } + VLOG(1) << "cost_graph_analytical_filtered size: " + << cost_graph_analytical_filtered.node_size(); + + op_perf_analytical_ = CostGraphToOpPerformanceData( + cost_graph_analytical_filtered, item_->graph); +} + +void CostAnalyzer::PreprocessCosts() { + for (int i = 0; i < op_perf_.op_performance_size(); i++) { + OpPerformance* perf = op_perf_.mutable_op_performance(i); + const OpPerformance& analytical = op_perf_analytical_.op_performance(i); + perf->set_compute_time(analytical.compute_time()); + perf->set_memory_time(analytical.memory_time()); + double measured_cost = perf->compute_cost(); + + double analytical_compute_cost = analytical.compute_time(); + if (analytical_compute_cost == 0) { + // Negative infinity indicates unavailable data. + perf->set_compute_efficiency(-INFINITY); + } else { + perf->set_compute_efficiency(analytical_compute_cost / measured_cost); + } + + double analytical_memory_cost = analytical.memory_time(); + if (analytical_memory_cost == 0) { + // Negative infinity indicates unavailable data. 
+ perf->set_memory_efficiency(-INFINITY); + } else { + perf->set_memory_efficiency(analytical_memory_cost / measured_cost); + } + } +} + + +void CostAnalyzer::SortOpsByTime(std::map ops) { + for (const auto& op : ops) { + ops_.push_back(op.second); + } + struct CompareByTime { + bool operator()(const OpPerfSummary& a, const OpPerfSummary& b) const { + return a.time > b.time; + } + }; + std::stable_sort(ops_.begin(), ops_.end(), CompareByTime()); +} + +void CostAnalyzer::AnalyzeCosts() { + std::map ops; + for (const auto& op_perf : op_perf_.op_performance()) { + string op_name = op_perf.op().op(); + ops[op_name].count++; + ops[op_name].time += op_perf.compute_cost(); + ops[op_name].compute_time += op_perf.compute_time(); + ops[op_name].memory_time += op_perf.memory_time(); + ops[op_name].time_upper += op_perf.compute_time() + op_perf.memory_time(); + ops[op_name].time_lower += + std::max(op_perf.compute_time(), op_perf.memory_time()); + ops[op_name].name = op_name; + } + SortOpsByTime(ops); + + total_time_measured_serialized_ = 0; + total_time_analytical_upper_ = 0; + total_time_analytical_lower_ = 0; + for (const auto& op : ops_) { + total_time_measured_serialized_ += op.time; + total_time_analytical_upper_ += op.time_upper; + total_time_analytical_lower_ += op.time_lower; + } +} + +void CostAnalyzer::PrintAnalysis(std::ostream& os) const { + os << std::endl; + os << std::left << std::setw(50) + << "Total time measured in ns (serialized): " << std::right + << std::setw(20) << total_time_measured_serialized_ << std::endl; + os << std::left << std::setw(50) + << "Total time measured in ns (actual): " << std::right << std::setw(20) + << total_time_measured_ << std::endl; + os << std::left << std::setw(50) + << "Total time analytical in ns (upper bound): " << std::right + << std::setw(20) << total_time_analytical_upper_ << std::endl; + os << std::left << std::setw(50) + << "Total time analytical in ns (lower bound): " << std::right + << std::setw(20) << 
total_time_analytical_lower_ << std::endl; + double efficiency_upper = static_cast(total_time_analytical_upper_) / + static_cast(total_time_measured_); + os << std::left << std::setw(50) + << "Overall efficiency (analytical upper/actual): " << std::right + << std::setw(20) << efficiency_upper << std::endl; + double efficiency_lower = static_cast(total_time_analytical_lower_) / + static_cast(total_time_measured_); + os << std::left << std::setw(50) + << "Overall efficiency (analytical lower/actual): " << std::right + << std::setw(20) << efficiency_lower << std::endl; + os << std::endl; + + int width = 35; + int width_narrow = 15; + int width_wide = 20; + os << std::setw(width + 1) << "Op,"; + os << std::setw(width_narrow + 1) << "Count,"; + os << std::setw(width_wide + 1) << "Measured time (ns),"; + os << std::setw(width_narrow + 2) << "Time percent,"; + os << std::setw(width_narrow + 2) << "Acc percent,"; + os << std::setw(width_wide + 1) << "Analytical upper,"; + os << std::setw(width_wide + 1) << "Analytical lower,"; + os << std::setw(width_narrow + 2) << "Overall eff"; + os << std::setw(width_narrow + 2) << "Compute eff"; + os << std::setw(width_narrow + 2) << "Memory eff" << std::endl; + float acc_percent = 0; + for (const auto& op : ops_) { + double percent = static_cast(op.time) / + static_cast(total_time_measured_serialized_); + double eff = + static_cast(op.time_upper) / static_cast(op.time); + double compute_eff = + static_cast(op.compute_time) / static_cast(op.time); + double memory_eff = + static_cast(op.memory_time) / static_cast(op.time); + os << std::setw(width) << op.name << ","; + os << std::setw(width_narrow) << op.count << ","; + os << std::setw(width_wide) << op.time << ","; + os << std::setw(width_narrow) << std::setprecision(2) << percent * 100 + << "%,"; + acc_percent += percent; + os << std::setw(width_narrow) << std::setprecision(2) << acc_percent * 100 + << "%,"; + os << std::setw(width_wide) << op.time_upper << ","; + os << 
std::setw(width_wide) << op.time_lower << ","; + os << std::setw(width_narrow) << std::setprecision(2) << eff * 100 << "%,"; + os << std::setw(width_narrow) << std::setprecision(2) << compute_eff * 100 + << "%,"; + os << std::setw(width_narrow) << std::setprecision(2) << memory_eff * 100 + << "%,"; + os << std::endl; + } + os << std::endl; +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/python/grappler/cost_analyzer.h b/tensorflow/python/grappler/cost_analyzer.h new file mode 100644 index 00000000000..3700bf5fb37 --- /dev/null +++ b/tensorflow/python/grappler/cost_analyzer.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ANALYZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ANALYZER_H_ + +#include +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h" +#include "tensorflow/core/grappler/costs/cost_estimator.h" +#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h" +#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" + +namespace tensorflow { +class GraphDef; +class CostGraphDef; + +namespace grappler { +struct GrapplerItem; + +// Aggregated perf summary for ops of the same type in a graph. +struct OpPerfSummary { + string name; + int64 count; + int64 time; + int64 compute_time; + int64 memory_time; + // Upper and lower bound for estimated time. + int64 time_upper; + int64 time_lower; +}; + +// Generate op-level performance insights on compute/memory +// efficiency, as well as graph-level aggregated performance statistics. 
+class CostAnalyzer { + public: + explicit CostAnalyzer(const GrapplerItem& item, Cluster* cluster, + const string& suffix); + Status GenerateReport(std::ostream& os); + + private: + void PredictCosts(CostEstimator* cost_estimator, CostGraphDef* cost_graph, + int64* total_time); + void GatherCosts(); + void PreprocessCosts(); + void AnalyzeCosts(); + void SortOpsByTime(std::map ops); + void PrintAnalysis(std::ostream& os) const; + + const GrapplerItem* item_; + MeasuringCostEstimator measure_estimator_; + AnalyticalCostEstimator analytical_estimator_; + OpPerformanceList op_perf_; + OpPerformanceList op_perf_analytical_; + int64 total_time_measured_; + int64 total_time_analytical_; + std::vector ops_; + int64 total_time_measured_serialized_; + int64 total_time_analytical_upper_; + int64 total_time_analytical_lower_; + string suffix_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ANALYZER_H_ diff --git a/tensorflow/python/grappler/cost_analyzer.i b/tensorflow/python/grappler/cost_analyzer.i new file mode 100644 index 00000000000..782ef35fad5 --- /dev/null +++ b/tensorflow/python/grappler/cost_analyzer.i @@ -0,0 +1,67 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +%include "tensorflow/python/lib/core/strings.i" +%include "tensorflow/python/platform/base.i" + +%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) { + char* c_string; + Py_ssize_t py_size; + if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) { + // Python has raised an error (likely TypeError or UnicodeEncodeError). + SWIG_fail; + } + + if (!temp.ParseFromString(string(c_string, py_size))) { + PyErr_SetString( + PyExc_TypeError, + "The MetaGraphDef could not be parsed as a valid protocol buffer"); + SWIG_fail; + } + $1 = &temp; +} + +%{ +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/clusters/single_machine.h" +#include "tensorflow/core/grappler/devices.h" +#include "tensorflow/core/grappler/grappler_item_builder.h" +#include "tensorflow/python/grappler/cost_analyzer.h" +%} + +%{ +string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph) { + tensorflow::grappler::ItemConfig cfg; + std::unique_ptr item = + tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph", metagraph, cfg); + + // TODO(bsteiner): we should wrap the tf session instead to properly handle the case of a + // distributed setup. 
+ const int timeout_s = 3600; + int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); + int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); + tensorflow::grappler::SingleMachine cluster(timeout_s, num_cpu_cores, num_gpus); + + string suffix; + tensorflow::grappler::CostAnalyzer analyzer(*item, &cluster, suffix); + + std::stringstream os; + analyzer.GenerateReport(os); + return os.str(); +} + +%} + +string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph); diff --git a/tensorflow/python/grappler/cost_analyzer.py b/tensorflow/python/grappler/cost_analyzer.py new file mode 100644 index 00000000000..d16614c7c75 --- /dev/null +++ b/tensorflow/python/grappler/cost_analyzer.py @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= +"""Provides a proper python API for the symbols exported through swig.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tensorflow as tf_wrap +from tensorflow.python.framework import errors + + +def GenerateCostReport(metagraph): + """Analyze the cost of each TensorFlow operation in the provided metagraph.""" + with errors.raise_exception_on_not_ok_status(): + ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString()) + return ret_from_swig diff --git a/tensorflow/python/grappler/cost_analyzer_test.py b/tensorflow/python/grappler/cost_analyzer_test.py new file mode 100644 index 00000000000..19d3c9695bf --- /dev/null +++ b/tensorflow/python/grappler/cost_analyzer_test.py @@ -0,0 +1,56 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the cost analyzer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import meta_graph +from tensorflow.python.framework import ops +from tensorflow.python.grappler import cost_analyzer +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class PyWrapOptimizeGraphTest(test.TestCase): + + def testBasic(self): + """Make sure arguments can be passed correctly.""" + a = constant_op.constant(10, name="a") + b = constant_op.constant(20, name="b") + c = math_ops.add_n([a, b], name="c") + d = math_ops.add_n([b, c], name="d") + train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) + train_op.append(d) + mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph()) + + report = cost_analyzer.GenerateCostReport(mg) + + # Check the report headers + self.assertTrue(b"Total time measured in ns (serialized):" in report) + self.assertTrue(b"Total time measured in ns (actual):" in report) + self.assertTrue(b"Total time analytical in ns (upper bound):" in report) + self.assertTrue(b"Total time analytical in ns (lower bound):" in report) + self.assertTrue(b"Overall efficiency (analytical upper/actual):" in report) + self.assertTrue(b"Overall efficiency (analytical lower/actual):" in report) + + # Also print the report to make it easier to debug + print("{}".format(report)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 5c2ad417e2f..a9a0b7fffa8 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -42,3 +42,4 @@ limitations under the License. 
%include "tensorflow/python/util/transform_graph.i" %include "tensorflow/python/grappler/tf_optimizer.i" +%include "tensorflow/python/grappler/cost_analyzer.i"