Add extra checks for shape and dtype for control flow conditionals.
PiperOrigin-RevId: 335096700 Change-Id: I742518d56648aa4d99bba374c824587f25c2c220
This commit is contained in:
parent
7ae3aed902
commit
99d1481fc6
@ -116,6 +116,33 @@ def _is_none_or_undef(value):
|
||||
or isinstance(value, variables.Undefined))
|
||||
|
||||
|
||||
def _verify_tf_condition(cond, tag):
|
||||
"""Ensures that the condition can be used in a TF control flow."""
|
||||
extra_hint = 'to check for None, use `is not None`'
|
||||
cond = ops.convert_to_tensor_v2(cond)
|
||||
|
||||
if cond.dtype != dtypes.bool:
|
||||
raise ValueError(
|
||||
'condition of {} expected to be `tf.bool` scalar, got {}'
|
||||
'; to use as boolean Tensor, use `tf.cast`'
|
||||
'; {}'.format(tag, cond, extra_hint))
|
||||
|
||||
if cond.shape is None or cond.shape.ndims is None:
|
||||
# TODO(mdan): Consider a explicit size check, if not too slow.
|
||||
cond = array_ops.reshape(cond, ())
|
||||
|
||||
elif cond.shape.ndims > 0:
|
||||
known_dims = [d for d in cond.shape.as_list() if d is not None]
|
||||
if np.prod(known_dims) > 1:
|
||||
raise ValueError(
|
||||
'condition of {} expected to be `tf.bool` scalar, got {}'
|
||||
'; {}'.format(tag, cond, extra_hint))
|
||||
else:
|
||||
cond = array_ops.reshape(cond, ())
|
||||
|
||||
return cond
|
||||
|
||||
|
||||
def _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars=None):
|
||||
"""Ensures that all values in the state are valid to use in a TF loop.
|
||||
|
||||
@ -1038,7 +1065,7 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
|
||||
loop_vars = loop_vars[1:]
|
||||
|
||||
set_state(loop_vars)
|
||||
return test()
|
||||
return _verify_tf_condition(test(), 'while loop')
|
||||
|
||||
def aug_body(*loop_vars):
|
||||
if require_one_iteration:
|
||||
@ -1141,6 +1168,8 @@ def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
|
||||
def _tf_if_stmt(
|
||||
cond, body, orelse, get_state, set_state, symbol_names, nouts):
|
||||
"""Overload of if_stmt that stages a TF cond."""
|
||||
cond = _verify_tf_condition(cond, 'if statement')
|
||||
|
||||
if not nouts:
|
||||
prev_get_state, prev_set_state = get_state, set_state
|
||||
# Control flow V1 wants at least one output.
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.autograph.utils import testing
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -46,6 +47,20 @@ from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _unranked_item(value):
|
||||
rand_rank = random_ops.random_uniform(
|
||||
shape=(), minval=3, maxval=4, dtype=dtypes.int32)
|
||||
rand_shape = array_ops.ones([rand_rank], dtype=dtypes.int32)
|
||||
return array_ops.fill(rand_shape, value)
|
||||
|
||||
|
||||
def _partial_shaped_bools():
|
||||
rand_vect = math_ops.range(
|
||||
random_ops.random_uniform(
|
||||
shape=(), minval=2, maxval=3, dtype=dtypes.int32))
|
||||
return array_ops.expand_dims_v2(rand_vect, 0) < 0
|
||||
|
||||
|
||||
class ForLoopTest(testing.AutoGraphTestCase):
|
||||
|
||||
def test_tensor(self):
|
||||
@ -871,6 +886,60 @@ class WhileLoopTest(testing.AutoGraphTestCase):
|
||||
with self.assertRaisesRegex(ValueError, r"'s'.* shape \(1,\) after"):
|
||||
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
|
||||
|
||||
def _fixed_while_loop(self, cond_fn):
|
||||
def test_():
|
||||
return cond_fn(s)
|
||||
|
||||
def body():
|
||||
nonlocal s
|
||||
s += 1
|
||||
|
||||
def set_state(loop_vars):
|
||||
nonlocal s
|
||||
s, = loop_vars
|
||||
|
||||
s = constant_op.constant(0)
|
||||
control_flow.while_stmt(
|
||||
test=test_,
|
||||
body=body,
|
||||
get_state=lambda: (s,),
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
return s
|
||||
|
||||
def _assertFixedLoopResult(self, cond, expected):
|
||||
def test_fn():
|
||||
return self._fixed_while_loop(cond)
|
||||
self.assertEqual(test_fn(), expected)
|
||||
|
||||
def test_tensor_legal_cond_scalar(self):
|
||||
self._assertFixedLoopResult(lambda s: constant_op.constant(False), 0)
|
||||
self._assertFixedLoopResult(lambda s: s < 2, 2)
|
||||
|
||||
def test_tensor_legal_cond_single_element_nd(self):
|
||||
self._assertFixedLoopResult(lambda s: constant_op.constant([[False]]), 0)
|
||||
self._assertFixedLoopResult(lambda s: _unranked_item(False), 0)
|
||||
|
||||
def _assertCondCheckFails(self, cond):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'condition of while loop expected to be `tf.bool`'):
|
||||
self._fixed_while_loop(cond)
|
||||
|
||||
def test_tensor_illegal_cond_not_bool(self):
|
||||
self._assertCondCheckFails(lambda s: constant_op.constant(1))
|
||||
self._assertCondCheckFails(lambda s: s)
|
||||
|
||||
def test_tensor_illegal_cond_not_single_element(self):
|
||||
self._assertCondCheckFails(lambda s: constant_op.constant([1, 2, 3]))
|
||||
self._assertCondCheckFails(lambda s: constant_op.constant([True, False]))
|
||||
|
||||
def test_tensor_illegal_cond_not_single_element_dynamic_shape(self):
|
||||
self._fixed_while_loop(lambda s: _partial_shaped_bools())
|
||||
# TODO(mdan): This error is quite bad. Measure the cost of an assertion.
|
||||
self.assertRaisesRuntime(
|
||||
errors_impl.InvalidArgumentError, 'requested shape has 1')
|
||||
|
||||
|
||||
class IfStmtTest(testing.AutoGraphTestCase):
|
||||
|
||||
@ -1065,6 +1134,62 @@ class IfStmtTest(testing.AutoGraphTestCase):
|
||||
TypeError, "'x' has dtype int32.*but.*float32"):
|
||||
self._basic_cond(lambda: 1, lambda: 1.0)
|
||||
|
||||
def _fixed_cond(self, cond_val):
|
||||
def body():
|
||||
nonlocal x
|
||||
x = 1
|
||||
|
||||
def orelse():
|
||||
nonlocal x
|
||||
x = -1
|
||||
|
||||
def set_state(cond_vars):
|
||||
nonlocal x
|
||||
x, = cond_vars
|
||||
|
||||
x = 0
|
||||
control_flow.if_stmt(
|
||||
cond=cond_val,
|
||||
body=body,
|
||||
orelse=orelse,
|
||||
get_state=lambda: (x,),
|
||||
set_state=set_state,
|
||||
symbol_names=('x',),
|
||||
nouts=1)
|
||||
return x
|
||||
|
||||
def _assertFixedCondResult(self, cond, expected):
|
||||
def test_fn():
|
||||
return self._fixed_cond(cond)
|
||||
self.assertEqual(test_fn(), expected)
|
||||
|
||||
def test_tensor_legal_cond_scalar(self):
|
||||
self._assertFixedCondResult(constant_op.constant(True), 1)
|
||||
self._assertFixedCondResult(constant_op.constant(False), -1)
|
||||
|
||||
def test_tensor_legal_cond_single_element_nd(self):
|
||||
self._assertFixedCondResult(constant_op.constant([[True]]), 1)
|
||||
self._assertFixedCondResult(constant_op.constant([[False]]), -1)
|
||||
self._assertFixedCondResult(_unranked_item(True), 1)
|
||||
self._assertFixedCondResult(_unranked_item(False), -1)
|
||||
|
||||
def _assertCondCheckFails(self, cond):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'condition of if statement expected to be `tf.bool`'):
|
||||
self._fixed_cond(cond)
|
||||
|
||||
def test_tensor_illegal_cond_not_bool(self):
|
||||
self._assertCondCheckFails(constant_op.constant(1))
|
||||
|
||||
def test_tensor_illegal_cond_not_single_element(self):
|
||||
self._assertCondCheckFails(constant_op.constant([1, 2, 3]))
|
||||
self._assertCondCheckFails(constant_op.constant([True, False]))
|
||||
|
||||
def test_tensor_illegal_cond_not_single_element_dynamic_shape(self):
|
||||
self._fixed_cond(_partial_shaped_bools())
|
||||
# TODO(mdan): This error is quite bad. Measure the cost of an assertion.
|
||||
self.assertRaisesRuntime(
|
||||
errors_impl.InvalidArgumentError, 'requested shape has 1')
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
|
||||
@ -81,18 +82,29 @@ class AutoGraphTestCase(test.TestCase):
|
||||
@def_function.function(autograph=False) # Testing autograph itself.
|
||||
def fn_wrapper():
|
||||
self.assertions = []
|
||||
self.raises_cm = None
|
||||
self.graph_assertions = []
|
||||
self.trace_log = []
|
||||
fn()
|
||||
targets = [args for _, args in self.assertions]
|
||||
return targets
|
||||
|
||||
tensors = fn_wrapper()
|
||||
try:
|
||||
tensors = fn_wrapper()
|
||||
|
||||
for assertion in self.graph_assertions:
|
||||
assertion(fn_wrapper.get_concrete_function().graph)
|
||||
for assertion in self.graph_assertions:
|
||||
assertion(fn_wrapper.get_concrete_function().graph)
|
||||
|
||||
actuals = self.evaluate(tensors)
|
||||
|
||||
except: # pylint:disable=bare-except
|
||||
if self.raises_cm is not None:
|
||||
# Note: Yes, the Raises and function contexts cross.
|
||||
self.raises_cm.__exit__(*sys.exc_info())
|
||||
return
|
||||
else:
|
||||
raise
|
||||
|
||||
actuals = self.evaluate(tensors)
|
||||
for (assertion, _), values in zip(self.assertions, actuals):
|
||||
assertion(*values)
|
||||
|
||||
@ -109,6 +121,7 @@ class AutoGraphTestCase(test.TestCase):
|
||||
super().setUp()
|
||||
self.variables = {}
|
||||
self.trace_log = []
|
||||
self.raises_cm = None
|
||||
op_callbacks.add_op_callback(self._op_callback)
|
||||
|
||||
def tearDown(self):
|
||||
@ -145,3 +158,9 @@ class AutoGraphTestCase(test.TestCase):
|
||||
|
||||
def assertDictEqual(self, *args):
|
||||
self.assertions.append((super().assertDictEqual, list(args)))
|
||||
|
||||
def assertRaisesRuntime(self, *args):
|
||||
if self.raises_cm is not None:
|
||||
raise ValueError('cannot use more than one assertRaisesRuntime in a test')
|
||||
self.raises_cm = self.assertRaisesRegex(*args)
|
||||
self.raises_cm.__enter__()
|
||||
|
Loading…
Reference in New Issue
Block a user