From 265de52331a10af793f733d21e2152123819a269 Mon Sep 17 00:00:00 2001
From: Scott Zhu
Date: Mon, 22 Jun 2020 18:20:43 -0700
Subject: [PATCH] Fix input mapping issue when a model is constructed/tested with dict input tensors.

The mapping of dict input tensors was not correct: it was still keyed by
the tensor name rather than by the dict key used when building the model.
This caused issues downstream when the inputs were provided with unknown
keys. We had some fallback logic that would usually do the right thing,
e.g. flattening the dict to preserve the original order, which was correct
in most cases but not reliable.

This change makes the following behavior changes:
1. When a model is built with dict input tensors, the dict key of each
   tensor, instead of its name, is used to map the tensor to the input
   data.
2. Unknown keys in the input data result in a warning. We do not throw an
   error since the user might pass them intentionally, e.g. when testing
   part of the model with the full input data.

PiperOrigin-RevId: 317776370
Change-Id: I91983443f2b770cb0b45ddb7726f52708cb91d61
---
 tensorflow/python/keras/engine/functional.py | 24 ++++++++--
 .../python/keras/engine/functional_test.py   | 46 +++++++++++++++++++
 2 files changed, 67 insertions(+), 3 deletions(-)

diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py
index 0612d70044d..fd80e7f8bb4 100644
--- a/tensorflow/python/keras/engine/functional.py
+++ b/tensorflow/python/keras/engine/functional.py
@@ -22,6 +22,7 @@ from __future__ import print_function
 import collections
 import copy
 import itertools
+import warnings
 
 from six.moves import zip  # pylint: disable=redefined-builtin
 
@@ -131,10 +132,10 @@ class Functional(training_lib.Model):
 
     # Models constructed with a single Tensor or list of Tensors can
     # be called with a dict, where the keys of the dict are the names
-    # of the `Input` objects. Extra keys are ignored.
+    # of the `Input` objects. Extra keys are ignored with a warning.
     self._enable_dict_to_input_mapping = (
         not nest.is_sequence(self._nested_inputs) or
-        (isinstance(self._nested_inputs, (list, tuple)) and
+        (isinstance(self._nested_inputs, (list, tuple, dict)) and
         not any(nest.is_sequence(t) for t in self._nested_inputs)))
 
     if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
@@ -524,10 +525,27 @@ class Functional(training_lib.Model):
       ref_inputs = self._nested_inputs
       if not nest.is_sequence(ref_inputs):
         ref_inputs = [self._nested_inputs]
+      if isinstance(ref_inputs, dict):
+        # If the graph is constructed with dict input tensors, we use the
+        # original dict keys to map against the keys in the input data.
+        # Note that model.inputs uses nest.flatten to process the input
+        # tensors, which means the dict input tensors are ordered by their
+        # keys.
+        ref_input_names = sorted(ref_inputs.keys())
+      else:
+        ref_input_names = [inp._keras_history.layer.name for inp in ref_inputs]
+
+      # Raise a warning if the input data contains more keys than input tensors.
+      if len(tensors) > len(ref_input_names):
+        warnings.warn(
+            'Input dict contained keys {} which did not match any model input. '
+            'They will be ignored by the model.'.format(
+                [n for n in tensors.keys() if n not in ref_input_names])
+        )
 
       try:
         # Flatten in the order `Input`s were passed during Model construction.
-        return [tensors[inp._keras_history.layer.name] for inp in ref_inputs]
+        return [tensors[n] for n in ref_input_names]
       except KeyError:
         # TODO(b/151582614)
         return nest.flatten(tensors)
diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py
index 3c14411deb9..0e82d95d3de 100644
--- a/tensorflow/python/keras/engine/functional_test.py
+++ b/tensorflow/python/keras/engine/functional_test.py
@@ -18,8 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import warnings
+
 import numpy as np
+
 
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
@@ -43,6 +46,7 @@ from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.platform import test
 from tensorflow.python.training.tracking.util import Checkpoint
@@ -1565,6 +1569,48 @@ class DefaultShapeInferenceBehaviorTest(keras_parameterized.TestCase):
     self.assertEqual(config['layers'][2]['inbound_nodes'],
                      [[['in1', 0, 0, {}], ['in2', 0, 0, {}]]])
 
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_dict_inputs_tensors(self):
+    # Note that this test runs with v2 eager only, since v1 behaves
+    # differently w.r.t. dict inputs for training.
+    inputs = {
+        'sentence2': input_layer_lib.Input(
+            shape=(), name='a', dtype=dtypes.string),
+        'sentence1': input_layer_lib.Input(
+            shape=(), name='b', dtype=dtypes.string),
+    }
+    strlen = layers.Lambda(string_ops.string_length_v2)
+    diff = layers.Subtract()(
+        [strlen(inputs['sentence1']), strlen(inputs['sentence2'])])
+    diff = math_ops.cast(diff, dtypes.float32)
+    model = training_lib.Model(inputs, diff)
+
+    extra_keys = {
+        'sentence1': constant_op.constant(['brown fox', 'lazy dog']),
+        'sentence2': constant_op.constant(['owl', 'cheeky cat']),
+        'label': constant_op.constant([0, 1]),
+    }
+
+    with warnings.catch_warnings(record=True) as w:
+      warnings.simplefilter('always')
+      model(extra_keys)
+      self.assertIn('ignored by the model', str(w[-1].message))
+
+    model.compile('sgd', 'mse')
+    with warnings.catch_warnings(record=True) as w:
+      warnings.simplefilter('always')
+      model.fit(extra_keys, y=constant_op.constant([0, 1]), steps_per_epoch=1)
+      self.assertIn('ignored by the model', str(w[-1].message))
+
+    with warnings.catch_warnings(record=True) as w:
+      warnings.simplefilter('always')
+      model.evaluate(extra_keys, constant_op.constant([0, 1]))
+      self.assertIn('ignored by the model', str(w[-1].message))
+
+    # Make sure the model inputs are sorted by the dict keys.
+    self.assertEqual(model.inputs[0]._keras_history.layer.name, 'b')
+    self.assertEqual(model.inputs[1]._keras_history.layer.name, 'a')
+
 
 class GraphUtilsTest(test.TestCase):
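
For readers of this patch, a minimal usage sketch of the new behavior follows. It is written against the public tf.keras API rather than the internal modules exercised above, and assumes a TensorFlow build that includes this change; the dict keys and layer names are illustrative only, not taken from the change itself.

    # Build a functional model from a dict of Input tensors. With this change,
    # the dict keys ('first'/'second'), not the Input names ('a'/'b'), are used
    # to match the incoming data.
    import tensorflow as tf

    inputs = {
        'first': tf.keras.Input(shape=(1,), name='a'),
        'second': tf.keras.Input(shape=(1,), name='b'),
    }
    outputs = tf.keras.layers.Subtract()([inputs['first'], inputs['second']])
    model = tf.keras.Model(inputs, outputs)

    data = {
        'first': tf.constant([[3.0]]),
        'second': tf.constant([[1.0]]),
        'extra': tf.constant([[0.0]]),  # unknown key: warned about, then ignored
    }
    print(model(data))  # expected result: [[2.]], i.e. first - second

Calling model(data) with the extra key only emits the warning added in functional.py above; the computation still proceeds on the keys that match the model's dict inputs.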