Fix an unexpected behavior in the Parameter Server setup. Say we have two workers, one ps, and a function that reads a variable stored on the ps. When we first execute the function on worker:0, a send kernel is created on the ps to ship the variable to worker:0. When we then execute the same function on worker:1, the partitioned function has the same name, so the ps runs the partition it already instantiated and sends the variable to worker:0 instead of worker:1. Fix this by appending a random suffix to the generated function name, so each instantiation gets its own partitioned functions.
PiperOrigin-RevId: 253679384
commit 8507511e5d
parent 3f13e7a71b
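For illustration, a minimal sketch of the failing scenario in current eager-API terms (an assumption-laden repro in the spirit of the test added below; the device strings and the name `read_v` are illustrative, and it presumes a cluster is already connected, e.g. via tf.config.experimental_connect_to_cluster, with task 2 playing the ps role):

    import tensorflow as tf

    with tf.device('/job:worker/task:2'):
      v = tf.Variable(1.0)  # the variable lives on the "ps" task

    @tf.function
    def read_v():
      return v.read_value()  # partitioning places a send kernel on the ps

    with tf.device('/job:worker/task:0'):
      read_v()  # instantiates the partitioned function; ps sends v to worker:0

    with tf.device('/job:worker/task:1'):
      read_v()  # before this fix: the ps reused the worker:0 partition (same
                # name) and sent v to worker:0 instead of worker:1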
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph_partition.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/util/device_name_utils.h"
 #include "tensorflow/core/util/dump_graph.h"
 #include "tensorflow/core/util/ptr_util.h"
@@ -737,7 +738,10 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   };
 
   int i = 0;
-  FunctionNameGenerator name_generator(&data->lib_def_, function_name);
+  // Generate a random function_name to avoid one function reusing the
+  // partition function instantiated by another function.
+  FunctionNameGenerator name_generator(
+      &data->lib_def_, absl::StrCat(function_name, "_", random::New64()));
   for (const auto& pair : subgraphs) {
     i += 1;
     const string& target = pair.first;
@@ -96,8 +96,9 @@ class MultiWorkersTest(test.TestCase):
   def setUp(self):
     super(MultiWorkersTest, self).setUp()
 
-    workers, _ = test_util.create_local_cluster(2, 0)
-    remote.connect_to_remote_host([workers[0].target, workers[1].target])
+    workers, _ = test_util.create_local_cluster(3, 0)
+    remote.connect_to_remote_host(
+        [workers[0].target, workers[1].target, workers[2].target])
 
   def testMultiDeviceFunctionOnRemoteDevice(self):
     with ops.device('/job:worker/replica:0/task:1'):
@@ -113,6 +114,24 @@ class MultiWorkersTest(test.TestCase):
     with ops.device('/job:worker/replica:0/task:0'):
       self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
 
+  def testSimpleParameterServer(self):
+
+    with ops.device('/job:worker/task:2/device:CPU:0'):
+      v1 = variables.Variable(initial_value=0)
+      v2 = variables.Variable(initial_value=10)
+
+    @def_function.function
+    def worker_fn():
+      v1.assign_add(1)
+      v2.assign_sub(2)
+      return v1.read_value() + v2.read_value()
+
+    with ops.device('/job:worker/task:0/device:CPU:0'):
+      self.assertAllEqual(worker_fn(), 9)
+
+    with ops.device('/job:worker/task:1/device:CPU:0'):
+      self.assertAllEqual(worker_fn(), 8)
+
 
 if __name__ == '__main__':
   test.main()