diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 5103c1129ee..3c765229ca8 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -8,6 +8,7 @@ cc_library( srcs = ["host_threads_xplane_to_tf_metrics_db.cc"], hdrs = ["host_threads_xplane_to_tf_metrics_db.h"], deps = [ + ":op_metrics_db_combiner", ":op_stack", "//tensorflow/core/platform:types", "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", diff --git a/tensorflow/core/profiler/convert/host_threads_xplane_to_tf_metrics_db.cc b/tensorflow/core/profiler/convert/host_threads_xplane_to_tf_metrics_db.cc index 60bf256d902..9ff1a9bfef5 100644 --- a/tensorflow/core/profiler/convert/host_threads_xplane_to_tf_metrics_db.cc +++ b/tensorflow/core/profiler/convert/host_threads_xplane_to_tf_metrics_db.cc @@ -20,6 +20,8 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "tensorflow/core/profiler/convert/op_stack.h" +#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/op_utils.h" #include "tensorflow/core/profiler/utils/timespan.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" @@ -137,6 +139,23 @@ void CollectTfActivities(const XLineVisitor& line, } // namespace +absl::flat_hash_map CollectTfOpsFromHostThreadsXPlane( + const XPlane& host_trace) { + absl::flat_hash_map tf_ops; + for (const auto& id_metadata : host_trace.event_metadata()) { + const XEventMetadata& metadata = id_metadata.second; + // On the host, we have added some user-specified TraceMe's in addition to + // the TraceMe's added to every TensorFlow op by the system. These + // user-inserted TraceMe's have "unknown" type. We don't count them in + // Tf-stats. + TfOp tf_op = ParseTfOpFullname(metadata.name()); + if (!IsUnknownOp(tf_op.type)) { + tf_ops.try_emplace(metadata.id(), tf_op); + } + } + return tf_ops; +} + TfMetricsDbData ConvertHostThreadsXLineToTfMetricsDbData( const XLineVisitor& line, const absl::flat_hash_map& tf_ops) { TfMetricsDbData tf_metrics_db_data; @@ -148,5 +167,24 @@ TfMetricsDbData ConvertHostThreadsXLineToTfMetricsDbData( return tf_metrics_db_data; } +void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst) { + AddIdleOp(&src.tf_metrics_db); + dst->Combine(src.tf_metrics_db); + src.tf_metrics_db.Clear(); +} + +OpMetricsDb ConvertHostThreadsXPlaneToTfMetricsDb(const XPlane& host_trace) { + absl::flat_hash_map tf_ops = + CollectTfOpsFromHostThreadsXPlane(host_trace); + OpMetricsDb result; + OpMetricsDbCombiner combiner(&result); + XPlaneVisitor plane(&host_trace); + plane.ForEachLine([&tf_ops, &combiner](const XLineVisitor& line) { + ConsumeTfMetricsDbData( + ConvertHostThreadsXLineToTfMetricsDbData(line, tf_ops), &combiner); + }); + return result; +} + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/host_threads_xplane_to_tf_metrics_db.h b/tensorflow/core/profiler/convert/host_threads_xplane_to_tf_metrics_db.h index e6f3c82f052..c8c6e10c2ef 100644 --- a/tensorflow/core/profiler/convert/host_threads_xplane_to_tf_metrics_db.h +++ b/tensorflow/core/profiler/convert/host_threads_xplane_to_tf_metrics_db.h @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/event_span.h" @@ -40,9 +41,16 @@ struct TfMetricsDbData { HostOpMetricsDbBuilder tf_metrics_db_builder{&tf_metrics_db}; }; +absl::flat_hash_map CollectTfOpsFromHostThreadsXPlane( + const XPlane& host_trace); + TfMetricsDbData ConvertHostThreadsXLineToTfMetricsDbData( const XLineVisitor& line, const absl::flat_hash_map& tf_ops); +void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst); + +OpMetricsDb ConvertHostThreadsXPlaneToTfMetricsDb(const XPlane& host_trace); + } // namespace profiler } // namespace tensorflow