[tf.data] Start to roll out the optimization enable_gradient_descent
. Also add the Python test.
PiperOrigin-RevId: 335747562 Change-Id: I9ce62981d540131f9a6297f61eaea1b6e0def492
This commit is contained in:
parent
84db924eae
commit
bd444b572f
@ -83,7 +83,9 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
// The map that stores the live experiment names and for how much percentage
|
||||
// of the Borg jobs, the experiments will be randomly turned on.
|
||||
// clang-format off
|
||||
absl::flat_hash_map<string, uint64> live_experiments;
|
||||
absl::flat_hash_map<string, uint64> live_experiments = {
|
||||
{"enable_gradient_descent", 1}
|
||||
};
|
||||
// clang-format on
|
||||
auto hash_func = [](const string& str) { return Hash64(str); };
|
||||
optimizations = SelectOptimizations(
|
||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from absl.testing import parameterized
|
||||
@ -224,6 +225,38 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
self.assertDatasetProduces(dataset, expected_output=expected_output)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(autotune=False, autotune_buffers=False) +
|
||||
combinations.combine(autotune=True, autotune_buffers=False) +
|
||||
combinations.combine(autotune=True, autotune_buffers=True),
|
||||
combinations.combine(set_env=[False, True])))
|
||||
def testOptimizationEnableGradientDescent(self, autotune, autotune_buffers,
|
||||
set_env):
|
||||
if set_env:
|
||||
os.environ["TF_DATA_EXPERIMENT_OPT_IN"] = "enable_gradient_descent"
|
||||
os.environ["TF_JOB_NAME"] = "test_job"
|
||||
|
||||
dataset = dataset_ops.Dataset.range(5)
|
||||
dataset = dataset.prefetch(buffer_size=-1)
|
||||
dataset = dataset.map(lambda x: x + 1, num_parallel_calls=2)
|
||||
dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1)
|
||||
dataset = dataset.prefetch(buffer_size=3)
|
||||
dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1)
|
||||
dataset = dataset.prefetch(buffer_size=1)
|
||||
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_optimization.autotune = autotune
|
||||
options.experimental_optimization.autotune_buffers = autotune_buffers
|
||||
dataset = dataset.with_options(options)
|
||||
|
||||
self.assertDatasetProduces(dataset, expected_output=list(range(3, 8)))
|
||||
|
||||
if set_env:
|
||||
del os.environ["TF_DATA_EXPERIMENT_OPT_IN"]
|
||||
del os.environ["TF_JOB_NAME"]
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
|
Loading…
Reference in New Issue
Block a user