Allow unused keys when a dict is passed to a single-input Functional API model.
Ensure that the key mapping to the name of the Input is used during Model execution. PiperOrigin-RevId: 300864855 Change-Id: I7871b7fb4df3023a5d3307e3e513dc68df991239
This commit is contained in:
parent
8ca8ad2aec
commit
88ca767048
@ -239,17 +239,12 @@ class Network(base_layer.Layer):
|
||||
outputs = outputs[0]
|
||||
self._nested_outputs = outputs
|
||||
self._nested_inputs = inputs
|
||||
self._nested_inputs_are_flat_list = (
|
||||
isinstance(self._nested_inputs, (list, tuple)) and
|
||||
not any(nest.is_sequence(t) for t in self._nested_inputs))
|
||||
self.inputs = nest.flatten(inputs)
|
||||
self.outputs = nest.flatten(outputs)
|
||||
|
||||
# Models constructed with a single Tensor or list of Tensors can
|
||||
# be called with a dict, where the keys of the dict are the names
|
||||
# of the `Input` objects. Extra keys are ignored.
|
||||
self._enable_dict_to_input_mapping = (
|
||||
not nest.is_sequence(self._nested_inputs) or
|
||||
(isinstance(self._nested_inputs, (list, tuple)) and
|
||||
not any(nest.is_sequence(t) for t in self._nested_inputs)))
|
||||
|
||||
if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
|
||||
base_layer_utils.create_keras_history(self._nested_outputs)
|
||||
|
||||
@ -917,16 +912,16 @@ class Network(base_layer.Layer):
|
||||
|
||||
def _flatten_to_reference_inputs(self, tensors):
|
||||
"""Maps `tensors` to their respective `keras.Input`."""
|
||||
if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
|
||||
ref_inputs = self._nested_inputs
|
||||
if not nest.is_sequence(ref_inputs):
|
||||
ref_inputs = [self._nested_inputs]
|
||||
|
||||
# Flatten in the order the `Input`s were passed during Model construction.
|
||||
return [tensors[inp._keras_history.layer.name] for inp in ref_inputs]
|
||||
|
||||
# Otherwise both self.inputs and tensors will already be in same order.
|
||||
return nest.flatten(tensors)
|
||||
if self._nested_inputs_are_flat_list and isinstance(tensors, dict):
|
||||
# Backwards compat: Allows passing a dict to a Model constructed with a
|
||||
# list. Matches dict keys to input names.
|
||||
tensors = [
|
||||
tensors[inp._keras_history.layer.name] for inp in self._nested_inputs
|
||||
]
|
||||
else:
|
||||
# Otherwise both self.inputs and tensors will be flattened in same order.
|
||||
tensors = nest.flatten(tensors)
|
||||
return tensors
|
||||
|
||||
def _conform_to_reference_input(self, tensor, ref_input):
|
||||
"""Set shape and dtype based on `keras.Input`s."""
|
||||
|
@ -1576,20 +1576,6 @@ class NestedNetworkTest(keras_parameterized.TestCase):
|
||||
res = reversed_model({'a': a_val, 'b': b_val})
|
||||
self.assertAllClose(self.evaluate(res), self.evaluate(b_val))
|
||||
|
||||
def test_dict_mapping_single_input(self):
|
||||
b = input_layer_lib.Input(shape=(1,), name='b')
|
||||
outputs = b * 2
|
||||
model = training_lib.Model(b, outputs)
|
||||
|
||||
b_val = array_ops.ones((1, 1))
|
||||
extra_val = array_ops.ones((1, 10))
|
||||
|
||||
inputs = {'a': extra_val, 'b': b_val}
|
||||
res = model(inputs)
|
||||
|
||||
# Check that 'b' was used and 'a' was ignored.
|
||||
self.assertEqual(res.shape.as_list(), [1, 1])
|
||||
|
||||
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
class AddLossTest(keras_parameterized.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user