Update update_confusion_matrix_variables to alwasy cast to variables_to_update dtype (vs. explicit float32)
This commits updates the function update_confusion_matrix_variables to alwasy cast to dtype based on variables_to_update (previously the values are casted to float32 explicitly and that cuases issues when keras' backend use non-float32). Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
8dadd8304f
commit
ee7bffc10c
|
@ -299,9 +299,19 @@ def update_confusion_matrix_variables(variables_to_update,
|
||||||
'`multi_label` is True.')
|
'`multi_label` is True.')
|
||||||
if variables_to_update is None:
|
if variables_to_update is None:
|
||||||
return
|
return
|
||||||
y_true = math_ops.cast(y_true, dtype=dtypes.float32)
|
if not any(
|
||||||
y_pred = math_ops.cast(y_pred, dtype=dtypes.float32)
|
key for key in variables_to_update if key in list(ConfusionMatrix)):
|
||||||
thresholds = ops.convert_to_tensor_v2(thresholds, dtype=dtypes.float32)
|
raise ValueError(
|
||||||
|
'Please provide at least one valid confusion matrix '
|
||||||
|
'variable to update. Valid variable key options are: "{}". '
|
||||||
|
'Received: "{}"'.format(
|
||||||
|
list(ConfusionMatrix), variables_to_update.keys()))
|
||||||
|
|
||||||
|
variable_dtype = list(variables_to_update.values())[0].dtype
|
||||||
|
|
||||||
|
y_true = math_ops.cast(y_true, dtype=variable_dtype)
|
||||||
|
y_pred = math_ops.cast(y_pred, dtype=variable_dtype)
|
||||||
|
thresholds = ops.convert_to_tensor_v2(thresholds, dtype=variable_dtype)
|
||||||
num_thresholds = thresholds.shape[0]
|
num_thresholds = thresholds.shape[0]
|
||||||
if multi_label:
|
if multi_label:
|
||||||
one_thresh = math_ops.equal(
|
one_thresh = math_ops.equal(
|
||||||
|
@ -314,14 +324,6 @@ def update_confusion_matrix_variables(variables_to_update,
|
||||||
sample_weight)
|
sample_weight)
|
||||||
one_thresh = math_ops.cast(True, dtype=dtypes.bool)
|
one_thresh = math_ops.cast(True, dtype=dtypes.bool)
|
||||||
|
|
||||||
if not any(
|
|
||||||
key for key in variables_to_update if key in list(ConfusionMatrix)):
|
|
||||||
raise ValueError(
|
|
||||||
'Please provide at least one valid confusion matrix '
|
|
||||||
'variable to update. Valid variable key options are: "{}". '
|
|
||||||
'Received: "{}"'.format(
|
|
||||||
list(ConfusionMatrix), variables_to_update.keys()))
|
|
||||||
|
|
||||||
invalid_keys = [
|
invalid_keys = [
|
||||||
key for key in variables_to_update if key not in list(ConfusionMatrix)
|
key for key in variables_to_update if key not in list(ConfusionMatrix)
|
||||||
]
|
]
|
||||||
|
@ -401,7 +403,7 @@ def update_confusion_matrix_variables(variables_to_update,
|
||||||
|
|
||||||
if sample_weight is not None:
|
if sample_weight is not None:
|
||||||
sample_weight = weights_broadcast_ops.broadcast_weights(
|
sample_weight = weights_broadcast_ops.broadcast_weights(
|
||||||
math_ops.cast(sample_weight, dtype=dtypes.float32), y_pred)
|
math_ops.cast(sample_weight, dtype=variable_dtype), y_pred)
|
||||||
weights_tiled = array_ops.tile(
|
weights_tiled = array_ops.tile(
|
||||||
array_ops.reshape(sample_weight, thresh_tiles), data_tiles)
|
array_ops.reshape(sample_weight, thresh_tiles), data_tiles)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue