parent
92e7793661
commit
df849b7767
@ -34,6 +34,7 @@ from tensorflow.python.ops import metrics_impl
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import weights_broadcast_ops
|
||||
from tensorflow.python.util.deprecation import deprecated
|
||||
|
||||
|
||||
@ -651,7 +652,7 @@ def _streaming_confusion_matrix_at_thresholds(
|
||||
label_is_neg = math_ops.logical_not(label_is_pos)
|
||||
|
||||
if weights is not None:
|
||||
broadcast_weights = _broadcast_weights(
|
||||
broadcast_weights = weights_broadcast_ops.broadcast_weights(
|
||||
math_ops.to_float(weights), predictions)
|
||||
weights_tiled = array_ops.tile(array_ops.reshape(
|
||||
broadcast_weights, [1, -1]), [num_thresholds, 1])
|
||||
@ -1924,7 +1925,7 @@ def streaming_covariance(predictions,
|
||||
weighted_predictions = predictions
|
||||
weighted_labels = labels
|
||||
else:
|
||||
weights = _broadcast_weights(weights, labels)
|
||||
weights = weights_broadcast_ops.broadcast_weights(weights, labels)
|
||||
batch_count = math_ops.reduce_sum(weights) # n_B in eqn
|
||||
weighted_predictions = math_ops.multiply(predictions, weights)
|
||||
weighted_labels = math_ops.multiply(labels, weights)
|
||||
@ -2051,7 +2052,7 @@ def streaming_pearson_correlation(predictions,
|
||||
# Broadcast weights here to avoid duplicate broadcasting in each call to
|
||||
# `streaming_covariance`.
|
||||
if weights is not None:
|
||||
weights = _broadcast_weights(weights, labels)
|
||||
weights = weights_broadcast_ops.broadcast_weights(weights, labels)
|
||||
cov, update_cov = streaming_covariance(
|
||||
predictions, labels, weights=weights, name='covariance')
|
||||
var_predictions, update_var_predictions = streaming_covariance(
|
||||
|
Loading…
Reference in New Issue
Block a user