If keras_model_path is google storage url, provide util to download model
remotely. PiperOrigin-RevId: 215295504
This commit is contained in:
parent
28a5ce4cf8
commit
6509437545
@ -368,6 +368,44 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
|
||||
return latest_path
|
||||
|
||||
|
||||
def _get_file_from_google_storage(keras_model_path, model_dir):
|
||||
"""Get file from google storage and download to local file.
|
||||
|
||||
Args:
|
||||
keras_model_path: a google storage path for compiled keras model.
|
||||
model_dir: the directory from estimator config.
|
||||
|
||||
Returns:
|
||||
The path where keras model is saved.
|
||||
|
||||
Raises:
|
||||
ValueError: if storage object name does not end with .h5.
|
||||
"""
|
||||
try:
|
||||
from google.cloud import storage # pylint:disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
raise TypeError('Could not save model to Google cloud storage; please '
|
||||
'install `google-cloud-storage` via '
|
||||
'`pip install google-cloud-storage`.')
|
||||
storage_client = storage.Client()
|
||||
path, blob_name = os.path.split(keras_model_path)
|
||||
_, bucket_name = os.path.split(path)
|
||||
keras_model_dir = os.path.join(model_dir, 'keras')
|
||||
if not gfile.Exists(keras_model_dir):
|
||||
gfile.MakeDirs(keras_model_dir)
|
||||
file_name = os.path.join(keras_model_dir, 'keras_model.h5')
|
||||
try:
|
||||
blob = storage_client.get_bucket(bucket_name).blob(blob_name)
|
||||
blob.download_to_filename(file_name)
|
||||
except:
|
||||
raise ValueError('Failed to download keras model, please check '
|
||||
'environment variable GOOGLE_APPLICATION_CREDENTIALS '
|
||||
'and model path storage.googleapis.com/{bucket}/{object}.')
|
||||
logging.info('Saving model to {}'.format(file_name))
|
||||
del storage_client
|
||||
return file_name
|
||||
|
||||
|
||||
def model_to_estimator(keras_model=None,
|
||||
keras_model_path=None,
|
||||
custom_objects=None,
|
||||
@ -407,12 +445,13 @@ def model_to_estimator(keras_model=None,
|
||||
'Please specity either `keras_model` or `keras_model_path`, '
|
||||
'but not both.')
|
||||
|
||||
config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
|
||||
config, model_dir)
|
||||
if not keras_model:
|
||||
if keras_model_path.startswith(
|
||||
'gs://') or 'storage.googleapis.com' in keras_model_path:
|
||||
raise ValueError(
|
||||
'%s is not a local path. Please copy the model locally first.' %
|
||||
keras_model_path)
|
||||
keras_model_path = _get_file_from_google_storage(keras_model_path,
|
||||
config.model_dir)
|
||||
logging.info('Loading models from %s', keras_model_path)
|
||||
keras_model = models.load_model(keras_model_path)
|
||||
else:
|
||||
@ -425,9 +464,6 @@ def model_to_estimator(keras_model=None,
|
||||
'Please compile the model with `model.compile()` '
|
||||
'before calling `model_to_estimator()`.')
|
||||
|
||||
config = estimator_lib.maybe_overwrite_model_dir_and_session_config(config,
|
||||
model_dir)
|
||||
|
||||
keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
|
||||
if _any_weight_initialized(keras_model):
|
||||
# Warn if config passed to estimator tries to update GPUOptions. If a
|
||||
|
@ -581,12 +581,6 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
|
||||
with self.assertRaisesRegexp(ValueError, 'compiled'):
|
||||
keras_lib.model_to_estimator(keras_model=keras_model)
|
||||
|
||||
with self.cached_session():
|
||||
keras_model = simple_sequential_model()
|
||||
with self.assertRaisesRegexp(ValueError, 'not a local path'):
|
||||
keras_lib.model_to_estimator(
|
||||
keras_model_path='gs://bucket/object')
|
||||
|
||||
def test_invalid_ionames_error(self):
|
||||
(x_train, y_train), (_, _) = testing_utils.get_test_data(
|
||||
train_samples=_TRAIN_SIZE,
|
||||
|
Loading…
Reference in New Issue
Block a user