Uniformize the handling of undefined simple and composite names in control flow.

PiperOrigin-RevId: 313461038
Change-Id: Ic70f11291dfa6da52073ec4cacecda883a4d126c
This commit is contained in:
Dan Moldovan 2020-05-27 14:10:51 -07:00 committed by TensorFlower Gardener
parent df1131c395
commit 1c1d4b619a
4 changed files with 67 additions and 31 deletions
tensorflow/python/autograph

View File

@ -72,11 +72,31 @@ class ControlFlowTransformer(converter.Base):
return results return results
def _create_state_functions( def _create_state_functions(
self, loop_vars, nonlocal_declarations, getter_name, setter_name): self, block_vars, nonlocal_declarations, getter_name, setter_name):
if loop_vars: if not block_vars:
template = """ template = """
def getter_name(): def getter_name():
return state_vars, return ()
def setter_name(block_vars):
pass
"""
return templates.replace(
template, getter_name=getter_name, setter_name=setter_name)
guarded_block_vars = []
for v in block_vars:
if v.is_simple():
guarded_block_vars.append(v)
else:
guarded_block_vars.append(
templates.replace_as_expression(
'ag__.ldu(lambda: var_, name)',
var_=v,
name=gast.Constant(str(v), kind=None)))
template = """
def getter_name():
return guarded_state_vars,
def setter_name(vars_): def setter_name(vars_):
nonlocal_declarations nonlocal_declarations
state_vars, = vars_ state_vars, = vars_
@ -85,17 +105,9 @@ class ControlFlowTransformer(converter.Base):
template, template,
nonlocal_declarations=nonlocal_declarations, nonlocal_declarations=nonlocal_declarations,
getter_name=getter_name, getter_name=getter_name,
guarded_state_vars=guarded_block_vars,
setter_name=setter_name, setter_name=setter_name,
state_vars=tuple(loop_vars)) state_vars=tuple(block_vars))
else:
template = """
def getter_name():
return ()
def setter_name(loop_vars):
pass
"""
return templates.replace(
template, getter_name=getter_name, setter_name=setter_name)
def _create_loop_options(self, node): def _create_loop_options(self, node):
if not anno.hasanno(node, anno.Basic.DIRECTIVES): if not anno.hasanno(node, anno.Basic.DIRECTIVES):

View File

@ -189,9 +189,9 @@ class WhileStatementTest(ControlFlowTestBase):
symbols={'TestClass': TestClass}) symbols={'TestClass': TestClass})
with self.converted( with self.converted(
test_fn, control_flow, {'TestClass': TestClass}) as result: test_fn, control_flow, {'TestClass': TestClass}) as result:
# TODO(b/128519776): Better error message. with self.assertRaisesRegex(
with self.assertRaisesRegex(AttributeError, 'subattr'): ValueError, "'tc.subattr' must be defined before the loop"):
result.test_fn(constant_op.constant(0), constant_op.constant(5)) result.test_fn(constant_op.constant(0), 0)
def test_composite_state_slice_initialized_in_loop(self): def test_composite_state_slice_initialized_in_loop(self):
@ -209,9 +209,9 @@ class WhileStatementTest(ControlFlowTestBase):
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)), self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
{'subkey': 14}) {'subkey': 14})
with self.converted(test_fn, control_flow, {}) as result: with self.converted(test_fn, control_flow, {}) as result:
# TODO(b/128519776): Better error message. with self.assertRaisesRegex(
with self.assertRaisesRegex(KeyError, 'subkey'): ValueError, r"'d\[k\]' must be defined before the loop"):
result.test_fn(constant_op.constant(0), constant_op.constant(5)) result.test_fn(constant_op.constant(0), 0)
def test_composite_state_literal_slice_initialized_in_loop(self): def test_composite_state_literal_slice_initialized_in_loop(self):
@ -228,9 +228,9 @@ class WhileStatementTest(ControlFlowTestBase):
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)), self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
{'subkey': 14}) {'subkey': 14})
with self.converted(test_fn, control_flow, {}) as result: with self.converted(test_fn, control_flow, {}) as result:
# TODO(b/128519776): Better error message. with self.assertRaisesRegex(
with self.assertRaisesRegex(KeyError, 'subkey'): ValueError, r"'d\['subkey'\]' must be defined before the loop"):
result.test_fn(constant_op.constant(0), constant_op.constant(5)) result.test_fn(constant_op.constant(0), 0)
def test_composite_state_slice_aliased_to_local(self): def test_composite_state_slice_aliased_to_local(self):
@ -245,7 +245,7 @@ class WhileStatementTest(ControlFlowTestBase):
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)), self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
{'subkey': 11}) {'subkey': 11})
with self.converted(test_fn, control_flow, {}) as result: with self.converted(test_fn, control_flow, {}) as result:
# TODO(b/128519776): Better error message. # TODO(b/136999953): Better error message.
# Note that this error happens at execution time. # Note that this error happens at execution time.
with self.assertRaises(errors.InaccessibleTensorError): with self.assertRaises(errors.InaccessibleTensorError):
graph_fn = def_function.function(result.test_fn, autograph=False) graph_fn = def_function.function(result.test_fn, autograph=False)
@ -671,11 +671,9 @@ class ForStatementTest(ControlFlowTestBase):
symbols={'TestClass': TestClass}) symbols={'TestClass': TestClass})
with self.converted( with self.converted(
test_fn, control_flow, {'TestClass': TestClass}) as result: test_fn, control_flow, {'TestClass': TestClass}) as result:
# TODO(b/128519776): Better error message.
with self.assertRaisesRegex( with self.assertRaisesRegex(
AttributeError, '\'TestClass\' object has no attribute \'x\''): ValueError, "'tc.x' must be defined before the loop"):
result.test_fn( result.test_fn(constant_op.constant(list(range(5))), 0)
constant_op.constant(list(range(5))), constant_op.constant(5))
def test_tuple_unpacking(self): def test_tuple_unpacking(self):
def test_fn(x_list): def test_fn(x_list):

View File

@ -62,5 +62,6 @@ from tensorflow.python.autograph.operators.slices import get_item
from tensorflow.python.autograph.operators.slices import GetItemOpts from tensorflow.python.autograph.operators.slices import GetItemOpts
from tensorflow.python.autograph.operators.slices import set_item from tensorflow.python.autograph.operators.slices import set_item
from tensorflow.python.autograph.operators.variables import ld from tensorflow.python.autograph.operators.variables import ld
from tensorflow.python.autograph.operators.variables import ldu
from tensorflow.python.autograph.operators.variables import Undefined from tensorflow.python.autograph.operators.variables import Undefined
from tensorflow.python.autograph.operators.variables import UndefinedReturnValue from tensorflow.python.autograph.operators.variables import UndefinedReturnValue

View File

@ -26,6 +26,31 @@ def ld(v):
return v return v
def ldu(load_v, name):
"""Load variable operator that returns Undefined when failing to evaluate.
Note: the name ("load or return undefined") is abbreviated to minimize
the amount of clutter in generated code.
This variant of `ld` is useful when loading symbols that may be undefined at
runtime, such as composite symbols, and whether they are defined or not cannot
be determined statically. For example `d['a']` is undefined when `d` is an
empty dict.
Args:
load_v: Lambda that executes the actual read.
name: Human-readable name of the symbol being read.
Returns:
Either the value of the symbol, or Undefined, if the symbol is not fully
defined.
"""
try:
# TODO(mdan): Use locals()/globals() here.
return load_v()
except (KeyError, AttributeError, NameError):
return Undefined(name)
class Undefined(object): class Undefined(object):
"""Represents an undefined symbol in Python. """Represents an undefined symbol in Python.