From ee9f39459d2e0838408010cafef05176ca9339c6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 19 Nov 2019 11:03:26 -0800 Subject: [PATCH] PR #34254: Allow `pathlib.Path` paths for loading models via Keras API Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/34254 This change allows the path for `tf.keras.models.load_model` to be specified as a `pathlib.Path` object. Previously, the provided path was explicitly checked to see if it was a string, triggering an `IOError` for `pathlib.Path` objects. Copybara import of the project: -- d26af8b4e308b99e9bdcb3cc0aa0f62dcbbffd88 by Matthew Petroff : Allow `pathlib.Path` paths for loading models via Keras API. Previously, the provided path was checked to see if it was a string, triggering an IOError. -- 96a2d971a0548b644a5719bcc1a3eae7799a9eb4 by Matthew Petroff : Add tests for loading Keras models in TF format. Both string-defined and pathlib-defined model paths are tested. -- 4f15ec68b80731f9ade1a0d3f836381445d21010 by Matthew Petroff : Add missing import. PiperOrigin-RevId: 281332697 Change-Id: I8d99f264dc45d1a7223340d9739f2f03ef4e51f7 --- tensorflow/python/keras/saving/save.py | 12 ++++++++++-- tensorflow/python/keras/saving/save_test.py | 16 ++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py index 4be3aa0bbda..a64df37aeca 100644 --- a/tensorflow/python/keras/saving/save.py +++ b/tensorflow/python/keras/saving/save.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +import sys import six @@ -30,6 +31,8 @@ from tensorflow.python.saved_model import loader_impl from tensorflow.python.util.tf_export import keras_export # pylint: disable=g-import-not-at-top +if sys.version >= '3.4': + import pathlib try: import h5py except ImportError: @@ -73,7 +76,7 @@ def save_model(model, Arguments: model: Keras model instance to be saved. filepath: One of the following: - - String, path where to save the model + - String or `pathlib.Path` object, path where to save the model - `h5py.File` object where to save the model overwrite: Whether we should overwrite any existing model at the target location, or instead ask the user with a manual prompt. @@ -95,6 +98,9 @@ def save_model(model, default_format = 'tf' if tf2.enabled() else 'h5' save_format = save_format or default_format + if sys.version >= '3.4' and isinstance(filepath, pathlib.Path): + filepath = str(filepath) + if (save_format == 'h5' or (h5py is not None and isinstance(filepath, h5py.File)) or os.path.splitext(filepath)[1] in _HDF5_EXTENSIONS): @@ -121,7 +127,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable= Arguments: filepath: One of the following: - - String, path to the saved model + - String or `pathlib.Path` object, path to the saved model - `h5py.File` object from which to load the model custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be @@ -145,6 +151,8 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable= isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))): return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile) + if sys.version >= '3.4' and isinstance(filepath, pathlib.Path): + filepath = str(filepath) if isinstance(filepath, six.string_types): loader_impl.parse_saved_model(filepath) return saved_model_load.load(filepath, compile) diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py index 9f9edf50176..f5fe8041857 100644 --- a/tensorflow/python/keras/saving/save_test.py +++ b/tensorflow/python/keras/saving/save_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +import sys import numpy as np @@ -34,6 +35,8 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import loader_impl +if sys.version >= '3.4': + import pathlib # pylint:disable=g-import-not-at-top try: import h5py # pylint:disable=g-import-not-at-top except ImportError: @@ -83,6 +86,19 @@ class TestSaveModel(test.TestCase): save.save_model(self.subclassed_model, path, save_format='tf') self.assert_saved_model(path) + @test_util.run_v2_only + def test_save_load_tf_string(self): + path = os.path.join(self.get_temp_dir(), 'model') + save.save_model(self.model, path, save_format='tf') + save.load_model(path) + + @test_util.run_v2_only + def test_save_load_tf_pathlib(self): + if sys.version >= '3.4': + path = pathlib.Path(self.get_temp_dir()) / 'model' + save.save_model(self.model, path, save_format='tf') + save.load_model(path) + @test_util.run_in_graph_and_eager_modes def test_saving_with_dense_features(self): cols = [