Fix input mapping logic for nested dictionaries in Functional Model.
PiperOrigin-RevId: 331180009 Change-Id: I7654c7e6ee6d8a0a486aaa58f1473f14e6ae6bc9
This commit is contained in:
parent
7d35f9e0ee
commit
e6a0a9e127
@ -140,10 +140,16 @@ class Functional(training_lib.Model):
|
||||
# Models constructed with a single Tensor or list of Tensors can
|
||||
# be called with a dict, where the keys of the dict are the names
|
||||
# of the `Input` objects. Extra keys are ignored with warning.
|
||||
self._enable_dict_to_input_mapping = (
|
||||
not nest.is_nested(self._nested_inputs) or
|
||||
(isinstance(self._nested_inputs, (list, tuple, dict)) and
|
||||
not any(nest.is_nested(t) for t in self._nested_inputs)))
|
||||
if not nest.is_nested(self._nested_inputs):
|
||||
self._enable_dict_to_input_mapping = True
|
||||
elif (isinstance(self._nested_inputs, (list, tuple)) and
|
||||
not any(nest.is_nested(t) for t in self._nested_inputs)):
|
||||
self._enable_dict_to_input_mapping = True
|
||||
elif (isinstance(self._nested_inputs, dict) and
|
||||
not any(nest.is_nested(t) for t in self._nested_inputs.values())):
|
||||
self._enable_dict_to_input_mapping = True
|
||||
else:
|
||||
self._enable_dict_to_input_mapping = False
|
||||
|
||||
if not keras_tensor.keras_tensors_enabled():
|
||||
if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
|
||||
|
@ -1972,6 +1972,28 @@ class NestedNetworkTest(keras_parameterized.TestCase):
|
||||
# Check that 'b' was used and 'a' was ignored.
|
||||
self.assertEqual(res.shape.as_list(), [1, 1])
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_nested_dict_mapping(self):
|
||||
a = input_layer_lib.Input(shape=(1,), dtype='int32', name='a')
|
||||
b = input_layer_lib.Input(shape=(1,), dtype='int32', name='b')
|
||||
c = input_layer_lib.Input(shape=(1,), dtype='int32', name='c')
|
||||
d = input_layer_lib.Input(shape=(1,), dtype='int32', name='d')
|
||||
inputs = {'a': (a, b), 'c': (c, d)}
|
||||
outputs = 1000 * a + 100 * b + 10 * c + d
|
||||
model = training_lib.Model(inputs, outputs)
|
||||
|
||||
a_val = array_ops.ones((1, 1), dtype='int32')
|
||||
b_val = 2 * array_ops.ones((1, 1), dtype='int32')
|
||||
c_val = 3 * array_ops.ones((1, 1), dtype='int32')
|
||||
d_val = 4 * array_ops.ones((1, 1), dtype='int32')
|
||||
|
||||
inputs_val = {'a': (a_val, b_val), 'c': (c_val, d_val)}
|
||||
res = model(inputs_val)
|
||||
|
||||
# Check that inputs were flattened in the correct order.
|
||||
self.assertFalse(model._enable_dict_to_input_mapping)
|
||||
self.assertEqual(self.evaluate(res), [1234])
|
||||
|
||||
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
class AddLossTest(keras_parameterized.TestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user