Merge pull request #47416 from lgeiger:zeros-like
PiperOrigin-RevId: 359774166 Change-Id: I2d142ebc30e75a8aba13600a34b2c9428079e4d6
This commit is contained in:
commit
7a1b65b613
@ -1178,7 +1178,7 @@ def _SigmoidGradGrad(op, grad):
|
|||||||
def _SignGrad(op, _):
|
def _SignGrad(op, _):
|
||||||
"""Returns 0."""
|
"""Returns 0."""
|
||||||
x = op.inputs[0]
|
x = op.inputs[0]
|
||||||
return array_ops.zeros(array_ops.shape(x), dtype=x.dtype)
|
return array_ops.zeros_like(x)
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("Sin")
|
@ops.RegisterGradient("Sin")
|
||||||
@ -1560,11 +1560,9 @@ def _MaximumMinimumGrad(op, grad, selector_op):
|
|||||||
# No gradient skipping, so do the full gradient computation
|
# No gradient skipping, so do the full gradient computation
|
||||||
pass
|
pass
|
||||||
x = op.inputs[0]
|
x = op.inputs[0]
|
||||||
gdtype = grad.dtype
|
|
||||||
sx = array_ops.shape(x)
|
sx = array_ops.shape(x)
|
||||||
sy = array_ops.shape(y)
|
sy = array_ops.shape(y)
|
||||||
gradshape = array_ops.shape(grad)
|
zeros = array_ops.zeros_like(grad)
|
||||||
zeros = array_ops.zeros(gradshape, gdtype)
|
|
||||||
xmask = selector_op(x, y)
|
xmask = selector_op(x, y)
|
||||||
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
|
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
|
||||||
if skip_input_indices is not None and 0 in skip_input_indices:
|
if skip_input_indices is not None and 0 in skip_input_indices:
|
||||||
|
@ -251,10 +251,8 @@ def _MaxPool3DGrad(op, grad):
|
|||||||
|
|
||||||
@ops.RegisterGradient("MaxPool3DGrad")
|
@ops.RegisterGradient("MaxPool3DGrad")
|
||||||
def _MaxPool3DGradGrad(op, grad):
|
def _MaxPool3DGradGrad(op, grad):
|
||||||
return (array_ops.zeros(
|
return (array_ops.zeros_like(op.inputs[0]),
|
||||||
shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
|
array_ops.zeros_like(op.inputs[1]),
|
||||||
array_ops.zeros(
|
|
||||||
shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
|
|
||||||
gen_nn_ops.max_pool3d_grad_grad(
|
gen_nn_ops.max_pool3d_grad_grad(
|
||||||
op.inputs[0],
|
op.inputs[0],
|
||||||
op.inputs[1],
|
op.inputs[1],
|
||||||
@ -267,10 +265,8 @@ def _MaxPool3DGradGrad(op, grad):
|
|||||||
|
|
||||||
@ops.RegisterGradient("MaxPool3DGradGrad")
|
@ops.RegisterGradient("MaxPool3DGradGrad")
|
||||||
def _MaxPool3DGradGradGrad(op, grad):
|
def _MaxPool3DGradGradGrad(op, grad):
|
||||||
return (array_ops.zeros(
|
return (array_ops.zeros_like(op.inputs[0]),
|
||||||
shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
|
array_ops.zeros_like(op.inputs[1]),
|
||||||
array_ops.zeros(
|
|
||||||
shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
|
|
||||||
gen_nn_ops.max_pool3d_grad(
|
gen_nn_ops.max_pool3d_grad(
|
||||||
op.inputs[0],
|
op.inputs[0],
|
||||||
op.inputs[1],
|
op.inputs[1],
|
||||||
@ -441,8 +437,7 @@ def _Relu6Grad(op, grad):
|
|||||||
@ops.RegisterGradient("Relu6Grad")
|
@ops.RegisterGradient("Relu6Grad")
|
||||||
def _Relu6GradGrad(op, grad):
|
def _Relu6GradGrad(op, grad):
|
||||||
x = op.inputs[1]
|
x = op.inputs[1]
|
||||||
return (gen_nn_ops.relu6_grad(grad, x),
|
return (gen_nn_ops.relu6_grad(grad, x), array_ops.zeros_like(x))
|
||||||
array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
|
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("LeakyRelu")
|
@ops.RegisterGradient("LeakyRelu")
|
||||||
@ -456,8 +451,8 @@ def _LeakyReluGrad(op, grad):
|
|||||||
def _LeakyReluGradGrad(op, grad):
|
def _LeakyReluGradGrad(op, grad):
|
||||||
x = op.inputs[1]
|
x = op.inputs[1]
|
||||||
alpha = op.get_attr("alpha")
|
alpha = op.get_attr("alpha")
|
||||||
return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha),
|
return (gen_nn_ops.leaky_relu_grad(grad, x,
|
||||||
array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
|
alpha=alpha), array_ops.zeros_like(x))
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("Elu")
|
@ops.RegisterGradient("Elu")
|
||||||
@ -496,8 +491,7 @@ def _SoftsignGrad(op, grad):
|
|||||||
@ops.RegisterGradient("ReluGrad")
|
@ops.RegisterGradient("ReluGrad")
|
||||||
def _ReluGradGrad(op, grad):
|
def _ReluGradGrad(op, grad):
|
||||||
x = op.inputs[1]
|
x = op.inputs[1]
|
||||||
return (gen_nn_ops.relu_grad(grad, x),
|
return (gen_nn_ops.relu_grad(grad, x), array_ops.zeros_like(x))
|
||||||
array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
|
|
||||||
|
|
||||||
|
|
||||||
def _BroadcastMul(vec, mat):
|
def _BroadcastMul(vec, mat):
|
||||||
@ -721,10 +715,8 @@ def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
|
|||||||
|
|
||||||
@ops.RegisterGradient("MaxPoolGrad")
|
@ops.RegisterGradient("MaxPoolGrad")
|
||||||
def _MaxPoolGradGrad(op, grad):
|
def _MaxPoolGradGrad(op, grad):
|
||||||
return (array_ops.zeros(
|
return (array_ops.zeros_like(op.inputs[0]),
|
||||||
shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
|
array_ops.zeros_like(op.inputs[1]),
|
||||||
array_ops.zeros(
|
|
||||||
shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
|
|
||||||
gen_nn_ops.max_pool_grad_grad(
|
gen_nn_ops.max_pool_grad_grad(
|
||||||
op.inputs[0],
|
op.inputs[0],
|
||||||
op.inputs[1],
|
op.inputs[1],
|
||||||
@ -739,10 +731,8 @@ def _MaxPoolGradGrad(op, grad):
|
|||||||
def _MaxPoolGradGradV2(op, grad):
|
def _MaxPoolGradGradV2(op, grad):
|
||||||
ksize = op.inputs[3]
|
ksize = op.inputs[3]
|
||||||
strides = op.inputs[4]
|
strides = op.inputs[4]
|
||||||
return (array_ops.zeros(
|
return (array_ops.zeros_like(op.inputs[0]),
|
||||||
shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
|
array_ops.zeros_like(op.inputs[1]),
|
||||||
array_ops.zeros(
|
|
||||||
shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
|
|
||||||
gen_nn_ops.max_pool_grad_grad_v2(
|
gen_nn_ops.max_pool_grad_grad_v2(
|
||||||
op.inputs[0],
|
op.inputs[0],
|
||||||
op.inputs[1],
|
op.inputs[1],
|
||||||
@ -755,10 +745,8 @@ def _MaxPoolGradGradV2(op, grad):
|
|||||||
|
|
||||||
@ops.RegisterGradient("MaxPoolGradGrad")
|
@ops.RegisterGradient("MaxPoolGradGrad")
|
||||||
def _MaxPoolGradGradGrad(op, grad):
|
def _MaxPoolGradGradGrad(op, grad):
|
||||||
return (array_ops.zeros(
|
return (array_ops.zeros_like(op.inputs[0]),
|
||||||
shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
|
array_ops.zeros_like(op.inputs[1]),
|
||||||
array_ops.zeros(
|
|
||||||
shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
|
|
||||||
gen_nn_ops.max_pool_grad(
|
gen_nn_ops.max_pool_grad(
|
||||||
op.inputs[0],
|
op.inputs[0],
|
||||||
op.inputs[1],
|
op.inputs[1],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user