From c2d506cc04119a397a4118b37b316f0251fae872 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 8 May 2019 07:53:44 -0700 Subject: [PATCH] Better error messages for invalid return values in TensorFlow conditionals and loops. Have a specific error message when a return statement appears in one branch of a TensorFlow conditional and not the other. Also have a nicer error message when returning a value from a TensorFlow loop. PiperOrigin-RevId: 247212194 --- .../autograph/converters/return_statements.py | 4 +++- .../python/autograph/operators/__init__.py | 2 ++ .../autograph/operators/control_flow.py | 22 ++++++++++++++++--- .../autograph/operators/special_values.py | 10 +++++++++ 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py index 3173e676e5d..a53206c867d 100644 --- a/tensorflow/python/autograph/converters/return_statements.py +++ b/tensorflow/python/autograph/converters/return_statements.py @@ -375,8 +375,10 @@ class ReturnStatementsTransformer(converter.Base): if self.default_to_null_return: template = """ do_return_var_name = False - retval_var_name = None + retval_var_name = ag__.UndefinedReturnValue() body + if ag__.is_undefined_return(retval_var_name): + retval_var_name = None return retval_var_name """ else: diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py index 5b3f45de056..bbc684eaf2b 100644 --- a/tensorflow/python/autograph/operators/__init__.py +++ b/tensorflow/python/autograph/operators/__init__.py @@ -72,4 +72,6 @@ from tensorflow.python.autograph.operators.slices import get_item from tensorflow.python.autograph.operators.slices import GetItemOpts from tensorflow.python.autograph.operators.slices import set_item from tensorflow.python.autograph.operators.special_values import is_undefined +from tensorflow.python.autograph.operators.special_values import is_undefined_return from tensorflow.python.autograph.operators.special_values import Undefined +from tensorflow.python.autograph.operators.special_values import UndefinedReturnValue diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index d1428bb524a..5575b4c1911 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -42,6 +42,7 @@ INEFFICIENT_UNROLL_MIN_OPS = 1 def _disallow_undefs_into_loop(*values): + """Ensures that all values in the state are defined when entering a loop.""" undefined = tuple(filter(special_values.is_undefined, values)) if undefined: raise ValueError( @@ -49,6 +50,14 @@ def _disallow_undefs_into_loop(*values): ' before the loop: {}'.format( tuple(s.symbol_name for s in undefined))) + for value in values: + if special_values.is_undefined_return(value): + # Assumption: the loop will only capture the variable which tracks the + # return value if the loop contained a return statement. + # TODO(mdan): This should be checked at the place where return occurs. + raise ValueError( + 'Return statements are not supported within a TensorFlow loop.') + def for_stmt(iter_, extra_test, body, init_state): """Functional form of a for statement. @@ -435,8 +444,8 @@ def if_stmt(cond, body, orelse, get_state, set_state): def tf_if_stmt(cond, body, orelse, get_state, set_state): """Overload of if_stmt that stages a TF cond.""" - body = _wrap_disallow_undefs_in_cond(body, branch_name='if') - orelse = _wrap_disallow_undefs_in_cond(orelse, branch_name='else') + body = _wrap_disallow_undefs_from_cond(body, branch_name='if') + orelse = _wrap_disallow_undefs_from_cond(orelse, branch_name='else') body = _isolate_state(body, get_state, set_state) orelse = _isolate_state(orelse, get_state, set_state) @@ -484,7 +493,7 @@ def _isolate_state(func, get_state, set_state): return wrapper -def _wrap_disallow_undefs_in_cond(func, branch_name): +def _wrap_disallow_undefs_from_cond(func, branch_name): """Wraps conditional branch to disallow returning undefined symbols.""" def wrapper(): @@ -503,6 +512,13 @@ def _wrap_disallow_undefs_in_cond(func, branch_name): ' statement.'.format(branch_name, tuple(s.symbol_name for s in undefined))) + for result in results_tuple: + if special_values.is_undefined_return(result): + raise ValueError( + 'A value must also be returned from the {} branch. If a value is ' + 'returned from one branch of a conditional a value must be ' + 'returned from all branches.'.format(branch_name)) + return results return wrapper diff --git a/tensorflow/python/autograph/operators/special_values.py b/tensorflow/python/autograph/operators/special_values.py index 13d846fc7cf..a41f516e550 100644 --- a/tensorflow/python/autograph/operators/special_values.py +++ b/tensorflow/python/autograph/operators/special_values.py @@ -64,3 +64,13 @@ def is_undefined(value): Boolean, whether the input value is undefined. """ return isinstance(value, Undefined) + + +class UndefinedReturnValue(object): + """Represents a default return value from a function (None in Python).""" + pass + + +def is_undefined_return(value): + """Checks whether `value` is the default return value.""" + return isinstance(value, UndefinedReturnValue)