Update update_confusion_matrix_variables to alwasy cast to variables_to_update dtype (vs. explicit float32)
This commits updates the function update_confusion_matrix_variables to alwasy cast to dtype based on variables_to_update (previously the values are casted to float32 explicitly and that cuases issues when keras' backend use non-float32). Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
8dadd8304f
commit
ee7bffc10c
|
@ -299,9 +299,19 @@ def update_confusion_matrix_variables(variables_to_update,
|
|||
'`multi_label` is True.')
|
||||
if variables_to_update is None:
|
||||
return
|
||||
y_true = math_ops.cast(y_true, dtype=dtypes.float32)
|
||||
y_pred = math_ops.cast(y_pred, dtype=dtypes.float32)
|
||||
thresholds = ops.convert_to_tensor_v2(thresholds, dtype=dtypes.float32)
|
||||
if not any(
|
||||
key for key in variables_to_update if key in list(ConfusionMatrix)):
|
||||
raise ValueError(
|
||||
'Please provide at least one valid confusion matrix '
|
||||
'variable to update. Valid variable key options are: "{}". '
|
||||
'Received: "{}"'.format(
|
||||
list(ConfusionMatrix), variables_to_update.keys()))
|
||||
|
||||
variable_dtype = list(variables_to_update.values())[0].dtype
|
||||
|
||||
y_true = math_ops.cast(y_true, dtype=variable_dtype)
|
||||
y_pred = math_ops.cast(y_pred, dtype=variable_dtype)
|
||||
thresholds = ops.convert_to_tensor_v2(thresholds, dtype=variable_dtype)
|
||||
num_thresholds = thresholds.shape[0]
|
||||
if multi_label:
|
||||
one_thresh = math_ops.equal(
|
||||
|
@ -314,14 +324,6 @@ def update_confusion_matrix_variables(variables_to_update,
|
|||
sample_weight)
|
||||
one_thresh = math_ops.cast(True, dtype=dtypes.bool)
|
||||
|
||||
if not any(
|
||||
key for key in variables_to_update if key in list(ConfusionMatrix)):
|
||||
raise ValueError(
|
||||
'Please provide at least one valid confusion matrix '
|
||||
'variable to update. Valid variable key options are: "{}". '
|
||||
'Received: "{}"'.format(
|
||||
list(ConfusionMatrix), variables_to_update.keys()))
|
||||
|
||||
invalid_keys = [
|
||||
key for key in variables_to_update if key not in list(ConfusionMatrix)
|
||||
]
|
||||
|
@ -401,7 +403,7 @@ def update_confusion_matrix_variables(variables_to_update,
|
|||
|
||||
if sample_weight is not None:
|
||||
sample_weight = weights_broadcast_ops.broadcast_weights(
|
||||
math_ops.cast(sample_weight, dtype=dtypes.float32), y_pred)
|
||||
math_ops.cast(sample_weight, dtype=variable_dtype), y_pred)
|
||||
weights_tiled = array_ops.tile(
|
||||
array_ops.reshape(sample_weight, thresh_tiles), data_tiles)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue