From 2c47fddbb807233ee491eea981365839effede00 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 8 Apr 2020 17:57:27 -0700 Subject: [PATCH] [tf.data] Solve the stack overflow problem of getting all options on Python side. PiperOrigin-RevId: 305592326 Change-Id: I77857cdcdf25e7e66cde5dd1ba9449d0df34e587 --- .../python/data/kernel_tests/options_test.py | 12 +++++++++ tensorflow/python/data/ops/dataset_ops.py | 26 ++++++++++--------- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/data/kernel_tests/options_test.py b/tensorflow/python/data/kernel_tests/options_test.py index b38d008b833..dea217367dc 100644 --- a/tensorflow/python/data/kernel_tests/options_test.py +++ b/tensorflow/python/data/kernel_tests/options_test.py @@ -100,6 +100,18 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(options1.experimental_threading, threading_options.ThreadingOptions()) + @combinations.generate(test_base.eager_only_combinations()) + def testNestedDataset(self): + ds = dataset_ops.Dataset.from_tensors(0) + result = ds + + for _ in range(999): + result = result.concatenate(ds) + options = dataset_ops.Options() + options.experimental_optimization.autotune = False + result = result.with_options(options) + self.assertDatasetProduces(result, [0]*1000) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index b23df3672c9..119d557a1b8 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -194,6 +194,13 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): name="_variant_tracker") self._graph_attr = ops.get_default_graph() + # Initialize the options for this dataset and its inputs. + self._options_attr = Options() + for input_dataset in self._inputs(): + input_options = input_dataset.options() + if input_options is not None: + self._options_attr = self._options_attr.merge(input_options) + @property def _variant_tensor(self): return self._variant_tensor_attr @@ -332,12 +339,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): Returns: A `tf.data.Options` object representing the dataset options. """ - options = Options() - for input_dataset in self._inputs(): - input_options = input_dataset.options() - if input_options is not None: - options = options.merge(input_options) - return options + return self._options_attr def _apply_options(self): """Apply options, such as optimization configuration, to the dataset.""" @@ -4413,16 +4415,16 @@ class _OptionsDataset(UnaryUnchangedStructureDataset): def __init__(self, input_dataset, options): self._input_dataset = input_dataset - self._options = input_dataset.options() - if self._options: - self._options = self._options.merge(options) - else: - self._options = options variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access super(_OptionsDataset, self).__init__(input_dataset, variant_tensor) + if self._options_attr: + self._options_attr = self._options_attr.merge(options) + else: + self._options_attr = options + def options(self): - return self._options + return self._options_attr class _ModelDataset(UnaryUnchangedStructureDataset):