Fix an unexpected behavior in the Parameter Server. Suppose we have two workers, one PS, and a function that reads a variable on the PS. When we first execute the function on worker:0, a send kernel is created on the PS to send the variable to worker:0. Then, when we execute the same function on worker:1, the same partition function is executed on the PS because the partition function has the same name — so it sends the variable to worker:0 instead of worker:1.
PiperOrigin-RevId: 253679384
This commit is contained in:
parent
3f13e7a71b
commit
8507511e5d
@ -37,6 +37,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/graph_partition.h"
|
#include "tensorflow/core/graph/graph_partition.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
#include "tensorflow/core/util/device_name_utils.h"
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
#include "tensorflow/core/util/dump_graph.h"
|
#include "tensorflow/core/util/dump_graph.h"
|
||||||
#include "tensorflow/core/util/ptr_util.h"
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
@ -737,7 +738,10 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
|||||||
};
|
};
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
FunctionNameGenerator name_generator(&data->lib_def_, function_name);
|
// Generate a random function_name to prevent one function from reusing the partition
|
||||||
|
// function instantiated by another function.
|
||||||
|
FunctionNameGenerator name_generator(
|
||||||
|
&data->lib_def_, absl::StrCat(function_name, "_", random::New64()));
|
||||||
for (const auto& pair : subgraphs) {
|
for (const auto& pair : subgraphs) {
|
||||||
i += 1;
|
i += 1;
|
||||||
const string& target = pair.first;
|
const string& target = pair.first;
|
||||||
|
@ -96,8 +96,9 @@ class MultiWorkersTest(test.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(MultiWorkersTest, self).setUp()
|
super(MultiWorkersTest, self).setUp()
|
||||||
|
|
||||||
workers, _ = test_util.create_local_cluster(2, 0)
|
workers, _ = test_util.create_local_cluster(3, 0)
|
||||||
remote.connect_to_remote_host([workers[0].target, workers[1].target])
|
remote.connect_to_remote_host(
|
||||||
|
[workers[0].target, workers[1].target, workers[2].target])
|
||||||
|
|
||||||
def testMultiDeviceFunctionOnRemoteDevice(self):
|
def testMultiDeviceFunctionOnRemoteDevice(self):
|
||||||
with ops.device('/job:worker/replica:0/task:1'):
|
with ops.device('/job:worker/replica:0/task:1'):
|
||||||
@ -113,6 +114,24 @@ class MultiWorkersTest(test.TestCase):
|
|||||||
with ops.device('/job:worker/replica:0/task:0'):
|
with ops.device('/job:worker/replica:0/task:0'):
|
||||||
self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
|
self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
|
||||||
|
|
||||||
|
def testSimpleParameterServer(self):
|
||||||
|
|
||||||
|
with ops.device('/job:worker/task:2/device:CPU:0'):
|
||||||
|
v1 = variables.Variable(initial_value=0)
|
||||||
|
v2 = variables.Variable(initial_value=10)
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def worker_fn():
|
||||||
|
v1.assign_add(1)
|
||||||
|
v2.assign_sub(2)
|
||||||
|
return v1.read_value() + v2.read_value()
|
||||||
|
|
||||||
|
with ops.device('/job:worker/task:0/device:CPU:0'):
|
||||||
|
self.assertAllEqual(worker_fn(), 9)
|
||||||
|
|
||||||
|
with ops.device('/job:worker/task:1/device:CPU:0'):
|
||||||
|
self.assertAllEqual(worker_fn(), 8)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user