Reduce retracing warning from RestoredFunction

PiperOrigin-RevId: 344913737
Change-Id: I7a8337fbe4e03ec1f18468c4a838dad12a9f4c7d
This commit is contained in:
A. Unique TensorFlower 2020-11-30 16:37:58 -08:00 committed by TensorFlower Gardener
parent ae3c3cf773
commit 3a158a0b4a
3 changed files with 47 additions and 3 deletions

View File

@ -105,10 +105,12 @@ class _FrequentTracingDetector(object):
counter = self._get_counter(key)
counter.called_without_tracing()
def called_with_tracing(self, key, function_name):
def called_with_tracing(self, key, function_name, omit_warning):
with self._lock:
counter = self._get_counter(key)
counter.called_with_tracing()
if omit_warning:
return
if counter.get_tracing_count() >= FREQUENT_TRACING_WARNING_THRESHOLD:
logging.warning(
"{} out of the last {} calls to {} triggered tf.function "
@ -514,6 +516,7 @@ class Function(object):
self._name = name
self._input_signature = input_signature
self._key_for_call_stats = self._get_key_for_call_stats()
self._omit_frequent_tracing_warning = False
ops._tf_function_api_guage.get_cell().set(True) # pylint: disable=protected-access
def __getstate__(self):
@ -794,8 +797,9 @@ class Function(object):
_frequent_tracing_detector.called_without_tracing(
self._key_for_call_stats)
else:
_frequent_tracing_detector.called_with_tracing(self._key_for_call_stats,
self._python_function)
_frequent_tracing_detector.called_with_tracing(
self._key_for_call_stats, self._python_function,
self._omit_frequent_tracing_warning)
return result

View File

@ -47,6 +47,9 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model.load import load
from tensorflow.python.saved_model.save import save
from tensorflow.python.training.tracking.util import Checkpoint
def undecorated_function(x):
@ -920,6 +923,39 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertLen(logs.output, 1)
self.assertIn('Tracing is expensive', logs.output[0])
def test_restored_function_retracing_warning(self):
class Foo(Checkpoint):
@def_function.function
def __call__(self, x):
return x
f_flexible = Foo()
_ = f_flexible.__call__.get_concrete_function(
tensor_spec.TensorSpec(shape=[None], dtype=dtypes.int32))
tmp_dir = self.create_tempdir()
save(f_flexible, tmp_dir.full_path)
restored_f_flexible = load(tmp_dir.full_path)
f_fixed_shape = Foo()
with self.assertLogs(level='WARN') as logs:
restored_f_flexible(constant_op.constant([1], dtypes.int32))
restored_f_flexible(constant_op.constant([1, 2], dtypes.int32))
restored_f_flexible(constant_op.constant([1, 2, 3], dtypes.int32))
restored_f_flexible(constant_op.constant([1, 2, 3, 4], dtypes.int32))
restored_f_flexible(constant_op.constant([1, 2, 3, 4, 5], dtypes.int32))
self.assertEmpty(logs.output)
f_fixed_shape(constant_op.constant([1], dtypes.int32))
f_fixed_shape(constant_op.constant([1, 2], dtypes.int32))
f_fixed_shape(constant_op.constant([1, 2, 3], dtypes.int32))
f_fixed_shape(constant_op.constant([1, 2, 3, 4], dtypes.int32))
f_fixed_shape(constant_op.constant([1, 2, 3, 4, 5], dtypes.int32))
self.assertLen(logs.output, 1)
self.assertIn('Tracing is expensive', logs.output[0])
def test_experimental_get_tracing_count_function(self):
@def_function.function

View File

@ -195,6 +195,10 @@ class RestoredFunction(def_function.Function):
self.concrete_functions = concrete_functions
self._function_spec = function_spec
# Prevent RestoredFunction from spamming users with frequent tracing
# warnings.
self._omit_frequent_tracing_warning = True
def _list_all_concrete_functions_for_serialization(self):
return self.concrete_functions