Reduce TensorShape.__init__ overhead by 50%.
TensorShape.__init__ is on the hotpath because a TensorShape is created the first time EagerTensor.shape is called. The TensorShape is created from EagerTensor._shape_tuple, which is a tuple of ints. This change optimizes the code for this common path. PiperOrigin-RevId: 316922384 Change-Id: I063ea393450123ea4150972e5c73647f03a29cf5
This commit is contained in:
parent
f8fd28e4a0
commit
2a8bbb92b7
|
@ -465,8 +465,8 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
for _ in range(10):
|
||||
yield [20]
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, r"Failed to convert '\[\[1\]\]' to a shape"):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
r"Dimension value must be integer or None"):
|
||||
dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=(dtypes.int64), output_shapes=[[1]])
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ from tensorflow.python.eager import test
|
|||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
@ -1441,6 +1442,19 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
|||
|
||||
self._run(fn, 10000)
|
||||
|
||||
def benchmark_tf_tensor_shape_creation_overhead(self):
|
||||
# A `TensorShape` is created the first time `EagerTensor.shape` is
|
||||
# called, which puts `TensorShape.__init__` on the hotpath. The
|
||||
# `TensorShape` is created from `EagerTensor._shape_tuple`.
|
||||
|
||||
x = array_ops.ones((1, 1))
|
||||
shape_tuple = x._shape_tuple()
|
||||
|
||||
def fn():
|
||||
tensor_shape.TensorShape(shape_tuple)
|
||||
|
||||
self._run(fn, 100000)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
@ -184,10 +184,14 @@ class Dimension(object):
|
|||
|
||||
def __init__(self, value):
|
||||
"""Creates a new Dimension with the given value."""
|
||||
if value is None:
|
||||
if isinstance(value, int): # Most common case.
|
||||
if value < 0:
|
||||
raise ValueError("Dimension %d must be >= 0" % value)
|
||||
self._value = value
|
||||
elif value is None:
|
||||
self._value = None
|
||||
elif isinstance(value, Dimension):
|
||||
self._value = value
|
||||
self._value = value._value
|
||||
else:
|
||||
try:
|
||||
# int(...) compensates for the int/long dichotomy on Python 2.X.
|
||||
|
@ -748,7 +752,9 @@ class TensorShape(object):
|
|||
Raises:
|
||||
TypeError: If dims cannot be converted to a list of dimensions.
|
||||
"""
|
||||
if dims is None:
|
||||
if isinstance(dims, (tuple, list)): # Most common case.
|
||||
self._dims = [Dimension(d) for d in dims]
|
||||
elif dims is None:
|
||||
self._dims = None
|
||||
elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
|
||||
if dims.unknown_rank:
|
||||
|
|
Loading…
Reference in New Issue