[tfdbg] Let op_callbacks support MirroredStrategy & OneDeviceStrategy

- Add unit tests in tensorflow/python/debug/lib/distributed_callbacks_test.py
- In check_numerics_callback.py and dumping_callback.py, replace some
  thread-local objects with non-thread-local ones, so that state can
  be shared between threads (see the sketch below).
- Add distributed_callbacks_test to cover the check-numerics and dumping ops
  running under various MirroredStrategy and OneDeviceStrategy scopes.

PiperOrigin-RevId: 279230921
Change-Id: I0bf137c4bb29eef6698ac96520841c419c3d6219
Shanqing Cai authored 2019-11-07 20:46:31 -08:00, committed by TensorFlower Gardener
parent 4f42698fdc
commit af019188ad
8 changed files with 862 additions and 345 deletions
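
The core of the refactoring: per-thread namedtuple configs become ordinary
objects whose bound `callback` method is registered as the op callback, so
that MirroredStrategy's replica threads can share one set of states. A
minimal sketch of the pattern (simplified, hypothetical names; not the
actual TensorFlow implementation):

    import threading

    _state = threading.local()  # Per-thread: only tracks enablement.

    class _SharedCallback(object):
      """Holds state that replica threads must share (maps, writer, etc.)."""

      def __init__(self, dump_root):
        self._dump_root = dump_root
        self._context_to_id = {}        # Shared across threads.
        self._lock = threading.Lock()   # Guards the shared map.

      def callback(self, op_type, inputs, attrs, outputs,
                   op_name=None, graph=None):
        with self._lock:
          self._context_to_id.setdefault(graph, len(self._context_to_id))
        return None

    def enable(dump_root):
      # The shared object sits in a thread-local slot only to make
      # enable/disable idempotent per thread; its attributes are plain
      # (non-thread-local) and hence visible to worker threads.
      if not hasattr(_state, "shared_callback"):
        _state.shared_callback = _SharedCallback(dump_root)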

tensorflow/python/debug/BUILD

@@ -704,6 +704,35 @@ cuda_py_test(
],
)
cuda_py_test(
name = "distributed_callbacks_test",
size = "medium",
srcs = ["lib/distributed_callbacks_test.py"],
additional_deps = [
":check_numerics_callback",
":debug_events_writer",
":dumping_callback",
":dumping_callback_test_lib",
"//third_party/py/numpy",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/keras",
],
tags = [
"guitar",
"multi_and_single_gpu",
"no_rocm",
"no_windows", # TODO(b/142475891): Enable this test on Windows.
"no_windows_gpu", # TODO(b/130551176)
],
xla_enable_strict_auto_jit = False, # Node names are different with autojit
)
cuda_py_test(
name = "dumping_callback_test",
size = "medium",
@@ -720,7 +749,7 @@ cuda_py_test(
"//tensorflow/python:variables",
"//tensorflow/python/keras",
],
shard_count = 6,
shard_count = 8,
tags = [
"no_windows", # TODO(b/142475891): Enable this test on Windows.
],

tensorflow/python/debug/lib/check_numerics_callback.py

@@ -87,6 +87,8 @@ SAFE_OPS = (
b"Unpack",
)
_state = threading.local()
def limit_string_length(string, max_len=50):
"""Limit the length of input string.
@@ -217,7 +219,15 @@ def _debug_summary(x):
debug_event_pb2.TensorDebugMode.REDUCE_INF_NAN_THREE_SLOTS))
def _check_numerics_callback(op_type,
class CheckNumericsCallback(object):
"""Wrapper for the numerics-checking callback for thread locality."""
def __init__(self, stack_height_limit, path_length_limit):
self._stack_height_limit = stack_height_limit
self._path_length_limit = path_length_limit
def callback(self,
op_type,
inputs,
attrs,
outputs,
@@ -229,7 +239,7 @@ def _check_numerics_callback(op_type,
is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
if (op_type_bytes in op_callbacks_common.OP_CALLBACK_SKIP_OPS or
op_type_bytes in SAFE_OPS):
return
return None
if graph:
# Under graph mode. Insert check_numerics op.
instrumented_outputs = []
@@ -240,8 +250,8 @@
# TF v2 has automatic control dependencies added to stateful async
# ops, which allows us to run check_numerics asynchronously.
# In the above case we use debug_summary to reduce all output
# tensors asynchronously from the op being checked and then process
# the tensor summary with check_numerics.
# tensors asynchronously from the op being checked and then
# process the tensor summary with check_numerics.
output if is_v1_graph_mode else _debug_summary(output),
get_check_numerics_error_message(
slot,
@@ -261,7 +271,7 @@
if op_type_bytes == b"CheckNumerics":
# TODO(b/140334369): Remove this special casing logic once op_callback.
# automatically prevents infinite recursion in eager mode.
return
return None
# Under eager mode. Eagerly execute check_numerics op.
for slot, output in enumerate(outputs):
if (output.dtype.is_floating and
@@ -270,13 +280,8 @@
output,
get_check_numerics_error_message(
slot, len(outputs), op_type, output, inputs,
stack_height_limit=_state.config.stack_height_limit,
path_length_limit=_state.config.path_length_limit))
CheckNumericsConfig = collections.namedtuple(
"CheckNumericsConfig", "stack_height_limit path_length_limit")
_state = threading.local()
stack_height_limit=self._stack_height_limit,
path_length_limit=self._path_length_limit))
@tf_export("debugging.enable_check_numerics")
@@ -362,12 +367,10 @@ def enable_check_numerics(stack_height_limit=30,
path_length_limit: Limit to the file path included in the printed stack
trace. Applicable only to ops in `tf.function`s (graphs).
"""
if not hasattr(_state, "config"):
_state.config = CheckNumericsConfig(
stack_height_limit=stack_height_limit,
path_length_limit=path_length_limit)
op_callbacks.add_op_callback(_check_numerics_callback)
if not hasattr(_state, "check_numerics_callback"):
_state.check_numerics_callback = CheckNumericsCallback(
stack_height_limit, path_length_limit)
op_callbacks.add_op_callback(_state.check_numerics_callback.callback)
logging.info(
"Enabled check-numerics callback in thread %s",
@@ -387,8 +390,11 @@ def disable_check_numerics():
This method takes effect only on the thread in which it is called.
"""
if not hasattr(_state, "check_numerics_callback"):
return
try:
op_callbacks.remove_op_callback(_check_numerics_callback)
op_callbacks.remove_op_callback(_state.check_numerics_callback.callback)
delattr(_state, "check_numerics_callback")
logging.info(
"Disabled check-numerics callback in thread %s",
threading.current_thread().name)
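
For reference, a usage sketch of the public API touched by this file (a
hedged example; the exact error text varies across builds):

    import tensorflow as tf

    tf.debugging.enable_check_numerics()

    @tf.function
    def divide(x, y):
      return x / y  # Emits Inf when y == 0.

    try:
      divide(tf.constant(1.0), tf.constant(0.0))
    except tf.errors.InvalidArgumentError as e:
      print(e.message)  # Mentions the op type ("RealDiv") and source line.

    tf.debugging.disable_check_numerics()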

tensorflow/python/debug/lib/check_numerics_callback_test.py

@@ -571,7 +571,6 @@ class CheckNumericsCallbackTest(test_util.TensorFlowTestCase):
self.assertTrue(np.isnan(batch_mean.squeeze()))
self.assertTrue(np.isnan(batch_variance.squeeze()))
# TODO(cais): Tests for Infs and NaNs during distributed execution.
# TODO(cais): Benchmark the slowdown due to callbacks and inserted nodes.
if __name__ == "__main__":

tensorflow/python/debug/lib/distributed_callbacks_test.py (new file)

@@ -0,0 +1,344 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfdbg op callbacks running with various `DistributionStrategy`s."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.debug.lib import check_numerics_callback
from tensorflow.python.debug.lib import dumping_callback
from tensorflow.python.debug.lib import dumping_callback_test_lib
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import gradient_descent
def filter_by_device_name(items, device_names, target_device_name):
"""Filter a list of items by device name.
Args:
items: A list of items to be filtered according to their corresponding
device names.
device_names: A list of the device names. Must have the same length
as `items`.
target_device_name: A `str` representing the desired device name.
Returns:
Filtered items from `items`.
"""
assert len(items) == len(device_names)
assert all(device_names), "device_names are not all non-empty strings"
# Note: we use `endswith` instead of `==` for device-name filtering because
# in some cases, the device names from kernel/op execution can have slightly
# different values than the device names from
# `distribution.extended.worker_devices`.
return [items[i] for i, device_name in enumerate(device_names)
if device_name.endswith(target_device_name)]
def filter_by_device_name_and_op_type(
items, device_names, op_types, target_device_name, target_op_type):
assert len(items) == len(device_names)
assert len(items) == len(op_types)
assert all(device_names), "device_names are not all non-empty strings"
assert all(op_types), "op_types are not all non-empty strings"
return [items[i] for i, device_name in enumerate(device_names)
if device_name.endswith(target_device_name)
and op_types[i] == target_op_type]
class MiniModel(keras.Model):
"""Minimal subclassed Keras model."""
def __init__(self, generate_infinity=False):
super(MiniModel, self).__init__(name="")
self._generate_infinity = generate_infinity
self.fc = keras.layers.Dense(
1, kernel_initializer="ones", bias_initializer="ones",
activation="linear")
@def_function.function
def call(self, inputs, training=True):
y = self.fc(inputs)
if self._generate_infinity:
y = math_ops.divide(y, array_ops.zeros_like(y))
return y
class DistributedDumpingCallbackTest(
dumping_callback_test_lib.DumpingCallbackTestBase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
inside_scope=[False, True],
# TODO(cais): Investigate why, under V1 graph mode (mode="graph"), the
# test occasionally (~1-2% of the time) runs into the following error:
# CancelledError: [_Derived_] Function was cancelled before it was
# started.
mode=["eager"],
))
def testCheckingInfinityInMiniModelOnOneOrTwoDevices(
self, distribution, inside_scope):
if not inside_scope:
check_numerics_callback.enable_check_numerics()
with distribution.scope():
if inside_scope:
check_numerics_callback.enable_check_numerics()
mini_model = MiniModel(generate_infinity=True)
def train_step():
with backprop.GradientTape() as tape:
loss = mini_model(array_ops.ones([1, 10]))
return tape.gradient(loss, mini_model.weights)
caught_error = None
try:
distribution.experimental_run_v2(train_step)
except errors.InvalidArgumentError as error:
caught_error = error
self.assertTrue(caught_error)
self.assertTrue(re.search(
r"Detected Infinity or NaN.*\"RealDiv\"", caught_error.message))
self.assertIn(
"-> | y = math_ops.divide(y, array_ops.zeros_like(y))",
caught_error.message)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
mode=["eager"],
tensor_debug_mode=["NO_TENSOR", "FULL_TENSOR"],
))
def testDumpingMiniModel(self, distribution, tensor_debug_mode):
with distribution.scope():
writer = dumping_callback.enable_dump_debug_info(
self.dump_root, tensor_debug_mode=tensor_debug_mode)
mini_model = MiniModel()
optimizer = gradient_descent.GradientDescentOptimizer(0.25)
def train_step():
with backprop.GradientTape() as tape:
loss = mini_model(array_ops.ones([1, 10]))
grads = tape.gradient(loss, mini_model.weights)
grads_and_vars = zip(grads, mini_model.weights)
optimizer.apply_gradients(grads_and_vars)
distribution.experimental_run_v2(train_step)
updated_var_values = self.evaluate(mini_model.variables)
num_devices = len(distribution.extended.worker_devices)
assert num_devices in [1, 2]
# TODO(cais): We currently refrain from asserting the
# element-by-element values of the variable updates. The values seem to
# vary among builds. On some builds, it's 0.75; on others, it's 1.0.
# This variation is seen in the MirroredCPUAndGPU and OneDeviceGPU
# strategies. Needs investigation.
# if num_devices == 1:
# self.assertAllEqual(0.75 * np.ones([10, 1]), updated_var_values[0])
# self.assertAllEqual([0.75], updated_var_values[1])
# else:
# self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0])
# self.assertAllEqual([0.5], updated_var_values[1])
self.assertEqual(updated_var_values[0].shape, (10, 1))
self.assertEqual(updated_var_values[1].shape, (1,))
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
(context_ids, _,
op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
(op_names, device_names, _,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
device_name_0 = distribution.extended.worker_devices[0]
logging.info("device_name_0 = %s", device_name_0)
if num_devices > 1:
device_name_1 = distribution.extended.worker_devices[1]
logging.info("device_name_1 = %s", device_name_1)
device_0_executed_op_types = filter_by_device_name(
executed_op_types, device_names, device_name_0)
if num_devices > 1:
device_1_executed_op_types = filter_by_device_name(
executed_op_types, device_names, device_name_1)
# Verify graph-execution traces are available for both devices.
# We don't assert MatMul occurs exactly once because the gradient of MatMul
# involves MatMul.
self.assertIn("MatMul", device_0_executed_op_types)
self.assertEqual(device_0_executed_op_types.count("BiasAdd"), 1)
if num_devices > 1:
self.assertIn("MatMul", device_1_executed_op_types)
self.assertEqual(device_1_executed_op_types.count("BiasAdd"), 1)
if tensor_debug_mode == "NO_TENSOR":
for value_list in tensor_values:
for tensor_value in value_list:
self.assertEqual(tensor_value.dtype, np.float32)
self.assertEqual(tensor_value.shape, [])
elif tensor_debug_mode == "FULL_TENSOR":
device_0_matmul_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_0,
"MatMul")
device_0_bias_add_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_0,
"BiasAdd")
self.assertAllClose(device_0_matmul_values[0], [[10.0]])
self.assertAllClose(device_0_bias_add_values[0], [[11.0]])
if num_devices > 1:
device_1_matmul_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_1,
"MatMul")
device_1_bias_add_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_1,
"BiasAdd")
self.assertAllClose(device_1_matmul_values[0], [[10.0]])
self.assertAllClose(device_1_bias_add_values[0], [[11.0]])
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
mode=["eager"],
tensor_debug_mode=["NO_TENSOR", "FULL_TENSOR"],
))
def testKerasModelFitOnOneOrTwoDevices(self, distribution, tensor_debug_mode):
writer = dumping_callback.enable_dump_debug_info(
self.dump_root, tensor_debug_mode=tensor_debug_mode)
with distribution.scope():
model = keras.Sequential()
model.add(keras.layers.Dense(
units=10, input_shape=[5], activation="relu"))
model.add(keras.layers.Dense(units=1))
model.compile(loss="mse", optimizer="sgd")
batch_size = 20
x = np.ones([batch_size, 5])
y = np.ones([batch_size, 1])
epochs = 1
history = model.fit(x, y, epochs=epochs, verbose=0)
self.assertLen(history.history["loss"], epochs)
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
(context_ids, _,
op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
(op_names, device_names, _,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
# Eager execution of tf.function should be recorded.
executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
fit_functions = [op_type for op_type in executed_op_types
if "_distributed_function" in op_type]
self.assertLen(fit_functions, epochs)
num_devices = len(distribution.extended.worker_devices)
device_name_0 = distribution.extended.worker_devices[0]
logging.info("device_name_0 = %s", device_name_0)
if num_devices > 1:
device_name_1 = distribution.extended.worker_devices[1]
logging.info("device_name_1 = %s", device_name_1)
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
device_0_executed_op_types = filter_by_device_name(
executed_op_types, device_names, device_name_0)
if num_devices > 1:
device_1_executed_op_types = filter_by_device_name(
executed_op_types, device_names, device_name_1)
self.assertIn("MatMul", device_0_executed_op_types)
self.assertIn("BiasAdd", device_0_executed_op_types)
self.assertIn("Relu", device_0_executed_op_types)
self.assertIn("ReluGrad", device_0_executed_op_types)
if num_devices > 1:
# If two devices are involved, assert that the ops inside tf.functions
# are executed and recorded an equal number of times by the dumping
# op-callback.
self.assertEqual(device_0_executed_op_types.count("MatMul"),
device_1_executed_op_types.count("MatMul"))
self.assertEqual(device_0_executed_op_types.count("BiasAdd"),
device_1_executed_op_types.count("BiasAdd"))
self.assertEqual(device_0_executed_op_types.count("Relu"),
device_1_executed_op_types.count("Relu"))
self.assertEqual(device_0_executed_op_types.count("ReluGrad"),
device_1_executed_op_types.count("ReluGrad"))
if tensor_debug_mode == "NO_TENSOR":
for value_list in tensor_values:
for tensor_value in value_list:
self.assertEqual(tensor_value.dtype, np.float32)
self.assertEqual(tensor_value.shape, [])
elif tensor_debug_mode == "FULL_TENSOR":
gpu_0_relu_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_0, "Relu")
self.assertTrue(gpu_0_relu_values)
gpu_0_relu_grad_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_0,
"ReluGrad")
self.assertTrue(gpu_0_relu_grad_values)
if num_devices > 1:
gpu_1_relu_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_1,
"Relu")
self.assertTrue(gpu_1_relu_values)
for i in range(len(gpu_0_relu_values)):
self.assertEqual(gpu_0_relu_values[i].shape,
gpu_1_relu_values[i].shape)
gpu_1_relu_grad_values = filter_by_device_name_and_op_type(
tensor_values, device_names, executed_op_types, device_name_1,
"ReluGrad")
self.assertTrue(gpu_1_relu_grad_values)
for i in range(len(gpu_0_relu_grad_values)):
self.assertEqual(
gpu_0_relu_grad_values[i].shape, gpu_1_relu_grad_values[i].shape)
if __name__ == "__main__":
googletest.main()
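
As an illustration of the suffix-based device filtering above (made-up
device names; not part of the original file):

    items = ["trace_a", "trace_b", "trace_c"]
    device_names = [
        "/job:localhost/replica:0/task:0/device:GPU:0",
        "/job:localhost/replica:0/task:0/device:GPU:1",
        "/job:localhost/replica:0/task:0/device:GPU:0",
    ]
    filter_by_device_name(items, device_names, "/device:GPU:0")
    # ==> ["trace_a", "trace_c"]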

tensorflow/python/debug/lib/dumping_callback.py

@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import atexit
import collections
import re
import socket
import threading
@@ -42,11 +41,8 @@ from tensorflow.python.util import compat
from tensorflow.python.util import tf_stack
from tensorflow.python.util.tf_export import tf_export
DumpingConfig = collections.namedtuple(
"DumpingConfig",
"dump_root tensor_debug_mode circular_buffer_size "
"op_regex tensor_dtypes")
_state = threading.local()
DEFAULT_TENSOR_DEBUG_MODE = "NO_TENSOR"
@ops.RegisterGradient("DebugIdentityV2")
@@ -56,39 +52,85 @@ def _debug_identity_v2_grad(op, dy):
return dy
def _get_writer():
"""Get the debug events writer for the currently configured dump root."""
# TODO(cais): Explore caching the object for possible performance gain.
# TODO(cais): Rename cyclic_buffer_size to circular_buffer_size in C++ and
# Python-binding code.
return debug_events_writer.DebugEventsWriter(
_state.config.dump_root,
circular_buffer_size=_state.config.circular_buffer_size)
def _get_id():
"""Get a short unique ID."""
return str(uuid.uuid4())
def _get_context_id(context):
class _DumpingCallback(object):
"""An object holding the states surrouding the dumping callback."""
def __init__(self,
dump_root,
tensor_debug_mode,
circular_buffer_size,
op_regex,
tensor_dtypes):
self._dump_root = dump_root
self._tensor_debug_mode = tensor_debug_mode
self._circular_buffer_size = circular_buffer_size
self._op_regex = op_regex
self._tensor_dtypes = tensor_dtypes
self._hostname = socket.gethostname()
# A list of source-file paths.
self._source_file_paths = []
# A map from stack frame (FileLineCol) to unique ID.
self._stack_frame_to_id = dict()
# Mapping op context to unique ID.
self._context_to_id = dict()
self._source_file_paths_lock = threading.Lock()
self._stack_frame_to_id_lock = threading.Lock()
self._context_to_id_lock = threading.Lock()
self._writer = None
@property
def dump_root(self):
return self._dump_root
@dump_root.setter
def dump_root(self, dump_root):
if self._dump_root != dump_root:
self._dump_root = dump_root
self._writer = None
@property
def tensor_debug_mode(self):
return self._tensor_debug_mode
@property
def circular_buffer_size(self):
return self._circular_buffer_size
def get_writer(self):
"""Get the debug events writer for the currently configured dump root."""
if not self._writer:
self._writer = debug_events_writer.DebugEventsWriter(
self._dump_root,
circular_buffer_size=self._circular_buffer_size)
return self._writer
def _get_context_id(self, context):
"""Get a unique ID for an op-construction context (e.g., a graph).
If the graph has been encountered before, reuse the same unique ID.
Args:
context: A context to get the unique ID for. Must be hashable. E.g., a Graph
object.
context: A context to get the unique ID for. Must be hashable. E.g., a
Graph object.
Returns:
A unique ID for the context.
"""
if context not in _state.context_to_id:
_state.context_to_id[context] = _get_id()
return _state.context_to_id[context]
# Use the double-checked lock pattern to optimize the common case.
if context in self._context_to_id: # 1st check, without lock.
return self._context_to_id[context]
with self._context_to_id_lock:
if context not in self._context_to_id: # 2nd check, with lock.
self._context_to_id[context] = _get_id()
return self._context_to_id[context]
def _write_source_file_content(file_path):
def _write_source_file_content(self, file_path):
"""Send the content of a source file via debug-events writer.
Args:
@@ -97,23 +139,25 @@ def _write_source_file_content(file_path):
Returns:
An int index for the file.
"""
if file_path not in _state.source_file_paths:
if file_path in self._source_file_paths:
return self._source_file_paths.index(file_path)
with self._source_file_paths_lock:
if file_path not in self._source_file_paths:
lines = None
if source_utils.is_extension_uncompiled_python_source(file_path):
try:
lines, _ = source_utils.load_source(file_path)
except IOError:
# Accept the fact that some source files are not readable. Here we use
# best effort to send the source-file contents.
# Accept the fact that some source files are not readable. Here we
# use best effort to send the source-file contents.
pass
writer = _get_writer()
writer = self.get_writer()
writer.WriteSourceFile(debug_event_pb2.SourceFile(
file_path=file_path, host_name=_state.hostname, lines=lines))
_state.source_file_paths.append(file_path)
return _state.source_file_paths.index(file_path)
file_path=file_path, host_name=self._hostname, lines=lines))
self._source_file_paths.append(file_path)
return self._source_file_paths.index(file_path)
def _process_stack_frames():
def _process_stack_frames(self):
"""Process stack frames.
Send the content of source-files, on a best-effort basis.
@@ -125,24 +169,199 @@ def _process_stack_frames():
stack_frame_ids = []
writer = None
for file_path, lineno, func, _ in stack_frames:
if (file_path, lineno, func) not in _state.stack_frame_to_id:
if (file_path, lineno, func) in self._stack_frame_to_id:
stack_frame_ids.append(
self._stack_frame_to_id[(file_path, lineno, func)])
continue
with self._stack_frame_to_id_lock:
if (file_path, lineno, func) not in self._stack_frame_to_id:
stack_frame_id = _get_id()
_state.stack_frame_to_id[(file_path, lineno, func)] = stack_frame_id
file_index = _write_source_file_content(file_path)
self._stack_frame_to_id[(file_path, lineno, func)] = stack_frame_id
file_index = self._write_source_file_content(file_path)
file_line_col = graph_debug_info_pb2.GraphDebugInfo.FileLineCol(
file_index=file_index, line=lineno, func=func)
stack_frame_with_id = debug_event_pb2.StackFrameWithId(
id=stack_frame_id, file_line_col=file_line_col)
writer = _get_writer()
writer = self.get_writer()
writer.WriteStackFrameWithId(stack_frame_with_id)
stack_frame_ids.append(_state.stack_frame_to_id[(file_path, lineno, func)])
stack_frame_ids.append(
self._stack_frame_to_id[(file_path, lineno, func)])
code_location = debug_event_pb2.CodeLocation(
host_name=_state.hostname, stack_frame_ids=stack_frame_ids)
host_name=self._hostname, stack_frame_ids=stack_frame_ids)
return code_location
def _instrument_symbolic_tensors(self,
tensors,
op_type,
op_name,
tfdbg_context_id):
"""Add debugging instrumentation for symbolic (i.e., non-eager) tensors.
def _should_dump_tensor(op_type, dtype):
The detailed fashion in which the tensors are instrumented is determined
by the tensor_debug_mode configured for the currently enabled dumping
callback.
Args:
tensors: A tuple of Tensors to instrument. It is assumed that their
ordering corresponds to the ordering of output tensors of an original
op. Output slot indices (0-based) will be generated based on the
ordering.
op_type: Type name of the op that emits the Tensors (e.g., "MatMul").
op_name: Name of the op that emits the Tensors (e.g., "dense_1/MatMul").
tfdbg_context_id: A unique ID for the context that the op belongs to
(e.g., a graph).
Returns:
Non-eager Tensors that override the `tensors` as the output of the op
that originally generated `tensors`. In some cases (e.g., non-V1 graph
mode), this may be `None`, as the instrumentation can simply rely on
automatic control dependencies (see `auto_control_deps.py`) instead of
tensor overriding.
"""
tensor_debug_mode = self._tensor_debug_mode
debug_urls = ["file://%s" % self._dump_root]
is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
instrumented_tensors = [] if is_v1_graph_mode else None
if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
for output_slot, tensor in enumerate(tensors):
if (not self._should_dump_tensor(op_type, tensor.dtype) or
not tensor.dtype.is_numpy_compatible):
# Instrumenting DT_VARIANT and DT_RESOURCE type tensors under
# V1 graph mode is known to have issues. TODO(cais): Investigate.
if is_v1_graph_mode:
instrumented_tensors.append(tensor)
continue
if is_v1_graph_mode and not tensor.dtype.is_numpy_compatible:
instrumented_tensors.append(tensor)
continue
# Except in V1 graph mode + control flow, debug_identity_v2 triggers an
# automatic control dependency because it's a stateful op.
debug_tensor = gen_debug_ops.debug_identity_v2(
# Use an empty (shape=[0]) float32 tensor for the NO_TENSOR mode
# as a low-overhead placeholder, since no actual tensor value is
# traced.
constant_op.constant([], dtype=dtypes.float32),
tfdbg_context_id=tfdbg_context_id,
op_name=op_name,
output_slot=output_slot,
tensor_debug_mode=self._tensor_debug_mode,
debug_urls=debug_urls)
if is_v1_graph_mode:
# TODO(cais): Evaluate performance optimization options. For the
# `NO_TENSOR` debug mode, an alternative is to add `debug_tensor` as a
# control dependency of `tensor.op` without an additional identity op.
identity = array_ops.identity(tensor)
identity.op._add_control_input( # pylint: disable=protected-access
debug_tensor.op)
instrumented_tensors.append(identity)
return instrumented_tensors
elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
for output_slot, tensor in enumerate(tensors):
if (not self._should_dump_tensor(op_type, tensor.dtype) or
not tensor.dtype.is_numpy_compatible):
# Instrumenting DT_VARIANT and DT_RESOURCE type tensors under
# V1 graph mode is known to have issues. TODO(cais): Investigate.
if is_v1_graph_mode:
instrumented_tensors.append(tensor)
continue
debug_tensor = gen_debug_ops.debug_identity_v2(
tensor,
tfdbg_context_id=tfdbg_context_id,
op_name=op_name,
output_slot=output_slot,
tensor_debug_mode=self._tensor_debug_mode,
debug_urls=debug_urls)
if is_v1_graph_mode:
instrumented_tensors.append(debug_tensor)
return instrumented_tensors
else:
raise NotImplementedError(
"Symbolic tensor instrumentation is not implemented for debug mode "
"%s" % self._tensor_debug_mode)
def _dump_eager_tensors(self, tensors, op_type, input_tensor_ids):
"""Dump the value of eager tensors.
The destination of the dumping is determined by the dump_root of the
currently enabled dumping callback. The tensors may be transformed prior to
dumping (e.g., reduced as summary statistics such as minimum, maximum and
arithmetic mean). The details of this transformation (if any) depend on
the tensor_debug_mode of the currently enabled dumping callback.
Args:
tensors: The EagerTensors whose values are to be dumped, with or without
value transform.
op_type: Type of the op that generates the tensors, as a string.
input_tensor_ids: IDs of the input EagerTensors to the op.
Returns:
A tfdbg Execution protocol buffer.
"""
tensor_debug_mode = self._tensor_debug_mode
output_tensor_ids = [
t._id for t in tensors] # pylint:disable=protected-access
if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
return debug_event_pb2.Execution(
op_type=op_type,
num_outputs=len(tensors),
input_tensor_ids=input_tensor_ids,
output_tensor_ids=output_tensor_ids,
tensor_debug_mode=tensor_debug_mode,
code_location=self._process_stack_frames())
elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
execution_proto = debug_event_pb2.Execution(
op_type=op_type,
num_outputs=len(tensors),
input_tensor_ids=input_tensor_ids,
output_tensor_ids=output_tensor_ids,
tensor_debug_mode=tensor_debug_mode,
code_location=self._process_stack_frames())
for tensor in tensors:
if (self._should_dump_tensor(op_type, tensor.dtype) and
tensor.dtype.is_numpy_compatible):
execution_proto.tensor_protos.append(
tensor_util.make_tensor_proto(tensor.numpy()))
return execution_proto
else:
raise NotImplementedError(
"Tensor instrumentation is not implemented for debug mode %s yet " %
self._tensor_debug_mode)
def callback(self,
op_type,
inputs,
attrs,
outputs,
op_name=None,
graph=None):
"""Op callback for tracing (dumping) a TF program's execution."""
del attrs # Unused
writer = self.get_writer()
if graph:
context_id = self._get_context_id(graph)
assert op_name is not None
graph_op_creation = debug_event_pb2.GraphOpCreation(
op_type=op_type,
op_name=op_name,
graph_name=graph.name if hasattr(graph, "name") else None,
graph_id=context_id,
input_names=[input_tensor.name for input_tensor in inputs],
num_outputs=len(outputs),
code_location=self._process_stack_frames())
writer.WriteGraphOpCreation(graph_op_creation)
if outputs and compat.as_bytes(
op_type) not in op_callbacks_common.OP_CALLBACK_SKIP_OPS:
return self._instrument_symbolic_tensors(
outputs, op_type, op_name, context_id)
else:
input_ids = [t._id for t in inputs] # pylint:disable=protected-access
writer.WriteExecution(
self._dump_eager_tensors(outputs, op_type, input_ids))
def _should_dump_tensor(self, op_type, dtype):
"""Determine if the given tensor's value will be dumped.
The determination is made given the configurations such as `op_regex`,
@@ -156,174 +375,19 @@ def _should_dump_tensor(op_type, dtype):
A bool indicating whether the tensor's value will be dumped.
"""
should_dump = True
if _state.config.op_regex:
if self._op_regex:
should_dump = (should_dump and
re.match(_state.config.op_regex, op_type))
if _state.config.tensor_dtypes:
if isinstance(_state.config.tensor_dtypes, (list, tuple)):
re.match(self._op_regex, op_type))
if self._tensor_dtypes:
if isinstance(self._tensor_dtypes, (list, tuple)):
should_dump = (should_dump and
any(dtype == dtype_item for dtype_item
in _state.config.tensor_dtypes))
in self._tensor_dtypes))
else: # A callable that takes a DType argument and return a boolean.
should_dump = should_dump and _state.config.tensor_dtypes(dtype)
should_dump = should_dump and self._tensor_dtypes(dtype)
return should_dump
def _instrument_symbolic_tensors(tensors, op_type, op_name, tfdbg_context_id):
"""Add debugging instrumentation for symbolic (i.e., non-eager) tensors.
The detailed fashion in which the tensors are instrumented is determined
by the tensor_debug_mode configured for the currently enabled dumping
callback.
Args:
tensors: A tuple of Tensors to instrument. It is assumed that their ordering
corresponds to the ordering of output tensors of an original op. Output
slot indices (0-based) will be generated based on the ordering.
op_type: Name of the op type of the node that emits `tensors` (e.g.,
"MatMul"), as a string.
op_name: Name of the node that emits `tensors` (e.g., "dense_1/MatMul"), as
a string.
tfdbg_context_id: A unique ID for the context that the op belongs to (e.g.,
a graph).
Returns:
Non-eager Tensors that override the `tensors` as the output of the op
that originally generated `tensors`. In some cases (e.g., non-V1 graph
mode), this may be `None`, as the instrumentation can simply rely on
automatic control dependencies (see `auto_control_deps.py`) instead of
tensor overriding.
"""
tensor_debug_mode = _state.config.tensor_debug_mode
debug_urls = ["file://%s" % _state.config.dump_root]
is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
instrumented_tensors = [] if is_v1_graph_mode else None
for output_slot, tensor in enumerate(tensors):
if not _should_dump_tensor(op_type, tensor.dtype):
if is_v1_graph_mode:
instrumented_tensors.append(tensor)
continue
if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
# Except in V1 graph mode + control flow, debug_identity_v2 triggers an
# automatic control dependency because it's a stateful op.
debug_tensor = gen_debug_ops.debug_identity_v2(
# Use an empty (shape=[0]) float32 tensor for the NO_TENSOR mode
# as a low-overhead placeholder, since no actual tensor value is
# traced.
constant_op.constant([], dtype=dtypes.float32),
tfdbg_context_id=tfdbg_context_id,
op_name=op_name,
output_slot=output_slot,
tensor_debug_mode=_state.config.tensor_debug_mode,
debug_urls=debug_urls)
if is_v1_graph_mode:
# TODO(cais): Evaluate performance optimization options. For the
# `NO_TENSOR` debug mode, an alternative is to add `debug_tensor` as a
# control dependency of `tensor.op` without an additional identity op.
identity = array_ops.identity(tensor)
identity.op._add_control_input( # pylint: disable=protected-access
debug_tensor.op)
instrumented_tensors.append(identity)
elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
debug_tensor = gen_debug_ops.debug_identity_v2(
tensor,
tfdbg_context_id=tfdbg_context_id,
op_name=op_name,
output_slot=output_slot,
tensor_debug_mode=_state.config.tensor_debug_mode,
debug_urls=debug_urls)
if is_v1_graph_mode:
instrumented_tensors.append(debug_tensor)
else:
raise NotImplementedError(
"Symbolic tensor instrumentation is not implemented for debug "
"mode %s" % _state.config.tensor_debug_mode)
return instrumented_tensors
def _dump_eager_tensors(tensors, op_type, input_tensor_ids):
"""Dump the value of eager tensors.
The destination of the dumping is determined by the dump_root of the currently
enabled dumping callback. The tensors may be transformed prior to dumping
(e.g., reduced as summary statistics such as minimum, maximum and arithmetic
mean). The details of this transformation (if any) depend on the
tensor_debug_mode of the currently enabled dumping callback.
Args:
tensors: The EagerTensors whose values are to be dumped, with or without
value transform.
op_type: Type of the op that generates the tensors, as a string.
input_tensor_ids: IDs of the input EagerTensors to the op.
Returns:
A tfdbg Execution protocol buffer.
"""
tensor_debug_mode = _state.config.tensor_debug_mode
output_tensor_ids = [
t._id for t in tensors] # pylint:disable=protected-access
if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
return debug_event_pb2.Execution(
op_type=op_type,
num_outputs=len(tensors),
input_tensor_ids=input_tensor_ids,
output_tensor_ids=output_tensor_ids,
tensor_debug_mode=tensor_debug_mode,
code_location=_process_stack_frames())
elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
execution_proto = debug_event_pb2.Execution(
op_type=op_type,
num_outputs=len(tensors),
input_tensor_ids=input_tensor_ids,
output_tensor_ids=output_tensor_ids,
tensor_debug_mode=tensor_debug_mode,
code_location=_process_stack_frames())
for tensor in tensors:
if (_should_dump_tensor(op_type, tensor.dtype) and
tensor.dtype.is_numpy_compatible):
execution_proto.tensor_protos.append(
tensor_util.make_tensor_proto(tensor.numpy()))
return execution_proto
else:
raise NotImplementedError(
"Tensor instrumentation is not implemented for debug mode %s yet " %
_state.config.tensor_debug_mode)
def _dumping_callback(op_type,
inputs,
attrs,
outputs,
op_name=None,
graph=None):
"""Op callback for tracing a TF program's execution."""
del attrs # Unused
writer = _get_writer()
if graph:
context_id = _get_context_id(graph)
assert op_name is not None
graph_op_creation = debug_event_pb2.GraphOpCreation(
op_type=op_type,
op_name=op_name,
graph_name=graph.name if hasattr(graph, "name") else None,
graph_id=context_id,
input_names=[input_tensor.name for input_tensor in inputs],
num_outputs=len(outputs),
code_location=_process_stack_frames())
writer.WriteGraphOpCreation(graph_op_creation)
if outputs and compat.as_bytes(
op_type) not in op_callbacks_common.OP_CALLBACK_SKIP_OPS:
return _instrument_symbolic_tensors(outputs, op_type, op_name, context_id)
else:
input_ids = [t._id for t in inputs] # pylint:disable=protected-access
writer.WriteExecution(_dump_eager_tensors(outputs, op_type, input_ids))
DEFAULT_TENSOR_DEBUG_MODE = "NO_TENSOR"
@tf_export("debugging.experimental.enable_dump_debug_info")
def enable_dump_debug_info(dump_root,
tensor_debug_mode=DEFAULT_TENSOR_DEBUG_MODE,
@@ -416,6 +480,8 @@ def enable_dump_debug_info(dump_root,
# TODO(cais): Revise the "UIs (currently under construction)" part of the doc
# string above.
# TODO(cais): Add Python code example to the doc string above.
global _state
tensor_debug_mode_keys = debug_event_pb2.TensorDebugMode.keys()
if tensor_debug_mode not in tensor_debug_mode_keys:
raise ValueError(
@@ -429,25 +495,6 @@
"tfdbg dumping: support for tensor debug mode %s is not "
"implemented yet" % tensor_debug_mode)
if (hasattr(_state, "config") and
_state.config.circular_buffer_size != circular_buffer_size):
raise ValueError(
"There is already a dumping callback configured with a different "
"circular-buffer size (%d). Therefore the newly request "
"circular-buffer size (%d) will not be honored." %
(_state.config.circular_buffer_size, circular_buffer_size))
if (hasattr(_state, "config") and
_state.config.tensor_debug_mode != tensor_debug_mode):
raise ValueError(
"There is already a dumping callback configured for dump root "
"%s with a different "
"tensor-debug mode (%s). Therefore the newly request "
"tensor-debug mode (%s) size will not be honored." %
(_state.config.dump_root,
tensor_debug_mode_keys[_state.config.tensor_debug_mode],
tensor_debug_mode_keys[tensor_debug_mode]))
# Validate the types of tensor_dtypes.
if tensor_dtypes is not None:
if (not isinstance(tensor_dtypes, (list, tuple)) and
@@ -460,30 +507,41 @@
tensor_dtypes = [
dtypes.as_dtype(dtype_item) for dtype_item in tensor_dtypes]
if not hasattr(_state, "config") or _state.config.dump_root != dump_root:
_state.config = DumpingConfig(
dump_root=dump_root,
tensor_debug_mode=tensor_debug_mode,
circular_buffer_size=int(circular_buffer_size),
op_regex=re.compile(op_regex) if op_regex else None,
tensor_dtypes=tensor_dtypes)
_state.hostname = socket.gethostname()
# A list of source-file paths.
_state.source_file_paths = []
# A map from stack frame (FileLineCol) to unique ID.
_state.stack_frame_to_id = dict()
# Mapping op context to unique ID.
_state.context_to_id = dict()
if hasattr(_state, "dumping_callback"):
if _state.dumping_callback.circular_buffer_size != circular_buffer_size:
raise ValueError(
"There is already a dumping callback configured with a different "
"circular-buffer size (%d). Therefore the newly request "
"circular-buffer size (%d) will not be honored." %
(_state.dumping_callback.circular_buffer_size, circular_buffer_size))
if _state.dumping_callback.tensor_debug_mode != tensor_debug_mode:
raise ValueError(
"There is already a dumping callback configured for dump root "
"%s with a different "
"tensor-debug mode (%s). Therefore the newly request "
"tensor-debug mode (%s) size will not be honored." %
(_state.dumping_callback.dump_root,
tensor_debug_mode_keys[_state.dumping_callback.tensor_debug_mode],
tensor_debug_mode_keys[tensor_debug_mode]))
else:
_state.dumping_callback = _DumpingCallback(dump_root,
tensor_debug_mode,
circular_buffer_size,
op_regex,
tensor_dtypes)
op_callbacks.add_op_callback(_state.dumping_callback.callback)
if _state.dumping_callback.dump_root != dump_root:
_state.dumping_callback.dump_root = dump_root
op_callbacks.add_op_callback(_dumping_callback)
logging.info(
"Enabled dumping callback in thread %s "
"(dump root: %s, tensor debug mode: %s)",
threading.current_thread().name, _state.config.dump_root,
tensor_debug_mode)
threading.current_thread().name,
_state.dumping_callback.dump_root, tensor_debug_mode)
atexit.register(disable_dump_debug_info)
return _get_writer()
return _state.dumping_callback.get_writer()
@tf_export("debugging.experimental.disable_dump_debug_info")
@@ -495,10 +553,10 @@ def disable_dump_debug_info():
`enable_dump_debug_info()` has been made, calling this method is a no-op.
Calling this method more than once is idempotent.
"""
if hasattr(_state, "config"):
dump_root = _state.config.dump_root
delattr(_state, "config")
if hasattr(_state, "dumping_callback"):
dump_root = _state.dumping_callback.dump_root
debug_events_writer.DebugEventsWriter(dump_root).Close()
op_callbacks.remove_op_callback(_dumping_callback)
op_callbacks.remove_op_callback(_state.dumping_callback.callback)
delattr(_state, "dumping_callback")
logging.info("Disabled dumping callback in thread %s (dump root: %s)",
threading.current_thread().name, dump_root)
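
The `_get_context_id` method above relies on double-checked locking. The
pattern in isolation (a sketch with generic names; safe in CPython because
dict lookups are atomic under the GIL):

    import threading

    class IdRegistry(object):
      """Maps hashable keys to IDs; lock-free on the common (hit) path."""

      def __init__(self):
        self._key_to_id = {}
        self._lock = threading.Lock()

      def get_id(self, key):
        if key in self._key_to_id:        # 1st check, without the lock.
          return self._key_to_id[key]
        with self._lock:
          if key not in self._key_to_id:  # 2nd check, with the lock.
            self._key_to_id[key] = len(self._key_to_id)
          return self._key_to_id[key]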

tensorflow/python/debug/lib/dumping_callback_test.py

@@ -219,8 +219,9 @@ class TracingCallbackTest(
# Session.run() in v1 graph mode, so doesn't get logged to the
# .execution file.
executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
executed_op_types = [op_type for op_type in executed_op_types
if "sin1p_log_sum" in op_type]
self.assertLen(executed_op_types, 1)
self.assertIn("sin1p_log_sum", executed_op_types[0])
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
(context_ids, op_types,
@@ -601,7 +602,8 @@ class TracingCallbackTest(
("NoTensor", "NO_TENSOR"),
("FullTensor", "FULL_TENSOR"),
)
def testMultiThreadedExecution(self, tensor_debug_mode):
def testMultiThreadedExecutionWithSameSetting(self, tensor_debug_mode):
"""Dumping from multiple threads using the same setting."""
writer = dumping_callback.enable_dump_debug_info(
self.dump_root, tensor_debug_mode=tensor_debug_mode)
x = variables.Variable(10.0, dtype=dtypes.float32)
@@ -658,6 +660,64 @@ class TracingCallbackTest(
]
self.assertAllClose(mul_values, [6.0, 6.0, 6.0, 6.0])
def testMultiThreadedDumpingWithDifferentSettings(self):
dump_root_1 = os.path.join(self.dump_root, "dump_root_1")
dump_root_2 = os.path.join(self.dump_root, "dump_root_2")
v1 = variables.Variable(10.0, dtype=dtypes.float32)
v2 = variables.Variable(3.0, dtype=dtypes.float32)
def add_negative_v1_squared_to_itself():
writer = dumping_callback.enable_dump_debug_info(
dump_root_1, tensor_debug_mode="FULL_TENSOR")
# Run in a loop to facilitate interleaving between threads.
for _ in range(3):
v1.assign_add(-(v1 ** 2.0))
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
def add_negative_v2_squared_to_itself():
writer = dumping_callback.enable_dump_debug_info(
dump_root_2, tensor_debug_mode="FULL_TENSOR")
v2_squared = v2 ** 2.0
# Since dumping is disabled before the Neg op is called, no tensor data
# should be dumped from the op, but this shouldn't affect the dumping of
# the tensor data from the Neg op in `add_negative_v1_squared_to_itself`.
# Both behaviors are checked below.
dumping_callback.disable_dump_debug_info()
negative_v2_squared = -v2_squared
v2.assign_add(negative_v2_squared)
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
# v2 is mutated on a sub-thread.
sub_thread = threading.Thread(target=add_negative_v2_squared_to_itself)
sub_thread.start()
add_negative_v1_squared_to_itself() # v1 is mutated on the main thread.
sub_thread.join()
# 10 - 10 * 10 = -90.
# -90 - (-90 * -90) = -8190.
# -8190 - (-8190 * -8190) = -67084290.
self.assertAllClose(v1.read_value(), -67084290.0)
self.assertAllClose(v2.read_value(), -6.0)
(executed_op_types, _, _, _,
tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_1)
v1_squared_values = [
tensor_values[i] for i, op_type in enumerate(executed_op_types)
if op_type == "Pow"]
negative_v1_squared_values = [
tensor_values[i] for i, op_type in enumerate(executed_op_types)
if op_type == "Neg"]
self.assertAllClose(v1_squared_values, [[100.0], [8100.0], [67076100.0]])
self.assertAllClose(
negative_v1_squared_values, [[-100.0], [-8100.0], [-67076100.0]])
(executed_op_types, _, _, _,
tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_2)
self.assertNotIn("Neg", executed_op_types)
v2_squared_values = tensor_values[executed_op_types.index("Pow")]
self.assertAllClose(v2_squared_values, [9.0])
@parameterized.named_parameters(
("NoTensor", "NO_TENSOR"),
("FullTensor", "FULL_TENSOR"),

tensorflow/python/debug/lib/dumping_callback_test_lib.py

@@ -23,6 +23,7 @@ import shutil
import socket
import tempfile
from tensorflow.core.framework import types_pb2
from tensorflow.python.debug.lib import check_numerics_callback
from tensorflow.python.debug.lib import debug_events_reader
from tensorflow.python.debug.lib import dumping_callback
@@ -137,9 +138,13 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
self.assertIn(stack_frame_id, stack_frame_by_id)
return context_ids, op_types, op_name_to_op_type
def _readAndCheckExecutionFile(self):
def _readAndCheckExecutionFile(self, dump_root=None):
"""Read and verify the content of the .execution debug-event file.
Args:
dump_root: Optional argument that can be used to override the default
dump root to read the data from.
Returns:
executed_op_types: Types of ops that are created, as a `list` of `str`.
input_tensor_ids: Input tensor IDs for each of the ops executed, as a
@@ -153,7 +158,8 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
execution event. Each item of the inner `list` corresponds to one
output tensor slot of the executed op or Function.
"""
reader = debug_events_reader.DebugEventsReader(self.dump_root)
dump_root = self.dump_root if dump_root is None else dump_root
reader = debug_events_reader.DebugEventsReader(dump_root)
execution_iter = reader.execution_iterator()
prev_wall_time = 1
executed_op_types = []
@@ -213,7 +219,10 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
self.assertIn(graph_execution_trace.tfdbg_context_id, context_ids)
output_slots.append(graph_execution_trace.output_slot)
dtype = dtypes.DType(graph_execution_trace.tensor_proto.dtype)
if dtype.is_numpy_compatible: # pylint:disable=protected-access
if (dtype.is_numpy_compatible and
dtype._type_enum != types_pb2.DT_STRING): # pylint:disable=protected-access
# TODO(cais): Figure out how to properly convert string tensor proto to
# numpy representation.
tensor_values.append(
tensor_util.MakeNdarray(graph_execution_trace.tensor_proto))
else:

tensorflow/python/distribute/mirrored_strategy.py

@@ -867,6 +867,7 @@ class _MirroredReplicaThread(threading.Thread):
ctx = context.context()
self.in_eager = ctx.executing_eagerly()
self.record_thread_local_summary_state()
self.record_thread_local_eager_context_state()
self.context_device_policy = (
pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
ctx._context_handle)) # pylint: disable=protected-access
@@ -892,6 +893,7 @@
if self.coord.should_stop():
return
self.restore_thread_local_summary_state()
self.restore_thread_local_eager_context_state()
# TODO(josh11b): Use current logical device instead of 0 here.
with self.coord.stop_on_exception(), \
_enter_graph(self._init_graph, self._init_in_eager), \
@@ -920,7 +922,6 @@
self._summary_recording = summary_state.is_recording
self._summary_recording_distribution_strategy = (
summary_state.is_recording_distribution_strategy)
# TODO(b/125892694): record other fields in EagerContext.
def restore_thread_local_summary_state(self):
"""Restore thread local summary state from self."""
@@ -931,7 +932,18 @@
summary_state.is_recording = self._summary_recording
summary_state.is_recording_distribution_strategy = (
self._summary_recording_distribution_strategy)
# TODO(b/125892694): restore other fields in EagerContext.
def record_thread_local_eager_context_state(self):
ctx = context.context()
eager_context_state = ctx._thread_local_data # pylint: disable=protected-access
self._eager_context_op_callbacks = eager_context_state.op_callbacks
# TODO(b/125892694): record other fields in EagerContext.
def restore_thread_local_eager_context_state(self):
ctx = context.context()
eager_context_state = ctx._thread_local_data # pylint: disable=protected-access
eager_context_state.op_callbacks = self._eager_context_op_callbacks
# TODO(b/125892694): record other fields in EagerContext.
class MirroredReplicaContext(distribute_lib.ReplicaContext):
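
The record/restore pair added above mirrors the existing summary-state
plumbing: op callbacks live in the eager context's thread-local data, so
they must be copied explicitly into each replica thread. The underlying
idea, reduced to plain Python (hypothetical names):

    import threading

    _tls = threading.local()  # Stand-in for EagerContext._thread_local_data.

    class ReplicaThread(threading.Thread):
      """Copies the creator thread's op callbacks into the worker thread."""

      def __init__(self, target):
        super(ReplicaThread, self).__init__()
        self._target = target
        # Record: runs on the creator thread.
        self._op_callbacks = list(getattr(_tls, "op_callbacks", []))

      def run(self):
        # Restore: runs on the replica thread, before the workload.
        _tls.op_callbacks = self._op_callbacks
        self._target()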