Merge pull request #31673 from saxenasaurabh/cherrypicks_VWQL1
Make `maybe_set_static_shape` a no-op when `shape` is a python constant.
This commit is contained in:
commit
4756cfbbec
@ -964,10 +964,40 @@ def shape_tensor(shape): # pylint: disable=invalid-name
|
|||||||
return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
|
return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
|
||||||
|
|
||||||
|
|
||||||
|
# DO NOT USE: For testing only.
|
||||||
|
_ENABLE_MAYBE_SET_STATIC_SHAPE = True
|
||||||
|
|
||||||
|
|
||||||
def maybe_set_static_shape(tensor, shape): # pylint: disable=invalid-name
|
def maybe_set_static_shape(tensor, shape): # pylint: disable=invalid-name
|
||||||
if (not context.executing_eagerly() and
|
"""Sets the shape of `tensor` to the `shape`'s constant value, if inferrable.
|
||||||
|
|
||||||
|
This is a temporary workaround to fix shape inference across functional op
|
||||||
|
boundaries. E.g.
|
||||||
|
|
||||||
|
```python
|
||||||
|
shape = tf.constant([3])
|
||||||
|
@tf.function
|
||||||
|
def f():
|
||||||
|
u = tf.random_uniform(shape)
|
||||||
|
return u
|
||||||
|
```
|
||||||
|
|
||||||
|
If we were to rely solely on C++ shape inference, the shape of `u` inside
|
||||||
|
`f` would be unknown because C++ shape inference is not aware of the outer
|
||||||
|
graph and all it sees is a Placeholder node when backtracing the captured
|
||||||
|
tensor for `shape`. `maybe_set_static_shape` computes the static shape value
|
||||||
|
of `shape` by traversing the `FuncGraph` boundaries and sets the correct
|
||||||
|
shape.
|
||||||
|
|
||||||
|
A longer term solution would be to fix C++ shape inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: A tensor.
|
||||||
|
shape: A shape tensor.
|
||||||
|
"""
|
||||||
|
if (_ENABLE_MAYBE_SET_STATIC_SHAPE and not context.executing_eagerly() and
|
||||||
ops.get_default_graph().building_function and
|
ops.get_default_graph().building_function and
|
||||||
not tensor.shape.is_fully_defined()):
|
not tensor.shape.is_fully_defined() and is_tensor(shape)):
|
||||||
shape = shape_tensor(shape)
|
shape = shape_tensor(shape)
|
||||||
const_shape = constant_value_as_shape(shape)
|
const_shape = constant_value_as_shape(shape)
|
||||||
tensor.set_shape(const_shape)
|
tensor.set_shape(const_shape)
|
||||||
|
@ -18,11 +18,13 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import func_graph
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -1080,6 +1082,52 @@ class ConstantValueAsShapeTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
|
||||||
|
|
||||||
|
class MaybeSetStaticShapeTest(test.TestCase):
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def disableSetStaticShape(self):
|
||||||
|
flag_old = tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE
|
||||||
|
tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = False
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = flag_old
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testMaybeSetStaticShape(self):
|
||||||
|
shape = constant_op.constant([2, 5], dtype=dtypes.int32)
|
||||||
|
|
||||||
|
def reshape():
|
||||||
|
v = array_ops.zeros([10])
|
||||||
|
return array_ops.reshape(v, shape)
|
||||||
|
|
||||||
|
with self.disableSetStaticShape():
|
||||||
|
graph_without_shape_propagation = func_graph.func_graph_from_py_func(
|
||||||
|
"without_shape_propagation", reshape, [], {})
|
||||||
|
graph_with_shape_propagation = func_graph.func_graph_from_py_func(
|
||||||
|
"with_shape_propagation", reshape, [], {})
|
||||||
|
self.assertCountEqual(
|
||||||
|
[op.type for op in graph_without_shape_propagation.get_operations()],
|
||||||
|
[op.type for op in graph_with_shape_propagation.get_operations()])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testMaybeSetStaticShapeScalarShape(self):
|
||||||
|
|
||||||
|
def reshape():
|
||||||
|
v = array_ops.placeholder(dtypes.float32)
|
||||||
|
t = array_ops.reshape(v, [-1])
|
||||||
|
return t
|
||||||
|
|
||||||
|
with self.disableSetStaticShape():
|
||||||
|
graph_without_shape_propagation = func_graph.func_graph_from_py_func(
|
||||||
|
"without_shape_propagation", reshape, [], {})
|
||||||
|
graph_with_shape_propagation = func_graph.func_graph_from_py_func(
|
||||||
|
"with_shape_propagation", reshape, [], {})
|
||||||
|
self.assertCountEqual(
|
||||||
|
[op.type for op in graph_without_shape_propagation.get_operations()],
|
||||||
|
[op.type for op in graph_with_shape_propagation.get_operations()])
|
||||||
|
|
||||||
|
|
||||||
class ShapeTensorTest(test_util.TensorFlowTestCase):
|
class ShapeTensorTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
Loading…
x
Reference in New Issue
Block a user