diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index e68aea46ecd..b0979d4c8d7 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -248,14 +248,15 @@ void BaseRemoteRendezvous::SameWorkerRecvDone( return; } + WorkerSession* sess = session(); Device* src_device; - Status s = env_->device_mgr->LookupDevice(parsed.src_device, &src_device); + Status s = sess->device_mgr->LookupDevice(parsed.src_device, &src_device); if (!s.ok()) { done(s); return; } Device* dst_device; - s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); + s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device); if (!s.ok()) { done(s); return; diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index e3f23ef0dd0..4ff2d0f5e3d 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -372,7 +372,7 @@ void Master::CreateSession(const CreateSessionRequest* req, DeviceNameUtils::ParsedName name = d->parsed_name(); if (name.job == *worker_cache_factory_options.job_name && name.task == worker_cache_factory_options.task_index && - name.type == "CPU") { + name.type == "CPU" && name.id == 0) { device_set->set_client_device(d.get()); } } @@ -399,7 +399,8 @@ void Master::CreateSession(const CreateSessionRequest* req, } } - CHECK(device_set->client_device()); + CHECK(device_set->client_device()) << "No client device found. Missing " + << "CPU:0 device?"; SessionOptions options; options.config = req->config(); diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h index f09bea328fd..793d58c8a1c 100644 --- a/tensorflow/core/distributed_runtime/worker_env.h +++ b/tensorflow/core/distributed_runtime/worker_env.h @@ -48,6 +48,10 @@ struct WorkerEnv { // device_mgr manages local devices (cpu and gpu). The WorkerService // is the network interface for managed devices. + // + // Note: Please use the device_mgr associated with your session if appropriate + // instead of this one. Using this device_mgr does not support ClusterSpec + // propagated sessions. DeviceMgr* device_mgr = nullptr; // A set of rendezvous keyed by step ids. diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 922f9baf78d..7882e088d0e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3074,6 +3074,7 @@ py_test( deps = [ ":array_ops", ":client", + ":client_testlib", ":framework", ":framework_for_generated_wrappers", ":framework_test_lib", diff --git a/tensorflow/python/client/session_clusterspec_prop_test.py b/tensorflow/python/client/session_clusterspec_prop_test.py index f40d3f18728..1ce2b7d7c3a 100644 --- a/tensorflow/python/client/session_clusterspec_prop_test.py +++ b/tensorflow/python/client/session_clusterspec_prop_test.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import resource_variable_ops # pylint: disable=unuse from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest +from tensorflow.python.platform import test from tensorflow.python.training import server_lib ops._USE_C_API = True @@ -135,6 +136,82 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase): dev_stats.device and 'Const' == node_stats.node_name ])) + def testFullDeviceNames(self): + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'renamed_worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + with ops.Graph().as_default() as g, ops.device( + '/job:renamed_worker/replica:0/task:1/device:CPU:0'): + const = constant_op.constant(17) + sess = session.Session(server1.target, config=config, graph=g) + run_options = config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE) + run_metadata = config_pb2.RunMetadata() + output = sess.run(const, options=run_options, run_metadata=run_metadata) + self.assertEqual(17, output) + self.assertEqual(1, + len([ + node_stats + for dev_stats in run_metadata.step_stats.dev_stats + for node_stats in dev_stats.node_stats + if '/job:renamed_worker/replica:0/task:1/device:CPU:0' + == dev_stats.device and 'Const' == node_stats.node_name + ])) + + @test_util.disable_c_api # Operation._set_device doesn't work with C API + def testMultipleLocalDevices(self): + # Note: CPU->CPU transfers have a fast-path in + # BaseRemoteRendezvous::SameWorkerRecvDone that means the test doesn't + # actually capture the motivating bug unless run on a GPU machine. + # + # Example error message (before bugfix -- linebreaks added because lint): + # + # W0718 17:14:41.521534 190121 device_mgr.cc:107] Unknown device: + # /job:worker/replica:0/task:0/device:CPU:0 all devices: + # /job:local/replica:0/task:0/gpu:0, + # /job:local/replica:0/task:0/device:GPU:0, + # /job:local/replica:0/task:0/cpu:1, CPU:0, GPU:0, + # /job:local/replica:0/task:0/device:CPU:1, + # /job:local/replica:0/task:0/device:CPU:0, CPU:1, + # /job:local/replica:0/task:0/cpu:0 + server_config = config_pb2.ConfigProto(device_count={'CPU': 2}) + server1 = server_lib.Server.create_local_server(config=server_config) + server2 = server_lib.Server.create_local_server(config=server_config) + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + with ops.Graph().as_default() as g: + with ops.device('/job:worker/task:1/cpu:1'): + input1 = constant_op.constant(17, dtypes.float32) + with ops.device('/job:worker/task:0/cpu:1'): + input2 = constant_op.constant(3, dtypes.float32) + with ops.device('/job:worker/task:1/cpu:0'): + sum1 = input1 + input2 + + if test.is_gpu_available(): + device_str = '/job:worker/task:0/gpu:0' + else: + device_str = '/job:worker/task:0/cpu:1' + with ops.device(device_str): + sum2 = input2 + input1 + + with ops.device('/job:worker/task:0/cpu:0'): + sum3 = sum1 + sum2 + sess = session.Session(server1.target, config=config, graph=g) + output = sess.run(sum3) + self.assertEqual(40, output) + + @test_util.disable_c_api # Operation._set_device doesn't work with C API def testLegacyDeviceNames(self): server1 = server_lib.Server.create_local_server() server2 = server_lib.Server.create_local_server()