From 1c1d4b619a3ee0a45d26edacdee591648e911314 Mon Sep 17 00:00:00 2001
From: Dan Moldovan <mdan@google.com>
Date: Wed, 27 May 2020 14:10:51 -0700
Subject: [PATCH] Uniformize the handling of undefined simple and composite
 names in control flow.

PiperOrigin-RevId: 313461038
Change-Id: Ic70f11291dfa6da52073ec4cacecda883a4d126c
---
 .../autograph/converters/control_flow.py      | 46 ++++++++++++-------
 .../autograph/converters/control_flow_test.py | 26 +++++------
 .../python/autograph/operators/__init__.py    |  1 +
 .../python/autograph/operators/variables.py   | 25 ++++++++++
 4 files changed, 67 insertions(+), 31 deletions(-)

diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
index 673781e47dd..b54770cbd28 100644
--- a/tensorflow/python/autograph/converters/control_flow.py
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -72,31 +72,43 @@ class ControlFlowTransformer(converter.Base):
     return results
 
   def _create_state_functions(
-      self, loop_vars, nonlocal_declarations, getter_name, setter_name):
-    if loop_vars:
-      template = """
-        def getter_name():
-          return state_vars,
-        def setter_name(vars_):
-          nonlocal_declarations
-          state_vars, = vars_
-      """
-      return templates.replace(
-          template,
-          nonlocal_declarations=nonlocal_declarations,
-          getter_name=getter_name,
-          setter_name=setter_name,
-          state_vars=tuple(loop_vars))
-    else:
+      self, block_vars, nonlocal_declarations, getter_name, setter_name):
+    if not block_vars:
       template = """
         def getter_name():
           return ()
-        def setter_name(loop_vars):
+        def setter_name(block_vars):
           pass
       """
       return templates.replace(
           template, getter_name=getter_name, setter_name=setter_name)
 
+    guarded_block_vars = []
+    for v in block_vars:
+      if v.is_simple():
+        guarded_block_vars.append(v)
+      else:
+        guarded_block_vars.append(
+            templates.replace_as_expression(
+                'ag__.ldu(lambda: var_, name)',
+                var_=v,
+                name=gast.Constant(str(v), kind=None)))
+
+    template = """
+      def getter_name():
+        return guarded_state_vars,
+      def setter_name(vars_):
+        nonlocal_declarations
+        state_vars, = vars_
+    """
+    return templates.replace(
+        template,
+        nonlocal_declarations=nonlocal_declarations,
+        getter_name=getter_name,
+        guarded_state_vars=guarded_block_vars,
+        setter_name=setter_name,
+        state_vars=tuple(block_vars))
+
   def _create_loop_options(self, node):
     if not anno.hasanno(node, anno.Basic.DIRECTIVES):
       return gast.Dict([], [])
diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py
index 935e2cec4b8..f0681128698 100644
--- a/tensorflow/python/autograph/converters/control_flow_test.py
+++ b/tensorflow/python/autograph/converters/control_flow_test.py
@@ -189,9 +189,9 @@ class WhileStatementTest(ControlFlowTestBase):
         symbols={'TestClass': TestClass})
     with self.converted(
         test_fn, control_flow, {'TestClass': TestClass}) as result:
-      # TODO(b/128519776): Better error message.
-      with self.assertRaisesRegex(AttributeError, 'subattr'):
-        result.test_fn(constant_op.constant(0), constant_op.constant(5))
+      with self.assertRaisesRegex(
+          ValueError, "'tc.subattr' must be defined before the loop"):
+        result.test_fn(constant_op.constant(0), 0)
 
   def test_composite_state_slice_initialized_in_loop(self):
 
@@ -209,9 +209,9 @@ class WhileStatementTest(ControlFlowTestBase):
     self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
                                  {'subkey': 14})
     with self.converted(test_fn, control_flow, {}) as result:
-      # TODO(b/128519776): Better error message.
-      with self.assertRaisesRegex(KeyError, 'subkey'):
-        result.test_fn(constant_op.constant(0), constant_op.constant(5))
+      with self.assertRaisesRegex(
+          ValueError, r"'d\[k\]' must be defined before the loop"):
+        result.test_fn(constant_op.constant(0), 0)
 
   def test_composite_state_literal_slice_initialized_in_loop(self):
 
@@ -228,9 +228,9 @@ class WhileStatementTest(ControlFlowTestBase):
     self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
                                  {'subkey': 14})
     with self.converted(test_fn, control_flow, {}) as result:
-      # TODO(b/128519776): Better error message.
-      with self.assertRaisesRegex(KeyError, 'subkey'):
-        result.test_fn(constant_op.constant(0), constant_op.constant(5))
+      with self.assertRaisesRegex(
+          ValueError, r"'d\['subkey'\]' must be defined before the loop"):
+        result.test_fn(constant_op.constant(0), 0)
 
   def test_composite_state_slice_aliased_to_local(self):
 
@@ -245,7 +245,7 @@ class WhileStatementTest(ControlFlowTestBase):
     self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
                                  {'subkey': 11})
     with self.converted(test_fn, control_flow, {}) as result:
-      # TODO(b/128519776): Better error message.
+      # TODO(b/136999953): Better error message.
       # Note that this error happens at execution time.
       with self.assertRaises(errors.InaccessibleTensorError):
         graph_fn = def_function.function(result.test_fn, autograph=False)
@@ -671,11 +671,9 @@ class ForStatementTest(ControlFlowTestBase):
         symbols={'TestClass': TestClass})
     with self.converted(
         test_fn, control_flow, {'TestClass': TestClass}) as result:
-      # TODO(b/128519776): Better error message.
       with self.assertRaisesRegex(
-          AttributeError, '\'TestClass\' object has no attribute \'x\''):
-        result.test_fn(
-            constant_op.constant(list(range(5))), constant_op.constant(5))
+          ValueError, "'tc.x' must be defined before the loop"):
+        result.test_fn(constant_op.constant(list(range(5))), 0)
 
   def test_tuple_unpacking(self):
     def test_fn(x_list):
diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py
index 8ac4e1d8bb3..a42dcf326c3 100644
--- a/tensorflow/python/autograph/operators/__init__.py
+++ b/tensorflow/python/autograph/operators/__init__.py
@@ -62,5 +62,6 @@ from tensorflow.python.autograph.operators.slices import get_item
 from tensorflow.python.autograph.operators.slices import GetItemOpts
 from tensorflow.python.autograph.operators.slices import set_item
 from tensorflow.python.autograph.operators.variables import ld
+from tensorflow.python.autograph.operators.variables import ldu
 from tensorflow.python.autograph.operators.variables import Undefined
 from tensorflow.python.autograph.operators.variables import UndefinedReturnValue
diff --git a/tensorflow/python/autograph/operators/variables.py b/tensorflow/python/autograph/operators/variables.py
index 150f64e1758..c3bedc3fecf 100644
--- a/tensorflow/python/autograph/operators/variables.py
+++ b/tensorflow/python/autograph/operators/variables.py
@@ -26,6 +26,31 @@ def ld(v):
   return v
 
 
+def ldu(load_v, name):
+  """Load variable operator that returns Undefined when failing to evaluate.
+
+  Note: the name ("load or return undefined") is abbreviated to minimize
+  the amount of clutter in generated code.
+
+  This variant of `ld` is useful when loading symbols that may be undefined at
+  runtime, such as composite symbols, and whether they are defined or not cannot
+  be determined statically. For example `d['a']` is undefined when `d` is an
+  empty dict.
+
+  Args:
+    load_v: Lambda that executes the actual read.
+    name: Human-readable name of the symbol being read.
+  Returns:
+    Either the value of the symbol, or Undefined, if the symbol is not fully
+    defined.
+  """
+  try:
+    # TODO(mdan): Use locals()/globals() here.
+    return load_v()
+  except (KeyError, AttributeError, NameError):
+    return Undefined(name)
+
+
 class Undefined(object):
   """Represents an undefined symbol in Python.