Use float32 instead of float64 for confusion matrix computations to make it compatible with tpus.

PiperOrigin-RevId: 337635801
Change-Id: I2f9c282ca010168d19c316e1f0c6c86997567dec
This commit is contained in:
Abdullah Rashwan 2020-10-16 23:46:30 -07:00 committed by TensorFlower Gardener
parent 195b6a667e
commit 7a636941b3

View File

@ -2771,13 +2771,12 @@ class MeanIoU(Metric):
super(MeanIoU, self).__init__(name=name, dtype=dtype)
self.num_classes = num_classes
# Variable to accumulate the predictions in the confusion matrix. Setting
# the type to be `float64` as required by confusion_matrix_ops.
# Variable to accumulate the predictions in the confusion matrix.
self.total_cm = self.add_weight(
'total_confusion_matrix',
shape=(num_classes, num_classes),
initializer=init_ops.zeros_initializer,
dtype=dtypes.float64)
dtype=dtypes.float32)
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates the confusion matrix statistics.
@ -2814,7 +2813,7 @@ class MeanIoU(Metric):
y_pred,
self.num_classes,
weights=sample_weight,
dtype=dtypes.float64)
dtype=dtypes.float32)
return self.total_cm.assign_add(current_cm)
def result(self):