Add __fspath__ support (PEP 519) to files datasets

PiperOrigin-RevId: 340292396
Change-Id: I876a8f2c0ad3e72663ae13ab6456da4e76165726
Parent: c37d199bd6
Commit: 794aa738cd
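With this change the file-reader datasets (TextLineDataset, FixedLengthRecordDataset, TFRecordDataset) accept any os.PathLike object wherever a filename string was expected. A minimal sketch of the PEP 519 contract the readers now rely on; the paths are illustrative:

import os
import pathlib

# Any object implementing __fspath__ is an os.PathLike; os.fspath() converts
# it to str (or bytes), while plain strings pass through unchanged.
path = pathlib.Path("/tmp/data") / "records.tfrecord"  # illustrative path
assert os.fspath(path) == str(path)
assert os.fspath("already-a-string") == "already-a-string"

# After this change a Path can be handed straight to a reader, e.g.:
#   dataset = tf.data.TFRecordDataset(path)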
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import gzip
 import os
+import pathlib
 import zlib
 
 from absl.testing import parameterized
@@ -190,6 +191,24 @@ class FixedLengthRecordDatasetTest(test_base.DatasetTestBase,
             r"which is not an exact multiple of the record length \(4 bytes\).")
         )
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testFixedLengthRecordDatasetPathlib(self):
+    test_filenames = self._createFiles()
+    test_filenames = [pathlib.Path(f) for f in test_filenames]
+    dataset = readers.FixedLengthRecordDataset(
+        test_filenames,
+        self._record_bytes,
+        self._header_bytes,
+        self._footer_bytes,
+        buffer_size=10,
+        num_parallel_reads=4)
+    expected_output = []
+    for j in range(self._num_files):
+      expected_output.extend(
+          [self._record(j, i) for i in range(self._num_records)])
+    self.assertDatasetProduces(dataset, expected_output=expected_output,
+                               assert_items_equal=True)
+
 
 if __name__ == "__main__":
   test.main()
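The new test above drives FixedLengthRecordDataset with pathlib.Path filenames over the usual header/record/footer file layout. For reference, a plain-Python sketch of that byte layout; the sizes are made up, not the ones _createFiles uses:

import pathlib
import tempfile

header_bytes, record_bytes, footer_bytes = 5, 4, 2  # made-up sizes
records = [b"aaaa", b"bbbb", b"cccc"]

path = pathlib.Path(tempfile.mkdtemp()) / "fixed_length.bin"
path.write_bytes(b"H" * header_bytes + b"".join(records) + b"F" * footer_bytes)

# FixedLengthRecordDataset skips the header, emits each record_bytes-sized
# chunk, and ignores the footer; the same slicing in plain Python:
payload = path.read_bytes()[header_bytes:-footer_bytes]
out = [payload[i:i + record_bytes] for i in range(0, len(payload), record_bytes)]
assert out == records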
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import gzip
 import os
+import pathlib
 import zlib
 
 from absl.testing import parameterized
@@ -168,6 +169,16 @@ class TextLineDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     open_files = psutil.Process().open_files()
     self.assertNotIn(filename, [open_file.path for open_file in open_files])
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testTextLineDatasetPathlib(self):
+    files = self._createFiles(1, 5)
+    files = [pathlib.Path(f) for f in files]
+
+    expected_output = [self._lineText(0, i) for i in range(5)]
+    ds = readers.TextLineDataset(files)
+    self.assertDatasetProduces(
+        ds, expected_output=expected_output, assert_items_equal=True)
+
 
 if __name__ == "__main__":
   test.main()
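With the normalisation in place, a TextLineDataset can be built from a pathlib.Path directly. A usage sketch assuming TF2 eager execution; the file path and contents are made up:

import pathlib
import tempfile

import tensorflow as tf

# Write a small text file and read it back through a pathlib.Path filename.
path = pathlib.Path(tempfile.mkdtemp()) / "lines.txt"
path.write_text("line 0\nline 1\nline 2\n")

ds = tf.data.TextLineDataset(path)
print([line.numpy() for line in ds])  # [b'line 0', b'line 1', b'line 2']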
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import gzip
 import os
+import pathlib
 import zlib
 
 from absl.testing import parameterized
@@ -187,6 +188,15 @@ class TFRecordDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         dataset, expected_output=expected_output * 10, assert_items_equal=True)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testDatasetPathlib(self):
+    files = [pathlib.Path(self.test_filenames[0])]
+
+    expected_output = [self._record(0, i) for i in range(self._num_records)]
+    ds = readers.TFRecordDataset(files)
+    self.assertDatasetProduces(
+        ds, expected_output=expected_output, assert_items_equal=True)
+
 
 if __name__ == "__main__":
   test.main()
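One caveat, visible in the readers.py hunk further down: the os.PathLike conversion happens only on the branch where filenames is a plain Python value. A tf.data.Dataset of filenames is validated and used as-is, so it must already hold scalar tf.string elements. A sketch of the two accepted input shapes; the filenames are illustrative and need not exist until the dataset is iterated:

import pathlib

import tensorflow as tf

# Plain Python inputs (str, pathlib.Path, or lists of them) are normalised
# to a tf.string tensor before use.
ds_from_paths = tf.data.TFRecordDataset(
    [pathlib.Path("a.tfrecord"), pathlib.Path("b.tfrecord")])

# A dataset of filenames is passed through untouched, so it must already
# contain scalar tf.string elements; Path objects inside it are not converted.
filename_ds = tf.data.Dataset.from_tensor_slices(["a.tfrecord", "b.tfrecord"])
ds_from_dataset = tf.data.TFRecordDataset(filename_ds)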
@@ -51,11 +51,14 @@ py_library(
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dataset_ops_gen",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:experimental_dataset_ops_gen",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python/compat",
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python:tf2",
         "//tensorflow/python:tf_export",
         "//tensorflow/python:util",
         "//tensorflow/python/data/util:convert",
         "//tensorflow/python/data/util:structure",
     ],
 )
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 from tensorflow.python import tf2
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
@@ -27,11 +29,17 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 _DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # 256 KB
 
 
+def _normalise_fspath(path):
+  """Convert pathlib-like objects to str (__fspath__ compatibility, PEP 519)."""
+  return os.fspath(path) if isinstance(path, os.PathLike) else path
+
+
 def _create_or_validate_filenames_dataset(filenames):
   """Creates (or validates) a dataset of filenames.
 
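The isinstance(path, os.PathLike) guard means only genuine path-like objects are converted; strings, bytes, and tensors pass through untouched, so existing callers see no behaviour change. A few illustrative calls against a standalone copy of the helper:

import os
import pathlib


def normalise_fspath(path):  # standalone copy of the helper above
  return os.fspath(path) if isinstance(path, os.PathLike) else path


assert normalise_fspath(pathlib.Path("a.txt")) == "a.txt"  # converted
assert normalise_fspath("a.txt") == "a.txt"    # str left untouched
assert normalise_fspath(b"a.txt") == b"a.txt"  # bytes left untouched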
@@ -52,6 +60,7 @@ def _create_or_validate_filenames_dataset(filenames):
           "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
           "elements.")
   else:
+    filenames = nest.map_structure(_normalise_fspath, filenames)
     filenames = ops.convert_to_tensor(filenames, dtype_hint=dtypes.string)
     if filenames.dtype != dtypes.string:
       raise TypeError(
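Because the helper is applied with nest.map_structure, arbitrarily nested filename structures are normalised leaf by leaf before convert_to_tensor sees them. A standalone sketch, using the public tf.nest API in place of the internal nest module:

import os
import pathlib

import tensorflow as tf


def normalise_fspath(path):
  return os.fspath(path) if isinstance(path, os.PathLike) else path


nested = [pathlib.Path("a.txt"), ["b.txt", pathlib.Path("c.txt")]]
print(tf.nest.map_structure(normalise_fspath, nested))
# ['a.txt', ['b.txt', 'c.txt']]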