Automated g4 rollback of changelist 167604306
PiperOrigin-RevId: 167639833
This commit is contained in:
parent
eaaa0b9385
commit
0ab137cd8b
@ -1496,15 +1496,6 @@ class StreamingAUCTest(test.TestCase):
|
||||
for _ in range(10):
|
||||
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
|
||||
|
||||
def testPredictionsOutOfRange(self):
|
||||
with self.test_session() as sess:
|
||||
predictions = constant_op.constant(
|
||||
[1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32)
|
||||
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
|
||||
_, update_op = metrics.streaming_auc(predictions, labels)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
self.assertRaises(errors_impl.InvalidArgumentError, update_op.eval)
|
||||
|
||||
def testAllCorrect(self):
|
||||
self.allCorrectAsExpected('ROC')
|
||||
|
||||
|
@ -463,16 +463,10 @@ def _confusion_matrix_at_thresholds(
|
||||
if include not in all_includes:
|
||||
raise ValueError('Invaild key: %s.' % include)
|
||||
|
||||
with ops.control_dependencies([
|
||||
check_ops.assert_greater_equal(
|
||||
predictions, 0.0, message='predictions must be in [0, 1]'),
|
||||
check_ops.assert_less_equal(
|
||||
predictions, 1.0, message='predictions must be in [0, 1]')
|
||||
]):
|
||||
predictions, labels, weights = _remove_squeezable_dimensions(
|
||||
predictions=math_ops.to_float(predictions),
|
||||
labels=math_ops.cast(labels, dtype=dtypes.bool),
|
||||
weights=weights)
|
||||
predictions, labels, weights = _remove_squeezable_dimensions(
|
||||
predictions=math_ops.to_float(predictions),
|
||||
labels=math_ops.cast(labels, dtype=dtypes.bool),
|
||||
weights=weights)
|
||||
|
||||
num_thresholds = len(thresholds)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user