Fixed a bug in Functional model serialization that occurred when a layer producing keyword arguments (kwargs) for another layer's call had already been used to define a different functional model.
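
In essence, the failing scenario looks like this (an illustrative sketch that mirrors the regression test added below; it uses the public tf.keras API rather than the internal test modules):

    import tensorflow as tf

    class IdentityLayer(tf.keras.layers.Layer):

      def call(self, x):
        return x

    class MaybeAdd(tf.keras.layers.Layer):

      def call(self, x1, x2=None):
        if x2 is not None:
          return x1 + x2
        return x1

    inp1 = tf.keras.Input(10)
    inp2 = tf.keras.Input(10)
    identity_layer = IdentityLayer()

    # Using the layer to define another model first leaves it with a
    # pre-existing inbound node.
    tf.keras.Model(inp1, identity_layer(inp1))

    # Its output is then passed as a kwarg (x2=...) to another layer's call.
    outputs = MaybeAdd()(inp1, x2=identity_layer(inp2))
    model = tf.keras.Model([inp1, inp2], outputs)

    # Round-tripping through get_config()/from_config() previously resolved
    # the kwarg tensor against the wrong inbound node index.
    restored = tf.keras.Model.from_config(
        model.get_config(),
        custom_objects={'MaybeAdd': MaybeAdd, 'IdentityLayer': IdentityLayer})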

PiperOrigin-RevId: 324934074
Change-Id: Ibcee3958120e50bc72eb3a3c95410f8e5e1135ea
Author: Tomer Kaftan, 2020-08-04 18:35:44 -07:00 (committed by TensorFlower Gardener)
Parent: deefe8cafb
Commit: e95a955af8
4 changed files with 88 additions and 26 deletions

@@ -1129,7 +1129,18 @@ def reconstruct_from_config(config, custom_objects=None, created_layers=None):
         tensor_index = t[2]
         layer = layer_map[layer_name]
-        node = layer._inbound_nodes[get_node_index(layer, node_index)]
+        new_node_index = get_node_index(layer, node_index)
+        if new_node_index is None:
+          # The inbound node may not have been processed yet
+          # (this can happen e.g. if it depends on a different set
+          # of inputs than those that have been processed already).
+          # Raise an IndexError so that the current node puts itself
+          # back on the unprocessed queue.
+          # Caution: this may lead to infinite loops for malformed
+          # network configurations (or when there is a bug in
+          # the network config loading code).
+          raise IndexError
+        node = layer._inbound_nodes[new_node_index]
         return nest.flatten(node.outputs)[tensor_index]
       return t
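
Note: the IndexError raised above is consumed by the node-processing loop in reconstruct_from_config, which keeps a queue of not-yet-buildable nodes. A simplified sketch of that retry pattern (paraphrased for illustration; the real loop differs in detail):

    def process_all_nodes(unprocessed_nodes, process_node):
      """Repeatedly processes nodes, re-queueing ones whose inputs are missing."""
      # unprocessed_nodes: dict mapping layer -> list of node configs.
      while unprocessed_nodes:
        for layer in list(unprocessed_nodes):
          remaining = []
          for node_data in unprocessed_nodes.pop(layer):
            try:
              process_node(layer, node_data)  # may hit the IndexError above
            except IndexError:
              # A node this one depends on has not been created yet; put it
              # back and retry on a later pass. As the comment above warns,
              # a malformed config can loop forever here.
              remaining.append(node_data)
          if remaining:
            unprocessed_nodes[layer] = remaining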

@@ -998,8 +998,11 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
     # Check that second input was correctly added to first.
     self.assertEqual(history.history['loss'][0], 0.0)
-  @combinations.generate(combinations.keras_mode_combinations())
-  def test_call_kwarg_derived_from_keras_layer(self):
+  @combinations.generate(
+      combinations.times(
+          combinations.keras_mode_combinations(),
+          combinations.combine(share_already_used_layer=[True, False])))
+  def test_call_kwarg_derived_from_keras_layer(self, share_already_used_layer):
     class MaybeAdd(layers.Layer):
@@ -1008,9 +1011,26 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
           return x1 + x2
         return x1
+    class IdentityLayer(layers.Layer):
+
+      def call(self, x):
+        return x
     input1 = input_layer_lib.Input(10)
     input2 = input_layer_lib.Input(10)
-    outputs = MaybeAdd()(input1, x2=input2)
+    identity_layer = IdentityLayer()
+    if share_already_used_layer:
+      # Model serialization/deserialization has broken in the past when a
+      # layer was previously used to construct other functional models, and
+      # therefore had a non-empty list of inbound nodes before being used to
+      # define the model being serialized/deserialized. (The config code was
+      # not correctly remapping the node_index in that case.)
+      # So, we explicitly test this case.
+      training_lib.Model([input1], identity_layer(input1))
+    outputs = MaybeAdd()(input1, x2=identity_layer(input2))
     model = training_lib.Model([input1, input2], outputs)
     model.compile(
         'sgd',
@@ -1024,7 +1044,11 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
     self.assertEqual(history.history['loss'][0], 0.0)
     model = training_lib.Model.from_config(
-        model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
+        model.get_config(),
+        custom_objects={
+            'MaybeAdd': MaybeAdd,
+            'IdentityLayer': IdentityLayer
+        })
     model.compile(
         'sgd',
         'mse',
@@ -1107,10 +1131,18 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
         TypeError, 'Layer double was passed non-JSON-serializable arguments.'):
       model.get_config()
-  @combinations.generate(combinations.times(
-      combinations.keras_mode_combinations(),
-      combinations.keras_tensor_combinations()))
-  def test_call_kwarg_derived_from_keras_layer_and_first_arg_is_constant(self):
+  @combinations.generate(
+      combinations.times(
+          combinations.keras_mode_combinations(),
+          combinations.keras_tensor_combinations(),
+          combinations.combine(share_already_used_layer=[True, False])))
+  def test_call_kwarg_derived_from_keras_layer_and_first_arg_is_constant(
+      self, share_already_used_layer):
+    class IdentityLayer(layers.Layer):
+
+      def call(self, x):
+        return x
     class MaybeAdd(layers.Layer):
@@ -1120,7 +1152,18 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
         return x1
     input2 = input_layer_lib.Input(10)
-    outputs = MaybeAdd()(3., x2=input2)
+    identity_layer = IdentityLayer()
+    if share_already_used_layer:
+      # Model serialization/deserialization has broken in the past when a
+      # layer was previously used to construct other functional models, and
+      # therefore had a non-empty list of inbound nodes before being used to
+      # define the model being serialized/deserialized. (The config code was
+      # not correctly remapping the node_index in that case.)
+      # So, we explicitly test this case.
+      training_lib.Model([input2], identity_layer(input2))
+    outputs = MaybeAdd()(3., x2=identity_layer(input2))
     model = training_lib.Model([input2], outputs)
     model.compile(
         'sgd',
@@ -1134,7 +1177,11 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
     self.assertEqual(history.history['loss'][0], 0.0)
     model = training_lib.Model.from_config(
-        model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
+        model.get_config(),
+        custom_objects={
+            'MaybeAdd': MaybeAdd,
+            'IdentityLayer': IdentityLayer
+        })
     model.compile(
         'sgd',
         'mse',

@@ -169,6 +169,23 @@ class Node(object):
     arguments.update(kwargs)
     kwargs = arguments
+    def _serialize_keras_tensor(t):
+      """Serializes a single Tensor passed to `call`."""
+      if hasattr(t, '_keras_history'):
+        kh = t._keras_history
+        node_index = kh.node_index
+        node_key = make_node_key(kh.layer.name, node_index)
+        new_node_index = node_conversion_map.get(node_key, 0)
+        return [kh.layer.name, new_node_index, kh.tensor_index]
+      if isinstance(t, np.ndarray):
+        return t.tolist()
+      if isinstance(t, ops.Tensor):
+        return backend.get_value(t).tolist()
+      return t
+    kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
     try:
       json.dumps(kwargs, default=json_utils.get_json_type)
@@ -273,18 +290,3 @@ class KerasHistory(
 def is_keras_tensor(obj):
   return hasattr(obj, '_keras_history')
-def _serialize_keras_tensor(t):
-  """Serializes a single Tensor passed to `call`."""
-  if hasattr(t, '_keras_history'):
-    kh = t._keras_history
-    return [kh.layer.name, kh.node_index, kh.tensor_index]
-  if isinstance(t, np.ndarray):
-    return t.tolist()
-  if isinstance(t, ops.Tensor):
-    return backend.get_value(t).tolist()
-  return t
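
Note on the change above: a model config only serializes the inbound nodes that belong to the model being saved, so a layer reused across models can have a different node index inside the config than in its global _inbound_nodes list. The node_conversion_map used above encodes that remapping; roughly, it is built like this (a paraphrased sketch, not the exact Keras code; model_network_nodes stands in for the model's own set of node keys):

    def build_node_conversion_map(model, make_node_key, model_network_nodes):
      """Sketch: map global inbound-node indices to config-local indices."""
      node_conversion_map = {}
      for layer in model.layers:
        kept_nodes = 0
        for original_node_index in range(len(layer._inbound_nodes)):
          node_key = make_node_key(layer.name, original_node_index)
          if node_key in model_network_nodes:
            # The node belongs to this model: remap its global index to its
            # position among the nodes this config actually serializes.
            node_conversion_map[node_key] = kept_nodes
            kept_nodes += 1
      return node_conversion_map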

@@ -206,6 +206,8 @@ def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
   ancillary_layers = [
       layer for layer in created_layers.values() if layer not in model.layers
   ]
+  # TODO(b/162887610): This may need to adjust the inbound node index if the
+  # created layers had already been used to define other models.
   if ancillary_layers:
     new_nodes = nest.flatten([
         layer.inbound_nodes[1:]