More robustly check for undefined symbols before attempting to use them. This check is required because undefined symbols are initialized with a special placeholder before entering control flow. This placeholder can lead to confusing error messages if left unchecked. The change introduces two more general operators: "variable load" and "return".

PiperOrigin-RevId: 311411422
Change-Id: Ic8abda74c1f68c1d4de491949d309d60099b91b4
This commit is contained in:
Dan Moldovan 2020-05-13 15:00:41 -07:00 committed by TensorFlower Gardener
parent 18c0da1024
commit 4eeb6d742e
23 changed files with 351 additions and 494 deletions

View File

@ -33,6 +33,7 @@ py_library(
"logical_expressions.py",
"return_statements.py",
"slices.py",
"variables.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
@ -213,3 +214,16 @@ py_test(
"//tensorflow/python/autograph/pyct",
],
)
py_test(
name = "variables_test",
srcs = ["variables_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":converters",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/core:test_lib",
"//tensorflow/python/autograph/pyct",
],
)

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.autograph.converters import asserts
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
@ -36,7 +37,8 @@ class AssertsTest(converter_testing.TestCase):
return a
with ops.Graph().as_default():
with self.converted(test_fn, (functions, asserts), {}) as result:
with self.converted(
test_fn, (functions, asserts, return_statements), {}) as result:
op = result.test_fn(constant_op.constant(False))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):

View File

@ -38,15 +38,6 @@ class _Function(object):
class FunctionTransformer(converter.Base):
"""Wraps function bodies around autograph-specific boilerplate."""
def visit_Return(self, node):
if node.value is None:
return node
node = self.generic_visit(node)
return templates.replace(
'return function_context_name.mark_return_value(value)',
function_context_name=self.state[_Function].context_name,
value=node.value)
def _function_scope_options(self, fn_scope):
"""Returns the options with which to create function scopes."""
# Top-level function receive the options that were directly requested.

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import converter_testing
@ -74,7 +75,7 @@ class FunctionTransformer(converter_testing.TestCase):
l += 1
return l, inner_fn(l)
with self.converted(test_fn, functions, {},
with self.converted(test_fn, (functions, return_statements), {},
(ops.name_scope,)) as result:
first, second = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', first.op.name)
@ -119,6 +120,7 @@ class FunctionTransformer(converter_testing.TestCase):
ns = {'TestClass': TestClass}
node, ctx = self.prepare(TestClass, ns)
node = functions.transform(node, ctx)
node = return_statements.transform(node, ctx)
with self.compiled(node, {}, (ops.name_scope,)) as result:
first, second = result.TestClass().test_fn(constant_op.constant(1))

View File

@ -220,9 +220,9 @@ class ReturnStatementsTransformer(converter.Base):
retval = val
"""
def __init__(self, ctx, default_to_null_return):
def __init__(self, ctx, allow_missing_return):
super(ReturnStatementsTransformer, self).__init__(ctx)
self.default_to_null_return = default_to_null_return
self.allow_missing_return = allow_missing_return
def visit_Return(self, node):
for block in reversed(self.state[_Block].stack):
@ -339,75 +339,68 @@ class ReturnStatementsTransformer(converter.Base):
return node
def visit_FunctionDef(self, node):
self.state[_Function].enter()
self.state[_Block].enter()
self.state[_Block].is_function = True
with self.state[_Function] as fn:
with self.state[_Block] as block:
block.is_function = True
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
do_return_var_name = self.ctx.namer.new_symbol(
'do_return', scope.referenced)
retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced)
self.state[_Function].do_return_var_name = do_return_var_name
self.state[_Function].retval_var_name = retval_var_name
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
do_return_var_name = self.ctx.namer.new_symbol('do_return',
scope.referenced)
retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced)
fn.do_return_var_name = do_return_var_name
fn.retval_var_name = retval_var_name
converted_body = self._visit_statement_block(node, node.body)
node.body = self._visit_statement_block(node, node.body)
# Avoid placing statements before any eventual docstring.
# TODO(mdan): Should a docstring even be included in the output?
docstring = None
if converted_body:
if (isinstance(converted_body[0], gast.Expr) and
isinstance(converted_body[0].value, gast.Constant)):
docstring = converted_body[0]
converted_body = converted_body[1:]
if block.return_used:
if self.state[_Block].return_used:
if self.allow_missing_return:
# The function whould have a single `with` node that wraps the
# entire body. If the function had a docstring, the body has two
# nodes, with the `with` as the second node.
wrapper_node = node.body[-1]
assert isinstance(wrapper_node, gast.With), (
'This transformer requires the functions converter.')
if self.default_to_null_return:
# TODO(mdan): Remove the (do_return_var_name,) below.
# Currently, that line ensures the variable is both defined and alive
# throughout the function.
template = """
do_return_var_name = False
retval_var_name = ag__.UndefinedReturnValue()
body
(do_return_var_name,)
return ag__.retval(retval_var_name)
"""
else:
template = """
body
return retval_var_name
"""
node.body = templates.replace(
template,
body=converted_body,
do_return_var_name=do_return_var_name,
retval_var_name=retval_var_name)
template = """
do_return_var_name = False
retval_var_name = ag__.UndefinedReturnValue()
body
return function_context.ret(retval_var_name, do_return_var_name)
"""
if docstring:
node.body.insert(0, docstring)
wrapper_node.body = templates.replace(
template,
body=wrapper_node.body,
do_return_var_name=do_return_var_name,
function_context=anno.getanno(node, 'function_context_name'),
retval_var_name=retval_var_name)
else:
template = """
body
return retval_var_name
"""
node.body = templates.replace(
template,
body=node.body,
do_return_var_name=do_return_var_name,
retval_var_name=retval_var_name)
self.state[_Block].exit()
self.state[_Function].exit()
return node
def transform(node, ctx, default_to_null_return=True):
"""Ensure a function has only a single return."""
# Note: Technically, these two could be merged into a single walk, but
# keeping them separate helps with readability.
"""Ensure a function has only a single return, at the end."""
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
# Note: Technically, these two could be merged into a single walk, but
# keeping them separate helps with readability.
node = ConditionalReturnRewriter(ctx).visit(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
transformer = ReturnStatementsTransformer(
ctx, default_to_null_return=default_to_null_return)
ctx, allow_missing_return=default_to_null_return)
node = transformer.visit(node)
return node

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import ops
@ -28,7 +29,7 @@ class SingleReturnTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
ns = {'ops': ops}
with self.converted(test_fn, return_statements, ns) as result:
with self.converted(test_fn, (functions, return_statements), ns) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_straightline(self):

View File

@ -0,0 +1,76 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Overloads all variable read operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gast
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import templates
class VariableAccessTransformer(converter.Base):
"""Rewrites basic symbol reads.
This transformer rewrites variable reads with a "read" operator which allows
tracking activity.
Example:
For a basic statement:
a = b + c
This is translated to:
a = ld(b) + ld(c)
Augmented assignment operations also introduce a `ld` operator:
a += b
The assignment target also receives an operator to properly represent the
read:
a = ld(a)
a += ld(b)
"""
def visit_Name(self, node):
# Only the loads which existed in the original code are overloaded.
if not anno.hasanno(node, anno.Static.ORIG_DEFINITIONS):
return node
if isinstance(node.ctx, gast.Load):
node = templates.replace_as_expression('ag__.ld(var_)', var_=node)
return node
def visit_AugAssign(self, node):
if isinstance(node.target, gast.Name):
template = """
var_ = ag__.ld(var_)
original
"""
node = templates.replace(template, var_=node.target, original=node)
else:
node = self.generic_visit(node)
return node
def transform(node, ctx):
return VariableAccessTransformer(ctx).visit(node)

View File

@ -0,0 +1,116 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for variables module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from tensorflow.python.autograph.converters import variables
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test
class VariablesTest(converter_testing.TestCase):
@contextlib.contextmanager
def apply_add_one_conversion(self, fn):
"""Generates code which adds 1 to all variable reads."""
with self.converted(fn, variables, {}) as result:
result.ag__.__dict__['ld'] = lambda x: x + 1
yield result
def test_read(self):
def test_fn(l):
return l
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(1), 2)
def test_aug_assign(self):
def test_fn(l):
l *= 10
return l
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(1), (1 + 1) * 10 + 1) # two reads
def test_attribute(self):
class TestClass(object):
def __init__(self):
self.v = 1
def __add__(self, other):
self.v += other
return self
def test_fn(l):
return l.v
tc = TestClass()
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(tc), 2)
def test_subscript(self):
class TestClass(object):
def __init__(self):
self.v = 1
def __add__(self, other):
self.v += other
return self
def __getitem__(self, _):
return self.v
def test_fn(l):
return l[0]
tc = TestClass()
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(tc), 2)
def test_call(self):
class TestClass(object):
def __init__(self):
self.v = 1
def __add__(self, other):
self.v += other
return self
def __call__(self):
return self.v
def test_fn(l):
return l()
tc = TestClass()
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(tc), 2)
if __name__ == '__main__':
test.main()

View File

@ -30,6 +30,7 @@ py_library(
visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/python:framework_ops",
"//tensorflow/python/autograph/operators",
"//tensorflow/python/autograph/pyct",
"//tensorflow/python/autograph/pyct/static_analysis",
"//tensorflow/python/autograph/utils",

View File

@ -20,12 +20,16 @@ from __future__ import print_function
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.operators import variables
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.util import nest
# TODO(mdan): Move this into operators - it represents a function definition.
class FunctionScope(object):
"""Context manager that wraps the body of a converted function.
@ -84,8 +88,13 @@ class FunctionScope(object):
if self.use_auto_deps:
self.autodeps_scope.__exit__(exc_type, exc_val, exc_tb)
def mark_return_value(self, value):
def ret(self, value, did_return):
"""Marks a value as returned from the function guarded by the scope."""
del did_return
if isinstance(value, variables.UndefinedReturnValue):
return None
if self.use_auto_deps:
self._return_value_marked = True
if value is None:

View File

@ -46,7 +46,7 @@ class FunctionWrappersTest(test.TestCase):
converter.ConversionOptions(
optional_features=converter.Feature.AUTO_CONTROL_DEPS)) as scope:
v.assign(2)
op = scope.mark_return_value(constant_op.constant(1))
op = scope.ret(constant_op.constant(1), True)
self.evaluate(op)
self.assertEqual(self.evaluate(v.read_value()), 2)

View File

@ -36,6 +36,7 @@ from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.converters import logical_expressions
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.converters import slices
from tensorflow.python.autograph.converters import variables
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
@ -92,6 +93,7 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler):
node = control_flow.transform(node, ctx)
node = conditional_expressions.transform(node, ctx)
node = logical_expressions.transform(node, ctx)
node = variables.transform(node, ctx)
return node

View File

@ -29,8 +29,7 @@ py_library(
"logical.py",
"py_builtins.py",
"slices.py",
"special_values.py",
"symbols.py",
"variables.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
@ -148,19 +147,8 @@ py_test(
)
py_test(
name = "special_values_test",
srcs = ["special_values_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":operators",
"//tensorflow/python:client_testlib",
],
)
py_test(
name = "symbols_test",
srcs = ["symbols_test.py"],
name = "variables_test",
srcs = ["variables_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [

View File

@ -60,8 +60,6 @@ from tensorflow.python.autograph.operators.py_builtins import range_
from tensorflow.python.autograph.operators.slices import get_item
from tensorflow.python.autograph.operators.slices import GetItemOpts
from tensorflow.python.autograph.operators.slices import set_item
from tensorflow.python.autograph.operators.special_values import is_undefined
from tensorflow.python.autograph.operators.special_values import is_undefined_return
from tensorflow.python.autograph.operators.special_values import retval
from tensorflow.python.autograph.operators.special_values import Undefined
from tensorflow.python.autograph.operators.special_values import UndefinedReturnValue
from tensorflow.python.autograph.operators.variables import ld
from tensorflow.python.autograph.operators.variables import Undefined
from tensorflow.python.autograph.operators.variables import UndefinedReturnValue

View File

@ -65,7 +65,7 @@ import traceback
import numpy as np
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.operators import special_values
from tensorflow.python.autograph.operators import variables
from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import compat_util
from tensorflow.python.autograph.utils import misc
@ -103,13 +103,13 @@ def _verify_loop_init_vars(values, symbol_names):
for name, value in zip(symbol_names, values):
if value is None:
raise ValueError('"{}" may not be None before the loop.'.format(name))
if special_values.is_undefined_return(value):
if isinstance(value, variables.UndefinedReturnValue):
# Assumption: the loop will only capture the variable which tracks the
# return value if the loop contained a return statement.
# TODO(mdan): This should be checked at the place where return occurs.
raise ValueError(
'return statements are not supported within a TensorFlow loop.')
if special_values.is_undefined(value):
if isinstance(value, variables.Undefined):
raise ValueError('"{}" must be defined before the loop.'.format(name))
@ -495,8 +495,7 @@ def _tf_range_for_stmt(
iterate = compat_util.BasicRef(start)
def _value_or(name, var, default):
if (name == opts['iterate_names']
and isinstance(var, special_values.Undefined)):
if (name == opts['iterate_names'] and isinstance(var, variables.Undefined)):
return default
return var
@ -1019,7 +1018,15 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
results_tuple = results
else:
results_tuple = results,
undefined = tuple(filter(special_values.is_undefined, results_tuple))
for result in results_tuple:
if isinstance(result, variables.UndefinedReturnValue):
raise ValueError(
'A value must also be returned from the {} branch. If a value is '
'returned from one branch of a conditional a value must be '
'returned from all branches.'.format(branch_name))
undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)]
if undefined:
raise ValueError(
'The following symbols must also be initialized in the {} branch: {}.'
@ -1027,13 +1034,6 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
' statement.'.format(branch_name,
tuple(s.symbol_name for s in undefined)))
for result in results_tuple:
if special_values.is_undefined_return(result):
raise ValueError(
'A value must also be returned from the {} branch. If a value is '
'returned from one branch of a conditional a value must be '
'returned from all branches.'.format(branch_name))
return results
return wrapper

View File

@ -66,7 +66,7 @@ import functools
import numpy as np
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.operators import special_values
from tensorflow.python.autograph.operators import variables
from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import misc
from tensorflow.python.autograph.utils import tensors
@ -103,13 +103,13 @@ INEFFICIENT_UNROLL_MIN_OPS = 1
def _disallow_undefs_into_loop(*values):
"""Ensures that all values in the state are defined when entering a loop."""
undefined = tuple(filter(special_values.is_undefined, values))
undefined = [v for v in values if isinstance(v, variables.Undefined)]
if undefined:
raise ValueError(
'{} must be defined before the loop.'.format(
','.join(s.symbol_name for s in undefined)))
for value in values:
if special_values.is_undefined_return(value):
if isinstance(value, variables.UndefinedReturnValue):
# Assumption: the loop will only capture the variable which tracks the
# return value if the loop contained a return statement.
# TODO(mdan): This should be checked at the place where return occurs.
@ -1129,7 +1129,7 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
results_tuple = results
else:
results_tuple = results,
undefined = tuple(filter(special_values.is_undefined, results_tuple))
undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)]
if undefined:
raise ValueError(
'The following symbols must also be initialized in the {} branch: {}.'
@ -1138,7 +1138,7 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
tuple(s.symbol_name for s in undefined)))
for result in results_tuple:
if special_values.is_undefined_return(result):
if isinstance(result, variables.UndefinedReturnValue):
raise ValueError(
'A value must also be returned from the {} branch. If a value is '
'returned from one branch of a conditional a value must be '

View File

@ -29,7 +29,7 @@ import numpy as np
import six
from tensorflow.python.autograph.operators import control_flow
from tensorflow.python.autograph.operators import special_values
from tensorflow.python.autograph.operators import variables as variable_operators
from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function
@ -546,7 +546,7 @@ class ForLoopTest(test.TestCase):
with self.assertRaisesRegex(ValueError, '"s" may not be None'):
self._basic_loop(None, lambda i, s: s)
with self.assertRaisesRegex(ValueError, '"s" must be defined'):
self._basic_loop(special_values.Undefined(''), lambda i, s: s)
self._basic_loop(variable_operators.Undefined(''), lambda i, s: s)
def test_tensor_none_output(self):
with self.assertRaisesRegex(ValueError, '"s" is None at the end'):
@ -785,7 +785,7 @@ class WhileLoopTest(test.TestCase):
with self.assertRaisesRegex(ValueError, '"s" may not be None'):
self._basic_loop(None, lambda i, s: s)
with self.assertRaisesRegex(ValueError, '"s" must be defined'):
self._basic_loop(special_values.Undefined(''), lambda i, s: s)
self._basic_loop(variable_operators.Undefined(''), lambda i, s: s)
def test_tensor_none_output(self):
with self.assertRaisesRegex(ValueError, '"s" is None at the end'):
@ -887,10 +887,10 @@ class IfStmtTest(test.TestCase):
def test_tensor_undefined_output(self):
with self.assertRaisesRegex(
ValueError, "must also be initialized in the if.*'s'"):
self._basic_cond(lambda: special_values.Undefined('s'), lambda: 1)
self._basic_cond(lambda: variable_operators.Undefined('s'), lambda: 1)
with self.assertRaisesRegex(
ValueError, "must also be initialized in the else.*'s'"):
self._basic_cond(lambda: 1, lambda: special_values.Undefined('s'))
self._basic_cond(lambda: 1, lambda: variable_operators.Undefined('s'))
def test_tensor_dtype_change(self):
with self.assertRaisesRegex(TypeError, '"s" has dtype int32.*but.*float32'):

View File

@ -1,115 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstract representation of composite symbols that can be used in staging code.
This provides a way to checkpoint the values of symbols that may be undefined
entering staged control flow. This checkpointing is necessary to prevent some
unintended side-effects. For example checkpointing prevents side-effects in one
branch of a conditional from leaking into another.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.operators import special_values
is_undefined = special_values.is_undefined
Undefined = special_values.Undefined
class Symbol(object):
"""Representation of a simple or composite Python symbol.
Subclasses should implement `maybe_compute_value(self)` that returns the value
corresponding to the symbol or Undefined if no such value exists.
"""
def __init__(self, name):
self.name = name
class ValueSymbol(Symbol):
"""Representation of a simple Python symbol with a concrete value.
This includes variables and literals. Since we are reifying undefined symbols
`Undefined` is also a valid value.
"""
def __init__(self, name, value):
super(ValueSymbol, self).__init__(name)
self.value = value
def maybe_compute_value(self):
return self.value
class AttributeAccessSymbol(Symbol):
"""Representation of Python attribute access e.g. `a.b`."""
def __init__(self, parent_symbol, attr_name):
super(AttributeAccessSymbol, self).__init__(
parent_symbol.name + '.' + attr_name)
self.attr_name = attr_name
self.parent_symbol = parent_symbol
def maybe_compute_value(self):
"""Compute the value corresponding to the attribute access or `Undefined`.
This will be `Undefined` if no such value exists either because there is no
such attribute or if the base is itself undefined.
Returns:
value corresponding to the attribute access or `Undefined`
"""
parent_value = self.parent_symbol.maybe_compute_value()
if (is_undefined(parent_value) or
getattr(parent_value, self.attr_name, None) is None):
return Undefined(self.name)
return parent_value.__getattribute__(self.attr_name)
class SubscriptSymbol(Symbol):
"""Representation of Python subscript access e.g. `a[b]`."""
def __init__(self, parent_symbol, index_symbol):
super(SubscriptSymbol, self).__init__(
parent_symbol.name + '[' + index_symbol.name + ']')
self.index_symbol = index_symbol
self.parent_symbol = parent_symbol
def maybe_compute_value(self):
"""Compute the value corresponding to the subscript access or `Undefined`.
This will be `Undefined` if no such value exists either because there is no
element corresponding to the given subscript or if the base itself is
not defined.
Returns:
value corresponding to the subscript access or `Undefined`
"""
parent_value = self.parent_symbol.maybe_compute_value()
index_value = self.index_symbol.maybe_compute_value()
if is_undefined(parent_value) or is_undefined(index_value):
return Undefined(self.name)
try:
return parent_value[index_value]
except (IndexError, KeyError, TypeError):
# Reify the lack of an object for the given index/key
# This allows us to define them later without regret
return Undefined(self.name)

View File

@ -1,230 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for special symbol handling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.operators import special_values
from tensorflow.python.autograph.operators import symbols
from tensorflow.python.platform import test
Undefined = special_values.Undefined
AttributeAccessSymbol = symbols.AttributeAccessSymbol
SubscriptSymbol = symbols.SubscriptSymbol
ValueSymbol = symbols.ValueSymbol
class SymbolsTest(test.TestCase):
def test_value_symbol_returns_value(self):
a = 42
a_symbol = ValueSymbol('a', a)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(a_symbol.name, 'a')
def test_attribute_access_missing_attribute(self):
class Foo(object):
pass
a = Foo()
a_symbol = ValueSymbol('a', a)
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a.b')
def test_attribute_access_undefined_target(self):
a = Undefined('a')
a_symbol = ValueSymbol('a', a)
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a.b')
def test_attribute_access_basic(self):
class Foo(object):
def __init__(self):
self.b = 'this is an attribute'
a = Foo()
a_symbol = ValueSymbol('a', a)
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(a_b_symbol.maybe_compute_value(), a.b)
def test_item_access_undefined_index(self):
class Foo(object):
def __getitem__(self, key):
return 'this is an item'
a = Foo()
b = Undefined('b')
a_symbol = ValueSymbol('a', a)
b_symbol = ValueSymbol('b', b)
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(b_symbol.maybe_compute_value(), b)
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]')
def test_item_access_no_getitem(self):
class Foo(object):
pass
a = Foo()
b = 42
a_symbol = ValueSymbol('a', a)
b_symbol = ValueSymbol('b', b)
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(b_symbol.maybe_compute_value(), b)
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]')
def test_item_access_undefined_root(self):
a = Undefined('a')
b = 42
a_symbol = ValueSymbol('a', a)
b_symbol = ValueSymbol('b', b)
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(b_symbol.maybe_compute_value(), b)
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]')
def test_item_access_basic(self):
class Foo(object):
def __getitem__(self, key):
return 'this is an item'
a = Foo()
b = 42
a_symbol = ValueSymbol('a', a)
b_symbol = ValueSymbol('b', b)
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(b_symbol.maybe_compute_value(), b)
self.assertEqual(a_b_symbol.maybe_compute_value(), a[b])
def test_item_access_after_attribute_access(self):
class Foo(object):
def __getitem__(self, key):
return 'this is an item'
class Bar(object):
def __init__(self):
self.b = Foo()
a = Bar()
c = 42
a_symbol = ValueSymbol('a', a)
c_symbol = ValueSymbol('c', c)
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
a_b_c_symbol = SubscriptSymbol(a_b_symbol, c_symbol)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(c_symbol.maybe_compute_value(), c)
self.assertEqual(a_b_symbol.maybe_compute_value(), a.b)
self.assertEqual(a_b_c_symbol.maybe_compute_value(), a.b[c])
def test_attribute_access_after_item_access(self):
class Bar(object):
def __init__(self):
self.c = object()
item = Bar()
class Foo(object):
def __getitem__(self, key):
return item
a = Foo()
b = 42
a_symbol = ValueSymbol('a', a)
b_symbol = ValueSymbol('b', b)
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
a_b_c_symbol = AttributeAccessSymbol(a_b_symbol, 'c')
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(b_symbol.maybe_compute_value(), b)
self.assertEqual(a_b_symbol.maybe_compute_value(), a[b])
self.assertEqual(a_b_c_symbol.maybe_compute_value(), a[b].c)
def test_item_access_after_item_access(self):
class Bar(object):
def __getitem__(self, key):
return 'this is an item'
item = Bar()
class Foo(object):
def __getitem__(self, key):
return item
a = Foo()
b = 42
c = 43
a_symbol = ValueSymbol('a', a)
b_symbol = ValueSymbol('b', b)
c_symbol = ValueSymbol('b', c)
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
a_b_c_symbol = SubscriptSymbol(a_b_symbol, c_symbol)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(b_symbol.maybe_compute_value(), b)
self.assertEqual(a_b_symbol.maybe_compute_value(), a[b])
self.assertEqual(a_b_c_symbol.maybe_compute_value(), a[b][c])
def test_attribute_access_after_attribute_access(self):
class Bar(object):
def __init__(self):
self.c = object()
class Foo(object):
def __init__(self):
self.b = Bar()
a = Foo()
a_symbol = ValueSymbol('a', a)
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
a_b_c_symbol = AttributeAccessSymbol(a_b_symbol, 'c')
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(a_b_symbol.maybe_compute_value(), a.b)
self.assertEqual(a_b_c_symbol.maybe_compute_value(), a.b.c)
if __name__ == '__main__':
test.main()

View File

@ -19,6 +19,13 @@ from __future__ import division
from __future__ import print_function
def ld(v):
"""Load variable operator."""
if isinstance(v, Undefined):
return v.read()
return v
class Undefined(object):
"""Represents an undefined symbol in Python.
@ -51,6 +58,10 @@ class Undefined(object):
def __init__(self, symbol_name):
self.symbol_name = symbol_name
def read(self):
raise UnboundLocalError("'{}' is used before assignment".format(
self.symbol_name))
def __repr__(self):
return self.symbol_name
@ -66,34 +77,7 @@ class Undefined(object):
return self
def is_undefined(value):
"""Checks whether Autograph has determined that a given value is undefined.
This only works in places where Autograph reifies undefined symbols. Note that
if this function is passed a truly undefined symbol the call-site will raise
NameError.
Args:
value: value to test for undefinedness
Returns:
Boolean, whether the input value is undefined.
"""
return isinstance(value, Undefined)
# TODO(mdan): Refactor as a RetVal object, aggregating the value and do_return.
class UndefinedReturnValue(object):
"""Represents a default return value from a function (None in Python)."""
"""Represents a return value that is undefined."""
pass
def retval(value):
"""Returns the actual value that a return statement should produce."""
if isinstance(value, UndefinedReturnValue):
return None
return value
def is_undefined_return(value):
"""Checks whether `value` is the default return value."""
return isinstance(value, UndefinedReturnValue)

View File

@ -18,28 +18,38 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.operators import special_values
from tensorflow.python.autograph.operators import variables
from tensorflow.python.platform import test
class SpecialValuesTest(test.TestCase):
def test_undefined(self):
undefined_symbol = special_values.Undefined('name')
self.assertEqual(undefined_symbol.symbol_name, 'name')
undefined_symbol = variables.Undefined('name')
undefined_symbol2 = variables.Undefined('name')
undefined_symbol2 = special_values.Undefined('name')
self.assertEqual(undefined_symbol.symbol_name, 'name')
self.assertEqual(undefined_symbol2.symbol_name, 'name')
self.assertNotEqual(undefined_symbol, undefined_symbol2)
self.assertTrue(special_values.is_undefined(undefined_symbol))
self.assertTrue(special_values.is_undefined(undefined_symbol2))
def test_undefined_operations(self):
undefined_symbol = special_values.Undefined('name')
undefined_symbol = variables.Undefined('name')
self.assertIsInstance(undefined_symbol.foo, variables.Undefined)
self.assertIsInstance(undefined_symbol[0], variables.Undefined)
self.assertNotIsInstance(undefined_symbol.__class__, variables.Undefined)
def test_read(self):
self.assertEqual(variables.ld(1), 1)
o = object()
self.assertEqual(variables.ld(o), o)
self.assertIsNone(variables.ld(None))
def test_read_undefined(self):
with self.assertRaisesRegex(UnboundLocalError, 'used before assignment'):
variables.ld(variables.Undefined('a'))
self.assertTrue(special_values.is_undefined(undefined_symbol.foo))
self.assertTrue(special_values.is_undefined(undefined_symbol[0]))
self.assertFalse(special_values.is_undefined(undefined_symbol.__class__))
if __name__ == '__main__':
test.main()

View File

@ -78,6 +78,18 @@ class ReachingDefinitionsAnalyzerTest(
self.assertSameDef(local_body[1].test, local_body[2].value.elts[0])
# Note: the function name is is visible inside the function body. But it's
# a closure variable, not a local.
#
# Example:
#
# >>> def f():
# ... print(f)
# >>> g = f
# >>> f = 'something else'
# >>> g()
# something else
#
self.assertHasDefinedIn(local_body[1], ('a', 'b'))

View File

@ -255,6 +255,9 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase):
inner_fn_body = fn_body[1].body[1].body
def_of_a_in_foo = inner_fn_body[0].value
# Even though `a` is visible in the inner functio above, the late binding
# makes it impossible to assume that the same value will be visible at
# call time.
self.assertHasDefs(def_of_a_in_foo, 0)
def test_nested_functions_isolation(self):