More robustly check for undefined symbols before attempting to use them. This check is required because undefined symbols are initialized with a special placeholder before entering control flow. This placeholder can lead to confusing error messages if left unchecked. The change introduces two more general operators: "variable load" and "return".
PiperOrigin-RevId: 311411422 Change-Id: Ic8abda74c1f68c1d4de491949d309d60099b91b4
This commit is contained in:
parent
18c0da1024
commit
4eeb6d742e
@ -33,6 +33,7 @@ py_library(
|
||||
"logical_expressions.py",
|
||||
"return_statements.py",
|
||||
"slices.py",
|
||||
"variables.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
@ -213,3 +214,16 @@ py_test(
|
||||
"//tensorflow/python/autograph/pyct",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "variables_test",
|
||||
srcs = ["variables_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":converters",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/autograph/core:test_lib",
|
||||
"//tensorflow/python/autograph/pyct",
|
||||
],
|
||||
)
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.converters import asserts
|
||||
from tensorflow.python.autograph.converters import functions
|
||||
from tensorflow.python.autograph.converters import return_statements
|
||||
from tensorflow.python.autograph.core import converter_testing
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors_impl
|
||||
@ -36,7 +37,8 @@ class AssertsTest(converter_testing.TestCase):
|
||||
return a
|
||||
|
||||
with ops.Graph().as_default():
|
||||
with self.converted(test_fn, (functions, asserts), {}) as result:
|
||||
with self.converted(
|
||||
test_fn, (functions, asserts, return_statements), {}) as result:
|
||||
op = result.test_fn(constant_op.constant(False))
|
||||
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):
|
||||
|
@ -38,15 +38,6 @@ class _Function(object):
|
||||
class FunctionTransformer(converter.Base):
|
||||
"""Wraps function bodies around autograph-specific boilerplate."""
|
||||
|
||||
def visit_Return(self, node):
|
||||
if node.value is None:
|
||||
return node
|
||||
node = self.generic_visit(node)
|
||||
return templates.replace(
|
||||
'return function_context_name.mark_return_value(value)',
|
||||
function_context_name=self.state[_Function].context_name,
|
||||
value=node.value)
|
||||
|
||||
def _function_scope_options(self, fn_scope):
|
||||
"""Returns the options with which to create function scopes."""
|
||||
# Top-level function receive the options that were directly requested.
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.converters import functions
|
||||
from tensorflow.python.autograph.converters import return_statements
|
||||
from tensorflow.python.autograph.core import ag_ctx
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import converter_testing
|
||||
@ -74,7 +75,7 @@ class FunctionTransformer(converter_testing.TestCase):
|
||||
l += 1
|
||||
return l, inner_fn(l)
|
||||
|
||||
with self.converted(test_fn, functions, {},
|
||||
with self.converted(test_fn, (functions, return_statements), {},
|
||||
(ops.name_scope,)) as result:
|
||||
first, second = result.test_fn(constant_op.constant(1))
|
||||
self.assertIn('test_fn/', first.op.name)
|
||||
@ -119,6 +120,7 @@ class FunctionTransformer(converter_testing.TestCase):
|
||||
ns = {'TestClass': TestClass}
|
||||
node, ctx = self.prepare(TestClass, ns)
|
||||
node = functions.transform(node, ctx)
|
||||
node = return_statements.transform(node, ctx)
|
||||
|
||||
with self.compiled(node, {}, (ops.name_scope,)) as result:
|
||||
first, second = result.TestClass().test_fn(constant_op.constant(1))
|
||||
|
@ -220,9 +220,9 @@ class ReturnStatementsTransformer(converter.Base):
|
||||
retval = val
|
||||
"""
|
||||
|
||||
def __init__(self, ctx, default_to_null_return):
|
||||
def __init__(self, ctx, allow_missing_return):
|
||||
super(ReturnStatementsTransformer, self).__init__(ctx)
|
||||
self.default_to_null_return = default_to_null_return
|
||||
self.allow_missing_return = allow_missing_return
|
||||
|
||||
def visit_Return(self, node):
|
||||
for block in reversed(self.state[_Block].stack):
|
||||
@ -339,75 +339,68 @@ class ReturnStatementsTransformer(converter.Base):
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
self.state[_Function].enter()
|
||||
self.state[_Block].enter()
|
||||
self.state[_Block].is_function = True
|
||||
with self.state[_Function] as fn:
|
||||
with self.state[_Block] as block:
|
||||
block.is_function = True
|
||||
|
||||
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
|
||||
do_return_var_name = self.ctx.namer.new_symbol(
|
||||
'do_return', scope.referenced)
|
||||
retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced)
|
||||
self.state[_Function].do_return_var_name = do_return_var_name
|
||||
self.state[_Function].retval_var_name = retval_var_name
|
||||
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
|
||||
do_return_var_name = self.ctx.namer.new_symbol('do_return',
|
||||
scope.referenced)
|
||||
retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced)
|
||||
fn.do_return_var_name = do_return_var_name
|
||||
fn.retval_var_name = retval_var_name
|
||||
|
||||
converted_body = self._visit_statement_block(node, node.body)
|
||||
node.body = self._visit_statement_block(node, node.body)
|
||||
|
||||
# Avoid placing statements before any eventual docstring.
|
||||
# TODO(mdan): Should a docstring even be included in the output?
|
||||
docstring = None
|
||||
if converted_body:
|
||||
if (isinstance(converted_body[0], gast.Expr) and
|
||||
isinstance(converted_body[0].value, gast.Constant)):
|
||||
docstring = converted_body[0]
|
||||
converted_body = converted_body[1:]
|
||||
if block.return_used:
|
||||
|
||||
if self.state[_Block].return_used:
|
||||
if self.allow_missing_return:
|
||||
# The function whould have a single `with` node that wraps the
|
||||
# entire body. If the function had a docstring, the body has two
|
||||
# nodes, with the `with` as the second node.
|
||||
wrapper_node = node.body[-1]
|
||||
assert isinstance(wrapper_node, gast.With), (
|
||||
'This transformer requires the functions converter.')
|
||||
|
||||
if self.default_to_null_return:
|
||||
# TODO(mdan): Remove the (do_return_var_name,) below.
|
||||
# Currently, that line ensures the variable is both defined and alive
|
||||
# throughout the function.
|
||||
template = """
|
||||
do_return_var_name = False
|
||||
retval_var_name = ag__.UndefinedReturnValue()
|
||||
body
|
||||
(do_return_var_name,)
|
||||
return ag__.retval(retval_var_name)
|
||||
"""
|
||||
else:
|
||||
template = """
|
||||
body
|
||||
return retval_var_name
|
||||
"""
|
||||
node.body = templates.replace(
|
||||
template,
|
||||
body=converted_body,
|
||||
do_return_var_name=do_return_var_name,
|
||||
retval_var_name=retval_var_name)
|
||||
template = """
|
||||
do_return_var_name = False
|
||||
retval_var_name = ag__.UndefinedReturnValue()
|
||||
body
|
||||
return function_context.ret(retval_var_name, do_return_var_name)
|
||||
"""
|
||||
|
||||
if docstring:
|
||||
node.body.insert(0, docstring)
|
||||
wrapper_node.body = templates.replace(
|
||||
template,
|
||||
body=wrapper_node.body,
|
||||
do_return_var_name=do_return_var_name,
|
||||
function_context=anno.getanno(node, 'function_context_name'),
|
||||
retval_var_name=retval_var_name)
|
||||
else:
|
||||
template = """
|
||||
body
|
||||
return retval_var_name
|
||||
"""
|
||||
node.body = templates.replace(
|
||||
template,
|
||||
body=node.body,
|
||||
do_return_var_name=do_return_var_name,
|
||||
retval_var_name=retval_var_name)
|
||||
|
||||
self.state[_Block].exit()
|
||||
self.state[_Function].exit()
|
||||
return node
|
||||
|
||||
|
||||
def transform(node, ctx, default_to_null_return=True):
|
||||
"""Ensure a function has only a single return."""
|
||||
# Note: Technically, these two could be merged into a single walk, but
|
||||
# keeping them separate helps with readability.
|
||||
|
||||
"""Ensure a function has only a single return, at the end."""
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
|
||||
# Note: Technically, these two could be merged into a single walk, but
|
||||
# keeping them separate helps with readability.
|
||||
node = ConditionalReturnRewriter(ctx).visit(node)
|
||||
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
|
||||
transformer = ReturnStatementsTransformer(
|
||||
ctx, default_to_null_return=default_to_null_return)
|
||||
ctx, allow_missing_return=default_to_null_return)
|
||||
node = transformer.visit(node)
|
||||
|
||||
return node
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.converters import functions
|
||||
from tensorflow.python.autograph.converters import return_statements
|
||||
from tensorflow.python.autograph.core import converter_testing
|
||||
from tensorflow.python.framework import ops
|
||||
@ -28,7 +29,7 @@ class SingleReturnTest(converter_testing.TestCase):
|
||||
|
||||
def assertTransformedEquivalent(self, test_fn, *inputs):
|
||||
ns = {'ops': ops}
|
||||
with self.converted(test_fn, return_statements, ns) as result:
|
||||
with self.converted(test_fn, (functions, return_statements), ns) as result:
|
||||
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
|
||||
|
||||
def test_straightline(self):
|
||||
|
76
tensorflow/python/autograph/converters/variables.py
Normal file
76
tensorflow/python/autograph/converters/variables.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Overloads all variable read operations."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
|
||||
|
||||
class VariableAccessTransformer(converter.Base):
|
||||
"""Rewrites basic symbol reads.
|
||||
|
||||
This transformer rewrites variable reads with a "read" operator which allows
|
||||
tracking activity.
|
||||
|
||||
Example:
|
||||
|
||||
For a basic statement:
|
||||
|
||||
a = b + c
|
||||
|
||||
This is translated to:
|
||||
|
||||
a = ld(b) + ld(c)
|
||||
|
||||
Augmented assignment operations also introduce a `ld` operator:
|
||||
|
||||
a += b
|
||||
|
||||
The assignment target also receives an operator to properly represent the
|
||||
read:
|
||||
|
||||
a = ld(a)
|
||||
a += ld(b)
|
||||
"""
|
||||
|
||||
def visit_Name(self, node):
|
||||
# Only the loads which existed in the original code are overloaded.
|
||||
if not anno.hasanno(node, anno.Static.ORIG_DEFINITIONS):
|
||||
return node
|
||||
if isinstance(node.ctx, gast.Load):
|
||||
node = templates.replace_as_expression('ag__.ld(var_)', var_=node)
|
||||
return node
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
if isinstance(node.target, gast.Name):
|
||||
template = """
|
||||
var_ = ag__.ld(var_)
|
||||
original
|
||||
"""
|
||||
node = templates.replace(template, var_=node.target, original=node)
|
||||
else:
|
||||
node = self.generic_visit(node)
|
||||
return node
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
return VariableAccessTransformer(ctx).visit(node)
|
116
tensorflow/python/autograph/converters/variables_test.py
Normal file
116
tensorflow/python/autograph/converters/variables_test.py
Normal file
@ -0,0 +1,116 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for variables module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
|
||||
from tensorflow.python.autograph.converters import variables
|
||||
from tensorflow.python.autograph.core import converter_testing
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class VariablesTest(converter_testing.TestCase):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def apply_add_one_conversion(self, fn):
|
||||
"""Generates code which adds 1 to all variable reads."""
|
||||
with self.converted(fn, variables, {}) as result:
|
||||
result.ag__.__dict__['ld'] = lambda x: x + 1
|
||||
yield result
|
||||
|
||||
def test_read(self):
|
||||
|
||||
def test_fn(l):
|
||||
return l
|
||||
|
||||
with self.apply_add_one_conversion(test_fn) as result:
|
||||
self.assertEqual(result.test_fn(1), 2)
|
||||
|
||||
def test_aug_assign(self):
|
||||
|
||||
def test_fn(l):
|
||||
l *= 10
|
||||
return l
|
||||
|
||||
with self.apply_add_one_conversion(test_fn) as result:
|
||||
self.assertEqual(result.test_fn(1), (1 + 1) * 10 + 1) # two reads
|
||||
|
||||
def test_attribute(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def __init__(self):
|
||||
self.v = 1
|
||||
|
||||
def __add__(self, other):
|
||||
self.v += other
|
||||
return self
|
||||
|
||||
def test_fn(l):
|
||||
return l.v
|
||||
|
||||
tc = TestClass()
|
||||
with self.apply_add_one_conversion(test_fn) as result:
|
||||
self.assertEqual(result.test_fn(tc), 2)
|
||||
|
||||
def test_subscript(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def __init__(self):
|
||||
self.v = 1
|
||||
|
||||
def __add__(self, other):
|
||||
self.v += other
|
||||
return self
|
||||
|
||||
def __getitem__(self, _):
|
||||
return self.v
|
||||
|
||||
def test_fn(l):
|
||||
return l[0]
|
||||
|
||||
tc = TestClass()
|
||||
with self.apply_add_one_conversion(test_fn) as result:
|
||||
self.assertEqual(result.test_fn(tc), 2)
|
||||
|
||||
def test_call(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def __init__(self):
|
||||
self.v = 1
|
||||
|
||||
def __add__(self, other):
|
||||
self.v += other
|
||||
return self
|
||||
|
||||
def __call__(self):
|
||||
return self.v
|
||||
|
||||
def test_fn(l):
|
||||
return l()
|
||||
|
||||
tc = TestClass()
|
||||
with self.apply_add_one_conversion(test_fn) as result:
|
||||
self.assertEqual(result.test_fn(tc), 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -30,6 +30,7 @@ py_library(
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python/autograph/operators",
|
||||
"//tensorflow/python/autograph/pyct",
|
||||
"//tensorflow/python/autograph/pyct/static_analysis",
|
||||
"//tensorflow/python/autograph/utils",
|
||||
|
@ -20,12 +20,16 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.core import ag_ctx
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.operators import variables
|
||||
from tensorflow.python.framework import auto_control_deps
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
# TODO(mdan): Move this into operators - it represents a function definition.
|
||||
|
||||
|
||||
class FunctionScope(object):
|
||||
"""Context manager that wraps the body of a converted function.
|
||||
|
||||
@ -84,8 +88,13 @@ class FunctionScope(object):
|
||||
if self.use_auto_deps:
|
||||
self.autodeps_scope.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def mark_return_value(self, value):
|
||||
def ret(self, value, did_return):
|
||||
"""Marks a value as returned from the function guarded by the scope."""
|
||||
del did_return
|
||||
|
||||
if isinstance(value, variables.UndefinedReturnValue):
|
||||
return None
|
||||
|
||||
if self.use_auto_deps:
|
||||
self._return_value_marked = True
|
||||
if value is None:
|
||||
|
@ -46,7 +46,7 @@ class FunctionWrappersTest(test.TestCase):
|
||||
converter.ConversionOptions(
|
||||
optional_features=converter.Feature.AUTO_CONTROL_DEPS)) as scope:
|
||||
v.assign(2)
|
||||
op = scope.mark_return_value(constant_op.constant(1))
|
||||
op = scope.ret(constant_op.constant(1), True)
|
||||
self.evaluate(op)
|
||||
self.assertEqual(self.evaluate(v.read_value()), 2)
|
||||
|
||||
|
@ -36,6 +36,7 @@ from tensorflow.python.autograph.converters import lists
|
||||
from tensorflow.python.autograph.converters import logical_expressions
|
||||
from tensorflow.python.autograph.converters import return_statements
|
||||
from tensorflow.python.autograph.converters import slices
|
||||
from tensorflow.python.autograph.converters import variables
|
||||
from tensorflow.python.autograph.core import config
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import function_wrappers
|
||||
@ -92,6 +93,7 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler):
|
||||
node = control_flow.transform(node, ctx)
|
||||
node = conditional_expressions.transform(node, ctx)
|
||||
node = logical_expressions.transform(node, ctx)
|
||||
node = variables.transform(node, ctx)
|
||||
return node
|
||||
|
||||
|
||||
|
@ -29,8 +29,7 @@ py_library(
|
||||
"logical.py",
|
||||
"py_builtins.py",
|
||||
"slices.py",
|
||||
"special_values.py",
|
||||
"symbols.py",
|
||||
"variables.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
@ -148,19 +147,8 @@ py_test(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "special_values_test",
|
||||
srcs = ["special_values_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":operators",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "symbols_test",
|
||||
srcs = ["symbols_test.py"],
|
||||
name = "variables_test",
|
||||
srcs = ["variables_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
|
@ -60,8 +60,6 @@ from tensorflow.python.autograph.operators.py_builtins import range_
|
||||
from tensorflow.python.autograph.operators.slices import get_item
|
||||
from tensorflow.python.autograph.operators.slices import GetItemOpts
|
||||
from tensorflow.python.autograph.operators.slices import set_item
|
||||
from tensorflow.python.autograph.operators.special_values import is_undefined
|
||||
from tensorflow.python.autograph.operators.special_values import is_undefined_return
|
||||
from tensorflow.python.autograph.operators.special_values import retval
|
||||
from tensorflow.python.autograph.operators.special_values import Undefined
|
||||
from tensorflow.python.autograph.operators.special_values import UndefinedReturnValue
|
||||
from tensorflow.python.autograph.operators.variables import ld
|
||||
from tensorflow.python.autograph.operators.variables import Undefined
|
||||
from tensorflow.python.autograph.operators.variables import UndefinedReturnValue
|
||||
|
@ -65,7 +65,7 @@ import traceback
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.autograph.operators import py_builtins
|
||||
from tensorflow.python.autograph.operators import special_values
|
||||
from tensorflow.python.autograph.operators import variables
|
||||
from tensorflow.python.autograph.utils import ag_logging
|
||||
from tensorflow.python.autograph.utils import compat_util
|
||||
from tensorflow.python.autograph.utils import misc
|
||||
@ -103,13 +103,13 @@ def _verify_loop_init_vars(values, symbol_names):
|
||||
for name, value in zip(symbol_names, values):
|
||||
if value is None:
|
||||
raise ValueError('"{}" may not be None before the loop.'.format(name))
|
||||
if special_values.is_undefined_return(value):
|
||||
if isinstance(value, variables.UndefinedReturnValue):
|
||||
# Assumption: the loop will only capture the variable which tracks the
|
||||
# return value if the loop contained a return statement.
|
||||
# TODO(mdan): This should be checked at the place where return occurs.
|
||||
raise ValueError(
|
||||
'return statements are not supported within a TensorFlow loop.')
|
||||
if special_values.is_undefined(value):
|
||||
if isinstance(value, variables.Undefined):
|
||||
raise ValueError('"{}" must be defined before the loop.'.format(name))
|
||||
|
||||
|
||||
@ -495,8 +495,7 @@ def _tf_range_for_stmt(
|
||||
iterate = compat_util.BasicRef(start)
|
||||
|
||||
def _value_or(name, var, default):
|
||||
if (name == opts['iterate_names']
|
||||
and isinstance(var, special_values.Undefined)):
|
||||
if (name == opts['iterate_names'] and isinstance(var, variables.Undefined)):
|
||||
return default
|
||||
return var
|
||||
|
||||
@ -1019,7 +1018,15 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
|
||||
results_tuple = results
|
||||
else:
|
||||
results_tuple = results,
|
||||
undefined = tuple(filter(special_values.is_undefined, results_tuple))
|
||||
|
||||
for result in results_tuple:
|
||||
if isinstance(result, variables.UndefinedReturnValue):
|
||||
raise ValueError(
|
||||
'A value must also be returned from the {} branch. If a value is '
|
||||
'returned from one branch of a conditional a value must be '
|
||||
'returned from all branches.'.format(branch_name))
|
||||
|
||||
undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)]
|
||||
if undefined:
|
||||
raise ValueError(
|
||||
'The following symbols must also be initialized in the {} branch: {}.'
|
||||
@ -1027,13 +1034,6 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
|
||||
' statement.'.format(branch_name,
|
||||
tuple(s.symbol_name for s in undefined)))
|
||||
|
||||
for result in results_tuple:
|
||||
if special_values.is_undefined_return(result):
|
||||
raise ValueError(
|
||||
'A value must also be returned from the {} branch. If a value is '
|
||||
'returned from one branch of a conditional a value must be '
|
||||
'returned from all branches.'.format(branch_name))
|
||||
|
||||
return results
|
||||
|
||||
return wrapper
|
||||
|
@ -66,7 +66,7 @@ import functools
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.autograph.operators import py_builtins
|
||||
from tensorflow.python.autograph.operators import special_values
|
||||
from tensorflow.python.autograph.operators import variables
|
||||
from tensorflow.python.autograph.utils import ag_logging
|
||||
from tensorflow.python.autograph.utils import misc
|
||||
from tensorflow.python.autograph.utils import tensors
|
||||
@ -103,13 +103,13 @@ INEFFICIENT_UNROLL_MIN_OPS = 1
|
||||
|
||||
def _disallow_undefs_into_loop(*values):
|
||||
"""Ensures that all values in the state are defined when entering a loop."""
|
||||
undefined = tuple(filter(special_values.is_undefined, values))
|
||||
undefined = [v for v in values if isinstance(v, variables.Undefined)]
|
||||
if undefined:
|
||||
raise ValueError(
|
||||
'{} must be defined before the loop.'.format(
|
||||
','.join(s.symbol_name for s in undefined)))
|
||||
for value in values:
|
||||
if special_values.is_undefined_return(value):
|
||||
if isinstance(value, variables.UndefinedReturnValue):
|
||||
# Assumption: the loop will only capture the variable which tracks the
|
||||
# return value if the loop contained a return statement.
|
||||
# TODO(mdan): This should be checked at the place where return occurs.
|
||||
@ -1129,7 +1129,7 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
|
||||
results_tuple = results
|
||||
else:
|
||||
results_tuple = results,
|
||||
undefined = tuple(filter(special_values.is_undefined, results_tuple))
|
||||
undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)]
|
||||
if undefined:
|
||||
raise ValueError(
|
||||
'The following symbols must also be initialized in the {} branch: {}.'
|
||||
@ -1138,7 +1138,7 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
|
||||
tuple(s.symbol_name for s in undefined)))
|
||||
|
||||
for result in results_tuple:
|
||||
if special_values.is_undefined_return(result):
|
||||
if isinstance(result, variables.UndefinedReturnValue):
|
||||
raise ValueError(
|
||||
'A value must also be returned from the {} branch. If a value is '
|
||||
'returned from one branch of a conditional a value must be '
|
||||
|
@ -29,7 +29,7 @@ import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python.autograph.operators import control_flow
|
||||
from tensorflow.python.autograph.operators import special_values
|
||||
from tensorflow.python.autograph.operators import variables as variable_operators
|
||||
from tensorflow.python.autograph.utils import ag_logging
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -546,7 +546,7 @@ class ForLoopTest(test.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, '"s" may not be None'):
|
||||
self._basic_loop(None, lambda i, s: s)
|
||||
with self.assertRaisesRegex(ValueError, '"s" must be defined'):
|
||||
self._basic_loop(special_values.Undefined(''), lambda i, s: s)
|
||||
self._basic_loop(variable_operators.Undefined(''), lambda i, s: s)
|
||||
|
||||
def test_tensor_none_output(self):
|
||||
with self.assertRaisesRegex(ValueError, '"s" is None at the end'):
|
||||
@ -785,7 +785,7 @@ class WhileLoopTest(test.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, '"s" may not be None'):
|
||||
self._basic_loop(None, lambda i, s: s)
|
||||
with self.assertRaisesRegex(ValueError, '"s" must be defined'):
|
||||
self._basic_loop(special_values.Undefined(''), lambda i, s: s)
|
||||
self._basic_loop(variable_operators.Undefined(''), lambda i, s: s)
|
||||
|
||||
def test_tensor_none_output(self):
|
||||
with self.assertRaisesRegex(ValueError, '"s" is None at the end'):
|
||||
@ -887,10 +887,10 @@ class IfStmtTest(test.TestCase):
|
||||
def test_tensor_undefined_output(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "must also be initialized in the if.*'s'"):
|
||||
self._basic_cond(lambda: special_values.Undefined('s'), lambda: 1)
|
||||
self._basic_cond(lambda: variable_operators.Undefined('s'), lambda: 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "must also be initialized in the else.*'s'"):
|
||||
self._basic_cond(lambda: 1, lambda: special_values.Undefined('s'))
|
||||
self._basic_cond(lambda: 1, lambda: variable_operators.Undefined('s'))
|
||||
|
||||
def test_tensor_dtype_change(self):
|
||||
with self.assertRaisesRegex(TypeError, '"s" has dtype int32.*but.*float32'):
|
||||
|
@ -1,115 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Abstract representation of composite symbols that can be used in staging code.
|
||||
|
||||
This provides a way to checkpoint the values of symbols that may be undefined
|
||||
entering staged control flow. This checkpointing is necessary to prevent some
|
||||
unintended side-effects. For example checkpointing prevents side-effects in one
|
||||
branch of a conditional from leaking into another.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.operators import special_values
|
||||
|
||||
|
||||
is_undefined = special_values.is_undefined
|
||||
Undefined = special_values.Undefined
|
||||
|
||||
|
||||
class Symbol(object):
|
||||
"""Representation of a simple or composite Python symbol.
|
||||
|
||||
Subclasses should implement `maybe_compute_value(self)` that returns the value
|
||||
corresponding to the symbol or Undefined if no such value exists.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class ValueSymbol(Symbol):
|
||||
"""Representation of a simple Python symbol with a concrete value.
|
||||
|
||||
This includes variables and literals. Since we are reifying undefined symbols
|
||||
`Undefined` is also a valid value.
|
||||
"""
|
||||
|
||||
def __init__(self, name, value):
|
||||
super(ValueSymbol, self).__init__(name)
|
||||
self.value = value
|
||||
|
||||
def maybe_compute_value(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class AttributeAccessSymbol(Symbol):
|
||||
"""Representation of Python attribute access e.g. `a.b`."""
|
||||
|
||||
def __init__(self, parent_symbol, attr_name):
|
||||
super(AttributeAccessSymbol, self).__init__(
|
||||
parent_symbol.name + '.' + attr_name)
|
||||
self.attr_name = attr_name
|
||||
self.parent_symbol = parent_symbol
|
||||
|
||||
def maybe_compute_value(self):
|
||||
"""Compute the value corresponding to the attribute access or `Undefined`.
|
||||
|
||||
This will be `Undefined` if no such value exists either because there is no
|
||||
such attribute or if the base is itself undefined.
|
||||
|
||||
Returns:
|
||||
value corresponding to the attribute access or `Undefined`
|
||||
"""
|
||||
parent_value = self.parent_symbol.maybe_compute_value()
|
||||
if (is_undefined(parent_value) or
|
||||
getattr(parent_value, self.attr_name, None) is None):
|
||||
return Undefined(self.name)
|
||||
|
||||
return parent_value.__getattribute__(self.attr_name)
|
||||
|
||||
|
||||
class SubscriptSymbol(Symbol):
|
||||
"""Representation of Python subscript access e.g. `a[b]`."""
|
||||
|
||||
def __init__(self, parent_symbol, index_symbol):
|
||||
super(SubscriptSymbol, self).__init__(
|
||||
parent_symbol.name + '[' + index_symbol.name + ']')
|
||||
self.index_symbol = index_symbol
|
||||
self.parent_symbol = parent_symbol
|
||||
|
||||
def maybe_compute_value(self):
|
||||
"""Compute the value corresponding to the subscript access or `Undefined`.
|
||||
|
||||
This will be `Undefined` if no such value exists either because there is no
|
||||
element corresponding to the given subscript or if the base itself is
|
||||
not defined.
|
||||
|
||||
Returns:
|
||||
value corresponding to the subscript access or `Undefined`
|
||||
"""
|
||||
parent_value = self.parent_symbol.maybe_compute_value()
|
||||
index_value = self.index_symbol.maybe_compute_value()
|
||||
if is_undefined(parent_value) or is_undefined(index_value):
|
||||
return Undefined(self.name)
|
||||
|
||||
try:
|
||||
return parent_value[index_value]
|
||||
except (IndexError, KeyError, TypeError):
|
||||
# Reify the lack of an object for the given index/key
|
||||
# This allows us to define them later without regret
|
||||
return Undefined(self.name)
|
@ -1,230 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for special symbol handling."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.operators import special_values
|
||||
from tensorflow.python.autograph.operators import symbols
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
Undefined = special_values.Undefined
|
||||
AttributeAccessSymbol = symbols.AttributeAccessSymbol
|
||||
SubscriptSymbol = symbols.SubscriptSymbol
|
||||
ValueSymbol = symbols.ValueSymbol
|
||||
|
||||
|
||||
class SymbolsTest(test.TestCase):
|
||||
|
||||
def test_value_symbol_returns_value(self):
|
||||
a = 42
|
||||
a_symbol = ValueSymbol('a', a)
|
||||
self.assertEqual(a_symbol.maybe_compute_value(), a)
|
||||
self.assertEqual(a_symbol.name, 'a')
|
||||
|
||||
def test_attribute_access_missing_attribute(self):
|
||||
class Foo(object):
|
||||
pass
|
||||
a = Foo()
|
||||
|
||||
a_symbol = ValueSymbol('a', a)
|
||||
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
|
||||
|
||||
self.assertEqual(a_symbol.maybe_compute_value(), a)
|
||||
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
|
||||
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a.b')
|
||||
|
||||
def test_attribute_access_undefined_target(self):
|
||||
a = Undefined('a')
|
||||
a_symbol = ValueSymbol('a', a)
|
||||
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
|
||||
|
||||
self.assertEqual(a_symbol.maybe_compute_value(), a)
|
||||
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
|
||||
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a.b')
|
||||
|
||||
def test_attribute_access_basic(self):
|
||||
class Foo(object):
|
||||
|
||||
def __init__(self):
|
||||
self.b = 'this is an attribute'
|
||||
|
||||
a = Foo()
|
||||
a_symbol = ValueSymbol('a', a)
|
||||
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
|
||||
|
||||
self.assertEqual(a_symbol.maybe_compute_value(), a)
|
||||
self.assertEqual(a_b_symbol.maybe_compute_value(), a.b)
|
||||
|
||||
def test_item_access_undefined_index(self):
|
||||
class Foo(object):
|
||||
|
||||
def __getitem__(self, key):
|
||||
return 'this is an item'
|
||||
|
||||
a = Foo()
|
||||
b = Undefined('b')
|
||||
a_symbol = ValueSymbol('a', a)
|
||||
b_symbol = ValueSymbol('b', b)
|
||||
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
|
||||
|
||||
self.assertEqual(a_symbol.maybe_compute_value(), a)
|
||||
self.assertEqual(b_symbol.maybe_compute_value(), b)
|
||||
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
|
||||
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]')
|
||||
|
||||
def test_item_access_no_getitem(self):
|
||||
class Foo(object):
|
||||
pass
|
||||
|
||||
a = Foo()
|
||||
b = 42
|
||||
a_symbol = ValueSymbol('a', a)
|
||||
b_symbol = ValueSymbol('b', b)
|
||||
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
|
||||
|
||||
self.assertEqual(a_symbol.maybe_compute_value(), a)
|
||||
self.assertEqual(b_symbol.maybe_compute_value(), b)
|
||||
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
|
||||
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]')
|
||||
|
||||
def test_item_access_undefined_root(self):
|
||||
a = Undefined('a')
|
||||
b = 42
|
||||
a_symbol = ValueSymbol('a', a)
|
||||
b_symbol = ValueSymbol('b', b)
|
||||
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
|
||||
|
||||
self.assertEqual(a_symbol.maybe_compute_value(), a)
|
||||
self.assertEqual(b_symbol.maybe_compute_value(), b)
|
||||
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
|
||||
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]')
|
||||
|
||||
def test_item_access_basic(self):
|
||||
class Foo(object):
|
||||
|
||||
def __getitem__(self, key):
|
||||
return 'this is an item'
|
||||
|
||||
a = Foo()
|
||||
b = 42
|
||||
a_symbol = ValueSymbol('a', a)
|
||||
b_symbol = ValueSymbol('b', b)
|
||||
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
|
||||
|
||||
self.assertEqual(a_symbol.maybe_compute_value(), a)
|
||||
self.assertEqual(b_symbol.maybe_compute_value(), b)
|
||||
self.assertEqual(a_b_symbol.maybe_compute_value(), a[b])
|
||||
|
||||
def test_item_access_after_attribute_access(self):
|
||||
class Foo(object):
|
||||
|
||||
def __getitem__(self, key):
|
||||
return 'this is an item'
|
||||
|
||||
class Bar(object):
|
||||
|
||||
def __init__(self):
|
||||
self.b = Foo()
|
||||
|
||||
a = Bar()
|
||||
c = 42
|
||||
a_symbol = ValueSymbol('a', a)
|
||||
c_symbol = ValueSymbol('c', c)
|
||||
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
|
||||
a_b_c_symbol = SubscriptSymbol(a_b_symbol, c_symbol)
|
||||
|
||||
self.assertEqual(a_symbol.maybe_compute_value(), a)
|
||||
self.assertEqual(c_symbol.maybe_compute_value(), c)
|
||||
self.assertEqual(a_b_symbol.maybe_compute_value(), a.b)
|
||||
self.assertEqual(a_b_c_symbol.maybe_compute_value(), a.b[c])
|
||||
|
||||
def test_attribute_access_after_item_access(self):
|
||||
class Bar(object):
|
||||
|
||||
def __init__(self):
|
||||
self.c = object()
|
||||
|
||||
item = Bar()
|
||||
|
||||
class Foo(object):
|
||||
|
||||
def __getitem__(self, key):
|
||||
return item
|
||||
|
||||
a = Foo()
|
||||
b = 42
|
||||
a_symbol = ValueSymbol('a', a)
|
||||
b_symbol = ValueSymbol('b', b)
|
||||
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
|
||||
a_b_c_symbol = AttributeAccessSymbol(a_b_symbol, 'c')
|
||||
|
||||
self.assertEqual(a_symbol.maybe_compute_value(), a)
|
||||
self.assertEqual(b_symbol.maybe_compute_value(), b)
|
||||
self.assertEqual(a_b_symbol.maybe_compute_value(), a[b])
|
||||
self.assertEqual(a_b_c_symbol.maybe_compute_value(), a[b].c)
|
||||
|
||||
def test_item_access_after_item_access(self):
|
||||
class Bar(object):
|
||||
|
||||
def __getitem__(self, key):
|
||||
return 'this is an item'
|
||||
|
||||
item = Bar()
|
||||
|
||||
class Foo(object):
|
||||
|
||||
def __getitem__(self, key):
|
||||
return item
|
||||
|
||||
a = Foo()
|
||||
b = 42
|
||||
c = 43
|
||||
a_symbol = ValueSymbol('a', a)
|
||||
b_symbol = ValueSymbol('b', b)
|
||||
c_symbol = ValueSymbol('b', c)
|
||||
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
|
||||
a_b_c_symbol = SubscriptSymbol(a_b_symbol, c_symbol)
|
||||
|
||||
self.assertEqual(a_symbol.maybe_compute_value(), a)
|
||||
self.assertEqual(b_symbol.maybe_compute_value(), b)
|
||||
self.assertEqual(a_b_symbol.maybe_compute_value(), a[b])
|
||||
self.assertEqual(a_b_c_symbol.maybe_compute_value(), a[b][c])
|
||||
|
||||
def test_attribute_access_after_attribute_access(self):
|
||||
class Bar(object):
|
||||
|
||||
def __init__(self):
|
||||
self.c = object()
|
||||
|
||||
class Foo(object):
|
||||
|
||||
def __init__(self):
|
||||
self.b = Bar()
|
||||
|
||||
a = Foo()
|
||||
a_symbol = ValueSymbol('a', a)
|
||||
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
|
||||
a_b_c_symbol = AttributeAccessSymbol(a_b_symbol, 'c')
|
||||
|
||||
self.assertEqual(a_symbol.maybe_compute_value(), a)
|
||||
self.assertEqual(a_b_symbol.maybe_compute_value(), a.b)
|
||||
self.assertEqual(a_b_c_symbol.maybe_compute_value(), a.b.c)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -19,6 +19,13 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
def ld(v):
|
||||
"""Load variable operator."""
|
||||
if isinstance(v, Undefined):
|
||||
return v.read()
|
||||
return v
|
||||
|
||||
|
||||
class Undefined(object):
|
||||
"""Represents an undefined symbol in Python.
|
||||
|
||||
@ -51,6 +58,10 @@ class Undefined(object):
|
||||
def __init__(self, symbol_name):
|
||||
self.symbol_name = symbol_name
|
||||
|
||||
def read(self):
|
||||
raise UnboundLocalError("'{}' is used before assignment".format(
|
||||
self.symbol_name))
|
||||
|
||||
def __repr__(self):
|
||||
return self.symbol_name
|
||||
|
||||
@ -66,34 +77,7 @@ class Undefined(object):
|
||||
return self
|
||||
|
||||
|
||||
def is_undefined(value):
|
||||
"""Checks whether Autograph has determined that a given value is undefined.
|
||||
|
||||
This only works in places where Autograph reifies undefined symbols. Note that
|
||||
if this function is passed a truly undefined symbol the call-site will raise
|
||||
NameError.
|
||||
|
||||
Args:
|
||||
value: value to test for undefinedness
|
||||
Returns:
|
||||
Boolean, whether the input value is undefined.
|
||||
"""
|
||||
return isinstance(value, Undefined)
|
||||
|
||||
|
||||
# TODO(mdan): Refactor as a RetVal object, aggregating the value and do_return.
|
||||
class UndefinedReturnValue(object):
|
||||
"""Represents a default return value from a function (None in Python)."""
|
||||
"""Represents a return value that is undefined."""
|
||||
pass
|
||||
|
||||
|
||||
def retval(value):
|
||||
"""Returns the actual value that a return statement should produce."""
|
||||
if isinstance(value, UndefinedReturnValue):
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
def is_undefined_return(value):
|
||||
"""Checks whether `value` is the default return value."""
|
||||
return isinstance(value, UndefinedReturnValue)
|
@ -18,28 +18,38 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.operators import special_values
|
||||
from tensorflow.python.autograph.operators import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SpecialValuesTest(test.TestCase):
|
||||
|
||||
def test_undefined(self):
|
||||
undefined_symbol = special_values.Undefined('name')
|
||||
self.assertEqual(undefined_symbol.symbol_name, 'name')
|
||||
undefined_symbol = variables.Undefined('name')
|
||||
undefined_symbol2 = variables.Undefined('name')
|
||||
|
||||
undefined_symbol2 = special_values.Undefined('name')
|
||||
self.assertEqual(undefined_symbol.symbol_name, 'name')
|
||||
self.assertEqual(undefined_symbol2.symbol_name, 'name')
|
||||
self.assertNotEqual(undefined_symbol, undefined_symbol2)
|
||||
|
||||
self.assertTrue(special_values.is_undefined(undefined_symbol))
|
||||
self.assertTrue(special_values.is_undefined(undefined_symbol2))
|
||||
|
||||
def test_undefined_operations(self):
|
||||
undefined_symbol = special_values.Undefined('name')
|
||||
undefined_symbol = variables.Undefined('name')
|
||||
|
||||
self.assertIsInstance(undefined_symbol.foo, variables.Undefined)
|
||||
self.assertIsInstance(undefined_symbol[0], variables.Undefined)
|
||||
self.assertNotIsInstance(undefined_symbol.__class__, variables.Undefined)
|
||||
|
||||
def test_read(self):
|
||||
self.assertEqual(variables.ld(1), 1)
|
||||
o = object()
|
||||
self.assertEqual(variables.ld(o), o)
|
||||
|
||||
self.assertIsNone(variables.ld(None))
|
||||
|
||||
def test_read_undefined(self):
|
||||
with self.assertRaisesRegex(UnboundLocalError, 'used before assignment'):
|
||||
variables.ld(variables.Undefined('a'))
|
||||
|
||||
self.assertTrue(special_values.is_undefined(undefined_symbol.foo))
|
||||
self.assertTrue(special_values.is_undefined(undefined_symbol[0]))
|
||||
self.assertFalse(special_values.is_undefined(undefined_symbol.__class__))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -78,6 +78,18 @@ class ReachingDefinitionsAnalyzerTest(
|
||||
|
||||
self.assertSameDef(local_body[1].test, local_body[2].value.elts[0])
|
||||
|
||||
# Note: the function name is is visible inside the function body. But it's
|
||||
# a closure variable, not a local.
|
||||
#
|
||||
# Example:
|
||||
#
|
||||
# >>> def f():
|
||||
# ... print(f)
|
||||
# >>> g = f
|
||||
# >>> f = 'something else'
|
||||
# >>> g()
|
||||
# something else
|
||||
#
|
||||
self.assertHasDefinedIn(local_body[1], ('a', 'b'))
|
||||
|
||||
|
||||
|
@ -255,6 +255,9 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase):
|
||||
|
||||
inner_fn_body = fn_body[1].body[1].body
|
||||
def_of_a_in_foo = inner_fn_body[0].value
|
||||
# Even though `a` is visible in the inner functio above, the late binding
|
||||
# makes it impossible to assume that the same value will be visible at
|
||||
# call time.
|
||||
self.assertHasDefs(def_of_a_in_foo, 0)
|
||||
|
||||
def test_nested_functions_isolation(self):
|
||||
|
Loading…
Reference in New Issue
Block a user