Fix bug with pfor reductions in dynamic-length loops.
Since `PForConfig._maybe_iters` is used to build TensorShapes, it must be static or None, not a Tensor. PiperOrigin-RevId: 356276299 Change-Id: I42f62d9e8bb055c0946d8c64df0e114aec2351c8
This commit is contained in:
parent
b97a34c5f9
commit
750119433d
@ -29,6 +29,7 @@ from tensorflow.core.example import example_pb2
|
|||||||
from tensorflow.core.example import feature_pb2
|
from tensorflow.core.example import feature_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import composite_tensor
|
from tensorflow.python.framework import composite_tensor
|
||||||
from tensorflow.python.framework import config
|
from tensorflow.python.framework import config
|
||||||
@ -315,6 +316,20 @@ class ReductionTest(PForTestCase):
|
|||||||
"parallel_iterations currently unsupported"):
|
"parallel_iterations currently unsupported"):
|
||||||
pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2)
|
pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2)
|
||||||
|
|
||||||
|
def test_var_loop_len(self):
|
||||||
|
if context.executing_eagerly():
|
||||||
|
self.skipTest("Variable length not possible under eager execution.")
|
||||||
|
|
||||||
|
x = random_ops.random_uniform([8, 3])
|
||||||
|
|
||||||
|
def loop_fn(i, pfor_config):
|
||||||
|
return pfor_config.reduce_sum(array_ops.gather(x, i))
|
||||||
|
|
||||||
|
num_iters = array_ops.placeholder(dtypes.int32)
|
||||||
|
pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
|
||||||
|
with self.cached_session() as sess:
|
||||||
|
sess.run(pfor, feed_dict={num_iters: 8})
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class BitwiseTest(PForTestCase):
|
class BitwiseTest(PForTestCase):
|
||||||
|
|||||||
@ -1118,6 +1118,8 @@ class PForConfig(object):
|
|||||||
|
|
||||||
def _set_iters(self, iters):
|
def _set_iters(self, iters):
|
||||||
"""Set number of pfor iterations."""
|
"""Set number of pfor iterations."""
|
||||||
|
if isinstance(iters, ops.Tensor):
|
||||||
|
iters = tensor_util.constant_value(iters)
|
||||||
self._maybe_iters = iters
|
self._maybe_iters = iters
|
||||||
|
|
||||||
def reduce(self, fn, *args):
|
def reduce(self, fn, *args):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user