diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index a2e0a1e7626..81fcf6d7b29 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -83,7 +83,9 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, // The map that stores the live experiment names and for how much percentage // of the Borg jobs, the experiments will be randomly turned on. // clang-format off - absl::flat_hash_map live_experiments; + absl::flat_hash_map live_experiments = { + {"enable_gradient_descent", 1} + }; // clang-format on auto hash_func = [](const string& str) { return Hash64(str); }; optimizations = SelectOptimizations( diff --git a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py index ec212a7d46e..14a0eafdd01 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import functools +import os import warnings from absl.testing import parameterized @@ -224,6 +225,38 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces(dataset, expected_output=expected_output) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(autotune=False, autotune_buffers=False) + + combinations.combine(autotune=True, autotune_buffers=False) + + combinations.combine(autotune=True, autotune_buffers=True), + combinations.combine(set_env=[False, True]))) + def testOptimizationEnableGradientDescent(self, autotune, autotune_buffers, + set_env): + if set_env: + os.environ["TF_DATA_EXPERIMENT_OPT_IN"] = "enable_gradient_descent" + os.environ["TF_JOB_NAME"] = "test_job" + + dataset = dataset_ops.Dataset.range(5) + dataset = dataset.prefetch(buffer_size=-1) + dataset = dataset.map(lambda x: x + 1, num_parallel_calls=2) + dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1) + dataset = dataset.prefetch(buffer_size=3) + dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1) + dataset = dataset.prefetch(buffer_size=1) + + options = dataset_ops.Options() + options.experimental_optimization.autotune = autotune + options.experimental_optimization.autotune_buffers = autotune_buffers + dataset = dataset.with_options(options) + + self.assertDatasetProduces(dataset, expected_output=list(range(3, 8))) + + if set_env: + del os.environ["TF_DATA_EXPERIMENT_OPT_IN"] + del os.environ["TF_JOB_NAME"] + @combinations.generate( combinations.times( test_base.default_test_combinations(),