Fix bug with pfor reductions in dynamic-length loops.
Since `PForConfig._maybe_iters` is used to build TensorShapes, it must be static or None, not a Tensor. PiperOrigin-RevId: 356276299 Change-Id: I42f62d9e8bb055c0946d8c64df0e114aec2351c8
This commit is contained in:
		
							parent
							
								
									b97a34c5f9
								
							
						
					
					
						commit
						750119433d
					
				| @ -29,6 +29,7 @@ from tensorflow.core.example import example_pb2 | |||||||
| from tensorflow.core.example import feature_pb2 | from tensorflow.core.example import feature_pb2 | ||||||
| from tensorflow.python.client import session | from tensorflow.python.client import session | ||||||
| from tensorflow.python.eager import backprop | from tensorflow.python.eager import backprop | ||||||
|  | from tensorflow.python.eager import context | ||||||
| from tensorflow.python.eager import def_function | from tensorflow.python.eager import def_function | ||||||
| from tensorflow.python.framework import composite_tensor | from tensorflow.python.framework import composite_tensor | ||||||
| from tensorflow.python.framework import config | from tensorflow.python.framework import config | ||||||
| @ -315,6 +316,20 @@ class ReductionTest(PForTestCase): | |||||||
|                                 "parallel_iterations currently unsupported"): |                                 "parallel_iterations currently unsupported"): | ||||||
|       pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2) |       pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2) | ||||||
| 
 | 
 | ||||||
|  |   def test_var_loop_len(self): | ||||||
|  |     if context.executing_eagerly(): | ||||||
|  |       self.skipTest("Variable length not possible under eager execution.") | ||||||
|  | 
 | ||||||
|  |     x = random_ops.random_uniform([8, 3]) | ||||||
|  | 
 | ||||||
|  |     def loop_fn(i, pfor_config): | ||||||
|  |       return pfor_config.reduce_sum(array_ops.gather(x, i)) | ||||||
|  | 
 | ||||||
|  |     num_iters = array_ops.placeholder(dtypes.int32) | ||||||
|  |     pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) | ||||||
|  |     with self.cached_session() as sess: | ||||||
|  |       sess.run(pfor, feed_dict={num_iters: 8}) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @test_util.run_all_in_graph_and_eager_modes | @test_util.run_all_in_graph_and_eager_modes | ||||||
| class BitwiseTest(PForTestCase): | class BitwiseTest(PForTestCase): | ||||||
|  | |||||||
| @ -1118,6 +1118,8 @@ class PForConfig(object): | |||||||
| 
 | 
 | ||||||
|   def _set_iters(self, iters): |   def _set_iters(self, iters): | ||||||
|     """Set number of pfor iterations.""" |     """Set number of pfor iterations.""" | ||||||
|  |     if isinstance(iters, ops.Tensor): | ||||||
|  |       iters = tensor_util.constant_value(iters) | ||||||
|     self._maybe_iters = iters |     self._maybe_iters = iters | ||||||
| 
 | 
 | ||||||
|   def reduce(self, fn, *args): |   def reduce(self, fn, *args): | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user