Uniformize the handling of undefined simple and composite names in control flow.
PiperOrigin-RevId: 313461038 Change-Id: Ic70f11291dfa6da52073ec4cacecda883a4d126c
This commit is contained in:
parent
df1131c395
commit
1c1d4b619a
@ -72,31 +72,43 @@ class ControlFlowTransformer(converter.Base):
|
||||
return results
|
||||
|
||||
def _create_state_functions(
|
||||
self, loop_vars, nonlocal_declarations, getter_name, setter_name):
|
||||
if loop_vars:
|
||||
template = """
|
||||
def getter_name():
|
||||
return state_vars,
|
||||
def setter_name(vars_):
|
||||
nonlocal_declarations
|
||||
state_vars, = vars_
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
nonlocal_declarations=nonlocal_declarations,
|
||||
getter_name=getter_name,
|
||||
setter_name=setter_name,
|
||||
state_vars=tuple(loop_vars))
|
||||
else:
|
||||
self, block_vars, nonlocal_declarations, getter_name, setter_name):
|
||||
if not block_vars:
|
||||
template = """
|
||||
def getter_name():
|
||||
return ()
|
||||
def setter_name(loop_vars):
|
||||
def setter_name(block_vars):
|
||||
pass
|
||||
"""
|
||||
return templates.replace(
|
||||
template, getter_name=getter_name, setter_name=setter_name)
|
||||
|
||||
guarded_block_vars = []
|
||||
for v in block_vars:
|
||||
if v.is_simple():
|
||||
guarded_block_vars.append(v)
|
||||
else:
|
||||
guarded_block_vars.append(
|
||||
templates.replace_as_expression(
|
||||
'ag__.ldu(lambda: var_, name)',
|
||||
var_=v,
|
||||
name=gast.Constant(str(v), kind=None)))
|
||||
|
||||
template = """
|
||||
def getter_name():
|
||||
return guarded_state_vars,
|
||||
def setter_name(vars_):
|
||||
nonlocal_declarations
|
||||
state_vars, = vars_
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
nonlocal_declarations=nonlocal_declarations,
|
||||
getter_name=getter_name,
|
||||
guarded_state_vars=guarded_block_vars,
|
||||
setter_name=setter_name,
|
||||
state_vars=tuple(block_vars))
|
||||
|
||||
def _create_loop_options(self, node):
|
||||
if not anno.hasanno(node, anno.Basic.DIRECTIVES):
|
||||
return gast.Dict([], [])
|
||||
|
@ -189,9 +189,9 @@ class WhileStatementTest(ControlFlowTestBase):
|
||||
symbols={'TestClass': TestClass})
|
||||
with self.converted(
|
||||
test_fn, control_flow, {'TestClass': TestClass}) as result:
|
||||
# TODO(b/128519776): Better error message.
|
||||
with self.assertRaisesRegex(AttributeError, 'subattr'):
|
||||
result.test_fn(constant_op.constant(0), constant_op.constant(5))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "'tc.subattr' must be defined before the loop"):
|
||||
result.test_fn(constant_op.constant(0), 0)
|
||||
|
||||
def test_composite_state_slice_initialized_in_loop(self):
|
||||
|
||||
@ -209,9 +209,9 @@ class WhileStatementTest(ControlFlowTestBase):
|
||||
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
|
||||
{'subkey': 14})
|
||||
with self.converted(test_fn, control_flow, {}) as result:
|
||||
# TODO(b/128519776): Better error message.
|
||||
with self.assertRaisesRegex(KeyError, 'subkey'):
|
||||
result.test_fn(constant_op.constant(0), constant_op.constant(5))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"'d\[k\]' must be defined before the loop"):
|
||||
result.test_fn(constant_op.constant(0), 0)
|
||||
|
||||
def test_composite_state_literal_slice_initialized_in_loop(self):
|
||||
|
||||
@ -228,9 +228,9 @@ class WhileStatementTest(ControlFlowTestBase):
|
||||
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
|
||||
{'subkey': 14})
|
||||
with self.converted(test_fn, control_flow, {}) as result:
|
||||
# TODO(b/128519776): Better error message.
|
||||
with self.assertRaisesRegex(KeyError, 'subkey'):
|
||||
result.test_fn(constant_op.constant(0), constant_op.constant(5))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"'d\['subkey'\]' must be defined before the loop"):
|
||||
result.test_fn(constant_op.constant(0), 0)
|
||||
|
||||
def test_composite_state_slice_aliased_to_local(self):
|
||||
|
||||
@ -245,7 +245,7 @@ class WhileStatementTest(ControlFlowTestBase):
|
||||
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
|
||||
{'subkey': 11})
|
||||
with self.converted(test_fn, control_flow, {}) as result:
|
||||
# TODO(b/128519776): Better error message.
|
||||
# TODO(b/136999953): Better error message.
|
||||
# Note that this error happens at execution time.
|
||||
with self.assertRaises(errors.InaccessibleTensorError):
|
||||
graph_fn = def_function.function(result.test_fn, autograph=False)
|
||||
@ -671,11 +671,9 @@ class ForStatementTest(ControlFlowTestBase):
|
||||
symbols={'TestClass': TestClass})
|
||||
with self.converted(
|
||||
test_fn, control_flow, {'TestClass': TestClass}) as result:
|
||||
# TODO(b/128519776): Better error message.
|
||||
with self.assertRaisesRegex(
|
||||
AttributeError, '\'TestClass\' object has no attribute \'x\''):
|
||||
result.test_fn(
|
||||
constant_op.constant(list(range(5))), constant_op.constant(5))
|
||||
ValueError, "'tc.x' must be defined before the loop"):
|
||||
result.test_fn(constant_op.constant(list(range(5))), 0)
|
||||
|
||||
def test_tuple_unpacking(self):
|
||||
def test_fn(x_list):
|
||||
|
@ -62,5 +62,6 @@ from tensorflow.python.autograph.operators.slices import get_item
|
||||
from tensorflow.python.autograph.operators.slices import GetItemOpts
|
||||
from tensorflow.python.autograph.operators.slices import set_item
|
||||
from tensorflow.python.autograph.operators.variables import ld
|
||||
from tensorflow.python.autograph.operators.variables import ldu
|
||||
from tensorflow.python.autograph.operators.variables import Undefined
|
||||
from tensorflow.python.autograph.operators.variables import UndefinedReturnValue
|
||||
|
@ -26,6 +26,31 @@ def ld(v):
|
||||
return v
|
||||
|
||||
|
||||
def ldu(load_v, name):
|
||||
"""Load variable operator that returns Undefined when failing to evaluate.
|
||||
|
||||
Note: the name ("load or return undefined") is abbreviated to minimize
|
||||
the amount of clutter in generated code.
|
||||
|
||||
This variant of `ld` is useful when loading symbols that may be undefined at
|
||||
runtime, such as composite symbols, and whether they are defined or not cannot
|
||||
be determined statically. For example `d['a']` is undefined when `d` is an
|
||||
empty dict.
|
||||
|
||||
Args:
|
||||
load_v: Lambda that executes the actual read.
|
||||
name: Human-readable name of the symbol being read.
|
||||
Returns:
|
||||
Either the value of the symbol, or Undefined, if the symbol is not fully
|
||||
defined.
|
||||
"""
|
||||
try:
|
||||
# TODO(mdan): Use locals()/globals() here.
|
||||
return load_v()
|
||||
except (KeyError, AttributeError, NameError):
|
||||
return Undefined(name)
|
||||
|
||||
|
||||
class Undefined(object):
|
||||
"""Represents an undefined symbol in Python.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user