Enable input spec checking for Functional models.
PiperOrigin-RevId: 324625967 Change-Id: Ide0a8cb4d6d7614f86f22088a5ef95d72636c54e
parent 856dc4f7b6 · commit 59dc165d26
@@ -25,7 +25,14 @@
 * Code that requires very tricky shape manipulation via converted op layers in order to work, where the Keras symbolic shape inference proves insufficient.
 * Code that tries manually walking a `tf.keras.Model` layer by layer and assumes layers only ever have one positional argument. This assumption doesn't hold true before TF 2.4 either, but is more likely to cause issues now.
 * Code that manually enters `keras.backend.get_graph()` before building a functional model. This is no longer needed.
+* Start enforcing input shape assumptions when calling Functional API Keras
+  models. This may potentially break some users, in case there is a mismatch
+  between the shape used when creating `Input` objects in a Functional model,
+  and the shape of the data passed to that model. You can fix this mismatch by
+  either calling the model with correctly-shaped data, or by relaxing `Input`
+  shape assumptions (note that you can pass shapes with `None` entries for axes
+  that are meant to be dynamic). You can also disable the input checking
+  entirely by setting `model.input_spec = None`.

 ## Known Caveats

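The note above translates into behavior like the following; a minimal sketch assuming TF 2.4-style `tf.keras` (model and data names are illustrative):

```python
import numpy as np
import tensorflow as tf

# A Functional model whose Input declares a feature dimension of 10.
inputs = tf.keras.Input(shape=(10,))
outputs = tf.keras.layers.Dense(4)(inputs)
model = tf.keras.Model(inputs, outputs)

model(np.zeros((2, 10)))      # OK: matches the declared Input shape.
try:
  model(np.zeros((2, 11)))    # Mismatch: now raises instead of passing through.
except ValueError as e:
  print(e)                    # Message mentions 'expected shape=(None, 10)'.

# Fix 1: declare the axis as dynamic with `None` (layers must tolerate it).
dyn_in = tf.keras.Input(shape=(None,))
dyn_model = tf.keras.Model(dyn_in, tf.keras.layers.Activation('relu')(dyn_in))
dyn_model(np.zeros((2, 11)))  # OK: axis 1 is left unchecked.

# Fix 2: opt out of model-level input checking entirely.
model.input_spec = None       # Individual layers may still run their own checks.
```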
@@ -1239,7 +1239,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
     dataset = dataset.repeat(100)
     dataset = dataset.batch(10)

-    with self.assertRaisesRegex(ValueError, 'incompatible with the layer'):
+    with self.assertRaisesRegex(ValueError, 'is incompatible with'):
       model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)

   @combinations.generate(
@@ -970,12 +970,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       if self._autocast:
         inputs = self._maybe_cast_inputs(inputs, input_list)

+      input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
       if eager:
         call_fn = self.call
         name_scope = self._name
       else:
-        input_spec.assert_input_compatibility(self.input_spec, inputs,
-                                              self.name)
         name_scope = self._name_scope()  # Avoid autoincrementing.
         call_fn = self._autographed_call()

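Hoisting `assert_input_compatibility` above the eager/graph branch means the spec check now runs for eager calls as well, not only symbolic ones. A hedged sketch of the effect with a toy custom layer (class name is illustrative):

```python
import numpy as np
import tensorflow as tf

class RankTwoLayer(tf.keras.layers.Layer):
  """Toy layer that declares a rank-2 input requirement via InputSpec."""

  def __init__(self):
    super(RankTwoLayer, self).__init__()
    self.input_spec = tf.keras.layers.InputSpec(ndim=2)

  def call(self, inputs):
    return inputs

layer = RankTwoLayer()
layer(np.zeros((2, 3)))  # OK: rank 2.
# layer(np.zeros((2,)))  # ValueError: expected ndim=2 -- raised eagerly too.
```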
@@ -33,6 +33,7 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import input_layer as input_layer_module
+from tensorflow.python.keras.engine import input_spec
 from tensorflow.python.keras.engine import keras_tensor
 from tensorflow.python.keras.engine import node as node_module
 from tensorflow.python.keras.engine import training as training_lib
@@ -248,6 +249,32 @@ class Functional(training_lib.Model):
     """
     return nest.map_structure(backend.int_shape, self.input)

+  @property
+  def input_spec(self):
+    if hasattr(self, '_manual_input_spec'):
+      return self._manual_input_spec
+    if (isinstance(self._nested_inputs, (dict, list, tuple)) and
+        len(self._nested_inputs) != len(self.inputs)):
+      # Case where we have a nested structure.
+      # In such a case we can't safely run any checks.
+      return None
+    if isinstance(self._nested_inputs, dict):
+      # Case where `_nested_inputs` is a plain dict of Inputs.
+      names = sorted(self._nested_inputs.keys())
+      return [input_spec.InputSpec(
+          shape=shape_with_no_batch_size(self._nested_inputs[name]),
+          allow_last_axis_squeeze=True, name=name) for name in names]
+    else:
+      # Single input, or list / tuple of inputs.
+      # The data may be passed as a dict keyed by input name.
+      return [input_spec.InputSpec(
+          shape=shape_with_no_batch_size(x), allow_last_axis_squeeze=True,
+          name=x._keras_history.layer.name) for x in self.inputs]
+
+  @input_spec.setter
+  def input_spec(self, value):
+    self._manual_input_spec = value
+
   @property
   def output(self):
     """Retrieves the output tensor(s) of a layer.
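With this property in place, a Functional model derives one `InputSpec` per input from its `Input` layers, with the batch axis relaxed to `None` and `allow_last_axis_squeeze=True`, unless a manual spec was assigned. A quick inspection sketch (the exact repr may differ):

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(10,), name='features')
model = tf.keras.Model(inputs, tf.keras.layers.Dense(4)(inputs))

print(model.input_spec)
# Roughly: [InputSpec(shape=(None, 10), allow_last_axis_squeeze=True,
#                     name='features')]

model.input_spec = None  # The setter stores a manual override; None disables
print(model.input_spec)  # the model-level check entirely.
```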
@@ -1312,3 +1339,12 @@ def get_network_config(network, serialize_layer_fn=None):
   model_outputs = tf_utils.convert_inner_node_data(model_outputs)
   config['output_layers'] = model_outputs
   return config
+
+
+def shape_with_no_batch_size(x):
+  if x.shape.rank is None:
+    return None
+  shape = x.shape.as_list()
+  if shape:
+    shape[0] = None
+  return shape
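`shape_with_no_batch_size` just erases the leading (batch) dimension of a tensor's static shape, so the derived specs never pin the batch size. An equivalent standalone sketch:

```python
import tensorflow as tf

x = tf.keras.Input(shape=(28, 28, 1))  # Static shape: (None, 28, 28, 1).
shape = x.shape.as_list()
if shape:
  shape[0] = None  # Leave the batch axis unchecked.
print(shape)       # [None, 28, 28, 1]
```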
@@ -1059,7 +1059,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
     self.assertEqual(history.history['loss'][0], 0.0)

     # Check the output dtype
-    self.assertEqual(model(array_ops.ones(3, 3)).dtype, dtypes.float16)
+    self.assertEqual(model(array_ops.ones((3, 10))).dtype, dtypes.float16)

     model = training_lib.Model.from_config(
         model.get_config(), custom_objects={'Double': Double})
@@ -1075,7 +1075,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
     self.assertEqual(history.history['loss'][0], 0.0)

     # Check the output dtype
-    self.assertEqual(model(array_ops.ones(3, 3)).dtype, dtypes.float16)
+    self.assertEqual(model(array_ops.ones((3, 10))).dtype, dtypes.float16)

   @combinations.generate(combinations.keras_mode_combinations())
   def test_call_kwarg_nonserializable(self):
@@ -1793,8 +1793,8 @@ class NestedNetworkTest(keras_parameterized.TestCase):
     network = functional.Functional.from_config(network.get_config())

     result_tensor = network({
-        'x': array_ops.ones((1, 1), 'float32'),
-        'y': array_ops.ones((1, 1), 'float32')
+        'x1': array_ops.ones((1, 1), 'float32'),
+        'x2': array_ops.ones((1, 1), 'float32')
     })
     result = self.evaluate(result_tensor)
     self.assertAllEqual(result, [[2.]])
@@ -2340,6 +2340,57 @@ class InputsOutputsErrorTest(keras_parameterized.TestCase):
         TypeError, "('Keyword argument not understood:', 'output')"):
       models.Model(inputs=inputs, output=outputs)

+  def test_input_spec(self):
+    if not context.executing_eagerly():
+      return
+    inputs = input_layer_lib.Input((10,))
+    outputs = layers.Dense(10)(inputs)
+    model = models.Model(inputs, outputs)
+    with self.assertRaisesRegex(
+        ValueError, r'.*expected shape=.*'):
+      model(np.zeros((3, 11)))
+
+  def test_input_spec_list_of_inputs(self):
+    if not context.executing_eagerly():
+      return
+    input_1 = input_layer_lib.Input((10,), name='1')
+    input_2 = input_layer_lib.Input((5,), name='2')
+    x = layers.Concatenate()([input_1, input_2])
+    outputs = layers.Dense(10)(x)
+    model = models.Model([input_1, input_2], outputs)
+    with self.assertRaisesRegex(
+        ValueError, r'.*expects 2 input.*'):
+      model(np.zeros((3, 10)))
+    with self.assertRaisesRegex(
+        ValueError, r'.*expects 2 input.*'):
+      model([np.zeros((3, 10)), np.zeros((3, 5)), np.zeros((3, 10))])
+    with self.assertRaisesRegex(
+        ValueError, r'.*expected shape=.*'):
+      model([np.zeros((3, 10)), np.zeros((3, 6))])
+
+    # Test passing data via dict keyed by input name
+    with self.assertRaisesRegex(
+        ValueError, r'Missing data for input.*'):
+      model({'1': np.zeros((3, 10))})
+    with self.assertRaisesRegex(
+        ValueError, r'.*expected shape=.*'):
+      model({'1': np.zeros((3, 10)), '2': np.zeros((3, 6))})
+
+  def test_input_spec_dict(self):
+    if not context.executing_eagerly():
+      return
+    input_1 = input_layer_lib.Input((10,))
+    input_2 = input_layer_lib.Input((5,))
+    x = layers.Concatenate()([input_1, input_2])
+    outputs = layers.Dense(10)(x)
+    model = models.Model({'1': input_1, '2': input_2}, outputs)
+    with self.assertRaisesRegex(
+        ValueError, r'Missing data for input.*'):
+      model({'1': np.zeros((3, 10))})
+    with self.assertRaisesRegex(
+        ValueError, r'.*expected shape=.*'):
+      model({'1': np.zeros((3, 10)), '2': np.zeros((3, 6))})
+

 if __name__ == '__main__':
   test.main()
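The behaviors these tests pin down are also reachable through the public API; a sketch (the tests above use TF-internal modules, the names here are illustrative):

```python
import numpy as np
import tensorflow as tf

a = tf.keras.Input((10,), name='a')
b = tf.keras.Input((5,), name='b')
out = tf.keras.layers.Dense(1)(tf.keras.layers.Concatenate()([a, b]))
model = tf.keras.Model({'a': a, 'b': b}, out)

model({'a': np.zeros((3, 10)), 'b': np.zeros((3, 5))})  # OK.
# model({'a': np.zeros((3, 10))})
#   -> ValueError: Missing data for input ...
# model({'a': np.zeros((3, 10)), 'b': np.zeros((3, 6))})
#   -> ValueError: ... expected shape= ...
```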
@@ -46,12 +46,30 @@ class InputSpec(object):
   Arguments:
     dtype: Expected DataType of the input.
     shape: Shape tuple, expected shape of the input
-      (may include None for unchecked axes).
+      (may include None for unchecked axes). Includes the batch size.
     ndim: Integer, expected rank of the input.
     max_ndim: Integer, maximum rank of the input.
     min_ndim: Integer, minimum rank of the input.
     axes: Dictionary mapping integer axes to
       a specific dimension value.
+    allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as long
+      as the last axis of the input is 1, as well as inputs of rank N-1
+      as long as the last axis of the spec is 1.
+    name: Expected key corresponding to this input when passing data as
+      a dictionary.
+
+  Example:
+
+  ```python
+  class MyLayer(Layer):
+    def __init__(self):
+      super(MyLayer, self).__init__()
+      # The layer will accept inputs with shape (?, 28, 28) & (?, 28, 28, 1)
+      # and raise an appropriate error message otherwise.
+      self.input_spec = InputSpec(
+          shape=(None, 28, 28, 1),
+          allow_last_axis_squeeze=True)
+  ```
   """

   def __init__(self,
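Besides `allow_last_axis_squeeze`, the other new field is `name`, which is what lets a spec be matched against dict-shaped input data. A short sketch of specs like the ones the Functional `input_spec` property now builds (values are illustrative):

```python
import tensorflow as tf

# One spec per model input; `name` is the dictionary key expected for it.
specs = [
    tf.keras.layers.InputSpec(shape=(None, 10), allow_last_axis_squeeze=True,
                              name='a'),
    tf.keras.layers.InputSpec(shape=(None, 5), allow_last_axis_squeeze=True,
                              name='b'),
]
```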
@@ -60,8 +78,15 @@ class InputSpec(object):
                ndim=None,
                max_ndim=None,
                min_ndim=None,
-               axes=None):
+               axes=None,
+               allow_last_axis_squeeze=False,
+               name=None):
     self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
+    shape = tensor_shape.TensorShape(shape)
+    if shape.rank is None:
+      shape = None
+    else:
+      shape = tuple(shape.as_list())
     if shape is not None:
       self.ndim = len(shape)
       self.shape = shape
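Routing `shape` through `TensorShape` in the constructor means a `TensorShape`, list, or tuple all normalize the same way, and an unknown rank collapses to `None`. A small illustration:

```python
import tensorflow as tf

print(tuple(tf.TensorShape([None, 28, 28]).as_list()))  # (None, 28, 28)
print(tf.TensorShape(None).rank)                        # None -> spec.shape = None
```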
@@ -70,6 +95,8 @@ class InputSpec(object):
       self.shape = None
     self.max_ndim = max_ndim
     self.min_ndim = min_ndim
+    self.name = name
+    self.allow_last_axis_squeeze = allow_last_axis_squeeze
     try:
       axes = axes or {}
       self.axes = {int(k): axes[k] for k in axes}
@@ -149,6 +176,21 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
   if not input_spec:
     return

+  input_spec = nest.flatten(input_spec)
+  if isinstance(inputs, dict):
+    # Flatten `inputs` by reference order if input spec names are provided
+    names = [spec.name for spec in input_spec]
+    if all(names):
+      list_inputs = []
+      for name in names:
+        if name not in inputs:
+          raise ValueError('Missing data for input "%s". '
+                           'You passed a data dictionary with keys %s. '
+                           'Expected the following keys: %s' %
+                           (name, list(inputs.keys()), names))
+        list_inputs.append(inputs[name])
+      inputs = list_inputs
+
   inputs = nest.flatten(inputs)
   for x in inputs:
     # Having a shape/dtype is the only commonality of the various tensor-like
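The dict branch simply reorders the data to match the spec order before the per-input checks run; the logic in isolation (a hypothetical standalone helper):

```python
def flatten_by_spec_names(inputs, names):
  """Reorder a dict of inputs into a list following the spec names."""
  for name in names:
    if name not in inputs:
      raise ValueError('Missing data for input "%s". '
                       'You passed a data dictionary with keys %s. '
                       'Expected the following keys: %s'
                       % (name, list(inputs.keys()), names))
  return [inputs[name] for name in names]

print(flatten_by_spec_names({'b': 2, 'a': 1}, ['a', 'b']))  # [1, 2]
```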
@@ -157,81 +199,83 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
     # have a `shape` attribute.
     if not hasattr(x, 'shape'):
       raise TypeError('Inputs to a layer should be tensors. Got: %s' % (x,))
-  input_spec = nest.flatten(input_spec)
+
   if len(inputs) != len(input_spec):
     raise ValueError('Layer ' + layer_name + ' expects ' +
-                     str(len(input_spec)) + ' inputs, '
+                     str(len(input_spec)) + ' input(s), '
                      'but it received ' + str(len(inputs)) +
                      ' input tensors. Inputs received: ' + str(inputs))
   for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
     if spec is None:
       continue

-    if (spec.ndim is not None or
-        spec.min_ndim is not None or
-        spec.max_ndim is not None):
-      if x.shape.ndims is None:
-        raise ValueError('Input ' + str(input_index) + ' of layer ' +
-                         layer_name + ' is incompatible with the layer: '
-                         'its rank is undefined, but the layer requires a '
-                         'defined rank.')
+    shape = tensor_shape.TensorShape(x.shape)
+    if shape.rank is None:
+      return
+
     # Check ndim.
-    if spec.ndim is not None:
-      ndim = x.shape.ndims
+    if spec.ndim is not None and not spec.allow_last_axis_squeeze:
+      ndim = shape.rank
       if ndim != spec.ndim:
         raise ValueError('Input ' + str(input_index) + ' of layer ' +
                          layer_name + ' is incompatible with the layer: '
                          'expected ndim=' + str(spec.ndim) + ', found ndim=' +
                          str(ndim) + '. Full shape received: ' +
-                         str(x.shape.as_list()))
+                         str(tuple(shape)))
     if spec.max_ndim is not None:
-      ndim = x.shape.ndims
+      ndim = x.shape.rank
       if ndim is not None and ndim > spec.max_ndim:
         raise ValueError('Input ' + str(input_index) + ' of layer ' +
                          layer_name + ' is incompatible with the layer: '
                          'expected max_ndim=' + str(spec.max_ndim) +
                          ', found ndim=' + str(ndim))
     if spec.min_ndim is not None:
-      ndim = x.shape.ndims
+      ndim = x.shape.rank
       if ndim is not None and ndim < spec.min_ndim:
         raise ValueError('Input ' + str(input_index) + ' of layer ' +
                          layer_name + ' is incompatible with the layer: '
                          ': expected min_ndim=' + str(spec.min_ndim) +
                          ', found ndim=' + str(ndim) +
                          '. Full shape received: ' +
-                         str(x.shape.as_list()))
+                         str(tuple(shape)))
     # Check dtype.
     if spec.dtype is not None:
-      if x.dtype != spec.dtype:
+      if x.dtype.name != spec.dtype:
         raise ValueError('Input ' + str(input_index) + ' of layer ' +
                          layer_name + ' is incompatible with the layer: '
                          'expected dtype=' + str(spec.dtype) +
                          ', found dtype=' + str(x.dtype))

     # Check specific shape axes.
+    shape_as_list = shape.as_list()
     if spec.axes:
-      shape = x.shape.as_list()
-      if shape is not None:
       for axis, value in spec.axes.items():
         if hasattr(value, 'value'):
           value = value.value
-        if value is not None and shape[int(axis)] not in {value, None}:
+        if value is not None and shape_as_list[int(axis)] not in {value, None}:
           raise ValueError(
               'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
               ' incompatible with the layer: expected axis ' + str(axis) +
               ' of input shape to have value ' + str(value) +
-              ' but received input with shape ' + str(shape))
+              ' but received input with shape ' + display_shape(x.shape))
     # Check shape.
-    if spec.shape is not None:
-      shape = x.shape.as_list()
-      if shape is not None:
-        for spec_dim, dim in zip(spec.shape, shape):
+    if spec.shape is not None and shape.rank is not None:
+      spec_shape = spec.shape
+      if spec.allow_last_axis_squeeze:
+        if shape_as_list and shape_as_list[-1] == 1:
+          shape_as_list = shape_as_list[:-1]
+        if spec_shape and spec_shape[-1] == 1:
+          spec_shape = spec_shape[:-1]
+      for spec_dim, dim in zip(spec_shape, shape_as_list):
         if spec_dim is not None and dim is not None:
           if spec_dim != dim:
             raise ValueError('Input ' + str(input_index) +
                              ' is incompatible with layer ' + layer_name +
                              ': expected shape=' + str(spec.shape) +
-                             ', found shape=' + str(shape))
+                             ', found shape=' + display_shape(x.shape))
+
+
+def display_shape(shape):
+  return str(tuple(shape.as_list()))


 def to_tensor_spec(input_spec, default_dtype=None):
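The last-axis-squeeze rule drops a trailing 1 from either side before the dimension-by-dimension comparison, which is what lets `(32, 28, 28)` data satisfy a `(None, 28, 28, 1)` spec and vice versa. A pure-Python sketch of just that comparison (helper name is illustrative):

```python
def shapes_compatible(spec_shape, shape, allow_last_axis_squeeze=True):
  """Compare shapes the way the updated check does; None matches anything."""
  spec_shape, shape = list(spec_shape), list(shape)
  if allow_last_axis_squeeze:
    if shape and shape[-1] == 1:
      shape = shape[:-1]
    if spec_shape and spec_shape[-1] == 1:
      spec_shape = spec_shape[:-1]
  return all(s is None or d is None or s == d
             for s, d in zip(spec_shape, shape))

print(shapes_compatible((None, 28, 28, 1), (32, 28, 28)))     # True (rank N-1)
print(shapes_compatible((None, 28, 28), (32, 28, 28, 1)))     # True (rank N+1)
print(shapes_compatible((None, 28, 28, 1), (32, 28, 27, 1)))  # False
```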
@@ -495,10 +495,16 @@ class Sequential(functional.Functional):

   @property
   def input_spec(self):
+    if hasattr(self, '_manual_input_spec'):
+      return self._manual_input_spec
     if self.layers and hasattr(self.layers[0], 'input_spec'):
       return self.layers[0].input_spec
     return None

+  @input_spec.setter
+  def input_spec(self, value):
+    self._manual_input_spec = value
+
   @property
   def _trackable_saved_model_saver(self):
     return model_serialization.SequentialSavedModelSaver(self)
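With the setter in place, `Sequential` gets the same escape hatch as Functional models; a sketch:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(10,))])
model.input_spec = None  # Manual override wins over the first layer's spec;
print(model.input_spec)  # None -> model-level checking is skipped.
```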
@@ -324,9 +324,9 @@ class AutoLambdaTest(keras_parameterized.TestCase):
         run_eagerly=testing_utils.should_run_eagerly())

     np_inputs = nest.map_structure(
-        lambda x: np.ones((10,) + tuple(x.shape[1:]), 'float32'), model.inputs)
+        lambda x: np.ones((2,) + tuple(x.shape[1:]), 'float32'), model.inputs)
     np_outputs = nest.map_structure(
-        lambda x: np.ones((10,) + tuple(x.shape[1:]), 'float32'), model.outputs)
+        lambda x: np.ones((2,) + tuple(x.shape[1:]), 'float32'), model.outputs)
     model.fit(np_inputs, np_outputs, batch_size=2)
     model(np_inputs)  # Test calling the model directly on inputs.

|
|||||||
def test_getitem_slice_with_step_only(self):
|
def test_getitem_slice_with_step_only(self):
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
self.skipTest('Complex slicing like this fails in v1')
|
self.skipTest('Complex slicing like this fails in v1')
|
||||||
inp = keras.Input(shape=(4, 3, 8))
|
inp = keras.Input(shape=(8,))
|
||||||
slice_step = keras.Input(shape=(), dtype='int32')
|
slice_step = keras.Input(shape=(), dtype='int32')
|
||||||
|
|
||||||
out = inp[..., ::slice_step[0]]
|
out = inp[..., ::slice_step[0]]
|
||||||
@@ -508,7 +508,7 @@ class AutoLambdaTest(keras_parameterized.TestCase):
   def test_getitem_slice_with_stop_only(self):
     if not context.executing_eagerly():
       self.skipTest('Complex slicing like this fails in v1')
-    inp = keras.Input(shape=(4, 3, 8))
+    inp = keras.Input(shape=(8,))
     slice_stop = keras.Input(shape=(), dtype='int32')

     out = inp[:slice_stop[0]]
@@ -544,7 +544,7 @@ class AutoLambdaTest(keras_parameterized.TestCase):
   def test_getitem_slice_with_stop_and_ellipsis_only(self):
     if not context.executing_eagerly():
       self.skipTest('Complex slicing like this fails in v1')
-    inp = keras.Input(shape=(4, 3, 8))
+    inp = keras.Input(shape=(8,))
     slice_stop = keras.Input(shape=(), dtype='int32')

     out = inp[..., :slice_stop[0]]
@@ -646,14 +646,14 @@ class AutoLambdaTest(keras_parameterized.TestCase):

   def test_numerical_correctness_with_attrs(self):
     x = ops.convert_to_tensor_v2([[1.5, 1.5], [2.5, 3.5]])
-    inputs = keras.Input(shape=(10,))
+    inputs = keras.Input(shape=(2,))
     outputs = math_ops.reduce_mean(inputs, axis=1)
     model = keras.Model(inputs, outputs)
     y = self.evaluate(model(x))
     self.assertAllClose(y, [1.5, 3.])

   def test_numerical_correctness_serialization(self):
-    x = ops.convert_to_tensor_v2([-1., 0., -2., 1.])
+    x = ops.convert_to_tensor_v2([[-1., 0., -2., 1.]])
     inputs = keras.Input(shape=(4,))
     outputs = gen_nn_ops.relu(inputs)
     model1 = keras.Model(inputs, outputs)
@@ -277,11 +277,6 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase):
       def call(self, inputs):
         return inputs

-    if not context.executing_eagerly():
-      layer = CustomerLayer()
-      with self.assertRaisesRegex(ValueError, r'requires a defined rank'):
-        layer.apply(array_ops.placeholder('int32'))
-
     layer = CustomerLayer()
     with self.assertRaisesRegex(ValueError, r'expected ndim=2'):
       layer.apply(constant_op.constant([1]))
@@ -295,29 +290,24 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase):
   @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def testInputSpecMinNdimCheck(self):

-    class CustomerLayer(base_layers.Layer):
+    class CustomLayer(base_layers.Layer):

       def __init__(self):
-        super(CustomerLayer, self).__init__()
+        super(CustomLayer, self).__init__()
         self.input_spec = input_spec.InputSpec(min_ndim=2)

       def call(self, inputs):
         return inputs

-    if not context.executing_eagerly():
-      layer = CustomerLayer()
-      with self.assertRaisesRegex(ValueError, r'requires a defined rank'):
-        layer.apply(array_ops.placeholder('int32'))
-
-    layer = CustomerLayer()
+    layer = CustomLayer()
     with self.assertRaisesRegex(ValueError, r'expected min_ndim=2'):
       layer.apply(constant_op.constant([1]))

     # Works
-    layer = CustomerLayer()
+    layer = CustomLayer()
     layer.apply(constant_op.constant([[1], [2]]))

-    layer = CustomerLayer()
+    layer = CustomLayer()
     layer.apply(constant_op.constant([[[1], [2]]]))

   @combinations.generate(combinations.combine(mode=['graph', 'eager']))
@@ -332,11 +322,6 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase):
       def call(self, inputs):
         return inputs

-    if not context.executing_eagerly():
-      layer = CustomerLayer()
-      with self.assertRaisesRegex(ValueError, r'requires a defined rank'):
-        layer.apply(array_ops.placeholder('int32'))
-
     layer = CustomerLayer()
     with self.assertRaisesRegex(ValueError, r'expected max_ndim=2'):
       layer.apply(constant_op.constant([[[1], [2]]]))
@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "from_config"
@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "from_config"
@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
  }
   member_method {
     name: "from_config"