Fix bug in parallel_walk, along with a few structural bugs that this fix revealed:

1. The conversion process was inconsistently packaging the final output into modules or lists. This CL uniformly uses a list of nodes as output from all *_to_graph functions. As a side effect, converter_testing.py asserts that the output is always a single node and extracts it, so there is no need for tests to unpack it any more. Modify the compiler to skip generating a source map by default.

2. The class converter was incorrectly saving the superclass value to the string 'object' instead of the symbol `object`.

Additional refactoring that was caught along: Simplify the source mapping code, move it to origin_info.py, add tests and additional checks. Slightly simplify the error rewriting mechanism.

PiperOrigin-RevId: 206087110
This commit is contained in:
Dan Moldovan 2018-07-25 18:01:50 -07:00 committed by TensorFlower Gardener
parent 59305b118a
commit b24037513f
22 changed files with 530 additions and 335 deletions

View File

@ -35,7 +35,7 @@ class AssertsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = asserts.transform(node, ctx)
self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call))
self.assertTrue(isinstance(node.body[0].value, gast.Call))
if __name__ == '__main__':

View File

@ -38,7 +38,7 @@ class DirectivesTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {'directives': directives})
node = directives_converter.transform(node, ctx)
def_, = anno.getanno(node.body[0].body[0].targets[0],
def_, = anno.getanno(node.body[0].targets[0],
anno.Static.DEFINITIONS)
d = def_.directives[directives.set_element_type]
self.assertEqual(d['dtype'].s, 'a')
@ -52,7 +52,7 @@ class DirectivesTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {'directives': directives})
node = directives_converter.transform(node, ctx)
def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS)
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
d = def_.directives[directives.set_element_type]
self.assertEqual(d['dtype'].n, 1)
self.assertEqual(d['shape'].n, 2)
@ -67,7 +67,7 @@ class DirectivesTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {'directives': directives})
node = directives_converter.transform(node, ctx)
d = anno.getanno(node.body[0].body[1], AgAnno.DIRECTIVES)
d = anno.getanno(node.body[1], AgAnno.DIRECTIVES)
d = d[directives.set_loop_options]
self.assertEqual(d['parallel_iterations'].n, 10)
self.assertEqual(d['back_prop'].id, 'a')

View File

@ -34,11 +34,13 @@ class ErrorHandlersTest(converter_testing.TestCase):
raise ValueError()
node, ctx = self.prepare(test_fn, {})
anno.setanno(node.body[0], anno.Basic.ORIGIN,
origin_info.OriginInfo('test_path', None, None, None, None))
anno.setanno(node, anno.Basic.ORIGIN,
origin_info.OriginInfo(None, None, None))
node = error_handlers.transform(node, ctx)
with self.compiled(node, {}) as result:
with self.assertRaises(errors.GraphConstructionError):
# Here we just assert that the handler works. Its correctness is
# verified by errors_test.py.
result.test_fn()
def test_no_origin_annotation(self):

View File

@ -79,7 +79,7 @@ class ListTest(converter_testing.TestCase):
ns = {'special_functions': special_functions}
node, ctx = self.prepare(test_fn, ns)
def_, = anno.getanno(node.body[0].body[0].targets[0],
def_, = anno.getanno(node.body[0].targets[0],
anno.Static.ORIG_DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32'),
@ -114,7 +114,7 @@ class ListTest(converter_testing.TestCase):
return tf.stack(l)
node, ctx = self.prepare(test_fn, {})
def_, = anno.getanno(node.body[0].body[0].targets[0],
def_, = anno.getanno(node.body[0].targets[0],
anno.Static.ORIG_DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32')

View File

@ -43,7 +43,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = side_effect_guards.transform(node, ctx)
self.assertEqual(len(node.body[0].body), 1)
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign) as result:
with self.test_session() as sess:
@ -64,7 +64,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = side_effect_guards.transform(node, ctx)
self.assertEqual(len(node.body[0].body), 1)
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign) as result:
with self.test_session() as sess:
@ -84,7 +84,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = side_effect_guards.transform(node, ctx)
self.assertEqual(len(node.body[0].body), 1)
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, control_flow_ops.Assert) as result:
with self.test_session() as sess:
@ -104,7 +104,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = side_effect_guards.transform(node, ctx)
self.assertEqual(len(node.body[0].body), 1)
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign_add) as result:
with self.test_session() as sess:
@ -125,7 +125,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = side_effect_guards.transform(node, ctx)
self.assertEqual(len(node.body[0].body[0].body), 1)
self.assertEqual(len(node.body[0].body), 1)
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
with self.test_session() as sess:
@ -147,7 +147,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
node, ctx = self.prepare(test_fn, {})
node = side_effect_guards.transform(node, ctx)
self.assertEqual(len(node.body[0].body), 1)
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign,
state_ops.assign_add) as result:

View File

@ -38,7 +38,7 @@ class SliceTest(converter_testing.TestCase):
return l[1]
node, ctx = self.prepare(test_fn, {})
def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS)
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32')
}
@ -59,11 +59,11 @@ class SliceTest(converter_testing.TestCase):
return l[1]
node, ctx = self.prepare(test_fn, {})
def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS)
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32')
}
def_, = anno.getanno(node.body[0].body[0].body[0].targets[0],
def_, = anno.getanno(node.body[0].body[0].targets[0],
anno.Static.DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.float32')

View File

@ -94,7 +94,8 @@ class TestCase(test.TestCase):
return 7
try:
result, source = compiler.ast_to_object(node)
result, source = compiler.ast_to_object(node, include_source_map=True)
result.tf = self.make_fake_mod('fake_tf', *symbols)
fake_ag = self.make_fake_mod('fake_ag', converted_call)
fake_ag.__dict__.update(operators.__dict__)
@ -144,6 +145,7 @@ class TestCase(test.TestCase):
recursive=True,
autograph_decorators=()):
node, source = parser.parse_entity(test_fn)
node = node.body[0]
if namer is None:
namer = FakeNamer()
program_ctx = converter.ProgramContext(

View File

@ -31,11 +31,14 @@ import logging
import sys
import traceback
from tensorflow.contrib.autograph.pyct.origin_info import CodeLocation
from tensorflow.contrib.autograph.pyct import origin_info
from tensorflow.python.framework import errors_impl
from tensorflow.python.util import tf_inspect
# TODO(mdan): Add a superclass common to all errors.
class GraphConstructionError(Exception):
"""Error for graph construction errors from AutoGraph generated code."""
@ -65,27 +68,35 @@ class TfRuntimeError(Exception):
return message + ''.join(traceback.format_list(self.custom_traceback))
def _rewrite_frame(source_map, cleaned_traceback, stack_frame_indices):
"""Rewrites the stack frames at the given indices using the given source map.
def _rewrite_tb(source_map, tb, filter_function_name=None):
"""Rewrites code references in a traceback.
Args:
source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
AG generated code.
cleaned_traceback: List[Tuple[text, text, text, text]], the current
traceback.
stack_frame_indices: Iterable[Int], frame indices to possibly rewrite if
there are matching source mapping keys.
source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping
locations to their origin
tb: List[Tuple[Text, Text, Text, Text]], consistent with
traceback.extract_tb
filter_function_name: Optional[Text], allows restricting restricts the
frames to rewrite to a particular function name
Returns:
None
List[Tuple[Text, Text, Text, Text]], the rewritten traceback
"""
for frame_index in stack_frame_indices:
# (file_path, line number, function name, code)
file_path, line_number, _, _ = cleaned_traceback[frame_index]
source_map_key = CodeLocation(file_path=file_path, line_number=line_number)
found_mapping = source_map_key in source_map
if found_mapping:
cleaned_traceback[frame_index] = source_map[source_map_key].as_frame()
new_tb = []
for frame in tb:
filename, lineno, function_name, _ = frame
loc = origin_info.LineLocation(filename, lineno)
origin = source_map.get(loc)
# TODO(mdan): We shouldn't need the function name at all.
# filename + lineno should be sufficient, even if there are multiple source
# maps.
if origin is not None:
if filter_function_name == function_name or filter_function_name is None:
new_tb.append(origin.as_frame())
else:
new_tb.append(frame)
else:
new_tb.append(frame)
return new_tb
# TODO(znado): Make more robust to name changes in the rewriting logic.
@ -98,18 +109,20 @@ def _remove_rewrite_frames(tb):
return cleaned_tb
# TODO(mdan): rename to raise_*
def rewrite_graph_construction_error(source_map):
"""Rewrites errors raised by non-AG APIs inside AG generated code.
Meant to be called from the try/except block inside each AutoGraph generated
function. Only rewrites the traceback frames corresponding to the function
that this is called from. When we raise a GraphConstructionError at the end
it is then caught by calling functions, where they can be responsible for
rewriting their own frames.
This is called from the except handler inside an AutoGraph generated function
(that is, during exception handling). Only rewrites the frames corresponding
to the function that this is called from, so each function is responsible
to call this to have its own frames rewritten.
This function always raises an error.
Args:
source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
AG generated code.
source_map: Dict[origin_info.Location, origin_info.OriginInfo], the source
map belonging to the calling function
Raises:
GraphConstructionError: The rewritten underlying error.
@ -120,31 +133,19 @@ def rewrite_graph_construction_error(source_map):
assert original_error is not None
try:
_, _, _, func_name, _, _ = tf_inspect.stack()[1]
# The latest function call is added to the beginning of a traceback, but
# when rewriting the traceback of multiple function calls the previous
# functions' except blocks may have already rewritten their own frames so
# we want to copy over all of the previous frames. We may have rewritten
# previous frames only if the error is a GraphConstructionError.
if isinstance(original_error, GraphConstructionError):
# TODO(mdan): This is incomplete.
# The error might have bubbled through a non-converted function.
cleaned_traceback = traceback.extract_tb(e_traceback)
previous_traceback = original_error.custom_traceback
cleaned_traceback = [cleaned_traceback[0]] + previous_traceback
else:
cleaned_traceback = traceback.extract_tb(e_traceback)
cleaned_traceback = _remove_rewrite_frames(cleaned_traceback)
current_frame_indices = []
# This code is meant to be called from the try/except block that wraps a
# function body. Here we look for all frames that came from the function
# that this wraps, look for any matching line numbers in the source
# mapping, and then rewrite them if matches are found.
for fi, frame in enumerate(cleaned_traceback):
_, _, frame_func_name, _ = frame
if frame_func_name == func_name:
current_frame_indices.append(fi)
break
if current_frame_indices:
_rewrite_frame(source_map, cleaned_traceback, current_frame_indices)
# Remove the frame corresponding to this function call.
cleaned_traceback = cleaned_traceback[1:]
cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback, func_name)
if isinstance(original_error, GraphConstructionError):
original_error.custom_traceback = cleaned_traceback
@ -153,6 +154,7 @@ def rewrite_graph_construction_error(source_map):
new_error = GraphConstructionError(original_error, cleaned_traceback)
except Exception:
logging.exception('Error while rewriting AutoGraph error:')
# TODO(mdan): Should reraise here, removing the top frame as well.
raise original_error
else:
raise new_error
@ -161,18 +163,17 @@ def rewrite_graph_construction_error(source_map):
del e_traceback
# TODO(mdan): This should be consistent with rewrite_graph_construction_error
# Both should either raise or return.
def rewrite_tf_runtime_error(error, source_map):
"""Rewrites TensorFlow runtime errors raised by ops created in AG code.
Args:
error: error_impl.OpError, an TensorFlow error that will have its traceback
rewritten.
source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
AG generated code.
error: tf.OpError
source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo]
Returns:
A TfRuntimeError with a traceback rewritten according to the given
source mapping.
TfRuntimeError, the rewritten underlying error.
"""
# Check for cases where we leave a user method and re-enter it in the
# traceback. This is done by looking at the function names when the
@ -198,15 +199,16 @@ def rewrite_tf_runtime_error(error, source_map):
# The source map keys are (file_path, line_number) so get the set of all user
# file_paths.
try:
all_user_files = set(k.file_path for k in source_map)
all_user_files = set(loc.filename for loc in source_map)
cleaned_traceback = []
last_user_frame_index = None
last_user_user_file_path = None
last_user_user_fn_name = None
# TODO(mdan): Simplify this logic.
for fi, frame in enumerate(error.op.traceback):
frame_file_path, frame_line_number, _, _ = frame
src_map_key = CodeLocation(
file_path=frame_file_path, line_number=frame_line_number)
frame_file_path, lineno, _, _ = frame
lineno -= 1 # Frame line numbers are 1-based.
src_map_key = origin_info.LineLocation(frame_file_path, lineno)
if frame_file_path in all_user_files:
if src_map_key in source_map:
original_fn_name = source_map[src_map_key].function_name
@ -223,8 +225,8 @@ def rewrite_tf_runtime_error(error, source_map):
last_user_user_file_path = frame_file_path
cleaned_traceback.append(frame)
for fi in range(len(cleaned_traceback)):
_rewrite_frame(source_map, cleaned_traceback, [fi])
cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback)
op_name = error.op.name
op_message = error.message
rewritten_error = TfRuntimeError(op_name, op_message, cleaned_traceback)
@ -263,7 +265,7 @@ def improved_errors(converted_function):
ValueError: If converted_function is not generated by AutoGraph
"""
if (getattr(converted_function, 'ag_source_map', None) is None or
not converted_function.ag_source_map):
not isinstance(converted_function.ag_source_map, dict)):
raise ValueError(
'converted_function must be the result of an autograph.to_graph call')
try:

View File

@ -28,88 +28,76 @@ from tensorflow.python.util import tf_inspect
def zero_div():
return array_ops.constant(10, dtype=dtypes.int32) // 0
x = array_ops.constant(10, dtype=dtypes.int32)
return x // 0
def zero_div_caller():
a = zero_div() + 2
return a
return zero_div()
class RuntimeErrorsTest(test.TestCase):
def setUp(self):
self._fake_origin = origin_info.OriginInfo('new file', 'new func', 96, 0,
'print("hello world!")')
def fake_origin(self, function, line_offset):
_, lineno = tf_inspect.getsourcelines(function)
filename = tf_inspect.getsourcefile(function)
lineno += line_offset
loc = origin_info.LineLocation(filename, lineno)
origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code')
return loc, origin
def test_error_replacement(self):
_, zero_div_lineno = tf_inspect.getsourcelines(zero_div)
src_map = {
errors.CodeLocation(
file_path=__file__, line_number=zero_div_lineno + 1):
self._fake_origin
}
def test_improved_errors_basic(self):
loc, origin = self.fake_origin(zero_div, 2)
zero_div_caller.ag_source_map = {loc: origin}
ops = zero_div_caller()
with self.assertRaises(errors.TfRuntimeError) as cm:
z = zero_div_caller()
zero_div_caller.ag_source_map = src_map
with errors.improved_errors(zero_div_caller):
with self.test_session() as sess:
sess.run(z)
expected = cm.exception
current_traceback = expected.custom_traceback
for frame in current_traceback:
self.assertNotEqual('zero_div', frame[2])
self.assertTrue(
any(self._fake_origin.as_frame() == frame
for frame in current_traceback))
sess.run(ops)
def test_error_not_found(self):
src_map = {
errors.CodeLocation(file_path=__file__, line_number=-1):
self._fake_origin
}
for frame in cm.exception.custom_traceback:
_, _, function_name, _ = frame
self.assertNotEqual('zero_div', function_name)
self.assertIn(origin.as_frame(), set(cm.exception.custom_traceback))
def test_improved_errors_no_matching_lineno(self):
loc, origin = self.fake_origin(zero_div, -1)
zero_div_caller.ag_source_map = {loc: origin}
ops = zero_div_caller()
with self.assertRaises(errors.TfRuntimeError) as cm:
z = zero_div_caller()
zero_div_caller.ag_source_map = src_map
with errors.improved_errors(zero_div_caller):
with self.test_session() as sess:
sess.run(z)
expected = cm.exception
current_traceback = expected.custom_traceback
self.assertTrue(any('zero_div' in frame[2] for frame in current_traceback))
for frame in current_traceback:
self.assertNotEqual(frame, self._fake_origin.as_frame())
sess.run(ops)
def test_rewriting_error(self):
_, zero_div_lineno = tf_inspect.getsourcelines(zero_div)
src_map = {
errors.CodeLocation(
file_path=__file__, line_number=zero_div_lineno + 1):
None
}
with self.assertRaisesRegexp(tf_errors.InvalidArgumentError,
'Integer division by zero'):
z = zero_div_caller()
zero_div_caller.ag_source_map = src_map
all_function_names = set()
for frame in cm.exception.custom_traceback:
_, _, function_name, _ = frame
all_function_names.add(function_name)
self.assertNotEqual('test_function_name', function_name)
self.assertIn('zero_div', all_function_names)
def test_improved_errors_failures(self):
loc, _ = self.fake_origin(zero_div, 2)
zero_div_caller.ag_source_map = {loc: 'bogus object'}
ops = zero_div_caller()
with self.assertRaises(tf_errors.InvalidArgumentError):
with errors.improved_errors(zero_div_caller):
with self.test_session() as sess:
sess.run(z)
sess.run(ops)
def test_no_ag_source_map(self):
def test_improved_errors_validation(self):
with self.assertRaisesRegexp(
ValueError,
'converted_function must be the result of an autograph.to_graph call'):
with errors.improved_errors(None):
pass
def test_bad_ag_source_map(self):
errors.improved_errors(zero_div).__enter__()
with self.assertRaisesRegexp(
ValueError,
'converted_function must be the result of an autograph.to_graph call'):
src_map = None
zero_div_caller.ag_source_map = src_map
with errors.improved_errors(None):
pass
zero_div_caller.ag_source_map = 'not a dict'
errors.improved_errors(zero_div_caller).__enter__()
if __name__ == '__main__':

View File

@ -23,7 +23,6 @@ from functools import wraps
from enum import Enum
# pylint:disable=g-bad-import-order
import gast
import six
# pylint:enable=g-bad-import-order
@ -245,19 +244,21 @@ def to_graph(e,
_, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
arg_types)
module = gast.Module([])
nodes = []
for dep in reversed(program_ctx.dependency_cache.values()):
module.body.append(dep)
compiled_node, compiled_src = compiler.ast_to_object(
module, source_prefix=program_ctx.required_imports)
nodes.extend(dep)
compiled_module, compiled_src = compiler.ast_to_object(
nodes,
source_prefix=program_ctx.required_imports,
include_source_map=True)
# The compiled code should see everything the entry entity saw.
# TODO(mdan): This might not work well if the call tree spans modules?
for key, val in namespace.items():
# Avoid overwriting entities that have been transformed.
if key not in compiled_node.__dict__:
compiled_node.__dict__[key] = val
compiled_fn = getattr(compiled_node, name)
if key not in compiled_module.__dict__:
compiled_module.__dict__[key] = val
compiled_fn = getattr(compiled_module, name)
# Need this so the source_mapping attribute is available for the context
# manager to access for runtime errors.
@ -270,7 +271,7 @@ def to_graph(e,
'"%s", which is reserved for AutoGraph.' %
(compiled_fn, source_map_attribute_name))
setattr(compiled_fn, source_map_attribute_name,
compiled_node.__dict__['ag_source_map__'])
compiled_module.__dict__['ag_source_map__'])
if verbose:
logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
@ -308,7 +309,7 @@ def to_code(e,
conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)
code = '\n'.join(
compiler.ast_to_source(dep, indentation)[0]
compiler.ast_to_source(dep, indentation)
for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache))))
return program_ctx.required_imports + '\n\n' + code

View File

@ -164,7 +164,7 @@ def class_to_graph(c, program_ctx):
class_namespace = namespace
else:
class_namespace.update(namespace)
converted_members[m] = node
converted_members[m] = node[0]
namer = program_ctx.new_namer(class_namespace)
class_name = namer.compiled_class_name(c.__name__, c)
@ -175,10 +175,10 @@ def class_to_graph(c, program_ctx):
# program_ctx.update_name_map(namer)).
output_nodes = []
renames = {}
bases = []
base_names = []
for base in c.__bases__:
if isinstance(object, base):
bases.append('object')
base_names.append('object')
continue
if is_whitelisted_for_graph(base):
alias = namer.new_symbol(base.__name__, ())
@ -190,28 +190,28 @@ def class_to_graph(c, program_ctx):
else:
# This will trigger a conversion into a class with this name.
alias = namer.compiled_class_name(base.__name__, base)
bases.append(alias)
base_names.append(alias)
renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
program_ctx.update_name_map(namer)
# Generate the definition of the converted class.
output_nodes.append(
gast.ClassDef(
class_name,
bases=bases,
keywords=[],
body=list(converted_members.values()),
decorator_list=[]))
node = gast.Module(output_nodes)
bases = [gast.Name(n, gast.Load(), None) for n in base_names]
class_def = gast.ClassDef(
class_name,
bases=bases,
keywords=[],
body=list(converted_members.values()),
decorator_list=[])
# Make a final pass to replace references to the class or its base classes.
# Most commonly, this occurs when making super().__init__() calls.
# TODO(mdan): Making direct references to superclass' superclass will fail.
node = qual_names.resolve(node)
class_def = qual_names.resolve(class_def)
renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
node = ast_util.rename_symbols(node, renames)
class_def = ast_util.rename_symbols(class_def, renames)
return node, class_name, class_namespace
output_nodes.append(class_def)
return output_nodes, class_name, class_namespace
def _add_reserved_symbol(namespace, name, entity):
@ -279,7 +279,7 @@ def function_to_graph(f,
program_ctx.update_name_map(namer)
# TODO(mdan): Use this at compilation.
return node, new_name, namespace
return (node,), new_name, namespace
def node_to_graph(node, context, rewrite_errors=True):

View File

@ -60,10 +60,11 @@ class ConversionTest(test.TestCase):
return a + b
program_ctx = self._simple_program_ctx()
ast, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
self.assertTrue(isinstance(ast, gast.FunctionDef), ast)
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, = nodes
self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual('tf__f', name)
self.assertTrue(ns['b'] is b)
self.assertIs(ns['b'], b)
def test_entity_to_graph_call_tree(self):
@ -78,14 +79,11 @@ class ConversionTest(test.TestCase):
self.assertTrue(f in program_ctx.dependency_cache)
self.assertTrue(g in program_ctx.dependency_cache)
self.assertEqual('tf__f', program_ctx.dependency_cache[f].name)
# need one extra .body[0] in order to step past the try/except wrapper that
# is added automatically, the other for the with tf.name_scope('f') that is
# added automatically
self.assertEqual(
'tf__g',
program_ctx.dependency_cache[f].body[0].body[0].body[0].value.func.id)
self.assertEqual('tf__g', program_ctx.dependency_cache[g].name)
f_node = program_ctx.dependency_cache[f][0]
g_node = program_ctx.dependency_cache[g][0]
self.assertEqual('tf__f', f_node.name)
self.assertEqual('tf__g', f_node.body[0].body[0].body[0].value.func.id)
self.assertEqual('tf__g', g_node.name)
def test_entity_to_graph_class_hierarchy(self):
@ -118,9 +116,9 @@ class ConversionTest(test.TestCase):
self.assertTrue(TestBase in program_ctx.dependency_cache)
self.assertTrue(TestSubclass in program_ctx.dependency_cache)
self.assertEqual('TfTestBase',
program_ctx.dependency_cache[TestBase].body[-1].name)
program_ctx.dependency_cache[TestBase][-1].name)
self.assertEqual('TfTestSubclass',
program_ctx.dependency_cache[TestSubclass].body[-1].name)
program_ctx.dependency_cache[TestSubclass][-1].name)
def test_entity_to_graph_class_hierarchy_whitelisted(self):
@ -139,10 +137,9 @@ class ConversionTest(test.TestCase):
self.assertTrue(TestSubclass in program_ctx.dependency_cache)
self.assertFalse(training.Model in program_ctx.dependency_cache)
self.assertEqual(
'Model',
program_ctx.dependency_cache[TestSubclass].body[0].names[0].name)
'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name)
self.assertEqual('TfTestSubclass',
program_ctx.dependency_cache[TestSubclass].body[-1].name)
program_ctx.dependency_cache[TestSubclass][-1].name)
def test_entity_to_graph_lambda(self):
f = lambda a: a

View File

@ -99,6 +99,16 @@ py_test(
],
)
py_test(
name = "origin_info_test",
srcs = ["origin_info_test.py"],
srcs_version = "PY2AND3",
deps = [
":pyct",
"//tensorflow/python:client_testlib",
],
)
py_test(
name = "parser_test",
srcs = ["parser_test.py"],

View File

@ -20,7 +20,6 @@ from __future__ import print_function
import ast
import collections
import gast
from tensorflow.contrib.autograph.pyct import anno
@ -185,6 +184,7 @@ class PatternMatcher(gast.NodeVisitor):
if v != p:
return self.no_match()
def matches(node, pattern):
"""Basic pattern matcher for AST.
@ -253,30 +253,61 @@ def apply_to_single_assignments(targets, values, apply_fn):
apply_fn(target, values)
def iter_fields(node):
for field in sorted(node._fields):
try:
yield getattr(node, field)
except AttributeError:
pass
def parallel_walk(node, other):
"""Walks two ASTs in parallel.
The two trees must have identical structure.
def iter_child_nodes(node):
for field in iter_fields(node):
if isinstance(field, gast.AST):
yield field
elif isinstance(field, list):
for item in field:
if isinstance(item, gast.AST):
yield item
Args:
node: Union[ast.AST, Iterable[ast.AST]]
other: Union[ast.AST, Iterable[ast.AST]]
Yields:
Tuple[ast.AST, ast.AST]
Raises:
ValueError: if the two trees don't have identical structure.
"""
if isinstance(node, (list, tuple)):
node_stack = list(node)
else:
node_stack = [node]
if isinstance(other, (list, tuple)):
other_stack = list(other)
else:
other_stack = [other]
def parallel_walk(node_a, node_b):
todo_a = collections.deque([node_a])
todo_b = collections.deque([node_b])
while todo_a and todo_b:
node_a = todo_a.popleft()
node_b = todo_b.popleft()
todo_a.extend(iter_child_nodes(node_a))
todo_b.extend(iter_child_nodes(node_b))
yield node_a, node_b
while node_stack and other_stack:
assert len(node_stack) == len(other_stack)
n = node_stack.pop()
o = other_stack.pop()
if (not isinstance(n, (ast.AST, gast.AST)) or
not isinstance(o, (ast.AST, gast.AST)) or
n.__class__.__name__ != o.__class__.__name__):
raise ValueError('inconsistent nodes: {} and {}'.format(n, o))
yield n, o
for f in n._fields:
n_child = getattr(n, f, None)
o_child = getattr(o, f, None)
if f.startswith('__') or n_child is None or o_child is None:
continue
if isinstance(n_child, (list, tuple)):
if (not isinstance(o_child, (list, tuple)) or
len(n_child) != len(o_child)):
raise ValueError(
'inconsistent values for field {}: {} and {}'.format(
f, n_child, o_child))
node_stack.extend(n_child)
other_stack.extend(o_child)
elif isinstance(n_child, (gast.AST, ast.AST)):
node_stack.append(n_child)
other_stack.append(o_child)
elif n_child != o_child:
raise ValueError(
'inconsistent values for field {}: {} and {}'.format(
f, n_child, o_child))

View File

@ -44,7 +44,7 @@ class AstUtilTest(test.TestCase):
node, {qual_names.QN('a'): qual_names.QN('renamed_a')})
self.assertIsInstance(node.body[0].value.left.id, str)
source, _ = compiler.ast_to_source(node)
source = compiler.ast_to_source(node)
self.assertEqual(source.strip(), 'renamed_a + b')
def test_rename_symbols_attributes(self):
@ -54,7 +54,7 @@ class AstUtilTest(test.TestCase):
node = ast_util.rename_symbols(
node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})
source, _ = compiler.ast_to_source(node)
source = compiler.ast_to_source(node)
self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d')
def test_rename_symbols_annotations(self):
@ -97,10 +97,10 @@ class AstUtilTest(test.TestCase):
d = ast_util.keywords_to_dict(keywords)
# Make sure we generate a usable dict node by attaching it to a variable and
# compiling everything.
output = parser.parse_str('b = 3')
output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),)
result, _ = compiler.ast_to_object(output)
self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'})
node = parser.parse_str('def f(b): pass').body[0]
node.body.append(ast.Return(d))
result, _ = compiler.ast_to_object(node)
self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'})
def assertMatch(self, target_str, pattern_str):
node = parser.parse_expression(target_str)
@ -130,8 +130,8 @@ class AstUtilTest(test.TestCase):
'super(Bar, _).__init__(_)')
def _mock_apply_fn(self, target, source):
target, _ = compiler.ast_to_source(target)
source, _ = compiler.ast_to_source(source)
target = compiler.ast_to_source(target)
source = compiler.ast_to_source(source)
self._invocation_counts[(target.strip(), source.strip())] += 1
def test_apply_to_single_assignments_dynamic_unpack(self):
@ -157,24 +157,40 @@ class AstUtilTest(test.TestCase):
})
def test_parallel_walk(self):
ret = ast.Return(
ast.BinOp(
op=ast.Add(),
left=ast.Name(id='a', ctx=ast.Load()),
right=ast.Num(1)))
node = ast.FunctionDef(
name='f',
args=ast.arguments(
args=[ast.Name(id='a', ctx=ast.Param())],
vararg=None,
kwarg=None,
defaults=[]),
body=[ret],
decorator_list=[],
returns=None)
node = parser.parse_str(
textwrap.dedent("""
def f(a):
return a + 1
"""))
for child_a, child_b in ast_util.parallel_walk(node, node):
self.assertEqual(child_a, child_b)
def test_parallel_walk_inconsistent_trees(self):
node_1 = parser.parse_str(
textwrap.dedent("""
def f(a):
return a + 1
"""))
node_2 = parser.parse_str(
textwrap.dedent("""
def f(a):
return a + (a * 2)
"""))
node_3 = parser.parse_str(
textwrap.dedent("""
def f(a):
return a + 2
"""))
with self.assertRaises(ValueError):
for _ in ast_util.parallel_walk(node_1, node_2):
pass
# There is not particular reason to reject trees that differ only in the
# value of a constant.
# TODO(mdan): This should probably be allowed.
with self.assertRaises(ValueError):
for _ in ast_util.parallel_walk(node_1, node_3):
pass
if __name__ == '__main__':
test.main()

View File

@ -67,10 +67,8 @@ class Node(object):
if isinstance(self.ast_node, gast.FunctionDef):
return 'def %s' % self.ast_node.name
elif isinstance(self.ast_node, gast.withitem):
source, _ = compiler.ast_to_source(self.ast_node.context_expr)
return source.strip()
source, _ = compiler.ast_to_source(self.ast_node)
return source.strip()
return compiler.ast_to_source(self.ast_node.context_expr).strip()
return compiler.ast_to_source(self.ast_node).strip()
class Graph(

View File

@ -43,7 +43,7 @@ class AnfTransformerTest(test.TestCase):
return a
node, _ = parser.parse_entity(test_function)
node = anf.transform(node, self._simple_source_info())
node = anf.transform(node.body[0], self._simple_source_info())
result, _ = compiler.ast_to_object(node)
self.assertEqual(test_function(), result.test_function())

View File

@ -30,44 +30,7 @@ import tempfile
import astor
import gast
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import origin_info
from tensorflow.contrib.autograph.pyct import parser
def _build_source_map(node, code):
"""Return the Python objects represented by given AST.
Compiling the AST code this way ensures that the source code is readable by
e.g. `pdb` or `inspect`.
Args:
node: An AST node of the original generated code, before the source code is
generated.
code: The string representation of the source code for the newly generated
code.
Returns:
Dict[CodeLocation, OriginInfo], a mapping between the user and AutoGraph
generated code.
"""
# After we have the final generated code we reparse it to get the final line
# numbers. Then we walk through the generated and original ASTs in parallel
# to build the mapping between the user and generated code.
new_node = parser.parse_str(code)
origin_info.resolve(new_node, code)
source_mapping = {}
for before, after in ast_util.parallel_walk(node, new_node):
# Need both checks because if origin information is ever copied over to new
# nodes then we need to rely on the fact that only the original user code
# has the origin annotation.
if (anno.hasanno(before, anno.Basic.ORIGIN) and
anno.hasanno(after, anno.Basic.ORIGIN)):
source_info = anno.getanno(before, anno.Basic.ORIGIN)
new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
source_mapping[new_line_number] = source_info
return source_mapping
def ast_to_source(node, indentation=' '):
@ -81,24 +44,28 @@ def ast_to_source(node, indentation=' '):
code: The source code generated from the AST object
source_mapping: A mapping between the user and AutoGraph generated code.
"""
original_node = node
if isinstance(node, gast.AST):
node = gast.gast_to_ast(node)
if not isinstance(node, (list, tuple)):
node = (node,)
generator = astor.codegen.SourceGenerator(indentation, False,
astor.string_repr.pretty_string)
generator.visit(node)
generator.result.append('\n')
for n in node:
if isinstance(n, gast.AST):
n = gast.gast_to_ast(n)
generator.visit(n)
generator.result.append('\n')
# In some versions of Python, literals may appear as actual values. This
# ensures everything is string.
code = map(str, generator.result)
code = astor.source_repr.pretty_source(code).lstrip()
source_mapping = _build_source_map(original_node, code)
return code, source_mapping
return code
def ast_to_object(node,
def ast_to_object(nodes,
indentation=' ',
include_source_map=False,
source_prefix=None,
delete_on_exit=True):
"""Return the Python objects represented by given AST.
@ -107,42 +74,46 @@ def ast_to_object(node,
e.g. `pdb` or `inspect`.
Args:
node: The code to compile, as an AST object.
indentation: The string to use for indentation.
source_prefix: Optional string to print as-is into the source file.
delete_on_exit: Whether to delete the temporary file used for compilation on
exit.
nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST
object.
indentation: Text, the string to use for indentation.
include_source_map: bool, whether to attach a source map to the compiled
object. Also see origin_info.py.
source_prefix: Optional[Text], string to print as-is into the source file.
delete_on_exit: bool, whether to delete the temporary file used for
compilation on exit.
Returns:
compiled_node: A module object containing the compiled source code.
compiled_nodes: A module object containing the compiled source code.
source: The source code of the compiled object
Raises:
ValueError: If ag_source_map__ is already in the namespace of the compiled
node.
nodes.
"""
# code_source_mapping does not yet include the offsets from import statements.
source, code_source_mapping = ast_to_source(node, indentation=indentation)
if not isinstance(nodes, (list, tuple)):
nodes = (nodes,)
source = ast_to_source(nodes, indentation=indentation)
if source_prefix:
source = source_prefix + '\n' + source
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
# TODO(znado): move into an _offset_source_map() helper function.
# Need to offset the generated line numbers by the number of import lines.
if source_prefix:
num_import_lines = source_prefix.count('\n') + 1
else:
num_import_lines = 0
source_mapping = {}
for line_number, original_position in code_source_mapping.items():
source_map_key = origin_info.CodeLocation(
file_path=f.name, line_number=line_number + num_import_lines)
source_mapping[source_map_key] = original_position
module_name = os.path.basename(f.name[:-3])
if source_prefix:
f.write(source_prefix)
f.write('\n')
f.write(source)
if isinstance(nodes, (list, tuple)):
indices = range(-len(nodes), 0)
else:
indices = (-1,)
if include_source_map:
source_map = origin_info.source_map(nodes, source, f.name, indices)
# TODO(mdan): Try flush() and delete=False instead.
if delete_on_exit:
atexit.register(lambda: os.remove(f.name))
compiled_node = imp.load_source(module_name, f.name)
compiled_nodes = imp.load_source(module_name, f.name)
# TODO(znado): Clean this up so we don't need to attach it to the namespace.
# TODO(znado): This does not work for classes because their methods share a
@ -158,11 +129,13 @@ def ast_to_object(node,
# is hard, and this cleanly fixes the
# issues encountered with nested functions because this is attached to the
# outermost one.
source_map_name = 'ag_source_map__'
if source_map_name in compiled_node.__dict__:
raise ValueError('cannot convert %s because is has namespace attribute '
'"%s", which is reserved for AutoGraph.' %
(compiled_node, source_map_name))
compiled_node.__dict__[source_map_name] = source_mapping
if include_source_map:
# TODO(mdan): This name should be decided by the caller.
source_map_name = 'ag_source_map__'
if source_map_name in compiled_nodes.__dict__:
raise ValueError('cannot convert %s because is has namespace attribute '
'"%s", which is reserved for AutoGraph.' %
(compiled_nodes, source_map_name))
compiled_nodes.__dict__[source_map_name] = source_map
return compiled_node, source
return compiled_nodes, source

View File

@ -59,7 +59,7 @@ class CompilerTest(test.TestCase):
value=gast.Str('c'))
])
source, _ = compiler.ast_to_source(node, indentation=' ')
source = compiler.ast_to_source(node, indentation=' ')
self.assertEqual(
textwrap.dedent("""
if 1:

View File

@ -22,49 +22,115 @@ import collections
import gast
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.python.util import tf_inspect
class CodeLocation(
collections.namedtuple('CodeLocation', ('file_path', 'line_number'))):
"""Location of a line of code.
class LineLocation(
collections.namedtuple('LineLocation', ('filename', 'lineno'))):
"""Similar to Location, but without column information.
Attributes:
file_path: text, the full path to the file containing the code.
line_number: Int, the 1-based line number of the code in its file.
filename: Text
lineno: int, 1-based
"""
pass
class Location(
collections.namedtuple('Location', ('filename', 'lineno', 'col_offset'))):
"""Encodes code location information.
Attributes:
filename: Text
lineno: int, 1-based
col_offset: int
"""
@property
def line_loc(self):
return LineLocation(self.filename, self.lineno)
class OriginInfo(
collections.namedtuple('OriginInfo',
('file_path', 'function_name', 'line_number',
'column_offset', 'source_code_line'))):
collections.namedtuple(
'OriginInfo',
('loc', 'function_name', 'source_code_line'))):
"""Container for information about the source code before conversion.
Instances of this class contain information about the source code that
transformed code originated from. Examples include:
* line number
* file name
* original user code
Attributes:
loc: Location
function_name: Optional[Text]
source_code_line: Text
"""
def as_frame(self):
"""Makes a traceback frame tuple.
Returns:
A tuple of (file_path, line_number, function_name, source_code_line).
"""
return (self.file_path, self.line_number, self.function_name,
"""Returns a 4-tuple consistent with the return of traceback.extract_tb."""
return (self.loc.filename, self.loc.lineno, self.function_name,
self.source_code_line)
# TODO(mdan): This source map should be a class - easier to refer to.
def source_map(nodes, code, filename, indices_in_code):
"""Creates a source map between an annotated AST and the code it compiles to.
Args:
nodes: Iterable[ast.AST, ...]
code: Text
filename: Optional[Text]
indices_in_code: Union[int, Iterable[int, ...]], the positions at which
nodes appear in code. The parser always returns a module when parsing
code. This argument indicates the position in that module's body at
which the corresponding of node should appear.
Returns:
Dict[CodeLocation, OriginInfo], mapping locations in code to locations
indicated by origin annotations in node.
"""
reparsed_nodes = parser.parse_str(code)
reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code]
resolve(reparsed_nodes, code)
result = {}
for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
# Note: generated code might not be mapped back to its origin.
# TODO(mdan): Generated code should always be mapped to something.
origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
if origin_info is None or final_info is None:
continue
line_loc = LineLocation(filename, final_info.loc.lineno)
existing_origin = result.get(line_loc)
if existing_origin is not None:
# Overlaps may exist because of child nodes, but almost never to
# different line locations. Exception make decorated functions, where
# both lines are mapped to the same line in the AST.
# Line overlaps: keep bottom node.
if existing_origin.loc.line_loc == origin_info.loc.line_loc:
if existing_origin.loc.lineno >= origin_info.loc.lineno:
continue
# In case of overlaps, keep the leftmost node.
if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
continue
result[line_loc] = origin_info
return result
# TODO(znado): Consider refactoring this into a Visitor.
def resolve(node, source, function=None):
# TODO(mdan): Does this work correctly with inner functions?
def resolve(nodes, source, function=None):
"""Adds an origin information to all nodes inside the body of function.
Args:
node: The AST node for the function whose body nodes will be annotated.
nodes: Union[ast.AST, Iterable[ast.AST, ...]]
source: Text, the source code string for the function whose body nodes will
be annotated.
function: Callable, the function that will have all nodes inside of it
@ -76,25 +142,32 @@ def resolve(node, source, function=None):
A tuple of the AST node for function and a String containing its source
code.
"""
if not isinstance(nodes, (list, tuple)):
nodes = (nodes,)
if function:
_, function_lineno = tf_inspect.getsourcelines(function)
function_filepath = tf_inspect.getsourcefile(function)
else:
function_lineno = None
function_filepath = None
source_lines = source.split('\n')
for n in gast.walk(node):
if hasattr(n, 'lineno'):
# n.lineno is relative to the start of the enclosing function, so need to
# offset it by the line of the function.
source_code_line = source_lines[n.lineno - 1]
for node in nodes:
for n in gast.walk(node):
if not hasattr(n, 'lineno'):
continue
lineno_in_body = n.lineno
source_code_line = source_lines[lineno_in_body - 1]
if function:
source_lineno = n.lineno + function_lineno - 1
source_lineno = function_lineno + lineno_in_body
function_name = function.__name__
else:
source_lineno = n.lineno
source_lineno = lineno_in_body
function_name = None
anno.setanno(
n, anno.Basic.ORIGIN,
OriginInfo(function_filepath, function_name, source_lineno,
n.col_offset, source_code_line))
location = Location(function_filepath, source_lineno, n.col_offset)
origin = OriginInfo(location, function_name, source_code_line)
anno.setanno(n, anno.Basic.ORIGIN, origin)

View File

@ -0,0 +1,101 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for origin_info module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import origin_info
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.python.platform import test
class OriginInfoTest(test.TestCase):
def test_source_map(self):
def test_fn(x):
if x > 0:
x += 1
return x
node, source = parser.parse_entity(test_fn)
fn_node = node.body[0]
origin_info.resolve(fn_node, source)
# Insert a traced line.
new_node = parser.parse_str('x = abs(x)').body[0]
anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
fn_node.body.insert(0, new_node)
# Insert an untraced line.
fn_node.body.insert(0, parser.parse_str('x = 0').body[0])
modified_source = compiler.ast_to_source(fn_node)
source_map = origin_info.source_map(fn_node, modified_source,
'test_filename', [0])
loc = origin_info.LineLocation('test_filename', 1)
origin = source_map[loc]
self.assertEqual(origin.source_code_line, 'def test_fn(x):')
self.assertEqual(origin.loc.lineno, 1)
# The untraced line, inserted second.
loc = origin_info.LineLocation('test_filename', 2)
self.assertFalse(loc in source_map)
# The traced line, inserted first.
loc = origin_info.LineLocation('test_filename', 3)
origin = source_map[loc]
self.assertEqual(origin.source_code_line, ' if x > 0:')
self.assertEqual(origin.loc.lineno, 2)
loc = origin_info.LineLocation('test_filename', 4)
origin = source_map[loc]
self.assertEqual(origin.source_code_line, ' if x > 0:')
self.assertEqual(origin.loc.lineno, 2)
def test_resolve(self):
def test_fn(x):
"""Docstring."""
return x # comment
node, source = parser.parse_entity(test_fn)
fn_node = node.body[0]
origin_info.resolve(fn_node, source)
origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 1)
self.assertEqual(origin.loc.col_offset, 0)
self.assertEqual(origin.source_code_line, 'def test_fn(x):')
origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 2)
self.assertEqual(origin.loc.col_offset, 2)
self.assertEqual(origin.source_code_line, ' """Docstring."""')
origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 3)
self.assertEqual(origin.loc.col_offset, 2)
self.assertEqual(origin.source_code_line, ' return x # comment')
if __name__ == '__main__':
test.main()

View File

@ -37,6 +37,7 @@ def parse_entity(entity):
def parse_str(src):
"""Returns the AST of given piece of code."""
# TODO(mdan): This should exclude the module things are autowrapped in.
return gast.parse(src)