[tf.data] Solve the stack overflow problem of getting all options on Python side.
PiperOrigin-RevId: 305592326 Change-Id: I77857cdcdf25e7e66cde5dd1ba9449d0df34e587
This commit is contained in:
parent
a70db4d1d1
commit
2c47fddbb8
|
@ -100,6 +100,18 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
self.assertEqual(options1.experimental_threading,
|
||||
threading_options.ThreadingOptions())
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testNestedDataset(self):
|
||||
ds = dataset_ops.Dataset.from_tensors(0)
|
||||
result = ds
|
||||
|
||||
for _ in range(999):
|
||||
result = result.concatenate(ds)
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_optimization.autotune = False
|
||||
result = result.with_options(options)
|
||||
self.assertDatasetProduces(result, [0]*1000)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
@ -194,6 +194,13 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||
name="_variant_tracker")
|
||||
self._graph_attr = ops.get_default_graph()
|
||||
|
||||
# Initialize the options for this dataset and its inputs.
|
||||
self._options_attr = Options()
|
||||
for input_dataset in self._inputs():
|
||||
input_options = input_dataset.options()
|
||||
if input_options is not None:
|
||||
self._options_attr = self._options_attr.merge(input_options)
|
||||
|
||||
@property
|
||||
def _variant_tensor(self):
|
||||
return self._variant_tensor_attr
|
||||
|
@ -332,12 +339,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||
Returns:
|
||||
A `tf.data.Options` object representing the dataset options.
|
||||
"""
|
||||
options = Options()
|
||||
for input_dataset in self._inputs():
|
||||
input_options = input_dataset.options()
|
||||
if input_options is not None:
|
||||
options = options.merge(input_options)
|
||||
return options
|
||||
return self._options_attr
|
||||
|
||||
def _apply_options(self):
|
||||
"""Apply options, such as optimization configuration, to the dataset."""
|
||||
|
@ -4413,16 +4415,16 @@ class _OptionsDataset(UnaryUnchangedStructureDataset):
|
|||
|
||||
def __init__(self, input_dataset, options):
|
||||
self._input_dataset = input_dataset
|
||||
self._options = input_dataset.options()
|
||||
if self._options:
|
||||
self._options = self._options.merge(options)
|
||||
else:
|
||||
self._options = options
|
||||
variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access
|
||||
super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
if self._options_attr:
|
||||
self._options_attr = self._options_attr.merge(options)
|
||||
else:
|
||||
self._options_attr = options
|
||||
|
||||
def options(self):
|
||||
return self._options
|
||||
return self._options_attr
|
||||
|
||||
|
||||
class _ModelDataset(UnaryUnchangedStructureDataset):
|
||||
|
|
Loading…
Reference in New Issue