diff --git a/tensorflow/python/data/kernel_tests/fixed_length_record_dataset_test.py b/tensorflow/python/data/kernel_tests/fixed_length_record_dataset_test.py
index 9caf1177ae9..ff89ff80465 100644
--- a/tensorflow/python/data/kernel_tests/fixed_length_record_dataset_test.py
+++ b/tensorflow/python/data/kernel_tests/fixed_length_record_dataset_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import gzip
 import os
+import pathlib
 import zlib
 
 from absl.testing import parameterized
@@ -190,6 +191,24 @@ class FixedLengthRecordDatasetTest(test_base.DatasetTestBase,
             r"which is not an exact multiple of the record length \(4 bytes\).")
     )
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testFixedLengthRecordDatasetPathlib(self):
+    test_filenames = self._createFiles()
+    test_filenames = [pathlib.Path(f) for f in test_filenames]
+    dataset = readers.FixedLengthRecordDataset(
+        test_filenames,
+        self._record_bytes,
+        self._header_bytes,
+        self._footer_bytes,
+        buffer_size=10,
+        num_parallel_reads=4)
+    expected_output = []
+    for j in range(self._num_files):
+      expected_output.extend(
+          [self._record(j, i) for i in range(self._num_records)])
+    self.assertDatasetProduces(dataset, expected_output=expected_output,
+                               assert_items_equal=True)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/kernel_tests/text_line_dataset_test.py b/tensorflow/python/data/kernel_tests/text_line_dataset_test.py
index 2a81dff0058..4e171d2f3ba 100644
--- a/tensorflow/python/data/kernel_tests/text_line_dataset_test.py
+++ b/tensorflow/python/data/kernel_tests/text_line_dataset_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import gzip
 import os
+import pathlib
 import zlib
 
 from absl.testing import parameterized
@@ -168,6 +169,16 @@ class TextLineDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     open_files = psutil.Process().open_files()
     self.assertNotIn(filename, [open_file.path for open_file in open_files])
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testTextLineDatasetPathlib(self):
+    files = self._createFiles(1, 5)
+    files = [pathlib.Path(f) for f in files]
+
+    expected_output = [self._lineText(0, i) for i in range(5)]
+    ds = readers.TextLineDataset(files)
+    self.assertDatasetProduces(
+        ds, expected_output=expected_output, assert_items_equal=True)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py b/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py
index 792c4926640..a16fa334155 100644
--- a/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py
+++ b/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import gzip
 import os
+import pathlib
 import zlib
 
 from absl.testing import parameterized
@@ -187,6 +188,15 @@ class TFRecordDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         dataset, expected_output=expected_output * 10, assert_items_equal=True)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testDatasetPathlib(self):
+    files = [pathlib.Path(self.test_filenames[0])]
+
+    expected_output = [self._record(0, i) for i in range(self._num_records)]
+    ds = readers.TFRecordDataset(files)
+    self.assertDatasetProduces(
+        ds, expected_output=expected_output, assert_items_equal=True)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 6ece7cc2f01..d69bf6a2297 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -51,11 +51,14 @@ py_library(
         "//tensorflow/python:array_ops",
        "//tensorflow/python:dataset_ops_gen",
        "//tensorflow/python:dtypes",
+        "//tensorflow/python:experimental_dataset_ops_gen",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:tensor_shape",
-        "//tensorflow/python/compat",
+        "//tensorflow/python:tensor_spec",
+        "//tensorflow/python:tf2",
+        "//tensorflow/python:tf_export",
+        "//tensorflow/python:util",
        "//tensorflow/python/data/util:convert",
-        "//tensorflow/python/data/util:structure",
     ],
 )
 
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index dbc580ce331..4db302be75b 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 from tensorflow.python import tf2
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
@@ -27,11 +29,17 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 _DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # 256 KB
 
 
+def _normalise_fspath(path):
+  """Convert pathlib-like objects to str (__fspath__ compatibility, PEP 519)."""
+  return os.fspath(path) if isinstance(path, os.PathLike) else path
+
+
 def _create_or_validate_filenames_dataset(filenames):
   """Creates (or validates) a dataset of filenames.
 
@@ -52,6 +60,7 @@ def _create_or_validate_filenames_dataset(filenames):
         "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
         "elements.")
   else:
+    filenames = nest.map_structure(_normalise_fspath, filenames)
    filenames = ops.convert_to_tensor(filenames, dtype_hint=dtypes.string)
    if filenames.dtype != dtypes.string:
      raise TypeError(
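
Taken together, these changes let the tf.data reader datasets accept os.PathLike objects (e.g. pathlib.Path) anywhere a string filename was accepted before: _create_or_validate_filenames_dataset now maps _normalise_fspath over the filename structure before converting it to a tf.string tensor. A minimal sketch of the resulting user-facing behaviour, assuming eager execution and purely hypothetical file paths:

import pathlib

import tensorflow as tf

# Hypothetical record files; after this change any os.PathLike object is
# normalised to str via os.fspath() before ops.convert_to_tensor(), so the
# Path objects no longer need to be wrapped in str().
files = [pathlib.Path("/tmp/data/train-0.tfrecord"),
         pathlib.Path("/tmp/data/train-1.tfrecord")]

dataset = tf.data.TFRecordDataset(files)

# The same applies to tf.data.TextLineDataset and
# tf.data.FixedLengthRecordDataset, as exercised by the new tests above.
for raw_record in dataset.take(1):
  print(raw_record.numpy())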