Use automatic control dependencies in Keras in V2 mode.
PiperOrigin-RevId: 246676835
This commit is contained in:
parent
5dfad3cfa8
commit
b14c390fc8
@ -335,7 +335,7 @@ class AutomaticControlDependencies(object):
|
||||
|
||||
# Ensure all ops which must run do run
|
||||
self.ops_which_must_run.update(ops_which_must_run)
|
||||
for r in self._returned_tensors:
|
||||
for r in nest.flatten(list(self._returned_tensors), expand_composites=True):
|
||||
if self.ops_which_must_run:
|
||||
r.op._add_control_inputs( # pylint: disable=protected-access
|
||||
[o for o in self.ops_which_must_run
|
||||
|
@ -27,10 +27,12 @@ from six.moves import zip # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.framework import node_def_pb2
|
||||
from tensorflow.python import autograph
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import values as distribute_values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import execute
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import auto_control_deps
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
@ -617,18 +619,22 @@ class Layer(module.Module):
|
||||
if (self._expects_training_arg and
|
||||
not base_layer_utils.training_arg_passed_to_call(
|
||||
tf_inspect.getfullargspec(self.call), args, kwargs) and
|
||||
getattr(graph, 'name', None) == 'keras_graph'):
|
||||
base_layer_utils.is_in_keras_graph()):
|
||||
learning_phase_passed_by_framework = True
|
||||
kwargs['training'] = backend.learning_phase()
|
||||
if not self.dynamic:
|
||||
try:
|
||||
with base_layer_utils.autocast_context_manager(
|
||||
input_list,
|
||||
self._mixed_precision_policy.should_cast_variables), (
|
||||
base_layer_utils.AutoAddUpdates(self,
|
||||
inputs)) as auto_updater:
|
||||
outputs = call_fn(inputs, *args, **kwargs)
|
||||
auto_updater.set_outputs(outputs)
|
||||
self._mixed_precision_policy.should_cast_variables):
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
with auto_control_deps.AutomaticControlDependencies() as acd:
|
||||
outputs = call_fn(inputs, *args, **kwargs)
|
||||
# Wrap Tensors in `outputs` in `tf.identity` to avoid
|
||||
# circular dependencies.
|
||||
outputs = base_layer_utils.mark_as_return(outputs, acd)
|
||||
else:
|
||||
outputs = call_fn(inputs, *args, **kwargs)
|
||||
|
||||
except TypeError as e:
|
||||
exception_str = str(e)
|
||||
@ -739,7 +745,25 @@ class Layer(module.Module):
|
||||
def updates(self):
|
||||
if not self.trainable and not self.stateful:
|
||||
return []
|
||||
return self._updates + self._gather_children_attribute('updates')
|
||||
with backend.get_graph().as_default():
|
||||
updates = []
|
||||
for u in self._updates:
|
||||
# Filter out updates created in a cross-replica context when in a
|
||||
# replica context and vice versa.
|
||||
if (getattr(u, '_in_cross_replica_context', False) !=
|
||||
ds_context.in_cross_replica_context()):
|
||||
continue
|
||||
if callable(u):
|
||||
try:
|
||||
u = u()
|
||||
except ValueError as e:
|
||||
if 'Trying to capture a tensor from an inner function' in str(e):
|
||||
base_layer_utils.check_graph_consistency(
|
||||
method='add_update', force_raise=True)
|
||||
raise
|
||||
base_layer_utils.check_graph_consistency(u, method='add_update')
|
||||
updates.append(u)
|
||||
return updates + self._gather_children_attribute('updates')
|
||||
|
||||
@property
|
||||
def losses(self):
|
||||
@ -1011,14 +1035,13 @@ class Layer(module.Module):
|
||||
"""
|
||||
updates = generic_utils.to_list(updates)
|
||||
|
||||
if context.executing_eagerly():
|
||||
# Don't run callable updates if currently executing inside the `call`
|
||||
# of a Layer/Model with `trainable=False`.
|
||||
# All updates can be run immediately in Eager or in a tf.function.
|
||||
if base_layer_utils.is_in_eager_or_tf_function():
|
||||
if not base_layer_utils.is_in_frozen_context():
|
||||
for update in updates:
|
||||
if callable(update):
|
||||
update()
|
||||
return # Updates already applied when in eager mode.
|
||||
return
|
||||
|
||||
def process_update(x):
|
||||
"""Standardize update ops.
|
||||
@ -1030,24 +1053,29 @@ class Layer(module.Module):
|
||||
An update op.
|
||||
"""
|
||||
if callable(x):
|
||||
x = x()
|
||||
if isinstance(x, ops.Operation):
|
||||
update = lambda: process_update(x())
|
||||
if not ops.executing_eagerly_outside_functions():
|
||||
# In V1 mode, call the callable right away and process. This is needed
|
||||
# for TPU strategy.
|
||||
return update()
|
||||
elif isinstance(x, ops.Operation):
|
||||
update = x
|
||||
elif hasattr(x, 'op'):
|
||||
update = x.op
|
||||
else:
|
||||
update = ops.convert_to_tensor(x)
|
||||
base_layer_utils.check_graph_consistency(update, method='add_update')
|
||||
update._unconditional_update = (inputs is None)
|
||||
update._in_cross_replica_context = (
|
||||
ds_context.has_strategy() and ds_context.in_cross_replica_context())
|
||||
return update
|
||||
|
||||
updates = [process_update(x) for x in updates]
|
||||
# Non-callable Updates are run automatically inside `call` in V2, so
|
||||
# they do not need to be tracked later.
|
||||
if (ops.executing_eagerly_outside_functions() and
|
||||
base_layer_utils.is_in_call_context()):
|
||||
updates = [u for u in updates if callable(u)]
|
||||
self._updates += updates
|
||||
if inputs is None:
|
||||
for u in updates:
|
||||
u._unconditional_update = True # pylint: disable=protected-access
|
||||
else:
|
||||
for u in updates:
|
||||
u._unconditional_update = False # pylint: disable=protected-access
|
||||
|
||||
def set_weights(self, weights):
|
||||
"""Sets the weights of the layer, from Numpy arrays.
|
||||
|
@ -651,13 +651,20 @@ class NameScopingTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
|
||||
|
||||
def test_name_scope_sublayer(self):
|
||||
|
||||
class NameScopeTracker(keras.layers.Layer):
|
||||
|
||||
def call(self, inputs):
|
||||
self.active_name_scope = ops.get_name_scope()
|
||||
return inputs
|
||||
|
||||
x = keras.backend.placeholder(shape=(10, 10))
|
||||
layer = keras.layers.Dense(
|
||||
10, activation=keras.layers.ReLU(name='MyAct'), name='MyName2')
|
||||
y = layer(x)
|
||||
sublayer = NameScopeTracker(name='Sublayer')
|
||||
layer = keras.layers.Dense(10, activation=sublayer, name='MyName2')
|
||||
layer(x)
|
||||
self.assertEqual(layer.bias.name, 'MyName2/bias:0')
|
||||
self.assertEqual(layer.kernel.name, 'MyName2/kernel:0')
|
||||
self.assertEqual(y.name, 'MyName2/MyAct/Relu:0')
|
||||
self.assertEqual(sublayer.active_name_scope, 'MyName2/Sublayer')
|
||||
|
||||
def test_name_scope_tf_tensor(self):
|
||||
x = ops.convert_to_tensor(np.ones((10, 10)))
|
||||
@ -779,7 +786,8 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
||||
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
self.add_update(self.counter.assign_add(math_ops.reduce_sum(inputs)))
|
||||
z = math_ops.reduce_sum(inputs)
|
||||
self.add_update(lambda: self.counter.assign_add(z))
|
||||
return inputs
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
@ -797,7 +805,9 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
||||
# TODO(fchollet): support the same workflow in graph mode.
|
||||
with self.assertRaisesRegexp(RuntimeError,
|
||||
'`add_update` in a control flow branch'):
|
||||
layer = MyLayer()(keras.Input((3,)))
|
||||
layer = MyLayer()
|
||||
layer(keras.Input((3,)))
|
||||
_ = layer.updates
|
||||
|
||||
@parameterized.named_parameters(('eager', True),
|
||||
('symbolic', False))
|
||||
|
@ -23,13 +23,11 @@ import enum
|
||||
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import auto_control_deps
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import control_flow_util_v2
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import init_ops_v2
|
||||
@ -351,6 +349,21 @@ def is_in_frozen_context():
|
||||
return getattr(_call_context, 'frozen', False)
|
||||
|
||||
|
||||
def is_in_keras_graph():
|
||||
"""Returns if currently executing inside of a Keras graph."""
|
||||
# Returns True even if in a subgraph of the Keras graph, such as those
|
||||
# created by control flow ops.
|
||||
return (getattr(backend.get_graph(), 'name', None) == 'keras_graph' or
|
||||
getattr(_call_context, 'in_keras_graph', False))
|
||||
|
||||
|
||||
def is_in_eager_or_tf_function():
|
||||
"""Returns if in eager mode or inside of a tf.function."""
|
||||
return (context.executing_eagerly() or
|
||||
(ops.executing_eagerly_outside_functions() and
|
||||
not is_in_keras_graph()))
|
||||
|
||||
|
||||
def uses_keras_history(tensors):
|
||||
"""Check if at least one Tensor originates from a `keras.Input`.
|
||||
|
||||
@ -413,7 +426,11 @@ def call_context(layer):
|
||||
"""Scope that marks when we are currently inside a Layer/Model's `call`."""
|
||||
was_in_call = is_in_call_context()
|
||||
was_frozen = is_in_frozen_context()
|
||||
was_in_keras_graph = getattr(_call_context, 'in_keras_graph', False)
|
||||
_call_context.in_call = True
|
||||
_call_context.in_keras_graph = (
|
||||
was_in_keras_graph or
|
||||
getattr(backend.get_graph(), 'name', None) == 'keras_graph')
|
||||
if not layer.trainable:
|
||||
_call_context.frozen = True
|
||||
try:
|
||||
@ -421,6 +438,7 @@ def call_context(layer):
|
||||
finally:
|
||||
_call_context.in_call = was_in_call
|
||||
_call_context.frozen = was_frozen
|
||||
_call_context.in_keras_graph = was_in_keras_graph
|
||||
|
||||
|
||||
def training_arg_passed_to_call(argspec, args, kwargs):
|
||||
@ -431,121 +449,6 @@ def training_arg_passed_to_call(argspec, args, kwargs):
|
||||
return 'training' in full_args
|
||||
|
||||
|
||||
class AutoAddUpdates(object):
|
||||
"""Automatically track stateful ops with `add_update`.
|
||||
|
||||
This context manager is used to automatically add stateful ops to a Layer
|
||||
or Model's `.updates`. This ensures that stateful ops are run in the Keras
|
||||
training loop. It also allows for these stateful ops to be disabled by
|
||||
setting `trainable=False`.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
with AutoAddUpdates(layer, inputs) as auto_updates:
|
||||
outputs = layer.call(inputs)
|
||||
auto_updates.set_outputs(outputs)
|
||||
```
|
||||
|
||||
Attributes:
|
||||
layer: Layer or Model instance to add the updates to.
|
||||
inputs: The inputs to this Layer or Model, to be used for input-conditional
|
||||
updates.
|
||||
outputs: The outputs of this Layer or Model.
|
||||
"""
|
||||
|
||||
def __init__(self, layer, inputs):
|
||||
self.layer = layer
|
||||
self.inputs = inputs
|
||||
self.outputs = []
|
||||
|
||||
def set_outputs(self, outputs):
|
||||
if self.outputs:
|
||||
raise RuntimeError('`set_outputs` should only be called once on an'
|
||||
'`AutoAddUpdates` instance.')
|
||||
self.outputs = outputs
|
||||
|
||||
def __enter__(self):
|
||||
# Only run in V2 Function mode.
|
||||
if (context.executing_eagerly() or
|
||||
not ops.executing_eagerly_outside_functions()):
|
||||
return self
|
||||
|
||||
self._graph = ops.get_default_graph()
|
||||
self._num_operations = len(self._graph.get_operations())
|
||||
return self
|
||||
|
||||
def __exit__(self, error_type, unused_value, unused_traceback):
|
||||
if error_type:
|
||||
# Allow errors that occurred inside this context manager to pass through
|
||||
# normally.
|
||||
return
|
||||
|
||||
# Only run in V2 Function mode.
|
||||
if (context.executing_eagerly() or
|
||||
not ops.executing_eagerly_outside_functions()):
|
||||
return
|
||||
|
||||
if (self._graph is not ops.get_default_graph() or
|
||||
self._graph.name != 'keras_graph'):
|
||||
# Only auto-track updates when the Keras Graph is the only one used.
|
||||
return
|
||||
|
||||
new_operations = self._graph.get_operations()[self._num_operations:]
|
||||
new_stateful_ops = set()
|
||||
|
||||
# pylint: disable=protected-access
|
||||
for op in new_operations:
|
||||
# While loop is not supported in general for automatic control
|
||||
# dependencies.
|
||||
if control_flow_util.IsInWhileLoop(op):
|
||||
continue
|
||||
|
||||
# Track stateful ops via `add_update`.
|
||||
is_stateful_op = (
|
||||
op.type not in self._graph._registered_ops or
|
||||
auto_control_deps.op_is_stateful(
|
||||
self._graph._registered_ops[op.type]))
|
||||
|
||||
# Ignore ReadVariableOps as they are not needed to be run separately.
|
||||
# This ensures existing Layers don't get extra updates.
|
||||
if is_stateful_op and op.type != 'ReadVariableOp':
|
||||
new_stateful_ops.add(op)
|
||||
|
||||
explicit_updates = set(
|
||||
[u for u in self.layer.updates if not isinstance(u, tuple)])
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Don't add updates that will already be run by virtue of being consumed by
|
||||
# other stateful ops or by the Layer's outputs. This ensures that existing
|
||||
# Layers like `BatchNormalization` continue to return the same values for
|
||||
# `.update` calls.
|
||||
minimum_ops = set()
|
||||
targets = new_stateful_ops.union(
|
||||
set(nest.flatten(self.outputs)), explicit_updates)
|
||||
for op in new_stateful_ops:
|
||||
# Scrub any ops that are consumed by the outputs or other stateful ops.
|
||||
reachable = tf_utils.get_reachable_from_inputs(op)
|
||||
if not (targets - {op}).intersection(reachable):
|
||||
minimum_ops.add(op)
|
||||
new_stateful_ops = minimum_ops
|
||||
|
||||
# Don't double-track updates added via explicitly calling `add_update`.
|
||||
# Also don't double-track updates already tracked in sublayers.
|
||||
new_stateful_ops = new_stateful_ops - explicit_updates
|
||||
|
||||
# Decide whether to track as input-conditional or unconditional.
|
||||
input_reachable_ops = tf_utils.get_reachable_from_inputs(
|
||||
self.inputs, targets=new_stateful_ops)
|
||||
unconditional_updates = new_stateful_ops - input_reachable_ops
|
||||
conditional_updates = new_stateful_ops - unconditional_updates
|
||||
|
||||
if unconditional_updates:
|
||||
self.layer.add_update(list(unconditional_updates))
|
||||
if conditional_updates:
|
||||
self.layer.add_update(list(conditional_updates), inputs=self.inputs)
|
||||
|
||||
|
||||
def _get_var_read_dtype(input_list, should_cast):
|
||||
"""Gets the dtype that AutoCastVariables should be read in."""
|
||||
if should_cast and input_list and input_list[0].dtype.is_floating:
|
||||
@ -579,7 +482,7 @@ def is_subclassed(layer):
|
||||
layer.__module__.find('keras.layers') == -1)
|
||||
|
||||
|
||||
def check_graph_consistency(tensor, method):
|
||||
def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
|
||||
"""Checks that tensors passed to `add_*` method match the Keras graph.
|
||||
|
||||
When one of the `add_*` method is called inside a V2 conditional branch,
|
||||
@ -589,79 +492,101 @@ def check_graph_consistency(tensor, method):
|
||||
Arguments:
|
||||
tensor: Tensor to check.
|
||||
method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
|
||||
force_raise: If an error should be raised regardless of `tensor`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: In case of an out-of-graph tensor.
|
||||
"""
|
||||
if ops.executing_eagerly_outside_functions() and hasattr(tensor, 'graph'):
|
||||
if isinstance(tensor.graph,
|
||||
(control_flow_util_v2.CondBranchFuncGraph,
|
||||
control_flow_util_v2.WhileCondFuncGraph,
|
||||
control_flow_util_v2.WhileBodyFuncGraph)):
|
||||
if method == 'add_metric':
|
||||
bad_example = """
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
metric = compute_metric(inputs)
|
||||
self.add_metric(metric, name='my_metric', aggregation='mean')
|
||||
return inputs
|
||||
"""
|
||||
correct_example = """
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
metric = compute_metric(inputs)
|
||||
else:
|
||||
metric = 0.
|
||||
if (force_raise or (ops.executing_eagerly_outside_functions() and
|
||||
hasattr(tensor, 'graph') and
|
||||
isinstance(tensor.graph,
|
||||
(control_flow_util_v2.CondBranchFuncGraph,
|
||||
control_flow_util_v2.WhileCondFuncGraph,
|
||||
control_flow_util_v2.WhileBodyFuncGraph)))):
|
||||
if method == 'add_metric':
|
||||
bad_example = """
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
metric = compute_metric(inputs)
|
||||
self.add_metric(metric, name='my_metric', aggregation='mean')
|
||||
return inputs
|
||||
"""
|
||||
elif method == 'add_loss':
|
||||
bad_example = """
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
loss = compute_loss(inputs)
|
||||
self.add_loss(loss)
|
||||
return inputs
|
||||
"""
|
||||
correct_example = """
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
loss = compute_loss(inputs)
|
||||
else:
|
||||
loss = 0.
|
||||
return inputs
|
||||
"""
|
||||
correct_example = """
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
metric = compute_metric(inputs)
|
||||
else:
|
||||
metric = 0.
|
||||
self.add_metric(metric, name='my_metric', aggregation='mean')
|
||||
return inputs
|
||||
"""
|
||||
elif method == 'add_loss':
|
||||
bad_example = """
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
loss = compute_loss(inputs)
|
||||
self.add_loss(loss)
|
||||
return inputs
|
||||
"""
|
||||
else:
|
||||
bad_example = """
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
self.add_update(self.w.assign_add(1))
|
||||
return inputs
|
||||
"""
|
||||
correct_example = """
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
increment = 1
|
||||
else:
|
||||
increment = 0
|
||||
self.add_update(self.w.assign_add(increment))
|
||||
return inputs
|
||||
"""
|
||||
raise RuntimeError(
|
||||
'You are using the method `{method}` in a control flow branch '
|
||||
'in your layer, e.g.:\n{bad_example}\n'
|
||||
'This is not currently supported. '
|
||||
'You should either use static control flow (`tf.cond`) '
|
||||
'or move your call to {method} out of the control flow branch, '
|
||||
'e.g.:\n{correct_example}\n'
|
||||
'You can also resolve this by marking your layer '
|
||||
'as dynamic (eager-only) by passing '
|
||||
'`dynamic=True` to the layer constructor. '
|
||||
'Any kind of control flow is supported with dynamic layers. '
|
||||
'Note that using `dynamic=True` requires you '
|
||||
'to implement static shape inference '
|
||||
'in the `compute_output_shape(input_shape)` method.'.format(
|
||||
method=method,
|
||||
bad_example=bad_example,
|
||||
correct_example=correct_example))
|
||||
return inputs
|
||||
"""
|
||||
correct_example = """
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
loss = compute_loss(inputs)
|
||||
else:
|
||||
loss = 0.
|
||||
self.add_loss(loss)
|
||||
return inputs
|
||||
"""
|
||||
else:
|
||||
bad_example = """
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
self.add_update(self.w.assign_add(1))
|
||||
return inputs
|
||||
"""
|
||||
correct_example = """
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
increment = 1
|
||||
else:
|
||||
increment = 0
|
||||
self.add_update(self.w.assign_add(increment))
|
||||
return inputs
|
||||
"""
|
||||
raise RuntimeError(
|
||||
'You are using the method `{method}` in a control flow branch '
|
||||
'in your layer, e.g.:\n{bad_example}\n'
|
||||
'This is not currently supported. '
|
||||
'You should either use static control flow (`tf.cond`) '
|
||||
'or move your call to {method} out of the control flow branch, '
|
||||
'e.g.:\n{correct_example}\n'
|
||||
'You can also resolve this by marking your layer '
|
||||
'as dynamic (eager-only) by passing '
|
||||
'`dynamic=True` to the layer constructor. '
|
||||
'Any kind of control flow is supported with dynamic layers. '
|
||||
'Note that using `dynamic=True` requires you '
|
||||
'to implement static shape inference '
|
||||
'in the `compute_output_shape(input_shape)` method.'.format(
|
||||
method=method,
|
||||
bad_example=bad_example,
|
||||
correct_example=correct_example))
|
||||
|
||||
|
||||
def mark_as_return(outputs, acd):
|
||||
"""Marks `outputs` as the return values for automatic control deps."""
|
||||
|
||||
def _mark_as_return(tensor):
|
||||
"""Marks `tensor` as the return value for automatic control deps."""
|
||||
if not tensor_util.is_tensor(tensor):
|
||||
return tensor
|
||||
|
||||
# pylint: disable=protected-access
|
||||
return_tensor = acd.mark_as_return(tensor)
|
||||
if getattr(tensor, '_keras_mask', None) is not None:
|
||||
return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask)
|
||||
else:
|
||||
return_tensor._keras_mask = None
|
||||
return return_tensor
|
||||
# pylint: enable=protected-access
|
||||
|
||||
return nest.map_structure(_mark_as_return, outputs)
|
||||
|
@ -2920,7 +2920,7 @@ class BareUpdateLayer(keras.layers.Layer):
|
||||
return math_ops.cast(self.counter, inputs.dtype) * inputs
|
||||
|
||||
|
||||
class AddUpdateLayer(keras.layers.Layer):
|
||||
class LambdaUpdateLayer(keras.layers.Layer):
|
||||
|
||||
def build(self, input_shape):
|
||||
self.counter = self.add_weight(
|
||||
@ -2932,7 +2932,7 @@ class AddUpdateLayer(keras.layers.Layer):
|
||||
|
||||
def call(self, inputs):
|
||||
# Make sure update isn't run twice.
|
||||
self.add_update(state_ops.assign_add(self.counter, 1))
|
||||
self.add_update(lambda: state_ops.assign_add(self.counter, 1))
|
||||
return math_ops.cast(self.counter, inputs.dtype) * inputs
|
||||
|
||||
|
||||
@ -2950,12 +2950,31 @@ class NestedUpdateLayer(keras.layers.Layer):
|
||||
return self.layer(inputs)
|
||||
|
||||
|
||||
class SubgraphUpdateLayer(keras.layers.Layer):
|
||||
|
||||
def build(self, input_shape):
|
||||
self.counter = self.add_weight(
|
||||
'counter',
|
||||
dtype='int32',
|
||||
shape=(),
|
||||
initializer='zeros',
|
||||
trainable=False)
|
||||
|
||||
def call(self, inputs, training=None):
|
||||
if training is None:
|
||||
training = keras.backend.learning_phase()
|
||||
|
||||
if training:
|
||||
self.counter.assign(self.counter + 1)
|
||||
return inputs
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
class TestAutoUpdates(keras_parameterized.TestCase):
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@parameterized.named_parameters(('bare_update', BareUpdateLayer()),
|
||||
('add_update', AddUpdateLayer()),
|
||||
('lambda_update', LambdaUpdateLayer()),
|
||||
('nested_update', NestedUpdateLayer()))
|
||||
def test_updates_in_model(self, layer):
|
||||
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||
@ -2963,16 +2982,34 @@ class TestAutoUpdates(keras_parameterized.TestCase):
|
||||
[layer, keras.layers.Dense(1)], input_shape=(10,))
|
||||
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||
model.fit(x, y, batch_size=2, epochs=1)
|
||||
if not testing_utils.should_run_eagerly():
|
||||
# Check that `trainable=False` disables updates.
|
||||
layer.trainable = False
|
||||
model.compile(
|
||||
'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||
model.fit(x, y, batch_size=2, epochs=1)
|
||||
self.assertEqual(self.evaluate(layer.counter), 5)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_lambda_updates_trainable_false(self):
|
||||
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||
layer = LambdaUpdateLayer()
|
||||
model = testing_utils.get_model_from_layers(
|
||||
[layer, keras.layers.Dense(1)], input_shape=(10,))
|
||||
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||
model.fit(x, y, batch_size=2, epochs=1)
|
||||
self.assertEqual(self.evaluate(layer.counter), 5)
|
||||
layer.trainable = False
|
||||
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||
model.fit(x, y, batch_size=2, epochs=1)
|
||||
self.assertEqual(self.evaluate(layer.counter), 5)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_subgraph_updates_in_model(self):
|
||||
layer = SubgraphUpdateLayer()
|
||||
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||
model = testing_utils.get_model_from_layers(
|
||||
[layer, keras.layers.Dense(1)], input_shape=(10,))
|
||||
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||
model.fit(x, y, batch_size=2, epochs=1)
|
||||
self.assertEqual(self.evaluate(layer.counter), 5)
|
||||
|
||||
@parameterized.named_parameters(('bare_update', BareUpdateLayer()),
|
||||
('add_update', AddUpdateLayer()),
|
||||
('lambda_update', LambdaUpdateLayer()),
|
||||
('nested_update', NestedUpdateLayer()))
|
||||
def test_updates_standalone_layer(self, layer):
|
||||
y = layer(np.ones((10, 10)))
|
||||
@ -2980,23 +3017,23 @@ class TestAutoUpdates(keras_parameterized.TestCase):
|
||||
self.evaluate(y)
|
||||
self.assertEqual(self.evaluate(layer.counter), 1)
|
||||
|
||||
def test_trainable_false(self):
|
||||
x = keras.backend.placeholder(shape=(10, 10), dtype='float32')
|
||||
layer = NestedUpdateLayer()
|
||||
def test_trainable_false_standalone_layer(self):
|
||||
layer = LambdaUpdateLayer()
|
||||
y = layer(np.ones((10, 10)))
|
||||
self.evaluate(layer.counter.initializer)
|
||||
self.evaluate(y)
|
||||
self.assertEqual(self.evaluate(layer.counter), 1)
|
||||
layer.trainable = False
|
||||
y = layer(x)
|
||||
func = keras.backend.function([x], [y])
|
||||
x_val = np.ones((10, 10))
|
||||
func(x_val)
|
||||
counter = keras.backend.get_value(layer.counter)
|
||||
self.assertEqual(counter, 0)
|
||||
y = layer(np.ones((10, 10)))
|
||||
self.evaluate(y)
|
||||
self.assertEqual(self.evaluate(layer.counter), 1)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_batchnorm_trainable_false(self):
|
||||
bn = keras.layers.BatchNormalization()
|
||||
bn.trainable = False
|
||||
model = testing_utils.get_model_from_layers([bn, keras.layers.Dense(1)],
|
||||
input_shape=(10,))
|
||||
bn.trainable = False
|
||||
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||
model.fit(x, y, batch_size=2, epochs=1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user