diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py index 4be3aa0bbda..a64df37aeca 100644 --- a/tensorflow/python/keras/saving/save.py +++ b/tensorflow/python/keras/saving/save.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +import sys import six @@ -30,6 +31,8 @@ from tensorflow.python.saved_model import loader_impl from tensorflow.python.util.tf_export import keras_export # pylint: disable=g-import-not-at-top +if sys.version >= '3.4': + import pathlib try: import h5py except ImportError: @@ -73,7 +76,7 @@ def save_model(model, Arguments: model: Keras model instance to be saved. filepath: One of the following: - - String, path where to save the model + - String or `pathlib.Path` object, path where to save the model - `h5py.File` object where to save the model overwrite: Whether we should overwrite any existing model at the target location, or instead ask the user with a manual prompt. @@ -95,6 +98,9 @@ def save_model(model, default_format = 'tf' if tf2.enabled() else 'h5' save_format = save_format or default_format + if sys.version >= '3.4' and isinstance(filepath, pathlib.Path): + filepath = str(filepath) + if (save_format == 'h5' or (h5py is not None and isinstance(filepath, h5py.File)) or os.path.splitext(filepath)[1] in _HDF5_EXTENSIONS): @@ -121,7 +127,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable= Arguments: filepath: One of the following: - - String, path to the saved model + - String or `pathlib.Path` object, path to the saved model - `h5py.File` object from which to load the model custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be @@ -145,6 +151,8 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable= isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))): return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile) + if sys.version >= '3.4' and isinstance(filepath, pathlib.Path): + filepath = str(filepath) if isinstance(filepath, six.string_types): loader_impl.parse_saved_model(filepath) return saved_model_load.load(filepath, compile) diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py index 9f9edf50176..f5fe8041857 100644 --- a/tensorflow/python/keras/saving/save_test.py +++ b/tensorflow/python/keras/saving/save_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +import sys import numpy as np @@ -34,6 +35,8 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import loader_impl +if sys.version >= '3.4': + import pathlib # pylint:disable=g-import-not-at-top try: import h5py # pylint:disable=g-import-not-at-top except ImportError: @@ -83,6 +86,19 @@ class TestSaveModel(test.TestCase): save.save_model(self.subclassed_model, path, save_format='tf') self.assert_saved_model(path) + @test_util.run_v2_only + def test_save_load_tf_string(self): + path = os.path.join(self.get_temp_dir(), 'model') + save.save_model(self.model, path, save_format='tf') + save.load_model(path) + + @test_util.run_v2_only + def test_save_load_tf_pathlib(self): + if sys.version >= '3.4': + path = pathlib.Path(self.get_temp_dir()) / 'model' + save.save_model(self.model, path, save_format='tf') + save.load_model(path) + @test_util.run_in_graph_and_eager_modes def test_saving_with_dense_features(self): cols = [