diff --git a/tensorflow/python/debug/lib/check_numerics_callback.py b/tensorflow/python/debug/lib/check_numerics_callback.py index edcafad201e..440dc758e76 100644 --- a/tensorflow/python/debug/lib/check_numerics_callback.py +++ b/tensorflow/python/debug/lib/check_numerics_callback.py @@ -410,6 +410,21 @@ def enable_check_numerics(stack_height_limit=30, z = tf.matmul(y, y) ``` + NOTE: If your code is running on TPUs, be sure to call + `tf.config.set_soft_device_placement(True)` before calling + `tf.debugging.enable_check_numerics()` as this API uses automatic outside + compilation on TPUs. For example: + + ```py + tf.config.set_soft_device_placement(True) + tf.debugging.enable_check_numerics() + + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') + strategy = tf.distribute.experimental.TPUStrategy(resolver) + with strategy.scope(): + # ... + ``` + Args: stack_height_limit: Limit to the height of the printed stack trace. Applicable only to ops in `tf.function`s (graphs). diff --git a/tensorflow/python/debug/lib/dumping_callback.py b/tensorflow/python/debug/lib/dumping_callback.py index 5f7fe5e7ea4..f012faf5f3c 100644 --- a/tensorflow/python/debug/lib/dumping_callback.py +++ b/tensorflow/python/debug/lib/dumping_callback.py @@ -721,6 +721,22 @@ def enable_dump_debug_info(dump_root, # Code to build, train and run your TensorFlow model... ``` + NOTE: If your code is running on TPUs, be sure to call + `tf.config.set_soft_device_placement(True)` before calling + `tf.debugging.experimental.enable_dump_debug_info()` as this API uses + automatic outside compilation on TPUs. For example: + + ```py + tf.config.set_soft_device_placement(True) + tf.debugging.experimental.enable_dump_debug_info( + logdir, tensor_debug_mode="FULL_HEALTH") + + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') + strategy = tf.distribute.experimental.TPUStrategy(resolver) + with strategy.scope(): + # ... + ``` + Args: dump_root: The directory path where the dumping information will be written. tensor_debug_mode: Debug mode for tensor values, as a string.