Minimize calls to tesor_util.constant_value in array_grad._StridedSliceGrad.
PiperOrigin-RevId: 293716338 Change-Id: Id05c9afa21f80543ef783d0cfbc33027caecdf05
This commit is contained in:
parent
6a7b7a211f
commit
265e1be025
@ -273,14 +273,14 @@ def _StridedSliceGrad(op, grad):
|
|||||||
# be the same.
|
# be the same.
|
||||||
x = array_ops.shape(op.inputs[0], out_type=begin.dtype)
|
x = array_ops.shape(op.inputs[0], out_type=begin.dtype)
|
||||||
|
|
||||||
if tensor_util.constant_value(x) is not None:
|
x_static = tensor_util.constant_value(x)
|
||||||
x = tensor_util.constant_value(x)
|
x = x_static if x_static is not None else x
|
||||||
if tensor_util.constant_value(begin) is not None:
|
begin_static = tensor_util.constant_value(begin)
|
||||||
begin = tensor_util.constant_value(begin)
|
begin = begin_static if begin_static is not None else begin
|
||||||
if tensor_util.constant_value(end) is not None:
|
end_static = tensor_util.constant_value(end)
|
||||||
end = tensor_util.constant_value(end)
|
end = end_static if end_static is not None else end
|
||||||
if tensor_util.constant_value(strides) is not None:
|
strides_static = tensor_util.constant_value(strides)
|
||||||
strides = tensor_util.constant_value(strides)
|
strides = strides_static if strides_static is not None else strides
|
||||||
|
|
||||||
return array_ops.strided_slice_grad(
|
return array_ops.strided_slice_grad(
|
||||||
x,
|
x,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user