Fix input mapping logic for nested dictionaries in Functional Model.

PiperOrigin-RevId: 331180009
Change-Id: I7654c7e6ee6d8a0a486aaa58f1473f14e6ae6bc9
This commit is contained in:
Thomas O'Malley 2020-09-11 10:52:56 -07:00 committed by TensorFlower Gardener
parent 7d35f9e0ee
commit e6a0a9e127
2 changed files with 32 additions and 4 deletions

View File

@ -140,10 +140,16 @@ class Functional(training_lib.Model):
# Models constructed with a single Tensor or list of Tensors can
# be called with a dict, where the keys of the dict are the names
# of the `Input` objects. Extra keys are ignored with warning.
self._enable_dict_to_input_mapping = (
not nest.is_nested(self._nested_inputs) or
(isinstance(self._nested_inputs, (list, tuple, dict)) and
not any(nest.is_nested(t) for t in self._nested_inputs)))
if not nest.is_nested(self._nested_inputs):
self._enable_dict_to_input_mapping = True
elif (isinstance(self._nested_inputs, (list, tuple)) and
not any(nest.is_nested(t) for t in self._nested_inputs)):
self._enable_dict_to_input_mapping = True
elif (isinstance(self._nested_inputs, dict) and
not any(nest.is_nested(t) for t in self._nested_inputs.values())):
self._enable_dict_to_input_mapping = True
else:
self._enable_dict_to_input_mapping = False
if not keras_tensor.keras_tensors_enabled():
if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):

View File

@ -1972,6 +1972,28 @@ class NestedNetworkTest(keras_parameterized.TestCase):
# Check that 'b' was used and 'a' was ignored.
self.assertEqual(res.shape.as_list(), [1, 1])
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_nested_dict_mapping(self):
a = input_layer_lib.Input(shape=(1,), dtype='int32', name='a')
b = input_layer_lib.Input(shape=(1,), dtype='int32', name='b')
c = input_layer_lib.Input(shape=(1,), dtype='int32', name='c')
d = input_layer_lib.Input(shape=(1,), dtype='int32', name='d')
inputs = {'a': (a, b), 'c': (c, d)}
outputs = 1000 * a + 100 * b + 10 * c + d
model = training_lib.Model(inputs, outputs)
a_val = array_ops.ones((1, 1), dtype='int32')
b_val = 2 * array_ops.ones((1, 1), dtype='int32')
c_val = 3 * array_ops.ones((1, 1), dtype='int32')
d_val = 4 * array_ops.ones((1, 1), dtype='int32')
inputs_val = {'a': (a_val, b_val), 'c': (c_val, d_val)}
res = model(inputs_val)
# Check that inputs were flattened in the correct order.
self.assertFalse(model._enable_dict_to_input_mapping)
self.assertEqual(self.evaluate(res), [1234])
@combinations.generate(combinations.keras_mode_combinations())
class AddLossTest(keras_parameterized.TestCase):