1) Remove `autotune_algorithm` from the experimental optimization options. Instead, HILL_CLIMB is used when buffer size autotuning is disabled, and GRADIENT_DESCENT when it is enabled (see the usage sketch below).

2) Some refactoring:
  a) s/static_optimization/graph_rewrite/, because not all of our rewrites are
     'optimizations', strictly speaking
  b) Moved the logic for determining which autotuning options to apply into
     `optimization_options.py`

PiperOrigin-RevId: 283867649
Change-Id: Ica01a469f8c0039b11db2aa2304a50c700a5ddd7
Rachel Lim 2019-12-04 16:54:49 -08:00 committed by TensorFlower Gardener
parent 30b99dd053
commit b656428fb1
7 changed files with 217 additions and 250 deletions
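For illustration only (not part of this commit), a minimal sketch of how a pipeline selects between the two algorithms after this change; the dataset, the map function, and the option values are placeholders:

```python
import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(np.arange(1000))
dataset = dataset.map(lambda x: x * 2,
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)

options = tf.data.Options()
options.experimental_optimization.autotune = True
# The algorithm is no longer set via `autotune_algorithm`; it follows
# `autotune_buffers`: False -> HILL_CLIMB (tune parallelism only),
# True -> GRADIENT_DESCENT (tune parallelism and buffer sizes).
options.experimental_optimization.autotune_buffers = True
dataset = dataset.with_options(options)
```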

View File

@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -30,149 +31,116 @@ from tensorflow.python.platform import test
class AutotuneBenchmark(test.Benchmark):
"""Benchmarks for autotuning performance knobs."""
def benchmark_map(self):
a = self._benchmark_map(autotune=False)
b = self._benchmark_map(autotune=True)
c = self._benchmark_map(
autotune=True, algorithm=dataset_ops.AutotuneAlgorithm.GRADIENT_DESCENT)
print("HillClimb vs Default speedup: %f" % (a / b))
print("GradientDescent vs Default speedup: %f" % (a / c))
def _benchmark_map(self,
autotune,
algorithm=dataset_ops.AutotuneAlgorithm.HILL_CLIMB):
k = 1024 * 1024
dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
np.random.rand(4 * k,
1))).repeat()
dataset = dataset.map(
math_ops.matmul, num_parallel_calls=dataset_ops.AUTOTUNE)
def _run_benchmark(self, dataset, autotune, autotune_buffers,
benchmark_iters, benchmark_label):
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.autotune = autotune
if autotune:
options.experimental_optimization.autotune_algorithm = algorithm.value
options.experimental_optimization.autotune_buffers = autotune_buffers
dataset = dataset.with_options(options)
iterator = dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next()
# Run the op directly to avoid copying the tensor to python.
get_next_op = nest.flatten(get_next)[0].op
deltas = []
with session.Session() as sess:
for _ in range(5):
sess.run(get_next.op)
for _ in range(10000):
sess.run(get_next_op)
for _ in range(benchmark_iters):
start = time.time()
sess.run(get_next.op)
sess.run(get_next_op)
end = time.time()
deltas.append(end - start)
autotune_string = "_autotune_{}".format(
"parallelism_and_buffer_sizes"
if autotune_buffers else "parallelism_only")
self.report_benchmark(
iters=10000,
iters=benchmark_iters,
wall_time=np.median(deltas),
name="map" + (("_autotune_%s" % algorithm.name) if autotune else ""))
name=benchmark_label + (autotune_string if autotune else ""))
return np.median(deltas)
def benchmark_map(self):
a = self._benchmark_map(autotune=False)
b = self._benchmark_map(autotune=True, autotune_buffers=False)
c = self._benchmark_map(autotune=True, autotune_buffers=True)
print("autotune parallelism vs no autotuning speedup: {}".format(a / b))
print("autotune parallelism and buffer sizes vs no autotuning speedup: {}"
.format(a / c))
def _benchmark_map(self, autotune, autotune_buffers=False):
k = 1024 * 1024
dataset = dataset_ops.Dataset.from_tensors(
(np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))).repeat()
dataset = dataset.map(
math_ops.matmul, num_parallel_calls=dataset_ops.AUTOTUNE)
return self._run_benchmark(
dataset,
autotune,
autotune_buffers,
benchmark_iters=10000,
benchmark_label="map")
def benchmark_map_and_batch(self):
a = self._benchmark_map_and_batch(autotune=False)
b = self._benchmark_map_and_batch(autotune=True)
c = self._benchmark_map_and_batch(
autotune=True, algorithm=dataset_ops.AutotuneAlgorithm.GRADIENT_DESCENT)
print("HillClimb vs Default speedup: %f" % (a / b))
print("GradientDescent vs Default speedup: %f" % (a / c))
b = self._benchmark_map_and_batch(autotune=True, autotune_buffers=False)
c = self._benchmark_map_and_batch(autotune=True, autotune_buffers=True)
print("autotune parallelism vs no autotuning speedup: {}".format(a / b))
print("autotune parallelism and buffer sizes vs no autotuning speedup: {}"
.format(a / c))
def _benchmark_map_and_batch(
self, autotune, algorithm=dataset_ops.AutotuneAlgorithm.HILL_CLIMB):
def _benchmark_map_and_batch(self, autotune, autotune_buffers=False):
batch_size = 16
k = 1024 * 1024
dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
np.random.rand(4 * k,
1))).repeat()
dataset = dataset_ops.Dataset.from_tensors(
(np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))).repeat()
dataset = dataset.map(
math_ops.matmul, num_parallel_calls=dataset_ops.AUTOTUNE)
dataset = dataset.batch(batch_size=batch_size)
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.map_and_batch_fusion = True
options.experimental_optimization.autotune = autotune
if autotune:
options.experimental_optimization.autotune_algorithm = algorithm.value
dataset = dataset.with_options(options)
iterator = dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next()
deltas = []
with session.Session() as sess:
for _ in range(5):
sess.run(get_next.op)
for _ in range(1000):
start = time.time()
sess.run(get_next.op)
end = time.time()
deltas.append(end - start)
self.report_benchmark(
iters=1000,
wall_time=np.median(deltas),
name="map_and_batch" +
(("_autotune_%s" % algorithm.name) if autotune else ""))
return np.median(deltas)
return self._run_benchmark(
dataset,
autotune,
autotune_buffers,
benchmark_iters=1000,
benchmark_label="map_and_batch")
def benchmark_interleave(self):
a = self._benchmark_interleave(autotune=False)
b = self._benchmark_interleave(autotune=True)
c = self._benchmark_interleave(
autotune=True, algorithm=dataset_ops.AutotuneAlgorithm.GRADIENT_DESCENT)
print("HillClimb vs Default speedup: %f" % (a / b))
print("GradientDescent vs Default speedup: %f" % (a / c))
b = self._benchmark_interleave(autotune=True, autotune_buffers=False)
c = self._benchmark_interleave(autotune=True, autotune_buffers=True)
print("autotune parallelism vs no autotuning speedup: {}".format(a / b))
print("autotune parallelism and buffer sizes vs no autotuning speedup: {}"
.format(a / c))
def _benchmark_interleave(self,
autotune,
algorithm=dataset_ops.AutotuneAlgorithm.HILL_CLIMB):
def _benchmark_interleave(self, autotune, autotune_buffers=False):
k = 1024 * 1024
dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
np.random.rand(4 * k,
1))).repeat()
dataset = dataset_ops.Dataset.from_tensors(
(np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))).repeat()
dataset = dataset.map(math_ops.matmul)
dataset = dataset_ops.Dataset.range(1).repeat().interleave(
lambda _: dataset,
cycle_length=10,
num_parallel_calls=dataset_ops.AUTOTUNE)
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.autotune = autotune
if autotune:
options.experimental_optimization.autotune_algorithm = algorithm.value
dataset = dataset.with_options(options)
iterator = dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next()
deltas = []
with session.Session() as sess:
for _ in range(5):
sess.run(get_next.op)
for _ in range(10000):
start = time.time()
sess.run(get_next.op)
end = time.time()
deltas.append(end - start)
self.report_benchmark(
iters=10000,
wall_time=np.median(deltas),
name="interleave" +
(("_autotune_%s" % algorithm.name) if autotune else ""))
return np.median(deltas)
return self._run_benchmark(
dataset,
autotune,
autotune_buffers,
benchmark_iters=10000,
benchmark_label="interleave")
def benchmark_map_and_interleave(self):
a = self._benchmark_map_and_interleave(autotune=False)
b = self._benchmark_map_and_interleave(autotune=True)
c = self._benchmark_map_and_interleave(
autotune=True, algorithm=dataset_ops.AutotuneAlgorithm.GRADIENT_DESCENT)
print("HillClimb vs Default speedup: %f" % (a / b))
print("GradientDescent vs Default speedup: %f" % (a / c))
b = self._benchmark_map_and_interleave(
autotune=True, autotune_buffers=False)
c = self._benchmark_map_and_interleave(autotune=True, autotune_buffers=True)
print("autotune parallelism vs no autotuning speedup: {}".format(a / b))
print("autotune parallelism and buffer sizes vs no autotuning speedup: {}"
.format(a / c))
def _benchmark_map_and_interleave(
self, autotune, algorithm=dataset_ops.AutotuneAlgorithm.HILL_CLIMB):
def _benchmark_map_and_interleave(self, autotune, autotune_buffers=False):
k = 1024 * 1024
a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))
@@ -204,42 +172,26 @@ class AutotuneBenchmark(test.Benchmark):
dataset = dataset_ops.Dataset.zip((dataset, dataset_c))
dataset = dataset.map(f2, num_parallel_calls=dataset_ops.AUTOTUNE)
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.autotune = autotune
if autotune:
options.experimental_optimization.autotune_algorithm = algorithm.value
dataset = dataset.with_options(options)
iterator = dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next()
deltas = []
with session.Session() as sess:
for _ in range(5):
sess.run(get_next)
for _ in range(10000):
start = time.time()
sess.run(get_next)
end = time.time()
deltas.append(end - start)
self.report_benchmark(
iters=10000,
wall_time=np.median(deltas),
name="map_and_interleave" +
(("_autotune_%s" % algorithm.name) if autotune else ""))
return np.median(deltas)
return self._run_benchmark(
dataset,
autotune,
autotune_buffers,
benchmark_iters=10000,
benchmark_label="map_and_interleave")
def benchmark_map_batch_and_interleave(self):
a = self._benchmark_map_batch_and_interleave(autotune=False)
b = self._benchmark_map_batch_and_interleave(autotune=True)
b = self._benchmark_map_batch_and_interleave(
autotune=True, autotune_buffers=False)
c = self._benchmark_map_batch_and_interleave(
autotune=True, algorithm=dataset_ops.AutotuneAlgorithm.GRADIENT_DESCENT)
print("HillClimb vs Default speedup: %f" % (a / b))
print("GradientDescent vs Default speedup: %f" % (a / c))
autotune=True, autotune_buffers=True)
print("autotune parallelism vs no autotuning speedup: {}".format(a / b))
print("autotune parallelism and buffer sizes vs no autotuning speedup: {}"
.format(a / c))
def _benchmark_map_batch_and_interleave(
self, autotune, algorithm=dataset_ops.AutotuneAlgorithm.HILL_CLIMB):
def _benchmark_map_batch_and_interleave(self,
autotune,
autotune_buffers=False):
batch_size = 16
k = 1024 * 1024
a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
@@ -268,32 +220,12 @@ class AutotuneBenchmark(test.Benchmark):
math_ops.matmul, num_parallel_calls=dataset_ops.AUTOTUNE)
dataset_c = dataset_c.batch(batch_size=batch_size)
dataset = dataset_ops.Dataset.zip((dataset, dataset_c))
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.map_and_batch_fusion = True
options.experimental_optimization.autotune = autotune
if autotune:
options.experimental_optimization.autotune_algorithm = algorithm.value
dataset = dataset.with_options(options)
iterator = dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next()
deltas = []
with session.Session() as sess:
for _ in range(5):
sess.run(get_next)
for _ in range(1000):
start = time.time()
sess.run(get_next)
end = time.time()
deltas.append(end - start)
self.report_benchmark(
iters=1000,
wall_time=np.median(deltas),
name="map_batch_and_interleave" +
(("_autotune_%s" % algorithm.name) if autotune else ""))
return np.median(deltas)
return self._run_benchmark(
dataset,
autotune,
autotune_buffers,
benchmark_iters=1000,
benchmark_label="map_batch_and_interleave")
if __name__ == "__main__":

View File

@@ -24,6 +24,7 @@ import numpy as np
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.experimental.ops import threadpool
@@ -215,11 +216,11 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
optimized_it = dataset_ops.make_initializable_iterator(optimized_dataset)
self.assertGreaterEqual(len(w), 1)
expected = ("tf.data static optimizations are not compatible with "
"tf.Variable. The following optimizations will be disabled: %s."
" To enable optimizations, use resource variables instead by "
expected = ("tf.data graph rewrites are not compatible with "
"tf.Variable. The following rewrites will be disabled: %s."
" To enable rewrites, use resource variables instead by "
"calling `tf.enable_resource_variables()` at the start of the "
"program." % (", ".join(options._static_optimizations())))
"program." % (", ".join(options._graph_rewrites())))
self.assertTrue(any([expected in str(warning) for warning in w]))
# Check that outputs are the same in the optimized and unoptimized cases,
@@ -249,10 +250,10 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
"shuffle_and_repeat_fusion",
]
self.assertEqual(
set(options._static_optimizations()), set(expected_optimizations))
set(options._graph_rewrites()), set(expected_optimizations))
def testOptimizationDisableDefault(self):
"""Tests that we can disable all static optimizations enabled by default.
"""Tests that we can disable all graph optimizations enabled by default.
If the `apply_default_optimizations` optimization options flag is False,
only explicitly enabled optimizations will be applied.
@@ -266,7 +267,27 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
"noop_elimination",
]
self.assertEqual(
set(options._static_optimizations()), set(expected_optimizations))
set(options._graph_rewrites()), set(expected_optimizations))
def testAutotuningDefaults(self):
options = dataset_ops.Options()
# Check defaults
autotune, algorithm, cpu_budget = options._autotune_settings()
self.assertTrue(autotune)
self.assertEqual(algorithm,
optimization_options._AutotuneAlgorithm.HILL_CLIMB)
self.assertEqual(cpu_budget, 0)
def testAutotuningBufferSizes(self):
options = dataset_ops.Options()
options.experimental_optimization.autotune_buffers = True
self.assertIn("inject_prefetch", options._graph_rewrites())
autotune, algorithm, cpu_budget = options._autotune_settings()
self.assertTrue(autotune)
self.assertEqual(algorithm,
optimization_options._AutotuneAlgorithm.GRADIENT_DESCENT)
self.assertEqual(cpu_budget, 0)
if __name__ == "__main__":

View File

@@ -44,9 +44,9 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
dataset, ["/cpu:1", "/cpu:2"])
dataset = multi_device_iterator._dataset # pylint: disable=protected-access
self.assertIn("slack", dataset.options()._static_optimizations())
self.assertIn("slack", dataset.options()._graph_rewrites())
self.assertIn("slack:slack_period:2",
dataset.options()._static_optimization_configs())
dataset.options()._graph_rewrite_configs())
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config):
@@ -67,9 +67,9 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
options = dataset_ops.Options()
options.experimental_slack = True
dataset = dataset.with_options(options)
self.assertIn("slack", dataset.options()._static_optimizations())
self.assertIn("slack", dataset.options()._graph_rewrites())
self.assertIn("slack:slack_period:1",
dataset.options()._static_optimization_configs())
dataset.options()._graph_rewrite_configs())
self.assertDatasetProduces(dataset, range(10))
def testWithPassthroughDataset(self):

View File

@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@@ -24,6 +26,12 @@ from tensorflow.python.util.tf_export import tf_export
_ENABLE_AUTOTUNE_BUFFERS_BY_DEFAULT = False
class _AutotuneAlgorithm(enum.Enum):
"""Controls what algorithm is used in the autotune implementation."""
HILL_CLIMB = 0
GRADIENT_DESCENT = 1
@tf_export("data.experimental.MapVectorizationOptions")
class MapVectorizationOptions(options.OptionsBase):
"""Represents options for the MapVectorization optimization."""
@@ -44,12 +52,14 @@ class MapVectorizationOptions(options.OptionsBase):
"original segment at runtime based on their iterations speed. If None, "
"defaults to False.")
def _static_optimizations(self):
def _graph_rewrites(self):
if self.enabled:
return ["map_vectorization"]
return []
def _static_optimization_configs(self):
def _graph_rewrite_configs(self):
if not self.enabled:
return []
if self.use_choose_fastest:
return ["map_vectorization:use_choose_fastest:true"]
else:
@@ -76,7 +86,7 @@ class OptimizationOptions(options.OptionsBase):
name="apply_default_optimizations",
ty=bool,
docstring=
"Whether to apply default static optimizations. If False, only static "
"Whether to apply default graph optimizations. If False, only graph "
"optimizations that have been explicitly enabled will be applied.")
autotune = options.create_option(
@@ -86,13 +96,6 @@ class OptimizationOptions(options.OptionsBase):
"Whether to automatically tune performance knobs. If None, defaults to "
"True.")
autotune_algorithm = options.create_option(
name="autotune_algorithm",
ty=int,
docstring=
"When autotuning is enabled (through `autotune`), identifies the "
"algorithm to use for the autotuning optimization.")
autotune_buffers = options.create_option(
name="autotune_buffers",
ty=bool,
@@ -183,8 +186,34 @@ class OptimizationOptions(options.OptionsBase):
docstring="Whether to fuse shuffle and repeat transformations. If None, "
"defaults to True.")
def _static_optimizations(self):
"""Produces the list of enabled static optimizations."""
def _autotune_buffers(self):
if self.autotune_buffers is not None:
return self.autotune_buffers
# The default setting for autotune_buffers is based on
# _ENABLE_AUTOTUNE_BUFFERS_BY_DEFAULT
return _ENABLE_AUTOTUNE_BUFFERS_BY_DEFAULT
def _autotune_settings(self):
# Default autotune settings
autotune = True
# If autotune_buffers is enabled, we use the GRADIENT_DESCENT algorithm by
# default, which is more performant for tuning heterogeneous parameters.
algorithm = (
_AutotuneAlgorithm.GRADIENT_DESCENT
if self._autotune_buffers() else _AutotuneAlgorithm.HILL_CLIMB)
cpu_budget = 0 # Indicates that all CPU cores should be used by default.
# Set these options if they are explicitly set by the user.
if self.autotune is False: # pylint: disable=g-bool-id-comparison
autotune = False
if self.autotune_cpu_budget is not None:
cpu_budget = self.autotune_cpu_budget
return autotune, algorithm, cpu_budget
def _graph_rewrites(self):
"""Produces the list of enabled graph optimizations."""
result = set()
all_optimizations = [
"filter_fusion",
@@ -215,17 +244,19 @@ class OptimizationOptions(options.OptionsBase):
result.add(optimization)
if self.map_vectorization is not None:
result.update(self.map_vectorization._static_optimizations()) # pylint: disable=protected-access
result.update(self.map_vectorization._graph_rewrites()) # pylint: disable=protected-access
# The default setting for autotune_buffers is based on
# _ENABLE_AUTOTUNE_BUFFERS_BY_DEFAULT
autotune_buffers = self.autotune_buffers or (
self.autotune_buffers is None and _ENABLE_AUTOTUNE_BUFFERS_BY_DEFAULT)
autotune_buffers = self._autotune_buffers()
if self.autotune is not False and autotune_buffers: # pylint: disable=g-bool-id-comparison
# When autotuning buffer sizes is enabled, we inject a `prefetch`
# transformation after asynchronous dataset ops. Only the buffer sizes of
# prefetch transformations will be autotuned, though this is practically
# equivalent to tuning the buffer sizes of the other asynchronous
# transformations.
result.add("inject_prefetch")
return sorted(list(result))
def _static_optimization_configs(self):
def _graph_rewrite_configs(self):
if self.map_vectorization is not None:
return self.map_vectorization._static_optimization_configs() # pylint: disable=protected-access
return self.map_vectorization._graph_rewrite_configs() # pylint: disable=protected-access
return []
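As a rough illustration of the `inject_prefetch` rewrite described in the comment above (a sketch, not part of this diff; the explicit `prefetch` call is an assumption about what the rewrite is approximately equivalent to), enabling buffer-size autotuning behaves much as if a tunable prefetch were appended after each asynchronous transformation:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(1_000_000)
dataset = dataset.map(lambda x: x + 1,
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)
# With `autotune_buffers = True`, the inject_prefetch rewrite adds a prefetch
# after asynchronous transformations and autotunes only its buffer size;
# written out by hand, that is roughly:
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
```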

View File

@@ -29,7 +29,6 @@ import numpy as np
import six
from six.moves import queue as Queue # pylint: disable=redefined-builtin
from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.compat import compat
@@ -90,17 +89,11 @@ autograph = lazy_loader.LazyLoader(
ops.NotDifferentiable("ReduceDataset")
# A constant that can be used to enable auto-tuning.
AUTOTUNE = -1
tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
class AutotuneAlgorithm(enum.Enum):
HILL_CLIMB = 0
GRADIENT_DESCENT = 1
class ExternalStatePolicy(enum.Enum):
WARN = 0
IGNORE = 1
@@ -227,9 +220,9 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
def. In that case, the state in these ops would be thrown away.
strip_device_assignment: If true, non-local (i.e. job and task) device
assignment is stripped from ops in the serialized graph.
external_state_policy: The ExternalStatePolicy enum that determines how
we handle input pipelines that depend on external state. By default,
its set to WARN.
external_state_policy: The ExternalStatePolicy enum that determines how we
handle input pipelines that depend on external state. By default, it's
set to WARN.
Returns:
A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
@@ -355,6 +348,8 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
dataset = self
options = self.options()
# (1) Apply threading options
if options.experimental_threading is not None:
t_options = options.experimental_threading
if t_options.max_intra_op_parallelism is not None:
@@ -363,36 +358,31 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
if t_options.private_threadpool_size is not None:
dataset = _PrivateThreadPoolDataset(dataset,
t_options.private_threadpool_size)
# (2) Apply graph rewrite options
# pylint: disable=protected-access
static_optimizations = options._static_optimizations()
static_optimization_configs = options._static_optimization_configs()
graph_rewrites = options._graph_rewrites()
graph_rewrite_configs = options._graph_rewrite_configs()
# pylint: enable=protected-access
if static_optimizations:
if graph_rewrites:
if self._has_captured_ref():
warnings.warn(
"tf.data static optimizations are not compatible with tf.Variable. "
"The following optimizations will be disabled: %s. To enable "
"optimizations, use resource variables instead by calling "
"tf.data graph rewrites are not compatible with tf.Variable. "
"The following rewrites will be disabled: %s. To enable "
"rewrites, use resource variables instead by calling "
"`tf.enable_resource_variables()` at the start of the program." %
", ".join(static_optimizations))
", ".join(graph_rewrites))
else:
dataset = _OptimizeDataset(dataset, static_optimizations,
static_optimization_configs)
dataset = _OptimizeDataset(dataset, graph_rewrites,
graph_rewrite_configs)
autotune = True
algorithm = AutotuneAlgorithm.HILL_CLIMB
cpu_budget = 0 # Indicates that all CPU cores should be used.
if options.experimental_optimization is not None:
if options.experimental_optimization.autotune is False: # pylint: disable=g-bool-id-comparison
autotune = False
if options.experimental_optimization.autotune_algorithm is not None:
algorithm = options.experimental_optimization.autotune_algorithm
if options.experimental_optimization.autotune_cpu_budget is not None:
cpu_budget = options.experimental_optimization.autotune_cpu_budget
# (3) Apply autotune options
autotune, algorithm, cpu_budget = options._autotune_settings() # pylint: disable=protected-access
if autotune:
dataset = _ModelDataset(dataset, algorithm, cpu_budget)
# (4) Apply stats aggregator options
if options.experimental_stats and options.experimental_stats.aggregator: # pylint: disable=line-too-long
dataset = _SetStatsAggregatorDataset( # pylint: disable=protected-access
dataset, options.experimental_stats.aggregator,
@@ -2600,7 +2590,7 @@ def get_legacy_output_types(dataset_or_iterator):
class Options(options_lib.OptionsBase):
"""Represents options for tf.data.Dataset.
An `Options` object can be, for instance, used to control which static
An `Options` object can be, for instance, used to control which graph
optimizations to apply or whether to use performance modeling to dynamically
tune the parallelism of operations such as `tf.data.Dataset.map` or
`tf.data.Dataset.interleave`.
@@ -2675,11 +2665,15 @@ class Options(options_lib.OptionsBase):
"might be thrown away; FAIL: We fail if any state is being captured.",
default_factory=lambda: ExternalStatePolicy.WARN)
def _static_optimizations(self):
"""Produces the list of enabled static optimizations."""
def _graph_rewrites(self):
"""Produces the list of enabled static graph rewrites."""
result = []
result.extend(self.experimental_optimization._static_optimizations()) # pylint: disable=protected-access
if self.experimental_optimization is not None:
result.extend(self.experimental_optimization._graph_rewrites()) # pylint: disable=protected-access
else:
# Apply default options
result.extend(
optimization_options.OptimizationOptions()._graph_rewrites()) # pylint: disable=protected-access
if self.experimental_deterministic is False:
result.append("make_sloppy")
@@ -2692,12 +2686,11 @@ class Options(options_lib.OptionsBase):
result.append("make_stateless")
return result
def _static_optimization_configs(self):
"""Produces the list of configurations for enabled static optimizations."""
def _graph_rewrite_configs(self):
"""Produces the list of configurations for enabled graph optimizations."""
result = []
if self.experimental_optimization:
result.extend(
self.experimental_optimization._static_optimization_configs()) # pylint: disable=protected-access
result.extend(self.experimental_optimization._graph_rewrite_configs()) # pylint: disable=protected-access
if self.experimental_slack:
num_devices = self.experimental_distribute.num_devices
@@ -2706,6 +2699,13 @@ class Options(options_lib.OptionsBase):
result.append("slack:slack_period:%d" % num_devices)
return result
def _autotune_settings(self):
if self.experimental_optimization is not None:
return self.experimental_optimization._autotune_settings() # pylint: disable=protected-access
# Return default autotune options
return optimization_options.OptimizationOptions()._autotune_settings() # pylint: disable=protected-access
def merge(self, options):
"""Merges itself with the given `tf.data.Options`.
@@ -4177,20 +4177,11 @@ class _ModelDataset(UnaryUnchangedStructureDataset):
def __init__(self, input_dataset, algorithm, cpu_budget):
self._input_dataset = input_dataset
# TODO(jsimsa): This check is introduced for forward compatibility and can
# be removed after 7/24/2019. At that point, all servers are expected to
# recognize the `algorithm` attribute.
if algorithm != AutotuneAlgorithm.HILL_CLIMB:
variant_tensor = gen_dataset_ops.model_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
algorithm=algorithm,
cpu_budget=cpu_budget,
**self._flat_structure)
else:
variant_tensor = gen_dataset_ops.model_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
cpu_budget=cpu_budget,
**self._flat_structure)
variant_tensor = gen_dataset_ops.model_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
algorithm=algorithm.value,
cpu_budget=cpu_budget,
**self._flat_structure)
super(_ModelDataset, self).__init__(input_dataset, variant_tensor)
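To tie the pieces together, a small sketch of how the resolved settings look after this change, using the internal, underscore-prefixed helpers introduced in this diff (not public API, shown here only for illustration):

```python
from tensorflow.python.data.ops import dataset_ops

options = dataset_ops.Options()
options.experimental_optimization.autotune_buffers = True

# Graph rewrites are now produced by _graph_rewrites(); with buffer-size
# autotuning enabled the list includes "inject_prefetch" alongside the
# default rewrites.
rewrites = options._graph_rewrites()

# Autotune settings are resolved by OptimizationOptions._autotune_settings():
# autotune=True, GRADIENT_DESCENT (because autotune_buffers is enabled), and
# cpu_budget=0 (meaning all available cores).
autotune, algorithm, cpu_budget = options._autotune_settings()
```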

View File

@@ -11,10 +11,6 @@ tf_class {
name: "autotune"
mtype: "<type \'property\'>"
}
member {
name: "autotune_algorithm"
mtype: "<type \'property\'>"
}
member {
name: "autotune_buffers"
mtype: "<type \'property\'>"

View File

@@ -11,10 +11,6 @@ tf_class {
name: "autotune"
mtype: "<type \'property\'>"
}
member {
name: "autotune_algorithm"
mtype: "<type \'property\'>"
}
member {
name: "autotune_buffers"
mtype: "<type \'property\'>"