PR #34254: Allow `pathlib.Path` paths for loading models via Keras API

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/34254

This change allows the path for `tf.keras.models.load_model` to be specified as a `pathlib.Path` object. Previously, the provided path was explicitly checked to see if it was a string, triggering an `IOError` for `pathlib.Path` objects.
Copybara import of the project:

--
d26af8b4e308b99e9bdcb3cc0aa0f62dcbbffd88 by Matthew Petroff <matthew@mpetroff.net>:

Allow `pathlib.Path` paths for loading models via Keras API.

Previously, the provided path was checked to see if it was a string,
triggering an IOError.

--
96a2d971a0548b644a5719bcc1a3eae7799a9eb4 by Matthew Petroff <matthew@mpetroff.net>:

Add tests for loading Keras models in TF format.

Both string-defined and pathlib-defined model paths are tested.

--
4f15ec68b80731f9ade1a0d3f836381445d21010 by Matthew Petroff <matthew@mpetroff.net>:

Add missing import.

PiperOrigin-RevId: 281332697
Change-Id: I8d99f264dc45d1a7223340d9739f2f03ef4e51f7
This commit is contained in:
A. Unique TensorFlower 2019-11-19 11:03:26 -08:00 committed by TensorFlower Gardener
parent 84f9fb043e
commit ee9f39459d
2 changed files with 26 additions and 2 deletions

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os
import sys
import six
@ -30,6 +31,8 @@ from tensorflow.python.saved_model import loader_impl
from tensorflow.python.util.tf_export import keras_export
# pylint: disable=g-import-not-at-top
if sys.version >= '3.4':
import pathlib
try:
import h5py
except ImportError:
@ -73,7 +76,7 @@ def save_model(model,
Arguments:
model: Keras model instance to be saved.
filepath: One of the following:
- String, path where to save the model
- String or `pathlib.Path` object, path where to save the model
- `h5py.File` object where to save the model
overwrite: Whether we should overwrite any existing model at the target
location, or instead ask the user with a manual prompt.
@ -95,6 +98,9 @@ def save_model(model,
default_format = 'tf' if tf2.enabled() else 'h5'
save_format = save_format or default_format
if sys.version >= '3.4' and isinstance(filepath, pathlib.Path):
filepath = str(filepath)
if (save_format == 'h5' or
(h5py is not None and isinstance(filepath, h5py.File)) or
os.path.splitext(filepath)[1] in _HDF5_EXTENSIONS):
@ -121,7 +127,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
Arguments:
filepath: One of the following:
- String, path to the saved model
- String or `pathlib.Path` object, path to the saved model
- `h5py.File` object from which to load the model
custom_objects: Optional dictionary mapping names
(strings) to custom classes or functions to be
@ -145,6 +151,8 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
if sys.version >= '3.4' and isinstance(filepath, pathlib.Path):
filepath = str(filepath)
if isinstance(filepath, six.string_types):
loader_impl.parse_saved_model(filepath)
return saved_model_load.load(filepath, compile)

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
@ -34,6 +35,8 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import loader_impl
if sys.version >= '3.4':
import pathlib # pylint:disable=g-import-not-at-top
try:
import h5py # pylint:disable=g-import-not-at-top
except ImportError:
@ -83,6 +86,19 @@ class TestSaveModel(test.TestCase):
save.save_model(self.subclassed_model, path, save_format='tf')
self.assert_saved_model(path)
@test_util.run_v2_only
def test_save_load_tf_string(self):
path = os.path.join(self.get_temp_dir(), 'model')
save.save_model(self.model, path, save_format='tf')
save.load_model(path)
@test_util.run_v2_only
def test_save_load_tf_pathlib(self):
if sys.version >= '3.4':
path = pathlib.Path(self.get_temp_dir()) / 'model'
save.save_model(self.model, path, save_format='tf')
save.load_model(path)
@test_util.run_in_graph_and_eager_modes
def test_saving_with_dense_features(self):
cols = [