Update update_confusion_matrix_variables to alwasy cast to variables_to_update dtype (vs. explicit float32)

This commits updates the function update_confusion_matrix_variables
to alwasy cast to dtype based on  variables_to_update (previously
the values are casted to float32 explicitly and that cuases issues
when keras' backend use non-float32).

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-05-26 20:32:18 +00:00
parent 8dadd8304f
commit ee7bffc10c
1 changed files with 14 additions and 12 deletions

View File

@ -299,9 +299,19 @@ def update_confusion_matrix_variables(variables_to_update,
'`multi_label` is True.') '`multi_label` is True.')
if variables_to_update is None: if variables_to_update is None:
return return
y_true = math_ops.cast(y_true, dtype=dtypes.float32) if not any(
y_pred = math_ops.cast(y_pred, dtype=dtypes.float32) key for key in variables_to_update if key in list(ConfusionMatrix)):
thresholds = ops.convert_to_tensor_v2(thresholds, dtype=dtypes.float32) raise ValueError(
'Please provide at least one valid confusion matrix '
'variable to update. Valid variable key options are: "{}". '
'Received: "{}"'.format(
list(ConfusionMatrix), variables_to_update.keys()))
variable_dtype = list(variables_to_update.values())[0].dtype
y_true = math_ops.cast(y_true, dtype=variable_dtype)
y_pred = math_ops.cast(y_pred, dtype=variable_dtype)
thresholds = ops.convert_to_tensor_v2(thresholds, dtype=variable_dtype)
num_thresholds = thresholds.shape[0] num_thresholds = thresholds.shape[0]
if multi_label: if multi_label:
one_thresh = math_ops.equal( one_thresh = math_ops.equal(
@ -314,14 +324,6 @@ def update_confusion_matrix_variables(variables_to_update,
sample_weight) sample_weight)
one_thresh = math_ops.cast(True, dtype=dtypes.bool) one_thresh = math_ops.cast(True, dtype=dtypes.bool)
if not any(
key for key in variables_to_update if key in list(ConfusionMatrix)):
raise ValueError(
'Please provide at least one valid confusion matrix '
'variable to update. Valid variable key options are: "{}". '
'Received: "{}"'.format(
list(ConfusionMatrix), variables_to_update.keys()))
invalid_keys = [ invalid_keys = [
key for key in variables_to_update if key not in list(ConfusionMatrix) key for key in variables_to_update if key not in list(ConfusionMatrix)
] ]
@ -401,7 +403,7 @@ def update_confusion_matrix_variables(variables_to_update,
if sample_weight is not None: if sample_weight is not None:
sample_weight = weights_broadcast_ops.broadcast_weights( sample_weight = weights_broadcast_ops.broadcast_weights(
math_ops.cast(sample_weight, dtype=dtypes.float32), y_pred) math_ops.cast(sample_weight, dtype=variable_dtype), y_pred)
weights_tiled = array_ops.tile( weights_tiled = array_ops.tile(
array_ops.reshape(sample_weight, thresh_tiles), data_tiles) array_ops.reshape(sample_weight, thresh_tiles), data_tiles)
else: else: