diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index 9058564ed1f..bc1bb5409a7 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -29,6 +29,7 @@ from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.client import session from tensorflow.python.eager import backprop +from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import config @@ -315,6 +316,20 @@ class ReductionTest(PForTestCase): "parallel_iterations currently unsupported"): pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2) + def test_var_loop_len(self): + if context.executing_eagerly(): + self.skipTest("Variable length not possible under eager execution.") + + x = random_ops.random_uniform([8, 3]) + + def loop_fn(i, pfor_config): + return pfor_config.reduce_sum(array_ops.gather(x, i)) + + num_iters = array_ops.placeholder(dtypes.int32) + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + with self.cached_session() as sess: + sess.run(pfor, feed_dict={num_iters: 8}) + @test_util.run_all_in_graph_and_eager_modes class BitwiseTest(PForTestCase): diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index f23747d2bfb..cbbaf4d56ee 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -1118,6 +1118,8 @@ class PForConfig(object): def _set_iters(self, iters): """Set number of pfor iterations.""" + if isinstance(iters, ops.Tensor): + iters = tensor_util.constant_value(iters) self._maybe_iters = iters def reduce(self, fn, *args):