diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 3c6bb4fb829..c5821ce8fcb 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -1178,7 +1178,7 @@ def _SigmoidGradGrad(op, grad):
 def _SignGrad(op, _):
   """Returns 0."""
   x = op.inputs[0]
-  return array_ops.zeros(array_ops.shape(x), dtype=x.dtype)
+  return array_ops.zeros_like(x)
 
 
 @ops.RegisterGradient("Sin")
@@ -1560,11 +1560,9 @@ def _MaximumMinimumGrad(op, grad, selector_op):
     # No gradient skipping, so do the full gradient computation
     pass
   x = op.inputs[0]
-  gdtype = grad.dtype
   sx = array_ops.shape(x)
   sy = array_ops.shape(y)
-  gradshape = array_ops.shape(grad)
-  zeros = array_ops.zeros(gradshape, gdtype)
+  zeros = array_ops.zeros_like(grad)
   xmask = selector_op(x, y)
   rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
   if skip_input_indices is not None and 0 in skip_input_indices:
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 70b137a57a8..089297171bf 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -251,10 +251,8 @@ def _MaxPool3DGrad(op, grad):
 
 @ops.RegisterGradient("MaxPool3DGrad")
 def _MaxPool3DGradGrad(op, grad):
-  return (array_ops.zeros(
-      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
-          array_ops.zeros(
-              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+  return (array_ops.zeros_like(op.inputs[0]),
+          array_ops.zeros_like(op.inputs[1]),
           gen_nn_ops.max_pool3d_grad_grad(
               op.inputs[0],
               op.inputs[1],
@@ -267,10 +265,8 @@ def _MaxPool3DGradGrad(op, grad):
 
 @ops.RegisterGradient("MaxPool3DGradGrad")
 def _MaxPool3DGradGradGrad(op, grad):
-  return (array_ops.zeros(
-      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
-          array_ops.zeros(
-              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+  return (array_ops.zeros_like(op.inputs[0]),
+          array_ops.zeros_like(op.inputs[1]),
           gen_nn_ops.max_pool3d_grad(
               op.inputs[0],
               op.inputs[1],
@@ -441,8 +437,7 @@ def _Relu6Grad(op, grad):
 @ops.RegisterGradient("Relu6Grad")
 def _Relu6GradGrad(op, grad):
   x = op.inputs[1]
-  return (gen_nn_ops.relu6_grad(grad, x),
-          array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+  return (gen_nn_ops.relu6_grad(grad, x), array_ops.zeros_like(x))
 
 
 @ops.RegisterGradient("LeakyRelu")
@@ -456,8 +451,8 @@ def _LeakyReluGrad(op, grad):
 def _LeakyReluGradGrad(op, grad):
   x = op.inputs[1]
   alpha = op.get_attr("alpha")
-  return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha),
-          array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+  return (gen_nn_ops.leaky_relu_grad(grad, x,
+                                     alpha=alpha), array_ops.zeros_like(x))
 
 
 @ops.RegisterGradient("Elu")
@@ -496,8 +491,7 @@ def _SoftsignGrad(op, grad):
 @ops.RegisterGradient("ReluGrad")
 def _ReluGradGrad(op, grad):
   x = op.inputs[1]
-  return (gen_nn_ops.relu_grad(grad, x),
-          array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+  return (gen_nn_ops.relu_grad(grad, x), array_ops.zeros_like(x))
 
 
 def _BroadcastMul(vec, mat):
@@ -721,10 +715,8 @@ def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
 
 @ops.RegisterGradient("MaxPoolGrad")
 def _MaxPoolGradGrad(op, grad):
-  return (array_ops.zeros(
-      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
-          array_ops.zeros(
-              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+  return (array_ops.zeros_like(op.inputs[0]),
+          array_ops.zeros_like(op.inputs[1]),
           gen_nn_ops.max_pool_grad_grad(
               op.inputs[0],
               op.inputs[1],
@@ -739,10 +731,8 @@ def _MaxPoolGradGrad(op, grad):
 def _MaxPoolGradGradV2(op, grad):
   ksize = op.inputs[3]
   strides = op.inputs[4]
-  return (array_ops.zeros(
-      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
-          array_ops.zeros(
-              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+  return (array_ops.zeros_like(op.inputs[0]),
+          array_ops.zeros_like(op.inputs[1]),
           gen_nn_ops.max_pool_grad_grad_v2(
               op.inputs[0],
               op.inputs[1],
@@ -755,10 +745,8 @@ def _MaxPoolGradGradV2(op, grad):
 
 @ops.RegisterGradient("MaxPoolGradGrad")
 def _MaxPoolGradGradGrad(op, grad):
-  return (array_ops.zeros(
-      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
-          array_ops.zeros(
-              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+  return (array_ops.zeros_like(op.inputs[0]),
+          array_ops.zeros_like(op.inputs[1]),
           gen_nn_ops.max_pool_grad(
               op.inputs[0],
               op.inputs[1],
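
Sanity-check sketch (not part of the patch): every hunk above relies on array_ops.zeros_like(x) producing the same shape and dtype as array_ops.zeros(array_ops.shape(x), dtype=x.dtype). The snippet below illustrates that equivalence via the public tf.* API as a stand-in for the internal array_ops module; it assumes TensorFlow 2.x with eager execution.

import tensorflow as tf

x = tf.constant([[1.5, -2.0], [0.0, 3.25]], dtype=tf.float64)

# Pattern removed by this diff: build an explicit shape, then pass the dtype.
old_style = tf.zeros(tf.shape(x), dtype=x.dtype)
# Pattern introduced by this diff: infer both shape and dtype from x.
new_style = tf.zeros_like(x)

assert old_style.shape == new_style.shape == x.shape
assert old_style.dtype == new_style.dtype == x.dtype
print(tf.reduce_all(old_style == new_style).numpy())  # True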