Allow unused keys when a dict is passed to a single-input Functional API model.
Ensure that the key mapping to the name of the Input is used during Model execution. PiperOrigin-RevId: 300864855 Change-Id: I7871b7fb4df3023a5d3307e3e513dc68df991239
This commit is contained in:
parent
8ca8ad2aec
commit
88ca767048
@ -239,17 +239,12 @@ class Network(base_layer.Layer):
|
|||||||
outputs = outputs[0]
|
outputs = outputs[0]
|
||||||
self._nested_outputs = outputs
|
self._nested_outputs = outputs
|
||||||
self._nested_inputs = inputs
|
self._nested_inputs = inputs
|
||||||
|
self._nested_inputs_are_flat_list = (
|
||||||
|
isinstance(self._nested_inputs, (list, tuple)) and
|
||||||
|
not any(nest.is_sequence(t) for t in self._nested_inputs))
|
||||||
self.inputs = nest.flatten(inputs)
|
self.inputs = nest.flatten(inputs)
|
||||||
self.outputs = nest.flatten(outputs)
|
self.outputs = nest.flatten(outputs)
|
||||||
|
|
||||||
# Models constructed with a single Tensor or list of Tensors can
|
|
||||||
# be called with a dict, where the keys of the dict are the names
|
|
||||||
# of the `Input` objects. Extra keys are ignored.
|
|
||||||
self._enable_dict_to_input_mapping = (
|
|
||||||
not nest.is_sequence(self._nested_inputs) or
|
|
||||||
(isinstance(self._nested_inputs, (list, tuple)) and
|
|
||||||
not any(nest.is_sequence(t) for t in self._nested_inputs)))
|
|
||||||
|
|
||||||
if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
|
if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
|
||||||
base_layer_utils.create_keras_history(self._nested_outputs)
|
base_layer_utils.create_keras_history(self._nested_outputs)
|
||||||
|
|
||||||
@ -917,16 +912,16 @@ class Network(base_layer.Layer):
|
|||||||
|
|
||||||
def _flatten_to_reference_inputs(self, tensors):
|
def _flatten_to_reference_inputs(self, tensors):
|
||||||
"""Maps `tensors` to their respective `keras.Input`."""
|
"""Maps `tensors` to their respective `keras.Input`."""
|
||||||
if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
|
if self._nested_inputs_are_flat_list and isinstance(tensors, dict):
|
||||||
ref_inputs = self._nested_inputs
|
# Backwards compat: Allows passing a dict to a Model constructed with a
|
||||||
if not nest.is_sequence(ref_inputs):
|
# list. Matches dict keys to input names.
|
||||||
ref_inputs = [self._nested_inputs]
|
tensors = [
|
||||||
|
tensors[inp._keras_history.layer.name] for inp in self._nested_inputs
|
||||||
# Flatten in the order the `Input`s were passed during Model construction.
|
]
|
||||||
return [tensors[inp._keras_history.layer.name] for inp in ref_inputs]
|
else:
|
||||||
|
# Otherwise both self.inputs and tensors will be flattened in same order.
|
||||||
# Otherwise both self.inputs and tensors will already be in same order.
|
tensors = nest.flatten(tensors)
|
||||||
return nest.flatten(tensors)
|
return tensors
|
||||||
|
|
||||||
def _conform_to_reference_input(self, tensor, ref_input):
|
def _conform_to_reference_input(self, tensor, ref_input):
|
||||||
"""Set shape and dtype based on `keras.Input`s."""
|
"""Set shape and dtype based on `keras.Input`s."""
|
||||||
|
@ -1576,20 +1576,6 @@ class NestedNetworkTest(keras_parameterized.TestCase):
|
|||||||
res = reversed_model({'a': a_val, 'b': b_val})
|
res = reversed_model({'a': a_val, 'b': b_val})
|
||||||
self.assertAllClose(self.evaluate(res), self.evaluate(b_val))
|
self.assertAllClose(self.evaluate(res), self.evaluate(b_val))
|
||||||
|
|
||||||
def test_dict_mapping_single_input(self):
|
|
||||||
b = input_layer_lib.Input(shape=(1,), name='b')
|
|
||||||
outputs = b * 2
|
|
||||||
model = training_lib.Model(b, outputs)
|
|
||||||
|
|
||||||
b_val = array_ops.ones((1, 1))
|
|
||||||
extra_val = array_ops.ones((1, 10))
|
|
||||||
|
|
||||||
inputs = {'a': extra_val, 'b': b_val}
|
|
||||||
res = model(inputs)
|
|
||||||
|
|
||||||
# Check that 'b' was used and 'a' was ignored.
|
|
||||||
self.assertEqual(res.shape.as_list(), [1, 1])
|
|
||||||
|
|
||||||
|
|
||||||
@combinations.generate(combinations.keras_mode_combinations())
|
@combinations.generate(combinations.keras_mode_combinations())
|
||||||
class AddLossTest(keras_parameterized.TestCase):
|
class AddLossTest(keras_parameterized.TestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user