diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index 79836f80179..0deeadf1ca4 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -632,11 +632,11 @@ def _build_network_on_replica(model, mode, inputs=None, targets=None):
   # We rely on the internal methods to avoid having share_weights weights in the
   # public API.
   if isinstance(model, sequential.Sequential):
-    updated_model = models._clone_sequential_model(model, input_tensors=inputs,
-                                                   share_weights=True)
+    updated_model = models._clone_sequential_model(
+        model, input_tensors=inputs, layer_fn=models.share_weights)
   else:
-    updated_model = models._clone_functional_model(model, input_tensors=inputs,
-                                                   share_weights=True)
+    updated_model = models._clone_functional_model(
+        model, input_tensors=inputs, layer_fn=models.share_weights)
 
   # Recast all low precision outputs back to float32 since we only casted
   # the inputs to bfloat16 and not targets. This is done so that we can preserve
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index e4371c2a93d..80cb17a0f99 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -45,28 +45,35 @@ model_from_config = model_config.model_from_config
 model_from_yaml = model_config.model_from_yaml
 model_from_json = model_config.model_from_json
 
 
+# Callable used as a `layer_fn` to preserve a layer and share its weights.
+def share_weights(layer):
+  return layer
+
+
 def _clone_layer(layer):
   return layer.__class__.from_config(layer.get_config())
 
 
-def _clone_functional_model(model, input_tensors=None, share_weights=False):
+def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
   """Clone a functional `Model` instance.
 
   Model cloning is similar to calling a model on new inputs,
   except that it creates new layers (and thus new weights) instead
   of sharing the weights of the existing layers.
 
+  Input layers are always cloned.
+
   Arguments:
       model: Instance of `Model`.
       input_tensors: optional list of input tensors to build the model upon. If
           not provided, placeholders will be created.
-      share_weights: flag to enable sharing of non-input layers between the
-          cloned and original model. Note this still clones the input layers.
-          This is required when we create a per-replica copy of the model with
-          distribution strategy; we want the weights to be shared but still
-          feed inputs separately so we create new input layers.
+      layer_fn: callable applied to each non-input layer in the model. By
+          default it clones the layer (`_clone_layer`). Passing `share_weights`
+          instead preserves the layer so that its weights are shared. This is
+          needed when creating a per-replica copy of the model with a
+          distribution strategy: the weights should be shared, but inputs are
+          still fed separately, so new input layers are created.
 
   Returns:
       An instance of `Model` reproducing the behavior
@@ -74,7 +81,8 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
       using newly instantiated weights.
 
   Raises:
-      ValueError: in case of invalid `model` argument value.
+      ValueError: in case of invalid `model` argument value or `layer_fn`
+          argument value.
   """
   if not isinstance(model, Model):
     raise ValueError('Expected `model` argument '
@@ -123,6 +131,9 @@ def _clone_functional_model(model, input_tensors=None, share_weights=False):
     for x, y in zip(model.inputs, input_tensors):
       tensor_map[x] = y
 
+  if not callable(layer_fn):
+    raise ValueError('Expected `layer_fn` argument to be a callable.')
+
   # Iterated over every node in the reference model, in depth order.
   depth_keys = list(model._nodes_by_depth.keys())
   depth_keys.sort(reverse=True)
@@ -134,11 +145,9 @@
 
       # Get or create layer.
       if layer not in layer_map:
-        if not share_weights:
-          # Clone layer.
-          new_layer = _clone_layer(layer)
-          layer_map[layer] = new_layer
-          layer = new_layer
+        new_layer = layer_fn(layer)
+        layer_map[layer] = new_layer
+        layer = new_layer
       else:
         # Reuse previously cloned layer.
         layer = layer_map[layer]
@@ -172,7 +181,7 @@
   return Model(input_tensors, output_tensors, name=model.name)
 
 
-def _clone_sequential_model(model, input_tensors=None, share_weights=False):
+def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
   """Clone a `Sequential` model instance.
 
   Model cloning is similar to calling a model on new inputs,
@@ -184,11 +193,12 @@
       input_tensors: optional list of input tensors to build the model upon. If
           not provided, placeholders will be created.
-      share_weights: flag to enable sharing of non-input layers between the
-          cloned and original model. Note this still clones the input layers.
-          This is required when we create a per-replica copy of the model with
-          distribution strategy; we want the weights to be shared but still
-          feed inputs separately so we create new input layers.
+      layer_fn: callable applied to each non-input layer in the model. By
+          default it clones the layer (`_clone_layer`). Passing `share_weights`
+          instead preserves the layer so that its weights are shared. This is
+          needed when creating a per-replica copy of the model with a
+          distribution strategy: the weights should be shared, but inputs are
+          still fed separately, so new input layers are created.
 
   Returns:
       An instance of `Sequential` reproducing the behavior
@@ -196,35 +206,36 @@
       using newly instantiated weights.
 
   Raises:
-      ValueError: in case of invalid `model` argument value.
+      ValueError: in case of invalid `model` argument value or `layer_fn`
+          argument value.
   """
   if not isinstance(model, Sequential):
     raise ValueError('Expected `model` argument '
                      'to be a `Sequential` model instance, '
                      'but got:', model)
 
+  if not callable(layer_fn):
+    raise ValueError('Expected `layer_fn` argument to be a callable.')
+
   # Use model._layers to ensure that all layers are cloned. The model's layers
   # property will exclude the initial InputLayer (if it exists) in the model,
   # resulting in a different Sequential model structure.
   if input_tensors is None:
-    if share_weights:
-      # In preserve weights case we still want the input layers to be cloned.
-      layers = []
-      for layer in model._layers:
-        if isinstance(layer, InputLayer):
-          layers.append(_clone_layer(layer))
-        else:
-          layers.append(layer)
-    else:
-      layers = [_clone_layer(layer) for layer in model._layers]
+    layers = []
+    for layer in model._layers:
+      if isinstance(layer, InputLayer):
+        layers.append(_clone_layer(layer))
+      else:
+        layers.append(layer_fn(layer))
     return Sequential(layers=layers, name=model.name)
   else:
     # If input tensors are provided, the original model's InputLayer is
     # overwritten with a different InputLayer.
     layers = [
-        layer for layer in model._layers if not isinstance(layer, InputLayer)]
-    if not share_weights:
-      layers = [_clone_layer(layer) for layer in layers]
+        layer_fn(layer)
+        for layer in model._layers
+        if not isinstance(layer, InputLayer)
+    ]
     if len(generic_utils.to_list(input_tensors)) != 1:
       raise ValueError('To clone a `Sequential` model, we expect '
                        ' at most one tensor '
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index f429aba498d..0ef7323fe5e 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -104,7 +104,7 @@ class TestModelCloning(keras_parameterized.TestCase):
 
     if share_weights:
       clone_fn = functools.partial(
-          keras.models._clone_sequential_model, share_weights=True)
+          keras.models._clone_sequential_model, layer_fn=models.share_weights)
     else:
       clone_fn = keras.models.clone_model
 
@@ -151,7 +151,7 @@ class TestModelCloning(keras_parameterized.TestCase):
   def test_clone_functional_model(self, share_weights):
     if share_weights:
       clone_fn = functools.partial(
-          keras.models._clone_functional_model, share_weights=True)
+          keras.models._clone_functional_model, layer_fn=models.share_weights)
     else:
       clone_fn = keras.models.clone_model
 
@@ -212,7 +212,7 @@ class TestModelCloning(keras_parameterized.TestCase):
   def test_clone_functional_with_masking(self, share_weights):
     if share_weights:
       clone_fn = functools.partial(
-          keras.models._clone_functional_model, share_weights=True)
+          keras.models._clone_functional_model, layer_fn=models.share_weights)
     else:
      clone_fn = keras.models.clone_model
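
Usage sketch (illustrative, not part of the patch): `_clone_functional_model`, `_clone_layer`, and `share_weights` are private symbols of `tensorflow.python.keras.models`, so the snippet below assumes internal access and mirrors what the updated tests exercise. The default `layer_fn` reproduces the old `share_weights=False` behavior, and passing `models.share_weights` reproduces `share_weights=True`.

    import tensorflow as tf
    from tensorflow.python.keras import models

    inputs = tf.keras.Input(shape=(4,))
    outputs = tf.keras.layers.Dense(2, name='dense')(inputs)
    model = tf.keras.Model(inputs, outputs)

    # Default layer_fn clones every non-input layer, so the copy gets fresh
    # weights (the old share_weights=False path).
    fresh = models._clone_functional_model(model, layer_fn=models._clone_layer)

    # share_weights returns each non-input layer unchanged, so the copy reuses
    # the original layers and their weights while still creating new input
    # layers (the old share_weights=True path).
    shared = models._clone_functional_model(model, layer_fn=models.share_weights)

    assert shared.get_layer('dense') is model.get_layer('dense')
    assert fresh.get_layer('dense') is not model.get_layer('dense')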