diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index 0612d70044d..fd80e7f8bb4 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -22,6 +22,7 @@ from __future__ import print_function import collections import copy import itertools +import warnings from six.moves import zip # pylint: disable=redefined-builtin @@ -131,10 +132,10 @@ class Functional(training_lib.Model): # Models constructed with a single Tensor or list of Tensors can # be called with a dict, where the keys of the dict are the names - # of the `Input` objects. Extra keys are ignored. + # of the `Input` objects. Extra keys are ignored with a warning. self._enable_dict_to_input_mapping = ( not nest.is_sequence(self._nested_inputs) or - (isinstance(self._nested_inputs, (list, tuple)) and + (isinstance(self._nested_inputs, (list, tuple, dict)) and not any(nest.is_sequence(t) for t in self._nested_inputs))) if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs): @@ -524,10 +525,27 @@ class Functional(training_lib.Model): ref_inputs = self._nested_inputs if not nest.is_sequence(ref_inputs): ref_inputs = [self._nested_inputs] + if isinstance(ref_inputs, dict): + # In the case that the graph is constructed with dict input tensors, + # we will use the original dict key to match the keys in the input + # data. Note that the model.inputs is using nest.flatten to process the + # input tensors, which means the dict input tensors are ordered by their + # keys. + ref_input_names = sorted(ref_inputs.keys()) + else: + ref_input_names = [inp._keras_history.layer.name for inp in ref_inputs] + + # Raise a warning if the input data contains more keys than the input tensors + if len(tensors) > len(ref_input_names): + warnings.warn( + 'Input dict contained keys {} which did not match any model input. 
' 'They will be ignored by the model.'.format( + [n for n in tensors.keys() if n not in ref_input_names]) + ) try: # Flatten in the order `Input`s were passed during Model construction. - return [tensors[inp._keras_history.layer.name] for inp in ref_inputs] + return [tensors[n] for n in ref_input_names] except KeyError: # TODO(b/151582614) return nest.flatten(tensors) diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py index 3c14411deb9..0e82d95d3de 100644 --- a/tensorflow/python/keras/engine/functional_test.py +++ b/tensorflow/python/keras/engine/functional_test.py @@ -18,8 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings + import numpy as np + from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op @@ -43,6 +46,7 @@ from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import string_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test from tensorflow.python.training.tracking.util import Checkpoint @@ -1565,6 +1569,48 @@ class DefaultShapeInferenceBehaviorTest(keras_parameterized.TestCase): self.assertEqual(config['layers'][2]['inbound_nodes'], [[['in1', 0, 0, {}], ['in2', 0, 0, {}]]]) + @combinations.generate(combinations.combine(mode=['eager'])) + def test_dict_inputs_tensors(self): + # Note that this test is running with v2 eager only, since the v1 + # will behave differently with respect to dict input for training. 
+ inputs = { + 'sentence2': input_layer_lib.Input( + shape=(), name='a', dtype=dtypes.string), + 'sentence1': input_layer_lib.Input( + shape=(), name='b', dtype=dtypes.string), + } + strlen = layers.Lambda(string_ops.string_length_v2) + diff = layers.Subtract()( + [strlen(inputs['sentence1']), strlen(inputs['sentence2'])]) + diff = math_ops.cast(diff, dtypes.float32) + model = training_lib.Model(inputs, diff) + + extra_keys = { + 'sentence1': constant_op.constant(['brown fox', 'lazy dog']), + 'sentence2': constant_op.constant(['owl', 'cheeky cat']), + 'label': constant_op.constant([0, 1]), + } + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + model(extra_keys) + self.assertIn('ignored by the model', str(w[-1].message)) + + model.compile('sgd', 'mse') + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + model.fit(extra_keys, y=constant_op.constant([0, 1]), steps_per_epoch=1) + self.assertIn('ignored by the model', str(w[-1].message)) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + model.evaluate(extra_keys, constant_op.constant([0, 1])) + self.assertIn('ignored by the model', str(w[-1].message)) + + # Make sure the model inputs are sorted by the dict keys. + self.assertEqual(model.inputs[0]._keras_history.layer.name, 'b') + self.assertEqual(model.inputs[1]._keras_history.layer.name, 'a') + class GraphUtilsTest(test.TestCase):