Merge pull request #31673 from saxenasaurabh/cherrypicks_VWQL1

Make `maybe_set_static_shape` a no-op when `shape` is a python constant.
This commit is contained in:
Goldie Gadde 2019-08-16 14:19:11 -07:00 committed by GitHub
commit 4756cfbbec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 2 deletions

View File

@ -964,10 +964,40 @@ def shape_tensor(shape): # pylint: disable=invalid-name
return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
# DO NOT USE: For testing only.
_ENABLE_MAYBE_SET_STATIC_SHAPE = True
def maybe_set_static_shape(tensor, shape): # pylint: disable=invalid-name
if (not context.executing_eagerly() and
"""Sets the shape of `tensor` to the `shape`'s constant value, if inferrable.
This is a temporary workaround to fix shape inference across functional op
boundaries. E.g.
```python
shape = tf.constant([3])
@tf.function
def f():
u = tf.random_uniform(shape)
return u
```
If we were to rely solely on C++ shape inference, the shape of `u` inside
`f` would be unknown because C++ shape inference is not aware of the outer
graph and all it sees is a Placeholder node when backtracing the captured
tensor for `shape`. `maybe_set_static_shape` computes the static shape value
of `shape` by traversing the `FuncGraph` boundaries and sets the correct
shape.
A longer term solution would be to fix C++ shape inference.
Args:
tensor: A tensor.
shape: A shape tensor.
"""
if (_ENABLE_MAYBE_SET_STATIC_SHAPE and not context.executing_eagerly() and
ops.get_default_graph().building_function and
not tensor.shape.is_fully_defined()):
not tensor.shape.is_fully_defined() and is_tensor(shape)):
shape = shape_tensor(shape)
const_shape = constant_value_as_shape(shape)
tensor.set_shape(const_shape)

View File

@ -18,11 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import sys
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
@ -1080,6 +1082,52 @@ class ConstantValueAsShapeTest(test.TestCase):
c_val = tensor_util.constant_value_as_shape(tf_val)
class MaybeSetStaticShapeTest(test.TestCase):
@contextlib.contextmanager
def disableSetStaticShape(self):
flag_old = tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE
tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = False
try:
yield
finally:
tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = flag_old
@test_util.run_deprecated_v1
def testMaybeSetStaticShape(self):
shape = constant_op.constant([2, 5], dtype=dtypes.int32)
def reshape():
v = array_ops.zeros([10])
return array_ops.reshape(v, shape)
with self.disableSetStaticShape():
graph_without_shape_propagation = func_graph.func_graph_from_py_func(
"without_shape_propagation", reshape, [], {})
graph_with_shape_propagation = func_graph.func_graph_from_py_func(
"with_shape_propagation", reshape, [], {})
self.assertCountEqual(
[op.type for op in graph_without_shape_propagation.get_operations()],
[op.type for op in graph_with_shape_propagation.get_operations()])
@test_util.run_deprecated_v1
def testMaybeSetStaticShapeScalarShape(self):
def reshape():
v = array_ops.placeholder(dtypes.float32)
t = array_ops.reshape(v, [-1])
return t
with self.disableSetStaticShape():
graph_without_shape_propagation = func_graph.func_graph_from_py_func(
"without_shape_propagation", reshape, [], {})
graph_with_shape_propagation = func_graph.func_graph_from_py_func(
"with_shape_propagation", reshape, [], {})
self.assertCountEqual(
[op.type for op in graph_without_shape_propagation.get_operations()],
[op.type for op in graph_with_shape_propagation.get_operations()])
class ShapeTensorTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes