unblock eager execution for model_to_estimator
PiperOrigin-RevId: 203789461
This commit is contained in:
parent
9fea659a48
commit
f83a382e87
@ -39,7 +39,6 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
|||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import metrics as metrics_module
|
from tensorflow.python.ops import metrics as metrics_module
|
||||||
from tensorflow.python.ops import variables as variables_module
|
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.saved_model import signature_constants
|
from tensorflow.python.saved_model import signature_constants
|
||||||
from tensorflow.python.training import distribute as distribute_lib
|
from tensorflow.python.training import distribute as distribute_lib
|
||||||
@ -71,16 +70,22 @@ def _convert_tensor(x):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def _any_variable_initialized():
|
def _any_weight_initialized(keras_model):
|
||||||
"""Check if any variable has been initialized in the Keras model.
|
"""Check if any weights has been initialized in the Keras model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keras_model: An instance of compiled keras model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
boolean, True if at least one variable has been initialized, else False.
|
boolean, True if at least one weight has been initialized, else False.
|
||||||
|
Currently keras initialize all weights at get_session().
|
||||||
"""
|
"""
|
||||||
variables = variables_module.global_variables()
|
if keras_model is None:
|
||||||
for v in variables:
|
return False
|
||||||
if getattr(v, '_keras_initialized', False):
|
for layer in keras_model.layers:
|
||||||
return True
|
for weight in layer.weights:
|
||||||
|
if hasattr(weight, '_keras_initialized'):
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@ -520,7 +525,7 @@ def model_to_estimator(keras_model=None,
|
|||||||
keras_model_fn, model_dir=model_dir, config=config)
|
keras_model_fn, model_dir=model_dir, config=config)
|
||||||
|
|
||||||
# Check if we need to call get_weights:
|
# Check if we need to call get_weights:
|
||||||
if _any_variable_initialized():
|
if _any_weight_initialized(keras_model):
|
||||||
keras_weights = keras_model.get_weights()
|
keras_weights = keras_model.get_weights()
|
||||||
# Warn if config passed to estimator tries to update GPUOptions. If a
|
# Warn if config passed to estimator tries to update GPUOptions. If a
|
||||||
# session has already been created, the GPUOptions passed to the first
|
# session has already been created, the GPUOptions passed to the first
|
||||||
|
@ -204,6 +204,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
|
|||||||
writer_cache.FileWriterCache.clear()
|
writer_cache.FileWriterCache.clear()
|
||||||
gfile.DeleteRecursively(self._config.model_dir)
|
gfile.DeleteRecursively(self._config.model_dir)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def test_train_with_tf_optimizer(self):
|
def test_train_with_tf_optimizer(self):
|
||||||
for model_type in ['sequential', 'functional']:
|
for model_type in ['sequential', 'functional']:
|
||||||
keras_model, (_, _), (
|
keras_model, (_, _), (
|
||||||
@ -231,6 +232,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
|
|||||||
writer_cache.FileWriterCache.clear()
|
writer_cache.FileWriterCache.clear()
|
||||||
gfile.DeleteRecursively(self._config.model_dir)
|
gfile.DeleteRecursively(self._config.model_dir)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def test_train_with_subclassed_model(self):
|
def test_train_with_subclassed_model(self):
|
||||||
keras_model, (_, _), (
|
keras_model, (_, _), (
|
||||||
_, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
|
_, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user