Minimize calls to tesor_util.constant_value in array_grad._StridedSliceGrad.
PiperOrigin-RevId: 293716338 Change-Id: Id05c9afa21f80543ef783d0cfbc33027caecdf05
This commit is contained in:
parent
6a7b7a211f
commit
265e1be025
@ -273,14 +273,14 @@ def _StridedSliceGrad(op, grad):
|
||||
# be the same.
|
||||
x = array_ops.shape(op.inputs[0], out_type=begin.dtype)
|
||||
|
||||
if tensor_util.constant_value(x) is not None:
|
||||
x = tensor_util.constant_value(x)
|
||||
if tensor_util.constant_value(begin) is not None:
|
||||
begin = tensor_util.constant_value(begin)
|
||||
if tensor_util.constant_value(end) is not None:
|
||||
end = tensor_util.constant_value(end)
|
||||
if tensor_util.constant_value(strides) is not None:
|
||||
strides = tensor_util.constant_value(strides)
|
||||
x_static = tensor_util.constant_value(x)
|
||||
x = x_static if x_static is not None else x
|
||||
begin_static = tensor_util.constant_value(begin)
|
||||
begin = begin_static if begin_static is not None else begin
|
||||
end_static = tensor_util.constant_value(end)
|
||||
end = end_static if end_static is not None else end
|
||||
strides_static = tensor_util.constant_value(strides)
|
||||
strides = strides_static if strides_static is not None else strides
|
||||
|
||||
return array_ops.strided_slice_grad(
|
||||
x,
|
||||
|
Loading…
Reference in New Issue
Block a user