unblock eager execution for model_to_estimator

PiperOrigin-RevId: 203789461
This commit is contained in:
Zhenyu Tan 2018-07-09 10:36:58 -07:00 committed by TensorFlower Gardener
parent 9fea659a48
commit f83a382e87
2 changed files with 16 additions and 9 deletions

View File

@ -39,7 +39,6 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import distribute as distribute_lib
@ -71,16 +70,22 @@ def _convert_tensor(x):
return x
def _any_variable_initialized():
"""Check if any variable has been initialized in the Keras model.
def _any_weight_initialized(keras_model):
"""Check if any weights has been initialized in the Keras model.
Args:
keras_model: An instance of compiled keras model.
Returns:
boolean, True if at least one variable has been initialized, else False.
boolean, True if at least one weight has been initialized, else False.
Currently keras initialize all weights at get_session().
"""
variables = variables_module.global_variables()
for v in variables:
if getattr(v, '_keras_initialized', False):
return True
if keras_model is None:
return False
for layer in keras_model.layers:
for weight in layer.weights:
if hasattr(weight, '_keras_initialized'):
return True
return False
@ -520,7 +525,7 @@ def model_to_estimator(keras_model=None,
keras_model_fn, model_dir=model_dir, config=config)
# Check if we need to call get_weights:
if _any_variable_initialized():
if _any_weight_initialized(keras_model):
keras_weights = keras_model.get_weights()
# Warn if config passed to estimator tries to update GPUOptions. If a
# session has already been created, the GPUOptions passed to the first

View File

@ -204,6 +204,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
@test_util.run_in_graph_and_eager_modes
def test_train_with_tf_optimizer(self):
for model_type in ['sequential', 'functional']:
keras_model, (_, _), (
@ -231,6 +232,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
@test_util.run_in_graph_and_eager_modes
def test_train_with_subclassed_model(self):
keras_model, (_, _), (
_, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(