Add a warning message to users who attempt to use MirroredStrategy or MultiWorkerMirroredStrategy in pure eager mode.
PiperOrigin-RevId: 254138876
This commit is contained in:
parent
026bc91eae
commit
1df8a1708e
@ -43,6 +43,7 @@ from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import coordinator
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -722,6 +723,14 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
return self._get_cross_device_ops().broadcast(tensor, destinations)
|
||||
|
||||
def _call_for_each_replica(self, fn, args, kwargs):
|
||||
if context.executing_eagerly():
|
||||
logging.log_first_n(logging.WARN, "Using %s eagerly has significant "
|
||||
"overhead currently. We will be working on improving "
|
||||
"this in the future, but for now please wrap "
|
||||
"`call_for_each_replica` or `experimental_run` or "
|
||||
"`experimental_run_v2` inside a tf.function to get "
|
||||
"the best performance." %
|
||||
self._container_strategy().__class__.__name__, 5)
|
||||
return _call_for_each_replica(self._container_strategy(), self._device_map,
|
||||
fn, args, kwargs)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user