For a multi-device function, if the output-producing node is assigned to a remote device, colocate the retval node with its input node.
If a retval node is assigned to a remote device, function instantiation fails with the error "outputting tensors on remote devices is not supported." After this change, a retval node that would previously have been assigned to a remote device may instead be assigned to a local device. PiperOrigin-RevId: 285090649 Change-Id: Iade0982fee600591e3e0e2dfe12d676efa720729
This commit is contained in:
		
							parent
							
								
									824250a7ed
								
							
						
					
					
						commit
						126c04a9ed
					
				| @ -410,6 +410,18 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets( | |||||||
|                   << " src_device: " << *src_device |                   << " src_device: " << *src_device | ||||||
|                   << " colo group: " << colocation_group; |                   << " colo group: " << colocation_group; | ||||||
|         } |         } | ||||||
|  |         // If colocation_group is not set and output producing node is assigned
 | ||||||
|  |         // to a remote device, colocate the retval node with its input node.
 | ||||||
|  |         // TODO(yujingzhang): Remove this when we support outputting tensors on
 | ||||||
|  |         // remote devices.
 | ||||||
|  |         const bool remote_src_device = | ||||||
|  |             !src_device->empty() && GetFLR(*src_device) == nullptr; | ||||||
|  |         if (colocation_group.empty() && remote_src_device) { | ||||||
|  |           colocation_group = | ||||||
|  |               absl::StrCat(kColocationGroupPrefix, it->src()->name()); | ||||||
|  |           VLOG(3) << "Considering src: " << src_node->name() | ||||||
|  |                   << " colo group: " << colocation_group; | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|         // If resource is produced by a function call node, we can't trust
 |         // If resource is produced by a function call node, we can't trust
 | ||||||
|         // source node device assignment, because multi-device functions can
 |         // source node device assignment, because multi-device functions can
 | ||||||
|  | |||||||
| @ -94,12 +94,8 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase): | |||||||
|         c = variable_b + 1 |         c = variable_b + 1 | ||||||
|       return c, i + variable_b |       return c, i + variable_b | ||||||
| 
 | 
 | ||||||
|     with self.assertRaises(errors.UnimplementedError) as cm: |     self.assertAllEqual( | ||||||
|       remote_output(constant_op.constant([1])) |         remote_output(constant_op.constant([1]))[0].numpy(), 2) | ||||||
| 
 |  | ||||||
|     self.assertIn( |  | ||||||
|         'Currently, outputting tensors on remote devices is not supported.', |  | ||||||
|         cm.exception.message) |  | ||||||
| 
 | 
 | ||||||
|   def testMultiDeviceFunctionAmbiguousDevice(self): |   def testMultiDeviceFunctionAmbiguousDevice(self): | ||||||
| 
 | 
 | ||||||
| @ -176,6 +172,19 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase): | |||||||
|     # Reset the context to avoid polluting other test cases. |     # Reset the context to avoid polluting other test cases. | ||||||
|     context._reset_context() |     context._reset_context() | ||||||
| 
 | 
 | ||||||
|  |   @test_util.eager_lazy_remote_copy_on_and_off | ||||||
|  |   def testReturnRemoteArgument(self): | ||||||
|  | 
 | ||||||
|  |     @def_function.function | ||||||
|  |     def local_func(i): | ||||||
|  |       return i | ||||||
|  | 
 | ||||||
|  |     with ops.device('/job:worker/replica:0/task:0'): | ||||||
|  |       x = constant_op.constant([2, 1]) | ||||||
|  | 
 | ||||||
|  |     with ops.device('/job:worker/replica:0/task:1'): | ||||||
|  |       self.assertAllEqual(local_func(x), [2, 1]) | ||||||
|  | 
 | ||||||
|   @test_util.eager_lazy_remote_copy_on_and_off |   @test_util.eager_lazy_remote_copy_on_and_off | ||||||
|   def testMultiDeviceFunctionOnLocalDevice(self): |   def testMultiDeviceFunctionOnLocalDevice(self): | ||||||
|     with ops.device('/job:worker/replica:0/task:1'): |     with ops.device('/job:worker/replica:0/task:1'): | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user