From 2e5eba38fd410faf11bce6ba618cd75533f8a55f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Thu, 4 Mar 2021 19:05:21 -0800 Subject: [PATCH] [tf.data] Make the result of `tf.data.Dataset.options()` immutable. This change is in preparation for making `tf.data.Options()` persistent across `tf.function` boundaries. PiperOrigin-RevId: 361053121 Change-Id: I9b4ab3592f914e2311381d41b1a7bd11c45830aa --- RELEASE.md | 1 + .../experimental/ops/optimization_options.py | 6 +++++ .../python/data/kernel_tests/options_test.py | 18 ++++++++----- tensorflow/python/data/ops/dataset_ops.py | 27 ++++++++++--------- tensorflow/python/data/util/options.py | 8 ++++++ 5 files changed, 41 insertions(+), 19 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index d853fa284b7..353bbf81437 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -38,6 +38,7 @@ multiple input batches should be computed in parallel. With `num_parallel_calls` set, `deterministic` is used to indicate that outputs can be obtained in the non-deterministic order. + * Options returned by `tf.data.Dataset.options()` are no longer mutable. 
## Bug Fixes and Other Changes diff --git a/tensorflow/python/data/experimental/ops/optimization_options.py b/tensorflow/python/data/experimental/ops/optimization_options.py index 5e399fde217..70fdc573325 100644 --- a/tensorflow/python/data/experimental/ops/optimization_options.py +++ b/tensorflow/python/data/experimental/ops/optimization_options.py @@ -418,3 +418,9 @@ class OptimizationOptions(options.OptionsBase): self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None: self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion + + def _set_mutable(self, mutable): + """Change the mutability value to `mutable` on this options and children.""" + # pylint: disable=protected-access + object.__setattr__(self, "_mutable", mutable) + self.map_vectorization._set_mutable(mutable) diff --git a/tensorflow/python/data/kernel_tests/options_test.py b/tensorflow/python/data/kernel_tests/options_test.py index efd3f598a1f..a0572dd7004 100644 --- a/tensorflow/python/data/kernel_tests/options_test.py +++ b/tensorflow/python/data/kernel_tests/options_test.py @@ -111,14 +111,18 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase): threading_options.ThreadingOptions()) @combinations.generate(test_base.default_test_combinations()) - def testMutableOptions(self): + def testMutatingOptionsRaiseValueError(self): ds = dataset_ops.Dataset.range(0) - ds.options().experimental_optimization.autotune = True - self.assertTrue(ds.options().experimental_optimization.autotune) - options = dataset_ops.Options() - ds = ds.with_options(options) - ds.options().experimental_deterministic = True - self.assertTrue(ds.options().experimental_deterministic) + options1 = dataset_ops.Options() + options1.experimental_slack = True + options2 = dataset_ops.Options() + options2.experimental_optimization.autotune = True + ds = ds.with_options(options1) + ds = ds.map(lambda x: 2 * x) + ds = ds.with_options(options2) + 
with self.assertRaises(ValueError): + dataset_options = ds.options() + dataset_options.experimental_deterministic = True @combinations.generate(test_base.eager_only_combinations()) def testNestedDataset(self): diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index a26f8a2cdf9..fcfb55c2112 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -218,6 +218,7 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable, input_options = input_dataset.options() if input_options is not None: self._options_attr = self._options_attr.merge(input_options) + self._options_attr._set_mutable(False) # pylint: disable=protected-access @property def _variant_tensor(self): @@ -2990,16 +2991,10 @@ class Options(options_lib.OptionsBase): The options are set for the entire dataset and are carried over to datasets created through tf.data transformations. - The options can be set either by mutating the object returned by - `tf.data.Dataset.options()` or by constructing an `Options` object and using - the `tf.data.Dataset.with_options(options)` transformation, which returns a + The options can be set by constructing an `Options` object and using the + `tf.data.Dataset.with_options(options)` transformation, which returns a dataset with the options set. 
- >>> dataset = tf.data.Dataset.range(42) - >>> dataset.options().experimental_deterministic = False - >>> print(dataset.options().experimental_deterministic) - False - >>> dataset = tf.data.Dataset.range(42) >>> options = tf.data.Options() >>> options.experimental_deterministic = False @@ -3099,6 +3094,14 @@ class Options(options_lib.OptionsBase): self.experimental_slack = pb.slack self.experimental_threading._from_proto(pb.threading_options) # pylint: disable=protected-access + def _set_mutable(self, mutable): + """Change the mutability value to `mutable` on this options and children.""" + # pylint: disable=protected-access + object.__setattr__(self, "_mutable", mutable) + self.experimental_distribute._set_mutable(mutable) + self.experimental_optimization._set_mutable(mutable) + self.experimental_threading._set_mutable(mutable) + def _graph_rewrites(self): """Produces lists of enabled, disabled, default static graph rewrites. @@ -4665,17 +4668,17 @@ class _OptionsDataset(UnaryUnchangedStructureDataset): """An identity `Dataset` that stores options.""" def __init__(self, input_dataset, options): + # pylint: disable=protected-access self._input_dataset = input_dataset - variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access + variant_tensor = input_dataset._variant_tensor super(_OptionsDataset, self).__init__(input_dataset, variant_tensor) if self._options_attr: + self._options_attr._set_mutable(True) self._options_attr = self._options_attr.merge(options) else: self._options_attr = options - - def options(self): - return self._options_attr + self._options_attr._set_mutable(False) class _ModelDataset(UnaryUnchangedStructureDataset): diff --git a/tensorflow/python/data/util/options.py b/tensorflow/python/data/util/options.py index 3df6f000bb6..20e4c625ba2 100644 --- a/tensorflow/python/data/util/options.py +++ b/tensorflow/python/data/util/options.py @@ -37,6 +37,7 @@ class OptionsBase(object): def __init__(self): # NOTE: Cannot use 
`self._options` here as we override `__setattr__` object.__setattr__(self, "_options", {}) + object.__setattr__(self, "_mutable", True) def __eq__(self, other): if not isinstance(other, self.__class__): @@ -53,12 +54,19 @@ class OptionsBase(object): return NotImplemented def __setattr__(self, name, value): + if not self._mutable: + raise ValueError("Mutating `tf.data.Options()` returned by " + "`tf.data.Dataset.options()` has no effect.") if hasattr(self, name): object.__setattr__(self, name, value) else: raise AttributeError( "Cannot set the property %s on %s." % (name, type(self).__name__)) + def _set_mutable(self, mutable): + """Change the mutability property to `mutable`.""" + object.__setattr__(self, "_mutable", mutable) + def _to_proto(self): """Convert options to protocol buffer.""" raise NotImplementedError("%s._to_proto()" % type(self).__name__)