diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index 49265445659..1bdccf5f0ea 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -409,6 +409,21 @@ void StepStatsCollector::Save(const string& device, } } +void StepStatsCollector::SaveThreadName(const string& device, + const uint32 thread_id, + const string& thread_name) { + VLOG(1) << "Save dev " << device << " thread id " << thread_id << " name " + << thread_name; + { + mutex_lock l(mu_); + if (finalized_) { + LOG(WARNING) << "thread_name saved after finalize will not be collected."; + } + auto& thread_names_map = thread_names_[device]; + thread_names_map[thread_id] = thread_name; + } +} + NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats( const Node* node) { // Only collect statistics for non-transfer nodes. @@ -531,5 +546,15 @@ void StepStatsCollector::FinalizeInternal() { stats->stats()->Swap(dss->add_node_stats()); } } + for (const auto& device_thread : thread_names_) { + if (dev_stats_pb.find(device_thread.first) == dev_stats_pb.end()) { + // skip device without DeviceStepStats. + continue; + } + DeviceStepStats* dss = dev_stats_pb.at(device_thread.first); + for (const auto& thread_name : device_thread.second) { + (*dss->mutable_thread_names())[thread_name.first] = thread_name.second; + } + } } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h index 7d34383ce82..dfcc51ff4c7 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.h +++ b/tensorflow/core/common_runtime/step_stats_collector.h @@ -175,6 +175,10 @@ class StepStatsCollector : public StepStatsCollectorInterface { void Save(const string& device, NodeExecStats* node_stats_pb); void Save(const string& device, NodeExecStatsWrapper* node_stats); + // Saves thread name. 
+ void SaveThreadName(const string& device, const uint32 thread_id, + const string& thread_name); + NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override; string ReportAllocsOnResourceExhausted(const string& err) override; @@ -191,12 +195,14 @@ class StepStatsCollector : public StepStatsCollectorInterface { static const uint64 kMaxCollectedNodes = 1 << 20; typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeStatsVector; + typedef std::unordered_map<uint32, string> ThreadNamesMap; void FinalizeInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_); mutex mu_; bool finalized_ GUARDED_BY(mu_); std::unordered_map<string, NodeStatsVector> dev_stats_ GUARDED_BY(mu_); + std::unordered_map<string, ThreadNamesMap> thread_names_ GUARDED_BY(mu_); StepStats* step_stats_ GUARDED_BY(mu_); uint64 collected_nodes_ GUARDED_BY(mu_) = 0; }; diff --git a/tensorflow/core/framework/step_stats.proto b/tensorflow/core/framework/step_stats.proto index 67cc9e38459..f8cab135aba 100644 --- a/tensorflow/core/framework/step_stats.proto +++ b/tensorflow/core/framework/step_stats.proto @@ -77,6 +77,8 @@ message NodeExecStats { message DeviceStepStats { string device = 1; repeated NodeExecStats node_stats = 2; + // Its key is thread id. 
+ map<uint32, string> thread_names = 3; } message StepStats { diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer.cc b/tensorflow/core/profiler/internal/cpu/host_tracer.cc index 1efc50695ac..3fb29664688 100644 --- a/tensorflow/core/profiler/internal/cpu/host_tracer.cc +++ b/tensorflow/core/profiler/internal/cpu/host_tracer.cc @@ -79,6 +79,8 @@ Status HostTracer::CollectDataToCollector( const string cpu_name = "/host:CPU"; for (auto& thread : events_) { + step_stats_collector->SaveThreadName(cpu_name, thread.thread.tid, + thread.thread.name); for (auto& event : thread.events) { if (!event.end_time) { auto it = end_times.find(event.activity_id); diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc index 4641ae2a5f6..1eb9ed6a990 100644 --- a/tensorflow/core/profiler/lib/profiler_session.cc +++ b/tensorflow/core/profiler/lib/profiler_session.cc @@ -47,6 +47,12 @@ void ConvertRunMetadataToTraceEvent(RunMetadata* run_metadata, resource.set_name("0"); resource.set_resource_id(0); (*device.mutable_resources())[0] = resource; + for (const auto& thread_name : device_stats->thread_names()) { + tensorflow::tpu::Resource resource; + resource.set_resource_id(thread_name.first); + resource.set_name(thread_name.second); + (*device.mutable_resources())[thread_name.first] = resource; + } (*trace_devices)[device_id] = device; // Emit events.