diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index 2a071e44a5c..cc4921e5781 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -204,7 +204,8 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context, const Tensor* input, Tensor* output, int dev_to_dev_stream_index, StatusCallback done, bool sync_dst_compute) { - profiler::ScopedAnnotation annotation(edge_name); + profiler::ScopedAnnotation annotation( + [&] { return absl::StrCat("#edge_name=", edge_name, "#"); }); VLOG(1) << "Copy " << edge_name; const DeviceType src_device_type( diff --git a/tensorflow/core/common_runtime/memory_types.cc b/tensorflow/core/common_runtime/memory_types.cc index 4088165fac4..b37e65a7ca5 100644 --- a/tensorflow/core/common_runtime/memory_types.cc +++ b/tensorflow/core/common_runtime/memory_types.cc @@ -129,6 +129,8 @@ static Node* Send(Graph* g, const string& tensor_name, .Attr("send_device_incarnation", 0) // Do not care. .Attr("recv_device", device_name) .Attr("_hostmem_sendrecv", true) + .Attr("_src", edge->src()->name()) + .Attr("_dst", edge->dst()->name()) .Finalize(g, &ret)); return ret; } @@ -144,6 +146,8 @@ static Node* Recv(Graph* g, const string& tensor_name, .Attr("send_device_incarnation", 0) .Attr("recv_device", device_name) .Attr("_hostmem_sendrecv", true) + .Attr("_src", edge->src()->name()) + .Attr("_dst", edge->dst()->name()) .Finalize(g, &ret)); return ret; } diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 65b341fbae0..bf57e263441 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -189,6 +189,8 @@ void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge, opts.get_incarnation(edge->src()->assigned_device_name()))); builder->Attr("recv_device", edge->dst()->assigned_device_name()); builder->Attr("client_terminated", false); + builder->Attr("_src", edge->src()->name()); + builder->Attr("_dst", edge->dst()->name()); } NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info, diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f940866da5f..e42de02b979 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5175,6 +5175,7 @@ cc_library( REQUIRED_DEPS = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", ] tf_kernel_library( diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc index 7e0e3496645..12456037415 100644 --- a/tensorflow/core/kernels/sendrecv_ops.cc +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/core/kernels/sendrecv_ops.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -107,6 +109,22 @@ void SendOp::Compute(OpKernelContext* ctx) { } } +string SendOp::TraceString(OpKernelContext* ctx, bool verbose) { + const auto& attr = def().attr(); + auto src_it = attr.find("_src"); + auto dst_it = attr.find("_dst"); + const string& src = src_it != attr.end() ? src_it->second.s() : ""; + const string& dst = dst_it != attr.end() ? dst_it->second.s() : ""; + if (!verbose) { + return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src, + ",to=", dst, "#"); + } else { + string trace_args = GetTraceArgument(ctx); + return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src, + ",to=", dst, ",", trace_args, "#"); + } +} + REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp); REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_DEFAULT), SendOp); @@ -139,6 +157,22 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { } } +string RecvOp::TraceString(OpKernelContext* ctx, bool verbose) { + const auto& attr = def().attr(); + auto src_it = attr.find("_src"); + auto dst_it = attr.find("_dst"); + const string& src = src_it != attr.end() ? src_it->second.s() : ""; + const string& dst = dst_it != attr.end() ? dst_it->second.s() : ""; + if (!verbose) { + return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src, + ",to=", dst, "#"); + } else { + string trace_args = GetTraceArgument(ctx); + return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src, + ",to=", dst, ",", trace_args, "#"); + } +} + namespace { Rendezvous::DoneCallback make_recv_callback(OpKernelContext* ctx, AsyncOpKernel::DoneCallback done) { diff --git a/tensorflow/core/kernels/sendrecv_ops.h b/tensorflow/core/kernels/sendrecv_ops.h index 223854de132..06c5663bc04 100644 --- a/tensorflow/core/kernels/sendrecv_ops.h +++ b/tensorflow/core/kernels/sendrecv_ops.h @@ -26,6 +26,8 @@ class SendOp : public OpKernel { explicit SendOp(OpKernelConstruction* ctx); void Compute(OpKernelContext* ctx) override; + string TraceString(OpKernelContext* ctx, bool verbose) override; + private: string key_prefix_; Rendezvous::ParsedKey parsed_key_; @@ -39,6 +41,8 @@ class RecvOp : public AsyncOpKernel { explicit RecvOp(OpKernelConstruction* ctx); void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + string TraceString(OpKernelContext* ctx, bool verbose) override; + private: string key_prefix_; Rendezvous::ParsedKey parsed_key_;