Track colocation context manager locations (file:line) and add colocation information support to error interpolation.

This CL adds a new private property on ops: Operation._colocation_dict.  This property will return a dictionary for which the keys are nodes with which this Operation is colocated, and for which the values are traceable_stack.TraceableObject instances.  The TraceableObject instances record the location of the relevant colocation context manager but have the "obj" field set to None to prevent leaking private data.

For example, suppose file_a contained these lines:

  file_a.py:
    14: node_a = tf.constant(3, name='NODE_A')
    15: with tf.colocate_with(node_a):
    16:   node_b = tf.constant(4, name='NODE_B')

Then a TraceableObject t_obj representing the colocation context manager would have these member values:

  t_obj.obj -> None
  t_obj.name = 'NODE_A'
  t_obj.filename = 'file_a.py'
  t_obj.lineno = 15

and node_b.op._colocation_dict would return the dictionary

  { 'NODE_A': t_obj }

PiperOrigin-RevId: 205035378
This commit is contained in:
A. Unique TensorFlower 2018-07-18 00:18:21 -07:00 committed by TensorFlower Gardener
parent d9d029f510
commit 5d6aec5318
6 changed files with 258 additions and 31 deletions

View File

@ -1044,6 +1044,7 @@ py_test(
":client_testlib",
":constant_op",
":error_interpolation",
":traceable_stack",
],
)

View File

@ -60,6 +60,8 @@ def _parse_message(message):
Supported tags after node:<node_name>
file: Replaced with the filename in which the node was defined.
line: Replaced by the line number at which the node was defined.
colocations: Replaced by a multi-line message describing the file and
line numbers at which this node was colocated with other nodes.
Args:
message: String to parse
@ -85,13 +87,53 @@ def _parse_message(message):
return seps, tags
def _get_field_dict_from_traceback(tf_traceback, frame_index):
  """Build the interpolation dictionary for a single traceback frame.

  Args:
    tf_traceback: Indexable collection of stack frames; each frame is indexed
      with the tf_stack.TB_FILENAME and tf_stack.TB_LINENO constants.
    frame_index: Index into tf_traceback of the frame to describe.

  Returns:
    A dictionary with a "file" entry and a "line" entry, both taken from the
    selected frame.
  """
  selected_frame = tf_traceback[frame_index]
  filename = selected_frame[tf_stack.TB_FILENAME]
  lineno = selected_frame[tf_stack.TB_LINENO]
  return {"file": filename, "line": lineno}
def _compute_colocation_summary_from_dict(colocation_dict, prefix=""):
  """Return a summary of an op's colocation stack.

  Args:
    colocation_dict: The op._colocation_dict: a mapping from the name of a
      node the op was colocated with to an object whose `filename` and
      `lineno` attributes locate the colocation context manager.  May be
      empty or None.
    prefix: An optional string prefix used before each line of the multi-
      line string returned by this function.

  Returns:
    A multi-line string similar to:
        Node-device colocations active during op creation:
          with tf.colocate_with(test_node_1): <test_1.py:27>
          with tf.colocate_with(test_node_2): <test_2.py:38>
    The first line will have no padding to its left by default.  Subsequent
    lines will have two spaces of left-padding.  Use the prefix argument
    to increase indentation.  When colocation_dict is empty or None, a single
    line stating that no colocations were active is returned instead.
  """
  if not colocation_dict:
    return (prefix +
            "No node-device colocations were active during op creation.")

  # Entries are nested two spaces under the header line, as documented above.
  entry_indent = "  "
  lines = ["%sNode-device colocations active during op creation:" % prefix]
  for name, location in colocation_dict.items():
    location_summary = "<{file}:{line}>".format(file=location.filename,
                                                line=location.lineno)
    lines.append(
        "{prefix}{indent}with tf.colocate_with({name}): {loc}".format(
            prefix=prefix, indent=entry_indent, name=name,
            loc=location_summary))
  return "\n".join(lines)
def _compute_colocation_summary_from_op(op, prefix=""):
  """Summarize the colocation context managers recorded on an op.

  Args:
    op: The op whose private _colocation_dict is summarized.  A falsy value
      (such as None) is accepted.
    prefix: Optional string forwarded to
      _compute_colocation_summary_from_dict.

  Returns:
    An empty string when op is falsy; otherwise the string produced by
    _compute_colocation_summary_from_dict for op._colocation_dict.
  """
  if op:
    colocation_dict = op._colocation_dict  # pylint: disable=protected-access
    return _compute_colocation_summary_from_dict(colocation_dict, prefix)
  return ""
def _find_index_of_defining_frame_for_op(op):
@ -125,6 +167,54 @@ def _find_index_of_defining_frame_for_op(op):
return 0
def _get_defining_frame_from_op(op):
  """Find and return stack frame where op was defined.

  Args:
    op: The op to inspect, or None.

  Returns:
    The entry of op._traceback selected by
    _find_index_of_defining_frame_for_op, or None when op is falsy or has no
    usable _traceback (attribute missing, None, or empty).
  """
  if not op:
    return None
  # The caller substitutes placeholder values when no frame is returned, so a
  # missing or empty traceback is reported as None instead of raising
  # AttributeError / IndexError.
  if not getattr(op, "_traceback", None):
    return None
  frame_index = _find_index_of_defining_frame_for_op(op)
  return op._traceback[frame_index]  # pylint: disable=protected-access
def _compute_field_dict(op):
  """Return a dictionary mapping interpolation tokens to values.

  Args:
    op: op.Operation object having a _traceback member, or None.

  Returns:
    A dictionary with the keys "file", "line" and "colocations".  The keys
    are shown below along with example values.
    {
      "file": "tool_utils.py",
      "line": "124",
      "colocations":
          '''Node-device colocations active during op creation:
               with tf.colocate_with(test_node_1): <test_1.py:27>
               with tf.colocate_with(test_node_2): <test_2.py:38>'''
    }
    Every value starts out as "<NA>".  "file" and "line" are replaced only
    when a defining frame is found for op, and "colocations" only when a
    non-empty colocation summary is produced for op.
  """
  placeholder = "<NA>"
  field_dict = dict.fromkeys(("file", "line", "colocations"), placeholder)

  defining_frame = _get_defining_frame_from_op(op)
  if defining_frame:
    field_dict["file"] = defining_frame[tf_stack.TB_FILENAME]
    field_dict["line"] = defining_frame[tf_stack.TB_LINENO]

  summary = _compute_colocation_summary_from_op(op)
  if summary:
    field_dict["colocations"] = summary
  return field_dict
def interpolate(error_message, graph):
"""Interpolates an error message.
@ -148,19 +238,7 @@ def interpolate(error_message, graph):
except KeyError:
op = None
if op:
frame_index = _find_index_of_defining_frame_for_op(op)
# pylint: disable=protected-access
field_dict = _get_field_dict_from_traceback(op._traceback, frame_index)
# pylint: enable=protected-access
else:
field_dict = {
"file": "<NA>",
"line": "<NA>",
"func": "<NA>",
"code": None,
}
node_name_to_substitution_dict[name] = field_dict
node_name_to_substitution_dict[name] = _compute_field_dict(op)
subs = [
string.Template(tag.format).safe_substitute(

View File

@ -22,6 +22,8 @@ import os
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import ops
from tensorflow.python.framework import traceable_stack
from tensorflow.python.platform import test
from tensorflow.python.util import tf_stack
@ -55,6 +57,47 @@ def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
op._traceback = stack
def assert_node_in_colocation_summary(test_obj, colocation_summary_string,
                                      name, filename="", lineno=""):
  """Assert that a colocation summary mentions a node and its location.

  Args:
    test_obj: Object providing assertIn and assertNotIn (a test case).
    colocation_summary_string: The summary text to inspect.
    name: Node name expected inside a "colocate_with(...)" phrase.
    filename: Optional file name expected in the summary.
    lineno: Optional line number expected in the summary; converted to str.
  """
  lineno_text = str(lineno)
  expected_terms = ("colocate_with(%s)" % name, filename, lineno_text)
  for expected in expected_terms:
    test_obj.assertIn(expected, colocation_summary_string)
  # The raw "loc:@" colocation-group syntax must not appear in the summary.
  test_obj.assertNotIn("loc:@", colocation_summary_string)
class ComputeColocationSummaryFromOpTest(test.TestCase):
  """Tests for the formatting of colocation summary strings."""

  def testCorrectFormatWithActiveColocations(self):
    """Each colocated node is listed with its file name and line number."""
    expected_locations = (
        ("test_node_1", "test_1.py", 27),
        ("test_node_2", "test_2.py", 38),
    )
    colocation_dict = {}
    for node_name, filename, lineno in expected_locations:
      colocation_dict[node_name] = traceable_stack.TraceableObject(
          None, filename=filename, lineno=lineno)

    summary = error_interpolation._compute_colocation_summary_from_dict(
        colocation_dict, prefix=" ")

    for node_name, filename, lineno in expected_locations:
      assert_node_in_colocation_summary(
          self, summary, name=node_name, filename=filename, lineno=lineno)

  def testCorrectFormatWhenNoColocationsWereActive(self):
    """An empty colocation dict yields the "no colocations" message."""
    summary = error_interpolation._compute_colocation_summary_from_dict(
        {}, prefix=" ")
    self.assertIn("No node-device colocations", summary)
class InterpolateTest(test.TestCase):
def setUp(self):
@ -134,5 +177,56 @@ class InterpolateTest(test.TestCase):
self.assertRegexpMatches(interpolated_string, expected_regex)
class InterpolateColocationSummaryTest(test.TestCase):
  """Tests interpolation of the ${colocations} tag in error messages."""

  def setUp(self):
    """Create constants with various colocation nestings in one graph."""
    # These nodes are looked up by name when messages are interpolated.
    one = constant_op.constant(1, name="One")
    two = constant_op.constant(2, name="Two")

    # Created directly under a single colocation with "One".
    with ops.colocate_with(one):
      three = constant_op.constant(3, name="Three_with_one")

    # Still a single colocation, even though "Three_with_one" is itself
    # (transitively) colocated with "One".
    with ops.colocate_with(three):
      constant_op.constant(4, name="Four_with_three")

    # Two colocations, because "One" and "Two" are not colocated.
    with ops.colocate_with(two):
      with ops.colocate_with(one):
        constant_op.constant(5, name="Five_with_one_with_two")

    self.graph = three.graph

  def testNodeThreeHasColocationInterpolation(self):
    """The summary for Three_with_one mentions node One."""
    interpolated = error_interpolation.interpolate(
        "^^node:Three_with_one:${colocations}^^", self.graph)
    assert_node_in_colocation_summary(self, interpolated, name="One")

  def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
    """The summary for Four_with_three mentions Three_with_one but not One."""
    interpolated = error_interpolation.interpolate(
        "^^node:Four_with_three:${colocations}^^", self.graph)
    assert_node_in_colocation_summary(
        self, interpolated, name="Three_with_one")
    failure_message = (
        "Node One should not appear in Four_with_three's summary:\n%s"
        % interpolated)
    self.assertNotIn("One", interpolated, failure_message)

  def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
    """The summary for Five_with_one_with_two mentions both One and Two."""
    interpolated = error_interpolation.interpolate(
        "^^node:Five_with_one_with_two:${colocations}^^", self.graph)
    for expected_name in ("One", "Two"):
      assert_node_in_colocation_summary(
          self, interpolated, name=expected_name)

  def testColocationInterpolationForNodeLackingColocation(self):
    """A node created outside any colocation reports no colocations."""
    interpolated = error_interpolation.interpolate(
        "^^node:One:${colocations}^^", self.graph)
    self.assertIn("No node-device colocations", interpolated)
    for unexpected_name in ("One", "Two"):
      self.assertNotIn(unexpected_name, interpolated)
if __name__ == "__main__":
test.main()

View File

@ -47,10 +47,10 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import registry
from tensorflow.python.util import tf_stack
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import traceable_stack
from tensorflow.python.framework import versions
from tensorflow.python.util import tf_stack
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
@ -1712,10 +1712,14 @@ class Operation(object):
# This will be set by self.inputs.
self._inputs_val = None
self._id_value = self._graph._next_id() # pylint: disable=protected-access
# pylint: disable=protected-access
self._id_value = self._graph._next_id()
self._original_op = original_op
self._traceback = tf_stack.extract_stack()
self._control_flow_context = self.graph._get_control_flow_context() # pylint: disable=protected-access
# List of traceable_stack.TraceableObjects for colocation context managers.
self._colocation_code_locations = None
self._control_flow_context = self.graph._get_control_flow_context()
# pylint: enable=protected-access
# Initialize self._c_op.
if c_op:
@ -1853,6 +1857,42 @@ class Operation(object):
"""
return c_api.TF_OperationDevice(self._c_op)
@property
def _colocation_dict(self):
"""Code locations for colocation context managers active at op creation.
This property will return a dictionary for which the keys are the names of
nodes with which this Operation is colocated, and for which the values are
traceable_stack.TraceableObject instances. The TraceableObject instances
record the location of the relevant colocation context manager but have the
"obj" field set to None to prevent leaking private data.
For example, suppose file_a contained these lines:
file_a.py:
14: node_a = tf.constant(3, name='NODE_A')
15: with tf.colocate_with(node_a):
16: node_b = tf.constant(4, name='NODE_B')
Then a TraceableObject t_obj representing the colocation context manager
would have these member values:
t_obj.obj -> None
t_obj.name = 'NODE_A'
t_obj.filename = 'file_a.py'
t_obj.lineno = 15
and node_b.op._colocation_dict would return the dictionary
{ 'NODE_A': t_obj }
Returns:
{str: traceable_stack.TraceableObject} as per this method's description,
above.
"""
# Fall back to an empty dict when no locations were recorded, and hand back
# a copy so callers cannot mutate the mapping stored on this op.
locations_dict = self._colocation_code_locations or {}
return locations_dict.copy()
@property
def _output_types(self):
"""List this operation's output types.
@ -3249,6 +3289,7 @@ class Graph(object):
# pylint: disable=protected-access
op._set_attr("_class", attr_value_pb2.AttrValue(
list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
# pylint: enable=protected-access
# Sets "container" attribute if
@ -4010,7 +4051,10 @@ class Graph(object):
self._colocation_stack = traceable_stack.TraceableStack()
if op is not None:
self._colocation_stack.push_obj(op, name=op.name, offset=1)
# offset refers to the stack frame used for storing code location.
# We use 4, the sum of 1 to use our caller's stack frame and 3
# to jump over layers of context managers above us.
self._colocation_stack.push_obj(op, offset=4)
try:
yield
@ -4658,6 +4702,11 @@ class Graph(object):
else:
return self._graph_colocation_stack
def _snapshot_colocation_stack_metadata(self):
  """Return colocation stack metadata as a dictionary.

  Returns:
    A dictionary with one entry per traceable object currently reported by
    self._colocation_stack.peek_traceable_objs(), keyed by the name of the
    object it wraps and holding the result of its copy_metadata() call.
  """
  snapshot = {}
  for traceable_obj in self._colocation_stack.peek_traceable_objs():
    snapshot[traceable_obj.obj.name] = traceable_obj.copy_metadata()
  return snapshot
@_colocation_stack.setter
def _colocation_stack(self, colocation_stack):
if self._stack_state_is_thread_local:

View File

@ -2554,6 +2554,14 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
c.op.get_attr("_class")
# Roughly test that stack information is being saved correctly for the op.
locations_dict = b.op._colocation_dict
self.assertIn("a", locations_dict)
metadata = locations_dict["a"]
self.assertIsNone(metadata.obj)
basename = metadata.filename.split("/")[-1]
self.assertEqual("ops_test.py", basename)
def testColocationDeviceInteraction(self):
with ops.device("/cpu:0"):
with ops.device("/device:GPU:0"):

View File

@ -27,9 +27,8 @@ class TraceableObject(object):
# Return codes for the set_filename_and_line_from_caller() method.
SUCCESS, HEURISTIC_USED, FAILURE = (0, 1, 2)
def __init__(self, obj, name=None, filename=None, lineno=None):
def __init__(self, obj, filename=None, lineno=None):
self.obj = obj
self.name = name
self.filename = filename
self.lineno = lineno
@ -72,8 +71,7 @@ class TraceableObject(object):
def copy_metadata(self):
"""Return a TraceableObject like this one, but without the object."""
return self.__class__(None, name=self.name, filename=self.filename,
lineno=self.lineno)
return self.__class__(None, filename=self.filename, lineno=self.lineno)
class TraceableStack(object):
@ -88,12 +86,11 @@ class TraceableStack(object):
"""
self._stack = existing_stack[:] if existing_stack else []
def push_obj(self, obj, name=None, offset=0):
def push_obj(self, obj, offset=0):
"""Add object to the stack and record its filename and line information.
Args:
obj: An object to store on the stack.
name: A name for the object, used for dict keys in get_item_metadata_dict.
offset: Integer. If 0, the caller's stack frame is used. If 1,
the caller's caller's stack frame is used.
@ -102,7 +99,7 @@ class TraceableStack(object):
TraceableObject.HEURISTIC_USED if the stack was smaller than expected,
and TraceableObject.FAILURE if the stack was empty.
"""
traceable_obj = TraceableObject(obj, name=name)
traceable_obj = TraceableObject(obj)
self._stack.append(traceable_obj)
# Offset is defined in "Args" as relative to the caller. We are 1 frame
# beyond the caller and need to compensate.