diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 7e3dadedcf9..984c6d6e000 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -44,6 +44,7 @@ from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils import version_utils from tensorflow.python.keras.utils.data_utils import Sequence from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.keras.utils.io_utils import path_to_string from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops @@ -1044,12 +1045,12 @@ class ModelCheckpoint(Callback): ``` Arguments: - filepath: string, path to save the model file. `filepath` can contain - named formatting options, which will be filled the value of `epoch` and - keys in `logs` (passed in `on_epoch_end`). For example: if `filepath` is - `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model checkpoints - will be saved with the epoch number and the validation loss in the - filename. + filepath: string or `PathLike`, path to save the model file. `filepath` + can contain named formatting options, which will be filled the value of + `epoch` and keys in `logs` (passed in `on_epoch_end`). For example: if + `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model + checkpoints will be saved with the epoch number and the validation loss + in the filename. monitor: quantity to monitor. verbose: verbosity mode, 0 or 1. save_best_only: if `save_best_only=True`, the latest best model according @@ -1090,7 +1091,7 @@ class ModelCheckpoint(Callback): self._supports_tf_logs = True self.monitor = monitor self.verbose = verbose - self.filepath = filepath + self.filepath = path_to_string(filepath) self.save_best_only = save_best_only self.save_weights_only = save_weights_only self.save_freq = save_freq @@ -1780,7 +1781,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): self._supports_tf_logs = True self._validate_kwargs(kwargs) - self.log_dir = log_dir + self.log_dir = path_to_string(log_dir) self.histogram_freq = histogram_freq self.write_graph = write_graph self.write_images = write_images @@ -2280,7 +2281,7 @@ class CSVLogger(Callback): def __init__(self, filename, separator=',', append=False): self.sep = separator - self.filename = filename + self.filename = path_to_string(filename) self.append = append self.writer = None self.keys = None diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 747e51fc4e2..8954e30f7ca 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -51,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.keras.utils.io_utils import path_to_string from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -1030,7 +1031,8 @@ class Network(base_layer.Layer): access specific variables, e.g. `model.get_layer("dense_1").kernel`. Arguments: - filepath: String, path to SavedModel or H5 file to save the model. + filepath: String, PathLike, path to SavedModel or H5 file to save the + model. overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt. include_optimizer: If True, save optimizer's state together. @@ -1103,10 +1105,10 @@ class Network(base_layer.Layer): on the TensorFlow format. Arguments: - filepath: String, path to the file to save the weights to. When saving - in TensorFlow format, this is the prefix used for checkpoint files - (multiple files are generated). Note that the '.h5' suffix causes - weights to be saved in HDF5 format. + filepath: String or PathLike, path to the file to save the weights to. + When saving in TensorFlow format, this is the prefix used for + checkpoint files (multiple files are generated). Note that the '.h5' + suffix causes weights to be saved in HDF5 format. overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt. save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or @@ -1119,6 +1121,7 @@ class Network(base_layer.Layer): ValueError: For invalid/unknown format arguments. """ self._assert_weights_created() + filepath = path_to_string(filepath) filepath_is_h5 = _is_hdf5_filepath(filepath) if save_format is None: if filepath_is_h5: @@ -1201,9 +1204,9 @@ class Network(base_layer.Layer): which layers are assigned in the `Model`'s constructor. Arguments: - filepath: String, path to the weights file to load. For weight files in - TensorFlow format, this is the file prefix (the same as was passed - to `save_weights`). + filepath: String or PathLike, path to the weights file to load. For + weight files in TensorFlow format, this is the file prefix (the + same as was passed to `save_weights`). by_name: Boolean, whether to load weights by name or by topological order. Only topological loading is supported for weight files in TensorFlow format. @@ -1232,6 +1235,7 @@ class Network(base_layer.Layer): 'When calling model.load_weights, skip_mismatch can only be set to ' 'True when by_name is True.') + filepath = path_to_string(filepath) if _is_hdf5_filepath(filepath): save_format = 'h5' else: diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py index 88c46c85c41..43c09a62ea9 100644 --- a/tensorflow/python/keras/saving/save.py +++ b/tensorflow/python/keras/saving/save.py @@ -19,8 +19,6 @@ from __future__ import division from __future__ import print_function import os -import sys - import six from tensorflow.python import tf2 @@ -28,12 +26,11 @@ from tensorflow.python.keras.saving import hdf5_format from tensorflow.python.keras.saving.saved_model import load as saved_model_load from tensorflow.python.keras.saving.saved_model import save as saved_model_save from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils.io_utils import path_to_string from tensorflow.python.saved_model import loader_impl from tensorflow.python.util.tf_export import keras_export # pylint: disable=g-import-not-at-top -if sys.version_info >= (3, 4): - import pathlib try: import h5py except ImportError: @@ -115,8 +112,7 @@ def save_model(model, default_format = 'tf' if tf2.enabled() else 'h5' save_format = save_format or default_format - if sys.version_info >= (3, 4) and isinstance(filepath, pathlib.Path): - filepath = str(filepath) + filepath = path_to_string(filepath) if (save_format == 'h5' or (h5py is not None and isinstance(filepath, h5py.File)) or @@ -183,8 +179,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable= isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))): return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile) - if sys.version_info >= (3, 4) and isinstance(filepath, pathlib.Path): - filepath = str(filepath) + filepath = path_to_string(filepath) if isinstance(filepath, six.string_types): loader_impl.parse_saved_model(filepath) return saved_model_load.load(filepath, compile) diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py index 7c48cca49ff..5c5846fe738 100644 --- a/tensorflow/python/keras/saving/save_test.py +++ b/tensorflow/python/keras/saving/save_test.py @@ -41,7 +41,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import loader_impl -if sys.version_info >= (3, 4): +if sys.version_info >= (3, 6): import pathlib # pylint:disable=g-import-not-at-top try: import h5py # pylint:disable=g-import-not-at-top @@ -100,7 +100,7 @@ class TestSaveModel(test.TestCase, parameterized.TestCase): @test_util.run_v2_only def test_save_load_tf_pathlib(self): - if sys.version_info >= (3, 4): + if sys.version_info >= (3, 6): path = pathlib.Path(self.get_temp_dir()) / 'model' save.save_model(self.model, path, save_format='tf') save.load_model(path) diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index 00df6b59739..1cf27f8fb65 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -46,6 +46,7 @@ from six.moves.urllib.error import URLError from tensorflow.python.framework import ops from six.moves.urllib.request import urlopen from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.keras.utils.io_utils import path_to_string from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import keras_export @@ -137,6 +138,9 @@ def _extract_archive(file_path, path='.', archive_format='auto'): if isinstance(archive_format, six.string_types): archive_format = [archive_format] + file_path = path_to_string(file_path) + path = path_to_string(path) + for archive_type in archive_format: if archive_type == 'tar': open_fn = tarfile.open @@ -230,6 +234,8 @@ def get_file(fname, datadir = os.path.join(datadir_base, cache_subdir) _makedirs_exist_ok(datadir) + fname = path_to_string(fname) + if untar: untar_fpath = os.path.join(datadir, fname) fpath = untar_fpath + '.tar.gz' diff --git a/tensorflow/python/keras/utils/io_utils.py b/tensorflow/python/keras/utils/io_utils.py index 99ef25a17d9..7c3395b239c 100644 --- a/tensorflow/python/keras/utils/io_utils.py +++ b/tensorflow/python/keras/utils/io_utils.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function import collections +import os +import sys import numpy as np import six @@ -33,6 +35,48 @@ except ImportError: h5py = None +if sys.version_info >= (3, 6): + + def _path_to_string(path): + if isinstance(path, os.PathLike): + return os.fspath(path) + return path +elif sys.version_info >= (3, 4): + + def _path_to_string(path): + import pathlib + if isinstance(path, pathlib.Path): + return str(path) + return path +else: + + def _path_to_string(path): + return path + + +def path_to_string(path): + """Convert `PathLike` objects to their string representation. + + If given a non-string typed path object, converts it to its string + representation. Depending on the python version used, this function + can handle the following arguments: + python >= 3.6: Everything supporting the fs path protocol + https://www.python.org/dev/peps/pep-0519 + python >= 3.4: Only `pathlib.Path` objects + + If the object passed to `path` is not among the above, then it is + returned unchanged. This allows e.g. passthrough of file objects + through this function. + + Args: + path: `PathLike` object that represents a path + + Returns: + A string representation of the path argument, if Python support exists. + """ + return _path_to_string(path) + + @keras_export('keras.utils.HDF5Matrix') class HDF5Matrix(object): """Representation of HDF5 dataset to be used instead of a Numpy array. diff --git a/tensorflow/python/keras/utils/io_utils_test.py b/tensorflow/python/keras/utils/io_utils_test.py index 363fb5d7211..29328e52dbc 100644 --- a/tensorflow/python/keras/utils/io_utils_test.py +++ b/tensorflow/python/keras/utils/io_utils_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import os import shutil +import sys import numpy as np import six @@ -137,6 +138,25 @@ class TestIOUtils(keras_parameterized.TestCase): self.assertFalse( io_utils.ask_to_proceed_with_overwrite('/tmp/not_exists')) + def test_path_to_string(self): + + class PathLikeDummy(object): + + def __fspath__(self): + return 'dummypath' + + dummy = object() + if sys.version_info >= (3, 4): + from pathlib import Path # pylint:disable=g-import-not-at-top + # conversion of PathLike + self.assertEqual(io_utils.path_to_string(Path('path')), 'path') + if sys.version_info >= (3, 6): + self.assertEqual(io_utils.path_to_string(PathLikeDummy()), 'dummypath') + + # pass-through, works for all versions of python + self.assertEqual(io_utils.path_to_string('path'), 'path') + self.assertIs(io_utils.path_to_string(dummy), dummy) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/utils/vis_utils.py b/tensorflow/python/keras/utils/vis_utils.py index 9819cb831e2..87c436a5bd7 100644 --- a/tensorflow/python/keras/utils/vis_utils.py +++ b/tensorflow/python/keras/utils/vis_utils.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import sys +from tensorflow.python.keras.utils.io_utils import path_to_string from tensorflow.python.util import nest from tensorflow.python.util.tf_export import keras_export @@ -299,6 +300,7 @@ def plot_model(model, rankdir=rankdir, expand_nested=expand_nested, dpi=dpi) + to_file = path_to_string(to_file) if dot is None: return _, extension = os.path.splitext(to_file)