diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 21df53ed59d..912c782deef 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -239,17 +239,12 @@ class Network(base_layer.Layer): outputs = outputs[0] self._nested_outputs = outputs self._nested_inputs = inputs + self._nested_inputs_are_flat_list = ( + isinstance(self._nested_inputs, (list, tuple)) and + not any(nest.is_sequence(t) for t in self._nested_inputs)) self.inputs = nest.flatten(inputs) self.outputs = nest.flatten(outputs) - # Models constructed with a single Tensor or list of Tensors can - # be called with a dict, where the keys of the dict are the names - # of the `Input` objects. Extra keys are ignored. - self._enable_dict_to_input_mapping = ( - not nest.is_sequence(self._nested_inputs) or - (isinstance(self._nested_inputs, (list, tuple)) and - not any(nest.is_sequence(t) for t in self._nested_inputs))) - if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs): base_layer_utils.create_keras_history(self._nested_outputs) @@ -917,16 +912,16 @@ class Network(base_layer.Layer): def _flatten_to_reference_inputs(self, tensors): """Maps `tensors` to their respective `keras.Input`.""" - if self._enable_dict_to_input_mapping and isinstance(tensors, dict): - ref_inputs = self._nested_inputs - if not nest.is_sequence(ref_inputs): - ref_inputs = [self._nested_inputs] - - # Flatten in the order the `Input`s were passed during Model construction. - return [tensors[inp._keras_history.layer.name] for inp in ref_inputs] - - # Otherwise both self.inputs and tensors will already be in same order. - return nest.flatten(tensors) + if self._nested_inputs_are_flat_list and isinstance(tensors, dict): + # Backwards compat: Allows passing a dict to a Model constructed with a + # list. Matches dict keys to input names. + tensors = [ + tensors[inp._keras_history.layer.name] for inp in self._nested_inputs + ] + else: + # Otherwise both self.inputs and tensors will be flattened in same order. + tensors = nest.flatten(tensors) + return tensors def _conform_to_reference_input(self, tensor, ref_input): """Set shape and dtype based on `keras.Input`s.""" diff --git a/tensorflow/python/keras/engine/network_test.py b/tensorflow/python/keras/engine/network_test.py index e227d08f595..46f4d951407 100644 --- a/tensorflow/python/keras/engine/network_test.py +++ b/tensorflow/python/keras/engine/network_test.py @@ -1576,20 +1576,6 @@ class NestedNetworkTest(keras_parameterized.TestCase): res = reversed_model({'a': a_val, 'b': b_val}) self.assertAllClose(self.evaluate(res), self.evaluate(b_val)) - def test_dict_mapping_single_input(self): - b = input_layer_lib.Input(shape=(1,), name='b') - outputs = b * 2 - model = training_lib.Model(b, outputs) - - b_val = array_ops.ones((1, 1)) - extra_val = array_ops.ones((1, 10)) - - inputs = {'a': extra_val, 'b': b_val} - res = model(inputs) - - # Check that 'b' was used and 'a' was ignored. - self.assertEqual(res.shape.as_list(), [1, 1]) - @combinations.generate(combinations.keras_mode_combinations()) class AddLossTest(keras_parameterized.TestCase):