[tf.data] Default autotuning to conservative values to avoid accidentally allocating too much memory before the optimization loop picks values that respect the memory budget.

PiperOrigin-RevId: 355946537
Change-Id: I2e581f42fa5c016b0240fd69f0b16f08fc2fdbfd
Jiri Simsa, 2021-02-05 15:56:03 -08:00 (committed by TensorFlower Gardener)
parent 385019cd24
commit 911d9336e2
4 changed files with 38 additions and 14 deletions
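
For context, this is the kind of user pipeline the change affects: a minimal sketch, assuming the TF 2.4-era Python API, where `tf.data.AUTOTUNE` marks parallelism for autotuning.

    import tensorflow as tf

    dataset = tf.data.Dataset.range(1000)
    # With this change, AUTOTUNE-managed parallelism starts at 1 and is ramped
    # up by the autotuning optimization loop, instead of starting at the full
    # runner-threadpool size.
    dataset = dataset.map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(32)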

@@ -216,7 +216,20 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(*mu_);
       if (num_parallel_calls_->value == model::kAutotune) {
-        num_parallel_calls_->value = ctx->runner_threadpool_size();
+        // If autotuning is enabled, we initialize the parallelism to 1 to
+        // avoid accidentally running the machine out of memory before the
+        // optimization can pick values that respect the memory budget.
+        //
+        // If autotuning is disabled but the transformation uses `AUTOTUNE`, we
+        // default the parallelism to the size of the threadpool used for
+        // executing the user-defined computation. If this causes OOM, the
+        // input pipeline should either enable autotuning, or replace
+        // `AUTOTUNE` with fixed parallelism.
+        if (TF_PREDICT_TRUE(ctx->model())) {
+          num_parallel_calls_->value = 1;
+        } else {
+          num_parallel_calls_->value = ctx->runner_threadpool_size();
+        }
       }
       TF_RETURN_IF_ERROR(RegisterCancellationCallback(
           ctx->cancellation_manager(),
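
The new Initialize logic reduces to a two-way decision. A minimal Python sketch, where `model` and `runner_threadpool_size` are hypothetical stand-ins for the IteratorContext fields; `TF_PREDICT_TRUE` is only a branch-prediction hint and does not affect the logic:

    def initial_parallelism(model, runner_threadpool_size):
      # Autotuning enabled: a model is attached to the iterator context, so
      # start at 1 and let the optimization loop grow parallelism within the
      # memory budget.
      if model is not None:
        return 1
      # Autotuning disabled but AUTOTUNE requested: fall back to the size of
      # the threadpool that executes the user-defined computation.
      return runner_threadpool_size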

@@ -221,7 +221,20 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(*mu_);
       if (num_parallel_calls_->value == model::kAutotune) {
-        num_parallel_calls_->value = ctx->runner_threadpool_size();
+        // If autotuning is enabled, we initialize the parallelism to 1 to
+        // avoid accidentally running the machine out of memory before the
+        // optimization can pick values that respect the memory budget.
+        //
+        // If autotuning is disabled but the transformation uses `AUTOTUNE`, we
+        // default the parallelism to the size of the threadpool used for
+        // executing the user-defined computation. If this causes OOM, the
+        // input pipeline should either enable autotuning, or replace
+        // `AUTOTUNE` with fixed parallelism.
+        if (TF_PREDICT_TRUE(ctx->model())) {
+          num_parallel_calls_->value = 1;
+        } else {
+          num_parallel_calls_->value = ctx->runner_threadpool_size();
+        }
       }
       cancellation_manager_ =
           absl::make_unique<CancellationManager>(ctx->cancellation_manager());
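
The comment's escape hatch (disabling autotuning, so that `AUTOTUNE` falls back to the threadpool size) would look roughly like this in user code; a sketch assuming the TF 2.4-era options API:

    import tensorflow as tf

    options = tf.data.Options()
    # Turn the autotuning optimization loop off; AUTOTUNE then defaults to
    # the runner-threadpool size, as in the else-branch above.
    options.experimental_optimization.autotune = False

    dataset = tf.data.Dataset.range(100).map(
        lambda x: x + 1, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.with_options(options)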

@@ -564,22 +564,19 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     # Tests that vectorization maintains the determinism setting.
     expect_determinism = local_determinism or (local_determinism is None and
                                                global_determinism)
-    elements = list(range(1000))
+    num_elements = 1000

     def dataset_fn(delay_ms):

       def sleep(x):
-        time.sleep(delay_ms / 1000)
+        # Inject random delay in the interval [0, delay_ms / 1000).
+        time.sleep(delay_ms * (np.random.randint(x + 1) / (x + 1)) / 1000)
         return x

       def map_function(x):
-        if math_ops.equal(x, 0):
-          return check_ops.ensure_shape(
-              script_ops.py_func(sleep, [x], x.dtype, stateful=False), ())
-        else:
-          return x
+        return check_ops.ensure_shape(
+            script_ops.py_func(sleep, [x], x.dtype, stateful=False), ())

-      dataset = dataset_ops.Dataset.from_tensor_slices(elements)
+      dataset = dataset_ops.Dataset.range(num_elements)
       dataset = dataset.map(
           map_function, num_parallel_calls=10, deterministic=local_determinism)
       dataset = dataset.batch(1)
@@ -595,7 +592,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.checkDeterminism(
         dataset_fn,
         expect_determinism,
-        expected_elements=[[element] for element in elements])
+        expected_elements=[[element] for element in range(num_elements)])

   @combinations.generate(test_base.default_test_combinations())
   def testOptimizationIgnoreStateful(self):
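
The test change replaces a fixed sleep on element 0 with a random per-element delay, which makes out-of-order completion far more likely under parallel execution. A standalone restatement of the delay expression above:

    import numpy as np

    def delay_seconds(x, delay_ms):
      # np.random.randint(x + 1) draws uniformly from {0, 1, ..., x}, so the
      # delay is uniform over [0, delay_ms / 1000) seconds; element 0 always
      # gets zero delay.
      return delay_ms * (np.random.randint(x + 1) / (x + 1)) / 1000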

@@ -340,6 +340,7 @@ class DatasetTestBase(test.TestCase):
       dataset = dataset_fn(delay_ms)
       actual = self.getDatasetOutput(dataset)
       self.assertCountEqual(expected_elements, actual)
-      if actual[0] != expected_elements[0]:
-        return
+      for i in range(len(actual)):
+        if actual[i] != expected_elements[i]:
+          return
     self.fail("Failed to observe nondeterministic ordering")
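
The old check declared nondeterminism only when the first element moved; the new loop returns as soon as any position differs from the expected order, so the test now fails only if every run reproduces the expected order exactly. A hypothetical illustration:

    expected = [0, 1, 2, 3]
    actual = [0, 2, 1, 3]  # only positions 1 and 2 were reordered
    # Old check: actual[0] == expected[0], so this reordering went unnoticed.
    # New check: actual[1] != expected[1], so the run counts as
    # nondeterministic and the loop returns.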