nn_impl.py cleanup: used keepdims instead of deprecated keep_dims.
PiperOrigin-RevId: 177972555
This commit is contained in:
parent
f88cd91955
commit
c72bb97541
@ -341,7 +341,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
|
|||||||
with ops.name_scope(name, "l2_normalize", [x]) as name:
|
with ops.name_scope(name, "l2_normalize", [x]) as name:
|
||||||
axis = deprecated_argument_lookup("axis", axis, "dim", dim)
|
axis = deprecated_argument_lookup("axis", axis, "dim", dim)
|
||||||
x = ops.convert_to_tensor(x, name="x")
|
x = ops.convert_to_tensor(x, name="x")
|
||||||
square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keep_dims=True)
|
square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
|
||||||
x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
|
x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
|
||||||
return math_ops.multiply(x, x_inv_norm, name=name)
|
return math_ops.multiply(x, x_inv_norm, name=name)
|
||||||
|
|
||||||
@ -593,8 +593,8 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
|
|||||||
else: # no shift.
|
else: # no shift.
|
||||||
m_ss = x
|
m_ss = x
|
||||||
v_ss = math_ops.square(x)
|
v_ss = math_ops.square(x)
|
||||||
m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
|
m_ss = math_ops.reduce_sum(m_ss, axes, keepdims=keep_dims, name="mean_ss")
|
||||||
v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
|
v_ss = math_ops.reduce_sum(v_ss, axes, keepdims=keep_dims, name="var_ss")
|
||||||
return counts, m_ss, v_ss, shift
|
return counts, m_ss, v_ss, shift
|
||||||
|
|
||||||
|
|
||||||
@ -664,12 +664,12 @@ def moments(x, axes,
|
|||||||
# on 32-bit floats before converting the mean and variance back to fp16
|
# on 32-bit floats before converting the mean and variance back to fp16
|
||||||
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
|
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
|
||||||
# Compute true mean while keeping the dims for proper broadcasting.
|
# Compute true mean while keeping the dims for proper broadcasting.
|
||||||
mean = math_ops.reduce_mean(y, axes, keep_dims=True, name="mean")
|
mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
|
||||||
# sample variance, not unbiased variance
|
# sample variance, not unbiased variance
|
||||||
variance = math_ops.reduce_mean(
|
variance = math_ops.reduce_mean(
|
||||||
math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
|
math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
|
||||||
axes,
|
axes,
|
||||||
keep_dims=True,
|
keepdims=True,
|
||||||
name="variance")
|
name="variance")
|
||||||
if not keep_dims:
|
if not keep_dims:
|
||||||
mean = array_ops.squeeze(mean, axes)
|
mean = array_ops.squeeze(mean, axes)
|
||||||
@ -714,7 +714,7 @@ def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
|
|||||||
# Note that we use keep_dims=True for our reductions regardless of the arg;
|
# Note that we use keep_dims=True for our reductions regardless of the arg;
|
||||||
# this is so that the results remain broadcast-compatible with the inputs.
|
# this is so that the results remain broadcast-compatible with the inputs.
|
||||||
weighted_input_sum = math_ops.reduce_sum(
|
weighted_input_sum = math_ops.reduce_sum(
|
||||||
frequency_weights * x, axes, name="weighted_input_sum", keep_dims=True)
|
frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)
|
||||||
|
|
||||||
# The shape of the weights isn't necessarily the same as x's
|
# The shape of the weights isn't necessarily the same as x's
|
||||||
# shape, just broadcast-compatible with it -- so this expression
|
# shape, just broadcast-compatible with it -- so this expression
|
||||||
@ -725,7 +725,7 @@ def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
|
|||||||
broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
|
broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
|
||||||
|
|
||||||
sum_of_weights = math_ops.reduce_sum(
|
sum_of_weights = math_ops.reduce_sum(
|
||||||
broadcasted_weights, axes, name="sum_of_weights", keep_dims=True)
|
broadcasted_weights, axes, name="sum_of_weights", keepdims=True)
|
||||||
|
|
||||||
divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")
|
divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")
|
||||||
|
|
||||||
@ -736,7 +736,7 @@ def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
|
|||||||
frequency_weights * math_ops.squared_difference(x, weighted_mean),
|
frequency_weights * math_ops.squared_difference(x, weighted_mean),
|
||||||
axes,
|
axes,
|
||||||
name="weighted_distsq",
|
name="weighted_distsq",
|
||||||
keep_dims=True)
|
keepdims=True)
|
||||||
|
|
||||||
weighted_variance = math_ops.multiply(weighted_distsq, divisor)
|
weighted_variance = math_ops.multiply(weighted_distsq, divisor)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user