Fix a bug in parallel_walk, along with a few structural bugs that the fix revealed:

1. The conversion process packaged its final output inconsistently, sometimes as a module and sometimes as a list. This CL uniformly uses a list of nodes as the output of all *_to_graph functions. As a side effect, converter_testing.py now asserts that the output is a single node and extracts it, so individual tests no longer need to unpack it. The compiler is also modified to skip generating a source map by default.
2. The class converter incorrectly saved the superclass value as the string 'object' instead of the symbol `object`.

Additional refactoring caught along the way: simplify the source mapping code, move it to origin_info.py, and add tests and additional checks; slightly simplify the error rewriting mechanism.

PiperOrigin-RevId: 206087110
Parent: 59305b118a
Commit: b24037513f
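For context, the first fix changes the calling convention of the conversion entry points. A minimal sketch of the new shape, based on the updated conversion_test.py further below (`f` and `program_ctx` stand for any convertible function and its ProgramContext):

    # entity_to_graph now returns a list of output nodes rather than a module
    # or a bare node; single-function conversions yield a one-element list.
    nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
    fn_node, = nodes  # converter_testing unpacks the single node for tests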
@@ -35,7 +35,7 @@ class AssertsTest(converter_testing.TestCase):
     node, ctx = self.prepare(test_fn, {})
     node = asserts.transform(node, ctx)

-    self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call))
+    self.assertTrue(isinstance(node.body[0].value, gast.Call))


 if __name__ == '__main__':
@@ -38,7 +38,7 @@ class DirectivesTest(converter_testing.TestCase):
     node, ctx = self.prepare(test_fn, {'directives': directives})
     node = directives_converter.transform(node, ctx)

-    def_, = anno.getanno(node.body[0].body[0].targets[0],
+    def_, = anno.getanno(node.body[0].targets[0],
                          anno.Static.DEFINITIONS)
     d = def_.directives[directives.set_element_type]
     self.assertEqual(d['dtype'].s, 'a')
@@ -52,7 +52,7 @@ class DirectivesTest(converter_testing.TestCase):
     node, ctx = self.prepare(test_fn, {'directives': directives})
     node = directives_converter.transform(node, ctx)

-    def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS)
+    def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
     d = def_.directives[directives.set_element_type]
     self.assertEqual(d['dtype'].n, 1)
     self.assertEqual(d['shape'].n, 2)
@@ -67,7 +67,7 @@ class DirectivesTest(converter_testing.TestCase):
     node, ctx = self.prepare(test_fn, {'directives': directives})
     node = directives_converter.transform(node, ctx)

-    d = anno.getanno(node.body[0].body[1], AgAnno.DIRECTIVES)
+    d = anno.getanno(node.body[1], AgAnno.DIRECTIVES)
     d = d[directives.set_loop_options]
     self.assertEqual(d['parallel_iterations'].n, 10)
     self.assertEqual(d['back_prop'].id, 'a')
@@ -34,11 +34,13 @@ class ErrorHandlersTest(converter_testing.TestCase):
       raise ValueError()

     node, ctx = self.prepare(test_fn, {})
-    anno.setanno(node.body[0], anno.Basic.ORIGIN,
-                 origin_info.OriginInfo('test_path', None, None, None, None))
+    anno.setanno(node, anno.Basic.ORIGIN,
+                 origin_info.OriginInfo(None, None, None))
     node = error_handlers.transform(node, ctx)
     with self.compiled(node, {}) as result:
       with self.assertRaises(errors.GraphConstructionError):
+        # Here we just assert that the handler works. Its correctness is
+        # verified by errors_test.py.
         result.test_fn()

   def test_no_origin_annotation(self):
@@ -79,7 +79,7 @@ class ListTest(converter_testing.TestCase):

     ns = {'special_functions': special_functions}
     node, ctx = self.prepare(test_fn, ns)
-    def_, = anno.getanno(node.body[0].body[0].targets[0],
+    def_, = anno.getanno(node.body[0].targets[0],
                          anno.Static.ORIG_DEFINITIONS)
     def_.directives[directives.set_element_type] = {
         'dtype': parser.parse_expression('tf.int32'),
@@ -114,7 +114,7 @@ class ListTest(converter_testing.TestCase):
       return tf.stack(l)

     node, ctx = self.prepare(test_fn, {})
-    def_, = anno.getanno(node.body[0].body[0].targets[0],
+    def_, = anno.getanno(node.body[0].targets[0],
                          anno.Static.ORIG_DEFINITIONS)
     def_.directives[directives.set_element_type] = {
         'dtype': parser.parse_expression('tf.int32')
@@ -43,7 +43,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
     node, ctx = self.prepare(test_fn, {})
     node = side_effect_guards.transform(node, ctx)

-    self.assertEqual(len(node.body[0].body), 1)
+    self.assertEqual(len(node.body), 1)

     with self.compiled(node, {}, state_ops.assign) as result:
       with self.test_session() as sess:
@@ -64,7 +64,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
     node, ctx = self.prepare(test_fn, {})
     node = side_effect_guards.transform(node, ctx)

-    self.assertEqual(len(node.body[0].body), 1)
+    self.assertEqual(len(node.body), 1)

     with self.compiled(node, {}, state_ops.assign) as result:
       with self.test_session() as sess:
@@ -84,7 +84,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
     node, ctx = self.prepare(test_fn, {})
     node = side_effect_guards.transform(node, ctx)

-    self.assertEqual(len(node.body[0].body), 1)
+    self.assertEqual(len(node.body), 1)

     with self.compiled(node, {}, control_flow_ops.Assert) as result:
       with self.test_session() as sess:
@@ -104,7 +104,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
     node, ctx = self.prepare(test_fn, {})
     node = side_effect_guards.transform(node, ctx)

-    self.assertEqual(len(node.body[0].body), 1)
+    self.assertEqual(len(node.body), 1)

     with self.compiled(node, {}, state_ops.assign_add) as result:
       with self.test_session() as sess:
@@ -125,7 +125,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
     node, ctx = self.prepare(test_fn, {})
     node = side_effect_guards.transform(node, ctx)

-    self.assertEqual(len(node.body[0].body[0].body), 1)
+    self.assertEqual(len(node.body[0].body), 1)

     with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
       with self.test_session() as sess:
@@ -147,7 +147,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
     node, ctx = self.prepare(test_fn, {})
     node = side_effect_guards.transform(node, ctx)

-    self.assertEqual(len(node.body[0].body), 1)
+    self.assertEqual(len(node.body), 1)

     with self.compiled(node, {}, state_ops.assign,
                        state_ops.assign_add) as result:
@@ -38,7 +38,7 @@ class SliceTest(converter_testing.TestCase):
       return l[1]

     node, ctx = self.prepare(test_fn, {})
-    def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS)
+    def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
     def_.directives[directives.set_element_type] = {
         'dtype': parser.parse_expression('tf.int32')
     }
@@ -59,11 +59,11 @@ class SliceTest(converter_testing.TestCase):
       return l[1]

     node, ctx = self.prepare(test_fn, {})
-    def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS)
+    def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
     def_.directives[directives.set_element_type] = {
         'dtype': parser.parse_expression('tf.int32')
     }
-    def_, = anno.getanno(node.body[0].body[0].body[0].targets[0],
+    def_, = anno.getanno(node.body[0].body[0].targets[0],
                          anno.Static.DEFINITIONS)
     def_.directives[directives.set_element_type] = {
         'dtype': parser.parse_expression('tf.float32')
@@ -94,7 +94,8 @@ class TestCase(test.TestCase):
       return 7

     try:
-      result, source = compiler.ast_to_object(node)
+      result, source = compiler.ast_to_object(node, include_source_map=True)
+
       result.tf = self.make_fake_mod('fake_tf', *symbols)
       fake_ag = self.make_fake_mod('fake_ag', converted_call)
       fake_ag.__dict__.update(operators.__dict__)
@@ -144,6 +145,7 @@ class TestCase(test.TestCase):
               recursive=True,
               autograph_decorators=()):
     node, source = parser.parse_entity(test_fn)
+    node = node.body[0]
     if namer is None:
       namer = FakeNamer()
     program_ctx = converter.ProgramContext(
@@ -31,11 +31,14 @@ import logging
 import sys
 import traceback

-from tensorflow.contrib.autograph.pyct.origin_info import CodeLocation
+from tensorflow.contrib.autograph.pyct import origin_info
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.util import tf_inspect


+# TODO(mdan): Add a superclass common to all errors.
+
+
 class GraphConstructionError(Exception):
   """Error for graph construction errors from AutoGraph generated code."""

@@ -65,27 +68,35 @@ class TfRuntimeError(Exception):
     return message + ''.join(traceback.format_list(self.custom_traceback))


-def _rewrite_frame(source_map, cleaned_traceback, stack_frame_indices):
-  """Rewrites the stack frames at the given indices using the given source map.
+def _rewrite_tb(source_map, tb, filter_function_name=None):
+  """Rewrites code references in a traceback.

   Args:
-    source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
-        AG generated code.
-    cleaned_traceback: List[Tuple[text, text, text, text]], the current
-        traceback.
-    stack_frame_indices: Iterable[Int], frame indices to possibly rewrite if
-        there are matching source mapping keys.
-
+    source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping
+        locations to their origin
+    tb: List[Tuple[Text, Text, Text, Text]], consistent with
+        traceback.extract_tb
+    filter_function_name: Optional[Text], allows restricting the frames to
+        rewrite to a particular function name
   Returns:
-    None
+    List[Tuple[Text, Text, Text, Text]], the rewritten traceback
   """
-  for frame_index in stack_frame_indices:
-    # (file_path, line number, function name, code)
-    file_path, line_number, _, _ = cleaned_traceback[frame_index]
-    source_map_key = CodeLocation(file_path=file_path, line_number=line_number)
-    found_mapping = source_map_key in source_map
-    if found_mapping:
-      cleaned_traceback[frame_index] = source_map[source_map_key].as_frame()
+  new_tb = []
+  for frame in tb:
+    filename, lineno, function_name, _ = frame
+    loc = origin_info.LineLocation(filename, lineno)
+    origin = source_map.get(loc)
+    # TODO(mdan): We shouldn't need the function name at all.
+    # filename + lineno should be sufficient, even if there are multiple source
+    # maps.
+    if origin is not None:
+      if filter_function_name == function_name or filter_function_name is None:
+        new_tb.append(origin.as_frame())
+      else:
+        new_tb.append(frame)
+    else:
+      new_tb.append(frame)
+  return new_tb


 # TODO(znado): Make more robust to name changes in the rewriting logic.
@@ -98,18 +109,20 @@ def _remove_rewrite_frames(tb):
   return cleaned_tb


+# TODO(mdan): rename to raise_*
 def rewrite_graph_construction_error(source_map):
   """Rewrites errors raised by non-AG APIs inside AG generated code.

-  Meant to be called from the try/except block inside each AutoGraph generated
-  function. Only rewrites the traceback frames corresponding to the function
-  that this is called from. When we raise a GraphConstructionError at the end
-  it is then caught by calling functions, where they can be responsible for
-  rewriting their own frames.
+  This is called from the except handler inside an AutoGraph generated function
+  (that is, during exception handling). Only rewrites the frames corresponding
+  to the function that this is called from, so each function is responsible
+  for calling this to have its own frames rewritten.
+
+  This function always raises an error.

   Args:
-    source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
-        AG generated code.
+    source_map: Dict[origin_info.Location, origin_info.OriginInfo], the source
+        map belonging to the calling function

   Raises:
     GraphConstructionError: The rewritten underlying error.
@@ -120,31 +133,19 @@ def rewrite_graph_construction_error(source_map):
   assert original_error is not None
   try:
     _, _, _, func_name, _, _ = tf_inspect.stack()[1]
-    # The latest function call is added to the beginning of a traceback, but
-    # when rewriting the traceback of multiple function calls the previous
-    # functions' except blocks may have already rewritten their own frames so
-    # we want to copy over all of the previous frames. We may have rewritten
-    # previous frames only if the error is a GraphConstructionError.
     if isinstance(original_error, GraphConstructionError):
+      # TODO(mdan): This is incomplete.
+      # The error might have bubbled through a non-converted function.
       cleaned_traceback = traceback.extract_tb(e_traceback)
       previous_traceback = original_error.custom_traceback
       cleaned_traceback = [cleaned_traceback[0]] + previous_traceback
     else:
       cleaned_traceback = traceback.extract_tb(e_traceback)
     cleaned_traceback = _remove_rewrite_frames(cleaned_traceback)

-    current_frame_indices = []
-    # This code is meant to be called from the try/except block that wraps a
-    # function body. Here we look for all frames that came from the function
-    # that this wraps, look for any matching line numbers in the source
-    # mapping, and then rewrite them if matches are found.
-    for fi, frame in enumerate(cleaned_traceback):
-      _, _, frame_func_name, _ = frame
-      if frame_func_name == func_name:
-        current_frame_indices.append(fi)
-        break
-    if current_frame_indices:
-      _rewrite_frame(source_map, cleaned_traceback, current_frame_indices)
-    # Remove the frame corresponding to this function call.
-    cleaned_traceback = cleaned_traceback[1:]
+    cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback, func_name)

     if isinstance(original_error, GraphConstructionError):
       original_error.custom_traceback = cleaned_traceback
@@ -153,6 +154,7 @@ def rewrite_graph_construction_error(source_map):
       new_error = GraphConstructionError(original_error, cleaned_traceback)
   except Exception:
     logging.exception('Error while rewriting AutoGraph error:')
+    # TODO(mdan): Should reraise here, removing the top frame as well.
     raise original_error
   else:
     raise new_error
@@ -161,18 +163,17 @@ def rewrite_graph_construction_error(source_map):
     del e_traceback


+# TODO(mdan): This should be consistent with rewrite_graph_construction_error
+# Both should either raise or return.
 def rewrite_tf_runtime_error(error, source_map):
   """Rewrites TensorFlow runtime errors raised by ops created in AG code.

   Args:
-    error: error_impl.OpError, an TensorFlow error that will have its traceback
-        rewritten.
-    source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
-        AG generated code.
+    error: tf.OpError
+    source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo]

   Returns:
-    A TfRuntimeError with a traceback rewritten according to the given
-    source mapping.
+    TfRuntimeError, the rewritten underlying error.
   """
   # Check for cases where we leave a user method and re-enter it in the
   # traceback. This is done by looking at the function names when the
@@ -198,15 +199,16 @@ def rewrite_tf_runtime_error(error, source_map):
   # The source map keys are (file_path, line_number) so get the set of all user
   # file_paths.
   try:
-    all_user_files = set(k.file_path for k in source_map)
+    all_user_files = set(loc.filename for loc in source_map)
     cleaned_traceback = []
     last_user_frame_index = None
     last_user_user_file_path = None
     last_user_user_fn_name = None
+    # TODO(mdan): Simplify this logic.
     for fi, frame in enumerate(error.op.traceback):
-      frame_file_path, frame_line_number, _, _ = frame
-      src_map_key = CodeLocation(
-          file_path=frame_file_path, line_number=frame_line_number)
+      frame_file_path, lineno, _, _ = frame
+      lineno -= 1  # Frame line numbers are 1-based.
+      src_map_key = origin_info.LineLocation(frame_file_path, lineno)
       if frame_file_path in all_user_files:
         if src_map_key in source_map:
           original_fn_name = source_map[src_map_key].function_name
@@ -223,8 +225,8 @@ def rewrite_tf_runtime_error(error, source_map):
         last_user_user_file_path = frame_file_path
     cleaned_traceback.append(frame)

-  for fi in range(len(cleaned_traceback)):
-    _rewrite_frame(source_map, cleaned_traceback, [fi])
+  cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback)
+
   op_name = error.op.name
   op_message = error.message
   rewritten_error = TfRuntimeError(op_name, op_message, cleaned_traceback)
@@ -263,7 +265,7 @@ def improved_errors(converted_function):
     ValueError: If converted_function is not generated by AutoGraph
   """
   if (getattr(converted_function, 'ag_source_map', None) is None or
-      not converted_function.ag_source_map):
+      not isinstance(converted_function.ag_source_map, dict)):
     raise ValueError(
         'converted_function must be the result of an autograph.to_graph call')
   try:
@@ -28,88 +28,76 @@ from tensorflow.python.util import tf_inspect


 def zero_div():
-  return array_ops.constant(10, dtype=dtypes.int32) // 0
+  x = array_ops.constant(10, dtype=dtypes.int32)
+  return x // 0


 def zero_div_caller():
-  a = zero_div() + 2
-  return a
+  return zero_div()


 class RuntimeErrorsTest(test.TestCase):

-  def setUp(self):
-    self._fake_origin = origin_info.OriginInfo('new file', 'new func', 96, 0,
-                                               'print("hello world!")')
+  def fake_origin(self, function, line_offset):
+    _, lineno = tf_inspect.getsourcelines(function)
+    filename = tf_inspect.getsourcefile(function)
+    lineno += line_offset
+    loc = origin_info.LineLocation(filename, lineno)
+    origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code')
+    return loc, origin

-  def test_error_replacement(self):
-    _, zero_div_lineno = tf_inspect.getsourcelines(zero_div)
-    src_map = {
-        errors.CodeLocation(
-            file_path=__file__, line_number=zero_div_lineno + 1):
-            self._fake_origin
-    }
+  def test_improved_errors_basic(self):
+    loc, origin = self.fake_origin(zero_div, 2)
+    zero_div_caller.ag_source_map = {loc: origin}
+
+    ops = zero_div_caller()
     with self.assertRaises(errors.TfRuntimeError) as cm:
-      z = zero_div_caller()
-      zero_div_caller.ag_source_map = src_map
       with errors.improved_errors(zero_div_caller):
         with self.test_session() as sess:
-          sess.run(z)
-    expected = cm.exception
-    current_traceback = expected.custom_traceback
-    for frame in current_traceback:
-      self.assertNotEqual('zero_div', frame[2])
-    self.assertTrue(
-        any(self._fake_origin.as_frame() == frame
-            for frame in current_traceback))
+          sess.run(ops)

-  def test_error_not_found(self):
-    src_map = {
-        errors.CodeLocation(file_path=__file__, line_number=-1):
-            self._fake_origin
-    }
+    for frame in cm.exception.custom_traceback:
+      _, _, function_name, _ = frame
+      self.assertNotEqual('zero_div', function_name)
+    self.assertIn(origin.as_frame(), set(cm.exception.custom_traceback))
+
+  def test_improved_errors_no_matching_lineno(self):
+    loc, origin = self.fake_origin(zero_div, -1)
+    zero_div_caller.ag_source_map = {loc: origin}
+
+    ops = zero_div_caller()
     with self.assertRaises(errors.TfRuntimeError) as cm:
-      z = zero_div_caller()
-      zero_div_caller.ag_source_map = src_map
       with errors.improved_errors(zero_div_caller):
         with self.test_session() as sess:
-          sess.run(z)
-    expected = cm.exception
-    current_traceback = expected.custom_traceback
-    self.assertTrue(any('zero_div' in frame[2] for frame in current_traceback))
-    for frame in current_traceback:
-      self.assertNotEqual(frame, self._fake_origin.as_frame())
+          sess.run(ops)

-  def test_rewriting_error(self):
-    _, zero_div_lineno = tf_inspect.getsourcelines(zero_div)
-    src_map = {
-        errors.CodeLocation(
-            file_path=__file__, line_number=zero_div_lineno + 1):
-            None
-    }
-    with self.assertRaisesRegexp(tf_errors.InvalidArgumentError,
-                                 'Integer division by zero'):
-      z = zero_div_caller()
-      zero_div_caller.ag_source_map = src_map
+    all_function_names = set()
+    for frame in cm.exception.custom_traceback:
+      _, _, function_name, _ = frame
+      all_function_names.add(function_name)
+      self.assertNotEqual('test_function_name', function_name)
+    self.assertIn('zero_div', all_function_names)
+
+  def test_improved_errors_failures(self):
+    loc, _ = self.fake_origin(zero_div, 2)
+    zero_div_caller.ag_source_map = {loc: 'bogus object'}
+
+    ops = zero_div_caller()
+    with self.assertRaises(tf_errors.InvalidArgumentError):
       with errors.improved_errors(zero_div_caller):
         with self.test_session() as sess:
-          sess.run(z)
+          sess.run(ops)

-  def test_no_ag_source_map(self):
+  def test_improved_errors_validation(self):
     with self.assertRaisesRegexp(
         ValueError,
         'converted_function must be the result of an autograph.to_graph call'):
-      with errors.improved_errors(None):
-        pass
-
-  def test_bad_ag_source_map(self):
+      errors.improved_errors(zero_div).__enter__()
     with self.assertRaisesRegexp(
         ValueError,
         'converted_function must be the result of an autograph.to_graph call'):
-      src_map = None
-      zero_div_caller.ag_source_map = src_map
-      with errors.improved_errors(None):
-        pass
+      zero_div_caller.ag_source_map = 'not a dict'
+      errors.improved_errors(zero_div_caller).__enter__()


 if __name__ == '__main__':
@@ -23,7 +23,6 @@ from functools import wraps
 from enum import Enum

 # pylint:disable=g-bad-import-order
-import gast
 import six
 # pylint:enable=g-bad-import-order

@@ -245,19 +244,21 @@ def to_graph(e,
   _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                   arg_types)

-  module = gast.Module([])
+  nodes = []
   for dep in reversed(program_ctx.dependency_cache.values()):
-    module.body.append(dep)
-  compiled_node, compiled_src = compiler.ast_to_object(
-      module, source_prefix=program_ctx.required_imports)
+    nodes.extend(dep)
+  compiled_module, compiled_src = compiler.ast_to_object(
+      nodes,
+      source_prefix=program_ctx.required_imports,
+      include_source_map=True)

   # The compiled code should see everything the entry entity saw.
   # TODO(mdan): This might not work well if the call tree spans modules?
   for key, val in namespace.items():
     # Avoid overwriting entities that have been transformed.
-    if key not in compiled_node.__dict__:
-      compiled_node.__dict__[key] = val
-  compiled_fn = getattr(compiled_node, name)
+    if key not in compiled_module.__dict__:
+      compiled_module.__dict__[key] = val
+  compiled_fn = getattr(compiled_module, name)

   # Need this so the source_mapping attribute is available for the context
   # manager to access for runtime errors.
@@ -270,7 +271,7 @@ def to_graph(e,
                      '"%s", which is reserved for AutoGraph.' %
                      (compiled_fn, source_map_attribute_name))
   setattr(compiled_fn, source_map_attribute_name,
-          compiled_node.__dict__['ag_source_map__'])
+          compiled_module.__dict__['ag_source_map__'])

   if verbose:
     logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
@@ -308,7 +309,7 @@ def to_code(e,
   conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

   code = '\n'.join(
-      compiler.ast_to_source(dep, indentation)[0]
+      compiler.ast_to_source(dep, indentation)
       for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache))))

   return program_ctx.required_imports + '\n\n' + code
@@ -164,7 +164,7 @@ def class_to_graph(c, program_ctx):
       class_namespace = namespace
     else:
       class_namespace.update(namespace)
-    converted_members[m] = node
+    converted_members[m] = node[0]
   namer = program_ctx.new_namer(class_namespace)
   class_name = namer.compiled_class_name(c.__name__, c)

@@ -175,10 +175,10 @@ def class_to_graph(c, program_ctx):
   # program_ctx.update_name_map(namer)).
   output_nodes = []
   renames = {}
-  bases = []
+  base_names = []
   for base in c.__bases__:
     if isinstance(object, base):
-      bases.append('object')
+      base_names.append('object')
       continue
     if is_whitelisted_for_graph(base):
       alias = namer.new_symbol(base.__name__, ())
@@ -190,28 +190,28 @@ def class_to_graph(c, program_ctx):
     else:
       # This will trigger a conversion into a class with this name.
       alias = namer.compiled_class_name(base.__name__, base)
-    bases.append(alias)
+    base_names.append(alias)
     renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
   program_ctx.update_name_map(namer)

   # Generate the definition of the converted class.
-  output_nodes.append(
-      gast.ClassDef(
-          class_name,
-          bases=bases,
-          keywords=[],
-          body=list(converted_members.values()),
-          decorator_list=[]))
-  node = gast.Module(output_nodes)
-
+  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
+  class_def = gast.ClassDef(
+      class_name,
+      bases=bases,
+      keywords=[],
+      body=list(converted_members.values()),
+      decorator_list=[])
   # Make a final pass to replace references to the class or its base classes.
   # Most commonly, this occurs when making super().__init__() calls.
   # TODO(mdan): Making direct references to superclass' superclass will fail.
-  node = qual_names.resolve(node)
+  class_def = qual_names.resolve(class_def)
   renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
-  node = ast_util.rename_symbols(node, renames)
+  class_def = ast_util.rename_symbols(class_def, renames)

-  return node, class_name, class_namespace
+  output_nodes.append(class_def)
+
+  return output_nodes, class_name, class_namespace


 def _add_reserved_symbol(namespace, name, entity):
@@ -279,7 +279,7 @@ def function_to_graph(f,
   program_ctx.update_name_map(namer)
   # TODO(mdan): Use this at compilation.

-  return node, new_name, namespace
+  return (node,), new_name, namespace


 def node_to_graph(node, context, rewrite_errors=True):
@@ -60,10 +60,11 @@ class ConversionTest(test.TestCase):
       return a + b

     program_ctx = self._simple_program_ctx()
-    ast, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
-    self.assertTrue(isinstance(ast, gast.FunctionDef), ast)
+    nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
+    fn_node, = nodes
+    self.assertIsInstance(fn_node, gast.FunctionDef)
     self.assertEqual('tf__f', name)
-    self.assertTrue(ns['b'] is b)
+    self.assertIs(ns['b'], b)

   def test_entity_to_graph_call_tree(self):
@@ -78,14 +79,11 @@ class ConversionTest(test.TestCase):

     self.assertTrue(f in program_ctx.dependency_cache)
     self.assertTrue(g in program_ctx.dependency_cache)
-    self.assertEqual('tf__f', program_ctx.dependency_cache[f].name)
-    # need one extra .body[0] in order to step past the try/except wrapper that
-    # is added automatically, the other for the with tf.name_scope('f') that is
-    # added automatically
-    self.assertEqual(
-        'tf__g',
-        program_ctx.dependency_cache[f].body[0].body[0].body[0].value.func.id)
-    self.assertEqual('tf__g', program_ctx.dependency_cache[g].name)
+    f_node = program_ctx.dependency_cache[f][0]
+    g_node = program_ctx.dependency_cache[g][0]
+    self.assertEqual('tf__f', f_node.name)
+    self.assertEqual('tf__g', f_node.body[0].body[0].body[0].value.func.id)
+    self.assertEqual('tf__g', g_node.name)

   def test_entity_to_graph_class_hierarchy(self):
@@ -118,9 +116,9 @@ class ConversionTest(test.TestCase):
     self.assertTrue(TestBase in program_ctx.dependency_cache)
     self.assertTrue(TestSubclass in program_ctx.dependency_cache)
     self.assertEqual('TfTestBase',
-                     program_ctx.dependency_cache[TestBase].body[-1].name)
+                     program_ctx.dependency_cache[TestBase][-1].name)
     self.assertEqual('TfTestSubclass',
-                     program_ctx.dependency_cache[TestSubclass].body[-1].name)
+                     program_ctx.dependency_cache[TestSubclass][-1].name)

   def test_entity_to_graph_class_hierarchy_whitelisted(self):
||||
@ -139,10 +137,9 @@ class ConversionTest(test.TestCase):
|
||||
self.assertTrue(TestSubclass in program_ctx.dependency_cache)
|
||||
self.assertFalse(training.Model in program_ctx.dependency_cache)
|
||||
self.assertEqual(
|
||||
'Model',
|
||||
program_ctx.dependency_cache[TestSubclass].body[0].names[0].name)
|
||||
'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name)
|
||||
self.assertEqual('TfTestSubclass',
|
||||
program_ctx.dependency_cache[TestSubclass].body[-1].name)
|
||||
program_ctx.dependency_cache[TestSubclass][-1].name)
|
||||
|
||||
def test_entity_to_graph_lambda(self):
|
||||
f = lambda a: a
|
||||
|
@@ -99,6 +99,16 @@ py_test(
     ],
 )

+py_test(
+    name = "origin_info_test",
+    srcs = ["origin_info_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":pyct",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
 py_test(
     name = "parser_test",
     srcs = ["parser_test.py"],
@@ -20,7 +20,6 @@ from __future__ import print_function

 import ast
-
 import collections
 import gast

 from tensorflow.contrib.autograph.pyct import anno
@@ -185,6 +184,7 @@ class PatternMatcher(gast.NodeVisitor):
     if v != p:
       return self.no_match()

+
 def matches(node, pattern):
   """Basic pattern matcher for AST.

@@ -253,30 +253,61 @@ def apply_to_single_assignments(targets, values, apply_fn):
       apply_fn(target, values)


-def iter_fields(node):
-  for field in sorted(node._fields):
-    try:
-      yield getattr(node, field)
-    except AttributeError:
-      pass
-
-
-def iter_child_nodes(node):
-  for field in iter_fields(node):
-    if isinstance(field, gast.AST):
-      yield field
-    elif isinstance(field, list):
-      for item in field:
-        if isinstance(item, gast.AST):
-          yield item
-
-
-def parallel_walk(node_a, node_b):
-  todo_a = collections.deque([node_a])
-  todo_b = collections.deque([node_b])
-  while todo_a and todo_b:
-    node_a = todo_a.popleft()
-    node_b = todo_b.popleft()
-    todo_a.extend(iter_child_nodes(node_a))
-    todo_b.extend(iter_child_nodes(node_b))
-    yield node_a, node_b
+def parallel_walk(node, other):
+  """Walks two ASTs in parallel.
+
+  The two trees must have identical structure.
+
+  Args:
+    node: Union[ast.AST, Iterable[ast.AST]]
+    other: Union[ast.AST, Iterable[ast.AST]]
+  Yields:
+    Tuple[ast.AST, ast.AST]
+  Raises:
+    ValueError: if the two trees don't have identical structure.
+  """
+  if isinstance(node, (list, tuple)):
+    node_stack = list(node)
+  else:
+    node_stack = [node]
+
+  if isinstance(other, (list, tuple)):
+    other_stack = list(other)
+  else:
+    other_stack = [other]
+
+  while node_stack and other_stack:
+    assert len(node_stack) == len(other_stack)
+    n = node_stack.pop()
+    o = other_stack.pop()
+
+    if (not isinstance(n, (ast.AST, gast.AST)) or
+        not isinstance(o, (ast.AST, gast.AST)) or
+        n.__class__.__name__ != o.__class__.__name__):
+      raise ValueError('inconsistent nodes: {} and {}'.format(n, o))
+
+    yield n, o
+
+    for f in n._fields:
+      n_child = getattr(n, f, None)
+      o_child = getattr(o, f, None)
+      if f.startswith('__') or n_child is None or o_child is None:
+        continue
+
+      if isinstance(n_child, (list, tuple)):
+        if (not isinstance(o_child, (list, tuple)) or
+            len(n_child) != len(o_child)):
+          raise ValueError(
+              'inconsistent values for field {}: {} and {}'.format(
+                  f, n_child, o_child))
+        node_stack.extend(n_child)
+        other_stack.extend(o_child)
+
+      elif isinstance(n_child, (gast.AST, ast.AST)):
+        node_stack.append(n_child)
+        other_stack.append(o_child)
+
+      elif n_child != o_child:
+        raise ValueError(
+            'inconsistent values for field {}: {} and {}'.format(
+                f, n_child, o_child))
@@ -44,7 +44,7 @@ class AstUtilTest(test.TestCase):
         node, {qual_names.QN('a'): qual_names.QN('renamed_a')})

     self.assertIsInstance(node.body[0].value.left.id, str)
-    source, _ = compiler.ast_to_source(node)
+    source = compiler.ast_to_source(node)
     self.assertEqual(source.strip(), 'renamed_a + b')

   def test_rename_symbols_attributes(self):
@@ -54,7 +54,7 @@ class AstUtilTest(test.TestCase):
     node = ast_util.rename_symbols(
         node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})

-    source, _ = compiler.ast_to_source(node)
+    source = compiler.ast_to_source(node)
     self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d')

   def test_rename_symbols_annotations(self):
@@ -97,10 +97,10 @@ class AstUtilTest(test.TestCase):
     d = ast_util.keywords_to_dict(keywords)
     # Make sure we generate a usable dict node by attaching it to a variable and
     # compiling everything.
-    output = parser.parse_str('b = 3')
-    output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),)
-    result, _ = compiler.ast_to_object(output)
-    self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'})
+    node = parser.parse_str('def f(b): pass').body[0]
+    node.body.append(ast.Return(d))
+    result, _ = compiler.ast_to_object(node)
+    self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'})

   def assertMatch(self, target_str, pattern_str):
     node = parser.parse_expression(target_str)
@@ -130,8 +130,8 @@ class AstUtilTest(test.TestCase):
                      'super(Bar, _).__init__(_)')

   def _mock_apply_fn(self, target, source):
-    target, _ = compiler.ast_to_source(target)
-    source, _ = compiler.ast_to_source(source)
+    target = compiler.ast_to_source(target)
+    source = compiler.ast_to_source(source)
     self._invocation_counts[(target.strip(), source.strip())] += 1

   def test_apply_to_single_assignments_dynamic_unpack(self):
@@ -157,24 +157,40 @@ class AstUtilTest(test.TestCase):
     })

   def test_parallel_walk(self):
-    ret = ast.Return(
-        ast.BinOp(
-            op=ast.Add(),
-            left=ast.Name(id='a', ctx=ast.Load()),
-            right=ast.Num(1)))
-    node = ast.FunctionDef(
-        name='f',
-        args=ast.arguments(
-            args=[ast.Name(id='a', ctx=ast.Param())],
-            vararg=None,
-            kwarg=None,
-            defaults=[]),
-        body=[ret],
-        decorator_list=[],
-        returns=None)
+    node = parser.parse_str(
+        textwrap.dedent("""
+      def f(a):
+        return a + 1
+    """))
     for child_a, child_b in ast_util.parallel_walk(node, node):
       self.assertEqual(child_a, child_b)

+  def test_parallel_walk_inconsistent_trees(self):
+    node_1 = parser.parse_str(
+        textwrap.dedent("""
+      def f(a):
+        return a + 1
+    """))
+    node_2 = parser.parse_str(
+        textwrap.dedent("""
+      def f(a):
+        return a + (a * 2)
+    """))
+    node_3 = parser.parse_str(
+        textwrap.dedent("""
+      def f(a):
+        return a + 2
+    """))
+    with self.assertRaises(ValueError):
+      for _ in ast_util.parallel_walk(node_1, node_2):
+        pass
+    # There is no particular reason to reject trees that differ only in the
+    # value of a constant.
+    # TODO(mdan): This should probably be allowed.
+    with self.assertRaises(ValueError):
+      for _ in ast_util.parallel_walk(node_1, node_3):
+        pass
+

 if __name__ == '__main__':
   test.main()
@@ -67,10 +67,8 @@ class Node(object):
     if isinstance(self.ast_node, gast.FunctionDef):
       return 'def %s' % self.ast_node.name
     elif isinstance(self.ast_node, gast.withitem):
-      source, _ = compiler.ast_to_source(self.ast_node.context_expr)
-      return source.strip()
-    source, _ = compiler.ast_to_source(self.ast_node)
-    return source.strip()
+      return compiler.ast_to_source(self.ast_node.context_expr).strip()
+    return compiler.ast_to_source(self.ast_node).strip()


 class Graph(
@@ -43,7 +43,7 @@ class AnfTransformerTest(test.TestCase):
       return a

     node, _ = parser.parse_entity(test_function)
-    node = anf.transform(node, self._simple_source_info())
+    node = anf.transform(node.body[0], self._simple_source_info())
     result, _ = compiler.ast_to_object(node)

     self.assertEqual(test_function(), result.test_function())
@@ -30,44 +30,7 @@ import tempfile
 import astor
 import gast

-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
 from tensorflow.contrib.autograph.pyct import origin_info
-from tensorflow.contrib.autograph.pyct import parser
-
-
-def _build_source_map(node, code):
-  """Return the Python objects represented by given AST.
-
-  Compiling the AST code this way ensures that the source code is readable by
-  e.g. `pdb` or `inspect`.
-
-  Args:
-    node: An AST node of the original generated code, before the source code is
-        generated.
-    code: The string representation of the source code for the newly generated
-        code.
-
-  Returns:
-    Dict[CodeLocation, OriginInfo], a mapping between the user and AutoGraph
-    generated code.
-  """
-  # After we have the final generated code we reparse it to get the final line
-  # numbers. Then we walk through the generated and original ASTs in parallel
-  # to build the mapping between the user and generated code.
-  new_node = parser.parse_str(code)
-  origin_info.resolve(new_node, code)
-  source_mapping = {}
-  for before, after in ast_util.parallel_walk(node, new_node):
-    # Need both checks because if origin information is ever copied over to new
-    # nodes then we need to rely on the fact that only the original user code
-    # has the origin annotation.
-    if (anno.hasanno(before, anno.Basic.ORIGIN) and
-        anno.hasanno(after, anno.Basic.ORIGIN)):
-      source_info = anno.getanno(before, anno.Basic.ORIGIN)
-      new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
-      source_mapping[new_line_number] = source_info
-  return source_mapping


 def ast_to_source(node, indentation='  '):
@@ -81,24 +44,28 @@ def ast_to_source(node, indentation='  '):
     code: The source code generated from the AST object
-    source_mapping: A mapping between the user and AutoGraph generated code.
   """
-  original_node = node
-  if isinstance(node, gast.AST):
-    node = gast.gast_to_ast(node)
+  if not isinstance(node, (list, tuple)):
+    node = (node,)
   generator = astor.codegen.SourceGenerator(indentation, False,
                                             astor.string_repr.pretty_string)
-  generator.visit(node)
-  generator.result.append('\n')
+
+  for n in node:
+    if isinstance(n, gast.AST):
+      n = gast.gast_to_ast(n)
+    generator.visit(n)
+    generator.result.append('\n')
+
   # In some versions of Python, literals may appear as actual values. This
   # ensures everything is string.
   code = map(str, generator.result)
   code = astor.source_repr.pretty_source(code).lstrip()
-  source_mapping = _build_source_map(original_node, code)

-  return code, source_mapping
+  return code


-def ast_to_object(node,
+def ast_to_object(nodes,
                   indentation='  ',
+                  include_source_map=False,
                   source_prefix=None,
                   delete_on_exit=True):
   """Return the Python objects represented by given AST.
@@ -107,42 +74,46 @@ def ast_to_object(nodes,
   e.g. `pdb` or `inspect`.

   Args:
-    node: The code to compile, as an AST object.
-    indentation: The string to use for indentation.
-    source_prefix: Optional string to print as-is into the source file.
-    delete_on_exit: Whether to delete the temporary file used for compilation on
-        exit.
+    nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST
+        object.
+    indentation: Text, the string to use for indentation.
+    include_source_map: bool, whether to attach a source map to the compiled
+        object. Also see origin_info.py.
+    source_prefix: Optional[Text], string to print as-is into the source file.
+    delete_on_exit: bool, whether to delete the temporary file used for
+        compilation on exit.

   Returns:
-    compiled_node: A module object containing the compiled source code.
+    compiled_nodes: A module object containing the compiled source code.
     source: The source code of the compiled object
   Raises:
     ValueError: If ag_source_map__ is already in the namespace of the compiled
-        node.
+        nodes.
   """
-  # code_source_mapping does not yet include the offsets from import statements.
-  source, code_source_mapping = ast_to_source(node, indentation=indentation)
+  if not isinstance(nodes, (list, tuple)):
+    nodes = (nodes,)
+
+  source = ast_to_source(nodes, indentation=indentation)
+
+  if source_prefix:
+    source = source_prefix + '\n' + source

   with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
-    # TODO(znado): move into an _offset_source_map() helper function.
-    # Need to offset the generated line numbers by the number of import lines.
-    if source_prefix:
-      num_import_lines = source_prefix.count('\n') + 1
-    else:
-      num_import_lines = 0
-    source_mapping = {}
-    for line_number, original_position in code_source_mapping.items():
-      source_map_key = origin_info.CodeLocation(
-          file_path=f.name, line_number=line_number + num_import_lines)
-      source_mapping[source_map_key] = original_position
     module_name = os.path.basename(f.name[:-3])
-    if source_prefix:
-      f.write(source_prefix)
-      f.write('\n')
     f.write(source)
+
+    if isinstance(nodes, (list, tuple)):
+      indices = range(-len(nodes), 0)
+    else:
+      indices = (-1,)
+
+    if include_source_map:
+      source_map = origin_info.source_map(nodes, source, f.name, indices)
+
     # TODO(mdan): Try flush() and delete=False instead.
     if delete_on_exit:
       atexit.register(lambda: os.remove(f.name))
-  compiled_node = imp.load_source(module_name, f.name)
+  compiled_nodes = imp.load_source(module_name, f.name)

   # TODO(znado): Clean this up so we don't need to attach it to the namespace.
   # TODO(znado): This does not work for classes because their methods share a
@@ -158,11 +129,13 @@ def ast_to_object(nodes,
   # is hard, and this cleanly fixes the
   # issues encountered with nested functions because this is attached to the
   # outermost one.
-  source_map_name = 'ag_source_map__'
-  if source_map_name in compiled_node.__dict__:
-    raise ValueError('cannot convert %s because is has namespace attribute '
-                     '"%s", which is reserved for AutoGraph.' %
-                     (compiled_node, source_map_name))
-  compiled_node.__dict__[source_map_name] = source_mapping
+  if include_source_map:
+    # TODO(mdan): This name should be decided by the caller.
+    source_map_name = 'ag_source_map__'
+    if source_map_name in compiled_nodes.__dict__:
+      raise ValueError('cannot convert %s because is has namespace attribute '
+                       '"%s", which is reserved for AutoGraph.' %
+                       (compiled_nodes, source_map_name))
+    compiled_nodes.__dict__[source_map_name] = source_map

-  return compiled_node, source
+  return compiled_nodes, source
@@ -59,7 +59,7 @@ class CompilerTest(test.TestCase):
             value=gast.Str('c'))
     ])

-    source, _ = compiler.ast_to_source(node, indentation='  ')
+    source = compiler.ast_to_source(node, indentation='  ')
     self.assertEqual(
         textwrap.dedent("""
             if 1:
@@ -22,49 +22,115 @@ import collections
 import gast

 from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import ast_util
+from tensorflow.contrib.autograph.pyct import parser
 from tensorflow.python.util import tf_inspect


-class CodeLocation(
-    collections.namedtuple('CodeLocation', ('file_path', 'line_number'))):
-  """Location of a line of code.
+class LineLocation(
+    collections.namedtuple('LineLocation', ('filename', 'lineno'))):
+  """Similar to Location, but without column information.

   Attributes:
-    file_path: text, the full path to the file containing the code.
-    line_number: Int, the 1-based line number of the code in its file.
+    filename: Text
+    lineno: int, 1-based
   """
   pass


+class Location(
+    collections.namedtuple('Location', ('filename', 'lineno', 'col_offset'))):
+  """Encodes code location information.
+
+  Attributes:
+    filename: Text
+    lineno: int, 1-based
+    col_offset: int
+  """
+
+  @property
+  def line_loc(self):
+    return LineLocation(self.filename, self.lineno)
+
+
 class OriginInfo(
-    collections.namedtuple('OriginInfo',
-                           ('file_path', 'function_name', 'line_number',
-                            'column_offset', 'source_code_line'))):
+    collections.namedtuple(
+        'OriginInfo',
+        ('loc', 'function_name', 'source_code_line'))):
   """Container for information about the source code before conversion.

-  Instances of this class contain information about the source code that
-  transformed code originated from. Examples include:
-    * line number
-    * file name
-    * original user code
+  Attributes:
+    loc: Location
+    function_name: Optional[Text]
+    source_code_line: Text
   """

   def as_frame(self):
-    """Makes a traceback frame tuple.
-
-    Returns:
-      A tuple of (file_path, line_number, function_name, source_code_line).
-    """
-    return (self.file_path, self.line_number, self.function_name,
+    """Returns a 4-tuple consistent with the return of traceback.extract_tb."""
+    return (self.loc.filename, self.loc.lineno, self.function_name,
             self.source_code_line)


+# TODO(mdan): This source map should be a class - easier to refer to.
+def source_map(nodes, code, filename, indices_in_code):
+  """Creates a source map between an annotated AST and the code it compiles to.
+
+  Args:
+    nodes: Iterable[ast.AST, ...]
+    code: Text
+    filename: Optional[Text]
+    indices_in_code: Union[int, Iterable[int, ...]], the positions at which
+        nodes appear in code. The parser always returns a module when parsing
+        code. This argument indicates the position in that module's body at
+        which the corresponding node should appear.
+
+  Returns:
+    Dict[CodeLocation, OriginInfo], mapping locations in code to locations
+    indicated by origin annotations in node.
+  """
+  reparsed_nodes = parser.parse_str(code)
+  reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code]
+
+  resolve(reparsed_nodes, code)
+  result = {}
+
+  for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
+    # Note: generated code might not be mapped back to its origin.
+    # TODO(mdan): Generated code should always be mapped to something.
+    origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
+    final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
+    if origin_info is None or final_info is None:
+      continue
+
+    line_loc = LineLocation(filename, final_info.loc.lineno)
+
+    existing_origin = result.get(line_loc)
+    if existing_origin is not None:
+      # Overlaps may exist because of child nodes, but almost never to
+      # different line locations. Exceptions are decorated functions, where
+      # both lines are mapped to the same line in the AST.
+
+      # Line overlaps: keep bottom node.
+      if existing_origin.loc.line_loc == origin_info.loc.line_loc:
+        if existing_origin.loc.lineno >= origin_info.loc.lineno:
+          continue
+
+      # In case of overlaps, keep the leftmost node.
+      if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
+        continue
+
+    result[line_loc] = origin_info
+
+  return result


-# TODO(znado): Consider refactoring this into a Visitor.
-def resolve(node, source, function=None):
+# TODO(mdan): Does this work correctly with inner functions?
+def resolve(nodes, source, function=None):
   """Adds origin information to all nodes inside the body of function.

   Args:
-    node: The AST node for the function whose body nodes will be annotated.
+    nodes: Union[ast.AST, Iterable[ast.AST, ...]]
     source: Text, the source code string for the function whose body nodes will
         be annotated.
     function: Callable, the function that will have all nodes inside of it
@@ -76,25 +142,32 @@ def resolve(nodes, source, function=None):
     A tuple of the AST node for function and a String containing its source
     code.
   """
+  if not isinstance(nodes, (list, tuple)):
+    nodes = (nodes,)
+
   if function:
     _, function_lineno = tf_inspect.getsourcelines(function)
     function_filepath = tf_inspect.getsourcefile(function)
   else:
     function_lineno = None
     function_filepath = None

   source_lines = source.split('\n')
-  for n in gast.walk(node):
-    if hasattr(n, 'lineno'):
-      # n.lineno is relative to the start of the enclosing function, so need to
-      # offset it by the line of the function.
-      source_code_line = source_lines[n.lineno - 1]
+  for node in nodes:
+    for n in gast.walk(node):
+      if not hasattr(n, 'lineno'):
+        continue
+
+      lineno_in_body = n.lineno
+
+      source_code_line = source_lines[lineno_in_body - 1]
       if function:
-        source_lineno = n.lineno + function_lineno - 1
+        source_lineno = function_lineno + lineno_in_body
         function_name = function.__name__
       else:
-        source_lineno = n.lineno
+        source_lineno = lineno_in_body
         function_name = None
-      anno.setanno(
-          n, anno.Basic.ORIGIN,
-          OriginInfo(function_filepath, function_name, source_lineno,
-                     n.col_offset, source_code_line))
+
+      location = Location(function_filepath, source_lineno, n.col_offset)
+      origin = OriginInfo(location, function_name, source_code_line)
+      anno.setanno(n, anno.Basic.ORIGIN, origin)
tensorflow/contrib/autograph/pyct/origin_info_test.py (new file, 101 lines)
@@ -0,0 +1,101 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for origin_info module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import compiler
+from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.platform import test
+
+
+class OriginInfoTest(test.TestCase):
+
+  def test_source_map(self):
+
+    def test_fn(x):
+      if x > 0:
+        x += 1
+      return x
+
+    node, source = parser.parse_entity(test_fn)
+    fn_node = node.body[0]
+    origin_info.resolve(fn_node, source)
+
+    # Insert a traced line.
+    new_node = parser.parse_str('x = abs(x)').body[0]
+    anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
+    fn_node.body.insert(0, new_node)
+
+    # Insert an untraced line.
+    fn_node.body.insert(0, parser.parse_str('x = 0').body[0])
+
+    modified_source = compiler.ast_to_source(fn_node)
+
+    source_map = origin_info.source_map(fn_node, modified_source,
+                                        'test_filename', [0])
+
+    loc = origin_info.LineLocation('test_filename', 1)
+    origin = source_map[loc]
+    self.assertEqual(origin.source_code_line, 'def test_fn(x):')
+    self.assertEqual(origin.loc.lineno, 1)
+
+    # The untraced line, inserted second.
+    loc = origin_info.LineLocation('test_filename', 2)
+    self.assertFalse(loc in source_map)
+
+    # The traced line, inserted first.
+    loc = origin_info.LineLocation('test_filename', 3)
+    origin = source_map[loc]
+    self.assertEqual(origin.source_code_line, '  if x > 0:')
+    self.assertEqual(origin.loc.lineno, 2)
+
+    loc = origin_info.LineLocation('test_filename', 4)
+    origin = source_map[loc]
+    self.assertEqual(origin.source_code_line, '  if x > 0:')
+    self.assertEqual(origin.loc.lineno, 2)
+
+  def test_resolve(self):
+
+    def test_fn(x):
+      """Docstring."""
+      return x  # comment
+
+    node, source = parser.parse_entity(test_fn)
+    fn_node = node.body[0]
+    origin_info.resolve(fn_node, source)
+
+    origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
+    self.assertEqual(origin.loc.lineno, 1)
+    self.assertEqual(origin.loc.col_offset, 0)
+    self.assertEqual(origin.source_code_line, 'def test_fn(x):')
+
+    origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
+    self.assertEqual(origin.loc.lineno, 2)
+    self.assertEqual(origin.loc.col_offset, 2)
+    self.assertEqual(origin.source_code_line, '  """Docstring."""')
+
+    origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
+    self.assertEqual(origin.loc.lineno, 3)
+    self.assertEqual(origin.loc.col_offset, 2)
+    self.assertEqual(origin.source_code_line, '  return x  # comment')
+
+
+if __name__ == '__main__':
+  test.main()
@@ -37,6 +37,7 @@ def parse_entity(entity):

 def parse_str(src):
   """Returns the AST of given piece of code."""
+  # TODO(mdan): This should exclude the module things are autowrapped in.
   return gast.parse(src)
