Add a callback to the model cloning API.

PiperOrigin-RevId: 239097271
Yunlu Li 2019-03-18 17:37:31 -07:00 committed by TensorFlower Gardener
parent c79154ef49
commit dc0137f16a
3 changed files with 50 additions and 39 deletions


@@ -632,11 +632,11 @@ def _build_network_on_replica(model, mode, inputs=None, targets=None):
   # We rely on the internal methods to avoid exposing share_weights in the
   # public API.
   if isinstance(model, sequential.Sequential):
-    updated_model = models._clone_sequential_model(model, input_tensors=inputs,
-                                                   share_weights=True)
+    updated_model = models._clone_sequential_model(
+        model, input_tensors=inputs, layer_fn=models.share_weights)
   else:
-    updated_model = models._clone_functional_model(model, input_tensors=inputs,
-                                                   share_weights=True)
+    updated_model = models._clone_functional_model(
+        model, input_tensors=inputs, layer_fn=models.share_weights)
   # Recast all low precision outputs back to float32 since we only cast
   # the inputs to bfloat16 and not targets. This is done so that we can preserve
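For context, the per-replica build keeps calling the same two private helpers; only the spelling of "share the weights" changes, from a boolean flag to the identity callback models.share_weights. A minimal sketch of the two behaviors against the private API shown in this diff, using a hypothetical toy model (the model and the 'head' layer name are illustrative, not from this commit):

import tensorflow as tf
from tensorflow.python.keras import models

# Hypothetical toy model to clone.
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(2, name='head')(inputs)
original = tf.keras.Model(inputs, outputs)

# layer_fn=models.share_weights is the identity, so non-input layers are
# reused and the clone shares the original's weight variables. Input
# layers are still cloned, which is what per-replica copies need.
shared = models._clone_functional_model(
    original, layer_fn=models.share_weights)
assert shared.get_layer('head') is original.get_layer('head')

# The default layer_fn (_clone_layer) rebuilds each layer from its
# config, so this clone gets fresh, newly initialized weights.
fresh = models._clone_functional_model(original)
assert fresh.get_layer('head') is not original.get_layer('head')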


@@ -45,28 +45,35 @@ model_from_config = model_config.model_from_config
 model_from_yaml = model_config.model_from_yaml
 model_from_json = model_config.model_from_json
 
 
+# Callable used to clone a layer with weights preserved.
+def share_weights(layer):
+  return layer
+
+
 def _clone_layer(layer):
   return layer.__class__.from_config(layer.get_config())
 
 
-def _clone_functional_model(model, input_tensors=None, share_weights=False):
+def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
   """Clone a functional `Model` instance.
 
   Model cloning is similar to calling a model on new inputs,
   except that it creates new layers (and thus new weights) instead
   of sharing the weights of the existing layers.
 
   Input layers are always cloned.
 
   Arguments:
       model: Instance of `Model`.
       input_tensors: optional list of input tensors
          to build the model upon. If not provided,
          placeholders will be created.
-      share_weights: flag to enable sharing of non-input layers between the
-          cloned and original model. Note this still clones the input layers.
-          This is required when we create a per-replica copy of the model with
-          distribution strategy; we want the weights to be shared but still
-          feed inputs separately so we create new input layers.
+      layer_fn: callable to be applied on non-input layers in the model. By
+          default it clones the layer. Another example is to preserve the layer
+          to share the weights. This is required when we create a per-replica
+          copy of the model with distribution strategy; we want the weights to
+          be shared but still feed inputs separately so we create new input
+          layers.
 
   Returns:
       An instance of `Model` reproducing the behavior
@@ -74,7 +81,8 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
       using newly instantiated weights.
 
   Raises:
-      ValueError: in case of invalid `model` argument value.
+      ValueError: in case of invalid `model` argument value or `layer_fn`
+          argument value.
   """
   if not isinstance(model, Model):
     raise ValueError('Expected `model` argument '
@@ -123,6 +131,9 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
     for x, y in zip(model.inputs, input_tensors):
       tensor_map[x] = y
 
+  if not callable(layer_fn):
+    raise ValueError('Expected `layer_fn` argument to be a callable.')
+
   # Iterate over every node in the reference model, in depth order.
   depth_keys = list(model._nodes_by_depth.keys())
   depth_keys.sort(reverse=True)
@@ -134,11 +145,9 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
       # Get or create layer.
       if layer not in layer_map:
-        if not share_weights:
-          # Clone layer.
-          new_layer = _clone_layer(layer)
-          layer_map[layer] = new_layer
-          layer = new_layer
+        new_layer = layer_fn(layer)
+        layer_map[layer] = new_layer
+        layer = new_layer
       else:
         # Reuse previously cloned layer.
         layer = layer_map[layer]
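Because the node loop now funnels every non-input layer through layer_fn, callers are no longer limited to "clone" or "share". As a hypothetical example (the clone_and_freeze helper below is illustrative, not part of this commit), a callback can rebuild each layer from its config, as the default does, and additionally freeze it:

def clone_and_freeze(layer):
  # Rebuild the layer from its config (the same mechanism as the
  # default _clone_layer), then mark the copy as non-trainable.
  new_layer = layer.__class__.from_config(layer.get_config())
  new_layer.trainable = False
  return new_layer

frozen = models._clone_functional_model(original, layer_fn=clone_and_freeze)

Anything that is not callable is now rejected up front by the callable(layer_fn) check, instead of failing somewhere inside the node loop.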
@@ -172,7 +181,7 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
   return Model(input_tensors, output_tensors, name=model.name)
 
 
-def _clone_sequential_model(model, input_tensors=None, share_weights=False):
+def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
   """Clone a `Sequential` model instance.
 
   Model cloning is similar to calling a model on new inputs,
@@ -184,11 +193,12 @@ def _clone_sequential_model(model, input_tensors=None, share_weights=False):
       input_tensors: optional list of input tensors
          to build the model upon. If not provided,
          placeholders will be created.
-      share_weights: flag to enable sharing of non-input layers between the
-          cloned and original model. Note this still clones the input layers.
-          This is required when we create a per-replica copy of the model with
-          distribution strategy; we want the weights to be shared but still
-          feed inputs separately so we create new input layers.
+      layer_fn: callable to be applied on non-input layers in the model. By
+          default it clones the layer. Another example is to preserve the layer
+          to share the weights. This is required when we create a per-replica
+          copy of the model with distribution strategy; we want the weights to
+          be shared but still feed inputs separately so we create new input
+          layers.
 
   Returns:
       An instance of `Sequential` reproducing the behavior
@@ -196,35 +206,36 @@ def _clone_sequential_model(model, input_tensors=None, share_weights=False):
       using newly instantiated weights.
 
   Raises:
-      ValueError: in case of invalid `model` argument value.
+      ValueError: in case of invalid `model` argument value or `layer_fn`
+          argument value.
   """
   if not isinstance(model, Sequential):
     raise ValueError('Expected `model` argument '
                      'to be a `Sequential` model instance, '
                      'but got:', model)
 
+  if not callable(layer_fn):
+    raise ValueError('Expected `layer_fn` argument to be a callable.')
+
   # Use model._layers to ensure that all layers are cloned. The model's layers
   # property will exclude the initial InputLayer (if it exists) in the model,
   # resulting in a different Sequential model structure.
   if input_tensors is None:
-    if share_weights:
-      # In preserve weights case we still want the input layers to be cloned.
-      layers = []
-      for layer in model._layers:
-        if isinstance(layer, InputLayer):
-          layers.append(_clone_layer(layer))
-        else:
-          layers.append(layer)
-    else:
-      layers = [_clone_layer(layer) for layer in model._layers]
+    layers = []
+    for layer in model._layers:
+      if isinstance(layer, InputLayer):
+        layers.append(_clone_layer(layer))
+      else:
+        layers.append(layer_fn(layer))
     return Sequential(layers=layers, name=model.name)
   else:
     # If input tensors are provided, the original model's InputLayer is
     # overwritten with a different InputLayer.
-    layers = [
-        layer for layer in model._layers if not isinstance(layer, InputLayer)]
-    if not share_weights:
-      layers = [_clone_layer(layer) for layer in layers]
+    layers = [
+        layer_fn(layer)
+        for layer in model._layers
+        if not isinstance(layer, InputLayer)
+    ]
     if len(generic_utils.to_list(input_tensors)) != 1:
       raise ValueError('To clone a `Sequential` model, we expect '
                        ' at most one tensor '
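The Sequential path collapses to a single code path in the same way: InputLayer instances are always rebuilt with _clone_layer, and every other layer goes through layer_fn. Continuing the hypothetical sketch from above with a small Sequential model:

seq = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(4,), name='body'),
    tf.keras.layers.Dense(1, name='out'),
])

# Identity callback: the Dense layers are reused, so weights are shared,
# while the InputLayer is still cloned.
seq_shared = models._clone_sequential_model(
    seq, layer_fn=models.share_weights)
assert seq_shared.get_layer('body') is seq.get_layer('body')

# The new validation rejects non-callables before any cloning happens:
# models._clone_sequential_model(seq, layer_fn='share')  # raises ValueError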


@@ -104,7 +104,7 @@ class TestModelCloning(keras_parameterized.TestCase):
 
     if share_weights:
       clone_fn = functools.partial(
-          keras.models._clone_sequential_model, share_weights=True)
+          keras.models._clone_sequential_model, layer_fn=models.share_weights)
     else:
       clone_fn = keras.models.clone_model
@@ -151,7 +151,7 @@ class TestModelCloning(keras_parameterized.TestCase):
   def test_clone_functional_model(self, share_weights):
     if share_weights:
       clone_fn = functools.partial(
-          keras.models._clone_functional_model, share_weights=True)
+          keras.models._clone_functional_model, layer_fn=models.share_weights)
     else:
       clone_fn = keras.models.clone_model
@@ -212,7 +212,7 @@ class TestModelCloning(keras_parameterized.TestCase):
   def test_clone_functional_with_masking(self, share_weights):
     if share_weights:
       clone_fn = functools.partial(
-          keras.models._clone_functional_model, share_weights=True)
+          keras.models._clone_functional_model, layer_fn=models.share_weights)
     else:
       clone_fn = keras.models.clone_model
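The tests keep their public share_weights boolean parameterization and translate it to the new private argument in one spot: functools.partial pre-binds layer_fn, so both branches expose the same clone_fn(model) call shape. A condensed sketch of that pattern outside the test harness, reusing the hypothetical `original` model from the first sketch (imports assumed to match the test file):

import functools
from tensorflow.python import keras
from tensorflow.python.keras import models

def make_clone_fn(share_weights):
  if share_weights:
    # Pre-bind the private keyword; callers never see layer_fn.
    return functools.partial(
        keras.models._clone_functional_model, layer_fn=models.share_weights)
  return keras.models.clone_model

clone = make_clone_fn(share_weights=True)(original)   # shares weights
clone = make_clone_fn(share_weights=False)(original)  # fresh weights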