Use TraceMeOp and TraceMeEncode to hide TraceString encoding details
PiperOrigin-RevId: 323025848 Change-Id: I6f491180763d5d87a2e53296b460d0e41e4b0ac7
This commit is contained in:
parent
3edbef6647
commit
a63c4d60cf
@ -285,10 +285,11 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
|
||||
// starve executor threads.
|
||||
remote_access_->RunClosure([col_impl, col_ctx, done_safe, ctx]() {
|
||||
profiler::TraceMe activity(
|
||||
[&] {
|
||||
return strings::StrCat(ctx->op_kernel().name_view(), ":",
|
||||
ctx->op_kernel().type_string_view(),
|
||||
"#id=", ctx->step_id(), "#");
|
||||
[ctx] {
|
||||
string op = profiler::TraceMeOp(ctx->op_kernel().name_view(),
|
||||
ctx->op_kernel().type_string_view());
|
||||
return profiler::TraceMeEncode(std::move(op),
|
||||
{{"id", ctx->step_id()}});
|
||||
},
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
|
||||
|
@ -296,7 +296,7 @@ Status KernelAndDeviceOp::Run(
|
||||
// 'AnnotatedTraceMe' will trace both scheduling time on host and execution
|
||||
// time on device of the OpKernel.
|
||||
profiler::AnnotatedTraceMe activity(
|
||||
[&] { return kernel_->TraceString(&context, /*verbose=*/false); },
|
||||
[&] { return kernel_->TraceString(context, /*verbose=*/false); },
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
device_->Compute(kernel_.get(), &context);
|
||||
}
|
||||
|
@ -530,9 +530,9 @@ Status ExecutorState<PropagatorStateType>::ProcessSync(
|
||||
tracing::ScopedRegion region(tracing::EventCategory::kCompute,
|
||||
op_kernel->name_view());
|
||||
profiler::AnnotatedTraceMe activity(
|
||||
[&] {
|
||||
[op_kernel, &ctx] {
|
||||
return op_kernel->TraceString(
|
||||
&ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
|
||||
ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
|
||||
},
|
||||
profiler::GetTFTraceMeLevel(is_expensive));
|
||||
device->Compute(op_kernel, &ctx);
|
||||
@ -597,9 +597,9 @@ void ExecutorState<PropagatorStateType>::ProcessAsync(
|
||||
nodestats::SetOpStart(stats);
|
||||
{
|
||||
profiler::AnnotatedTraceMe activity(
|
||||
[&] {
|
||||
[async_kernel, state] {
|
||||
return async_kernel->TraceString(
|
||||
&state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
|
||||
state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
|
||||
},
|
||||
profiler::GetTFTraceMeLevel(kernel_stats_->IsExpensive(item)));
|
||||
immutable_state_.params().device->ComputeAsync(async_kernel, &state->ctx,
|
||||
|
@ -524,8 +524,9 @@ void DatasetOpKernel::Compute(OpKernelContext* ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
string DatasetOpKernel::TraceString(OpKernelContext* ctx, bool verbose) {
|
||||
return strings::StrCat(name_view(), ":", type_string_view());
|
||||
string DatasetOpKernel::TraceString(const OpKernelContext& ctx,
|
||||
bool verbose) const {
|
||||
return profiler::TraceMeOp(name_view(), type_string_view());
|
||||
}
|
||||
|
||||
// static
|
||||
|
@ -1063,7 +1063,7 @@ class DatasetOpKernel : public OpKernel {
|
||||
// the `DatasetOpKernel` class.
|
||||
static bool IsDatasetOp(const OpDef* op_def);
|
||||
|
||||
string TraceString(OpKernelContext* ctx, bool verbose) override;
|
||||
string TraceString(const OpKernelContext& ctx, bool verbose) const override;
|
||||
|
||||
protected:
|
||||
// Subclasses should implement this method. It will be called during Compute
|
||||
|
@ -53,6 +53,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/platform_strings.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -172,34 +173,38 @@ Status OpKernel::OutputRange(StringPiece output_name, int* start,
|
||||
}
|
||||
}
|
||||
|
||||
string OpKernel::GetTraceArgument(OpKernelContext* ctx) {
|
||||
int num_inputs = ctx->num_inputs();
|
||||
string OpKernel::ShapeTraceString(const OpKernelContext& ctx) const {
|
||||
int num_inputs = ctx.num_inputs();
|
||||
if (num_inputs == 0) return "";
|
||||
std::vector<string> tensor_shapes;
|
||||
tensor_shapes.reserve(num_inputs);
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
if (!ctx->has_input(i)) {
|
||||
if (!ctx.has_input(i)) {
|
||||
tensor_shapes.emplace_back(); // Placeholder
|
||||
continue;
|
||||
}
|
||||
DataType input_dtype = ctx->input_dtype(i);
|
||||
DataType input_dtype = ctx.input_dtype(i);
|
||||
if (input_dtype == DataType::DT_RESOURCE ||
|
||||
input_dtype == DataType::DT_VARIANT || IsRefType(input_dtype)) {
|
||||
tensor_shapes.emplace_back(); // Placeholder
|
||||
continue;
|
||||
}
|
||||
tensor_shapes.emplace_back(strings::StrCat(
|
||||
DataTypeString(input_dtype), ctx->input(i).shape().DebugString()));
|
||||
DataTypeString(input_dtype), ctx.input(i).shape().DebugString()));
|
||||
}
|
||||
return strings::StrCat("shape=(", absl::StrJoin(tensor_shapes, ";"), ")");
|
||||
return strings::StrCat("(", absl::StrJoin(tensor_shapes, ";"), ")");
|
||||
}
|
||||
|
||||
string OpKernel::TraceString(OpKernelContext* ctx, bool verbose) {
|
||||
string trace_string = strings::StrCat(name_view(), ":", type_string_view());
|
||||
if (!verbose) return trace_string;
|
||||
string trace_args = GetTraceArgument(ctx);
|
||||
if (trace_args.empty()) return trace_string;
|
||||
return strings::StrCat(trace_string, "#", trace_args, "#");
|
||||
string OpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const {
|
||||
string trace_string = profiler::TraceMeOp(name_view(), type_string_view());
|
||||
if (verbose) {
|
||||
string shape = ShapeTraceString(ctx);
|
||||
if (!shape.empty()) {
|
||||
trace_string =
|
||||
profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
|
||||
}
|
||||
}
|
||||
return trace_string;
|
||||
}
|
||||
|
||||
void AsyncOpKernel::Compute(OpKernelContext* context) {
|
||||
@ -413,7 +418,7 @@ Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const Tensor& OpKernelContext::input(int index) {
|
||||
const Tensor& OpKernelContext::input(int index) const {
|
||||
CHECK_GE(index, 0);
|
||||
CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
|
||||
CHECK(!input_is_ref(index));
|
||||
|
@ -177,12 +177,10 @@ class OpKernel {
|
||||
// Returns a trace string for current computation, op name/type and input
|
||||
// tensor shape/dtype are encoded for profiler cost analysis. Most OpKernel
|
||||
// should use the default implementation.
|
||||
// Override this function to add OpKernel specific attributes that are
|
||||
// necessary for cost analysis.
|
||||
virtual string TraceString(OpKernelContext* ctx, bool verbose);
|
||||
virtual string TraceString(const OpKernelContext& ctx, bool verbose) const;
|
||||
|
||||
protected:
|
||||
string GetTraceArgument(OpKernelContext* ctx);
|
||||
string ShapeTraceString(const OpKernelContext& ctx) const;
|
||||
|
||||
private:
|
||||
const std::shared_ptr<const NodeProperties> props_;
|
||||
@ -734,7 +732,7 @@ class OpKernelContext {
|
||||
// inputs. For Ref inputs use mutable_input below.
|
||||
// REQUIRES: !IsRefType(input_dtype(index))
|
||||
// TODO(mrry): Convert this to return Status.
|
||||
const Tensor& input(int index);
|
||||
const Tensor& input(int index) const;
|
||||
|
||||
// Returns the named immutable input tensor in "tensor", as defined
|
||||
// in the OpDef. May only be used for non-Ref inputs. For Ref inputs
|
||||
|
@ -1105,7 +1105,7 @@ void BM_TraceString(const int iters, const int verbose) {
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
auto trace = op->TraceString(ctx.get(), verbose);
|
||||
auto trace = op->TraceString(*ctx, verbose);
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
@ -3824,6 +3824,7 @@ tf_kernel_library(
|
||||
":transpose_functor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//third_party/eigen3",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -5341,7 +5342,9 @@ tf_kernel_library(
|
||||
tf_kernel_library(
|
||||
name = "sendrecv_ops",
|
||||
prefix = "sendrecv_ops",
|
||||
deps = REQUIRED_DEPS,
|
||||
deps = REQUIRED_DEPS + [
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/math/math_util.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/util/einsum_op_util.h"
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
@ -715,15 +716,17 @@ class EinsumOp : public OpKernel {
|
||||
ctx->set_output(0, output);
|
||||
}
|
||||
|
||||
string TraceString(OpKernelContext* ctx, bool verbose) override {
|
||||
if (!verbose) {
|
||||
return strings::StrCat(name_view(), ":", type_string_view(),
|
||||
"#equation=(", equation_, ")#");
|
||||
} else {
|
||||
string trace_args = GetTraceArgument(ctx);
|
||||
return strings::StrCat(name_view(), ":", type_string_view(),
|
||||
"#equation=(", equation_, "),", trace_args, "#");
|
||||
string TraceString(const OpKernelContext& ctx, bool verbose) const override {
|
||||
string op = profiler::TraceMeOp(name_view(), type_string_view());
|
||||
string equation = strings::StrCat("(", equation_, ")");
|
||||
if (verbose) {
|
||||
string shape = ShapeTraceString(ctx);
|
||||
if (!shape.empty()) {
|
||||
return profiler::TraceMeEncode(
|
||||
std::move(op), {{"equation", equation}, {"shape", shape}});
|
||||
}
|
||||
}
|
||||
return profiler::TraceMeEncode(std::move(op), {{"equation", equation}});
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -441,13 +441,18 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||
});
|
||||
}
|
||||
|
||||
string RemoteCallOp::TraceString(OpKernelContext* ctx, bool verbose) {
|
||||
string trace_string =
|
||||
strings::StrCat(name_view(), "__", func_.name(), ":", type_string_view());
|
||||
if (!verbose) return trace_string;
|
||||
string trace_args = GetTraceArgument(ctx);
|
||||
if (trace_args.empty()) return trace_string;
|
||||
return strings::StrCat(trace_string, "#", trace_args, "#");
|
||||
string RemoteCallOp::TraceString(const OpKernelContext& ctx,
|
||||
bool verbose) const {
|
||||
string trace_string = profiler::TraceMeOp(
|
||||
strings::StrCat(name_view(), "__", func_.name()), type_string_view());
|
||||
if (verbose) {
|
||||
string shape = ShapeTraceString(ctx);
|
||||
if (!shape.empty()) {
|
||||
trace_string =
|
||||
profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
|
||||
}
|
||||
}
|
||||
return trace_string;
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
|
@ -64,7 +64,7 @@ class RemoteCallOp : public AsyncOpKernel {
|
||||
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
|
||||
|
||||
string TraceString(OpKernelContext* ctx, bool verbose) override;
|
||||
string TraceString(const OpKernelContext& ctx, bool verbose) const override;
|
||||
|
||||
private:
|
||||
NameAttrList func_;
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -111,14 +112,14 @@ void SendOp::Compute(OpKernelContext* ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
string SendOp::TraceString(OpKernelContext* ctx, bool verbose) {
|
||||
string SendOp::TraceString(const OpKernelContext& ctx, bool verbose) const {
|
||||
const auto& attr = def().attr();
|
||||
auto src_it = attr.find("_src");
|
||||
auto dst_it = attr.find("_dst");
|
||||
const string& src = src_it != attr.end() ? src_it->second.s() : "";
|
||||
const string& dst = dst_it != attr.end() ? dst_it->second.s() : "";
|
||||
return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src,
|
||||
",to=", dst, "#");
|
||||
string op = profiler::TraceMeOp(name_view(), type_string_view());
|
||||
return profiler::TraceMeEncode(std::move(op), {{"from", src}, {"to", dst}});
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
|
||||
@ -155,14 +156,14 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
string RecvOp::TraceString(OpKernelContext* ctx, bool verbose) {
|
||||
string RecvOp::TraceString(const OpKernelContext& ctx, bool verbose) const {
|
||||
const auto& attr = def().attr();
|
||||
auto src_it = attr.find("_src");
|
||||
auto dst_it = attr.find("_dst");
|
||||
const string& src = src_it != attr.end() ? src_it->second.s() : "";
|
||||
const string& dst = dst_it != attr.end() ? dst_it->second.s() : "";
|
||||
return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src,
|
||||
",to=", dst, "#");
|
||||
string op = profiler::TraceMeOp(name_view(), type_string_view());
|
||||
return profiler::TraceMeEncode(std::move(op), {{"from", src}, {"to", dst}});
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -26,7 +26,7 @@ class SendOp : public OpKernel {
|
||||
explicit SendOp(OpKernelConstruction* ctx);
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
string TraceString(OpKernelContext* ctx, bool verbose) override;
|
||||
string TraceString(const OpKernelContext& ctx, bool verbose) const override;
|
||||
|
||||
private:
|
||||
string key_prefix_;
|
||||
@ -41,7 +41,7 @@ class RecvOp : public AsyncOpKernel {
|
||||
explicit RecvOp(OpKernelConstruction* ctx);
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
|
||||
|
||||
string TraceString(OpKernelContext* ctx, bool verbose) override;
|
||||
string TraceString(const OpKernelContext& ctx, bool verbose) const override;
|
||||
|
||||
private:
|
||||
string key_prefix_;
|
||||
|
@ -128,6 +128,17 @@ TF_ATTRIBUTE_ALWAYS_INLINE inline std::string TraceMeEncode(
|
||||
return traceme_internal::AppendArgs(std::string(), args);
|
||||
}
|
||||
|
||||
// Concatenates op_name and op_type.
|
||||
TF_ATTRIBUTE_ALWAYS_INLINE inline std::string TraceMeOp(
|
||||
absl::string_view op_name, absl::string_view op_type) {
|
||||
return absl::StrCat(op_name, ":", op_type);
|
||||
}
|
||||
// Overload for callers that hand over a temporary op_name: appends
// ":<op_type>" to op_name in place and returns the combined string, rather
// than building a separate string from scratch.
TF_ATTRIBUTE_ALWAYS_INLINE inline std::string TraceMeOp(
    std::string&& op_name, absl::string_view op_type) {
  absl::StrAppend(&op_name, ":", op_type);
  // op_name is an rvalue-reference parameter; before C++20 a plain
  // `return op_name;` copy-constructs the result, so move it out explicitly.
  return std::move(op_name);
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user