From d2f8e98650f553009b53d814814fb2d9ead66711 Mon Sep 17 00:00:00 2001
From: Brennan Saeta
Date: Thu, 20 Jul 2017 16:11:17 -0700
Subject: [PATCH] Bugfix: Never use env_->device_mgr

The base_rendezvous_mgr handles transferring a tensor using DMAs to
non-"host" devices such as GPUs in the SameWorkerRecvDone function. This
function would use the worker_env's device_mgr to obtain a pointer to the
relevant Device (using the LookupDevice call). In the ClusterSpec-propagation
world, using the environment's device_set means that the devices are not
renamed, often resulting in devices not being found.

This change updates BaseRemoteRendezvous to use the WorkerSession stored when
the BaseRemoteRendezvous is initialized. The WorkerSession has a pointer to a
DeviceMgr that contains the appropriately renamed devices for the given
session the Rendezvous is associated with.

Note: because we have a fast-path host-device-only copy, the original bug
does not show up when using 2 CPU devices. I have added a test to ensure that
transferring between 2 CPU devices works in a ClusterSpec propagation
session, but note that this test does not actually reproduce the motivating
bug.

In the process of writing a test for the original bug, I discovered another
latent bug in ClusterSpec propagation where if there were 2 CPU devices
(i.e. due to explicit server configuration to have 2 CPU devices), a DCHECK
could be triggered. The Master::CreateSession would call
`device_set->set_client_device` multiple times (once for each CPU device).

PiperOrigin-RevId: 162680163
---
 .../base_rendezvous_mgr.cc                    |  5 +-
 tensorflow/core/distributed_runtime/master.cc |  5 +-
 .../core/distributed_runtime/worker_env.h     |  4 +
 tensorflow/python/BUILD                       |  1 +
 .../client/session_clusterspec_prop_test.py   | 77 +++++++++++++++++++
 5 files changed, 88 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
index e68aea46ecd..b0979d4c8d7 100644
--- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
@@ -248,14 +248,15 @@ void BaseRemoteRendezvous::SameWorkerRecvDone(
     return;
   }
 
+  WorkerSession* sess = session();
   Device* src_device;
-  Status s = env_->device_mgr->LookupDevice(parsed.src_device, &src_device);
+  Status s = sess->device_mgr->LookupDevice(parsed.src_device, &src_device);
   if (!s.ok()) {
     done(s);
     return;
   }
   Device* dst_device;
-  s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
+  s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
   if (!s.ok()) {
     done(s);
     return;
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index e3f23ef0dd0..4ff2d0f5e3d 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -372,7 +372,7 @@ void Master::CreateSession(const CreateSessionRequest* req,
       DeviceNameUtils::ParsedName name = d->parsed_name();
       if (name.job == *worker_cache_factory_options.job_name &&
           name.task == worker_cache_factory_options.task_index &&
-          name.type == "CPU") {
+          name.type == "CPU" && name.id == 0) {
         device_set->set_client_device(d.get());
       }
     }
@@ -399,7 +399,8 @@ void Master::CreateSession(const CreateSessionRequest* req,
     }
   }
 
-  CHECK(device_set->client_device());
+  CHECK(device_set->client_device()) << "No client device found. Missing "
+                                     << "CPU:0 device?";
 
   SessionOptions options;
   options.config = req->config();
diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h
index f09bea328fd..793d58c8a1c 100644
--- a/tensorflow/core/distributed_runtime/worker_env.h
+++ b/tensorflow/core/distributed_runtime/worker_env.h
@@ -48,6 +48,10 @@ struct WorkerEnv {
 
   // device_mgr manages local devices (cpu and gpu). The WorkerService
   // is the network interface for managed devices.
+  //
+  // Note: Please use the device_mgr associated with your session if appropriate
+  // instead of this one. Using this device_mgr does not support ClusterSpec
+  // propagated sessions.
   DeviceMgr* device_mgr = nullptr;
 
   // A set of rendezvous keyed by step ids.
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 922f9baf78d..7882e088d0e 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3074,6 +3074,7 @@ py_test(
     deps = [
         ":array_ops",
         ":client",
+        ":client_testlib",
         ":framework",
         ":framework_for_generated_wrappers",
         ":framework_test_lib",
diff --git a/tensorflow/python/client/session_clusterspec_prop_test.py b/tensorflow/python/client/session_clusterspec_prop_test.py
index f40d3f18728..1ce2b7d7c3a 100644
--- a/tensorflow/python/client/session_clusterspec_prop_test.py
+++ b/tensorflow/python/client/session_clusterspec_prop_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unuse
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
 from tensorflow.python.training import server_lib
 
 ops._USE_C_API = True
@@ -135,6 +136,82 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
                          dev_stats.device and 'Const' == node_stats.node_name
                      ]))
 
+  def testFullDeviceNames(self):
+    server1 = server_lib.Server.create_local_server()
+    server2 = server_lib.Server.create_local_server()
+    cluster_def = cluster_pb2.ClusterDef()
+    job = cluster_def.job.add()
+    job.name = 'renamed_worker'
+    job.tasks[0] = server1.target[len('grpc://'):]
+    job.tasks[1] = server2.target[len('grpc://'):]
+    config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+    with ops.Graph().as_default() as g, ops.device(
+        '/job:renamed_worker/replica:0/task:1/device:CPU:0'):
+      const = constant_op.constant(17)
+    sess = session.Session(server1.target, config=config, graph=g)
+    run_options = config_pb2.RunOptions(
+        trace_level=config_pb2.RunOptions.FULL_TRACE)
+    run_metadata = config_pb2.RunMetadata()
+    output = sess.run(const, options=run_options, run_metadata=run_metadata)
+    self.assertEqual(17, output)
+    self.assertEqual(1,
+                     len([
+                         node_stats
+                         for dev_stats in run_metadata.step_stats.dev_stats
+                         for node_stats in dev_stats.node_stats
+                         if '/job:renamed_worker/replica:0/task:1/device:CPU:0'
+                         == dev_stats.device and 'Const' == node_stats.node_name
+                     ]))
+
+  @test_util.disable_c_api  # Operation._set_device doesn't work with C API
+  def testMultipleLocalDevices(self):
+    # Note: CPU->CPU transfers have a fast-path in
+    # BaseRemoteRendezvous::SameWorkerRecvDone that means the test doesn't
+    # actually capture the motivating bug unless run on a GPU machine.
+    #
+    # Example error message (before bugfix -- linebreaks added because lint):
+    #
+    # W0718 17:14:41.521534 190121 device_mgr.cc:107] Unknown device:
+    # /job:worker/replica:0/task:0/device:CPU:0 all devices:
+    # /job:local/replica:0/task:0/gpu:0,
+    # /job:local/replica:0/task:0/device:GPU:0,
+    # /job:local/replica:0/task:0/cpu:1, CPU:0, GPU:0,
+    # /job:local/replica:0/task:0/device:CPU:1,
+    # /job:local/replica:0/task:0/device:CPU:0, CPU:1,
+    # /job:local/replica:0/task:0/cpu:0
+    server_config = config_pb2.ConfigProto(device_count={'CPU': 2})
+    server1 = server_lib.Server.create_local_server(config=server_config)
+    server2 = server_lib.Server.create_local_server(config=server_config)
+    cluster_def = cluster_pb2.ClusterDef()
+    job = cluster_def.job.add()
+    job.name = 'worker'
+    job.tasks[0] = server1.target[len('grpc://'):]
+    job.tasks[1] = server2.target[len('grpc://'):]
+    config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+    with ops.Graph().as_default() as g:
+      with ops.device('/job:worker/task:1/cpu:1'):
+        input1 = constant_op.constant(17, dtypes.float32)
+      with ops.device('/job:worker/task:0/cpu:1'):
+        input2 = constant_op.constant(3, dtypes.float32)
+      with ops.device('/job:worker/task:1/cpu:0'):
+        sum1 = input1 + input2
+
+      if test.is_gpu_available():
+        device_str = '/job:worker/task:0/gpu:0'
+      else:
+        device_str = '/job:worker/task:0/cpu:1'
+      with ops.device(device_str):
+        sum2 = input2 + input1
+
+      with ops.device('/job:worker/task:0/cpu:0'):
+        sum3 = sum1 + sum2
+    sess = session.Session(server1.target, config=config, graph=g)
+    output = sess.run(sum3)
+    self.assertEqual(40, output)
+
+  @test_util.disable_c_api  # Operation._set_device doesn't work with C API
   def testLegacyDeviceNames(self):
     server1 = server_lib.Server.create_local_server()
     server2 = server_lib.Server.create_local_server()
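Illustration (not part of the patch): the new tests above exercise the renamed-device path through TensorFlow-internal modules (server_lib, session, config_pb2). The sketch below is adapted from testFullDeviceNames and uses what should be the equivalent public TF 1.x API (tf.train.Server, tf.train.ClusterSpec, tf.ConfigProto); it is a minimal, illustrative rendition of the ClusterSpec-propagation scenario, not part of this change.

# Sketch (adapted from testFullDeviceNames above): a ClusterSpec-propagation
# session whose job name differs from the servers' default, so device lookups
# must go through the per-session DeviceMgr holding the renamed devices.
import tensorflow as tf

server1 = tf.train.Server.create_local_server()
server2 = tf.train.Server.create_local_server()

# Register both in-process servers under the job name 'renamed_worker' rather
# than their default '/job:local'; the master propagates this ClusterDef when
# the session is created.
cluster_def = tf.train.ClusterSpec({
    'renamed_worker': [server1.target[len('grpc://'):],
                       server2.target[len('grpc://'):]]
}).as_cluster_def()
config = tf.ConfigProto(cluster_def=cluster_def)

with tf.Graph().as_default() as g:
  # This device name only resolves against the session's renamed device set.
  with tf.device('/job:renamed_worker/replica:0/task:1/device:CPU:0'):
    const = tf.constant(17)

sess = tf.Session(server1.target, config=config, graph=g)
print(sess.run(const))  # 17, executed on task 1 under the renamed job name.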