Better error messages for invalid return values in TensorFlow conditionals and loops.

Have a specific error message when a return statement appears in one branch of a TensorFlow conditional and not the other. Also have a nicer error message when returning a value from a TensorFlow loop.

PiperOrigin-RevId: 247212194
This commit is contained in:
A. Unique TensorFlower 2019-05-08 07:53:44 -07:00 committed by TensorFlower Gardener
parent 0410cff073
commit c2d506cc04
4 changed files with 34 additions and 4 deletions

View File

@ -375,8 +375,10 @@ class ReturnStatementsTransformer(converter.Base):
if self.default_to_null_return:
template = """
do_return_var_name = False
retval_var_name = None
retval_var_name = ag__.UndefinedReturnValue()
body
if ag__.is_undefined_return(retval_var_name):
retval_var_name = None
return retval_var_name
"""
else:

View File

@ -72,4 +72,6 @@ from tensorflow.python.autograph.operators.slices import get_item
from tensorflow.python.autograph.operators.slices import GetItemOpts
from tensorflow.python.autograph.operators.slices import set_item
from tensorflow.python.autograph.operators.special_values import is_undefined
from tensorflow.python.autograph.operators.special_values import is_undefined_return
from tensorflow.python.autograph.operators.special_values import Undefined
from tensorflow.python.autograph.operators.special_values import UndefinedReturnValue

View File

@ -42,6 +42,7 @@ INEFFICIENT_UNROLL_MIN_OPS = 1
def _disallow_undefs_into_loop(*values):
"""Ensures that all values in the state are defined when entering a loop."""
undefined = tuple(filter(special_values.is_undefined, values))
if undefined:
raise ValueError(
@ -49,6 +50,14 @@ def _disallow_undefs_into_loop(*values):
' before the loop: {}'.format(
tuple(s.symbol_name for s in undefined)))
for value in values:
if special_values.is_undefined_return(value):
# Assumption: the loop will only capture the variable which tracks the
# return value if the loop contained a return statement.
# TODO(mdan): This should be checked at the place where return occurs.
raise ValueError(
'Return statements are not supported within a TensorFlow loop.')
def for_stmt(iter_, extra_test, body, init_state):
"""Functional form of a for statement.
@ -435,8 +444,8 @@ def if_stmt(cond, body, orelse, get_state, set_state):
def tf_if_stmt(cond, body, orelse, get_state, set_state):
"""Overload of if_stmt that stages a TF cond."""
body = _wrap_disallow_undefs_in_cond(body, branch_name='if')
orelse = _wrap_disallow_undefs_in_cond(orelse, branch_name='else')
body = _wrap_disallow_undefs_from_cond(body, branch_name='if')
orelse = _wrap_disallow_undefs_from_cond(orelse, branch_name='else')
body = _isolate_state(body, get_state, set_state)
orelse = _isolate_state(orelse, get_state, set_state)
@ -484,7 +493,7 @@ def _isolate_state(func, get_state, set_state):
return wrapper
def _wrap_disallow_undefs_in_cond(func, branch_name):
def _wrap_disallow_undefs_from_cond(func, branch_name):
"""Wraps conditional branch to disallow returning undefined symbols."""
def wrapper():
@ -503,6 +512,13 @@ def _wrap_disallow_undefs_in_cond(func, branch_name):
' statement.'.format(branch_name,
tuple(s.symbol_name for s in undefined)))
for result in results_tuple:
if special_values.is_undefined_return(result):
raise ValueError(
'A value must also be returned from the {} branch. If a value is '
'returned from one branch of a conditional a value must be '
'returned from all branches.'.format(branch_name))
return results
return wrapper

View File

@ -64,3 +64,13 @@ def is_undefined(value):
Boolean, whether the input value is undefined.
"""
return isinstance(value, Undefined)
class UndefinedReturnValue(object):
"""Represents a default return value from a function (None in Python)."""
pass
def is_undefined_return(value):
"""Checks whether `value` is the default return value."""
return isinstance(value, UndefinedReturnValue)