Add a warning message to users who attempt to use MirroredStrategy or MultiWorkerMirroredStrategy in pure eager mode.

PiperOrigin-RevId: 254138876
This commit is contained in:
Anjali Sridhar 2019-06-19 23:09:59 -07:00 committed by TensorFlower Gardener
parent 026bc91eae
commit 1df8a1708e

View File

@ -43,6 +43,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@ -722,6 +723,14 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
return self._get_cross_device_ops().broadcast(tensor, destinations)
def _call_for_each_replica(self, fn, args, kwargs):
if context.executing_eagerly():
logging.log_first_n(logging.WARN, "Using %s eagerly has significant "
"overhead currently. We will be working on improving "
"this in the future, but for now please wrap "
"`call_for_each_replica` or `experimental_run` or "
"`experimental_run_v2` inside a tf.function to get "
"the best performance." %
self._container_strategy().__class__.__name__, 5)
return _call_for_each_replica(self._container_strategy(), self._device_map,
fn, args, kwargs)