Allow unused keys when a dict is passed to a single-input Functional API model.

Ensure that the key mapping to the name of the Input is used during Model
execution.

PiperOrigin-RevId: 300864855
Change-Id: I7871b7fb4df3023a5d3307e3e513dc68df991239
This commit is contained in:
A. Unique TensorFlower 2020-03-13 18:35:36 -07:00 committed by TensorFlower Gardener
parent 8ca8ad2aec
commit 88ca767048
2 changed files with 13 additions and 32 deletions

View File

@ -239,17 +239,12 @@ class Network(base_layer.Layer):
outputs = outputs[0]
self._nested_outputs = outputs
self._nested_inputs = inputs
self._nested_inputs_are_flat_list = (
isinstance(self._nested_inputs, (list, tuple)) and
not any(nest.is_sequence(t) for t in self._nested_inputs))
self.inputs = nest.flatten(inputs)
self.outputs = nest.flatten(outputs)
# Models constructed with a single Tensor or list of Tensors can
# be called with a dict, where the keys of the dict are the names
# of the `Input` objects. Extra keys are ignored.
self._enable_dict_to_input_mapping = (
not nest.is_sequence(self._nested_inputs) or
(isinstance(self._nested_inputs, (list, tuple)) and
not any(nest.is_sequence(t) for t in self._nested_inputs)))
if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
base_layer_utils.create_keras_history(self._nested_outputs)
@ -917,16 +912,16 @@ class Network(base_layer.Layer):
def _flatten_to_reference_inputs(self, tensors):
"""Maps `tensors` to their respective `keras.Input`."""
if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
ref_inputs = self._nested_inputs
if not nest.is_sequence(ref_inputs):
ref_inputs = [self._nested_inputs]
# Flatten in the order the `Input`s were passed during Model construction.
return [tensors[inp._keras_history.layer.name] for inp in ref_inputs]
# Otherwise both self.inputs and tensors will already be in same order.
return nest.flatten(tensors)
if self._nested_inputs_are_flat_list and isinstance(tensors, dict):
# Backwards compat: Allows passing a dict to a Model constructed with a
# list. Matches dict keys to input names.
tensors = [
tensors[inp._keras_history.layer.name] for inp in self._nested_inputs
]
else:
# Otherwise both self.inputs and tensors will be flattened in same order.
tensors = nest.flatten(tensors)
return tensors
def _conform_to_reference_input(self, tensor, ref_input):
"""Set shape and dtype based on `keras.Input`s."""

View File

@ -1576,20 +1576,6 @@ class NestedNetworkTest(keras_parameterized.TestCase):
res = reversed_model({'a': a_val, 'b': b_val})
self.assertAllClose(self.evaluate(res), self.evaluate(b_val))
def test_dict_mapping_single_input(self):
b = input_layer_lib.Input(shape=(1,), name='b')
outputs = b * 2
model = training_lib.Model(b, outputs)
b_val = array_ops.ones((1, 1))
extra_val = array_ops.ones((1, 10))
inputs = {'a': extra_val, 'b': b_val}
res = model(inputs)
# Check that 'b' was used and 'a' was ignored.
self.assertEqual(res.shape.as_list(), [1, 1])
@combinations.generate(combinations.keras_mode_combinations())
class AddLossTest(keras_parameterized.TestCase):