Silence warnings about expanding IndexedSlices into dense tensors where the conversion is intended.
PiperOrigin-RevId: 240179107
commit 787c2fd411
parent 7e1b975e85
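Every hunk below applies the same fix: the one reshape that deliberately densifies an IndexedSlices gradient is wrapped in warnings.catch_warnings() with an "ignore" filter scoped to the conversion warning's message, so the filter is undone when the block exits. A minimal standalone sketch of that pattern (plain Python, not part of the diff; noisy_convert and intended_densify are illustrative stand-ins):

    import warnings


    def noisy_convert(x):
      # Stand-in for TensorFlow's implicit IndexedSlices -> Tensor conversion,
      # which emits a UserWarning beginning with this text.
      warnings.warn("Converting sparse IndexedSlices to a dense Tensor of "
                    "unknown shape. This may consume a large amount of memory.")
      return x


    def intended_densify(x):
      # Same shape as the hunks below: the "ignore" filter lives only inside
      # the with-block, so unrelated conversions elsewhere still warn.
      with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="Converting sparse IndexedSlices to a dense Tensor.*")
        return noisy_convert(x)


    intended_densify(42)  # no warning is emitted for this intended conversion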
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import warnings
+
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -406,7 +408,11 @@ def _GatherGrad(op, grad):
   indices = op.inputs[1]
   size = array_ops.expand_dims(array_ops.size(indices), 0)
   values_shape = array_ops.concat([size, params_shape[1:]], 0)
-  values = array_ops.reshape(grad, values_shape)
+  with warnings.catch_warnings():
+    warnings.filterwarnings(
+        "ignore",
+        message="Converting sparse IndexedSlices to a dense Tensor.*")
+    values = array_ops.reshape(grad, values_shape)
   indices = array_ops.reshape(indices, size)
   return [ops.IndexedSlices(values, indices, params_shape), None]
 
@@ -437,7 +443,11 @@ def _GatherV2Grad(op, grad):
     else:
       params_tail_shape = params_shape[1:]
     values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
-    values = array_ops.reshape(grad, values_shape)
+    with warnings.catch_warnings():
+      warnings.filterwarnings(
+          "ignore",
+          message="Converting sparse IndexedSlices to a dense Tensor.*")
+      values = array_ops.reshape(grad, values_shape)
     indices = array_ops.reshape(indices, indices_size)
     return [ops.IndexedSlices(values, indices, params_shape), None, None]
 
@@ -451,7 +461,11 @@ def _GatherV2Grad(op, grad):
                                       outer_dims + 1 + inner_dims)
 
   values_shape = array_ops.concat([outer_shape, indices_size, inner_shape], 0)
-  values = array_ops.reshape(grad, values_shape)
+  with warnings.catch_warnings():
+    warnings.filterwarnings(
+        "ignore",
+        message="Converting sparse IndexedSlices to a dense Tensor.*")
+    values = array_ops.reshape(grad, values_shape)
   indices = array_ops.reshape(indices, indices_size)
 
   # We need to sum up every slice `values[..., i, ....]` corresponding to
@@ -519,7 +533,11 @@ ops.NotDifferentiable("StopGradient")
 
 @ops.RegisterGradient("Reshape")
 def _ReshapeGrad(op, grad):
-  return [array_ops.reshape(grad, array_ops.shape(op.inputs[0])), None]
+  with warnings.catch_warnings():
+    warnings.filterwarnings(
+        "ignore",
+        message="Converting sparse IndexedSlices to a dense Tensor.*")
+    return [array_ops.reshape(grad, array_ops.shape(op.inputs[0])), None]
 
 
 ops.NotDifferentiable("InvertPermutation")
@@ -527,7 +545,11 @@ ops.NotDifferentiable("InvertPermutation")
 
 def _ReshapeToInput(op, grad):
   """Reshapes the gradient to the shape of the original input."""
-  return array_ops.reshape(grad, array_ops.shape(op.inputs[0]))
+  with warnings.catch_warnings():
+    warnings.filterwarnings(
+        "ignore",
+        message="Converting sparse IndexedSlices to a dense Tensor.*")
+    return array_ops.reshape(grad, array_ops.shape(op.inputs[0]))
 
 
 @ops.RegisterGradient("ExpandDims")
@@ -783,10 +805,14 @@ def _ExtractImagePatchesGrad(op, grad):
                                    (1, 0),
                                    (input_indices_num - 1, output_indices_num))
 
-  grad_expanded = array_ops.transpose(
-      array_ops.reshape(
-          grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
-      (1, 2, 3, 4, 0, 5))
+  with warnings.catch_warnings():
+    warnings.filterwarnings(
+        "ignore",
+        message="Converting sparse IndexedSlices to a dense Tensor.*")
+    grad_expanded = array_ops.transpose(
+        array_ops.reshape(
+            grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
+        (1, 2, 3, 4, 0, 5))
   grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))
 
   jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)
@@ -844,10 +870,14 @@ def _ExtractVolumePatchesGrad(op, grad):
   sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                    (input_indices_num - 1, output_indices_num))
 
-  grad_expanded = array_ops.transpose(
-      array_ops.reshape(grad, (batch_size, planes_out, rows_out, cols_out,
-                               ksize_p, ksize_r, ksize_c, channels)),
-      (1, 2, 3, 4, 5, 6, 0, 7))
+  with warnings.catch_warnings():
+    warnings.filterwarnings(
+        "ignore",
+        message="Converting sparse IndexedSlices to a dense Tensor.*")
+    grad_expanded = array_ops.transpose(
+        array_ops.reshape(grad, (batch_size, planes_out, rows_out, cols_out,
+                                 ksize_p, ksize_r, ksize_c, channels)),
+        (1, 2, 3, 4, 5, 6, 0, 7))
   grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))
 
   jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)
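For reference, a rough TF 1.x graph-mode sketch of the situation these hunks target (tf.placeholder and tf.gradients assumed available; the warning appears only when the IndexedSlices' dense shape is statically unknown or very large). Here the GatherV2 gradient reaches _ReshapeGrad as an IndexedSlices, and the deliberate densification used to print the conversion warning:

    import tensorflow as tf  # assumes a TF 1.x build, matching this change

    x = tf.placeholder(tf.float32, shape=[None, 64])  # unknown leading dimension
    y = tf.reshape(x, [-1, 64])
    z = tf.gather(y, [0, 1, 2])
    loss = tf.reduce_sum(z)
    # Before this change, building the backward graph warned:
    #   UserWarning: Converting sparse IndexedSlices to a dense Tensor of
    #   unknown shape. This may consume a large amount of memory.
    grads = tf.gradients(loss, [x])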