[tf.data] Make the result of tf.data.Dataset.options() immutable. This change is in preparation for making tf.data.Options() persistent across tf.function boundaries.

PiperOrigin-RevId: 361053121
Change-Id: I9b4ab3592f914e2311381d41b1a7bd11c45830aa
This commit is contained in:
A. Unique TensorFlower 2021-03-04 19:05:21 -08:00 committed by TensorFlower Gardener
parent dd1ce23e93
commit 2e5eba38fd
5 changed files with 41 additions and 19 deletions

View File

@ -38,6 +38,7 @@
multiple input batches should be computed in parallel. With multiple input batches should be computed in parallel. With
`num_parallel_calls` set, `deterministic` is used to indicate that `num_parallel_calls` set, `deterministic` is used to indicate that
outputs can be obtained in the non-deterministic order. outputs can be obtained in the non-deterministic order.
* Options returned by `tf.data.Dataset.options()` are no longer mutable.
## Bug Fixes and Other Changes ## Bug Fixes and Other Changes

View File

@ -418,3 +418,9 @@ class OptimizationOptions(options.OptionsBase):
self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops
if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None: if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
def _set_mutable(self, mutable):
"""Change the mutability value to `mutable` on this options and children."""
# pylint: disable=protected-access
object.__setattr__(self, "_mutable", mutable)
self.map_vectorization._set_mutable(mutable)

View File

@ -111,14 +111,18 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
threading_options.ThreadingOptions()) threading_options.ThreadingOptions())
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
def testMutableOptions(self): def testMutatingOptionsRaiseValueError(self):
ds = dataset_ops.Dataset.range(0) ds = dataset_ops.Dataset.range(0)
ds.options().experimental_optimization.autotune = True options1 = dataset_ops.Options()
self.assertTrue(ds.options().experimental_optimization.autotune) options1.experimental_slack = True
options = dataset_ops.Options() options2 = dataset_ops.Options()
ds = ds.with_options(options) options2.experimental_optimization.autotune = True
ds.options().experimental_deterministic = True ds = ds.with_options(options1)
self.assertTrue(ds.options().experimental_deterministic) ds = ds.map(lambda x: 2 * x)
ds = ds.with_options(options2)
with self.assertRaises(ValueError):
dataset_options = ds.options()
dataset_options.experimental_deterministic = True
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testNestedDataset(self): def testNestedDataset(self):

View File

@ -218,6 +218,7 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
input_options = input_dataset.options() input_options = input_dataset.options()
if input_options is not None: if input_options is not None:
self._options_attr = self._options_attr.merge(input_options) self._options_attr = self._options_attr.merge(input_options)
self._options_attr._set_mutable(False) # pylint: disable=protected-access
@property @property
def _variant_tensor(self): def _variant_tensor(self):
@ -2990,16 +2991,10 @@ class Options(options_lib.OptionsBase):
The options are set for the entire dataset and are carried over to datasets The options are set for the entire dataset and are carried over to datasets
created through tf.data transformations. created through tf.data transformations.
The options can be set either by mutating the object returned by The options can be set by constructing an `Options` object and using the
`tf.data.Dataset.options()` or by constructing an `Options` object and using `tf.data.Dataset.with_options(options)` transformation, which returns a
the `tf.data.Dataset.with_options(options)` transformation, which returns a
dataset with the options set. dataset with the options set.
>>> dataset = tf.data.Dataset.range(42)
>>> dataset.options().experimental_deterministic = False
>>> print(dataset.options().experimental_deterministic)
False
>>> dataset = tf.data.Dataset.range(42) >>> dataset = tf.data.Dataset.range(42)
>>> options = tf.data.Options() >>> options = tf.data.Options()
>>> options.experimental_deterministic = False >>> options.experimental_deterministic = False
@ -3099,6 +3094,14 @@ class Options(options_lib.OptionsBase):
self.experimental_slack = pb.slack self.experimental_slack = pb.slack
self.experimental_threading._from_proto(pb.threading_options) # pylint: disable=protected-access self.experimental_threading._from_proto(pb.threading_options) # pylint: disable=protected-access
def _set_mutable(self, mutable):
"""Change the mutability value to `mutable` on this options and children."""
# pylint: disable=protected-access
object.__setattr__(self, "_mutable", mutable)
self.experimental_distribute._set_mutable(mutable)
self.experimental_optimization._set_mutable(mutable)
self.experimental_threading._set_mutable(mutable)
def _graph_rewrites(self): def _graph_rewrites(self):
"""Produces lists of enabled, disabled, default static graph rewrites. """Produces lists of enabled, disabled, default static graph rewrites.
@ -4665,17 +4668,17 @@ class _OptionsDataset(UnaryUnchangedStructureDataset):
"""An identity `Dataset` that stores options.""" """An identity `Dataset` that stores options."""
def __init__(self, input_dataset, options): def __init__(self, input_dataset, options):
# pylint: disable=protected-access
self._input_dataset = input_dataset self._input_dataset = input_dataset
variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access variant_tensor = input_dataset._variant_tensor
super(_OptionsDataset, self).__init__(input_dataset, variant_tensor) super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
if self._options_attr: if self._options_attr:
self._options_attr._set_mutable(True)
self._options_attr = self._options_attr.merge(options) self._options_attr = self._options_attr.merge(options)
else: else:
self._options_attr = options self._options_attr = options
self._options_attr._set_mutable(False)
def options(self):
return self._options_attr
class _ModelDataset(UnaryUnchangedStructureDataset): class _ModelDataset(UnaryUnchangedStructureDataset):

View File

@ -37,6 +37,7 @@ class OptionsBase(object):
def __init__(self): def __init__(self):
# NOTE: Cannot use `self._options` here as we override `__setattr__` # NOTE: Cannot use `self._options` here as we override `__setattr__`
object.__setattr__(self, "_options", {}) object.__setattr__(self, "_options", {})
object.__setattr__(self, "_mutable", True)
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
@ -53,12 +54,19 @@ class OptionsBase(object):
return NotImplemented return NotImplemented
def __setattr__(self, name, value): def __setattr__(self, name, value):
if not self._mutable:
raise ValueError("Mutating `tf.data.Options()` returned by "
"`tf.data.Dataset.options()` has no effect.")
if hasattr(self, name): if hasattr(self, name):
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
else: else:
raise AttributeError( raise AttributeError(
"Cannot set the property %s on %s." % (name, type(self).__name__)) "Cannot set the property %s on %s." % (name, type(self).__name__))
def _set_mutable(self, mutable):
"""Change the mutability property to `mutable`."""
object.__setattr__(self, "_mutable", mutable)
def _to_proto(self): def _to_proto(self):
"""Convert options to protocol buffer.""" """Convert options to protocol buffer."""
raise NotImplementedError("%s._to_proto()" % type(self).__name__) raise NotImplementedError("%s._to_proto()" % type(self).__name__)