For a multi-device function, if the output-producing node is assigned to a remote device, colocate the retval node with its input node.
Previously, if a retval node was assigned to a remote device, function instantiation failed with the error "outputting tensors on remote devices is not supported." After this change, a retval node that would have been assigned to a remote device may be assigned to a local device instead. PiperOrigin-RevId: 285090649 Change-Id: Iade0982fee600591e3e0e2dfe12d676efa720729
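Below is a minimal Python sketch of the behavior this enables, mirroring the updated SingleWorkerTest in the diff; the variable value, the device string, and the assumption that a remote worker task is already connected to the eager context are mine, not part of this change:

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables

# Assumed: a worker task is already connected to the eager context
# (the test harness sets this up before the test body runs).
with ops.device('/job:worker/replica:0/task:0/cpu:0'):
  variable_b = variables.Variable(1)

@def_function.function
def remote_output(i):
  with ops.device('/job:worker/replica:0/task:0/cpu:0'):
    c = variable_b + 1  # this output is produced on the remote device
  return c, i + variable_b

# Before this change: UnimplementedError, "Currently, outputting tensors on
# remote devices is not supported." After it, the retval node is colocated
# with its input node and the call succeeds.
outputs = remote_output(constant_op.constant([1]))
assert outputs[0].numpy() == 2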
parent 824250a7ed
commit 126c04a9ed
@@ -410,6 +410,18 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
                   << " src_device: " << *src_device
                   << " colo group: " << colocation_group;
         }
+        // If colocation_group is not set and output producing node is assigned
+        // to a remote device, colocate the retval node with its input node.
+        // TODO(yujingzhang): Remove this when we support outputting tensors on
+        // remote devices.
+        const bool remote_src_device =
+            !src_device->empty() && GetFLR(*src_device) == nullptr;
+        if (colocation_group.empty() && remote_src_device) {
+          colocation_group =
+              absl::StrCat(kColocationGroupPrefix, it->src()->name());
+          VLOG(3) << "Considering src: " << src_node->name()
+                  << " colo group: " << colocation_group;
+        }
 
         // If resource is produced by a function call node, we can't trust
         // source node device assignment, because multi-device functions can
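For reference, the colocation group built above becomes the graph-level "_class" constraint on the retval node (kColocationGroupPrefix is "loc:@"). A small, hypothetical Python sketch of what that constraint looks like on a node; the op names are made up and only the attribute format is the point:

import tensorflow as tf

# Hypothetical node names; this only illustrates the "loc:@<node>" colocation
# constraint that the C++ change attaches to the retval node.
with tf.Graph().as_default():
  src = tf.constant([2, 1], name='src_node')
  with tf.compat.v1.colocate_with(src):
    retval_like = tf.identity(src, name='retval_like')
  print(retval_like.op.get_attr('_class'))  # [b'loc:@src_node']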
@@ -94,12 +94,8 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
         c = variable_b + 1
       return c, i + variable_b
 
-    with self.assertRaises(errors.UnimplementedError) as cm:
-      remote_output(constant_op.constant([1]))
-
-    self.assertIn(
-        'Currently, outputting tensors on remote devices is not supported.',
-        cm.exception.message)
+    self.assertAllEqual(
+        remote_output(constant_op.constant([1]))[0].numpy(), 2)
 
   def testMultiDeviceFunctionAmbiguousDevice(self):
 
@@ -176,6 +172,19 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
     # Reset the context to avoid polluting other test cases.
     context._reset_context()
 
+  @test_util.eager_lazy_remote_copy_on_and_off
+  def testReturnRemoteArgument(self):
+
+    @def_function.function
+    def local_func(i):
+      return i
+
+    with ops.device('/job:worker/replica:0/task:0'):
+      x = constant_op.constant([2, 1])
+
+    with ops.device('/job:worker/replica:0/task:1'):
+      self.assertAllEqual(local_func(x), [2, 1])
+
   @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionOnLocalDevice(self):
     with ops.device('/job:worker/replica:0/task:1'):
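The new testReturnRemoteArgument case passes a tensor that lives on one worker task into a function invoked under another task's device scope. A hedged end-to-end sketch of that pattern with the public eager API follows; the worker addresses are placeholders, and the real test builds its cluster through the framework's own harness rather than this call:

import tensorflow as tf

# Placeholder addresses; assume two worker processes are actually listening here.
tf.config.experimental_connect_to_host(
    ['localhost:8470', 'localhost:8471'], job_name='worker')

@tf.function
def local_func(i):
  return i

with tf.device('/job:worker/replica:0/task:0'):
  x = tf.constant([2, 1])  # the argument lives on worker task 0

with tf.device('/job:worker/replica:0/task:1'):
  # The retval's source node sits on a device that is remote relative to task 1;
  # after this change it is colocated with its input node instead of failing.
  print(local_func(x).numpy())  # [2 1]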