Update thread utilization async in map_and_batch dataset.
PiperOrigin-RevId: 291070968 Change-Id: Ibc416c6a60fdbf6c72e5beeb09a79966e1e07382
This commit is contained in:
parent
3ab94ef0f3
commit
8c6713c89c
tensorflow
core/kernels/data/experimental
python/data/experimental
@ -71,9 +71,6 @@ constexpr char kStatus[] = "status";
|
||||
constexpr char kCode[] = "code";
|
||||
constexpr char kMessage[] = "msg";
|
||||
|
||||
// Period between reporting dataset statistics.
|
||||
constexpr int kStatsReportingPeriodMillis = 1000;
|
||||
|
||||
class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
|
||||
@ -221,7 +218,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
std::shared_ptr<BatchResult> result;
|
||||
{
|
||||
mutex_lock l(*mu_);
|
||||
EnsureThreadsStarted(ctx);
|
||||
EnsureRunnerThreadStarted(ctx);
|
||||
while (!cancelled_ && (batch_results_.empty() ||
|
||||
batch_results_.front()->num_calls > 0)) {
|
||||
++waiting_;
|
||||
@ -450,18 +447,13 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void EnsureThreadsStarted(IteratorContext* ctx)
|
||||
void EnsureRunnerThreadStarted(IteratorContext* ctx)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||
if (!runner_thread_) {
|
||||
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
|
||||
runner_thread_ = ctx->StartThread(
|
||||
kTFDataMapAndBatch,
|
||||
std::bind(&Iterator::RunnerThread, this, ctx_copy));
|
||||
if (ctx->stats_aggregator()) {
|
||||
stats_thread_ = ctx->StartThread(
|
||||
"tf_data_map_and_batch_stats",
|
||||
std::bind(&Iterator::StatsThread, this, ctx_copy));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -588,6 +580,15 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
num_calls_++;
|
||||
}
|
||||
}
|
||||
const auto& stats_aggregator = ctx->stats_aggregator();
|
||||
if (stats_aggregator) {
|
||||
mutex_lock l(*mu_);
|
||||
stats_aggregator->AddScalar(
|
||||
stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
|
||||
static_cast<float>(num_calls_) /
|
||||
static_cast<float>(num_parallel_calls_->value),
|
||||
num_elements());
|
||||
}
|
||||
for (const auto& call : new_calls) {
|
||||
CallFunction(ctx, call.first, call.second);
|
||||
}
|
||||
@ -595,34 +596,6 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
}
|
||||
|
||||
void StatsThread(const std::shared_ptr<IteratorContext>& ctx) {
|
||||
for (int64 step = 0;; ++step) {
|
||||
int num_calls;
|
||||
int num_parallel_calls;
|
||||
{
|
||||
mutex_lock l(*mu_);
|
||||
if (step != 0 && !cancelled_) {
|
||||
cond_var_->wait_for(
|
||||
l, std::chrono::milliseconds(kStatsReportingPeriodMillis));
|
||||
}
|
||||
if (cancelled_) {
|
||||
return;
|
||||
}
|
||||
num_calls = num_calls_;
|
||||
num_parallel_calls = num_parallel_calls_->value;
|
||||
}
|
||||
if (num_parallel_calls == 0) {
|
||||
// Avoid division by zero.
|
||||
num_parallel_calls = 1;
|
||||
}
|
||||
ctx->stats_aggregator()->AddScalar(
|
||||
stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
|
||||
static_cast<float>(num_calls) /
|
||||
static_cast<float>(num_parallel_calls),
|
||||
step);
|
||||
}
|
||||
}
|
||||
|
||||
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
|
||||
size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||
batch_results_.push_back(
|
||||
@ -762,7 +735,6 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
|
||||
// Background thread used for coordinating input processing.
|
||||
std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
|
||||
std::unique_ptr<Thread> stats_thread_ GUARDED_BY(*mu_);
|
||||
// Determines whether the transformation has been cancelled.
|
||||
bool cancelled_ GUARDED_BY(*mu_) = false;
|
||||
// Identifies the number of callers currently waiting for a batch result.
|
||||
|
@ -80,7 +80,6 @@ tf_py_test(
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python/data/benchmarks:benchmark_base",
|
||||
"//tensorflow/python/data/experimental/ops:batching",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
|
@ -25,9 +25,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.data.benchmarks import benchmark_base
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.experimental.ops import stats_aggregator
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -39,7 +37,7 @@ from tensorflow.python.platform import test
|
||||
_NUMPY_RANDOM_SEED = 42
|
||||
|
||||
|
||||
class MapAndBatchBenchmark(benchmark_base.DatasetBenchmarkBase):
|
||||
class MapAndBatchBenchmark(test.Benchmark):
|
||||
"""Benchmarks for `tf.data.experimental.map_and_batch()`."""
|
||||
|
||||
def benchmark_map_and_batch(self):
|
||||
@ -202,16 +200,6 @@ class MapAndBatchBenchmark(benchmark_base.DatasetBenchmarkBase):
|
||||
benchmark("Transformation parallelism evaluation", par_num_calls_series)
|
||||
benchmark("Threadpool size evaluation", par_inter_op_series)
|
||||
|
||||
def benchmark_stats(self):
  """Benchmarks `map_and_batch` with a stats aggregator attached.

  Builds an endlessly repeating single-element range dataset, applies
  `map_and_batch` (adding 1 to each element, batch size 1) with 32 parallel
  calls, attaches a `StatsAggregator` through the dataset options, and
  reports the benchmark under the name "stats" over 1000 elements.
  """
  dataset = dataset_ops.Dataset.range(1).repeat()
  # `num_parallel_calls` is an argument of `map_and_batch`, not of
  # `Dataset.apply` (which only accepts the transformation function), so it
  # must be passed inside the `map_and_batch(...)` call.
  dataset = dataset.apply(
      batching.map_and_batch(lambda x: x + 1, 1, num_parallel_calls=32))
  aggregator = stats_aggregator.StatsAggregator()
  options = dataset_ops.Options()
  options.experimental_stats.aggregator = aggregator
  dataset = dataset.with_options(options)
  self.run_and_report_benchmark(dataset, num_elements=1000, name="stats")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -380,8 +380,6 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase,
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testMapAndBatchAutoTuneBufferUtilization(self):
|
||||
self.skipTest("b/147897892: This test is flaky because thread utilization "
|
||||
"is recorded asynchronously")
|
||||
|
||||
def dataset_fn():
|
||||
return dataset_ops.Dataset.range(100).apply(
|
||||
|
Loading…
Reference in New Issue
Block a user