Use float32 instead of float64 for confusion matrix computations to make it compatible with tpus.
PiperOrigin-RevId: 337635801 Change-Id: I2f9c282ca010168d19c316e1f0c6c86997567dec
This commit is contained in:
parent
195b6a667e
commit
7a636941b3
@ -2771,13 +2771,12 @@ class MeanIoU(Metric):
|
||||
super(MeanIoU, self).__init__(name=name, dtype=dtype)
|
||||
self.num_classes = num_classes
|
||||
|
||||
# Variable to accumulate the predictions in the confusion matrix. Setting
|
||||
# the type to be `float64` as required by confusion_matrix_ops.
|
||||
# Variable to accumulate the predictions in the confusion matrix.
|
||||
self.total_cm = self.add_weight(
|
||||
'total_confusion_matrix',
|
||||
shape=(num_classes, num_classes),
|
||||
initializer=init_ops.zeros_initializer,
|
||||
dtype=dtypes.float64)
|
||||
dtype=dtypes.float32)
|
||||
|
||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||
"""Accumulates the confusion matrix statistics.
|
||||
@ -2814,7 +2813,7 @@ class MeanIoU(Metric):
|
||||
y_pred,
|
||||
self.num_classes,
|
||||
weights=sample_weight,
|
||||
dtype=dtypes.float64)
|
||||
dtype=dtypes.float32)
|
||||
return self.total_cm.assign_add(current_cm)
|
||||
|
||||
def result(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user