nn_impl.py cleanup: used keepdims instead of deprecated keep_dims.
PiperOrigin-RevId: 177972555
This commit is contained in:
parent
f88cd91955
commit
c72bb97541
@ -341,7 +341,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
|
||||
with ops.name_scope(name, "l2_normalize", [x]) as name:
|
||||
axis = deprecated_argument_lookup("axis", axis, "dim", dim)
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keep_dims=True)
|
||||
square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
|
||||
x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
|
||||
return math_ops.multiply(x, x_inv_norm, name=name)
|
||||
|
||||
@ -593,8 +593,8 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
|
||||
else: # no shift.
|
||||
m_ss = x
|
||||
v_ss = math_ops.square(x)
|
||||
m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
|
||||
v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
|
||||
m_ss = math_ops.reduce_sum(m_ss, axes, keepdims=keep_dims, name="mean_ss")
|
||||
v_ss = math_ops.reduce_sum(v_ss, axes, keepdims=keep_dims, name="var_ss")
|
||||
return counts, m_ss, v_ss, shift
|
||||
|
||||
|
||||
@ -664,12 +664,12 @@ def moments(x, axes,
|
||||
# on 32-bit floats before converting the mean and variance back to fp16
|
||||
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
|
||||
# Compute true mean while keeping the dims for proper broadcasting.
|
||||
mean = math_ops.reduce_mean(y, axes, keep_dims=True, name="mean")
|
||||
mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
|
||||
# sample variance, not unbiased variance
|
||||
variance = math_ops.reduce_mean(
|
||||
math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
|
||||
axes,
|
||||
keep_dims=True,
|
||||
keepdims=True,
|
||||
name="variance")
|
||||
if not keep_dims:
|
||||
mean = array_ops.squeeze(mean, axes)
|
||||
@ -714,7 +714,7 @@ def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
|
||||
# Note that we use keep_dims=True for our reductions regardless of the arg;
|
||||
# this is so that the results remain broadcast-compatible with the inputs.
|
||||
weighted_input_sum = math_ops.reduce_sum(
|
||||
frequency_weights * x, axes, name="weighted_input_sum", keep_dims=True)
|
||||
frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)
|
||||
|
||||
# The shape of the weights isn't necessarily the same as x's
|
||||
# shape, just broadcast-compatible with it -- so this expression
|
||||
@ -725,7 +725,7 @@ def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
|
||||
broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
|
||||
|
||||
sum_of_weights = math_ops.reduce_sum(
|
||||
broadcasted_weights, axes, name="sum_of_weights", keep_dims=True)
|
||||
broadcasted_weights, axes, name="sum_of_weights", keepdims=True)
|
||||
|
||||
divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")
|
||||
|
||||
@ -736,7 +736,7 @@ def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
|
||||
frequency_weights * math_ops.squared_difference(x, weighted_mean),
|
||||
axes,
|
||||
name="weighted_distsq",
|
||||
keep_dims=True)
|
||||
keepdims=True)
|
||||
|
||||
weighted_variance = math_ops.multiply(weighted_distsq, divisor)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user