[tf.data] Making it possible to override tf.data options.
PiperOrigin-RevId: 326680778 Change-Id: Ia41fb00680240d3e1488fc0165647e81e5837d6c
This commit is contained in:
parent
57e69437b4
commit
505a1599c3
@ -88,6 +88,7 @@
|
||||
dataset when it is safe to do so. The optimization can be disabled via
|
||||
the `experimental_optimization.reorder_data_discarding_ops` dataset
|
||||
option.
|
||||
* `tf.data.Options` were previously immutable and can now be overriden.
|
||||
* `tf.image`:
|
||||
* Added deterministic `tf.image.stateless_random_*` functions for each
|
||||
`tf.image.random_*` function. Added a new op
|
||||
|
@ -51,25 +51,28 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(options, ds.options())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptionsTwiceDifferent(self):
|
||||
def testOptionsTwiceDifferentOptions(self):
|
||||
options1 = dataset_ops.Options()
|
||||
options1.experimental_optimization.autotune = True
|
||||
options2 = dataset_ops.Options()
|
||||
options2.experimental_deterministic = False
|
||||
ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
|
||||
options2)
|
||||
ds = dataset_ops.Dataset.range(0)
|
||||
ds = ds.with_options(options1)
|
||||
ds = ds.with_options(options2)
|
||||
self.assertTrue(ds.options().experimental_optimization.autotune)
|
||||
# Explicitly check that flag is False since assertFalse allows None
|
||||
self.assertIs(ds.options().experimental_deterministic, False)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptionsTwiceDifferentError(self):
|
||||
def testOptionsTwiceSameOption(self):
|
||||
options1 = dataset_ops.Options()
|
||||
options1.experimental_optimization.autotune = True
|
||||
options1.experimental_optimization.autotune = False
|
||||
options2 = dataset_ops.Options()
|
||||
options2.experimental_optimization.autotune = False
|
||||
with self.assertRaisesRegex(ValueError, "Cannot merge incompatible values"):
|
||||
dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)
|
||||
options2.experimental_optimization.autotune = True
|
||||
ds = dataset_ops.Dataset.range(0)
|
||||
ds = ds.with_options(options1)
|
||||
ds = ds.with_options(options2)
|
||||
self.assertTrue(ds.options().experimental_optimization.autotune)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptionsMergeOptionsFromMultipleInputs(self):
|
||||
@ -77,9 +80,9 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
options1.experimental_optimization.autotune = True
|
||||
options2 = dataset_ops.Options()
|
||||
options2.experimental_deterministic = True
|
||||
ds = dataset_ops.Dataset.zip(
|
||||
(dataset_ops.Dataset.range(0).with_options(options1),
|
||||
dataset_ops.Dataset.range(0).with_options(options2)))
|
||||
ds1 = dataset_ops.Dataset.range(0).with_options(options1)
|
||||
ds2 = dataset_ops.Dataset.range(0).with_options(options2)
|
||||
ds = dataset_ops.Dataset.zip((ds1, ds2))
|
||||
self.assertTrue(ds.options().experimental_optimization.autotune)
|
||||
self.assertTrue(ds.options().experimental_deterministic)
|
||||
|
||||
@ -99,6 +102,16 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(options1.experimental_threading,
|
||||
threading_options.ThreadingOptions())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMutableOptions(self):
|
||||
ds = dataset_ops.Dataset.range(0)
|
||||
ds.options().experimental_optimization.autotune = True
|
||||
self.assertTrue(ds.options().experimental_optimization.autotune)
|
||||
options = dataset_ops.Options()
|
||||
ds = ds.with_options(options)
|
||||
ds.options().experimental_deterministic = True
|
||||
self.assertTrue(ds.options().experimental_deterministic)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testNestedDataset(self):
|
||||
ds = dataset_ops.Dataset.from_tensors(0)
|
||||
|
@ -2836,20 +2836,37 @@ def get_legacy_output_types(dataset_or_iterator):
|
||||
|
||||
@tf_export("data.Options")
|
||||
class Options(options_lib.OptionsBase):
|
||||
"""Represents options for tf.data.Dataset.
|
||||
"""Represents options for `tf.data.Dataset`.
|
||||
|
||||
An `Options` object can be, for instance, used to control which graph
|
||||
optimizations to apply or whether to use performance modeling to dynamically
|
||||
tune the parallelism of operations such as `tf.data.Dataset.map` or
|
||||
`tf.data.Dataset.interleave`.
|
||||
A `tf.data.Options` object can be, for instance, used to control which static
|
||||
optimizations to apply to the input pipeline graph or whether to use
|
||||
performance modeling to dynamically tune the parallelism of operations such as
|
||||
`tf.data.Dataset.map` or `tf.data.Dataset.interleave`.
|
||||
|
||||
After constructing an `Options` object, use `dataset.with_options(options)` to
|
||||
apply the options to a dataset.
|
||||
The options are set for the entire dataset and are carried over to datasets
|
||||
created through tf.data transformations.
|
||||
|
||||
>>> dataset = tf.data.Dataset.range(3)
|
||||
The options can be set either by mutating the object returned by
|
||||
`tf.data.Dataset.options()` or by constructing an `Options` object and using
|
||||
the `tf.data.Dataset.with_options(options)` transformation, which returns a
|
||||
dataset with the options set.
|
||||
|
||||
>>> dataset = tf.data.Dataset.range(42)
|
||||
>>> dataset.options().experimental_deterministic = False
|
||||
>>> print(dataset.options().experimental_deterministic)
|
||||
False
|
||||
|
||||
>>> dataset = tf.data.Dataset.range(42)
|
||||
>>> options = tf.data.Options()
|
||||
>>> # Set options here.
|
||||
>>> options.experimental_deterministic = False
|
||||
>>> dataset = dataset.with_options(options)
|
||||
>>> print(dataset.options().experimental_deterministic)
|
||||
False
|
||||
|
||||
Note: A known limitation of the `tf.data.Options` implementation is that the
|
||||
options are not preserved across tf.function boundaries. In particular, to
|
||||
set options for a dataset that is iterated within a tf.function, the options
|
||||
need to be set within the same tf.function.
|
||||
"""
|
||||
|
||||
experimental_deterministic = options_lib.create_option(
|
||||
@ -2968,17 +2985,15 @@ class Options(options_lib.OptionsBase):
|
||||
def merge(self, options):
|
||||
"""Merges itself with the given `tf.data.Options`.
|
||||
|
||||
The given `tf.data.Options` can be merged as long as there does not exist an
|
||||
attribute that is set to different values in `self` and `options`.
|
||||
If this object and the `options` to merge set an option differently, a
|
||||
warning is generated and this object's value is updated with the `options`
|
||||
object's value.
|
||||
|
||||
Args:
|
||||
options: a `tf.data.Options` to merge with
|
||||
|
||||
Raises:
|
||||
ValueError: if the given `tf.data.Options` cannot be merged
|
||||
|
||||
Returns:
|
||||
New `tf.data.Options()` object which is the result of merging self with
|
||||
New `tf.data.Options` object which is the result of merging self with
|
||||
the input `tf.data.Options`.
|
||||
"""
|
||||
return options_lib.merge_options(self, options)
|
||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from absl import logging
|
||||
|
||||
|
||||
def _internal_attr_name(name):
|
||||
return "_" + name
|
||||
@ -98,23 +100,23 @@ def merge_options(*options_list):
|
||||
"""Merges the given options, returning the result as a new options object.
|
||||
|
||||
The input arguments are expected to have a matching type that derives from
|
||||
`OptionsBase` (and thus each represent a set of options). The method outputs
|
||||
an object of the same type created by merging the sets of options represented
|
||||
by the input arguments.
|
||||
`tf.data.OptionsBase` (and thus each represent a set of options). The method
|
||||
outputs an object of the same type created by merging the sets of options
|
||||
represented by the input arguments.
|
||||
|
||||
The sets of options can be merged as long as there does not exist an option
|
||||
with different non-default values.
|
||||
If an option is set to different values by different options objects, the
|
||||
result will match the setting of the options object that appears in the input
|
||||
list last.
|
||||
|
||||
If an option is an instance of `OptionsBase` itself, then this method is
|
||||
applied recursively to the set of options represented by this option.
|
||||
If an option is an instance of `tf.data.OptionsBase` itself, then this method
|
||||
is applied recursively to the set of options represented by this option.
|
||||
|
||||
Args:
|
||||
*options_list: options to merge
|
||||
|
||||
Raises:
|
||||
TypeError: if the input arguments are incompatible or not derived from
|
||||
`OptionsBase`
|
||||
ValueError: if the given options cannot be merged
|
||||
`tf.data.OptionsBase`
|
||||
|
||||
Returns:
|
||||
A new options object which is the result of merging the given options.
|
||||
@ -134,7 +136,7 @@ def merge_options(*options_list):
|
||||
default_options = result_type()
|
||||
result = result_type()
|
||||
for options in options_list:
|
||||
# Iterate over all set options and merge the into the result.
|
||||
# Iterate over all set options and merge them into the result.
|
||||
for name in options._options: # pylint: disable=protected-access
|
||||
this = getattr(result, name)
|
||||
that = getattr(options, name)
|
||||
@ -146,7 +148,7 @@ def merge_options(*options_list):
|
||||
elif isinstance(this, OptionsBase):
|
||||
setattr(result, name, merge_options(this, that))
|
||||
elif this != that:
|
||||
raise ValueError(
|
||||
"Cannot merge incompatible values (%r and %r) of option: %s" %
|
||||
(this, that, name))
|
||||
logging.warning("Changing the value of option %s from %r to %r.", name,
|
||||
this, that)
|
||||
setattr(result, name, that)
|
||||
return result
|
||||
|
Loading…
Reference in New Issue
Block a user