[tfdbg] Support linking FuncGraph building and execution
- Append a _function_name property to FuncGraph, in order to allow
  establishment of connections between the FuncGraph object and
  _EagerDefinedFunctions based on it.
- The dumping op callback extracts the graph_id value and saves it with the
  DebugEvent.execution proto.

Also in this CL:
- Add a unit test for the recorded graph IDs for eager execution of FuncGraphs.
- Replace the magic string prefixes for Function names (e.g., "__inference_")
  with constants.
- Use these string constants in dumping_callback.py and in
  function_deserialization.py.

PiperOrigin-RevId: 284228685
Change-Id: I8fc540d6d6de0ed58c77d8de5804b0e997297f68
Parent: a5ac4c72dd
Commit: 2792dd7cf2
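
Note: as an orientation aid, here is a minimal, self-contained sketch (not the
actual TensorFlow code) of the linking flow this CL implements. The names
mirror those in the diff below; the registry and the ID scheme (id() in place
of the real UUID-based context ID) are simplified stand-ins.

    import weakref

    _function_weakref_to_graph_id = {}  # weakref(function) -> graph (context) ID


    def function_callback(function):
      # Invoked right after an _EagerDefinedFunction is created;
      # function.graph is the FuncGraph it was compiled from.
      _function_weakref_to_graph_id[weakref.ref(function)] = id(function.graph)


    def graph_id_from_func_name(op_type):
      # At eager execution time, op_type is the function name, e.g.
      # "__inference_my_function_13579". Scan the registry for a live function
      # with that name; the result is saved in DebugEvent.execution.graph_id.
      for function_weakref, graph_id in _function_weakref_to_graph_id.items():
        function = function_weakref()
        if function is not None and function.name == op_type:
          return graph_id
      return None
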
@@ -726,7 +726,7 @@ cuda_py_test(
         "//tensorflow/python/keras",
     ],
     python_version = "PY3",
-    shard_count = 8,
+    shard_count = 4,
     tags = [
         "guitar",
         "multi_and_single_gpu",
@@ -202,11 +202,11 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
 
     # Before FlushExecutionFiles() is called. No data should have been written
     # to the file.
-    executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+    executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
     self.assertFalse(executed_op_types)
 
     writer.FlushExecutionFiles()
-    executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+    executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
     for i, executed_op_type in enumerate(executed_op_types):
       self.assertEqual(
           executed_op_type,
@@ -222,7 +222,7 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
       writer.WriteExecution(execution)
     writer.FlushExecutionFiles()
 
-    executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+    executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
     self.assertLen(executed_op_types, num_execution_events)
     for i, executed_op_type in enumerate(executed_op_types):
       self.assertEqual(executed_op_type, "OpType%d" % i)
@@ -302,7 +302,7 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
     writer.FlushExecutionFiles()
 
     # Verify the content of the .execution file.
-    executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+    executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
     self.assertLen(executed_op_types, circular_buffer_size)
     self.assertLen(executed_op_types, len(set(executed_op_types)))
@@ -266,7 +266,7 @@ class DistributedDumpingCallbackTest(
      tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
 
     # Eager execution of tf.function should be recorded.
-    executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+    executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
     fit_functions = [op_type for op_type in executed_op_types
                      if "_distributed_function" in op_type]
     self.assertLen(fit_functions, epochs)
@@ -23,6 +23,7 @@ import re
 import socket
 import threading
 import uuid
+import weakref
 
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
@@ -31,6 +32,7 @@ from tensorflow.core.protobuf import graph_debug_info_pb2
 from tensorflow.python.debug.lib import debug_events_writer
 from tensorflow.python.debug.lib import op_callbacks_common
 from tensorflow.python.debug.lib import source_utils
+from tensorflow.python.eager import function as function_lib
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import op_callbacks
@@ -81,14 +83,35 @@ class _DumpingCallback(object):
     self._stack_frame_to_id = dict()
     # Mapping op context to unique ID.
     self._context_to_id = dict()
+    self._function_weakref_to_graph_id = dict()
+    # pylint:disable=protected-access
+    self._function_prefixes = (
+        compat.as_bytes(function_lib._FORWARD_PREFIX),
+        compat.as_bytes(function_lib._BACKWARD_PREFIX),
+        compat.as_bytes(function_lib._INFERENCE_PREFIX))
+    # pylint:enable=protected-access
+    self._op_type_to_context_id = dict()
     # Keeps track of counter for symbolic tensors output by in-graph ops.
     self._symbolic_tensor_counter = 0
     self._source_file_paths_lock = threading.Lock()
     self._stack_frame_to_id_lock = threading.Lock()
-    self._context_to_id_lock = threading.Lock()
+    self._context_lock = threading.Lock()
     self._symbolic_tensor_counter_lock = threading.Lock()
     self._writer = None
+
+  def function_callback(self, function):
+    """A callback to be called on creation of Functions.
+
+    Used to establish a join between function name and graph (context) ID.
+
+    Args:
+      function: The just-created Function.
+    """
+    function_weakref = weakref.ref(function)
+    graph_id = self._get_context_id(function.graph)
+    with self._context_lock:
+      self._function_weakref_to_graph_id[function_weakref] = graph_id
 
   @property
   def dump_root(self):
     return self._dump_root
@@ -133,7 +156,7 @@ class _DumpingCallback(object):
     if context in self._context_to_id:  # 1st check, without lock.
       return self._context_to_id[context]
     graph_is_new = False
-    with self._context_to_id_lock:
+    with self._context_lock:
       if context not in self._context_to_id:  # 2nd check, with lock.
         graph_is_new = True
         context_id = _get_id()
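
Note: the "1st check, without lock" / "2nd check, with lock" comments above
mark the standard double-checked locking idiom, now guarded by the shared
_context_lock. A generic sketch of the pattern (hypothetical cache and key):

    import threading

    _cache = {}
    _cache_lock = threading.Lock()


    def get_or_create_id(key):
      if key in _cache:  # 1st check, without lock: cheap fast path for hits.
        return _cache[key]
      with _cache_lock:
        if key not in _cache:  # 2nd check, with lock: another thread may have
          _cache[key] = object()  # inserted the key between the two checks.
        return _cache[key]
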
@@ -318,7 +341,11 @@ class _DumpingCallback(object):
         "Symbolic tensor instrumentation is not implemented for debug mode "
         "%s" % self._tensor_debug_mode)
 
-  def _dump_eager_tensors(self, tensors, op_type, input_tensor_ids):
+  def _dump_eager_tensors(self,
+                          tensors,
+                          op_type,
+                          input_tensor_ids,
+                          graph_id=None):
     """Dump the value of eager tensors.
 
     The destination of the dumping is determined by the dump_root of the
@@ -332,6 +359,8 @@ class _DumpingCallback(object):
         value transform.
       op_type: Type of the op that generates the tensors, as a string.
       input_tensor_ids: IDs of the input EagerTensors to the op.
+      graph_id: ID of the executed graph, applicable only to eager execution of
+        a FuncGraph.
 
     Returns:
       A tfdbg Execution protocol buffer.
@@ -342,6 +371,7 @@ class _DumpingCallback(object):
     if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
       return debug_event_pb2.Execution(
           op_type=op_type,
+          graph_id=graph_id,
           num_outputs=len(tensors),
           input_tensor_ids=input_tensor_ids,
           output_tensor_ids=output_tensor_ids,
@@ -351,6 +381,7 @@ class _DumpingCallback(object):
       execution_proto = debug_event_pb2.Execution(
           op_type=op_type,
           num_outputs=len(tensors),
+          graph_id=graph_id,
           input_tensor_ids=input_tensor_ids,
           output_tensor_ids=output_tensor_ids,
           tensor_debug_mode=tensor_debug_mode,
@@ -396,9 +427,45 @@ class _DumpingCallback(object):
       return self._instrument_symbolic_tensors(
           outputs, op_type, op_name, context_id, output_tensor_ids)
     else:
+      context_id = self._func_graph_id_from_func_name(op_type)
       input_ids = [t._id for t in inputs]  # pylint:disable=protected-access
-      writer.WriteExecution(
-          self._dump_eager_tensors(outputs, op_type, input_ids))
+      writer.WriteExecution(self._dump_eager_tensors(
+          outputs, op_type, input_ids, graph_id=context_id))
+
+  def _func_graph_id_from_func_name(self, op_type):
+    """Attempt to get the ID of a FuncGraph based on an op type name.
+
+    Also caches the ID for faster access later.
+
+    Args:
+      op_type: Op type string, which may be the name of a function.
+
+    Returns:
+      If the op_type name does not fit the pattern of a function name (e.g.,
+      one that starts with "__inference_"), `None` is returned immediately.
+      Else, if the FuncGraph is found, ID of the underlying FuncGraph is
+      returned as a string.
+      Else, `None` is returned.
+    """
+    op_type = compat.as_bytes(op_type)
+    if op_type.startswith(self._function_prefixes):
+      # op_type for eagerly-executed FuncGraphs have the prefixed and suffixed
+      # form such as "__inference_my_function_13579", wherein the middle part
+      # "my_function" is the name of the Python function from which the
+      # FuncGraph is compiled. Due to the suffix, the op_type is unique for
+      # - duplicate Python function names
+      # - multiple compilation of the same Python function
+      if op_type in self._op_type_to_context_id:
+        return self._op_type_to_context_id[op_type]
+      with self._context_lock:
+        for function_weakref in self._function_weakref_to_graph_id:
+          if function_weakref().name == op_type:
+            graph_id = self._function_weakref_to_graph_id[function_weakref]
+            self._op_type_to_context_id[op_type] = graph_id
+            return graph_id
+      return None
+    else:
+      return None
 
   def _get_symbolic_tensor_ids(self, num_tensors):
     tensor_ids = []
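
Note: the registry scanned above is keyed by weak references, so holding an
entry does not keep a Function object alive. A standalone illustration of the
lookup (FakeFunction and the graph ID value are hypothetical stand-ins);
dereferencing a weakref yields None once its referent has been
garbage-collected, so a consumer should be prepared to skip dead entries:

    import weakref


    class FakeFunction(object):  # stand-in for _EagerDefinedFunction

      def __init__(self, name):
        self.name = name


    fn = FakeFunction(b"__inference_my_function_13579")
    registry = {weakref.ref(fn): "graph-id-1"}  # illustrative graph ID

    target = b"__inference_my_function_13579"
    for fn_weakref, graph_id in registry.items():
      fn_ref = fn_weakref()  # None if the function was garbage-collected
      if fn_ref is not None and fn_ref.name == target:
        print(graph_id)  # -> graph-id-1
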
@@ -578,6 +645,8 @@ def enable_dump_debug_info(dump_root,
                                                op_regex,
                                                tensor_dtypes)
     op_callbacks.add_op_callback(_state.dumping_callback.callback)
+    function_lib.add_function_callback(
+        _state.dumping_callback.function_callback)
 
   if _state.dumping_callback.dump_root != dump_root:
     _state.dumping_callback.dump_root = dump_root
@@ -605,6 +674,8 @@ def disable_dump_debug_info():
     dump_root = _state.dumping_callback.dump_root
     debug_events_writer.DebugEventsWriter(dump_root).Close()
     op_callbacks.remove_op_callback(_state.dumping_callback.callback)
+    function_lib.remove_function_callback(
+        _state.dumping_callback.function_callback)
     delattr(_state, "dumping_callback")
     logging.info("Disabled dumping callback in thread %s (dump root: %s)",
                  threading.current_thread().name, dump_root)
@@ -129,6 +129,8 @@ class TracingCallbackTest(
       prev_wall_time = debug_event.wall_time
       execution = debug_event.execution
       executed_op_types.append(execution.op_type)
+      # No graph IDs should have been logged for eager op executions.
+      self.assertFalse(execution.graph_id)
       self.assertTrue(execution.input_tensor_ids)
       self.assertTrue(execution.output_tensor_ids)
       if tensor_debug_mode == "NO_TENSOR":
@@ -218,17 +220,30 @@ class TracingCallbackTest(
     # NOTE(b/142486213): Execution of the TF function happens with
     # Session.run() in v1 graph mode, so doesn't get logged to the
     # .execution file.
-    executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+    (executed_op_types, executed_graph_ids,
+     _, _, _, _) = self._readAndCheckExecutionFile()
     executed_op_types = [op_type for op_type in executed_op_types
                          if "sin1p_log_sum" in op_type]
     self.assertLen(executed_op_types, 1)
 
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    (context_ids, op_types,
-     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
+    (context_ids, op_types, op_name_to_op_type,
+     op_name_to_context_id) = self._readAndCheckGraphsFile(stack_frame_by_id)
     self.assertIn("AddV2", op_types)
     self.assertIn("Log", op_types)
     self.assertIn("Sin", op_types)
+    if context.executing_eagerly():
+      # Check the correctness of the ID of the executed graph ID.
+      sin_op_name = [op_name for op_name in op_name_to_op_type
+                     if op_name_to_op_type[op_name] == "Sin"]
+      self.assertLen(sin_op_name, 1)
+      sin_context_id = op_name_to_context_id[sin_op_name[0]]
+      # The executed "op" is a FuncGraph, and its graph ID should have been
+      # recorded properly and be the ID of the graph that the Sin op belongs to.
+      executed_graph_ids = [
+          executed_graph_ids[i] for i, op_type
+          in enumerate(executed_op_types) if "sin1p_log_sum" in op_type]
+      self.assertEqual(executed_graph_ids[0], sin_context_id)
 
     (op_names, _, _,
      tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
@@ -248,6 +263,72 @@ class TracingCallbackTest(
     self.assertAllClose(tensor_values[3],
                         np.sin(np.log(5.0) + 1.0))  # Sin op.
 
+  def testCapturingExecutedGraphIdsOfTwoCompilationsOfSameFunction(self):
+    """Test correct executed IDs of two FuncGraphs from the same Py function."""
+    writer = dumping_callback.enable_dump_debug_info(
+        self.dump_root, tensor_debug_mode="NO_TENSOR")
+
+    @def_function.function
+    def ceil_times_two(x):
+      return math_ops.ceil(x) * 2.0
+
+    x_float32 = np.array(3.5, dtype=np.float32)
+    x_float64 = np.array(4.5, dtype=np.float64)
+    # Four executions, with two different FuncGraphs, which should lead
+    # to two unique executed graph IDs (see assertion below).
+    self.assertAllClose(ceil_times_two(x_float32), 8.0)
+    self.assertAllClose(ceil_times_two(x_float64), 10.0)
+    self.assertAllClose(ceil_times_two(x_float32), 8.0)
+    self.assertAllClose(ceil_times_two(x_float64), 10.0)
+    writer.FlushNonExecutionFiles()
+    writer.FlushExecutionFiles()
+
+    (executed_op_types, executed_graph_ids,
+     _, _, _, _) = self._readAndCheckExecutionFile()
+    self.assertLen(executed_op_types, 4)
+    for executed_op_type in executed_op_types:
+      self.assertStartsWith(executed_op_type, "__inference_ceil_times_two_")
+    self.assertLen(executed_graph_ids, 4)
+    self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
+    self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
+    self.assertLen(set(executed_graph_ids), 2)
+
+  def testCapturingExecutedGraphIdsOfDuplicateFunctionNames(self):
+    """Two FuncGraphs compiled from Python functions with identical names."""
+    writer = dumping_callback.enable_dump_debug_info(
+        self.dump_root, tensor_debug_mode="NO_TENSOR")
+
+    class TestClass(object):
+
+      @def_function.function
+      def ceil_times_two(self, x):
+        return math_ops.ceil(x) * 2.0
+
+    # The `ceil_times_two` method of the two objects will be compiled
+    # into separate FuncGraphs.
+    test_object_1 = TestClass()
+    test_object_2 = TestClass()
+
+    x = np.array(3.5, dtype=np.float32)
+    # Four executions, with two different FuncGraphs, which should lead
+    # to two unique executed graph IDs (see assertion below).
+    self.assertAllClose(test_object_1.ceil_times_two(x), 8.0)
+    self.assertAllClose(test_object_2.ceil_times_two(x), 8.0)
+    self.assertAllClose(test_object_1.ceil_times_two(x), 8.0)
+    self.assertAllClose(test_object_2.ceil_times_two(x), 8.0)
+    writer.FlushNonExecutionFiles()
+    writer.FlushExecutionFiles()
+
+    (executed_op_types, executed_graph_ids,
+     _, _, _, _) = self._readAndCheckExecutionFile()
+    self.assertLen(executed_op_types, 4)
+    for executed_op_type in executed_op_types:
+      self.assertStartsWith(executed_op_type, "__inference_ceil_times_two_")
+    self.assertLen(executed_graph_ids, 4)
+    self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
+    self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
+    self.assertLen(set(executed_graph_ids), 2)
+
   @parameterized.named_parameters(
       ("AddV2", "AddV2"),
       ("Log", "Log"),
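
Note: the first test above relies on tf.function retracing: calling the same
Python function with a new input dtype builds a second FuncGraph. A sketch of
that behavior using only the public API (assuming TF 2.x eager mode; the trace
counts follow from the retracing rules rather than being asserted here):

    import numpy as np
    import tensorflow as tf


    @tf.function
    def ceil_times_two(x):
      return tf.math.ceil(x) * 2.0


    ceil_times_two(np.array(3.5, dtype=np.float32))  # trace 1: float32 FuncGraph
    ceil_times_two(np.array(4.5, dtype=np.float64))  # trace 2: float64 FuncGraph
    ceil_times_two(np.array(3.5, dtype=np.float32))  # reuses the float32 graph
    ceil_times_two(np.array(4.5, dtype=np.float64))  # reuses the float64 graph
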
@@ -438,7 +519,7 @@ class TracingCallbackTest(
     # After the flushing, the .execution file should hold the appropriate
     # contents.
     if context.executing_eagerly():
-      (executed_op_types, input_tensor_ids, output_tensor_ids,
+      (executed_op_types, _, input_tensor_ids, output_tensor_ids,
        tensor_debug_modes, tensor_values) = self._readAndCheckExecutionFile()
       # NOTE(b/142486213): Execution of the TF function happens with
       # Session.run() in v1 graph mode, hence it doesn't get logged to the
@@ -558,7 +639,7 @@ class TracingCallbackTest(
     writer.FlushExecutionFiles()
     stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
     context_ids, _, _, _ = self._readAndCheckGraphsFile(stack_frame_by_id)
-    _, _, _, _, tensor_values = self._readAndCheckExecutionFile()
+    _, _, _, _, _, tensor_values = self._readAndCheckExecutionFile()
     self.assertEqual(tensor_values, [[]])
     (_, _, _,
      tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
@@ -702,7 +783,7 @@ class TracingCallbackTest(
     self.assertAllClose(v1.read_value(), -67084290.0)
     self.assertAllClose(v2.read_value(), -6.0)
 
-    (executed_op_types, _, _, _,
+    (executed_op_types, _, _, _, _,
      tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_1)
     v1_squared_values = [
         tensor_values[i] for i, op_type in enumerate(executed_op_types)
@@ -714,7 +795,7 @@ class TracingCallbackTest(
     self.assertAllClose(
         negative_v1_squared_values, [[-100.0], [-8100.0], [-67076100.0]])
 
-    (executed_op_types, _, _, _,
+    (executed_op_types, _, _, _, _,
      tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_2)
     self.assertNotIn("Neg", executed_op_types)
     v2_squared_values = tensor_values[executed_op_types.index("Pow")]
@@ -800,7 +881,7 @@ class TracingCallbackTest(
     # NOTE(b/142486213): Execution of the TF function happens with
     # Session.run() in v1 graph mode, hence it doesn't get logged to the
     # .execution file.
-    (executed_op_types, _, _, _,
+    (executed_op_types, _, _, _, _,
      tensor_values) = self._readAndCheckExecutionFile()
     self.assertTrue(executed_op_types)
 
@@ -867,7 +948,7 @@ class TracingCallbackTest(
     # NOTE(b/142486213): Execution of the TF function happens with
     # Session.run() in v1 graph mode, hence it doesn't get logged to the
     # .execution file.
-    (executed_op_types, _, _, _,
+    (executed_op_types, _, _, _, _,
      tensor_values) = self._readAndCheckExecutionFile()
     self.assertTrue(executed_op_types)
     if tensor_debug_mode == "NO_TENSOR":
@@ -940,7 +1021,7 @@ class TracingCallbackTest(
     # NOTE(b/142486213): Execution of the TF function happens with
     # Session.run() in v1 graph mode, hence it doesn't get logged to the
     # .execution file.
-    executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
+    executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
     self.assertTrue(executed_op_types)
 
     (op_names, _, _,
@@ -193,6 +193,11 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
 
     Returns:
       executed_op_types: Types of ops that are created, as a `list` of `str`.
+      executed_graph_ids: A `list` of the same length as `executed_op_types`.
+        If the executed op is a FuncGraph, the corresponding element of the
+        `list` will be the ID of the FuncGraph. Else, the corresponding element
+        will be an empty string. This allows establishing connection between
+        eagerly executed FuncGraphs and their prior graph building.
       input_tensor_ids: Input tensor IDs for each of the ops executed, as a
         `list` of `list` of `int`s, with the same length as `executed_op_types`.
       output_tensor_ids: Output tensor IDs for each of the ops executed, as a
@@ -209,6 +214,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
     execution_iter = reader.execution_iterator()
     prev_wall_time = 1
     executed_op_types = []
+    executed_graph_ids = []  # Empty string for execution of individual ops.
     input_tensor_ids = []
     output_tensor_ids = []
     tensor_debug_modes = []
@@ -218,6 +224,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
       prev_wall_time = debug_event.wall_time
       execution = debug_event.execution
       executed_op_types.append(execution.op_type)
+      executed_graph_ids.append(execution.graph_id)
       input_tensor_ids.append(execution.input_tensor_ids)
       output_tensor_ids.append(execution.output_tensor_ids)
       tensor_debug_modes.append(execution.tensor_debug_mode)
@@ -227,8 +234,8 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
     ])
     # TODO(cais): When tensor debug modes other than NO_TENSOR is supported,
     # return tensor_values as well.
-    return (executed_op_types, input_tensor_ids, output_tensor_ids,
-            tensor_debug_modes, tensor_values)
+    return (executed_op_types, executed_graph_ids, input_tensor_ids,
+            output_tensor_ids, tensor_debug_modes, tensor_values)
 
   def _readAndCheckGraphExecutionTracesFile(self, context_ids):
     """Read & verify the content of the .graph_execution_trace debug-event file.
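
Note: per the updated docstring above, executed_graph_ids is index-aligned
with executed_op_types. A hypothetical consumer pairing the two (the values
below are made up for illustration):

    # Illustrative reader output: one entry per executed op.
    executed_op_types = ["AddV2", "__inference_f_123"]
    executed_graph_ids = ["", "e2adab5b"]  # empty string for individual ops

    func_graph_executions = [
        (op_type, graph_id)
        for op_type, graph_id in zip(executed_op_types, executed_graph_ids)
        if graph_id]  # non-empty only for eagerly executed FuncGraphs
    print(func_graph_executions)  # [('__inference_f_123', 'e2adab5b')]
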
@@ -345,19 +345,59 @@ class _InterpolateFunctionError(object):
     return False
 
 
+_function_callbacks = set()
+
+
+def add_function_callback(function_callback):
+  """Add a callback function for Function creation.
+
+  The callback function has the signature:
+
+    `def function_callback(function):`
+
+  wherein `function` is the just-created _EagerDefinedFunction.
+  The callback is invoked immediately after a new `_EagerDefinedFunction`
+  is created. The return value(s) of the callback function (if any) is ignored.
+
+  Repeated registration of the same callback function is idempotent.
+  After a callback is added, it can be removed with the
+  `remove_function_callback()` method.
+
+  Args:
+    function_callback: The callback to add.
+  """
+  _function_callbacks.add(function_callback)
+
+
+def remove_function_callback(function_callback):
+  """Remove an already-added function callback.
+
+  See the doc string of `add_function_callback()` for more information.
+
+  Args:
+    function_callback: The callback to remove.
+  """
+  _function_callbacks.remove(function_callback)
+
+
+_FORWARD_PREFIX = "__forward_"
+_BACKWARD_PREFIX = "__backward_"
+_INFERENCE_PREFIX = "__inference_"
+
+
 def _forward_name(n):
   """The name of a generated forward defun named n."""
-  return "__forward_%s_%s" % (n, ops.uid())
+  return "%s%s_%s" % (_FORWARD_PREFIX, n, ops.uid())
 
 
 def _backward_name(n):
   """The name of a generated backward defun named n."""
-  return "__backward_%s_%s" % (n, ops.uid())
+  return "%s%s_%s" % (_BACKWARD_PREFIX, n, ops.uid())
 
 
 def _inference_name(n):
   """The name of a forward-but-no-gradient defun named n."""
-  return "__inference_%s_%s" % (n, ops.uid())
+  return "%s%s_%s" % (_INFERENCE_PREFIX, n, ops.uid())
 
 
 def _enclosing_xla_context():
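
Note: a sketch of how the new callback registry can be used directly.
tensorflow.python.eager.function is an internal module; the supported entry
point added in this CL is dumping_callback.enable_dump_debug_info, which
registers its own callback, so the logger below is purely illustrative:

    from tensorflow.python.eager import function as function_lib


    def log_function_creation(fn):
      # fn is the just-created _EagerDefinedFunction; name is now a property.
      print("Created function %s from graph %s" % (fn.name, fn.graph))


    function_lib.add_function_callback(log_function_creation)
    # ... any tf.function traced here triggers the callback ...
    function_lib.remove_function_callback(log_function_creation)
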
@@ -463,7 +503,7 @@ class _EagerDefinedFunction(object):
       proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
     function_def = function_pb2.FunctionDef()
     function_def.ParseFromString(compat.as_bytes(proto_data))
-    self.name = compat.as_bytes(function_def.signature.name)
+    self._name = compat.as_bytes(function_def.signature.name)
     with ops.init_scope():
       if context.executing_eagerly():
         context.ensure_initialized()
@@ -485,6 +525,9 @@ class _EagerDefinedFunction(object):
     self.graph = graph
     self._stateful_ops = tuple(op for op in operations if op._is_stateful)  # pylint: disable=protected-access
 
+    for function_callback in _function_callbacks:
+      function_callback(self)
+
   def add_to_graph(self, g=None):
     # pylint: disable=protected-access
     if not g and context.executing_eagerly():
@@ -497,6 +540,10 @@ class _EagerDefinedFunction(object):
       g._add_function(f)
     # pylint: enable=protected-access
 
+  @property
+  def name(self):
+    return self._name
+
   @property
   def stateful_ops(self):
     return self._stateful_ops
@@ -533,6 +580,7 @@ class _EagerDefinedFunction(object):
     executor_type = function_call_options.executor_type or ""
 
     executing_eagerly = ctx.executing_eagerly()
+    attrs = ("executor_type", executor_type, "config_proto", config)
     if executing_eagerly:
       with _InterpolateFunctionError(self):
         if cancellation_manager is None:
@@ -540,14 +588,14 @@ class _EagerDefinedFunction(object):
               str(self.signature.name),
               num_outputs=self._num_outputs,
               inputs=args,
-              attrs=("executor_type", executor_type, "config_proto", config),
+              attrs=attrs,
               ctx=ctx)
         else:
           outputs = execute.execute_with_cancellation(
               str(self.signature.name),
               num_outputs=self._num_outputs,
              inputs=args,
-              attrs=("executor_type", executor_type, "config_proto", config),
+              attrs=attrs,
               ctx=ctx,
               cancellation_manager=cancellation_manager)
     # Replace empty list with None
@@ -447,11 +447,15 @@ def _list_function_deps(fdef, library_function_names):
   return deps
 
 
+_FUNCTION_WARPPER_NAME_REGEX = r"^%s(.*)_\d+$" % (
+    function_lib._INFERENCE_PREFIX)  # pylint:disable=protected-access
+
+
 def _clean_function_name(name):
   """Vanity function to keep the function names comprehensible."""
   # Note: each time a function is wrapped into `function_lib.ConcreteFunction`
   # its name becomes "__inference_<orig>_xyz".
-  match = re.search(r"^__inference_(.*)_\d+$", name)
+  match = re.search(_FUNCTION_WARPPER_NAME_REGEX, name)
   if match:
     return match.group(1)
   else:
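
Note: a quick standalone check of the wrapper-name regex introduced above
(the prefix constant is re-declared here for illustration; the real one lives
in function_lib):

    import re

    _INFERENCE_PREFIX = "__inference_"
    _FUNCTION_WARPPER_NAME_REGEX = r"^%s(.*)_\d+$" % (_INFERENCE_PREFIX,)

    match = re.search(_FUNCTION_WARPPER_NAME_REGEX,
                      "__inference_my_function_13579")
    print(match.group(1))  # -> "my_function"
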