From 8156c2272776fce26d7096dc83bb272d2556593f Mon Sep 17 00:00:00 2001 From: Jay Shi Date: Tue, 2 Mar 2021 10:29:51 -0800 Subject: [PATCH] [tf.data] Collecting TraceMe metadata for `PrivateThreadPoolDataset` and `MaxIntraOpParallelismDataset`. PiperOrigin-RevId: 360460622 Change-Id: Icb9d9790981e2d69ff56b0d817bd646d86740684 --- .../experimental/threadpool_dataset_op.cc | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc index 5a9c0c32b2f..6d2b558329a 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/stringprintf.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/util/work_sharder.h" @@ -279,7 +280,11 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel { int64 max_intra_op_parallelism) : DatasetBase(DatasetContext(ctx)), input_(input), - max_intra_op_parallelism_(max_intra_op_parallelism) { + max_intra_op_parallelism_(max_intra_op_parallelism), + traceme_metadata_( + {{"parallelism", + strings::Printf("%lld", static_cast( + max_intra_op_parallelism_))}}) { input_->Ref(); } @@ -371,12 +376,17 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } + TraceMeMetadata GetTraceMeMetadata() const override { + return dataset()->traceme_metadata_; + } + private: std::unique_ptr input_impl_; }; const DatasetBase* const input_; const int64 max_intra_op_parallelism_; + const TraceMeMetadata traceme_metadata_; }; }; @@ -401,8 +411,10 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel { Dataset(OpKernelContext* ctx, const DatasetBase* input, int num_threads) : DatasetBase(DatasetContext(ctx)), input_(input), - num_threads_(num_threads == 0 ? port::MaxParallelism() - : num_threads) { + num_threads_(num_threads == 0 ? port::MaxParallelism() : num_threads), + traceme_metadata_( + {{"num_threads", strings::Printf("%lld", static_cast( + num_threads_))}}) { thread_pool_ = absl::make_unique( ctx->env(), ThreadOptions{}, "data_private_threadpool", num_threads_, /*low_latency_hint=*/false); @@ -498,12 +510,17 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } + TraceMeMetadata GetTraceMeMetadata() const override { + return dataset()->traceme_metadata_; + } + private: std::unique_ptr input_impl_; }; const DatasetBase* const input_; const int64 num_threads_; + const TraceMeMetadata traceme_metadata_; std::unique_ptr thread_pool_; }; };