Use array_ops.where_v2 in _SelectGradV2
The usages of array_ops.where_v2 and array_ops.where are the same as else and then are None in both cases. Use array_ops.where_v2 for consistency. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
2c27f64b94
commit
518a2bad14
@ -1321,7 +1321,7 @@ def _SelectGradV2(op, grad):
|
||||
# Reduce away broadcasted leading dims.
|
||||
gx = math_ops.reduce_sum(gx, axis=math_ops.range(rankdiff_x))
|
||||
# Reduce but keep x's 1-valued dims which were broadcast.
|
||||
axis = array_ops.where(gx_shape[rankdiff_x:] > x_shape)
|
||||
axis = array_ops.where_v2(gx_shape[rankdiff_x:] > x_shape)
|
||||
# tf.where returns 2D so squeeze.
|
||||
axis = array_ops.squeeze(axis)
|
||||
gx = math_ops.reduce_sum(gx, keepdims=True, axis=axis)
|
||||
@ -1333,7 +1333,7 @@ def _SelectGradV2(op, grad):
|
||||
# Reduce away broadcasted leading dims.
|
||||
gy = math_ops.reduce_sum(gy, axis=math_ops.range(rankdiff_y))
|
||||
# Reduce but keep y's 1-valued dims which were broadcast.
|
||||
axis = array_ops.where(gy_shape[rankdiff_y:] > y_shape)
|
||||
axis = array_ops.where_v2(gy_shape[rankdiff_y:] > y_shape)
|
||||
# tf.where returns 2D so squeeze.
|
||||
axis = array_ops.squeeze(axis)
|
||||
gy = math_ops.reduce_sum(gy, keepdims=True, axis=axis)
|
||||
|
Loading…
Reference in New Issue
Block a user