Use the function context created during conversion to identify the correct frame for calling super and eval. Fixes #29191.

PiperOrigin-RevId: 261144150
This commit is contained in:
Dan Moldovan 2019-08-01 10:10:00 -07:00 committed by TensorFlower Gardener
parent 6a42e239dc
commit 4d4f7ed0a4
12 changed files with 222 additions and 176 deletions

View File

@ -174,14 +174,15 @@ class CallTreeTransformer(converter.Base):
keywords=ast_util.keywords_to_dict(normal_keywords))
template = """
ag__.converted_call(func, options, args, kwargs)
ag__.converted_call(func, options, args, kwargs, function_ctx)
"""
new_call = templates.replace_as_expression(
template,
func=func,
options=parser.parse_expression(function_context_name + '.callopts'),
args=args,
kwargs=kwargs)
kwargs=kwargs,
function_ctx=function_context_name)
return new_call

View File

@ -56,18 +56,20 @@ class FunctionBodyTransformer(converter.Base):
return node
scope = anno.getanno(node, anno.Static.SCOPE)
function_context_name = self.ctx.namer.new_symbol(
'lambda_scope', scope.referenced)
function_context_name = self.ctx.namer.new_symbol('lambda_scope',
scope.referenced)
self.state[_Function].context_name = function_context_name
anno.setanno(node, 'function_context_name', function_context_name)
template = """
ag__.with_function_scope(lambda function_context_name: body, options)
ag__.with_function_scope(
lambda function_context: body, function_context_name, options)
"""
node.body = templates.replace_as_expression(
template,
options=self.ctx.program.options.to_ast(),
function_context_name=function_context_name,
function_context=function_context_name,
function_context_name=gast.Str(function_context_name),
body=node.body)
self.state[_Function].exit()
@ -93,14 +95,16 @@ class FunctionBodyTransformer(converter.Base):
node.body = node.body[1:]
template = """
with ag__.FunctionScope(function_name, options) as function_context_name:
with ag__.FunctionScope(
function_name, context_name, options) as function_context:
body
"""
wrapped_body = templates.replace(
template,
function_name=gast.Str(node.name),
context_name=gast.Str(function_context_name),
options=self.ctx.program.options.to_ast(),
function_context_name=function_context_name,
function_context=function_context_name,
body=node.body)
if docstring_node is not None:

View File

@ -262,13 +262,15 @@ class EntityContext(transformer.Context):
Attributes:
namer: Namer
info: transformer.EntityInfo
program: ProgramContext
program: ProgramContext,
targe_name: Text
"""
def __init__(self, namer, entity_info, program_ctx):
def __init__(self, namer, entity_info, program_ctx, target_name=None):
super(EntityContext, self).__init__(entity_info)
self.namer = namer
self.program = program_ctx
self.target_name = target_name
class Base(transformer.Base):

View File

@ -57,7 +57,7 @@ class TestCase(test.TestCase):
self.dynamic_calls = []
# See api.converted_call
def converted_call(f, unused_opts, args, kwargs):
def converted_call(f, unused_opts, args, kwargs, unused_function_ctx):
"""Mock version of api.converted_call."""
self.dynamic_calls.append((args, kwargs))
if kwargs is None:
@ -135,7 +135,8 @@ class TestCase(test.TestCase):
source_file='<fragment>',
future_features=future_features,
namespace=namespace)
ctx = converter.EntityContext(namer, entity_info, program_ctx)
ctx = converter.EntityContext(
namer, entity_info, program_ctx, 'test_fn')
origin_info.resolve_entity(node, source, test_fn)
node = converter.standard_analysis(node, ctx, is_initial=True)
return node, ctx

View File

@ -40,12 +40,13 @@ class FunctionScope(object):
conversion options;
"""
def __init__(self, function_name, options):
def __init__(self, function_name, scope_name, options):
self.name = scope_name
self.options = options
if options.user_requested:
self.autograph_ctx = ag_ctx.ControlStatusCtx(
ag_ctx.Status.ENABLED, options)
self.autograph_ctx = ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED,
options)
self.callopts = options.call_options()
use_name_scope = options.uses(converter.Feature.NAME_SCOPES)
@ -101,7 +102,7 @@ class FunctionScope(object):
return value
def with_function_scope(thunk, options):
def with_function_scope(thunk, scope_name, options):
"""Inline version of the FunctionScope context manager."""
with FunctionScope('lambda_', options) as scope:
with FunctionScope('lambda_', scope_name, options) as scope:
return thunk(scope)

View File

@ -33,7 +33,7 @@ class FunctionWrappersTest(test.TestCase):
self.skipTest('Tensor names are disabled in eager')
with function_wrappers.FunctionScope(
'test_name',
'test_name', None,
converter.ConversionOptions(
optional_features=converter.Feature.NAME_SCOPES)):
t = constant_op.constant(1)
@ -42,7 +42,7 @@ class FunctionWrappersTest(test.TestCase):
def test_auto_cotrol_deps(self):
v = variables.Variable(1)
with function_wrappers.FunctionScope(
'_',
'_', None,
converter.ConversionOptions(
optional_features=converter.Feature.AUTO_CONTROL_DEPS)) as scope:
v.assign(2)
@ -51,9 +51,11 @@ class FunctionWrappersTest(test.TestCase):
self.assertEqual(self.evaluate(v.read_value()), 2)
def test_all_disabled(self):
with function_wrappers.FunctionScope(None, converter.STANDARD_OPTIONS):
with function_wrappers.FunctionScope(None, None,
converter.STANDARD_OPTIONS):
t = constant_op.constant(1)
self.assertEqual(self.evaluate(t), 1)
if __name__ == '__main__':
test.main()

View File

@ -362,8 +362,23 @@ def _is_known_loaded_type(f, module_name, entity_name):
return False
def converted_call(f, options, args, kwargs):
"""Compiles a function call inline. For internal use only."""
def converted_call(f, options, args, kwargs, caller_fn_scope=None):
"""Compiles a function call inline.
For internal use only.
Args:
f: The function to convert.
options: converter.ConversionOptions
args: Tuple, the original positional arguments of f
kwargs: Dict, the original keyword arguments of f
caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
scope of the converted function in which this call was originally made.
Returns:
Any, the result of executing a possibly-converted `f` with the given
arguments.
"""
logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args,
kwargs)
@ -372,9 +387,9 @@ def converted_call(f, options, args, kwargs):
if inspect_utils.isbuiltin(f):
if f is eval:
return py_builtins.eval_in_original_context(f, args, 1)
return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
if f is super:
return py_builtins.super_in_original_context(f, args, 1)
return py_builtins.super_in_original_context(f, args, caller_fn_scope)
if kwargs:
return py_builtins.overload_of(f)(*args, **kwargs)
else:

View File

@ -653,24 +653,27 @@ def convert_func_to_ast(f, program_ctx, do_rename=True):
_add_self_references(namespace, program_ctx.autograph_module)
namer = naming.Namer(namespace)
if isinstance(node, gast.Lambda):
new_name = namer.new_symbol('tf__lambda', ())
elif do_rename:
new_name = namer.function_name(f.__name__)
else:
new_name = f.__name__
entity_info = transformer.EntityInfo(
source_code=source,
source_file='<fragment>',
future_features=future_features,
namespace=namespace)
context = converter.EntityContext(namer, entity_info, program_ctx)
context = converter.EntityContext(namer, entity_info, program_ctx, new_name)
node = node_to_graph(node, context)
if isinstance(node, gast.Lambda):
new_name = namer.new_symbol('tf__lambda', ())
node = gast.Assign(
targets=[gast.Name(new_name, gast.Store(), None)], value=node)
elif do_rename:
new_name = namer.function_name(f.__name__)
node.name = new_name
else:
new_name = f.__name__
assert node.name == new_name
return (node,), new_name, entity_info

View File

@ -100,6 +100,7 @@ py_test(
deps = [
":operators",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/core",
],
)

View File

@ -49,11 +49,34 @@ def overload_of(f):
return f
def eval_in_original_context(f, args, caller_level_delta):
"""Executes the eval function with the user-specified globals/locals."""
def _find_originating_frame(caller_fn_scope, innermost=True):
"""Locates the frame in which `caller_fn_scope` was defined."""
ctx_frame = inspect.currentframe()
for _ in range(caller_level_delta + 1):
result = None
while ctx_frame is not None:
# Note it should not be normally possible to get false positives this way
# because the function scope object is not accessible to user code (barring
# call stack introspection).
if ctx_frame.f_locals.get(caller_fn_scope.name, None) is caller_fn_scope:
result = ctx_frame
if innermost:
break
ctx_frame = ctx_frame.f_back
assert result is not None, (
'the conversion process should ensure the caller_fn_scope is always'
' found somewhere on the call stack')
return result
def eval_in_original_context(f, args, caller_fn_scope):
"""Executes the eval function in the context of a specified function."""
# When control flow is rewritten using functions, eval should use the
# variables found in the same block where it was called. That is equivalent
# to the innermost function call.
ctx_frame = _find_originating_frame(caller_fn_scope, innermost=True)
args = (
args[0],
ctx_frame.f_globals if len(args) < 2 else args[1],
@ -62,33 +85,34 @@ def eval_in_original_context(f, args, caller_level_delta):
return f(*args)
def super_in_original_context(f, args, caller_level_delta):
"""Executes the super function with the correct implicit argument handling.
def super_in_original_context(f, args, caller_fn_scope):
"""Executes the super function in the context of a specified function.
See https://docs.python.org/3/library/functions.html#super for the exact
details
Args:
f: super builtin function object.
args: Arguments that will be passed to super(...). A valid call should have
0, 1, or 2 number of arguments
caller_level_delta: The number of nested frames to the original super(...)'s
context frame.
f: Callable, typically the super builtin
args: List[Any], the original call arguments
caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
scope of the converted function in which this call was originally made
Returns:
The result of super(...) call.
The result of calling `f` as if it was called in the frame indicated by
`caller_fn_scope`.
"""
# Python 2 doesn't support implicit argument super variants.
if six.PY2:
return overload_of(f)(*args)
return f(*args)
if len(args) >= 1:
return overload_of(f)(*args)
# Only the no-arg call is desugared.
if args:
return f(*args)
ctx_frame = inspect.currentframe()
for _ in range(caller_level_delta + 1):
ctx_frame = ctx_frame.f_back
# Inner functions seem to include their closure in f_locals, so we need
# to find the outermost frame.
ctx_frame = _find_originating_frame(caller_fn_scope, innermost=False)
# When super(..) is called without arguments, it looks for __class__ cell
# variable and the first argument passed in the enclosing function according
@ -104,33 +128,23 @@ def super_in_original_context(f, args, caller_level_delta):
# https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py2_test_grammar.py#L157
# https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py3_test_grammar.py#L192
#
# TODO(kkimlabs): mdan@ had an idea to do it more correctly without relying
# on the co_varnames argument order assumption.
# 1. Getting the caller function from the call stack.
# 2. Getting its argspec.
# 3. Get the name of the first argument from argspec.
# 4. Retrieve its value from locals.
# Note: the name can be more reliably obtained by inspecting the calling
# function's argspec.
#
# Sample code snippet:
# Even though methods can be declared using *args (def method(*args)),
# that pattern is disallowed by super() -- it raises super() no arguments.
# Method definitions using **kwargs are not allowed at all.
# In other words, we can always assume that self is on the first positional
# argument (for correct code).
#
# def fn2():
# fr = inspect.currentframe()
# parent_fr = fr.f_back
# grandparent_fr = parent_fr.f_back
# f_name = parent_fr.f_code.co_name
# f = grandparent_fr.f_locals[f_name]
#
# def fn1():
# fn2()
#
# fn1()
#
# However, we also need to handle some edge cases like
# function in the closure or globals, etc,...
# TODO(mdan): Consider additional checks in case the input code is incorrect.
# For example, the error might be cryptic compared to what super() regularly
# raises.
first_arg_name = ctx_frame.f_code.co_varnames[0]
first_arg = ctx_frame.f_locals[first_arg_name]
return f(ctx_frame.f_locals['__class__'], first_arg)
type_arg = ctx_frame.f_locals['__class__']
self_arg_name = ctx_frame.f_code.co_varnames[0]
self_arg = ctx_frame.f_locals[self_arg_name]
return f(type_arg, self_arg)
def abs_(x):

View File

@ -18,116 +18,105 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.platform import test
class TestBaseClass(object):
def overridden_method(self, x):
return x + 20
class PyBuiltinsTest(test.TestCase):
def test_super_with_no_arg_in_original_context(self):
def _basic_function_scope(self):
return function_wrappers.FunctionScope(
'test_function_name',
'test_scope', # Note: this must match the name in the `with` statement.
converter.ConversionOptions())
def test_super_in_original_context_niladic_call(self):
test_case_self = self
class TestBase(object):
class TestSubclass(TestBaseClass):
def plus_twenty(self, x):
return x + 20
class TestSubclass(TestBase):
def plus_twenty(self, x):
def overridden_method(self, x):
test_case_self.fail('This should never be called.')
def no_arg(self):
test_base = py_builtins.super_in_original_context(super, (), 0)
return test_base.plus_twenty(1)
def test_method(self):
with test_case_self._basic_function_scope() as test_scope:
b = py_builtins.super_in_original_context(super, (), test_scope)
return b.overridden_method(1)
tc = TestSubclass()
self.assertEqual(tc.no_arg(), 21)
self.assertEqual(tc.test_method(), 21)
def test_super_in_original_context_with_locals(self):
def test_super_in_original_context_caller_with_locals(self):
test_case_self = self
class TestBase(object):
class TestSubclass(TestBaseClass):
def plus_twenty(self, x):
return x + 20
class TestSubclass(TestBase):
def plus_twenty(self, x):
def overridden_method(self, x):
test_case_self.fail('This should never be called.')
def with_locals(self):
x = 1
def test_method(self, x):
y = 7
z = 7
test_base = py_builtins.super_in_original_context(super, (), 0)
return test_base.plus_twenty(x + y - z)
with test_case_self._basic_function_scope() as test_scope:
z = 7
return py_builtins.super_in_original_context(
super, (), test_scope).overridden_method(x + y - z)
tc = TestSubclass()
self.assertEqual(tc.with_locals(), 21)
self.assertEqual(tc.test_method(1), 21)
def test_super_in_original_context_with_args(self):
def test_super_in_original_context_inner_function(self):
test_case_self = self
class TestBase(object):
class TestSubclass(TestBaseClass):
def plus_twenty(self, x):
return x + 20
class TestSubclass(TestBase):
def plus_twenty(self, x):
def overridden_method(self, x):
test_case_self.fail('This should never be called.')
def with_args(self, x, y, z):
test_base = py_builtins.super_in_original_context(super, (), 0)
return test_base.plus_twenty(x + y - z)
def test_method(self, x):
with test_case_self._basic_function_scope() as test_scope:
# Oddly, it's sufficient to use `self` in an inner function
# to gain access to __class__ in this scope.
# TODO(mdan): Is this true across implementations?
# Note: normally, it's illegal to use super() in inner functions (it
# throws an error), but the generated code may create them.
def inner_fn():
return py_builtins.super_in_original_context(
super, (), test_scope).overridden_method(x)
return inner_fn()
tc = TestSubclass()
self.assertEqual(tc.with_args(1, 7, 7), 21)
self.assertEqual(tc.with_args.__func__(*[tc, 1, 7, 7]), 21)
self.assertEqual(tc.test_method(1), 21)
def test_super_in_original_context_with_varargs(self):
def test_super_in_original_context_inner_lambda(self):
test_case_self = self
class TestBase(object):
class TestSubclass(TestBaseClass):
def plus_twenty(self, x):
return x + 20
class TestSubclass(TestBase):
def plus_twenty(self, x):
def overridden_method(self, x):
test_case_self.fail('This should never be called.')
def with_varargs(self, *args):
test_base = py_builtins.super_in_original_context(super, (), 0)
return test_base.plus_twenty(args[0] + args[1] - args[2])
def test_method(self, x):
with test_case_self._basic_function_scope() as test_scope:
# Oddly, it's sufficient to use `self` in an inner function
# to gain access to __class__ in this scope.
# TODO(mdan): Is this true across implementations?
# Note: normally, it's illegal to use super() in inner functions (it
# throws an error), but the generated code may create them.
l = lambda: py_builtins.super_in_original_context( # pylint:disable=g-long-lambda
super, (), test_scope).overridden_method(x)
return l()
tc = TestSubclass()
self.assertEqual(tc.with_varargs.__func__(*[tc, 1, 7, 7]), 21)
def test_super_in_original_context_with_kwargs(self):
test_case_self = self
class TestBase(object):
def plus_twenty(self, x):
return x + 20
class TestSubclass(TestBase):
def plus_twenty(self, x):
test_case_self.fail('This should never be called.')
def with_kwargs(self, **kwargs):
test_base = py_builtins.super_in_original_context(super, (), 0)
return test_base.plus_twenty(kwargs['x'] + kwargs['y'] - kwargs['z'])
tc = TestSubclass()
self.assertEqual(tc.with_kwargs.__func__(tc, x=1, y=7, z=7), 21)
self.assertEqual(tc.test_method(1), 21)
if __name__ == '__main__':

View File

@ -22,6 +22,8 @@ import sys
import six
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.operators import data_structures
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.data.ops import dataset_ops
@ -34,6 +36,12 @@ from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
class TestBase(object):
def plus_twenty(self, x):
return x + 20
class PyBuiltinsTest(test.TestCase):
def test_abs(self):
@ -155,66 +163,71 @@ class PyBuiltinsTest(test.TestCase):
self.assertAllEqual(self.evaluate(iterator.get_next()), (20, b'a'))
self.assertAllEqual(self.evaluate(iterator.get_next()), (21, b'c'))
def _basic_function_scope(self):
return function_wrappers.FunctionScope(
'test_function_name',
'test_scope', # Note: this must match the name in the `with` statement.
converter.ConversionOptions())
def test_eval_in_original_context(self):
def caller_1(lvl_delta):
def test_fn():
l = 1 # pylint:disable=unused-variable
return py_builtins.eval_in_original_context(eval, ('l',), lvl_delta)
with self._basic_function_scope() as test_scope:
return py_builtins.eval_in_original_context(eval, ('l',), test_scope)
def caller_2(lvl_delta):
l = 2 # pylint:disable=unused-variable
return caller_1(lvl_delta)
self.assertEqual(test_fn(), 1)
def caller_3(lvl_delta):
l = 3 # pylint:disable=unused-variable
return caller_2(lvl_delta)
def test_eval_in_original_context_inner_function(self):
self.assertEqual(caller_3(0), 1)
self.assertEqual(caller_3(1), 2)
self.assertEqual(caller_3(2), 3)
def test_fn():
l = 1 # pylint:disable=unused-variable
with self._basic_function_scope() as test_scope:
def test_super_with_one_arg_in_original_context(self):
def inner_fn():
# Note: a user function without a top-level function scope should
# never be found in user code; it's only possible in generated code.
l = 2 # pylint:disable=unused-variable
return py_builtins.eval_in_original_context(eval, ('l',), test_scope)
return inner_fn()
self.assertEqual(test_fn(), 2)
def test_super_in_original_context_unary_call(self):
test_case_self = self
class TestBase(object):
def plus_twenty(self, x):
return x + 20
class TestSubclass(TestBase):
def plus_twenty(self, x):
test_case_self.fail('This should never be called.')
def one_arg(self):
test_base_unbound = py_builtins.super_in_original_context(
super, (TestSubclass,), 0)
test_base = test_base_unbound.__get__(self, TestSubclass)
return test_base.plus_twenty(1)
def test_method(self):
with test_case_self._basic_function_scope() as test_scope:
test_base_unbound = py_builtins.super_in_original_context(
super, (TestSubclass,), test_scope)
test_base = test_base_unbound.__get__(self, TestSubclass)
return test_base.plus_twenty(1)
tc = TestSubclass()
self.assertEqual(tc.one_arg(), 21)
self.assertEqual(tc.test_method(), 21)
def test_super_with_two_args_in_original_context(self):
def test_super_in_original_context_binary_call(self):
test_case_self = self
class TestBase(object):
def plus_twenty(self, x):
return x + 20
class TestSubclass(TestBase):
def plus_twenty(self, x):
test_case_self.fail('This should never be called.')
def two_args(self):
test_base = py_builtins.super_in_original_context(
super, (TestSubclass, self), 0)
return test_base.plus_twenty(1)
def test_method(self):
with test_case_self._basic_function_scope() as test_scope:
test_base = py_builtins.super_in_original_context(
super, (TestSubclass, self), test_scope)
return test_base.plus_twenty(1)
tc = TestSubclass()
self.assertEqual(tc.two_args(), 21)
self.assertEqual(tc.test_method(), 21)
if __name__ == '__main__':