Use float32 instead of float64 for confusion matrix computations to make it compatible with tpus.
PiperOrigin-RevId: 337636936 Change-Id: I06295e65d089d35bff638cd7502f6975a15dfc40
This commit is contained in:
parent
7a636941b3
commit
7529bc18e8
@ -2771,12 +2771,13 @@ class MeanIoU(Metric):
|
||||
super(MeanIoU, self).__init__(name=name, dtype=dtype)
|
||||
self.num_classes = num_classes
|
||||
|
||||
# Variable to accumulate the predictions in the confusion matrix.
|
||||
# Variable to accumulate the predictions in the confusion matrix. Setting
|
||||
# the type to be `float64` as required by confusion_matrix_ops.
|
||||
self.total_cm = self.add_weight(
|
||||
'total_confusion_matrix',
|
||||
shape=(num_classes, num_classes),
|
||||
initializer=init_ops.zeros_initializer,
|
||||
dtype=dtypes.float32)
|
||||
dtype=dtypes.float64)
|
||||
|
||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||
"""Accumulates the confusion matrix statistics.
|
||||
@ -2813,7 +2814,7 @@ class MeanIoU(Metric):
|
||||
y_pred,
|
||||
self.num_classes,
|
||||
weights=sample_weight,
|
||||
dtype=dtypes.float32)
|
||||
dtype=dtypes.float64)
|
||||
return self.total_cm.assign_add(current_cm)
|
||||
|
||||
def result(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user