diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index 9c1d5a38707..ec780a7c0a1 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -33,6 +33,7 @@ py_library( "logical_expressions.py", "return_statements.py", "slices.py", + "variables.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -213,3 +214,16 @@ py_test( "//tensorflow/python/autograph/pyct", ], ) + +py_test( + name = "variables_test", + srcs = ["variables_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + ":converters", + "//tensorflow/python:client_testlib", + "//tensorflow/python/autograph/core:test_lib", + "//tensorflow/python/autograph/pyct", + ], +) diff --git a/tensorflow/python/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py index fd31cd15a0e..dc435cbc90e 100644 --- a/tensorflow/python/autograph/converters/asserts_test.py +++ b/tensorflow/python/autograph/converters/asserts_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.autograph.converters import asserts from tensorflow.python.autograph.converters import functions +from tensorflow.python.autograph.converters import return_statements from tensorflow.python.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl @@ -36,7 +37,8 @@ class AssertsTest(converter_testing.TestCase): return a with ops.Graph().as_default(): - with self.converted(test_fn, (functions, asserts), {}) as result: + with self.converted( + test_fn, (functions, asserts, return_statements), {}) as result: op = result.test_fn(constant_op.constant(False)) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'): diff --git a/tensorflow/python/autograph/converters/functions.py b/tensorflow/python/autograph/converters/functions.py index fc33dafb63d..26ead131f9b 100644 --- a/tensorflow/python/autograph/converters/functions.py +++ b/tensorflow/python/autograph/converters/functions.py @@ -38,15 +38,6 @@ class _Function(object): class FunctionTransformer(converter.Base): """Wraps function bodies around autograph-specific boilerplate.""" - def visit_Return(self, node): - if node.value is None: - return node - node = self.generic_visit(node) - return templates.replace( - 'return function_context_name.mark_return_value(value)', - function_context_name=self.state[_Function].context_name, - value=node.value) - def _function_scope_options(self, fn_scope): """Returns the options with which to create function scopes.""" # Top-level function receive the options that were directly requested. diff --git a/tensorflow/python/autograph/converters/functions_test.py b/tensorflow/python/autograph/converters/functions_test.py index aad455e67d7..2a51ef71ebf 100644 --- a/tensorflow/python/autograph/converters/functions_test.py +++ b/tensorflow/python/autograph/converters/functions_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.autograph.converters import functions +from tensorflow.python.autograph.converters import return_statements from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter_testing @@ -74,7 +75,7 @@ class FunctionTransformer(converter_testing.TestCase): l += 1 return l, inner_fn(l) - with self.converted(test_fn, functions, {}, + with self.converted(test_fn, (functions, return_statements), {}, (ops.name_scope,)) as result: first, second = result.test_fn(constant_op.constant(1)) self.assertIn('test_fn/', first.op.name) @@ -119,6 +120,7 @@ class FunctionTransformer(converter_testing.TestCase): ns = {'TestClass': TestClass} node, ctx = self.prepare(TestClass, ns) node = functions.transform(node, ctx) + node = return_statements.transform(node, ctx) with self.compiled(node, {}, (ops.name_scope,)) as result: first, second = result.TestClass().test_fn(constant_op.constant(1)) diff --git a/tensorflow/python/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py index 39bac60fb91..e4062e42db7 100644 --- a/tensorflow/python/autograph/converters/return_statements.py +++ b/tensorflow/python/autograph/converters/return_statements.py @@ -220,9 +220,9 @@ class ReturnStatementsTransformer(converter.Base): retval = val """ - def __init__(self, ctx, default_to_null_return): + def __init__(self, ctx, allow_missing_return): super(ReturnStatementsTransformer, self).__init__(ctx) - self.default_to_null_return = default_to_null_return + self.allow_missing_return = allow_missing_return def visit_Return(self, node): for block in reversed(self.state[_Block].stack): @@ -339,75 +339,68 @@ class ReturnStatementsTransformer(converter.Base): return node def visit_FunctionDef(self, node): - self.state[_Function].enter() - self.state[_Block].enter() - self.state[_Block].is_function = True + with self.state[_Function] as fn: + with self.state[_Block] as block: + block.is_function = True - scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - do_return_var_name = self.ctx.namer.new_symbol( - 'do_return', scope.referenced) - retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced) - self.state[_Function].do_return_var_name = do_return_var_name - self.state[_Function].retval_var_name = retval_var_name + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + do_return_var_name = self.ctx.namer.new_symbol('do_return', + scope.referenced) + retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced) + fn.do_return_var_name = do_return_var_name + fn.retval_var_name = retval_var_name - converted_body = self._visit_statement_block(node, node.body) + node.body = self._visit_statement_block(node, node.body) - # Avoid placing statements before any eventual docstring. - # TODO(mdan): Should a docstring even be included in the output? - docstring = None - if converted_body: - if (isinstance(converted_body[0], gast.Expr) and - isinstance(converted_body[0].value, gast.Constant)): - docstring = converted_body[0] - converted_body = converted_body[1:] + if block.return_used: - if self.state[_Block].return_used: + if self.allow_missing_return: + # The function whould have a single `with` node that wraps the + # entire body. If the function had a docstring, the body has two + # nodes, with the `with` as the second node. + wrapper_node = node.body[-1] + assert isinstance(wrapper_node, gast.With), ( + 'This transformer requires the functions converter.') - if self.default_to_null_return: - # TODO(mdan): Remove the (do_return_var_name,) below. - # Currently, that line ensures the variable is both defined and alive - # throughout the function. - template = """ - do_return_var_name = False - retval_var_name = ag__.UndefinedReturnValue() - body - (do_return_var_name,) - return ag__.retval(retval_var_name) - """ - else: - template = """ - body - return retval_var_name - """ - node.body = templates.replace( - template, - body=converted_body, - do_return_var_name=do_return_var_name, - retval_var_name=retval_var_name) + template = """ + do_return_var_name = False + retval_var_name = ag__.UndefinedReturnValue() + body + return function_context.ret(retval_var_name, do_return_var_name) + """ - if docstring: - node.body.insert(0, docstring) + wrapper_node.body = templates.replace( + template, + body=wrapper_node.body, + do_return_var_name=do_return_var_name, + function_context=anno.getanno(node, 'function_context_name'), + retval_var_name=retval_var_name) + else: + template = """ + body + return retval_var_name + """ + node.body = templates.replace( + template, + body=node.body, + do_return_var_name=do_return_var_name, + retval_var_name=retval_var_name) - self.state[_Block].exit() - self.state[_Function].exit() return node def transform(node, ctx, default_to_null_return=True): - """Ensure a function has only a single return.""" - # Note: Technically, these two could be merged into a single walk, but - # keeping them separate helps with readability. - + """Ensure a function has only a single return, at the end.""" node = qual_names.resolve(node) node = activity.resolve(node, ctx, None) + # Note: Technically, these two could be merged into a single walk, but + # keeping them separate helps with readability. node = ConditionalReturnRewriter(ctx).visit(node) node = qual_names.resolve(node) node = activity.resolve(node, ctx, None) - transformer = ReturnStatementsTransformer( - ctx, default_to_null_return=default_to_null_return) + ctx, allow_missing_return=default_to_null_return) node = transformer.visit(node) - return node diff --git a/tensorflow/python/autograph/converters/return_statements_test.py b/tensorflow/python/autograph/converters/return_statements_test.py index df687927638..3f1e6a0bd97 100644 --- a/tensorflow/python/autograph/converters/return_statements_test.py +++ b/tensorflow/python/autograph/converters/return_statements_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.autograph.converters import functions from tensorflow.python.autograph.converters import return_statements from tensorflow.python.autograph.core import converter_testing from tensorflow.python.framework import ops @@ -28,7 +29,7 @@ class SingleReturnTest(converter_testing.TestCase): def assertTransformedEquivalent(self, test_fn, *inputs): ns = {'ops': ops} - with self.converted(test_fn, return_statements, ns) as result: + with self.converted(test_fn, (functions, return_statements), ns) as result: self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) def test_straightline(self): diff --git a/tensorflow/python/autograph/converters/variables.py b/tensorflow/python/autograph/converters/variables.py new file mode 100644 index 00000000000..3028a65a69b --- /dev/null +++ b/tensorflow/python/autograph/converters/variables.py @@ -0,0 +1,76 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Overloads all variable read operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import templates + + +class VariableAccessTransformer(converter.Base): + """Rewrites basic symbol reads. + + This transformer rewrites variable reads with a "read" operator which allows + tracking activity. + + Example: + + For a basic statement: + + a = b + c + + This is translated to: + + a = ld(b) + ld(c) + + Augmented assignment operations also introduce a `ld` operator: + + a += b + + The assignment target also receives an operator to properly represent the + read: + + a = ld(a) + a += ld(b) + """ + + def visit_Name(self, node): + # Only the loads which existed in the original code are overloaded. + if not anno.hasanno(node, anno.Static.ORIG_DEFINITIONS): + return node + if isinstance(node.ctx, gast.Load): + node = templates.replace_as_expression('ag__.ld(var_)', var_=node) + return node + + def visit_AugAssign(self, node): + if isinstance(node.target, gast.Name): + template = """ + var_ = ag__.ld(var_) + original + """ + node = templates.replace(template, var_=node.target, original=node) + else: + node = self.generic_visit(node) + return node + + +def transform(node, ctx): + return VariableAccessTransformer(ctx).visit(node) diff --git a/tensorflow/python/autograph/converters/variables_test.py b/tensorflow/python/autograph/converters/variables_test.py new file mode 100644 index 00000000000..556dafbaa8a --- /dev/null +++ b/tensorflow/python/autograph/converters/variables_test.py @@ -0,0 +1,116 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for variables module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib + +from tensorflow.python.autograph.converters import variables +from tensorflow.python.autograph.core import converter_testing +from tensorflow.python.platform import test + + +class VariablesTest(converter_testing.TestCase): + + @contextlib.contextmanager + def apply_add_one_conversion(self, fn): + """Generates code which adds 1 to all variable reads.""" + with self.converted(fn, variables, {}) as result: + result.ag__.__dict__['ld'] = lambda x: x + 1 + yield result + + def test_read(self): + + def test_fn(l): + return l + + with self.apply_add_one_conversion(test_fn) as result: + self.assertEqual(result.test_fn(1), 2) + + def test_aug_assign(self): + + def test_fn(l): + l *= 10 + return l + + with self.apply_add_one_conversion(test_fn) as result: + self.assertEqual(result.test_fn(1), (1 + 1) * 10 + 1) # two reads + + def test_attribute(self): + + class TestClass(object): + + def __init__(self): + self.v = 1 + + def __add__(self, other): + self.v += other + return self + + def test_fn(l): + return l.v + + tc = TestClass() + with self.apply_add_one_conversion(test_fn) as result: + self.assertEqual(result.test_fn(tc), 2) + + def test_subscript(self): + + class TestClass(object): + + def __init__(self): + self.v = 1 + + def __add__(self, other): + self.v += other + return self + + def __getitem__(self, _): + return self.v + + def test_fn(l): + return l[0] + + tc = TestClass() + with self.apply_add_one_conversion(test_fn) as result: + self.assertEqual(result.test_fn(tc), 2) + + def test_call(self): + + class TestClass(object): + + def __init__(self): + self.v = 1 + + def __add__(self, other): + self.v += other + return self + + def __call__(self): + return self.v + + def test_fn(l): + return l() + + tc = TestClass() + with self.apply_add_one_conversion(test_fn) as result: + self.assertEqual(result.test_fn(tc), 2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD index 655dc118a37..4a5c50dac55 100644 --- a/tensorflow/python/autograph/core/BUILD +++ b/tensorflow/python/autograph/core/BUILD @@ -30,6 +30,7 @@ py_library( visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/python:framework_ops", + "//tensorflow/python/autograph/operators", "//tensorflow/python/autograph/pyct", "//tensorflow/python/autograph/pyct/static_analysis", "//tensorflow/python/autograph/utils", diff --git a/tensorflow/python/autograph/core/function_wrappers.py b/tensorflow/python/autograph/core/function_wrappers.py index cc0e7b98de5..d425f8b679d 100644 --- a/tensorflow/python/autograph/core/function_wrappers.py +++ b/tensorflow/python/autograph/core/function_wrappers.py @@ -20,12 +20,16 @@ from __future__ import print_function from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.operators import variables from tensorflow.python.framework import auto_control_deps from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.util import nest +# TODO(mdan): Move this into operators - it represents a function definition. + + class FunctionScope(object): """Context manager that wraps the body of a converted function. @@ -84,8 +88,13 @@ class FunctionScope(object): if self.use_auto_deps: self.autodeps_scope.__exit__(exc_type, exc_val, exc_tb) - def mark_return_value(self, value): + def ret(self, value, did_return): """Marks a value as returned from the function guarded by the scope.""" + del did_return + + if isinstance(value, variables.UndefinedReturnValue): + return None + if self.use_auto_deps: self._return_value_marked = True if value is None: diff --git a/tensorflow/python/autograph/core/function_wrappers_test.py b/tensorflow/python/autograph/core/function_wrappers_test.py index 917a5358633..344ba495570 100644 --- a/tensorflow/python/autograph/core/function_wrappers_test.py +++ b/tensorflow/python/autograph/core/function_wrappers_test.py @@ -46,7 +46,7 @@ class FunctionWrappersTest(test.TestCase): converter.ConversionOptions( optional_features=converter.Feature.AUTO_CONTROL_DEPS)) as scope: v.assign(2) - op = scope.mark_return_value(constant_op.constant(1)) + op = scope.ret(constant_op.constant(1), True) self.evaluate(op) self.assertEqual(self.evaluate(v.read_value()), 2) diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index 7a7efe3d43a..eeea0aef896 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -36,6 +36,7 @@ from tensorflow.python.autograph.converters import lists from tensorflow.python.autograph.converters import logical_expressions from tensorflow.python.autograph.converters import return_statements from tensorflow.python.autograph.converters import slices +from tensorflow.python.autograph.converters import variables from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import function_wrappers @@ -92,6 +93,7 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler): node = control_flow.transform(node, ctx) node = conditional_expressions.transform(node, ctx) node = logical_expressions.transform(node, ctx) + node = variables.transform(node, ctx) return node diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD index 6db9e4f8e3b..3851c7b44ba 100644 --- a/tensorflow/python/autograph/operators/BUILD +++ b/tensorflow/python/autograph/operators/BUILD @@ -29,8 +29,7 @@ py_library( "logical.py", "py_builtins.py", "slices.py", - "special_values.py", - "symbols.py", + "variables.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -148,19 +147,8 @@ py_test( ) py_test( - name = "special_values_test", - srcs = ["special_values_test.py"], - python_version = "PY3", - srcs_version = "PY2AND3", - deps = [ - ":operators", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "symbols_test", - srcs = ["symbols_test.py"], + name = "variables_test", + srcs = ["variables_test.py"], python_version = "PY3", srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py index 495b6070aae..f7f9078107c 100644 --- a/tensorflow/python/autograph/operators/__init__.py +++ b/tensorflow/python/autograph/operators/__init__.py @@ -60,8 +60,6 @@ from tensorflow.python.autograph.operators.py_builtins import range_ from tensorflow.python.autograph.operators.slices import get_item from tensorflow.python.autograph.operators.slices import GetItemOpts from tensorflow.python.autograph.operators.slices import set_item -from tensorflow.python.autograph.operators.special_values import is_undefined -from tensorflow.python.autograph.operators.special_values import is_undefined_return -from tensorflow.python.autograph.operators.special_values import retval -from tensorflow.python.autograph.operators.special_values import Undefined -from tensorflow.python.autograph.operators.special_values import UndefinedReturnValue +from tensorflow.python.autograph.operators.variables import ld +from tensorflow.python.autograph.operators.variables import Undefined +from tensorflow.python.autograph.operators.variables import UndefinedReturnValue diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index 48b7971ec16..592281b0ce2 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -65,7 +65,7 @@ import traceback import numpy as np from tensorflow.python.autograph.operators import py_builtins -from tensorflow.python.autograph.operators import special_values +from tensorflow.python.autograph.operators import variables from tensorflow.python.autograph.utils import ag_logging from tensorflow.python.autograph.utils import compat_util from tensorflow.python.autograph.utils import misc @@ -103,13 +103,13 @@ def _verify_loop_init_vars(values, symbol_names): for name, value in zip(symbol_names, values): if value is None: raise ValueError('"{}" may not be None before the loop.'.format(name)) - if special_values.is_undefined_return(value): + if isinstance(value, variables.UndefinedReturnValue): # Assumption: the loop will only capture the variable which tracks the # return value if the loop contained a return statement. # TODO(mdan): This should be checked at the place where return occurs. raise ValueError( 'return statements are not supported within a TensorFlow loop.') - if special_values.is_undefined(value): + if isinstance(value, variables.Undefined): raise ValueError('"{}" must be defined before the loop.'.format(name)) @@ -495,8 +495,7 @@ def _tf_range_for_stmt( iterate = compat_util.BasicRef(start) def _value_or(name, var, default): - if (name == opts['iterate_names'] - and isinstance(var, special_values.Undefined)): + if (name == opts['iterate_names'] and isinstance(var, variables.Undefined)): return default return var @@ -1019,7 +1018,15 @@ def _wrap_disallow_undefs_from_cond(func, branch_name): results_tuple = results else: results_tuple = results, - undefined = tuple(filter(special_values.is_undefined, results_tuple)) + + for result in results_tuple: + if isinstance(result, variables.UndefinedReturnValue): + raise ValueError( + 'A value must also be returned from the {} branch. If a value is ' + 'returned from one branch of a conditional a value must be ' + 'returned from all branches.'.format(branch_name)) + + undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)] if undefined: raise ValueError( 'The following symbols must also be initialized in the {} branch: {}.' @@ -1027,13 +1034,6 @@ def _wrap_disallow_undefs_from_cond(func, branch_name): ' statement.'.format(branch_name, tuple(s.symbol_name for s in undefined))) - for result in results_tuple: - if special_values.is_undefined_return(result): - raise ValueError( - 'A value must also be returned from the {} branch. If a value is ' - 'returned from one branch of a conditional a value must be ' - 'returned from all branches.'.format(branch_name)) - return results return wrapper diff --git a/tensorflow/python/autograph/operators/control_flow_deprecated_py2.py b/tensorflow/python/autograph/operators/control_flow_deprecated_py2.py index e01a2f206c8..5a900fb19ed 100644 --- a/tensorflow/python/autograph/operators/control_flow_deprecated_py2.py +++ b/tensorflow/python/autograph/operators/control_flow_deprecated_py2.py @@ -66,7 +66,7 @@ import functools import numpy as np from tensorflow.python.autograph.operators import py_builtins -from tensorflow.python.autograph.operators import special_values +from tensorflow.python.autograph.operators import variables from tensorflow.python.autograph.utils import ag_logging from tensorflow.python.autograph.utils import misc from tensorflow.python.autograph.utils import tensors @@ -103,13 +103,13 @@ INEFFICIENT_UNROLL_MIN_OPS = 1 def _disallow_undefs_into_loop(*values): """Ensures that all values in the state are defined when entering a loop.""" - undefined = tuple(filter(special_values.is_undefined, values)) + undefined = [v for v in values if isinstance(v, variables.Undefined)] if undefined: raise ValueError( '{} must be defined before the loop.'.format( ','.join(s.symbol_name for s in undefined))) for value in values: - if special_values.is_undefined_return(value): + if isinstance(value, variables.UndefinedReturnValue): # Assumption: the loop will only capture the variable which tracks the # return value if the loop contained a return statement. # TODO(mdan): This should be checked at the place where return occurs. @@ -1129,7 +1129,7 @@ def _wrap_disallow_undefs_from_cond(func, branch_name): results_tuple = results else: results_tuple = results, - undefined = tuple(filter(special_values.is_undefined, results_tuple)) + undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)] if undefined: raise ValueError( 'The following symbols must also be initialized in the {} branch: {}.' @@ -1138,7 +1138,7 @@ def _wrap_disallow_undefs_from_cond(func, branch_name): tuple(s.symbol_name for s in undefined))) for result in results_tuple: - if special_values.is_undefined_return(result): + if isinstance(result, variables.UndefinedReturnValue): raise ValueError( 'A value must also be returned from the {} branch. If a value is ' 'returned from one branch of a conditional a value must be ' diff --git a/tensorflow/python/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py index 5f0a9d09bf3..1c4407904b2 100644 --- a/tensorflow/python/autograph/operators/control_flow_test.py +++ b/tensorflow/python/autograph/operators/control_flow_test.py @@ -29,7 +29,7 @@ import numpy as np import six from tensorflow.python.autograph.operators import control_flow -from tensorflow.python.autograph.operators import special_values +from tensorflow.python.autograph.operators import variables as variable_operators from tensorflow.python.autograph.utils import ag_logging from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import def_function @@ -546,7 +546,7 @@ class ForLoopTest(test.TestCase): with self.assertRaisesRegex(ValueError, '"s" may not be None'): self._basic_loop(None, lambda i, s: s) with self.assertRaisesRegex(ValueError, '"s" must be defined'): - self._basic_loop(special_values.Undefined(''), lambda i, s: s) + self._basic_loop(variable_operators.Undefined(''), lambda i, s: s) def test_tensor_none_output(self): with self.assertRaisesRegex(ValueError, '"s" is None at the end'): @@ -785,7 +785,7 @@ class WhileLoopTest(test.TestCase): with self.assertRaisesRegex(ValueError, '"s" may not be None'): self._basic_loop(None, lambda i, s: s) with self.assertRaisesRegex(ValueError, '"s" must be defined'): - self._basic_loop(special_values.Undefined(''), lambda i, s: s) + self._basic_loop(variable_operators.Undefined(''), lambda i, s: s) def test_tensor_none_output(self): with self.assertRaisesRegex(ValueError, '"s" is None at the end'): @@ -887,10 +887,10 @@ class IfStmtTest(test.TestCase): def test_tensor_undefined_output(self): with self.assertRaisesRegex( ValueError, "must also be initialized in the if.*'s'"): - self._basic_cond(lambda: special_values.Undefined('s'), lambda: 1) + self._basic_cond(lambda: variable_operators.Undefined('s'), lambda: 1) with self.assertRaisesRegex( ValueError, "must also be initialized in the else.*'s'"): - self._basic_cond(lambda: 1, lambda: special_values.Undefined('s')) + self._basic_cond(lambda: 1, lambda: variable_operators.Undefined('s')) def test_tensor_dtype_change(self): with self.assertRaisesRegex(TypeError, '"s" has dtype int32.*but.*float32'): diff --git a/tensorflow/python/autograph/operators/symbols.py b/tensorflow/python/autograph/operators/symbols.py deleted file mode 100644 index 0dd7e0a5956..00000000000 --- a/tensorflow/python/autograph/operators/symbols.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Abstract representation of composite symbols that can be used in staging code. - -This provides a way to checkpoint the values of symbols that may be undefined -entering staged control flow. This checkpointing is necessary to prevent some -unintended side-effects. For example checkpointing prevents side-effects in one -branch of a conditional from leaking into another. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.autograph.operators import special_values - - -is_undefined = special_values.is_undefined -Undefined = special_values.Undefined - - -class Symbol(object): - """Representation of a simple or composite Python symbol. - - Subclasses should implement `maybe_compute_value(self)` that returns the value - corresponding to the symbol or Undefined if no such value exists. - """ - - def __init__(self, name): - self.name = name - - -class ValueSymbol(Symbol): - """Representation of a simple Python symbol with a concrete value. - - This includes variables and literals. Since we are reifying undefined symbols - `Undefined` is also a valid value. - """ - - def __init__(self, name, value): - super(ValueSymbol, self).__init__(name) - self.value = value - - def maybe_compute_value(self): - return self.value - - -class AttributeAccessSymbol(Symbol): - """Representation of Python attribute access e.g. `a.b`.""" - - def __init__(self, parent_symbol, attr_name): - super(AttributeAccessSymbol, self).__init__( - parent_symbol.name + '.' + attr_name) - self.attr_name = attr_name - self.parent_symbol = parent_symbol - - def maybe_compute_value(self): - """Compute the value corresponding to the attribute access or `Undefined`. - - This will be `Undefined` if no such value exists either because there is no - such attribute or if the base is itself undefined. - - Returns: - value corresponding to the attribute access or `Undefined` - """ - parent_value = self.parent_symbol.maybe_compute_value() - if (is_undefined(parent_value) or - getattr(parent_value, self.attr_name, None) is None): - return Undefined(self.name) - - return parent_value.__getattribute__(self.attr_name) - - -class SubscriptSymbol(Symbol): - """Representation of Python subscript access e.g. `a[b]`.""" - - def __init__(self, parent_symbol, index_symbol): - super(SubscriptSymbol, self).__init__( - parent_symbol.name + '[' + index_symbol.name + ']') - self.index_symbol = index_symbol - self.parent_symbol = parent_symbol - - def maybe_compute_value(self): - """Compute the value corresponding to the subscript access or `Undefined`. - - This will be `Undefined` if no such value exists either because there is no - element corresponding to the given subscript or if the base itself is - not defined. - - Returns: - value corresponding to the subscript access or `Undefined` - """ - parent_value = self.parent_symbol.maybe_compute_value() - index_value = self.index_symbol.maybe_compute_value() - if is_undefined(parent_value) or is_undefined(index_value): - return Undefined(self.name) - - try: - return parent_value[index_value] - except (IndexError, KeyError, TypeError): - # Reify the lack of an object for the given index/key - # This allows us to define them later without regret - return Undefined(self.name) diff --git a/tensorflow/python/autograph/operators/symbols_test.py b/tensorflow/python/autograph/operators/symbols_test.py deleted file mode 100644 index 3acb16273bd..00000000000 --- a/tensorflow/python/autograph/operators/symbols_test.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for special symbol handling.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.autograph.operators import special_values -from tensorflow.python.autograph.operators import symbols -from tensorflow.python.platform import test - -Undefined = special_values.Undefined -AttributeAccessSymbol = symbols.AttributeAccessSymbol -SubscriptSymbol = symbols.SubscriptSymbol -ValueSymbol = symbols.ValueSymbol - - -class SymbolsTest(test.TestCase): - - def test_value_symbol_returns_value(self): - a = 42 - a_symbol = ValueSymbol('a', a) - self.assertEqual(a_symbol.maybe_compute_value(), a) - self.assertEqual(a_symbol.name, 'a') - - def test_attribute_access_missing_attribute(self): - class Foo(object): - pass - a = Foo() - - a_symbol = ValueSymbol('a', a) - a_b_symbol = AttributeAccessSymbol(a_symbol, 'b') - - self.assertEqual(a_symbol.maybe_compute_value(), a) - self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined) - self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a.b') - - def test_attribute_access_undefined_target(self): - a = Undefined('a') - a_symbol = ValueSymbol('a', a) - a_b_symbol = AttributeAccessSymbol(a_symbol, 'b') - - self.assertEqual(a_symbol.maybe_compute_value(), a) - self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined) - self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a.b') - - def test_attribute_access_basic(self): - class Foo(object): - - def __init__(self): - self.b = 'this is an attribute' - - a = Foo() - a_symbol = ValueSymbol('a', a) - a_b_symbol = AttributeAccessSymbol(a_symbol, 'b') - - self.assertEqual(a_symbol.maybe_compute_value(), a) - self.assertEqual(a_b_symbol.maybe_compute_value(), a.b) - - def test_item_access_undefined_index(self): - class Foo(object): - - def __getitem__(self, key): - return 'this is an item' - - a = Foo() - b = Undefined('b') - a_symbol = ValueSymbol('a', a) - b_symbol = ValueSymbol('b', b) - a_b_symbol = SubscriptSymbol(a_symbol, b_symbol) - - self.assertEqual(a_symbol.maybe_compute_value(), a) - self.assertEqual(b_symbol.maybe_compute_value(), b) - self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined) - self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]') - - def test_item_access_no_getitem(self): - class Foo(object): - pass - - a = Foo() - b = 42 - a_symbol = ValueSymbol('a', a) - b_symbol = ValueSymbol('b', b) - a_b_symbol = SubscriptSymbol(a_symbol, b_symbol) - - self.assertEqual(a_symbol.maybe_compute_value(), a) - self.assertEqual(b_symbol.maybe_compute_value(), b) - self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined) - self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]') - - def test_item_access_undefined_root(self): - a = Undefined('a') - b = 42 - a_symbol = ValueSymbol('a', a) - b_symbol = ValueSymbol('b', b) - a_b_symbol = SubscriptSymbol(a_symbol, b_symbol) - - self.assertEqual(a_symbol.maybe_compute_value(), a) - self.assertEqual(b_symbol.maybe_compute_value(), b) - self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined) - self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]') - - def test_item_access_basic(self): - class Foo(object): - - def __getitem__(self, key): - return 'this is an item' - - a = Foo() - b = 42 - a_symbol = ValueSymbol('a', a) - b_symbol = ValueSymbol('b', b) - a_b_symbol = SubscriptSymbol(a_symbol, b_symbol) - - self.assertEqual(a_symbol.maybe_compute_value(), a) - self.assertEqual(b_symbol.maybe_compute_value(), b) - self.assertEqual(a_b_symbol.maybe_compute_value(), a[b]) - - def test_item_access_after_attribute_access(self): - class Foo(object): - - def __getitem__(self, key): - return 'this is an item' - - class Bar(object): - - def __init__(self): - self.b = Foo() - - a = Bar() - c = 42 - a_symbol = ValueSymbol('a', a) - c_symbol = ValueSymbol('c', c) - a_b_symbol = AttributeAccessSymbol(a_symbol, 'b') - a_b_c_symbol = SubscriptSymbol(a_b_symbol, c_symbol) - - self.assertEqual(a_symbol.maybe_compute_value(), a) - self.assertEqual(c_symbol.maybe_compute_value(), c) - self.assertEqual(a_b_symbol.maybe_compute_value(), a.b) - self.assertEqual(a_b_c_symbol.maybe_compute_value(), a.b[c]) - - def test_attribute_access_after_item_access(self): - class Bar(object): - - def __init__(self): - self.c = object() - - item = Bar() - - class Foo(object): - - def __getitem__(self, key): - return item - - a = Foo() - b = 42 - a_symbol = ValueSymbol('a', a) - b_symbol = ValueSymbol('b', b) - a_b_symbol = SubscriptSymbol(a_symbol, b_symbol) - a_b_c_symbol = AttributeAccessSymbol(a_b_symbol, 'c') - - self.assertEqual(a_symbol.maybe_compute_value(), a) - self.assertEqual(b_symbol.maybe_compute_value(), b) - self.assertEqual(a_b_symbol.maybe_compute_value(), a[b]) - self.assertEqual(a_b_c_symbol.maybe_compute_value(), a[b].c) - - def test_item_access_after_item_access(self): - class Bar(object): - - def __getitem__(self, key): - return 'this is an item' - - item = Bar() - - class Foo(object): - - def __getitem__(self, key): - return item - - a = Foo() - b = 42 - c = 43 - a_symbol = ValueSymbol('a', a) - b_symbol = ValueSymbol('b', b) - c_symbol = ValueSymbol('b', c) - a_b_symbol = SubscriptSymbol(a_symbol, b_symbol) - a_b_c_symbol = SubscriptSymbol(a_b_symbol, c_symbol) - - self.assertEqual(a_symbol.maybe_compute_value(), a) - self.assertEqual(b_symbol.maybe_compute_value(), b) - self.assertEqual(a_b_symbol.maybe_compute_value(), a[b]) - self.assertEqual(a_b_c_symbol.maybe_compute_value(), a[b][c]) - - def test_attribute_access_after_attribute_access(self): - class Bar(object): - - def __init__(self): - self.c = object() - - class Foo(object): - - def __init__(self): - self.b = Bar() - - a = Foo() - a_symbol = ValueSymbol('a', a) - a_b_symbol = AttributeAccessSymbol(a_symbol, 'b') - a_b_c_symbol = AttributeAccessSymbol(a_b_symbol, 'c') - - self.assertEqual(a_symbol.maybe_compute_value(), a) - self.assertEqual(a_b_symbol.maybe_compute_value(), a.b) - self.assertEqual(a_b_c_symbol.maybe_compute_value(), a.b.c) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/autograph/operators/special_values.py b/tensorflow/python/autograph/operators/variables.py similarity index 72% rename from tensorflow/python/autograph/operators/special_values.py rename to tensorflow/python/autograph/operators/variables.py index c172cce23f1..150f64e1758 100644 --- a/tensorflow/python/autograph/operators/special_values.py +++ b/tensorflow/python/autograph/operators/variables.py @@ -19,6 +19,13 @@ from __future__ import division from __future__ import print_function +def ld(v): + """Load variable operator.""" + if isinstance(v, Undefined): + return v.read() + return v + + class Undefined(object): """Represents an undefined symbol in Python. @@ -51,6 +58,10 @@ class Undefined(object): def __init__(self, symbol_name): self.symbol_name = symbol_name + def read(self): + raise UnboundLocalError("'{}' is used before assignment".format( + self.symbol_name)) + def __repr__(self): return self.symbol_name @@ -66,34 +77,7 @@ class Undefined(object): return self -def is_undefined(value): - """Checks whether Autograph has determined that a given value is undefined. - - This only works in places where Autograph reifies undefined symbols. Note that - if this function is passed a truly undefined symbol the call-site will raise - NameError. - - Args: - value: value to test for undefinedness - Returns: - Boolean, whether the input value is undefined. - """ - return isinstance(value, Undefined) - - # TODO(mdan): Refactor as a RetVal object, aggregating the value and do_return. class UndefinedReturnValue(object): - """Represents a default return value from a function (None in Python).""" + """Represents a return value that is undefined.""" pass - - -def retval(value): - """Returns the actual value that a return statement should produce.""" - if isinstance(value, UndefinedReturnValue): - return None - return value - - -def is_undefined_return(value): - """Checks whether `value` is the default return value.""" - return isinstance(value, UndefinedReturnValue) diff --git a/tensorflow/python/autograph/operators/special_values_test.py b/tensorflow/python/autograph/operators/variables_test.py similarity index 58% rename from tensorflow/python/autograph/operators/special_values_test.py rename to tensorflow/python/autograph/operators/variables_test.py index 1742cc4277d..168e6172232 100644 --- a/tensorflow/python/autograph/operators/special_values_test.py +++ b/tensorflow/python/autograph/operators/variables_test.py @@ -18,28 +18,38 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.autograph.operators import special_values +from tensorflow.python.autograph.operators import variables from tensorflow.python.platform import test class SpecialValuesTest(test.TestCase): def test_undefined(self): - undefined_symbol = special_values.Undefined('name') - self.assertEqual(undefined_symbol.symbol_name, 'name') + undefined_symbol = variables.Undefined('name') + undefined_symbol2 = variables.Undefined('name') - undefined_symbol2 = special_values.Undefined('name') + self.assertEqual(undefined_symbol.symbol_name, 'name') + self.assertEqual(undefined_symbol2.symbol_name, 'name') self.assertNotEqual(undefined_symbol, undefined_symbol2) - self.assertTrue(special_values.is_undefined(undefined_symbol)) - self.assertTrue(special_values.is_undefined(undefined_symbol2)) - def test_undefined_operations(self): - undefined_symbol = special_values.Undefined('name') + undefined_symbol = variables.Undefined('name') + + self.assertIsInstance(undefined_symbol.foo, variables.Undefined) + self.assertIsInstance(undefined_symbol[0], variables.Undefined) + self.assertNotIsInstance(undefined_symbol.__class__, variables.Undefined) + + def test_read(self): + self.assertEqual(variables.ld(1), 1) + o = object() + self.assertEqual(variables.ld(o), o) + + self.assertIsNone(variables.ld(None)) + + def test_read_undefined(self): + with self.assertRaisesRegex(UnboundLocalError, 'used before assignment'): + variables.ld(variables.Undefined('a')) - self.assertTrue(special_values.is_undefined(undefined_symbol.foo)) - self.assertTrue(special_values.is_undefined(undefined_symbol[0])) - self.assertFalse(special_values.is_undefined(undefined_symbol.__class__)) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_py3_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_py3_test.py index 7333ec0c872..ba27280f729 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_py3_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_py3_test.py @@ -78,6 +78,18 @@ class ReachingDefinitionsAnalyzerTest( self.assertSameDef(local_body[1].test, local_body[2].value.elts[0]) + # Note: the function name is is visible inside the function body. But it's + # a closure variable, not a local. + # + # Example: + # + # >>> def f(): + # ... print(f) + # >>> g = f + # >>> f = 'something else' + # >>> g() + # something else + # self.assertHasDefinedIn(local_body[1], ('a', 'b')) diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py index c4e7cbd4d17..64b00fcbeba 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py @@ -255,6 +255,9 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase): inner_fn_body = fn_body[1].body[1].body def_of_a_in_foo = inner_fn_body[0].value + # Even though `a` is visible in the inner functio above, the late binding + # makes it impossible to assume that the same value will be visible at + # call time. self.assertHasDefs(def_of_a_in_foo, 0) def test_nested_functions_isolation(self):