Add a callback to the model cloning API.
PiperOrigin-RevId: 239097271
This commit is contained in:
parent
c79154ef49
commit
dc0137f16a
@ -632,11 +632,11 @@ def _build_network_on_replica(model, mode, inputs=None, targets=None):
|
||||
# We rely on the internal methods to avoid having share_weights weights in the
|
||||
# public API.
|
||||
if isinstance(model, sequential.Sequential):
|
||||
updated_model = models._clone_sequential_model(model, input_tensors=inputs,
|
||||
share_weights=True)
|
||||
updated_model = models._clone_sequential_model(
|
||||
model, input_tensors=inputs, layer_fn=models.share_weights)
|
||||
else:
|
||||
updated_model = models._clone_functional_model(model, input_tensors=inputs,
|
||||
share_weights=True)
|
||||
updated_model = models._clone_functional_model(
|
||||
model, input_tensors=inputs, layer_fn=models.share_weights)
|
||||
|
||||
# Recast all low precision outputs back to float32 since we only casted
|
||||
# the inputs to bfloat16 and not targets. This is done so that we can preserve
|
||||
|
@ -45,28 +45,35 @@ model_from_config = model_config.model_from_config
|
||||
model_from_yaml = model_config.model_from_yaml
|
||||
model_from_json = model_config.model_from_json
|
||||
|
||||
# Callable used to clone a layer with weights preserved.
|
||||
def share_weights(layer):
|
||||
return layer
|
||||
|
||||
|
||||
def _clone_layer(layer):
|
||||
return layer.__class__.from_config(layer.get_config())
|
||||
|
||||
|
||||
def _clone_functional_model(model, input_tensors=None, share_weights=False):
|
||||
def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
|
||||
"""Clone a functional `Model` instance.
|
||||
|
||||
Model cloning is similar to calling a model on new inputs,
|
||||
except that it creates new layers (and thus new weights) instead
|
||||
of sharing the weights of the existing layers.
|
||||
|
||||
Input layers are always cloned.
|
||||
|
||||
Arguments:
|
||||
model: Instance of `Model`.
|
||||
input_tensors: optional list of input tensors
|
||||
to build the model upon. If not provided,
|
||||
placeholders will be created.
|
||||
share_weights: flag to enable sharing of non-input layers between the
|
||||
cloned and original model. Note this still clones the input layers.
|
||||
This is required when we create a per-replica copy of the model with
|
||||
distribution strategy; we want the weights to be shared but still
|
||||
feed inputs separately so we create new input layers.
|
||||
layer_fn: callable to be applied on non-input layers in the model. By
|
||||
default it clones the layer. Another example is to preserve the layer
|
||||
to share the weights. This is required when we create a per-replica
|
||||
copy of the model with distribution strategy; we want the weights to
|
||||
be shared but still feed inputs separately so we create new input
|
||||
layers.
|
||||
|
||||
Returns:
|
||||
An instance of `Model` reproducing the behavior
|
||||
@ -74,7 +81,8 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
|
||||
using newly instantiated weights.
|
||||
|
||||
Raises:
|
||||
ValueError: in case of invalid `model` argument value.
|
||||
ValueError: in case of invalid `model` argument value or `layer_fn`
|
||||
argument value.
|
||||
"""
|
||||
if not isinstance(model, Model):
|
||||
raise ValueError('Expected `model` argument '
|
||||
@ -123,6 +131,9 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
|
||||
for x, y in zip(model.inputs, input_tensors):
|
||||
tensor_map[x] = y
|
||||
|
||||
if not callable(layer_fn):
|
||||
raise ValueError('Expected `layer_fn` argument to be a callable.')
|
||||
|
||||
# Iterated over every node in the reference model, in depth order.
|
||||
depth_keys = list(model._nodes_by_depth.keys())
|
||||
depth_keys.sort(reverse=True)
|
||||
@ -134,11 +145,9 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
|
||||
|
||||
# Get or create layer.
|
||||
if layer not in layer_map:
|
||||
if not share_weights:
|
||||
# Clone layer.
|
||||
new_layer = _clone_layer(layer)
|
||||
layer_map[layer] = new_layer
|
||||
layer = new_layer
|
||||
new_layer = layer_fn(layer)
|
||||
layer_map[layer] = new_layer
|
||||
layer = new_layer
|
||||
else:
|
||||
# Reuse previously cloned layer.
|
||||
layer = layer_map[layer]
|
||||
@ -172,7 +181,7 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
|
||||
return Model(input_tensors, output_tensors, name=model.name)
|
||||
|
||||
|
||||
def _clone_sequential_model(model, input_tensors=None, share_weights=False):
|
||||
def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
|
||||
"""Clone a `Sequential` model instance.
|
||||
|
||||
Model cloning is similar to calling a model on new inputs,
|
||||
@ -184,11 +193,12 @@ def _clone_sequential_model(model, input_tensors=None, share_weights=False):
|
||||
input_tensors: optional list of input tensors
|
||||
to build the model upon. If not provided,
|
||||
placeholders will be created.
|
||||
share_weights: flag to enable sharing of non-input layers between the
|
||||
cloned and original model. Note this still clones the input layers.
|
||||
This is required when we create a per-replica copy of the model with
|
||||
distribution strategy; we want the weights to be shared but still
|
||||
feed inputs separately so we create new input layers.
|
||||
layer_fn: callable to be applied on non-input layers in the model. By
|
||||
default it clones the layer. Another example is to preserve the layer
|
||||
to share the weights. This is required when we create a per-replica
|
||||
copy of the model with distribution strategy; we want the weights to
|
||||
be shared but still feed inputs separately so we create new input
|
||||
layers.
|
||||
|
||||
Returns:
|
||||
An instance of `Sequential` reproducing the behavior
|
||||
@ -196,35 +206,36 @@ def _clone_sequential_model(model, input_tensors=None, share_weights=False):
|
||||
using newly instantiated weights.
|
||||
|
||||
Raises:
|
||||
ValueError: in case of invalid `model` argument value.
|
||||
ValueError: in case of invalid `model` argument value or `layer_fn`
|
||||
argument value.
|
||||
"""
|
||||
if not isinstance(model, Sequential):
|
||||
raise ValueError('Expected `model` argument '
|
||||
'to be a `Sequential` model instance, '
|
||||
'but got:', model)
|
||||
|
||||
if not callable(layer_fn):
|
||||
raise ValueError('Expected `layer_fn` argument to be a callable.')
|
||||
|
||||
# Use model._layers to ensure that all layers are cloned. The model's layers
|
||||
# property will exclude the initial InputLayer (if it exists) in the model,
|
||||
# resulting in a different Sequential model structure.
|
||||
if input_tensors is None:
|
||||
if share_weights:
|
||||
# In preserve weights case we still want the input layers to be cloned.
|
||||
layers = []
|
||||
for layer in model._layers:
|
||||
if isinstance(layer, InputLayer):
|
||||
layers.append(_clone_layer(layer))
|
||||
else:
|
||||
layers.append(layer)
|
||||
else:
|
||||
layers = [_clone_layer(layer) for layer in model._layers]
|
||||
layers = []
|
||||
for layer in model._layers:
|
||||
if isinstance(layer, InputLayer):
|
||||
layers.append(_clone_layer(layer))
|
||||
else:
|
||||
layers.append(layer_fn(layer))
|
||||
return Sequential(layers=layers, name=model.name)
|
||||
else:
|
||||
# If input tensors are provided, the original model's InputLayer is
|
||||
# overwritten with a different InputLayer.
|
||||
layers = [
|
||||
layer for layer in model._layers if not isinstance(layer, InputLayer)]
|
||||
if not share_weights:
|
||||
layers = [_clone_layer(layer) for layer in layers]
|
||||
layer_fn(layer)
|
||||
for layer in model._layers
|
||||
if not isinstance(layer, InputLayer)
|
||||
]
|
||||
if len(generic_utils.to_list(input_tensors)) != 1:
|
||||
raise ValueError('To clone a `Sequential` model, we expect '
|
||||
' at most one tensor '
|
||||
|
@ -104,7 +104,7 @@ class TestModelCloning(keras_parameterized.TestCase):
|
||||
|
||||
if share_weights:
|
||||
clone_fn = functools.partial(
|
||||
keras.models._clone_sequential_model, share_weights=True)
|
||||
keras.models._clone_sequential_model, layer_fn=models.share_weights)
|
||||
else:
|
||||
clone_fn = keras.models.clone_model
|
||||
|
||||
@ -151,7 +151,7 @@ class TestModelCloning(keras_parameterized.TestCase):
|
||||
def test_clone_functional_model(self, share_weights):
|
||||
if share_weights:
|
||||
clone_fn = functools.partial(
|
||||
keras.models._clone_functional_model, share_weights=True)
|
||||
keras.models._clone_functional_model, layer_fn=models.share_weights)
|
||||
else:
|
||||
clone_fn = keras.models.clone_model
|
||||
|
||||
@ -212,7 +212,7 @@ class TestModelCloning(keras_parameterized.TestCase):
|
||||
def test_clone_functional_with_masking(self, share_weights):
|
||||
if share_weights:
|
||||
clone_fn = functools.partial(
|
||||
keras.models._clone_functional_model, share_weights=True)
|
||||
keras.models._clone_functional_model, layer_fn=models.share_weights)
|
||||
else:
|
||||
clone_fn = keras.models.clone_model
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user