Merge pull request #32258 from kkimdev/cherrypicks_5XVFU

[r2.0 CherryPick]: @tf.function: Show a warning message when tracing happens too frequently
This commit is contained in:
Alexandre Passos 2019-09-05 15:42:16 -07:00 committed by GitHub
commit f67991359e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 1 deletions

View File

@ -38,6 +38,41 @@ from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
FREQUENT_TRACING_WARNING_THRESHOLD = 5
class _CallCounter(object):
"""Class keeping track of how many recent calls triggered tracing."""
def __init__(self, max_call_history):
self._max_call_history = max_call_history
self._calls_per_tracings = []
self.call_count = 0
def called_with_tracing(self):
self.call_count += 1
self._calls_per_tracings.append(1)
while self._calls_per_tracings:
if self.call_count - self._calls_per_tracings[0] > self._max_call_history:
self.call_count -= self._calls_per_tracings.pop(0)
else:
break
def called_without_tracing(self):
# TODO(kkimlabs): This is an unnecessary defensive check. Since this is last
# minute CL before 2.0 release, I've decided to be very defensive here to
# avoid a potential crash. Remove once we release 2.0.
if not self._calls_per_tracings:
return
self._calls_per_tracings[-1] += 1
self.call_count += 1
def get_tracing_count(self):
return len(self._calls_per_tracings)
class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
"""Variable which does not lift its initializer out of function context.
@ -297,6 +332,7 @@ class Function(object):
self._stateless_fn = None # GUARDED_BY(self._lock)
self._descriptor_cache = weakref.WeakKeyDictionary()
self._name = name
self._call_counter = _CallCounter(FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
def _defun_with_scope(self, scope):
"""Creates a defun wrapped inside a variable creator scope."""
@ -406,11 +442,41 @@ class Function(object):
self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
self._python_function, self.input_signature)
def _get_tracing_count(self):
result = self._stateless_fn.tracing_count if self._stateless_fn else 0
result += self._stateful_fn.tracing_count if self._stateful_fn else 0
return result
def __call__(self, *args, **kwds):
"""Calls the graph function."""
"""Calls the graph function and warn too frequent tracings."""
context.ensure_initialized()
if RUN_FUNCTIONS_EAGERLY:
return self._python_function(*args, **kwds)
tracing_count = self._get_tracing_count()
result = self._call(*args, **kwds)
if tracing_count == self._get_tracing_count():
self._call_counter.called_without_tracing()
return result
self._call_counter.called_with_tracing()
recent_tracing_count = self._call_counter.get_tracing_count()
if recent_tracing_count >= FREQUENT_TRACING_WARNING_THRESHOLD:
logging.warning(
"{} out of the last {} calls to {} triggered tf.function retracing. "
"Tracing is expensive and the excessive number of tracings is likely "
"due to passing python objects instead of tensors. Also, tf.function "
"has experimental_relax_shapes=True option that relaxes argument "
"shapes that can avoid unnecessary retracing. Please refer to "
"https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args"
" and https://www.tensorflow.org/api_docs/python/tf/function for more "
"details.".format(recent_tracing_count, self._call_counter.call_count,
self._python_function))
return result
def _call(self, *args, **kwds):
"""Calls the graph function."""
self._lock.acquire()
if self._created_variables:
# Release the lock early so that multiple threads can perform the call

View File

@ -1809,6 +1809,7 @@ class Function(object):
self._function_cache = FunctionCache()
self._function_attributes = attributes or {}
self._capture_by_value = capture_by_value
self.tracing_count = 0
self._lock = threading.Lock()
# _descriptor_cache is a of instance of a class to an instance-specific
@ -2011,6 +2012,8 @@ class Function(object):
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
self.tracing_count += 1
if self.input_signature is None:
arglen = len(args)
else: