[tf.data] Apply the gradient descent method as the default algorithm for autotuning optimization.

PiperOrigin-RevId: 341499875
Change-Id: Ie2eab5ed5e85e0c9afac1fb5b612057e51bd0e12
parent db1293e895
commit b3d45cd17c
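In user-facing terms, input pipelines that enable buffer autotuning are now tuned with gradient descent rather than hill climbing. A minimal sketch of how a pipeline opts into buffer autotuning through the public API (option names match those exercised in the test below; the snippet is illustrative, not part of this change):

import tensorflow as tf

dataset = tf.data.Dataset.range(1000).map(
    lambda x: x * 2, num_parallel_calls=tf.data.experimental.AUTOTUNE)

options = tf.data.Options()
options.experimental_optimization.autotune = True
# As of this revision, autotune_buffers=True selects the gradient descent
# algorithm; otherwise autotuning keeps using hill climbing.
options.experimental_optimization.autotune_buffers = True
dataset = dataset.with_options(options)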
@@ -84,6 +84,7 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
     // of the Borg jobs, the experiments will be randomly turned on.
     // clang-format off
     absl::flat_hash_map<string, uint64> live_experiments = {
+        {"enable_gradient_descent", 100},
         {"map_parallelization", 20}
     };
     // clang-format on
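The integer in live_experiments is a rollout percentage: enable_gradient_descent is turned on for 100% of eligible Borg jobs, while map_parallelization remains at a 20% random rollout. A minimal sketch of percentage-based selection, written in Python for brevity; the hashing scheme is an assumption, not the kernel's actual implementation:

import hashlib

def experiment_enabled(experiment: str, job_name: str, rollout_pct: int) -> bool:
    # Deterministically bucket the (experiment, job) pair into [0, 100) and
    # enable the experiment when the bucket falls under the rollout
    # percentage. Illustrative only; TF's real code hashes differently.
    digest = hashlib.sha256(f"{experiment}:{job_name}".encode()).hexdigest()
    return int(digest, 16) % 100 < rollout_pct

live_experiments = {"enable_gradient_descent": 100, "map_parallelization": 20}
for name, pct in live_experiments.items():
    print(name, experiment_enabled(name, "my_training_job", pct))
# enable_gradient_descent is always on (any bucket is < 100);
# map_parallelization is on for roughly 20% of job names.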
@@ -110,9 +111,6 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
 
     // The vector stores the graduated experiment names which will be turned on
     // for all input pipelines.
-    //
-    // Note some of the graduated experiments may be hard coded, so not listed
-    // below.
     // clang-format off
     std::vector<string> graduated_experiments = {"disable_intra_op_parallelism"};
     // clang-format on
@@ -245,6 +245,38 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     self.assertDatasetProduces(dataset, expected_output=expected_output)
 
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(autotune=False, autotune_buffers=False) +
+          combinations.combine(autotune=True, autotune_buffers=False) +
+          combinations.combine(autotune=True, autotune_buffers=True),
+          combinations.combine(set_env=[False, True])))
+  def testOptimizationEnableGradientDescent(self, autotune, autotune_buffers,
+                                            set_env):
+    if set_env:
+      os.environ["TF_DATA_EXPERIMENT_OPT_IN"] = "enable_gradient_descent"
+      os.environ["TF_JOB_NAME"] = "test_job"
+
+    dataset = dataset_ops.Dataset.range(5)
+    dataset = dataset.prefetch(buffer_size=-1)
+    dataset = dataset.map(lambda x: x + 1, num_parallel_calls=2)
+    dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1)
+    dataset = dataset.prefetch(buffer_size=3)
+    dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1)
+    dataset = dataset.prefetch(buffer_size=1)
+
+    options = dataset_ops.Options()
+    options.experimental_optimization.autotune = autotune
+    options.experimental_optimization.autotune_buffers = autotune_buffers
+    dataset = dataset.with_options(options)
+
+    self.assertDatasetProduces(dataset, expected_output=list(range(3, 8)))
+
+    if set_env:
+      del os.environ["TF_DATA_EXPERIMENT_OPT_IN"]
+      del os.environ["TF_JOB_NAME"]
+
   @combinations.generate(
       combinations.times(
           test_base.default_test_combinations(),
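The new test covers both rollout paths: when set_env is true, the TF_DATA_EXPERIMENT_OPT_IN and TF_JOB_NAME environment variables force the experiment on regardless of the percentage above, and the pipeline must produce range(3, 8) either way. A hedged standalone version of the same opt-in using the public API (the environment variables come from the test; their behavior outside the test harness is assumed):

import os
import tensorflow as tf

# Mirror the test's set_env branch: opt this job into the experiment.
os.environ["TF_DATA_EXPERIMENT_OPT_IN"] = "enable_gradient_descent"
os.environ["TF_JOB_NAME"] = "my_training_job"

dataset = tf.data.Dataset.range(5)
dataset = dataset.map(lambda x: x + 1,
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
print([int(x) for x in dataset])  # [1, 2, 3, 4, 5]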
@@ -543,16 +575,14 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     if autotune_buffers is True:  # pylint: disable=g-bool-id-comparison
       self.assertIn("autotune_buffer_sizes", graph_rewrites.enabled)
       self.assertIn("disable_prefetch_legacy_autotune", graph_rewrites.enabled)
+      self.assertEqual(algorithm,
+                       optimization_options._AutotuneAlgorithm.GRADIENT_DESCENT)
     else:
       self.assertNotIn("autotune_buffer_sizes", graph_rewrites.enabled)
       self.assertNotIn("disable_prefetch_legacy_autotune",
                        graph_rewrites.enabled)
-    if autotune_buffers is False:  # pylint: disable=g-bool-id-comparison
       self.assertEqual(algorithm,
                        optimization_options._AutotuneAlgorithm.HILL_CLIMB)
-    else:
-      self.assertEqual(algorithm,
-                       optimization_options._AutotuneAlgorithm.GRADIENT_DESCENT)
 
   @combinations.generate(
       combinations.times(
@@ -228,8 +228,8 @@ class OptimizationOptions(options.OptionsBase):
     # If autotune_buffers is enabled, we use the GRADIENT_DESCENT algorithm by
     # default, which is more performant for tuning heterogeneous parameters.
     algorithm = (
-        _AutotuneAlgorithm.HILL_CLIMB if self.autotune_buffers is False  # pylint: disable=g-bool-id-comparison
-        else _AutotuneAlgorithm.GRADIENT_DESCENT)
+        _AutotuneAlgorithm.GRADIENT_DESCENT
+        if self._autotune_buffers() else _AutotuneAlgorithm.HILL_CLIMB)
     cpu_budget = 0  # Indicates that all CPU cores should be used by default.
     ram_budget = 0  # Indicates that default value of RAM budget should be used.
 
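After this change, autotune_buffers is the single switch that picks the algorithm: buffer autotuning implies gradient descent, everything else stays on hill climbing. A condensed sketch of that decision with a stand-in enum mirroring _AutotuneAlgorithm (the free function is illustrative; the real logic lives inside OptimizationOptions):

import enum

class AutotuneAlgorithm(enum.Enum):
    HILL_CLIMB = 0
    GRADIENT_DESCENT = 1

def select_algorithm(autotune_buffers: bool) -> AutotuneAlgorithm:
    # Mirrors the rewritten ternary: gradient descent is used whenever
    # buffers are autotuned, since it reportedly tunes heterogeneous
    # parameters (parallelism plus buffer sizes) better than hill climbing.
    return (AutotuneAlgorithm.GRADIENT_DESCENT
            if autotune_buffers else AutotuneAlgorithm.HILL_CLIMB)

assert select_algorithm(True) is AutotuneAlgorithm.GRADIENT_DESCENT
assert select_algorithm(False) is AutotuneAlgorithm.HILL_CLIMB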