Merge pull request #37955 from ngc92:pathlib
PiperOrigin-RevId: 305264725 Change-Id: Ia73b4865e0886cc01d52e772745476ddd89b21f7
This commit is contained in:
commit
35b03590f6
@ -44,6 +44,7 @@ from tensorflow.python.keras.utils import tf_utils
|
|||||||
from tensorflow.python.keras.utils import version_utils
|
from tensorflow.python.keras.utils import version_utils
|
||||||
from tensorflow.python.keras.utils.data_utils import Sequence
|
from tensorflow.python.keras.utils.data_utils import Sequence
|
||||||
from tensorflow.python.keras.utils.generic_utils import Progbar
|
from tensorflow.python.keras.utils.generic_utils import Progbar
|
||||||
|
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||||
from tensorflow.python.lib.io import file_io
|
from tensorflow.python.lib.io import file_io
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -1044,12 +1045,12 @@ class ModelCheckpoint(Callback):
|
|||||||
```
|
```
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
filepath: string, path to save the model file. `filepath` can contain
|
filepath: string or `PathLike`, path to save the model file. `filepath`
|
||||||
named formatting options, which will be filled the value of `epoch` and
|
can contain named formatting options, which will be filled the value of
|
||||||
keys in `logs` (passed in `on_epoch_end`). For example: if `filepath` is
|
`epoch` and keys in `logs` (passed in `on_epoch_end`). For example: if
|
||||||
`weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model checkpoints
|
`filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model
|
||||||
will be saved with the epoch number and the validation loss in the
|
checkpoints will be saved with the epoch number and the validation loss
|
||||||
filename.
|
in the filename.
|
||||||
monitor: quantity to monitor.
|
monitor: quantity to monitor.
|
||||||
verbose: verbosity mode, 0 or 1.
|
verbose: verbosity mode, 0 or 1.
|
||||||
save_best_only: if `save_best_only=True`, the latest best model according
|
save_best_only: if `save_best_only=True`, the latest best model according
|
||||||
@ -1090,7 +1091,7 @@ class ModelCheckpoint(Callback):
|
|||||||
self._supports_tf_logs = True
|
self._supports_tf_logs = True
|
||||||
self.monitor = monitor
|
self.monitor = monitor
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.filepath = filepath
|
self.filepath = path_to_string(filepath)
|
||||||
self.save_best_only = save_best_only
|
self.save_best_only = save_best_only
|
||||||
self.save_weights_only = save_weights_only
|
self.save_weights_only = save_weights_only
|
||||||
self.save_freq = save_freq
|
self.save_freq = save_freq
|
||||||
@ -1780,7 +1781,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
|
|||||||
self._supports_tf_logs = True
|
self._supports_tf_logs = True
|
||||||
self._validate_kwargs(kwargs)
|
self._validate_kwargs(kwargs)
|
||||||
|
|
||||||
self.log_dir = log_dir
|
self.log_dir = path_to_string(log_dir)
|
||||||
self.histogram_freq = histogram_freq
|
self.histogram_freq = histogram_freq
|
||||||
self.write_graph = write_graph
|
self.write_graph = write_graph
|
||||||
self.write_images = write_images
|
self.write_images = write_images
|
||||||
@ -2280,7 +2281,7 @@ class CSVLogger(Callback):
|
|||||||
|
|
||||||
def __init__(self, filename, separator=',', append=False):
|
def __init__(self, filename, separator=',', append=False):
|
||||||
self.sep = separator
|
self.sep = separator
|
||||||
self.filename = filename
|
self.filename = path_to_string(filename)
|
||||||
self.append = append
|
self.append = append
|
||||||
self.writer = None
|
self.writer = None
|
||||||
self.keys = None
|
self.keys = None
|
||||||
|
@ -51,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils
|
|||||||
from tensorflow.python.keras.utils import layer_utils
|
from tensorflow.python.keras.utils import layer_utils
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||||
|
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
@ -1030,7 +1031,8 @@ class Network(base_layer.Layer):
|
|||||||
access specific variables, e.g. `model.get_layer("dense_1").kernel`.
|
access specific variables, e.g. `model.get_layer("dense_1").kernel`.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
filepath: String, path to SavedModel or H5 file to save the model.
|
filepath: String, PathLike, path to SavedModel or H5 file to save the
|
||||||
|
model.
|
||||||
overwrite: Whether to silently overwrite any existing file at the
|
overwrite: Whether to silently overwrite any existing file at the
|
||||||
target location, or provide the user with a manual prompt.
|
target location, or provide the user with a manual prompt.
|
||||||
include_optimizer: If True, save optimizer's state together.
|
include_optimizer: If True, save optimizer's state together.
|
||||||
@ -1103,10 +1105,10 @@ class Network(base_layer.Layer):
|
|||||||
on the TensorFlow format.
|
on the TensorFlow format.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
filepath: String, path to the file to save the weights to. When saving
|
filepath: String or PathLike, path to the file to save the weights to.
|
||||||
in TensorFlow format, this is the prefix used for checkpoint files
|
When saving in TensorFlow format, this is the prefix used for
|
||||||
(multiple files are generated). Note that the '.h5' suffix causes
|
checkpoint files (multiple files are generated). Note that the '.h5'
|
||||||
weights to be saved in HDF5 format.
|
suffix causes weights to be saved in HDF5 format.
|
||||||
overwrite: Whether to silently overwrite any existing file at the
|
overwrite: Whether to silently overwrite any existing file at the
|
||||||
target location, or provide the user with a manual prompt.
|
target location, or provide the user with a manual prompt.
|
||||||
save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
|
save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
|
||||||
@ -1119,6 +1121,7 @@ class Network(base_layer.Layer):
|
|||||||
ValueError: For invalid/unknown format arguments.
|
ValueError: For invalid/unknown format arguments.
|
||||||
"""
|
"""
|
||||||
self._assert_weights_created()
|
self._assert_weights_created()
|
||||||
|
filepath = path_to_string(filepath)
|
||||||
filepath_is_h5 = _is_hdf5_filepath(filepath)
|
filepath_is_h5 = _is_hdf5_filepath(filepath)
|
||||||
if save_format is None:
|
if save_format is None:
|
||||||
if filepath_is_h5:
|
if filepath_is_h5:
|
||||||
@ -1201,9 +1204,9 @@ class Network(base_layer.Layer):
|
|||||||
which layers are assigned in the `Model`'s constructor.
|
which layers are assigned in the `Model`'s constructor.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
filepath: String, path to the weights file to load. For weight files in
|
filepath: String or PathLike, path to the weights file to load. For
|
||||||
TensorFlow format, this is the file prefix (the same as was passed
|
weight files in TensorFlow format, this is the file prefix (the
|
||||||
to `save_weights`).
|
same as was passed to `save_weights`).
|
||||||
by_name: Boolean, whether to load weights by name or by topological
|
by_name: Boolean, whether to load weights by name or by topological
|
||||||
order. Only topological loading is supported for weight files in
|
order. Only topological loading is supported for weight files in
|
||||||
TensorFlow format.
|
TensorFlow format.
|
||||||
@ -1232,6 +1235,7 @@ class Network(base_layer.Layer):
|
|||||||
'When calling model.load_weights, skip_mismatch can only be set to '
|
'When calling model.load_weights, skip_mismatch can only be set to '
|
||||||
'True when by_name is True.')
|
'True when by_name is True.')
|
||||||
|
|
||||||
|
filepath = path_to_string(filepath)
|
||||||
if _is_hdf5_filepath(filepath):
|
if _is_hdf5_filepath(filepath):
|
||||||
save_format = 'h5'
|
save_format = 'h5'
|
||||||
else:
|
else:
|
||||||
|
@ -19,8 +19,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
@ -28,12 +26,11 @@ from tensorflow.python.keras.saving import hdf5_format
|
|||||||
from tensorflow.python.keras.saving.saved_model import load as saved_model_load
|
from tensorflow.python.keras.saving.saved_model import load as saved_model_load
|
||||||
from tensorflow.python.keras.saving.saved_model import save as saved_model_save
|
from tensorflow.python.keras.saving.saved_model import save as saved_model_save
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
|
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||||
from tensorflow.python.saved_model import loader_impl
|
from tensorflow.python.saved_model import loader_impl
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
# pylint: disable=g-import-not-at-top
|
# pylint: disable=g-import-not-at-top
|
||||||
if sys.version_info >= (3, 4):
|
|
||||||
import pathlib
|
|
||||||
try:
|
try:
|
||||||
import h5py
|
import h5py
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -115,8 +112,7 @@ def save_model(model,
|
|||||||
default_format = 'tf' if tf2.enabled() else 'h5'
|
default_format = 'tf' if tf2.enabled() else 'h5'
|
||||||
save_format = save_format or default_format
|
save_format = save_format or default_format
|
||||||
|
|
||||||
if sys.version_info >= (3, 4) and isinstance(filepath, pathlib.Path):
|
filepath = path_to_string(filepath)
|
||||||
filepath = str(filepath)
|
|
||||||
|
|
||||||
if (save_format == 'h5' or
|
if (save_format == 'h5' or
|
||||||
(h5py is not None and isinstance(filepath, h5py.File)) or
|
(h5py is not None and isinstance(filepath, h5py.File)) or
|
||||||
@ -183,8 +179,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
|
|||||||
isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
|
isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
|
||||||
return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
|
return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
|
||||||
|
|
||||||
if sys.version_info >= (3, 4) and isinstance(filepath, pathlib.Path):
|
filepath = path_to_string(filepath)
|
||||||
filepath = str(filepath)
|
|
||||||
if isinstance(filepath, six.string_types):
|
if isinstance(filepath, six.string_types):
|
||||||
loader_impl.parse_saved_model(filepath)
|
loader_impl.parse_saved_model(filepath)
|
||||||
return saved_model_load.load(filepath, compile)
|
return saved_model_load.load(filepath, compile)
|
||||||
|
@ -41,7 +41,7 @@ from tensorflow.python.ops import lookup_ops
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.saved_model import loader_impl
|
from tensorflow.python.saved_model import loader_impl
|
||||||
|
|
||||||
if sys.version_info >= (3, 4):
|
if sys.version_info >= (3, 6):
|
||||||
import pathlib # pylint:disable=g-import-not-at-top
|
import pathlib # pylint:disable=g-import-not-at-top
|
||||||
try:
|
try:
|
||||||
import h5py # pylint:disable=g-import-not-at-top
|
import h5py # pylint:disable=g-import-not-at-top
|
||||||
@ -100,7 +100,7 @@ class TestSaveModel(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
def test_save_load_tf_pathlib(self):
|
def test_save_load_tf_pathlib(self):
|
||||||
if sys.version_info >= (3, 4):
|
if sys.version_info >= (3, 6):
|
||||||
path = pathlib.Path(self.get_temp_dir()) / 'model'
|
path = pathlib.Path(self.get_temp_dir()) / 'model'
|
||||||
save.save_model(self.model, path, save_format='tf')
|
save.save_model(self.model, path, save_format='tf')
|
||||||
save.load_model(path)
|
save.load_model(path)
|
||||||
|
@ -46,6 +46,7 @@ from six.moves.urllib.error import URLError
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from six.moves.urllib.request import urlopen
|
from six.moves.urllib.request import urlopen
|
||||||
from tensorflow.python.keras.utils.generic_utils import Progbar
|
from tensorflow.python.keras.utils.generic_utils import Progbar
|
||||||
|
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import tf_inspect
|
from tensorflow.python.util import tf_inspect
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
@ -137,6 +138,9 @@ def _extract_archive(file_path, path='.', archive_format='auto'):
|
|||||||
if isinstance(archive_format, six.string_types):
|
if isinstance(archive_format, six.string_types):
|
||||||
archive_format = [archive_format]
|
archive_format = [archive_format]
|
||||||
|
|
||||||
|
file_path = path_to_string(file_path)
|
||||||
|
path = path_to_string(path)
|
||||||
|
|
||||||
for archive_type in archive_format:
|
for archive_type in archive_format:
|
||||||
if archive_type == 'tar':
|
if archive_type == 'tar':
|
||||||
open_fn = tarfile.open
|
open_fn = tarfile.open
|
||||||
@ -230,6 +234,8 @@ def get_file(fname,
|
|||||||
datadir = os.path.join(datadir_base, cache_subdir)
|
datadir = os.path.join(datadir_base, cache_subdir)
|
||||||
_makedirs_exist_ok(datadir)
|
_makedirs_exist_ok(datadir)
|
||||||
|
|
||||||
|
fname = path_to_string(fname)
|
||||||
|
|
||||||
if untar:
|
if untar:
|
||||||
untar_fpath = os.path.join(datadir, fname)
|
untar_fpath = os.path.join(datadir, fname)
|
||||||
fpath = untar_fpath + '.tar.gz'
|
fpath = untar_fpath + '.tar.gz'
|
||||||
|
@ -19,6 +19,8 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
@ -33,6 +35,48 @@ except ImportError:
|
|||||||
h5py = None
|
h5py = None
|
||||||
|
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 6):
|
||||||
|
|
||||||
|
def _path_to_string(path):
|
||||||
|
if isinstance(path, os.PathLike):
|
||||||
|
return os.fspath(path)
|
||||||
|
return path
|
||||||
|
elif sys.version_info >= (3, 4):
|
||||||
|
|
||||||
|
def _path_to_string(path):
|
||||||
|
import pathlib
|
||||||
|
if isinstance(path, pathlib.Path):
|
||||||
|
return str(path)
|
||||||
|
return path
|
||||||
|
else:
|
||||||
|
|
||||||
|
def _path_to_string(path):
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def path_to_string(path):
|
||||||
|
"""Convert `PathLike` objects to their string representation.
|
||||||
|
|
||||||
|
If given a non-string typed path object, converts it to its string
|
||||||
|
representation. Depending on the python version used, this function
|
||||||
|
can handle the following arguments:
|
||||||
|
python >= 3.6: Everything supporting the fs path protocol
|
||||||
|
https://www.python.org/dev/peps/pep-0519
|
||||||
|
python >= 3.4: Only `pathlib.Path` objects
|
||||||
|
|
||||||
|
If the object passed to `path` is not among the above, then it is
|
||||||
|
returned unchanged. This allows e.g. passthrough of file objects
|
||||||
|
through this function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: `PathLike` object that represents a path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string representation of the path argument, if Python support exists.
|
||||||
|
"""
|
||||||
|
return _path_to_string(path)
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.utils.HDF5Matrix')
|
@keras_export('keras.utils.HDF5Matrix')
|
||||||
class HDF5Matrix(object):
|
class HDF5Matrix(object):
|
||||||
"""Representation of HDF5 dataset to be used instead of a Numpy array.
|
"""Representation of HDF5 dataset to be used instead of a Numpy array.
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
@ -137,6 +138,25 @@ class TestIOUtils(keras_parameterized.TestCase):
|
|||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
io_utils.ask_to_proceed_with_overwrite('/tmp/not_exists'))
|
io_utils.ask_to_proceed_with_overwrite('/tmp/not_exists'))
|
||||||
|
|
||||||
|
def test_path_to_string(self):
|
||||||
|
|
||||||
|
class PathLikeDummy(object):
|
||||||
|
|
||||||
|
def __fspath__(self):
|
||||||
|
return 'dummypath'
|
||||||
|
|
||||||
|
dummy = object()
|
||||||
|
if sys.version_info >= (3, 4):
|
||||||
|
from pathlib import Path # pylint:disable=g-import-not-at-top
|
||||||
|
# conversion of PathLike
|
||||||
|
self.assertEqual(io_utils.path_to_string(Path('path')), 'path')
|
||||||
|
if sys.version_info >= (3, 6):
|
||||||
|
self.assertEqual(io_utils.path_to_string(PathLikeDummy()), 'dummypath')
|
||||||
|
|
||||||
|
# pass-through, works for all versions of python
|
||||||
|
self.assertEqual(io_utils.path_to_string('path'), 'path')
|
||||||
|
self.assertIs(io_utils.path_to_string(dummy), dummy)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
@ -299,6 +300,7 @@ def plot_model(model,
|
|||||||
rankdir=rankdir,
|
rankdir=rankdir,
|
||||||
expand_nested=expand_nested,
|
expand_nested=expand_nested,
|
||||||
dpi=dpi)
|
dpi=dpi)
|
||||||
|
to_file = path_to_string(to_file)
|
||||||
if dot is None:
|
if dot is None:
|
||||||
return
|
return
|
||||||
_, extension = os.path.splitext(to_file)
|
_, extension = os.path.splitext(to_file)
|
||||||
|
Loading…
Reference in New Issue
Block a user