Support building model with dict inputs

Support already exists.
e79e90cb04/tensorflow/python/keras/engine/training.py (L377)

`valid_types` was updated to reflect that.

PiperOrigin-RevId: 334668220
Change-Id: I916ee0fcfb6871015751c47ca91a3d0102a6bb47
This commit is contained in:
Philip Pham 2020-09-30 13:35:47 -07:00 committed by TensorFlower Gardener
parent 55a21bf47e
commit b43033b589
2 changed files with 20 additions and 4 deletions

View File

@ -329,13 +329,13 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
never throw unexpected errors in an unrelated workflow).
Args:
input_shape: Single tuple, TensorShape, or list of shapes, where shapes
are tuples, integers, or TensorShapes.
input_shape: Single tuple, TensorShape, or list/dict of shapes, where
shapes are tuples, integers, or TensorShapes.
Raises:
ValueError:
1. In case of invalid user-provided data (not of type tuple,
list, or TensorShape).
list, TensorShape, or dict).
2. If the model requires call arguments that are agnostic
to the input shapes (positional or kwarg in call signature).
3. If not all layers were properly built.
@ -351,7 +351,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
if input_shape is None:
raise ValueError('Input shape must be defined when calling build on a '
'model subclass network.')
valid_types = (tuple, list, tensor_shape.TensorShape)
valid_types = (tuple, list, tensor_shape.TensorShape, dict)
if not isinstance(input_shape, valid_types):
raise ValueError('Specified input shape is not one of the valid types. '
'Please specify a batch input shape of type tuple or '

View File

@ -3707,6 +3707,22 @@ class TestBuildCustomModel(keras_parameterized.TestCase):
model.build([None, 1])
self.assertEqual(model.l1.kernel.shape.as_list(), [1, 1])
@keras_parameterized.run_all_keras_modes
def test_build_dict_inputs(self):
class MyModel(training_module.Model):
def __init__(self):
super(MyModel, self).__init__()
self.l1 = layers_module.Dense(1)
def call(self, inputs):
return self.l1(inputs['x'])
model = MyModel()
model.build({'x': [None, 16]})
self.assertEqual(model.l1.kernel.shape.as_list(), [16, 1])
if __name__ == '__main__':
test.main()