Check more robustly for undefined symbols before attempting to use them. The check is needed because undefined symbols are initialized with a special placeholder before entering control flow, and that placeholder can lead to confusing error messages if it is left unchecked. The change introduces two more general operators: "variable load" and "return".
PiperOrigin-RevId: 311411422
Change-Id: Ic8abda74c1f68c1d4de491949d309d60099b91b4

parent 18c0da1024
commit 4eeb6d742e
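For intuition, here is a minimal self-contained sketch of the two operators the message refers to. The names `ld`, `Undefined`, and `UndefinedReturnValue` match the ones added in this change, but the snippet is an illustration rather than the TensorFlow implementation:

class Undefined(object):
  """Placeholder for a symbol that has no value yet."""

  def __init__(self, symbol_name):
    self.symbol_name = symbol_name

  def read(self):
    # Reading the placeholder is the error case we want to surface clearly.
    raise UnboundLocalError(
        "'{}' is used before assignment".format(self.symbol_name))


class UndefinedReturnValue(object):
  """Placeholder for the return value of a function that never returned."""


def ld(v):
  """Variable load operator: unwraps placeholders, passes real values through."""
  if isinstance(v, Undefined):
    return v.read()
  return v


def ret(value):
  """Return operator: an undefined return value becomes plain None."""
  if isinstance(value, UndefinedReturnValue):
    return None
  return value


x = Undefined('x')
try:
  ld(x)
except UnboundLocalError as e:
  print(e)                            # 'x' is used before assignment
print(ret(UndefinedReturnValue()))    # None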
@@ -33,6 +33,7 @@ py_library(
         "logical_expressions.py",
         "return_statements.py",
         "slices.py",
+        "variables.py",
     ],
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
@@ -213,3 +214,16 @@ py_test(
         "//tensorflow/python/autograph/pyct",
     ],
 )
+
+py_test(
+    name = "variables_test",
+    srcs = ["variables_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":converters",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
+        "//tensorflow/python/autograph/pyct",
+    ],
+)
@@ -20,6 +20,7 @@ from __future__ import print_function

 from tensorflow.python.autograph.converters import asserts
 from tensorflow.python.autograph.converters import functions
+from tensorflow.python.autograph.converters import return_statements
 from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors_impl
@@ -36,7 +37,8 @@ class AssertsTest(converter_testing.TestCase):
       return a

     with ops.Graph().as_default():
-      with self.converted(test_fn, (functions, asserts), {}) as result:
+      with self.converted(
+          test_fn, (functions, asserts, return_statements), {}) as result:
         op = result.test_fn(constant_op.constant(False))

         with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):
@@ -38,15 +38,6 @@ class _Function(object):
 class FunctionTransformer(converter.Base):
   """Wraps function bodies around autograph-specific boilerplate."""

-  def visit_Return(self, node):
-    if node.value is None:
-      return node
-    node = self.generic_visit(node)
-    return templates.replace(
-        'return function_context_name.mark_return_value(value)',
-        function_context_name=self.state[_Function].context_name,
-        value=node.value)
-
   def _function_scope_options(self, fn_scope):
     """Returns the options with which to create function scopes."""
     # Top-level function receive the options that were directly requested.
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.autograph.converters import functions
+from tensorflow.python.autograph.converters import return_statements
 from tensorflow.python.autograph.core import ag_ctx
 from tensorflow.python.autograph.core import converter
 from tensorflow.python.autograph.core import converter_testing
@@ -74,7 +75,7 @@ class FunctionTransformer(converter_testing.TestCase):
       l += 1
       return l, inner_fn(l)

-    with self.converted(test_fn, functions, {},
+    with self.converted(test_fn, (functions, return_statements), {},
                         (ops.name_scope,)) as result:
       first, second = result.test_fn(constant_op.constant(1))
       self.assertIn('test_fn/', first.op.name)
@@ -119,6 +120,7 @@ class FunctionTransformer(converter_testing.TestCase):
     ns = {'TestClass': TestClass}
     node, ctx = self.prepare(TestClass, ns)
     node = functions.transform(node, ctx)
+    node = return_statements.transform(node, ctx)

     with self.compiled(node, {}, (ops.name_scope,)) as result:
       first, second = result.TestClass().test_fn(constant_op.constant(1))
@@ -220,9 +220,9 @@ class ReturnStatementsTransformer(converter.Base):
         retval = val
   """

-  def __init__(self, ctx, default_to_null_return):
+  def __init__(self, ctx, allow_missing_return):
     super(ReturnStatementsTransformer, self).__init__(ctx)
-    self.default_to_null_return = default_to_null_return
+    self.allow_missing_return = allow_missing_return

   def visit_Return(self, node):
     for block in reversed(self.state[_Block].stack):
@@ -339,41 +339,42 @@ class ReturnStatementsTransformer(converter.Base):
     return node

   def visit_FunctionDef(self, node):
-    self.state[_Function].enter()
-    self.state[_Block].enter()
-    self.state[_Block].is_function = True
+    with self.state[_Function] as fn:
+      with self.state[_Block] as block:
+        block.is_function = True

         scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
-        do_return_var_name = self.ctx.namer.new_symbol(
-            'do_return', scope.referenced)
+        do_return_var_name = self.ctx.namer.new_symbol('do_return',
+                                                       scope.referenced)
         retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced)
-        self.state[_Function].do_return_var_name = do_return_var_name
-        self.state[_Function].retval_var_name = retval_var_name
+        fn.do_return_var_name = do_return_var_name
+        fn.retval_var_name = retval_var_name

-        converted_body = self._visit_statement_block(node, node.body)
+        node.body = self._visit_statement_block(node, node.body)

-        # Avoid placing statements before any eventual docstring.
-        # TODO(mdan): Should a docstring even be included in the output?
-        docstring = None
-        if converted_body:
-          if (isinstance(converted_body[0], gast.Expr) and
-              isinstance(converted_body[0].value, gast.Constant)):
-            docstring = converted_body[0]
-            converted_body = converted_body[1:]
+        if block.return_used:

-        if self.state[_Block].return_used:
+          if self.allow_missing_return:
+            # The function whould have a single `with` node that wraps the
+            # entire body. If the function had a docstring, the body has two
+            # nodes, with the `with` as the second node.
+            wrapper_node = node.body[-1]
+            assert isinstance(wrapper_node, gast.With), (
+                'This transformer requires the functions converter.')

-          if self.default_to_null_return:
-            # TODO(mdan): Remove the (do_return_var_name,) below.
-            # Currently, that line ensures the variable is both defined and alive
-            # throughout the function.
             template = """
               do_return_var_name = False
               retval_var_name = ag__.UndefinedReturnValue()
               body
-              (do_return_var_name,)
-              return ag__.retval(retval_var_name)
+              return function_context.ret(retval_var_name, do_return_var_name)
             """
+            wrapper_node.body = templates.replace(
+                template,
+                body=wrapper_node.body,
+                do_return_var_name=do_return_var_name,
+                function_context=anno.getanno(node, 'function_context_name'),
+                retval_var_name=retval_var_name)
           else:
             template = """
               body
@@ -381,33 +382,25 @@ class ReturnStatementsTransformer(converter.Base):
             """
             node.body = templates.replace(
                 template,
-                body=converted_body,
+                body=node.body,
                 do_return_var_name=do_return_var_name,
                 retval_var_name=retval_var_name)

-    if docstring:
-      node.body.insert(0, docstring)
-
-    self.state[_Block].exit()
-    self.state[_Function].exit()
     return node


 def transform(node, ctx, default_to_null_return=True):
-  """Ensure a function has only a single return."""
-  # Note: Technically, these two could be merged into a single walk, but
-  # keeping them separate helps with readability.
-
+  """Ensure a function has only a single return, at the end."""
   node = qual_names.resolve(node)
   node = activity.resolve(node, ctx, None)

+  # Note: Technically, these two could be merged into a single walk, but
+  # keeping them separate helps with readability.
   node = ConditionalReturnRewriter(ctx).visit(node)

   node = qual_names.resolve(node)
   node = activity.resolve(node, ctx, None)

   transformer = ReturnStatementsTransformer(
       ctx, allow_missing_return=default_to_null_return)
   node = transformer.visit(node)

   return node
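For orientation, a rough, hand-written sketch of the shape this rewrite aims for. It is purely illustrative: `UndefinedReturnValue` mirrors the operator used above, while `_Ctx` is a hypothetical stand-in for the generated function context.

class UndefinedReturnValue(object):
  """Placeholder used until a `return` statement actually executes."""


class _Ctx(object):
  def ret(self, value, did_return):
    # Mirrors FunctionScope.ret from this change: an undefined return value
    # is mapped back to plain None.
    if isinstance(value, UndefinedReturnValue):
      return None
    return value


def f(x):  # original user code
  if x > 0:
    return x * 2


def f_rewritten(x, function_context=_Ctx()):
  # Shape of the rewritten body: a single return at the end, with the
  # return value tracked in dedicated variables.
  do_return = False
  retval = UndefinedReturnValue()
  if x > 0:
    do_return = True
    retval = x * 2
  return function_context.ret(retval, do_return)


assert f_rewritten(3) == f(3) == 6
assert f_rewritten(-1) is None and f(-1) is None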
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.autograph.converters import functions
 from tensorflow.python.autograph.converters import return_statements
 from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.framework import ops
@@ -28,7 +29,7 @@ class SingleReturnTest(converter_testing.TestCase):

   def assertTransformedEquivalent(self, test_fn, *inputs):
     ns = {'ops': ops}
-    with self.converted(test_fn, return_statements, ns) as result:
+    with self.converted(test_fn, (functions, return_statements), ns) as result:
       self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))

   def test_straightline(self):
tensorflow/python/autograph/converters/variables.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Overloads all variable read operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+
+
+class VariableAccessTransformer(converter.Base):
+  """Rewrites basic symbol reads.
+
+  This transformer rewrites variable reads with a "read" operator which allows
+  tracking activity.
+
+  Example:
+
+  For a basic statement:
+
+      a = b + c
+
+  This is translated to:
+
+      a = ld(b) + ld(c)
+
+  Augmented assignment operations also introduce a `ld` operator:
+
+      a += b
+
+  The assignment target also receives an operator to properly represent the
+  read:
+
+      a = ld(a)
+      a += ld(b)
+  """
+
+  def visit_Name(self, node):
+    # Only the loads which existed in the original code are overloaded.
+    if not anno.hasanno(node, anno.Static.ORIG_DEFINITIONS):
+      return node
+    if isinstance(node.ctx, gast.Load):
+      node = templates.replace_as_expression('ag__.ld(var_)', var_=node)
+    return node
+
+  def visit_AugAssign(self, node):
+    if isinstance(node.target, gast.Name):
+      template = """
+        var_ = ag__.ld(var_)
+        original
+      """
+      node = templates.replace(template, var_=node.target, original=node)
+    else:
+      node = self.generic_visit(node)
+    return node
+
+
+def transform(node, ctx):
+  return VariableAccessTransformer(ctx).visit(node)
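To make the docstring's example concrete, here is a hand-written imitation of the rewrite. It is illustrative only: `ld` below is a local stand-in for the `ag__.ld` operator that the generated code references.

reads = []


def ld(v):
  # Stand-in for ag__.ld: pass the value through and record the read.
  reads.append(v)
  return v


# Original user code:   a = b + c
b, c = 2, 3
a = ld(b) + ld(c)          # what the converter emits, conceptually
assert a == 5 and reads == [2, 3]

# Augmented assignment:  a += b
# becomes a read of the target followed by the original statement:
a = ld(a)
a += ld(b)
assert a == 7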
tensorflow/python/autograph/converters/variables_test.py (new file, 116 lines)
@@ -0,0 +1,116 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for variables module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python.autograph.converters import variables
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.platform import test
+
+
+class VariablesTest(converter_testing.TestCase):
+
+  @contextlib.contextmanager
+  def apply_add_one_conversion(self, fn):
+    """Generates code which adds 1 to all variable reads."""
+    with self.converted(fn, variables, {}) as result:
+      result.ag__.__dict__['ld'] = lambda x: x + 1
+      yield result
+
+  def test_read(self):
+
+    def test_fn(l):
+      return l
+
+    with self.apply_add_one_conversion(test_fn) as result:
+      self.assertEqual(result.test_fn(1), 2)
+
+  def test_aug_assign(self):
+
+    def test_fn(l):
+      l *= 10
+      return l
+
+    with self.apply_add_one_conversion(test_fn) as result:
+      self.assertEqual(result.test_fn(1), (1 + 1) * 10 + 1)  # two reads
+
+  def test_attribute(self):
+
+    class TestClass(object):
+
+      def __init__(self):
+        self.v = 1
+
+      def __add__(self, other):
+        self.v += other
+        return self
+
+    def test_fn(l):
+      return l.v
+
+    tc = TestClass()
+    with self.apply_add_one_conversion(test_fn) as result:
+      self.assertEqual(result.test_fn(tc), 2)
+
+  def test_subscript(self):
+
+    class TestClass(object):
+
+      def __init__(self):
+        self.v = 1
+
+      def __add__(self, other):
+        self.v += other
+        return self
+
+      def __getitem__(self, _):
+        return self.v
+
+    def test_fn(l):
+      return l[0]
+
+    tc = TestClass()
+    with self.apply_add_one_conversion(test_fn) as result:
+      self.assertEqual(result.test_fn(tc), 2)
+
+  def test_call(self):
+
+    class TestClass(object):
+
+      def __init__(self):
+        self.v = 1
+
+      def __add__(self, other):
+        self.v += other
+        return self
+
+      def __call__(self):
+        return self.v
+
+    def test_fn(l):
+      return l()
+
+    tc = TestClass()
+    with self.apply_add_one_conversion(test_fn) as result:
+      self.assertEqual(result.test_fn(tc), 2)
+
+
+if __name__ == '__main__':
+  test.main()
@@ -30,6 +30,7 @@ py_library(
     visibility = ["//tensorflow:__subpackages__"],
     deps = [
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python/autograph/operators",
        "//tensorflow/python/autograph/pyct",
        "//tensorflow/python/autograph/pyct/static_analysis",
        "//tensorflow/python/autograph/utils",
@@ -20,12 +20,16 @@ from __future__ import print_function

 from tensorflow.python.autograph.core import ag_ctx
 from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.operators import variables
 from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.util import nest


+# TODO(mdan): Move this into operators - it represents a function definition.
+
+
 class FunctionScope(object):
   """Context manager that wraps the body of a converted function.

@@ -84,8 +88,13 @@ class FunctionScope(object):
     if self.use_auto_deps:
       self.autodeps_scope.__exit__(exc_type, exc_val, exc_tb)

-  def mark_return_value(self, value):
+  def ret(self, value, did_return):
     """Marks a value as returned from the function guarded by the scope."""
+    del did_return
+
+    if isinstance(value, variables.UndefinedReturnValue):
+      return None
+
     if self.use_auto_deps:
       self._return_value_marked = True
       if value is None:
@@ -46,7 +46,7 @@ class FunctionWrappersTest(test.TestCase):
        converter.ConversionOptions(
            optional_features=converter.Feature.AUTO_CONTROL_DEPS)) as scope:
      v.assign(2)
-      op = scope.mark_return_value(constant_op.constant(1))
+      op = scope.ret(constant_op.constant(1), True)
    self.evaluate(op)
    self.assertEqual(self.evaluate(v.read_value()), 2)
@@ -36,6 +36,7 @@ from tensorflow.python.autograph.converters import lists
 from tensorflow.python.autograph.converters import logical_expressions
 from tensorflow.python.autograph.converters import return_statements
 from tensorflow.python.autograph.converters import slices
+from tensorflow.python.autograph.converters import variables
 from tensorflow.python.autograph.core import config
 from tensorflow.python.autograph.core import converter
 from tensorflow.python.autograph.core import function_wrappers
@@ -92,6 +93,7 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler):
     node = control_flow.transform(node, ctx)
     node = conditional_expressions.transform(node, ctx)
     node = logical_expressions.transform(node, ctx)
+    node = variables.transform(node, ctx)
     return node
@@ -29,8 +29,7 @@ py_library(
         "logical.py",
         "py_builtins.py",
         "slices.py",
-        "special_values.py",
-        "symbols.py",
+        "variables.py",
     ],
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
@@ -148,19 +147,8 @@ py_test(
 )

 py_test(
-    name = "special_values_test",
-    srcs = ["special_values_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY2AND3",
-    deps = [
-        ":operators",
-        "//tensorflow/python:client_testlib",
-    ],
-)
-
-py_test(
-    name = "symbols_test",
-    srcs = ["symbols_test.py"],
+    name = "variables_test",
+    srcs = ["variables_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
@@ -60,8 +60,6 @@ from tensorflow.python.autograph.operators.py_builtins import range_
 from tensorflow.python.autograph.operators.slices import get_item
 from tensorflow.python.autograph.operators.slices import GetItemOpts
 from tensorflow.python.autograph.operators.slices import set_item
-from tensorflow.python.autograph.operators.special_values import is_undefined
-from tensorflow.python.autograph.operators.special_values import is_undefined_return
-from tensorflow.python.autograph.operators.special_values import retval
-from tensorflow.python.autograph.operators.special_values import Undefined
-from tensorflow.python.autograph.operators.special_values import UndefinedReturnValue
+from tensorflow.python.autograph.operators.variables import ld
+from tensorflow.python.autograph.operators.variables import Undefined
+from tensorflow.python.autograph.operators.variables import UndefinedReturnValue
@@ -65,7 +65,7 @@ import traceback
 import numpy as np

 from tensorflow.python.autograph.operators import py_builtins
-from tensorflow.python.autograph.operators import special_values
+from tensorflow.python.autograph.operators import variables
 from tensorflow.python.autograph.utils import ag_logging
 from tensorflow.python.autograph.utils import compat_util
 from tensorflow.python.autograph.utils import misc
@@ -103,13 +103,13 @@ def _verify_loop_init_vars(values, symbol_names):
   for name, value in zip(symbol_names, values):
     if value is None:
       raise ValueError('"{}" may not be None before the loop.'.format(name))
-    if special_values.is_undefined_return(value):
+    if isinstance(value, variables.UndefinedReturnValue):
       # Assumption: the loop will only capture the variable which tracks the
       # return value if the loop contained a return statement.
       # TODO(mdan): This should be checked at the place where return occurs.
       raise ValueError(
           'return statements are not supported within a TensorFlow loop.')
-    if special_values.is_undefined(value):
+    if isinstance(value, variables.Undefined):
       raise ValueError('"{}" must be defined before the loop.'.format(name))


@@ -495,8 +495,7 @@ def _tf_range_for_stmt(
   iterate = compat_util.BasicRef(start)

   def _value_or(name, var, default):
-    if (name == opts['iterate_names']
-        and isinstance(var, special_values.Undefined)):
+    if (name == opts['iterate_names'] and isinstance(var, variables.Undefined)):
       return default
     return var

@@ -1019,7 +1018,15 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
       results_tuple = results
     else:
       results_tuple = results,
-    undefined = tuple(filter(special_values.is_undefined, results_tuple))
+
+    for result in results_tuple:
+      if isinstance(result, variables.UndefinedReturnValue):
+        raise ValueError(
+            'A value must also be returned from the {} branch. If a value is '
+            'returned from one branch of a conditional a value must be '
+            'returned from all branches.'.format(branch_name))
+
+    undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)]
     if undefined:
       raise ValueError(
           'The following symbols must also be initialized in the {} branch: {}.'
@@ -1027,13 +1034,6 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
           ' statement.'.format(branch_name,
                                tuple(s.symbol_name for s in undefined)))

-    for result in results_tuple:
-      if special_values.is_undefined_return(result):
-        raise ValueError(
-            'A value must also be returned from the {} branch. If a value is '
-            'returned from one branch of a conditional a value must be '
-            'returned from all branches.'.format(branch_name))
-
     return results

   return wrapper
@@ -66,7 +66,7 @@ import functools
 import numpy as np

 from tensorflow.python.autograph.operators import py_builtins
-from tensorflow.python.autograph.operators import special_values
+from tensorflow.python.autograph.operators import variables
 from tensorflow.python.autograph.utils import ag_logging
 from tensorflow.python.autograph.utils import misc
 from tensorflow.python.autograph.utils import tensors
@@ -103,13 +103,13 @@ INEFFICIENT_UNROLL_MIN_OPS = 1

 def _disallow_undefs_into_loop(*values):
   """Ensures that all values in the state are defined when entering a loop."""
-  undefined = tuple(filter(special_values.is_undefined, values))
+  undefined = [v for v in values if isinstance(v, variables.Undefined)]
   if undefined:
     raise ValueError(
         '{} must be defined before the loop.'.format(
             ','.join(s.symbol_name for s in undefined)))
   for value in values:
-    if special_values.is_undefined_return(value):
+    if isinstance(value, variables.UndefinedReturnValue):
       # Assumption: the loop will only capture the variable which tracks the
       # return value if the loop contained a return statement.
       # TODO(mdan): This should be checked at the place where return occurs.
@@ -1129,7 +1129,7 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
       results_tuple = results
     else:
       results_tuple = results,
-    undefined = tuple(filter(special_values.is_undefined, results_tuple))
+    undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)]
     if undefined:
       raise ValueError(
           'The following symbols must also be initialized in the {} branch: {}.'
@@ -1138,7 +1138,7 @@ def _wrap_disallow_undefs_from_cond(func, branch_name):
                                tuple(s.symbol_name for s in undefined)))

     for result in results_tuple:
-      if special_values.is_undefined_return(result):
+      if isinstance(result, variables.UndefinedReturnValue):
         raise ValueError(
             'A value must also be returned from the {} branch. If a value is '
             'returned from one branch of a conditional a value must be '
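A reduced sketch of the checks above, using stand-alone stand-ins for `variables.Undefined` and `variables.UndefinedReturnValue` (the real operators live in the autograph operators package):

class Undefined(object):
  def __init__(self, symbol_name):
    self.symbol_name = symbol_name


class UndefinedReturnValue(object):
  pass


def check_loop_state(*values):
  # Same shape as _disallow_undefs_into_loop / _verify_loop_init_vars above:
  # reject placeholders with a targeted message instead of a cryptic failure.
  undefined = [v for v in values if isinstance(v, Undefined)]
  if undefined:
    raise ValueError('{} must be defined before the loop.'.format(
        ','.join(s.symbol_name for s in undefined)))
  for value in values:
    if isinstance(value, UndefinedReturnValue):
      raise ValueError(
          'return statements are not supported within a TensorFlow loop.')


check_loop_state(0, 1.0)  # fine
try:
  check_loop_state(Undefined('total'))
except ValueError as e:
  print(e)  # total must be defined before the loop.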
@@ -29,7 +29,7 @@ import numpy as np
 import six

 from tensorflow.python.autograph.operators import control_flow
-from tensorflow.python.autograph.operators import special_values
+from tensorflow.python.autograph.operators import variables as variable_operators
 from tensorflow.python.autograph.utils import ag_logging
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import def_function
@@ -546,7 +546,7 @@ class ForLoopTest(test.TestCase):
     with self.assertRaisesRegex(ValueError, '"s" may not be None'):
       self._basic_loop(None, lambda i, s: s)
     with self.assertRaisesRegex(ValueError, '"s" must be defined'):
-      self._basic_loop(special_values.Undefined(''), lambda i, s: s)
+      self._basic_loop(variable_operators.Undefined(''), lambda i, s: s)

   def test_tensor_none_output(self):
     with self.assertRaisesRegex(ValueError, '"s" is None at the end'):
@@ -785,7 +785,7 @@ class WhileLoopTest(test.TestCase):
     with self.assertRaisesRegex(ValueError, '"s" may not be None'):
       self._basic_loop(None, lambda i, s: s)
     with self.assertRaisesRegex(ValueError, '"s" must be defined'):
-      self._basic_loop(special_values.Undefined(''), lambda i, s: s)
+      self._basic_loop(variable_operators.Undefined(''), lambda i, s: s)

   def test_tensor_none_output(self):
     with self.assertRaisesRegex(ValueError, '"s" is None at the end'):
@@ -887,10 +887,10 @@ class IfStmtTest(test.TestCase):
   def test_tensor_undefined_output(self):
     with self.assertRaisesRegex(
         ValueError, "must also be initialized in the if.*'s'"):
-      self._basic_cond(lambda: special_values.Undefined('s'), lambda: 1)
+      self._basic_cond(lambda: variable_operators.Undefined('s'), lambda: 1)
     with self.assertRaisesRegex(
         ValueError, "must also be initialized in the else.*'s'"):
-      self._basic_cond(lambda: 1, lambda: special_values.Undefined('s'))
+      self._basic_cond(lambda: 1, lambda: variable_operators.Undefined('s'))

   def test_tensor_dtype_change(self):
     with self.assertRaisesRegex(TypeError, '"s" has dtype int32.*but.*float32'):
(deleted file, 115 lines)
@@ -1,115 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Abstract representation of composite symbols that can be used in staging code.
-
-This provides a way to checkpoint the values of symbols that may be undefined
-entering staged control flow. This checkpointing is necessary to prevent some
-unintended side-effects. For example checkpointing prevents side-effects in one
-branch of a conditional from leaking into another.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.autograph.operators import special_values
-
-
-is_undefined = special_values.is_undefined
-Undefined = special_values.Undefined
-
-
-class Symbol(object):
-  """Representation of a simple or composite Python symbol.
-
-  Subclasses should implement `maybe_compute_value(self)` that returns the value
-  corresponding to the symbol or Undefined if no such value exists.
-  """
-
-  def __init__(self, name):
-    self.name = name
-
-
-class ValueSymbol(Symbol):
-  """Representation of a simple Python symbol with a concrete value.
-
-  This includes variables and literals. Since we are reifying undefined symbols
-  `Undefined` is also a valid value.
-  """
-
-  def __init__(self, name, value):
-    super(ValueSymbol, self).__init__(name)
-    self.value = value
-
-  def maybe_compute_value(self):
-    return self.value
-
-
-class AttributeAccessSymbol(Symbol):
-  """Representation of Python attribute access e.g. `a.b`."""
-
-  def __init__(self, parent_symbol, attr_name):
-    super(AttributeAccessSymbol, self).__init__(
-        parent_symbol.name + '.' + attr_name)
-    self.attr_name = attr_name
-    self.parent_symbol = parent_symbol
-
-  def maybe_compute_value(self):
-    """Compute the value corresponding to the attribute access or `Undefined`.
-
-    This will be `Undefined` if no such value exists either because there is no
-    such attribute or if the base is itself undefined.
-
-    Returns:
-      value corresponding to the attribute access or `Undefined`
-    """
-    parent_value = self.parent_symbol.maybe_compute_value()
-    if (is_undefined(parent_value) or
-        getattr(parent_value, self.attr_name, None) is None):
-      return Undefined(self.name)
-
-    return parent_value.__getattribute__(self.attr_name)
-
-
-class SubscriptSymbol(Symbol):
-  """Representation of Python subscript access e.g. `a[b]`."""
-
-  def __init__(self, parent_symbol, index_symbol):
-    super(SubscriptSymbol, self).__init__(
-        parent_symbol.name + '[' + index_symbol.name + ']')
-    self.index_symbol = index_symbol
-    self.parent_symbol = parent_symbol
-
-  def maybe_compute_value(self):
-    """Compute the value corresponding to the subscript access or `Undefined`.
-
-    This will be `Undefined` if no such value exists either because there is no
-    element corresponding to the given subscript or if the base itself is
-    not defined.
-
-    Returns:
-      value corresponding to the subscript access or `Undefined`
-    """
-    parent_value = self.parent_symbol.maybe_compute_value()
-    index_value = self.index_symbol.maybe_compute_value()
-    if is_undefined(parent_value) or is_undefined(index_value):
-      return Undefined(self.name)
-
-    try:
-      return parent_value[index_value]
-    except (IndexError, KeyError, TypeError):
-      # Reify the lack of an object for the given index/key
-      # This allows us to define them later without regret
-      return Undefined(self.name)
(deleted file, 230 lines)
@@ -1,230 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for special symbol handling."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.autograph.operators import special_values
-from tensorflow.python.autograph.operators import symbols
-from tensorflow.python.platform import test
-
-Undefined = special_values.Undefined
-AttributeAccessSymbol = symbols.AttributeAccessSymbol
-SubscriptSymbol = symbols.SubscriptSymbol
-ValueSymbol = symbols.ValueSymbol
-
-
-class SymbolsTest(test.TestCase):
-
-  def test_value_symbol_returns_value(self):
-    a = 42
-    a_symbol = ValueSymbol('a', a)
-    self.assertEqual(a_symbol.maybe_compute_value(), a)
-    self.assertEqual(a_symbol.name, 'a')
-
-  def test_attribute_access_missing_attribute(self):
-    class Foo(object):
-      pass
-    a = Foo()
-
-    a_symbol = ValueSymbol('a', a)
-    a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
-
-    self.assertEqual(a_symbol.maybe_compute_value(), a)
-    self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
-    self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a.b')
-
-  def test_attribute_access_undefined_target(self):
-    a = Undefined('a')
-    a_symbol = ValueSymbol('a', a)
-    a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
-
-    self.assertEqual(a_symbol.maybe_compute_value(), a)
-    self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
-    self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a.b')
-
-  def test_attribute_access_basic(self):
-    class Foo(object):
-
-      def __init__(self):
-        self.b = 'this is an attribute'
-
-    a = Foo()
-    a_symbol = ValueSymbol('a', a)
-    a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
-
-    self.assertEqual(a_symbol.maybe_compute_value(), a)
-    self.assertEqual(a_b_symbol.maybe_compute_value(), a.b)
-
-  def test_item_access_undefined_index(self):
-    class Foo(object):
-
-      def __getitem__(self, key):
-        return 'this is an item'
-
-    a = Foo()
-    b = Undefined('b')
-    a_symbol = ValueSymbol('a', a)
-    b_symbol = ValueSymbol('b', b)
-    a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
-
-    self.assertEqual(a_symbol.maybe_compute_value(), a)
-    self.assertEqual(b_symbol.maybe_compute_value(), b)
-    self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
-    self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]')
-
-  def test_item_access_no_getitem(self):
-    class Foo(object):
-      pass
-
-    a = Foo()
-    b = 42
-    a_symbol = ValueSymbol('a', a)
-    b_symbol = ValueSymbol('b', b)
-    a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
-
-    self.assertEqual(a_symbol.maybe_compute_value(), a)
-    self.assertEqual(b_symbol.maybe_compute_value(), b)
-    self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
-    self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]')
-
-  def test_item_access_undefined_root(self):
-    a = Undefined('a')
-    b = 42
-    a_symbol = ValueSymbol('a', a)
-    b_symbol = ValueSymbol('b', b)
-    a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
-
-    self.assertEqual(a_symbol.maybe_compute_value(), a)
-    self.assertEqual(b_symbol.maybe_compute_value(), b)
-    self.assertIsInstance(a_b_symbol.maybe_compute_value(), Undefined)
-    self.assertEqual(a_b_symbol.maybe_compute_value().symbol_name, 'a[b]')
-
-  def test_item_access_basic(self):
-    class Foo(object):
-
-      def __getitem__(self, key):
-        return 'this is an item'
-
-    a = Foo()
-    b = 42
-    a_symbol = ValueSymbol('a', a)
-    b_symbol = ValueSymbol('b', b)
-    a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
-
-    self.assertEqual(a_symbol.maybe_compute_value(), a)
-    self.assertEqual(b_symbol.maybe_compute_value(), b)
-    self.assertEqual(a_b_symbol.maybe_compute_value(), a[b])
-
-  def test_item_access_after_attribute_access(self):
-    class Foo(object):
-
-      def __getitem__(self, key):
-        return 'this is an item'
-
-    class Bar(object):
-
-      def __init__(self):
-        self.b = Foo()
-
-    a = Bar()
-    c = 42
-    a_symbol = ValueSymbol('a', a)
-    c_symbol = ValueSymbol('c', c)
-    a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
-    a_b_c_symbol = SubscriptSymbol(a_b_symbol, c_symbol)
-
-    self.assertEqual(a_symbol.maybe_compute_value(), a)
-    self.assertEqual(c_symbol.maybe_compute_value(), c)
-    self.assertEqual(a_b_symbol.maybe_compute_value(), a.b)
-    self.assertEqual(a_b_c_symbol.maybe_compute_value(), a.b[c])
-
-  def test_attribute_access_after_item_access(self):
-    class Bar(object):
-
-      def __init__(self):
-        self.c = object()
-
-    item = Bar()
-
-    class Foo(object):
-
-      def __getitem__(self, key):
-        return item
-
-    a = Foo()
-    b = 42
-    a_symbol = ValueSymbol('a', a)
-    b_symbol = ValueSymbol('b', b)
-    a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
-    a_b_c_symbol = AttributeAccessSymbol(a_b_symbol, 'c')
-
-    self.assertEqual(a_symbol.maybe_compute_value(), a)
-    self.assertEqual(b_symbol.maybe_compute_value(), b)
-    self.assertEqual(a_b_symbol.maybe_compute_value(), a[b])
-    self.assertEqual(a_b_c_symbol.maybe_compute_value(), a[b].c)
-
-  def test_item_access_after_item_access(self):
-    class Bar(object):
-
-      def __getitem__(self, key):
-        return 'this is an item'
-
-    item = Bar()
-
-    class Foo(object):
-
-      def __getitem__(self, key):
-        return item
-
-    a = Foo()
-    b = 42
-    c = 43
-    a_symbol = ValueSymbol('a', a)
-    b_symbol = ValueSymbol('b', b)
-    c_symbol = ValueSymbol('b', c)
-    a_b_symbol = SubscriptSymbol(a_symbol, b_symbol)
-    a_b_c_symbol = SubscriptSymbol(a_b_symbol, c_symbol)
-
-    self.assertEqual(a_symbol.maybe_compute_value(), a)
-    self.assertEqual(b_symbol.maybe_compute_value(), b)
-    self.assertEqual(a_b_symbol.maybe_compute_value(), a[b])
-    self.assertEqual(a_b_c_symbol.maybe_compute_value(), a[b][c])
-
-  def test_attribute_access_after_attribute_access(self):
-    class Bar(object):
-
-      def __init__(self):
-        self.c = object()
-
-    class Foo(object):
-
-      def __init__(self):
-        self.b = Bar()
-
-    a = Foo()
-    a_symbol = ValueSymbol('a', a)
-    a_b_symbol = AttributeAccessSymbol(a_symbol, 'b')
-    a_b_c_symbol = AttributeAccessSymbol(a_b_symbol, 'c')
-
-    self.assertEqual(a_symbol.maybe_compute_value(), a)
-    self.assertEqual(a_b_symbol.maybe_compute_value(), a.b)
-    self.assertEqual(a_b_c_symbol.maybe_compute_value(), a.b.c)
-
-
-if __name__ == '__main__':
-  test.main()
@@ -19,6 +19,13 @@ from __future__ import division
 from __future__ import print_function


+def ld(v):
+  """Load variable operator."""
+  if isinstance(v, Undefined):
+    return v.read()
+  return v
+
+
 class Undefined(object):
   """Represents an undefined symbol in Python.

@@ -51,6 +58,10 @@ class Undefined(object):
   def __init__(self, symbol_name):
     self.symbol_name = symbol_name

+  def read(self):
+    raise UnboundLocalError("'{}' is used before assignment".format(
+        self.symbol_name))
+
   def __repr__(self):
     return self.symbol_name

@@ -66,34 +77,7 @@ class Undefined(object):
     return self


-def is_undefined(value):
-  """Checks whether Autograph has determined that a given value is undefined.
-
-  This only works in places where Autograph reifies undefined symbols. Note that
-  if this function is passed a truly undefined symbol the call-site will raise
-  NameError.
-
-  Args:
-    value: value to test for undefinedness
-  Returns:
-    Boolean, whether the input value is undefined.
-  """
-  return isinstance(value, Undefined)
-
-
 # TODO(mdan): Refactor as a RetVal object, aggregating the value and do_return.
 class UndefinedReturnValue(object):
-  """Represents a default return value from a function (None in Python)."""
+  """Represents a return value that is undefined."""
   pass
-
-
-def retval(value):
-  """Returns the actual value that a return statement should produce."""
-  if isinstance(value, UndefinedReturnValue):
-    return None
-  return value
-
-
-def is_undefined_return(value):
-  """Checks whether `value` is the default return value."""
-  return isinstance(value, UndefinedReturnValue)
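The placeholder is deliberately permissive about attribute and item access, so it only fails loudly when it is actually read. A simplified sketch of that propagation behavior (the real class has more to it; the pattern is exercised by the tests in the next hunk):

class Undefined(object):
  """Simplified sketch of the placeholder's propagation behavior."""

  def __init__(self, symbol_name):
    self.symbol_name = symbol_name

  def __getattr__(self, _):
    return self          # attribute access stays undefined

  def __getitem__(self, _):
    return self          # item access stays undefined


a = Undefined('a')
assert isinstance(a.foo, Undefined)
assert isinstance(a[0], Undefined)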
@@ -18,28 +18,38 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.autograph.operators import special_values
+from tensorflow.python.autograph.operators import variables
 from tensorflow.python.platform import test


 class SpecialValuesTest(test.TestCase):

   def test_undefined(self):
-    undefined_symbol = special_values.Undefined('name')
-    self.assertEqual(undefined_symbol.symbol_name, 'name')
+    undefined_symbol = variables.Undefined('name')
+    undefined_symbol2 = variables.Undefined('name')

-    undefined_symbol2 = special_values.Undefined('name')
+    self.assertEqual(undefined_symbol.symbol_name, 'name')
+    self.assertEqual(undefined_symbol2.symbol_name, 'name')
     self.assertNotEqual(undefined_symbol, undefined_symbol2)

-    self.assertTrue(special_values.is_undefined(undefined_symbol))
-    self.assertTrue(special_values.is_undefined(undefined_symbol2))
-
   def test_undefined_operations(self):
-    undefined_symbol = special_values.Undefined('name')
+    undefined_symbol = variables.Undefined('name')

-    self.assertTrue(special_values.is_undefined(undefined_symbol.foo))
-    self.assertTrue(special_values.is_undefined(undefined_symbol[0]))
-    self.assertFalse(special_values.is_undefined(undefined_symbol.__class__))
+    self.assertIsInstance(undefined_symbol.foo, variables.Undefined)
+    self.assertIsInstance(undefined_symbol[0], variables.Undefined)
+    self.assertNotIsInstance(undefined_symbol.__class__, variables.Undefined)
+
+  def test_read(self):
+    self.assertEqual(variables.ld(1), 1)
+    o = object()
+    self.assertEqual(variables.ld(o), o)
+
+    self.assertIsNone(variables.ld(None))
+
+  def test_read_undefined(self):
+    with self.assertRaisesRegex(UnboundLocalError, 'used before assignment'):
+      variables.ld(variables.Undefined('a'))


 if __name__ == '__main__':
   test.main()
@@ -78,6 +78,18 @@ class ReachingDefinitionsAnalyzerTest(

     self.assertSameDef(local_body[1].test, local_body[2].value.elts[0])

+    # Note: the function name is is visible inside the function body. But it's
+    # a closure variable, not a local.
+    #
+    # Example:
+    #
+    #   >>> def f():
+    #   ...  print(f)
+    #   >>> g = f
+    #   >>> f = 'something else'
+    #   >>> g()
+    #   something else
+    #
     self.assertHasDefinedIn(local_body[1], ('a', 'b'))
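The behavior the new comment describes is plain Python late binding, independent of AutoGraph; for reference:

def f():
  print(f)        # `f` is looked up when the call happens, not at definition


g = f
f = 'something else'
g()               # prints: something else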
@@ -255,6 +255,9 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase):

     inner_fn_body = fn_body[1].body[1].body
     def_of_a_in_foo = inner_fn_body[0].value
+    # Even though `a` is visible in the inner functio above, the late binding
+    # makes it impossible to assume that the same value will be visible at
+    # call time.
     self.assertHasDefs(def_of_a_in_foo, 0)

   def test_nested_functions_isolation(self):