PR #34254: Allow `pathlib.Path` paths for loading models via Keras API
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/34254 This change allows the path for `tf.keras.models.load_model` to be specified as a `pathlib.Path` object. Previously, the provided path was explicitly checked to see if it was a string, triggering an `IOError` for `pathlib.Path` objects. Copybara import of the project: -- d26af8b4e308b99e9bdcb3cc0aa0f62dcbbffd88 by Matthew Petroff <matthew@mpetroff.net>: Allow `pathlib.Path` paths for loading models via Keras API. Previously, the provided path was checked to see if it was a string, triggering an IOError. -- 96a2d971a0548b644a5719bcc1a3eae7799a9eb4 by Matthew Petroff <matthew@mpetroff.net>: Add tests for loading Keras models in TF format. Both string-defined and pathlib-defined model paths are tested. -- 4f15ec68b80731f9ade1a0d3f836381445d21010 by Matthew Petroff <matthew@mpetroff.net>: Add missing import. PiperOrigin-RevId: 281332697 Change-Id: I8d99f264dc45d1a7223340d9739f2f03ef4e51f7
This commit is contained in:
parent
84f9fb043e
commit
ee9f39459d
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import six
|
||||
|
||||
|
@ -30,6 +31,8 @@ from tensorflow.python.saved_model import loader_impl
|
|||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
if sys.version >= '3.4':
|
||||
import pathlib
|
||||
try:
|
||||
import h5py
|
||||
except ImportError:
|
||||
|
@ -73,7 +76,7 @@ def save_model(model,
|
|||
Arguments:
|
||||
model: Keras model instance to be saved.
|
||||
filepath: One of the following:
|
||||
- String, path where to save the model
|
||||
- String or `pathlib.Path` object, path where to save the model
|
||||
- `h5py.File` object where to save the model
|
||||
overwrite: Whether we should overwrite any existing model at the target
|
||||
location, or instead ask the user with a manual prompt.
|
||||
|
@ -95,6 +98,9 @@ def save_model(model,
|
|||
default_format = 'tf' if tf2.enabled() else 'h5'
|
||||
save_format = save_format or default_format
|
||||
|
||||
if sys.version >= '3.4' and isinstance(filepath, pathlib.Path):
|
||||
filepath = str(filepath)
|
||||
|
||||
if (save_format == 'h5' or
|
||||
(h5py is not None and isinstance(filepath, h5py.File)) or
|
||||
os.path.splitext(filepath)[1] in _HDF5_EXTENSIONS):
|
||||
|
@ -121,7 +127,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
|
|||
|
||||
Arguments:
|
||||
filepath: One of the following:
|
||||
- String, path to the saved model
|
||||
- String or `pathlib.Path` object, path to the saved model
|
||||
- `h5py.File` object from which to load the model
|
||||
custom_objects: Optional dictionary mapping names
|
||||
(strings) to custom classes or functions to be
|
||||
|
@ -145,6 +151,8 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
|
|||
isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
|
||||
return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
|
||||
|
||||
if sys.version >= '3.4' and isinstance(filepath, pathlib.Path):
|
||||
filepath = str(filepath)
|
||||
if isinstance(filepath, six.string_types):
|
||||
loader_impl.parse_saved_model(filepath)
|
||||
return saved_model_load.load(filepath, compile)
|
||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -34,6 +35,8 @@ from tensorflow.python.ops import lookup_ops
|
|||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
|
||||
if sys.version >= '3.4':
|
||||
import pathlib # pylint:disable=g-import-not-at-top
|
||||
try:
|
||||
import h5py # pylint:disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
|
@ -83,6 +86,19 @@ class TestSaveModel(test.TestCase):
|
|||
save.save_model(self.subclassed_model, path, save_format='tf')
|
||||
self.assert_saved_model(path)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def test_save_load_tf_string(self):
|
||||
path = os.path.join(self.get_temp_dir(), 'model')
|
||||
save.save_model(self.model, path, save_format='tf')
|
||||
save.load_model(path)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def test_save_load_tf_pathlib(self):
|
||||
if sys.version >= '3.4':
|
||||
path = pathlib.Path(self.get_temp_dir()) / 'model'
|
||||
save.save_model(self.model, path, save_format='tf')
|
||||
save.load_model(path)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_saving_with_dense_features(self):
|
||||
cols = [
|
||||
|
|
Loading…
Reference in New Issue