Reduce retracing warnings.
Previously the code below would complain six times, thus spamming the users: ``` @tf.function def my_func(x): return x for i in range(10): my_func(i) ``` Limit the number of retracing warnings for each group of traced functions with a unique key - \_\_code\_\_ (or Python function instance if not available). For the example above the number of warnings will be reduced to two. PiperOrigin-RevId: 346474548 Change-Id: I0dfa54c4b72c5b9e29d98da3c6950ffb0f5c3fca
This commit is contained in:
parent
93dfb9b68f
commit
0923a15fa7
@ -51,84 +51,94 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
|
||||
FREQUENT_TRACING_WARNING_THRESHOLD = 5
|
||||
FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2
|
||||
|
||||
|
||||
class _CallCounter(object):
|
||||
class _FrequentTracingDetector(object):
|
||||
"""Class keeping track of how many recent calls triggered tracing."""
|
||||
|
||||
__slots__ = ["_max_call_history", "_calls_per_tracings", "call_count"]
|
||||
__slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"]
|
||||
|
||||
def __init__(self, max_call_history):
|
||||
self._max_call_history = max_call_history
|
||||
def __init__(self):
|
||||
self._calls_per_tracings = []
|
||||
self.call_count = 0
|
||||
self._total_warning_count = 0
|
||||
self._call_count = 0
|
||||
|
||||
def called_with_tracing(self):
|
||||
self.call_count += 1
|
||||
def called_with_tracing(self, function_name, omit_warning):
|
||||
"""Updates the list of most recent calls' tracing information.
|
||||
|
||||
Warns the user when recent calls caused retracing too often.
|
||||
|
||||
Args:
|
||||
function_name: the python function being traced.
|
||||
omit_warning: If 'True', this call will not warn the user even if
|
||||
retracing happens too often.
|
||||
"""
|
||||
self._call_count += 1
|
||||
self._calls_per_tracings.append(1)
|
||||
|
||||
while self._calls_per_tracings:
|
||||
if self.call_count - self._calls_per_tracings[0] > self._max_call_history:
|
||||
self.call_count -= self._calls_per_tracings.pop(0)
|
||||
if (self._call_count - self._calls_per_tracings[0] >
|
||||
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY):
|
||||
self._call_count -= self._calls_per_tracings.pop(0)
|
||||
else:
|
||||
break
|
||||
|
||||
if (omit_warning or self._total_warning_count >=
|
||||
FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR):
|
||||
return
|
||||
if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD:
|
||||
self._total_warning_count += 1
|
||||
logging.warning(
|
||||
"{} out of the last {} calls to {} triggered tf.function "
|
||||
"retracing. Tracing is expensive and the excessive number of "
|
||||
"tracings could be due to (1) creating @tf.function repeatedly in "
|
||||
"a loop, (2) passing tensors with different shapes, (3) passing "
|
||||
"Python objects instead of tensors. For (1), please define your "
|
||||
"@tf.function outside of the loop. For (2), @tf.function has "
|
||||
"experimental_relax_shapes=True option that relaxes argument "
|
||||
"shapes that can avoid unnecessary retracing. For (3), please "
|
||||
"refer to "
|
||||
"https://www.tensorflow.org/guide/function#controlling_retracing"
|
||||
" and https://www.tensorflow.org/api_docs/python/tf/function for "
|
||||
" more details.".format(
|
||||
len(self._calls_per_tracings), self._call_count, function_name))
|
||||
|
||||
def called_without_tracing(self):
|
||||
# We don't count tracing when users load a concrete function directly or
|
||||
# call get_concrete_function, so the first call can be not a tracing call.
|
||||
if not self._calls_per_tracings:
|
||||
self._calls_per_tracings = [0]
|
||||
self._calls_per_tracings[-1] += 1
|
||||
self.call_count += 1
|
||||
|
||||
def get_tracing_count(self):
|
||||
return len(self._calls_per_tracings)
|
||||
self._call_count += 1
|
||||
|
||||
|
||||
class _FrequentTracingDetector(object):
|
||||
"""Class for frequent retracing detection and warning."""
|
||||
class _FrequentTracingDetectorManager(object):
|
||||
"""Class for the management of all _FrequentTracingDetector objects."""
|
||||
|
||||
__slots__ = ["_counters", "_lock"]
|
||||
__slots__ = ["_detectors", "_lock"]
|
||||
|
||||
def __init__(self):
|
||||
self._counters = weakref.WeakKeyDictionary() # GUARDED_BY(self._lock)
|
||||
self._detectors = weakref.WeakKeyDictionary() # GUARDED_BY(self._lock)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _get_counter(self, key):
|
||||
if key not in self._counters:
|
||||
self._counters[key] = _CallCounter(
|
||||
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
|
||||
return self._counters[key]
|
||||
def _get_detector(self, key):
|
||||
if key not in self._detectors:
|
||||
self._detectors[key] = _FrequentTracingDetector()
|
||||
return self._detectors[key]
|
||||
|
||||
def called_without_tracing(self, key):
|
||||
with self._lock:
|
||||
counter = self._get_counter(key)
|
||||
counter.called_without_tracing()
|
||||
detector = self._get_detector(key)
|
||||
detector.called_without_tracing()
|
||||
|
||||
def called_with_tracing(self, key, function_name, omit_warning):
|
||||
with self._lock:
|
||||
counter = self._get_counter(key)
|
||||
counter.called_with_tracing()
|
||||
if omit_warning:
|
||||
return
|
||||
if counter.get_tracing_count() >= FREQUENT_TRACING_WARNING_THRESHOLD:
|
||||
logging.warning(
|
||||
"{} out of the last {} calls to {} triggered tf.function "
|
||||
"retracing. Tracing is expensive and the excessive number of "
|
||||
"tracings could be due to (1) creating @tf.function repeatedly in "
|
||||
"a loop, (2) passing tensors with different shapes, (3) passing "
|
||||
"Python objects instead of tensors. For (1), please define your "
|
||||
"@tf.function outside of the loop. For (2), @tf.function has "
|
||||
"experimental_relax_shapes=True option that relaxes argument "
|
||||
"shapes that can avoid unnecessary retracing. For (3), please "
|
||||
"refer to "
|
||||
"https://www.tensorflow.org/guide/function#controlling_retracing"
|
||||
" and https://www.tensorflow.org/api_docs/python/tf/function for "
|
||||
" more details.".format(counter.get_tracing_count(),
|
||||
counter.call_count, function_name))
|
||||
detector = self._get_detector(key)
|
||||
detector.called_with_tracing(function_name, omit_warning)
|
||||
|
||||
|
||||
_frequent_tracing_detector = _FrequentTracingDetector()
|
||||
_frequent_tracing_detector_manager = _FrequentTracingDetectorManager()
|
||||
|
||||
|
||||
class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
|
||||
@ -794,10 +804,10 @@ class Function(object):
|
||||
|
||||
if context.executing_eagerly():
|
||||
if without_tracing:
|
||||
_frequent_tracing_detector.called_without_tracing(
|
||||
_frequent_tracing_detector_manager.called_without_tracing(
|
||||
self._key_for_call_stats)
|
||||
else:
|
||||
_frequent_tracing_detector.called_with_tracing(
|
||||
_frequent_tracing_detector_manager.called_with_tracing(
|
||||
self._key_for_call_stats, self._python_function,
|
||||
self._omit_frequent_tracing_warning)
|
||||
|
||||
|
@ -956,6 +956,18 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertLen(logs.output, 1)
|
||||
self.assertIn('Tracing is expensive', logs.output[0])
|
||||
|
||||
def test_retracing_warning_limits(self):
|
||||
|
||||
@def_function.function
|
||||
def my_func(x):
|
||||
return x
|
||||
|
||||
with self.assertLogs(level='WARN') as logs:
|
||||
for i in range(10):
|
||||
my_func(i)
|
||||
|
||||
self.assertLen(logs.output, 2)
|
||||
|
||||
def test_experimental_get_tracing_count_function(self):
|
||||
|
||||
@def_function.function
|
||||
|
Loading…
Reference in New Issue
Block a user