Allow ZeroDivisionTest to run eagerly
PiperOrigin-RevId: 316486715 Change-Id: I5a705b27562c57760ed8efeae52c67a91539ee7c
This commit is contained in:
parent
05bec441e3
commit
a208b5cd29
|
@ -20,26 +20,25 @@ from __future__ import print_function
|
|||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ZeroDivisionTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testZeros(self):
|
||||
with test_util.use_gpu():
|
||||
for dtype in dtypes.uint8, dtypes.int16, dtypes.int32, dtypes.int64:
|
||||
zero = constant_op.constant(0, dtype=dtype)
|
||||
one = constant_op.constant(1, dtype=dtype)
|
||||
bads = [one // zero]
|
||||
bads = [lambda x, y: x // y]
|
||||
if dtype in (dtypes.int32, dtypes.int64):
|
||||
bads.append(one % zero)
|
||||
bads.append(lambda x, y: x % y)
|
||||
for bad in bads:
|
||||
try:
|
||||
result = self.evaluate(bad)
|
||||
except errors_impl.OpError as e:
|
||||
result = self.evaluate(bad(one, zero))
|
||||
except (errors.OpError, errors.InvalidArgumentError) as e:
|
||||
# Ideally, we'd get a nice exception. In theory, this should only
|
||||
# happen on CPU, but 32 bit integer GPU division is actually on
|
||||
# CPU due to a placer bug.
|
||||
|
|
Loading…
Reference in New Issue