Fixed a bug with Functional model serialization when a layer that produces keyword arguments of another layer had already been used to define a different functional model.
PiperOrigin-RevId: 324934074 Change-Id: Ibcee3958120e50bc72eb3a3c95410f8e5e1135ea
This commit is contained in:
parent
deefe8cafb
commit
e95a955af8
@ -1129,7 +1129,18 @@ def reconstruct_from_config(config, custom_objects=None, created_layers=None):
|
||||
tensor_index = t[2]
|
||||
|
||||
layer = layer_map[layer_name]
|
||||
node = layer._inbound_nodes[get_node_index(layer, node_index)]
|
||||
new_node_index = get_node_index(layer, node_index)
|
||||
if new_node_index is None:
|
||||
# The inbound node may not have been processed yet.
|
||||
# (This can happen e.g. if it depends on a different set
|
||||
# of inputs than those that have been processed already).
|
||||
# Raise an IndexError so that the current node puts itself
|
||||
# back on the unprocessed queue.
|
||||
# Caution: This may lead to infinite loops for malformed
|
||||
# network configurations! (or when there is a bug in
|
||||
# the network config loading code).
|
||||
raise IndexError
|
||||
node = layer._inbound_nodes[new_node_index]
|
||||
return nest.flatten(node.outputs)[tensor_index]
|
||||
return t
|
||||
|
||||
|
@ -998,8 +998,11 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
# Check that second input was correctly added to first.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
def test_call_kwarg_derived_from_keras_layer(self):
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
combinations.keras_mode_combinations(),
|
||||
combinations.combine(share_already_used_layer=[True, False])))
|
||||
def test_call_kwarg_derived_from_keras_layer(self, share_already_used_layer):
|
||||
|
||||
class MaybeAdd(layers.Layer):
|
||||
|
||||
@ -1008,9 +1011,26 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
return x1 + x2
|
||||
return x1
|
||||
|
||||
class IdentityLayer(layers.Layer):
|
||||
|
||||
def call(self, x):
|
||||
return x
|
||||
|
||||
input1 = input_layer_lib.Input(10)
|
||||
input2 = input_layer_lib.Input(10)
|
||||
outputs = MaybeAdd()(input1, x2=input2)
|
||||
identity_layer = IdentityLayer()
|
||||
|
||||
if share_already_used_layer:
|
||||
# We have had model serialization/deserialization break in the past:
|
||||
# when a layer was previously used to construct other functional models
|
||||
# and had a non-empty list of inbound nodes before being used to define
|
||||
# the model being serialized/deserialized.
|
||||
# (The serialization/deserialization was not correctly adjusting
|
||||
# the node_index serialization/deserialization).
|
||||
# So, we explicitly test this case.
|
||||
training_lib.Model([input1], identity_layer(input1))
|
||||
|
||||
outputs = MaybeAdd()(input1, x2=identity_layer(input2))
|
||||
model = training_lib.Model([input1, input2], outputs)
|
||||
model.compile(
|
||||
'sgd',
|
||||
@ -1024,7 +1044,11 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
model = training_lib.Model.from_config(
|
||||
model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
|
||||
model.get_config(),
|
||||
custom_objects={
|
||||
'MaybeAdd': MaybeAdd,
|
||||
'IdentityLayer': IdentityLayer
|
||||
})
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
@ -1107,10 +1131,18 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
TypeError, 'Layer double was passed non-JSON-serializable arguments.'):
|
||||
model.get_config()
|
||||
|
||||
@combinations.generate(combinations.times(
|
||||
combinations.keras_mode_combinations(),
|
||||
combinations.keras_tensor_combinations()))
|
||||
def test_call_kwarg_derived_from_keras_layer_and_first_arg_is_constant(self):
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
combinations.keras_mode_combinations(),
|
||||
combinations.keras_tensor_combinations(),
|
||||
combinations.combine(share_already_used_layer=[True, False])))
|
||||
def test_call_kwarg_derived_from_keras_layer_and_first_arg_is_constant(
|
||||
self, share_already_used_layer):
|
||||
|
||||
class IdentityLayer(layers.Layer):
|
||||
|
||||
def call(self, x):
|
||||
return x
|
||||
|
||||
class MaybeAdd(layers.Layer):
|
||||
|
||||
@ -1120,7 +1152,18 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
return x1
|
||||
|
||||
input2 = input_layer_lib.Input(10)
|
||||
outputs = MaybeAdd()(3., x2=input2)
|
||||
identity_layer = IdentityLayer()
|
||||
if share_already_used_layer:
|
||||
# We have had model serialization/deserialization break in the past:
|
||||
# when a layer was previously used to construct other functional models
|
||||
# and had a non-empty list of inbound nodes before being used to define
|
||||
# the model being serialized/deserialized.
|
||||
# (The serialization/deserialization was not correctly adjusting
|
||||
# the node_index serialization/deserialization).
|
||||
# So, we explicitly test this case.
|
||||
training_lib.Model([input2], identity_layer(input2))
|
||||
|
||||
outputs = MaybeAdd()(3., x2=identity_layer(input2))
|
||||
model = training_lib.Model([input2], outputs)
|
||||
model.compile(
|
||||
'sgd',
|
||||
@ -1134,7 +1177,11 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
model = training_lib.Model.from_config(
|
||||
model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
|
||||
model.get_config(),
|
||||
custom_objects={
|
||||
'MaybeAdd': MaybeAdd,
|
||||
'IdentityLayer': IdentityLayer
|
||||
})
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
|
@ -169,6 +169,23 @@ class Node(object):
|
||||
arguments.update(kwargs)
|
||||
kwargs = arguments
|
||||
|
||||
def _serialize_keras_tensor(t):
|
||||
"""Serializes a single Tensor passed to `call`."""
|
||||
if hasattr(t, '_keras_history'):
|
||||
kh = t._keras_history
|
||||
node_index = kh.node_index
|
||||
node_key = make_node_key(kh.layer.name, node_index)
|
||||
new_node_index = node_conversion_map.get(node_key, 0)
|
||||
return [kh.layer.name, new_node_index, kh.tensor_index]
|
||||
|
||||
if isinstance(t, np.ndarray):
|
||||
return t.tolist()
|
||||
|
||||
if isinstance(t, ops.Tensor):
|
||||
return backend.get_value(t).tolist()
|
||||
|
||||
return t
|
||||
|
||||
kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
|
||||
try:
|
||||
json.dumps(kwargs, default=json_utils.get_json_type)
|
||||
@ -273,18 +290,3 @@ class KerasHistory(
|
||||
|
||||
def is_keras_tensor(obj):
|
||||
return hasattr(obj, '_keras_history')
|
||||
|
||||
|
||||
def _serialize_keras_tensor(t):
|
||||
"""Serializes a single Tensor passed to `call`."""
|
||||
if hasattr(t, '_keras_history'):
|
||||
kh = t._keras_history
|
||||
return [kh.layer.name, kh.node_index, kh.tensor_index]
|
||||
|
||||
if isinstance(t, np.ndarray):
|
||||
return t.tolist()
|
||||
|
||||
if isinstance(t, ops.Tensor):
|
||||
return backend.get_value(t).tolist()
|
||||
|
||||
return t
|
||||
|
@ -206,6 +206,8 @@ def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
|
||||
ancillary_layers = [
|
||||
layer for layer in created_layers.values() if layer not in model.layers
|
||||
]
|
||||
# TODO(b/162887610): This may need to adjust the inbound node index if the
|
||||
# created layers had already been used to define other models.
|
||||
if ancillary_layers:
|
||||
new_nodes = nest.flatten([
|
||||
layer.inbound_nodes[1:]
|
||||
|
Loading…
Reference in New Issue
Block a user