test fixes?

This commit is contained in:
ngc92 2020-04-22 14:23:52 +03:00
parent 7f389c2ac3
commit 31077fffee

View File

@ -22,6 +22,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import confusion_matrix from tensorflow.python.ops import confusion_matrix
@ -188,7 +189,7 @@ class ConfusionMatrixTest(test.TestCase):
def testLabelsTooLarge(self): def testLabelsTooLarge(self):
labels = np.asarray([1, 1, 0, 3, 5], dtype=np.int32) labels = np.asarray([1, 1, 0, 3, 5], dtype=np.int32)
predictions = np.asarray([2, 1, 0, 2, 2], dtype=np.int32) predictions = np.asarray([2, 1, 0, 2, 2], dtype=np.int32)
with self.assertRaisesOpError("`labels`.*x < y"): with self.assertRaisesWithPredicateMatch(errors_impl.InvalidArgumentError, "`labels`.*x < y"):
self._testConfMatrix( self._testConfMatrix(
labels=labels, predictions=predictions, num_classes=3, truth=None) labels=labels, predictions=predictions, num_classes=3, truth=None)
@ -203,7 +204,7 @@ class ConfusionMatrixTest(test.TestCase):
def testPredictionsTooLarge(self): def testPredictionsTooLarge(self):
labels = np.asarray([1, 1, 0, 2, 2], dtype=np.int32) labels = np.asarray([1, 1, 0, 2, 2], dtype=np.int32)
predictions = np.asarray([2, 1, 0, 3, 5], dtype=np.int32) predictions = np.asarray([2, 1, 0, 3, 5], dtype=np.int32)
with self.assertRaisesOpError("`predictions`.*x < y"): with self.assertRaisesWithPredicateMatch(errors_impl.InvalidArgumentError, "`predictions`.*x < y"):
self._testConfMatrix( self._testConfMatrix(
labels=labels, predictions=predictions, num_classes=3, truth=None) labels=labels, predictions=predictions, num_classes=3, truth=None)