diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index dd50d0577d4..e401a798d68 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -141,7 +141,9 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", "//tensorflow/core:allocator", "//tensorflow/core:lib", + "//tensorflow/core/profiler/lib:connected_traceme", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/profiler/lib:traceme_encode", "//tensorflow/stream_executor:event", "//tensorflow/stream_executor:stream", "//tensorflow/stream_executor/host:host_platform_id", diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index ccb72b7ce30..ef259cf1cfd 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -98,7 +98,9 @@ limitations under the License. #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/connected_traceme.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/profiler/lib/traceme_encode.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/device_memory_allocator.h" #include "tensorflow/stream_executor/event.h" @@ -1429,10 +1431,9 @@ StatusOr PjRtExecutable::EnqueueExecution( int executable_idx, const RunId& run_id, const ExecuteOptions& options, Device* device, std::vector* device_buffers) const { int device_ordinal = device->local_device_state()->device_ordinal(); - tensorflow::profiler::TraceMe traceme([&] { - return absl::StrCat("LocalExecutable::Execute#run_id=", run_id.ToInt(), - "#"); - }); + tensorflow::profiler::TraceMeConsumer activity( + "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt, + run_id.ToInt()); VLOG(3) << "Replica " << replica << ", partition " << partition << " mapped to device ordinal for execution: " << device_ordinal; @@ -1721,10 +1722,9 @@ PjRtExecutable::ExecuteOnLocalDevices( absl::Span> argument_handles, const ExecuteOptions& options) const { RunId run_id; - tensorflow::profiler::TraceMe traceme([&] { - return absl::StrCat( - "LocalExecutable::ExecuteOnLocalDevices#run_id=", run_id.ToInt(), "#"); - }); + tensorflow::profiler::TraceMeProducer activity( + "LocalExecutable::ExecuteOnLocalDevices", + tensorflow::profiler::ContextType::kPjRt, run_id.ToInt()); const int num_local_devices = local_devices_.size(); diff --git a/tensorflow/core/profiler/lib/connected_traceme.h b/tensorflow/core/profiler/lib/connected_traceme.h index ed8b4ac1ad2..b55c4407fe6 100644 --- a/tensorflow/core/profiler/lib/connected_traceme.h +++ b/tensorflow/core/profiler/lib/connected_traceme.h @@ -29,6 +29,7 @@ enum class ContextType : int { kGeneric, kTfExecutor, kSharedBatchScheduler, + kPjRt, }; /*