Merge pull request #32258 from kkimdev/cherrypicks_5XVFU
[r2.0 CherryPick]: @tf.function: Show a warning message when tracing happens too frequently
This commit is contained in:
commit
f67991359e
@ -38,6 +38,41 @@ from tensorflow.python.util import object_identity
|
|||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
|
||||||
|
FREQUENT_TRACING_WARNING_THRESHOLD = 5
|
||||||
|
|
||||||
|
|
||||||
|
class _CallCounter(object):
|
||||||
|
"""Class keeping track of how many recent calls triggered tracing."""
|
||||||
|
|
||||||
|
def __init__(self, max_call_history):
|
||||||
|
self._max_call_history = max_call_history
|
||||||
|
self._calls_per_tracings = []
|
||||||
|
self.call_count = 0
|
||||||
|
|
||||||
|
def called_with_tracing(self):
|
||||||
|
self.call_count += 1
|
||||||
|
self._calls_per_tracings.append(1)
|
||||||
|
|
||||||
|
while self._calls_per_tracings:
|
||||||
|
if self.call_count - self._calls_per_tracings[0] > self._max_call_history:
|
||||||
|
self.call_count -= self._calls_per_tracings.pop(0)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
def called_without_tracing(self):
|
||||||
|
# TODO(kkimlabs): This is an unnecessary defensive check. Since this is last
|
||||||
|
# minute CL before 2.0 release, I've decided to be very defensive here to
|
||||||
|
# avoid a potential crash. Remove once we release 2.0.
|
||||||
|
if not self._calls_per_tracings:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._calls_per_tracings[-1] += 1
|
||||||
|
self.call_count += 1
|
||||||
|
|
||||||
|
def get_tracing_count(self):
|
||||||
|
return len(self._calls_per_tracings)
|
||||||
|
|
||||||
|
|
||||||
class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
|
class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
|
||||||
"""Variable which does not lift its initializer out of function context.
|
"""Variable which does not lift its initializer out of function context.
|
||||||
@ -297,6 +332,7 @@ class Function(object):
|
|||||||
self._stateless_fn = None # GUARDED_BY(self._lock)
|
self._stateless_fn = None # GUARDED_BY(self._lock)
|
||||||
self._descriptor_cache = weakref.WeakKeyDictionary()
|
self._descriptor_cache = weakref.WeakKeyDictionary()
|
||||||
self._name = name
|
self._name = name
|
||||||
|
self._call_counter = _CallCounter(FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
|
||||||
|
|
||||||
def _defun_with_scope(self, scope):
|
def _defun_with_scope(self, scope):
|
||||||
"""Creates a defun wrapped inside a variable creator scope."""
|
"""Creates a defun wrapped inside a variable creator scope."""
|
||||||
@ -406,11 +442,41 @@ class Function(object):
|
|||||||
self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
|
self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
|
||||||
self._python_function, self.input_signature)
|
self._python_function, self.input_signature)
|
||||||
|
|
||||||
|
def _get_tracing_count(self):
|
||||||
|
result = self._stateless_fn.tracing_count if self._stateless_fn else 0
|
||||||
|
result += self._stateful_fn.tracing_count if self._stateful_fn else 0
|
||||||
|
return result
|
||||||
|
|
||||||
def __call__(self, *args, **kwds):
|
def __call__(self, *args, **kwds):
|
||||||
"""Calls the graph function."""
|
"""Calls the graph function and warn too frequent tracings."""
|
||||||
context.ensure_initialized()
|
context.ensure_initialized()
|
||||||
if RUN_FUNCTIONS_EAGERLY:
|
if RUN_FUNCTIONS_EAGERLY:
|
||||||
return self._python_function(*args, **kwds)
|
return self._python_function(*args, **kwds)
|
||||||
|
|
||||||
|
tracing_count = self._get_tracing_count()
|
||||||
|
result = self._call(*args, **kwds)
|
||||||
|
if tracing_count == self._get_tracing_count():
|
||||||
|
self._call_counter.called_without_tracing()
|
||||||
|
return result
|
||||||
|
|
||||||
|
self._call_counter.called_with_tracing()
|
||||||
|
recent_tracing_count = self._call_counter.get_tracing_count()
|
||||||
|
if recent_tracing_count >= FREQUENT_TRACING_WARNING_THRESHOLD:
|
||||||
|
logging.warning(
|
||||||
|
"{} out of the last {} calls to {} triggered tf.function retracing. "
|
||||||
|
"Tracing is expensive and the excessive number of tracings is likely "
|
||||||
|
"due to passing python objects instead of tensors. Also, tf.function "
|
||||||
|
"has experimental_relax_shapes=True option that relaxes argument "
|
||||||
|
"shapes that can avoid unnecessary retracing. Please refer to "
|
||||||
|
"https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args"
|
||||||
|
" and https://www.tensorflow.org/api_docs/python/tf/function for more "
|
||||||
|
"details.".format(recent_tracing_count, self._call_counter.call_count,
|
||||||
|
self._python_function))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _call(self, *args, **kwds):
|
||||||
|
"""Calls the graph function."""
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
if self._created_variables:
|
if self._created_variables:
|
||||||
# Release the lock early so that multiple threads can perform the call
|
# Release the lock early so that multiple threads can perform the call
|
||||||
|
@ -1809,6 +1809,7 @@ class Function(object):
|
|||||||
self._function_cache = FunctionCache()
|
self._function_cache = FunctionCache()
|
||||||
self._function_attributes = attributes or {}
|
self._function_attributes = attributes or {}
|
||||||
self._capture_by_value = capture_by_value
|
self._capture_by_value = capture_by_value
|
||||||
|
self.tracing_count = 0
|
||||||
|
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
# _descriptor_cache is a of instance of a class to an instance-specific
|
# _descriptor_cache is a of instance of a class to an instance-specific
|
||||||
@ -2011,6 +2012,8 @@ class Function(object):
|
|||||||
|
|
||||||
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
|
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
|
||||||
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
|
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
|
||||||
|
self.tracing_count += 1
|
||||||
|
|
||||||
if self.input_signature is None:
|
if self.input_signature is None:
|
||||||
arglen = len(args)
|
arglen = len(args)
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user