Merge pull request #37955 from ngc92:pathlib

PiperOrigin-RevId: 305264725
Change-Id: Ia73b4865e0886cc01d52e772745476ddd89b21f7
This commit is contained in:
TensorFlower Gardener 2020-04-07 08:45:38 -07:00
commit 35b03590f6
8 changed files with 99 additions and 27 deletions

View File

@ -44,6 +44,7 @@ from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.data_utils import Sequence from tensorflow.python.keras.utils.data_utils import Sequence
from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.lib.io import file_io from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -1044,12 +1045,12 @@ class ModelCheckpoint(Callback):
``` ```
Arguments: Arguments:
filepath: string, path to save the model file. `filepath` can contain filepath: string or `PathLike`, path to save the model file. `filepath`
named formatting options, which will be filled the value of `epoch` and can contain named formatting options, which will be filled the value of
keys in `logs` (passed in `on_epoch_end`). For example: if `filepath` is `epoch` and keys in `logs` (passed in `on_epoch_end`). For example: if
`weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model checkpoints `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model
will be saved with the epoch number and the validation loss in the checkpoints will be saved with the epoch number and the validation loss
filename. in the filename.
monitor: quantity to monitor. monitor: quantity to monitor.
verbose: verbosity mode, 0 or 1. verbose: verbosity mode, 0 or 1.
save_best_only: if `save_best_only=True`, the latest best model according save_best_only: if `save_best_only=True`, the latest best model according
@ -1090,7 +1091,7 @@ class ModelCheckpoint(Callback):
self._supports_tf_logs = True self._supports_tf_logs = True
self.monitor = monitor self.monitor = monitor
self.verbose = verbose self.verbose = verbose
self.filepath = filepath self.filepath = path_to_string(filepath)
self.save_best_only = save_best_only self.save_best_only = save_best_only
self.save_weights_only = save_weights_only self.save_weights_only = save_weights_only
self.save_freq = save_freq self.save_freq = save_freq
@ -1780,7 +1781,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
self._supports_tf_logs = True self._supports_tf_logs = True
self._validate_kwargs(kwargs) self._validate_kwargs(kwargs)
self.log_dir = log_dir self.log_dir = path_to_string(log_dir)
self.histogram_freq = histogram_freq self.histogram_freq = histogram_freq
self.write_graph = write_graph self.write_graph = write_graph
self.write_images = write_images self.write_images = write_images
@ -2280,7 +2281,7 @@ class CSVLogger(Callback):
def __init__(self, filename, separator=',', append=False): def __init__(self, filename, separator=',', append=False):
self.sep = separator self.sep = separator
self.filename = filename self.filename = path_to_string(filename)
self.append = append self.append = append
self.writer = None self.writer = None
self.keys = None self.keys = None

View File

@ -51,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
@ -1030,7 +1031,8 @@ class Network(base_layer.Layer):
access specific variables, e.g. `model.get_layer("dense_1").kernel`. access specific variables, e.g. `model.get_layer("dense_1").kernel`.
Arguments: Arguments:
filepath: String, path to SavedModel or H5 file to save the model. filepath: String, PathLike, path to SavedModel or H5 file to save the
model.
overwrite: Whether to silently overwrite any existing file at the overwrite: Whether to silently overwrite any existing file at the
target location, or provide the user with a manual prompt. target location, or provide the user with a manual prompt.
include_optimizer: If True, save optimizer's state together. include_optimizer: If True, save optimizer's state together.
@ -1103,10 +1105,10 @@ class Network(base_layer.Layer):
on the TensorFlow format. on the TensorFlow format.
Arguments: Arguments:
filepath: String, path to the file to save the weights to. When saving filepath: String or PathLike, path to the file to save the weights to.
in TensorFlow format, this is the prefix used for checkpoint files When saving in TensorFlow format, this is the prefix used for
(multiple files are generated). Note that the '.h5' suffix causes checkpoint files (multiple files are generated). Note that the '.h5'
weights to be saved in HDF5 format. suffix causes weights to be saved in HDF5 format.
overwrite: Whether to silently overwrite any existing file at the overwrite: Whether to silently overwrite any existing file at the
target location, or provide the user with a manual prompt. target location, or provide the user with a manual prompt.
save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
@ -1119,6 +1121,7 @@ class Network(base_layer.Layer):
ValueError: For invalid/unknown format arguments. ValueError: For invalid/unknown format arguments.
""" """
self._assert_weights_created() self._assert_weights_created()
filepath = path_to_string(filepath)
filepath_is_h5 = _is_hdf5_filepath(filepath) filepath_is_h5 = _is_hdf5_filepath(filepath)
if save_format is None: if save_format is None:
if filepath_is_h5: if filepath_is_h5:
@ -1201,9 +1204,9 @@ class Network(base_layer.Layer):
which layers are assigned in the `Model`'s constructor. which layers are assigned in the `Model`'s constructor.
Arguments: Arguments:
filepath: String, path to the weights file to load. For weight files in filepath: String or PathLike, path to the weights file to load. For
TensorFlow format, this is the file prefix (the same as was passed weight files in TensorFlow format, this is the file prefix (the
to `save_weights`). same as was passed to `save_weights`).
by_name: Boolean, whether to load weights by name or by topological by_name: Boolean, whether to load weights by name or by topological
order. Only topological loading is supported for weight files in order. Only topological loading is supported for weight files in
TensorFlow format. TensorFlow format.
@ -1232,6 +1235,7 @@ class Network(base_layer.Layer):
'When calling model.load_weights, skip_mismatch can only be set to ' 'When calling model.load_weights, skip_mismatch can only be set to '
'True when by_name is True.') 'True when by_name is True.')
filepath = path_to_string(filepath)
if _is_hdf5_filepath(filepath): if _is_hdf5_filepath(filepath):
save_format = 'h5' save_format = 'h5'
else: else:

View File

@ -19,8 +19,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import sys
import six import six
from tensorflow.python import tf2 from tensorflow.python import tf2
@ -28,12 +26,11 @@ from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.saving.saved_model import load as saved_model_load from tensorflow.python.keras.saving.saved_model import load as saved_model_load
from tensorflow.python.keras.saving.saved_model import save as saved_model_save from tensorflow.python.keras.saving.saved_model import save as saved_model_save
from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import loader_impl
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
if sys.version_info >= (3, 4):
import pathlib
try: try:
import h5py import h5py
except ImportError: except ImportError:
@ -115,8 +112,7 @@ def save_model(model,
default_format = 'tf' if tf2.enabled() else 'h5' default_format = 'tf' if tf2.enabled() else 'h5'
save_format = save_format or default_format save_format = save_format or default_format
if sys.version_info >= (3, 4) and isinstance(filepath, pathlib.Path): filepath = path_to_string(filepath)
filepath = str(filepath)
if (save_format == 'h5' or if (save_format == 'h5' or
(h5py is not None and isinstance(filepath, h5py.File)) or (h5py is not None and isinstance(filepath, h5py.File)) or
@ -183,8 +179,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))): isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile) return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
if sys.version_info >= (3, 4) and isinstance(filepath, pathlib.Path): filepath = path_to_string(filepath)
filepath = str(filepath)
if isinstance(filepath, six.string_types): if isinstance(filepath, six.string_types):
loader_impl.parse_saved_model(filepath) loader_impl.parse_saved_model(filepath)
return saved_model_load.load(filepath, compile) return saved_model_load.load(filepath, compile)

View File

@ -41,7 +41,7 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import loader_impl
if sys.version_info >= (3, 4): if sys.version_info >= (3, 6):
import pathlib # pylint:disable=g-import-not-at-top import pathlib # pylint:disable=g-import-not-at-top
try: try:
import h5py # pylint:disable=g-import-not-at-top import h5py # pylint:disable=g-import-not-at-top
@ -100,7 +100,7 @@ class TestSaveModel(test.TestCase, parameterized.TestCase):
@test_util.run_v2_only @test_util.run_v2_only
def test_save_load_tf_pathlib(self): def test_save_load_tf_pathlib(self):
if sys.version_info >= (3, 4): if sys.version_info >= (3, 6):
path = pathlib.Path(self.get_temp_dir()) / 'model' path = pathlib.Path(self.get_temp_dir()) / 'model'
save.save_model(self.model, path, save_format='tf') save.save_model(self.model, path, save_format='tf')
save.load_model(path) save.load_model(path)

View File

@ -46,6 +46,7 @@ from six.moves.urllib.error import URLError
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from six.moves.urllib.request import urlopen from six.moves.urllib.request import urlopen
from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -137,6 +138,9 @@ def _extract_archive(file_path, path='.', archive_format='auto'):
if isinstance(archive_format, six.string_types): if isinstance(archive_format, six.string_types):
archive_format = [archive_format] archive_format = [archive_format]
file_path = path_to_string(file_path)
path = path_to_string(path)
for archive_type in archive_format: for archive_type in archive_format:
if archive_type == 'tar': if archive_type == 'tar':
open_fn = tarfile.open open_fn = tarfile.open
@ -230,6 +234,8 @@ def get_file(fname,
datadir = os.path.join(datadir_base, cache_subdir) datadir = os.path.join(datadir_base, cache_subdir)
_makedirs_exist_ok(datadir) _makedirs_exist_ok(datadir)
fname = path_to_string(fname)
if untar: if untar:
untar_fpath = os.path.join(datadir, fname) untar_fpath = os.path.join(datadir, fname)
fpath = untar_fpath + '.tar.gz' fpath = untar_fpath + '.tar.gz'

View File

@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections
import os
import sys
import numpy as np import numpy as np
import six import six
@ -33,6 +35,48 @@ except ImportError:
h5py = None h5py = None
if sys.version_info >= (3, 6):
def _path_to_string(path):
if isinstance(path, os.PathLike):
return os.fspath(path)
return path
elif sys.version_info >= (3, 4):
def _path_to_string(path):
import pathlib
if isinstance(path, pathlib.Path):
return str(path)
return path
else:
def _path_to_string(path):
return path
def path_to_string(path):
"""Convert `PathLike` objects to their string representation.
If given a non-string typed path object, converts it to its string
representation. Depending on the python version used, this function
can handle the following arguments:
python >= 3.6: Everything supporting the fs path protocol
https://www.python.org/dev/peps/pep-0519
python >= 3.4: Only `pathlib.Path` objects
If the object passed to `path` is not among the above, then it is
returned unchanged. This allows e.g. passthrough of file objects
through this function.
Args:
path: `PathLike` object that represents a path
Returns:
A string representation of the path argument, if Python support exists.
"""
return _path_to_string(path)
@keras_export('keras.utils.HDF5Matrix') @keras_export('keras.utils.HDF5Matrix')
class HDF5Matrix(object): class HDF5Matrix(object):
"""Representation of HDF5 dataset to be used instead of a Numpy array. """Representation of HDF5 dataset to be used instead of a Numpy array.

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import os import os
import shutil import shutil
import sys
import numpy as np import numpy as np
import six import six
@ -137,6 +138,25 @@ class TestIOUtils(keras_parameterized.TestCase):
self.assertFalse( self.assertFalse(
io_utils.ask_to_proceed_with_overwrite('/tmp/not_exists')) io_utils.ask_to_proceed_with_overwrite('/tmp/not_exists'))
def test_path_to_string(self):
class PathLikeDummy(object):
def __fspath__(self):
return 'dummypath'
dummy = object()
if sys.version_info >= (3, 4):
from pathlib import Path # pylint:disable=g-import-not-at-top
# conversion of PathLike
self.assertEqual(io_utils.path_to_string(Path('path')), 'path')
if sys.version_info >= (3, 6):
self.assertEqual(io_utils.path_to_string(PathLikeDummy()), 'dummypath')
# pass-through, works for all versions of python
self.assertEqual(io_utils.path_to_string('path'), 'path')
self.assertIs(io_utils.path_to_string(dummy), dummy)
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import os import os
import sys import sys
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -299,6 +300,7 @@ def plot_model(model,
rankdir=rankdir, rankdir=rankdir,
expand_nested=expand_nested, expand_nested=expand_nested,
dpi=dpi) dpi=dpi)
to_file = path_to_string(to_file)
if dot is None: if dot is None:
return return
_, extension = os.path.splitext(to_file) _, extension = os.path.splitext(to_file)