optimization on zeros_like_impl
Instead of calling convert_to_tensor every time, we first check whether the input is already a tensor. If it is, the conversion is skipped; otherwise, it is performed as before. This helps a lot for variable inputs. In benchmarks, performance improves by ~30%. PiperOrigin-RevId: 275479807 Change-Id: I81bbb4e6bc5a7918704b13f96619e53217ce5fea
This commit is contained in:
parent
256a92a079
commit
d3c1452077
@ -2525,10 +2525,13 @@ def zeros_like_v2(
|
||||
def zeros_like_impl(tensor, dtype, name, optimize=True):
|
||||
"""Internal implementation for the v1/v2 zeros_like API calls."""
|
||||
with ops.name_scope(name, "zeros_like", [tensor]) as name:
|
||||
if not tensor_util.is_tensor(tensor):
|
||||
tensor = ops.convert_to_tensor(tensor, name="tensor")
|
||||
tensor_shape = tensor.shape
|
||||
tensor_dtype = tensor.dtype
|
||||
|
||||
if context.executing_eagerly():
|
||||
if dtype is not None and dtype != tensor.dtype:
|
||||
if dtype is not None and dtype != tensor_dtype:
|
||||
return zeros(
|
||||
shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
|
||||
return gen_array_ops.zeros_like(tensor, name=name)
|
||||
@ -2536,13 +2539,13 @@ def zeros_like_impl(tensor, dtype, name, optimize=True):
|
||||
# For now, variant types must be created via zeros_like; as we need to
|
||||
# pass the input variant object to the proper zeros callback.
|
||||
|
||||
if (optimize and tensor.shape.is_fully_defined() and
|
||||
tensor.dtype != dtypes.variant):
|
||||
if (optimize and tensor_shape.is_fully_defined() and
|
||||
tensor_dtype != dtypes.variant):
|
||||
# We can produce a zeros tensor independent of the value of 'tensor',
|
||||
# since the shape is known statically.
|
||||
return zeros(tensor.shape, dtype=dtype or tensor.dtype, name=name)
|
||||
return zeros(tensor_shape, dtype=dtype or tensor_dtype, name=name)
|
||||
|
||||
if dtype is not None and dtype != tensor.dtype and dtype != dtypes.variant:
|
||||
if dtype is not None and dtype != tensor_dtype and dtype != dtypes.variant:
|
||||
return zeros(
|
||||
shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user