Open sourced the cost analyzer
PiperOrigin-RevId: 157178951
parent 3e767e9db0
commit 2251633a50
@@ -258,6 +258,7 @@ endif()
# We include tf_cc_ops first, because tf_c depends on tf_cc.
include(tf_cc_ops.cmake)
include(tf_c.cmake)
include(tf_grappler.cmake)
if(tensorflow_BUILD_CC_EXAMPLE)
  include(tf_tutorials.cmake)
  include(tf_label_image_example.cmake)
tensorflow/contrib/cmake/tf_grappler.cmake (new file, 27 lines)
@@ -0,0 +1,27 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
########################################################
# tf_grappler library
########################################################
file(GLOB tf_grappler_srcs
    "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.cc"
    "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h"
    "${tensorflow_source_dir}/tensorflow/python/grappler/cost_analyzer.cc"
    "${tensorflow_source_dir}/tensorflow/python/grappler/cost_analyzer.h"
)

add_library(tf_grappler OBJECT ${tf_grappler_srcs})

add_dependencies(tf_grappler tf_core_cpu)
@@ -721,6 +721,7 @@ if(WIN32)
    $<TARGET_OBJECTS:tf_cc_ops>
    $<TARGET_OBJECTS:tf_core_ops>
    $<TARGET_OBJECTS:tf_core_direct_session>
    $<TARGET_OBJECTS:tf_grappler>
    $<TARGET_OBJECTS:tf_tools_transform_graph_lib>
    $<$<BOOL:${tensorflow_ENABLE_GRPC_SUPPORT}>:$<TARGET_OBJECTS:tf_core_distributed_runtime>>
    $<TARGET_OBJECTS:tf_core_kernels>
@@ -767,6 +768,7 @@ add_library(pywrap_tensorflow_internal SHARED
    $<TARGET_OBJECTS:tf_cc_ops>
    $<TARGET_OBJECTS:tf_core_ops>
    $<TARGET_OBJECTS:tf_core_direct_session>
    $<TARGET_OBJECTS:tf_grappler>
    $<TARGET_OBJECTS:tf_tools_transform_graph_lib>
    $<$<BOOL:${tensorflow_ENABLE_GRPC_SUPPORT}>:$<TARGET_OBJECTS:tf_core_distributed_runtime>>
    $<TARGET_OBJECTS:tf_core_kernels>
@@ -230,5 +230,60 @@ string GetOpDescription(const OpInfo& op_info) {
  return description;
}

OpPerformanceList CostGraphToOpPerformanceData(const CostGraphDef& cost_graph,
                                               const GraphDef& graph) {
  OpPerformanceList ret;
  std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
  std::unordered_map<string, const NodeDef*> name_to_node;
  for (auto& node : cost_graph.node()) {
    name_to_cost[node.name()] = &node;
  }
  for (auto& node : graph.node()) {
    name_to_node[node.name()] = &node;
  }

  for (const auto& node : graph.node()) {
    // Skip the nodes that are not in the cost graph: these are nodes that
    // aren't run, because they aren't in the intersection of transitive
    // fan-in of a fetch node and the transitive fan-out of an input, or nodes
    // that were optimized away by the optimizer. Since they don't contribute
    // to the execution time we simply discard them.
    auto it = name_to_cost.find(node.name());
    if (it == name_to_cost.end()) {
      continue;
    }
    const CostGraphDef::Node* cost_node = it->second;

    OpPerformance* perf = ret.add_op_performance();
    perf->set_node(node.name());

    std::vector<OpInfo::TensorProperties> inputs =
        FindInputFeatures(node, name_to_cost, name_to_node);
    (*perf->mutable_op()) =
        BuildOpInfo(node, cost_node->device(), name_to_node, inputs);

    perf->set_temporary_memory_size(cost_node->temporary_memory_size());
    // Note that CostGraphDef::Node::compute_cost is microseconds, while
    // OpPerformance.compute_cost is nanoseconds.
    perf->set_compute_cost(cost_node->compute_cost() * 1000);
    perf->set_compute_time(cost_node->compute_time() * 1000);
    perf->set_memory_time(cost_node->memory_time() * 1000);

    for (const auto& output_info : cost_node->output_info()) {
      perf->mutable_op_memory()->add_output_memory(output_info.size());
    }

    perf->mutable_op_memory()->set_host_temp_memory(
        cost_node->host_temp_memory_size());
    perf->mutable_op_memory()->set_device_temp_memory(
        cost_node->device_temp_memory_size());
    perf->mutable_op_memory()->set_host_persistent_memory(
        cost_node->host_persistent_memory_size());
    perf->mutable_op_memory()->set_device_persistent_memory(
        cost_node->device_persistent_memory_size());
  }
  return ret;
}

}  // end namespace grappler
}  // end namespace tensorflow
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>

#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/graph/types.h"
@@ -56,6 +57,10 @@ OpInfo BuildOpInfo(
    const std::unordered_map<string, const NodeDef*>& name_to_node,
    const std::vector<OpInfo::TensorProperties>& inputs);

// Gather performance data from a cost graph.
OpPerformanceList CostGraphToOpPerformanceData(const CostGraphDef& cost_graph,
                                               const GraphDef& graph);

}  // end namespace grappler
}  // end namespace tensorflow
@@ -155,6 +155,22 @@ tf_py_test(
    ],
)

cc_library(
    name = "cost_analyzer_lib",
    srcs = ["grappler/cost_analyzer.cc"],
    hdrs = ["grappler/cost_analyzer.h"],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler/clusters:cluster",
        "//tensorflow/core/grappler/clusters:single_machine",
        "//tensorflow/core/grappler/costs:analytical_cost_estimator",
        "//tensorflow/core/grappler/costs:measuring_cost_estimator",
        "//tensorflow/core/grappler/costs:op_performance_data_cc",
        "//tensorflow/core/grappler/costs:utils",
    ],
)

cc_library(
    name = "numpy_lib",
    srcs = ["lib/core/numpy.cc"],
@@ -2644,6 +2660,7 @@ tf_py_wrap_cc(
        "client/tf_session.i",
        "framework/cpp_shape_inference.i",
        "framework/python_op_gen.i",
        "grappler/cost_analyzer.i",
        "grappler/tf_optimizer.i",
        "lib/core/py_func.i",
        "lib/core/strings.i",
@@ -2660,6 +2677,7 @@ tf_py_wrap_cc(
        "util/transform_graph.i",
    ],
    deps = [
        ":cost_analyzer_lib",
        ":cpp_shape_inference",
        ":kernel_registry",
        ":numpy_lib",
@@ -3673,3 +3691,28 @@ cuda_py_test(
        "//tensorflow/core:protos_all_py",
    ],
)

py_library(
    name = "cost_analyzer",
    srcs = [
        "grappler/cost_analyzer.py",
    ],
    srcs_version = "PY2AND3",
    deps = [":pywrap_tensorflow_internal"],
)

py_test(
    name = "cost_analyzer_test",
    size = "small",
    srcs = ["grappler/cost_analyzer_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":client_testlib",
        ":cost_analyzer",
        ":framework_for_generated_wrappers",
        ":math_ops",
        "//tensorflow/core:protos_all_py",
        "//third_party/py/numpy",
    ],
)
tensorflow/python/grappler/cost_analyzer.cc (new file, 224 lines)
@@ -0,0 +1,224 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/python/grappler/cost_analyzer.h"

#include <iomanip>
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace grappler {

CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
                           const string& suffix)
    : item_(&item),
      measure_estimator_(cluster, 10, 0),
      analytical_estimator_(cluster, false),
      suffix_(suffix) {}

Status CostAnalyzer::GenerateReport(std::ostream& os) {
  GatherCosts();
  PreprocessCosts();
  AnalyzeCosts();
  PrintAnalysis(os);
  return Status::OK();
}

void CostAnalyzer::PredictCosts(CostEstimator* cost_estimator,
                                CostGraphDef* cost_graph, int64* total_time) {
  TF_CHECK_OK(cost_estimator->Initialize(*item_));
  Costs costs;
  const Status status =
      cost_estimator->PredictCosts(item_->graph, cost_graph, &costs);
  *total_time = costs.execution_time.count();
  if (!status.ok()) {
    LOG(ERROR) << "Could not estimate the cost for item " << item_->id << ": "
               << status.error_message();
    return;
  }
}

void CostAnalyzer::GatherCosts() {
  CostGraphDef cost_graph_measured;
  PredictCosts(&measure_estimator_, &cost_graph_measured,
               &total_time_measured_);
  VLOG(1) << "cost_graph_measured size: " << cost_graph_measured.node_size();
  op_perf_ = CostGraphToOpPerformanceData(cost_graph_measured, item_->graph);

  CostGraphDef cost_graph_analytical;
  PredictCosts(&analytical_estimator_, &cost_graph_analytical,
               &total_time_analytical_);
  VLOG(1) << "cost_graph_analytical size: "
          << cost_graph_analytical.node_size();

  CostGraphDef cost_graph_analytical_filtered;
  std::set<string> cost_nodes;
  for (auto& node : cost_graph_measured.node()) {
    cost_nodes.insert(node.name());
  }
  for (const auto& node : cost_graph_analytical.node()) {
    auto it = cost_nodes.find(node.name());
    // Filter the nodes that are not the cost nodes returned by
    // MeasuringCostEstimator.
    if (it == cost_nodes.end()) {
      continue;
    }
    auto added_node = cost_graph_analytical_filtered.add_node();
    *added_node = node;
  }
  VLOG(1) << "cost_graph_analytical_filtered size: "
          << cost_graph_analytical_filtered.node_size();

  op_perf_analytical_ = CostGraphToOpPerformanceData(
      cost_graph_analytical_filtered, item_->graph);
}

void CostAnalyzer::PreprocessCosts() {
  for (int i = 0; i < op_perf_.op_performance_size(); i++) {
    OpPerformance* perf = op_perf_.mutable_op_performance(i);
    const OpPerformance& analytical = op_perf_analytical_.op_performance(i);
    perf->set_compute_time(analytical.compute_time());
    perf->set_memory_time(analytical.memory_time());
    double measured_cost = perf->compute_cost();

    double analytical_compute_cost = analytical.compute_time();
    if (analytical_compute_cost == 0) {
      // Negative infinity indicates unavailable data.
      perf->set_compute_efficiency(-INFINITY);
    } else {
      perf->set_compute_efficiency(analytical_compute_cost / measured_cost);
    }

    double analytical_memory_cost = analytical.memory_time();
    if (analytical_memory_cost == 0) {
      // Negative infinity indicates unavailable data.
      perf->set_memory_efficiency(-INFINITY);
    } else {
      perf->set_memory_efficiency(analytical_memory_cost / measured_cost);
    }
  }
}

void CostAnalyzer::SortOpsByTime(std::map<string, OpPerfSummary> ops) {
  for (const auto& op : ops) {
    ops_.push_back(op.second);
  }
  struct CompareByTime {
    bool operator()(const OpPerfSummary& a, const OpPerfSummary& b) const {
      return a.time > b.time;
    }
  };
  std::stable_sort(ops_.begin(), ops_.end(), CompareByTime());
}

void CostAnalyzer::AnalyzeCosts() {
  std::map<string, OpPerfSummary> ops;
  for (const auto& op_perf : op_perf_.op_performance()) {
    string op_name = op_perf.op().op();
    ops[op_name].count++;
    ops[op_name].time += op_perf.compute_cost();
    ops[op_name].compute_time += op_perf.compute_time();
    ops[op_name].memory_time += op_perf.memory_time();
    ops[op_name].time_upper += op_perf.compute_time() + op_perf.memory_time();
    ops[op_name].time_lower +=
        std::max(op_perf.compute_time(), op_perf.memory_time());
    ops[op_name].name = op_name;
  }
  SortOpsByTime(ops);

  total_time_measured_serialized_ = 0;
  total_time_analytical_upper_ = 0;
  total_time_analytical_lower_ = 0;
  for (const auto& op : ops_) {
    total_time_measured_serialized_ += op.time;
    total_time_analytical_upper_ += op.time_upper;
    total_time_analytical_lower_ += op.time_lower;
  }
}

void CostAnalyzer::PrintAnalysis(std::ostream& os) const {
  os << std::endl;
  os << std::left << std::setw(50)
     << "Total time measured in ns (serialized): " << std::right
     << std::setw(20) << total_time_measured_serialized_ << std::endl;
  os << std::left << std::setw(50)
     << "Total time measured in ns (actual): " << std::right << std::setw(20)
     << total_time_measured_ << std::endl;
  os << std::left << std::setw(50)
     << "Total time analytical in ns (upper bound): " << std::right
     << std::setw(20) << total_time_analytical_upper_ << std::endl;
  os << std::left << std::setw(50)
     << "Total time analytical in ns (lower bound): " << std::right
     << std::setw(20) << total_time_analytical_lower_ << std::endl;
  double efficiency_upper = static_cast<double>(total_time_analytical_upper_) /
                            static_cast<double>(total_time_measured_);
  os << std::left << std::setw(50)
     << "Overall efficiency (analytical upper/actual): " << std::right
     << std::setw(20) << efficiency_upper << std::endl;
  double efficiency_lower = static_cast<double>(total_time_analytical_lower_) /
                            static_cast<double>(total_time_measured_);
  os << std::left << std::setw(50)
     << "Overall efficiency (analytical lower/actual): " << std::right
     << std::setw(20) << efficiency_lower << std::endl;
  os << std::endl;

  int width = 35;
  int width_narrow = 15;
  int width_wide = 20;
  os << std::setw(width + 1) << "Op,";
  os << std::setw(width_narrow + 1) << "Count,";
  os << std::setw(width_wide + 1) << "Measured time (ns),";
  os << std::setw(width_narrow + 2) << "Time percent,";
  os << std::setw(width_narrow + 2) << "Acc percent,";
  os << std::setw(width_wide + 1) << "Analytical upper,";
  os << std::setw(width_wide + 1) << "Analytical lower,";
  os << std::setw(width_narrow + 2) << "Overall eff";
  os << std::setw(width_narrow + 2) << "Compute eff";
  os << std::setw(width_narrow + 2) << "Memory eff" << std::endl;
  float acc_percent = 0;
  for (const auto& op : ops_) {
    double percent = static_cast<double>(op.time) /
                     static_cast<double>(total_time_measured_serialized_);
    double eff =
        static_cast<double>(op.time_upper) / static_cast<double>(op.time);
    double compute_eff =
        static_cast<double>(op.compute_time) / static_cast<double>(op.time);
    double memory_eff =
        static_cast<double>(op.memory_time) / static_cast<double>(op.time);
    os << std::setw(width) << op.name << ",";
    os << std::setw(width_narrow) << op.count << ",";
    os << std::setw(width_wide) << op.time << ",";
    os << std::setw(width_narrow) << std::setprecision(2) << percent * 100
       << "%,";
    acc_percent += percent;
    os << std::setw(width_narrow) << std::setprecision(2) << acc_percent * 100
       << "%,";
    os << std::setw(width_wide) << op.time_upper << ",";
    os << std::setw(width_wide) << op.time_lower << ",";
    os << std::setw(width_narrow) << std::setprecision(2) << eff * 100 << "%,";
    os << std::setw(width_narrow) << std::setprecision(2) << compute_eff * 100
       << "%,";
    os << std::setw(width_narrow) << std::setprecision(2) << memory_eff * 100
       << "%,";
    os << std::endl;
  }
  os << std::endl;
}

}  // end namespace grappler
}  // end namespace tensorflow
tensorflow/python/grappler/cost_analyzer.h (new file, 81 lines)
@@ -0,0 +1,81 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ANALYZER_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ANALYZER_H_

#include <iostream>
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"

namespace tensorflow {
class GraphDef;
class CostGraphDef;

namespace grappler {
struct GrapplerItem;

// Aggregated perf summary for ops of the same type in a graph.
struct OpPerfSummary {
  string name;
  int64 count;
  int64 time;
  int64 compute_time;
  int64 memory_time;
  // Upper and lower bound for estimated time.
  int64 time_upper;
  int64 time_lower;
};

// Generate op-level performance insights on compute/memory
// efficiency, as well as graph-level aggregated performance statistics.
class CostAnalyzer {
 public:
  explicit CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
                        const string& suffix);
  Status GenerateReport(std::ostream& os);

 private:
  void PredictCosts(CostEstimator* cost_estimator, CostGraphDef* cost_graph,
                    int64* total_time);
  void GatherCosts();
  void PreprocessCosts();
  void AnalyzeCosts();
  void SortOpsByTime(std::map<string, OpPerfSummary> ops);
  void PrintAnalysis(std::ostream& os) const;

  const GrapplerItem* item_;
  MeasuringCostEstimator measure_estimator_;
  AnalyticalCostEstimator analytical_estimator_;
  OpPerformanceList op_perf_;
  OpPerformanceList op_perf_analytical_;
  int64 total_time_measured_;
  int64 total_time_analytical_;
  std::vector<OpPerfSummary> ops_;
  int64 total_time_measured_serialized_;
  int64 total_time_analytical_upper_;
  int64 total_time_analytical_lower_;
  string suffix_;
};

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ANALYZER_H_
tensorflow/python/grappler/cost_analyzer.i (new file, 67 lines)
@@ -0,0 +1,67 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"

%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
  char* c_string;
  Py_ssize_t py_size;
  if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
    // Python has raised an error (likely TypeError or UnicodeEncodeError).
    SWIG_fail;
  }

  if (!temp.ParseFromString(string(c_string, py_size))) {
    PyErr_SetString(
        PyExc_TypeError,
        "The MetaGraphDef could not be parsed as a valid protocol buffer");
    SWIG_fail;
  }
  $1 = &temp;
}

%{
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/python/grappler/cost_analyzer.h"
%}

%{
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph) {
  tensorflow::grappler::ItemConfig cfg;
  std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
      tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph", metagraph, cfg);

  // TODO(bsteiner): we should wrap the tf session instead to properly handle the case of a
  // distributed setup.
  const int timeout_s = 3600;
  int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
  int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
  tensorflow::grappler::SingleMachine cluster(timeout_s, num_cpu_cores, num_gpus);

  string suffix;
  tensorflow::grappler::CostAnalyzer analyzer(*item, &cluster, suffix);

  std::stringstream os;
  analyzer.GenerateReport(os);
  return os.str();
}

%}

string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph);
tensorflow/python/grappler/cost_analyzer.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Provides a proper python API for the symbols exported through swig."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import pywrap_tensorflow as tf_wrap
from tensorflow.python.framework import errors


def GenerateCostReport(metagraph):
  """Analyze the cost of each TensorFlow operation in the provided metagraph."""
  with errors.raise_exception_on_not_ok_status():
    ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString())
  return ret_from_swig
tensorflow/python/grappler/cost_analyzer_test.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the cost analyzer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cost_analyzer
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class PyWrapOptimizeGraphTest(test.TestCase):

  def testBasic(self):
    """Make sure arguments can be passed correctly."""
    a = constant_op.constant(10, name="a")
    b = constant_op.constant(20, name="b")
    c = math_ops.add_n([a, b], name="c")
    d = math_ops.add_n([b, c], name="d")
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    report = cost_analyzer.GenerateCostReport(mg)

    # Check the report headers
    self.assertTrue(b"Total time measured in ns (serialized):" in report)
    self.assertTrue(b"Total time measured in ns (actual):" in report)
    self.assertTrue(b"Total time analytical in ns (upper bound):" in report)
    self.assertTrue(b"Total time analytical in ns (lower bound):" in report)
    self.assertTrue(b"Overall efficiency (analytical upper/actual):" in report)
    self.assertTrue(b"Overall efficiency (analytical lower/actual):" in report)

    # Also print the report to make it easier to debug
    print("{}".format(report))


if __name__ == "__main__":
  test.main()
@@ -42,3 +42,4 @@ limitations under the License.
%include "tensorflow/python/util/transform_graph.i"

%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
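For orientation, a minimal standalone sketch of the new Python entry point added by this commit. It mirrors cost_analyzer_test.py above and assumes a TensorFlow build that includes these changes; the graph and node names are illustrative only.

# Hypothetical usage sketch of the cost analyzer; mirrors cost_analyzer_test.py.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cost_analyzer
from tensorflow.python.ops import math_ops

# Build a tiny graph and register a train op so grappler has a fetch node.
a = constant_op.constant(10, name="a")
b = constant_op.constant(20, name="b")
c = math_ops.add_n([a, b], name="c")
ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(c)

# Export the graph as a MetaGraphDef and generate the per-op cost report.
mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
report = cost_analyzer.GenerateCostReport(mg)
print(report)  # The report comes back as a (byte) string from the SWIG layer.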