[tf.data] Solve the stack overflow problem of getting all options on Python side.

PiperOrigin-RevId: 305592326
Change-Id: I77857cdcdf25e7e66cde5dd1ba9449d0df34e587
This commit is contained in:
A. Unique TensorFlower 2020-04-08 17:57:27 -07:00 committed by TensorFlower Gardener
parent a70db4d1d1
commit 2c47fddbb8
2 changed files with 26 additions and 12 deletions

View File

@ -100,6 +100,18 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(options1.experimental_threading,
threading_options.ThreadingOptions())
@combinations.generate(test_base.eager_only_combinations())
def testNestedDataset(self):
ds = dataset_ops.Dataset.from_tensors(0)
result = ds
for _ in range(999):
result = result.concatenate(ds)
options = dataset_ops.Options()
options.experimental_optimization.autotune = False
result = result.with_options(options)
self.assertDatasetProduces(result, [0]*1000)
if __name__ == "__main__":
test.main()

View File

@ -194,6 +194,13 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
name="_variant_tracker")
self._graph_attr = ops.get_default_graph()
# Initialize the options for this dataset and its inputs.
self._options_attr = Options()
for input_dataset in self._inputs():
input_options = input_dataset.options()
if input_options is not None:
self._options_attr = self._options_attr.merge(input_options)
@property
def _variant_tensor(self):
return self._variant_tensor_attr
@ -332,12 +339,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
Returns:
A `tf.data.Options` object representing the dataset options.
"""
options = Options()
for input_dataset in self._inputs():
input_options = input_dataset.options()
if input_options is not None:
options = options.merge(input_options)
return options
return self._options_attr
def _apply_options(self):
"""Apply options, such as optimization configuration, to the dataset."""
@ -4413,16 +4415,16 @@ class _OptionsDataset(UnaryUnchangedStructureDataset):
def __init__(self, input_dataset, options):
self._input_dataset = input_dataset
self._options = input_dataset.options()
if self._options:
self._options = self._options.merge(options)
else:
self._options = options
variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access
super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
if self._options_attr:
self._options_attr = self._options_attr.merge(options)
else:
self._options_attr = options
def options(self):
return self._options
return self._options_attr
class _ModelDataset(UnaryUnchangedStructureDataset):