Automated rollback of commit a7e7582e9a2b698054bf93aa27e53ebbc081d1a6. Revert #31106.
PiperOrigin-RevId: 263621326
parent e3c100fdce
commit 9cfcec4123
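This rollback restores the reshape + tile based gradient expansion in _SumGrad, _MinOrMaxGrad, and _ProdGrad (tensorflow/python/ops/math_grad.py), removing the array_ops.broadcast_to paths and the keep_dims special-casing that #31106 had introduced, as the two hunks below show.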
@@ -192,26 +192,22 @@ def _SumGrad(op, grad):
         return [array_ops.tile(grad, tile_scaling), None]
 
   input_shape = array_ops.shape(op.inputs[0])
-
-  if not op.get_attr("keep_dims"):
-    # TODO(apassos) remove this once device placement for eager ops makes more
-    # sense.
-    with ops.colocate_with(input_shape):
-      output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
-    grad = array_ops.reshape(grad, output_shape_kept_dims)
-  return [array_ops.broadcast_to(grad, input_shape), None]
+  # TODO(apassos) remove this once device placement for eager ops makes more
+  # sense.
+  with ops.colocate_with(input_shape):
+    output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
+    tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
+  grad = array_ops.reshape(grad, output_shape_kept_dims)
+  return [array_ops.tile(grad, tile_scaling), None]
 
 
 def _MinOrMaxGrad(op, grad):
   """Gradient for Min or Max. Amazingly it's precisely the same code."""
   input_shape = array_ops.shape(op.inputs[0])
+  output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
   y = op.outputs[0]
-  if not op.get_attr("keep_dims"):
-    output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
-    y = array_ops.reshape(y, output_shape_kept_dims)
-    grad = array_ops.reshape(grad, output_shape_kept_dims)
-  else:
-    output_shape_kept_dims = array_ops.shape(y)
+  y = array_ops.reshape(y, output_shape_kept_dims)
+  grad = array_ops.reshape(grad, output_shape_kept_dims)
 
   # Compute the number of selected (maximum or minimum) elements in each
   # reduction dimension. If there are multiple minimum or maximum elements
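As a sanity check on the two expansion strategies this hunk swaps, here is a standalone sketch (not part of the commit; the shapes are made up for illustration): reshaping a reduced-sum gradient to the kept-dims shape and tiling it yields the same tensor as broadcasting it to the input shape.

    import tensorflow as tf

    x = tf.constant([[1., 2., 3.],
                     [4., 5., 6.]])   # input to tf.reduce_sum(x, axis=1)
    grad = tf.constant([10., 20.])    # upstream gradient, shape [2]

    # Restored path: reshape to the kept-dims shape [2, 1], then tile by [1, 3].
    kept = tf.reshape(grad, [2, 1])
    by_tile = tf.tile(kept, [1, 3])

    # Reverted path: broadcast the kept-dims gradient to the input shape.
    by_broadcast = tf.broadcast_to(kept, [2, 3])

    assert bool(tf.reduce_all(by_tile == by_broadcast))  # identical results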
@@ -267,11 +263,10 @@ def _ProdGrad(op, grad):
   reduction_indices = array_ops.reshape(op.inputs[1], [-1])
 
   # Expand grad to full input shape
-  if not op.get_attr("keep_dims"):
-    output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
-    grad = array_ops.reshape(grad, output_shape_kept_dims)
-  grad = array_ops.broadcast_to(grad, input_shape)
+  output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
+  tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
+  grad = array_ops.reshape(grad, output_shape_kept_dims)
+  grad = array_ops.tile(grad, tile_scaling)
 
   # Pack all reduced dimensions into a single one, so we can perform the
   # cumprod ops. If the reduction dims list is empty, it defaults to float32,
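For reference, a minimal sketch (again not part of the commit) of what _ProdGrad computes once the gradient has been expanded back to the input shape: d(prod)/dx_i = prod / x_i for nonzero inputs, and the keepdims=True case is covered by either expansion path.

    import tensorflow as tf

    x = tf.constant([[1., 2.],
                     [3., 4.]])

    with tf.GradientTape() as tape:
        tape.watch(x)
        y = tf.reduce_prod(x, axis=0, keepdims=True)   # shape [1, 2]
    dx = tape.gradient(y, x)

    # Expected gradient for nonzero inputs: prod / x_i, elementwise.
    expected = tf.reduce_prod(x, axis=0, keepdims=True) / x
    assert bool(tf.reduce_all(tf.abs(dx - expected) < 1e-6))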