Minimize calls to tesor_util.constant_value in array_grad._StridedSliceGrad.

PiperOrigin-RevId: 293716338
Change-Id: Id05c9afa21f80543ef783d0cfbc33027caecdf05
This commit is contained in:
A. Unique TensorFlower 2020-02-06 17:38:52 -08:00 committed by TensorFlower Gardener
parent 6a7b7a211f
commit 265e1be025

View File

@ -273,14 +273,14 @@ def _StridedSliceGrad(op, grad):
# be the same.
x = array_ops.shape(op.inputs[0], out_type=begin.dtype)
if tensor_util.constant_value(x) is not None:
x = tensor_util.constant_value(x)
if tensor_util.constant_value(begin) is not None:
begin = tensor_util.constant_value(begin)
if tensor_util.constant_value(end) is not None:
end = tensor_util.constant_value(end)
if tensor_util.constant_value(strides) is not None:
strides = tensor_util.constant_value(strides)
x_static = tensor_util.constant_value(x)
x = x_static if x_static is not None else x
begin_static = tensor_util.constant_value(begin)
begin = begin_static if begin_static is not None else begin
end_static = tensor_util.constant_value(end)
end = end_static if end_static is not None else end
strides_static = tensor_util.constant_value(strides)
strides = strides_static if strides_static is not None else strides
return array_ops.strided_slice_grad(
x,