[tf.data] Collecting TraceMe metadata for PrivateThreadPoolDataset and MaxIntraOpParallelismDataset.

PiperOrigin-RevId: 360460622
Change-Id: Icb9d9790981e2d69ff56b0d817bd646d86740684
This commit is contained in:
Jay Shi 2021-03-02 10:29:51 -08:00 committed by TensorFlower Gardener
parent 8c1068332b
commit 8156c22727

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/work_sharder.h"
@ -279,7 +280,11 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
int64 max_intra_op_parallelism)
: DatasetBase(DatasetContext(ctx)),
input_(input),
max_intra_op_parallelism_(max_intra_op_parallelism) {
max_intra_op_parallelism_(max_intra_op_parallelism),
traceme_metadata_(
{{"parallelism",
strings::Printf("%lld", static_cast<long long>(
max_intra_op_parallelism_))}}) {
input_->Ref();
}
@ -371,12 +376,17 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
TraceMeMetadata GetTraceMeMetadata() const override {
return dataset()->traceme_metadata_;
}
private:
std::unique_ptr<IteratorBase> input_impl_;
};
const DatasetBase* const input_;
const int64 max_intra_op_parallelism_;
const TraceMeMetadata traceme_metadata_;
};
};
@ -401,8 +411,10 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
Dataset(OpKernelContext* ctx, const DatasetBase* input, int num_threads)
: DatasetBase(DatasetContext(ctx)),
input_(input),
num_threads_(num_threads == 0 ? port::MaxParallelism()
: num_threads) {
num_threads_(num_threads == 0 ? port::MaxParallelism() : num_threads),
traceme_metadata_(
{{"num_threads", strings::Printf("%lld", static_cast<long long>(
num_threads_))}}) {
thread_pool_ = absl::make_unique<thread::ThreadPool>(
ctx->env(), ThreadOptions{}, "data_private_threadpool", num_threads_,
/*low_latency_hint=*/false);
@ -498,12 +510,17 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
TraceMeMetadata GetTraceMeMetadata() const override {
return dataset()->traceme_metadata_;
}
private:
std::unique_ptr<IteratorBase> input_impl_;
};
const DatasetBase* const input_;
const int64 num_threads_;
const TraceMeMetadata traceme_metadata_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
};
};