Use float32 instead of float64 for confusion matrix computations to make it compatible with tpus.
PiperOrigin-RevId: 337636936 Change-Id: I06295e65d089d35bff638cd7502f6975a15dfc40
This commit is contained in:
parent
7a636941b3
commit
7529bc18e8
@ -2771,12 +2771,13 @@ class MeanIoU(Metric):
|
|||||||
super(MeanIoU, self).__init__(name=name, dtype=dtype)
|
super(MeanIoU, self).__init__(name=name, dtype=dtype)
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
|
|
||||||
# Variable to accumulate the predictions in the confusion matrix.
|
# Variable to accumulate the predictions in the confusion matrix. Setting
|
||||||
|
# the type to be `float64` as required by confusion_matrix_ops.
|
||||||
self.total_cm = self.add_weight(
|
self.total_cm = self.add_weight(
|
||||||
'total_confusion_matrix',
|
'total_confusion_matrix',
|
||||||
shape=(num_classes, num_classes),
|
shape=(num_classes, num_classes),
|
||||||
initializer=init_ops.zeros_initializer,
|
initializer=init_ops.zeros_initializer,
|
||||||
dtype=dtypes.float32)
|
dtype=dtypes.float64)
|
||||||
|
|
||||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||||
"""Accumulates the confusion matrix statistics.
|
"""Accumulates the confusion matrix statistics.
|
||||||
@ -2813,7 +2814,7 @@ class MeanIoU(Metric):
|
|||||||
y_pred,
|
y_pred,
|
||||||
self.num_classes,
|
self.num_classes,
|
||||||
weights=sample_weight,
|
weights=sample_weight,
|
||||||
dtype=dtypes.float32)
|
dtype=dtypes.float64)
|
||||||
return self.total_cm.assign_add(current_cm)
|
return self.total_cm.assign_add(current_cm)
|
||||||
|
|
||||||
def result(self):
|
def result(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user