[tf.data] Solve the stack overflow problem of getting all options on Python side.
PiperOrigin-RevId: 305592326 Change-Id: I77857cdcdf25e7e66cde5dd1ba9449d0df34e587
This commit is contained in:
parent
a70db4d1d1
commit
2c47fddbb8
|
@ -100,6 +100,18 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
self.assertEqual(options1.experimental_threading,
|
self.assertEqual(options1.experimental_threading,
|
||||||
threading_options.ThreadingOptions())
|
threading_options.ThreadingOptions())
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testNestedDataset(self):
|
||||||
|
ds = dataset_ops.Dataset.from_tensors(0)
|
||||||
|
result = ds
|
||||||
|
|
||||||
|
for _ in range(999):
|
||||||
|
result = result.concatenate(ds)
|
||||||
|
options = dataset_ops.Options()
|
||||||
|
options.experimental_optimization.autotune = False
|
||||||
|
result = result.with_options(options)
|
||||||
|
self.assertDatasetProduces(result, [0]*1000)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
@ -194,6 +194,13 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||||
name="_variant_tracker")
|
name="_variant_tracker")
|
||||||
self._graph_attr = ops.get_default_graph()
|
self._graph_attr = ops.get_default_graph()
|
||||||
|
|
||||||
|
# Initialize the options for this dataset and its inputs.
|
||||||
|
self._options_attr = Options()
|
||||||
|
for input_dataset in self._inputs():
|
||||||
|
input_options = input_dataset.options()
|
||||||
|
if input_options is not None:
|
||||||
|
self._options_attr = self._options_attr.merge(input_options)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _variant_tensor(self):
|
def _variant_tensor(self):
|
||||||
return self._variant_tensor_attr
|
return self._variant_tensor_attr
|
||||||
|
@ -332,12 +339,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||||
Returns:
|
Returns:
|
||||||
A `tf.data.Options` object representing the dataset options.
|
A `tf.data.Options` object representing the dataset options.
|
||||||
"""
|
"""
|
||||||
options = Options()
|
return self._options_attr
|
||||||
for input_dataset in self._inputs():
|
|
||||||
input_options = input_dataset.options()
|
|
||||||
if input_options is not None:
|
|
||||||
options = options.merge(input_options)
|
|
||||||
return options
|
|
||||||
|
|
||||||
def _apply_options(self):
|
def _apply_options(self):
|
||||||
"""Apply options, such as optimization configuration, to the dataset."""
|
"""Apply options, such as optimization configuration, to the dataset."""
|
||||||
|
@ -4413,16 +4415,16 @@ class _OptionsDataset(UnaryUnchangedStructureDataset):
|
||||||
|
|
||||||
def __init__(self, input_dataset, options):
|
def __init__(self, input_dataset, options):
|
||||||
self._input_dataset = input_dataset
|
self._input_dataset = input_dataset
|
||||||
self._options = input_dataset.options()
|
|
||||||
if self._options:
|
|
||||||
self._options = self._options.merge(options)
|
|
||||||
else:
|
|
||||||
self._options = options
|
|
||||||
variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access
|
variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access
|
||||||
super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
|
super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
|
||||||
|
|
||||||
|
if self._options_attr:
|
||||||
|
self._options_attr = self._options_attr.merge(options)
|
||||||
|
else:
|
||||||
|
self._options_attr = options
|
||||||
|
|
||||||
def options(self):
|
def options(self):
|
||||||
return self._options
|
return self._options_attr
|
||||||
|
|
||||||
|
|
||||||
class _ModelDataset(UnaryUnchangedStructureDataset):
|
class _ModelDataset(UnaryUnchangedStructureDataset):
|
||||||
|
|
Loading…
Reference in New Issue