Better error messages for invalid return values in TensorFlow conditionals and loops.
Have a specific error message when a return statement appears in one branch of a TensorFlow conditional and not the other. Also have a nicer error message when returning a value from a TensorFlow loop. PiperOrigin-RevId: 247212194
This commit is contained in:
parent
0410cff073
commit
c2d506cc04
@ -375,8 +375,10 @@ class ReturnStatementsTransformer(converter.Base):
|
|||||||
if self.default_to_null_return:
|
if self.default_to_null_return:
|
||||||
template = """
|
template = """
|
||||||
do_return_var_name = False
|
do_return_var_name = False
|
||||||
retval_var_name = None
|
retval_var_name = ag__.UndefinedReturnValue()
|
||||||
body
|
body
|
||||||
|
if ag__.is_undefined_return(retval_var_name):
|
||||||
|
retval_var_name = None
|
||||||
return retval_var_name
|
return retval_var_name
|
||||||
"""
|
"""
|
||||||
else:
|
else:
|
||||||
|
@ -72,4 +72,6 @@ from tensorflow.python.autograph.operators.slices import get_item
|
|||||||
from tensorflow.python.autograph.operators.slices import GetItemOpts
|
from tensorflow.python.autograph.operators.slices import GetItemOpts
|
||||||
from tensorflow.python.autograph.operators.slices import set_item
|
from tensorflow.python.autograph.operators.slices import set_item
|
||||||
from tensorflow.python.autograph.operators.special_values import is_undefined
|
from tensorflow.python.autograph.operators.special_values import is_undefined
|
||||||
|
from tensorflow.python.autograph.operators.special_values import is_undefined_return
|
||||||
from tensorflow.python.autograph.operators.special_values import Undefined
|
from tensorflow.python.autograph.operators.special_values import Undefined
|
||||||
|
from tensorflow.python.autograph.operators.special_values import UndefinedReturnValue
|
||||||
|
@ -42,6 +42,7 @@ INEFFICIENT_UNROLL_MIN_OPS = 1
|
|||||||
|
|
||||||
|
|
||||||
def _disallow_undefs_into_loop(*values):
|
def _disallow_undefs_into_loop(*values):
|
||||||
|
"""Ensures that all values in the state are defined when entering a loop."""
|
||||||
undefined = tuple(filter(special_values.is_undefined, values))
|
undefined = tuple(filter(special_values.is_undefined, values))
|
||||||
if undefined:
|
if undefined:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -49,6 +50,14 @@ def _disallow_undefs_into_loop(*values):
|
|||||||
' before the loop: {}'.format(
|
' before the loop: {}'.format(
|
||||||
tuple(s.symbol_name for s in undefined)))
|
tuple(s.symbol_name for s in undefined)))
|
||||||
|
|
||||||
|
for value in values:
|
||||||
|
if special_values.is_undefined_return(value):
|
||||||
|
# Assumption: the loop will only capture the variable which tracks the
|
||||||
|
# return value if the loop contained a return statement.
|
||||||
|
# TODO(mdan): This should be checked at the place where return occurs.
|
||||||
|
raise ValueError(
|
||||||
|
'Return statements are not supported within a TensorFlow loop.')
|
||||||
|
|
||||||
|
|
||||||
def for_stmt(iter_, extra_test, body, init_state):
|
def for_stmt(iter_, extra_test, body, init_state):
|
||||||
"""Functional form of a for statement.
|
"""Functional form of a for statement.
|
||||||
@ -435,8 +444,8 @@ def if_stmt(cond, body, orelse, get_state, set_state):
|
|||||||
|
|
||||||
def tf_if_stmt(cond, body, orelse, get_state, set_state):
|
def tf_if_stmt(cond, body, orelse, get_state, set_state):
|
||||||
"""Overload of if_stmt that stages a TF cond."""
|
"""Overload of if_stmt that stages a TF cond."""
|
||||||
body = _wrap_disallow_undefs_in_cond(body, branch_name='if')
|
body = _wrap_disallow_undefs_from_cond(body, branch_name='if')
|
||||||
orelse = _wrap_disallow_undefs_in_cond(orelse, branch_name='else')
|
orelse = _wrap_disallow_undefs_from_cond(orelse, branch_name='else')
|
||||||
body = _isolate_state(body, get_state, set_state)
|
body = _isolate_state(body, get_state, set_state)
|
||||||
orelse = _isolate_state(orelse, get_state, set_state)
|
orelse = _isolate_state(orelse, get_state, set_state)
|
||||||
|
|
||||||
@ -484,7 +493,7 @@ def _isolate_state(func, get_state, set_state):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def _wrap_disallow_undefs_in_cond(func, branch_name):
|
def _wrap_disallow_undefs_from_cond(func, branch_name):
|
||||||
"""Wraps conditional branch to disallow returning undefined symbols."""
|
"""Wraps conditional branch to disallow returning undefined symbols."""
|
||||||
|
|
||||||
def wrapper():
|
def wrapper():
|
||||||
@ -503,6 +512,13 @@ def _wrap_disallow_undefs_in_cond(func, branch_name):
|
|||||||
' statement.'.format(branch_name,
|
' statement.'.format(branch_name,
|
||||||
tuple(s.symbol_name for s in undefined)))
|
tuple(s.symbol_name for s in undefined)))
|
||||||
|
|
||||||
|
for result in results_tuple:
|
||||||
|
if special_values.is_undefined_return(result):
|
||||||
|
raise ValueError(
|
||||||
|
'A value must also be returned from the {} branch. If a value is '
|
||||||
|
'returned from one branch of a conditional a value must be '
|
||||||
|
'returned from all branches.'.format(branch_name))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
@ -64,3 +64,13 @@ def is_undefined(value):
|
|||||||
Boolean, whether the input value is undefined.
|
Boolean, whether the input value is undefined.
|
||||||
"""
|
"""
|
||||||
return isinstance(value, Undefined)
|
return isinstance(value, Undefined)
|
||||||
|
|
||||||
|
|
||||||
|
class UndefinedReturnValue(object):
|
||||||
|
"""Represents a default return value from a function (None in Python)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def is_undefined_return(value):
|
||||||
|
"""Checks whether `value` is the default return value."""
|
||||||
|
return isinstance(value, UndefinedReturnValue)
|
||||||
|
Loading…
Reference in New Issue
Block a user