[tf.data] Making it possible to override tf.data options.

PiperOrigin-RevId: 326680778
Change-Id: Ia41fb00680240d3e1488fc0165647e81e5837d6c
This commit is contained in:
Jiri Simsa 2020-08-14 10:23:33 -07:00 committed by TensorFlower Gardener
parent 57e69437b4
commit 505a1599c3
4 changed files with 70 additions and 39 deletions

View File

@ -88,6 +88,7 @@
dataset when it is safe to do so. The optimization can be disabled via dataset when it is safe to do so. The optimization can be disabled via
the `experimental_optimization.reorder_data_discarding_ops` dataset the `experimental_optimization.reorder_data_discarding_ops` dataset
option. option.
* `tf.data.Options` were previously immutable and can now be overridden.
* `tf.image`: * `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each * Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Added a new op `tf.image.random_*` function. Added a new op

View File

@ -51,25 +51,28 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(options, ds.options()) self.assertEqual(options, ds.options())
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
def testOptionsTwiceDifferent(self): def testOptionsTwiceDifferentOptions(self):
options1 = dataset_ops.Options() options1 = dataset_ops.Options()
options1.experimental_optimization.autotune = True options1.experimental_optimization.autotune = True
options2 = dataset_ops.Options() options2 = dataset_ops.Options()
options2.experimental_deterministic = False options2.experimental_deterministic = False
ds = dataset_ops.Dataset.range(0).with_options(options1).with_options( ds = dataset_ops.Dataset.range(0)
options2) ds = ds.with_options(options1)
ds = ds.with_options(options2)
self.assertTrue(ds.options().experimental_optimization.autotune) self.assertTrue(ds.options().experimental_optimization.autotune)
# Explicitly check that flag is False since assertFalse allows None # Explicitly check that flag is False since assertFalse allows None
self.assertIs(ds.options().experimental_deterministic, False) self.assertIs(ds.options().experimental_deterministic, False)
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
def testOptionsTwiceDifferentError(self): def testOptionsTwiceSameOption(self):
options1 = dataset_ops.Options() options1 = dataset_ops.Options()
options1.experimental_optimization.autotune = True options1.experimental_optimization.autotune = False
options2 = dataset_ops.Options() options2 = dataset_ops.Options()
options2.experimental_optimization.autotune = False options2.experimental_optimization.autotune = True
with self.assertRaisesRegex(ValueError, "Cannot merge incompatible values"): ds = dataset_ops.Dataset.range(0)
dataset_ops.Dataset.range(0).with_options(options1).with_options(options2) ds = ds.with_options(options1)
ds = ds.with_options(options2)
self.assertTrue(ds.options().experimental_optimization.autotune)
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
def testOptionsMergeOptionsFromMultipleInputs(self): def testOptionsMergeOptionsFromMultipleInputs(self):
@ -77,9 +80,9 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
options1.experimental_optimization.autotune = True options1.experimental_optimization.autotune = True
options2 = dataset_ops.Options() options2 = dataset_ops.Options()
options2.experimental_deterministic = True options2.experimental_deterministic = True
ds = dataset_ops.Dataset.zip( ds1 = dataset_ops.Dataset.range(0).with_options(options1)
(dataset_ops.Dataset.range(0).with_options(options1), ds2 = dataset_ops.Dataset.range(0).with_options(options2)
dataset_ops.Dataset.range(0).with_options(options2))) ds = dataset_ops.Dataset.zip((ds1, ds2))
self.assertTrue(ds.options().experimental_optimization.autotune) self.assertTrue(ds.options().experimental_optimization.autotune)
self.assertTrue(ds.options().experimental_deterministic) self.assertTrue(ds.options().experimental_deterministic)
@ -99,6 +102,16 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(options1.experimental_threading, self.assertEqual(options1.experimental_threading,
threading_options.ThreadingOptions()) threading_options.ThreadingOptions())
@combinations.generate(test_base.default_test_combinations())
def testMutableOptions(self):
ds = dataset_ops.Dataset.range(0)
ds.options().experimental_optimization.autotune = True
self.assertTrue(ds.options().experimental_optimization.autotune)
options = dataset_ops.Options()
ds = ds.with_options(options)
ds.options().experimental_deterministic = True
self.assertTrue(ds.options().experimental_deterministic)
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testNestedDataset(self): def testNestedDataset(self):
ds = dataset_ops.Dataset.from_tensors(0) ds = dataset_ops.Dataset.from_tensors(0)

View File

@ -2836,20 +2836,37 @@ def get_legacy_output_types(dataset_or_iterator):
@tf_export("data.Options") @tf_export("data.Options")
class Options(options_lib.OptionsBase): class Options(options_lib.OptionsBase):
"""Represents options for tf.data.Dataset. """Represents options for `tf.data.Dataset`.
An `Options` object can be, for instance, used to control which graph A `tf.data.Options` object can be, for instance, used to control which static
optimizations to apply or whether to use performance modeling to dynamically optimizations to apply to the input pipeline graph or whether to use
tune the parallelism of operations such as `tf.data.Dataset.map` or performance modeling to dynamically tune the parallelism of operations such as
`tf.data.Dataset.interleave`. `tf.data.Dataset.map` or `tf.data.Dataset.interleave`.
After constructing an `Options` object, use `dataset.with_options(options)` to The options are set for the entire dataset and are carried over to datasets
apply the options to a dataset. created through tf.data transformations.
>>> dataset = tf.data.Dataset.range(3) The options can be set either by mutating the object returned by
`tf.data.Dataset.options()` or by constructing an `Options` object and using
the `tf.data.Dataset.with_options(options)` transformation, which returns a
dataset with the options set.
>>> dataset = tf.data.Dataset.range(42)
>>> dataset.options().experimental_deterministic = False
>>> print(dataset.options().experimental_deterministic)
False
>>> dataset = tf.data.Dataset.range(42)
>>> options = tf.data.Options() >>> options = tf.data.Options()
>>> # Set options here. >>> options.experimental_deterministic = False
>>> dataset = dataset.with_options(options) >>> dataset = dataset.with_options(options)
>>> print(dataset.options().experimental_deterministic)
False
Note: A known limitation of the `tf.data.Options` implementation is that the
options are not preserved across tf.function boundaries. In particular, to
set options for a dataset that is iterated within a tf.function, the options
need to be set within the same tf.function.
""" """
experimental_deterministic = options_lib.create_option( experimental_deterministic = options_lib.create_option(
@ -2968,17 +2985,15 @@ class Options(options_lib.OptionsBase):
def merge(self, options): def merge(self, options):
"""Merges itself with the given `tf.data.Options`. """Merges itself with the given `tf.data.Options`.
The given `tf.data.Options` can be merged as long as there does not exist an If this object and the `options` to merge set an option differently, a
attribute that is set to different values in `self` and `options`. warning is generated and this object's value is updated with the `options`
object's value.
Args: Args:
options: a `tf.data.Options` to merge with options: a `tf.data.Options` to merge with
Raises:
ValueError: if the given `tf.data.Options` cannot be merged
Returns: Returns:
New `tf.data.Options()` object which is the result of merging self with New `tf.data.Options` object which is the result of merging self with
the input `tf.data.Options`. the input `tf.data.Options`.
""" """
return options_lib.merge_options(self, options) return options_lib.merge_options(self, options)

View File

@ -20,6 +20,8 @@ from __future__ import print_function
import collections import collections
from absl import logging
def _internal_attr_name(name): def _internal_attr_name(name):
return "_" + name return "_" + name
@ -98,23 +100,23 @@ def merge_options(*options_list):
"""Merges the given options, returning the result as a new options object. """Merges the given options, returning the result as a new options object.
The input arguments are expected to have a matching type that derives from The input arguments are expected to have a matching type that derives from
`OptionsBase` (and thus each represent a set of options). The method outputs `tf.data.OptionsBase` (and thus each represent a set of options). The method
an object of the same type created by merging the sets of options represented outputs an object of the same type created by merging the sets of options
by the input arguments. represented by the input arguments.
The sets of options can be merged as long as there does not exist an option If an option is set to different values by different options objects, the
with different non-default values. result will match the setting of the options object that appears in the input
list last.
If an option is an instance of `OptionsBase` itself, then this method is If an option is an instance of `tf.data.OptionsBase` itself, then this method
applied recursively to the set of options represented by this option. is applied recursively to the set of options represented by this option.
Args: Args:
*options_list: options to merge *options_list: options to merge
Raises: Raises:
TypeError: if the input arguments are incompatible or not derived from TypeError: if the input arguments are incompatible or not derived from
`OptionsBase` `tf.data.OptionsBase`
ValueError: if the given options cannot be merged
Returns: Returns:
A new options object which is the result of merging the given options. A new options object which is the result of merging the given options.
@ -134,7 +136,7 @@ def merge_options(*options_list):
default_options = result_type() default_options = result_type()
result = result_type() result = result_type()
for options in options_list: for options in options_list:
# Iterate over all set options and merge the into the result. # Iterate over all set options and merge them into the result.
for name in options._options: # pylint: disable=protected-access for name in options._options: # pylint: disable=protected-access
this = getattr(result, name) this = getattr(result, name)
that = getattr(options, name) that = getattr(options, name)
@ -146,7 +148,7 @@ def merge_options(*options_list):
elif isinstance(this, OptionsBase): elif isinstance(this, OptionsBase):
setattr(result, name, merge_options(this, that)) setattr(result, name, merge_options(this, that))
elif this != that: elif this != that:
raise ValueError( logging.warning("Changing the value of option %s from %r to %r.", name,
"Cannot merge incompatible values (%r and %r) of option: %s" % this, that)
(this, that, name)) setattr(result, name, that)
return result return result