Unifies the behaviors of "regular Sequential" (starts with an Input, wraps a Functional model) and "deferred Sequential" models (only gets built when it sees its input data for the first time).

Before, the following two models had the following behavior differences:

```
regular_sequential = Sequential([
  Dense(2, activation='relu', input_shape=(3,)),
  Dense(4),
])
deferred_sequential = Sequential([
  Dense(2, activation='relu'),
  Dense(4),
])
deferred_sequential(tf.zeros((1, 3)))  # Builds the deferred sequential
```

1) The regular sequential is inspectable: its `.summary()` displays intermediate output shapes. The deferred sequential does not display output shapes.
2) The regular sequential has `inputs` and `outputs` attributes, and its intermediate layers can be used to do feature extraction (see example below), etc. The deferred sequential can't do this.

Feature extraction example:

```
model = Sequential(...)
extractor = keras.Model(inputs=model.inputs,
                        outputs=[layer.output for layer in model.layers])
features = extractor(data)
```

After this CL, the two models behave exactly the same once the deferred sequential has been built (whether by `__call__`ing it, by calling `fit`/`evaluate`/`predict`, or by calling `build` directly). The input shape used is the most restrictive shape compatible with all shapes previously seen by the model (i.e. the set of invariants among all shapes).

The behavior unification is not applied for TF V1, since we don't want to disrupt legacy behaviors and don't want to add new features in V1.

Note that the deferred Sequential remains different in the following cases:
- When a deferred Sequential is called with inputs of different ranks. This is impossible to express in the Functional API (of which "regular Sequential" is a wrapper). However, this is almost certainly not something that anyone is doing.
- When a deferred Sequential starts with a layer that takes multiple inputs. At this time this is something that regular Sequential models do not support. This is an invalid use case of Sequential (which should be single-input and single-output), which unfortunately some users have come to rely on. We may choose to enable it in the future (since it can be expressed with the Functional API). When we enable it we could unify the deferred Sequential behavior in this case.
- When a deferred Sequential contains a non-autographable layer that isn't marked as dynamic, or that is marked as dynamic but does not support shape inference. In that case no Functional model can be built.

PiperOrigin-RevId: 306151882
Change-Id: I4aa881af254ee845f771e375933deae664c80354
This commit is contained in:
Francois Chollet 2020-04-12 16:00:50 -07:00 committed by TensorFlower Gardener
parent dd519b931a
commit 1227a9446d
13 changed files with 414 additions and 53 deletions

View File

@ -589,3 +589,17 @@ tf_py_test(
"@absl_py//absl/testing:parameterized",
],
)
# Small Python 3 test target that runs deferred_sequential_test.py.
tf_py_test(
name = "deferred_sequential_test",
size = "small",
srcs = ["deferred_sequential_test.py"],
python_version = "PY3",
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/keras",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -0,0 +1,216 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests specific to deferred-build `Sequential` models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import numpy as np
from tensorflow.python import keras
from tensorflow.python.compat import v2_compat
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
try:
import h5py # pylint:disable=g-import-not-at-top
except ImportError:
h5py = None
class TestDeferredSequential(keras_parameterized.TestCase):
  """Tests for `Sequential` models created without an explicit input shape."""

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_build_behavior(self):
    """The model becomes a graph network after call, `build`, or `fit`."""
    # Test graph network creation after __call__
    model = get_model()
    model(np.random.random((2, 6)))
    self.assertLen(model.weights, 4)
    self.assertTrue(model._is_graph_network)
    self.assertLen(model.inputs, 1)
    self.assertLen(model.outputs, 1)
    self.assertEqual(model.inputs[0].shape.as_list(), [2, 6])
    self.assertEqual(model.outputs[0].shape.as_list(), [2, 2])

    # Test effect of new __call__ with a different shape
    model(np.random.random((3, 6)))
    self.assertLen(model.inputs, 1)
    self.assertLen(model.outputs, 1)
    self.assertEqual(model.inputs[0].shape.as_list(), [None, 6])
    self.assertEqual(model.outputs[0].shape.as_list(), [None, 2])
    model(np.random.random((4, 6)))
    self.assertLen(model.inputs, 1)
    self.assertLen(model.outputs, 1)
    self.assertEqual(model.inputs[0].shape.as_list(), [None, 6])
    self.assertEqual(model.outputs[0].shape.as_list(), [None, 2])

    # Test graph network creation after build
    model = get_model()
    model.build((None, 6))
    self.assertLen(model.weights, 4)
    self.assertTrue(model._is_graph_network)
    self.assertLen(model.inputs, 1)
    self.assertLen(model.outputs, 1)
    self.assertEqual(model.inputs[0].shape.as_list(), [None, 6])
    self.assertEqual(model.outputs[0].shape.as_list(), [None, 2])

    # Test graph network creation after compile/fit
    model = get_model()
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        metrics=[keras.metrics.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(np.zeros((2, 6)), np.zeros((2, 2)))
    self.assertLen(model.weights, 4)
    self.assertTrue(model._is_graph_network)
    self.assertLen(model.inputs, 1)
    self.assertLen(model.outputs, 1)
    # Inconsistency here: with eager `fit`, the model is built with shape
    # (2, 6), but with graph function `fit`, it is built with shape `(None, 6)`.
    # This is likely due to our assumption "the batch size should be dynamic"
    # at the level of `Model`. TODO(fchollet): investigate and resolve.
    # Hence only the last dimension is checked below.
    self.assertEqual(model.inputs[0].shape.as_list()[-1], 6)
    self.assertEqual(model.outputs[0].shape.as_list()[-1], 2)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_add_and_pop(self):
    """`pop` and `add` keep a built deferred model built and graph-backed."""
    model = get_model()
    model.build((None, 6))
    self.assertTrue(model.built)
    self.assertTrue(model._is_graph_network)
    self.assertLen(model.layers, 3)
    self.assertLen(model.weights, 4)

    model.pop()
    self.assertTrue(model.built)
    self.assertTrue(model._is_graph_network)
    self.assertLen(model.layers, 2)
    self.assertLen(model.weights, 2)

    model.add(keras.layers.Dense(2))
    self.assertTrue(model.built)
    self.assertTrue(model._is_graph_network)
    self.assertLen(model.layers, 3)
    self.assertLen(model.weights, 4)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_feature_extraction(self):
    """Layer outputs stay connected to `model.inputs` after a rebuild."""
    # This tests layer connectivity reset when rebuilding
    model = get_model()
    model(np.random.random((3, 6)))  # First build
    model(np.random.random((4, 6)))  # Triggers a rebuild
    # Classic feature extractor pattern
    extractor = keras.Model(inputs=model.inputs,
                            outputs=[layer.output for layer in model.layers])
    # Check that inputs and outputs are connected
    _ = extractor(np.random.random((4, 6)))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_saving_savedmodel(self):
    """Layer names and weights survive a save/load round trip to a directory."""
    model = get_model()
    model(np.random.random((3, 6)))  # Build model

    path = os.path.join(self.get_temp_dir(), 'model_path')
    model.save(path)
    new_model = keras.models.load_model(path)
    for layer1, layer2 in zip(model._layers, new_model._layers):
      self.assertEqual(layer1.name, layer2.name)
      for w1, w2 in zip(layer1.weights, layer2.weights):
        self.assertAllClose(w1, w2)

  @unittest.skipIf(h5py is None, 'Test requires h5py')
  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_saving_h5(self):
    """Layer names and weights survive a save/load round trip to an .h5 file."""
    model = get_model()
    model(np.random.random((3, 6)))  # Build model

    # The save path is computed once; the original assigned it twice.
    path = os.path.join(self.get_temp_dir(), 'model_path.h5')
    model.save(path)
    new_model = keras.models.load_model(path)
    for layer1, layer2 in zip(model._layers, new_model._layers):
      self.assertEqual(layer1.name, layer2.name)
      for w1, w2 in zip(layer1.weights, layer2.weights):
        self.assertAllClose(w1, w2)

  @keras_parameterized.run_all_keras_modes
  def test_shared_layer(self):
    """A layer shared between Sequential models can be called in each."""
    # This tests that preexisting layer connectivity is preserved
    # when auto-building graph networks
    shared_layer = keras.layers.Dense(2)
    m1 = keras.Sequential([shared_layer])
    m1(np.random.random((3, 6)))
    m2 = keras.Sequential([shared_layer])
    m2(np.random.random((3, 6)))

    # Nesting case
    shared_layer = keras.layers.Dense(2)
    m1 = keras.Sequential([shared_layer])
    m2 = keras.Sequential([shared_layer, m1])
    m2(np.random.random((3, 2)))

  @keras_parameterized.run_all_keras_modes
  def test_loss_layer(self):
    """Losses added via `add_loss` are still reported after a rebuild."""

    class LossLayer(keras.layers.Layer):
      """Identity layer that adds the sum of its inputs as a loss."""

      def call(self, inputs):
        self.add_loss(math_ops.reduce_sum(inputs))
        return inputs

    # Test loss layer alone
    model = keras.Sequential([LossLayer()])
    model.compile('rmsprop', run_eagerly=testing_utils.should_run_eagerly())
    loss = model.train_on_batch(np.ones((2, 2)))
    self.assertAllClose(loss, 4.)
    model(np.random.random((4, 2)))  # Triggers a rebuild
    loss = model.train_on_batch(np.ones((1, 2)))
    self.assertAllClose(loss, 2.)

    # Test loss layer combined with another layer
    model = keras.Sequential([
        keras.layers.Dense(1, kernel_initializer='ones'),
        LossLayer()])
    model.compile('rmsprop', run_eagerly=testing_utils.should_run_eagerly())
    loss = model.train_on_batch(np.ones((2, 2)))
    self.assertAllClose(loss, 4.)
    model(np.random.random((4, 2)))  # Triggers a rebuild
    loss = model.train_on_batch(np.ones((1, 2)))
    self.assertLess(loss, 2.)

    # Test loss layer combined with external loss.
    # No value is asserted here: this section only checks that training runs
    # before and after the rebuild, so the returned losses are not bound.
    model = keras.Sequential([
        keras.layers.Dense(1, kernel_initializer='ones'),
        LossLayer()])
    model.compile('rmsprop', 'mse',
                  run_eagerly=testing_utils.should_run_eagerly())
    model.train_on_batch(np.ones((2, 2)), np.ones((2, 2)))
    model(np.random.random((4, 2)))  # Triggers a rebuild
    model.train_on_batch(np.ones((1, 2)), np.ones((1, 2)))
def get_model():
  """Return a fresh `Sequential` of Dense(2) -> Dropout(0.3) -> Dense(2).

  No input shape is passed to any layer, so the returned model is a
  deferred-build model.
  """
  sequential = keras.models.Sequential()
  stack = (
      keras.layers.Dense(2, name='first_layer'),
      keras.layers.Dropout(0.3, name='dp'),
      keras.layers.Dense(2, name='last_layer'),
  )
  for layer in stack:
    sequential.add(layer)
  return sequential
# Script entry point: switch on V2 behavior first, then hand over to the
# TensorFlow test runner.
if __name__ == '__main__':
v2_compat.enable_v2_behavior()
test.main()

View File

@ -273,25 +273,23 @@ def Input( # pylint: disable=invalid-name
batch_input_shape = kwargs.pop('batch_input_shape',
kwargs.pop('batch_shape', None))
if shape and batch_input_shape:
if shape is not None and batch_input_shape is not None:
raise ValueError('Only provide the `shape` OR `batch_input_shape` argument '
'to Input, not both at the same time.')
if batch_input_shape is None and shape is None and tensor is None:
raise ValueError('Please provide to Input either a `shape`'
' or a `tensor` argument. Note that '
'`shape` does not include the batch '
'dimension.')
if kwargs:
raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
if batch_input_shape:
shape = batch_input_shape[1:]
input_layer_config.update({'batch_input_shape': batch_input_shape})
else:
input_layer_config.update(
{'batch_size': batch_size, 'input_shape': shape})
if kwargs:
raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
if shape is None and tensor is None:
raise ValueError('Please provide to Input either a `shape`'
' or a `tensor` argument. Note that '
'`shape` does not include the batch '
'dimension.')
input_layer = InputLayer(**input_layer_config)
# Return tensor including `_keras_history`.

View File

@ -207,15 +207,17 @@ class Network(base_layer.Layer):
self.output_names = None
self.input_names = None
self._is_compiled = False
self._saved_model_inputs_spec = None
# This is True for Sequential networks and Functional networks.
self._compute_output_and_mask_jointly = False
if not hasattr(self, 'optimizer'):
# Don't reset optimizer if already set.
self.optimizer = None
# Don't reset compilation if already done. This may occur if calling
# `__init__` (or `_init_graph_network`) on an already-compiled model
# such as a Sequential model. Sequential models may need to rebuild
# themselves after compilation.
self._maybe_create_attribute('_is_compiled', False)
self._maybe_create_attribute('optimizer', None)
self._scope = None # Never used.
self._reuse = None # Never used.

View File

@ -21,6 +21,9 @@ from __future__ import print_function
import copy
from tensorflow.python import tf2
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import layers as layer_module
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import input_layer
@ -60,11 +63,8 @@ class Sequential(training.Model):
>>> # This is identical to the following:
>>> model = tf.keras.Sequential()
>>> model.add(tf.keras.layers.Dense(8, input_dim=16))
>>> # And to the following:
>>> model = tf.keras.Sequential()
>>> model.add(tf.keras.layers.Dense(8, batch_input_shape=(None, 16)))
>>> model.add(tf.keras.Input(shape=(16,)))
>>> model.add(tf.keras.layers.Dense(8))
>>> # Note that you can also omit the `input_shape` argument.
>>> # In that case the model doesn't have any weights until the first call
@ -94,8 +94,8 @@ class Sequential(training.Model):
```python
# Note that when using the delayed-build pattern (no input shape specified),
# the model gets built the first time you call `fit` (or other training and
# evaluation methods).
# the model gets built the first time you call `fit`, `eval`, or `predict`,
# or the first time you call the model on some input data.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(8))
model.add(tf.keras.layers.Dense(1))
@ -117,14 +117,22 @@ class Sequential(training.Model):
self.supports_masking = True
self._compute_output_and_mask_jointly = True
self._auto_track_sub_layers = False
self._inferred_input_shape = None
self._has_explicit_input_shape = False
self._input_dtype = None
self._layer_call_argspecs = {}
self._created_nodes = set()
# Unfortunately some Sequential models using custom layers or FeatureColumn
# layers have multiple inputs. This is fundamentally incompatible with
# most of the Sequential API, and we have to disable a number of features
# for such models.
self._use_legacy_deferred_behavior = False
# Add to the model any layers passed to the constructor.
if layers:
if not isinstance(layers, (list, tuple)):
layers = [layers]
tf_utils.assert_no_legacy_layers(layers)
for layer in layers:
self.add(layer)
@ -209,6 +217,7 @@ class Sequential(training.Model):
self.outputs = outputs
self.inputs = layer_utils.get_source_inputs(self.outputs[0])
self.built = True
self._has_explicit_input_shape = True
elif self.outputs:
# If the model is being built continuously on top of an input layer:
@ -247,12 +256,90 @@ class Sequential(training.Model):
self.outputs = None
self.inputs = None
self.built = False
self._inferred_input_shape = None
self._has_explicit_input_shape = False
elif self._is_graph_network:
self.layers[-1]._outbound_nodes = []
self.outputs = [self.layers[-1].output]
self._init_graph_network(self.inputs, self.outputs, name=self.name)
self.built = True
@trackable.no_automatic_dependency_tracking
def _build_graph_network_for_inferred_shape(self,
input_shape,
input_dtype=None):
"""Builds (or rebuilds) the graph network from an observed input shape.

Does nothing when `input_shape` is None, when the model has no layers, when
TF2 is not enabled or eager execution outside functions is off, when the
model has an explicit input shape, or when the legacy deferred behavior
flag is set.

Otherwise the candidate shape is `input_shape` on first use, or the result
of `relax_input_shape` applied to the previously inferred shape and
`input_shape`. If that candidate is not None and differs from the stored
shape, a new `Input` is created and every layer is called on it in order,
then `_init_graph_network` is invoked with the resulting tensors.

If a layer call or `_init_graph_network` raises, the model is switched to
the legacy deferred behavior instead of propagating the error.

Args:
input_shape: Shape of the data the model is being built for; converted
with `tuple(...)`.
input_dtype: Optional dtype forwarded to the created `Input`.

Raises:
ValueError: If a layer returns anything other than a single output.
"""
if input_shape is None or not self.layers:
return
if not tf2.enabled() or not ops.executing_eagerly_outside_functions():
# This behavior is disabled in V1 or when eager execution is disabled.
return
if (not self._has_explicit_input_shape and
not self._use_legacy_deferred_behavior):
# Determine whether the input shape is novel, i.e. whether the model
# should be rebuilt.
input_shape = tuple(input_shape)
if self._inferred_input_shape is None:
new_shape = input_shape
else:
new_shape = relax_input_shape(self._inferred_input_shape, input_shape)
if (new_shape is not None and new_shape != self._inferred_input_shape):
# A novel shape has been received: we need to rebuild the model.
# In case we are inside a graph function, we step out of it.
with ops.init_scope():
inputs = input_layer.Input(
batch_shape=new_shape,
dtype=input_dtype,
name=self.layers[0].name + '_input')
layer_input = inputs
# Nodes created during this pass; replaces `self._created_nodes` once
# every layer has been called successfully.
created_nodes = set()
for layer in self.layers:
# Clear nodes previously created via this method. This prevents
# node accumulation and ensures that e.g. `layer.output` is
# always connected to `model.inputs`
# (this is important e.g. for the feature extraction use case).
# We don't just do `layer._inbound_nodes = []` in order
# not to break shared layers added to Sequential models (which is
# technically illegal as per the `add()` docstring,
# but wasn't previously disabled).
clear_previously_created_nodes(layer, self._created_nodes)
try:
# Create Functional API connection by calling the current layer
layer_output = layer(layer_input)
except: # pylint:disable=bare-except
# Functional API calls may fail for a number of reasons:
# 1) The layer may be buggy. In this case it will be easier for
# the user to debug if we fail on the first call on concrete data,
# instead of our own call on a symbolic input.
# 2) The layer is dynamic (graph-incompatible) and hasn't
# overridden `compute_output_shape`. In this case, it is
# impossible to build a graph network.
# 3) The layer is otherwise incompatible with the Functional API
# (e.g. this is the case for some probabilistic layers that rely
# on hacks and that do not return tensors).
# In all these cases, we should avoid creating a graph network
# (or we simply can't).
self._use_legacy_deferred_behavior = True
return
if len(nest.flatten(layer_output)) != 1:
raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
# Keep track of nodes just created above
track_nodes_created_by_last_call(layer, created_nodes)
layer_input = layer_output
outputs = layer_output
self._created_nodes = created_nodes
try:
# Initialize a graph Network. This call will never fail for
# a stack of valid Keras layers.
# However some users have layers that are fundamentally incompatible
# with the Functional API, which do not return tensors. In this
# case, we fall back to the legacy deferred behavior.
# TODO(fchollet): consider raising here, as we should not be
# supporting such layers.
self._init_graph_network(inputs, outputs, name=self.name)
except: # pylint:disable=bare-except
self._use_legacy_deferred_behavior = True
# Remember the shape used for this build so that later calls can tell
# whether a newly seen shape is novel.
self._inferred_input_shape = new_shape
@generic_utils.default
def build(self, input_shape=None):
if self._is_graph_network:
@ -260,20 +347,35 @@ class Sequential(training.Model):
else:
if input_shape is None:
raise ValueError('You must provide an `input_shape` argument.')
input_shape = tuple(input_shape)
self._build_input_shape = input_shape
super(Sequential, self).build(input_shape)
self._build_graph_network_for_inferred_shape(input_shape)
if not self.built:
input_shape = tuple(input_shape)
self._build_input_shape = input_shape
super(Sequential, self).build(input_shape)
self.built = True
def call(self, inputs, training=None, mask=None): # pylint: disable=redefined-outer-name
# If applicable, update the static input shape of the model.
if not self._has_explicit_input_shape:
if not tensor_util.is_tensor(inputs):
# This is a Sequential with mutiple inputs. This is technically an
# invalid use case of Sequential, but we tolerate it for backwards
# compatibility.
self._use_legacy_deferred_behavior = True
self._build_input_shape = nest.map_structure(_get_shape_tuple, inputs)
if tf2.enabled():
logging.warning('Layers in a Sequential model should only have a '
'single input tensor, but we receive a %s input: %s'
'\nConsider rewriting this model with the Functional '
'API.' % (type(inputs), inputs))
else:
self._build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype)
if self._is_graph_network:
if not self.built:
self._init_graph_network(self.inputs, self.outputs, name=self.name)
return super(Sequential, self).call(inputs, training=training, mask=mask)
if self._build_input_shape is None:
self._build_input_shape = nest.map_structure(_get_shape_tuple, inputs)
outputs = inputs # handle the corner case where self.layers is empty
for layer in self.layers:
# During each iteration, `inputs` are the inputs to `layer`, and `outputs`
@ -293,7 +395,6 @@ class Sequential(training.Model):
# `outputs` will be the inputs to the next layer.
inputs = outputs
mask = outputs._keras_mask
return outputs
def compute_output_shape(self, input_shape):
@ -419,3 +520,34 @@ def _get_shape_tuple(t):
return tuple(shape.as_list())
return None
return None
def relax_input_shape(shape_1, shape_2):
  """Merge two shapes, keeping only the dimensions on which they agree.

  Args:
    shape_1: A sequence of dimensions, or None.
    shape_2: A sequence of dimensions, or None.

  Returns:
    None if either shape is None or if the two shapes differ in length.
    Otherwise a tuple of the same length in which each entry is the shared
    dimension where both shapes agree and None where they differ.
  """
  if shape_1 is None or shape_2 is None or len(shape_1) != len(shape_2):
    return None
  merged = []
  for dim_a, dim_b in zip(shape_1, shape_2):
    # Compare with `!=`, as the original does, so dimension types with custom
    # comparison operators are treated the same way.
    if dim_a != dim_b:
      merged.append(None)
    else:
      merged.append(dim_a)
  return tuple(merged)
def clear_previously_created_nodes(layer, created_nodes):
  """Drop every node found in `created_nodes` from `layer`'s connectivity.

  Filters `_outbound_nodes` on each layer feeding one of `layer`'s inbound
  nodes, then filters `layer._inbound_nodes` itself. Nodes that are not in
  `created_nodes` are left untouched.

  Args:
    layer: The layer whose inbound nodes are filtered.
    created_nodes: Collection of nodes to remove.
  """

  def _without_created(nodes):
    return [node for node in nodes if node not in created_nodes]

  for inbound_node in layer._inbound_nodes:
    for upstream in nest.flatten(inbound_node.inbound_layers):
      upstream._outbound_nodes = _without_created(upstream._outbound_nodes)
  layer._inbound_nodes = _without_created(layer._inbound_nodes)
def track_nodes_created_by_last_call(layer, created_nodes):
  """Record in `created_nodes` the nodes from the latest call to `layer`.

  Adds the last entry of `layer._inbound_nodes`, plus the last outbound node
  of each layer feeding that entry. Does nothing if `layer` has no inbound
  nodes.

  Args:
    layer: The layer that was just called.
    created_nodes: Set that is updated in place.
  """
  inbound = layer._inbound_nodes
  if inbound:
    latest = inbound[-1]
    created_nodes.add(latest)
    for upstream in nest.flatten(latest.inbound_layers):
      outbound = upstream._outbound_nodes
      if outbound:
        created_nodes.add(outbound[-1])

View File

@ -27,6 +27,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.ops import array_ops
@ -126,7 +127,6 @@ class TestSequential(keras_parameterized.TestCase):
y = np.random.random((batch_size, num_classes))
model.fit(x, y, epochs=1)
self.assertTrue(model.built)
self.assertFalse(model._is_graph_network)
self.assertEqual(len(model.weights), 2 * 2)
@keras_parameterized.run_all_keras_modes
@ -158,7 +158,6 @@ class TestSequential(keras_parameterized.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=steps_per_epoch)
self.assertTrue(model.built)
self.assertEqual(len(model.weights), 2 * 2)
self.assertFalse(model._is_graph_network)
# TODO(kaftan) This test fails w/ run_with_all_keras_modes. File ticket
@parameterized.parameters((True,), (False,))
@ -342,11 +341,16 @@ class TestSequential(keras_parameterized.TestCase):
y = np.random.random((2, 5))
model.fit(x, y, epochs=1)
@keras_parameterized.run_all_keras_modes
def test_variable_names(self):
@test_util.run_v1_only('Behavior changed in V2.')
def test_variable_names_deferred(self):
model = keras.models.Sequential([keras.layers.Dense(3)])
model.add(keras.layers.Dense(2))
model(array_ops.ones([2, 4]))
# Note that for regular sequential models (wrapping graph network),
# the layers' weights are built
# without the model name as prefix (because the Functional API __call__
# reset the name scope). This is fixable, but it would be
# backwards incompatible.
self.assertEqual(
['sequential/dense/kernel:0', 'sequential/dense/bias:0',
'sequential/dense_1/kernel:0', 'sequential/dense_1/bias:0'],
@ -404,7 +408,6 @@ class TestSequential(keras_parameterized.TestCase):
self.assertTrue(model.built)
model.add(keras.layers.Dense(3))
self.assertFalse(model.built)
model.compile('adam', loss='mse')
model.fit(np.random.random((1, 3)), np.random.random((1, 3)))

View File

@ -319,9 +319,9 @@ class LocallyConnectedImplementationModeTest(test.TestCase,
copy_model_weights(model_from=model_2, model_to=model_3)
# Compare outputs at initialization.
out_1 = model_1.call(inputs)
out_2 = model_2.call(inputs)
out_3 = model_3.call(inputs)
out_1 = model_1(inputs)
out_2 = model_2(inputs)
out_3 = model_3(inputs)
self.assertAllCloseAccordingToType(
out_2, out_1, rtol=1e-5, atol=1e-5)
@ -351,9 +351,9 @@ class LocallyConnectedImplementationModeTest(test.TestCase,
shuffle=False)
# Compare outputs after a few training steps.
out_1 = model_1.call(inputs)
out_2 = model_2.call(inputs)
out_3 = model_3.call(inputs)
out_1 = model_1(inputs)
out_2 = model_2(inputs)
out_3 = model_3(inputs)
self.assertAllCloseAccordingToType(
out_2, out_1, atol=2e-4)

View File

@ -519,7 +519,7 @@ class KerasModelTest(keras_parameterized.TestCase):
regularizer=regularizer,
input_shape=(1,))
if use_input_spec:
layer.input_spec = input_spec.InputSpec(shape=(2, 1))
layer.input_spec = input_spec.InputSpec(shape=(None, 1))
model = testing_utils.get_model_from_layers([layer], input_shape=(1,),
input_dtype=dtypes.float16)
if get_config:

View File

@ -544,7 +544,7 @@ class KerasObjectLoader(tf_load.Loader):
config = json_utils.decode(
self._proto.nodes[model_id].user_object.metadata)['config']
if isinstance(model, models_lib.Sequential):
if not isinstance(layers[0], input_layer.InputLayer):
if config['layers'][0]['class_name'] != 'InputLayer':
if 'batch_input_shape' in config['layers'][0]['config']:
batch_input_shape = config['layers'][0]['config']['batch_input_shape']
layers.insert(0, input_layer.InputLayer(

View File

@ -91,8 +91,8 @@ def raise_model_input_error(model):
raise ValueError(
'Model {} cannot be saved because the input shapes have not been '
'set. Usually, input shapes are automatically determined from calling'
' .fit() or .predict(). To manually set the shapes, call '
'model._set_inputs(inputs).'.format(model))
' `.fit()` or `.predict()`. To manually set the shapes, call '
'`model.build(input_shape)`.'.format(model))
def trace_model_call(model, input_signature=None):

View File

@ -157,11 +157,6 @@ class SequentialIntegrationTest(KerasIntegrationTest):
verbose=2)
model = self._save_and_reload_model(model)
# TODO(b/134537740): model.pop doesn't update model outputs properly when
# model.outputs is already defined, so just set to `None` for now.
model.inputs = None
model.outputs = None
model.pop()
model.add(keras.layers.Dense(y_train.shape[-1], activation='softmax'))

View File

@ -1257,7 +1257,6 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
model=deferred_sequential)
status = deferred_sequential_checkpoint.restore(save_path)
deferred_sequential.add(core.Dense(4))
deferred_sequential(constant_op.constant([[1.]]))
deferred_second_dense = core.Dense(5)
deferred_sequential.add(deferred_second_dense)
deferred_sequential(constant_op.constant([[1.]]))

View File

@ -56,7 +56,9 @@ class SerializationTests(test.TestCase):
sequential_round_trip = json.loads(
json.dumps(model, default=serialization.get_json_type))
self.assertEqual(
5, sequential_round_trip["config"]["layers"][1]["config"]["units"])
# Note that `config['layers'][0]` will be an InputLayer in V2
# (but not in V1)
5, sequential_round_trip["config"]["layers"][-1]["config"]["units"])
@test_util.run_in_graph_and_eager_modes
def test_serialize_model(self):