Use automatic control dependencies in Keras in V2 mode.
PiperOrigin-RevId: 246676835
This commit is contained in:
parent
5dfad3cfa8
commit
b14c390fc8
@ -335,7 +335,7 @@ class AutomaticControlDependencies(object):
|
|||||||
|
|
||||||
# Ensure all ops which must run do run
|
# Ensure all ops which must run do run
|
||||||
self.ops_which_must_run.update(ops_which_must_run)
|
self.ops_which_must_run.update(ops_which_must_run)
|
||||||
for r in self._returned_tensors:
|
for r in nest.flatten(list(self._returned_tensors), expand_composites=True):
|
||||||
if self.ops_which_must_run:
|
if self.ops_which_must_run:
|
||||||
r.op._add_control_inputs( # pylint: disable=protected-access
|
r.op._add_control_inputs( # pylint: disable=protected-access
|
||||||
[o for o in self.ops_which_must_run
|
[o for o in self.ops_which_must_run
|
||||||
|
|||||||
@ -27,10 +27,12 @@ from six.moves import zip # pylint: disable=redefined-builtin
|
|||||||
|
|
||||||
from tensorflow.core.framework import node_def_pb2
|
from tensorflow.core.framework import node_def_pb2
|
||||||
from tensorflow.python import autograph
|
from tensorflow.python import autograph
|
||||||
|
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||||
from tensorflow.python.distribute import values as distribute_values
|
from tensorflow.python.distribute import values as distribute_values
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import execute
|
from tensorflow.python.eager import execute
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
|
from tensorflow.python.framework import auto_control_deps
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import func_graph
|
from tensorflow.python.framework import func_graph
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -617,18 +619,22 @@ class Layer(module.Module):
|
|||||||
if (self._expects_training_arg and
|
if (self._expects_training_arg and
|
||||||
not base_layer_utils.training_arg_passed_to_call(
|
not base_layer_utils.training_arg_passed_to_call(
|
||||||
tf_inspect.getfullargspec(self.call), args, kwargs) and
|
tf_inspect.getfullargspec(self.call), args, kwargs) and
|
||||||
getattr(graph, 'name', None) == 'keras_graph'):
|
base_layer_utils.is_in_keras_graph()):
|
||||||
learning_phase_passed_by_framework = True
|
learning_phase_passed_by_framework = True
|
||||||
kwargs['training'] = backend.learning_phase()
|
kwargs['training'] = backend.learning_phase()
|
||||||
if not self.dynamic:
|
if not self.dynamic:
|
||||||
try:
|
try:
|
||||||
with base_layer_utils.autocast_context_manager(
|
with base_layer_utils.autocast_context_manager(
|
||||||
input_list,
|
input_list,
|
||||||
self._mixed_precision_policy.should_cast_variables), (
|
self._mixed_precision_policy.should_cast_variables):
|
||||||
base_layer_utils.AutoAddUpdates(self,
|
if ops.executing_eagerly_outside_functions():
|
||||||
inputs)) as auto_updater:
|
with auto_control_deps.AutomaticControlDependencies() as acd:
|
||||||
|
outputs = call_fn(inputs, *args, **kwargs)
|
||||||
|
# Wrap Tensors in `outputs` in `tf.identity` to avoid
|
||||||
|
# circular dependencies.
|
||||||
|
outputs = base_layer_utils.mark_as_return(outputs, acd)
|
||||||
|
else:
|
||||||
outputs = call_fn(inputs, *args, **kwargs)
|
outputs = call_fn(inputs, *args, **kwargs)
|
||||||
auto_updater.set_outputs(outputs)
|
|
||||||
|
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
exception_str = str(e)
|
exception_str = str(e)
|
||||||
@ -739,7 +745,25 @@ class Layer(module.Module):
|
|||||||
def updates(self):
|
def updates(self):
|
||||||
if not self.trainable and not self.stateful:
|
if not self.trainable and not self.stateful:
|
||||||
return []
|
return []
|
||||||
return self._updates + self._gather_children_attribute('updates')
|
with backend.get_graph().as_default():
|
||||||
|
updates = []
|
||||||
|
for u in self._updates:
|
||||||
|
# Filter out updates created in a cross-replica context when in a
|
||||||
|
# replica context and vice versa.
|
||||||
|
if (getattr(u, '_in_cross_replica_context', False) !=
|
||||||
|
ds_context.in_cross_replica_context()):
|
||||||
|
continue
|
||||||
|
if callable(u):
|
||||||
|
try:
|
||||||
|
u = u()
|
||||||
|
except ValueError as e:
|
||||||
|
if 'Trying to capture a tensor from an inner function' in str(e):
|
||||||
|
base_layer_utils.check_graph_consistency(
|
||||||
|
method='add_update', force_raise=True)
|
||||||
|
raise
|
||||||
|
base_layer_utils.check_graph_consistency(u, method='add_update')
|
||||||
|
updates.append(u)
|
||||||
|
return updates + self._gather_children_attribute('updates')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def losses(self):
|
def losses(self):
|
||||||
@ -1011,14 +1035,13 @@ class Layer(module.Module):
|
|||||||
"""
|
"""
|
||||||
updates = generic_utils.to_list(updates)
|
updates = generic_utils.to_list(updates)
|
||||||
|
|
||||||
if context.executing_eagerly():
|
# All updates can be run immediately in Eager or in a tf.function.
|
||||||
# Don't run callable updates if currently executing inside the `call`
|
if base_layer_utils.is_in_eager_or_tf_function():
|
||||||
# of a Layer/Model with `trainable=False`.
|
|
||||||
if not base_layer_utils.is_in_frozen_context():
|
if not base_layer_utils.is_in_frozen_context():
|
||||||
for update in updates:
|
for update in updates:
|
||||||
if callable(update):
|
if callable(update):
|
||||||
update()
|
update()
|
||||||
return # Updates already applied when in eager mode.
|
return
|
||||||
|
|
||||||
def process_update(x):
|
def process_update(x):
|
||||||
"""Standardize update ops.
|
"""Standardize update ops.
|
||||||
@ -1030,24 +1053,29 @@ class Layer(module.Module):
|
|||||||
An update op.
|
An update op.
|
||||||
"""
|
"""
|
||||||
if callable(x):
|
if callable(x):
|
||||||
x = x()
|
update = lambda: process_update(x())
|
||||||
if isinstance(x, ops.Operation):
|
if not ops.executing_eagerly_outside_functions():
|
||||||
|
# In V1 mode, call the callable right away and process. This is needed
|
||||||
|
# for TPU strategy.
|
||||||
|
return update()
|
||||||
|
elif isinstance(x, ops.Operation):
|
||||||
update = x
|
update = x
|
||||||
elif hasattr(x, 'op'):
|
elif hasattr(x, 'op'):
|
||||||
update = x.op
|
update = x.op
|
||||||
else:
|
else:
|
||||||
update = ops.convert_to_tensor(x)
|
update = ops.convert_to_tensor(x)
|
||||||
base_layer_utils.check_graph_consistency(update, method='add_update')
|
update._unconditional_update = (inputs is None)
|
||||||
|
update._in_cross_replica_context = (
|
||||||
|
ds_context.has_strategy() and ds_context.in_cross_replica_context())
|
||||||
return update
|
return update
|
||||||
|
|
||||||
updates = [process_update(x) for x in updates]
|
updates = [process_update(x) for x in updates]
|
||||||
|
# Non-callable Updates are run automatically inside `call` in V2, so
|
||||||
|
# they do not need to be tracked later.
|
||||||
|
if (ops.executing_eagerly_outside_functions() and
|
||||||
|
base_layer_utils.is_in_call_context()):
|
||||||
|
updates = [u for u in updates if callable(u)]
|
||||||
self._updates += updates
|
self._updates += updates
|
||||||
if inputs is None:
|
|
||||||
for u in updates:
|
|
||||||
u._unconditional_update = True # pylint: disable=protected-access
|
|
||||||
else:
|
|
||||||
for u in updates:
|
|
||||||
u._unconditional_update = False # pylint: disable=protected-access
|
|
||||||
|
|
||||||
def set_weights(self, weights):
|
def set_weights(self, weights):
|
||||||
"""Sets the weights of the layer, from Numpy arrays.
|
"""Sets the weights of the layer, from Numpy arrays.
|
||||||
|
|||||||
@ -651,13 +651,20 @@ class NameScopingTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
|
self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
|
||||||
|
|
||||||
def test_name_scope_sublayer(self):
|
def test_name_scope_sublayer(self):
|
||||||
|
|
||||||
|
class NameScopeTracker(keras.layers.Layer):
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
self.active_name_scope = ops.get_name_scope()
|
||||||
|
return inputs
|
||||||
|
|
||||||
x = keras.backend.placeholder(shape=(10, 10))
|
x = keras.backend.placeholder(shape=(10, 10))
|
||||||
layer = keras.layers.Dense(
|
sublayer = NameScopeTracker(name='Sublayer')
|
||||||
10, activation=keras.layers.ReLU(name='MyAct'), name='MyName2')
|
layer = keras.layers.Dense(10, activation=sublayer, name='MyName2')
|
||||||
y = layer(x)
|
layer(x)
|
||||||
self.assertEqual(layer.bias.name, 'MyName2/bias:0')
|
self.assertEqual(layer.bias.name, 'MyName2/bias:0')
|
||||||
self.assertEqual(layer.kernel.name, 'MyName2/kernel:0')
|
self.assertEqual(layer.kernel.name, 'MyName2/kernel:0')
|
||||||
self.assertEqual(y.name, 'MyName2/MyAct/Relu:0')
|
self.assertEqual(sublayer.active_name_scope, 'MyName2/Sublayer')
|
||||||
|
|
||||||
def test_name_scope_tf_tensor(self):
|
def test_name_scope_tf_tensor(self):
|
||||||
x = ops.convert_to_tensor(np.ones((10, 10)))
|
x = ops.convert_to_tensor(np.ones((10, 10)))
|
||||||
@ -779,7 +786,8 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
def call(self, inputs, training=None):
|
def call(self, inputs, training=None):
|
||||||
if training:
|
if training:
|
||||||
self.add_update(self.counter.assign_add(math_ops.reduce_sum(inputs)))
|
z = math_ops.reduce_sum(inputs)
|
||||||
|
self.add_update(lambda: self.counter.assign_add(z))
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def compute_output_shape(self, input_shape):
|
def compute_output_shape(self, input_shape):
|
||||||
@ -797,7 +805,9 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
# TODO(fchollet): support the same workflow in graph mode.
|
# TODO(fchollet): support the same workflow in graph mode.
|
||||||
with self.assertRaisesRegexp(RuntimeError,
|
with self.assertRaisesRegexp(RuntimeError,
|
||||||
'`add_update` in a control flow branch'):
|
'`add_update` in a control flow branch'):
|
||||||
layer = MyLayer()(keras.Input((3,)))
|
layer = MyLayer()
|
||||||
|
layer(keras.Input((3,)))
|
||||||
|
_ = layer.updates
|
||||||
|
|
||||||
@parameterized.named_parameters(('eager', True),
|
@parameterized.named_parameters(('eager', True),
|
||||||
('symbolic', False))
|
('symbolic', False))
|
||||||
|
|||||||
@ -23,13 +23,11 @@ import enum
|
|||||||
|
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import auto_control_deps
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.keras import backend
|
from tensorflow.python.keras import backend
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_util
|
|
||||||
from tensorflow.python.ops import control_flow_util_v2
|
from tensorflow.python.ops import control_flow_util_v2
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.ops import init_ops_v2
|
from tensorflow.python.ops import init_ops_v2
|
||||||
@ -351,6 +349,21 @@ def is_in_frozen_context():
|
|||||||
return getattr(_call_context, 'frozen', False)
|
return getattr(_call_context, 'frozen', False)
|
||||||
|
|
||||||
|
|
||||||
|
def is_in_keras_graph():
|
||||||
|
"""Returns if currently executing inside of a Keras graph."""
|
||||||
|
# Returns True even if in a subgraph of the Keras graph, such as those
|
||||||
|
# created by control flow ops.
|
||||||
|
return (getattr(backend.get_graph(), 'name', None) == 'keras_graph' or
|
||||||
|
getattr(_call_context, 'in_keras_graph', False))
|
||||||
|
|
||||||
|
|
||||||
|
def is_in_eager_or_tf_function():
|
||||||
|
"""Returns if in eager mode or inside of a tf.function."""
|
||||||
|
return (context.executing_eagerly() or
|
||||||
|
(ops.executing_eagerly_outside_functions() and
|
||||||
|
not is_in_keras_graph()))
|
||||||
|
|
||||||
|
|
||||||
def uses_keras_history(tensors):
|
def uses_keras_history(tensors):
|
||||||
"""Check if at least one Tensor originates from a `keras.Input`.
|
"""Check if at least one Tensor originates from a `keras.Input`.
|
||||||
|
|
||||||
@ -413,7 +426,11 @@ def call_context(layer):
|
|||||||
"""Scope that marks when we are currently inside a Layer/Model's `call`."""
|
"""Scope that marks when we are currently inside a Layer/Model's `call`."""
|
||||||
was_in_call = is_in_call_context()
|
was_in_call = is_in_call_context()
|
||||||
was_frozen = is_in_frozen_context()
|
was_frozen = is_in_frozen_context()
|
||||||
|
was_in_keras_graph = getattr(_call_context, 'in_keras_graph', False)
|
||||||
_call_context.in_call = True
|
_call_context.in_call = True
|
||||||
|
_call_context.in_keras_graph = (
|
||||||
|
was_in_keras_graph or
|
||||||
|
getattr(backend.get_graph(), 'name', None) == 'keras_graph')
|
||||||
if not layer.trainable:
|
if not layer.trainable:
|
||||||
_call_context.frozen = True
|
_call_context.frozen = True
|
||||||
try:
|
try:
|
||||||
@ -421,6 +438,7 @@ def call_context(layer):
|
|||||||
finally:
|
finally:
|
||||||
_call_context.in_call = was_in_call
|
_call_context.in_call = was_in_call
|
||||||
_call_context.frozen = was_frozen
|
_call_context.frozen = was_frozen
|
||||||
|
_call_context.in_keras_graph = was_in_keras_graph
|
||||||
|
|
||||||
|
|
||||||
def training_arg_passed_to_call(argspec, args, kwargs):
|
def training_arg_passed_to_call(argspec, args, kwargs):
|
||||||
@ -431,121 +449,6 @@ def training_arg_passed_to_call(argspec, args, kwargs):
|
|||||||
return 'training' in full_args
|
return 'training' in full_args
|
||||||
|
|
||||||
|
|
||||||
class AutoAddUpdates(object):
|
|
||||||
"""Automatically track stateful ops with `add_update`.
|
|
||||||
|
|
||||||
This context manager is used to automatically add stateful ops to a Layer
|
|
||||||
or Model's `.updates`. This ensures that stateful ops are run in the Keras
|
|
||||||
training loop. It also allows for these stateful ops to be disabled by
|
|
||||||
setting `trainable=False`.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```
|
|
||||||
with AutoAddUpdates(layer, inputs) as auto_updates:
|
|
||||||
outputs = layer.call(inputs)
|
|
||||||
auto_updates.set_outputs(outputs)
|
|
||||||
```
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
layer: Layer or Model instance to add the updates to.
|
|
||||||
inputs: The inputs to this Layer or Model, to be used for input-conditional
|
|
||||||
updates.
|
|
||||||
outputs: The outputs of this Layer or Model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, layer, inputs):
|
|
||||||
self.layer = layer
|
|
||||||
self.inputs = inputs
|
|
||||||
self.outputs = []
|
|
||||||
|
|
||||||
def set_outputs(self, outputs):
|
|
||||||
if self.outputs:
|
|
||||||
raise RuntimeError('`set_outputs` should only be called once on an'
|
|
||||||
'`AutoAddUpdates` instance.')
|
|
||||||
self.outputs = outputs
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
# Only run in V2 Function mode.
|
|
||||||
if (context.executing_eagerly() or
|
|
||||||
not ops.executing_eagerly_outside_functions()):
|
|
||||||
return self
|
|
||||||
|
|
||||||
self._graph = ops.get_default_graph()
|
|
||||||
self._num_operations = len(self._graph.get_operations())
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, error_type, unused_value, unused_traceback):
|
|
||||||
if error_type:
|
|
||||||
# Allow errors that occurred inside this context manager to pass through
|
|
||||||
# normally.
|
|
||||||
return
|
|
||||||
|
|
||||||
# Only run in V2 Function mode.
|
|
||||||
if (context.executing_eagerly() or
|
|
||||||
not ops.executing_eagerly_outside_functions()):
|
|
||||||
return
|
|
||||||
|
|
||||||
if (self._graph is not ops.get_default_graph() or
|
|
||||||
self._graph.name != 'keras_graph'):
|
|
||||||
# Only auto-track updates when the Keras Graph is the only one used.
|
|
||||||
return
|
|
||||||
|
|
||||||
new_operations = self._graph.get_operations()[self._num_operations:]
|
|
||||||
new_stateful_ops = set()
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
for op in new_operations:
|
|
||||||
# While loop is not supported in general for automatic control
|
|
||||||
# dependencies.
|
|
||||||
if control_flow_util.IsInWhileLoop(op):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Track stateful ops via `add_update`.
|
|
||||||
is_stateful_op = (
|
|
||||||
op.type not in self._graph._registered_ops or
|
|
||||||
auto_control_deps.op_is_stateful(
|
|
||||||
self._graph._registered_ops[op.type]))
|
|
||||||
|
|
||||||
# Ignore ReadVariableOps as they are not needed to be run separately.
|
|
||||||
# This ensures existing Layers don't get extra updates.
|
|
||||||
if is_stateful_op and op.type != 'ReadVariableOp':
|
|
||||||
new_stateful_ops.add(op)
|
|
||||||
|
|
||||||
explicit_updates = set(
|
|
||||||
[u for u in self.layer.updates if not isinstance(u, tuple)])
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
|
||||||
# Don't add updates that will already be run by virtue of being consumed by
|
|
||||||
# other stateful ops or by the Layer's outputs. This ensures that existing
|
|
||||||
# Layers like `BatchNormalization` continue to return the same values for
|
|
||||||
# `.update` calls.
|
|
||||||
minimum_ops = set()
|
|
||||||
targets = new_stateful_ops.union(
|
|
||||||
set(nest.flatten(self.outputs)), explicit_updates)
|
|
||||||
for op in new_stateful_ops:
|
|
||||||
# Scrub any ops that are consumed by the outputs or other stateful ops.
|
|
||||||
reachable = tf_utils.get_reachable_from_inputs(op)
|
|
||||||
if not (targets - {op}).intersection(reachable):
|
|
||||||
minimum_ops.add(op)
|
|
||||||
new_stateful_ops = minimum_ops
|
|
||||||
|
|
||||||
# Don't double-track updates added via explicitly calling `add_update`.
|
|
||||||
# Also don't double-track updates already tracked in sublayers.
|
|
||||||
new_stateful_ops = new_stateful_ops - explicit_updates
|
|
||||||
|
|
||||||
# Decide whether to track as input-conditional or unconditional.
|
|
||||||
input_reachable_ops = tf_utils.get_reachable_from_inputs(
|
|
||||||
self.inputs, targets=new_stateful_ops)
|
|
||||||
unconditional_updates = new_stateful_ops - input_reachable_ops
|
|
||||||
conditional_updates = new_stateful_ops - unconditional_updates
|
|
||||||
|
|
||||||
if unconditional_updates:
|
|
||||||
self.layer.add_update(list(unconditional_updates))
|
|
||||||
if conditional_updates:
|
|
||||||
self.layer.add_update(list(conditional_updates), inputs=self.inputs)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_var_read_dtype(input_list, should_cast):
|
def _get_var_read_dtype(input_list, should_cast):
|
||||||
"""Gets the dtype that AutoCastVariables should be read in."""
|
"""Gets the dtype that AutoCastVariables should be read in."""
|
||||||
if should_cast and input_list and input_list[0].dtype.is_floating:
|
if should_cast and input_list and input_list[0].dtype.is_floating:
|
||||||
@ -579,7 +482,7 @@ def is_subclassed(layer):
|
|||||||
layer.__module__.find('keras.layers') == -1)
|
layer.__module__.find('keras.layers') == -1)
|
||||||
|
|
||||||
|
|
||||||
def check_graph_consistency(tensor, method):
|
def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
|
||||||
"""Checks that tensors passed to `add_*` method match the Keras graph.
|
"""Checks that tensors passed to `add_*` method match the Keras graph.
|
||||||
|
|
||||||
When one of the `add_*` method is called inside a V2 conditional branch,
|
When one of the `add_*` method is called inside a V2 conditional branch,
|
||||||
@ -589,15 +492,17 @@ def check_graph_consistency(tensor, method):
|
|||||||
Arguments:
|
Arguments:
|
||||||
tensor: Tensor to check.
|
tensor: Tensor to check.
|
||||||
method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
|
method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
|
||||||
|
force_raise: If an error should be raised regardless of `tensor`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: In case of an out-of-graph tensor.
|
RuntimeError: In case of an out-of-graph tensor.
|
||||||
"""
|
"""
|
||||||
if ops.executing_eagerly_outside_functions() and hasattr(tensor, 'graph'):
|
if (force_raise or (ops.executing_eagerly_outside_functions() and
|
||||||
if isinstance(tensor.graph,
|
hasattr(tensor, 'graph') and
|
||||||
|
isinstance(tensor.graph,
|
||||||
(control_flow_util_v2.CondBranchFuncGraph,
|
(control_flow_util_v2.CondBranchFuncGraph,
|
||||||
control_flow_util_v2.WhileCondFuncGraph,
|
control_flow_util_v2.WhileCondFuncGraph,
|
||||||
control_flow_util_v2.WhileBodyFuncGraph)):
|
control_flow_util_v2.WhileBodyFuncGraph)))):
|
||||||
if method == 'add_metric':
|
if method == 'add_metric':
|
||||||
bad_example = """
|
bad_example = """
|
||||||
def call(self, inputs, training=None):
|
def call(self, inputs, training=None):
|
||||||
@ -665,3 +570,23 @@ def check_graph_consistency(tensor, method):
|
|||||||
method=method,
|
method=method,
|
||||||
bad_example=bad_example,
|
bad_example=bad_example,
|
||||||
correct_example=correct_example))
|
correct_example=correct_example))
|
||||||
|
|
||||||
|
|
||||||
|
def mark_as_return(outputs, acd):
|
||||||
|
"""Marks `outputs` as the return values for automatic control deps."""
|
||||||
|
|
||||||
|
def _mark_as_return(tensor):
|
||||||
|
"""Marks `tensor` as the return value for automatic control deps."""
|
||||||
|
if not tensor_util.is_tensor(tensor):
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
return_tensor = acd.mark_as_return(tensor)
|
||||||
|
if getattr(tensor, '_keras_mask', None) is not None:
|
||||||
|
return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask)
|
||||||
|
else:
|
||||||
|
return_tensor._keras_mask = None
|
||||||
|
return return_tensor
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
return nest.map_structure(_mark_as_return, outputs)
|
||||||
|
|||||||
@ -2920,7 +2920,7 @@ class BareUpdateLayer(keras.layers.Layer):
|
|||||||
return math_ops.cast(self.counter, inputs.dtype) * inputs
|
return math_ops.cast(self.counter, inputs.dtype) * inputs
|
||||||
|
|
||||||
|
|
||||||
class AddUpdateLayer(keras.layers.Layer):
|
class LambdaUpdateLayer(keras.layers.Layer):
|
||||||
|
|
||||||
def build(self, input_shape):
|
def build(self, input_shape):
|
||||||
self.counter = self.add_weight(
|
self.counter = self.add_weight(
|
||||||
@ -2932,7 +2932,7 @@ class AddUpdateLayer(keras.layers.Layer):
|
|||||||
|
|
||||||
def call(self, inputs):
|
def call(self, inputs):
|
||||||
# Make sure update isn't run twice.
|
# Make sure update isn't run twice.
|
||||||
self.add_update(state_ops.assign_add(self.counter, 1))
|
self.add_update(lambda: state_ops.assign_add(self.counter, 1))
|
||||||
return math_ops.cast(self.counter, inputs.dtype) * inputs
|
return math_ops.cast(self.counter, inputs.dtype) * inputs
|
||||||
|
|
||||||
|
|
||||||
@ -2950,12 +2950,31 @@ class NestedUpdateLayer(keras.layers.Layer):
|
|||||||
return self.layer(inputs)
|
return self.layer(inputs)
|
||||||
|
|
||||||
|
|
||||||
|
class SubgraphUpdateLayer(keras.layers.Layer):
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
self.counter = self.add_weight(
|
||||||
|
'counter',
|
||||||
|
dtype='int32',
|
||||||
|
shape=(),
|
||||||
|
initializer='zeros',
|
||||||
|
trainable=False)
|
||||||
|
|
||||||
|
def call(self, inputs, training=None):
|
||||||
|
if training is None:
|
||||||
|
training = keras.backend.learning_phase()
|
||||||
|
|
||||||
|
if training:
|
||||||
|
self.counter.assign(self.counter + 1)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
class TestAutoUpdates(keras_parameterized.TestCase):
|
class TestAutoUpdates(keras_parameterized.TestCase):
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
@parameterized.named_parameters(('bare_update', BareUpdateLayer()),
|
@parameterized.named_parameters(('bare_update', BareUpdateLayer()),
|
||||||
('add_update', AddUpdateLayer()),
|
('lambda_update', LambdaUpdateLayer()),
|
||||||
('nested_update', NestedUpdateLayer()))
|
('nested_update', NestedUpdateLayer()))
|
||||||
def test_updates_in_model(self, layer):
|
def test_updates_in_model(self, layer):
|
||||||
x, y = np.ones((10, 10)), np.ones((10, 1))
|
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||||
@ -2963,16 +2982,34 @@ class TestAutoUpdates(keras_parameterized.TestCase):
|
|||||||
[layer, keras.layers.Dense(1)], input_shape=(10,))
|
[layer, keras.layers.Dense(1)], input_shape=(10,))
|
||||||
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||||
model.fit(x, y, batch_size=2, epochs=1)
|
model.fit(x, y, batch_size=2, epochs=1)
|
||||||
if not testing_utils.should_run_eagerly():
|
self.assertEqual(self.evaluate(layer.counter), 5)
|
||||||
# Check that `trainable=False` disables updates.
|
|
||||||
|
@keras_parameterized.run_with_all_model_types
|
||||||
|
def test_lambda_updates_trainable_false(self):
|
||||||
|
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||||
|
layer = LambdaUpdateLayer()
|
||||||
|
model = testing_utils.get_model_from_layers(
|
||||||
|
[layer, keras.layers.Dense(1)], input_shape=(10,))
|
||||||
|
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
model.fit(x, y, batch_size=2, epochs=1)
|
||||||
|
self.assertEqual(self.evaluate(layer.counter), 5)
|
||||||
layer.trainable = False
|
layer.trainable = False
|
||||||
model.compile(
|
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||||
'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
model.fit(x, y, batch_size=2, epochs=1)
|
||||||
|
self.assertEqual(self.evaluate(layer.counter), 5)
|
||||||
|
|
||||||
|
@keras_parameterized.run_with_all_model_types
|
||||||
|
def test_subgraph_updates_in_model(self):
|
||||||
|
layer = SubgraphUpdateLayer()
|
||||||
|
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||||
|
model = testing_utils.get_model_from_layers(
|
||||||
|
[layer, keras.layers.Dense(1)], input_shape=(10,))
|
||||||
|
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||||
model.fit(x, y, batch_size=2, epochs=1)
|
model.fit(x, y, batch_size=2, epochs=1)
|
||||||
self.assertEqual(self.evaluate(layer.counter), 5)
|
self.assertEqual(self.evaluate(layer.counter), 5)
|
||||||
|
|
||||||
@parameterized.named_parameters(('bare_update', BareUpdateLayer()),
|
@parameterized.named_parameters(('bare_update', BareUpdateLayer()),
|
||||||
('add_update', AddUpdateLayer()),
|
('lambda_update', LambdaUpdateLayer()),
|
||||||
('nested_update', NestedUpdateLayer()))
|
('nested_update', NestedUpdateLayer()))
|
||||||
def test_updates_standalone_layer(self, layer):
|
def test_updates_standalone_layer(self, layer):
|
||||||
y = layer(np.ones((10, 10)))
|
y = layer(np.ones((10, 10)))
|
||||||
@ -2980,23 +3017,23 @@ class TestAutoUpdates(keras_parameterized.TestCase):
|
|||||||
self.evaluate(y)
|
self.evaluate(y)
|
||||||
self.assertEqual(self.evaluate(layer.counter), 1)
|
self.assertEqual(self.evaluate(layer.counter), 1)
|
||||||
|
|
||||||
def test_trainable_false(self):
|
def test_trainable_false_standalone_layer(self):
|
||||||
x = keras.backend.placeholder(shape=(10, 10), dtype='float32')
|
layer = LambdaUpdateLayer()
|
||||||
layer = NestedUpdateLayer()
|
y = layer(np.ones((10, 10)))
|
||||||
|
self.evaluate(layer.counter.initializer)
|
||||||
|
self.evaluate(y)
|
||||||
|
self.assertEqual(self.evaluate(layer.counter), 1)
|
||||||
layer.trainable = False
|
layer.trainable = False
|
||||||
y = layer(x)
|
y = layer(np.ones((10, 10)))
|
||||||
func = keras.backend.function([x], [y])
|
self.evaluate(y)
|
||||||
x_val = np.ones((10, 10))
|
self.assertEqual(self.evaluate(layer.counter), 1)
|
||||||
func(x_val)
|
|
||||||
counter = keras.backend.get_value(layer.counter)
|
|
||||||
self.assertEqual(counter, 0)
|
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
def test_batchnorm_trainable_false(self):
|
def test_batchnorm_trainable_false(self):
|
||||||
bn = keras.layers.BatchNormalization()
|
bn = keras.layers.BatchNormalization()
|
||||||
bn.trainable = False
|
|
||||||
model = testing_utils.get_model_from_layers([bn, keras.layers.Dense(1)],
|
model = testing_utils.get_model_from_layers([bn, keras.layers.Dense(1)],
|
||||||
input_shape=(10,))
|
input_shape=(10,))
|
||||||
|
bn.trainable = False
|
||||||
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||||
x, y = np.ones((10, 10)), np.ones((10, 1))
|
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||||
model.fit(x, y, batch_size=2, epochs=1)
|
model.fit(x, y, batch_size=2, epochs=1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user