diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index fbc8923e050..b77163cb97a 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -619,6 +619,11 @@ class StrategyBase(object): if not hasattr(extended, "_retrace_functions_for_each_device"): # pylint: disable=protected-access + # `extended._retrace_functions_for_each_device` dictates + # 1) whether all the ops created inside function will have devices + # inherited from outer stack, and + # 2) whether the same function will be retraced when it is called on + # different devices. try: extended._retrace_functions_for_each_device = ( len(extended.worker_devices) > 1)