[tf.data] Make the result of tf.data.Dataset.options()
immutable. This change is in preparation of making tf.data.Options()
persistent across tf.function
boundaries.
PiperOrigin-RevId: 361053121 Change-Id: I9b4ab3592f914e2311381d41b1a7bd11c45830aa
This commit is contained in:
parent
dd1ce23e93
commit
2e5eba38fd
@ -38,6 +38,7 @@
|
||||
multiple input batches should be computed in parallel. With
|
||||
`num_parallel_calls` set, `deterministic` is used to indicate that
|
||||
outputs can be obtained in the non-deterministic order.
|
||||
* Options returned by `tf.data.Dataset.options()` are no longer mutable.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
|
@ -418,3 +418,9 @@ class OptimizationOptions(options.OptionsBase):
|
||||
self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops
|
||||
if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
|
||||
self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
|
||||
|
||||
def _set_mutable(self, mutable):
|
||||
"""Change the mutability value to `mutable` on this options and children."""
|
||||
# pylint: disable=protected-access
|
||||
object.__setattr__(self, "_mutable", mutable)
|
||||
self.map_vectorization._set_mutable(mutable)
|
||||
|
@ -111,14 +111,18 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
threading_options.ThreadingOptions())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testMutableOptions(self):
|
||||
def testMutatingOptionsRaiseValueError(self):
|
||||
ds = dataset_ops.Dataset.range(0)
|
||||
ds.options().experimental_optimization.autotune = True
|
||||
self.assertTrue(ds.options().experimental_optimization.autotune)
|
||||
options = dataset_ops.Options()
|
||||
ds = ds.with_options(options)
|
||||
ds.options().experimental_deterministic = True
|
||||
self.assertTrue(ds.options().experimental_deterministic)
|
||||
options1 = dataset_ops.Options()
|
||||
options1.experimental_slack = True
|
||||
options2 = dataset_ops.Options()
|
||||
options2.experimental_optimization.autotune = True
|
||||
ds = ds.with_options(options1)
|
||||
ds = ds.map(lambda x: 2 * x)
|
||||
ds = ds.with_options(options2)
|
||||
with self.assertRaises(ValueError):
|
||||
dataset_options = ds.options()
|
||||
dataset_options.experimental_deterministic = True
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testNestedDataset(self):
|
||||
|
@ -218,6 +218,7 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
|
||||
input_options = input_dataset.options()
|
||||
if input_options is not None:
|
||||
self._options_attr = self._options_attr.merge(input_options)
|
||||
self._options_attr._set_mutable(False) # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def _variant_tensor(self):
|
||||
@ -2990,16 +2991,10 @@ class Options(options_lib.OptionsBase):
|
||||
The options are set for the entire dataset and are carried over to datasets
|
||||
created through tf.data transformations.
|
||||
|
||||
The options can be set either by mutating the object returned by
|
||||
`tf.data.Dataset.options()` or by constructing an `Options` object and using
|
||||
the `tf.data.Dataset.with_options(options)` transformation, which returns a
|
||||
The options can be set by constructing an `Options` object and using the
|
||||
`tf.data.Dataset.with_options(options)` transformation, which returns a
|
||||
dataset with the options set.
|
||||
|
||||
>>> dataset = tf.data.Dataset.range(42)
|
||||
>>> dataset.options().experimental_deterministic = False
|
||||
>>> print(dataset.options().experimental_deterministic)
|
||||
False
|
||||
|
||||
>>> dataset = tf.data.Dataset.range(42)
|
||||
>>> options = tf.data.Options()
|
||||
>>> options.experimental_deterministic = False
|
||||
@ -3099,6 +3094,14 @@ class Options(options_lib.OptionsBase):
|
||||
self.experimental_slack = pb.slack
|
||||
self.experimental_threading._from_proto(pb.threading_options) # pylint: disable=protected-access
|
||||
|
||||
def _set_mutable(self, mutable):
|
||||
"""Change the mutability value to `mutable` on this options and children."""
|
||||
# pylint: disable=protected-access
|
||||
object.__setattr__(self, "_mutable", mutable)
|
||||
self.experimental_distribute._set_mutable(mutable)
|
||||
self.experimental_optimization._set_mutable(mutable)
|
||||
self.experimental_threading._set_mutable(mutable)
|
||||
|
||||
def _graph_rewrites(self):
|
||||
"""Produces lists of enabled, disabled, default static graph rewrites.
|
||||
|
||||
@ -4665,17 +4668,17 @@ class _OptionsDataset(UnaryUnchangedStructureDataset):
|
||||
"""An identity `Dataset` that stores options."""
|
||||
|
||||
def __init__(self, input_dataset, options):
|
||||
# pylint: disable=protected-access
|
||||
self._input_dataset = input_dataset
|
||||
variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access
|
||||
variant_tensor = input_dataset._variant_tensor
|
||||
super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
if self._options_attr:
|
||||
self._options_attr._set_mutable(True)
|
||||
self._options_attr = self._options_attr.merge(options)
|
||||
else:
|
||||
self._options_attr = options
|
||||
|
||||
def options(self):
|
||||
return self._options_attr
|
||||
self._options_attr._set_mutable(False)
|
||||
|
||||
|
||||
class _ModelDataset(UnaryUnchangedStructureDataset):
|
||||
|
@ -37,6 +37,7 @@ class OptionsBase(object):
|
||||
def __init__(self):
|
||||
# NOTE: Cannot use `self._options` here as we override `__setattr__`
|
||||
object.__setattr__(self, "_options", {})
|
||||
object.__setattr__(self, "_mutable", True)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
@ -53,12 +54,19 @@ class OptionsBase(object):
|
||||
return NotImplemented
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if not self._mutable:
|
||||
raise ValueError("Mutating `tf.data.Options()` returned by "
|
||||
"`tf.data.Dataset.options()` has no effect.")
|
||||
if hasattr(self, name):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
raise AttributeError(
|
||||
"Cannot set the property %s on %s." % (name, type(self).__name__))
|
||||
|
||||
def _set_mutable(self, mutable):
|
||||
"""Change the mutability property to `mutable`."""
|
||||
object.__setattr__(self, "_mutable", mutable)
|
||||
|
||||
def _to_proto(self):
|
||||
"""Convert options to protocol buffer."""
|
||||
raise NotImplementedError("%s._to_proto()" % type(self).__name__)
|
||||
|
Loading…
Reference in New Issue
Block a user