Update initialization of variables in Keras.

- Removes the use of global variable collections for initialization. This means that variables created externally (such as input feed variables) will have to be initialized explicitly.
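
For example, a feed variable passed to keras.Input(tensor=...) now has to be initialized by the caller before the model is used. A minimal sketch against the TF 1.x graph-mode API (the shapes, names, and toy model are illustrative, not taken from the change):

import numpy as np
import tensorflow as tf

data = np.random.random((10, 3)).astype('float32')

# This variable is created outside of Keras, so it is no longer picked up
# from the global variables collection; the caller initializes it explicitly.
input_v = tf.Variable(data, dtype='float32')
tf.keras.backend.get_session().run(tf.variables_initializer([input_v]))

x = tf.keras.Input(tensor=input_v)
y = tf.keras.layers.Dense(4)(x)
model = tf.keras.Model(x, y)
model.compile(optimizer='sgd', loss='mse')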

PiperOrigin-RevId: 209692235

Pavithra Vijay 2018-08-21 17:49:43 -07:00 committed by TensorFlower Gardener
parent 274bebe3a6
commit d432bd4203
8 changed files with 85 additions and 20 deletions

View File

@@ -186,3 +186,7 @@ class CrossShardOptimizer(optimizer.Optimizer):
A list of strings.
"""
return self._opt.get_slot_names(*args, **kwargs)
def variables(self):
"""Forwarding the variables from the underlying optimizer."""
return self._opt.variables()
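
This forwarding matters because Keras now asks each tracked TF optimizer for its variables (via opt.optimizer.variables(), see the backend changes below) instead of reading the global collection. A rough sketch of the wrapper pattern, with illustrative names rather than the actual TPU classes:

class ForwardingOptimizer(object):
  """Sketch of an optimizer wrapper that exposes its inner optimizer's
  variables, mirroring the method added to CrossShardOptimizer above."""

  def __init__(self, opt):
    self._opt = opt  # the underlying tf.train.Optimizer

  def variables(self):
    # Without this, slot variables created by the wrapped optimizer would
    # be invisible to Keras' per-graph variable tracking.
    return self._opt.variables()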

View File

@@ -36,6 +36,7 @@ from tensorflow.python.keras import optimizers
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
@@ -314,7 +315,15 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
if not model.train_function:
# pylint: disable=protected-access
model._make_train_function()
K._initialize_variables(sess)
# We are using the global variables collection here because:
# the estimator runs in eager mode under the context.graph_mode() context
# manager. When we try to get all the TF optimizer variables using
# optimizer.variables(), only variables that belong to the current graph
# are returned. That check (variable.op.graph is current_graph) errors
# because the context is graph mode but the variables are eager.
# TODO(psv): investigate this and see if we can remove the usage of the
# collection here.
K._initialize_variables(sess, variables_module.global_variables())
# pylint: enable=protected-access
saver = saver_lib.Saver()
latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')

View File

@@ -94,6 +94,14 @@ _IMAGE_DATA_FORMAT = 'channels_last'
# We assume our devices don't change henceforth.
_LOCAL_DEVICES = None
# This dictionary holds a mapping between a graph and variables to initialize
# in the graph.
_GRAPH_VARIABLES = {}
# This dictionary holds a mapping between a graph and TF optimizers created in
# the graph.
_GRAPH_TF_OPTIMIZERS = {}
@tf_export('keras.backend.backend')
def backend():
@@ -309,6 +317,8 @@ def clear_session():
"""
global _SESSION
global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
global _GRAPH_VARIABLES # pylint: disable=global-variable-not-assigned
global _GRAPH_TF_OPTIMIZERS # pylint: disable=global-variable-not-assigned
ops.reset_default_graph()
reset_uids()
_SESSION = None
@@ -316,6 +326,8 @@ def clear_session():
False, shape=(), name='keras_learning_phase')
_GRAPH_LEARNING_PHASES = {}
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = phase
_GRAPH_VARIABLES.pop(ops.get_default_graph(), None)
_GRAPH_TF_OPTIMIZERS.pop(ops.get_default_graph(), None)
@tf_export('keras.backend.manual_variable_initialization')
@@ -651,12 +663,43 @@ def variable(value, dtype=None, name=None, constraint=None):
elif hasattr(value, 'shape'):
v._keras_shape = int_shape(value)
v._uses_learning_phase = False
track_variable(v)
return v
def _initialize_variables(session):
def track_tf_optimizer(tf_optimizer):
"""Tracks the given TF optimizer for initialization of its variables."""
if context.executing_eagerly():
return
graph = ops.get_default_graph()
if graph not in _GRAPH_TF_OPTIMIZERS:
_GRAPH_TF_OPTIMIZERS[graph] = set()
_GRAPH_TF_OPTIMIZERS[graph].add(tf_optimizer)
def track_variable(v):
"""Tracks the given variable for initialization."""
if context.executing_eagerly():
return
graph = v.graph if hasattr(v, 'graph') else ops.get_default_graph()
if graph not in _GRAPH_VARIABLES:
_GRAPH_VARIABLES[graph] = set()
_GRAPH_VARIABLES[graph].add(v)
def _get_variables(graph=None):
"""Returns variables corresponding to the given graph for initialization."""
assert not context.executing_eagerly()
variables = _GRAPH_VARIABLES.get(graph, set())
for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
variables.update(opt.optimizer.variables())
return variables
def _initialize_variables(session, variables=None):
"""Utility to initialize uninitialized variables on the fly."""
variables = variables_module.global_variables()
if variables is None:
variables = _get_variables(ops.get_default_graph())
candidate_vars = []
for v in variables:
if not getattr(v, '_keras_initialized', False):
@@ -974,6 +1017,7 @@ def zeros(shape, dtype=None, name=None):
v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
if py_all(v.shape.as_list()):
return variable(v, dtype=dtype, name=name)
track_variable(v)
return v
@@ -1008,6 +1052,7 @@ def ones(shape, dtype=None, name=None):
v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
if py_all(v.shape.as_list()):
return variable(v, dtype=dtype, name=name)
track_variable(v)
return v
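
Taken together, the backend changes replace the global collection with per-graph bookkeeping: variables and wrapped TF optimizers are registered as they are created, and initialization only considers those sets. A condensed sketch of that flow (simplified from the functions above; the real code uses the default graph, module-level dictionaries, and skips tracking in eager mode):

_graph_variables = {}      # mirrors _GRAPH_VARIABLES
_graph_tf_optimizers = {}  # mirrors _GRAPH_TF_OPTIMIZERS

def track_variable(v, graph):
  # Called from variable(), zeros(), ones() and Layer.add_weight().
  _graph_variables.setdefault(graph, set()).add(v)

def track_tf_optimizer(opt, graph):
  # Called when a raw tf.train.Optimizer is wrapped in TFOptimizer.
  _graph_tf_optimizers.setdefault(graph, set()).add(opt)

def get_variables(graph):
  # Union of explicitly tracked variables and the wrapped optimizers'
  # slot variables, gathered lazily at initialization time.
  variables = set(_graph_variables.get(graph, set()))
  for opt in _graph_tf_optimizers.get(graph, set()):
    variables.update(opt.optimizer.variables())
  return variables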

View File

@@ -570,6 +570,7 @@ class Layer(checkpointable.CheckpointableBase):
use_resource=use_resource,
synchronization=synchronization,
aggregation=aggregation)
backend.track_variable(variable)
if regularizer is not None:
# TODO(fchollet): in the future, this should be handled at the
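
For instance, a weight created through add_weight in a custom layer is now registered with the backend automatically, so no explicit initialization is needed for it (a toy layer for illustration, not part of the change):

import tensorflow as tf

class Scale(tf.keras.layers.Layer):
  """Toy layer whose single weight is created through add_weight and is
  therefore tracked and initialized by the Keras backend."""

  def build(self, input_shape):
    self.alpha = self.add_weight(name='alpha', shape=(), initializer='ones')
    super(Scale, self).build(input_shape)

  def call(self, inputs):
    return inputs * self.alpha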

View File

@@ -37,6 +37,7 @@ from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.rmsprop import RMSPropOptimizer
@@ -1488,9 +1489,10 @@ class TestTrainingWithDataTensors(test.TestCase):
output_a_np = np.random.random((10, 4))
output_b_np = np.random.random((10, 3))
a = keras.Input(
tensor=keras.backend.variables_module.Variable(input_a_np,
dtype='float32'))
input_v = keras.backend.variables_module.Variable(
input_a_np, dtype='float32')
self.evaluate(variables_lib.variables_initializer([input_v]))
a = keras.Input(tensor=input_v)
b = keras.Input(shape=(3,), name='input_b')
a_2 = keras.layers.Dense(4, name='dense_1')(a)
@@ -1535,9 +1537,8 @@ class TestTrainingWithDataTensors(test.TestCase):
# Now test a model with a single input
# i.e. we don't pass any data to fit the model.
a = keras.Input(
tensor=keras.backend.variables_module.Variable(input_a_np,
dtype='float32'))
self.evaluate(variables_lib.variables_initializer([input_v]))
a = keras.Input(tensor=input_v)
a_2 = keras.layers.Dense(4, name='dense_1')(a)
a_2 = keras.layers.Dropout(0.5, name='dropout')(a_2)
model = keras.models.Model(a, a_2)
@@ -1575,9 +1576,8 @@ class TestTrainingWithDataTensors(test.TestCase):
# Same, without learning phase
# i.e. we don't pass any data to fit the model.
a = keras.Input(
tensor=keras.backend.variables_module.Variable(input_a_np,
dtype='float32'))
self.evaluate(variables_lib.variables_initializer([input_v]))
a = keras.Input(tensor=input_v)
a_2 = keras.layers.Dense(4, name='dense_1')(a)
model = keras.models.Model(a, a_2)
model.summary()
@@ -1700,9 +1700,10 @@ class TestTrainingWithDataTensors(test.TestCase):
out = model.evaluate(input_a_np, None)
# Test model with no external data at all.
a = keras.Input(
tensor=keras.backend.variables_module.Variable(input_a_np,
dtype='float32'))
input_v = keras.backend.variables_module.Variable(
input_a_np, dtype='float32')
self.evaluate(variables_lib.variables_initializer([input_v]))
a = keras.Input(tensor=input_v)
a_2 = keras.layers.Dense(4, name='dense_1')(a)
a_2 = keras.layers.Dropout(0.5, name='dropout')(a_2)
model = keras.models.Model(a, a_2)
@@ -1743,9 +1744,8 @@ class TestTrainingWithDataTensors(test.TestCase):
self.assertEqual(out.shape, (10 * 3, 4))
# Test multi-output model with no external data at all.
a = keras.Input(
tensor=keras.backend.variables_module.Variable(input_a_np,
dtype='float32'))
self.evaluate(variables_lib.variables_initializer([input_v]))
a = keras.Input(tensor=input_v)
a_1 = keras.layers.Dense(4, name='dense_1')(a)
a_2 = keras.layers.Dropout(0.5, name='dropout')(a_1)
model = keras.models.Model(a, [a_1, a_2])

View File

@@ -450,7 +450,9 @@ def clone_and_build_model(
else:
optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config)
optimizer.iterations = training_util.get_or_create_global_step()
global_step = training_util.get_or_create_global_step()
K.track_variable(global_step)
optimizer.iterations = global_step
clone.compile(
optimizer,

View File

@@ -813,7 +813,9 @@ def get(identifier):
"""
# Wrap TF optimizer instances
if isinstance(identifier, tf_optimizer_module.Optimizer):
return TFOptimizer(identifier)
opt = TFOptimizer(identifier)
K.track_tf_optimizer(opt)
return opt
if isinstance(identifier, dict):
return deserialize(identifier)
elif isinstance(identifier, six.string_types):
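
With this change, passing a raw tf.train optimizer to compile both wraps it in TFOptimizer and registers the wrapper, so its slot variables are created and initialized through the per-graph tracking. A usage sketch (TF 1.x graph mode; the model shape and data are illustrative):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])

# A plain tf.train optimizer: optimizers.get() wraps it and calls
# K.track_tf_optimizer() so its slot variables get initialized lazily.
model.compile(loss='mse', optimizer=tf.train.AdamOptimizer(0.01))

model.fit(np.random.random((5, 3)), np.random.random((5, 2)),
          epochs=1, batch_size=5, verbose=0)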

View File

@@ -142,6 +142,7 @@ class KerasOptimizersTest(test.TestCase):
2, input_shape=(3,), kernel_constraint=keras.constraints.MaxNorm(1)))
# This is possible
model.compile(loss='mean_squared_error', optimizer=optimizer)
keras.backend.track_tf_optimizer(optimizer)
model.fit(np.random.random((5, 3)),
np.random.random((5, 2)),
epochs=1,
@@ -163,6 +164,7 @@ class KerasOptimizersTest(test.TestCase):
model.add(keras.layers.Dense(
2, input_shape=(3,), kernel_constraint=keras.constraints.MaxNorm(1)))
model.compile(loss='mean_squared_error', optimizer=optimizer)
keras.backend.track_tf_optimizer(optimizer)
self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 0)
model.fit(np.random.random((55, 3)),