parent
92e7793661
commit
df849b7767
@ -34,6 +34,7 @@ from tensorflow.python.ops import metrics_impl
|
|||||||
from tensorflow.python.ops import nn
|
from tensorflow.python.ops import nn
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
|
from tensorflow.python.ops import weights_broadcast_ops
|
||||||
from tensorflow.python.util.deprecation import deprecated
|
from tensorflow.python.util.deprecation import deprecated
|
||||||
|
|
||||||
|
|
||||||
@ -651,7 +652,7 @@ def _streaming_confusion_matrix_at_thresholds(
|
|||||||
label_is_neg = math_ops.logical_not(label_is_pos)
|
label_is_neg = math_ops.logical_not(label_is_pos)
|
||||||
|
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
broadcast_weights = _broadcast_weights(
|
broadcast_weights = weights_broadcast_ops.broadcast_weights(
|
||||||
math_ops.to_float(weights), predictions)
|
math_ops.to_float(weights), predictions)
|
||||||
weights_tiled = array_ops.tile(array_ops.reshape(
|
weights_tiled = array_ops.tile(array_ops.reshape(
|
||||||
broadcast_weights, [1, -1]), [num_thresholds, 1])
|
broadcast_weights, [1, -1]), [num_thresholds, 1])
|
||||||
@ -1924,7 +1925,7 @@ def streaming_covariance(predictions,
|
|||||||
weighted_predictions = predictions
|
weighted_predictions = predictions
|
||||||
weighted_labels = labels
|
weighted_labels = labels
|
||||||
else:
|
else:
|
||||||
weights = _broadcast_weights(weights, labels)
|
weights = weights_broadcast_ops.broadcast_weights(weights, labels)
|
||||||
batch_count = math_ops.reduce_sum(weights) # n_B in eqn
|
batch_count = math_ops.reduce_sum(weights) # n_B in eqn
|
||||||
weighted_predictions = math_ops.multiply(predictions, weights)
|
weighted_predictions = math_ops.multiply(predictions, weights)
|
||||||
weighted_labels = math_ops.multiply(labels, weights)
|
weighted_labels = math_ops.multiply(labels, weights)
|
||||||
@ -2051,7 +2052,7 @@ def streaming_pearson_correlation(predictions,
|
|||||||
# Broadcast weights here to avoid duplicate broadcasting in each call to
|
# Broadcast weights here to avoid duplicate broadcasting in each call to
|
||||||
# `streaming_covariance`.
|
# `streaming_covariance`.
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
weights = _broadcast_weights(weights, labels)
|
weights = weights_broadcast_ops.broadcast_weights(weights, labels)
|
||||||
cov, update_cov = streaming_covariance(
|
cov, update_cov = streaming_covariance(
|
||||||
predictions, labels, weights=weights, name='covariance')
|
predictions, labels, weights=weights, name='covariance')
|
||||||
var_predictions, update_var_predictions = streaming_covariance(
|
var_predictions, update_var_predictions = streaming_covariance(
|
||||||
|
Loading…
Reference in New Issue
Block a user