From b57c7d71eff5914a503d15130cb90a240b3bcf40 Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Fri, 15 Mar 2019 12:18:44 -0700 Subject: [PATCH] Switch `_set_inputs` to use full `__call__`. PiperOrigin-RevId: 238688012 --- tensorflow/python/keras/callbacks_test.py | 14 ++++-- tensorflow/python/keras/engine/base_layer.py | 22 ++++----- .../python/keras/engine/base_layer_utils.py | 7 +-- tensorflow/python/keras/engine/network.py | 11 ++--- tensorflow/python/keras/engine/training.py | 48 ++++++++----------- .../python/keras/model_subclassing_test.py | 24 +++++++++- 6 files changed, 72 insertions(+), 54 deletions(-) diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 690409f2522..d25ef6360f0 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -1164,6 +1164,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase): model = self._get_model() x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) tb_cbk = keras.callbacks.TensorBoard(self.logdir, histogram_freq=1) + model_type = testing_utils.get_model_type() model.fit( x, @@ -1182,7 +1183,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase): }, ) self.assertEqual( - self._strip_layer_names(summary_file.histograms), + self._strip_layer_names(summary_file.histograms, model_type), { _ObservedSummary(logdir=self.train_dir, tag='bias_0'), _ObservedSummary(logdir=self.train_dir, tag='kernel_0'), @@ -1194,6 +1195,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase): x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) tb_cbk = keras.callbacks.TensorBoard( self.logdir, histogram_freq=1, write_images=True) + model_type = testing_utils.get_model_type() model.fit( x, @@ -1212,14 +1214,14 @@ class TestTensorBoardV2(keras_parameterized.TestCase): }, ) self.assertEqual( - self._strip_layer_names(summary_file.histograms), + self._strip_layer_names(summary_file.histograms, model_type), { _ObservedSummary(logdir=self.train_dir, tag='bias_0'), _ObservedSummary(logdir=self.train_dir, tag='kernel_0'), }, ) self.assertEqual( - self._strip_layer_names(summary_file.images), + self._strip_layer_names(summary_file.images, model_type), { _ObservedSummary(logdir=self.train_dir, tag='bias_0/image/0'), _ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/0'), @@ -1228,7 +1230,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase): }, ) - def _strip_layer_names(self, summaries): + def _strip_layer_names(self, summaries, model_type): """Deduplicate summary names modulo layer prefix. This removes the first slash-component of each tag name: for @@ -1236,6 +1238,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase): Args: summaries: A `set` of `_ObservedSummary` values. + model_type: The model type currently being tested. Returns: A new `set` of `_ObservedSummary` values with layer prefixes @@ -1245,7 +1248,8 @@ class TestTensorBoardV2(keras_parameterized.TestCase): for summary in summaries: if '/' not in summary.tag: raise ValueError('tag has no layer name: %r' % summary.tag) - new_tag = summary.tag.split('/', 1)[1] + start_from = 2 if 'subclass' in model_type else 1 + new_tag = '/'.join(summary.tag.split('/')[start_from:]) result.add(summary._replace(tag=new_tag)) return result diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index fc59f3c81f7..cf70c28e75e 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -714,9 +714,7 @@ class Layer(trackable.Trackable): @property def updates(self): - if not self.trainable and not self.stateful: - return [] - return self._updates + self._gather_children_attribute('updates') + return self._get_unfiltered_updates(check_trainable=True) @property def losses(self): @@ -967,13 +965,15 @@ class Layer(trackable.Trackable): if inputs is None: # Requesting unconditional updates. - return [x for x in self._unfiltered_updates if x._unconditional_update] # pylint: disable=protected-access + return [ + x for x in self._get_unfiltered_updates() if x._unconditional_update # pylint: disable=protected-access + ] # Requesting input-conditional updates. inputs = nest.flatten(inputs) - reachable = tf_utils.get_reachable_from_inputs(inputs, - self._unfiltered_updates) - return [u for u in self._unfiltered_updates if u in reachable] # pylint: disable=protected-access + reachable = tf_utils.get_reachable_from_inputs( + inputs, self._get_unfiltered_updates()) + return [u for u in self._get_unfiltered_updates() if u in reachable] # pylint: disable=protected-access def get_losses_for(self, inputs): """Retrieves losses relevant to a specific set of inputs. @@ -1847,10 +1847,10 @@ class Layer(trackable.Trackable): def _is_layer(self): return True - @property - def _unfiltered_updates(self): - # Overridden in `Network`. - return self.updates + def _get_unfiltered_updates(self, check_trainable=True): + if check_trainable and not self.trainable and not self.stateful: + return [] + return self._updates + self._gather_children_attribute('updates') class Node(object): diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index 95f709fd425..4697f8d1f9e 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -482,8 +482,10 @@ class AutoAddUpdates(object): if is_stateful_op and op.type != 'ReadVariableOp': new_stateful_ops.add(op) - explicit_updates = set( - [u for u in self.layer._unfiltered_updates if not isinstance(u, tuple)]) + explicit_updates = set([ + u for u in self.layer._get_unfiltered_updates(check_trainable=False) + if not isinstance(u, tuple) + ]) # pylint: enable=protected-access # Don't add updates that will already be run by virtue of being consumed by @@ -542,4 +544,3 @@ def autocast_context_manager(input_list, should_cast): var_read_dtype = _get_var_read_dtype(input_list, should_cast) return ops.get_default_graph()._enable_auto_casting_variables( # pylint: disable=protected-access var_read_dtype) - diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 99003294b27..fdda1141fd7 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -521,11 +521,12 @@ class Network(base_layer.Layer): return layer raise ValueError('No such layer: ' + name) - @property - def _unfiltered_updates(self): + def _get_unfiltered_updates(self, check_trainable=True): + if check_trainable and not self.trainable and not self.stateful: + return [] updates = [] for layer in self.layers: - updates += layer._unfiltered_updates + updates += layer._get_unfiltered_updates(check_trainable=check_trainable) updates += list(self._updates) return updates @@ -605,10 +606,8 @@ class Network(base_layer.Layer): Returns: A list of update ops. """ - if not self.trainable and not self.stateful: - return [] - updates = self._unfiltered_updates + updates = self._get_unfiltered_updates(check_trainable=True) # `updates` might contain irrelevant updates, so it needs to be filtered # with respect to inputs the model has been called on. diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 820a99b4463..56c787b97aa 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -2612,7 +2612,7 @@ class Model(network.Network): 'However we received `validation_data=%s`' % validation_data) return val_x, val_y, val_sample_weight - @trackable.no_automatic_dependency_tracking + # TODO(omalleyt): Consider changing to a more descriptive function name. def _set_inputs(self, inputs, outputs=None, training=None): """Set model's input and output specs based on the input data received. @@ -2639,6 +2639,22 @@ class Model(network.Network): ValueError: If dict inputs are passed to a Sequential Model where the first layer isn't FeatureLayer. """ + inputs = self._set_input_attrs(inputs) + + if outputs is None: + kwargs = {'training': training} if self._expects_training_arg else {} + try: + outputs = self(inputs, **kwargs) + except NotImplementedError: + # This Model or a submodel is dynamic and hasn't overridden + # `compute_output_shape`. + outputs = None + + self._set_output_attrs(outputs) + + @trackable.no_automatic_dependency_tracking + def _set_input_attrs(self, inputs): + """Sets attributes related to the inputs of the Model.""" if self.inputs: raise ValueError('Model inputs are already set.') @@ -2675,33 +2691,11 @@ class Model(network.Network): self._feed_inputs.append(v) self._feed_input_shapes.append(K.int_shape(v)) - # TODO(fchollet): consider calling `_maybe_build` before calling the model. - if outputs is None: - if not self._dynamic: - # The network may include dynamic layers but its `call` - # itself isn't dynamic. - # Obtain symbolic outputs by calling the model. - with K.get_graph().as_default(): - contains_symbolic_tensors = getattr( - self, '_contains_symbolic_tensors', False) - if self._expects_training_arg: - outputs = self.call(inputs, training=training) - else: - outputs = self.call(inputs) - # Reset to the previously saved value. If `call()` had `add_metric` - # or `add_loss`, then `_contains_symbolic_tensors` will have been set - # to True since we are not in `__call__` context. Hence we are - # resetting to the old value here. - self._contains_symbolic_tensors = contains_symbolic_tensors - else: - # Case: network's `call` is dynamic. - try: - outputs = self._symbolic_call(inputs) - except NotImplementedError: - # Static shape inference was not implemented for this dynamic net. - # Do not specify symbolic outputs. - outputs = None + return inputs + @trackable.no_automatic_dependency_tracking + def _set_output_attrs(self, outputs): + """Sets attributes related to the outputs of the Model.""" outputs = nest.flatten(outputs) self.outputs = outputs self.output_names = training_utils.generic_output_names(outputs) diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py index 5220f4e28f4..768b8e4dd3b 100644 --- a/tensorflow/python/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/model_subclassing_test.py @@ -187,8 +187,8 @@ def get_nested_model_3(input_dim, num_classes): return keras.Model(inputs, outputs, name='nested_model_3') -@test_util.run_all_in_graph_and_eager_modes -class ModelSubclassingTest(test.TestCase): +@keras_parameterized.run_all_keras_modes +class ModelSubclassingTest(keras_parameterized.TestCase): def test_custom_build(self): class DummyModel(keras.Model): @@ -210,6 +210,26 @@ class ModelSubclassingTest(test.TestCase): self.assertTrue(test_model.uses_custom_build, 'Model should use user ' 'defined build when called.') + def test_custom_build_with_fit(self): + + class DummyModel(keras.Model): + + def __init__(self): + super(DummyModel, self).__init__() + self.layer1 = keras.layers.Dense(10, activation='relu') + + def build(self, input_shape): + self.layer2 = keras.layers.Dense(1, activation='relu') + + def call(self, inputs): + return self.layer2(self.layer1(inputs)) + + model = DummyModel() + model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) + model.fit(np.ones((10, 10)), np.ones((10, 1)), batch_size=2, epochs=2) + self.assertLen(model.layers, 2) + self.assertLen(model.trainable_variables, 4) + def test_invalid_input_shape_build(self): num_classes = 2 input_dim = 50