Use automatic control dependencies in Keras in V2 mode.

PiperOrigin-RevId: 246676835
This commit is contained in:
Thomas O'Malley 2019-05-04 16:50:39 -07:00 committed by TensorFlower Gardener
parent 5dfad3cfa8
commit b14c390fc8
5 changed files with 236 additions and 236 deletions

View File

@ -335,7 +335,7 @@ class AutomaticControlDependencies(object):
# Ensure all ops which must run do run
self.ops_which_must_run.update(ops_which_must_run)
for r in self._returned_tensors:
for r in nest.flatten(list(self._returned_tensors), expand_composites=True):
if self.ops_which_must_run:
r.op._add_control_inputs( # pylint: disable=protected-access
[o for o in self.ops_which_must_run

View File

@ -27,10 +27,12 @@ from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.core.framework import node_def_pb2
from tensorflow.python import autograph
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values as distribute_values
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import function
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
@ -617,18 +619,22 @@ class Layer(module.Module):
if (self._expects_training_arg and
not base_layer_utils.training_arg_passed_to_call(
tf_inspect.getfullargspec(self.call), args, kwargs) and
getattr(graph, 'name', None) == 'keras_graph'):
base_layer_utils.is_in_keras_graph()):
learning_phase_passed_by_framework = True
kwargs['training'] = backend.learning_phase()
if not self.dynamic:
try:
with base_layer_utils.autocast_context_manager(
input_list,
self._mixed_precision_policy.should_cast_variables), (
base_layer_utils.AutoAddUpdates(self,
inputs)) as auto_updater:
outputs = call_fn(inputs, *args, **kwargs)
auto_updater.set_outputs(outputs)
self._mixed_precision_policy.should_cast_variables):
if ops.executing_eagerly_outside_functions():
with auto_control_deps.AutomaticControlDependencies() as acd:
outputs = call_fn(inputs, *args, **kwargs)
# Wrap Tensors in `outputs` in `tf.identity` to avoid
# circular dependencies.
outputs = base_layer_utils.mark_as_return(outputs, acd)
else:
outputs = call_fn(inputs, *args, **kwargs)
except TypeError as e:
exception_str = str(e)
@ -739,7 +745,25 @@ class Layer(module.Module):
def updates(self):
if not self.trainable and not self.stateful:
return []
return self._updates + self._gather_children_attribute('updates')
with backend.get_graph().as_default():
updates = []
for u in self._updates:
# Filter out updates created in a cross-replica context when in a
# replica context and vice versa.
if (getattr(u, '_in_cross_replica_context', False) !=
ds_context.in_cross_replica_context()):
continue
if callable(u):
try:
u = u()
except ValueError as e:
if 'Trying to capture a tensor from an inner function' in str(e):
base_layer_utils.check_graph_consistency(
method='add_update', force_raise=True)
raise
base_layer_utils.check_graph_consistency(u, method='add_update')
updates.append(u)
return updates + self._gather_children_attribute('updates')
@property
def losses(self):
@ -1011,14 +1035,13 @@ class Layer(module.Module):
"""
updates = generic_utils.to_list(updates)
if context.executing_eagerly():
# Don't run callable updates if currently executing inside the `call`
# of a Layer/Model with `trainable=False`.
# All updates can be run immediately in Eager or in a tf.function.
if base_layer_utils.is_in_eager_or_tf_function():
if not base_layer_utils.is_in_frozen_context():
for update in updates:
if callable(update):
update()
return # Updates already applied when in eager mode.
return
def process_update(x):
"""Standardize update ops.
@ -1030,24 +1053,29 @@ class Layer(module.Module):
An update op.
"""
if callable(x):
x = x()
if isinstance(x, ops.Operation):
update = lambda: process_update(x())
if not ops.executing_eagerly_outside_functions():
# In V1 mode, call the callable right away and process. This is needed
# for TPU strategy.
return update()
elif isinstance(x, ops.Operation):
update = x
elif hasattr(x, 'op'):
update = x.op
else:
update = ops.convert_to_tensor(x)
base_layer_utils.check_graph_consistency(update, method='add_update')
update._unconditional_update = (inputs is None)
update._in_cross_replica_context = (
ds_context.has_strategy() and ds_context.in_cross_replica_context())
return update
updates = [process_update(x) for x in updates]
# Non-callable Updates are run automatically inside `call` in V2, so
# they do not need to be tracked later.
if (ops.executing_eagerly_outside_functions() and
base_layer_utils.is_in_call_context()):
updates = [u for u in updates if callable(u)]
self._updates += updates
if inputs is None:
for u in updates:
u._unconditional_update = True # pylint: disable=protected-access
else:
for u in updates:
u._unconditional_update = False # pylint: disable=protected-access
def set_weights(self, weights):
"""Sets the weights of the layer, from Numpy arrays.

View File

@ -651,13 +651,20 @@ class NameScopingTest(keras_parameterized.TestCase):
self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
def test_name_scope_sublayer(self):
class NameScopeTracker(keras.layers.Layer):
def call(self, inputs):
self.active_name_scope = ops.get_name_scope()
return inputs
x = keras.backend.placeholder(shape=(10, 10))
layer = keras.layers.Dense(
10, activation=keras.layers.ReLU(name='MyAct'), name='MyName2')
y = layer(x)
sublayer = NameScopeTracker(name='Sublayer')
layer = keras.layers.Dense(10, activation=sublayer, name='MyName2')
layer(x)
self.assertEqual(layer.bias.name, 'MyName2/bias:0')
self.assertEqual(layer.kernel.name, 'MyName2/kernel:0')
self.assertEqual(y.name, 'MyName2/MyAct/Relu:0')
self.assertEqual(sublayer.active_name_scope, 'MyName2/Sublayer')
def test_name_scope_tf_tensor(self):
x = ops.convert_to_tensor(np.ones((10, 10)))
@ -779,7 +786,8 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
def call(self, inputs, training=None):
if training:
self.add_update(self.counter.assign_add(math_ops.reduce_sum(inputs)))
z = math_ops.reduce_sum(inputs)
self.add_update(lambda: self.counter.assign_add(z))
return inputs
def compute_output_shape(self, input_shape):
@ -797,7 +805,9 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
# TODO(fchollet): support the same workflow in graph mode.
with self.assertRaisesRegexp(RuntimeError,
'`add_update` in a control flow branch'):
layer = MyLayer()(keras.Input((3,)))
layer = MyLayer()
layer(keras.Input((3,)))
_ = layer.updates
@parameterized.named_parameters(('eager', True),
('symbolic', False))

View File

@ -23,13 +23,11 @@ import enum
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
@ -351,6 +349,21 @@ def is_in_frozen_context():
return getattr(_call_context, 'frozen', False)
def is_in_keras_graph():
"""Returns if currently executing inside of a Keras graph."""
# Returns True even if in a subgraph of the Keras graph, such as those
# created by control flow ops.
return (getattr(backend.get_graph(), 'name', None) == 'keras_graph' or
getattr(_call_context, 'in_keras_graph', False))
def is_in_eager_or_tf_function():
"""Returns if in eager mode or inside of a tf.function."""
return (context.executing_eagerly() or
(ops.executing_eagerly_outside_functions() and
not is_in_keras_graph()))
def uses_keras_history(tensors):
"""Check if at least one Tensor originates from a `keras.Input`.
@ -413,7 +426,11 @@ def call_context(layer):
"""Scope that marks when we are currently inside a Layer/Model's `call`."""
was_in_call = is_in_call_context()
was_frozen = is_in_frozen_context()
was_in_keras_graph = getattr(_call_context, 'in_keras_graph', False)
_call_context.in_call = True
_call_context.in_keras_graph = (
was_in_keras_graph or
getattr(backend.get_graph(), 'name', None) == 'keras_graph')
if not layer.trainable:
_call_context.frozen = True
try:
@ -421,6 +438,7 @@ def call_context(layer):
finally:
_call_context.in_call = was_in_call
_call_context.frozen = was_frozen
_call_context.in_keras_graph = was_in_keras_graph
def training_arg_passed_to_call(argspec, args, kwargs):
@ -431,121 +449,6 @@ def training_arg_passed_to_call(argspec, args, kwargs):
return 'training' in full_args
class AutoAddUpdates(object):
"""Automatically track stateful ops with `add_update`.
This context manager is used to automatically add stateful ops to a Layer
or Model's `.updates`. This ensures that stateful ops are run in the Keras
training loop. It also allows for these stateful ops to be disabled by
setting `trainable=False`.
Example:
```
with AutoAddUpdates(layer, inputs) as auto_updates:
outputs = layer.call(inputs)
auto_updates.set_outputs(outputs)
```
Attributes:
layer: Layer or Model instance to add the updates to.
inputs: The inputs to this Layer or Model, to be used for input-conditional
updates.
outputs: The outputs of this Layer or Model.
"""
def __init__(self, layer, inputs):
self.layer = layer
self.inputs = inputs
self.outputs = []
def set_outputs(self, outputs):
if self.outputs:
raise RuntimeError('`set_outputs` should only be called once on an'
'`AutoAddUpdates` instance.')
self.outputs = outputs
def __enter__(self):
# Only run in V2 Function mode.
if (context.executing_eagerly() or
not ops.executing_eagerly_outside_functions()):
return self
self._graph = ops.get_default_graph()
self._num_operations = len(self._graph.get_operations())
return self
def __exit__(self, error_type, unused_value, unused_traceback):
if error_type:
# Allow errors that occurred inside this context manager to pass through
# normally.
return
# Only run in V2 Function mode.
if (context.executing_eagerly() or
not ops.executing_eagerly_outside_functions()):
return
if (self._graph is not ops.get_default_graph() or
self._graph.name != 'keras_graph'):
# Only auto-track updates when the Keras Graph is the only one used.
return
new_operations = self._graph.get_operations()[self._num_operations:]
new_stateful_ops = set()
# pylint: disable=protected-access
for op in new_operations:
# While loop is not supported in general for automatic control
# dependencies.
if control_flow_util.IsInWhileLoop(op):
continue
# Track stateful ops via `add_update`.
is_stateful_op = (
op.type not in self._graph._registered_ops or
auto_control_deps.op_is_stateful(
self._graph._registered_ops[op.type]))
# Ignore ReadVariableOps as they are not needed to be run separately.
# This ensures existing Layers don't get extra updates.
if is_stateful_op and op.type != 'ReadVariableOp':
new_stateful_ops.add(op)
explicit_updates = set(
[u for u in self.layer.updates if not isinstance(u, tuple)])
# pylint: enable=protected-access
# Don't add updates that will already be run by virtue of being consumed by
# other stateful ops or by the Layer's outputs. This ensures that existing
# Layers like `BatchNormalization` continue to return the same values for
# `.update` calls.
minimum_ops = set()
targets = new_stateful_ops.union(
set(nest.flatten(self.outputs)), explicit_updates)
for op in new_stateful_ops:
# Scrub any ops that are consumed by the outputs or other stateful ops.
reachable = tf_utils.get_reachable_from_inputs(op)
if not (targets - {op}).intersection(reachable):
minimum_ops.add(op)
new_stateful_ops = minimum_ops
# Don't double-track updates added via explicitly calling `add_update`.
# Also don't double-track updates already tracked in sublayers.
new_stateful_ops = new_stateful_ops - explicit_updates
# Decide whether to track as input-conditional or unconditional.
input_reachable_ops = tf_utils.get_reachable_from_inputs(
self.inputs, targets=new_stateful_ops)
unconditional_updates = new_stateful_ops - input_reachable_ops
conditional_updates = new_stateful_ops - unconditional_updates
if unconditional_updates:
self.layer.add_update(list(unconditional_updates))
if conditional_updates:
self.layer.add_update(list(conditional_updates), inputs=self.inputs)
def _get_var_read_dtype(input_list, should_cast):
"""Gets the dtype that AutoCastVariables should be read in."""
if should_cast and input_list and input_list[0].dtype.is_floating:
@ -579,7 +482,7 @@ def is_subclassed(layer):
layer.__module__.find('keras.layers') == -1)
def check_graph_consistency(tensor, method):
def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
"""Checks that tensors passed to `add_*` method match the Keras graph.
When one of the `add_*` method is called inside a V2 conditional branch,
@ -589,79 +492,101 @@ def check_graph_consistency(tensor, method):
Arguments:
tensor: Tensor to check.
method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
force_raise: If an error should be raised regardless of `tensor`.
Raises:
RuntimeError: In case of an out-of-graph tensor.
"""
if ops.executing_eagerly_outside_functions() and hasattr(tensor, 'graph'):
if isinstance(tensor.graph,
(control_flow_util_v2.CondBranchFuncGraph,
control_flow_util_v2.WhileCondFuncGraph,
control_flow_util_v2.WhileBodyFuncGraph)):
if method == 'add_metric':
bad_example = """
def call(self, inputs, training=None):
if training:
metric = compute_metric(inputs)
self.add_metric(metric, name='my_metric', aggregation='mean')
return inputs
"""
correct_example = """
def call(self, inputs, training=None):
if training:
metric = compute_metric(inputs)
else:
metric = 0.
if (force_raise or (ops.executing_eagerly_outside_functions() and
hasattr(tensor, 'graph') and
isinstance(tensor.graph,
(control_flow_util_v2.CondBranchFuncGraph,
control_flow_util_v2.WhileCondFuncGraph,
control_flow_util_v2.WhileBodyFuncGraph)))):
if method == 'add_metric':
bad_example = """
def call(self, inputs, training=None):
if training:
metric = compute_metric(inputs)
self.add_metric(metric, name='my_metric', aggregation='mean')
return inputs
"""
elif method == 'add_loss':
bad_example = """
def call(self, inputs, training=None):
if training:
loss = compute_loss(inputs)
self.add_loss(loss)
return inputs
"""
correct_example = """
def call(self, inputs, training=None):
if training:
loss = compute_loss(inputs)
else:
loss = 0.
return inputs
"""
correct_example = """
def call(self, inputs, training=None):
if training:
metric = compute_metric(inputs)
else:
metric = 0.
self.add_metric(metric, name='my_metric', aggregation='mean')
return inputs
"""
elif method == 'add_loss':
bad_example = """
def call(self, inputs, training=None):
if training:
loss = compute_loss(inputs)
self.add_loss(loss)
return inputs
"""
else:
bad_example = """
def call(self, inputs, training=None):
if training:
self.add_update(self.w.assign_add(1))
return inputs
"""
correct_example = """
def call(self, inputs, training=None):
if training:
increment = 1
else:
increment = 0
self.add_update(self.w.assign_add(increment))
return inputs
"""
raise RuntimeError(
'You are using the method `{method}` in a control flow branch '
'in your layer, e.g.:\n{bad_example}\n'
'This is not currently supported. '
'You should either use static control flow (`tf.cond`) '
'or move your call to {method} out of the control flow branch, '
'e.g.:\n{correct_example}\n'
'You can also resolve this by marking your layer '
'as dynamic (eager-only) by passing '
'`dynamic=True` to the layer constructor. '
'Any kind of control flow is supported with dynamic layers. '
'Note that using `dynamic=True` requires you '
'to implement static shape inference '
'in the `compute_output_shape(input_shape)` method.'.format(
method=method,
bad_example=bad_example,
correct_example=correct_example))
return inputs
"""
correct_example = """
def call(self, inputs, training=None):
if training:
loss = compute_loss(inputs)
else:
loss = 0.
self.add_loss(loss)
return inputs
"""
else:
bad_example = """
def call(self, inputs, training=None):
if training:
self.add_update(self.w.assign_add(1))
return inputs
"""
correct_example = """
def call(self, inputs, training=None):
if training:
increment = 1
else:
increment = 0
self.add_update(self.w.assign_add(increment))
return inputs
"""
raise RuntimeError(
'You are using the method `{method}` in a control flow branch '
'in your layer, e.g.:\n{bad_example}\n'
'This is not currently supported. '
'You should either use static control flow (`tf.cond`) '
'or move your call to {method} out of the control flow branch, '
'e.g.:\n{correct_example}\n'
'You can also resolve this by marking your layer '
'as dynamic (eager-only) by passing '
'`dynamic=True` to the layer constructor. '
'Any kind of control flow is supported with dynamic layers. '
'Note that using `dynamic=True` requires you '
'to implement static shape inference '
'in the `compute_output_shape(input_shape)` method.'.format(
method=method,
bad_example=bad_example,
correct_example=correct_example))
def mark_as_return(outputs, acd):
"""Marks `outputs` as the return values for automatic control deps."""
def _mark_as_return(tensor):
"""Marks `tensor` as the return value for automatic control deps."""
if not tensor_util.is_tensor(tensor):
return tensor
# pylint: disable=protected-access
return_tensor = acd.mark_as_return(tensor)
if getattr(tensor, '_keras_mask', None) is not None:
return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask)
else:
return_tensor._keras_mask = None
return return_tensor
# pylint: enable=protected-access
return nest.map_structure(_mark_as_return, outputs)

View File

@ -2920,7 +2920,7 @@ class BareUpdateLayer(keras.layers.Layer):
return math_ops.cast(self.counter, inputs.dtype) * inputs
class AddUpdateLayer(keras.layers.Layer):
class LambdaUpdateLayer(keras.layers.Layer):
def build(self, input_shape):
self.counter = self.add_weight(
@ -2932,7 +2932,7 @@ class AddUpdateLayer(keras.layers.Layer):
def call(self, inputs):
# Make sure update isn't run twice.
self.add_update(state_ops.assign_add(self.counter, 1))
self.add_update(lambda: state_ops.assign_add(self.counter, 1))
return math_ops.cast(self.counter, inputs.dtype) * inputs
@ -2950,12 +2950,31 @@ class NestedUpdateLayer(keras.layers.Layer):
return self.layer(inputs)
class SubgraphUpdateLayer(keras.layers.Layer):
def build(self, input_shape):
self.counter = self.add_weight(
'counter',
dtype='int32',
shape=(),
initializer='zeros',
trainable=False)
def call(self, inputs, training=None):
if training is None:
training = keras.backend.learning_phase()
if training:
self.counter.assign(self.counter + 1)
return inputs
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class TestAutoUpdates(keras_parameterized.TestCase):
@keras_parameterized.run_with_all_model_types
@parameterized.named_parameters(('bare_update', BareUpdateLayer()),
('add_update', AddUpdateLayer()),
('lambda_update', LambdaUpdateLayer()),
('nested_update', NestedUpdateLayer()))
def test_updates_in_model(self, layer):
x, y = np.ones((10, 10)), np.ones((10, 1))
@ -2963,16 +2982,34 @@ class TestAutoUpdates(keras_parameterized.TestCase):
[layer, keras.layers.Dense(1)], input_shape=(10,))
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
model.fit(x, y, batch_size=2, epochs=1)
if not testing_utils.should_run_eagerly():
# Check that `trainable=False` disables updates.
layer.trainable = False
model.compile(
'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
model.fit(x, y, batch_size=2, epochs=1)
self.assertEqual(self.evaluate(layer.counter), 5)
@keras_parameterized.run_with_all_model_types
def test_lambda_updates_trainable_false(self):
x, y = np.ones((10, 10)), np.ones((10, 1))
layer = LambdaUpdateLayer()
model = testing_utils.get_model_from_layers(
[layer, keras.layers.Dense(1)], input_shape=(10,))
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
model.fit(x, y, batch_size=2, epochs=1)
self.assertEqual(self.evaluate(layer.counter), 5)
layer.trainable = False
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
model.fit(x, y, batch_size=2, epochs=1)
self.assertEqual(self.evaluate(layer.counter), 5)
@keras_parameterized.run_with_all_model_types
def test_subgraph_updates_in_model(self):
layer = SubgraphUpdateLayer()
x, y = np.ones((10, 10)), np.ones((10, 1))
model = testing_utils.get_model_from_layers(
[layer, keras.layers.Dense(1)], input_shape=(10,))
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
model.fit(x, y, batch_size=2, epochs=1)
self.assertEqual(self.evaluate(layer.counter), 5)
@parameterized.named_parameters(('bare_update', BareUpdateLayer()),
('add_update', AddUpdateLayer()),
('lambda_update', LambdaUpdateLayer()),
('nested_update', NestedUpdateLayer()))
def test_updates_standalone_layer(self, layer):
y = layer(np.ones((10, 10)))
@ -2980,23 +3017,23 @@ class TestAutoUpdates(keras_parameterized.TestCase):
self.evaluate(y)
self.assertEqual(self.evaluate(layer.counter), 1)
def test_trainable_false(self):
x = keras.backend.placeholder(shape=(10, 10), dtype='float32')
layer = NestedUpdateLayer()
def test_trainable_false_standalone_layer(self):
layer = LambdaUpdateLayer()
y = layer(np.ones((10, 10)))
self.evaluate(layer.counter.initializer)
self.evaluate(y)
self.assertEqual(self.evaluate(layer.counter), 1)
layer.trainable = False
y = layer(x)
func = keras.backend.function([x], [y])
x_val = np.ones((10, 10))
func(x_val)
counter = keras.backend.get_value(layer.counter)
self.assertEqual(counter, 0)
y = layer(np.ones((10, 10)))
self.evaluate(y)
self.assertEqual(self.evaluate(layer.counter), 1)
@keras_parameterized.run_with_all_model_types
def test_batchnorm_trainable_false(self):
bn = keras.layers.BatchNormalization()
bn.trainable = False
model = testing_utils.get_model_from_layers([bn, keras.layers.Dense(1)],
input_shape=(10,))
bn.trainable = False
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
x, y = np.ones((10, 10)), np.ones((10, 1))
model.fit(x, y, batch_size=2, epochs=1)