diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 955204fd96b..6281005087d 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -104,25 +104,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector): We recommend that descendants of `Layer` implement the following methods: - * `__init__()`: Defines custom layer attributes, and creates layer state - variables that do not depend on input shapes, using `add_weight()`. - * `build(self, input_shape)`: This method can be used to create weights that - depend on the shape(s) of the input(s), using `add_weight()`. `__call__()` - will automatically build the layer (if it has not been built yet) by - calling `build()`. - * `call(self, *args, **kwargs)`: Called in `__call__` after making sure - `build()` has been called. `call()` performs the logic of applying the - layer to the input tensors (which should be passed in as argument). - Two reserved keyword arguments you can optionally use in `call()` are: - - `training` (boolean, whether the call is in - inference mode or training mode) - - `mask` (boolean tensor encoding masked timesteps in the input, used - in RNN layers) - * `get_config(self)`: Returns a dictionary containing the configuration used - to initialize this layer. If the keys differ from the arguments - in `__init__`, then override `from_config(self)` as well. - This method is used when saving - the layer or a model that contains this layer. + * `__init__()`: Save configuration in member variables + * `build()`: Called once from `__call__`, when we know the shapes of inputs + and `dtype`. Should have the calls to `add_weight()`, and then + call the super's `build()` (which sets `self.built = True`, which is + nice in case the user wants to call `build()` manually before the + first `__call__`). + * `call()`: Called in `__call__` after making sure `build()` has been called + once. Should actually perform the logic of applying the layer to the + input tensors (which should be passed in as the first argument). Arguments: trainable: Boolean, whether the layer's variables should be trainable. @@ -205,10 +195,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # Indicates whether `build` needs to be called upon layer call, to create # the layer's weights. self.built = False - # Record the build input shape for loading purposes. - # TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is - # submitted. - self._build_input_shape = None # Provides information about which inputs are compatible with the layer. self._input_spec = None self.supports_masking = False @@ -290,7 +276,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector): `TensorShape` if the layer expects a list of inputs (one instance per input). """ - self._build_input_shape = input_shape self.built = True @doc_controls.for_subclass_implementers @@ -2215,11 +2200,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # to avoid creating symbolic Tensors that will later pollute any eager # operations. with tf_utils.maybe_init_scope(self): - self.build(input_shapes) # pylint:disable=not-callable - # We must set also ensure that the layer is marked as built, and the build - # shape is stored since user defined build functions may not be calling - # `super.build()` - Layer.build(self, input_shapes) + self.build(input_shapes) + # We must set self.built since user defined build functions are not + # constrained to set self.built. + self.built = True # Optionally load weight values specified at layer instantiation. if self._initial_weights is not None: diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index e6efd23ad1c..31193fe8b57 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -630,25 +630,6 @@ class BaseLayerTest(keras_parameterized.TestCase): out = self.evaluate(layer(x=x, y=y)) self.assertAllClose(out, 2 * np.ones((10, 1))) - def test_build_input_shape(self): - class CustomLayer(keras.layers.Layer): - - def build(self, input_shape): - self.add_weight('w', shape=input_shape[1:]) - super(CustomLayer, self).build(input_shape) - - layer = CustomLayer() - self.assertFalse(layer.built) - - layer.build([None, 1, 2, 3]) - self.assertTrue(layer.built) - self.assertEqual([None, 1, 2, 3], layer._build_input_shape) - - layer = CustomLayer() - layer(keras.Input((3,))) - self.assertTrue(layer.built) - self.assertEqual([None, 3], layer._build_input_shape.as_list()) - class SymbolicSupportTest(test.TestCase): diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index deb3bd27928..85699ff14df 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -244,7 +244,6 @@ class Network(base_layer.Layer): # A Network does not create weights of its own, thus it is already # built. self.built = True - self._build_input_shape = nest.map_structure(lambda x: x.shape, inputs) self._compute_output_and_mask_jointly = True self._is_graph_network = True # `_expects_training_arg` is True since the `training` argument is always @@ -357,7 +356,6 @@ class Network(base_layer.Layer): self.outputs = [] self.inputs = [] self.built = False - self._build_input_shape = None @property @trackable_layer_utils.cache_recursive_attribute('dynamic') @@ -621,7 +619,7 @@ class Network(base_layer.Layer): on real tensor data. """ if self._is_graph_network: - super(Network, self).build(input_shape) + self.built = True return # If subclass network @@ -686,7 +684,7 @@ class Network(base_layer.Layer): 'model, `call` your model on real tensor data (of ' 'the correct dtype).') - super(Network, self).build(input_shape) + self.built = True def call(self, inputs, training=None, mask=None): """Calls the model on new inputs. diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index 61c0949f7f9..ca9662e78e6 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -110,6 +110,7 @@ class Sequential(training.Model): """ super(Sequential, self).__init__(name=name, autocast=False) self.supports_masking = True + self._build_input_shape = None self._compute_output_and_mask_jointly = True self._layer_call_argspecs = {} @@ -262,7 +263,8 @@ class Sequential(training.Model): if input_shape is None: raise ValueError('You must provide an `input_shape` argument.') input_shape = tuple(input_shape) - super(Sequential, self).build(input_shape) + self._build_input_shape = input_shape + super(Sequential, self).build(input_shape) self.built = True def call(self, inputs, training=None, mask=None): # pylint: disable=redefined-outer-name