Merge pull request #27157 from gshashank84:patch-3

PiperOrigin-RevId: 246165383
This commit is contained in:
TensorFlower Gardener 2019-05-01 11:40:20 -07:00
commit 6e95b93346
2 changed files with 6 additions and 2 deletions

View File

@ -2290,6 +2290,10 @@ class MeanIoU(Metric):
Returns:
Update op.
"""
y_true = math_ops.cast(y_true, self._dtype)
y_pred = math_ops.cast(y_pred, self._dtype)
# Flatten the input if its rank > 1.
if y_pred.shape.ndims > 1:
y_pred = array_ops.reshape(y_pred, [-1])

View File

@ -1141,8 +1141,8 @@ class MeanIoUTest(test.TestCase):
self.assertEqual(m_obj2.num_classes, 2)
def test_unweighted(self):
y_pred = constant_op.constant([0, 1, 0, 1], dtype=dtypes.float32)
y_true = constant_op.constant([0, 0, 1, 1])
y_pred = [0, 1, 0, 1]
y_true = [0, 0, 1, 1]
m_obj = metrics.MeanIoU(num_classes=2)
self.evaluate(variables.variables_initializer(m_obj.variables))