Automated rollback of commit a7e7582e9a2b698054bf93aa27e53ebbc081d1a6. Revert #31106.
PiperOrigin-RevId: 263621326
This commit is contained in:
parent
e3c100fdce
commit
9cfcec4123
@ -192,26 +192,22 @@ def _SumGrad(op, grad):
|
||||
return [array_ops.tile(grad, tile_scaling), None]
|
||||
|
||||
input_shape = array_ops.shape(op.inputs[0])
|
||||
|
||||
if not op.get_attr("keep_dims"):
|
||||
# TODO(apassos) remove this once device placement for eager ops makes more
|
||||
# sense.
|
||||
with ops.colocate_with(input_shape):
|
||||
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
|
||||
grad = array_ops.reshape(grad, output_shape_kept_dims)
|
||||
return [array_ops.broadcast_to(grad, input_shape), None]
|
||||
# TODO(apassos) remove this once device placement for eager ops makes more
|
||||
# sense.
|
||||
with ops.colocate_with(input_shape):
|
||||
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
|
||||
tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
|
||||
grad = array_ops.reshape(grad, output_shape_kept_dims)
|
||||
return [array_ops.tile(grad, tile_scaling), None]
|
||||
|
||||
|
||||
def _MinOrMaxGrad(op, grad):
|
||||
"""Gradient for Min or Max. Amazingly it's precisely the same code."""
|
||||
input_shape = array_ops.shape(op.inputs[0])
|
||||
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
|
||||
y = op.outputs[0]
|
||||
if not op.get_attr("keep_dims"):
|
||||
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
|
||||
y = array_ops.reshape(y, output_shape_kept_dims)
|
||||
grad = array_ops.reshape(grad, output_shape_kept_dims)
|
||||
else:
|
||||
output_shape_kept_dims = array_ops.shape(y)
|
||||
y = array_ops.reshape(y, output_shape_kept_dims)
|
||||
grad = array_ops.reshape(grad, output_shape_kept_dims)
|
||||
|
||||
# Compute the number of selected (maximum or minimum) elements in each
|
||||
# reduction dimension. If there are multiple minimum or maximum elements
|
||||
@ -267,11 +263,10 @@ def _ProdGrad(op, grad):
|
||||
reduction_indices = array_ops.reshape(op.inputs[1], [-1])
|
||||
|
||||
# Expand grad to full input shape
|
||||
if not op.get_attr("keep_dims"):
|
||||
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
|
||||
grad = array_ops.reshape(grad, output_shape_kept_dims)
|
||||
|
||||
grad = array_ops.broadcast_to(grad, input_shape)
|
||||
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
|
||||
tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
|
||||
grad = array_ops.reshape(grad, output_shape_kept_dims)
|
||||
grad = array_ops.tile(grad, tile_scaling)
|
||||
|
||||
# Pack all reduced dimensions into a single one, so we can perform the
|
||||
# cumprod ops. If the reduction dims list is empty, it defaults to float32,
|
||||
|
Loading…
x
Reference in New Issue
Block a user