diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 8e6a31a98b5..1109c0d5ed8 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -329,13 +329,13 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): never throw unexpected errors in an unrelated workflow). Args: - input_shape: Single tuple, TensorShape, or list of shapes, where shapes - are tuples, integers, or TensorShapes. + input_shape: Single tuple, TensorShape, or list/dict of shapes, where + shapes are tuples, integers, or TensorShapes. Raises: ValueError: 1. In case of invalid user-provided data (not of type tuple, - list, or TensorShape). + list, TensorShape, or dict). 2. If the model requires call arguments that are agnostic to the input shapes (positional or kwarg in call signature). 3. If not all layers were properly built. @@ -351,7 +351,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): if input_shape is None: raise ValueError('Input shape must be defined when calling build on a ' 'model subclass network.') - valid_types = (tuple, list, tensor_shape.TensorShape) + valid_types = (tuple, list, tensor_shape.TensorShape, dict) if not isinstance(input_shape, valid_types): raise ValueError('Specified input shape is not one of the valid types. ' 'Please specify a batch input shape of type tuple or ' diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 2d7d57559a6..6a833560cff 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -3707,6 +3707,22 @@ class TestBuildCustomModel(keras_parameterized.TestCase): model.build([None, 1]) self.assertEqual(model.l1.kernel.shape.as_list(), [1, 1]) + @keras_parameterized.run_all_keras_modes + def test_build_dict_inputs(self): + + class MyModel(training_module.Model): + + def __init__(self): + super(MyModel, self).__init__() + self.l1 = layers_module.Dense(1) + + def call(self, inputs): + return self.l1(inputs['x']) + + model = MyModel() + model.build({'x': [None, 16]}) + self.assertEqual(model.l1.kernel.shape.as_list(), [16, 1]) + if __name__ == '__main__': test.main()