More robustly check for undefined symbols before attempting to use them. This check is required because undefined symbols are initialized with a special placeholder before entering control flow. This placeholder can lead to confusing error messages if left unchecked. The change introduces two more general operators: "variable load" and "return".

PiperOrigin-RevId: 311411422
Change-Id: Ic8abda74c1f68c1d4de491949d309d60099b91b4
This commit is contained in:
Dan Moldovan 2020-05-13 15:00:41 -07:00 committed by TensorFlower Gardener
parent 18c0da1024
commit 4eeb6d742e
23 changed files with 351 additions and 494 deletions

View File

@ -33,6 +33,7 @@ py_library(
"logical_expressions.py", "logical_expressions.py",
"return_statements.py", "return_statements.py",
"slices.py", "slices.py",
"variables.py",
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"], visibility = ["//tensorflow:__subpackages__"],
@ -213,3 +214,16 @@ py_test(
"//tensorflow/python/autograph/pyct", "//tensorflow/python/autograph/pyct",
], ],
) )
py_test(
name = "variables_test",
srcs = ["variables_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":converters",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/core:test_lib",
"//tensorflow/python/autograph/pyct",
],
)

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.autograph.converters import asserts from tensorflow.python.autograph.converters import asserts
from tensorflow.python.autograph.converters import functions from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
@ -36,7 +37,8 @@ class AssertsTest(converter_testing.TestCase):
return a return a
with ops.Graph().as_default(): with ops.Graph().as_default():
with self.converted(test_fn, (functions, asserts), {}) as result: with self.converted(
test_fn, (functions, asserts, return_statements), {}) as result:
op = result.test_fn(constant_op.constant(False)) op = result.test_fn(constant_op.constant(False))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):

View File

@ -38,15 +38,6 @@ class _Function(object):
class FunctionTransformer(converter.Base): class FunctionTransformer(converter.Base):
"""Wraps function bodies around autograph-specific boilerplate.""" """Wraps function bodies around autograph-specific boilerplate."""
def visit_Return(self, node):
if node.value is None:
return node
node = self.generic_visit(node)
return templates.replace(
'return function_context_name.mark_return_value(value)',
function_context_name=self.state[_Function].context_name,
value=node.value)
def _function_scope_options(self, fn_scope): def _function_scope_options(self, fn_scope):
"""Returns the options with which to create function scopes.""" """Returns the options with which to create function scopes."""
# Top-level function receive the options that were directly requested. # Top-level function receive the options that were directly requested.

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.autograph.converters import functions from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
@ -74,7 +75,7 @@ class FunctionTransformer(converter_testing.TestCase):
l += 1 l += 1
return l, inner_fn(l) return l, inner_fn(l)
with self.converted(test_fn, functions, {}, with self.converted(test_fn, (functions, return_statements), {},
(ops.name_scope,)) as result: (ops.name_scope,)) as result:
first, second = result.test_fn(constant_op.constant(1)) first, second = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', first.op.name) self.assertIn('test_fn/', first.op.name)
@ -119,6 +120,7 @@ class FunctionTransformer(converter_testing.TestCase):
ns = {'TestClass': TestClass} ns = {'TestClass': TestClass}
node, ctx = self.prepare(TestClass, ns) node, ctx = self.prepare(TestClass, ns)
node = functions.transform(node, ctx) node = functions.transform(node, ctx)
node = return_statements.transform(node, ctx)
with self.compiled(node, {}, (ops.name_scope,)) as result: with self.compiled(node, {}, (ops.name_scope,)) as result:
first, second = result.TestClass().test_fn(constant_op.constant(1)) first, second = result.TestClass().test_fn(constant_op.constant(1))

View File

@ -220,9 +220,9 @@ class ReturnStatementsTransformer(converter.Base):
retval = val retval = val
""" """
def __init__(self, ctx, default_to_null_return): def __init__(self, ctx, allow_missing_return):
super(ReturnStatementsTransformer, self).__init__(ctx) super(ReturnStatementsTransformer, self).__init__(ctx)
self.default_to_null_return = default_to_null_return self.allow_missing_return = allow_missing_return
def visit_Return(self, node): def visit_Return(self, node):
for block in reversed(self.state[_Block].stack): for block in reversed(self.state[_Block].stack):
@ -339,41 +339,42 @@ class ReturnStatementsTransformer(converter.Base):
return node return node
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
self.state[_Function].enter() with self.state[_Function] as fn:
self.state[_Block].enter() with self.state[_Block] as block:
self.state[_Block].is_function = True block.is_function = True
scope = anno.getanno(node, NodeAnno.BODY_SCOPE) scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
do_return_var_name = self.ctx.namer.new_symbol( do_return_var_name = self.ctx.namer.new_symbol('do_return',
'do_return', scope.referenced) scope.referenced)
retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced) retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced)
self.state[_Function].do_return_var_name = do_return_var_name fn.do_return_var_name = do_return_var_name
self.state[_Function].retval_var_name = retval_var_name fn.retval_var_name = retval_var_name
converted_body = self._visit_statement_block(node, node.body) node.body = self._visit_statement_block(node, node.body)
# Avoid placing statements before any eventual docstring. if block.return_used:
# TODO(mdan): Should a docstring even be included in the output?
docstring = None
if converted_body:
if (isinstance(converted_body[0], gast.Expr) and
isinstance(converted_body[0].value, gast.Constant)):
docstring = converted_body[0]
converted_body = converted_body[1:]
if self.state[_Block].return_used: if self.allow_missing_return:
# The function whould have a single `with` node that wraps the
# entire body. If the function had a docstring, the body has two
# nodes, with the `with` as the second node.
wrapper_node = node.body[-1]
assert isinstance(wrapper_node, gast.With), (
'This transformer requires the functions converter.')
if self.default_to_null_return:
# TODO(mdan): Remove the (do_return_var_name,) below.
# Currently, that line ensures the variable is both defined and alive
# throughout the function.
template = """ template = """
do_return_var_name = False do_return_var_name = False
retval_var_name = ag__.UndefinedReturnValue() retval_var_name = ag__.UndefinedReturnValue()
body body
(do_return_var_name,) return function_context.ret(retval_var_name, do_return_var_name)
return ag__.retval(retval_var_name)
""" """
wrapper_node.body = templates.replace(
template,
body=wrapper_node.body,
do_return_var_name=do_return_var_name,
function_context=anno.getanno(node, 'function_context_name'),
retval_var_name=retval_var_name)
else: else:
template = """ template = """
body body
@ -381,33 +382,25 @@ class ReturnStatementsTransformer(converter.Base):
""" """
node.body = templates.replace( node.body = templates.replace(
template, template,
body=converted_body, body=node.body,
do_return_var_name=do_return_var_name, do_return_var_name=do_return_var_name,
retval_var_name=retval_var_name) retval_var_name=retval_var_name)
if docstring:
node.body.insert(0, docstring)
self.state[_Block].exit()
self.state[_Function].exit()
return node return node
def transform(node, ctx, default_to_null_return=True): def transform(node, ctx, default_to_null_return=True):
"""Ensure a function has only a single return.""" """Ensure a function has only a single return, at the end."""
# Note: Technically, these two could be merged into a single walk, but
# keeping them separate helps with readability.
node = qual_names.resolve(node) node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None) node = activity.resolve(node, ctx, None)
# Note: Technically, these two could be merged into a single walk, but
# keeping them separate helps with readability.
node = ConditionalReturnRewriter(ctx).visit(node) node = ConditionalReturnRewriter(ctx).visit(node)
node = qual_names.resolve(node) node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None) node = activity.resolve(node, ctx, None)
transformer = ReturnStatementsTransformer( transformer = ReturnStatementsTransformer(
ctx, default_to_null_return=default_to_null_return) ctx, allow_missing_return=default_to_null_return)
node = transformer.visit(node) node = transformer.visit(node)
return node return node

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.converters import return_statements from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -28,7 +29,7 @@ class SingleReturnTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs): def assertTransformedEquivalent(self, test_fn, *inputs):
ns = {'ops': ops} ns = {'ops': ops}
with self.converted(test_fn, return_statements, ns) as result: with self.converted(test_fn, (functions, return_statements), ns) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_straightline(self): def test_straightline(self):

View File

@ -0,0 +1,76 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Overloads all variable read operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gast
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import templates
class VariableAccessTransformer(converter.Base):
"""Rewrites basic symbol reads.
This transformer rewrites variable reads with a "read" operator which allows
tracking activity.
Example:
For a basic statement:
a = b + c
This is translated to:
a = ld(b) + ld(c)
Augmented assignment operations also introduce a `ld` operator:
a += b
The assignment target also receives an operator to properly represent the
read:
a = ld(a)
a += ld(b)
"""
def visit_Name(self, node):
# Only the loads which existed in the original code are overloaded.
if not anno.hasanno(node, anno.Static.ORIG_DEFINITIONS):
return node
if isinstance(node.ctx, gast.Load):
node = templates.replace_as_expression('ag__.ld(var_)', var_=node)
return node
def visit_AugAssign(self, node):
if isinstance(node.target, gast.Name):
template = """
var_ = ag__.ld(var_)
original
"""
node = templates.replace(template, var_=node.target, original=node)
else:
node = self.generic_visit(node)
return node
def transform(node, ctx):
return VariableAccessTransformer(ctx).visit(node)

View File

@ -0,0 +1,116 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for variables module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from tensorflow.python.autograph.converters import variables
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test
class VariablesTest(converter_testing.TestCase):
@contextlib.contextmanager
def apply_add_one_conversion(self, fn):
"""Generates code which adds 1 to all variable reads."""
with self.converted(fn, variables, {}) as result:
result.ag__.__dict__['ld'] = lambda x: x + 1
yield result
def test_read(self):
def test_fn(l):
return l
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(1), 2)
def test_aug_assign(self):
def test_fn(l):
l *= 10
return l
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(1), (1 + 1) * 10 + 1) # two reads
def test_attribute(self):
class TestClass(object):
def __init__(self):
self.v = 1
def __add__(self, other):
self.v += other
return self
def test_fn(l):
return l.v
tc = TestClass()
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(tc), 2)
def test_subscript(self):
class TestClass(object):
def __init__(self):
self.v = 1
def __add__(self, other):
self.v += other
return self
def __getitem__(self, _):
return self.v
def test_fn(l):
return l[0]
tc = TestClass()
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(tc), 2)
def test_call(self):
class TestClass(object):
def __init__(self):
self.v = 1
def __add__(self, other):
self.v += other
return self
def __call__(self):
return self.v
def test_fn(l):
return l()
tc = TestClass()
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(tc), 2)
if __name__ == '__main__':
test.main()

View File

@ -30,6 +30,7 @@ py_library(
visibility = ["//tensorflow:__subpackages__"], visibility = ["//tensorflow:__subpackages__"],
deps = [ deps = [
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python/autograph/operators",
"//tensorflow/python/autograph/pyct", "//tensorflow/python/autograph/pyct",
"//tensorflow/python/autograph/pyct/static_analysis", "//tensorflow/python/autograph/pyct/static_analysis",
"//tensorflow/python/autograph/utils", "//tensorflow/python/autograph/utils",

View File

@ -20,12 +20,16 @@ from __future__ import print_function
from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.operators import variables
from tensorflow.python.framework import auto_control_deps from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.util import nest from tensorflow.python.util import nest
# TODO(mdan): Move this into operators - it represents a function definition.
class FunctionScope(object): class FunctionScope(object):
"""Context manager that wraps the body of a converted function. """Context manager that wraps the body of a converted function.
@ -84,8 +88,13 @@ class FunctionScope(object):
if self.use_auto_deps: if self.use_auto_deps:
self.autodeps_scope.__exit__(exc_type, exc_val, exc_tb) self.autodeps_scope.__exit__(exc_type, exc_val, exc_tb)
def mark_return_value(self, value): def ret(self, value, did_return):
"""Marks a value as returned from the function guarded by the scope.""" """Marks a value as returned from the function guarded by the scope."""
del did_return
if isinstance(value, variables.UndefinedReturnValue):
return None
if self.use_auto_deps: if self.use_auto_deps:
self._return_value_marked = True self._return_value_marked = True
if value is None: if value is None:

View File

@ -46,7 +46,7 @@ class FunctionWrappersTest(test.TestCase):
converter.ConversionOptions( converter.ConversionOptions(
optional_features=converter.Feature.AUTO_CONTROL_DEPS)) as scope: optional_features=converter.Feature.AUTO_CONTROL_DEPS)) as scope:
v.assign(2) v.assign(2)
op = scope.mark_return_value(constant_op.constant(1)) op = scope.ret(constant_op.constant(1), True)
self.evaluate(op) self.evaluate(op)
self.assertEqual(self.evaluate(v.read_value()), 2) self.assertEqual(self.evaluate(v.read_value()), 2)

View File

@ -36,6 +36,7 @@ from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.converters import logical_expressions from tensorflow.python.autograph.converters import logical_expressions
from tensorflow.python.autograph.converters import return_statements from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.converters import slices from tensorflow.python.autograph.converters import slices
from tensorflow.python.autograph.converters import variables
from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers from tensorflow.python.autograph.core import function_wrappers
@ -92,6 +93,7 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler):
node = control_flow.transform(node, ctx) node = control_flow.transform(node, ctx)
node = conditional_expressions.transform(node, ctx) node = conditional_expressions.transform(node, ctx)
node = logical_expressions.transform(node, ctx) node = logical_expressions.transform(node, ctx)
node = variables.transform(node, ctx)
return node return node

View File

@ -29,8 +29,7 @@ py_library(
"logical.py", "logical.py",
"py_builtins.py", "py_builtins.py",
"slices.py", "slices.py",
"special_values.py", "variables.py",
"symbols.py",
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"], visibility = ["//tensorflow:__subpackages__"],
@ -148,19 +147,8 @@ py_test(
) )
py_test( py_test(
name = "special_values_test", name = "variables_test",
srcs = ["special_values_test.py"], srcs = ["variables_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":operators",
"//tensorflow/python:client_testlib",
],
)
py_test(
name = "symbols_test",
srcs = ["symbols_test.py"],
python_version = "PY3", python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [

View File

@ -60,8 +60,6 @@ from tensorflow.python.autograph.operators.py_builtins import range_
from tensorflow.python.autograph.operators.slices import get_item from tensorflow.python.autograph.operators.slices import get_item
from tensorflow.python.autograph.operators.slices import GetItemOpts from tensorflow.python.autograph.operators.slices import GetItemOpts
from tensorflow.python.autograph.operators.slices import set_item from tensorflow.python.autograph.operators.slices import set_item
from tensorflow.python.autograph.operators.special_values import is_undefined from tensorflow.python.autograph.operators.variables import ld
from tensorflow.python.autograph.operators.special_values import is_undefined_return from tensorflow.python.autograph.operators.variables import Undefined
from tensorflow.python.autograph.operators.special_values import retval from tensorflow.python.autograph.operators.variables import UndefinedReturnValue
from tensorflow.python.autograph.operators.special_values import Undefined
from tensorflow.python.autograph.operators.special_values import UndefinedReturnValue

View File

@ -65,7 +65,7 @@ import traceback
import numpy as np import numpy as np
from tensorflow.python.autograph.operators import py_builtins from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.operators import special_values from tensorflow.python.autograph.operators import variables
from tensorflow.python.autograph.utils import ag_logging from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import compat_util from tensorflow.python.autograph.utils import compat_util
from tensorflow.python.autograph.utils import misc from tensorflow.python.autograph.utils import misc
@ -103,13 +103,13 @@ def _verify_loop_init_vars(values, symbol_names):
for name, value in zip(symbol_names, values): for name, value in zip(symbol_names, values):
if value is None: if value is None:
raise ValueError('"{}" may not be None before the loop.'.format(name)) raise ValueError('"{}" may not be None before the loop.'.format(name))
if special_values.is_undefined_return(value): if isinstance(value, variables.UndefinedReturnValue):
# Assumption: the loop will only capture the variable which tracks the # Assumption: the loop will only capture the variable which tracks the
# return value if the loop contained a return statement. # return value if the loop contained a return statement.
# TODO(mdan): This should be checked at the place where return occurs. # TODO(mdan): This should be checked at the place where return occurs.
raise ValueError( raise ValueError(
'return statements are not supported within a TensorFlow loop.') 'return statements are not supported within a TensorFlow loop.')
if special_values.is_undefined(value): if isinstance(value, variables.Undefined):
raise ValueError('"{}" must be defined before the loop.'.format(name)) raise ValueError('"{}" must be defined before the loop.'.format(name))
@ -495,8 +495,7 @@ def _tf_range_for_stmt(
iterate = compat_util.BasicRef(start) iterate = compat_util.BasicRef(start)
def _value_or(name, var, default): def _value_or(name, var, default):
if (name == opts['iterate_names'] if (name == opts['iterate_names'] and isinstance(var, variables.Undefined)):
and isinstance(var, special_values.Undefined)):
return default return default
return var return var
@ -1019,7 +1018,15 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
results_tuple = results results_tuple = results
else: else:
results_tuple = results, results_tuple = results,
undefined = tuple(filter(special_values.is_undefined, results_tuple))
for result in results_tuple:
if isinstance(result, variables.UndefinedReturnValue):
raise ValueError(
'A value must also be returned from the {} branch. If a value is '
'returned from one branch of a conditional a value must be '
'returned from all branches.'.format(branch_name))
undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)]
if undefined: if undefined:
raise ValueError( raise ValueError(
'The following symbols must also be initialized in the {} branch: {}.' 'The following symbols must also be initialized in the {} branch: {}.'
@ -1027,13 +1034,6 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
' statement.'.format(branch_name, ' statement.'.format(branch_name,
tuple(s.symbol_name for s in undefined))) tuple(s.symbol_name for s in undefined)))
for result in results_tuple:
if special_values.is_undefined_return(result):
raise ValueError(
'A value must also be returned from the {} branch. If a value is '
'returned from one branch of a conditional a value must be '
'returned from all branches.'.format(branch_name))
return results return results
return wrapper return wrapper

View File

@ -66,7 +66,7 @@ import functools
import numpy as np import numpy as np
from tensorflow.python.autograph.operators import py_builtins from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.operators import special_values from tensorflow.python.autograph.operators import variables
from tensorflow.python.autograph.utils import ag_logging from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import misc from tensorflow.python.autograph.utils import misc
from tensorflow.python.autograph.utils import tensors from tensorflow.python.autograph.utils import tensors
@ -103,13 +103,13 @@ INEFFICIENT_UNROLL_MIN_OPS = 1
def _disallow_undefs_into_loop(*values): def _disallow_undefs_into_loop(*values):
"""Ensures that all values in the state are defined when entering a loop.""" """Ensures that all values in the state are defined when entering a loop."""
undefined = tuple(filter(special_values.is_undefined, values)) undefined = [v for v in values if isinstance(v, variables.Undefined)]
if undefined: if undefined:
raise ValueError( raise ValueError(
'{} must be defined before the loop.'.format( '{} must be defined before the loop.'.format(
','.join(s.symbol_name for s in undefined))) ','.join(s.symbol_name for s in undefined)))
for value in values: for value in values:
if special_values.is_undefined_return(value): if isinstance(value, variables.UndefinedReturnValue):
# Assumption: the loop will only capture the variable which tracks the # Assumption: the loop will only capture the variable which tracks the
# return value if the loop contained a return statement. # return value if the loop contained a return statement.
# TODO(mdan): This should be checked at the place where return occurs. # TODO(mdan): This should be checked at the place where return occurs.
@ -1129,7 +1129,7 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
results_tuple = results results_tuple = results
else: else:
results_tuple = results, results_tuple = results,
undefined = tuple(filter(special_values.is_undefined, results_tuple)) undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)]
if undefined: if undefined:
raise ValueError( raise ValueError(
'The following symbols must also be initialized in the {} branch: {}.' 'The following symbols must also be initialized in the {} branch: {}.'
@ -1138,7 +1138,7 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
tuple(s.symbol_name for s in undefined))) tuple(s.symbol_name for s in undefined)))
for result in results_tuple: for result in results_tuple:
if special_values.is_undefined_return(result): if isinstance(result, variables.UndefinedReturnValue):
raise ValueError( raise ValueError(
'A value must also be returned from the {} branch. If a value is ' 'A value must also be returned from the {} branch. If a value is '
'returned from one branch of a conditional a value must be ' 'returned from one branch of a conditional a value must be '

View File

@ -29,7 +29,7 @@ import numpy as np
import six import six
from tensorflow.python.autograph.operators import control_flow from tensorflow.python.autograph.operators import control_flow
from tensorflow.python.autograph.operators import special_values from tensorflow.python.autograph.operators import variables as variable_operators
from tensorflow.python.autograph.utils import ag_logging from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
@ -546,7 +546,7 @@ class ForLoopTest(test.TestCase):
with self.assertRaisesRegex(ValueError, '"s" may not be None'): with self.assertRaisesRegex(ValueError, '"s" may not be None'):
self._basic_loop(None, lambda i, s: s) self._basic_loop(None, lambda i, s: s)
with self.assertRaisesRegex(ValueError, '"s" must be defined'): with self.assertRaisesRegex(ValueError, '"s" must be defined'):
self._basic_loop(special_values.Undefined(''), lambda i, s: s) self._basic_loop(variable_operators.Undefined(''), lambda i, s: s)
def test_tensor_none_output(self): def test_tensor_none_output(self):
with self.assertRaisesRegex(ValueError, '"s" is None at the end'): with self.assertRaisesRegex(ValueError, '"s" is None at the end'):
@ -785,7 +785,7 @@ class WhileLoopTest(test.TestCase):
with self.assertRaisesRegex(ValueError, '"s" may not be None'): with self.assertRaisesRegex(ValueError, '"s" may not be None'):
self._basic_loop(None, lambda i, s: s) self._basic_loop(None, lambda i, s: s)
with self.assertRaisesRegex(ValueError, '"s" must be defined'): with self.assertRaisesRegex(ValueError, '"s" must be defined'):
self._basic_loop(special_values.Undefined(''), lambda i, s: s) self._basic_loop(variable_operators.Undefined(''), lambda i, s: s)
def test_tensor_none_output(self): def test_tensor_none_output(self):
with self.assertRaisesRegex(ValueError, '"s" is None at the end'): with self.assertRaisesRegex(ValueError, '"s" is None at the end'):
@ -887,10 +887,10 @@ class IfStmtTest(test.TestCase):
def test_tensor_undefined_output(self): def test_tensor_undefined_output(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, "must also be initialized in the if.*'s'"): ValueError, "must also be initialized in the if.*'s'"):
self._basic_cond(lambda: special_values.Undefined('s'), lambda: 1) self._basic_cond(lambda: variable_operators.Undefined('s'), lambda: 1)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, "must also be initialized in the else.*'s'"): ValueError, "must also be initialized in the else.*'s'"):
self._basic_cond(lambda: 1, lambda: special_values.Undefined('s')) self._basic_cond(lambda: 1, lambda: variable_operators.Undefined('s'))
def test_tensor_dtype_change(self): def test_tensor_dtype_change(self):
with self.assertRaisesRegex(TypeError, '"s" has dtype int32.*but.*float32'): with self.assertRaisesRegex(TypeError, '"s" has dtype int32.*but.*float32'):

View File

@ -1,115 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstract representation of composite symbols that can be used in staging code.
This provides a way to checkpoint the values of symbols that may be undefined
entering staged control flow. This checkpointing is necessary to prevent some
unintended side-effects. For example checkpointing prevents side-effects in one
branch of a conditional from leaking into another.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.operators import special_values
is_undefined = special_values.is_undefined
Undefined = special_values.Undefined
class Symbol(object):
"""Representation of a simple or composite Python symbol.
Subclasses should implement `maybe_compute_value(self)` that returns the value
corresponding to the symbol or Undefined if no such value exists.
"""
def __init__(self, name):
self.name = name
class ValueSymbol(Symbol):
"""Representation of a simple Python symbol with a concrete value.
This includes variables and literals. Since we are reifying undefined symbols
`Undefined` is also a valid value.
"""
def __init__(self, name, value):
super(ValueSymbol, self).__init__(name)
self.value = value
def maybe_compute_value(self):
return self.value
class AttributeAccessSymbol(Symbol):
"""Representation of Python attribute access e.g. `a.b`."""
def __init__(self, parent_symbol, attr_name):
super(AttributeAccessSymbol, self).__init__(
parent_symbol.name + '.' + attr_name)
self.attr_name = attr_name
self.parent_symbol = parent_symbol
def maybe_compute_value(self):
"""Compute the value corresponding to the attribute access or `Undefined`.
This will be `Undefined` if no such value exists either because there is no
such attribute or if the base is itself undefined.
Returns:
value corresponding to the attribute access or `Undefined`
"""
parent_value = self.parent_symbol.maybe_compute_value()
if (is_undefined(parent_value) or
getattr(parent_value, self.attr_name, None) is None):
return Undefined(self.name)
return parent_value.__getattribute__(self.attr_name)
class SubscriptSymbol(Symbol):
"""Representation of Python subscript access e.g. `a[b]`."""
def __init__(self, parent_symbol, index_symbol):
super(SubscriptSymbol, self).__init__(
parent_symbol.name + '[' + index_symbol.name + ']')
self.index_symbol = index_symbol
self.parent_symbol = parent_symbol
def maybe_compute_value(self):
"""Compute the value corresponding to the subscript access or `Undefined`.
This will be `Undefined` if no such value exists either because there is no
element corresponding to the given subscript or if the base itself is
not defined.
Returns:
value corresponding to the subscript access or `Undefined`
"""
parent_value = self.parent_symbol.maybe_compute_value()
index_value = self.index_symbol.maybe_compute_value()
if is_undefined(parent_value) or is_undefined(index_value):
return Undefined(self.name)
try:
return parent_value[index_value]
except (IndexError, KeyError, TypeError):
# Reify the lack of an object for the given index/key
# This allows us to define them later without regret
return Undefined(self.name)

View File

@ -1,230 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for special symbol handling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.operators import special_values
from tensorflow.python.autograph.operators import symbols
from tensorflow.python.platform import test
Undefined = special_values.Undefined
AttributeAccessSymbol = symbols.AttributeAccessSymbol
SubscriptSymbol = symbols.SubscriptSymbol
ValueSymbol = symbols.ValueSymbol
class SymbolsTest(test.TestCase):
def test_value_symbol_returns_value(self):
a = 42
a_symbol = ValueSymbol('a', a)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(a_symbol.name, 'a')
def test_attribute_access_missing_attribute(self):
class Foo(object):
pass
a = Foo()
a_symbol = ValueSymbol('a', a)
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a.b')
def test_attribute_access_undefined_target(self):
a = Undefined('a')
a_symbol = ValueSymbol('a', a)
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a.b')
def test_attribute_access_basic(self):
class Foo(object):
def __init__(self):
self.b = 'this is an attribute'
a = Foo()
a_symbol = ValueSymbol('a', a)
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(a_b_symbol.maybe_compute_value(), a.b)
def test_item_access_undefined_index(self):
class Foo(object):
def __getitem__(self, key):
return 'this is an item'
a = Foo()
b = Undefined('b')
a_symbol = ValueSymbol('a', a)
b_symbol = ValueSymbol('b', b)
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(b_symbol.maybe_compute_value(), b)
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]')
def test_item_access_no_getitem(self):
class Foo(object):
pass
a = Foo()
b = 42
a_symbol = ValueSymbol('a', a)
b_symbol = ValueSymbol('b', b)
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(b_symbol.maybe_compute_value(), b)
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]')
def test_item_access_undefined_root(self):
a = Undefined('a')
b = 42
a_symbol = ValueSymbol('a', a)
b_symbol = ValueSymbol('b', b)
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(b_symbol.maybe_compute_value(), b)
self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]')
def test_item_access_basic(self):
class Foo(object):
def __getitem__(self, key):
return 'this is an item'
a = Foo()
b = 42
a_symbol = ValueSymbol('a', a)
b_symbol = ValueSymbol('b', b)
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(b_symbol.maybe_compute_value(), b)
self.assertEqual(a_b_symbol.maybe_compute_value(), a[b])
def test_item_access_after_attribute_access(self):
class Foo(object):
def __getitem__(self, key):
return 'this is an item'
class Bar(object):
def __init__(self):
self.b = Foo()
a = Bar()
c = 42
a_symbol = ValueSymbol('a', a)
c_symbol = ValueSymbol('c', c)
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
a_b_c_symbol = SubscriptSymbol(a_b_symbol, c_symbol)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(c_symbol.maybe_compute_value(), c)
self.assertEqual(a_b_symbol.maybe_compute_value(), a.b)
self.assertEqual(a_b_c_symbol.maybe_compute_value(), a.b[c])
def test_attribute_access_after_item_access(self):
class Bar(object):
def __init__(self):
self.c = object()
item = Bar()
class Foo(object):
def __getitem__(self, key):
return item
a = Foo()
b = 42
a_symbol = ValueSymbol('a', a)
b_symbol = ValueSymbol('b', b)
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
a_b_c_symbol = AttributeAccessSymbol(a_b_symbol, 'c')
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(b_symbol.maybe_compute_value(), b)
self.assertEqual(a_b_symbol.maybe_compute_value(), a[b])
self.assertEqual(a_b_c_symbol.maybe_compute_value(), a[b].c)
def test_item_access_after_item_access(self):
class Bar(object):
def __getitem__(self, key):
return 'this is an item'
item = Bar()
class Foo(object):
def __getitem__(self, key):
return item
a = Foo()
b = 42
c = 43
a_symbol = ValueSymbol('a', a)
b_symbol = ValueSymbol('b', b)
c_symbol = ValueSymbol('b', c)
a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
a_b_c_symbol = SubscriptSymbol(a_b_symbol, c_symbol)
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(b_symbol.maybe_compute_value(), b)
self.assertEqual(a_b_symbol.maybe_compute_value(), a[b])
self.assertEqual(a_b_c_symbol.maybe_compute_value(), a[b][c])
def test_attribute_access_after_attribute_access(self):
class Bar(object):
def __init__(self):
self.c = object()
class Foo(object):
def __init__(self):
self.b = Bar()
a = Foo()
a_symbol = ValueSymbol('a', a)
a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
a_b_c_symbol = AttributeAccessSymbol(a_b_symbol, 'c')
self.assertEqual(a_symbol.maybe_compute_value(), a)
self.assertEqual(a_b_symbol.maybe_compute_value(), a.b)
self.assertEqual(a_b_c_symbol.maybe_compute_value(), a.b.c)
if __name__ == '__main__':
test.main()

View File

@ -19,6 +19,13 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
def ld(v):
"""Load variable operator."""
if isinstance(v, Undefined):
return v.read()
return v
class Undefined(object): class Undefined(object):
"""Represents an undefined symbol in Python. """Represents an undefined symbol in Python.
@ -51,6 +58,10 @@ class Undefined(object):
def __init__(self, symbol_name): def __init__(self, symbol_name):
self.symbol_name = symbol_name self.symbol_name = symbol_name
def read(self):
raise UnboundLocalError("'{}' is used before assignment".format(
self.symbol_name))
def __repr__(self): def __repr__(self):
return self.symbol_name return self.symbol_name
@ -66,34 +77,7 @@ class Undefined(object):
return self return self
def is_undefined(value):
"""Checks whether Autograph has determined that a given value is undefined.
This only works in places where Autograph reifies undefined symbols. Note that
if this function is passed a truly undefined symbol the call-site will raise
NameError.
Args:
value: value to test for undefinedness
Returns:
Boolean, whether the input value is undefined.
"""
return isinstance(value, Undefined)
# TODO(mdan): Refactor as a RetVal object, aggregating the value and do_return. # TODO(mdan): Refactor as a RetVal object, aggregating the value and do_return.
class UndefinedReturnValue(object): class UndefinedReturnValue(object):
"""Represents a default return value from a function (None in Python).""" """Represents a return value that is undefined."""
pass pass
def retval(value):
"""Returns the actual value that a return statement should produce."""
if isinstance(value, UndefinedReturnValue):
return None
return value
def is_undefined_return(value):
"""Checks whether `value` is the default return value."""
return isinstance(value, UndefinedReturnValue)

View File

@ -18,28 +18,38 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.autograph.operators import special_values from tensorflow.python.autograph.operators import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
class SpecialValuesTest(test.TestCase): class SpecialValuesTest(test.TestCase):
def test_undefined(self): def test_undefined(self):
undefined_symbol = special_values.Undefined('name') undefined_symbol = variables.Undefined('name')
self.assertEqual(undefined_symbol.symbol_name, 'name') undefined_symbol2 = variables.Undefined('name')
undefined_symbol2 = special_values.Undefined('name') self.assertEqual(undefined_symbol.symbol_name, 'name')
self.assertEqual(undefined_symbol2.symbol_name, 'name')
self.assertNotEqual(undefined_symbol, undefined_symbol2) self.assertNotEqual(undefined_symbol, undefined_symbol2)
self.assertTrue(special_values.is_undefined(undefined_symbol))
self.assertTrue(special_values.is_undefined(undefined_symbol2))
def test_undefined_operations(self): def test_undefined_operations(self):
undefined_symbol = special_values.Undefined('name') undefined_symbol = variables.Undefined('name')
self.assertIsInstance(undefined_symbol.foo, variables.Undefined)
self.assertIsInstance(undefined_symbol[0], variables.Undefined)
self.assertNotIsInstance(undefined_symbol.__class__, variables.Undefined)
def test_read(self):
self.assertEqual(variables.ld(1), 1)
o = object()
self.assertEqual(variables.ld(o), o)
self.assertIsNone(variables.ld(None))
def test_read_undefined(self):
with self.assertRaisesRegex(UnboundLocalError, 'used before assignment'):
variables.ld(variables.Undefined('a'))
self.assertTrue(special_values.is_undefined(undefined_symbol.foo))
self.assertTrue(special_values.is_undefined(undefined_symbol[0]))
self.assertFalse(special_values.is_undefined(undefined_symbol.__class__))
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -78,6 +78,18 @@ class ReachingDefinitionsAnalyzerTest(
self.assertSameDef(local_body[1].test, local_body[2].value.elts[0]) self.assertSameDef(local_body[1].test, local_body[2].value.elts[0])
# Note: the function name is is visible inside the function body. But it's
# a closure variable, not a local.
#
# Example:
#
# >>> def f():
# ... print(f)
# >>> g = f
# >>> f = 'something else'
# >>> g()
# something else
#
self.assertHasDefinedIn(local_body[1], ('a', 'b')) self.assertHasDefinedIn(local_body[1], ('a', 'b'))

View File

@ -255,6 +255,9 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase):
inner_fn_body = fn_body[1].body[1].body inner_fn_body = fn_body[1].body[1].body
def_of_a_in_foo = inner_fn_body[0].value def_of_a_in_foo = inner_fn_body[0].value
# Even though `a` is visible in the inner functio above, the late binding
# makes it impossible to assume that the same value will be visible at
# call time.
self.assertHasDefs(def_of_a_in_foo, 0) self.assertHasDefs(def_of_a_in_foo, 0)
def test_nested_functions_isolation(self): def test_nested_functions_isolation(self):