From b14c390fc869b63fd2b1a4e6f4477ce410b9383e Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Sat, 4 May 2019 16:50:39 -0700 Subject: [PATCH] Use automatic control dependencies in Keras in V2 mode. PiperOrigin-RevId: 246676835 --- .../python/framework/auto_control_deps.py | 2 +- tensorflow/python/keras/engine/base_layer.py | 68 ++-- .../python/keras/engine/base_layer_test.py | 22 +- .../python/keras/engine/base_layer_utils.py | 303 +++++++----------- .../python/keras/engine/training_test.py | 77 +++-- 5 files changed, 236 insertions(+), 236 deletions(-) diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py index a8ba4ea50d1..9aae8594dcc 100644 --- a/tensorflow/python/framework/auto_control_deps.py +++ b/tensorflow/python/framework/auto_control_deps.py @@ -335,7 +335,7 @@ class AutomaticControlDependencies(object): # Ensure all ops which must run do run self.ops_which_must_run.update(ops_which_must_run) - for r in self._returned_tensors: + for r in nest.flatten(list(self._returned_tensors), expand_composites=True): if self.ops_which_must_run: r.op._add_control_inputs( # pylint: disable=protected-access [o for o in self.ops_which_must_run diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 70ef87f6bac..b98f9344b4f 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -27,10 +27,12 @@ from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.core.framework import node_def_pb2 from tensorflow.python import autograph +from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import values as distribute_values from tensorflow.python.eager import context from tensorflow.python.eager import execute from tensorflow.python.eager import function +from tensorflow.python.framework import auto_control_deps from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops @@ -617,18 +619,22 @@ class Layer(module.Module): if (self._expects_training_arg and not base_layer_utils.training_arg_passed_to_call( tf_inspect.getfullargspec(self.call), args, kwargs) and - getattr(graph, 'name', None) == 'keras_graph'): + base_layer_utils.is_in_keras_graph()): learning_phase_passed_by_framework = True kwargs['training'] = backend.learning_phase() if not self.dynamic: try: with base_layer_utils.autocast_context_manager( input_list, - self._mixed_precision_policy.should_cast_variables), ( - base_layer_utils.AutoAddUpdates(self, - inputs)) as auto_updater: - outputs = call_fn(inputs, *args, **kwargs) - auto_updater.set_outputs(outputs) + self._mixed_precision_policy.should_cast_variables): + if ops.executing_eagerly_outside_functions(): + with auto_control_deps.AutomaticControlDependencies() as acd: + outputs = call_fn(inputs, *args, **kwargs) + # Wrap Tensors in `outputs` in `tf.identity` to avoid + # circular dependencies. + outputs = base_layer_utils.mark_as_return(outputs, acd) + else: + outputs = call_fn(inputs, *args, **kwargs) except TypeError as e: exception_str = str(e) @@ -739,7 +745,25 @@ class Layer(module.Module): def updates(self): if not self.trainable and not self.stateful: return [] - return self._updates + self._gather_children_attribute('updates') + with backend.get_graph().as_default(): + updates = [] + for u in self._updates: + # Filter out updates created in a cross-replica context when in a + # replica context and vice versa. + if (getattr(u, '_in_cross_replica_context', False) != + ds_context.in_cross_replica_context()): + continue + if callable(u): + try: + u = u() + except ValueError as e: + if 'Trying to capture a tensor from an inner function' in str(e): + base_layer_utils.check_graph_consistency( + method='add_update', force_raise=True) + raise + base_layer_utils.check_graph_consistency(u, method='add_update') + updates.append(u) + return updates + self._gather_children_attribute('updates') @property def losses(self): @@ -1011,14 +1035,13 @@ class Layer(module.Module): """ updates = generic_utils.to_list(updates) - if context.executing_eagerly(): - # Don't run callable updates if currently executing inside the `call` - # of a Layer/Model with `trainable=False`. + # All updates can be run immediately in Eager or in a tf.function. + if base_layer_utils.is_in_eager_or_tf_function(): if not base_layer_utils.is_in_frozen_context(): for update in updates: if callable(update): update() - return # Updates already applied when in eager mode. + return def process_update(x): """Standardize update ops. @@ -1030,24 +1053,29 @@ class Layer(module.Module): An update op. """ if callable(x): - x = x() - if isinstance(x, ops.Operation): + update = lambda: process_update(x()) + if not ops.executing_eagerly_outside_functions(): + # In V1 mode, call the callable right away and process. This is needed + # for TPU strategy. + return update() + elif isinstance(x, ops.Operation): update = x elif hasattr(x, 'op'): update = x.op else: update = ops.convert_to_tensor(x) - base_layer_utils.check_graph_consistency(update, method='add_update') + update._unconditional_update = (inputs is None) + update._in_cross_replica_context = ( + ds_context.has_strategy() and ds_context.in_cross_replica_context()) return update updates = [process_update(x) for x in updates] + # Non-callable Updates are run automatically inside `call` in V2, so + # they do not need to be tracked later. + if (ops.executing_eagerly_outside_functions() and + base_layer_utils.is_in_call_context()): + updates = [u for u in updates if callable(u)] self._updates += updates - if inputs is None: - for u in updates: - u._unconditional_update = True # pylint: disable=protected-access - else: - for u in updates: - u._unconditional_update = False # pylint: disable=protected-access def set_weights(self, weights): """Sets the weights of the layer, from Numpy arrays. diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 6835e02ff2d..b8a3f9a5eed 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -651,13 +651,20 @@ class NameScopingTest(keras_parameterized.TestCase): self.assertEqual(layer.kernel.name, 'MyName/kernel:0') def test_name_scope_sublayer(self): + + class NameScopeTracker(keras.layers.Layer): + + def call(self, inputs): + self.active_name_scope = ops.get_name_scope() + return inputs + x = keras.backend.placeholder(shape=(10, 10)) - layer = keras.layers.Dense( - 10, activation=keras.layers.ReLU(name='MyAct'), name='MyName2') - y = layer(x) + sublayer = NameScopeTracker(name='Sublayer') + layer = keras.layers.Dense(10, activation=sublayer, name='MyName2') + layer(x) self.assertEqual(layer.bias.name, 'MyName2/bias:0') self.assertEqual(layer.kernel.name, 'MyName2/kernel:0') - self.assertEqual(y.name, 'MyName2/MyAct/Relu:0') + self.assertEqual(sublayer.active_name_scope, 'MyName2/Sublayer') def test_name_scope_tf_tensor(self): x = ops.convert_to_tensor(np.ones((10, 10))) @@ -779,7 +786,8 @@ class AutographControlFlowTest(keras_parameterized.TestCase): def call(self, inputs, training=None): if training: - self.add_update(self.counter.assign_add(math_ops.reduce_sum(inputs))) + z = math_ops.reduce_sum(inputs) + self.add_update(lambda: self.counter.assign_add(z)) return inputs def compute_output_shape(self, input_shape): @@ -797,7 +805,9 @@ class AutographControlFlowTest(keras_parameterized.TestCase): # TODO(fchollet): support the same workflow in graph mode. with self.assertRaisesRegexp(RuntimeError, '`add_update` in a control flow branch'): - layer = MyLayer()(keras.Input((3,))) + layer = MyLayer() + layer(keras.Input((3,))) + _ = layer.updates @parameterized.named_parameters(('eager', True), ('symbolic', False)) diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index b97326eea6a..c9a27f7e82b 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -23,13 +23,11 @@ import enum from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context -from tensorflow.python.framework import auto_control_deps from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend -from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import control_flow_util_v2 from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops_v2 @@ -351,6 +349,21 @@ def is_in_frozen_context(): return getattr(_call_context, 'frozen', False) +def is_in_keras_graph(): + """Returns if currently executing inside of a Keras graph.""" + # Returns True even if in a subgraph of the Keras graph, such as those + # created by control flow ops. + return (getattr(backend.get_graph(), 'name', None) == 'keras_graph' or + getattr(_call_context, 'in_keras_graph', False)) + + +def is_in_eager_or_tf_function(): + """Returns if in eager mode or inside of a tf.function.""" + return (context.executing_eagerly() or + (ops.executing_eagerly_outside_functions() and + not is_in_keras_graph())) + + def uses_keras_history(tensors): """Check if at least one Tensor originates from a `keras.Input`. @@ -413,7 +426,11 @@ def call_context(layer): """Scope that marks when we are currently inside a Layer/Model's `call`.""" was_in_call = is_in_call_context() was_frozen = is_in_frozen_context() + was_in_keras_graph = getattr(_call_context, 'in_keras_graph', False) _call_context.in_call = True + _call_context.in_keras_graph = ( + was_in_keras_graph or + getattr(backend.get_graph(), 'name', None) == 'keras_graph') if not layer.trainable: _call_context.frozen = True try: @@ -421,6 +438,7 @@ def call_context(layer): finally: _call_context.in_call = was_in_call _call_context.frozen = was_frozen + _call_context.in_keras_graph = was_in_keras_graph def training_arg_passed_to_call(argspec, args, kwargs): @@ -431,121 +449,6 @@ def training_arg_passed_to_call(argspec, args, kwargs): return 'training' in full_args -class AutoAddUpdates(object): - """Automatically track stateful ops with `add_update`. - - This context manager is used to automatically add stateful ops to a Layer - or Model's `.updates`. This ensures that stateful ops are run in the Keras - training loop. It also allows for these stateful ops to be disabled by - setting `trainable=False`. - - Example: - - ``` - with AutoAddUpdates(layer, inputs) as auto_updates: - outputs = layer.call(inputs) - auto_updates.set_outputs(outputs) - ``` - - Attributes: - layer: Layer or Model instance to add the updates to. - inputs: The inputs to this Layer or Model, to be used for input-conditional - updates. - outputs: The outputs of this Layer or Model. - """ - - def __init__(self, layer, inputs): - self.layer = layer - self.inputs = inputs - self.outputs = [] - - def set_outputs(self, outputs): - if self.outputs: - raise RuntimeError('`set_outputs` should only be called once on an' - '`AutoAddUpdates` instance.') - self.outputs = outputs - - def __enter__(self): - # Only run in V2 Function mode. - if (context.executing_eagerly() or - not ops.executing_eagerly_outside_functions()): - return self - - self._graph = ops.get_default_graph() - self._num_operations = len(self._graph.get_operations()) - return self - - def __exit__(self, error_type, unused_value, unused_traceback): - if error_type: - # Allow errors that occurred inside this context manager to pass through - # normally. - return - - # Only run in V2 Function mode. - if (context.executing_eagerly() or - not ops.executing_eagerly_outside_functions()): - return - - if (self._graph is not ops.get_default_graph() or - self._graph.name != 'keras_graph'): - # Only auto-track updates when the Keras Graph is the only one used. - return - - new_operations = self._graph.get_operations()[self._num_operations:] - new_stateful_ops = set() - - # pylint: disable=protected-access - for op in new_operations: - # While loop is not supported in general for automatic control - # dependencies. - if control_flow_util.IsInWhileLoop(op): - continue - - # Track stateful ops via `add_update`. - is_stateful_op = ( - op.type not in self._graph._registered_ops or - auto_control_deps.op_is_stateful( - self._graph._registered_ops[op.type])) - - # Ignore ReadVariableOps as they are not needed to be run separately. - # This ensures existing Layers don't get extra updates. - if is_stateful_op and op.type != 'ReadVariableOp': - new_stateful_ops.add(op) - - explicit_updates = set( - [u for u in self.layer.updates if not isinstance(u, tuple)]) - # pylint: enable=protected-access - - # Don't add updates that will already be run by virtue of being consumed by - # other stateful ops or by the Layer's outputs. This ensures that existing - # Layers like `BatchNormalization` continue to return the same values for - # `.update` calls. - minimum_ops = set() - targets = new_stateful_ops.union( - set(nest.flatten(self.outputs)), explicit_updates) - for op in new_stateful_ops: - # Scrub any ops that are consumed by the outputs or other stateful ops. - reachable = tf_utils.get_reachable_from_inputs(op) - if not (targets - {op}).intersection(reachable): - minimum_ops.add(op) - new_stateful_ops = minimum_ops - - # Don't double-track updates added via explicitly calling `add_update`. - # Also don't double-track updates already tracked in sublayers. - new_stateful_ops = new_stateful_ops - explicit_updates - - # Decide whether to track as input-conditional or unconditional. - input_reachable_ops = tf_utils.get_reachable_from_inputs( - self.inputs, targets=new_stateful_ops) - unconditional_updates = new_stateful_ops - input_reachable_ops - conditional_updates = new_stateful_ops - unconditional_updates - - if unconditional_updates: - self.layer.add_update(list(unconditional_updates)) - if conditional_updates: - self.layer.add_update(list(conditional_updates), inputs=self.inputs) - - def _get_var_read_dtype(input_list, should_cast): """Gets the dtype that AutoCastVariables should be read in.""" if should_cast and input_list and input_list[0].dtype.is_floating: @@ -579,7 +482,7 @@ def is_subclassed(layer): layer.__module__.find('keras.layers') == -1) -def check_graph_consistency(tensor, method): +def check_graph_consistency(tensor=None, method='add_loss', force_raise=False): """Checks that tensors passed to `add_*` method match the Keras graph. When one of the `add_*` method is called inside a V2 conditional branch, @@ -589,79 +492,101 @@ def check_graph_consistency(tensor, method): Arguments: tensor: Tensor to check. method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}. + force_raise: If an error should be raised regardless of `tensor`. Raises: RuntimeError: In case of an out-of-graph tensor. """ - if ops.executing_eagerly_outside_functions() and hasattr(tensor, 'graph'): - if isinstance(tensor.graph, - (control_flow_util_v2.CondBranchFuncGraph, - control_flow_util_v2.WhileCondFuncGraph, - control_flow_util_v2.WhileBodyFuncGraph)): - if method == 'add_metric': - bad_example = """ - def call(self, inputs, training=None): - if training: - metric = compute_metric(inputs) - self.add_metric(metric, name='my_metric', aggregation='mean') - return inputs - """ - correct_example = """ - def call(self, inputs, training=None): - if training: - metric = compute_metric(inputs) - else: - metric = 0. + if (force_raise or (ops.executing_eagerly_outside_functions() and + hasattr(tensor, 'graph') and + isinstance(tensor.graph, + (control_flow_util_v2.CondBranchFuncGraph, + control_flow_util_v2.WhileCondFuncGraph, + control_flow_util_v2.WhileBodyFuncGraph)))): + if method == 'add_metric': + bad_example = """ + def call(self, inputs, training=None): + if training: + metric = compute_metric(inputs) self.add_metric(metric, name='my_metric', aggregation='mean') - return inputs - """ - elif method == 'add_loss': - bad_example = """ - def call(self, inputs, training=None): - if training: - loss = compute_loss(inputs) - self.add_loss(loss) - return inputs - """ - correct_example = """ - def call(self, inputs, training=None): - if training: - loss = compute_loss(inputs) - else: - loss = 0. + return inputs + """ + correct_example = """ + def call(self, inputs, training=None): + if training: + metric = compute_metric(inputs) + else: + metric = 0. + self.add_metric(metric, name='my_metric', aggregation='mean') + return inputs + """ + elif method == 'add_loss': + bad_example = """ + def call(self, inputs, training=None): + if training: + loss = compute_loss(inputs) self.add_loss(loss) - return inputs - """ - else: - bad_example = """ - def call(self, inputs, training=None): - if training: - self.add_update(self.w.assign_add(1)) - return inputs - """ - correct_example = """ - def call(self, inputs, training=None): - if training: - increment = 1 - else: - increment = 0 - self.add_update(self.w.assign_add(increment)) - return inputs - """ - raise RuntimeError( - 'You are using the method `{method}` in a control flow branch ' - 'in your layer, e.g.:\n{bad_example}\n' - 'This is not currently supported. ' - 'You should either use static control flow (`tf.cond`) ' - 'or move your call to {method} out of the control flow branch, ' - 'e.g.:\n{correct_example}\n' - 'You can also resolve this by marking your layer ' - 'as dynamic (eager-only) by passing ' - '`dynamic=True` to the layer constructor. ' - 'Any kind of control flow is supported with dynamic layers. ' - 'Note that using `dynamic=True` requires you ' - 'to implement static shape inference ' - 'in the `compute_output_shape(input_shape)` method.'.format( - method=method, - bad_example=bad_example, - correct_example=correct_example)) + return inputs + """ + correct_example = """ + def call(self, inputs, training=None): + if training: + loss = compute_loss(inputs) + else: + loss = 0. + self.add_loss(loss) + return inputs + """ + else: + bad_example = """ + def call(self, inputs, training=None): + if training: + self.add_update(self.w.assign_add(1)) + return inputs + """ + correct_example = """ + def call(self, inputs, training=None): + if training: + increment = 1 + else: + increment = 0 + self.add_update(self.w.assign_add(increment)) + return inputs + """ + raise RuntimeError( + 'You are using the method `{method}` in a control flow branch ' + 'in your layer, e.g.:\n{bad_example}\n' + 'This is not currently supported. ' + 'You should either use static control flow (`tf.cond`) ' + 'or move your call to {method} out of the control flow branch, ' + 'e.g.:\n{correct_example}\n' + 'You can also resolve this by marking your layer ' + 'as dynamic (eager-only) by passing ' + '`dynamic=True` to the layer constructor. ' + 'Any kind of control flow is supported with dynamic layers. ' + 'Note that using `dynamic=True` requires you ' + 'to implement static shape inference ' + 'in the `compute_output_shape(input_shape)` method.'.format( + method=method, + bad_example=bad_example, + correct_example=correct_example)) + + +def mark_as_return(outputs, acd): + """Marks `outputs` as the return values for automatic control deps.""" + + def _mark_as_return(tensor): + """Marks `tensor` as the return value for automatic control deps.""" + if not tensor_util.is_tensor(tensor): + return tensor + + # pylint: disable=protected-access + return_tensor = acd.mark_as_return(tensor) + if getattr(tensor, '_keras_mask', None) is not None: + return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask) + else: + return_tensor._keras_mask = None + return return_tensor + # pylint: enable=protected-access + + return nest.map_structure(_mark_as_return, outputs) diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index de751520091..732270c7c4d 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -2920,7 +2920,7 @@ class BareUpdateLayer(keras.layers.Layer): return math_ops.cast(self.counter, inputs.dtype) * inputs -class AddUpdateLayer(keras.layers.Layer): +class LambdaUpdateLayer(keras.layers.Layer): def build(self, input_shape): self.counter = self.add_weight( @@ -2932,7 +2932,7 @@ class AddUpdateLayer(keras.layers.Layer): def call(self, inputs): # Make sure update isn't run twice. - self.add_update(state_ops.assign_add(self.counter, 1)) + self.add_update(lambda: state_ops.assign_add(self.counter, 1)) return math_ops.cast(self.counter, inputs.dtype) * inputs @@ -2950,12 +2950,31 @@ class NestedUpdateLayer(keras.layers.Layer): return self.layer(inputs) +class SubgraphUpdateLayer(keras.layers.Layer): + + def build(self, input_shape): + self.counter = self.add_weight( + 'counter', + dtype='int32', + shape=(), + initializer='zeros', + trainable=False) + + def call(self, inputs, training=None): + if training is None: + training = keras.backend.learning_phase() + + if training: + self.counter.assign(self.counter + 1) + return inputs + + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) class TestAutoUpdates(keras_parameterized.TestCase): @keras_parameterized.run_with_all_model_types @parameterized.named_parameters(('bare_update', BareUpdateLayer()), - ('add_update', AddUpdateLayer()), + ('lambda_update', LambdaUpdateLayer()), ('nested_update', NestedUpdateLayer())) def test_updates_in_model(self, layer): x, y = np.ones((10, 10)), np.ones((10, 1)) @@ -2963,16 +2982,34 @@ class TestAutoUpdates(keras_parameterized.TestCase): [layer, keras.layers.Dense(1)], input_shape=(10,)) model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) model.fit(x, y, batch_size=2, epochs=1) - if not testing_utils.should_run_eagerly(): - # Check that `trainable=False` disables updates. - layer.trainable = False - model.compile( - 'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) - model.fit(x, y, batch_size=2, epochs=1) + self.assertEqual(self.evaluate(layer.counter), 5) + + @keras_parameterized.run_with_all_model_types + def test_lambda_updates_trainable_false(self): + x, y = np.ones((10, 10)), np.ones((10, 1)) + layer = LambdaUpdateLayer() + model = testing_utils.get_model_from_layers( + [layer, keras.layers.Dense(1)], input_shape=(10,)) + model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) + model.fit(x, y, batch_size=2, epochs=1) + self.assertEqual(self.evaluate(layer.counter), 5) + layer.trainable = False + model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) + model.fit(x, y, batch_size=2, epochs=1) + self.assertEqual(self.evaluate(layer.counter), 5) + + @keras_parameterized.run_with_all_model_types + def test_subgraph_updates_in_model(self): + layer = SubgraphUpdateLayer() + x, y = np.ones((10, 10)), np.ones((10, 1)) + model = testing_utils.get_model_from_layers( + [layer, keras.layers.Dense(1)], input_shape=(10,)) + model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) + model.fit(x, y, batch_size=2, epochs=1) self.assertEqual(self.evaluate(layer.counter), 5) @parameterized.named_parameters(('bare_update', BareUpdateLayer()), - ('add_update', AddUpdateLayer()), + ('lambda_update', LambdaUpdateLayer()), ('nested_update', NestedUpdateLayer())) def test_updates_standalone_layer(self, layer): y = layer(np.ones((10, 10))) @@ -2980,23 +3017,23 @@ class TestAutoUpdates(keras_parameterized.TestCase): self.evaluate(y) self.assertEqual(self.evaluate(layer.counter), 1) - def test_trainable_false(self): - x = keras.backend.placeholder(shape=(10, 10), dtype='float32') - layer = NestedUpdateLayer() + def test_trainable_false_standalone_layer(self): + layer = LambdaUpdateLayer() + y = layer(np.ones((10, 10))) + self.evaluate(layer.counter.initializer) + self.evaluate(y) + self.assertEqual(self.evaluate(layer.counter), 1) layer.trainable = False - y = layer(x) - func = keras.backend.function([x], [y]) - x_val = np.ones((10, 10)) - func(x_val) - counter = keras.backend.get_value(layer.counter) - self.assertEqual(counter, 0) + y = layer(np.ones((10, 10))) + self.evaluate(y) + self.assertEqual(self.evaluate(layer.counter), 1) @keras_parameterized.run_with_all_model_types def test_batchnorm_trainable_false(self): bn = keras.layers.BatchNormalization() - bn.trainable = False model = testing_utils.get_model_from_layers([bn, keras.layers.Dense(1)], input_shape=(10,)) + bn.trainable = False model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) x, y = np.ones((10, 10)), np.ones((10, 1)) model.fit(x, y, batch_size=2, epochs=1)