diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py index 30179333b87..5a5a2c95dde 100644 --- a/tensorflow/python/autograph/converters/call_trees.py +++ b/tensorflow/python/autograph/converters/call_trees.py @@ -174,14 +174,15 @@ class CallTreeTransformer(converter.Base): keywords=ast_util.keywords_to_dict(normal_keywords)) template = """ - ag__.converted_call(func, options, args, kwargs) + ag__.converted_call(func, options, args, kwargs, function_ctx) """ new_call = templates.replace_as_expression( template, func=func, options=parser.parse_expression(function_context_name + '.callopts'), args=args, - kwargs=kwargs) + kwargs=kwargs, + function_ctx=function_context_name) return new_call diff --git a/tensorflow/python/autograph/converters/function_scopes.py b/tensorflow/python/autograph/converters/function_scopes.py index 4b33b6bf24f..52bd701b790 100644 --- a/tensorflow/python/autograph/converters/function_scopes.py +++ b/tensorflow/python/autograph/converters/function_scopes.py @@ -56,18 +56,20 @@ class FunctionBodyTransformer(converter.Base): return node scope = anno.getanno(node, anno.Static.SCOPE) - function_context_name = self.ctx.namer.new_symbol( - 'lambda_scope', scope.referenced) + function_context_name = self.ctx.namer.new_symbol('lambda_scope', + scope.referenced) self.state[_Function].context_name = function_context_name anno.setanno(node, 'function_context_name', function_context_name) template = """ - ag__.with_function_scope(lambda function_context_name: body, options) + ag__.with_function_scope( + lambda function_context: body, function_context_name, options) """ node.body = templates.replace_as_expression( template, options=self.ctx.program.options.to_ast(), - function_context_name=function_context_name, + function_context=function_context_name, + function_context_name=gast.Str(function_context_name), body=node.body) self.state[_Function].exit() @@ -93,14 +95,16 @@ class FunctionBodyTransformer(converter.Base): node.body = node.body[1:] template = """ - with ag__.FunctionScope(function_name, options) as function_context_name: + with ag__.FunctionScope( + function_name, context_name, options) as function_context: body """ wrapped_body = templates.replace( template, function_name=gast.Str(node.name), + context_name=gast.Str(function_context_name), options=self.ctx.program.options.to_ast(), - function_context_name=function_context_name, + function_context=function_context_name, body=node.body) if docstring_node is not None: diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py index f7f0d93c7a9..e9bf009d029 100644 --- a/tensorflow/python/autograph/core/converter.py +++ b/tensorflow/python/autograph/core/converter.py @@ -262,13 +262,15 @@ class EntityContext(transformer.Context): Attributes: namer: Namer info: transformer.EntityInfo - program: ProgramContext + program: ProgramContext, + targe_name: Text """ - def __init__(self, namer, entity_info, program_ctx): + def __init__(self, namer, entity_info, program_ctx, target_name=None): super(EntityContext, self).__init__(entity_info) self.namer = namer self.program = program_ctx + self.target_name = target_name class Base(transformer.Base): diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index d4b1daf921e..7560b436ef5 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -57,7 +57,7 @@ class TestCase(test.TestCase): self.dynamic_calls = [] # See api.converted_call - def converted_call(f, unused_opts, args, kwargs): + def converted_call(f, unused_opts, args, kwargs, unused_function_ctx): """Mock version of api.converted_call.""" self.dynamic_calls.append((args, kwargs)) if kwargs is None: @@ -135,7 +135,8 @@ class TestCase(test.TestCase): source_file='', future_features=future_features, namespace=namespace) - ctx = converter.EntityContext(namer, entity_info, program_ctx) + ctx = converter.EntityContext( + namer, entity_info, program_ctx, 'test_fn') origin_info.resolve_entity(node, source, test_fn) node = converter.standard_analysis(node, ctx, is_initial=True) return node, ctx diff --git a/tensorflow/python/autograph/core/function_wrappers.py b/tensorflow/python/autograph/core/function_wrappers.py index e981d6b4ce9..55b1071b029 100644 --- a/tensorflow/python/autograph/core/function_wrappers.py +++ b/tensorflow/python/autograph/core/function_wrappers.py @@ -40,12 +40,13 @@ class FunctionScope(object): conversion options; """ - def __init__(self, function_name, options): + def __init__(self, function_name, scope_name, options): + self.name = scope_name self.options = options if options.user_requested: - self.autograph_ctx = ag_ctx.ControlStatusCtx( - ag_ctx.Status.ENABLED, options) + self.autograph_ctx = ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED, + options) self.callopts = options.call_options() use_name_scope = options.uses(converter.Feature.NAME_SCOPES) @@ -101,7 +102,7 @@ class FunctionScope(object): return value -def with_function_scope(thunk, options): +def with_function_scope(thunk, scope_name, options): """Inline version of the FunctionScope context manager.""" - with FunctionScope('lambda_', options) as scope: + with FunctionScope('lambda_', scope_name, options) as scope: return thunk(scope) diff --git a/tensorflow/python/autograph/core/function_wrappers_test.py b/tensorflow/python/autograph/core/function_wrappers_test.py index cd107096e7e..01918007bbd 100644 --- a/tensorflow/python/autograph/core/function_wrappers_test.py +++ b/tensorflow/python/autograph/core/function_wrappers_test.py @@ -33,7 +33,7 @@ class FunctionWrappersTest(test.TestCase): self.skipTest('Tensor names are disabled in eager') with function_wrappers.FunctionScope( - 'test_name', + 'test_name', None, converter.ConversionOptions( optional_features=converter.Feature.NAME_SCOPES)): t = constant_op.constant(1) @@ -42,7 +42,7 @@ class FunctionWrappersTest(test.TestCase): def test_auto_cotrol_deps(self): v = variables.Variable(1) with function_wrappers.FunctionScope( - '_', + '_', None, converter.ConversionOptions( optional_features=converter.Feature.AUTO_CONTROL_DEPS)) as scope: v.assign(2) @@ -51,9 +51,11 @@ class FunctionWrappersTest(test.TestCase): self.assertEqual(self.evaluate(v.read_value()), 2) def test_all_disabled(self): - with function_wrappers.FunctionScope(None, converter.STANDARD_OPTIONS): + with function_wrappers.FunctionScope(None, None, + converter.STANDARD_OPTIONS): t = constant_op.constant(1) self.assertEqual(self.evaluate(t), 1) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 7d2b33c3842..283e294a79b 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -362,8 +362,23 @@ def _is_known_loaded_type(f, module_name, entity_name): return False -def converted_call(f, options, args, kwargs): - """Compiles a function call inline. For internal use only.""" +def converted_call(f, options, args, kwargs, caller_fn_scope=None): + """Compiles a function call inline. + + For internal use only. + + Args: + f: The function to convert. + options: converter.ConversionOptions + args: Tuple, the original positional arguments of f + kwargs: Dict, the original keyword arguments of f + caller_fn_scope: Optional[function_wrappers.FunctionScope], the function + scope of the converted function in which this call was originally made. + + Returns: + Any, the result of executing a possibly-converted `f` with the given + arguments. + """ logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args, kwargs) @@ -372,9 +387,9 @@ def converted_call(f, options, args, kwargs): if inspect_utils.isbuiltin(f): if f is eval: - return py_builtins.eval_in_original_context(f, args, 1) + return py_builtins.eval_in_original_context(f, args, caller_fn_scope) if f is super: - return py_builtins.super_in_original_context(f, args, 1) + return py_builtins.super_in_original_context(f, args, caller_fn_scope) if kwargs: return py_builtins.overload_of(f)(*args, **kwargs) else: diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index 1538c6df8e1..a0275725ad1 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -653,24 +653,27 @@ def convert_func_to_ast(f, program_ctx, do_rename=True): _add_self_references(namespace, program_ctx.autograph_module) namer = naming.Namer(namespace) + if isinstance(node, gast.Lambda): + new_name = namer.new_symbol('tf__lambda', ()) + elif do_rename: + new_name = namer.function_name(f.__name__) + else: + new_name = f.__name__ + entity_info = transformer.EntityInfo( source_code=source, source_file='', future_features=future_features, namespace=namespace) - context = converter.EntityContext(namer, entity_info, program_ctx) + context = converter.EntityContext(namer, entity_info, program_ctx, new_name) node = node_to_graph(node, context) if isinstance(node, gast.Lambda): - new_name = namer.new_symbol('tf__lambda', ()) node = gast.Assign( targets=[gast.Name(new_name, gast.Store(), None)], value=node) - elif do_rename: - new_name = namer.function_name(f.__name__) node.name = new_name else: - new_name = f.__name__ assert node.name == new_name return (node,), new_name, entity_info diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD index 1337b1e1c83..dd7acdabd86 100644 --- a/tensorflow/python/autograph/operators/BUILD +++ b/tensorflow/python/autograph/operators/BUILD @@ -100,6 +100,7 @@ py_test( deps = [ ":operators", "//tensorflow/python:client_testlib", + "//tensorflow/python/autograph/core", ], ) diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py index e2d52065ef8..435e1030e36 100644 --- a/tensorflow/python/autograph/operators/py_builtins.py +++ b/tensorflow/python/autograph/operators/py_builtins.py @@ -49,11 +49,34 @@ def overload_of(f): return f -def eval_in_original_context(f, args, caller_level_delta): - """Executes the eval function with the user-specified globals/locals.""" +def _find_originating_frame(caller_fn_scope, innermost=True): + """Locates the frame in which `caller_fn_scope` was defined.""" ctx_frame = inspect.currentframe() - for _ in range(caller_level_delta + 1): + result = None + while ctx_frame is not None: + # Note it should not be normally possible to get false positives this way + # because the function scope object is not accessible to user code (barring + # call stack introspection). + if ctx_frame.f_locals.get(caller_fn_scope.name, None) is caller_fn_scope: + result = ctx_frame + if innermost: + break ctx_frame = ctx_frame.f_back + + assert result is not None, ( + 'the conversion process should ensure the caller_fn_scope is always' + ' found somewhere on the call stack') + + return result + + +def eval_in_original_context(f, args, caller_fn_scope): + """Executes the eval function in the context of a specified function.""" + # When control flow is rewritten using functions, eval should use the + # variables found in the same block where it was called. That is equivalent + # to the innermost function call. + ctx_frame = _find_originating_frame(caller_fn_scope, innermost=True) + args = ( args[0], ctx_frame.f_globals if len(args) < 2 else args[1], @@ -62,33 +85,34 @@ def eval_in_original_context(f, args, caller_level_delta): return f(*args) -def super_in_original_context(f, args, caller_level_delta): - """Executes the super function with the correct implicit argument handling. +def super_in_original_context(f, args, caller_fn_scope): + """Executes the super function in the context of a specified function. See https://docs.python.org/3/library/functions.html#super for the exact details Args: - f: super builtin function object. - args: Arguments that will be passed to super(...). A valid call should have - 0, 1, or 2 number of arguments - caller_level_delta: The number of nested frames to the original super(...)'s - context frame. + f: Callable, typically the super builtin + args: List[Any], the original call arguments + caller_fn_scope: Optional[function_wrappers.FunctionScope], the function + scope of the converted function in which this call was originally made Returns: - The result of super(...) call. + The result of calling `f` as if it was called in the frame indicated by + `caller_fn_scope`. """ # Python 2 doesn't support implicit argument super variants. if six.PY2: - return overload_of(f)(*args) + return f(*args) - if len(args) >= 1: - return overload_of(f)(*args) + # Only the no-arg call is desugared. + if args: + return f(*args) - ctx_frame = inspect.currentframe() - for _ in range(caller_level_delta + 1): - ctx_frame = ctx_frame.f_back + # Inner functions seem to include their closure in f_locals, so we need + # to find the outermost frame. + ctx_frame = _find_originating_frame(caller_fn_scope, innermost=False) # When super(..) is called without arguments, it looks for __class__ cell # variable and the first argument passed in the enclosing function according @@ -104,33 +128,23 @@ def super_in_original_context(f, args, caller_level_delta): # https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py2_test_grammar.py#L157 # https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py3_test_grammar.py#L192 # - # TODO(kkimlabs): mdan@ had an idea to do it more correctly without relying - # on the co_varnames argument order assumption. - # 1. Getting the caller function from the call stack. - # 2. Getting its argspec. - # 3. Get the name of the first argument from argspec. - # 4. Retrieve its value from locals. + # Note: the name can be more reliably obtained by inspecting the calling + # function's argspec. # - # Sample code snippet: + # Even though methods can be declared using *args (def method(*args)), + # that pattern is disallowed by super() -- it raises super() no arguments. + # Method definitions using **kwargs are not allowed at all. + # In other words, we can always assume that self is on the first positional + # argument (for correct code). # - # def fn2(): - # fr = inspect.currentframe() - # parent_fr = fr.f_back - # grandparent_fr = parent_fr.f_back - # f_name = parent_fr.f_code.co_name - # f = grandparent_fr.f_locals[f_name] - # - # def fn1(): - # fn2() - # - # fn1() - # - # However, we also need to handle some edge cases like - # function in the closure or globals, etc,... + # TODO(mdan): Consider additional checks in case the input code is incorrect. + # For example, the error might be cryptic compared to what super() regularly + # raises. - first_arg_name = ctx_frame.f_code.co_varnames[0] - first_arg = ctx_frame.f_locals[first_arg_name] - return f(ctx_frame.f_locals['__class__'], first_arg) + type_arg = ctx_frame.f_locals['__class__'] + self_arg_name = ctx_frame.f_code.co_varnames[0] + self_arg = ctx_frame.f_locals[self_arg_name] + return f(type_arg, self_arg) def abs_(x): diff --git a/tensorflow/python/autograph/operators/py_builtins_py3_test.py b/tensorflow/python/autograph/operators/py_builtins_py3_test.py index bd77139d214..11a33b90b75 100644 --- a/tensorflow/python/autograph/operators/py_builtins_py3_test.py +++ b/tensorflow/python/autograph/operators/py_builtins_py3_test.py @@ -18,116 +18,105 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.core import function_wrappers from tensorflow.python.autograph.operators import py_builtins from tensorflow.python.platform import test +class TestBaseClass(object): + + def overridden_method(self, x): + return x + 20 + + class PyBuiltinsTest(test.TestCase): - def test_super_with_no_arg_in_original_context(self): + def _basic_function_scope(self): + return function_wrappers.FunctionScope( + 'test_function_name', + 'test_scope', # Note: this must match the name in the `with` statement. + converter.ConversionOptions()) + + def test_super_in_original_context_niladic_call(self): test_case_self = self - class TestBase(object): + class TestSubclass(TestBaseClass): - def plus_twenty(self, x): - return x + 20 - - class TestSubclass(TestBase): - - def plus_twenty(self, x): + def overridden_method(self, x): test_case_self.fail('This should never be called.') - def no_arg(self): - test_base = py_builtins.super_in_original_context(super, (), 0) - return test_base.plus_twenty(1) + def test_method(self): + with test_case_self._basic_function_scope() as test_scope: + b = py_builtins.super_in_original_context(super, (), test_scope) + return b.overridden_method(1) tc = TestSubclass() - self.assertEqual(tc.no_arg(), 21) + self.assertEqual(tc.test_method(), 21) - def test_super_in_original_context_with_locals(self): + def test_super_in_original_context_caller_with_locals(self): test_case_self = self - class TestBase(object): + class TestSubclass(TestBaseClass): - def plus_twenty(self, x): - return x + 20 - - class TestSubclass(TestBase): - - def plus_twenty(self, x): + def overridden_method(self, x): test_case_self.fail('This should never be called.') - def with_locals(self): - x = 1 + def test_method(self, x): y = 7 - z = 7 - - test_base = py_builtins.super_in_original_context(super, (), 0) - return test_base.plus_twenty(x + y - z) + with test_case_self._basic_function_scope() as test_scope: + z = 7 + return py_builtins.super_in_original_context( + super, (), test_scope).overridden_method(x + y - z) tc = TestSubclass() - self.assertEqual(tc.with_locals(), 21) + self.assertEqual(tc.test_method(1), 21) - def test_super_in_original_context_with_args(self): + def test_super_in_original_context_inner_function(self): test_case_self = self - class TestBase(object): + class TestSubclass(TestBaseClass): - def plus_twenty(self, x): - return x + 20 - - class TestSubclass(TestBase): - - def plus_twenty(self, x): + def overridden_method(self, x): test_case_self.fail('This should never be called.') - def with_args(self, x, y, z): - test_base = py_builtins.super_in_original_context(super, (), 0) - return test_base.plus_twenty(x + y - z) + def test_method(self, x): + with test_case_self._basic_function_scope() as test_scope: + # Oddly, it's sufficient to use `self` in an inner function + # to gain access to __class__ in this scope. + # TODO(mdan): Is this true across implementations? + # Note: normally, it's illegal to use super() in inner functions (it + # throws an error), but the generated code may create them. + def inner_fn(): + return py_builtins.super_in_original_context( + super, (), test_scope).overridden_method(x) + + return inner_fn() tc = TestSubclass() - self.assertEqual(tc.with_args(1, 7, 7), 21) - self.assertEqual(tc.with_args.__func__(*[tc, 1, 7, 7]), 21) + self.assertEqual(tc.test_method(1), 21) - def test_super_in_original_context_with_varargs(self): + def test_super_in_original_context_inner_lambda(self): test_case_self = self - class TestBase(object): + class TestSubclass(TestBaseClass): - def plus_twenty(self, x): - return x + 20 - - class TestSubclass(TestBase): - - def plus_twenty(self, x): + def overridden_method(self, x): test_case_self.fail('This should never be called.') - def with_varargs(self, *args): - test_base = py_builtins.super_in_original_context(super, (), 0) - return test_base.plus_twenty(args[0] + args[1] - args[2]) + def test_method(self, x): + with test_case_self._basic_function_scope() as test_scope: + # Oddly, it's sufficient to use `self` in an inner function + # to gain access to __class__ in this scope. + # TODO(mdan): Is this true across implementations? + # Note: normally, it's illegal to use super() in inner functions (it + # throws an error), but the generated code may create them. + l = lambda: py_builtins.super_in_original_context( # pylint:disable=g-long-lambda + super, (), test_scope).overridden_method(x) + return l() tc = TestSubclass() - self.assertEqual(tc.with_varargs.__func__(*[tc, 1, 7, 7]), 21) - - def test_super_in_original_context_with_kwargs(self): - test_case_self = self - - class TestBase(object): - - def plus_twenty(self, x): - return x + 20 - - class TestSubclass(TestBase): - - def plus_twenty(self, x): - test_case_self.fail('This should never be called.') - - def with_kwargs(self, **kwargs): - test_base = py_builtins.super_in_original_context(super, (), 0) - return test_base.plus_twenty(kwargs['x'] + kwargs['y'] - kwargs['z']) - - tc = TestSubclass() - self.assertEqual(tc.with_kwargs.__func__(tc, x=1, y=7, z=7), 21) + self.assertEqual(tc.test_method(1), 21) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py index 0e3f8f5cd52..e706a281ad7 100644 --- a/tensorflow/python/autograph/operators/py_builtins_test.py +++ b/tensorflow/python/autograph/operators/py_builtins_test.py @@ -22,6 +22,8 @@ import sys import six +from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.core import function_wrappers from tensorflow.python.autograph.operators import data_structures from tensorflow.python.autograph.operators import py_builtins from tensorflow.python.data.ops import dataset_ops @@ -34,6 +36,12 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test +class TestBase(object): + + def plus_twenty(self, x): + return x + 20 + + class PyBuiltinsTest(test.TestCase): def test_abs(self): @@ -155,66 +163,71 @@ class PyBuiltinsTest(test.TestCase): self.assertAllEqual(self.evaluate(iterator.get_next()), (20, b'a')) self.assertAllEqual(self.evaluate(iterator.get_next()), (21, b'c')) + def _basic_function_scope(self): + return function_wrappers.FunctionScope( + 'test_function_name', + 'test_scope', # Note: this must match the name in the `with` statement. + converter.ConversionOptions()) + def test_eval_in_original_context(self): - def caller_1(lvl_delta): + def test_fn(): l = 1 # pylint:disable=unused-variable - return py_builtins.eval_in_original_context(eval, ('l',), lvl_delta) + with self._basic_function_scope() as test_scope: + return py_builtins.eval_in_original_context(eval, ('l',), test_scope) - def caller_2(lvl_delta): - l = 2 # pylint:disable=unused-variable - return caller_1(lvl_delta) + self.assertEqual(test_fn(), 1) - def caller_3(lvl_delta): - l = 3 # pylint:disable=unused-variable - return caller_2(lvl_delta) + def test_eval_in_original_context_inner_function(self): - self.assertEqual(caller_3(0), 1) - self.assertEqual(caller_3(1), 2) - self.assertEqual(caller_3(2), 3) + def test_fn(): + l = 1 # pylint:disable=unused-variable + with self._basic_function_scope() as test_scope: - def test_super_with_one_arg_in_original_context(self): + def inner_fn(): + # Note: a user function without a top-level function scope should + # never be found in user code; it's only possible in generated code. + l = 2 # pylint:disable=unused-variable + return py_builtins.eval_in_original_context(eval, ('l',), test_scope) + + return inner_fn() + + self.assertEqual(test_fn(), 2) + + def test_super_in_original_context_unary_call(self): test_case_self = self - class TestBase(object): - - def plus_twenty(self, x): - return x + 20 - class TestSubclass(TestBase): def plus_twenty(self, x): test_case_self.fail('This should never be called.') - def one_arg(self): - test_base_unbound = py_builtins.super_in_original_context( - super, (TestSubclass,), 0) - test_base = test_base_unbound.__get__(self, TestSubclass) - return test_base.plus_twenty(1) + def test_method(self): + with test_case_self._basic_function_scope() as test_scope: + test_base_unbound = py_builtins.super_in_original_context( + super, (TestSubclass,), test_scope) + test_base = test_base_unbound.__get__(self, TestSubclass) + return test_base.plus_twenty(1) tc = TestSubclass() - self.assertEqual(tc.one_arg(), 21) + self.assertEqual(tc.test_method(), 21) - def test_super_with_two_args_in_original_context(self): + def test_super_in_original_context_binary_call(self): test_case_self = self - class TestBase(object): - - def plus_twenty(self, x): - return x + 20 - class TestSubclass(TestBase): def plus_twenty(self, x): test_case_self.fail('This should never be called.') - def two_args(self): - test_base = py_builtins.super_in_original_context( - super, (TestSubclass, self), 0) - return test_base.plus_twenty(1) + def test_method(self): + with test_case_self._basic_function_scope() as test_scope: + test_base = py_builtins.super_in_original_context( + super, (TestSubclass, self), test_scope) + return test_base.plus_twenty(1) tc = TestSubclass() - self.assertEqual(tc.two_args(), 21) + self.assertEqual(tc.test_method(), 21) if __name__ == '__main__':