[tf.data] Collecting TraceMe metadata for PrivateThreadPoolDataset
and MaxIntraOpParallelismDataset
.
PiperOrigin-RevId: 360460622 Change-Id: Icb9d9790981e2d69ff56b0d817bd646d86740684
This commit is contained in:
parent
8c1068332b
commit
8156c22727
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/cpu_info.h"
|
||||
#include "tensorflow/core/platform/stringprintf.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
@ -279,7 +280,11 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
|
||||
int64 max_intra_op_parallelism)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
max_intra_op_parallelism_(max_intra_op_parallelism) {
|
||||
max_intra_op_parallelism_(max_intra_op_parallelism),
|
||||
traceme_metadata_(
|
||||
{{"parallelism",
|
||||
strings::Printf("%lld", static_cast<long long>(
|
||||
max_intra_op_parallelism_))}}) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
@ -371,12 +376,17 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TraceMeMetadata GetTraceMeMetadata() const override {
|
||||
return dataset()->traceme_metadata_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<IteratorBase> input_impl_;
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
const int64 max_intra_op_parallelism_;
|
||||
const TraceMeMetadata traceme_metadata_;
|
||||
};
|
||||
};
|
||||
|
||||
@ -401,8 +411,10 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input, int num_threads)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
num_threads_(num_threads == 0 ? port::MaxParallelism()
|
||||
: num_threads) {
|
||||
num_threads_(num_threads == 0 ? port::MaxParallelism() : num_threads),
|
||||
traceme_metadata_(
|
||||
{{"num_threads", strings::Printf("%lld", static_cast<long long>(
|
||||
num_threads_))}}) {
|
||||
thread_pool_ = absl::make_unique<thread::ThreadPool>(
|
||||
ctx->env(), ThreadOptions{}, "data_private_threadpool", num_threads_,
|
||||
/*low_latency_hint=*/false);
|
||||
@ -498,12 +510,17 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TraceMeMetadata GetTraceMeMetadata() const override {
|
||||
return dataset()->traceme_metadata_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<IteratorBase> input_impl_;
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
const int64 num_threads_;
|
||||
const TraceMeMetadata traceme_metadata_;
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
};
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user