[tfdbg] Let op_callbacks support MirroredStrategy & OneDeviceStrategy

- Add unit tests in tensorflow/python/debug/lib/distributed_callbacks_test.py,
  covering the check-numerics and dumping op callbacks running under various
  MirroredStrategy and OneDeviceStrategy scopes.
- In check_numerics_callback.py and dumping_callback.py, replace some
  thread-local objects with non-thread-local ones so that state can be
  shared between threads (see the sketch below).

PiperOrigin-RevId: 279230921
Change-Id: I0bf137c4bb29eef6698ac96520841c419c3d6219
commit af019188ad
parent 4f42698fdc
Author: Shanqing Cai
Date:   2019-11-07 20:46:31 -08:00
Committed by: TensorFlower Gardener

8 changed files with 862 additions and 345 deletions
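
In outline, the change keeps a per-thread handle to the callback object but makes the object's own bookkeeping structures ordinary (non-thread-local) members guarded by locks, so that replica threads spawned by MirroredStrategy can safely share them. Below is a minimal, self-contained sketch of that pattern; the names _SharedCallback and enable_callback are hypothetical stand-ins for the real _DumpingCallback / enable_dump_debug_info machinery, and the double-checked locking mirrors _DumpingCallback._get_context_id further down.

import threading

_state = threading.local()  # Per-thread handle; the object it refers to is shared.


class _SharedCallback(object):
  """Stand-in for _DumpingCallback: plain dicts guarded by locks."""

  def __init__(self):
    self._context_to_id = {}
    self._context_to_id_lock = threading.Lock()

  def get_context_id(self, context):
    # Double-checked locking: the common (already-registered) case is lock-free.
    if context in self._context_to_id:
      return self._context_to_id[context]
    with self._context_to_id_lock:
      if context not in self._context_to_id:
        self._context_to_id[context] = len(self._context_to_id)
      return self._context_to_id[context]


def enable_callback():
  # Idempotent per thread, analogous to enable_dump_debug_info().
  if not hasattr(_state, "callback"):
    _state.callback = _SharedCallback()
  return _state.callback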

tensorflow/python/debug/BUILD

@@ -704,6 +704,35 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "distributed_callbacks_test",
+    size = "medium",
+    srcs = ["lib/distributed_callbacks_test.py"],
+    additional_deps = [
+        ":check_numerics_callback",
+        ":debug_events_writer",
+        ":dumping_callback",
+        ":dumping_callback_test_lib",
+        "//third_party/py/numpy",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:mirrored_strategy",
+        "//tensorflow/python/distribute:strategy_combinations",
+        "//tensorflow/python/keras",
+    ],
+    tags = [
+        "guitar",
+        "multi_and_single_gpu",
+        "no_rocm",
+        "no_windows",  # TODO(b/142475891): Enable this test on Windows.
+        "no_windows_gpu",  # TODO(b/130551176)
+    ],
+    xla_enable_strict_auto_jit = False,  # Node names are different with autojit
+)
+
 cuda_py_test(
     name = "dumping_callback_test",
     size = "medium",
@@ -720,7 +749,7 @@ cuda_py_test(
         "//tensorflow/python:variables",
         "//tensorflow/python/keras",
     ],
-    shard_count = 6,
+    shard_count = 8,
     tags = [
         "no_windows",  # TODO(b/142475891): Enable this test on Windows.
     ],

tensorflow/python/debug/lib/check_numerics_callback.py

@@ -87,6 +87,8 @@ SAFE_OPS = (
     b"Unpack",
 )
 
+_state = threading.local()
+
 
 def limit_string_length(string, max_len=50):
   """Limit the length of input string.
@@ -217,66 +219,69 @@ def _debug_summary(x):
           debug_event_pb2.TensorDebugMode.REDUCE_INF_NAN_THREE_SLOTS))
 
 
-def _check_numerics_callback(op_type,
-                             inputs,
-                             attrs,
-                             outputs,
-                             op_name=None,
-                             graph=None):
-  """Eager-function unified callback for checking numerics."""
-  del attrs, op_name  # Unused
-  op_type_bytes = compat.as_bytes(op_type)
-  is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
-  if (op_type_bytes in op_callbacks_common.OP_CALLBACK_SKIP_OPS or
-      op_type_bytes in SAFE_OPS):
-    return
-  if graph:
-    # Under graph mode. Insert check_numerics op.
-    instrumented_outputs = []
-    for slot, output in enumerate(outputs):
-      if (output.dtype.is_floating and
-          (op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
-        checked_output = array_ops.check_numerics(
-            # TF v2 has automatic control dependencies added to stateful async
-            # ops, which allows us to run check_numerics asynchronously.
-            # In the above case we use debug_summary to reduce all output
-            # tensors asynchronously from the op being checked and then process
-            # the tensor summary with check_numerics.
-            output if is_v1_graph_mode else _debug_summary(output),
-            get_check_numerics_error_message(
-                slot,
-                len(outputs),
-                op_type,
-                output,
-                inputs,
-                graph=graph,
-                traceback=output.op.traceback))
-        _CHECK_NUMERICS_INPUT_LOOKUP[graph][checked_output.name] = output
-        instrumented_outputs.append(
-            checked_output if is_v1_graph_mode else output)
-      else:
-        instrumented_outputs.append(output)
-    return instrumented_outputs
-  else:
-    if op_type_bytes == b"CheckNumerics":
-      # TODO(b/140334369): Remove this special casing logic once op_callback
-      # automatically prevents infinite recursion in eager mode.
-      return
-    # Under eager mode. Eagerly execute check_numerics op.
-    for slot, output in enumerate(outputs):
-      if (output.dtype.is_floating and
-          (op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
-        array_ops.check_numerics(
-            output,
-            get_check_numerics_error_message(
-                slot, len(outputs), op_type, output, inputs,
-                stack_height_limit=_state.config.stack_height_limit,
-                path_length_limit=_state.config.path_length_limit))
-
-
-CheckNumericsConfig = collections.namedtuple(
-    "CheckNumericsConfig", "stack_height_limit path_length_limit")
-_state = threading.local()
+class CheckNumericsCallback(object):
+  """Wrapper for the numerics-checking callback for thread locality."""
+
+  def __init__(self, stack_height_limit, path_length_limit):
+    self._stack_height_limit = stack_height_limit
+    self._path_length_limit = path_length_limit
+
+  def callback(self,
+               op_type,
+               inputs,
+               attrs,
+               outputs,
+               op_name=None,
+               graph=None):
+    """Eager-function unified callback for checking numerics."""
+    del attrs, op_name  # Unused
+    op_type_bytes = compat.as_bytes(op_type)
+    is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
+    if (op_type_bytes in op_callbacks_common.OP_CALLBACK_SKIP_OPS or
+        op_type_bytes in SAFE_OPS):
+      return None
+    if graph:
+      # Under graph mode. Insert check_numerics op.
+      instrumented_outputs = []
+      for slot, output in enumerate(outputs):
+        if (output.dtype.is_floating and
+            (op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
+          checked_output = array_ops.check_numerics(
+              # TF v2 has automatic control dependencies added to stateful async
+              # ops, which allows us to run check_numerics asynchronously.
+              # In the above case we use debug_summary to reduce all output
+              # tensors asynchronously from the op being checked and then
+              # process the tensor summary with check_numerics.
+              output if is_v1_graph_mode else _debug_summary(output),
+              get_check_numerics_error_message(
+                  slot,
+                  len(outputs),
+                  op_type,
+                  output,
+                  inputs,
+                  graph=graph,
+                  traceback=output.op.traceback))
+          _CHECK_NUMERICS_INPUT_LOOKUP[graph][checked_output.name] = output
+          instrumented_outputs.append(
+              checked_output if is_v1_graph_mode else output)
+        else:
+          instrumented_outputs.append(output)
+      return instrumented_outputs
+    else:
+      if op_type_bytes == b"CheckNumerics":
+        # TODO(b/140334369): Remove this special casing logic once op_callback
+        # automatically prevents infinite recursion in eager mode.
+        return None
+      # Under eager mode. Eagerly execute check_numerics op.
+      for slot, output in enumerate(outputs):
+        if (output.dtype.is_floating and
+            (op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
+          array_ops.check_numerics(
+              output,
+              get_check_numerics_error_message(
+                  slot, len(outputs), op_type, output, inputs,
+                  stack_height_limit=self._stack_height_limit,
+                  path_length_limit=self._path_length_limit))
 
 
 @tf_export("debugging.enable_check_numerics")
@@ -362,12 +367,10 @@ def enable_check_numerics(stack_height_limit=30,
     path_length_limit: Limit to the file path included in the printed stack
       trace. Applicable only to ops in `tf.function`s (graphs).
   """
-  if not hasattr(_state, "config"):
-    _state.config = CheckNumericsConfig(
-        stack_height_limit=stack_height_limit,
-        path_length_limit=path_length_limit)
-  op_callbacks.add_op_callback(_check_numerics_callback)
+  if not hasattr(_state, "check_numerics_callback"):
+    _state.check_numerics_callback = CheckNumericsCallback(
+        stack_height_limit, path_length_limit)
+  op_callbacks.add_op_callback(_state.check_numerics_callback.callback)
 
   logging.info(
       "Enabled check-numerics callback in thread %s",
@@ -387,8 +390,11 @@ def disable_check_numerics():
 
   This method takes effect only on the thread in which it is called.
   """
+  if not hasattr(_state, "check_numerics_callback"):
+    return
   try:
-    op_callbacks.remove_op_callback(_check_numerics_callback)
+    op_callbacks.remove_op_callback(_state.check_numerics_callback.callback)
+    delattr(_state, "check_numerics_callback")
     logging.info(
         "Disabled check-numerics callback in thread %s",
         threading.current_thread().name)
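
As exercised by the new distributed_callbacks_test.py below, the wrapped callback can be enabled before or inside a strategy scope. A hedged usage sketch (TF 2.0-era API; strategy.experimental_run_v2 was the replica-run entry point at the time of this commit):

import tensorflow as tf

tf.debugging.enable_check_numerics()

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(1.0)


@tf.function
def step():
  # Division by zero yields Inf; the inserted CheckNumerics op raises an
  # InvalidArgumentError whose message points back at this source line.
  return v / tf.constant(0.0)


strategy.experimental_run_v2(step)  # Raises tf.errors.InvalidArgumentError.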

tensorflow/python/debug/lib/check_numerics_callback_test.py

@@ -571,7 +571,6 @@ class CheckNumericsCallbackTest(test_util.TensorFlowTestCase):
     self.assertTrue(np.isnan(batch_mean.squeeze()))
     self.assertTrue(np.isnan(batch_variance.squeeze()))
 
-# TODO(cais): Tests for Infs and NaNs during distributed execution.
 # TODO(cais): Benchmark the slowdown due to callbacks and inserted nodes.
 
 if __name__ == "__main__":

tensorflow/python/debug/lib/distributed_callbacks_test.py (new file)

@@ -0,0 +1,344 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfdbg op callbacks running with various `DistributionStrategy`s."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.debug.lib import check_numerics_callback
from tensorflow.python.debug.lib import dumping_callback
from tensorflow.python.debug.lib import dumping_callback_test_lib
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import gradient_descent
def filter_by_device_name(items, device_names, target_device_name):
"""Filter a list of items by device name.
Args:
items: A list of items to be filtered according to their corresponding
device names.
device_names: A list of the device names. Must have the same length
as `items`.
target_device_name: A `str` representing the desired device name.
Returns:
Filtered items from `items`.
"""
assert len(items) == len(device_names)
assert all(device_names), "device_names are not all non-empty strings"
# Note: we use `endswith` instead of `==` for device-name filtering because
# in some cases, the device names from kernel/op execution can have slightly
# different values than the device names from
# `distribution.extended.worker_devices`.
return [items[i] for i, device_name in enumerate(device_names)
if device_name.endswith(target_device_name)]
def filter_by_device_name_and_op_type(
items, device_names, op_types, target_device_name, target_op_type):
assert len(items) == len(device_names)
assert len(items) == len(op_types)
assert all(device_names), "device_names are not all non-empty strings"
assert all(op_types), "op_types are not all non-empty strings"
return [items[i] for i, device_name in enumerate(device_names)
if device_name.endswith(target_device_name)
and op_types[i] == target_op_type]
class MiniModel(keras.Model):
"""Minimal subclassed Keras model."""
def __init__(self, generate_infinity=False):
super(MiniModel, self).__init__(name="")
self._generate_infinity = generate_infinity
self.fc = keras.layers.Dense(
1, kernel_initializer="ones", bias_initializer="ones",
activation="linear")
@def_function.function
def call(self, inputs, training=True):
y = self.fc(inputs)
if self._generate_infinity:
y = math_ops.divide(y, array_ops.zeros_like(y))
return y
class DistributedDumpingCallbackTest(
dumping_callback_test_lib.DumpingCallbackTestBase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
inside_scope=[False, True],
# TODO(cais): Investigate why, under V1 graph mode (mode="graph"), the
# test occasionally (~1-2% of the time) runs into the following error:
# CancelledError: [_Derived_] Function was cancelled before it was
# started.
mode=["eager"],
))
def testCheckingInfinityInMiniModelOnOneOrTwoDevices(
self, distribution, inside_scope):
if not inside_scope:
check_numerics_callback.enable_check_numerics()
with distribution.scope():
if inside_scope:
check_numerics_callback.enable_check_numerics()
mini_model = MiniModel(generate_infinity=True)
def train_step():
with backprop.GradientTape() as tape:
loss = mini_model(array_ops.ones([1, 10]))
return tape.gradient(loss, mini_model.weights)
caught_error = None
try:
distribution.experimental_run_v2(train_step)
except errors.InvalidArgumentError as error:
caught_error = error
self.assertTrue(caught_error)
self.assertTrue(re.search(
r"Detected Infinity or NaN.*\"RealDiv\"", caught_error.message))
self.assertIn(
"-> | y = math_ops.divide(y, array_ops.zeros_like(y))",
caught_error.message)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
mode=["eager"],
tensor_debug_mode=["NO_TENSOR", "FULL_TENSOR"],
))
def testDumpingMiniModel(self, distribution, tensor_debug_mode):
with distribution.scope():
writer = dumping_callback.enable_dump_debug_info(
self.dump_root, tensor_debug_mode=tensor_debug_mode)
mini_model = MiniModel()
optimizer = gradient_descent.GradientDescentOptimizer(0.25)
def train_step():
with backprop.GradientTape() as tape:
loss = mini_model(array_ops.ones([1, 10]))
grads = tape.gradient(loss, mini_model.weights)
grads_and_vars = zip(grads, mini_model.weights)
optimizer.apply_gradients(grads_and_vars)
distribution.experimental_run_v2(train_step)
updated_var_values = self.evaluate(mini_model.variables)
num_devices = len(distribution.extended.worker_devices)
assert num_devices in [1, 2]
# TODO(cais): We currently refrain from asserting the
# element-by-element values of the variable updates. The values seem to
# vary among builds. On some builds, it's 0.75; on others, it's 1.0.
# This variation is seen in the MirroredCPUAndGPU and OneDeviceGPU
# strategies. Needs investigation.
# if num_devices == 1:
# self.assertAllEqual(0.75 * np.ones([10, 1]), updated_var_values[0])
# self.assertAllEqual([0.75], updated_var_values[1]).
# else:
# self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0])
# self.assertAllEqual([0.5], updated_var_values[1])
self.assertEqual(updated_var_values[0].shape, (10, 1))
self.assertEqual(updated_var_values[1].shape, (1,))
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
(context_ids, _,
op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
(op_names, device_names, _,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
device_name_0 = distribution.extended.worker_devices[0]
logging.info("device_name_0 = %s", device_name_0)
if num_devices > 1:
device_name_1 = distribution.extended.worker_devices[1]
logging.info("device_name_1 = %s", device_name_1)
device_0_executed_op_types = filter_by_device_name(
executed_op_types, device_names, device_name_0)
if num_devices > 1:
device_1_executed_op_types = filter_by_device_name(
executed_op_types, device_names, device_name_1)
# Verify graph-execution traces are available for both devices.
# We don't assert MatMul occurs exactly once because the gradient of MatMul
# involves MatMul.
self.assertIn("MatMul", device_0_executed_op_types)
self.assertEqual(device_0_executed_op_types.count("BiasAdd"), 1)
if num_devices > 1:
self.assertIn("MatMul", device_1_executed_op_types)
self.assertEqual(device_1_executed_op_types.count("BiasAdd"), 1)
if tensor_debug_mode == "NO_TENSOR":
for value_list in tensor_values:
for tensor_value in value_list:
self.assertEqual(tensor_value.dtype, np.float32)
self.assertEqual(tensor_value.shape, [])
elif tensor_debug_mode == "FULL_TENSOR":
device_0_matmul_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_0,
"MatMul")
device_0_bias_add_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_0,
"BiasAdd")
self.assertAllClose(device_0_matmul_values[0], [[10.0]])
self.assertAllClose(device_0_bias_add_values[0], [[11.0]])
if num_devices > 1:
device_1_matmul_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_1,
"MatMul")
device_1_bias_add_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_1,
"BiasAdd")
self.assertAllClose(device_1_matmul_values[0], [[10.0]])
self.assertAllClose(device_1_bias_add_values[0], [[11.0]])
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
mode=["eager"],
tensor_debug_mode=["NO_TENSOR", "FULL_TENSOR"],
))
def testKerasModelFitOnOneOrTwoDevices(self, distribution, tensor_debug_mode):
writer = dumping_callback.enable_dump_debug_info(
self.dump_root, tensor_debug_mode=tensor_debug_mode)
with distribution.scope():
model = keras.Sequential()
model.add(keras.layers.Dense(
units=10, input_shape=[5], activation="relu"))
model.add(keras.layers.Dense(units=1))
model.compile(loss="mse", optimizer="sgd")
batch_size = 20
x = np.ones([batch_size, 5])
y = np.ones([batch_size, 1])
epochs = 1
history = model.fit(x, y, epochs=epochs, verbose=0)
self.assertLen(history.history["loss"], epochs)
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
(context_ids, _,
op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
(op_names, device_names, _,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
# Eager execution of tf.function should be recorded.
executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
fit_functions = [op_type for op_type in executed_op_types
if "_distributed_function" in op_type]
self.assertLen(fit_functions, epochs)
num_devices = len(distribution.extended.worker_devices)
device_name_0 = distribution.extended.worker_devices[0]
logging.info("device_name_0 = %s", device_name_0)
if num_devices > 1:
device_name_1 = distribution.extended.worker_devices[1]
logging.info("device_name_1 = %s", device_name_1)
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
device_0_executed_op_types = filter_by_device_name(
executed_op_types, device_names, device_name_0)
if num_devices > 1:
device_1_executed_op_types = filter_by_device_name(
executed_op_types, device_names, device_name_1)
self.assertIn("MatMul", device_0_executed_op_types)
self.assertIn("BiasAdd", device_0_executed_op_types)
self.assertIn("Relu", device_0_executed_op_types)
self.assertIn("ReluGrad", device_0_executed_op_types)
if num_devices > 1:
# If there are two devices involved, assert the ops inside tf.functions
# are executed and recorded for the equal numbers of times by the
# dumping op-callback.
self.assertEqual(device_0_executed_op_types.count("MatMul"),
device_1_executed_op_types.count("MatMul"))
self.assertEqual(device_0_executed_op_types.count("BiasAdd"),
device_1_executed_op_types.count("BiasAdd"))
self.assertEqual(device_0_executed_op_types.count("Relu"),
device_1_executed_op_types.count("Relu"))
self.assertEqual(device_0_executed_op_types.count("ReluGrad"),
device_1_executed_op_types.count("ReluGrad"))
if tensor_debug_mode == "NO_TENSOR":
for value_list in tensor_values:
for tensor_value in value_list:
self.assertEqual(tensor_value.dtype, np.float32)
self.assertEqual(tensor_value.shape, [])
elif tensor_debug_mode == "FULL_TENSOR":
gpu_0_relu_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_0, "Relu")
self.assertTrue(gpu_0_relu_values)
gpu_0_relu_grad_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_0,
"ReluGrad")
self.assertTrue(gpu_0_relu_grad_values)
if num_devices > 1:
gpu_1_relu_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_1,
"Relu")
self.assertTrue(gpu_1_relu_values)
for i in range(len(gpu_0_relu_values)):
self.assertEqual(gpu_0_relu_values[i].shape,
gpu_1_relu_values[i].shape)
gpu_1_relu_grad_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_1,
"ReluGrad")
self.assertTrue(gpu_1_relu_grad_values)
for i in range(len(gpu_0_relu_grad_values)):
self.assertEqual(
gpu_0_relu_grad_values[i].shape, gpu_1_relu_grad_values[i].shape)
if __name__ == "__main__":
googletest.main()

tensorflow/python/debug/lib/dumping_callback.py

@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import atexit
-import collections
 import re
 import socket
 import threading
@@ -42,11 +41,8 @@ from tensorflow.python.util import compat
 from tensorflow.python.util import tf_stack
 from tensorflow.python.util.tf_export import tf_export
 
-DumpingConfig = collections.namedtuple(
-    "DumpingConfig",
-    "dump_root tensor_debug_mode circular_buffer_size "
-    "op_regex tensor_dtypes")
 _state = threading.local()
+DEFAULT_TENSOR_DEBUG_MODE = "NO_TENSOR"
 
 
 @ops.RegisterGradient("DebugIdentityV2")
@@ -56,272 +52,340 @@ def _debug_identity_v2_grad(op, dy):
   return dy
 
 
-def _get_writer():
-  """Get the debug events writer for the currently configured dump root."""
-  # TODO(cais): Explore caching the object for possible performance gain.
-  # TODO(cais): Rename cyclic_buffer_size to circular_buffer_size in C++ and
-  # Python-binding code.
-  return debug_events_writer.DebugEventsWriter(
-      _state.config.dump_root,
-      circular_buffer_size=_state.config.circular_buffer_size)
-
-
 def _get_id():
   """Get a short unique ID."""
   return str(uuid.uuid4())
 
 
-def _get_context_id(context):
-  """Get a unique ID for an op-construction context (e.g., a graph).
-
-  If the graph has been encountered before, reuse the same unique ID.
-
-  Args:
-    context: A context to get the unique ID for. Must be hashable. E.g., a Graph
-      object.
-
-  Returns:
-    A unique ID for the context.
-  """
-  if context not in _state.context_to_id:
-    _state.context_to_id[context] = _get_id()
-  return _state.context_to_id[context]
-
-
-def _write_source_file_content(file_path):
-  """Send the content of a source file via debug-events writer.
-
-  Args:
-    file_path: Path to the source file.
-
-  Returns:
-    An int index for the file.
-  """
-  if file_path not in _state.source_file_paths:
-    lines = None
-    if source_utils.is_extension_uncompiled_python_source(file_path):
-      try:
-        lines, _ = source_utils.load_source(file_path)
-      except IOError:
-        # Accept the fact that some source files are not readable. Here we use
-        # best effort to send the source-file contents.
-        pass
-    writer = _get_writer()
-    writer.WriteSourceFile(debug_event_pb2.SourceFile(
-        file_path=file_path, host_name=_state.hostname, lines=lines))
-    _state.source_file_paths.append(file_path)
-  return _state.source_file_paths.index(file_path)
-
-
-def _process_stack_frames():
-  """Process stack frames.
-
-  Send the content of source-files, on a best-effort basis.
-
-  Returns:
-    A list of stack frame IDs.
-  """
-  stack_frames = tf_stack.extract_stack()
-  stack_frame_ids = []
-  writer = None
-  for file_path, lineno, func, _ in stack_frames:
-    if (file_path, lineno, func) not in _state.stack_frame_to_id:
-      stack_frame_id = _get_id()
-      _state.stack_frame_to_id[(file_path, lineno, func)] = stack_frame_id
-      file_index = _write_source_file_content(file_path)
-      file_line_col = graph_debug_info_pb2.GraphDebugInfo.FileLineCol(
-          file_index=file_index, line=lineno, func=func)
-      stack_frame_with_id = debug_event_pb2.StackFrameWithId(
-          id=stack_frame_id, file_line_col=file_line_col)
-      writer = _get_writer()
-      writer.WriteStackFrameWithId(stack_frame_with_id)
-    stack_frame_ids.append(_state.stack_frame_to_id[(file_path, lineno, func)])
-  code_location = debug_event_pb2.CodeLocation(
-      host_name=_state.hostname, stack_frame_ids=stack_frame_ids)
-  return code_location
-
-
-def _should_dump_tensor(op_type, dtype):
-  """Determine if the given tensor's value will be dumped.
-
-  The determination is made given the configurations such as `op_regex`,
-  `tensor_dtypes`.
-
-  Args:
-    op_type: Name of the op's type, as a string (e.g., "MatMul").
-    dtype: The dtype of the tensor, as a `dtypes.DType` object.
-
-  Returns:
-    A bool indicating whether the tensor's value will be dumped.
-  """
-  should_dump = True
-  if _state.config.op_regex:
-    should_dump = (should_dump and
-                   re.match(_state.config.op_regex, op_type))
-  if _state.config.tensor_dtypes:
-    if isinstance(_state.config.tensor_dtypes, (list, tuple)):
-      should_dump = (should_dump and
-                     any(dtype == dtype_item for dtype_item
-                         in _state.config.tensor_dtypes))
-    else:  # A callable that takes a DType argument and returns a boolean.
-      should_dump = should_dump and _state.config.tensor_dtypes(dtype)
-  return should_dump
-
-
-def _instrument_symbolic_tensors(tensors, op_type, op_name, tfdbg_context_id):
-  """Add debugging instrumentation for symbolic (i.e., non-eager) tensors.
-
-  The detailed fashion in which the tensors are instrumented is determined
-  by the tensor_debug_mode configured for the currently enabled dumping
-  callback.
-
-  Args:
-    tensors: A tuple of Tensors to instrument. It is assumed that their ordering
-      corresponds to the ordering of output tensors of an original op. Output
-      slot indices (0-based) will be generated based on the ordering.
-    op_type: Name of the op type of the node that emits `tensors` (e.g.,
-      "MatMul"), as a string.
-    op_name: Name of the node that emits `tensors` (e.g., "dense_1/MatMul"), as
-      a string.
-    tfdbg_context_id: A unique ID for the context that the op belongs to (e.g.,
-      a graph).
-
-  Returns:
-    Non-eager Tensors that override the `tensors` as the output of the op
-    that originally generated `tensors`. In some cases (e.g., non-V1 graph
-    mode), this may be `None`, as the instrumentation can simply rely on
-    automatic control dependencies (see `auto_control_deps.py`) instead of
-    tensor overriding.
-  """
-  tensor_debug_mode = _state.config.tensor_debug_mode
-  debug_urls = ["file://%s" % _state.config.dump_root]
-  is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
-  instrumented_tensors = [] if is_v1_graph_mode else None
-  for output_slot, tensor in enumerate(tensors):
-    if not _should_dump_tensor(op_type, tensor.dtype):
-      if is_v1_graph_mode:
-        instrumented_tensors.append(tensor)
-      continue
-    if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
-      # Except in V1 graph mode + control flow, debug_identity_v2 triggers auto
-      # control dependency because it's a stateful op.
-      debug_tensor = gen_debug_ops.debug_identity_v2(
-          # Use an empty (shape=[0]) float32 tensor for the NO_TENSOR mode
-          # as a low-overhead placeholder, since no actual tensor value is
-          # traced.
-          constant_op.constant([], dtype=dtypes.float32),
-          tfdbg_context_id=tfdbg_context_id,
-          op_name=op_name,
-          output_slot=output_slot,
-          tensor_debug_mode=_state.config.tensor_debug_mode,
-          debug_urls=debug_urls)
-      if is_v1_graph_mode:
-        # TODO(cais): Evaluate performance optimization options. For the
-        # `NO_TENSOR` debug mode, an alternative is to add `debug_tensor` as a
-        # control dependency of `tensor.op` without an additional identity op.
-        identity = array_ops.identity(tensor)
-        identity.op._add_control_input(  # pylint: disable=protected-access
-            debug_tensor.op)
-        instrumented_tensors.append(identity)
-    elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
-      debug_tensor = gen_debug_ops.debug_identity_v2(
-          tensor,
-          tfdbg_context_id=tfdbg_context_id,
-          op_name=op_name,
-          output_slot=output_slot,
-          tensor_debug_mode=_state.config.tensor_debug_mode,
-          debug_urls=debug_urls)
-      if is_v1_graph_mode:
-        instrumented_tensors.append(debug_tensor)
-    else:
-      raise NotImplementedError(
-          "Symbolic tensor instrumentation is not implemented for debug "
-          "mode %s" % _state.config.tensor_debug_mode)
-  return instrumented_tensors
-
-
-def _dump_eager_tensors(tensors, op_type, input_tensor_ids):
-  """Dump the value of eager tensors.
-
-  The destination of the dumping is determined by the dump_root of the currently
-  enabled dumping callback. The tensors may be transformed prior to dumping
-  (e.g., reduced as summary statistics such as minimum, maximum and arithmetic
-  mean). The details of this transformation (if any) depends on the
-  tensor_debug_mode of the currently enabled dumping callback.
-
-  Args:
-    tensors: The EagerTensors whose values are to be dumped, with or without
-      value transform.
-    op_type: Type of the op that generates the tensors, as a string.
-    input_tensor_ids: IDs of the input EagerTensors to the op.
-
-  Returns:
-    A tfdbg Execution protocol buffer.
-  """
-  tensor_debug_mode = _state.config.tensor_debug_mode
-  output_tensor_ids = [
-      t._id for t in tensors]  # pylint:disable=protected-access
-  if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
-    return debug_event_pb2.Execution(
-        op_type=op_type,
-        num_outputs=len(tensors),
-        input_tensor_ids=input_tensor_ids,
-        output_tensor_ids=output_tensor_ids,
-        tensor_debug_mode=tensor_debug_mode,
-        code_location=_process_stack_frames())
-  elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
-    execution_proto = debug_event_pb2.Execution(
-        op_type=op_type,
-        num_outputs=len(tensors),
-        input_tensor_ids=input_tensor_ids,
-        output_tensor_ids=output_tensor_ids,
-        tensor_debug_mode=tensor_debug_mode,
-        code_location=_process_stack_frames())
-    for tensor in tensors:
-      if (_should_dump_tensor(op_type, tensor.dtype) and
-          tensor.dtype.is_numpy_compatible):
-        execution_proto.tensor_protos.append(
-            tensor_util.make_tensor_proto(tensor.numpy()))
-    return execution_proto
-  else:
-    raise NotImplementedError(
-        "Tensor instrumentation is not implemented for debug mode %s yet " %
-        _state.config.tensor_debug_mode)
-
-
-def _dumping_callback(op_type,
-                      inputs,
-                      attrs,
-                      outputs,
-                      op_name=None,
-                      graph=None):
-  """Op callback for tracing a TF program's execution."""
-  del attrs  # Unused
-
-  writer = _get_writer()
-  if graph:
-    context_id = _get_context_id(graph)
-    assert op_name is not None
-    graph_op_creation = debug_event_pb2.GraphOpCreation(
-        op_type=op_type,
-        op_name=op_name,
-        graph_name=graph.name if hasattr(graph, "name") else None,
-        graph_id=context_id,
-        input_names=[input_tensor.name for input_tensor in inputs],
-        num_outputs=len(outputs),
-        code_location=_process_stack_frames())
-    writer.WriteGraphOpCreation(graph_op_creation)
-    if outputs and compat.as_bytes(
-        op_type) not in op_callbacks_common.OP_CALLBACK_SKIP_OPS:
-      return _instrument_symbolic_tensors(outputs, op_type, op_name, context_id)
-  else:
-    input_ids = [t._id for t in inputs]  # pylint:disable=protected-access
-    writer.WriteExecution(_dump_eager_tensors(outputs, op_type, input_ids))
-
-
-DEFAULT_TENSOR_DEBUG_MODE = "NO_TENSOR"
+class _DumpingCallback(object):
+  """An object holding the states surrounding the dumping callback."""
+
+  def __init__(self,
+               dump_root,
+               tensor_debug_mode,
+               circular_buffer_size,
+               op_regex,
+               tensor_dtypes):
+    self._dump_root = dump_root
+    self._tensor_debug_mode = tensor_debug_mode
+    self._circular_buffer_size = circular_buffer_size
+    self._op_regex = op_regex
+    self._tensor_dtypes = tensor_dtypes
+
+    self._hostname = socket.gethostname()
+    # A list of source-file paths.
+    self._source_file_paths = []
+    # A map from stack frame (FileLineCol) to unique ID.
+    self._stack_frame_to_id = dict()
+    # Mapping op context to unique ID.
+    self._context_to_id = dict()
+    self._source_file_paths_lock = threading.Lock()
+    self._stack_frame_to_id_lock = threading.Lock()
+    self._context_to_id_lock = threading.Lock()
+    self._writer = None
+
+  @property
+  def dump_root(self):
+    return self._dump_root
+
+  @dump_root.setter
+  def dump_root(self, dump_root):
+    if self._dump_root != dump_root:
+      self._dump_root = dump_root
+      self._writer = None
+
+  @property
+  def tensor_debug_mode(self):
+    return self._tensor_debug_mode
+
+  @property
+  def circular_buffer_size(self):
+    return self._circular_buffer_size
+
+  def get_writer(self):
+    """Get the debug events writer for the currently configured dump root."""
+    if not self._writer:
+      self._writer = debug_events_writer.DebugEventsWriter(
+          self._dump_root,
+          circular_buffer_size=self._circular_buffer_size)
+    return self._writer
+
+  def _get_context_id(self, context):
+    """Get a unique ID for an op-construction context (e.g., a graph).
+
+    If the graph has been encountered before, reuse the same unique ID.
+
+    Args:
+      context: A context to get the unique ID for. Must be hashable. E.g., a
+        Graph object.
+
+    Returns:
+      A unique ID for the context.
+    """
+    # Use the double-checked lock pattern to optimize the common case.
+    if context in self._context_to_id:  # 1st check, without lock.
+      return self._context_to_id[context]
+    with self._context_to_id_lock:
+      if context not in self._context_to_id:  # 2nd check, with lock.
+        self._context_to_id[context] = _get_id()
+      return self._context_to_id[context]
+
+  def _write_source_file_content(self, file_path):
+    """Send the content of a source file via debug-events writer.
+
+    Args:
+      file_path: Path to the source file.
+
+    Returns:
+      An int index for the file.
+    """
+    if file_path in self._source_file_paths:
+      return self._source_file_paths.index(file_path)
+    with self._source_file_paths_lock:
+      if file_path not in self._source_file_paths:
+        lines = None
+        if source_utils.is_extension_uncompiled_python_source(file_path):
+          try:
+            lines, _ = source_utils.load_source(file_path)
+          except IOError:
+            # Accept the fact that some source files are not readable. Here we
+            # use best effort to send the source-file contents.
+            pass
+        writer = self.get_writer()
+        writer.WriteSourceFile(debug_event_pb2.SourceFile(
+            file_path=file_path, host_name=self._hostname, lines=lines))
+        self._source_file_paths.append(file_path)
+      return self._source_file_paths.index(file_path)
+
+  def _process_stack_frames(self):
+    """Process stack frames.
+
+    Send the content of source-files, on a best-effort basis.
+
+    Returns:
+      A list of stack frame IDs.
+    """
+    stack_frames = tf_stack.extract_stack()
+    stack_frame_ids = []
+    writer = None
+    for file_path, lineno, func, _ in stack_frames:
+      if (file_path, lineno, func) in self._stack_frame_to_id:
+        stack_frame_ids.append(
+            self._stack_frame_to_id[(file_path, lineno, func)])
+        continue
+      with self._stack_frame_to_id_lock:
+        if (file_path, lineno, func) not in self._stack_frame_to_id:
+          stack_frame_id = _get_id()
+          self._stack_frame_to_id[(file_path, lineno, func)] = stack_frame_id
+          file_index = self._write_source_file_content(file_path)
+          file_line_col = graph_debug_info_pb2.GraphDebugInfo.FileLineCol(
+              file_index=file_index, line=lineno, func=func)
+          stack_frame_with_id = debug_event_pb2.StackFrameWithId(
+              id=stack_frame_id, file_line_col=file_line_col)
+          writer = self.get_writer()
+          writer.WriteStackFrameWithId(stack_frame_with_id)
+        stack_frame_ids.append(
+            self._stack_frame_to_id[(file_path, lineno, func)])
+    code_location = debug_event_pb2.CodeLocation(
+        host_name=self._hostname, stack_frame_ids=stack_frame_ids)
+    return code_location
+
+  def _instrument_symbolic_tensors(self,
+                                   tensors,
+                                   op_type,
+                                   op_name,
+                                   tfdbg_context_id):
+    """Add debugging instrumentation for symbolic (i.e., non-eager) tensors.
+
+    The detailed fashion in which the tensors are instrumented is determined
+    by the tensor_debug_mode configured for the currently enabled dumping
+    callback.
+
+    Args:
+      tensors: A tuple of Tensors to instrument. It is assumed that their
+        ordering corresponds to the ordering of output tensors of an original
+        op. Output slot indices (0-based) will be generated based on the
+        ordering.
+      op_type: Type name of the op that emits the Tensors (e.g., "MatMul").
+      op_name: Name of the op that emits the Tensors (e.g., "dense_1/MatMul").
+      tfdbg_context_id: A unique ID for the context that the op belongs to
+        (e.g., a graph).
+
+    Returns:
+      Non-eager Tensors that override the `tensors` as the output of the op
+      that originally generated `tensors`. In some cases (e.g., non-V1 graph
+      mode), this may be `None`, as the instrumentation can simply rely on
+      automatic control dependencies (see `auto_control_deps.py`) instead of
+      tensor overriding.
+    """
+    tensor_debug_mode = self._tensor_debug_mode
+    debug_urls = ["file://%s" % self._dump_root]
+    is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
+    instrumented_tensors = [] if is_v1_graph_mode else None
+    if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
+      for output_slot, tensor in enumerate(tensors):
+        if (not self._should_dump_tensor(op_type, tensor.dtype) or
+            not tensor.dtype.is_numpy_compatible):
+          # Instrumenting DT_VARIANT and DT_RESOURCE type tensors under
+          # V1 graph mode is known to have issues. TODO(cais): Investigate.
+          if is_v1_graph_mode:
+            instrumented_tensors.append(tensor)
+          continue
+        if is_v1_graph_mode and not tensor.dtype.is_numpy_compatible:
+          instrumented_tensors.append(tensor)
+          continue
+        # Except in V1 graph mode + control flow, debug_identity_v2 triggers
+        # auto control dependency because it's a stateful op.
+        debug_tensor = gen_debug_ops.debug_identity_v2(
+            # Use an empty (shape=[0]) float32 tensor for the NO_TENSOR mode
+            # as a low-overhead placeholder, since no actual tensor value is
+            # traced.
+            constant_op.constant([], dtype=dtypes.float32),
+            tfdbg_context_id=tfdbg_context_id,
+            op_name=op_name,
+            output_slot=output_slot,
+            tensor_debug_mode=self._tensor_debug_mode,
+            debug_urls=debug_urls)
+        if is_v1_graph_mode:
+          # TODO(cais): Evaluate performance optimization options. For the
+          # `NO_TENSOR` debug mode, an alternative is to add `debug_tensor` as
+          # a control dependency of `tensor.op` without an additional identity
+          # op.
+          identity = array_ops.identity(tensor)
+          identity.op._add_control_input(  # pylint: disable=protected-access
+              debug_tensor.op)
+          instrumented_tensors.append(identity)
+      return instrumented_tensors
+    elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
+      for output_slot, tensor in enumerate(tensors):
+        if (not self._should_dump_tensor(op_type, tensor.dtype) or
+            not tensor.dtype.is_numpy_compatible):
+          # Instrumenting DT_VARIANT and DT_RESOURCE type tensors under
+          # V1 graph mode is known to have issues. TODO(cais): Investigate.
+          if is_v1_graph_mode:
+            instrumented_tensors.append(tensor)
+          continue
+        debug_tensor = gen_debug_ops.debug_identity_v2(
+            tensor,
+            tfdbg_context_id=tfdbg_context_id,
+            op_name=op_name,
+            output_slot=output_slot,
+            tensor_debug_mode=self._tensor_debug_mode,
+            debug_urls=debug_urls)
+        if is_v1_graph_mode:
+          instrumented_tensors.append(debug_tensor)
+      return instrumented_tensors
+    else:
+      raise NotImplementedError(
+          "Symbolic tensor instrumentation is not implemented for debug mode "
+          "%s" % self._tensor_debug_mode)
+
+  def _dump_eager_tensors(self, tensors, op_type, input_tensor_ids):
+    """Dump the value of eager tensors.
+
+    The destination of the dumping is determined by the dump_root of the
+    currently enabled dumping callback. The tensors may be transformed prior to
+    dumping (e.g., reduced as summary statistics such as minimum, maximum and
+    arithmetic mean). The details of this transformation (if any) depends on
+    the tensor_debug_mode of the currently enabled dumping callback.
+
+    Args:
+      tensors: The EagerTensors whose values are to be dumped, with or without
+        value transform.
+      op_type: Type of the op that generates the tensors, as a string.
+      input_tensor_ids: IDs of the input EagerTensors to the op.
+
+    Returns:
+      A tfdbg Execution protocol buffer.
+    """
+    tensor_debug_mode = self._tensor_debug_mode
+    output_tensor_ids = [
+        t._id for t in tensors]  # pylint:disable=protected-access
+    if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
+      return debug_event_pb2.Execution(
+          op_type=op_type,
+          num_outputs=len(tensors),
+          input_tensor_ids=input_tensor_ids,
+          output_tensor_ids=output_tensor_ids,
+          tensor_debug_mode=tensor_debug_mode,
+          code_location=self._process_stack_frames())
+    elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
+      execution_proto = debug_event_pb2.Execution(
+          op_type=op_type,
+          num_outputs=len(tensors),
+          input_tensor_ids=input_tensor_ids,
+          output_tensor_ids=output_tensor_ids,
+          tensor_debug_mode=tensor_debug_mode,
+          code_location=self._process_stack_frames())
+      for tensor in tensors:
+        if (self._should_dump_tensor(op_type, tensor.dtype) and
+            tensor.dtype.is_numpy_compatible):
+          execution_proto.tensor_protos.append(
+              tensor_util.make_tensor_proto(tensor.numpy()))
+      return execution_proto
+    else:
+      raise NotImplementedError(
+          "Tensor instrumentation is not implemented for debug mode %s yet " %
+          self._tensor_debug_mode)
+
+  def callback(self,
+               op_type,
+               inputs,
+               attrs,
+               outputs,
+               op_name=None,
+               graph=None):
+    """Op callback for tracing (dumping) a TF program's execution."""
+    del attrs  # Unused
+
+    writer = self.get_writer()
+    if graph:
+      context_id = self._get_context_id(graph)
+      assert op_name is not None
+      graph_op_creation = debug_event_pb2.GraphOpCreation(
+          op_type=op_type,
+          op_name=op_name,
+          graph_name=graph.name if hasattr(graph, "name") else None,
+          graph_id=context_id,
+          input_names=[input_tensor.name for input_tensor in inputs],
+          num_outputs=len(outputs),
+          code_location=self._process_stack_frames())
+      writer.WriteGraphOpCreation(graph_op_creation)
+      if outputs and compat.as_bytes(
+          op_type) not in op_callbacks_common.OP_CALLBACK_SKIP_OPS:
+        return self._instrument_symbolic_tensors(
+            outputs, op_type, op_name, context_id)
+    else:
+      input_ids = [t._id for t in inputs]  # pylint:disable=protected-access
+      writer.WriteExecution(
+          self._dump_eager_tensors(outputs, op_type, input_ids))
+
+  def _should_dump_tensor(self, op_type, dtype):
+    """Determine if the given tensor's value will be dumped.
+
+    The determination is made given the configurations such as `op_regex`,
+    `tensor_dtypes`.
+
+    Args:
+      op_type: Name of the op's type, as a string (e.g., "MatMul").
+      dtype: The dtype of the tensor, as a `dtypes.DType` object.
+
+    Returns:
+      A bool indicating whether the tensor's value will be dumped.
+    """
+    should_dump = True
+    if self._op_regex:
+      should_dump = (should_dump and
+                     re.match(self._op_regex, op_type))
+    if self._tensor_dtypes:
+      if isinstance(self._tensor_dtypes, (list, tuple)):
+        should_dump = (should_dump and
+                       any(dtype == dtype_item for dtype_item
+                           in self._tensor_dtypes))
+      else:  # A callable that takes a DType argument and returns a boolean.
+        should_dump = should_dump and self._tensor_dtypes(dtype)
+    return should_dump
 
 
 @tf_export("debugging.experimental.enable_dump_debug_info")
@@ -416,6 +480,8 @@ def enable_dump_debug_info(dump_root,
   # TODO(cais): Revise the "UIs (currently under construction)" part of the doc
   # string above.
   # TODO(cais): Add Python code example to the doc string above.
+  global _state
+
   tensor_debug_mode_keys = debug_event_pb2.TensorDebugMode.keys()
   if tensor_debug_mode not in tensor_debug_mode_keys:
     raise ValueError(
@@ -429,25 +495,6 @@ def enable_dump_debug_info(dump_root,
         "tfdbg dumping: support for tensor debug mode %s is not "
         "implemented yet" % tensor_debug_mode)
 
-  if (hasattr(_state, "config") and
-      _state.config.circular_buffer_size != circular_buffer_size):
-    raise ValueError(
-        "There is already a dumping callback configured with a different "
-        "circular-buffer size (%d). Therefore the newly request "
-        "circular-buffer size (%d) will not be honored." %
-        (_state.config.circular_buffer_size, circular_buffer_size))
-
-  if (hasattr(_state, "config") and
-      _state.config.tensor_debug_mode != tensor_debug_mode):
-    raise ValueError(
-        "There is already a dumping callback configured for dump root "
-        "%s with a different "
-        "tensor-debug mode (%s). Therefore the newly request "
-        "tensor-debug mode (%s) size will not be honored." %
-        (_state.config.dump_root,
-         tensor_debug_mode_keys[_state.config.tensor_debug_mode],
-         tensor_debug_mode_keys[tensor_debug_mode]))
-
   # Validate the types of tensor_dtypes.
   if tensor_dtypes is not None:
     if (not isinstance(tensor_dtypes, (list, tuple)) and
@@ -460,30 +507,41 @@ def enable_dump_debug_info(dump_root,
     tensor_dtypes = [
         dtypes.as_dtype(dtype_item) for dtype_item in tensor_dtypes]
 
-  if not hasattr(_state, "config") or _state.config.dump_root != dump_root:
-    _state.config = DumpingConfig(
-        dump_root=dump_root,
-        tensor_debug_mode=tensor_debug_mode,
-        circular_buffer_size=int(circular_buffer_size),
-        op_regex=re.compile(op_regex) if op_regex else None,
-        tensor_dtypes=tensor_dtypes)
-    _state.hostname = socket.gethostname()
-    # A list of source-file paths.
-    _state.source_file_paths = []
-    # A map from stack frame (FileLineCol) to unique ID.
-    _state.stack_frame_to_id = dict()
-    # Mapping op context to unique ID.
-    _state.context_to_id = dict()
+  if hasattr(_state, "dumping_callback"):
+    if _state.dumping_callback.circular_buffer_size != circular_buffer_size:
+      raise ValueError(
+          "There is already a dumping callback configured with a different "
+          "circular-buffer size (%d). Therefore the newly request "
+          "circular-buffer size (%d) will not be honored." %
+          (_state.dumping_callback.circular_buffer_size, circular_buffer_size))
+    if _state.dumping_callback.tensor_debug_mode != tensor_debug_mode:
+      raise ValueError(
+          "There is already a dumping callback configured for dump root "
+          "%s with a different "
+          "tensor-debug mode (%s). Therefore the newly request "
+          "tensor-debug mode (%s) size will not be honored." %
+          (_state.dumping_callback.dump_root,
+           tensor_debug_mode_keys[_state.dumping_callback.tensor_debug_mode],
+           tensor_debug_mode_keys[tensor_debug_mode]))
+  else:
+    _state.dumping_callback = _DumpingCallback(dump_root,
+                                               tensor_debug_mode,
+                                               circular_buffer_size,
+                                               op_regex,
+                                               tensor_dtypes)
+    op_callbacks.add_op_callback(_state.dumping_callback.callback)
+
+  if _state.dumping_callback.dump_root != dump_root:
+    _state.dumping_callback.dump_root = dump_root
 
-  op_callbacks.add_op_callback(_dumping_callback)
   logging.info(
       "Enabled dumping callback in thread %s "
      "(dump root: %s, tensor debug mode: %s)",
-      threading.current_thread().name, _state.config.dump_root,
-      tensor_debug_mode)
+      threading.current_thread().name,
+      _state.dumping_callback.dump_root, tensor_debug_mode)
   atexit.register(disable_dump_debug_info)
-  return _get_writer()
+  return _state.dumping_callback.get_writer()
 
 
 @tf_export("debugging.experimental.disable_dump_debug_info")
@@ -495,10 +553,10 @@ def disable_dump_debug_info():
   `enable_dump_debug_info()` has been made, calling this method is a no-op.
   Calling this method more than once is idempotent.
   """
-  if hasattr(_state, "config"):
-    dump_root = _state.config.dump_root
-    delattr(_state, "config")
+  if hasattr(_state, "dumping_callback"):
+    dump_root = _state.dumping_callback.dump_root
     debug_events_writer.DebugEventsWriter(dump_root).Close()
-    op_callbacks.remove_op_callback(_dumping_callback)
+    op_callbacks.remove_op_callback(_state.dumping_callback.callback)
+    delattr(_state, "dumping_callback")
     logging.info("Disabled dumping callback in thread %s (dump root: %s)",
                  threading.current_thread().name, dump_root)
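
For reference, a sketch of how the refactored dumping API is driven from user code, modeled on the tests in distributed_callbacks_test.py (the dump path is illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # enable_dump_debug_info() returns the (now lazily created) DebugEventsWriter.
  writer = tf.debugging.experimental.enable_dump_debug_info(
      "/tmp/tfdbg_dump", tensor_debug_mode="FULL_TENSOR")
  model = tf.keras.Sequential(
      [tf.keras.layers.Dense(1, input_shape=[3])])
  model.compile(loss="mse", optimizer="sgd")

model.fit(tf.ones([4, 3]), tf.ones([4, 1]), epochs=1, verbose=0)
# The same writer instance is shared by all replica threads.
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
tf.debugging.experimental.disable_dump_debug_info()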

tensorflow/python/debug/lib/dumping_callback_test.py

@@ -219,8 +219,9 @@ class TracingCallbackTest(
     # Session.run() in v1 graph mode, so doesn't get logged to the
     # .execution file.
     executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+    executed_op_types = [op_type for op_type in executed_op_types
+                         if "sin1p_log_sum" in op_type]
     self.assertLen(executed_op_types, 1)
-    self.assertIn("sin1p_log_sum", executed_op_types[0])
 
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
     (context_ids, op_types,
@@ -601,7 +602,8 @@ class TracingCallbackTest(
       ("NoTensor", "NO_TENSOR"),
       ("FullTensor", "FULL_TENSOR"),
   )
-  def testMultiThreadedExecution(self, tensor_debug_mode):
+  def testMultiThreadedExecutionWithSameSetting(self, tensor_debug_mode):
+    """Dumping from multiple threads using the same setting."""
     writer = dumping_callback.enable_dump_debug_info(
         self.dump_root, tensor_debug_mode=tensor_debug_mode)
     x = variables.Variable(10.0, dtype=dtypes.float32)
@@ -658,6 +660,64 @@ class TracingCallbackTest(
     ]
     self.assertAllClose(mul_values, [6.0, 6.0, 6.0, 6.0])
 
+  def testMultiThreadedDumpingWithDifferentSettings(self):
+    dump_root_1 = os.path.join(self.dump_root, "dump_root_1")
+    dump_root_2 = os.path.join(self.dump_root, "dump_root_2")
+    v1 = variables.Variable(10.0, dtype=dtypes.float32)
+    v2 = variables.Variable(3.0, dtype=dtypes.float32)
+
+    def add_negative_v1_squared_to_itself():
+      writer = dumping_callback.enable_dump_debug_info(
+          dump_root_1, tensor_debug_mode="FULL_TENSOR")
+      # Run in a loop to facilitate interleaving between threads.
+      for _ in range(3):
+        v1.assign_add(-(v1 ** 2.0))
+      writer.FlushNonExecutionFiles()
+      writer.FlushExecutionFiles()
+
+    def add_negative_v2_squared_to_itself():
+      writer = dumping_callback.enable_dump_debug_info(
+          dump_root_2, tensor_debug_mode="FULL_TENSOR")
+      v2_squared = v2 ** 2.0
+      # Since dumping is disabled before the Neg op is called, no tensor data
+      # should be dumped from the op, but this shouldn't affect the dumping of
+      # the tensor data from the Neg op in `add_negative_v1_squared_to_itself`.
+      # Both behaviors are checked below.
+      dumping_callback.disable_dump_debug_info()
+      negative_v2_squared = -v2_squared
+      v2.assign_add(negative_v2_squared)
+      writer.FlushNonExecutionFiles()
+      writer.FlushExecutionFiles()
+
+    # v2 is mutated on a sub-thread.
+    sub_thread = threading.Thread(target=add_negative_v2_squared_to_itself)
+    sub_thread.start()
+    add_negative_v1_squared_to_itself()  # v1 is mutated on the main thread.
+    sub_thread.join()
+    # 10 - 10 * 10 = -90.
+    # -90 - (-90 * -90) = -8190.
+    # -8190 - (-8190 * -8190) = -67084290.
+    self.assertAllClose(v1.read_value(), -67084290.0)
+    self.assertAllClose(v2.read_value(), -6.0)
+
+    (executed_op_types, _, _, _,
+     tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_1)
+    v1_squared_values = [
+        tensor_values[i] for i, op_type in enumerate(executed_op_types)
+        if op_type == "Pow"]
+    negative_v1_squared_values = [
+        tensor_values[i] for i, op_type in enumerate(executed_op_types)
+        if op_type == "Neg"]
+    self.assertAllClose(v1_squared_values, [[100.0], [8100.0], [67076100.0]])
+    self.assertAllClose(
+        negative_v1_squared_values, [[-100.0], [-8100.0], [-67076100.0]])
+
+    (executed_op_types, _, _, _,
+     tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_2)
+    self.assertNotIn("Neg", executed_op_types)
+    v2_squared_values = tensor_values[executed_op_types.index("Pow")]
+    self.assertAllClose(v2_squared_values, [9.0])
+
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
       ("FullTensor", "FULL_TENSOR"),

tensorflow/python/debug/lib/dumping_callback_test_lib.py

@@ -23,6 +23,7 @@ import shutil
 import socket
 import tempfile
 
+from tensorflow.core.framework import types_pb2
 from tensorflow.python.debug.lib import check_numerics_callback
 from tensorflow.python.debug.lib import debug_events_reader
 from tensorflow.python.debug.lib import dumping_callback
@@ -137,9 +138,13 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
       self.assertIn(stack_frame_id, stack_frame_by_id)
     return context_ids, op_types, op_name_to_op_type
 
-  def _readAndCheckExecutionFile(self):
+  def _readAndCheckExecutionFile(self, dump_root=None):
     """Read and verify the content of the .execution debug-event file.
 
+    Args:
+      dump_root: Optional argument that can be used to override the default
+        dump root to read the data from.
+
     Returns:
       executed_op_types: Types of ops that are created, as a `list` of `str`.
       input_tensor_ids: Input tensor IDs for each of the ops executed, as a
@@ -153,7 +158,8 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
       execution event. Each item of the inner `list` corresponds to one
       output tensor slot of the executed op or Function.
     """
-    reader = debug_events_reader.DebugEventsReader(self.dump_root)
+    dump_root = self.dump_root if dump_root is None else dump_root
+    reader = debug_events_reader.DebugEventsReader(dump_root)
     execution_iter = reader.execution_iterator()
     prev_wall_time = 1
     executed_op_types = []
@@ -213,7 +219,10 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
       self.assertIn(graph_execution_trace.tfdbg_context_id, context_ids)
       output_slots.append(graph_execution_trace.output_slot)
       dtype = dtypes.DType(graph_execution_trace.tensor_proto.dtype)
-      if dtype.is_numpy_compatible:  # pylint:disable=protected-access
+      if (dtype.is_numpy_compatible and
+          dtype._type_enum != types_pb2.DT_STRING):  # pylint:disable=protected-access
+        # TODO(cais): Figure out how to properly convert string tensor proto to
+        # numpy representation.
         tensor_values.append(
             tensor_util.MakeNdarray(graph_execution_trace.tensor_proto))
       else:

tensorflow/python/distribute/mirrored_strategy.py

@@ -867,6 +867,7 @@ class _MirroredReplicaThread(threading.Thread):
     ctx = context.context()
     self.in_eager = ctx.executing_eagerly()
     self.record_thread_local_summary_state()
+    self.record_thread_local_eager_context_state()
     self.context_device_policy = (
         pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
             ctx._context_handle))  # pylint: disable=protected-access
@@ -892,6 +893,7 @@ class _MirroredReplicaThread(threading.Thread):
     if self.coord.should_stop():
       return
     self.restore_thread_local_summary_state()
+    self.restore_thread_local_eager_context_state()
     # TODO(josh11b): Use current logical device instead of 0 here.
     with self.coord.stop_on_exception(), \
         _enter_graph(self._init_graph, self._init_in_eager), \
@@ -920,7 +922,6 @@ class _MirroredReplicaThread(threading.Thread):
     self._summary_recording = summary_state.is_recording
     self._summary_recording_distribution_strategy = (
         summary_state.is_recording_distribution_strategy)
-    # TODO(b/125892694): record other fields in EagerContext.
 
   def restore_thread_local_summary_state(self):
     """Restore thread local summary state from self."""
@@ -931,7 +932,18 @@ class _MirroredReplicaThread(threading.Thread):
     summary_state.is_recording = self._summary_recording
     summary_state.is_recording_distribution_strategy = (
         self._summary_recording_distribution_strategy)
-    # TODO(b/125892694): restore other fields in EagerContext.
+
+  def record_thread_local_eager_context_state(self):
+    ctx = context.context()
+    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
+    self._eager_context_op_callbacks = eager_context_state.op_callbacks
+    # TODO(b/125892694): record other fields in EagerContext.
+
+  def restore_thread_local_eager_context_state(self):
+    ctx = context.context()
+    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
+    eager_context_state.op_callbacks = self._eager_context_op_callbacks
+    # TODO(b/125892694): record other fields in EagerContext.
 
 
 class MirroredReplicaContext(distribute_lib.ReplicaContext):
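
The two new methods above are what allow op callbacks registered on the main thread to fire inside MirroredStrategy replica threads: the eager context's thread-local op_callbacks field is recorded when the replica thread object is constructed and restored when it starts running. A minimal pure-Python sketch of that hand-off (_thread_local_data here stands in for EagerContext._thread_local_data):

import threading

_thread_local_data = threading.local()  # Each thread sees its own attributes.


class _ReplicaThread(threading.Thread):
  """Sketch of how a replica thread inherits the parent's op callbacks."""

  def __init__(self):
    super(_ReplicaThread, self).__init__()
    # Recorded on the creating (main) thread, mirroring
    # record_thread_local_eager_context_state().
    self._op_callbacks = getattr(_thread_local_data, "op_callbacks", [])

  def run(self):
    # Restored on the replica thread, mirroring
    # restore_thread_local_eager_context_state(); callbacks added by the main
    # thread now also apply to ops executed on this thread.
    _thread_local_data.op_callbacks = self._op_callbacks


_thread_local_data.op_callbacks = ["my_callback"]  # Registered on main thread.
replica_thread = _ReplicaThread()
replica_thread.start()
replica_thread.join()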