diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py index 673781e47dd..b54770cbd28 100644 --- a/tensorflow/python/autograph/converters/control_flow.py +++ b/tensorflow/python/autograph/converters/control_flow.py @@ -72,31 +72,43 @@ class ControlFlowTransformer(converter.Base): return results def _create_state_functions( - self, loop_vars, nonlocal_declarations, getter_name, setter_name): - if loop_vars: - template = """ - def getter_name(): - return state_vars, - def setter_name(vars_): - nonlocal_declarations - state_vars, = vars_ - """ - return templates.replace( - template, - nonlocal_declarations=nonlocal_declarations, - getter_name=getter_name, - setter_name=setter_name, - state_vars=tuple(loop_vars)) - else: + self, block_vars, nonlocal_declarations, getter_name, setter_name): + if not block_vars: template = """ def getter_name(): return () - def setter_name(loop_vars): + def setter_name(block_vars): pass """ return templates.replace( template, getter_name=getter_name, setter_name=setter_name) + guarded_block_vars = [] + for v in block_vars: + if v.is_simple(): + guarded_block_vars.append(v) + else: + guarded_block_vars.append( + templates.replace_as_expression( + 'ag__.ldu(lambda: var_, name)', + var_=v, + name=gast.Constant(str(v), kind=None))) + + template = """ + def getter_name(): + return guarded_state_vars, + def setter_name(vars_): + nonlocal_declarations + state_vars, = vars_ + """ + return templates.replace( + template, + nonlocal_declarations=nonlocal_declarations, + getter_name=getter_name, + guarded_state_vars=guarded_block_vars, + setter_name=setter_name, + state_vars=tuple(block_vars)) + def _create_loop_options(self, node): if not anno.hasanno(node, anno.Basic.DIRECTIVES): return gast.Dict([], []) diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py index 935e2cec4b8..f0681128698 100644 --- a/tensorflow/python/autograph/converters/control_flow_test.py +++ b/tensorflow/python/autograph/converters/control_flow_test.py @@ -189,9 +189,9 @@ class WhileStatementTest(ControlFlowTestBase): symbols={'TestClass': TestClass}) with self.converted( test_fn, control_flow, {'TestClass': TestClass}) as result: - # TODO(b/128519776): Better error message. - with self.assertRaisesRegex(AttributeError, 'subattr'): - result.test_fn(constant_op.constant(0), constant_op.constant(5)) + with self.assertRaisesRegex( + ValueError, "'tc.subattr' must be defined before the loop"): + result.test_fn(constant_op.constant(0), 0) def test_composite_state_slice_initialized_in_loop(self): @@ -209,9 +209,9 @@ class WhileStatementTest(ControlFlowTestBase): self.assertTransformedResult(test_fn, (0, constant_op.constant(10)), {'subkey': 14}) with self.converted(test_fn, control_flow, {}) as result: - # TODO(b/128519776): Better error message. - with self.assertRaisesRegex(KeyError, 'subkey'): - result.test_fn(constant_op.constant(0), constant_op.constant(5)) + with self.assertRaisesRegex( + ValueError, r"'d\[k\]' must be defined before the loop"): + result.test_fn(constant_op.constant(0), 0) def test_composite_state_literal_slice_initialized_in_loop(self): @@ -228,9 +228,9 @@ class WhileStatementTest(ControlFlowTestBase): self.assertTransformedResult(test_fn, (0, constant_op.constant(10)), {'subkey': 14}) with self.converted(test_fn, control_flow, {}) as result: - # TODO(b/128519776): Better error message. - with self.assertRaisesRegex(KeyError, 'subkey'): - result.test_fn(constant_op.constant(0), constant_op.constant(5)) + with self.assertRaisesRegex( + ValueError, r"'d\['subkey'\]' must be defined before the loop"): + result.test_fn(constant_op.constant(0), 0) def test_composite_state_slice_aliased_to_local(self): @@ -245,7 +245,7 @@ class WhileStatementTest(ControlFlowTestBase): self.assertTransformedResult(test_fn, (0, constant_op.constant(10)), {'subkey': 11}) with self.converted(test_fn, control_flow, {}) as result: - # TODO(b/128519776): Better error message. + # TODO(b/136999953): Better error message. # Note that this error happens at execution time. with self.assertRaises(errors.InaccessibleTensorError): graph_fn = def_function.function(result.test_fn, autograph=False) @@ -671,11 +671,9 @@ class ForStatementTest(ControlFlowTestBase): symbols={'TestClass': TestClass}) with self.converted( test_fn, control_flow, {'TestClass': TestClass}) as result: - # TODO(b/128519776): Better error message. with self.assertRaisesRegex( - AttributeError, '\'TestClass\' object has no attribute \'x\''): - result.test_fn( - constant_op.constant(list(range(5))), constant_op.constant(5)) + ValueError, "'tc.x' must be defined before the loop"): + result.test_fn(constant_op.constant(list(range(5))), 0) def test_tuple_unpacking(self): def test_fn(x_list): diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py index 8ac4e1d8bb3..a42dcf326c3 100644 --- a/tensorflow/python/autograph/operators/__init__.py +++ b/tensorflow/python/autograph/operators/__init__.py @@ -62,5 +62,6 @@ from tensorflow.python.autograph.operators.slices import get_item from tensorflow.python.autograph.operators.slices import GetItemOpts from tensorflow.python.autograph.operators.slices import set_item from tensorflow.python.autograph.operators.variables import ld +from tensorflow.python.autograph.operators.variables import ldu from tensorflow.python.autograph.operators.variables import Undefined from tensorflow.python.autograph.operators.variables import UndefinedReturnValue diff --git a/tensorflow/python/autograph/operators/variables.py b/tensorflow/python/autograph/operators/variables.py index 150f64e1758..c3bedc3fecf 100644 --- a/tensorflow/python/autograph/operators/variables.py +++ b/tensorflow/python/autograph/operators/variables.py @@ -26,6 +26,31 @@ def ld(v): return v +def ldu(load_v, name): + """Load variable operator that returns Undefined when failing to evaluate. + + Note: the name ("load or return undefined") is abbreviated to minimize + the amount of clutter in generated code. + + This variant of `ld` is useful when loading symbols that may be undefined at + runtime, such as composite symbols, and whether they are defined or not cannot + be determined statically. For example `d['a']` is undefined when `d` is an + empty dict. + + Args: + load_v: Lambda that executes the actual read. + name: Human-readable name of the symbol being read. + Returns: + Either the value of the symbol, or Undefined, if the symbol is not fully + defined. + """ + try: + # TODO(mdan): Use locals()/globals() here. + return load_v() + except (KeyError, AttributeError, NameError): + return Undefined(name) + + class Undefined(object): """Represents an undefined symbol in Python.