Enable input spec checking for Functional models.
PiperOrigin-RevId: 324625967 Change-Id: Ide0a8cb4d6d7614f86f22088a5ef95d72636c54e
This commit is contained in:
parent
856dc4f7b6
commit
59dc165d26
@ -25,7 +25,14 @@
|
||||
* Code that requires very tricky shape manipulation via converted op layers in order to work, where the Keras symbolic shape inference proves insufficient.
|
||||
* Code that tries manually walking a `tf.keras.Model` layer by layer and assumes layers only ever have one positional argument. This assumption doesn't hold true before TF 2.4 either, but is more likely to cause issues know.
|
||||
* Code that manually enters `keras.backend.get_graph()` before building a functional model. This is no longer needed.
|
||||
|
||||
* Start enforcing input shape assumptions when calling Functional API Keras
|
||||
models. This may potentially break some users, in case there is a mismatch
|
||||
between the shape used when creating `Input` objects in a Functional model,
|
||||
and the shape of the data passed to that model. You can fix this mismatch by
|
||||
either calling the model with correctly-shaped data, or by relaxing `Input`
|
||||
shape assumptions (note that you can pass shapes with `None` entries for axes
|
||||
that are meant to be dynamic). You can also disable the input checking
|
||||
entirely by setting `model.input_spec = None`.
|
||||
|
||||
## Known Caveats
|
||||
|
||||
|
@ -1239,7 +1239,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
|
||||
dataset = dataset.repeat(100)
|
||||
dataset = dataset.batch(10)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, 'incompatible with the layer'):
|
||||
with self.assertRaisesRegex(ValueError, 'is incompatible with'):
|
||||
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
|
||||
|
||||
@combinations.generate(
|
||||
|
@ -970,12 +970,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
if self._autocast:
|
||||
inputs = self._maybe_cast_inputs(inputs, input_list)
|
||||
|
||||
input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
|
||||
if eager:
|
||||
call_fn = self.call
|
||||
name_scope = self._name
|
||||
else:
|
||||
input_spec.assert_input_compatibility(self.input_spec, inputs,
|
||||
self.name)
|
||||
name_scope = self._name_scope() # Avoid autoincrementing.
|
||||
call_fn = self._autographed_call()
|
||||
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import input_layer as input_layer_module
|
||||
from tensorflow.python.keras.engine import input_spec
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.keras.engine import node as node_module
|
||||
from tensorflow.python.keras.engine import training as training_lib
|
||||
@ -248,6 +249,32 @@ class Functional(training_lib.Model):
|
||||
"""
|
||||
return nest.map_structure(backend.int_shape, self.input)
|
||||
|
||||
@property
|
||||
def input_spec(self):
|
||||
if hasattr(self, '_manual_input_spec'):
|
||||
return self._manual_input_spec
|
||||
if (isinstance(self._nested_inputs, (dict, list, tuple)) and
|
||||
len(self._nested_inputs) != len(self.inputs)):
|
||||
# Case where we have a nested structure.
|
||||
# In such a case we can't safely run any checks.
|
||||
return None
|
||||
if isinstance(self._nested_inputs, dict):
|
||||
# Case where `_nested_inputs` is a plain dict of Inputs.
|
||||
names = sorted(self._nested_inputs.keys())
|
||||
return [input_spec.InputSpec(
|
||||
shape=shape_with_no_batch_size(self._nested_inputs[name]),
|
||||
allow_last_axis_squeeze=True, name=name) for name in names]
|
||||
else:
|
||||
# Single input, or list / tuple of inputs.
|
||||
# The data may be passed as a dict keyed by input name.
|
||||
return [input_spec.InputSpec(
|
||||
shape=shape_with_no_batch_size(x), allow_last_axis_squeeze=True,
|
||||
name=x._keras_history.layer.name) for x in self.inputs]
|
||||
|
||||
@input_spec.setter
|
||||
def input_spec(self, value):
|
||||
self._manual_input_spec = value
|
||||
|
||||
@property
|
||||
def output(self):
|
||||
"""Retrieves the output tensor(s) of a layer.
|
||||
@ -1312,3 +1339,12 @@ def get_network_config(network, serialize_layer_fn=None):
|
||||
model_outputs = tf_utils.convert_inner_node_data(model_outputs)
|
||||
config['output_layers'] = model_outputs
|
||||
return config
|
||||
|
||||
|
||||
def shape_with_no_batch_size(x):
|
||||
if x.shape.rank is None:
|
||||
return None
|
||||
shape = x.shape.as_list()
|
||||
if shape:
|
||||
shape[0] = None
|
||||
return shape
|
||||
|
@ -1059,7 +1059,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
# Check the output dtype
|
||||
self.assertEqual(model(array_ops.ones(3, 3)).dtype, dtypes.float16)
|
||||
self.assertEqual(model(array_ops.ones((3, 10))).dtype, dtypes.float16)
|
||||
|
||||
model = training_lib.Model.from_config(
|
||||
model.get_config(), custom_objects={'Double': Double})
|
||||
@ -1075,7 +1075,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
# Check the output dtype
|
||||
self.assertEqual(model(array_ops.ones(3, 3)).dtype, dtypes.float16)
|
||||
self.assertEqual(model(array_ops.ones((3, 10))).dtype, dtypes.float16)
|
||||
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
def test_call_kwarg_nonserializable(self):
|
||||
@ -1793,8 +1793,8 @@ class NestedNetworkTest(keras_parameterized.TestCase):
|
||||
network = functional.Functional.from_config(network.get_config())
|
||||
|
||||
result_tensor = network({
|
||||
'x': array_ops.ones((1, 1), 'float32'),
|
||||
'y': array_ops.ones((1, 1), 'float32')
|
||||
'x1': array_ops.ones((1, 1), 'float32'),
|
||||
'x2': array_ops.ones((1, 1), 'float32')
|
||||
})
|
||||
result = self.evaluate(result_tensor)
|
||||
self.assertAllEqual(result, [[2.]])
|
||||
@ -2340,6 +2340,57 @@ class InputsOutputsErrorTest(keras_parameterized.TestCase):
|
||||
TypeError, "('Keyword argument not understood:', 'output')"):
|
||||
models.Model(inputs=inputs, output=outputs)
|
||||
|
||||
def test_input_spec(self):
|
||||
if not context.executing_eagerly():
|
||||
return
|
||||
inputs = input_layer_lib.Input((10,))
|
||||
outputs = layers.Dense(10)(inputs)
|
||||
model = models.Model(inputs, outputs)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'.*expected shape=.*'):
|
||||
model(np.zeros((3, 11)))
|
||||
|
||||
def test_input_spec_list_of_inputs(self):
|
||||
if not context.executing_eagerly():
|
||||
return
|
||||
input_1 = input_layer_lib.Input((10,), name='1')
|
||||
input_2 = input_layer_lib.Input((5,), name='2')
|
||||
x = layers.Concatenate()([input_1, input_2])
|
||||
outputs = layers.Dense(10)(x)
|
||||
model = models.Model([input_1, input_2], outputs)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'.*expects 2 input.*'):
|
||||
model(np.zeros((3, 10)))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'.*expects 2 input.*'):
|
||||
model([np.zeros((3, 10)), np.zeros((3, 5)), np.zeros((3, 10))])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'.*expected shape=.*'):
|
||||
model([np.zeros((3, 10)), np.zeros((3, 6))])
|
||||
|
||||
# Test passing data via dict keyed by input name
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Missing data for input.*'):
|
||||
model({'1': np.zeros((3, 10))})
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'.*expected shape=.*'):
|
||||
model({'1': np.zeros((3, 10)), '2': np.zeros((3, 6))})
|
||||
|
||||
def test_input_spec_dict(self):
|
||||
if not context.executing_eagerly():
|
||||
return
|
||||
input_1 = input_layer_lib.Input((10,))
|
||||
input_2 = input_layer_lib.Input((5,))
|
||||
x = layers.Concatenate()([input_1, input_2])
|
||||
outputs = layers.Dense(10)(x)
|
||||
model = models.Model({'1': input_1, '2': input_2}, outputs)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Missing data for input.*'):
|
||||
model({'1': np.zeros((3, 10))})
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'.*expected shape=.*'):
|
||||
model({'1': np.zeros((3, 10)), '2': np.zeros((3, 6))})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -44,14 +44,32 @@ class InputSpec(object):
|
||||
a None shape is compatible with any shape.
|
||||
|
||||
Arguments:
|
||||
dtype: Expected DataType of the input.
|
||||
shape: Shape tuple, expected shape of the input
|
||||
(may include None for unchecked axes).
|
||||
ndim: Integer, expected rank of the input.
|
||||
max_ndim: Integer, maximum rank of the input.
|
||||
min_ndim: Integer, minimum rank of the input.
|
||||
axes: Dictionary mapping integer axes to
|
||||
a specific dimension value.
|
||||
dtype: Expected DataType of the input.
|
||||
shape: Shape tuple, expected shape of the input
|
||||
(may include None for unchecked axes). Includes the batch size.
|
||||
ndim: Integer, expected rank of the input.
|
||||
max_ndim: Integer, maximum rank of the input.
|
||||
min_ndim: Integer, minimum rank of the input.
|
||||
axes: Dictionary mapping integer axes to
|
||||
a specific dimension value.
|
||||
allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as long
|
||||
as the last axis of the input is 1, as well as inputs of rank N-1
|
||||
as long as the last axis of the spec is 1.
|
||||
name: Expected key corresponding to this input when passing data as
|
||||
a dictionary.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
class MyLayer(Layer):
|
||||
def __init__(self):
|
||||
super(MyLayer, self).__init__()
|
||||
# The layer will accept inputs with shape (?, 28, 28) & (?, 28, 28, 1)
|
||||
# and raise an appropriate error message otherwise.
|
||||
self.input_spec = InputSpec(
|
||||
shape=(None, 28, 28, 1),
|
||||
allow_last_axis_squeeze=True)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -60,8 +78,15 @@ class InputSpec(object):
|
||||
ndim=None,
|
||||
max_ndim=None,
|
||||
min_ndim=None,
|
||||
axes=None):
|
||||
axes=None,
|
||||
allow_last_axis_squeeze=False,
|
||||
name=None):
|
||||
self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
|
||||
shape = tensor_shape.TensorShape(shape)
|
||||
if shape.rank is None:
|
||||
shape = None
|
||||
else:
|
||||
shape = tuple(shape.as_list())
|
||||
if shape is not None:
|
||||
self.ndim = len(shape)
|
||||
self.shape = shape
|
||||
@ -70,6 +95,8 @@ class InputSpec(object):
|
||||
self.shape = None
|
||||
self.max_ndim = max_ndim
|
||||
self.min_ndim = min_ndim
|
||||
self.name = name
|
||||
self.allow_last_axis_squeeze = allow_last_axis_squeeze
|
||||
try:
|
||||
axes = axes or {}
|
||||
self.axes = {int(k): axes[k] for k in axes}
|
||||
@ -149,6 +176,21 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
|
||||
if not input_spec:
|
||||
return
|
||||
|
||||
input_spec = nest.flatten(input_spec)
|
||||
if isinstance(inputs, dict):
|
||||
# Flatten `inputs` by reference order if input spec names are provided
|
||||
names = [spec.name for spec in input_spec]
|
||||
if all(names):
|
||||
list_inputs = []
|
||||
for name in names:
|
||||
if name not in inputs:
|
||||
raise ValueError('Missing data for input "%s". '
|
||||
'You passed a data dictionary with keys %s. '
|
||||
'Expected the following keys: %s' %
|
||||
(name, list(inputs.keys()), names))
|
||||
list_inputs.append(inputs[name])
|
||||
inputs = list_inputs
|
||||
|
||||
inputs = nest.flatten(inputs)
|
||||
for x in inputs:
|
||||
# Having a shape/dtype is the only commonality of the various tensor-like
|
||||
@ -157,81 +199,83 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
|
||||
# have a `shape` attribute.
|
||||
if not hasattr(x, 'shape'):
|
||||
raise TypeError('Inputs to a layer should be tensors. Got: %s' % (x,))
|
||||
input_spec = nest.flatten(input_spec)
|
||||
|
||||
if len(inputs) != len(input_spec):
|
||||
raise ValueError('Layer ' + layer_name + ' expects ' +
|
||||
str(len(input_spec)) + ' inputs, '
|
||||
str(len(input_spec)) + ' input(s), '
|
||||
'but it received ' + str(len(inputs)) +
|
||||
' input tensors. Inputs received: ' + str(inputs))
|
||||
for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
|
||||
if spec is None:
|
||||
continue
|
||||
|
||||
if (spec.ndim is not None or
|
||||
spec.min_ndim is not None or
|
||||
spec.max_ndim is not None):
|
||||
if x.shape.ndims is None:
|
||||
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||
layer_name + ' is incompatible with the layer: '
|
||||
'its rank is undefined, but the layer requires a '
|
||||
'defined rank.')
|
||||
|
||||
shape = tensor_shape.TensorShape(x.shape)
|
||||
if shape.rank is None:
|
||||
return
|
||||
# Check ndim.
|
||||
if spec.ndim is not None:
|
||||
ndim = x.shape.ndims
|
||||
if spec.ndim is not None and not spec.allow_last_axis_squeeze:
|
||||
ndim = shape.rank
|
||||
if ndim != spec.ndim:
|
||||
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||
layer_name + ' is incompatible with the layer: '
|
||||
'expected ndim=' + str(spec.ndim) + ', found ndim=' +
|
||||
str(ndim) + '. Full shape received: ' +
|
||||
str(x.shape.as_list()))
|
||||
str(tuple(shape)))
|
||||
if spec.max_ndim is not None:
|
||||
ndim = x.shape.ndims
|
||||
ndim = x.shape.rank
|
||||
if ndim is not None and ndim > spec.max_ndim:
|
||||
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||
layer_name + ' is incompatible with the layer: '
|
||||
'expected max_ndim=' + str(spec.max_ndim) +
|
||||
', found ndim=' + str(ndim))
|
||||
if spec.min_ndim is not None:
|
||||
ndim = x.shape.ndims
|
||||
ndim = x.shape.rank
|
||||
if ndim is not None and ndim < spec.min_ndim:
|
||||
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||
layer_name + ' is incompatible with the layer: '
|
||||
': expected min_ndim=' + str(spec.min_ndim) +
|
||||
', found ndim=' + str(ndim) +
|
||||
'. Full shape received: ' +
|
||||
str(x.shape.as_list()))
|
||||
str(tuple(shape)))
|
||||
# Check dtype.
|
||||
if spec.dtype is not None:
|
||||
if x.dtype != spec.dtype:
|
||||
if x.dtype.name != spec.dtype:
|
||||
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||
layer_name + ' is incompatible with the layer: '
|
||||
'expected dtype=' + str(spec.dtype) +
|
||||
', found dtype=' + str(x.dtype))
|
||||
|
||||
# Check specific shape axes.
|
||||
shape_as_list = shape.as_list()
|
||||
if spec.axes:
|
||||
shape = x.shape.as_list()
|
||||
if shape is not None:
|
||||
for axis, value in spec.axes.items():
|
||||
if hasattr(value, 'value'):
|
||||
value = value.value
|
||||
if value is not None and shape[int(axis)] not in {value, None}:
|
||||
raise ValueError(
|
||||
'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
|
||||
' incompatible with the layer: expected axis ' + str(axis) +
|
||||
' of input shape to have value ' + str(value) +
|
||||
' but received input with shape ' + str(shape))
|
||||
for axis, value in spec.axes.items():
|
||||
if hasattr(value, 'value'):
|
||||
value = value.value
|
||||
if value is not None and shape_as_list[int(axis)] not in {value, None}:
|
||||
raise ValueError(
|
||||
'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
|
||||
' incompatible with the layer: expected axis ' + str(axis) +
|
||||
' of input shape to have value ' + str(value) +
|
||||
' but received input with shape ' + display_shape(x.shape))
|
||||
# Check shape.
|
||||
if spec.shape is not None:
|
||||
shape = x.shape.as_list()
|
||||
if shape is not None:
|
||||
for spec_dim, dim in zip(spec.shape, shape):
|
||||
if spec_dim is not None and dim is not None:
|
||||
if spec_dim != dim:
|
||||
raise ValueError('Input ' + str(input_index) +
|
||||
' is incompatible with layer ' + layer_name +
|
||||
': expected shape=' + str(spec.shape) +
|
||||
', found shape=' + str(shape))
|
||||
if spec.shape is not None and shape.rank is not None:
|
||||
spec_shape = spec.shape
|
||||
if spec.allow_last_axis_squeeze:
|
||||
if shape_as_list and shape_as_list[-1] == 1:
|
||||
shape_as_list = shape_as_list[:-1]
|
||||
if spec_shape and spec_shape[-1] == 1:
|
||||
spec_shape = spec_shape[:-1]
|
||||
for spec_dim, dim in zip(spec_shape, shape_as_list):
|
||||
if spec_dim is not None and dim is not None:
|
||||
if spec_dim != dim:
|
||||
raise ValueError('Input ' + str(input_index) +
|
||||
' is incompatible with layer ' + layer_name +
|
||||
': expected shape=' + str(spec.shape) +
|
||||
', found shape=' + display_shape(x.shape))
|
||||
|
||||
|
||||
def display_shape(shape):
|
||||
return str(tuple(shape.as_list()))
|
||||
|
||||
|
||||
def to_tensor_spec(input_spec, default_dtype=None):
|
||||
|
@ -495,10 +495,16 @@ class Sequential(functional.Functional):
|
||||
|
||||
@property
|
||||
def input_spec(self):
|
||||
if hasattr(self, '_manual_input_spec'):
|
||||
return self._manual_input_spec
|
||||
if self.layers and hasattr(self.layers[0], 'input_spec'):
|
||||
return self.layers[0].input_spec
|
||||
return None
|
||||
|
||||
@input_spec.setter
|
||||
def input_spec(self, value):
|
||||
self._manual_input_spec = value
|
||||
|
||||
@property
|
||||
def _trackable_saved_model_saver(self):
|
||||
return model_serialization.SequentialSavedModelSaver(self)
|
||||
|
@ -324,9 +324,9 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
|
||||
np_inputs = nest.map_structure(
|
||||
lambda x: np.ones((10,) + tuple(x.shape[1:]), 'float32'), model.inputs)
|
||||
lambda x: np.ones((2,) + tuple(x.shape[1:]), 'float32'), model.inputs)
|
||||
np_outputs = nest.map_structure(
|
||||
lambda x: np.ones((10,) + tuple(x.shape[1:]), 'float32'), model.outputs)
|
||||
lambda x: np.ones((2,) + tuple(x.shape[1:]), 'float32'), model.outputs)
|
||||
model.fit(np_inputs, np_outputs, batch_size=2)
|
||||
model(np_inputs) # Test calling the model directly on inputs.
|
||||
|
||||
@ -402,7 +402,7 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
||||
def test_getitem_slice_with_step_only(self):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest('Complex slicing like this fails in v1')
|
||||
inp = keras.Input(shape=(4, 3, 8))
|
||||
inp = keras.Input(shape=(8,))
|
||||
slice_step = keras.Input(shape=(), dtype='int32')
|
||||
|
||||
out = inp[..., ::slice_step[0]]
|
||||
@ -508,7 +508,7 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
||||
def test_getitem_slice_with_stop_only(self):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest('Complex slicing like this fails in v1')
|
||||
inp = keras.Input(shape=(4, 3, 8))
|
||||
inp = keras.Input(shape=(8,))
|
||||
slice_stop = keras.Input(shape=(), dtype='int32')
|
||||
|
||||
out = inp[:slice_stop[0]]
|
||||
@ -544,7 +544,7 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
||||
def test_getitem_slice_with_stop_and_ellipsis_only(self):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest('Complex slicing like this fails in v1')
|
||||
inp = keras.Input(shape=(4, 3, 8))
|
||||
inp = keras.Input(shape=(8,))
|
||||
slice_stop = keras.Input(shape=(), dtype='int32')
|
||||
|
||||
out = inp[..., :slice_stop[0]]
|
||||
@ -646,14 +646,14 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_numerical_correctness_with_attrs(self):
|
||||
x = ops.convert_to_tensor_v2([[1.5, 1.5], [2.5, 3.5]])
|
||||
inputs = keras.Input(shape=(10,))
|
||||
inputs = keras.Input(shape=(2,))
|
||||
outputs = math_ops.reduce_mean(inputs, axis=1)
|
||||
model = keras.Model(inputs, outputs)
|
||||
y = self.evaluate(model(x))
|
||||
self.assertAllClose(y, [1.5, 3.])
|
||||
|
||||
def test_numerical_correctness_serialization(self):
|
||||
x = ops.convert_to_tensor_v2([-1., 0., -2., 1.])
|
||||
x = ops.convert_to_tensor_v2([[-1., 0., -2., 1.]])
|
||||
inputs = keras.Input(shape=(4,))
|
||||
outputs = gen_nn_ops.relu(inputs)
|
||||
model1 = keras.Model(inputs, outputs)
|
||||
|
@ -277,11 +277,6 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase):
|
||||
def call(self, inputs):
|
||||
return inputs
|
||||
|
||||
if not context.executing_eagerly():
|
||||
layer = CustomerLayer()
|
||||
with self.assertRaisesRegex(ValueError, r'requires a defined rank'):
|
||||
layer.apply(array_ops.placeholder('int32'))
|
||||
|
||||
layer = CustomerLayer()
|
||||
with self.assertRaisesRegex(ValueError, r'expected ndim=2'):
|
||||
layer.apply(constant_op.constant([1]))
|
||||
@ -295,29 +290,24 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase):
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testInputSpecMinNdimCheck(self):
|
||||
|
||||
class CustomerLayer(base_layers.Layer):
|
||||
class CustomLayer(base_layers.Layer):
|
||||
|
||||
def __init__(self):
|
||||
super(CustomerLayer, self).__init__()
|
||||
super(CustomLayer, self).__init__()
|
||||
self.input_spec = input_spec.InputSpec(min_ndim=2)
|
||||
|
||||
def call(self, inputs):
|
||||
return inputs
|
||||
|
||||
if not context.executing_eagerly():
|
||||
layer = CustomerLayer()
|
||||
with self.assertRaisesRegex(ValueError, r'requires a defined rank'):
|
||||
layer.apply(array_ops.placeholder('int32'))
|
||||
|
||||
layer = CustomerLayer()
|
||||
layer = CustomLayer()
|
||||
with self.assertRaisesRegex(ValueError, r'expected min_ndim=2'):
|
||||
layer.apply(constant_op.constant([1]))
|
||||
|
||||
# Works
|
||||
layer = CustomerLayer()
|
||||
layer = CustomLayer()
|
||||
layer.apply(constant_op.constant([[1], [2]]))
|
||||
|
||||
layer = CustomerLayer()
|
||||
layer = CustomLayer()
|
||||
layer.apply(constant_op.constant([[[1], [2]]]))
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
@ -332,11 +322,6 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase):
|
||||
def call(self, inputs):
|
||||
return inputs
|
||||
|
||||
if not context.executing_eagerly():
|
||||
layer = CustomerLayer()
|
||||
with self.assertRaisesRegex(ValueError, r'requires a defined rank'):
|
||||
layer.apply(array_ops.placeholder('int32'))
|
||||
|
||||
layer = CustomerLayer()
|
||||
with self.assertRaisesRegex(ValueError, r'expected max_ndim=2'):
|
||||
layer.apply(constant_op.constant([[[1], [2]]]))
|
||||
|
@ -4,7 +4,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
|
@ -4,7 +4,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
|
@ -4,7 +4,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
|
Loading…
Reference in New Issue
Block a user