unblock eager execution for model_to_estimator
PiperOrigin-RevId: 203789461
This commit is contained in:
parent
9fea659a48
commit
f83a382e87
@ -39,7 +39,6 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import metrics as metrics_module
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.training import distribute as distribute_lib
|
||||
@ -71,16 +70,22 @@ def _convert_tensor(x):
|
||||
return x
|
||||
|
||||
|
||||
def _any_variable_initialized():
|
||||
"""Check if any variable has been initialized in the Keras model.
|
||||
def _any_weight_initialized(keras_model):
|
||||
"""Check if any weights has been initialized in the Keras model.
|
||||
|
||||
Args:
|
||||
keras_model: An instance of compiled keras model.
|
||||
|
||||
Returns:
|
||||
boolean, True if at least one variable has been initialized, else False.
|
||||
boolean, True if at least one weight has been initialized, else False.
|
||||
Currently keras initialize all weights at get_session().
|
||||
"""
|
||||
variables = variables_module.global_variables()
|
||||
for v in variables:
|
||||
if getattr(v, '_keras_initialized', False):
|
||||
return True
|
||||
if keras_model is None:
|
||||
return False
|
||||
for layer in keras_model.layers:
|
||||
for weight in layer.weights:
|
||||
if hasattr(weight, '_keras_initialized'):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@ -520,7 +525,7 @@ def model_to_estimator(keras_model=None,
|
||||
keras_model_fn, model_dir=model_dir, config=config)
|
||||
|
||||
# Check if we need to call get_weights:
|
||||
if _any_variable_initialized():
|
||||
if _any_weight_initialized(keras_model):
|
||||
keras_weights = keras_model.get_weights()
|
||||
# Warn if config passed to estimator tries to update GPUOptions. If a
|
||||
# session has already been created, the GPUOptions passed to the first
|
||||
|
@ -204,6 +204,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
|
||||
writer_cache.FileWriterCache.clear()
|
||||
gfile.DeleteRecursively(self._config.model_dir)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_train_with_tf_optimizer(self):
|
||||
for model_type in ['sequential', 'functional']:
|
||||
keras_model, (_, _), (
|
||||
@ -231,6 +232,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
|
||||
writer_cache.FileWriterCache.clear()
|
||||
gfile.DeleteRecursively(self._config.model_dir)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_train_with_subclassed_model(self):
|
||||
keras_model, (_, _), (
|
||||
_, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
|
||||
|
Loading…
Reference in New Issue
Block a user