Silence warnings about expanding indexed slices into dense tensors where it is intended.

PiperOrigin-RevId: 240179107
This commit is contained in:
A. Unique TensorFlower 2019-03-25 11:12:04 -07:00 committed by TensorFlower Gardener
parent 7e1b975e85
commit 787c2fd411

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@ -406,7 +408,11 @@ def _GatherGrad(op, grad):
indices = op.inputs[1]
size = array_ops.expand_dims(array_ops.size(indices), 0)
values_shape = array_ops.concat([size, params_shape[1:]], 0)
values = array_ops.reshape(grad, values_shape)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Converting sparse IndexedSlices to a dense Tensor.*")
values = array_ops.reshape(grad, values_shape)
indices = array_ops.reshape(indices, size)
return [ops.IndexedSlices(values, indices, params_shape), None]
@ -437,7 +443,11 @@ def _GatherV2Grad(op, grad):
else:
params_tail_shape = params_shape[1:]
values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
values = array_ops.reshape(grad, values_shape)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Converting sparse IndexedSlices to a dense Tensor.*")
values = array_ops.reshape(grad, values_shape)
indices = array_ops.reshape(indices, indices_size)
return [ops.IndexedSlices(values, indices, params_shape), None, None]
@ -451,7 +461,11 @@ def _GatherV2Grad(op, grad):
outer_dims + 1 + inner_dims)
values_shape = array_ops.concat([outer_shape, indices_size, inner_shape], 0)
values = array_ops.reshape(grad, values_shape)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Converting sparse IndexedSlices to a dense Tensor.*")
values = array_ops.reshape(grad, values_shape)
indices = array_ops.reshape(indices, indices_size)
# We need to sum up every slice `values[..., i, ....]` corresponding to
@ -519,7 +533,11 @@ ops.NotDifferentiable("StopGradient")
@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op, grad):
return [array_ops.reshape(grad, array_ops.shape(op.inputs[0])), None]
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Converting sparse IndexedSlices to a dense Tensor.*")
return [array_ops.reshape(grad, array_ops.shape(op.inputs[0])), None]
ops.NotDifferentiable("InvertPermutation")
@ -527,7 +545,11 @@ ops.NotDifferentiable("InvertPermutation")
def _ReshapeToInput(op, grad):
"""Reshapes the gradient to the shape of the original input."""
return array_ops.reshape(grad, array_ops.shape(op.inputs[0]))
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Converting sparse IndexedSlices to a dense Tensor.*")
return array_ops.reshape(grad, array_ops.shape(op.inputs[0]))
@ops.RegisterGradient("ExpandDims")
@ -783,10 +805,14 @@ def _ExtractImagePatchesGrad(op, grad):
(1, 0),
(input_indices_num - 1, output_indices_num))
grad_expanded = array_ops.transpose(
array_ops.reshape(
grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
(1, 2, 3, 4, 0, 5))
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Converting sparse IndexedSlices to a dense Tensor.*")
grad_expanded = array_ops.transpose(
array_ops.reshape(
grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
(1, 2, 3, 4, 0, 5))
grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))
jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)
@ -844,10 +870,14 @@ def _ExtractVolumePatchesGrad(op, grad):
sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
(input_indices_num - 1, output_indices_num))
grad_expanded = array_ops.transpose(
array_ops.reshape(grad, (batch_size, planes_out, rows_out, cols_out,
ksize_p, ksize_r, ksize_c, channels)),
(1, 2, 3, 4, 5, 6, 0, 7))
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Converting sparse IndexedSlices to a dense Tensor.*")
grad_expanded = array_ops.transpose(
array_ops.reshape(grad, (batch_size, planes_out, rows_out, cols_out,
ksize_p, ksize_r, ksize_c, channels)),
(1, 2, 3, 4, 5, 6, 0, 7))
grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))
jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)