[tf.data] Default autotuning to conservative values to avoid accidentally allocating too much memory before the optimization loop picks values that respect the memory budget.
PiperOrigin-RevId: 355946537
Change-Id: I2e581f42fa5c016b0240fd69f0b16f08fc2fdbfd
parent 385019cd24
commit 911d9336e2
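For orientation, the sketch below shows the user-facing pattern this change affects: a tf.data pipeline that requests AUTOTUNE parallelism. It is an illustrative example only, not code from this commit; the dataset contents and map function are made up, and `tf.data.AUTOTUNE` is assumed to be available (older releases expose it as `tf.data.experimental.AUTOTUNE`).

import tensorflow as tf

# Illustrative pipeline. With num_parallel_calls=tf.data.AUTOTUNE and
# autotuning active, parallelism now starts at 1 and is grown by the
# optimization loop within the memory budget, instead of starting at the
# size of the runner threadpool.
dataset = tf.data.Dataset.range(1000)
dataset = dataset.map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(32)

for batch in dataset.take(2):
  print(batch.numpy())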
@@ -216,7 +216,20 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
   Status Initialize(IteratorContext* ctx) override {
     mutex_lock l(*mu_);
     if (num_parallel_calls_->value == model::kAutotune) {
-      num_parallel_calls_->value = ctx->runner_threadpool_size();
+      // If autotuning is enabled, we initialize the parallelism to 1 to
+      // avoid accidentally running the machine out of memory before the
+      // optimization can pick values that respect the memory budget.
+      //
+      // If autotuning is disabled but the transformation uses `AUTOTUNE`, we
+      // default the parallelism to the size of the threadpool used for
+      // executing the user-defined computation. If this causes OOM, the
+      // input pipeline should either enable autotuning, or replace
+      // `AUTOTUNE` with fixed parallelism.
+      if (TF_PREDICT_TRUE(ctx->model())) {
+        num_parallel_calls_->value = 1;
+      } else {
+        num_parallel_calls_->value = ctx->runner_threadpool_size();
+      }
     }
     TF_RETURN_IF_ERROR(RegisterCancellationCallback(
         ctx->cancellation_manager(),
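The decision above can be restated compactly; the following is a hedged Python paraphrase of the C++ branch, not TensorFlow API. `requested`, `autotuning_enabled`, and `threadpool_size` are hypothetical stand-ins for `num_parallel_calls_->value`, `ctx->model()`, and `ctx->runner_threadpool_size()`.

AUTOTUNE = -1  # sentinel playing the role of model::kAutotune in this sketch


def initial_parallelism(requested, autotuning_enabled, threadpool_size):
  # Fixed parallelism requested by the user is left untouched.
  if requested != AUTOTUNE:
    return requested
  # Autotuning active: start at 1 and let the optimization loop grow it
  # within the memory budget.
  if autotuning_enabled:
    return 1
  # Autotuning disabled but AUTOTUNE requested: fall back to the threadpool size.
  return threadpool_size


assert initial_parallelism(4, True, 32) == 4
assert initial_parallelism(AUTOTUNE, True, 32) == 1
assert initial_parallelism(AUTOTUNE, False, 32) == 32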
@@ -221,7 +221,20 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
   Status Initialize(IteratorContext* ctx) override {
     mutex_lock l(*mu_);
     if (num_parallel_calls_->value == model::kAutotune) {
-      num_parallel_calls_->value = ctx->runner_threadpool_size();
+      // If autotuning is enabled, we initialize the parallelism to 1 to
+      // avoid accidentally running the machine out of memory before the
+      // optimization can pick values that respect the memory budget.
+      //
+      // If autotuning is disabled but the transformation uses `AUTOTUNE`, we
+      // default the parallelism to the size of the threadpool used for
+      // executing the user-defined computation. If this causes OOM, the
+      // input pipeline should either enable autotuning, or replace
+      // `AUTOTUNE` with fixed parallelism.
+      if (TF_PREDICT_TRUE(ctx->model())) {
+        num_parallel_calls_->value = 1;
+      } else {
+        num_parallel_calls_->value = ctx->runner_threadpool_size();
+      }
     }
     cancellation_manager_ =
         absl::make_unique<CancellationManager>(ctx->cancellation_manager());
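The second half of the new comment (the fallback when autotuning is disabled) corresponds to a pipeline configured roughly as below. This is a hedged sketch: the option attribute for turning autotuning off is assumed to be `experimental_optimization.autotune` and may differ across TensorFlow versions.

import tensorflow as tf

# If autotuning is disabled but a map still requests AUTOTUNE, the kernel
# falls back to the runner-threadpool size, which can OOM for memory-heavy
# maps; prefer a fixed num_parallel_calls in that configuration.
options = tf.data.Options()
options.experimental_optimization.autotune = False  # assumed attribute name

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: x + 1, num_parallel_calls=4)  # fixed parallelism
dataset = dataset.with_options(options)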
@@ -564,22 +564,19 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     # Tests that vectorization maintains the determinism setting.
     expect_determinism = local_determinism or (local_determinism is None and
                                                global_determinism)
-    elements = list(range(1000))
+    num_elements = 1000
 
     def dataset_fn(delay_ms):
 
       def sleep(x):
-        time.sleep(delay_ms / 1000)
+        # Inject random delay in the interval [0, delay_ms / 1000).
+        time.sleep(delay_ms * (np.random.randint(x + 1) / (x + 1)) / 1000)
         return x
 
       def map_function(x):
-        if math_ops.equal(x, 0):
-          return check_ops.ensure_shape(
-              script_ops.py_func(sleep, [x], x.dtype, stateful=False), ())
-        else:
-          return x
+        return check_ops.ensure_shape(
+            script_ops.py_func(sleep, [x], x.dtype, stateful=False), ())
 
-      dataset = dataset_ops.Dataset.from_tensor_slices(elements)
+      dataset = dataset_ops.Dataset.range(num_elements)
       dataset = dataset.map(
           map_function, num_parallel_calls=10, deterministic=local_determinism)
       dataset = dataset.batch(1)
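To see why the per-element random delay makes reordering observable at any position, here is a hedged standalone sketch of the delay computation added above; `delay_seconds` is a hypothetical helper, not part of the test.

import numpy as np


def delay_seconds(x, delay_ms):
  # np.random.randint(x + 1) is uniform over {0, ..., x}, so the scaling
  # factor lies in [0, 1) and the delay lies in [0, delay_ms / 1000) seconds.
  return delay_ms * (np.random.randint(x + 1) / (x + 1)) / 1000


np.random.seed(0)
samples = [delay_seconds(x, delay_ms=100) for x in range(1, 6)]
assert all(0 <= s < 0.1 for s in samples)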
@@ -595,7 +592,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.checkDeterminism(
         dataset_fn,
         expect_determinism,
-        expected_elements=[[element] for element in elements])
+        expected_elements=[[element] for element in range(num_elements)])
 
   @combinations.generate(test_base.default_test_combinations())
   def testOptimizationIgnoreStateful(self):
@@ -340,6 +340,7 @@ class DatasetTestBase(test.TestCase):
       dataset = dataset_fn(delay_ms)
       actual = self.getDatasetOutput(dataset)
       self.assertCountEqual(expected_elements, actual)
-      if actual[0] != expected_elements[0]:
-        return
+      for i in range(len(actual)):
+        if actual[i] != expected_elements[i]:
+          return
     self.fail("Failed to observe nondeterministic ordering")
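The loop above treats the ordering as nondeterministic as soon as any position differs from the expected (deterministic) order; only a fully in-order result falls through to the retry-and-fail path. A hedged restatement, with `observed_nondeterminism` as a hypothetical helper:

def observed_nondeterminism(actual, expected):
  # True as soon as any position differs from the expected in-order result.
  return any(a != e for a, e in zip(actual, expected))


assert observed_nondeterminism([2, 1, 3], [1, 2, 3])
assert not observed_nondeterminism([1, 2, 3], [1, 2, 3])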