Reduce retracing warnings.
Previously the code below would complain six times, thus spamming the users: ``` @tf.function def my_func(x): return x for i in range(10): my_func(i) ``` Limit the number of retracing warnings for each group of traced functions with a unique key - \_\_code\_\_ (or Python function instance if not available). For the example above the number of warnings will be reduced to two. PiperOrigin-RevId: 346474548 Change-Id: I0dfa54c4b72c5b9e29d98da3c6950ffb0f5c3fca
This commit is contained in:
parent
93dfb9b68f
commit
0923a15fa7
@ -51,84 +51,94 @@ from tensorflow.python.util.tf_export import tf_export
|
|||||||
|
|
||||||
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
|
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
|
||||||
FREQUENT_TRACING_WARNING_THRESHOLD = 5
|
FREQUENT_TRACING_WARNING_THRESHOLD = 5
|
||||||
|
FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2
|
||||||
|
|
||||||
|
|
||||||
class _CallCounter(object):
|
class _FrequentTracingDetector(object):
|
||||||
"""Class keeping track of how many recent calls triggered tracing."""
|
"""Class keeping track of how many recent calls triggered tracing."""
|
||||||
|
|
||||||
__slots__ = ["_max_call_history", "_calls_per_tracings", "call_count"]
|
__slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"]
|
||||||
|
|
||||||
def __init__(self, max_call_history):
|
def __init__(self):
|
||||||
self._max_call_history = max_call_history
|
|
||||||
self._calls_per_tracings = []
|
self._calls_per_tracings = []
|
||||||
self.call_count = 0
|
self._total_warning_count = 0
|
||||||
|
self._call_count = 0
|
||||||
|
|
||||||
def called_with_tracing(self):
|
def called_with_tracing(self, function_name, omit_warning):
|
||||||
self.call_count += 1
|
"""Updates the list of most recent calls' tracing information.
|
||||||
|
|
||||||
|
Warns the user when recent calls caused retracing too often.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function_name: the python function being traced.
|
||||||
|
omit_warning: If 'True', this call will not warn the user even if
|
||||||
|
retracing happens too often.
|
||||||
|
"""
|
||||||
|
self._call_count += 1
|
||||||
self._calls_per_tracings.append(1)
|
self._calls_per_tracings.append(1)
|
||||||
|
|
||||||
while self._calls_per_tracings:
|
while self._calls_per_tracings:
|
||||||
if self.call_count - self._calls_per_tracings[0] > self._max_call_history:
|
if (self._call_count - self._calls_per_tracings[0] >
|
||||||
self.call_count -= self._calls_per_tracings.pop(0)
|
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY):
|
||||||
|
self._call_count -= self._calls_per_tracings.pop(0)
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if (omit_warning or self._total_warning_count >=
|
||||||
|
FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR):
|
||||||
|
return
|
||||||
|
if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD:
|
||||||
|
self._total_warning_count += 1
|
||||||
|
logging.warning(
|
||||||
|
"{} out of the last {} calls to {} triggered tf.function "
|
||||||
|
"retracing. Tracing is expensive and the excessive number of "
|
||||||
|
"tracings could be due to (1) creating @tf.function repeatedly in "
|
||||||
|
"a loop, (2) passing tensors with different shapes, (3) passing "
|
||||||
|
"Python objects instead of tensors. For (1), please define your "
|
||||||
|
"@tf.function outside of the loop. For (2), @tf.function has "
|
||||||
|
"experimental_relax_shapes=True option that relaxes argument "
|
||||||
|
"shapes that can avoid unnecessary retracing. For (3), please "
|
||||||
|
"refer to "
|
||||||
|
"https://www.tensorflow.org/guide/function#controlling_retracing"
|
||||||
|
" and https://www.tensorflow.org/api_docs/python/tf/function for "
|
||||||
|
" more details.".format(
|
||||||
|
len(self._calls_per_tracings), self._call_count, function_name))
|
||||||
|
|
||||||
def called_without_tracing(self):
|
def called_without_tracing(self):
|
||||||
# We don't count tracing when users load a concrete function directly or
|
# We don't count tracing when users load a concrete function directly or
|
||||||
# call get_concrete_function, so the first call can be not a tracing call.
|
# call get_concrete_function, so the first call can be not a tracing call.
|
||||||
if not self._calls_per_tracings:
|
if not self._calls_per_tracings:
|
||||||
self._calls_per_tracings = [0]
|
self._calls_per_tracings = [0]
|
||||||
self._calls_per_tracings[-1] += 1
|
self._calls_per_tracings[-1] += 1
|
||||||
self.call_count += 1
|
self._call_count += 1
|
||||||
|
|
||||||
def get_tracing_count(self):
|
|
||||||
return len(self._calls_per_tracings)
|
|
||||||
|
|
||||||
|
|
||||||
class _FrequentTracingDetector(object):
|
class _FrequentTracingDetectorManager(object):
|
||||||
"""Class for frequent retracing detection and warning."""
|
"""Class for the management of all _FrequentTracingDetector objects."""
|
||||||
|
|
||||||
__slots__ = ["_counters", "_lock"]
|
__slots__ = ["_detectors", "_lock"]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._counters = weakref.WeakKeyDictionary() # GUARDED_BY(self._lock)
|
self._detectors = weakref.WeakKeyDictionary() # GUARDED_BY(self._lock)
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def _get_counter(self, key):
|
def _get_detector(self, key):
|
||||||
if key not in self._counters:
|
if key not in self._detectors:
|
||||||
self._counters[key] = _CallCounter(
|
self._detectors[key] = _FrequentTracingDetector()
|
||||||
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
|
return self._detectors[key]
|
||||||
return self._counters[key]
|
|
||||||
|
|
||||||
def called_without_tracing(self, key):
|
def called_without_tracing(self, key):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
counter = self._get_counter(key)
|
detector = self._get_detector(key)
|
||||||
counter.called_without_tracing()
|
detector.called_without_tracing()
|
||||||
|
|
||||||
def called_with_tracing(self, key, function_name, omit_warning):
|
def called_with_tracing(self, key, function_name, omit_warning):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
counter = self._get_counter(key)
|
detector = self._get_detector(key)
|
||||||
counter.called_with_tracing()
|
detector.called_with_tracing(function_name, omit_warning)
|
||||||
if omit_warning:
|
|
||||||
return
|
|
||||||
if counter.get_tracing_count() >= FREQUENT_TRACING_WARNING_THRESHOLD:
|
|
||||||
logging.warning(
|
|
||||||
"{} out of the last {} calls to {} triggered tf.function "
|
|
||||||
"retracing. Tracing is expensive and the excessive number of "
|
|
||||||
"tracings could be due to (1) creating @tf.function repeatedly in "
|
|
||||||
"a loop, (2) passing tensors with different shapes, (3) passing "
|
|
||||||
"Python objects instead of tensors. For (1), please define your "
|
|
||||||
"@tf.function outside of the loop. For (2), @tf.function has "
|
|
||||||
"experimental_relax_shapes=True option that relaxes argument "
|
|
||||||
"shapes that can avoid unnecessary retracing. For (3), please "
|
|
||||||
"refer to "
|
|
||||||
"https://www.tensorflow.org/guide/function#controlling_retracing"
|
|
||||||
" and https://www.tensorflow.org/api_docs/python/tf/function for "
|
|
||||||
" more details.".format(counter.get_tracing_count(),
|
|
||||||
counter.call_count, function_name))
|
|
||||||
|
|
||||||
|
|
||||||
_frequent_tracing_detector = _FrequentTracingDetector()
|
_frequent_tracing_detector_manager = _FrequentTracingDetectorManager()
|
||||||
|
|
||||||
|
|
||||||
class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
|
class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
|
||||||
@ -794,10 +804,10 @@ class Function(object):
|
|||||||
|
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
if without_tracing:
|
if without_tracing:
|
||||||
_frequent_tracing_detector.called_without_tracing(
|
_frequent_tracing_detector_manager.called_without_tracing(
|
||||||
self._key_for_call_stats)
|
self._key_for_call_stats)
|
||||||
else:
|
else:
|
||||||
_frequent_tracing_detector.called_with_tracing(
|
_frequent_tracing_detector_manager.called_with_tracing(
|
||||||
self._key_for_call_stats, self._python_function,
|
self._key_for_call_stats, self._python_function,
|
||||||
self._omit_frequent_tracing_warning)
|
self._omit_frequent_tracing_warning)
|
||||||
|
|
||||||
|
@ -956,6 +956,18 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertLen(logs.output, 1)
|
self.assertLen(logs.output, 1)
|
||||||
self.assertIn('Tracing is expensive', logs.output[0])
|
self.assertIn('Tracing is expensive', logs.output[0])
|
||||||
|
|
||||||
|
def test_retracing_warning_limits(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def my_func(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
with self.assertLogs(level='WARN') as logs:
|
||||||
|
for i in range(10):
|
||||||
|
my_func(i)
|
||||||
|
|
||||||
|
self.assertLen(logs.output, 2)
|
||||||
|
|
||||||
def test_experimental_get_tracing_count_function(self):
|
def test_experimental_get_tracing_count_function(self):
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
|
Loading…
x
Reference in New Issue
Block a user