Uniformize the handling of undefined simple and composite names in control flow.

PiperOrigin-RevId: 313461038
Change-Id: Ic70f11291dfa6da52073ec4cacecda883a4d126c
This commit is contained in:
Dan Moldovan 2020-05-27 14:10:51 -07:00 committed by TensorFlower Gardener
parent df1131c395
commit 1c1d4b619a
4 changed files with 67 additions and 31 deletions

View File

@ -72,31 +72,43 @@ class ControlFlowTransformer(converter.Base):
return results
def _create_state_functions(
self, loop_vars, nonlocal_declarations, getter_name, setter_name):
if loop_vars:
template = """
def getter_name():
return state_vars,
def setter_name(vars_):
nonlocal_declarations
state_vars, = vars_
"""
return templates.replace(
template,
nonlocal_declarations=nonlocal_declarations,
getter_name=getter_name,
setter_name=setter_name,
state_vars=tuple(loop_vars))
else:
self, block_vars, nonlocal_declarations, getter_name, setter_name):
if not block_vars:
template = """
def getter_name():
return ()
def setter_name(loop_vars):
def setter_name(block_vars):
pass
"""
return templates.replace(
template, getter_name=getter_name, setter_name=setter_name)
guarded_block_vars = []
for v in block_vars:
if v.is_simple():
guarded_block_vars.append(v)
else:
guarded_block_vars.append(
templates.replace_as_expression(
'ag__.ldu(lambda: var_, name)',
var_=v,
name=gast.Constant(str(v), kind=None)))
template = """
def getter_name():
return guarded_state_vars,
def setter_name(vars_):
nonlocal_declarations
state_vars, = vars_
"""
return templates.replace(
template,
nonlocal_declarations=nonlocal_declarations,
getter_name=getter_name,
guarded_state_vars=guarded_block_vars,
setter_name=setter_name,
state_vars=tuple(block_vars))
def _create_loop_options(self, node):
if not anno.hasanno(node, anno.Basic.DIRECTIVES):
return gast.Dict([], [])

View File

@ -189,9 +189,9 @@ class WhileStatementTest(ControlFlowTestBase):
symbols={'TestClass': TestClass})
with self.converted(
test_fn, control_flow, {'TestClass': TestClass}) as result:
# TODO(b/128519776): Better error message.
with self.assertRaisesRegex(AttributeError, 'subattr'):
result.test_fn(constant_op.constant(0), constant_op.constant(5))
with self.assertRaisesRegex(
ValueError, "'tc.subattr' must be defined before the loop"):
result.test_fn(constant_op.constant(0), 0)
def test_composite_state_slice_initialized_in_loop(self):
@ -209,9 +209,9 @@ class WhileStatementTest(ControlFlowTestBase):
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
{'subkey': 14})
with self.converted(test_fn, control_flow, {}) as result:
# TODO(b/128519776): Better error message.
with self.assertRaisesRegex(KeyError, 'subkey'):
result.test_fn(constant_op.constant(0), constant_op.constant(5))
with self.assertRaisesRegex(
ValueError, r"'d\[k\]' must be defined before the loop"):
result.test_fn(constant_op.constant(0), 0)
def test_composite_state_literal_slice_initialized_in_loop(self):
@ -228,9 +228,9 @@ class WhileStatementTest(ControlFlowTestBase):
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
{'subkey': 14})
with self.converted(test_fn, control_flow, {}) as result:
# TODO(b/128519776): Better error message.
with self.assertRaisesRegex(KeyError, 'subkey'):
result.test_fn(constant_op.constant(0), constant_op.constant(5))
with self.assertRaisesRegex(
ValueError, r"'d\['subkey'\]' must be defined before the loop"):
result.test_fn(constant_op.constant(0), 0)
def test_composite_state_slice_aliased_to_local(self):
@ -245,7 +245,7 @@ class WhileStatementTest(ControlFlowTestBase):
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
{'subkey': 11})
with self.converted(test_fn, control_flow, {}) as result:
# TODO(b/128519776): Better error message.
# TODO(b/136999953): Better error message.
# Note that this error happens at execution time.
with self.assertRaises(errors.InaccessibleTensorError):
graph_fn = def_function.function(result.test_fn, autograph=False)
@ -671,11 +671,9 @@ class ForStatementTest(ControlFlowTestBase):
symbols={'TestClass': TestClass})
with self.converted(
test_fn, control_flow, {'TestClass': TestClass}) as result:
# TODO(b/128519776): Better error message.
with self.assertRaisesRegex(
AttributeError, '\'TestClass\' object has no attribute \'x\''):
result.test_fn(
constant_op.constant(list(range(5))), constant_op.constant(5))
ValueError, "'tc.x' must be defined before the loop"):
result.test_fn(constant_op.constant(list(range(5))), 0)
def test_tuple_unpacking(self):
def test_fn(x_list):

View File

@ -62,5 +62,6 @@ from tensorflow.python.autograph.operators.slices import get_item
from tensorflow.python.autograph.operators.slices import GetItemOpts
from tensorflow.python.autograph.operators.slices import set_item
from tensorflow.python.autograph.operators.variables import ld
from tensorflow.python.autograph.operators.variables import ldu
from tensorflow.python.autograph.operators.variables import Undefined
from tensorflow.python.autograph.operators.variables import UndefinedReturnValue

View File

@ -26,6 +26,31 @@ def ld(v):
return v
def ldu(load_v, name):
"""Load variable operator that returns Undefined when failing to evaluate.
Note: the name ("load or return undefined") is abbreviated to minimize
the amount of clutter in generated code.
This variant of `ld` is useful when loading symbols that may be undefined at
runtime, such as composite symbols, and whether they are defined or not cannot
be determined statically. For example `d['a']` is undefined when `d` is an
empty dict.
Args:
load_v: Lambda that executes the actual read.
name: Human-readable name of the symbol being read.
Returns:
Either the value of the symbol, or Undefined, if the symbol is not fully
defined.
"""
try:
# TODO(mdan): Use locals()/globals() here.
return load_v()
except (KeyError, AttributeError, NameError):
return Undefined(name)
class Undefined(object):
"""Represents an undefined symbol in Python.