Reduce retracing warnings.

Previously the code below would complain six times, thus spamming the users:

```
@tf.function
def my_func(x):
  return x

for i in range(10):
  my_func(i)
```

Limit the number of retracing warnings for each group of traced functions with a unique key - \_\_code\_\_ (or Python function instance if not available). For the example above the number of warnings will be reduced to two.

PiperOrigin-RevId: 346474548
Change-Id: I0dfa54c4b72c5b9e29d98da3c6950ffb0f5c3fca
This commit is contained in:
A. Unique TensorFlower 2020-12-08 21:32:06 -08:00 committed by TensorFlower Gardener
parent 93dfb9b68f
commit 0923a15fa7
2 changed files with 68 additions and 46 deletions

View File

@ -51,84 +51,94 @@ from tensorflow.python.util.tf_export import tf_export
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
FREQUENT_TRACING_WARNING_THRESHOLD = 5
FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2
class _CallCounter(object):
class _FrequentTracingDetector(object):
"""Class keeping track of how many recent calls triggered tracing."""
__slots__ = ["_max_call_history", "_calls_per_tracings", "call_count"]
__slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"]
def __init__(self, max_call_history):
self._max_call_history = max_call_history
def __init__(self):
self._calls_per_tracings = []
self.call_count = 0
self._total_warning_count = 0
self._call_count = 0
def called_with_tracing(self):
self.call_count += 1
def called_with_tracing(self, function_name, omit_warning):
"""Updates the list of most recent calls' tracing information.
Warns the user when recent calls caused retracing too often.
Args:
function_name: the python function being traced.
omit_warning: If 'True', this call will not warn the user even if
retracing happens too often.
"""
self._call_count += 1
self._calls_per_tracings.append(1)
while self._calls_per_tracings:
if self.call_count - self._calls_per_tracings[0] > self._max_call_history:
self.call_count -= self._calls_per_tracings.pop(0)
if (self._call_count - self._calls_per_tracings[0] >
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY):
self._call_count -= self._calls_per_tracings.pop(0)
else:
break
if (omit_warning or self._total_warning_count >=
FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR):
return
if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD:
self._total_warning_count += 1
logging.warning(
"{} out of the last {} calls to {} triggered tf.function "
"retracing. Tracing is expensive and the excessive number of "
"tracings could be due to (1) creating @tf.function repeatedly in "
"a loop, (2) passing tensors with different shapes, (3) passing "
"Python objects instead of tensors. For (1), please define your "
"@tf.function outside of the loop. For (2), @tf.function has "
"experimental_relax_shapes=True option that relaxes argument "
"shapes that can avoid unnecessary retracing. For (3), please "
"refer to "
"https://www.tensorflow.org/guide/function#controlling_retracing"
" and https://www.tensorflow.org/api_docs/python/tf/function for "
" more details.".format(
len(self._calls_per_tracings), self._call_count, function_name))
def called_without_tracing(self):
# We don't count tracing when users load a concrete function directly or
# call get_concrete_function, so the first call can be not a tracing call.
if not self._calls_per_tracings:
self._calls_per_tracings = [0]
self._calls_per_tracings[-1] += 1
self.call_count += 1
def get_tracing_count(self):
return len(self._calls_per_tracings)
self._call_count += 1
class _FrequentTracingDetector(object):
"""Class for frequent retracing detection and warning."""
class _FrequentTracingDetectorManager(object):
"""Class for the management of all _FrequentTracingDetector objects."""
__slots__ = ["_counters", "_lock"]
__slots__ = ["_detectors", "_lock"]
def __init__(self):
self._counters = weakref.WeakKeyDictionary() # GUARDED_BY(self._lock)
self._detectors = weakref.WeakKeyDictionary() # GUARDED_BY(self._lock)
self._lock = threading.Lock()
def _get_counter(self, key):
if key not in self._counters:
self._counters[key] = _CallCounter(
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
return self._counters[key]
def _get_detector(self, key):
if key not in self._detectors:
self._detectors[key] = _FrequentTracingDetector()
return self._detectors[key]
def called_without_tracing(self, key):
with self._lock:
counter = self._get_counter(key)
counter.called_without_tracing()
detector = self._get_detector(key)
detector.called_without_tracing()
def called_with_tracing(self, key, function_name, omit_warning):
with self._lock:
counter = self._get_counter(key)
counter.called_with_tracing()
if omit_warning:
return
if counter.get_tracing_count() >= FREQUENT_TRACING_WARNING_THRESHOLD:
logging.warning(
"{} out of the last {} calls to {} triggered tf.function "
"retracing. Tracing is expensive and the excessive number of "
"tracings could be due to (1) creating @tf.function repeatedly in "
"a loop, (2) passing tensors with different shapes, (3) passing "
"Python objects instead of tensors. For (1), please define your "
"@tf.function outside of the loop. For (2), @tf.function has "
"experimental_relax_shapes=True option that relaxes argument "
"shapes that can avoid unnecessary retracing. For (3), please "
"refer to "
"https://www.tensorflow.org/guide/function#controlling_retracing"
" and https://www.tensorflow.org/api_docs/python/tf/function for "
" more details.".format(counter.get_tracing_count(),
counter.call_count, function_name))
detector = self._get_detector(key)
detector.called_with_tracing(function_name, omit_warning)
_frequent_tracing_detector = _FrequentTracingDetector()
_frequent_tracing_detector_manager = _FrequentTracingDetectorManager()
class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
@ -794,10 +804,10 @@ class Function(object):
if context.executing_eagerly():
if without_tracing:
_frequent_tracing_detector.called_without_tracing(
_frequent_tracing_detector_manager.called_without_tracing(
self._key_for_call_stats)
else:
_frequent_tracing_detector.called_with_tracing(
_frequent_tracing_detector_manager.called_with_tracing(
self._key_for_call_stats, self._python_function,
self._omit_frequent_tracing_warning)

View File

@ -956,6 +956,18 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertLen(logs.output, 1)
self.assertIn('Tracing is expensive', logs.output[0])
def test_retracing_warning_limits(self):
@def_function.function
def my_func(x):
return x
with self.assertLogs(level='WARN') as logs:
for i in range(10):
my_func(i)
self.assertLen(logs.output, 2)
def test_experimental_get_tracing_count_function(self):
@def_function.function