Remove the cloning argument to compile().

Keras models are distributed by cloning in graph mode and without cloning in eager mode, as of change 258652546.

PiperOrigin-RevId: 258893048
This commit is contained in:
Igor Saprykin 2019-07-18 19:40:46 -07:00 committed by TensorFlower Gardener
parent 47e650a119
commit e5ced34f45
3 changed files with 3 additions and 16 deletions

View File

@@ -638,9 +638,7 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
def is_distributing_by_cloning(model): def is_distributing_by_cloning(model):
"""Decide whether this model is going to be distributed via cloning. """Decide whether this model is going to be distributed via cloning.
We are going to distribute the model by cloning if the user has signaled We are going to distribute the model by cloning in graph mode.
that intent by setting `cloning=True` in `Model.compile()` unless we are in
graph mode.
Args: Args:
model: Keras model to distribute. model: Keras model to distribute.

View File

@@ -245,7 +245,6 @@ class Model(network.Network):
if ((sample_weight_mode is not None) if ((sample_weight_mode is not None)
or (target_tensors is not None) or (target_tensors is not None)
or (weighted_metrics is not None) or (weighted_metrics is not None)
or (kwargs.get('cloning', False))
or not context.executing_eagerly()): or not context.executing_eagerly()):
# Fallback out of things that aren't supported with v2 loops # Fallback out of things that aren't supported with v2 loops
self._run_distributed = False self._run_distributed = False
@@ -269,12 +268,6 @@ class Model(network.Network):
self._distribution_strategy = ( self._distribution_strategy = (
distribution_strategy_context.get_strategy()) distribution_strategy_context.get_strategy())
# Check whether the experimental feature of distributing the Model without
# cloning is requested.
# TODO(b/124517980, b/124377929): Remove this temporary undocumented way
# of enabling the feature and graduate it to the main distributed code path.
self._cloning = kwargs.pop('cloning', False)
if not self._run_distributed: if not self._run_distributed:
self._validate_compile_param_for_distribution_strategy(self.run_eagerly, self._validate_compile_param_for_distribution_strategy(self.run_eagerly,
sample_weight_mode, sample_weight_mode,
@@ -479,8 +472,7 @@ class Model(network.Network):
# TODO(scottzhu): Finish getting sequences working with the v2 loops. # TODO(scottzhu): Finish getting sequences working with the v2 loops.
and not isinstance(inputs, (data_utils.Sequence)) and not isinstance(inputs, (data_utils.Sequence))
and not distributed_training_utils.is_tpu_strategy( and not distributed_training_utils.is_tpu_strategy(
self._distribution_strategy) self._distribution_strategy)):
and not getattr(self, '_cloning', False)):
return training_v2.Loop() return training_v2.Loop()
# Case 1: distribution strategy. # Case 1: distribution strategy.
@@ -2417,8 +2409,7 @@ class Model(network.Network):
loss_weights=self.loss_weights, loss_weights=self.loss_weights,
target_tensors=target_tensors, target_tensors=target_tensors,
run_eagerly=self.run_eagerly, run_eagerly=self.run_eagerly,
run_distributed=self._run_distributed, run_distributed=self._run_distributed)
cloning=self._cloning)
# In graph mode, if we had just set inputs and targets as symbolic tensors # In graph mode, if we had just set inputs and targets as symbolic tensors
# by invoking build and compile on the model respectively, we do not have to # by invoking build and compile on the model respectively, we do not have to

View File

@@ -283,8 +283,6 @@ def model_iteration(model,
# Get outputs. # Get outputs.
try: try:
# `ins` can be callable in tf.distribute.Strategy + eager case. # `ins` can be callable in tf.distribute.Strategy + eager case.
# TODO(b/134179782): Simplify this condition when cloning never
# happens.
if not callable(ins) or ( if not callable(ins) or (
model._distribution_strategy and model._distribution_strategy and
not distributed_training_utils.is_distributing_by_cloning(model)): not distributed_training_utils.is_distributing_by_cloning(model)):