[tf.data] Make the result of tf.data.Dataset.options() immutable. This change is in preparation of making tf.data.Options() persistent across tf.function boundaries.

PiperOrigin-RevId: 361053121
Change-Id: I9b4ab3592f914e2311381d41b1a7bd11c45830aa
This commit is contained in:
A. Unique TensorFlower 2021-03-04 19:05:21 -08:00 committed by TensorFlower Gardener
parent dd1ce23e93
commit 2e5eba38fd
5 changed files with 41 additions and 19 deletions

View File

@ -38,6 +38,7 @@
multiple input batches should be computed in parallel. With
`num_parallel_calls` set, `deterministic` is used to indicate that
outputs can be obtained in the non-deterministic order.
* Options returned by `tf.data.Dataset.options()` are no longer mutable.
## Bug Fixes and Other Changes

View File

@ -418,3 +418,9 @@ class OptimizationOptions(options.OptionsBase):
self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops
if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
def _set_mutable(self, mutable):
"""Change the mutability value to `mutable` on this options and children."""
# pylint: disable=protected-access
object.__setattr__(self, "_mutable", mutable)
self.map_vectorization._set_mutable(mutable)

View File

@ -111,14 +111,18 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
threading_options.ThreadingOptions())
@combinations.generate(test_base.default_test_combinations())
def testMutableOptions(self):
def testMutatingOptionsRaiseValueError(self):
ds = dataset_ops.Dataset.range(0)
ds.options().experimental_optimization.autotune = True
self.assertTrue(ds.options().experimental_optimization.autotune)
options = dataset_ops.Options()
ds = ds.with_options(options)
ds.options().experimental_deterministic = True
self.assertTrue(ds.options().experimental_deterministic)
options1 = dataset_ops.Options()
options1.experimental_slack = True
options2 = dataset_ops.Options()
options2.experimental_optimization.autotune = True
ds = ds.with_options(options1)
ds = ds.map(lambda x: 2 * x)
ds = ds.with_options(options2)
with self.assertRaises(ValueError):
dataset_options = ds.options()
dataset_options.experimental_deterministic = True
@combinations.generate(test_base.eager_only_combinations())
def testNestedDataset(self):

View File

@ -218,6 +218,7 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
input_options = input_dataset.options()
if input_options is not None:
self._options_attr = self._options_attr.merge(input_options)
self._options_attr._set_mutable(False) # pylint: disable=protected-access
@property
def _variant_tensor(self):
@ -2990,16 +2991,10 @@ class Options(options_lib.OptionsBase):
The options are set for the entire dataset and are carried over to datasets
created through tf.data transformations.
The options can be set either by mutating the object returned by
`tf.data.Dataset.options()` or by constructing an `Options` object and using
the `tf.data.Dataset.with_options(options)` transformation, which returns a
The options can be set by constructing an `Options` object and using the
`tf.data.Dataset.with_options(options)` transformation, which returns a
dataset with the options set.
>>> dataset = tf.data.Dataset.range(42)
>>> dataset.options().experimental_deterministic = False
>>> print(dataset.options().experimental_deterministic)
False
>>> dataset = tf.data.Dataset.range(42)
>>> options = tf.data.Options()
>>> options.experimental_deterministic = False
@ -3099,6 +3094,14 @@ class Options(options_lib.OptionsBase):
self.experimental_slack = pb.slack
self.experimental_threading._from_proto(pb.threading_options) # pylint: disable=protected-access
def _set_mutable(self, mutable):
"""Change the mutability value to `mutable` on this options and children."""
# pylint: disable=protected-access
object.__setattr__(self, "_mutable", mutable)
self.experimental_distribute._set_mutable(mutable)
self.experimental_optimization._set_mutable(mutable)
self.experimental_threading._set_mutable(mutable)
def _graph_rewrites(self):
"""Produces lists of enabled, disabled, default static graph rewrites.
@ -4665,17 +4668,17 @@ class _OptionsDataset(UnaryUnchangedStructureDataset):
"""An identity `Dataset` that stores options."""
def __init__(self, input_dataset, options):
# pylint: disable=protected-access
self._input_dataset = input_dataset
variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access
variant_tensor = input_dataset._variant_tensor
super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
if self._options_attr:
self._options_attr._set_mutable(True)
self._options_attr = self._options_attr.merge(options)
else:
self._options_attr = options
def options(self):
return self._options_attr
self._options_attr._set_mutable(False)
class _ModelDataset(UnaryUnchangedStructureDataset):

View File

@ -37,6 +37,7 @@ class OptionsBase(object):
def __init__(self):
# NOTE: Cannot use `self._options` here as we override `__setattr__`
object.__setattr__(self, "_options", {})
object.__setattr__(self, "_mutable", True)
def __eq__(self, other):
if not isinstance(other, self.__class__):
@ -53,12 +54,19 @@ class OptionsBase(object):
return NotImplemented
def __setattr__(self, name, value):
if not self._mutable:
raise ValueError("Mutating `tf.data.Options()` returned by "
"`tf.data.Dataset.options()` has no effect.")
if hasattr(self, name):
object.__setattr__(self, name, value)
else:
raise AttributeError(
"Cannot set the property %s on %s." % (name, type(self).__name__))
def _set_mutable(self, mutable):
"""Change the mutability property to `mutable`."""
object.__setattr__(self, "_mutable", mutable)
def _to_proto(self):
"""Convert options to protocol buffer."""
raise NotImplementedError("%s._to_proto()" % type(self).__name__)