PR #40778: Slightly improve performance of TensorShape methods in the common case
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/40778 `TensorShape` is used a lot within the codebase and can end up on the hot path during eager execution. This PR simplifies some method calls to slightly improve performance for the case of fully defined shapes. Most changes are related to the switch to a "Easier to ask for forgiveness than permission" (EAFP) pattern for accessing `TensorShape._dims` which is quite common in Python, but I would be happy to revert that if you prefer to still explicitely check for existance before access. This PR improves performance of `.merge_elements()` by 50% and `.num_elements()` by 25%. `.rank`, `as_list` and `__add__` are 5-10% faster on my machine which probably doesn't really matter in practice. Copybara import of the project: -- 8d93940cbee2d32ddafac4ba2d277f90cb60a09b by Lukas Geiger <lukas.geiger94@gmail.com>: Slightly improve performance of TensorShape PiperOrigin-RevId: 334589733 Change-Id: I96c1329b8ae5d9a5698bb2aa131977874f674286
This commit is contained in:
parent
48e66b2056
commit
ed63e8a52c
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import operator
|
||||
import six
|
||||
|
||||
from tensorflow.core.framework import tensor_shape_pb2
|
||||
@ -916,10 +918,7 @@ class TensorShape(object):
|
||||
def num_elements(self):
|
||||
"""Returns the total number of elements, or none for incomplete shapes."""
|
||||
if self.is_fully_defined():
|
||||
size = 1
|
||||
for dim in self._dims:
|
||||
size *= dim.value
|
||||
return size
|
||||
return functools.reduce(operator.mul, self.as_list(), 1)
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -942,19 +941,20 @@ class TensorShape(object):
|
||||
other = as_shape(other)
|
||||
if self._dims is None:
|
||||
return other
|
||||
if other.dims is None:
|
||||
return self
|
||||
else:
|
||||
try:
|
||||
self.assert_same_rank(other)
|
||||
new_dims = []
|
||||
for i, dim in enumerate(self._dims):
|
||||
new_dims.append(dim.merge_with(other[i]))
|
||||
new_dims = [
|
||||
dim.merge_with(other_dim)
|
||||
for dim, other_dim in zip(self._dims, other.dims)
|
||||
]
|
||||
return TensorShape(new_dims)
|
||||
except ValueError:
|
||||
raise ValueError("Shapes %s and %s are not compatible" % (self, other))
|
||||
|
||||
def __add__(self, other):
|
||||
if not isinstance(other, TensorShape):
|
||||
other = TensorShape(other)
|
||||
return self.concatenate(other)
|
||||
|
||||
def __radd__(self, other):
|
||||
@ -1157,10 +1157,10 @@ class TensorShape(object):
|
||||
if self._dims is None or other.dims is None or self.rank != other.rank:
|
||||
return unknown_shape()
|
||||
|
||||
dims = [(Dimension(None))] * self.rank
|
||||
for i, (d1, d2) in enumerate(zip(self._dims, other.dims)):
|
||||
if d1 is not None and d2 is not None and d1 == d2:
|
||||
dims[i] = d1
|
||||
dims = [
|
||||
d1 if d1 is not None and d2 is not None and d1 == d2 else None
|
||||
for d1, d2 in zip(self._dims, other.dims)
|
||||
]
|
||||
return TensorShape(dims)
|
||||
|
||||
def is_fully_defined(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user