diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 1ead2d5baaf..792260ce5ac 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -173,6 +173,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":graph_properties", + ":utils", ":virtual_placer", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 727aeb7ee6c..7b7d79fc7ed 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -219,5 +219,16 @@ OpInfo BuildOpInfo( return op_info; } +string GetOpDescription(const OpInfo& op_info) { + string description = "["; + description += "Op=" + op_info.op() + ", "; + description += "input_shapes=["; + for (auto const& input : op_info.inputs()) { + description += PartialTensorShape::DebugString(input.shape()); + } + description += "]"; + return description; +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h index bdba4e4b156..cb23ac83553 100644 --- a/tensorflow/core/grappler/costs/utils.h +++ b/tensorflow/core/grappler/costs/utils.h @@ -45,6 +45,9 @@ std::vector FindInputFeatures( DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node); DeviceProperties GetDeviceInfo(const string& device_str); +// Return a string describing a node given a nodeinfo. +string GetOpDescription(const OpInfo& op_info); + // Builds the OpInfo proto for node, given all nodes in the graph, the node's // device and its input properties which are typically built by shape inference // or calling FindInputFeatures. diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 86cf498538c..80318fe8ad5 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/clusters/utils.h" +#include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/util/device_name_utils.h" @@ -349,6 +350,13 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { const auto* node = GetCurrNode(); const auto& op_name = node->op(); + // Also keep track of op counts and times per op (with their shapes). + NodeInfo node_info = GetCurrNodeInfo(); + string node_description = GetOpDescription(node_info.op_info); + op_counts_[node_description] += 1; + op_costs_[node_description] = + node_costs.execution_time.asMicroSeconds().count(); + auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_); op_cost = CombineCosts(op_cost, node_costs); @@ -445,6 +453,13 @@ Costs VirtualScheduler::Summary() const { } } + // Also log the op description and their corresponding counts. + VLOG(1) << "Node description, counts, cost:"; + for (const auto& item : op_counts_) { + VLOG(1) << "Node: " << item.first << ", Count: " << item.second + << ", Individual Cost: " << op_costs_.at(item.first); + } + VLOG(1) << "Critical path execution time: " << critical_path_costs.execution_time.count(); return critical_path_costs; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 08550714329..83878eea0a6 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -137,6 +137,8 @@ class VirtualScheduler { bool IsRecvOp(const NodeDef* node) const; GraphProperties graph_properties_; + std::map op_counts_; // Op counts with key with input shape. + std::map op_costs_; // Individual op costs (with input shapes). Costs graph_costs_; // Graph cost. std::map op_to_cost_; // Per-op cost. std::unique_ptr ready_nodes_;