diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 1a7ca90d15d..01bf1c04a99 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_partition.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/dump_graph.h" #include "tensorflow/core/util/ptr_util.h" @@ -737,7 +738,10 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( }; int i = 0; - FunctionNameGenerator name_generator(&data->lib_def_, function_name); + // Generate a random function_name to prevent one function from reusing the + // partition function instantiated by another function. 
+ FunctionNameGenerator name_generator( + &data->lib_def_, absl::StrCat(function_name, "_", random::New64())); for (const auto& pair : subgraphs) { i += 1; const string& target = pair.first; diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py index 4cb54479b62..9522b3eae49 100644 --- a/tensorflow/python/eager/remote_test.py +++ b/tensorflow/python/eager/remote_test.py @@ -96,8 +96,9 @@ class MultiWorkersTest(test.TestCase): def setUp(self): super(MultiWorkersTest, self).setUp() - workers, _ = test_util.create_local_cluster(2, 0) - remote.connect_to_remote_host([workers[0].target, workers[1].target]) + workers, _ = test_util.create_local_cluster(3, 0) + remote.connect_to_remote_host( + [workers[0].target, workers[1].target, workers[2].target]) def testMultiDeviceFunctionOnRemoteDevice(self): with ops.device('/job:worker/replica:0/task:1'): @@ -113,6 +114,24 @@ class MultiWorkersTest(test.TestCase): with ops.device('/job:worker/replica:0/task:0'): self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) + def testSimpleParameterServer(self): + + with ops.device('/job:worker/task:2/device:CPU:0'): + v1 = variables.Variable(initial_value=0) + v2 = variables.Variable(initial_value=10) + + @def_function.function + def worker_fn(): + v1.assign_add(1) + v2.assign_sub(2) + return v1.read_value() + v2.read_value() + + with ops.device('/job:worker/task:0/device:CPU:0'): + self.assertAllEqual(worker_fn(), 9) + + with ops.device('/job:worker/task:1/device:CPU:0'): + self.assertAllEqual(worker_fn(), 8) + if __name__ == '__main__': test.main()