Use float32 instead of float64 for confusion matrix computations to make it compatible with tpus.

PiperOrigin-RevId: 337636936
Change-Id: I06295e65d089d35bff638cd7502f6975a15dfc40
This commit is contained in:
A. Unique TensorFlower 2020-10-17 00:07:42 -07:00 committed by TensorFlower Gardener
parent 7a636941b3
commit 7529bc18e8

View File

@ -2771,12 +2771,13 @@ class MeanIoU(Metric):
super(MeanIoU, self).__init__(name=name, dtype=dtype)
self.num_classes = num_classes
# Variable to accumulate the predictions in the confusion matrix.
# Variable to accumulate the predictions in the confusion matrix. Setting
# the type to be `float64` as required by confusion_matrix_ops.
self.total_cm = self.add_weight(
'total_confusion_matrix',
shape=(num_classes, num_classes),
initializer=init_ops.zeros_initializer,
dtype=dtypes.float32)
dtype=dtypes.float64)
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates the confusion matrix statistics.
@ -2813,7 +2814,7 @@ class MeanIoU(Metric):
y_pred,
self.num_classes,
weights=sample_weight,
dtype=dtypes.float32)
dtype=dtypes.float64)
return self.total_cm.assign_add(current_cm)
def result(self):