Merge pull request #31673 from saxenasaurabh/cherrypicks_VWQL1
Make `maybe_set_static_shape` a no-op when `shape` is a python constant.
This commit is contained in:
commit
4756cfbbec
@ -964,10 +964,40 @@ def shape_tensor(shape): # pylint: disable=invalid-name
|
||||
return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
|
||||
|
||||
|
||||
# DO NOT USE: For testing only.
|
||||
_ENABLE_MAYBE_SET_STATIC_SHAPE = True
|
||||
|
||||
|
||||
def maybe_set_static_shape(tensor, shape): # pylint: disable=invalid-name
|
||||
if (not context.executing_eagerly() and
|
||||
"""Sets the shape of `tensor` to the `shape`'s constant value, if inferrable.
|
||||
|
||||
This is a temporary workaround to fix shape inference across functional op
|
||||
boundaries. E.g.
|
||||
|
||||
```python
|
||||
shape = tf.constant([3])
|
||||
@tf.function
|
||||
def f():
|
||||
u = tf.random_uniform(shape)
|
||||
return u
|
||||
```
|
||||
|
||||
If we were to rely solely on C++ shape inference, the shape of `u` inside
|
||||
`f` would be unknown because C++ shape inference is not aware of the outer
|
||||
graph and all it sees is a Placeholder node when backtracing the captured
|
||||
tensor for `shape`. `maybe_set_static_shape` computes the static shape value
|
||||
of `shape` by traversing the `FuncGraph` boundaries and sets the correct
|
||||
shape.
|
||||
|
||||
A longer term solution would be to fix C++ shape inference.
|
||||
|
||||
Args:
|
||||
tensor: A tensor.
|
||||
shape: A shape tensor.
|
||||
"""
|
||||
if (_ENABLE_MAYBE_SET_STATIC_SHAPE and not context.executing_eagerly() and
|
||||
ops.get_default_graph().building_function and
|
||||
not tensor.shape.is_fully_defined()):
|
||||
not tensor.shape.is_fully_defined() and is_tensor(shape)):
|
||||
shape = shape_tensor(shape)
|
||||
const_shape = constant_value_as_shape(shape)
|
||||
tensor.set_shape(const_shape)
|
||||
|
@ -18,11 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -1080,6 +1082,52 @@ class ConstantValueAsShapeTest(test.TestCase):
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
|
||||
|
||||
class MaybeSetStaticShapeTest(test.TestCase):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def disableSetStaticShape(self):
|
||||
flag_old = tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE
|
||||
tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = False
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = flag_old
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMaybeSetStaticShape(self):
|
||||
shape = constant_op.constant([2, 5], dtype=dtypes.int32)
|
||||
|
||||
def reshape():
|
||||
v = array_ops.zeros([10])
|
||||
return array_ops.reshape(v, shape)
|
||||
|
||||
with self.disableSetStaticShape():
|
||||
graph_without_shape_propagation = func_graph.func_graph_from_py_func(
|
||||
"without_shape_propagation", reshape, [], {})
|
||||
graph_with_shape_propagation = func_graph.func_graph_from_py_func(
|
||||
"with_shape_propagation", reshape, [], {})
|
||||
self.assertCountEqual(
|
||||
[op.type for op in graph_without_shape_propagation.get_operations()],
|
||||
[op.type for op in graph_with_shape_propagation.get_operations()])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMaybeSetStaticShapeScalarShape(self):
|
||||
|
||||
def reshape():
|
||||
v = array_ops.placeholder(dtypes.float32)
|
||||
t = array_ops.reshape(v, [-1])
|
||||
return t
|
||||
|
||||
with self.disableSetStaticShape():
|
||||
graph_without_shape_propagation = func_graph.func_graph_from_py_func(
|
||||
"without_shape_propagation", reshape, [], {})
|
||||
graph_with_shape_propagation = func_graph.func_graph_from_py_func(
|
||||
"with_shape_propagation", reshape, [], {})
|
||||
self.assertCountEqual(
|
||||
[op.type for op in graph_without_shape_propagation.get_operations()],
|
||||
[op.type for op in graph_with_shape_propagation.get_operations()])
|
||||
|
||||
|
||||
class ShapeTensorTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
|
Loading…
x
Reference in New Issue
Block a user