Use array_ops.where_v2 in _SelectGradV2

The usages of array_ops.where_v2 and array_ops.where
are the same as else and then are None in both cases.
Use array_ops.where_v2 for consistency.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2019-05-14 05:06:03 +00:00
parent 2c27f64b94
commit 518a2bad14

View File

@ -1321,7 +1321,7 @@ def _SelectGradV2(op, grad):
# Reduce away broadcasted leading dims.
gx = math_ops.reduce_sum(gx, axis=math_ops.range(rankdiff_x))
# Reduce but keep x's 1-valued dims which were broadcast.
axis = array_ops.where(gx_shape[rankdiff_x:] > x_shape)
axis = array_ops.where_v2(gx_shape[rankdiff_x:] > x_shape)
# tf.where returns 2D so squeeze.
axis = array_ops.squeeze(axis)
gx = math_ops.reduce_sum(gx, keepdims=True, axis=axis)
@ -1333,7 +1333,7 @@ def _SelectGradV2(op, grad):
# Reduce away broadcasted leading dims.
gy = math_ops.reduce_sum(gy, axis=math_ops.range(rankdiff_y))
# Reduce but keep y's 1-valued dims which were broadcast.
axis = array_ops.where(gy_shape[rankdiff_y:] > y_shape)
axis = array_ops.where_v2(gy_shape[rankdiff_y:] > y_shape)
# tf.where returns 2D so squeeze.
axis = array_ops.squeeze(axis)
gy = math_ops.reduce_sum(gy, keepdims=True, axis=axis)