Better error messages for invalid return values in TensorFlow conditionals and loops.
Have a specific error message when a return statement appears in one branch of a TensorFlow conditional and not the other. Also have a nicer error message when returning a value from a TensorFlow loop. PiperOrigin-RevId: 247212194
This commit is contained in:
parent
0410cff073
commit
c2d506cc04
@ -375,8 +375,10 @@ class ReturnStatementsTransformer(converter.Base):
|
||||
if self.default_to_null_return:
|
||||
template = """
|
||||
do_return_var_name = False
|
||||
retval_var_name = None
|
||||
retval_var_name = ag__.UndefinedReturnValue()
|
||||
body
|
||||
if ag__.is_undefined_return(retval_var_name):
|
||||
retval_var_name = None
|
||||
return retval_var_name
|
||||
"""
|
||||
else:
|
||||
|
@ -72,4 +72,6 @@ from tensorflow.python.autograph.operators.slices import get_item
|
||||
from tensorflow.python.autograph.operators.slices import GetItemOpts
|
||||
from tensorflow.python.autograph.operators.slices import set_item
|
||||
from tensorflow.python.autograph.operators.special_values import is_undefined
|
||||
from tensorflow.python.autograph.operators.special_values import is_undefined_return
|
||||
from tensorflow.python.autograph.operators.special_values import Undefined
|
||||
from tensorflow.python.autograph.operators.special_values import UndefinedReturnValue
|
||||
|
@ -42,6 +42,7 @@ INEFFICIENT_UNROLL_MIN_OPS = 1
|
||||
|
||||
|
||||
def _disallow_undefs_into_loop(*values):
|
||||
"""Ensures that all values in the state are defined when entering a loop."""
|
||||
undefined = tuple(filter(special_values.is_undefined, values))
|
||||
if undefined:
|
||||
raise ValueError(
|
||||
@ -49,6 +50,14 @@ def _disallow_undefs_into_loop(*values):
|
||||
' before the loop: {}'.format(
|
||||
tuple(s.symbol_name for s in undefined)))
|
||||
|
||||
for value in values:
|
||||
if special_values.is_undefined_return(value):
|
||||
# Assumption: the loop will only capture the variable which tracks the
|
||||
# return value if the loop contained a return statement.
|
||||
# TODO(mdan): This should be checked at the place where return occurs.
|
||||
raise ValueError(
|
||||
'Return statements are not supported within a TensorFlow loop.')
|
||||
|
||||
|
||||
def for_stmt(iter_, extra_test, body, init_state):
|
||||
"""Functional form of a for statement.
|
||||
@ -435,8 +444,8 @@ def if_stmt(cond, body, orelse, get_state, set_state):
|
||||
|
||||
def tf_if_stmt(cond, body, orelse, get_state, set_state):
|
||||
"""Overload of if_stmt that stages a TF cond."""
|
||||
body = _wrap_disallow_undefs_in_cond(body, branch_name='if')
|
||||
orelse = _wrap_disallow_undefs_in_cond(orelse, branch_name='else')
|
||||
body = _wrap_disallow_undefs_from_cond(body, branch_name='if')
|
||||
orelse = _wrap_disallow_undefs_from_cond(orelse, branch_name='else')
|
||||
body = _isolate_state(body, get_state, set_state)
|
||||
orelse = _isolate_state(orelse, get_state, set_state)
|
||||
|
||||
@ -484,7 +493,7 @@ def _isolate_state(func, get_state, set_state):
|
||||
return wrapper
|
||||
|
||||
|
||||
def _wrap_disallow_undefs_in_cond(func, branch_name):
|
||||
def _wrap_disallow_undefs_from_cond(func, branch_name):
|
||||
"""Wraps conditional branch to disallow returning undefined symbols."""
|
||||
|
||||
def wrapper():
|
||||
@ -503,6 +512,13 @@ def _wrap_disallow_undefs_in_cond(func, branch_name):
|
||||
' statement.'.format(branch_name,
|
||||
tuple(s.symbol_name for s in undefined)))
|
||||
|
||||
for result in results_tuple:
|
||||
if special_values.is_undefined_return(result):
|
||||
raise ValueError(
|
||||
'A value must also be returned from the {} branch. If a value is '
|
||||
'returned from one branch of a conditional a value must be '
|
||||
'returned from all branches.'.format(branch_name))
|
||||
|
||||
return results
|
||||
|
||||
return wrapper
|
||||
|
@ -64,3 +64,13 @@ def is_undefined(value):
|
||||
Boolean, whether the input value is undefined.
|
||||
"""
|
||||
return isinstance(value, Undefined)
|
||||
|
||||
|
||||
class UndefinedReturnValue(object):
|
||||
"""Represents a default return value from a function (None in Python)."""
|
||||
pass
|
||||
|
||||
|
||||
def is_undefined_return(value):
|
||||
"""Checks whether `value` is the default return value."""
|
||||
return isinstance(value, UndefinedReturnValue)
|
||||
|
Loading…
Reference in New Issue
Block a user