[tf.data] Making it possible to override tf.data options.

PiperOrigin-RevId: 326680778
Change-Id: Ia41fb00680240d3e1488fc0165647e81e5837d6c
This commit is contained in:
Jiri Simsa 2020-08-14 10:23:33 -07:00 committed by TensorFlower Gardener
parent 57e69437b4
commit 505a1599c3
4 changed files with 70 additions and 39 deletions

View File

@ -88,6 +88,7 @@
dataset when it is safe to do so. The optimization can be disabled via
the `experimental_optimization.reorder_data_discarding_ops` dataset
option.
* `tf.data.Options` were previously immutable and can now be overridden.
* `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Added a new op

View File

@ -51,25 +51,28 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(options, ds.options())
@combinations.generate(test_base.default_test_combinations())
def testOptionsTwiceDifferent(self):
def testOptionsTwiceDifferentOptions(self):
options1 = dataset_ops.Options()
options1.experimental_optimization.autotune = True
options2 = dataset_ops.Options()
options2.experimental_deterministic = False
ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
options2)
ds = dataset_ops.Dataset.range(0)
ds = ds.with_options(options1)
ds = ds.with_options(options2)
self.assertTrue(ds.options().experimental_optimization.autotune)
# Explicitly check that flag is False since assertFalse allows None
self.assertIs(ds.options().experimental_deterministic, False)
@combinations.generate(test_base.default_test_combinations())
def testOptionsTwiceDifferentError(self):
def testOptionsTwiceSameOption(self):
options1 = dataset_ops.Options()
options1.experimental_optimization.autotune = True
options1.experimental_optimization.autotune = False
options2 = dataset_ops.Options()
options2.experimental_optimization.autotune = False
with self.assertRaisesRegex(ValueError, "Cannot merge incompatible values"):
dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)
options2.experimental_optimization.autotune = True
ds = dataset_ops.Dataset.range(0)
ds = ds.with_options(options1)
ds = ds.with_options(options2)
self.assertTrue(ds.options().experimental_optimization.autotune)
@combinations.generate(test_base.default_test_combinations())
def testOptionsMergeOptionsFromMultipleInputs(self):
@ -77,9 +80,9 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
options1.experimental_optimization.autotune = True
options2 = dataset_ops.Options()
options2.experimental_deterministic = True
ds = dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(0).with_options(options1),
dataset_ops.Dataset.range(0).with_options(options2)))
ds1 = dataset_ops.Dataset.range(0).with_options(options1)
ds2 = dataset_ops.Dataset.range(0).with_options(options2)
ds = dataset_ops.Dataset.zip((ds1, ds2))
self.assertTrue(ds.options().experimental_optimization.autotune)
self.assertTrue(ds.options().experimental_deterministic)
@ -99,6 +102,16 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(options1.experimental_threading,
threading_options.ThreadingOptions())
@combinations.generate(test_base.default_test_combinations())
def testMutableOptions(self):
ds = dataset_ops.Dataset.range(0)
ds.options().experimental_optimization.autotune = True
self.assertTrue(ds.options().experimental_optimization.autotune)
options = dataset_ops.Options()
ds = ds.with_options(options)
ds.options().experimental_deterministic = True
self.assertTrue(ds.options().experimental_deterministic)
@combinations.generate(test_base.eager_only_combinations())
def testNestedDataset(self):
ds = dataset_ops.Dataset.from_tensors(0)

View File

@ -2836,20 +2836,37 @@ def get_legacy_output_types(dataset_or_iterator):
@tf_export("data.Options")
class Options(options_lib.OptionsBase):
"""Represents options for tf.data.Dataset.
"""Represents options for `tf.data.Dataset`.
An `Options` object can be, for instance, used to control which graph
optimizations to apply or whether to use performance modeling to dynamically
tune the parallelism of operations such as `tf.data.Dataset.map` or
`tf.data.Dataset.interleave`.
A `tf.data.Options` object can be, for instance, used to control which static
optimizations to apply to the input pipeline graph or whether to use
performance modeling to dynamically tune the parallelism of operations such as
`tf.data.Dataset.map` or `tf.data.Dataset.interleave`.
After constructing an `Options` object, use `dataset.with_options(options)` to
apply the options to a dataset.
The options are set for the entire dataset and are carried over to datasets
created through tf.data transformations.
>>> dataset = tf.data.Dataset.range(3)
The options can be set either by mutating the object returned by
`tf.data.Dataset.options()` or by constructing an `Options` object and using
the `tf.data.Dataset.with_options(options)` transformation, which returns a
dataset with the options set.
>>> dataset = tf.data.Dataset.range(42)
>>> dataset.options().experimental_deterministic = False
>>> print(dataset.options().experimental_deterministic)
False
>>> dataset = tf.data.Dataset.range(42)
>>> options = tf.data.Options()
>>> # Set options here.
>>> options.experimental_deterministic = False
>>> dataset = dataset.with_options(options)
>>> print(dataset.options().experimental_deterministic)
False
Note: A known limitation of the `tf.data.Options` implementation is that the
options are not preserved across tf.function boundaries. In particular, to
set options for a dataset that is iterated within a tf.function, the options
need to be set within the same tf.function.
"""
experimental_deterministic = options_lib.create_option(
@ -2968,17 +2985,15 @@ class Options(options_lib.OptionsBase):
def merge(self, options):
"""Merges itself with the given `tf.data.Options`.
The given `tf.data.Options` can be merged as long as there does not exist an
attribute that is set to different values in `self` and `options`.
If this object and the `options` to merge set an option differently, a
warning is logged and the returned object uses the value from the `options`
argument.
Args:
options: a `tf.data.Options` to merge with
Raises:
ValueError: if the given `tf.data.Options` cannot be merged
Returns:
New `tf.data.Options()` object which is the result of merging self with
New `tf.data.Options` object which is the result of merging self with
the input `tf.data.Options`.
"""
return options_lib.merge_options(self, options)

View File

@ -20,6 +20,8 @@ from __future__ import print_function
import collections
from absl import logging
def _internal_attr_name(name):
return "_" + name
@ -98,23 +100,23 @@ def merge_options(*options_list):
"""Merges the given options, returning the result as a new options object.
The input arguments are expected to have a matching type that derives from
`OptionsBase` (and thus each represent a set of options). The method outputs
an object of the same type created by merging the sets of options represented
by the input arguments.
`tf.data.OptionsBase` (and thus each represent a set of options). The method
outputs an object of the same type created by merging the sets of options
represented by the input arguments.
The sets of options can be merged as long as there does not exist an option
with different non-default values.
If an option is set to different values by different options objects, the
result will match the setting of the options object that appears in the input
list last.
If an option is an instance of `OptionsBase` itself, then this method is
applied recursively to the set of options represented by this option.
If an option is an instance of `tf.data.OptionsBase` itself, then this method
is applied recursively to the set of options represented by this option.
Args:
*options_list: options to merge
Raises:
TypeError: if the input arguments are incompatible or not derived from
`OptionsBase`
ValueError: if the given options cannot be merged
`tf.data.OptionsBase`
Returns:
A new options object which is the result of merging the given options.
@ -134,7 +136,7 @@ def merge_options(*options_list):
default_options = result_type()
result = result_type()
for options in options_list:
# Iterate over all set options and merge the into the result.
# Iterate over all set options and merge them into the result.
for name in options._options: # pylint: disable=protected-access
this = getattr(result, name)
that = getattr(options, name)
@ -146,7 +148,7 @@ def merge_options(*options_list):
elif isinstance(this, OptionsBase):
setattr(result, name, merge_options(this, that))
elif this != that:
raise ValueError(
"Cannot merge incompatible values (%r and %r) of option: %s" %
(this, that, name))
logging.warning("Changing the value of option %s from %r to %r.", name,
this, that)
setattr(result, name, that)
return result