Add a callback to the model cloning API.
PiperOrigin-RevId: 239097271
commit dc0137f16a (parent c79154ef49)
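In short: the private cloning helpers lose their boolean `share_weights` flag and instead take a `layer_fn` callback that is applied to every non-input layer. A minimal before/after sketch (hypothetical caller code, not part of the commit; it assumes the internal `tensorflow.python.keras.models` module shown in the diff below):

import tensorflow as tf
from tensorflow.python.keras import models

# A small functional model to clone (illustrative only).
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs, outputs)

# Before this commit, weight sharing was a boolean flag:
#   clone = models._clone_functional_model(model, share_weights=True)

# After this commit, the caller passes a callable applied to each non-input
# layer: `models.share_weights` returns the layer unchanged (weights shared),
# while the default, `models._clone_layer`, rebuilds it from its config.
clone = models._clone_functional_model(model, layer_fn=models.share_weights)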
@@ -632,11 +632,11 @@ def _build_network_on_replica(model, mode, inputs=None, targets=None):
   # We rely on the internal methods to avoid having share_weights weights in the
   # public API.
   if isinstance(model, sequential.Sequential):
-    updated_model = models._clone_sequential_model(model, input_tensors=inputs,
-                                                   share_weights=True)
+    updated_model = models._clone_sequential_model(
+        model, input_tensors=inputs, layer_fn=models.share_weights)
   else:
-    updated_model = models._clone_functional_model(model, input_tensors=inputs,
-                                                   share_weights=True)
+    updated_model = models._clone_functional_model(
+        model, input_tensors=inputs, layer_fn=models.share_weights)
 
   # Recast all low precision outputs back to float32 since we only casted
   # the inputs to bfloat16 and not targets. This is done so that we can preserve
@@ -45,28 +45,35 @@ model_from_config = model_config.model_from_config
 model_from_yaml = model_config.model_from_yaml
 model_from_json = model_config.model_from_json
 
 
+# Callable used to clone a layer with weights preserved.
+def share_weights(layer):
+  return layer
+
+
 def _clone_layer(layer):
   return layer.__class__.from_config(layer.get_config())
 
 
-def _clone_functional_model(model, input_tensors=None, share_weights=False):
+def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
   """Clone a functional `Model` instance.
 
   Model cloning is similar to calling a model on new inputs,
   except that it creates new layers (and thus new weights) instead
   of sharing the weights of the existing layers.
 
+  Input layers are always cloned.
+
   Arguments:
       model: Instance of `Model`.
       input_tensors: optional list of input tensors
           to build the model upon. If not provided,
           placeholders will be created.
-      share_weights: flag to enable sharing of non-input layers between the
-          cloned and original model. Note this still clones the input layers.
-          This is required when we create a per-replica copy of the model with
-          distribution strategy; we want the weights to be shared but still
-          feed inputs separately so we create new input layers.
+      layer_fn: callable to be applied on non-input layers in the model. By
+          default it clones the layer. Another example is to preserve the layer
+          to share the weights. This is required when we create a per-replica
+          copy of the model with distribution strategy; we want the weights to
+          be shared but still feed inputs separately so we create new input
+          layers.
 
   Returns:
       An instance of `Model` reproducing the behavior
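The docstring above pins down the contract: `layer_fn` receives each non-input layer and returns the layer to place in the clone. Beyond the two built-ins (`_clone_layer` and `share_weights`), a caller could supply its own; a hedged sketch (the `clone_and_freeze` name and behavior are hypothetical, not part of this commit):

def clone_and_freeze(layer):
  """Example `layer_fn`: clone each non-input layer, then freeze it."""
  new_layer = layer.__class__.from_config(layer.get_config())
  new_layer.trainable = False  # fresh weights, excluded from training
  return new_layer

# Usage sketch:
#   frozen = models._clone_functional_model(model, layer_fn=clone_and_freeze)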
@@ -74,7 +81,8 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
       using newly instantiated weights.
 
   Raises:
-      ValueError: in case of invalid `model` argument value.
+      ValueError: in case of invalid `model` argument value or `layer_fn`
+          argument value.
   """
   if not isinstance(model, Model):
     raise ValueError('Expected `model` argument '
@@ -123,6 +131,9 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
     for x, y in zip(model.inputs, input_tensors):
       tensor_map[x] = y
 
+  if not callable(layer_fn):
+    raise ValueError('Expected `layer_fn` argument to be a callable.')
+
   # Iterated over every node in the reference model, in depth order.
   depth_keys = list(model._nodes_by_depth.keys())
   depth_keys.sort(reverse=True)
@@ -134,11 +145,9 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
 
       # Get or create layer.
       if layer not in layer_map:
-        if not share_weights:
-          # Clone layer.
-          new_layer = _clone_layer(layer)
-          layer_map[layer] = new_layer
-          layer = new_layer
+        new_layer = layer_fn(layer)
+        layer_map[layer] = new_layer
+        layer = new_layer
       else:
         # Reuse previously cloned layer.
         layer = layer_map[layer]
@@ -172,7 +181,7 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
   return Model(input_tensors, output_tensors, name=model.name)
 
 
-def _clone_sequential_model(model, input_tensors=None, share_weights=False):
+def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
   """Clone a `Sequential` model instance.
 
   Model cloning is similar to calling a model on new inputs,
@@ -184,11 +193,12 @@ def _clone_sequential_model(model, input_tensors=None, share_weights=False):
       input_tensors: optional list of input tensors
           to build the model upon. If not provided,
           placeholders will be created.
-      share_weights: flag to enable sharing of non-input layers between the
-          cloned and original model. Note this still clones the input layers.
-          This is required when we create a per-replica copy of the model with
-          distribution strategy; we want the weights to be shared but still
-          feed inputs separately so we create new input layers.
+      layer_fn: callable to be applied on non-input layers in the model. By
+          default it clones the layer. Another example is to preserve the layer
+          to share the weights. This is required when we create a per-replica
+          copy of the model with distribution strategy; we want the weights to
+          be shared but still feed inputs separately so we create new input
+          layers.
 
   Returns:
       An instance of `Sequential` reproducing the behavior
@@ -196,35 +206,36 @@ def _clone_sequential_model(model, input_tensors=None, share_weights=False):
       using newly instantiated weights.
 
   Raises:
-      ValueError: in case of invalid `model` argument value.
+      ValueError: in case of invalid `model` argument value or `layer_fn`
+          argument value.
   """
   if not isinstance(model, Sequential):
     raise ValueError('Expected `model` argument '
                      'to be a `Sequential` model instance, '
                      'but got:', model)
 
+  if not callable(layer_fn):
+    raise ValueError('Expected `layer_fn` argument to be a callable.')
+
   # Use model._layers to ensure that all layers are cloned. The model's layers
   # property will exclude the initial InputLayer (if it exists) in the model,
   # resulting in a different Sequential model structure.
   if input_tensors is None:
-    if share_weights:
-      # In preserve weights case we still want the input layers to be cloned.
-      layers = []
-      for layer in model._layers:
-        if isinstance(layer, InputLayer):
-          layers.append(_clone_layer(layer))
-        else:
-          layers.append(layer)
-    else:
-      layers = [_clone_layer(layer) for layer in model._layers]
+    layers = []
+    for layer in model._layers:
+      if isinstance(layer, InputLayer):
+        layers.append(_clone_layer(layer))
+      else:
+        layers.append(layer_fn(layer))
     return Sequential(layers=layers, name=model.name)
   else:
     # If input tensors are provided, the original model's InputLayer is
     # overwritten with a different InputLayer.
     layers = [
-        layer for layer in model._layers if not isinstance(layer, InputLayer)]
-    if not share_weights:
-      layers = [_clone_layer(layer) for layer in layers]
+        layer_fn(layer)
+        for layer in model._layers
+        if not isinstance(layer, InputLayer)
+    ]
     if len(generic_utils.to_list(input_tensors)) != 1:
       raise ValueError('To clone a `Sequential` model, we expect '
                        ' at most one tensor '
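The test changes below swap the flag for the callback via `functools.partial`. A quick hedged sketch (same internal module assumed, not from the commit) of what sharing means in practice:

import functools

import tensorflow as tf
from tensorflow.python.keras import models

clone_fn = functools.partial(
    models._clone_sequential_model, layer_fn=models.share_weights)

seq = tf.keras.Sequential([tf.keras.layers.Dense(3, input_shape=(5,))])
shared = clone_fn(seq)

# Non-input layers are reused as-is, so both models point at the same weight
# variables; only the InputLayer is cloned.
assert shared.layers[0].weights[0] is seq.layers[0].weights[0]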
@@ -104,7 +104,7 @@ class TestModelCloning(keras_parameterized.TestCase):
 
     if share_weights:
       clone_fn = functools.partial(
-          keras.models._clone_sequential_model, share_weights=True)
+          keras.models._clone_sequential_model, layer_fn=models.share_weights)
     else:
       clone_fn = keras.models.clone_model
 
@@ -151,7 +151,7 @@ class TestModelCloning(keras_parameterized.TestCase):
   def test_clone_functional_model(self, share_weights):
     if share_weights:
       clone_fn = functools.partial(
-          keras.models._clone_functional_model, share_weights=True)
+          keras.models._clone_functional_model, layer_fn=models.share_weights)
     else:
       clone_fn = keras.models.clone_model
 
@@ -212,7 +212,7 @@ class TestModelCloning(keras_parameterized.TestCase):
   def test_clone_functional_with_masking(self, share_weights):
     if share_weights:
       clone_fn = functools.partial(
-          keras.models._clone_functional_model, share_weights=True)
+          keras.models._clone_functional_model, layer_fn=models.share_weights)
     else:
       clone_fn = keras.models.clone_model
 