diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index ceb461e1196..cabff2109b1 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -589,3 +589,17 @@ tf_py_test( "@absl_py//absl/testing:parameterized", ], ) + +tf_py_test( + name = "deferred_sequential_test", + size = "small", + srcs = ["deferred_sequential_test.py"], + python_version = "PY3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/keras", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/python/keras/engine/deferred_sequential_test.py b/tensorflow/python/keras/engine/deferred_sequential_test.py new file mode 100644 index 00000000000..06f0aa33d5c --- /dev/null +++ b/tensorflow/python/keras/engine/deferred_sequential_test.py @@ -0,0 +1,216 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests specific to deferred-build `Sequential` models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import unittest +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.compat import v2_compat +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras import testing_utils +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + +try: + import h5py # pylint:disable=g-import-not-at-top +except ImportError: + h5py = None + + +class TestDeferredSequential(keras_parameterized.TestCase): + + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_build_behavior(self): + # Test graph network creation after __call__ + model = get_model() + model(np.random.random((2, 6))) + self.assertLen(model.weights, 4) + self.assertTrue(model._is_graph_network) + self.assertLen(model.inputs, 1) + self.assertLen(model.outputs, 1) + self.assertEqual(model.inputs[0].shape.as_list(), [2, 6]) + self.assertEqual(model.outputs[0].shape.as_list(), [2, 2]) + + # Test effect of new __call__ with a different shape + model(np.random.random((3, 6))) + self.assertLen(model.inputs, 1) + self.assertLen(model.outputs, 1) + self.assertEqual(model.inputs[0].shape.as_list(), [None, 6]) + self.assertEqual(model.outputs[0].shape.as_list(), [None, 2]) + model(np.random.random((4, 6))) + self.assertLen(model.inputs, 1) + self.assertLen(model.outputs, 1) + self.assertEqual(model.inputs[0].shape.as_list(), [None, 6]) + self.assertEqual(model.outputs[0].shape.as_list(), [None, 2]) + + # Test graph network creation after build + model = get_model() + model.build((None, 6)) + self.assertLen(model.weights, 4) + self.assertTrue(model._is_graph_network) + self.assertLen(model.inputs, 1) + self.assertLen(model.outputs, 1) + 
self.assertEqual(model.inputs[0].shape.as_list(), [None, 6]) + self.assertEqual(model.outputs[0].shape.as_list(), [None, 2]) + + # Test graph network creation after compile/fit + model = get_model() + model.compile( + loss='mse', + optimizer='rmsprop', + metrics=[keras.metrics.CategoricalAccuracy()], + run_eagerly=testing_utils.should_run_eagerly()) + model.fit(np.zeros((2, 6)), np.zeros((2, 2))) + self.assertLen(model.weights, 4) + self.assertTrue(model._is_graph_network) + self.assertLen(model.inputs, 1) + self.assertLen(model.outputs, 1) + # Inconsistency here: with eager `fit`, the model is built with shape + # (2, 6), but with graph function `fit`, it is built with shape `(None, 6)`. + # This is likely due to our assumption "the batch size should be dynamic" + # at the level of `Model`. TODO(fchollet): investigate and resolve. + self.assertEqual(model.inputs[0].shape.as_list()[-1], 6) + self.assertEqual(model.outputs[0].shape.as_list()[-1], 2) + + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_add_and_pop(self): + model = get_model() + model.build((None, 6)) + self.assertTrue(model.built) + self.assertTrue(model._is_graph_network) + self.assertLen(model.layers, 3) + self.assertLen(model.weights, 4) + model.pop() + self.assertTrue(model.built) + self.assertTrue(model._is_graph_network) + self.assertLen(model.layers, 2) + self.assertLen(model.weights, 2) + model.add(keras.layers.Dense(2)) + self.assertTrue(model.built) + self.assertTrue(model._is_graph_network) + self.assertLen(model.layers, 3) + self.assertLen(model.weights, 4) + + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_feature_extraction(self): + # This tests layer connectivity reset when rebuilding + model = get_model() + model(np.random.random((3, 6))) # First build + model(np.random.random((4, 6))) # Triggers a rebuild + # Classic feature extractor pattern + extractor = keras.Model(inputs=model.inputs, + outputs=[layer.output for layer in 
model.layers]) + # Check that inputs and outputs are connected + _ = extractor(np.random.random((4, 6))) + + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_saving_savedmodel(self): + model = get_model() + model(np.random.random((3, 6))) # Build model + + path = os.path.join(self.get_temp_dir(), 'model_path') + model.save(path) + new_model = keras.models.load_model(path) + for layer1, layer2 in zip(model._layers, new_model._layers): + self.assertEqual(layer1.name, layer2.name) + for w1, w2 in zip(layer1.weights, layer2.weights): + self.assertAllClose(w1, w2) + + @unittest.skipIf(h5py is None, 'Test requires h5py') + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_saving_h5(self): + path = os.path.join(self.get_temp_dir(), 'model_path.h5') + model = get_model() + model(np.random.random((3, 6))) # Build model + + path = os.path.join(self.get_temp_dir(), 'model_path.h5') + model.save(path) + new_model = keras.models.load_model(path) + for layer1, layer2 in zip(model._layers, new_model._layers): + self.assertEqual(layer1.name, layer2.name) + for w1, w2 in zip(layer1.weights, layer2.weights): + self.assertAllClose(w1, w2) + + @keras_parameterized.run_all_keras_modes + def test_shared_layer(self): + # This tests that preexisting layer connectivity is preserved + # when auto-building graph networks + shared_layer = keras.layers.Dense(2) + m1 = keras.Sequential([shared_layer]) + m1(np.random.random((3, 6))) + m2 = keras.Sequential([shared_layer]) + m2(np.random.random((3, 6))) + # Nesting case + shared_layer = keras.layers.Dense(2) + m1 = keras.Sequential([shared_layer]) + m2 = keras.Sequential([shared_layer, m1]) + m2(np.random.random((3, 2))) + + @keras_parameterized.run_all_keras_modes + def test_loss_layer(self): + class LossLayer(keras.layers.Layer): + + def call(self, inputs): + self.add_loss(math_ops.reduce_sum(inputs)) + return inputs + + # Test loss layer alone + model = keras.Sequential([LossLayer()]) + 
model.compile('rmsprop', run_eagerly=testing_utils.should_run_eagerly()) + loss = model.train_on_batch(np.ones((2, 2))) + self.assertAllClose(loss, 4.) + model(np.random.random((4, 2))) # Triggers a rebuild + loss = model.train_on_batch(np.ones((1, 2))) + self.assertAllClose(loss, 2.) + + # Test loss layer combined with another layer + model = keras.Sequential([ + keras.layers.Dense(1, kernel_initializer='ones'), + LossLayer()]) + model.compile('rmsprop', run_eagerly=testing_utils.should_run_eagerly()) + loss = model.train_on_batch(np.ones((2, 2))) + self.assertAllClose(loss, 4.) + model(np.random.random((4, 2))) # Triggers a rebuild + loss = model.train_on_batch(np.ones((1, 2))) + self.assertLess(loss, 2.) + + # Test loss layer combined with external loss + model = keras.Sequential([ + keras.layers.Dense(1, kernel_initializer='ones'), + LossLayer()]) + model.compile('rmsprop', 'mse', + run_eagerly=testing_utils.should_run_eagerly()) + loss = model.train_on_batch(np.ones((2, 2)), np.ones((2, 2))) + model(np.random.random((4, 2))) # Triggers a rebuild + loss = model.train_on_batch(np.ones((1, 2)), np.ones((1, 2))) + + +def get_model(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, name='first_layer')) + model.add(keras.layers.Dropout(0.3, name='dp')) + model.add(keras.layers.Dense(2, name='last_layer')) + return model + + +if __name__ == '__main__': + v2_compat.enable_v2_behavior() + test.main() diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py index 1231ef6a55b..bcaaa5c46f6 100644 --- a/tensorflow/python/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -273,25 +273,23 @@ def Input( # pylint: disable=invalid-name batch_input_shape = kwargs.pop('batch_input_shape', kwargs.pop('batch_shape', None)) - if shape and batch_input_shape: + if shape is not None and batch_input_shape is not None: raise ValueError('Only provide the `shape` OR `batch_input_shape` 
argument ' 'to Input, not both at the same time.') + if batch_input_shape is None and shape is None and tensor is None: + raise ValueError('Please provide to Input either a `shape`' + ' or a `tensor` argument. Note that ' + '`shape` does not include the batch ' + 'dimension.') + if kwargs: + raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) + if batch_input_shape: shape = batch_input_shape[1:] input_layer_config.update({'batch_input_shape': batch_input_shape}) else: input_layer_config.update( {'batch_size': batch_size, 'input_shape': shape}) - - if kwargs: - raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) - - if shape is None and tensor is None: - raise ValueError('Please provide to Input either a `shape`' - ' or a `tensor` argument. Note that ' - '`shape` does not include the batch ' - 'dimension.') - input_layer = InputLayer(**input_layer_config) # Return tensor including `_keras_history`. diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 8954e30f7ca..d4b88136531 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -207,15 +207,17 @@ class Network(base_layer.Layer): self.output_names = None self.input_names = None - self._is_compiled = False self._saved_model_inputs_spec = None # This is True for Sequential networks and Functional networks. self._compute_output_and_mask_jointly = False - if not hasattr(self, 'optimizer'): - # Don't reset optimizer if already set. - self.optimizer = None + # Don't reset compilation if already done. This may occur if calling + # `__init__` (or `_init_graph_network`) on an already-compiled model + # such as a Sequential model. Sequential models may need to rebuild + # themselves after compilation. + self._maybe_create_attribute('_is_compiled', False) + self._maybe_create_attribute('optimizer', None) self._scope = None # Never used. self._reuse = None # Never used. 
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index 9fb35e21e01..1341c03be0d 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -21,6 +21,9 @@ from __future__ import print_function import copy +from tensorflow.python import tf2 +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.keras import layers as layer_module from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import input_layer @@ -60,11 +63,8 @@ class Sequential(training.Model): >>> # This is identical to the following: >>> model = tf.keras.Sequential() - >>> model.add(tf.keras.layers.Dense(8, input_dim=16)) - - >>> # And to the following: - >>> model = tf.keras.Sequential() - >>> model.add(tf.keras.layers.Dense(8, batch_input_shape=(None, 16))) + >>> model.add(tf.keras.Input(shape=(16,))) + >>> model.add(tf.keras.layers.Dense(8)) >>> # Note that you can also omit the `input_shape` argument. >>> # In that case the model doesn't have any weights until the first call @@ -94,8 +94,8 @@ class Sequential(training.Model): ```python # Note that when using the delayed-build pattern (no input shape specified), - # the model gets built the first time you call `fit` (or other training and - # evaluation methods). + # the model gets built the first time you call `fit`, `evaluate`, or `predict`, + # or the first time you call the model on some input data.
model = tf.keras.Sequential() model.add(tf.keras.layers.Dense(8)) model.add(tf.keras.layers.Dense(1)) @@ -117,14 +117,22 @@ class Sequential(training.Model): self.supports_masking = True self._compute_output_and_mask_jointly = True self._auto_track_sub_layers = False - + self._inferred_input_shape = None + self._has_explicit_input_shape = False + self._input_dtype = None self._layer_call_argspecs = {} + self._created_nodes = set() + + # Unfortunately some Sequential models using custom layers or FeatureColumn + # layers have multiple inputs. This is fundamentally incompatible with + # most of the Sequential API, and we have to disable a number of features + # for such models. + self._use_legacy_deferred_behavior = False # Add to the model any layers passed to the constructor. if layers: if not isinstance(layers, (list, tuple)): layers = [layers] - tf_utils.assert_no_legacy_layers(layers) for layer in layers: self.add(layer) @@ -209,6 +217,7 @@ class Sequential(training.Model): self.outputs = outputs self.inputs = layer_utils.get_source_inputs(self.outputs[0]) self.built = True + self._has_explicit_input_shape = True elif self.outputs: # If the model is being built continuously on top of an input layer: @@ -247,12 +256,90 @@ class Sequential(training.Model): self.outputs = None self.inputs = None self.built = False + self._inferred_input_shape = None + self._has_explicit_input_shape = False elif self._is_graph_network: self.layers[-1]._outbound_nodes = [] self.outputs = [self.layers[-1].output] self._init_graph_network(self.inputs, self.outputs, name=self.name) self.built = True + @trackable.no_automatic_dependency_tracking + def _build_graph_network_for_inferred_shape(self, + input_shape, + input_dtype=None): + if input_shape is None or not self.layers: + return + if not tf2.enabled() or not ops.executing_eagerly_outside_functions(): + # This behavior is disabled in V1 or when eager execution is disabled. 
+ return + if (not self._has_explicit_input_shape and + not self._use_legacy_deferred_behavior): + # Determine whether the input shape is novel, i.e. whether the model + # should be rebuilt. + input_shape = tuple(input_shape) + if self._inferred_input_shape is None: + new_shape = input_shape + else: + new_shape = relax_input_shape(self._inferred_input_shape, input_shape) + if (new_shape is not None and new_shape != self._inferred_input_shape): + # A novel shape has been received: we need to rebuild the model. + # In case we are inside a graph function, we step out of it. + with ops.init_scope(): + inputs = input_layer.Input( + batch_shape=new_shape, + dtype=input_dtype, + name=self.layers[0].name + '_input') + layer_input = inputs + created_nodes = set() + for layer in self.layers: + # Clear nodes previously created via this method. This prevents + # node accumulation and ensures that e.g. `layer.output` is + # always connected to `model.inputs` + # (this is important e.g. for the feature extraction use case). + # We don't just do `layer._inbound_nodes = []` in order + # not to break shared layers added to Sequential models (which is + # technically illegal as per the `add()` docstring, + # but wasn't previously disabled). + clear_previously_created_nodes(layer, self._created_nodes) + try: + # Create Functional API connection by calling the current layer + layer_output = layer(layer_input) + except: # pylint:disable=bare-except + # Functional API calls may fail for a number of reasons: + # 1) The layer may be buggy. In this case it will be easier for + # the user to debug if we fail on the first call on concrete data, + # instead of our own call on a symbolic input. + # 2) The layer is dynamic (graph-incompatible) and hasn't + # overridden `compute_output_shape`. In this case, it is + # impossible to build a graph network. + # 3) The layer is otherwise incompatible with the Functional API + # (e.g. 
this is the case for some probabilistic layers that rely + # on hacks and that do not return tensors). + # In all these cases, we should avoid creating a graph network + # (or we simply can't). + self._use_legacy_deferred_behavior = True + return + if len(nest.flatten(layer_output)) != 1: + raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG) + # Keep track of nodes just created above + track_nodes_created_by_last_call(layer, created_nodes) + layer_input = layer_output + outputs = layer_output + self._created_nodes = created_nodes + try: + # Initialize a graph Network. This call will never fail for + # a stack of valid Keras layers. + # However some users have layers that are fundamentally incompatible + # with the Functional API, which do not return tensors. In this + # case, we fall back to the legacy deferred behavior. + # TODO(fchollet): consider raising here, as we should not be + # supporting such layers. + self._init_graph_network(inputs, outputs, name=self.name) + except: # pylint:disable=bare-except + self._use_legacy_deferred_behavior = True + self._inferred_input_shape = new_shape + @generic_utils.default def build(self, input_shape=None): if self._is_graph_network: @@ -260,20 +347,35 @@ class Sequential(training.Model): else: if input_shape is None: raise ValueError('You must provide an `input_shape` argument.') - input_shape = tuple(input_shape) - self._build_input_shape = input_shape - super(Sequential, self).build(input_shape) + self._build_graph_network_for_inferred_shape(input_shape) + if not self.built: + input_shape = tuple(input_shape) + self._build_input_shape = input_shape + super(Sequential, self).build(input_shape) self.built = True def call(self, inputs, training=None, mask=None): # pylint: disable=redefined-outer-name + # If applicable, update the static input shape of the model. + if not self._has_explicit_input_shape: + if not tensor_util.is_tensor(inputs): + # This is a Sequential with multiple inputs.
This is technically an + # invalid use case of Sequential, but we tolerate it for backwards + # compatibility. + self._use_legacy_deferred_behavior = True + self._build_input_shape = nest.map_structure(_get_shape_tuple, inputs) + if tf2.enabled(): + logging.warning('Layers in a Sequential model should only have a ' + 'single input tensor, but we receive a %s input: %s' + '\nConsider rewriting this model with the Functional ' + 'API.' % (type(inputs), inputs)) + else: + self._build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype) + if self._is_graph_network: if not self.built: self._init_graph_network(self.inputs, self.outputs, name=self.name) return super(Sequential, self).call(inputs, training=training, mask=mask) - if self._build_input_shape is None: - self._build_input_shape = nest.map_structure(_get_shape_tuple, inputs) - outputs = inputs # handle the corner case where self.layers is empty for layer in self.layers: # During each iteration, `inputs` are the inputs to `layer`, and `outputs` @@ -293,7 +395,6 @@ class Sequential(training.Model): # `outputs` will be the inputs to the next layer. 
inputs = outputs mask = outputs._keras_mask - return outputs def compute_output_shape(self, input_shape): @@ -419,3 +520,34 @@ def _get_shape_tuple(t): return tuple(shape.as_list()) return None return None + + +def relax_input_shape(shape_1, shape_2): + if shape_1 is None or shape_2 is None: + return None + if len(shape_1) != len(shape_2): + return None + return tuple(None if d1 != d2 else d1 for d1, d2 in zip(shape_1, shape_2)) + + +def clear_previously_created_nodes(layer, created_nodes): + """Remove nodes from `created_nodes` from the layer's inbound_nodes.""" + for node in layer._inbound_nodes: + prev_layers = node.inbound_layers + for prev_layer in nest.flatten(prev_layers): + prev_layer._outbound_nodes = [ + n for n in prev_layer._outbound_nodes + if n not in created_nodes] + layer._inbound_nodes = [ + n for n in layer._inbound_nodes if n not in created_nodes] + + +def track_nodes_created_by_last_call(layer, created_nodes): + """Adds to `created_nodes` the nodes created by the last call to `layer`.""" + if not layer._inbound_nodes: + return + created_nodes.add(layer._inbound_nodes[-1]) + prev_layers = layer._inbound_nodes[-1].inbound_layers + for prev_layer in nest.flatten(prev_layers): + if prev_layer._outbound_nodes: + created_nodes.add(prev_layer._outbound_nodes[-1]) diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index 440388f5453..c65ac094663 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -27,6 +27,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.ops import array_ops @@ -126,7 +127,6 @@ class 
TestSequential(keras_parameterized.TestCase): y = np.random.random((batch_size, num_classes)) model.fit(x, y, epochs=1) self.assertTrue(model.built) - self.assertFalse(model._is_graph_network) self.assertEqual(len(model.weights), 2 * 2) @keras_parameterized.run_all_keras_modes @@ -158,7 +158,6 @@ class TestSequential(keras_parameterized.TestCase): model.fit(dataset, epochs=1, steps_per_epoch=steps_per_epoch) self.assertTrue(model.built) self.assertEqual(len(model.weights), 2 * 2) - self.assertFalse(model._is_graph_network) # TODO(kaftan) This test fails w/ run_with_all_keras_modes. File ticket @parameterized.parameters((True,), (False,)) @@ -342,11 +341,16 @@ class TestSequential(keras_parameterized.TestCase): y = np.random.random((2, 5)) model.fit(x, y, epochs=1) - @keras_parameterized.run_all_keras_modes - def test_variable_names(self): + @test_util.run_v1_only('Behavior changed in V2.') + def test_variable_names_deferred(self): model = keras.models.Sequential([keras.layers.Dense(3)]) model.add(keras.layers.Dense(2)) model(array_ops.ones([2, 4])) + # Note that for regular sequential models (wrapping graph network), + # the layers' weights are built + # without the model name as prefix (because the Functional API __call__ + # resets the name scope). This is fixable, but it would be + # backwards incompatible.
self.assertEqual( ['sequential/dense/kernel:0', 'sequential/dense/bias:0', 'sequential/dense_1/kernel:0', 'sequential/dense_1/bias:0'], @@ -404,7 +408,6 @@ class TestSequential(keras_parameterized.TestCase): self.assertTrue(model.built) model.add(keras.layers.Dense(3)) - self.assertFalse(model.built) model.compile('adam', loss='mse') model.fit(np.random.random((1, 3)), np.random.random((1, 3))) diff --git a/tensorflow/python/keras/layers/local_test.py b/tensorflow/python/keras/layers/local_test.py index 52aaffb8ef3..78611317972 100644 --- a/tensorflow/python/keras/layers/local_test.py +++ b/tensorflow/python/keras/layers/local_test.py @@ -319,9 +319,9 @@ class LocallyConnectedImplementationModeTest(test.TestCase, copy_model_weights(model_from=model_2, model_to=model_3) # Compare outputs at initialization. - out_1 = model_1.call(inputs) - out_2 = model_2.call(inputs) - out_3 = model_3.call(inputs) + out_1 = model_1(inputs) + out_2 = model_2(inputs) + out_3 = model_3(inputs) self.assertAllCloseAccordingToType( out_2, out_1, rtol=1e-5, atol=1e-5) @@ -351,9 +351,9 @@ class LocallyConnectedImplementationModeTest(test.TestCase, shuffle=False) # Compare outputs after a few training steps. 
- out_1 = model_1.call(inputs) - out_2 = model_2.call(inputs) - out_3 = model_3.call(inputs) + out_1 = model_1(inputs) + out_2 = model_2(inputs) + out_3 = model_3(inputs) self.assertAllCloseAccordingToType( out_2, out_1, atol=2e-4) diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py index 3711a3d43c0..a27be08deb2 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py @@ -519,7 +519,7 @@ class KerasModelTest(keras_parameterized.TestCase): regularizer=regularizer, input_shape=(1,)) if use_input_spec: - layer.input_spec = input_spec.InputSpec(shape=(2, 1)) + layer.input_spec = input_spec.InputSpec(shape=(None, 1)) model = testing_utils.get_model_from_layers([layer], input_shape=(1,), input_dtype=dtypes.float16) if get_config: diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py index eb731e2d9f8..5ffeb0671a1 100644 --- a/tensorflow/python/keras/saving/saved_model/load.py +++ b/tensorflow/python/keras/saving/saved_model/load.py @@ -544,7 +544,7 @@ class KerasObjectLoader(tf_load.Loader): config = json_utils.decode( self._proto.nodes[model_id].user_object.metadata)['config'] if isinstance(model, models_lib.Sequential): - if not isinstance(layers[0], input_layer.InputLayer): + if config['layers'][0]['class_name'] != 'InputLayer': if 'batch_input_shape' in config['layers'][0]['config']: batch_input_shape = config['layers'][0]['config']['batch_input_shape'] layers.insert(0, input_layer.InputLayer( diff --git a/tensorflow/python/keras/saving/saving_utils.py b/tensorflow/python/keras/saving/saving_utils.py index 90fcff89249..f0465146a19 100644 --- a/tensorflow/python/keras/saving/saving_utils.py +++ b/tensorflow/python/keras/saving/saving_utils.py @@ -91,8 +91,8 @@ def raise_model_input_error(model): raise ValueError( 
'Model {} cannot be saved because the input shapes have not been ' 'set. Usually, input shapes are automatically determined from calling' - ' .fit() or .predict(). To manually set the shapes, call ' - 'model._set_inputs(inputs).'.format(model)) + ' `.fit()` or `.predict()`. To manually set the shapes, call ' + '`model.build(input_shape)`.'.format(model)) def trace_model_call(model, input_signature=None): diff --git a/tensorflow/python/keras/tests/integration_test.py b/tensorflow/python/keras/tests/integration_test.py index 5493859f3ad..dbb6f75f031 100644 --- a/tensorflow/python/keras/tests/integration_test.py +++ b/tensorflow/python/keras/tests/integration_test.py @@ -157,11 +157,6 @@ class SequentialIntegrationTest(KerasIntegrationTest): verbose=2) model = self._save_and_reload_model(model) - # TODO(b/134537740): model.pop doesn't update model outputs properly when - # model.outputs is already defined, so just set to `None` for now. - model.inputs = None - model.outputs = None - model.pop() model.add(keras.layers.Dense(y_train.shape[-1], activation='softmax')) diff --git a/tensorflow/python/training/tracking/util_test.py b/tensorflow/python/training/tracking/util_test.py index e63baa60003..a69a34c1038 100644 --- a/tensorflow/python/training/tracking/util_test.py +++ b/tensorflow/python/training/tracking/util_test.py @@ -1257,7 +1257,6 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase): model=deferred_sequential) status = deferred_sequential_checkpoint.restore(save_path) deferred_sequential.add(core.Dense(4)) - deferred_sequential(constant_op.constant([[1.]])) deferred_second_dense = core.Dense(5) deferred_sequential.add(deferred_second_dense) deferred_sequential(constant_op.constant([[1.]])) diff --git a/tensorflow/python/util/serialization_test.py b/tensorflow/python/util/serialization_test.py index 6df7533831b..a66dd11ba99 100644 --- a/tensorflow/python/util/serialization_test.py +++ b/tensorflow/python/util/serialization_test.py @@ -56,7 +56,9 
@@ class SerializationTests(test.TestCase): sequential_round_trip = json.loads( json.dumps(model, default=serialization.get_json_type)) self.assertEqual( - 5, sequential_round_trip["config"]["layers"][1]["config"]["units"]) + # Note that `config['layers'][0]` will be an InputLayer in V2 + # (but not in V1) + 5, sequential_round_trip["config"]["layers"][-1]["config"]["units"]) @test_util.run_in_graph_and_eager_modes def test_serialize_model(self):