From e6a0a9e1275da92621c0515438144f7dec9d623d Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Fri, 11 Sep 2020 10:52:56 -0700 Subject: [PATCH] Fix input mapping logic for nested dictionaries in Functional Model. PiperOrigin-RevId: 331180009 Change-Id: I7654c7e6ee6d8a0a486aaa58f1473f14e6ae6bc9 --- tensorflow/python/keras/engine/functional.py | 14 ++++++++---- .../python/keras/engine/functional_test.py | 22 +++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index e1399ba6777..b20c1a6160c 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -140,10 +140,16 @@ class Functional(training_lib.Model): # Models constructed with a single Tensor or list of Tensors can # be called with a dict, where the keys of the dict are the names # of the `Input` objects. Extra keys are ignored with warning. - self._enable_dict_to_input_mapping = ( - not nest.is_nested(self._nested_inputs) or - (isinstance(self._nested_inputs, (list, tuple, dict)) and - not any(nest.is_nested(t) for t in self._nested_inputs))) + if not nest.is_nested(self._nested_inputs): + self._enable_dict_to_input_mapping = True + elif (isinstance(self._nested_inputs, (list, tuple)) and + not any(nest.is_nested(t) for t in self._nested_inputs)): + self._enable_dict_to_input_mapping = True + elif (isinstance(self._nested_inputs, dict) and + not any(nest.is_nested(t) for t in self._nested_inputs.values())): + self._enable_dict_to_input_mapping = True + else: + self._enable_dict_to_input_mapping = False if not keras_tensor.keras_tensors_enabled(): if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs): diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py index 63e735810fc..8427517f235 100644 --- a/tensorflow/python/keras/engine/functional_test.py +++ b/tensorflow/python/keras/engine/functional_test.py @@ -1972,6 +1972,28 @@ class NestedNetworkTest(keras_parameterized.TestCase): # Check that 'b' was used and 'a' was ignored. self.assertEqual(res.shape.as_list(), [1, 1]) + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def test_nested_dict_mapping(self): + a = input_layer_lib.Input(shape=(1,), dtype='int32', name='a') + b = input_layer_lib.Input(shape=(1,), dtype='int32', name='b') + c = input_layer_lib.Input(shape=(1,), dtype='int32', name='c') + d = input_layer_lib.Input(shape=(1,), dtype='int32', name='d') + inputs = {'a': (a, b), 'c': (c, d)} + outputs = 1000 * a + 100 * b + 10 * c + d + model = training_lib.Model(inputs, outputs) + + a_val = array_ops.ones((1, 1), dtype='int32') + b_val = 2 * array_ops.ones((1, 1), dtype='int32') + c_val = 3 * array_ops.ones((1, 1), dtype='int32') + d_val = 4 * array_ops.ones((1, 1), dtype='int32') + + inputs_val = {'a': (a_val, b_val), 'c': (c_val, d_val)} + res = model(inputs_val) + + # Check that inputs were flattened in the correct order. + self.assertFalse(model._enable_dict_to_input_mapping) + self.assertEqual(self.evaluate(res), [1234]) + @combinations.generate(combinations.keras_mode_combinations()) class AddLossTest(keras_parameterized.TestCase):