Fix an unexpected behavior in Parameter Server. Suppose we have two workers, one ps, and a function that reads a variable on the ps. When we first execute the function on worker:0, a send kernel is created on the ps to send the variable to worker:0. When we then execute the same function on worker:1, the ps executes the same partition function (because it has the same name), so it sends the variable to worker:0 instead of worker:1.

PiperOrigin-RevId: 253679384
This commit is contained in:
Xiao Yu 2019-06-17 15:58:06 -07:00 committed by TensorFlower Gardener
parent 3f13e7a71b
commit 8507511e5d
2 changed files with 26 additions and 3 deletions

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
@ -737,7 +738,10 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
};
int i = 0;
FunctionNameGenerator name_generator(&data->lib_def_, function_name);
// Generate a random function_name to prevent one function from reusing a
// partition function instantiated by another function (which would reuse
// the first caller's send/recv wiring).
FunctionNameGenerator name_generator(
&data->lib_def_, absl::StrCat(function_name, "_", random::New64()));
for (const auto& pair : subgraphs) {
i += 1;
const string& target = pair.first;

View File

@ -96,8 +96,9 @@ class MultiWorkersTest(test.TestCase):
def setUp(self):
super(MultiWorkersTest, self).setUp()
workers, _ = test_util.create_local_cluster(2, 0)
remote.connect_to_remote_host([workers[0].target, workers[1].target])
workers, _ = test_util.create_local_cluster(3, 0)
remote.connect_to_remote_host(
[workers[0].target, workers[1].target, workers[2].target])
def testMultiDeviceFunctionOnRemoteDevice(self):
with ops.device('/job:worker/replica:0/task:1'):
@ -113,6 +114,24 @@ class MultiWorkersTest(test.TestCase):
with ops.device('/job:worker/replica:0/task:0'):
self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
def testSimpleParameterServer(self):
# Regression test for the partition-function reuse bug: two different
# workers (task:0 and task:1) call the same tf.function that touches
# variables hosted on a third task acting as a parameter server
# (task:2). Before the fix, the second caller could reuse the partition
# function instantiated for the first caller, so the ps would send the
# variable values to the wrong worker.
# NOTE(review): indentation was flattened by the diff rendering; the
# original file nests these statements inside the method body.
with ops.device('/job:worker/task:2/device:CPU:0'):
# Variables live on task:2, the "parameter server" in this test.
v1 = variables.Variable(initial_value=0)
v2 = variables.Variable(initial_value=10)
@def_function.function
def worker_fn():
# Mutate both ps-hosted variables, then read them back.
v1.assign_add(1)
v2.assign_sub(2)
return v1.read_value() + v2.read_value()
# First call from task:0: v1 -> 1, v2 -> 8, sum == 9.
with ops.device('/job:worker/task:0/device:CPU:0'):
self.assertAllEqual(worker_fn(), 9)
# Second call from a DIFFERENT worker (task:1): v1 -> 2, v2 -> 6,
# sum == 8. This call must get its own send/recv wiring on the ps.
with ops.device('/job:worker/task:1/device:CPU:0'):
self.assertAllEqual(worker_fn(), 8)
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
test.main()