[tfdbg] Expose tf.debugging.enable_check_numerics() and disable_check_numerics()
RELNOTES: Add `tf.debugging.enable_check_numerics()` and `tf.debugging.disable_check_numerics()` to facilitate debugging of numeric instability (`Infinity`s and `NaN`s) under eager mode and `tf.function`s. PiperOrigin-RevId: 269418702
This commit is contained in:
parent
9cca4a050c
commit
b056951d0d
@ -169,6 +169,7 @@ py_library(
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/compiler",
|
||||
"//tensorflow/python/data",
|
||||
"//tensorflow/python/debug:debug_py",
|
||||
"//tensorflow/python/distribute",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:distribute_config",
|
||||
|
||||
@ -157,6 +157,9 @@ _tf2_gauge.get_cell().set(_tf2.enabled())
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
|
||||
# TensorFlow Debugger (tfdbg).
|
||||
from tensorflow.python.debug.lib import check_numerics_callback
|
||||
|
||||
# XLA JIT compiler APIs.
|
||||
from tensorflow.python.compiler.xla import jit
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
|
||||
@ -28,6 +28,7 @@ from tensorflow.python.framework import op_callbacks
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
def limit_string_length(string, max_len=50):
|
||||
@ -208,6 +209,7 @@ CheckNumericsConfig = collections.namedtuple(
|
||||
_state = threading.local()
|
||||
|
||||
|
||||
@tf_export("debugging.enable_check_numerics")
|
||||
def enable_check_numerics(stack_height_limit=30,
|
||||
path_length_limit=50):
|
||||
r"""Enable tensor numerics checking in an eager/graph unified fashion.
|
||||
@ -302,6 +304,7 @@ def enable_check_numerics(stack_height_limit=30,
|
||||
threading.current_thread().name)
|
||||
|
||||
|
||||
@tf_export("debugging.disable_check_numerics")
|
||||
def disable_check_numerics():
|
||||
"""Disable the eager/graph unified numerics checking mechanism.
|
||||
|
||||
|
||||
@ -92,15 +92,15 @@ def add_check_numerics_ops():
|
||||
|
||||
@compatibility(eager)
|
||||
Not compatible with eager execution. To check for `Inf`s and `NaN`s under
|
||||
eager execution, call `tfe.seterr(inf_or_nan='raise')` once before executing
|
||||
the checked operations.
|
||||
eager execution, call `tf.debugging.enable_check_numerics()` once before
|
||||
executing the checked operations.
|
||||
@end_compatibility
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"add_check_numerics_ops() is not compatible with eager execution. "
|
||||
"To check for Inf's and NaN's under eager execution, call "
|
||||
"tfe.seterr(inf_or_nan='raise') once before executing the "
|
||||
"tf.debugging.enable_check_numerics() once before executing the "
|
||||
"checked operations.")
|
||||
|
||||
check_op = []
|
||||
|
||||
@ -92,6 +92,14 @@ tf_module {
|
||||
name: "check_numerics"
|
||||
argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "disable_check_numerics"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "enable_check_numerics"
|
||||
argspec: "args=[\'stack_height_limit\', \'path_length_limit\'], varargs=None, keywords=None, defaults=[\'30\', \'50\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_log_device_placement"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
||||
@ -92,6 +92,14 @@ tf_module {
|
||||
name: "check_numerics"
|
||||
argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "disable_check_numerics"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "enable_check_numerics"
|
||||
argspec: "args=[\'stack_height_limit\', \'path_length_limit\'], varargs=None, keywords=None, defaults=[\'30\', \'50\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_log_device_placement"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user