diff --git a/tensorflow/python/data/experimental/benchmarks/autotune_benchmark.py b/tensorflow/python/data/experimental/benchmarks/autotune_benchmark.py
index a6ee0d7dec7..9123aff4df9 100644
--- a/tensorflow/python/data/experimental/benchmarks/autotune_benchmark.py
+++ b/tensorflow/python/data/experimental/benchmarks/autotune_benchmark.py
@@ -23,6 +23,7 @@ import numpy as np
 
 from tensorflow.python.client import session
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
@@ -30,149 +31,116 @@ from tensorflow.python.platform import test
 class AutotuneBenchmark(test.Benchmark):
   """Benchmarks for autotuning performance knobs."""
 
-  def benchmark_map(self):
-    a = self._benchmark_map(autotune=False)
-    b = self._benchmark_map(autotune=True)
-    c = self._benchmark_map(
-        autotune=True, algorithm=dataset_ops.AutotuneAlgorithm.GRADIENT_DESCENT)
-    print("HillClimb vs Default speedup: %f" % (a / b))
-    print("GradientDescent vs Default speedup: %f" % (a / c))
-
-  def _benchmark_map(self,
-                     autotune,
-                     algorithm=dataset_ops.AutotuneAlgorithm.HILL_CLIMB):
-    k = 1024 * 1024
-    dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
-                                                np.random.rand(4 * k,
-                                                               1))).repeat()
-    dataset = dataset.map(
-        math_ops.matmul, num_parallel_calls=dataset_ops.AUTOTUNE)
+  def _run_benchmark(self, dataset, autotune, autotune_buffers,
+                     benchmark_iters, benchmark_label):
     options = dataset_ops.Options()
     options.experimental_optimization.apply_default_optimizations = False
     options.experimental_optimization.autotune = autotune
-    if autotune:
-      options.experimental_optimization.autotune_algorithm = algorithm.value
+    options.experimental_optimization.autotune_buffers = autotune_buffers
     dataset = dataset.with_options(options)
     iterator = dataset_ops.make_one_shot_iterator(dataset)
     get_next = iterator.get_next()
+    # Run the op directly to avoid copying the tensor to python.
+    get_next_op = nest.flatten(get_next)[0].op
 
     deltas = []
     with session.Session() as sess:
       for _ in range(5):
-        sess.run(get_next.op)
-      for _ in range(10000):
+        sess.run(get_next_op)
+      for _ in range(benchmark_iters):
         start = time.time()
-        sess.run(get_next.op)
+        sess.run(get_next_op)
         end = time.time()
         deltas.append(end - start)
 
+    autotune_string = "_autotune_{}".format(
+        "parallelism_and_buffer_sizes"
+        if autotune_buffers else "parallelism_only")
+
     self.report_benchmark(
-        iters=10000,
+        iters=benchmark_iters,
         wall_time=np.median(deltas),
-        name="map" + (("_autotune_%s" % algorithm.name) if autotune else ""))
+        name=benchmark_label + (autotune_string if autotune else ""))
     return np.median(deltas)
 
+  def benchmark_map(self):
+    a = self._benchmark_map(autotune=False)
+    b = self._benchmark_map(autotune=True, autotune_buffers=False)
+    c = self._benchmark_map(autotune=True, autotune_buffers=True)
+    print("autotune parallelism vs no autotuning speedup: {}".format(a / b))
+    print("autotune parallelism and buffer sizes vs no autotuning speedup: {}"
+          .format(a / c))
+
+  def _benchmark_map(self, autotune, autotune_buffers=False):
+    k = 1024 * 1024
+    dataset = dataset_ops.Dataset.from_tensors(
+        (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))).repeat()
+    dataset = dataset.map(
+        math_ops.matmul, num_parallel_calls=dataset_ops.AUTOTUNE)
+    return self._run_benchmark(
+        dataset,
+        autotune,
+        autotune_buffers,
+        benchmark_iters=10000,
+        benchmark_label="map")
+
   def benchmark_map_and_batch(self):
     a = self._benchmark_map_and_batch(autotune=False)
-    b = self._benchmark_map_and_batch(autotune=True)
-    c = self._benchmark_map_and_batch(
-        autotune=True, algorithm=dataset_ops.AutotuneAlgorithm.GRADIENT_DESCENT)
-    print("HillClimb vs Default speedup: %f" % (a / b))
-    print("GradientDescent vs Default speedup: %f" % (a / c))
+    b = self._benchmark_map_and_batch(autotune=True, autotune_buffers=False)
+    c = self._benchmark_map_and_batch(autotune=True, autotune_buffers=True)
+    print("autotune parallelism vs no autotuning speedup: {}".format(a / b))
+    print("autotune parallelism and buffer sizes vs no autotuning speedup: {}"
+          .format(a / c))
 
-  def _benchmark_map_and_batch(
-      self, autotune, algorithm=dataset_ops.AutotuneAlgorithm.HILL_CLIMB):
+  def _benchmark_map_and_batch(self, autotune, autotune_buffers=False):
     batch_size = 16
     k = 1024 * 1024
-    dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
-                                                np.random.rand(4 * k,
-                                                               1))).repeat()
+    dataset = dataset_ops.Dataset.from_tensors(
+        (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))).repeat()
     dataset = dataset.map(
         math_ops.matmul, num_parallel_calls=dataset_ops.AUTOTUNE)
     dataset = dataset.batch(batch_size=batch_size)
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    options.experimental_optimization.map_and_batch_fusion = True
-    options.experimental_optimization.autotune = autotune
-    if autotune:
-      options.experimental_optimization.autotune_algorithm = algorithm.value
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
-    get_next = iterator.get_next()
-
-    deltas = []
-    with session.Session() as sess:
-      for _ in range(5):
-        sess.run(get_next.op)
-      for _ in range(1000):
-        start = time.time()
-        sess.run(get_next.op)
-        end = time.time()
-        deltas.append(end - start)
-
-    self.report_benchmark(
-        iters=1000,
-        wall_time=np.median(deltas),
-        name="map_and_batch" +
-        (("_autotune_%s" % algorithm.name) if autotune else ""))
-    return np.median(deltas)
+    return self._run_benchmark(
+        dataset,
+        autotune,
+        autotune_buffers,
+        benchmark_iters=1000,
+        benchmark_label="map_and_batch")
 
   def benchmark_interleave(self):
     a = self._benchmark_interleave(autotune=False)
-    b = self._benchmark_interleave(autotune=True)
-    c = self._benchmark_interleave(
-        autotune=True, algorithm=dataset_ops.AutotuneAlgorithm.GRADIENT_DESCENT)
-    print("HillClimb vs Default speedup: %f" % (a / b))
-    print("GradientDescent vs Default speedup: %f" % (a / c))
+    b = self._benchmark_interleave(autotune=True, autotune_buffers=False)
+    c = self._benchmark_interleave(autotune=True, autotune_buffers=True)
+    print("autotune parallelism vs no autotuning speedup: {}".format(a / b))
+    print("autotune parallelism and buffer sizes vs no autotuning speedup: {}"
+          .format(a / c))
 
-  def _benchmark_interleave(self,
-                            autotune,
-                            algorithm=dataset_ops.AutotuneAlgorithm.HILL_CLIMB):
+  def _benchmark_interleave(self, autotune, autotune_buffers=False):
     k = 1024 * 1024
-    dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
-                                                np.random.rand(4 * k,
-                                                               1))).repeat()
+    dataset = dataset_ops.Dataset.from_tensors(
+        (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))).repeat()
     dataset = dataset.map(math_ops.matmul)
     dataset = dataset_ops.Dataset.range(1).repeat().interleave(
         lambda _: dataset,
         cycle_length=10,
         num_parallel_calls=dataset_ops.AUTOTUNE)
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    options.experimental_optimization.autotune = autotune
-    if autotune:
-      options.experimental_optimization.autotune_algorithm = algorithm.value
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
-    get_next = iterator.get_next()
-
-    deltas = []
-    with session.Session() as sess:
-      for _ in range(5):
-        sess.run(get_next.op)
-      for _ in range(10000):
-        start = time.time()
-        sess.run(get_next.op)
-        end = time.time()
-        deltas.append(end - start)
-
-    self.report_benchmark(
-        iters=10000,
-        wall_time=np.median(deltas),
-        name="interleave" +
-        (("_autotune_%s" % algorithm.name) if autotune else ""))
-    return np.median(deltas)
+    return self._run_benchmark(
+        dataset,
+        autotune,
+        autotune_buffers,
+        benchmark_iters=10000,
+        benchmark_label="interleave")
 
   def benchmark_map_and_interleave(self):
     a = self._benchmark_map_and_interleave(autotune=False)
-    b = self._benchmark_map_and_interleave(autotune=True)
-    c = self._benchmark_map_and_interleave(
-        autotune=True, algorithm=dataset_ops.AutotuneAlgorithm.GRADIENT_DESCENT)
-    print("HillClimb vs Default speedup: %f" % (a / b))
-    print("GradientDescent vs Default speedup: %f" % (a / c))
+    b = self._benchmark_map_and_interleave(
+        autotune=True, autotune_buffers=False)
+    c = self._benchmark_map_and_interleave(autotune=True, autotune_buffers=True)
+    print("autotune parallelism vs no autotuning speedup: {}".format(a / b))
+    print("autotune parallelism and buffer sizes vs no autotuning speedup: {}"
+          .format(a / c))
 
-  def _benchmark_map_and_interleave(
-      self, autotune, algorithm=dataset_ops.AutotuneAlgorithm.HILL_CLIMB):
+  def _benchmark_map_and_interleave(self, autotune, autotune_buffers=False):
     k = 1024 * 1024
     a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
     b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))
@@ -204,42 +172,26 @@ class AutotuneBenchmark(test.Benchmark):
     dataset = dataset_ops.Dataset.zip((dataset, dataset_c))
     dataset = dataset.map(f2, num_parallel_calls=dataset_ops.AUTOTUNE)
 
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    options.experimental_optimization.autotune = autotune
-    if autotune:
-      options.experimental_optimization.autotune_algorithm = algorithm.value
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
-    get_next = iterator.get_next()
-
-    deltas = []
-    with session.Session() as sess:
-      for _ in range(5):
-        sess.run(get_next)
-      for _ in range(10000):
-        start = time.time()
-        sess.run(get_next)
-        end = time.time()
-        deltas.append(end - start)
-
-    self.report_benchmark(
-        iters=10000,
-        wall_time=np.median(deltas),
-        name="map_and_interleave" +
-        (("_autotune_%s" % algorithm.name) if autotune else ""))
-    return np.median(deltas)
+    return self._run_benchmark(
+        dataset,
+        autotune,
+        autotune_buffers,
+        benchmark_iters=10000,
+        benchmark_label="map_and_interleave")
 
   def benchmark_map_batch_and_interleave(self):
     a = self._benchmark_map_batch_and_interleave(autotune=False)
-    b = self._benchmark_map_batch_and_interleave(autotune=True)
+    b = self._benchmark_map_batch_and_interleave(
+        autotune=True, autotune_buffers=False)
     c = self._benchmark_map_batch_and_interleave(
-        autotune=True, algorithm=dataset_ops.AutotuneAlgorithm.GRADIENT_DESCENT)
-    print("HillClimb vs Default speedup: %f" % (a / b))
-    print("GradientDescent vs Default speedup: %f" % (a / c))
+        autotune=True, autotune_buffers=True)
+    print("autotune parallelism vs no autotuning speedup: {}".format(a / b))
+    print("autotune parallelism and buffer sizes vs no autotuning speedup: {}"
+          .format(a / c))
 
-  def _benchmark_map_batch_and_interleave(
-      self, autotune, algorithm=dataset_ops.AutotuneAlgorithm.HILL_CLIMB):
+  def _benchmark_map_batch_and_interleave(self,
+                                          autotune,
+                                          autotune_buffers=False):
     batch_size = 16
     k = 1024 * 1024
     a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
@@ -268,32 +220,12 @@ class AutotuneBenchmark(test.Benchmark):
         math_ops.matmul, num_parallel_calls=dataset_ops.AUTOTUNE)
     dataset_c = dataset_c.batch(batch_size=batch_size)
     dataset = dataset_ops.Dataset.zip((dataset, dataset_c))
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    options.experimental_optimization.map_and_batch_fusion = True
-    options.experimental_optimization.autotune = autotune
-    if autotune:
-      options.experimental_optimization.autotune_algorithm = algorithm.value
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
-    get_next = iterator.get_next()
-
-    deltas = []
-    with session.Session() as sess:
-      for _ in range(5):
-        sess.run(get_next)
-      for _ in range(1000):
-        start = time.time()
-        sess.run(get_next)
-        end = time.time()
-        deltas.append(end - start)
-
-    self.report_benchmark(
-        iters=1000,
-        wall_time=np.median(deltas),
-        name="map_batch_and_interleave" +
-        (("_autotune_%s" % algorithm.name) if autotune else ""))
-    return np.median(deltas)
+    return self._run_benchmark(
+        dataset,
+        autotune,
+        autotune_buffers,
+        benchmark_iters=1000,
+        benchmark_label="map_batch_and_interleave")
 
 
 if __name__ == "__main__":
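With the timing and reporting consolidated into `_run_benchmark`, each benchmark above reduces to constructing its input pipeline. For illustration, a hypothetical new benchmark would follow the same shape (a sketch only; `_benchmark_batch` and its label are invented here, not part of this patch):

```python
# Sketch only: a hypothetical benchmark built on the consolidated helper.
def _benchmark_batch(self, autotune, autotune_buffers=False):
  k = 1024 * 1024
  dataset = dataset_ops.Dataset.from_tensors(np.random.rand(4 * k)).repeat()
  dataset = dataset.batch(batch_size=16)
  return self._run_benchmark(
      dataset,
      autotune,
      autotune_buffers,
      benchmark_iters=1000,
      benchmark_label="batch")
```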
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py
index 1bd7e320466..397703e1c40 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py
@@ -24,6 +24,7 @@ import numpy as np
 
 from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.experimental.ops import optimization_options
 from tensorflow.python.data.experimental.ops import scan_ops
 from tensorflow.python.data.experimental.ops import testing
 from tensorflow.python.data.experimental.ops import threadpool
@@ -215,11 +216,11 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     optimized_it = dataset_ops.make_initializable_iterator(optimized_dataset)
 
     self.assertGreaterEqual(len(w), 1)
-    expected = ("tf.data static optimizations are not compatible with "
-                "tf.Variable. The following optimizations will be disabled: %s."
-                " To enable optimizations, use resource variables instead by "
+    expected = ("tf.data graph rewrites are not compatible with "
+                "tf.Variable. The following rewrites will be disabled: %s."
+                " To enable rewrites, use resource variables instead by "
                 "calling `tf.enable_resource_variables()` at the start of the "
-                "program." % (", ".join(options._static_optimizations())))
+                "program." % (", ".join(options._graph_rewrites())))
     self.assertTrue(any([expected in str(warning) for warning in w]))
 
     # Check that outputs are the same in the optimized and unoptimized cases,
@@ -249,10 +250,10 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
         "shuffle_and_repeat_fusion",
     ]
     self.assertEqual(
-        set(options._static_optimizations()), set(expected_optimizations))
+        set(options._graph_rewrites()), set(expected_optimizations))
 
   def testOptimizationDisableDefault(self):
-    """Tests that we can disable all static optimizations enabled by default.
+    """Tests that we can disable all graph optimizations enabled by default.
 
     If the `apply_default_optimizations` optimization options flag is False,
     only explicitly enabled optimizations will be applied.
@@ -266,7 +267,27 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
         "noop_elimination",
     ]
     self.assertEqual(
-        set(options._static_optimizations()), set(expected_optimizations))
+        set(options._graph_rewrites()), set(expected_optimizations))
+
+  def testAutotuningDefaults(self):
+    options = dataset_ops.Options()
+
+    # Check defaults
+    autotune, algorithm, cpu_budget = options._autotune_settings()
+    self.assertTrue(autotune)
+    self.assertEqual(algorithm,
+                     optimization_options._AutotuneAlgorithm.HILL_CLIMB)
+    self.assertEqual(cpu_budget, 0)
+
+  def testAutotuningBufferSizes(self):
+    options = dataset_ops.Options()
+    options.experimental_optimization.autotune_buffers = True
+    self.assertIn("inject_prefetch", options._graph_rewrites())
+    autotune, algorithm, cpu_budget = options._autotune_settings()
+    self.assertTrue(autotune)
+    self.assertEqual(algorithm,
+                     optimization_options._AutotuneAlgorithm.GRADIENT_DESCENT)
+    self.assertEqual(cpu_budget, 0)
 
 
 if __name__ == "__main__":
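The two new tests pin down how the private `_autotune_settings()` helper resolves the user-facing knobs. A condensed sketch of those resolution rules, using the same modules the test file imports (private API, so subject to change):

```python
# Sketch of _autotune_settings() resolution, mirroring the tests above.
options = dataset_ops.Options()

# Defaults: autotuning on, hill climbing, cpu_budget 0 (use all cores).
assert options._autotune_settings() == (
    True, optimization_options._AutotuneAlgorithm.HILL_CLIMB, 0)

# Opting in to buffer autotuning switches the algorithm to gradient descent.
options.experimental_optimization.autotune_buffers = True
autotune, algorithm, _ = options._autotune_settings()
assert autotune
assert algorithm == optimization_options._AutotuneAlgorithm.GRADIENT_DESCENT

# An explicit autotune=False always wins.
options.experimental_optimization.autotune = False
assert not options._autotune_settings()[0]
```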
diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py
index 5de98189322..abc9eb5f0ad 100644
--- a/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py
@@ -44,9 +44,9 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
     multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
         dataset, ["/cpu:1", "/cpu:2"])
     dataset = multi_device_iterator._dataset  # pylint: disable=protected-access
-    self.assertIn("slack", dataset.options()._static_optimizations())
+    self.assertIn("slack", dataset.options()._graph_rewrites())
     self.assertIn("slack:slack_period:2",
-                  dataset.options()._static_optimization_configs())
+                  dataset.options()._graph_rewrite_configs())
 
     config = config_pb2.ConfigProto(device_count={"CPU": 3})
     with self.test_session(config=config):
@@ -67,9 +67,9 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
     options = dataset_ops.Options()
     options.experimental_slack = True
     dataset = dataset.with_options(options)
-    self.assertIn("slack", dataset.options()._static_optimizations())
+    self.assertIn("slack", dataset.options()._graph_rewrites())
     self.assertIn("slack:slack_period:1",
-                  dataset.options()._static_optimization_configs())
+                  dataset.options()._graph_rewrite_configs())
     self.assertDatasetProduces(dataset, range(10))
 
   def testWithPassthroughDataset(self):
diff --git a/tensorflow/python/data/experimental/ops/optimization_options.py b/tensorflow/python/data/experimental/ops/optimization_options.py
index 57cee3d0e5f..5db4db91c17 100644
--- a/tensorflow/python/data/experimental/ops/optimization_options.py
+++ b/tensorflow/python/data/experimental/ops/optimization_options.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import enum
+
 from tensorflow.python.data.util import options
 from tensorflow.python.util.tf_export import tf_export
 
@@ -24,6 +26,12 @@ from tensorflow.python.util.tf_export import tf_export
 _ENABLE_AUTOTUNE_BUFFERS_BY_DEFAULT = False
 
 
+class _AutotuneAlgorithm(enum.Enum):
+  """Controls what algorithm is used in the autotune implementation."""
+  HILL_CLIMB = 0
+  GRADIENT_DESCENT = 1
+
+
 @tf_export("data.experimental.MapVectorizationOptions")
 class MapVectorizationOptions(options.OptionsBase):
   """Represents options for the MapVectorization optimization."""
@@ -44,12 +52,14 @@ class MapVectorizationOptions(options.OptionsBase):
       "original segment at runtime based on their iterations speed. If None, "
       "defaults to False.")
 
-  def _static_optimizations(self):
+  def _graph_rewrites(self):
     if self.enabled:
       return ["map_vectorization"]
     return []
 
-  def _static_optimization_configs(self):
+  def _graph_rewrite_configs(self):
+    if not self.enabled:
+      return []
     if self.use_choose_fastest:
       return ["map_vectorization:use_choose_fastest:true"]
     else:
@@ -76,7 +86,7 @@ class OptimizationOptions(options.OptionsBase):
       name="apply_default_optimizations",
       ty=bool,
       docstring=
-      "Whether to apply default static optimizations. If False, only static "
+      "Whether to apply default graph optimizations. If False, only graph "
       "optimizations that have been explicitly enabled will be applied.")
 
   autotune = options.create_option(
@@ -86,13 +96,6 @@
       "Whether to automatically tune performance knobs. If None, defaults to "
       "True.")
 
-  autotune_algorithm = options.create_option(
-      name="autotune_algorithm",
-      ty=int,
-      docstring=
-      "When autotuning is enabled (through `autotune`), identifies the "
-      "algorithm to use for the autotuning optimization.")
-
   autotune_buffers = options.create_option(
       name="autotune_buffers",
       ty=bool,
@@ -183,8 +186,34 @@
       docstring="Whether to fuse shuffle and repeat transformations. If None, "
       "defaults to True.")
 
-  def _static_optimizations(self):
-    """Produces the list of enabled static optimizations."""
+  def _autotune_buffers(self):
+    if self.autotune_buffers is not None:
+      return self.autotune_buffers
+    # The default setting for autotune_buffers is based on
+    # _ENABLE_AUTOTUNE_BUFFERS_BY_DEFAULT
+    return _ENABLE_AUTOTUNE_BUFFERS_BY_DEFAULT
+
+  def _autotune_settings(self):
+    # Default autotune settings
+    autotune = True
+
+    # If autotune_buffers is enabled, we use the GRADIENT_DESCENT algorithm by
+    # default, which is more performant for tuning heterogeneous parameters.
+    algorithm = (
+        _AutotuneAlgorithm.GRADIENT_DESCENT
+        if self._autotune_buffers() else _AutotuneAlgorithm.HILL_CLIMB)
+    cpu_budget = 0  # Indicates that all CPU cores should be used by default.
+
+    # Set these options if they are explicitly set by the user.
+    if self.autotune is False:  # pylint: disable=g-bool-id-comparison
+      autotune = False
+    if self.autotune_cpu_budget is not None:
+      cpu_budget = self.autotune_cpu_budget
+
+    return autotune, algorithm, cpu_budget
+
+  def _graph_rewrites(self):
+    """Produces the list of enabled graph optimizations."""
     result = set()
     all_optimizations = [
         "filter_fusion",
@@ -215,17 +244,19 @@
         result.add(optimization)
 
     if self.map_vectorization is not None:
-      result.update(self.map_vectorization._static_optimizations())  # pylint: disable=protected-access
+      result.update(self.map_vectorization._graph_rewrites())  # pylint: disable=protected-access
 
-    # The default setting for autotune_buffers is based on
-    # _ENABLE_AUTOTUNE_BUFFERS_BY_DEFAULT
-    autotune_buffers = self.autotune_buffers or (
-        self.autotune_buffers is None and _ENABLE_AUTOTUNE_BUFFERS_BY_DEFAULT)
+    autotune_buffers = self._autotune_buffers()
     if self.autotune is not False and autotune_buffers:  # pylint: disable=g-bool-id-comparison
+      # When autotuning buffer sizes is enabled, we inject a `prefetch`
+      # transformation after asynchronous dataset ops. Only the buffer sizes of
+      # prefetch transformations will be autotuned, though this is practically
+      # equivalent to tuning the buffer sizes of the other asynchronous
+      # transformations.
       result.add("inject_prefetch")
     return sorted(list(result))
 
-  def _static_optimization_configs(self):
+  def _graph_rewrite_configs(self):
     if self.map_vectorization is not None:
-      return self.map_vectorization._static_optimization_configs()  # pylint: disable=protected-access
+      return self.map_vectorization._graph_rewrite_configs()  # pylint: disable=protected-access
     return []
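The comment above documents the one subtle interaction in `_graph_rewrites()`: `inject_prefetch` is added only when buffer autotuning is requested (or on by default) and autotuning has not been explicitly disabled. A minimal sketch of that gating, assuming `OptimizationOptions` is in scope from this module:

```python
# Sketch of the inject_prefetch gating in _graph_rewrites().
opt = OptimizationOptions()
opt.autotune_buffers = True
assert "inject_prefetch" in opt._graph_rewrites()

opt.autotune = False  # explicit opt-out suppresses the injection
assert "inject_prefetch" not in opt._graph_rewrites()
```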
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index f3367023a7b..06bdfd03eb8 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -29,7 +29,6 @@ import numpy as np
 import six
 from six.moves import queue as Queue  # pylint: disable=redefined-builtin
 
-
 from tensorflow.core.framework import graph_pb2
 from tensorflow.python import tf2
 from tensorflow.python.compat import compat
@@ -90,17 +89,11 @@ autograph = lazy_loader.LazyLoader(
 
 ops.NotDifferentiable("ReduceDataset")
 
-
 # A constant that can be used to enable auto-tuning.
 AUTOTUNE = -1
 tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
 
 
-class AutotuneAlgorithm(enum.Enum):
-  HILL_CLIMB = 0
-  GRADIENT_DESCENT = 1
-
-
 class ExternalStatePolicy(enum.Enum):
   WARN = 0
   IGNORE = 1
@@ -227,9 +220,9 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
         def. In that case, the state in these ops would be thrown away.
       strip_device_assignment: If true, non-local (i.e. job and task) device
         assignment is stripped from ops in the serialized graph.
-      external_state_policy: The ExternalStatePolicy enum that determines how
-        we handle input pipelines that depend on external state. By default,
-        its set to WARN.
+      external_state_policy: The ExternalStatePolicy enum that determines how we
+        handle input pipelines that depend on external state. By default, its
+        set to WARN.
 
     Returns:
       A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
@@ -355,6 +348,8 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
     dataset = self
     options = self.options()
+
+    # (1) Apply threading options
     if options.experimental_threading is not None:
       t_options = options.experimental_threading
       if t_options.max_intra_op_parallelism is not None:
@@ -363,36 +358,31 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
       if t_options.private_threadpool_size is not None:
         dataset = _PrivateThreadPoolDataset(dataset,
                                             t_options.private_threadpool_size)
+
+    # (2) Apply graph rewrite options
     # pylint: disable=protected-access
-    static_optimizations = options._static_optimizations()
-    static_optimization_configs = options._static_optimization_configs()
+    graph_rewrites = options._graph_rewrites()
+    graph_rewrite_configs = options._graph_rewrite_configs()
     # pylint: enable=protected-access
-    if static_optimizations:
+    if graph_rewrites:
       if self._has_captured_ref():
         warnings.warn(
-            "tf.data static optimizations are not compatible with tf.Variable. "
-            "The following optimizations will be disabled: %s. To enable "
-            "optimizations, use resource variables instead by calling "
+            "tf.data graph rewrites are not compatible with tf.Variable. "
+            "The following rewrites will be disabled: %s. To enable "
+            "rewrites, use resource variables instead by calling "
            "`tf.enable_resource_variables()` at the start of the program." %
-            ", ".join(static_optimizations))
+            ", ".join(graph_rewrites))
       else:
-        dataset = _OptimizeDataset(dataset, static_optimizations,
-                                   static_optimization_configs)
+        dataset = _OptimizeDataset(dataset, graph_rewrites,
+                                   graph_rewrite_configs)
 
-    autotune = True
-    algorithm = AutotuneAlgorithm.HILL_CLIMB
-    cpu_budget = 0  # Indicates that all CPU cores should be used.
-    if options.experimental_optimization is not None:
-      if options.experimental_optimization.autotune is False:  # pylint: disable=g-bool-id-comparison
-        autotune = False
-      if options.experimental_optimization.autotune_algorithm is not None:
-        algorithm = options.experimental_optimization.autotune_algorithm
-      if options.experimental_optimization.autotune_cpu_budget is not None:
-        cpu_budget = options.experimental_optimization.autotune_cpu_budget
+    # (3) Apply autotune options
+    autotune, algorithm, cpu_budget = options._autotune_settings()  # pylint: disable=protected-access
 
     if autotune:
       dataset = _ModelDataset(dataset, algorithm, cpu_budget)
 
+    # (4) Apply stats aggregator options
     if options.experimental_stats and options.experimental_stats.aggregator:  # pylint: disable=line-too-long
       dataset = _SetStatsAggregatorDataset(  # pylint: disable=protected-access
           dataset, options.experimental_stats.aggregator,
@@ -2600,7 +2590,7 @@ def get_legacy_output_types(dataset_or_iterator):
 class Options(options_lib.OptionsBase):
   """Represents options for tf.data.Dataset.
 
-  An `Options` object can be, for instance, used to control which static
+  An `Options` object can be, for instance, used to control which graph
   optimizations to apply or whether to use performance modeling to dynamically
   tune the parallelism of operations such as `tf.data.Dataset.map` or
   `tf.data.Dataset.interleave`.
@@ -2675,11 +2665,15 @@
       "might be thrown away; FAIL: We fail if any state is being captured.",
       default_factory=lambda: ExternalStatePolicy.WARN)
 
-  def _static_optimizations(self):
-    """Produces the list of enabled static optimizations."""
-
+  def _graph_rewrites(self):
+    """Produces the list of enabled static graph rewrites."""
     result = []
-    result.extend(self.experimental_optimization._static_optimizations())  # pylint: disable=protected-access
+    if self.experimental_optimization is not None:
+      result.extend(self.experimental_optimization._graph_rewrites())  # pylint: disable=protected-access
+    else:
+      # Apply default options
+      result.extend(
+          optimization_options.OptimizationOptions()._graph_rewrites())  # pylint: disable=protected-access
 
     if self.experimental_deterministic is False:
       result.append("make_sloppy")
@@ -2692,12 +2686,11 @@
       result.append("make_stateless")
     return result
 
-  def _static_optimization_configs(self):
-    """Produces the list of configurations for enabled static optimizations."""
+  def _graph_rewrite_configs(self):
+    """Produces the list of configurations for enabled graph optimizations."""
     result = []
     if self.experimental_optimization:
-      result.extend(
-          self.experimental_optimization._static_optimization_configs())  # pylint: disable=protected-access
+      result.extend(self.experimental_optimization._graph_rewrite_configs())  # pylint: disable=protected-access
 
     if self.experimental_slack:
       num_devices = self.experimental_distribute.num_devices
@@ -2706,6 +2699,13 @@
       result.append("slack:slack_period:%d" % num_devices)
     return result
 
+  def _autotune_settings(self):
+    if self.experimental_optimization is not None:
+      return self.experimental_optimization._autotune_settings()  # pylint: disable=protected-access
+
+    # Return default autotune options
+    return optimization_options.OptimizationOptions()._autotune_settings()  # pylint: disable=protected-access
+
   def merge(self, options):
     """Merges itself with the given `tf.data.Options`.
 
@@ -4177,20 +4177,11 @@ class _ModelDataset(UnaryUnchangedStructureDataset):
   def __init__(self, input_dataset, algorithm, cpu_budget):
     self._input_dataset = input_dataset
-    # TODO(jsimsa): This check is introduced for forward compatibility and can
-    # be removed after 7/24/2019. At that point, all servers are expected to
-    # recognize the `algorithm` attribute.
-    if algorithm != AutotuneAlgorithm.HILL_CLIMB:
-      variant_tensor = gen_dataset_ops.model_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          algorithm=algorithm,
-          cpu_budget=cpu_budget,
-          **self._flat_structure)
-    else:
-      variant_tensor = gen_dataset_ops.model_dataset(
-          input_dataset._variant_tensor,  # pylint: disable=protected-access
-          cpu_budget=cpu_budget,
-          **self._flat_structure)
+    variant_tensor = gen_dataset_ops.model_dataset(
+        input_dataset._variant_tensor,  # pylint: disable=protected-access
+        algorithm=algorithm.value,
+        cpu_budget=cpu_budget,
+        **self._flat_structure)
     super(_ModelDataset, self).__init__(input_dataset, variant_tensor)
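`_apply_options()` now applies options in a fixed, documented four-stage order, with all autotune resolution delegated to `Options._autotune_settings()`. From a user's perspective, the knobs feeding stage (3) look like this (a usage sketch in the graph-mode style of the benchmarks above; the budget value is arbitrary):

```python
# Usage sketch: options feeding the four-stage _apply_options() flow.
dataset = dataset_ops.Dataset.range(1000).map(
    lambda x: x * x, num_parallel_calls=AUTOTUNE)
options = Options()
options.experimental_optimization.autotune = True
options.experimental_optimization.autotune_buffers = True  # GRADIENT_DESCENT
options.experimental_optimization.autotune_cpu_budget = 4  # arbitrary cap
dataset = dataset.with_options(options)
# Iterating this dataset triggers: (1) threading options, (2) graph rewrites
# (here including inject_prefetch), (3) _ModelDataset with the resolved
# algorithm and cpu_budget, and (4) stats aggregation, in that order.
```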
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt
index f7301ff180c..a79d205cf0b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt
@@ -11,10 +11,6 @@ tf_class {
     name: "autotune"
     mtype: "<type \'property\'>"
   }
-  member {
-    name: "autotune_algorithm"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "autotune_buffers"
     mtype: "<type \'property\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt
index f7301ff180c..a79d205cf0b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt
@@ -11,10 +11,6 @@ tf_class {
     name: "autotune"
    mtype: "<type \'property\'>"
   }
-  member {
-    name: "autotune_algorithm"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "autotune_buffers"
     mtype: "<type \'property\'>"
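Because `autotune_algorithm` disappears from the public `OptimizationOptions` surface (hence the golden-file updates above), callers that set it need to migrate. A sketch, assuming the caller previously forced gradient descent the way the old benchmark code did:

```python
# Migration sketch (assumption: caller previously set the removed option, e.g.
#   options.experimental_optimization.autotune_algorithm = algorithm.value).
options = dataset_ops.Options()
# Gradient descent is now selected implicitly by enabling buffer autotuning;
# the algorithm enum itself has become private to optimization_options.
options.experimental_optimization.autotune_buffers = True
```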