Support building model with dict inputs
Support already exists.
e79e90cb04/tensorflow/python/keras/engine/training.py (L377)
`valid_types` was updated to reflect that.
PiperOrigin-RevId: 334668220
Change-Id: I916ee0fcfb6871015751c47ca91a3d0102a6bb47
This commit is contained in:
parent
55a21bf47e
commit
b43033b589
@ -329,13 +329,13 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
never throw unexpected errors in an unrelated workflow).
|
||||
|
||||
Args:
|
||||
input_shape: Single tuple, TensorShape, or list of shapes, where shapes
|
||||
are tuples, integers, or TensorShapes.
|
||||
input_shape: Single tuple, TensorShape, or list/dict of shapes, where
|
||||
shapes are tuples, integers, or TensorShapes.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
1. In case of invalid user-provided data (not of type tuple,
|
||||
list, or TensorShape).
|
||||
list, TensorShape, or dict).
|
||||
2. If the model requires call arguments that are agnostic
|
||||
to the input shapes (positional or kwarg in call signature).
|
||||
3. If not all layers were properly built.
|
||||
@ -351,7 +351,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
if input_shape is None:
|
||||
raise ValueError('Input shape must be defined when calling build on a '
|
||||
'model subclass network.')
|
||||
valid_types = (tuple, list, tensor_shape.TensorShape)
|
||||
valid_types = (tuple, list, tensor_shape.TensorShape, dict)
|
||||
if not isinstance(input_shape, valid_types):
|
||||
raise ValueError('Specified input shape is not one of the valid types. '
|
||||
'Please specify a batch input shape of type tuple or '
|
||||
|
@ -3707,6 +3707,22 @@ class TestBuildCustomModel(keras_parameterized.TestCase):
|
||||
model.build([None, 1])
|
||||
self.assertEqual(model.l1.kernel.shape.as_list(), [1, 1])
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_build_dict_inputs(self):
|
||||
|
||||
class MyModel(training_module.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
self.l1 = layers_module.Dense(1)
|
||||
|
||||
def call(self, inputs):
|
||||
return self.l1(inputs['x'])
|
||||
|
||||
model = MyModel()
|
||||
model.build({'x': [None, 16]})
|
||||
self.assertEqual(model.l1.kernel.shape.as_list(), [16, 1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user