Add extra checks for shape and dtype for control flow conditionals.

PiperOrigin-RevId: 335096700
Change-Id: I742518d56648aa4d99bba374c824587f25c2c220
This commit is contained in:
Dan Moldovan 2020-10-02 14:04:41 -07:00 committed by TensorFlower Gardener
parent 7ae3aed902
commit 99d1481fc6
3 changed files with 178 additions and 5 deletions

View File

@ -116,6 +116,33 @@ def _is_none_or_undef(value):
or isinstance(value, variables.Undefined))
def _verify_tf_condition(cond, tag):
"""Ensures that the condition can be used in a TF control flow."""
extra_hint = 'to check for None, use `is not None`'
cond = ops.convert_to_tensor_v2(cond)
if cond.dtype != dtypes.bool:
raise ValueError(
'condition of {} expected to be `tf.bool` scalar, got {}'
'; to use as boolean Tensor, use `tf.cast`'
'; {}'.format(tag, cond, extra_hint))
if cond.shape is None or cond.shape.ndims is None:
# TODO(mdan): Consider a explicit size check, if not too slow.
cond = array_ops.reshape(cond, ())
elif cond.shape.ndims > 0:
known_dims = [d for d in cond.shape.as_list() if d is not None]
if np.prod(known_dims) > 1:
raise ValueError(
'condition of {} expected to be `tf.bool` scalar, got {}'
'; {}'.format(tag, cond, extra_hint))
else:
cond = array_ops.reshape(cond, ())
return cond
def _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars=None):
"""Ensures that all values in the state are valid to use in a TF loop.
@ -1038,7 +1065,7 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
loop_vars = loop_vars[1:]
set_state(loop_vars)
return test()
return _verify_tf_condition(test(), 'while loop')
def aug_body(*loop_vars):
if require_one_iteration:
@ -1141,6 +1168,8 @@ def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
def _tf_if_stmt(
cond, body, orelse, get_state, set_state, symbol_names, nouts):
"""Overload of if_stmt that stages a TF cond."""
cond = _verify_tf_condition(cond, 'if statement')
if not nouts:
prev_get_state, prev_set_state = get_state, set_state
# Control flow V1 wants at least one output.

View File

@ -35,6 +35,7 @@ from tensorflow.python.autograph.utils import testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@ -46,6 +47,20 @@ from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
def _unranked_item(value):
rand_rank = random_ops.random_uniform(
shape=(), minval=3, maxval=4, dtype=dtypes.int32)
rand_shape = array_ops.ones([rand_rank], dtype=dtypes.int32)
return array_ops.fill(rand_shape, value)
def _partial_shaped_bools():
rand_vect = math_ops.range(
random_ops.random_uniform(
shape=(), minval=2, maxval=3, dtype=dtypes.int32))
return array_ops.expand_dims_v2(rand_vect, 0) < 0
class ForLoopTest(testing.AutoGraphTestCase):
def test_tensor(self):
@ -871,6 +886,60 @@ class WhileLoopTest(testing.AutoGraphTestCase):
with self.assertRaisesRegex(ValueError, r"'s'.* shape \(1,\) after"):
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
def _fixed_while_loop(self, cond_fn):
def test_():
return cond_fn(s)
def body():
nonlocal s
s += 1
def set_state(loop_vars):
nonlocal s
s, = loop_vars
s = constant_op.constant(0)
control_flow.while_stmt(
test=test_,
body=body,
get_state=lambda: (s,),
set_state=set_state,
symbol_names=('s',),
opts={})
return s
def _assertFixedLoopResult(self, cond, expected):
def test_fn():
return self._fixed_while_loop(cond)
self.assertEqual(test_fn(), expected)
def test_tensor_legal_cond_scalar(self):
self._assertFixedLoopResult(lambda s: constant_op.constant(False), 0)
self._assertFixedLoopResult(lambda s: s < 2, 2)
def test_tensor_legal_cond_single_element_nd(self):
self._assertFixedLoopResult(lambda s: constant_op.constant([[False]]), 0)
self._assertFixedLoopResult(lambda s: _unranked_item(False), 0)
def _assertCondCheckFails(self, cond):
with self.assertRaisesRegex(
ValueError, 'condition of while loop expected to be `tf.bool`'):
self._fixed_while_loop(cond)
def test_tensor_illegal_cond_not_bool(self):
self._assertCondCheckFails(lambda s: constant_op.constant(1))
self._assertCondCheckFails(lambda s: s)
def test_tensor_illegal_cond_not_single_element(self):
self._assertCondCheckFails(lambda s: constant_op.constant([1, 2, 3]))
self._assertCondCheckFails(lambda s: constant_op.constant([True, False]))
def test_tensor_illegal_cond_not_single_element_dynamic_shape(self):
self._fixed_while_loop(lambda s: _partial_shaped_bools())
# TODO(mdan): This error is quite bad. Measure the cost of an assertion.
self.assertRaisesRuntime(
errors_impl.InvalidArgumentError, 'requested shape has 1')
class IfStmtTest(testing.AutoGraphTestCase):
@ -1065,6 +1134,62 @@ class IfStmtTest(testing.AutoGraphTestCase):
TypeError, "'x' has dtype int32.*but.*float32"):
self._basic_cond(lambda: 1, lambda: 1.0)
def _fixed_cond(self, cond_val):
def body():
nonlocal x
x = 1
def orelse():
nonlocal x
x = -1
def set_state(cond_vars):
nonlocal x
x, = cond_vars
x = 0
control_flow.if_stmt(
cond=cond_val,
body=body,
orelse=orelse,
get_state=lambda: (x,),
set_state=set_state,
symbol_names=('x',),
nouts=1)
return x
def _assertFixedCondResult(self, cond, expected):
def test_fn():
return self._fixed_cond(cond)
self.assertEqual(test_fn(), expected)
def test_tensor_legal_cond_scalar(self):
self._assertFixedCondResult(constant_op.constant(True), 1)
self._assertFixedCondResult(constant_op.constant(False), -1)
def test_tensor_legal_cond_single_element_nd(self):
self._assertFixedCondResult(constant_op.constant([[True]]), 1)
self._assertFixedCondResult(constant_op.constant([[False]]), -1)
self._assertFixedCondResult(_unranked_item(True), 1)
self._assertFixedCondResult(_unranked_item(False), -1)
def _assertCondCheckFails(self, cond):
with self.assertRaisesRegex(
ValueError, 'condition of if statement expected to be `tf.bool`'):
self._fixed_cond(cond)
def test_tensor_illegal_cond_not_bool(self):
self._assertCondCheckFails(constant_op.constant(1))
def test_tensor_illegal_cond_not_single_element(self):
self._assertCondCheckFails(constant_op.constant([1, 2, 3]))
self._assertCondCheckFails(constant_op.constant([True, False]))
def test_tensor_illegal_cond_not_single_element_dynamic_shape(self):
self._fixed_cond(_partial_shaped_bools())
# TODO(mdan): This error is quite bad. Measure the cost of an assertion.
self.assertRaisesRuntime(
errors_impl.InvalidArgumentError, 'requested shape has 1')
if __name__ == '__main__':
test.main()

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import re
import sys
import types
import unittest
@ -81,18 +82,29 @@ class AutoGraphTestCase(test.TestCase):
@def_function.function(autograph=False) # Testing autograph itself.
def fn_wrapper():
self.assertions = []
self.raises_cm = None
self.graph_assertions = []
self.trace_log = []
fn()
targets = [args for _, args in self.assertions]
return targets
tensors = fn_wrapper()
try:
tensors = fn_wrapper()
for assertion in self.graph_assertions:
assertion(fn_wrapper.get_concrete_function().graph)
for assertion in self.graph_assertions:
assertion(fn_wrapper.get_concrete_function().graph)
actuals = self.evaluate(tensors)
except: # pylint:disable=bare-except
if self.raises_cm is not None:
# Note: Yes, the Raises and function contexts cross.
self.raises_cm.__exit__(*sys.exc_info())
return
else:
raise
actuals = self.evaluate(tensors)
for (assertion, _), values in zip(self.assertions, actuals):
assertion(*values)
@ -109,6 +121,7 @@ class AutoGraphTestCase(test.TestCase):
super().setUp()
self.variables = {}
self.trace_log = []
self.raises_cm = None
op_callbacks.add_op_callback(self._op_callback)
def tearDown(self):
@ -145,3 +158,9 @@ class AutoGraphTestCase(test.TestCase):
def assertDictEqual(self, *args):
self.assertions.append((super().assertDictEqual, list(args)))
def assertRaisesRuntime(self, *args):
if self.raises_cm is not None:
raise ValueError('cannot use more than one assertRaisesRuntime in a test')
self.raises_cm = self.assertRaisesRegex(*args)
self.raises_cm.__enter__()