Reduce retracing warning from RestoredFunction
PiperOrigin-RevId: 344913737 Change-Id: I7a8337fbe4e03ec1f18468c4a838dad12a9f4c7d
This commit is contained in:
parent
ae3c3cf773
commit
3a158a0b4a
tensorflow/python
@ -105,10 +105,12 @@ class _FrequentTracingDetector(object):
|
||||
counter = self._get_counter(key)
|
||||
counter.called_without_tracing()
|
||||
|
||||
def called_with_tracing(self, key, function_name):
|
||||
def called_with_tracing(self, key, function_name, omit_warning):
|
||||
with self._lock:
|
||||
counter = self._get_counter(key)
|
||||
counter.called_with_tracing()
|
||||
if omit_warning:
|
||||
return
|
||||
if counter.get_tracing_count() >= FREQUENT_TRACING_WARNING_THRESHOLD:
|
||||
logging.warning(
|
||||
"{} out of the last {} calls to {} triggered tf.function "
|
||||
@ -514,6 +516,7 @@ class Function(object):
|
||||
self._name = name
|
||||
self._input_signature = input_signature
|
||||
self._key_for_call_stats = self._get_key_for_call_stats()
|
||||
self._omit_frequent_tracing_warning = False
|
||||
ops._tf_function_api_guage.get_cell().set(True) # pylint: disable=protected-access
|
||||
|
||||
def __getstate__(self):
|
||||
@ -794,8 +797,9 @@ class Function(object):
|
||||
_frequent_tracing_detector.called_without_tracing(
|
||||
self._key_for_call_stats)
|
||||
else:
|
||||
_frequent_tracing_detector.called_with_tracing(self._key_for_call_stats,
|
||||
self._python_function)
|
||||
_frequent_tracing_detector.called_with_tracing(
|
||||
self._key_for_call_stats, self._python_function,
|
||||
self._omit_frequent_tracing_warning)
|
||||
|
||||
return result
|
||||
|
||||
|
@ -47,6 +47,9 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import save_context
|
||||
from tensorflow.python.saved_model import save_options
|
||||
from tensorflow.python.saved_model.load import load
|
||||
from tensorflow.python.saved_model.save import save
|
||||
from tensorflow.python.training.tracking.util import Checkpoint
|
||||
|
||||
|
||||
def undecorated_function(x):
|
||||
@ -920,6 +923,39 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertLen(logs.output, 1)
|
||||
self.assertIn('Tracing is expensive', logs.output[0])
|
||||
|
||||
def test_restored_function_retracing_warning(self):
|
||||
|
||||
class Foo(Checkpoint):
|
||||
|
||||
@def_function.function
|
||||
def __call__(self, x):
|
||||
return x
|
||||
|
||||
f_flexible = Foo()
|
||||
_ = f_flexible.__call__.get_concrete_function(
|
||||
tensor_spec.TensorSpec(shape=[None], dtype=dtypes.int32))
|
||||
tmp_dir = self.create_tempdir()
|
||||
save(f_flexible, tmp_dir.full_path)
|
||||
restored_f_flexible = load(tmp_dir.full_path)
|
||||
|
||||
f_fixed_shape = Foo()
|
||||
|
||||
with self.assertLogs(level='WARN') as logs:
|
||||
restored_f_flexible(constant_op.constant([1], dtypes.int32))
|
||||
restored_f_flexible(constant_op.constant([1, 2], dtypes.int32))
|
||||
restored_f_flexible(constant_op.constant([1, 2, 3], dtypes.int32))
|
||||
restored_f_flexible(constant_op.constant([1, 2, 3, 4], dtypes.int32))
|
||||
restored_f_flexible(constant_op.constant([1, 2, 3, 4, 5], dtypes.int32))
|
||||
self.assertEmpty(logs.output)
|
||||
|
||||
f_fixed_shape(constant_op.constant([1], dtypes.int32))
|
||||
f_fixed_shape(constant_op.constant([1, 2], dtypes.int32))
|
||||
f_fixed_shape(constant_op.constant([1, 2, 3], dtypes.int32))
|
||||
f_fixed_shape(constant_op.constant([1, 2, 3, 4], dtypes.int32))
|
||||
f_fixed_shape(constant_op.constant([1, 2, 3, 4, 5], dtypes.int32))
|
||||
self.assertLen(logs.output, 1)
|
||||
self.assertIn('Tracing is expensive', logs.output[0])
|
||||
|
||||
def test_experimental_get_tracing_count_function(self):
|
||||
|
||||
@def_function.function
|
||||
|
@ -195,6 +195,10 @@ class RestoredFunction(def_function.Function):
|
||||
self.concrete_functions = concrete_functions
|
||||
self._function_spec = function_spec
|
||||
|
||||
# Prevent RestoredFunction from spamming users with frequent tracing
|
||||
# warnings.
|
||||
self._omit_frequent_tracing_warning = True
|
||||
|
||||
def _list_all_concrete_functions_for_serialization(self):
|
||||
return self.concrete_functions
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user