Use the function context created during conversion to identify the correct frame for calling super
and eval
. Fixes #29191.
PiperOrigin-RevId: 261144150
This commit is contained in:
parent
6a42e239dc
commit
4d4f7ed0a4
@ -174,14 +174,15 @@ class CallTreeTransformer(converter.Base):
|
||||
keywords=ast_util.keywords_to_dict(normal_keywords))
|
||||
|
||||
template = """
|
||||
ag__.converted_call(func, options, args, kwargs)
|
||||
ag__.converted_call(func, options, args, kwargs, function_ctx)
|
||||
"""
|
||||
new_call = templates.replace_as_expression(
|
||||
template,
|
||||
func=func,
|
||||
options=parser.parse_expression(function_context_name + '.callopts'),
|
||||
args=args,
|
||||
kwargs=kwargs)
|
||||
kwargs=kwargs,
|
||||
function_ctx=function_context_name)
|
||||
|
||||
return new_call
|
||||
|
||||
|
@ -56,18 +56,20 @@ class FunctionBodyTransformer(converter.Base):
|
||||
return node
|
||||
|
||||
scope = anno.getanno(node, anno.Static.SCOPE)
|
||||
function_context_name = self.ctx.namer.new_symbol(
|
||||
'lambda_scope', scope.referenced)
|
||||
function_context_name = self.ctx.namer.new_symbol('lambda_scope',
|
||||
scope.referenced)
|
||||
self.state[_Function].context_name = function_context_name
|
||||
anno.setanno(node, 'function_context_name', function_context_name)
|
||||
|
||||
template = """
|
||||
ag__.with_function_scope(lambda function_context_name: body, options)
|
||||
ag__.with_function_scope(
|
||||
lambda function_context: body, function_context_name, options)
|
||||
"""
|
||||
node.body = templates.replace_as_expression(
|
||||
template,
|
||||
options=self.ctx.program.options.to_ast(),
|
||||
function_context_name=function_context_name,
|
||||
function_context=function_context_name,
|
||||
function_context_name=gast.Str(function_context_name),
|
||||
body=node.body)
|
||||
|
||||
self.state[_Function].exit()
|
||||
@ -93,14 +95,16 @@ class FunctionBodyTransformer(converter.Base):
|
||||
node.body = node.body[1:]
|
||||
|
||||
template = """
|
||||
with ag__.FunctionScope(function_name, options) as function_context_name:
|
||||
with ag__.FunctionScope(
|
||||
function_name, context_name, options) as function_context:
|
||||
body
|
||||
"""
|
||||
wrapped_body = templates.replace(
|
||||
template,
|
||||
function_name=gast.Str(node.name),
|
||||
context_name=gast.Str(function_context_name),
|
||||
options=self.ctx.program.options.to_ast(),
|
||||
function_context_name=function_context_name,
|
||||
function_context=function_context_name,
|
||||
body=node.body)
|
||||
|
||||
if docstring_node is not None:
|
||||
|
@ -262,13 +262,15 @@ class EntityContext(transformer.Context):
|
||||
Attributes:
|
||||
namer: Namer
|
||||
info: transformer.EntityInfo
|
||||
program: ProgramContext
|
||||
program: ProgramContext,
|
||||
targe_name: Text
|
||||
"""
|
||||
|
||||
def __init__(self, namer, entity_info, program_ctx):
|
||||
def __init__(self, namer, entity_info, program_ctx, target_name=None):
|
||||
super(EntityContext, self).__init__(entity_info)
|
||||
self.namer = namer
|
||||
self.program = program_ctx
|
||||
self.target_name = target_name
|
||||
|
||||
|
||||
class Base(transformer.Base):
|
||||
|
@ -57,7 +57,7 @@ class TestCase(test.TestCase):
|
||||
|
||||
self.dynamic_calls = []
|
||||
# See api.converted_call
|
||||
def converted_call(f, unused_opts, args, kwargs):
|
||||
def converted_call(f, unused_opts, args, kwargs, unused_function_ctx):
|
||||
"""Mock version of api.converted_call."""
|
||||
self.dynamic_calls.append((args, kwargs))
|
||||
if kwargs is None:
|
||||
@ -135,7 +135,8 @@ class TestCase(test.TestCase):
|
||||
source_file='<fragment>',
|
||||
future_features=future_features,
|
||||
namespace=namespace)
|
||||
ctx = converter.EntityContext(namer, entity_info, program_ctx)
|
||||
ctx = converter.EntityContext(
|
||||
namer, entity_info, program_ctx, 'test_fn')
|
||||
origin_info.resolve_entity(node, source, test_fn)
|
||||
node = converter.standard_analysis(node, ctx, is_initial=True)
|
||||
return node, ctx
|
||||
|
@ -40,12 +40,13 @@ class FunctionScope(object):
|
||||
conversion options;
|
||||
"""
|
||||
|
||||
def __init__(self, function_name, options):
|
||||
def __init__(self, function_name, scope_name, options):
|
||||
self.name = scope_name
|
||||
self.options = options
|
||||
|
||||
if options.user_requested:
|
||||
self.autograph_ctx = ag_ctx.ControlStatusCtx(
|
||||
ag_ctx.Status.ENABLED, options)
|
||||
self.autograph_ctx = ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED,
|
||||
options)
|
||||
self.callopts = options.call_options()
|
||||
|
||||
use_name_scope = options.uses(converter.Feature.NAME_SCOPES)
|
||||
@ -101,7 +102,7 @@ class FunctionScope(object):
|
||||
return value
|
||||
|
||||
|
||||
def with_function_scope(thunk, options):
|
||||
def with_function_scope(thunk, scope_name, options):
|
||||
"""Inline version of the FunctionScope context manager."""
|
||||
with FunctionScope('lambda_', options) as scope:
|
||||
with FunctionScope('lambda_', scope_name, options) as scope:
|
||||
return thunk(scope)
|
||||
|
@ -33,7 +33,7 @@ class FunctionWrappersTest(test.TestCase):
|
||||
self.skipTest('Tensor names are disabled in eager')
|
||||
|
||||
with function_wrappers.FunctionScope(
|
||||
'test_name',
|
||||
'test_name', None,
|
||||
converter.ConversionOptions(
|
||||
optional_features=converter.Feature.NAME_SCOPES)):
|
||||
t = constant_op.constant(1)
|
||||
@ -42,7 +42,7 @@ class FunctionWrappersTest(test.TestCase):
|
||||
def test_auto_cotrol_deps(self):
|
||||
v = variables.Variable(1)
|
||||
with function_wrappers.FunctionScope(
|
||||
'_',
|
||||
'_', None,
|
||||
converter.ConversionOptions(
|
||||
optional_features=converter.Feature.AUTO_CONTROL_DEPS)) as scope:
|
||||
v.assign(2)
|
||||
@ -51,9 +51,11 @@ class FunctionWrappersTest(test.TestCase):
|
||||
self.assertEqual(self.evaluate(v.read_value()), 2)
|
||||
|
||||
def test_all_disabled(self):
|
||||
with function_wrappers.FunctionScope(None, converter.STANDARD_OPTIONS):
|
||||
with function_wrappers.FunctionScope(None, None,
|
||||
converter.STANDARD_OPTIONS):
|
||||
t = constant_op.constant(1)
|
||||
self.assertEqual(self.evaluate(t), 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -362,8 +362,23 @@ def _is_known_loaded_type(f, module_name, entity_name):
|
||||
return False
|
||||
|
||||
|
||||
def converted_call(f, options, args, kwargs):
|
||||
"""Compiles a function call inline. For internal use only."""
|
||||
def converted_call(f, options, args, kwargs, caller_fn_scope=None):
|
||||
"""Compiles a function call inline.
|
||||
|
||||
For internal use only.
|
||||
|
||||
Args:
|
||||
f: The function to convert.
|
||||
options: converter.ConversionOptions
|
||||
args: Tuple, the original positional arguments of f
|
||||
kwargs: Dict, the original keyword arguments of f
|
||||
caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
|
||||
scope of the converted function in which this call was originally made.
|
||||
|
||||
Returns:
|
||||
Any, the result of executing a possibly-converted `f` with the given
|
||||
arguments.
|
||||
"""
|
||||
logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args,
|
||||
kwargs)
|
||||
|
||||
@ -372,9 +387,9 @@ def converted_call(f, options, args, kwargs):
|
||||
|
||||
if inspect_utils.isbuiltin(f):
|
||||
if f is eval:
|
||||
return py_builtins.eval_in_original_context(f, args, 1)
|
||||
return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
|
||||
if f is super:
|
||||
return py_builtins.super_in_original_context(f, args, 1)
|
||||
return py_builtins.super_in_original_context(f, args, caller_fn_scope)
|
||||
if kwargs:
|
||||
return py_builtins.overload_of(f)(*args, **kwargs)
|
||||
else:
|
||||
|
@ -653,24 +653,27 @@ def convert_func_to_ast(f, program_ctx, do_rename=True):
|
||||
_add_self_references(namespace, program_ctx.autograph_module)
|
||||
namer = naming.Namer(namespace)
|
||||
|
||||
if isinstance(node, gast.Lambda):
|
||||
new_name = namer.new_symbol('tf__lambda', ())
|
||||
elif do_rename:
|
||||
new_name = namer.function_name(f.__name__)
|
||||
else:
|
||||
new_name = f.__name__
|
||||
|
||||
entity_info = transformer.EntityInfo(
|
||||
source_code=source,
|
||||
source_file='<fragment>',
|
||||
future_features=future_features,
|
||||
namespace=namespace)
|
||||
context = converter.EntityContext(namer, entity_info, program_ctx)
|
||||
context = converter.EntityContext(namer, entity_info, program_ctx, new_name)
|
||||
node = node_to_graph(node, context)
|
||||
|
||||
if isinstance(node, gast.Lambda):
|
||||
new_name = namer.new_symbol('tf__lambda', ())
|
||||
node = gast.Assign(
|
||||
targets=[gast.Name(new_name, gast.Store(), None)], value=node)
|
||||
|
||||
elif do_rename:
|
||||
new_name = namer.function_name(f.__name__)
|
||||
node.name = new_name
|
||||
else:
|
||||
new_name = f.__name__
|
||||
assert node.name == new_name
|
||||
|
||||
return (node,), new_name, entity_info
|
||||
|
@ -100,6 +100,7 @@ py_test(
|
||||
deps = [
|
||||
":operators",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/autograph/core",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -49,11 +49,34 @@ def overload_of(f):
|
||||
return f
|
||||
|
||||
|
||||
def eval_in_original_context(f, args, caller_level_delta):
|
||||
"""Executes the eval function with the user-specified globals/locals."""
|
||||
def _find_originating_frame(caller_fn_scope, innermost=True):
|
||||
"""Locates the frame in which `caller_fn_scope` was defined."""
|
||||
ctx_frame = inspect.currentframe()
|
||||
for _ in range(caller_level_delta + 1):
|
||||
result = None
|
||||
while ctx_frame is not None:
|
||||
# Note it should not be normally possible to get false positives this way
|
||||
# because the function scope object is not accessible to user code (barring
|
||||
# call stack introspection).
|
||||
if ctx_frame.f_locals.get(caller_fn_scope.name, None) is caller_fn_scope:
|
||||
result = ctx_frame
|
||||
if innermost:
|
||||
break
|
||||
ctx_frame = ctx_frame.f_back
|
||||
|
||||
assert result is not None, (
|
||||
'the conversion process should ensure the caller_fn_scope is always'
|
||||
' found somewhere on the call stack')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def eval_in_original_context(f, args, caller_fn_scope):
|
||||
"""Executes the eval function in the context of a specified function."""
|
||||
# When control flow is rewritten using functions, eval should use the
|
||||
# variables found in the same block where it was called. That is equivalent
|
||||
# to the innermost function call.
|
||||
ctx_frame = _find_originating_frame(caller_fn_scope, innermost=True)
|
||||
|
||||
args = (
|
||||
args[0],
|
||||
ctx_frame.f_globals if len(args) < 2 else args[1],
|
||||
@ -62,33 +85,34 @@ def eval_in_original_context(f, args, caller_level_delta):
|
||||
return f(*args)
|
||||
|
||||
|
||||
def super_in_original_context(f, args, caller_level_delta):
|
||||
"""Executes the super function with the correct implicit argument handling.
|
||||
def super_in_original_context(f, args, caller_fn_scope):
|
||||
"""Executes the super function in the context of a specified function.
|
||||
|
||||
See https://docs.python.org/3/library/functions.html#super for the exact
|
||||
details
|
||||
|
||||
Args:
|
||||
f: super builtin function object.
|
||||
args: Arguments that will be passed to super(...). A valid call should have
|
||||
0, 1, or 2 number of arguments
|
||||
caller_level_delta: The number of nested frames to the original super(...)'s
|
||||
context frame.
|
||||
f: Callable, typically the super builtin
|
||||
args: List[Any], the original call arguments
|
||||
caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
|
||||
scope of the converted function in which this call was originally made
|
||||
|
||||
Returns:
|
||||
The result of super(...) call.
|
||||
The result of calling `f` as if it was called in the frame indicated by
|
||||
`caller_fn_scope`.
|
||||
"""
|
||||
|
||||
# Python 2 doesn't support implicit argument super variants.
|
||||
if six.PY2:
|
||||
return overload_of(f)(*args)
|
||||
return f(*args)
|
||||
|
||||
if len(args) >= 1:
|
||||
return overload_of(f)(*args)
|
||||
# Only the no-arg call is desugared.
|
||||
if args:
|
||||
return f(*args)
|
||||
|
||||
ctx_frame = inspect.currentframe()
|
||||
for _ in range(caller_level_delta + 1):
|
||||
ctx_frame = ctx_frame.f_back
|
||||
# Inner functions seem to include their closure in f_locals, so we need
|
||||
# to find the outermost frame.
|
||||
ctx_frame = _find_originating_frame(caller_fn_scope, innermost=False)
|
||||
|
||||
# When super(..) is called without arguments, it looks for __class__ cell
|
||||
# variable and the first argument passed in the enclosing function according
|
||||
@ -104,33 +128,23 @@ def super_in_original_context(f, args, caller_level_delta):
|
||||
# https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py2_test_grammar.py#L157
|
||||
# https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py3_test_grammar.py#L192
|
||||
#
|
||||
# TODO(kkimlabs): mdan@ had an idea to do it more correctly without relying
|
||||
# on the co_varnames argument order assumption.
|
||||
# 1. Getting the caller function from the call stack.
|
||||
# 2. Getting its argspec.
|
||||
# 3. Get the name of the first argument from argspec.
|
||||
# 4. Retrieve its value from locals.
|
||||
# Note: the name can be more reliably obtained by inspecting the calling
|
||||
# function's argspec.
|
||||
#
|
||||
# Sample code snippet:
|
||||
# Even though methods can be declared using *args (def method(*args)),
|
||||
# that pattern is disallowed by super() -- it raises super() no arguments.
|
||||
# Method definitions using **kwargs are not allowed at all.
|
||||
# In other words, we can always assume that self is on the first positional
|
||||
# argument (for correct code).
|
||||
#
|
||||
# def fn2():
|
||||
# fr = inspect.currentframe()
|
||||
# parent_fr = fr.f_back
|
||||
# grandparent_fr = parent_fr.f_back
|
||||
# f_name = parent_fr.f_code.co_name
|
||||
# f = grandparent_fr.f_locals[f_name]
|
||||
#
|
||||
# def fn1():
|
||||
# fn2()
|
||||
#
|
||||
# fn1()
|
||||
#
|
||||
# However, we also need to handle some edge cases like
|
||||
# function in the closure or globals, etc,...
|
||||
# TODO(mdan): Consider additional checks in case the input code is incorrect.
|
||||
# For example, the error might be cryptic compared to what super() regularly
|
||||
# raises.
|
||||
|
||||
first_arg_name = ctx_frame.f_code.co_varnames[0]
|
||||
first_arg = ctx_frame.f_locals[first_arg_name]
|
||||
return f(ctx_frame.f_locals['__class__'], first_arg)
|
||||
type_arg = ctx_frame.f_locals['__class__']
|
||||
self_arg_name = ctx_frame.f_code.co_varnames[0]
|
||||
self_arg = ctx_frame.f_locals[self_arg_name]
|
||||
return f(type_arg, self_arg)
|
||||
|
||||
|
||||
def abs_(x):
|
||||
|
@ -18,116 +18,105 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import function_wrappers
|
||||
from tensorflow.python.autograph.operators import py_builtins
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class TestBaseClass(object):
|
||||
|
||||
def overridden_method(self, x):
|
||||
return x + 20
|
||||
|
||||
|
||||
class PyBuiltinsTest(test.TestCase):
|
||||
|
||||
def test_super_with_no_arg_in_original_context(self):
|
||||
def _basic_function_scope(self):
|
||||
return function_wrappers.FunctionScope(
|
||||
'test_function_name',
|
||||
'test_scope', # Note: this must match the name in the `with` statement.
|
||||
converter.ConversionOptions())
|
||||
|
||||
def test_super_in_original_context_niladic_call(self):
|
||||
test_case_self = self
|
||||
|
||||
class TestBase(object):
|
||||
class TestSubclass(TestBaseClass):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
return x + 20
|
||||
|
||||
class TestSubclass(TestBase):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
def overridden_method(self, x):
|
||||
test_case_self.fail('This should never be called.')
|
||||
|
||||
def no_arg(self):
|
||||
test_base = py_builtins.super_in_original_context(super, (), 0)
|
||||
return test_base.plus_twenty(1)
|
||||
def test_method(self):
|
||||
with test_case_self._basic_function_scope() as test_scope:
|
||||
b = py_builtins.super_in_original_context(super, (), test_scope)
|
||||
return b.overridden_method(1)
|
||||
|
||||
tc = TestSubclass()
|
||||
self.assertEqual(tc.no_arg(), 21)
|
||||
self.assertEqual(tc.test_method(), 21)
|
||||
|
||||
def test_super_in_original_context_with_locals(self):
|
||||
def test_super_in_original_context_caller_with_locals(self):
|
||||
test_case_self = self
|
||||
|
||||
class TestBase(object):
|
||||
class TestSubclass(TestBaseClass):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
return x + 20
|
||||
|
||||
class TestSubclass(TestBase):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
def overridden_method(self, x):
|
||||
test_case_self.fail('This should never be called.')
|
||||
|
||||
def with_locals(self):
|
||||
x = 1
|
||||
def test_method(self, x):
|
||||
y = 7
|
||||
z = 7
|
||||
|
||||
test_base = py_builtins.super_in_original_context(super, (), 0)
|
||||
return test_base.plus_twenty(x + y - z)
|
||||
with test_case_self._basic_function_scope() as test_scope:
|
||||
z = 7
|
||||
return py_builtins.super_in_original_context(
|
||||
super, (), test_scope).overridden_method(x + y - z)
|
||||
|
||||
tc = TestSubclass()
|
||||
self.assertEqual(tc.with_locals(), 21)
|
||||
self.assertEqual(tc.test_method(1), 21)
|
||||
|
||||
def test_super_in_original_context_with_args(self):
|
||||
def test_super_in_original_context_inner_function(self):
|
||||
test_case_self = self
|
||||
|
||||
class TestBase(object):
|
||||
class TestSubclass(TestBaseClass):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
return x + 20
|
||||
|
||||
class TestSubclass(TestBase):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
def overridden_method(self, x):
|
||||
test_case_self.fail('This should never be called.')
|
||||
|
||||
def with_args(self, x, y, z):
|
||||
test_base = py_builtins.super_in_original_context(super, (), 0)
|
||||
return test_base.plus_twenty(x + y - z)
|
||||
def test_method(self, x):
|
||||
with test_case_self._basic_function_scope() as test_scope:
|
||||
# Oddly, it's sufficient to use `self` in an inner function
|
||||
# to gain access to __class__ in this scope.
|
||||
# TODO(mdan): Is this true across implementations?
|
||||
# Note: normally, it's illegal to use super() in inner functions (it
|
||||
# throws an error), but the generated code may create them.
|
||||
def inner_fn():
|
||||
return py_builtins.super_in_original_context(
|
||||
super, (), test_scope).overridden_method(x)
|
||||
|
||||
return inner_fn()
|
||||
|
||||
tc = TestSubclass()
|
||||
self.assertEqual(tc.with_args(1, 7, 7), 21)
|
||||
self.assertEqual(tc.with_args.__func__(*[tc, 1, 7, 7]), 21)
|
||||
self.assertEqual(tc.test_method(1), 21)
|
||||
|
||||
def test_super_in_original_context_with_varargs(self):
|
||||
def test_super_in_original_context_inner_lambda(self):
|
||||
test_case_self = self
|
||||
|
||||
class TestBase(object):
|
||||
class TestSubclass(TestBaseClass):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
return x + 20
|
||||
|
||||
class TestSubclass(TestBase):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
def overridden_method(self, x):
|
||||
test_case_self.fail('This should never be called.')
|
||||
|
||||
def with_varargs(self, *args):
|
||||
test_base = py_builtins.super_in_original_context(super, (), 0)
|
||||
return test_base.plus_twenty(args[0] + args[1] - args[2])
|
||||
def test_method(self, x):
|
||||
with test_case_self._basic_function_scope() as test_scope:
|
||||
# Oddly, it's sufficient to use `self` in an inner function
|
||||
# to gain access to __class__ in this scope.
|
||||
# TODO(mdan): Is this true across implementations?
|
||||
# Note: normally, it's illegal to use super() in inner functions (it
|
||||
# throws an error), but the generated code may create them.
|
||||
l = lambda: py_builtins.super_in_original_context( # pylint:disable=g-long-lambda
|
||||
super, (), test_scope).overridden_method(x)
|
||||
return l()
|
||||
|
||||
tc = TestSubclass()
|
||||
self.assertEqual(tc.with_varargs.__func__(*[tc, 1, 7, 7]), 21)
|
||||
|
||||
def test_super_in_original_context_with_kwargs(self):
|
||||
test_case_self = self
|
||||
|
||||
class TestBase(object):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
return x + 20
|
||||
|
||||
class TestSubclass(TestBase):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
test_case_self.fail('This should never be called.')
|
||||
|
||||
def with_kwargs(self, **kwargs):
|
||||
test_base = py_builtins.super_in_original_context(super, (), 0)
|
||||
return test_base.plus_twenty(kwargs['x'] + kwargs['y'] - kwargs['z'])
|
||||
|
||||
tc = TestSubclass()
|
||||
self.assertEqual(tc.with_kwargs.__func__(tc, x=1, y=7, z=7), 21)
|
||||
self.assertEqual(tc.test_method(1), 21)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -22,6 +22,8 @@ import sys
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import function_wrappers
|
||||
from tensorflow.python.autograph.operators import data_structures
|
||||
from tensorflow.python.autograph.operators import py_builtins
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
@ -34,6 +36,12 @@ from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class TestBase(object):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
return x + 20
|
||||
|
||||
|
||||
class PyBuiltinsTest(test.TestCase):
|
||||
|
||||
def test_abs(self):
|
||||
@ -155,66 +163,71 @@ class PyBuiltinsTest(test.TestCase):
|
||||
self.assertAllEqual(self.evaluate(iterator.get_next()), (20, b'a'))
|
||||
self.assertAllEqual(self.evaluate(iterator.get_next()), (21, b'c'))
|
||||
|
||||
def _basic_function_scope(self):
|
||||
return function_wrappers.FunctionScope(
|
||||
'test_function_name',
|
||||
'test_scope', # Note: this must match the name in the `with` statement.
|
||||
converter.ConversionOptions())
|
||||
|
||||
def test_eval_in_original_context(self):
|
||||
|
||||
def caller_1(lvl_delta):
|
||||
def test_fn():
|
||||
l = 1 # pylint:disable=unused-variable
|
||||
return py_builtins.eval_in_original_context(eval, ('l',), lvl_delta)
|
||||
with self._basic_function_scope() as test_scope:
|
||||
return py_builtins.eval_in_original_context(eval, ('l',), test_scope)
|
||||
|
||||
def caller_2(lvl_delta):
|
||||
l = 2 # pylint:disable=unused-variable
|
||||
return caller_1(lvl_delta)
|
||||
self.assertEqual(test_fn(), 1)
|
||||
|
||||
def caller_3(lvl_delta):
|
||||
l = 3 # pylint:disable=unused-variable
|
||||
return caller_2(lvl_delta)
|
||||
def test_eval_in_original_context_inner_function(self):
|
||||
|
||||
self.assertEqual(caller_3(0), 1)
|
||||
self.assertEqual(caller_3(1), 2)
|
||||
self.assertEqual(caller_3(2), 3)
|
||||
def test_fn():
|
||||
l = 1 # pylint:disable=unused-variable
|
||||
with self._basic_function_scope() as test_scope:
|
||||
|
||||
def test_super_with_one_arg_in_original_context(self):
|
||||
def inner_fn():
|
||||
# Note: a user function without a top-level function scope should
|
||||
# never be found in user code; it's only possible in generated code.
|
||||
l = 2 # pylint:disable=unused-variable
|
||||
return py_builtins.eval_in_original_context(eval, ('l',), test_scope)
|
||||
|
||||
return inner_fn()
|
||||
|
||||
self.assertEqual(test_fn(), 2)
|
||||
|
||||
def test_super_in_original_context_unary_call(self):
|
||||
test_case_self = self
|
||||
|
||||
class TestBase(object):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
return x + 20
|
||||
|
||||
class TestSubclass(TestBase):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
test_case_self.fail('This should never be called.')
|
||||
|
||||
def one_arg(self):
|
||||
test_base_unbound = py_builtins.super_in_original_context(
|
||||
super, (TestSubclass,), 0)
|
||||
test_base = test_base_unbound.__get__(self, TestSubclass)
|
||||
return test_base.plus_twenty(1)
|
||||
def test_method(self):
|
||||
with test_case_self._basic_function_scope() as test_scope:
|
||||
test_base_unbound = py_builtins.super_in_original_context(
|
||||
super, (TestSubclass,), test_scope)
|
||||
test_base = test_base_unbound.__get__(self, TestSubclass)
|
||||
return test_base.plus_twenty(1)
|
||||
|
||||
tc = TestSubclass()
|
||||
self.assertEqual(tc.one_arg(), 21)
|
||||
self.assertEqual(tc.test_method(), 21)
|
||||
|
||||
def test_super_with_two_args_in_original_context(self):
|
||||
def test_super_in_original_context_binary_call(self):
|
||||
test_case_self = self
|
||||
|
||||
class TestBase(object):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
return x + 20
|
||||
|
||||
class TestSubclass(TestBase):
|
||||
|
||||
def plus_twenty(self, x):
|
||||
test_case_self.fail('This should never be called.')
|
||||
|
||||
def two_args(self):
|
||||
test_base = py_builtins.super_in_original_context(
|
||||
super, (TestSubclass, self), 0)
|
||||
return test_base.plus_twenty(1)
|
||||
def test_method(self):
|
||||
with test_case_self._basic_function_scope() as test_scope:
|
||||
test_base = py_builtins.super_in_original_context(
|
||||
super, (TestSubclass, self), test_scope)
|
||||
return test_base.plus_twenty(1)
|
||||
|
||||
tc = TestSubclass()
|
||||
self.assertEqual(tc.two_args(), 21)
|
||||
self.assertEqual(tc.test_method(), 21)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user