diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 20508f37eb7..de29cc53c1f 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools +import operator import six from tensorflow.core.framework import tensor_shape_pb2 @@ -916,10 +918,7 @@ class TensorShape(object): def num_elements(self): """Returns the total number of elements, or none for incomplete shapes.""" if self.is_fully_defined(): - size = 1 - for dim in self._dims: - size *= dim.value - return size + return functools.reduce(operator.mul, self.as_list(), 1) else: return None @@ -942,19 +941,20 @@ class TensorShape(object): other = as_shape(other) if self._dims is None: return other + if other.dims is None: + return self else: try: self.assert_same_rank(other) - new_dims = [] - for i, dim in enumerate(self._dims): - new_dims.append(dim.merge_with(other[i])) + new_dims = [ + dim.merge_with(other_dim) + for dim, other_dim in zip(self._dims, other.dims) + ] return TensorShape(new_dims) except ValueError: raise ValueError("Shapes %s and %s are not compatible" % (self, other)) def __add__(self, other): - if not isinstance(other, TensorShape): - other = TensorShape(other) return self.concatenate(other) def __radd__(self, other): @@ -1157,10 +1157,10 @@ class TensorShape(object): if self._dims is None or other.dims is None or self.rank != other.rank: return unknown_shape() - dims = [(Dimension(None))] * self.rank - for i, (d1, d2) in enumerate(zip(self._dims, other.dims)): - if d1 is not None and d2 is not None and d1 == d2: - dims[i] = d1 + dims = [ + d1 if d1 is not None and d2 is not None and d1 == d2 else None + for d1, d2 in zip(self._dims, other.dims) + ] return TensorShape(dims) def is_fully_defined(self):