From 356e90472b519bbf70e0cae57ba74d1a22f4c77c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Feb 2020 13:57:49 -0800 Subject: [PATCH] Add flops/bytes of non XLA GPU models to tensorflow stats. PiperOrigin-RevId: 294749986 Change-Id: I1142617e7cafc259be8ea7ffdc1140918d4d5326 --- tensorflow/core/profiler/convert/BUILD | 1 + .../core/profiler/convert/xplane_to_op_metrics_db.cc | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index fe2a34b3b18..05d802ed87b 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -16,6 +16,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:cost_utils", "//tensorflow/core/profiler/utils:event_span", "//tensorflow/core/profiler/utils:op_utils", "//tensorflow/core/profiler/utils:tf_op_utils", diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index 55bf81c552a..b5181b1edd3 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -24,7 +24,9 @@ limitations under the License. #include "tensorflow/core/profiler/convert/op_stack.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/cost_utils.h" #include "tensorflow/core/profiler/utils/op_utils.h" +#include "tensorflow/core/profiler/utils/tf_op_utils.h" #include "tensorflow/core/profiler/utils/timespan.h" #include "tensorflow/core/profiler/utils/trace_utils.h" @@ -198,6 +200,7 @@ OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb( int64 first_op_offset_ps = kint64max; int64 last_op_offset_ps = 0; + TfOpRoofLineCostEstimator op_level_cost_estimator; XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_trace); plane.ForEachLine([&](const XLineVisitor& line) { if (IsDerivedThreadId(line.Id())) return; @@ -210,11 +213,14 @@ OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb( absl::string_view tf_op_fullname = stat->str_value(); if (tf_op_fullname.empty()) return; TfOp tf_op = ParseTfOpFullname(tf_op_fullname); + TfOpRoofLineCostEstimator::OpRoofLineStats costs; + if (tf_op.type != kUnknownOp) { + costs = op_level_cost_estimator.Predict(event); + } device_op_metrics_db_builder.EnterOp( /*program_id=*/0, tf_op.name, tf_op.type, tf_op_fullname, /*occurrences=*/1, event.DurationPs(), - /*children_time_ps=*/0, /*flops=*/0, - /*bytes_accessed=*/0); + /*children_time_ps=*/0, costs.flops, costs.bytes_accessed); }); }); result.set_total_time_ps(last_op_offset_ps - first_op_offset_ps);