[tf.data] Default autotuned parallelism to conservative values to avoid accidentally allocating too much memory before the optimization loop picks values that respect the memory budget.

PiperOrigin-RevId: 355946537
Change-Id: I2e581f42fa5c016b0240fd69f0b16f08fc2fdbfd
Jiri Simsa 2021-02-05 15:56:03 -08:00 committed by TensorFlower Gardener
parent 385019cd24
commit 911d9336e2
4 changed files with 38 additions and 14 deletions
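For context, a minimal user-level sketch (not part of this change) of where the new code path is exercised, assuming the TensorFlow 2.x Python API; the attribute used to disable autotuning has moved between releases, so treat the exact option name as an assumption:

    import tensorflow as tf

    # A map transformation that requests AUTOTUNE for its parallelism. With
    # autotuning enabled (the default), the new code starts the parallelism at
    # 1 and lets the optimization loop grow it within the memory budget.
    dataset = tf.data.Dataset.range(1000)
    dataset = dataset.map(lambda x: x * 2,
                          num_parallel_calls=tf.data.AUTOTUNE)

    # If autotuning is disabled but AUTOTUNE is still requested, the
    # parallelism falls back to the runner threadpool size; if that OOMs,
    # either re-enable autotuning or pass a fixed num_parallel_calls.
    options = tf.data.Options()
    options.experimental_optimization.autotune = False  # name may vary by TF version
    dataset = dataset.with_options(options)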


@@ -216,7 +216,20 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(*mu_);
       if (num_parallel_calls_->value == model::kAutotune) {
-        num_parallel_calls_->value = ctx->runner_threadpool_size();
+        // If autotuning is enabled, we initialize the parallelism to 1 to
+        // avoid accidentally running the machine out of memory before the
+        // optimization can pick values that respect the memory budget.
+        //
+        // If autotuning is disabled but the transformation uses `AUTOTUNE`, we
+        // default the parallelism to the size of the threadpool used for
+        // executing the user-defined computation. If this causes OOM, the
+        // input pipeline should either enable autotuning, or replace
+        // `AUTOTUNE` with fixed parallelism.
+        if (TF_PREDICT_TRUE(ctx->model())) {
+          num_parallel_calls_->value = 1;
+        } else {
+          num_parallel_calls_->value = ctx->runner_threadpool_size();
+        }
       }
       TF_RETURN_IF_ERROR(RegisterCancellationCallback(
           ctx->cancellation_manager(),


@@ -221,7 +221,20 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(*mu_);
      if (num_parallel_calls_->value == model::kAutotune) {
-        num_parallel_calls_->value = ctx->runner_threadpool_size();
+        // If autotuning is enabled, we initialize the parallelism to 1 to
+        // avoid accidentally running the machine out of memory before the
+        // optimization can pick values that respect the memory budget.
+        //
+        // If autotuning is disabled but the transformation uses `AUTOTUNE`, we
+        // default the parallelism to the size of the threadpool used for
+        // executing the user-defined computation. If this causes OOM, the
+        // input pipeline should either enable autotuning, or replace
+        // `AUTOTUNE` with fixed parallelism.
+        if (TF_PREDICT_TRUE(ctx->model())) {
+          num_parallel_calls_->value = 1;
+        } else {
+          num_parallel_calls_->value = ctx->runner_threadpool_size();
+        }
       }
       cancellation_manager_ =
           absl::make_unique<CancellationManager>(ctx->cancellation_manager());


@@ -564,22 +564,19 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     # Tests that vectorization maintains the determinism setting.
     expect_determinism = local_determinism or (local_determinism is None and
                                                global_determinism)
-    elements = list(range(1000))
+    num_elements = 1000

     def dataset_fn(delay_ms):

       def sleep(x):
-        time.sleep(delay_ms / 1000)
+        # Inject random delay in the interval [0, delay_ms / 1000).
+        time.sleep(delay_ms * (np.random.randint(x + 1) / (x + 1)) / 1000)
         return x

       def map_function(x):
-        if math_ops.equal(x, 0):
-          return check_ops.ensure_shape(
-              script_ops.py_func(sleep, [x], x.dtype, stateful=False), ())
-        else:
-          return x
+        return check_ops.ensure_shape(
+            script_ops.py_func(sleep, [x], x.dtype, stateful=False), ())

-      dataset = dataset_ops.Dataset.from_tensor_slices(elements)
+      dataset = dataset_ops.Dataset.range(num_elements)
       dataset = dataset.map(
           map_function, num_parallel_calls=10, deterministic=local_determinism)
       dataset = dataset.batch(1)
@@ -595,7 +592,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.checkDeterminism(
         dataset_fn,
         expect_determinism,
-        expected_elements=[[element] for element in elements])
+        expected_elements=[[element] for element in range(num_elements)])

   @combinations.generate(test_base.default_test_combinations())
   def testOptimizationIgnoreStateful(self):


@@ -340,6 +340,7 @@ class DatasetTestBase(test.TestCase):
       dataset = dataset_fn(delay_ms)
       actual = self.getDatasetOutput(dataset)
       self.assertCountEqual(expected_elements, actual)
-      if actual[0] != expected_elements[0]:
-        return
+      for i in range(len(actual)):
+        if actual[i] != expected_elements[i]:
+          return
     self.fail("Failed to observe nondeterministic ordering")
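The test-base hunk above replaces the check of index 0 alone with a scan over every position, so a run only counts as nondeterministic if some element actually appears out of order. A standalone sketch of the same predicate, using a hypothetical helper name that is not part of DatasetTestBase:

    def saw_nondeterministic_order(expected, actual):
      # The output must be a permutation of the expected elements
      # (mirrors the assertCountEqual call above)...
      assert sorted(expected) == sorted(actual)
      # ...and ordering counts as nondeterministic only if some position
      # differs from the expected deterministic order.
      return any(a != e for a, e in zip(actual, expected))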