Remove the cloning
argument to compile()
.
Keras models are distributed by cloning in graph mode and without cloning in eager mode as of the change # 258652546. PiperOrigin-RevId: 258893048
This commit is contained in:
parent
47e650a119
commit
e5ced34f45
@ -638,9 +638,7 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
|
|||||||
def is_distributing_by_cloning(model):
|
def is_distributing_by_cloning(model):
|
||||||
"""Decide whether this model is going to be distributed via cloning.
|
"""Decide whether this model is going to be distributed via cloning.
|
||||||
|
|
||||||
We are going to distribute the model by cloning if the user has signaled
|
We are going to distribute the model by cloning in graph mode.
|
||||||
that intent by setting `cloning=True` in `Model.compile()` unless we are in
|
|
||||||
graph mode.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Keras model to distribute.
|
model: Keras model to distribute.
|
||||||
|
@ -245,7 +245,6 @@ class Model(network.Network):
|
|||||||
if ((sample_weight_mode is not None)
|
if ((sample_weight_mode is not None)
|
||||||
or (target_tensors is not None)
|
or (target_tensors is not None)
|
||||||
or (weighted_metrics is not None)
|
or (weighted_metrics is not None)
|
||||||
or (kwargs.get('cloning', False))
|
|
||||||
or not context.executing_eagerly()):
|
or not context.executing_eagerly()):
|
||||||
# Fallback out of things that aren't supported with v2 loops
|
# Fallback out of things that aren't supported with v2 loops
|
||||||
self._run_distributed = False
|
self._run_distributed = False
|
||||||
@ -269,12 +268,6 @@ class Model(network.Network):
|
|||||||
self._distribution_strategy = (
|
self._distribution_strategy = (
|
||||||
distribution_strategy_context.get_strategy())
|
distribution_strategy_context.get_strategy())
|
||||||
|
|
||||||
# Check whether the experimental feature of distributing the Model without
|
|
||||||
# cloning is requested.
|
|
||||||
# TODO(b/124517980, b/124377929): Remove this temporary undocumented way
|
|
||||||
# of enabling the feature and graduate it to the main distributed code path.
|
|
||||||
self._cloning = kwargs.pop('cloning', False)
|
|
||||||
|
|
||||||
if not self._run_distributed:
|
if not self._run_distributed:
|
||||||
self._validate_compile_param_for_distribution_strategy(self.run_eagerly,
|
self._validate_compile_param_for_distribution_strategy(self.run_eagerly,
|
||||||
sample_weight_mode,
|
sample_weight_mode,
|
||||||
@ -479,8 +472,7 @@ class Model(network.Network):
|
|||||||
# TODO(scottzhu): Finish getting sequences working with the v2 loops.
|
# TODO(scottzhu): Finish getting sequences working with the v2 loops.
|
||||||
and not isinstance(inputs, (data_utils.Sequence))
|
and not isinstance(inputs, (data_utils.Sequence))
|
||||||
and not distributed_training_utils.is_tpu_strategy(
|
and not distributed_training_utils.is_tpu_strategy(
|
||||||
self._distribution_strategy)
|
self._distribution_strategy)):
|
||||||
and not getattr(self, '_cloning', False)):
|
|
||||||
return training_v2.Loop()
|
return training_v2.Loop()
|
||||||
|
|
||||||
# Case 1: distribution strategy.
|
# Case 1: distribution strategy.
|
||||||
@ -2417,8 +2409,7 @@ class Model(network.Network):
|
|||||||
loss_weights=self.loss_weights,
|
loss_weights=self.loss_weights,
|
||||||
target_tensors=target_tensors,
|
target_tensors=target_tensors,
|
||||||
run_eagerly=self.run_eagerly,
|
run_eagerly=self.run_eagerly,
|
||||||
run_distributed=self._run_distributed,
|
run_distributed=self._run_distributed)
|
||||||
cloning=self._cloning)
|
|
||||||
|
|
||||||
# In graph mode, if we had just set inputs and targets as symbolic tensors
|
# In graph mode, if we had just set inputs and targets as symbolic tensors
|
||||||
# by invoking build and compile on the model respectively, we do not have to
|
# by invoking build and compile on the model respectively, we do not have to
|
||||||
|
@ -283,8 +283,6 @@ def model_iteration(model,
|
|||||||
# Get outputs.
|
# Get outputs.
|
||||||
try:
|
try:
|
||||||
# `ins` can be callable in tf.distribute.Strategy + eager case.
|
# `ins` can be callable in tf.distribute.Strategy + eager case.
|
||||||
# TODO(b/134179782): Simplify this condition when cloning never
|
|
||||||
# happens.
|
|
||||||
if not callable(ins) or (
|
if not callable(ins) or (
|
||||||
model._distribution_strategy and
|
model._distribution_strategy and
|
||||||
not distributed_training_utils.is_distributing_by_cloning(model)):
|
not distributed_training_utils.is_distributing_by_cloning(model)):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user