Update metrics_op.py (#12586)

Using core method over private
This commit is contained in:
Alan Yee 2017-09-08 10:31:15 -07:00 committed by Yifei Feng
parent 92e7793661
commit df849b7767

View File

@ -34,6 +34,7 @@ from tensorflow.python.ops import metrics_impl
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.util.deprecation import deprecated
@ -651,7 +652,7 @@ def _streaming_confusion_matrix_at_thresholds(
label_is_neg = math_ops.logical_not(label_is_pos)
if weights is not None:
broadcast_weights = _broadcast_weights(
broadcast_weights = weights_broadcast_ops.broadcast_weights(
math_ops.to_float(weights), predictions)
weights_tiled = array_ops.tile(array_ops.reshape(
broadcast_weights, [1, -1]), [num_thresholds, 1])
@ -1924,7 +1925,7 @@ def streaming_covariance(predictions,
weighted_predictions = predictions
weighted_labels = labels
else:
weights = _broadcast_weights(weights, labels)
weights = weights_broadcast_ops.broadcast_weights(weights, labels)
batch_count = math_ops.reduce_sum(weights) # n_B in eqn
weighted_predictions = math_ops.multiply(predictions, weights)
weighted_labels = math_ops.multiply(labels, weights)
@ -2051,7 +2052,7 @@ def streaming_pearson_correlation(predictions,
# Broadcast weights here to avoid duplicate broadcasting in each call to
# `streaming_covariance`.
if weights is not None:
weights = _broadcast_weights(weights, labels)
weights = weights_broadcast_ops.broadcast_weights(weights, labels)
cov, update_cov = streaming_covariance(
predictions, labels, weights=weights, name='covariance')
var_predictions, update_var_predictions = streaming_covariance(