diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a362dee97d9..c33a579ad28 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1044,6 +1044,7 @@ py_test( ":client_testlib", ":constant_op", ":error_interpolation", + ":traceable_stack", ], ) diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py index 72d5dc99a81..a79073b748e 100644 --- a/tensorflow/python/framework/error_interpolation.py +++ b/tensorflow/python/framework/error_interpolation.py @@ -60,6 +60,8 @@ def _parse_message(message): Supported tags after node: file: Replaced with the filename in which the node was defined. line: Replaced by the line number at which the node was defined. + colocations: Replaced by a multi-line message describing the file and + line numbers at which this node was colocated with other nodes. Args: message: String to parse @@ -85,13 +87,53 @@ def _parse_message(message): return seps, tags -def _get_field_dict_from_traceback(tf_traceback, frame_index): - """Convert traceback elements into interpolation dictionary and return.""" - frame = tf_traceback[frame_index] - return { - "file": frame[tf_stack.TB_FILENAME], - "line": frame[tf_stack.TB_LINENO], - } +def _compute_colocation_summary_from_dict(colocation_dict, prefix=""): + """Return a summary of an op's colocation stack. + + Args: + colocation_dict: The op._colocation_dict. + prefix: An optional string prefix used before each line of the multi- + line string returned by this function. + + Returns: + A multi-line string similar to: + Node-device colocations active during op creation: + with tf.colocate_with(test_node_1): + with tf.colocate_with(test_node_2): + The first line will have no padding to its left by default. Subsequent + lines will have two spaces of left-padding. Use the prefix argument + to increase indentation. + """ + if not colocation_dict: + message = "No node-device colocations were active during op creation." + return prefix + message + + str_list = [] + str_list.append("%sNode-device colocations active during op creation:" + % prefix) + + for name, location in colocation_dict.items(): + location_summary = "<{file}:{line}>".format(file=location.filename, + line=location.lineno) + subs = { + "prefix": prefix, + "indent": " ", + "name": name, + "loc": location_summary, + } + str_list.append( + "{prefix}{indent}with tf.colocate_with({name}): {loc}".format(**subs)) + + return "\n".join(str_list) + + +def _compute_colocation_summary_from_op(op, prefix=""): + """Fetch colocation file, line, and nesting and return a summary string.""" + if not op: + return "" + # pylint: disable=protected-access + return _compute_colocation_summary_from_dict(op._colocation_dict, prefix) + # pylint: enable=protected-access def _find_index_of_defining_frame_for_op(op): @@ -125,6 +167,54 @@ def _find_index_of_defining_frame_for_op(op): return 0 +def _get_defining_frame_from_op(op): + """Find and return stack frame where op was defined.""" + frame = None + if op: + # pylint: disable=protected-access + frame_index = _find_index_of_defining_frame_for_op(op) + frame = op._traceback[frame_index] + # pylint: enable=protected-access + return frame + + +def _compute_field_dict(op): + """Return a dictionary mapping interpolation tokens to values. + + Args: + op: op.Operation object having a _traceback member. + + Returns: + A dictionary mapping string tokens to string values. The keys are shown + below along with example values. + { + "file": "tool_utils.py", + "line": "124", + "colocations": + '''Node-device colocations active during op creation: + with tf.colocate_with(test_node_1): + with tf.colocate_with(test_node_2): ''' + } + If op is None or lacks a _traceback field, the returned values will be + "". + """ + default_value = "" + field_dict = { + "file": default_value, + "line": default_value, + "colocations": default_value, + } + frame = _get_defining_frame_from_op(op) + if frame: + field_dict["file"] = frame[tf_stack.TB_FILENAME] + field_dict["line"] = frame[tf_stack.TB_LINENO] + colocation_summary = _compute_colocation_summary_from_op(op) + if colocation_summary: + field_dict["colocations"] = colocation_summary + + return field_dict + + def interpolate(error_message, graph): """Interpolates an error message. @@ -148,19 +238,7 @@ def interpolate(error_message, graph): except KeyError: op = None - if op: - frame_index = _find_index_of_defining_frame_for_op(op) - # pylint: disable=protected-access - field_dict = _get_field_dict_from_traceback(op._traceback, frame_index) - # pylint: enable=protected-access - else: - field_dict = { - "file": "", - "line": "", - "func": "", - "code": None, - } - node_name_to_substitution_dict[name] = field_dict + node_name_to_substitution_dict[name] = _compute_field_dict(op) subs = [ string.Template(tag.format).safe_substitute( diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py index b6615317d12..1e5cb738540 100644 --- a/tensorflow/python/framework/error_interpolation_test.py +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -22,6 +22,8 @@ import os from tensorflow.python.framework import constant_op from tensorflow.python.framework import error_interpolation +from tensorflow.python.framework import ops +from tensorflow.python.framework import traceable_stack from tensorflow.python.platform import test from tensorflow.python.util import tf_stack @@ -55,6 +57,47 @@ def _modify_op_stack_with_filenames(op, num_user_frames, user_filename, op._traceback = stack +def assert_node_in_colocation_summary(test_obj, colocation_summary_string, + name, filename="", lineno=""): + lineno = str(lineno) + name_phrase = "colocate_with(%s)" % name + for term in [name_phrase, filename, lineno]: + test_obj.assertIn(term, colocation_summary_string) + test_obj.assertNotIn("loc:@", colocation_summary_string) + + +class ComputeColocationSummaryFromOpTest(test.TestCase): + + def testCorrectFormatWithActiveColocations(self): + t_obj_1 = traceable_stack.TraceableObject(None, + filename="test_1.py", + lineno=27) + t_obj_2 = traceable_stack.TraceableObject(None, + filename="test_2.py", + lineno=38) + colocation_dict = { + "test_node_1": t_obj_1, + "test_node_2": t_obj_2, + } + summary = error_interpolation._compute_colocation_summary_from_dict( + colocation_dict, prefix=" ") + assert_node_in_colocation_summary(self, + summary, + name="test_node_1", + filename="test_1.py", + lineno=27) + assert_node_in_colocation_summary(self, summary, + name="test_node_2", + filename="test_2.py", + lineno=38) + + def testCorrectFormatWhenNoColocationsWereActive(self): + colocation_dict = {} + summary = error_interpolation._compute_colocation_summary_from_dict( + colocation_dict, prefix=" ") + self.assertIn("No node-device colocations", summary) + + class InterpolateTest(test.TestCase): def setUp(self): @@ -134,5 +177,56 @@ class InterpolateTest(test.TestCase): self.assertRegexpMatches(interpolated_string, expected_regex) +class InterpolateColocationSummaryTest(test.TestCase): + + def setUp(self): + # Add nodes to the graph for retrieval by name later. + node_one = constant_op.constant(1, name="One") + node_two = constant_op.constant(2, name="Two") + + # node_three has one colocation group, obviously. + with ops.colocate_with(node_one): + node_three = constant_op.constant(3, name="Three_with_one") + + # node_four has one colocation group even though three is (transitively) + # colocated with one. + with ops.colocate_with(node_three): + constant_op.constant(4, name="Four_with_three") + + # node_five has two colocation groups because one and two are not colocated. + with ops.colocate_with(node_two): + with ops.colocate_with(node_one): + constant_op.constant(5, name="Five_with_one_with_two") + + self.graph = node_three.graph + + def testNodeThreeHasColocationInterpolation(self): + message = "^^node:Three_with_one:${colocations}^^" + result = error_interpolation.interpolate(message, self.graph) + assert_node_in_colocation_summary(self, result, name="One") + + def testNodeFourHasColocationInterpolationForNodeThreeOnly(self): + message = "^^node:Four_with_three:${colocations}^^" + result = error_interpolation.interpolate(message, self.graph) + assert_node_in_colocation_summary(self, result, name="Three_with_one") + self.assertNotIn( + "One", result, + "Node One should not appear in Four_with_three's summary:\n%s" + % result) + + def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self): + message = "^^node:Five_with_one_with_two:${colocations}^^" + result = error_interpolation.interpolate(message, self.graph) + assert_node_in_colocation_summary(self, result, name="One") + assert_node_in_colocation_summary(self, result, name="Two") + + def testColocationInterpolationForNodeLackingColocation(self): + message = "^^node:One:${colocations}^^" + result = error_interpolation.interpolate(message, self.graph) + self.assertIn("No node-device colocations", result) + self.assertNotIn("One", result) + self.assertNotIn("Two", result) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index ea7a9986fe8..b813cd6c068 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -47,10 +47,10 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import registry -from tensorflow.python.util import tf_stack from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import traceable_stack from tensorflow.python.framework import versions +from tensorflow.python.util import tf_stack from tensorflow.python.ops import control_flow_util from tensorflow.python.platform import app from tensorflow.python.platform import tf_logging as logging @@ -1712,10 +1712,14 @@ class Operation(object): # This will be set by self.inputs. self._inputs_val = None - self._id_value = self._graph._next_id() # pylint: disable=protected-access + # pylint: disable=protected-access + self._id_value = self._graph._next_id() self._original_op = original_op self._traceback = tf_stack.extract_stack() - self._control_flow_context = self.graph._get_control_flow_context() # pylint: disable=protected-access + # List of traceable_stack.TraceableObjects for colocation context managers. + self._colocation_code_locations = None + self._control_flow_context = self.graph._get_control_flow_context() + # pylint: enable=protected-access # Initialize self._c_op. if c_op: @@ -1853,6 +1857,42 @@ class Operation(object): """ return c_api.TF_OperationDevice(self._c_op) + @property + def _colocation_dict(self): + """Code locations for colocation context managers active at op creation. + + This property will return a dictionary for which the keys are nodes with + which this Operation is colocated, and for which the values are + traceable_stack.TraceableObject instances. The TraceableObject instances + record the location of the relevant colocation context manager but have the + "obj" field set to None to prevent leaking private data. + + For example, suppose file_a contained these lines: + + file_a.py: + 14: node_a = tf.constant(3, name='NODE_A') + 15: with tf.colocate_with(node_a): + 16: node_b = tf.constant(4, name='NODE_B') + + Then a TraceableObject t_obj representing the colocation context manager + would have these member values: + + t_obj.obj -> None + t_obj.name = 'NODE_A' + t_obj.filename = 'file_a.py' + t_obj.lineno = 15 + + and node_b.op._colocation_code_locations would return the dictionary + + { 'NODE_A': t_obj } + + Returns: + {str: traceable_stack.TraceableObject} as per this method's description, + above. + """ + locations_dict = self._colocation_code_locations or {} + return locations_dict.copy() + @property def _output_types(self): """List this operation's output types. @@ -3249,6 +3289,7 @@ class Graph(object): # pylint: disable=protected-access op._set_attr("_class", attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups))) + op._colocation_code_locations = self._snapshot_colocation_stack_metadata() # pylint: enable=protected-access # Sets "container" attribute if @@ -4010,7 +4051,10 @@ class Graph(object): self._colocation_stack = traceable_stack.TraceableStack() if op is not None: - self._colocation_stack.push_obj(op, name=op.name, offset=1) + # offset refers to the stack frame used for storing code location. + # We use 4, the sum of 1 to use our caller's stack frame and 3 + # to jump over layers of context managers above us. + self._colocation_stack.push_obj(op, offset=4) try: yield @@ -4658,6 +4702,11 @@ class Graph(object): else: return self._graph_colocation_stack + def _snapshot_colocation_stack_metadata(self): + """Return colocation stack metadata as a dictionary.""" + traceable_objects = self._colocation_stack.peek_traceable_objs() + return {obj.obj.name: obj.copy_metadata() for obj in traceable_objects} + @_colocation_stack.setter def _colocation_stack(self, colocation_stack): if self._stack_state_is_thread_local: diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 150100d771b..f848b69782e 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -2554,6 +2554,14 @@ class ColocationGroupTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): c.op.get_attr("_class") + # Roughly test that stack information is being saved correctly for the op. + locations_dict = b.op._colocation_dict + self.assertIn("a", locations_dict) + metadata = locations_dict["a"] + self.assertIsNone(metadata.obj) + basename = metadata.filename.split("/")[-1] + self.assertEqual("ops_test.py", basename) + def testColocationDeviceInteraction(self): with ops.device("/cpu:0"): with ops.device("/device:GPU:0"): diff --git a/tensorflow/python/framework/traceable_stack.py b/tensorflow/python/framework/traceable_stack.py index 1b7c6bd7c56..7f4d28237ff 100644 --- a/tensorflow/python/framework/traceable_stack.py +++ b/tensorflow/python/framework/traceable_stack.py @@ -27,9 +27,8 @@ class TraceableObject(object): # Return codes for the set_filename_and_line_from_caller() method. SUCCESS, HEURISTIC_USED, FAILURE = (0, 1, 2) - def __init__(self, obj, name=None, filename=None, lineno=None): + def __init__(self, obj, filename=None, lineno=None): self.obj = obj - self.name = name self.filename = filename self.lineno = lineno @@ -72,8 +71,7 @@ class TraceableObject(object): def copy_metadata(self): """Return a TraceableObject like this one, but without the object.""" - return self.__class__(None, name=self.name, filename=self.filename, - lineno=self.lineno) + return self.__class__(None, filename=self.filename, lineno=self.lineno) class TraceableStack(object): @@ -88,12 +86,11 @@ class TraceableStack(object): """ self._stack = existing_stack[:] if existing_stack else [] - def push_obj(self, obj, name=None, offset=0): + def push_obj(self, obj, offset=0): """Add object to the stack and record its filename and line information. Args: obj: An object to store on the stack. - name: A name for the object, used for dict keys in get_item_metadata_dict. offset: Integer. If 0, the caller's stack frame is used. If 1, the caller's caller's stack frame is used. @@ -102,7 +99,7 @@ class TraceableStack(object): TraceableObject.HEURISTIC_USED if the stack was smaller than expected, and TraceableObject.FAILURE if the stack was empty. """ - traceable_obj = TraceableObject(obj, name=name) + traceable_obj = TraceableObject(obj) self._stack.append(traceable_obj) # Offset is defined in "Args" as relative to the caller. We are 1 frame # beyond the caller and need to compensate.