Switch to using variable._in_graph_mode instead of context.executing_eagerly() in optimizer.variables()

Remove global collection usage from the Keras model to estimator flow.

PiperOrigin-RevId: 209837803
This commit is contained in:
Pavithra Vijay 2018-08-22 15:09:08 -07:00 committed by TensorFlower Gardener
parent cd199a89db
commit d97e525c1f
4 changed files with 7 additions and 17 deletions

View File

@ -36,7 +36,6 @@ from tensorflow.python.keras import optimizers
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
@ -315,15 +314,7 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
if not model.train_function: if not model.train_function:
# pylint: disable=protected-access # pylint: disable=protected-access
model._make_train_function() model._make_train_function()
# We are using global variables collection here because: K._initialize_variables(sess)
# estimator runs eager mode under context.graph_mode() context manager
# When we try to get all the TF optimizer variables using
# optimizer.variables() we try to return variables that belong to the
# current graph. This check (variable.op.graph is current_graph) will
# error as the context is graph mode but variables are eager.
# TODO(psv): investigate this and see if we can remove the usage of
# collection here.
K._initialize_variables(sess, variables_module.global_variables())
# pylint: enable=protected-access # pylint: enable=protected-access
saver = saver_lib.Saver() saver = saver_lib.Saver()
latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt') latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')

View File

@ -696,9 +696,8 @@ def _get_variables(graph=None):
return variables return variables
def _initialize_variables(session, variables=None): def _initialize_variables(session):
"""Utility to initialize uninitialized variables on the fly.""" """Utility to initialize uninitialized variables on the fly."""
if variables is None:
variables = _get_variables(ops.get_default_graph()) variables = _get_variables(ops.get_default_graph())
candidate_vars = [] candidate_vars = []
for v in variables: for v in variables:

View File

@ -447,6 +447,7 @@ def clone_and_build_model(
elif model.optimizer: elif model.optimizer:
if isinstance(model.optimizer, optimizers.TFOptimizer): if isinstance(model.optimizer, optimizers.TFOptimizer):
optimizer = model.optimizer optimizer = model.optimizer
K.track_tf_optimizer(optimizer)
else: else:
optimizer_config = model.optimizer.get_config() optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config) optimizer = model.optimizer.__class__.from_config(optimizer_config)

View File

@ -772,16 +772,15 @@ class Optimizer(
Returns: Returns:
A list of variables. A list of variables.
""" """
executing_eagerly = context.executing_eagerly()
current_graph = ops.get_default_graph() current_graph = ops.get_default_graph()
def _from_current_graph(variable): def _from_current_graph(variable):
if executing_eagerly: if variable._in_graph_mode: # pylint: disable=protected-access
return variable.op.graph is current_graph
else:
# No variable.op in eager mode. We don't expect lots of eager graphs, # No variable.op in eager mode. We don't expect lots of eager graphs,
# but behavior should be consistent with graph mode. # but behavior should be consistent with graph mode.
return variable._graph_key == current_graph._graph_key # pylint: disable=protected-access return variable._graph_key == current_graph._graph_key # pylint: disable=protected-access
else:
return variable.op.graph is current_graph
optimizer_variables = [v for v in self._non_slot_variables() optimizer_variables = [v for v in self._non_slot_variables()
if _from_current_graph(v)] if _from_current_graph(v)]