Merge pull request #37955 from ngc92:pathlib

PiperOrigin-RevId: 305264725
Change-Id: Ia73b4865e0886cc01d52e772745476ddd89b21f7
This commit is contained in:
TensorFlower Gardener 2020-04-07 08:45:38 -07:00
commit 35b03590f6
8 changed files with 99 additions and 27 deletions

View File

@ -44,6 +44,7 @@ from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.data_utils import Sequence
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
@ -1044,12 +1045,12 @@ class ModelCheckpoint(Callback):
```
Arguments:
filepath: string, path to save the model file. `filepath` can contain
named formatting options, which will be filled the value of `epoch` and
keys in `logs` (passed in `on_epoch_end`). For example: if `filepath` is
`weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model checkpoints
will be saved with the epoch number and the validation loss in the
filename.
filepath: string or `PathLike`, path to save the model file. `filepath`
can contain named formatting options, which will be filled with the value of
`epoch` and keys in `logs` (passed in `on_epoch_end`). For example: if
`filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model
checkpoints will be saved with the epoch number and the validation loss
in the filename.
monitor: quantity to monitor.
verbose: verbosity mode, 0 or 1.
save_best_only: if `save_best_only=True`, the latest best model according
@ -1090,7 +1091,7 @@ class ModelCheckpoint(Callback):
self._supports_tf_logs = True
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
self.filepath = path_to_string(filepath)
self.save_best_only = save_best_only
self.save_weights_only = save_weights_only
self.save_freq = save_freq
@ -1780,7 +1781,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
self._supports_tf_logs = True
self._validate_kwargs(kwargs)
self.log_dir = log_dir
self.log_dir = path_to_string(log_dir)
self.histogram_freq = histogram_freq
self.write_graph = write_graph
self.write_images = write_images
@ -2280,7 +2281,7 @@ class CSVLogger(Callback):
def __init__(self, filename, separator=',', append=False):
self.sep = separator
self.filename = filename
self.filename = path_to_string(filename)
self.append = append
self.writer = None
self.keys = None

View File

@ -51,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
@ -1030,7 +1031,8 @@ class Network(base_layer.Layer):
access specific variables, e.g. `model.get_layer("dense_1").kernel`.
Arguments:
filepath: String, path to SavedModel or H5 file to save the model.
filepath: String or PathLike, path to SavedModel or H5 file to save the
model.
overwrite: Whether to silently overwrite any existing file at the
target location, or provide the user with a manual prompt.
include_optimizer: If True, save optimizer's state together.
@ -1103,10 +1105,10 @@ class Network(base_layer.Layer):
on the TensorFlow format.
Arguments:
filepath: String, path to the file to save the weights to. When saving
in TensorFlow format, this is the prefix used for checkpoint files
(multiple files are generated). Note that the '.h5' suffix causes
weights to be saved in HDF5 format.
filepath: String or PathLike, path to the file to save the weights to.
When saving in TensorFlow format, this is the prefix used for
checkpoint files (multiple files are generated). Note that the '.h5'
suffix causes weights to be saved in HDF5 format.
overwrite: Whether to silently overwrite any existing file at the
target location, or provide the user with a manual prompt.
save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
@ -1119,6 +1121,7 @@ class Network(base_layer.Layer):
ValueError: For invalid/unknown format arguments.
"""
self._assert_weights_created()
filepath = path_to_string(filepath)
filepath_is_h5 = _is_hdf5_filepath(filepath)
if save_format is None:
if filepath_is_h5:
@ -1201,9 +1204,9 @@ class Network(base_layer.Layer):
which layers are assigned in the `Model`'s constructor.
Arguments:
filepath: String, path to the weights file to load. For weight files in
TensorFlow format, this is the file prefix (the same as was passed
to `save_weights`).
filepath: String or PathLike, path to the weights file to load. For
weight files in TensorFlow format, this is the file prefix (the
same as was passed to `save_weights`).
by_name: Boolean, whether to load weights by name or by topological
order. Only topological loading is supported for weight files in
TensorFlow format.
@ -1232,6 +1235,7 @@ class Network(base_layer.Layer):
'When calling model.load_weights, skip_mismatch can only be set to '
'True when by_name is True.')
filepath = path_to_string(filepath)
if _is_hdf5_filepath(filepath):
save_format = 'h5'
else:

View File

@ -19,8 +19,6 @@ from __future__ import division
from __future__ import print_function
import os
import sys
import six
from tensorflow.python import tf2
@ -28,12 +26,11 @@ from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.saving.saved_model import load as saved_model_load
from tensorflow.python.keras.saving.saved_model import save as saved_model_save
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.util.tf_export import keras_export
# pylint: disable=g-import-not-at-top
if sys.version_info >= (3, 4):
import pathlib
try:
import h5py
except ImportError:
@ -115,8 +112,7 @@ def save_model(model,
default_format = 'tf' if tf2.enabled() else 'h5'
save_format = save_format or default_format
if sys.version_info >= (3, 4) and isinstance(filepath, pathlib.Path):
filepath = str(filepath)
filepath = path_to_string(filepath)
if (save_format == 'h5' or
(h5py is not None and isinstance(filepath, h5py.File)) or
@ -183,8 +179,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
if sys.version_info >= (3, 4) and isinstance(filepath, pathlib.Path):
filepath = str(filepath)
filepath = path_to_string(filepath)
if isinstance(filepath, six.string_types):
loader_impl.parse_saved_model(filepath)
return saved_model_load.load(filepath, compile)

View File

@ -41,7 +41,7 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import loader_impl
if sys.version_info >= (3, 4):
if sys.version_info >= (3, 6):
import pathlib # pylint:disable=g-import-not-at-top
try:
import h5py # pylint:disable=g-import-not-at-top
@ -100,7 +100,7 @@ class TestSaveModel(test.TestCase, parameterized.TestCase):
@test_util.run_v2_only
def test_save_load_tf_pathlib(self):
if sys.version_info >= (3, 4):
if sys.version_info >= (3, 6):
path = pathlib.Path(self.get_temp_dir()) / 'model'
save.save_model(self.model, path, save_format='tf')
save.load_model(path)

View File

@ -46,6 +46,7 @@ from six.moves.urllib.error import URLError
from tensorflow.python.framework import ops
from six.moves.urllib.request import urlopen
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export
@ -137,6 +138,9 @@ def _extract_archive(file_path, path='.', archive_format='auto'):
if isinstance(archive_format, six.string_types):
archive_format = [archive_format]
file_path = path_to_string(file_path)
path = path_to_string(path)
for archive_type in archive_format:
if archive_type == 'tar':
open_fn = tarfile.open
@ -230,6 +234,8 @@ def get_file(fname,
datadir = os.path.join(datadir_base, cache_subdir)
_makedirs_exist_ok(datadir)
fname = path_to_string(fname)
if untar:
untar_fpath = os.path.join(datadir, fname)
fpath = untar_fpath + '.tar.gz'

View File

@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
import collections
import os
import sys
import numpy as np
import six
@ -33,6 +35,48 @@ except ImportError:
h5py = None
# The conversion routine is selected once, at import time, according to what
# the running interpreter supports.
if sys.version_info >= (3, 6):

  def _path_to_string(path):
    """Converts any `os.PathLike` object with `os.fspath` (PEP 519)."""
    if isinstance(path, os.PathLike):
      return os.fspath(path)
    return path

elif sys.version_info >= (3, 4):

  def _path_to_string(path):
    """Converts `pathlib` path objects; `os.PathLike` is not available yet."""
    import pathlib  # pylint: disable=g-import-not-at-top
    # `PurePath` is the common base class of every pathlib path type
    # (including `Path`). Checking against it keeps pure paths consistent with
    # the Python >= 3.6 branch, where they are converted via `__fspath__`.
    if isinstance(path, pathlib.PurePath):
      return str(path)
    return path

else:

  def _path_to_string(path):
    """Returns `path` unchanged; no path objects are handled here."""
    return path


def path_to_string(path):
  """Convert `PathLike` objects to their string representation.

  If given a non-string typed path object, converts it to its string
  representation. Depending on the python version used, this function
  can handle the following arguments:

    python >= 3.6: Everything supporting the fs path protocol
      https://www.python.org/dev/peps/pep-0519
    python >= 3.4: `pathlib` path objects (`pathlib.PurePath` and its
      subclasses, e.g. `pathlib.Path`)

  If the object passed to `path` is not among the above, then it is
  returned unchanged. This allows e.g. passthrough of file objects
  through this function.

  Args:
    path: `PathLike` object that represents a path, or any other object.

  Returns:
    The converted path if `path` is a supported path object on the running
    Python version, otherwise `path` itself (the very same object).
  """
  return _path_to_string(path)
@keras_export('keras.utils.HDF5Matrix')
class HDF5Matrix(object):
"""Representation of HDF5 dataset to be used instead of a Numpy array.

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import os
import shutil
import sys
import numpy as np
import six
@ -137,6 +138,25 @@ class TestIOUtils(keras_parameterized.TestCase):
self.assertFalse(
io_utils.ask_to_proceed_with_overwrite('/tmp/not_exists'))
def test_path_to_string(self):
  """Checks conversion and pass-through behavior of `path_to_string`."""

  class PathLikeDummy(object):
    """Minimal object implementing the `__fspath__` protocol."""

    def __fspath__(self):
      return 'dummypath'

  opaque = object()
  py_version = sys.version_info

  if py_version >= (3, 4):
    import pathlib  # pylint:disable=g-import-not-at-top
    # `pathlib.Path` instances come back as plain strings.
    converted = io_utils.path_to_string(pathlib.Path('path'))
    self.assertEqual(converted, 'path')
  if py_version >= (3, 6):
    # Any object implementing `__fspath__` is converted as well.
    converted = io_utils.path_to_string(PathLikeDummy())
    self.assertEqual(converted, 'dummypath')

  # Pass-through cases are checked on every Python version: strings stay
  # equal and unrelated objects come back as the very same object.
  self.assertEqual(io_utils.path_to_string('path'), 'path')
  self.assertIs(io_utils.path_to_string(opaque), opaque)
if __name__ == '__main__':
test.main()

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import os
import sys
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@ -299,6 +300,7 @@ def plot_model(model,
rankdir=rankdir,
expand_nested=expand_nested,
dpi=dpi)
to_file = path_to_string(to_file)
if dot is None:
return
_, extension = os.path.splitext(to_file)