Merge pull request #27157 from gshashank84:patch-3
PiperOrigin-RevId: 246165383
This commit is contained in:
commit
6e95b93346
@ -2290,6 +2290,10 @@ class MeanIoU(Metric):
|
||||
Returns:
|
||||
Update op.
|
||||
"""
|
||||
|
||||
y_true = math_ops.cast(y_true, self._dtype)
|
||||
y_pred = math_ops.cast(y_pred, self._dtype)
|
||||
|
||||
# Flatten the input if its rank > 1.
|
||||
if y_pred.shape.ndims > 1:
|
||||
y_pred = array_ops.reshape(y_pred, [-1])
|
||||
|
@ -1141,8 +1141,8 @@ class MeanIoUTest(test.TestCase):
|
||||
self.assertEqual(m_obj2.num_classes, 2)
|
||||
|
||||
def test_unweighted(self):
|
||||
y_pred = constant_op.constant([0, 1, 0, 1], dtype=dtypes.float32)
|
||||
y_true = constant_op.constant([0, 0, 1, 1])
|
||||
y_pred = [0, 1, 0, 1]
|
||||
y_true = [0, 0, 1, 1]
|
||||
|
||||
m_obj = metrics.MeanIoU(num_classes=2)
|
||||
self.evaluate(variables.variables_initializer(m_obj.variables))
|
||||
|
Loading…
x
Reference in New Issue
Block a user