Give SendOp/RecvOp some temporary attributes for tracing, so that we can annotate the memcpy device events better.
PiperOrigin-RevId: 296242264 Change-Id: Ib515bc56faf37ede6610b62c8aaab3ab66ef6830
This commit is contained in:
parent
0fa7a0b033
commit
3113c74feb
@ -204,7 +204,8 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
|
||||
const Tensor* input, Tensor* output,
|
||||
int dev_to_dev_stream_index, StatusCallback done,
|
||||
bool sync_dst_compute) {
|
||||
profiler::ScopedAnnotation annotation(edge_name);
|
||||
profiler::ScopedAnnotation annotation(
|
||||
[&] { return absl::StrCat("#edge_name=", edge_name, "#"); });
|
||||
VLOG(1) << "Copy " << edge_name;
|
||||
|
||||
const DeviceType src_device_type(
|
||||
|
@ -129,6 +129,8 @@ static Node* Send(Graph* g, const string& tensor_name,
|
||||
.Attr("send_device_incarnation", 0) // Do not care.
|
||||
.Attr("recv_device", device_name)
|
||||
.Attr("_hostmem_sendrecv", true)
|
||||
.Attr("_src", edge->src()->name())
|
||||
.Attr("_dst", edge->dst()->name())
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
@ -144,6 +146,8 @@ static Node* Recv(Graph* g, const string& tensor_name,
|
||||
.Attr("send_device_incarnation", 0)
|
||||
.Attr("recv_device", device_name)
|
||||
.Attr("_hostmem_sendrecv", true)
|
||||
.Attr("_src", edge->src()->name())
|
||||
.Attr("_dst", edge->dst()->name())
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
@ -189,6 +189,8 @@ void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
|
||||
opts.get_incarnation(edge->src()->assigned_device_name())));
|
||||
builder->Attr("recv_device", edge->dst()->assigned_device_name());
|
||||
builder->Attr("client_terminated", false);
|
||||
builder->Attr("_src", edge->src()->name());
|
||||
builder->Attr("_dst", edge->dst()->name());
|
||||
}
|
||||
|
||||
NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
|
||||
|
@ -5175,6 +5175,7 @@ cc_library(
|
||||
REQUIRED_DEPS = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
]
|
||||
|
||||
tf_kernel_library(
|
||||
|
@ -15,7 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/kernels/sendrecv_ops.h"
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
@ -107,6 +109,22 @@ void SendOp::Compute(OpKernelContext* ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the trace label for this kernel in the form
//   <name>:<type>#from=<_src>,to=<_dst>#
// where <_src> and <_dst> are the string values of the "_src" and "_dst"
// attributes on this kernel's NodeDef (an empty string is used when the
// attribute is absent). When `verbose` is true, the result of
// GetTraceArgument(ctx) is appended after the "to=" field, separated by a
// comma, before the closing '#'.
string SendOp::TraceString(OpKernelContext* ctx, bool verbose) {
  const auto& node_attrs = def().attr();
  const auto src_pos = node_attrs.find("_src");
  const auto dst_pos = node_attrs.find("_dst");

  // Missing attributes fall back to an empty name.
  string from_node;
  if (src_pos != node_attrs.end()) {
    from_node = src_pos->second.s();
  }
  string to_node;
  if (dst_pos != node_attrs.end()) {
    to_node = dst_pos->second.s();
  }

  if (verbose) {
    // The trace argument is only computed when verbose output is requested.
    const string extra_args = GetTraceArgument(ctx);
    return strings::StrCat(name_view(), ":", type_string_view(),
                           "#from=", from_node, ",to=", to_node, ",",
                           extra_args, "#");
  }
  return strings::StrCat(name_view(), ":", type_string_view(),
                         "#from=", from_node, ",to=", to_node, "#");
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_DEFAULT), SendOp);
|
||||
|
||||
@ -139,6 +157,22 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the trace label for this kernel in the form
//   <name>:<type>#from=<_src>,to=<_dst>#
// where <_src> and <_dst> are the string values of the "_src" and "_dst"
// attributes on this kernel's NodeDef (an empty string is used when the
// attribute is absent). When `verbose` is true, the result of
// GetTraceArgument(ctx) is appended after the "to=" field, separated by a
// comma, before the closing '#'.
string RecvOp::TraceString(OpKernelContext* ctx, bool verbose) {
  const auto& node_attrs = def().attr();
  const auto src_pos = node_attrs.find("_src");
  const auto dst_pos = node_attrs.find("_dst");

  // Missing attributes fall back to an empty name.
  string from_node;
  if (src_pos != node_attrs.end()) {
    from_node = src_pos->second.s();
  }
  string to_node;
  if (dst_pos != node_attrs.end()) {
    to_node = dst_pos->second.s();
  }

  if (verbose) {
    // The trace argument is only computed when verbose output is requested.
    const string extra_args = GetTraceArgument(ctx);
    return strings::StrCat(name_view(), ":", type_string_view(),
                           "#from=", from_node, ",to=", to_node, ",",
                           extra_args, "#");
  }
  return strings::StrCat(name_view(), ":", type_string_view(),
                         "#from=", from_node, ",to=", to_node, "#");
}
|
||||
|
||||
namespace {
|
||||
Rendezvous::DoneCallback make_recv_callback(OpKernelContext* ctx,
|
||||
AsyncOpKernel::DoneCallback done) {
|
||||
|
@ -26,6 +26,8 @@ class SendOp : public OpKernel {
|
||||
explicit SendOp(OpKernelConstruction* ctx);
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
string TraceString(OpKernelContext* ctx, bool verbose) override;
|
||||
|
||||
private:
|
||||
string key_prefix_;
|
||||
Rendezvous::ParsedKey parsed_key_;
|
||||
@ -39,6 +41,8 @@ class RecvOp : public AsyncOpKernel {
|
||||
explicit RecvOp(OpKernelConstruction* ctx);
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
|
||||
|
||||
string TraceString(OpKernelContext* ctx, bool verbose) override;
|
||||
|
||||
private:
|
||||
string key_prefix_;
|
||||
Rendezvous::ParsedKey parsed_key_;
|
||||
|
Loading…
x
Reference in New Issue
Block a user