Fix input mapping issue when a model is constructed/tested with dict input tensors.
The mapping of dict input tensors was incorrect: it still used the tensor name, rather than the dict key of the tensor, when building the model. This caused issues downstream when the inputs were provided with unknown keys. There was backup logic that would usually do the right thing, e.g. just flatten the dict and keep the original order, which was correct in most cases but not reliable. This change makes two behavior changes:
1. When a model is built with dict input tensors, the key of the tensor, instead of its name, is used to map the tensor to the input data.
2. Unknown keys in the input data result in a warning. We don't raise an error since the user might provide them intentionally, e.g. when testing part of the model with the full input data.
PiperOrigin-RevId: 317776370
Change-Id: I91983443f2b770cb0b45ddb7726f52708cb91d61
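For illustration only (not part of the commit): a minimal sketch of the new behavior using the public tf.keras API. The model and the 'sentence1'/'sentence2'/'label' keys are invented for the example and mirror the unit test added below.

import tensorflow as tf

# Model built with dict input tensors. After this change, call-time data is
# matched by the dict keys ('sentence1'/'sentence2'), not by the `Input`
# layer names ('a'/'b').
inputs = {
    'sentence2': tf.keras.Input(shape=(), name='a', dtype=tf.string),
    'sentence1': tf.keras.Input(shape=(), name='b', dtype=tf.string),
}
strlen = tf.keras.layers.Lambda(tf.strings.length)
diff = tf.keras.layers.Subtract()(
    [strlen(inputs['sentence1']), strlen(inputs['sentence2'])])
diff = tf.cast(diff, tf.float32)
model = tf.keras.Model(inputs, diff)

# 'label' matches no model input, so it is now ignored with a warning instead
# of relying on the order-based flatten fallback.
data = {
    'sentence1': tf.constant(['brown fox', 'lazy dog']),
    'sentence2': tf.constant(['owl', 'cheeky cat']),
    'label': tf.constant([0, 1]),
}
print(model(data))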
parent e60cf08994
commit 265de52331
@@ -22,6 +22,7 @@ from __future__ import print_function
 import collections
 import copy
 import itertools
+import warnings

 from six.moves import zip  # pylint: disable=redefined-builtin

@@ -131,10 +132,10 @@ class Functional(training_lib.Model):

     # Models constructed with a single Tensor or list of Tensors can
     # be called with a dict, where the keys of the dict are the names
-    # of the `Input` objects. Extra keys are ignored.
+    # of the `Input` objects. Extra keys are ignored with warning.
     self._enable_dict_to_input_mapping = (
         not nest.is_sequence(self._nested_inputs) or
-        (isinstance(self._nested_inputs, (list, tuple)) and
+        (isinstance(self._nested_inputs, (list, tuple, dict)) and
          not any(nest.is_sequence(t) for t in self._nested_inputs)))

     if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
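A rough usage sketch of the convention described in the comment above (illustrative only; the names 'a', 'b' and the extra key 'c' are made up): a model built from a flat list of Inputs can still be called with a dict keyed by the Input names, and an unmatched key is now ignored with a warning.

import tensorflow as tf

# Model built from a flat list of `Input`s: call-time dict keys are matched
# against the `Input` names.
a = tf.keras.Input(shape=(1,), name='a')
b = tf.keras.Input(shape=(1,), name='b')
model = tf.keras.Model([a, b], tf.keras.layers.Add()([a, b]))

# 'c' matches no input name, so it is dropped and a warning is emitted.
data = {'a': tf.ones((2, 1)), 'b': tf.zeros((2, 1)), 'c': tf.ones((2, 1))}
print(model(data))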
@@ -524,10 +525,27 @@ class Functional(training_lib.Model):
       ref_inputs = self._nested_inputs
       if not nest.is_sequence(ref_inputs):
         ref_inputs = [self._nested_inputs]
+      if isinstance(ref_inputs, dict):
+        # In the case that the graph is constructed with dict input tensors,
+        # We will use the original dict key to map with the keys in the input
+        # data. Note that the model.inputs is using nest.flatten to process the
+        # input tensors, which means the dict input tensors are ordered by their
+        # keys.
+        ref_input_names = sorted(ref_inputs.keys())
+      else:
+        ref_input_names = [inp._keras_history.layer.name for inp in ref_inputs]
+
+      # Raise an warning if there are more input data comparing to input tensor
+      if len(tensors) > len(ref_input_names):
+        warnings.warn(
+            'Input dict contained keys {} which did not match any model input. '
+            'They will be ignored by the model.'.format(
+                [n for n in tensors.keys() if n not in ref_input_names])
+        )
+
       try:
         # Flatten in the order `Input`s were passed during Model construction.
-        return [tensors[inp._keras_history.layer.name] for inp in ref_inputs]
+        return [tensors[n] for n in ref_input_names]
       except KeyError:
         # TODO(b/151582614)
         return nest.flatten(tensors)
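A hedged note on the sorted-key mapping above (sketch only, not part of the diff; the keys 'x1'/'x2' and names 'a'/'b' are invented): because nest.flatten orders dict keys, model.inputs for a dict-built model follows sorted key order rather than the Input names, which is what the test below asserts.

import tensorflow as tf

# The dict keys ('x1' sorts before 'x2') determine the order of model.inputs,
# independent of the `Input` layer names ('b', 'a').
inputs = {
    'x2': tf.keras.Input(shape=(1,), name='a'),
    'x1': tf.keras.Input(shape=(1,), name='b'),
}
out = tf.keras.layers.Add()([inputs['x1'], inputs['x2']])
model = tf.keras.Model(inputs, out)
print([t._keras_history.layer.name for t in model.inputs])  # expected: ['b', 'a']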
@@ -18,8 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import warnings
+
 import numpy as np

 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
@@ -43,6 +46,7 @@ from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.platform import test
 from tensorflow.python.training.tracking.util import Checkpoint
@@ -1565,6 +1569,48 @@ class DefaultShapeInferenceBehaviorTest(keras_parameterized.TestCase):
     self.assertEqual(config['layers'][2]['inbound_nodes'],
                      [[['in1', 0, 0, {}], ['in2', 0, 0, {}]]])

+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_dict_inputs_tensors(self):
+    # Note that this test is running with v2 eager only, since the v1
+    # will behave differently wrt to dict input for training.
+    inputs = {
+        'sentence2': input_layer_lib.Input(
+            shape=(), name='a', dtype=dtypes.string),
+        'sentence1': input_layer_lib.Input(
+            shape=(), name='b', dtype=dtypes.string),
+    }
+    strlen = layers.Lambda(string_ops.string_length_v2)
+    diff = layers.Subtract()(
+        [strlen(inputs['sentence1']), strlen(inputs['sentence2'])])
+    diff = math_ops.cast(diff, dtypes.float32)
+    model = training_lib.Model(inputs, diff)
+
+    extra_keys = {
+        'sentence1': constant_op.constant(['brown fox', 'lazy dog']),
+        'sentence2': constant_op.constant(['owl', 'cheeky cat']),
+        'label': constant_op.constant([0, 1]),
+    }
+
+    with warnings.catch_warnings(record=True) as w:
+      warnings.simplefilter('always')
+      model(extra_keys)
+      self.assertIn('ignored by the model', str(w[-1].message))
+
+    model.compile('sgd', 'mse')
+    with warnings.catch_warnings(record=True) as w:
+      warnings.simplefilter('always')
+      model.fit(extra_keys, y=constant_op.constant([0, 1]), steps_per_epoch=1)
+      self.assertIn('ignored by the model', str(w[-1].message))
+
+    with warnings.catch_warnings(record=True) as w:
+      warnings.simplefilter('always')
+      model.evaluate(extra_keys, constant_op.constant([0, 1]))
+      self.assertIn('ignored by the model', str(w[-1].message))
+
+    # Make sure the model inputs are sorted with the dict keys.
+    self.assertEqual(model.inputs[0]._keras_history.layer.name, 'b')
+    self.assertEqual(model.inputs[1]._keras_history.layer.name, 'a')
+

 class GraphUtilsTest(test.TestCase):