Update thread utilization async in map_and_batch dataset.

PiperOrigin-RevId: 291070968
Change-Id: Ibc416c6a60fdbf6c72e5beeb09a79966e1e07382
Anna R 2020-01-22 17:54:39 -08:00 committed by TensorFlower Gardener
parent 3ab94ef0f3
commit 8c6713c89c
4 changed files with 12 additions and 55 deletions
tensorflow/core/kernels/data/experimental
tensorflow/python/data/experimental


@@ -71,9 +71,6 @@ constexpr char kStatus[] = "status";
 constexpr char kCode[] = "code";
 constexpr char kMessage[] = "msg";
-// Period between reporting dataset statistics.
-constexpr int kStatsReportingPeriodMillis = 1000;
 class MapAndBatchDatasetOp::Dataset : public DatasetBase {
  public:
   Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
@@ -221,7 +218,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
       std::shared_ptr<BatchResult> result;
       {
         mutex_lock l(*mu_);
-        EnsureThreadsStarted(ctx);
+        EnsureRunnerThreadStarted(ctx);
         while (!cancelled_ && (batch_results_.empty() ||
                                batch_results_.front()->num_calls > 0)) {
           ++waiting_;
@@ -450,18 +447,13 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
       return Status::OK();
     }
-    void EnsureThreadsStarted(IteratorContext* ctx)
+    void EnsureRunnerThreadStarted(IteratorContext* ctx)
         EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
         auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
         runner_thread_ = ctx->StartThread(
             kTFDataMapAndBatch,
             std::bind(&Iterator::RunnerThread, this, ctx_copy));
-        if (ctx->stats_aggregator()) {
-          stats_thread_ = ctx->StartThread(
-              "tf_data_map_and_batch_stats",
-              std::bind(&Iterator::StatsThread, this, ctx_copy));
-        }
       }
     }
@@ -588,6 +580,15 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
           num_calls_++;
         }
       }
+      const auto& stats_aggregator = ctx->stats_aggregator();
+      if (stats_aggregator) {
+        mutex_lock l(*mu_);
+        stats_aggregator->AddScalar(
+            stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
+            static_cast<float>(num_calls_) /
+                static_cast<float>(num_parallel_calls_->value),
+            num_elements());
+      }
       for (const auto& call : new_calls) {
         CallFunction(ctx, call.first, call.second);
       }
@@ -595,34 +596,6 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
       }
     }
-    void StatsThread(const std::shared_ptr<IteratorContext>& ctx) {
-      for (int64 step = 0;; ++step) {
-        int num_calls;
-        int num_parallel_calls;
-        {
-          mutex_lock l(*mu_);
-          if (step != 0 && !cancelled_) {
-            cond_var_->wait_for(
-                l, std::chrono::milliseconds(kStatsReportingPeriodMillis));
-          }
-          if (cancelled_) {
-            return;
-          }
-          num_calls = num_calls_;
-          num_parallel_calls = num_parallel_calls_->value;
-        }
-        if (num_parallel_calls == 0) {
-          // Avoid division by zero.
-          num_parallel_calls = 1;
-        }
-        ctx->stats_aggregator()->AddScalar(
-            stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
-            static_cast<float>(num_calls) /
-                static_cast<float>(num_parallel_calls),
-            step);
-      }
-    }
     Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
                            size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       batch_results_.push_back(
@@ -762,7 +735,6 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
     std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
     // Background thread used for coordinating input processing.
     std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
-    std::unique_ptr<Thread> stats_thread_ GUARDED_BY(*mu_);
     // Determines whether the transformation has been cancelled.
     bool cancelled_ GUARDED_BY(*mu_) = false;
     // Identifies the number of callers currently waiting for a batch result.
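For context, the scalar recorded by the kernel above is only collected when a stats aggregator is attached to the input pipeline. A minimal sketch of that wiring, mirroring the imports and options used in the benchmark removed below (internal module paths from this era of TensorFlow; the concrete dataset and parameter values are illustrative):

    from tensorflow.python.data.experimental.ops import batching
    from tensorflow.python.data.experimental.ops import stats_aggregator
    from tensorflow.python.data.ops import dataset_ops

    # Build a map_and_batch pipeline whose runner thread reports utilization.
    dataset = dataset_ops.Dataset.range(1000).repeat()
    dataset = dataset.apply(
        batching.map_and_batch(lambda x: x + 1, batch_size=32, num_parallel_calls=8))

    # Attach an aggregator so the kernel's AddScalar calls have a destination.
    aggregator = stats_aggregator.StatsAggregator()
    options = dataset_ops.Options()
    options.experimental_stats.aggregator = aggregator
    dataset = dataset.with_options(options)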


@@ -80,7 +80,6 @@ tf_py_test(
         "//tensorflow/python:math_ops",
         "//tensorflow/python:random_ops",
         "//tensorflow/python:session",
-        "//tensorflow/python/data/benchmarks:benchmark_base",
         "//tensorflow/python/data/experimental/ops:batching",
         "//tensorflow/python/data/ops:dataset_ops",
         "//third_party/py/numpy",


@@ -25,9 +25,7 @@ import numpy as np
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
-from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.experimental.ops import batching
-from tensorflow.python.data.experimental.ops import stats_aggregator
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -39,7 +37,7 @@ from tensorflow.python.platform import test
 _NUMPY_RANDOM_SEED = 42
-class MapAndBatchBenchmark(benchmark_base.DatasetBenchmarkBase):
+class MapAndBatchBenchmark(test.Benchmark):
   """Benchmarks for `tf.data.experimental.map_and_batch()`."""
   def benchmark_map_and_batch(self):
@@ -202,16 +200,6 @@ class MapAndBatchBenchmark(benchmark_base.DatasetBenchmarkBase):
     benchmark("Transformation parallelism evaluation", par_num_calls_series)
     benchmark("Threadpool size evaluation", par_inter_op_series)
-  def benchmark_stats(self):
-    dataset = dataset_ops.Dataset.range(1).repeat()
-    dataset = dataset.apply(
-        batching.map_and_batch(lambda x: x + 1, 1, num_parallel_calls=32))
-    aggregator = stats_aggregator.StatsAggregator()
-    options = dataset_ops.Options()
-    options.experimental_stats.aggregator = aggregator
-    dataset = dataset.with_options(options)
-    self.run_and_report_benchmark(dataset, num_elements=1000, name="stats")
 if __name__ == "__main__":
   test.main()
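With the benchmark class reverted to test.Benchmark and run_and_report_benchmark removed, benchmarks in this file time iteration themselves and call report_benchmark directly. A hypothetical minimal example of that pattern (the class name, numbers, and eager-mode iteration are illustrative, not part of this change):

    import time

    from tensorflow.python.data.experimental.ops import batching
    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.platform import test


    class MapAndBatchSketchBenchmark(test.Benchmark):
      """Illustrative only; not part of the change above."""

      def benchmark_map_and_batch_sketch(self):
        dataset = dataset_ops.Dataset.range(10000)
        dataset = dataset.apply(
            batching.map_and_batch(lambda x: x + 1, batch_size=100))
        start = time.time()
        for _ in dataset:  # Requires eager execution (default in TF 2.x).
          pass
        self.report_benchmark(
            iters=1, wall_time=time.time() - start, name="map_and_batch_sketch")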


@@ -380,8 +380,6 @@ class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase,
   @combinations.generate(test_base.eager_only_combinations())
   def testMapAndBatchAutoTuneBufferUtilization(self):
-    self.skipTest("b/147897892: This test is flaky because thread utilization "
-                  "is recorded asynchronously")
     def dataset_fn():
       return dataset_ops.Dataset.range(100).apply(