If keras_model_path is google storage url, provide util to download model

remotely.

PiperOrigin-RevId: 215295504
This commit is contained in:
Zhenyu Tan 2018-10-01 15:52:16 -07:00 committed by TensorFlower Gardener
parent 28a5ce4cf8
commit 6509437545
2 changed files with 42 additions and 12 deletions

View File

@ -368,6 +368,44 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
return latest_path
def _get_file_from_google_storage(keras_model_path, model_dir):
"""Get file from google storage and download to local file.
Args:
keras_model_path: a google storage path for compiled keras model.
model_dir: the directory from estimator config.
Returns:
The path where keras model is saved.
Raises:
ValueError: if storage object name does not end with .h5.
"""
try:
from google.cloud import storage # pylint:disable=g-import-not-at-top
except ImportError:
raise TypeError('Could not save model to Google cloud storage; please '
'install `google-cloud-storage` via '
'`pip install google-cloud-storage`.')
storage_client = storage.Client()
path, blob_name = os.path.split(keras_model_path)
_, bucket_name = os.path.split(path)
keras_model_dir = os.path.join(model_dir, 'keras')
if not gfile.Exists(keras_model_dir):
gfile.MakeDirs(keras_model_dir)
file_name = os.path.join(keras_model_dir, 'keras_model.h5')
try:
blob = storage_client.get_bucket(bucket_name).blob(blob_name)
blob.download_to_filename(file_name)
except:
raise ValueError('Failed to download keras model, please check '
'environment variable GOOGLE_APPLICATION_CREDENTIALS '
'and model path storage.googleapis.com/{bucket}/{object}.')
logging.info('Saving model to {}'.format(file_name))
del storage_client
return file_name
def model_to_estimator(keras_model=None,
keras_model_path=None,
custom_objects=None,
@ -407,12 +445,13 @@ def model_to_estimator(keras_model=None,
'Please specity either `keras_model` or `keras_model_path`, '
'but not both.')
config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
config, model_dir)
if not keras_model:
if keras_model_path.startswith(
'gs://') or 'storage.googleapis.com' in keras_model_path:
raise ValueError(
'%s is not a local path. Please copy the model locally first.' %
keras_model_path)
keras_model_path = _get_file_from_google_storage(keras_model_path,
config.model_dir)
logging.info('Loading models from %s', keras_model_path)
keras_model = models.load_model(keras_model_path)
else:
@ -425,9 +464,6 @@ def model_to_estimator(keras_model=None,
'Please compile the model with `model.compile()` '
'before calling `model_to_estimator()`.')
config = estimator_lib.maybe_overwrite_model_dir_and_session_config(config,
model_dir)
keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
if _any_weight_initialized(keras_model):
# Warn if config passed to estimator tries to update GPUOptions. If a

View File

@ -581,12 +581,6 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, 'compiled'):
keras_lib.model_to_estimator(keras_model=keras_model)
with self.cached_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'not a local path'):
keras_lib.model_to_estimator(
keras_model_path='gs://bucket/object')
def test_invalid_ionames_error(self):
(x_train, y_train), (_, _) = testing_utils.get_test_data(
train_samples=_TRAIN_SIZE,