Add __fspath__ support (PEP 519) to file-based datasets

PiperOrigin-RevId: 340292396
Change-Id: I876a8f2c0ad3e72663ae13ab6456da4e76165726
This commit is contained in:
Etienne Pot 2020-11-02 12:30:57 -08:00 committed by TensorFlower Gardener
parent c37d199bd6
commit 794aa738cd
5 changed files with 54 additions and 2 deletions

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import gzip
import os
import pathlib
import zlib
from absl.testing import parameterized
@ -190,6 +191,24 @@ class FixedLengthRecordDatasetTest(test_base.DatasetTestBase,
r"which is not an exact multiple of the record length \(4 bytes\).")
)
@combinations.generate(test_base.default_test_combinations())
def testFixedLengthRecordDatasetPathlib(self):
  """FixedLengthRecordDataset accepts `pathlib.Path` filenames (PEP 519)."""
  # Wrap every generated filename in a pathlib.Path so the dataset must
  # convert it back to a string via os.fspath.
  paths = [pathlib.Path(name) for name in self._createFiles()]
  dataset = readers.FixedLengthRecordDataset(
      paths,
      self._record_bytes,
      self._header_bytes,
      self._footer_bytes,
      buffer_size=10,
      num_parallel_reads=4)
  # All records from all files; order is unspecified because reads are
  # parallel, hence assert_items_equal below.
  expected = [
      self._record(file_idx, record_idx)
      for file_idx in range(self._num_files)
      for record_idx in range(self._num_records)
  ]
  self.assertDatasetProduces(
      dataset, expected_output=expected, assert_items_equal=True)
# Standard TensorFlow test entry point.
if __name__ == "__main__":
test.main()

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import gzip
import os
import pathlib
import zlib
from absl.testing import parameterized
@ -168,6 +169,16 @@ class TextLineDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
open_files = psutil.Process().open_files()
self.assertNotIn(filename, [open_file.path for open_file in open_files])
@combinations.generate(test_base.default_test_combinations())
def testTextLineDatasetPathlib(self):
  """TextLineDataset accepts `pathlib.Path` filenames (PEP 519)."""
  # One file with five lines, supplied as pathlib.Path rather than str.
  filenames = [pathlib.Path(f) for f in self._createFiles(1, 5)]
  expected = [self._lineText(0, line_idx) for line_idx in range(5)]
  dataset = readers.TextLineDataset(filenames)
  self.assertDatasetProduces(
      dataset, expected_output=expected, assert_items_equal=True)
# Standard TensorFlow test entry point.
if __name__ == "__main__":
test.main()

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import gzip
import os
import pathlib
import zlib
from absl.testing import parameterized
@ -187,6 +188,15 @@ class TFRecordDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
dataset, expected_output=expected_output * 10, assert_items_equal=True)
@combinations.generate(test_base.default_test_combinations())
def testDatasetPathlib(self):
  """TFRecordDataset accepts `pathlib.Path` filenames (PEP 519)."""
  # Pass the first test file as a pathlib.Path instead of a str.
  path = pathlib.Path(self.test_filenames[0])
  expected = [self._record(0, i) for i in range(self._num_records)]
  dataset = readers.TFRecordDataset([path])
  self.assertDatasetProduces(
      dataset, expected_output=expected, assert_items_equal=True)
# Standard TensorFlow test entry point.
if __name__ == "__main__":
test.main()

View File

@ -51,11 +51,14 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/compat",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tf2",
"//tensorflow/python:tf_export",
"//tensorflow/python:util",
"//tensorflow/python/data/util:convert",
"//tensorflow/python/data/util:structure",
],
)

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
@ -27,11 +29,17 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB
def _normalise_fspath(path):
"""Convert pathlib-like objects to str (__fspath__ compatibility, PEP 519)."""
return os.fspath(path) if isinstance(path, os.PathLike) else path
def _create_or_validate_filenames_dataset(filenames):
"""Creates (or validates) a dataset of filenames.
@ -52,6 +60,7 @@ def _create_or_validate_filenames_dataset(filenames):
"`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
"elements.")
else:
filenames = nest.map_structure(_normalise_fspath, filenames)
filenames = ops.convert_to_tensor(filenames, dtype_hint=dtypes.string)
if filenames.dtype != dtypes.string:
raise TypeError(