Merge pull request #37955 from ngc92:pathlib
PiperOrigin-RevId: 305264725 Change-Id: Ia73b4865e0886cc01d52e772745476ddd89b21f7
This commit is contained in:
commit
35b03590f6
@ -44,6 +44,7 @@ from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.keras.utils import version_utils
|
||||
from tensorflow.python.keras.utils.data_utils import Sequence
|
||||
from tensorflow.python.keras.utils.generic_utils import Progbar
|
||||
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -1044,12 +1045,12 @@ class ModelCheckpoint(Callback):
|
||||
```
|
||||
|
||||
Arguments:
|
||||
filepath: string, path to save the model file. `filepath` can contain
|
||||
named formatting options, which will be filled the value of `epoch` and
|
||||
keys in `logs` (passed in `on_epoch_end`). For example: if `filepath` is
|
||||
`weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model checkpoints
|
||||
will be saved with the epoch number and the validation loss in the
|
||||
filename.
|
||||
filepath: string or `PathLike`, path to save the model file. `filepath`
|
||||
can contain named formatting options, which will be filled the value of
|
||||
`epoch` and keys in `logs` (passed in `on_epoch_end`). For example: if
|
||||
`filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model
|
||||
checkpoints will be saved with the epoch number and the validation loss
|
||||
in the filename.
|
||||
monitor: quantity to monitor.
|
||||
verbose: verbosity mode, 0 or 1.
|
||||
save_best_only: if `save_best_only=True`, the latest best model according
|
||||
@ -1090,7 +1091,7 @@ class ModelCheckpoint(Callback):
|
||||
self._supports_tf_logs = True
|
||||
self.monitor = monitor
|
||||
self.verbose = verbose
|
||||
self.filepath = filepath
|
||||
self.filepath = path_to_string(filepath)
|
||||
self.save_best_only = save_best_only
|
||||
self.save_weights_only = save_weights_only
|
||||
self.save_freq = save_freq
|
||||
@ -1780,7 +1781,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
|
||||
self._supports_tf_logs = True
|
||||
self._validate_kwargs(kwargs)
|
||||
|
||||
self.log_dir = log_dir
|
||||
self.log_dir = path_to_string(log_dir)
|
||||
self.histogram_freq = histogram_freq
|
||||
self.write_graph = write_graph
|
||||
self.write_images = write_images
|
||||
@ -2280,7 +2281,7 @@ class CSVLogger(Callback):
|
||||
|
||||
def __init__(self, filename, separator=',', append=False):
|
||||
self.sep = separator
|
||||
self.filename = filename
|
||||
self.filename = path_to_string(filename)
|
||||
self.append = append
|
||||
self.writer = None
|
||||
self.keys = None
|
||||
|
@ -51,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
@ -1030,7 +1031,8 @@ class Network(base_layer.Layer):
|
||||
access specific variables, e.g. `model.get_layer("dense_1").kernel`.
|
||||
|
||||
Arguments:
|
||||
filepath: String, path to SavedModel or H5 file to save the model.
|
||||
filepath: String, PathLike, path to SavedModel or H5 file to save the
|
||||
model.
|
||||
overwrite: Whether to silently overwrite any existing file at the
|
||||
target location, or provide the user with a manual prompt.
|
||||
include_optimizer: If True, save optimizer's state together.
|
||||
@ -1103,10 +1105,10 @@ class Network(base_layer.Layer):
|
||||
on the TensorFlow format.
|
||||
|
||||
Arguments:
|
||||
filepath: String, path to the file to save the weights to. When saving
|
||||
in TensorFlow format, this is the prefix used for checkpoint files
|
||||
(multiple files are generated). Note that the '.h5' suffix causes
|
||||
weights to be saved in HDF5 format.
|
||||
filepath: String or PathLike, path to the file to save the weights to.
|
||||
When saving in TensorFlow format, this is the prefix used for
|
||||
checkpoint files (multiple files are generated). Note that the '.h5'
|
||||
suffix causes weights to be saved in HDF5 format.
|
||||
overwrite: Whether to silently overwrite any existing file at the
|
||||
target location, or provide the user with a manual prompt.
|
||||
save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
|
||||
@ -1119,6 +1121,7 @@ class Network(base_layer.Layer):
|
||||
ValueError: For invalid/unknown format arguments.
|
||||
"""
|
||||
self._assert_weights_created()
|
||||
filepath = path_to_string(filepath)
|
||||
filepath_is_h5 = _is_hdf5_filepath(filepath)
|
||||
if save_format is None:
|
||||
if filepath_is_h5:
|
||||
@ -1201,9 +1204,9 @@ class Network(base_layer.Layer):
|
||||
which layers are assigned in the `Model`'s constructor.
|
||||
|
||||
Arguments:
|
||||
filepath: String, path to the weights file to load. For weight files in
|
||||
TensorFlow format, this is the file prefix (the same as was passed
|
||||
to `save_weights`).
|
||||
filepath: String or PathLike, path to the weights file to load. For
|
||||
weight files in TensorFlow format, this is the file prefix (the
|
||||
same as was passed to `save_weights`).
|
||||
by_name: Boolean, whether to load weights by name or by topological
|
||||
order. Only topological loading is supported for weight files in
|
||||
TensorFlow format.
|
||||
@ -1232,6 +1235,7 @@ class Network(base_layer.Layer):
|
||||
'When calling model.load_weights, skip_mismatch can only be set to '
|
||||
'True when by_name is True.')
|
||||
|
||||
filepath = path_to_string(filepath)
|
||||
if _is_hdf5_filepath(filepath):
|
||||
save_format = 'h5'
|
||||
else:
|
||||
|
@ -19,8 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python import tf2
|
||||
@ -28,12 +26,11 @@ from tensorflow.python.keras.saving import hdf5_format
|
||||
from tensorflow.python.keras.saving.saved_model import load as saved_model_load
|
||||
from tensorflow.python.keras.saving.saved_model import save as saved_model_save
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
if sys.version_info >= (3, 4):
|
||||
import pathlib
|
||||
try:
|
||||
import h5py
|
||||
except ImportError:
|
||||
@ -115,8 +112,7 @@ def save_model(model,
|
||||
default_format = 'tf' if tf2.enabled() else 'h5'
|
||||
save_format = save_format or default_format
|
||||
|
||||
if sys.version_info >= (3, 4) and isinstance(filepath, pathlib.Path):
|
||||
filepath = str(filepath)
|
||||
filepath = path_to_string(filepath)
|
||||
|
||||
if (save_format == 'h5' or
|
||||
(h5py is not None and isinstance(filepath, h5py.File)) or
|
||||
@ -183,8 +179,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
|
||||
isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
|
||||
return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
|
||||
|
||||
if sys.version_info >= (3, 4) and isinstance(filepath, pathlib.Path):
|
||||
filepath = str(filepath)
|
||||
filepath = path_to_string(filepath)
|
||||
if isinstance(filepath, six.string_types):
|
||||
loader_impl.parse_saved_model(filepath)
|
||||
return saved_model_load.load(filepath, compile)
|
||||
|
@ -41,7 +41,7 @@ from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
|
||||
if sys.version_info >= (3, 4):
|
||||
if sys.version_info >= (3, 6):
|
||||
import pathlib # pylint:disable=g-import-not-at-top
|
||||
try:
|
||||
import h5py # pylint:disable=g-import-not-at-top
|
||||
@ -100,7 +100,7 @@ class TestSaveModel(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@test_util.run_v2_only
|
||||
def test_save_load_tf_pathlib(self):
|
||||
if sys.version_info >= (3, 4):
|
||||
if sys.version_info >= (3, 6):
|
||||
path = pathlib.Path(self.get_temp_dir()) / 'model'
|
||||
save.save_model(self.model, path, save_format='tf')
|
||||
save.load_model(path)
|
||||
|
@ -46,6 +46,7 @@ from six.moves.urllib.error import URLError
|
||||
from tensorflow.python.framework import ops
|
||||
from six.moves.urllib.request import urlopen
|
||||
from tensorflow.python.keras.utils.generic_utils import Progbar
|
||||
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
@ -137,6 +138,9 @@ def _extract_archive(file_path, path='.', archive_format='auto'):
|
||||
if isinstance(archive_format, six.string_types):
|
||||
archive_format = [archive_format]
|
||||
|
||||
file_path = path_to_string(file_path)
|
||||
path = path_to_string(path)
|
||||
|
||||
for archive_type in archive_format:
|
||||
if archive_type == 'tar':
|
||||
open_fn = tarfile.open
|
||||
@ -230,6 +234,8 @@ def get_file(fname,
|
||||
datadir = os.path.join(datadir_base, cache_subdir)
|
||||
_makedirs_exist_ok(datadir)
|
||||
|
||||
fname = path_to_string(fname)
|
||||
|
||||
if untar:
|
||||
untar_fpath = os.path.join(datadir, fname)
|
||||
fpath = untar_fpath + '.tar.gz'
|
||||
|
@ -19,6 +19,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
@ -33,6 +35,48 @@ except ImportError:
|
||||
h5py = None
|
||||
|
||||
|
||||
if sys.version_info >= (3, 6):
|
||||
|
||||
def _path_to_string(path):
|
||||
if isinstance(path, os.PathLike):
|
||||
return os.fspath(path)
|
||||
return path
|
||||
elif sys.version_info >= (3, 4):
|
||||
|
||||
def _path_to_string(path):
|
||||
import pathlib
|
||||
if isinstance(path, pathlib.Path):
|
||||
return str(path)
|
||||
return path
|
||||
else:
|
||||
|
||||
def _path_to_string(path):
|
||||
return path
|
||||
|
||||
|
||||
def path_to_string(path):
|
||||
"""Convert `PathLike` objects to their string representation.
|
||||
|
||||
If given a non-string typed path object, converts it to its string
|
||||
representation. Depending on the python version used, this function
|
||||
can handle the following arguments:
|
||||
python >= 3.6: Everything supporting the fs path protocol
|
||||
https://www.python.org/dev/peps/pep-0519
|
||||
python >= 3.4: Only `pathlib.Path` objects
|
||||
|
||||
If the object passed to `path` is not among the above, then it is
|
||||
returned unchanged. This allows e.g. passthrough of file objects
|
||||
through this function.
|
||||
|
||||
Args:
|
||||
path: `PathLike` object that represents a path
|
||||
|
||||
Returns:
|
||||
A string representation of the path argument, if Python support exists.
|
||||
"""
|
||||
return _path_to_string(path)
|
||||
|
||||
|
||||
@keras_export('keras.utils.HDF5Matrix')
|
||||
class HDF5Matrix(object):
|
||||
"""Representation of HDF5 dataset to be used instead of a Numpy array.
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
@ -137,6 +138,25 @@ class TestIOUtils(keras_parameterized.TestCase):
|
||||
self.assertFalse(
|
||||
io_utils.ask_to_proceed_with_overwrite('/tmp/not_exists'))
|
||||
|
||||
def test_path_to_string(self):
|
||||
|
||||
class PathLikeDummy(object):
|
||||
|
||||
def __fspath__(self):
|
||||
return 'dummypath'
|
||||
|
||||
dummy = object()
|
||||
if sys.version_info >= (3, 4):
|
||||
from pathlib import Path # pylint:disable=g-import-not-at-top
|
||||
# conversion of PathLike
|
||||
self.assertEqual(io_utils.path_to_string(Path('path')), 'path')
|
||||
if sys.version_info >= (3, 6):
|
||||
self.assertEqual(io_utils.path_to_string(PathLikeDummy()), 'dummypath')
|
||||
|
||||
# pass-through, works for all versions of python
|
||||
self.assertEqual(io_utils.path_to_string('path'), 'path')
|
||||
self.assertIs(io_utils.path_to_string(dummy), dummy)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
@ -299,6 +300,7 @@ def plot_model(model,
|
||||
rankdir=rankdir,
|
||||
expand_nested=expand_nested,
|
||||
dpi=dpi)
|
||||
to_file = path_to_string(to_file)
|
||||
if dot is None:
|
||||
return
|
||||
_, extension = os.path.splitext(to_file)
|
||||
|
Loading…
Reference in New Issue
Block a user