optimization on zeros_like_impl

Instead of always calling convert_to_tensor, we first check whether the input is already a tensor; if it is, the conversion is skipped, and otherwise it is converted as before.

This helps considerably for variable inputs.

In the benchmark, performance improved by ~30%.

PiperOrigin-RevId: 275479807
Change-Id: I81bbb4e6bc5a7918704b13f96619e53217ce5fea
This commit is contained in:
Yanhua Sun 2019-10-18 08:43:41 -07:00 committed by TensorFlower Gardener
parent 256a92a079
commit d3c1452077

View File

@ -2525,10 +2525,13 @@ def zeros_like_v2(
def zeros_like_impl(tensor, dtype, name, optimize=True):
"""Internal implementation for the v1/v2 zeros_like API calls."""
with ops.name_scope(name, "zeros_like", [tensor]) as name:
if not tensor_util.is_tensor(tensor):
tensor = ops.convert_to_tensor(tensor, name="tensor")
tensor_shape = tensor.shape
tensor_dtype = tensor.dtype
if context.executing_eagerly():
if dtype is not None and dtype != tensor.dtype:
if dtype is not None and dtype != tensor_dtype:
return zeros(
shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
return gen_array_ops.zeros_like(tensor, name=name)
@ -2536,13 +2539,13 @@ def zeros_like_impl(tensor, dtype, name, optimize=True):
# For now, variant types must be created via zeros_like; as we need to
# pass the input variant object to the proper zeros callback.
if (optimize and tensor.shape.is_fully_defined() and
tensor.dtype != dtypes.variant):
if (optimize and tensor_shape.is_fully_defined() and
tensor_dtype != dtypes.variant):
# We can produce a zeros tensor independent of the value of 'tensor',
# since the shape is known statically.
return zeros(tensor.shape, dtype=dtype or tensor.dtype, name=name)
return zeros(tensor_shape, dtype=dtype or tensor_dtype, name=name)
if dtype is not None and dtype != tensor.dtype and dtype != dtypes.variant:
if dtype is not None and dtype != tensor_dtype and dtype != dtypes.variant:
return zeros(
shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
else: