Fix bug with pfor reductions in dynamic-length loops.

Since `PForConfig._maybe_iters` is used to build TensorShapes, it must be static or None, not a Tensor.

PiperOrigin-RevId: 356276299
Change-Id: I42f62d9e8bb055c0946d8c64df0e114aec2351c8
This commit is contained in:
Dave Moore 2021-02-08 08:56:40 -08:00 committed by TensorFlower Gardener
parent b97a34c5f9
commit 750119433d
2 changed files with 17 additions and 0 deletions

View File

@ -29,6 +29,7 @@ from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
@ -315,6 +316,20 @@ class ReductionTest(PForTestCase):
"parallel_iterations currently unsupported"):
pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2)
def test_var_loop_len(self):
if context.executing_eagerly():
self.skipTest("Variable length not possible under eager execution.")
x = random_ops.random_uniform([8, 3])
def loop_fn(i, pfor_config):
return pfor_config.reduce_sum(array_ops.gather(x, i))
num_iters = array_ops.placeholder(dtypes.int32)
pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
with self.cached_session() as sess:
sess.run(pfor, feed_dict={num_iters: 8})
@test_util.run_all_in_graph_and_eager_modes
class BitwiseTest(PForTestCase):

View File

@ -1118,6 +1118,8 @@ class PForConfig(object):
def _set_iters(self, iters):
"""Set number of pfor iterations."""
if isinstance(iters, ops.Tensor):
iters = tensor_util.constant_value(iters)
self._maybe_iters = iters
def reduce(self, fn, *args):