[tf.data] Start to roll out the optimization enable_gradient_descent. Also add the Python test.

PiperOrigin-RevId: 335747562
Change-Id: I9ce62981d540131f9a6297f61eaea1b6e0def492
This commit is contained in:
Jay Shi 2020-10-06 16:38:30 -07:00 committed by TensorFlower Gardener
parent 84db924eae
commit bd444b572f
2 changed files with 36 additions and 1 deletions

View File

@ -83,7 +83,9 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
// The map that stores the live experiment names and for how much percentage
// of the Borg jobs, the experiments will be randomly turned on.
// clang-format off
absl::flat_hash_map<string, uint64> live_experiments;
absl::flat_hash_map<string, uint64> live_experiments = {
{"enable_gradient_descent", 1}
};
// clang-format on
auto hash_func = [](const string& str) { return Hash64(str); };
optimizations = SelectOptimizations(

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import functools
import os
import warnings
from absl.testing import parameterized
@ -224,6 +225,38 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(autotune=False, autotune_buffers=False) +
combinations.combine(autotune=True, autotune_buffers=False) +
combinations.combine(autotune=True, autotune_buffers=True),
combinations.combine(set_env=[False, True])))
def testOptimizationEnableGradientDescent(self, autotune, autotune_buffers,
set_env):
if set_env:
os.environ["TF_DATA_EXPERIMENT_OPT_IN"] = "enable_gradient_descent"
os.environ["TF_JOB_NAME"] = "test_job"
dataset = dataset_ops.Dataset.range(5)
dataset = dataset.prefetch(buffer_size=-1)
dataset = dataset.map(lambda x: x + 1, num_parallel_calls=2)
dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1)
dataset = dataset.prefetch(buffer_size=3)
dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1)
dataset = dataset.prefetch(buffer_size=1)
options = dataset_ops.Options()
options.experimental_optimization.autotune = autotune
options.experimental_optimization.autotune_buffers = autotune_buffers
dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, expected_output=list(range(3, 8)))
if set_env:
del os.environ["TF_DATA_EXPERIMENT_OPT_IN"]
del os.environ["TF_JOB_NAME"]
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),