Add a callback to the model cloning API.

PiperOrigin-RevId: 239097271
This commit is contained in:
Yunlu Li 2019-03-18 17:37:31 -07:00 committed by TensorFlower Gardener
parent c79154ef49
commit dc0137f16a
3 changed files with 50 additions and 39 deletions

View File

@ -632,11 +632,11 @@ def _build_network_on_replica(model, mode, inputs=None, targets=None):
# We rely on the internal methods to avoid exposing weight sharing in the # We rely on the internal methods to avoid exposing weight sharing in the
# public API. # public API.
if isinstance(model, sequential.Sequential): if isinstance(model, sequential.Sequential):
updated_model = models._clone_sequential_model(model, input_tensors=inputs, updated_model = models._clone_sequential_model(
share_weights=True) model, input_tensors=inputs, layer_fn=models.share_weights)
else: else:
updated_model = models._clone_functional_model(model, input_tensors=inputs, updated_model = models._clone_functional_model(
share_weights=True) model, input_tensors=inputs, layer_fn=models.share_weights)
# Recast all low precision outputs back to float32 since we only casted # Recast all low precision outputs back to float32 since we only casted
# the inputs to bfloat16 and not targets. This is done so that we can preserve # the inputs to bfloat16 and not targets. This is done so that we can preserve

View File

@ -45,28 +45,35 @@ model_from_config = model_config.model_from_config
model_from_yaml = model_config.model_from_yaml model_from_yaml = model_config.model_from_yaml
model_from_json = model_config.model_from_json model_from_json = model_config.model_from_json
# Callable used to clone a layer with weights preserved.
def share_weights(layer):
return layer
def _clone_layer(layer): def _clone_layer(layer):
return layer.__class__.from_config(layer.get_config()) return layer.__class__.from_config(layer.get_config())
def _clone_functional_model(model, input_tensors=None, share_weights=False): def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
"""Clone a functional `Model` instance. """Clone a functional `Model` instance.
Model cloning is similar to calling a model on new inputs, Model cloning is similar to calling a model on new inputs,
except that it creates new layers (and thus new weights) instead except that it creates new layers (and thus new weights) instead
of sharing the weights of the existing layers. of sharing the weights of the existing layers.
Input layers are always cloned.
Arguments: Arguments:
model: Instance of `Model`. model: Instance of `Model`.
input_tensors: optional list of input tensors input_tensors: optional list of input tensors
to build the model upon. If not provided, to build the model upon. If not provided,
placeholders will be created. placeholders will be created.
share_weights: flag to enable sharing of non-input layers between the layer_fn: callable to be applied on non-input layers in the model. By
cloned and original model. Note this still clones the input layers. default it clones the layer. Another example is to preserve the layer
This is required when we create a per-replica copy of the model with to share the weights. This is required when we create a per-replica
distribution strategy; we want the weights to be shared but still copy of the model with distribution strategy; we want the weights to
feed inputs separately so we create new input layers. be shared but still feed inputs separately so we create new input
layers.
Returns: Returns:
An instance of `Model` reproducing the behavior An instance of `Model` reproducing the behavior
@ -74,7 +81,8 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
using newly instantiated weights. using newly instantiated weights.
Raises: Raises:
ValueError: in case of invalid `model` argument value. ValueError: in case of invalid `model` argument value or `layer_fn`
argument value.
""" """
if not isinstance(model, Model): if not isinstance(model, Model):
raise ValueError('Expected `model` argument ' raise ValueError('Expected `model` argument '
@ -123,6 +131,9 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
for x, y in zip(model.inputs, input_tensors): for x, y in zip(model.inputs, input_tensors):
tensor_map[x] = y tensor_map[x] = y
if not callable(layer_fn):
raise ValueError('Expected `layer_fn` argument to be a callable.')
# Iterate over every node in the reference model, in depth order. # Iterate over every node in the reference model, in depth order.
depth_keys = list(model._nodes_by_depth.keys()) depth_keys = list(model._nodes_by_depth.keys())
depth_keys.sort(reverse=True) depth_keys.sort(reverse=True)
@ -134,11 +145,9 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
# Get or create layer. # Get or create layer.
if layer not in layer_map: if layer not in layer_map:
if not share_weights: new_layer = layer_fn(layer)
# Clone layer. layer_map[layer] = new_layer
new_layer = _clone_layer(layer) layer = new_layer
layer_map[layer] = new_layer
layer = new_layer
else: else:
# Reuse previously cloned layer. # Reuse previously cloned layer.
layer = layer_map[layer] layer = layer_map[layer]
@ -172,7 +181,7 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
return Model(input_tensors, output_tensors, name=model.name) return Model(input_tensors, output_tensors, name=model.name)
def _clone_sequential_model(model, input_tensors=None, share_weights=False): def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
"""Clone a `Sequential` model instance. """Clone a `Sequential` model instance.
Model cloning is similar to calling a model on new inputs, Model cloning is similar to calling a model on new inputs,
@ -184,11 +193,12 @@ def _clone_sequential_model(model, input_tensors=None, share_weights=False):
input_tensors: optional list of input tensors input_tensors: optional list of input tensors
to build the model upon. If not provided, to build the model upon. If not provided,
placeholders will be created. placeholders will be created.
share_weights: flag to enable sharing of non-input layers between the layer_fn: callable to be applied on non-input layers in the model. By
cloned and original model. Note this still clones the input layers. default it clones the layer. Another example is to preserve the layer
This is required when we create a per-replica copy of the model with to share the weights. This is required when we create a per-replica
distribution strategy; we want the weights to be shared but still copy of the model with distribution strategy; we want the weights to
feed inputs separately so we create new input layers. be shared but still feed inputs separately so we create new input
layers.
Returns: Returns:
An instance of `Sequential` reproducing the behavior An instance of `Sequential` reproducing the behavior
@ -196,35 +206,36 @@ def _clone_sequential_model(model, input_tensors=None, share_weights=False):
using newly instantiated weights. using newly instantiated weights.
Raises: Raises:
ValueError: in case of invalid `model` argument value. ValueError: in case of invalid `model` argument value or `layer_fn`
argument value.
""" """
if not isinstance(model, Sequential): if not isinstance(model, Sequential):
raise ValueError('Expected `model` argument ' raise ValueError('Expected `model` argument '
'to be a `Sequential` model instance, ' 'to be a `Sequential` model instance, '
'but got:', model) 'but got:', model)
if not callable(layer_fn):
raise ValueError('Expected `layer_fn` argument to be a callable.')
# Use model._layers to ensure that all layers are cloned. The model's layers # Use model._layers to ensure that all layers are cloned. The model's layers
# property will exclude the initial InputLayer (if it exists) in the model, # property will exclude the initial InputLayer (if it exists) in the model,
# resulting in a different Sequential model structure. # resulting in a different Sequential model structure.
if input_tensors is None: if input_tensors is None:
if share_weights: layers = []
# In preserve weights case we still want the input layers to be cloned. for layer in model._layers:
layers = [] if isinstance(layer, InputLayer):
for layer in model._layers: layers.append(_clone_layer(layer))
if isinstance(layer, InputLayer): else:
layers.append(_clone_layer(layer)) layers.append(layer_fn(layer))
else:
layers.append(layer)
else:
layers = [_clone_layer(layer) for layer in model._layers]
return Sequential(layers=layers, name=model.name) return Sequential(layers=layers, name=model.name)
else: else:
# If input tensors are provided, the original model's InputLayer is # If input tensors are provided, the original model's InputLayer is
# overwritten with a different InputLayer. # overwritten with a different InputLayer.
layers = [ layers = [
layer for layer in model._layers if not isinstance(layer, InputLayer)] layer_fn(layer)
if not share_weights: for layer in model._layers
layers = [_clone_layer(layer) for layer in layers] if not isinstance(layer, InputLayer)
]
if len(generic_utils.to_list(input_tensors)) != 1: if len(generic_utils.to_list(input_tensors)) != 1:
raise ValueError('To clone a `Sequential` model, we expect ' raise ValueError('To clone a `Sequential` model, we expect '
' at most one tensor ' ' at most one tensor '

View File

@ -104,7 +104,7 @@ class TestModelCloning(keras_parameterized.TestCase):
if share_weights: if share_weights:
clone_fn = functools.partial( clone_fn = functools.partial(
keras.models._clone_sequential_model, share_weights=True) keras.models._clone_sequential_model, layer_fn=models.share_weights)
else: else:
clone_fn = keras.models.clone_model clone_fn = keras.models.clone_model
@ -151,7 +151,7 @@ class TestModelCloning(keras_parameterized.TestCase):
def test_clone_functional_model(self, share_weights): def test_clone_functional_model(self, share_weights):
if share_weights: if share_weights:
clone_fn = functools.partial( clone_fn = functools.partial(
keras.models._clone_functional_model, share_weights=True) keras.models._clone_functional_model, layer_fn=models.share_weights)
else: else:
clone_fn = keras.models.clone_model clone_fn = keras.models.clone_model
@ -212,7 +212,7 @@ class TestModelCloning(keras_parameterized.TestCase):
def test_clone_functional_with_masking(self, share_weights): def test_clone_functional_with_masking(self, share_weights):
if share_weights: if share_weights:
clone_fn = functools.partial( clone_fn = functools.partial(
keras.models._clone_functional_model, share_weights=True) keras.models._clone_functional_model, layer_fn=models.share_weights)
else: else:
clone_fn = keras.models.clone_model clone_fn = keras.models.clone_model