From 2e5eba38fd410faf11bce6ba618cd75533f8a55f Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 4 Mar 2021 19:05:21 -0800
Subject: [PATCH] [tf.data] Make the result of `tf.data.Dataset.options()`
 immutable. This change is in preparation of making `tf.data.Options()`
 persistent across `tf.function` boundaries.

PiperOrigin-RevId: 361053121
Change-Id: I9b4ab3592f914e2311381d41b1a7bd11c45830aa
---
 RELEASE.md                                    |  1 +
 .../experimental/ops/optimization_options.py  |  6 +++++
 .../python/data/kernel_tests/options_test.py  | 18 ++++++++-----
 tensorflow/python/data/ops/dataset_ops.py     | 27 ++++++++++---------
 tensorflow/python/data/util/options.py        |  8 ++++++
 5 files changed, 41 insertions(+), 19 deletions(-)

diff --git a/RELEASE.md b/RELEASE.md
index d853fa284b7..353bbf81437 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -38,6 +38,7 @@
         multiple input batches should be computed in parallel. With
         `num_parallel_calls` set, `deterministic` is used to indicate that
         outputs can be obtained in the non-deterministic order.
+    *   Options returned by `tf.data.Dataset.options()` are no longer mutable.
 
 ## Bug Fixes and Other Changes
 
diff --git a/tensorflow/python/data/experimental/ops/optimization_options.py b/tensorflow/python/data/experimental/ops/optimization_options.py
index 5e399fde217..70fdc573325 100644
--- a/tensorflow/python/data/experimental/ops/optimization_options.py
+++ b/tensorflow/python/data/experimental/ops/optimization_options.py
@@ -418,3 +418,9 @@ class OptimizationOptions(options.OptionsBase):
       self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops
     if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
       self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
+
+  def _set_mutable(self, mutable):
+    """Change the mutability value to `mutable` on this options and children."""
+    # pylint: disable=protected-access
+    object.__setattr__(self, "_mutable", mutable)
+    self.map_vectorization._set_mutable(mutable)
diff --git a/tensorflow/python/data/kernel_tests/options_test.py b/tensorflow/python/data/kernel_tests/options_test.py
index efd3f598a1f..a0572dd7004 100644
--- a/tensorflow/python/data/kernel_tests/options_test.py
+++ b/tensorflow/python/data/kernel_tests/options_test.py
@@ -111,14 +111,18 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
                      threading_options.ThreadingOptions())
 
   @combinations.generate(test_base.default_test_combinations())
-  def testMutableOptions(self):
+  def testMutatingOptionsRaiseValueError(self):
     ds = dataset_ops.Dataset.range(0)
-    ds.options().experimental_optimization.autotune = True
-    self.assertTrue(ds.options().experimental_optimization.autotune)
-    options = dataset_ops.Options()
-    ds = ds.with_options(options)
-    ds.options().experimental_deterministic = True
-    self.assertTrue(ds.options().experimental_deterministic)
+    options1 = dataset_ops.Options()
+    options1.experimental_slack = True
+    options2 = dataset_ops.Options()
+    options2.experimental_optimization.autotune = True
+    ds = ds.with_options(options1)
+    ds = ds.map(lambda x: 2 * x)
+    ds = ds.with_options(options2)
+    with self.assertRaises(ValueError):
+      dataset_options = ds.options()
+      dataset_options.experimental_deterministic = True
 
   @combinations.generate(test_base.eager_only_combinations())
   def testNestedDataset(self):
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index a26f8a2cdf9..fcfb55c2112 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -218,6 +218,7 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
       input_options = input_dataset.options()
       if input_options is not None:
         self._options_attr = self._options_attr.merge(input_options)
+    self._options_attr._set_mutable(False)  # pylint: disable=protected-access
 
   @property
   def _variant_tensor(self):
@@ -2990,16 +2991,10 @@ class Options(options_lib.OptionsBase):
   The options are set for the entire dataset and are carried over to datasets
   created through tf.data transformations.
 
-  The options can be set either by mutating the object returned by
-  `tf.data.Dataset.options()` or by constructing an `Options` object and using
-  the `tf.data.Dataset.with_options(options)` transformation, which returns a
+  The options can be set by constructing an `Options` object and using the
+  `tf.data.Dataset.with_options(options)` transformation, which returns a
   dataset with the options set.
 
-  >>> dataset = tf.data.Dataset.range(42)
-  >>> dataset.options().experimental_deterministic = False
-  >>> print(dataset.options().experimental_deterministic)
-  False
-
   >>> dataset = tf.data.Dataset.range(42)
   >>> options = tf.data.Options()
   >>> options.experimental_deterministic = False
@@ -3099,6 +3094,14 @@ class Options(options_lib.OptionsBase):
       self.experimental_slack = pb.slack
     self.experimental_threading._from_proto(pb.threading_options)  # pylint: disable=protected-access
 
+  def _set_mutable(self, mutable):
+    """Change the mutability value to `mutable` on this options and children."""
+    # pylint: disable=protected-access
+    object.__setattr__(self, "_mutable", mutable)
+    self.experimental_distribute._set_mutable(mutable)
+    self.experimental_optimization._set_mutable(mutable)
+    self.experimental_threading._set_mutable(mutable)
+
   def _graph_rewrites(self):
     """Produces lists of enabled, disabled, default static graph rewrites.
 
@@ -4665,17 +4668,17 @@ class _OptionsDataset(UnaryUnchangedStructureDataset):
   """An identity `Dataset` that stores options."""
 
   def __init__(self, input_dataset, options):
+    # pylint: disable=protected-access
     self._input_dataset = input_dataset
-    variant_tensor = input_dataset._variant_tensor  # pylint: disable=protected-access
+    variant_tensor = input_dataset._variant_tensor
     super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
 
     if self._options_attr:
+      self._options_attr._set_mutable(True)
       self._options_attr = self._options_attr.merge(options)
     else:
       self._options_attr = options
-
-  def options(self):
-    return self._options_attr
+    self._options_attr._set_mutable(False)
 
 
 class _ModelDataset(UnaryUnchangedStructureDataset):
diff --git a/tensorflow/python/data/util/options.py b/tensorflow/python/data/util/options.py
index 3df6f000bb6..20e4c625ba2 100644
--- a/tensorflow/python/data/util/options.py
+++ b/tensorflow/python/data/util/options.py
@@ -37,6 +37,7 @@ class OptionsBase(object):
   def __init__(self):
     # NOTE: Cannot use `self._options` here as we override `__setattr__`
     object.__setattr__(self, "_options", {})
+    object.__setattr__(self, "_mutable", True)
 
   def __eq__(self, other):
     if not isinstance(other, self.__class__):
@@ -53,12 +54,19 @@ class OptionsBase(object):
       return NotImplemented
 
   def __setattr__(self, name, value):
+    if not self._mutable:
+      raise ValueError("Mutating `tf.data.Options()` returned by "
+                       "`tf.data.Dataset.options()` has no effect.")
     if hasattr(self, name):
       object.__setattr__(self, name, value)
     else:
       raise AttributeError(
           "Cannot set the property %s on %s." % (name, type(self).__name__))
 
+  def _set_mutable(self, mutable):
+    """Change the mutability property to `mutable`."""
+    object.__setattr__(self, "_mutable", mutable)
+
   def _to_proto(self):
     """Convert options to protocol buffer."""
     raise NotImplementedError("%s._to_proto()" % type(self).__name__)