[tfdbg] Support linking FuncGraph building and execution

- Append a _function_name property to FuncGraph, so that connections
  can be established between the FuncGraph object and the
  _EagerDefinedFunctions built from it (see the sketch below).
- The dumping op callback extracts the graph_id value and saves it
  with the DebugEvent.execution proto.

Also in this CL:
- Add unit test for the recorded graph IDs for eager execution of
  FuncGraphs.
- Replace the magic string prefixes for Function names (e.g.,
  "__inference_") with constants.
- Use these string constants in dumping_callback.py and in
  function_deserialization.py.
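In outline, the linkage works in two steps: at graph-building time, a function
callback records a mapping from the newly built _EagerDefinedFunction (held
weakly) to the ID of its FuncGraph; at execution time, the dumping op callback
recognizes function-shaped op_type names by their prefix and looks the graph ID
up from that mapping. A minimal sketch of the idea — the names `_graph_ids`,
`on_function_created`, and `graph_id_for_op_type` are illustrative, not the
actual TF internals:

```python
# Illustrative sketch only; names below are hypothetical stand-ins.
import weakref

_graph_ids = {}  # weakref to function -> graph (context) ID.

def on_function_created(function):
  # Invoked right after an _EagerDefinedFunction is built (see
  # add_function_callback below); remember which FuncGraph it came from.
  _graph_ids[weakref.ref(function)] = id(function.graph)

def graph_id_for_op_type(op_type):
  # An eagerly executed FuncGraph surfaces as an op whose type is the
  # function's name (e.g., "__inference_my_fn_123"); match it by name.
  for function_ref, graph_id in _graph_ids.items():
    function = function_ref()
    if function is not None and function.name == op_type:
      return graph_id
  return None  # A plain eager op, not a FuncGraph execution.
```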

PiperOrigin-RevId: 284228685
Change-Id: I8fc540d6d6de0ed58c77d8de5804b0e997297f68
Shanqing Cai 2019-12-06 11:39:55 -08:00 committed by TensorFlower Gardener
parent a5ac4c72dd
commit 2792dd7cf2
8 changed files with 241 additions and 30 deletions

View File

@@ -726,7 +726,7 @@ cuda_py_test(
"//tensorflow/python/keras",
],
python_version = "PY3",
-shard_count = 8,
+shard_count = 4,
tags = [
"guitar",
"multi_and_single_gpu",

View File

@@ -202,11 +202,11 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
# Before FlushExecutionFiles() is called. No data should have been written
# to the file.
-executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
self.assertFalse(executed_op_types)
writer.FlushExecutionFiles()
-executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
for i, executed_op_type in enumerate(executed_op_types):
self.assertEqual(
executed_op_type,
@@ -222,7 +222,7 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
writer.WriteExecution(execution)
writer.FlushExecutionFiles()
-executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
self.assertLen(executed_op_types, num_execution_events)
for i, executed_op_type in enumerate(executed_op_types):
self.assertEqual(executed_op_type, "OpType%d" % i)
@@ -302,7 +302,7 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
writer.FlushExecutionFiles()
# Verify the content of the .execution file.
-executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
self.assertLen(executed_op_types, circular_buffer_size)
self.assertLen(executed_op_types, len(set(executed_op_types)))

View File

@@ -266,7 +266,7 @@ class DistributedDumpingCallbackTest(
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
# Eager execution of tf.function should be recorded.
-executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
fit_functions = [op_type for op_type in executed_op_types
if "_distributed_function" in op_type]
self.assertLen(fit_functions, epochs)

View File

@@ -23,6 +23,7 @@ import re
import socket
import threading
import uuid
+import weakref
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -31,6 +32,7 @@ from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.debug.lib import debug_events_writer
from tensorflow.python.debug.lib import op_callbacks_common
from tensorflow.python.debug.lib import source_utils
+from tensorflow.python.eager import function as function_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_callbacks
@@ -81,14 +83,35 @@ class _DumpingCallback(object):
self._stack_frame_to_id = dict()
# Mapping op context to unique ID.
self._context_to_id = dict()
+self._function_weakref_to_graph_id = dict()
+# pylint:disable=protected-access
+self._function_prefixes = (
+compat.as_bytes(function_lib._FORWARD_PREFIX),
+compat.as_bytes(function_lib._BACKWARD_PREFIX),
+compat.as_bytes(function_lib._INFERENCE_PREFIX))
+# pylint:enable=protected-access
+self._op_type_to_context_id = dict()
# Keeps track of counter for symbolic tensors output by in-graph ops.
self._symbolic_tensor_counter = 0
self._source_file_paths_lock = threading.Lock()
self._stack_frame_to_id_lock = threading.Lock()
-self._context_to_id_lock = threading.Lock()
+self._context_lock = threading.Lock()
self._symbolic_tensor_counter_lock = threading.Lock()
self._writer = None
+def function_callback(self, function):
+"""A callback to be called on creation of Functions.
+Used to establish a join between function name and graph (context) ID.
+Args:
+function: The just-created Function.
+"""
+function_weakref = weakref.ref(function)
+graph_id = self._get_context_id(function.graph)
+with self._context_lock:
+self._function_weakref_to_graph_id[function_weakref] = graph_id
@property
def dump_root(self):
return self._dump_root
@@ -133,7 +156,7 @@
if context in self._context_to_id: # 1st check, without lock.
return self._context_to_id[context]
graph_is_new = False
-with self._context_to_id_lock:
+with self._context_lock:
if context not in self._context_to_id: # 2nd check, with lock.
graph_is_new = True
context_id = _get_id()
@@ -318,7 +341,11 @@
"Symbolic tensor instrumentation is not implemented for debug mode "
"%s" % self._tensor_debug_mode)
-def _dump_eager_tensors(self, tensors, op_type, input_tensor_ids):
+def _dump_eager_tensors(self,
+tensors,
+op_type,
+input_tensor_ids,
+graph_id=None):
"""Dump the value of eager tensors.
The destination of the dumping is determined by the dump_root of the
@@ -332,6 +359,8 @@
value transform.
op_type: Type of the op that generates the tensors, as a string.
input_tensor_ids: IDs of the input EagerTensors to the op.
+graph_id: ID of the executed graph, applicable only to eager execution of
+a FuncGraph.
Returns:
A tfdbg Execution protocol buffer.
@@ -342,6 +371,7 @@
if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
return debug_event_pb2.Execution(
op_type=op_type,
+graph_id=graph_id,
num_outputs=len(tensors),
input_tensor_ids=input_tensor_ids,
output_tensor_ids=output_tensor_ids,
@@ -351,6 +381,7 @@
execution_proto = debug_event_pb2.Execution(
op_type=op_type,
num_outputs=len(tensors),
+graph_id=graph_id,
input_tensor_ids=input_tensor_ids,
output_tensor_ids=output_tensor_ids,
tensor_debug_mode=tensor_debug_mode,
@@ -396,9 +427,45 @@
return self._instrument_symbolic_tensors(
outputs, op_type, op_name, context_id, output_tensor_ids)
else:
+context_id = self._func_graph_id_from_func_name(op_type)
input_ids = [t._id for t in inputs] # pylint:disable=protected-access
-writer.WriteExecution(
-self._dump_eager_tensors(outputs, op_type, input_ids))
+writer.WriteExecution(self._dump_eager_tensors(
+outputs, op_type, input_ids, graph_id=context_id))
+def _func_graph_id_from_func_name(self, op_type):
+"""Attempt to get the ID of a FuncGraph based on an op type name.
+Also caches the ID for faster access later.
+Args:
+op_type: Op type string, which may be the name of a function.
+Returns:
+If the op_type name does not fit the pattern of a function name (e.g.,
+one that starts with "__inference_"), `None` is returned immediately.
+Else, if the FuncGraph is found, the ID of the underlying FuncGraph is
+returned as a string.
+Else, `None` is returned.
+"""
+op_type = compat.as_bytes(op_type)
+if op_type.startswith(self._function_prefixes):
+# op_types for eagerly-executed FuncGraphs have the prefixed and suffixed
+# form such as "__inference_my_function_13579", wherein the middle part
+# "my_function" is the name of the Python function from which the
+# FuncGraph is compiled. Due to the suffix, the op_type is unique even for
+# - duplicate Python function names
+# - multiple compilations of the same Python function
+if op_type in self._op_type_to_context_id:
+return self._op_type_to_context_id[op_type]
+with self._context_lock:
+for function_weakref in self._function_weakref_to_graph_id:
+if function_weakref().name == op_type:
+graph_id = self._function_weakref_to_graph_id[function_weakref]
+self._op_type_to_context_id[op_type] = graph_id
+return graph_id
+return None
+else:
+return None
def _get_symbolic_tensor_ids(self, num_tensors):
tensor_ids = []
@@ -578,6 +645,8 @@ def enable_dump_debug_info(dump_root,
op_regex,
tensor_dtypes)
op_callbacks.add_op_callback(_state.dumping_callback.callback)
+function_lib.add_function_callback(
+_state.dumping_callback.function_callback)
if _state.dumping_callback.dump_root != dump_root:
_state.dumping_callback.dump_root = dump_root
@@ -605,6 +674,8 @@ def disable_dump_debug_info():
dump_root = _state.dumping_callback.dump_root
debug_events_writer.DebugEventsWriter(dump_root).Close()
op_callbacks.remove_op_callback(_state.dumping_callback.callback)
+function_lib.remove_function_callback(
+_state.dumping_callback.function_callback)
delattr(_state, "dumping_callback")
logging.info("Disabled dumping callback in thread %s (dump root: %s)",
threading.current_thread().name, dump_root)
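Note that `_get_context_id` (and the new `_func_graph_id_from_func_name` cache
above) relies on double-checked locking: the first dict lookup happens without
the lock so cache hits stay cheap, and the lock is taken only to fill a miss. A
generic sketch of that pattern — `IdCache` and `compute_id` are hypothetical
names, not part of the CL:

```python
import threading

class IdCache(object):
  """Double-checked locking around a dict, as in _get_context_id above."""

  def __init__(self, compute_id):
    self._ids = {}  # Shared cache: key -> ID.
    self._lock = threading.Lock()
    self._compute_id = compute_id

  def get(self, key):
    if key in self._ids:  # 1st check, without the lock (fast path).
      return self._ids[key]
    with self._lock:
      if key not in self._ids:  # 2nd check, with the lock held.
        self._ids[key] = self._compute_id(key)
    return self._ids[key]
```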

View File

@@ -129,6 +129,8 @@ class TracingCallbackTest(
prev_wall_time = debug_event.wall_time
execution = debug_event.execution
executed_op_types.append(execution.op_type)
+# No graph IDs should have been logged for eager op executions.
+self.assertFalse(execution.graph_id)
self.assertTrue(execution.input_tensor_ids)
self.assertTrue(execution.output_tensor_ids)
if tensor_debug_mode == "NO_TENSOR":
@@ -218,17 +220,30 @@
# NOTE(b/142486213): Execution of the TF function happens with
# Session.run() in v1 graph mode, so doesn't get logged to the
# .execution file.
-executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+(executed_op_types, executed_graph_ids,
+_, _, _, _) = self._readAndCheckExecutionFile()
executed_op_types = [op_type for op_type in executed_op_types
if "sin1p_log_sum" in op_type]
self.assertLen(executed_op_types, 1)
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-(context_ids, op_types,
-op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
+(context_ids, op_types, op_name_to_op_type,
+op_name_to_context_id) = self._readAndCheckGraphsFile(stack_frame_by_id)
self.assertIn("AddV2", op_types)
self.assertIn("Log", op_types)
self.assertIn("Sin", op_types)
+if context.executing_eagerly():
+# Check the correctness of the recorded executed graph ID.
+sin_op_name = [op_name for op_name in op_name_to_op_type
+if op_name_to_op_type[op_name] == "Sin"]
+self.assertLen(sin_op_name, 1)
+sin_context_id = op_name_to_context_id[sin_op_name[0]]
+# The executed "op" is a FuncGraph, and its graph ID should have been
+# recorded properly and be the ID of the graph that the Sin op belongs to.
+executed_graph_ids = [
+executed_graph_ids[i] for i, op_type
+in enumerate(executed_op_types) if "sin1p_log_sum" in op_type]
+self.assertEqual(executed_graph_ids[0], sin_context_id)
(op_names, _, _,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
@@ -248,6 +263,72 @@
self.assertAllClose(tensor_values[3],
np.sin(np.log(5.0) + 1.0)) # Sin op.
+def testCapturingExecutedGraphIdsOfTwoCompilationsOfSameFunction(self):
+"""Test correct executed IDs of two FuncGraphs from the same Py function."""
+writer = dumping_callback.enable_dump_debug_info(
+self.dump_root, tensor_debug_mode="NO_TENSOR")
+@def_function.function
+def ceil_times_two(x):
+return math_ops.ceil(x) * 2.0
+x_float32 = np.array(3.5, dtype=np.float32)
+x_float64 = np.array(4.5, dtype=np.float64)
+# Four executions, with two different FuncGraphs, which should lead
+# to two unique executed graph IDs (see assertion below).
+self.assertAllClose(ceil_times_two(x_float32), 8.0)
+self.assertAllClose(ceil_times_two(x_float64), 10.0)
+self.assertAllClose(ceil_times_two(x_float32), 8.0)
+self.assertAllClose(ceil_times_two(x_float64), 10.0)
+writer.FlushNonExecutionFiles()
+writer.FlushExecutionFiles()
+(executed_op_types, executed_graph_ids,
+_, _, _, _) = self._readAndCheckExecutionFile()
+self.assertLen(executed_op_types, 4)
+for executed_op_type in executed_op_types:
+self.assertStartsWith(executed_op_type, "__inference_ceil_times_two_")
+self.assertLen(executed_graph_ids, 4)
+self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
+self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
+self.assertLen(set(executed_graph_ids), 2)
+def testCapturingExecutedGraphIdsOfDuplicateFunctionNames(self):
+"""Two FuncGraphs compiled from Python functions with identical names."""
+writer = dumping_callback.enable_dump_debug_info(
+self.dump_root, tensor_debug_mode="NO_TENSOR")
+class TestClass(object):
+@def_function.function
+def ceil_times_two(self, x):
+return math_ops.ceil(x) * 2.0
+# The `ceil_times_two` method of the two objects will be compiled
+# into separate FuncGraphs.
+test_object_1 = TestClass()
+test_object_2 = TestClass()
+x = np.array(3.5, dtype=np.float32)
+# Four executions, with two different FuncGraphs, which should lead
+# to two unique executed graph IDs (see assertion below).
+self.assertAllClose(test_object_1.ceil_times_two(x), 8.0)
+self.assertAllClose(test_object_2.ceil_times_two(x), 8.0)
+self.assertAllClose(test_object_1.ceil_times_two(x), 8.0)
+self.assertAllClose(test_object_2.ceil_times_two(x), 8.0)
+writer.FlushNonExecutionFiles()
+writer.FlushExecutionFiles()
+(executed_op_types, executed_graph_ids,
+_, _, _, _) = self._readAndCheckExecutionFile()
+self.assertLen(executed_op_types, 4)
+for executed_op_type in executed_op_types:
+self.assertStartsWith(executed_op_type, "__inference_ceil_times_two_")
+self.assertLen(executed_graph_ids, 4)
+self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
+self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
+self.assertLen(set(executed_graph_ids), 2)
@parameterized.named_parameters(
("AddV2", "AddV2"),
("Log", "Log"),
@@ -438,7 +519,7 @@
# After the flushing, the .execution file should hold the appropriate
# contents.
if context.executing_eagerly():
-(executed_op_types, input_tensor_ids, output_tensor_ids,
+(executed_op_types, _, input_tensor_ids, output_tensor_ids,
tensor_debug_modes, tensor_values) = self._readAndCheckExecutionFile()
# NOTE(b/142486213): Execution of the TF function happens with
# Session.run() in v1 graph mode, hence it doesn't get logged to the
@@ -558,7 +639,7 @@
writer.FlushExecutionFiles()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
context_ids, _, _, _ = self._readAndCheckGraphsFile(stack_frame_by_id)
-_, _, _, _, tensor_values = self._readAndCheckExecutionFile()
+_, _, _, _, _, tensor_values = self._readAndCheckExecutionFile()
self.assertEqual(tensor_values, [[]])
(_, _, _,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
@@ -702,7 +783,7 @@
self.assertAllClose(v1.read_value(), -67084290.0)
self.assertAllClose(v2.read_value(), -6.0)
-(executed_op_types, _, _, _,
+(executed_op_types, _, _, _, _,
tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_1)
v1_squared_values = [
tensor_values[i] for i, op_type in enumerate(executed_op_types)
@@ -714,7 +795,7 @@
self.assertAllClose(
negative_v1_squared_values, [[-100.0], [-8100.0], [-67076100.0]])
-(executed_op_types, _, _, _,
+(executed_op_types, _, _, _, _,
tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_2)
self.assertNotIn("Neg", executed_op_types)
v2_squared_values = tensor_values[executed_op_types.index("Pow")]
@@ -800,7 +881,7 @@
# NOTE(b/142486213): Execution of the TF function happens with
# Session.run() in v1 graph mode, hence it doesn't get logged to the
# .execution file.
-(executed_op_types, _, _, _,
+(executed_op_types, _, _, _, _,
tensor_values) = self._readAndCheckExecutionFile()
self.assertTrue(executed_op_types)
@@ -867,7 +948,7 @@
# NOTE(b/142486213): Execution of the TF function happens with
# Session.run() in v1 graph mode, hence it doesn't get logged to the
# .execution file.
-(executed_op_types, _, _, _,
+(executed_op_types, _, _, _, _,
tensor_values) = self._readAndCheckExecutionFile()
self.assertTrue(executed_op_types)
if tensor_debug_mode == "NO_TENSOR":
@@ -940,7 +1021,7 @@
# NOTE(b/142486213): Execution of the TF function happens with
# Session.run() in v1 graph mode, hence it doesn't get logged to the
# .execution file.
-executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
self.assertTrue(executed_op_types)
(op_names, _, _,

View File

@@ -193,6 +193,11 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
Returns:
executed_op_types: Types of ops that are created, as a `list` of `str`.
+executed_graph_ids: A `list` of the same length as `executed_op_types`.
+If the executed op is a FuncGraph, the corresponding element of the
+`list` will be the ID of the FuncGraph. Else, the corresponding element
+will be an empty string. This allows establishing a connection between
+eagerly executed FuncGraphs and their prior graph building.
input_tensor_ids: Input tensor IDs for each of the ops executed, as a
`list` of `list` of `int`s, with the same length as `executed_op_types`.
output_tensor_ids: Output tensor IDs for each of the ops executed, as a
@@ -209,6 +214,7 @@
execution_iter = reader.execution_iterator()
prev_wall_time = 1
executed_op_types = []
+executed_graph_ids = [] # Empty string for execution of individual ops.
input_tensor_ids = []
output_tensor_ids = []
tensor_debug_modes = []
@@ -218,6 +224,7 @@
prev_wall_time = debug_event.wall_time
execution = debug_event.execution
executed_op_types.append(execution.op_type)
+executed_graph_ids.append(execution.graph_id)
input_tensor_ids.append(execution.input_tensor_ids)
output_tensor_ids.append(execution.output_tensor_ids)
tensor_debug_modes.append(execution.tensor_debug_mode)
@@ -227,8 +234,8 @@
])
# TODO(cais): When tensor debug modes other than NO_TENSOR is supported,
# return tensor_values as well.
-return (executed_op_types, input_tensor_ids, output_tensor_ids,
-tensor_debug_modes, tensor_values)
+return (executed_op_types, executed_graph_ids, input_tensor_ids,
+output_tensor_ids, tensor_debug_modes, tensor_values)
def _readAndCheckGraphExecutionTracesFile(self, context_ids):
"""Read & verify the content of the .graph_execution_trace debug-event file.

View File

@@ -345,19 +345,59 @@ class _InterpolateFunctionError(object):
return False
+_function_callbacks = set()
+def add_function_callback(function_callback):
+"""Add a callback function for Function creation.
+The callback function has the signature:
+`def function_callback(function):`
+wherein `function` is the just-created _EagerDefinedFunction.
+The callback is invoked immediately after a new `_EagerDefinedFunction`
+is created. The return value(s) of the callback function (if any) is ignored.
+Repeated registration of the same callback function is idempotent.
+After a callback is added, it can be removed with the
+`remove_function_callback()` function.
+Args:
+function_callback: The callback to add.
+"""
+_function_callbacks.add(function_callback)
+def remove_function_callback(function_callback):
+"""Remove an already-added function callback.
+See the doc string of `add_function_callback()` for more information.
+Args:
+function_callback: The callback to remove.
+"""
+_function_callbacks.remove(function_callback)
+_FORWARD_PREFIX = "__forward_"
+_BACKWARD_PREFIX = "__backward_"
+_INFERENCE_PREFIX = "__inference_"
def _forward_name(n):
"""The name of a generated forward defun named n."""
return "__forward_%s_%s" % (n, ops.uid())
return "%s%s_%s" % (_FORWARD_PREFIX, n, ops.uid())
def _backward_name(n):
"""The name of a generated backward defun named n."""
return "__backward_%s_%s" % (n, ops.uid())
return "%s%s_%s" % (_BACKWARD_PREFIX, n, ops.uid())
def _inference_name(n):
"""The name of a forward-but-no-gradient defun named n."""
return "__inference_%s_%s" % (n, ops.uid())
return "%s%s_%s" % (_INFERENCE_PREFIX, n, ops.uid())
def _enclosing_xla_context():
@@ -463,7 +503,7 @@ class _EagerDefinedFunction(object):
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
function_def = function_pb2.FunctionDef()
function_def.ParseFromString(compat.as_bytes(proto_data))
-self.name = compat.as_bytes(function_def.signature.name)
+self._name = compat.as_bytes(function_def.signature.name)
with ops.init_scope():
if context.executing_eagerly():
context.ensure_initialized()
@@ -485,6 +525,9 @@
self.graph = graph
self._stateful_ops = tuple(op for op in operations if op._is_stateful) # pylint: disable=protected-access
+for function_callback in _function_callbacks:
+function_callback(self)
def add_to_graph(self, g=None):
# pylint: disable=protected-access
if not g and context.executing_eagerly():
@@ -497,6 +540,10 @@
g._add_function(f)
# pylint: enable=protected-access
+@property
+def name(self):
+return self._name
@property
def stateful_ops(self):
return self._stateful_ops
@@ -533,6 +580,7 @@
executor_type = function_call_options.executor_type or ""
executing_eagerly = ctx.executing_eagerly()
attrs = ("executor_type", executor_type, "config_proto", config)
if executing_eagerly:
with _InterpolateFunctionError(self):
if cancellation_manager is None:
@@ -540,14 +588,14 @@
str(self.signature.name),
num_outputs=self._num_outputs,
inputs=args,
attrs=("executor_type", executor_type, "config_proto", config),
attrs=attrs,
ctx=ctx)
else:
outputs = execute.execute_with_cancellation(
str(self.signature.name),
num_outputs=self._num_outputs,
inputs=args,
attrs=("executor_type", executor_type, "config_proto", config),
attrs=attrs,
ctx=ctx,
cancellation_manager=cancellation_manager)
# Replace empty list with None
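Taken together, the registry above gives tooling a hook into function creation.
A hedged usage sketch, assuming only the `add_function_callback()` /
`remove_function_callback()` API added in this CL (`log_function_creation` is
an illustrative name):

```python
# Usage sketch; assumes only the callback API added in this CL.
from tensorflow.python.eager import function as function_lib

def log_function_creation(function):
  # `function` is the just-created _EagerDefinedFunction; `function.name`
  # carries the prefixed/suffixed form, e.g. b"__inference_my_fn_123".
  print("Created:", function.name)

function_lib.add_function_callback(log_function_creation)
try:
  pass  # ... build and call tf.functions; each trace fires the callback.
finally:
  function_lib.remove_function_callback(log_function_creation)
```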

View File

@@ -447,11 +447,15 @@ def _list_function_deps(fdef, library_function_names):
return deps
+_FUNCTION_WRAPPER_NAME_REGEX = r"^%s(.*)_\d+$" % (
+function_lib._INFERENCE_PREFIX) # pylint:disable=protected-access
def _clean_function_name(name):
"""Vanity function to keep the function names comprehensible."""
# Note: each time a function is wrapped into `function_lib.ConcreteFunction`
# its name becomes "__inference_<orig>_xyz".
match = re.search(r"^__inference_(.*)_\d+$", name)
match = re.search(_FUNCTION_WARPPER_NAME_REGEX, name)
if match:
return match.group(1)
else:
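The excerpt cuts off inside `_clean_function_name`. For illustration, here is a
standalone sketch of what the regex does, under the assumption (not shown in
the truncated hunk) that the `else:` branch returns the name unchanged:

```python
import re

_INFERENCE_PREFIX = "__inference_"
_FUNCTION_WRAPPER_NAME_REGEX = r"^%s(.*)_\d+$" % _INFERENCE_PREFIX

def clean_function_name(name):
  # Strip the "__inference_" prefix and the trailing uid, when present.
  match = re.search(_FUNCTION_WRAPPER_NAME_REGEX, name)
  return match.group(1) if match else name

assert clean_function_name("__inference_my_fn_123") == "my_fn"
assert clean_function_name("my_fn") == "my_fn"
```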