Use TraceMeOp and TraceMeEncode to hide TraceString encoding details

PiperOrigin-RevId: 323025848
Change-Id: I6f491180763d5d87a2e53296b460d0e41e4b0ac7
This commit is contained in:
Jose Baiocchi 2020-07-24 10:47:12 -07:00 committed by TensorFlower Gardener
parent 3edbef6647
commit a63c4d60cf
15 changed files with 84 additions and 56 deletions

View File

@ -285,10 +285,11 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
// starve executor threads.
remote_access_->RunClosure([col_impl, col_ctx, done_safe, ctx]() {
profiler::TraceMe activity(
[&] {
return strings::StrCat(ctx->op_kernel().name_view(), ":",
ctx->op_kernel().type_string_view(),
"#id=", ctx->step_id(), "#");
[ctx] {
string op = profiler::TraceMeOp(ctx->op_kernel().name_view(),
ctx->op_kernel().type_string_view());
return profiler::TraceMeEncode(std::move(op),
{{"id", ctx->step_id()}});
},
profiler::TraceMeLevel::kInfo);
col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {

View File

@ -296,7 +296,7 @@ Status KernelAndDeviceOp::Run(
// 'AnnotatedTraceMe' will trace both scheduling time on host and execution
// time on device of the OpKernel.
profiler::AnnotatedTraceMe activity(
[&] { return kernel_->TraceString(&context, /*verbose=*/false); },
[&] { return kernel_->TraceString(context, /*verbose=*/false); },
profiler::TraceMeLevel::kInfo);
device_->Compute(kernel_.get(), &context);
}

View File

@ -530,9 +530,9 @@ Status ExecutorState<PropagatorStateType>::ProcessSync(
tracing::ScopedRegion region(tracing::EventCategory::kCompute,
op_kernel->name_view());
profiler::AnnotatedTraceMe activity(
[&] {
[op_kernel, &ctx] {
return op_kernel->TraceString(
&ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
},
profiler::GetTFTraceMeLevel(is_expensive));
device->Compute(op_kernel, &ctx);
@ -597,9 +597,9 @@ void ExecutorState<PropagatorStateType>::ProcessAsync(
nodestats::SetOpStart(stats);
{
profiler::AnnotatedTraceMe activity(
[&] {
[async_kernel, state] {
return async_kernel->TraceString(
&state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
},
profiler::GetTFTraceMeLevel(kernel_stats_->IsExpensive(item)));
immutable_state_.params().device->ComputeAsync(async_kernel, &state->ctx,

View File

@ -524,8 +524,9 @@ void DatasetOpKernel::Compute(OpKernelContext* ctx) {
}
}
string DatasetOpKernel::TraceString(OpKernelContext* ctx, bool verbose) {
return strings::StrCat(name_view(), ":", type_string_view());
string DatasetOpKernel::TraceString(const OpKernelContext& ctx,
bool verbose) const {
return profiler::TraceMeOp(name_view(), type_string_view());
}
// static

View File

@ -1063,7 +1063,7 @@ class DatasetOpKernel : public OpKernel {
// the `DatasetOpKernel` class.
static bool IsDatasetOp(const OpDef* op_def);
string TraceString(OpKernelContext* ctx, bool verbose) override;
string TraceString(const OpKernelContext& ctx, bool verbose) const override;
protected:
// Subclasses should implement this method. It will be called during Compute

View File

@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform_strings.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@ -172,34 +173,38 @@ Status OpKernel::OutputRange(StringPiece output_name, int* start,
}
}
string OpKernel::GetTraceArgument(OpKernelContext* ctx) {
int num_inputs = ctx->num_inputs();
string OpKernel::ShapeTraceString(const OpKernelContext& ctx) const {
int num_inputs = ctx.num_inputs();
if (num_inputs == 0) return "";
std::vector<string> tensor_shapes;
tensor_shapes.reserve(num_inputs);
for (int i = 0; i < num_inputs; i++) {
if (!ctx->has_input(i)) {
if (!ctx.has_input(i)) {
tensor_shapes.emplace_back(); // Placeholder
continue;
}
DataType input_dtype = ctx->input_dtype(i);
DataType input_dtype = ctx.input_dtype(i);
if (input_dtype == DataType::DT_RESOURCE ||
input_dtype == DataType::DT_VARIANT || IsRefType(input_dtype)) {
tensor_shapes.emplace_back(); // Placeholder
continue;
}
tensor_shapes.emplace_back(strings::StrCat(
DataTypeString(input_dtype), ctx->input(i).shape().DebugString()));
DataTypeString(input_dtype), ctx.input(i).shape().DebugString()));
}
return strings::StrCat("shape=(", absl::StrJoin(tensor_shapes, ";"), ")");
return strings::StrCat("(", absl::StrJoin(tensor_shapes, ";"), ")");
}
string OpKernel::TraceString(OpKernelContext* ctx, bool verbose) {
string trace_string = strings::StrCat(name_view(), ":", type_string_view());
if (!verbose) return trace_string;
string trace_args = GetTraceArgument(ctx);
if (trace_args.empty()) return trace_string;
return strings::StrCat(trace_string, "#", trace_args, "#");
string OpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const {
string trace_string = profiler::TraceMeOp(name_view(), type_string_view());
if (verbose) {
string shape = ShapeTraceString(ctx);
if (!shape.empty()) {
trace_string =
profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
}
}
return trace_string;
}
void AsyncOpKernel::Compute(OpKernelContext* context) {
@ -413,7 +418,7 @@ Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
return Status::OK();
}
const Tensor& OpKernelContext::input(int index) {
const Tensor& OpKernelContext::input(int index) const {
CHECK_GE(index, 0);
CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
CHECK(!input_is_ref(index));

View File

@ -177,12 +177,10 @@ class OpKernel {
// Returns a trace string for current computation, op name/type and input
// tensor shape/dtype are encoded for profiler cost analysis. Most OpKernel
// should use the default implementation.
// Override this function to add OpKernel specific attributes that are
// necessary for cost analysis.
virtual string TraceString(OpKernelContext* ctx, bool verbose);
virtual string TraceString(const OpKernelContext& ctx, bool verbose) const;
protected:
string GetTraceArgument(OpKernelContext* ctx);
string ShapeTraceString(const OpKernelContext& ctx) const;
private:
const std::shared_ptr<const NodeProperties> props_;
@ -734,7 +732,7 @@ class OpKernelContext {
// inputs. For Ref inputs use mutable_input below.
// REQUIRES: !IsRefType(input_dtype(index))
// TODO(mrry): Convert this to return Status.
const Tensor& input(int index);
const Tensor& input(int index) const;
// Returns the named immutable input tensor in "tensor", as defined
// in the OpDef. May only be used for non-Ref inputs. For Ref inputs

View File

@ -1105,7 +1105,7 @@ void BM_TraceString(const int iters, const int verbose) {
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
auto trace = op->TraceString(ctx.get(), verbose);
auto trace = op->TraceString(*ctx, verbose);
}
testing::StopTiming();
}

View File

@ -3824,6 +3824,7 @@ tf_kernel_library(
":transpose_functor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
"//third_party/eigen3",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
@ -5341,7 +5342,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sendrecv_ops",
prefix = "sendrecv_ops",
deps = REQUIRED_DEPS,
deps = REQUIRED_DEPS + [
"//tensorflow/core/profiler/lib:traceme",
],
)
tf_cc_test(

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/einsum_op_util.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@ -715,15 +716,17 @@ class EinsumOp : public OpKernel {
ctx->set_output(0, output);
}
string TraceString(OpKernelContext* ctx, bool verbose) override {
if (!verbose) {
return strings::StrCat(name_view(), ":", type_string_view(),
"#equation=(", equation_, ")#");
} else {
string trace_args = GetTraceArgument(ctx);
return strings::StrCat(name_view(), ":", type_string_view(),
"#equation=(", equation_, "),", trace_args, "#");
string TraceString(const OpKernelContext& ctx, bool verbose) const override {
string op = profiler::TraceMeOp(name_view(), type_string_view());
string equation = strings::StrCat("(", equation_, ")");
if (verbose) {
string shape = ShapeTraceString(ctx);
if (!shape.empty()) {
return profiler::TraceMeEncode(
std::move(op), {{"equation", equation}, {"shape", shape}});
}
}
return profiler::TraceMeEncode(std::move(op), {{"equation", equation}});
}
private:

View File

@ -441,13 +441,18 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
});
}
string RemoteCallOp::TraceString(OpKernelContext* ctx, bool verbose) {
string trace_string =
strings::StrCat(name_view(), "__", func_.name(), ":", type_string_view());
if (!verbose) return trace_string;
string trace_args = GetTraceArgument(ctx);
if (trace_args.empty()) return trace_string;
return strings::StrCat(trace_string, "#", trace_args, "#");
string RemoteCallOp::TraceString(const OpKernelContext& ctx,
bool verbose) const {
string trace_string = profiler::TraceMeOp(
strings::StrCat(name_view(), "__", func_.name()), type_string_view());
if (verbose) {
string shape = ShapeTraceString(ctx);
if (!shape.empty()) {
trace_string =
profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
}
}
return trace_string;
}
REGISTER_KERNEL_BUILDER(

View File

@ -64,7 +64,7 @@ class RemoteCallOp : public AsyncOpKernel {
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
string TraceString(OpKernelContext* ctx, bool verbose) override;
string TraceString(const OpKernelContext& ctx, bool verbose) const override;
private:
NameAttrList func_;

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace tensorflow {
@ -111,14 +112,14 @@ void SendOp::Compute(OpKernelContext* ctx) {
}
}
string SendOp::TraceString(OpKernelContext* ctx, bool verbose) {
string SendOp::TraceString(const OpKernelContext& ctx, bool verbose) const {
const auto& attr = def().attr();
auto src_it = attr.find("_src");
auto dst_it = attr.find("_dst");
const string& src = src_it != attr.end() ? src_it->second.s() : "";
const string& dst = dst_it != attr.end() ? dst_it->second.s() : "";
return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src,
",to=", dst, "#");
string op = profiler::TraceMeOp(name_view(), type_string_view());
return profiler::TraceMeEncode(std::move(op), {{"from", src}, {"to", dst}});
}
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
@ -155,14 +156,14 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
}
}
string RecvOp::TraceString(OpKernelContext* ctx, bool verbose) {
string RecvOp::TraceString(const OpKernelContext& ctx, bool verbose) const {
const auto& attr = def().attr();
auto src_it = attr.find("_src");
auto dst_it = attr.find("_dst");
const string& src = src_it != attr.end() ? src_it->second.s() : "";
const string& dst = dst_it != attr.end() ? dst_it->second.s() : "";
return strings::StrCat(name_view(), ":", type_string_view(), "#from=", src,
",to=", dst, "#");
string op = profiler::TraceMeOp(name_view(), type_string_view());
return profiler::TraceMeEncode(std::move(op), {{"from", src}, {"to", dst}});
}
namespace {

View File

@ -26,7 +26,7 @@ class SendOp : public OpKernel {
explicit SendOp(OpKernelConstruction* ctx);
void Compute(OpKernelContext* ctx) override;
string TraceString(OpKernelContext* ctx, bool verbose) override;
string TraceString(const OpKernelContext& ctx, bool verbose) const override;
private:
string key_prefix_;
@ -41,7 +41,7 @@ class RecvOp : public AsyncOpKernel {
explicit RecvOp(OpKernelConstruction* ctx);
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
string TraceString(OpKernelContext* ctx, bool verbose) override;
string TraceString(const OpKernelContext& ctx, bool verbose) const override;
private:
string key_prefix_;

View File

@ -128,6 +128,17 @@ TF_ATTRIBUTE_ALWAYS_INLINE inline std::string TraceMeEncode(
return traceme_internal::AppendArgs(std::string(), args);
}
// Concatenates op_name and op_type.
TF_ATTRIBUTE_ALWAYS_INLINE inline std::string TraceMeOp(
absl::string_view op_name, absl::string_view op_type) {
return absl::StrCat(op_name, ":", op_type);
}
TF_ATTRIBUTE_ALWAYS_INLINE inline std::string TraceMeOp(
std::string&& op_name, absl::string_view op_type) {
absl::StrAppend(&op_name, ":", op_type);
return op_name;
}
} // namespace profiler
} // namespace tensorflow