Merge pull request #32258 from kkimdev/cherrypicks_5XVFU
[r2.0 CherryPick]: @tf.function: Show a warning message when tracing happens too frequently
This commit is contained in:
commit
f67991359e
@ -38,6 +38,41 @@ from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
|
||||
FREQUENT_TRACING_WARNING_THRESHOLD = 5
|
||||
|
||||
|
||||
class _CallCounter(object):
|
||||
"""Class keeping track of how many recent calls triggered tracing."""
|
||||
|
||||
def __init__(self, max_call_history):
|
||||
self._max_call_history = max_call_history
|
||||
self._calls_per_tracings = []
|
||||
self.call_count = 0
|
||||
|
||||
def called_with_tracing(self):
|
||||
self.call_count += 1
|
||||
self._calls_per_tracings.append(1)
|
||||
|
||||
while self._calls_per_tracings:
|
||||
if self.call_count - self._calls_per_tracings[0] > self._max_call_history:
|
||||
self.call_count -= self._calls_per_tracings.pop(0)
|
||||
else:
|
||||
break
|
||||
|
||||
def called_without_tracing(self):
|
||||
# TODO(kkimlabs): This is an unnecessary defensive check. Since this is last
|
||||
# minute CL before 2.0 release, I've decided to be very defensive here to
|
||||
# avoid a potential crash. Remove once we release 2.0.
|
||||
if not self._calls_per_tracings:
|
||||
return
|
||||
|
||||
self._calls_per_tracings[-1] += 1
|
||||
self.call_count += 1
|
||||
|
||||
def get_tracing_count(self):
|
||||
return len(self._calls_per_tracings)
|
||||
|
||||
|
||||
class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
|
||||
"""Variable which does not lift its initializer out of function context.
|
||||
@ -297,6 +332,7 @@ class Function(object):
|
||||
self._stateless_fn = None # GUARDED_BY(self._lock)
|
||||
self._descriptor_cache = weakref.WeakKeyDictionary()
|
||||
self._name = name
|
||||
self._call_counter = _CallCounter(FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
|
||||
|
||||
def _defun_with_scope(self, scope):
|
||||
"""Creates a defun wrapped inside a variable creator scope."""
|
||||
@ -406,11 +442,41 @@ class Function(object):
|
||||
self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
|
||||
self._python_function, self.input_signature)
|
||||
|
||||
def _get_tracing_count(self):
|
||||
result = self._stateless_fn.tracing_count if self._stateless_fn else 0
|
||||
result += self._stateful_fn.tracing_count if self._stateful_fn else 0
|
||||
return result
|
||||
|
||||
def __call__(self, *args, **kwds):
|
||||
"""Calls the graph function."""
|
||||
"""Calls the graph function and warn too frequent tracings."""
|
||||
context.ensure_initialized()
|
||||
if RUN_FUNCTIONS_EAGERLY:
|
||||
return self._python_function(*args, **kwds)
|
||||
|
||||
tracing_count = self._get_tracing_count()
|
||||
result = self._call(*args, **kwds)
|
||||
if tracing_count == self._get_tracing_count():
|
||||
self._call_counter.called_without_tracing()
|
||||
return result
|
||||
|
||||
self._call_counter.called_with_tracing()
|
||||
recent_tracing_count = self._call_counter.get_tracing_count()
|
||||
if recent_tracing_count >= FREQUENT_TRACING_WARNING_THRESHOLD:
|
||||
logging.warning(
|
||||
"{} out of the last {} calls to {} triggered tf.function retracing. "
|
||||
"Tracing is expensive and the excessive number of tracings is likely "
|
||||
"due to passing python objects instead of tensors. Also, tf.function "
|
||||
"has experimental_relax_shapes=True option that relaxes argument "
|
||||
"shapes that can avoid unnecessary retracing. Please refer to "
|
||||
"https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args"
|
||||
" and https://www.tensorflow.org/api_docs/python/tf/function for more "
|
||||
"details.".format(recent_tracing_count, self._call_counter.call_count,
|
||||
self._python_function))
|
||||
|
||||
return result
|
||||
|
||||
def _call(self, *args, **kwds):
|
||||
"""Calls the graph function."""
|
||||
self._lock.acquire()
|
||||
if self._created_variables:
|
||||
# Release the lock early so that multiple threads can perform the call
|
||||
|
@ -1809,6 +1809,7 @@ class Function(object):
|
||||
self._function_cache = FunctionCache()
|
||||
self._function_attributes = attributes or {}
|
||||
self._capture_by_value = capture_by_value
|
||||
self.tracing_count = 0
|
||||
|
||||
self._lock = threading.Lock()
|
||||
# _descriptor_cache is a of instance of a class to an instance-specific
|
||||
@ -2011,6 +2012,8 @@ class Function(object):
|
||||
|
||||
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
|
||||
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
|
||||
self.tracing_count += 1
|
||||
|
||||
if self.input_signature is None:
|
||||
arglen = len(args)
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user