Silence warnings about expanding indexed slices into dense tensors where it is intended.
PiperOrigin-RevId: 240179107
This commit is contained in:
parent
7e1b975e85
commit
787c2fd411
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -406,6 +408,10 @@ def _GatherGrad(op, grad):
|
|||||||
indices = op.inputs[1]
|
indices = op.inputs[1]
|
||||||
size = array_ops.expand_dims(array_ops.size(indices), 0)
|
size = array_ops.expand_dims(array_ops.size(indices), 0)
|
||||||
values_shape = array_ops.concat([size, params_shape[1:]], 0)
|
values_shape = array_ops.concat([size, params_shape[1:]], 0)
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message="Converting sparse IndexedSlices to a dense Tensor.*")
|
||||||
values = array_ops.reshape(grad, values_shape)
|
values = array_ops.reshape(grad, values_shape)
|
||||||
indices = array_ops.reshape(indices, size)
|
indices = array_ops.reshape(indices, size)
|
||||||
return [ops.IndexedSlices(values, indices, params_shape), None]
|
return [ops.IndexedSlices(values, indices, params_shape), None]
|
||||||
@ -437,6 +443,10 @@ def _GatherV2Grad(op, grad):
|
|||||||
else:
|
else:
|
||||||
params_tail_shape = params_shape[1:]
|
params_tail_shape = params_shape[1:]
|
||||||
values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
|
values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message="Converting sparse IndexedSlices to a dense Tensor.*")
|
||||||
values = array_ops.reshape(grad, values_shape)
|
values = array_ops.reshape(grad, values_shape)
|
||||||
indices = array_ops.reshape(indices, indices_size)
|
indices = array_ops.reshape(indices, indices_size)
|
||||||
return [ops.IndexedSlices(values, indices, params_shape), None, None]
|
return [ops.IndexedSlices(values, indices, params_shape), None, None]
|
||||||
@ -451,6 +461,10 @@ def _GatherV2Grad(op, grad):
|
|||||||
outer_dims + 1 + inner_dims)
|
outer_dims + 1 + inner_dims)
|
||||||
|
|
||||||
values_shape = array_ops.concat([outer_shape, indices_size, inner_shape], 0)
|
values_shape = array_ops.concat([outer_shape, indices_size, inner_shape], 0)
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message="Converting sparse IndexedSlices to a dense Tensor.*")
|
||||||
values = array_ops.reshape(grad, values_shape)
|
values = array_ops.reshape(grad, values_shape)
|
||||||
indices = array_ops.reshape(indices, indices_size)
|
indices = array_ops.reshape(indices, indices_size)
|
||||||
|
|
||||||
@ -519,6 +533,10 @@ ops.NotDifferentiable("StopGradient")
|
|||||||
|
|
||||||
@ops.RegisterGradient("Reshape")
|
@ops.RegisterGradient("Reshape")
|
||||||
def _ReshapeGrad(op, grad):
|
def _ReshapeGrad(op, grad):
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message="Converting sparse IndexedSlices to a dense Tensor.*")
|
||||||
return [array_ops.reshape(grad, array_ops.shape(op.inputs[0])), None]
|
return [array_ops.reshape(grad, array_ops.shape(op.inputs[0])), None]
|
||||||
|
|
||||||
|
|
||||||
@ -527,6 +545,10 @@ ops.NotDifferentiable("InvertPermutation")
|
|||||||
|
|
||||||
def _ReshapeToInput(op, grad):
|
def _ReshapeToInput(op, grad):
|
||||||
"""Reshapes the gradient to the shape of the original input."""
|
"""Reshapes the gradient to the shape of the original input."""
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message="Converting sparse IndexedSlices to a dense Tensor.*")
|
||||||
return array_ops.reshape(grad, array_ops.shape(op.inputs[0]))
|
return array_ops.reshape(grad, array_ops.shape(op.inputs[0]))
|
||||||
|
|
||||||
|
|
||||||
@ -783,6 +805,10 @@ def _ExtractImagePatchesGrad(op, grad):
|
|||||||
(1, 0),
|
(1, 0),
|
||||||
(input_indices_num - 1, output_indices_num))
|
(input_indices_num - 1, output_indices_num))
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message="Converting sparse IndexedSlices to a dense Tensor.*")
|
||||||
grad_expanded = array_ops.transpose(
|
grad_expanded = array_ops.transpose(
|
||||||
array_ops.reshape(
|
array_ops.reshape(
|
||||||
grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
|
grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
|
||||||
@ -844,6 +870,10 @@ def _ExtractVolumePatchesGrad(op, grad):
|
|||||||
sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
|
sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
|
||||||
(input_indices_num - 1, output_indices_num))
|
(input_indices_num - 1, output_indices_num))
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message="Converting sparse IndexedSlices to a dense Tensor.*")
|
||||||
grad_expanded = array_ops.transpose(
|
grad_expanded = array_ops.transpose(
|
||||||
array_ops.reshape(grad, (batch_size, planes_out, rows_out, cols_out,
|
array_ops.reshape(grad, (batch_size, planes_out, rows_out, cols_out,
|
||||||
ksize_p, ksize_r, ksize_c, channels)),
|
ksize_p, ksize_r, ksize_c, channels)),
|
||||||
|
Loading…
Reference in New Issue
Block a user