Bugfix: Never use env_->device_mgr
The base_rendezvous_mgr handles transferring a tensor to non-"host" devices such as GPUs via DMA in the SameWorkerRecvDone function. This function used the worker_env's device_mgr to obtain a pointer to the relevant Device (via the LookupDevice call). With ClusterSpec propagation, the environment's device_mgr holds devices under their original, un-renamed names, so lookups by the session's renamed device names often fail.

This change updates BaseRemoteRendezvous to use the WorkerSession stored when the BaseRemoteRendezvous is initialized. The WorkerSession holds a pointer to a DeviceMgr containing the appropriately renamed devices for the session the Rendezvous is associated with.

Note: because SameWorkerRecvDone has a fast path for copies involving only host (CPU) devices, the original bug does not show up when using 2 CPU devices. I have added a test to ensure that transferring between 2 CPU devices works in a ClusterSpec propagation session, but note that this test does not actually reproduce the motivating bug.

In the process of writing a test for the original bug, I discovered another latent bug in ClusterSpec propagation: if a server was explicitly configured with 2 CPU devices, a DCHECK could be triggered because Master::CreateSession would call `device_set->set_client_device` multiple times (once for each CPU device).

PiperOrigin-RevId: 162680163
parent f199febfa8
commit d2f8e98650
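Why the session-scoped DeviceMgr matters: with ClusterSpec propagation the same physical devices are registered under session-specific (renamed) job/task names, while the worker-level registry still knows them only by the names the server was started with, so a lookup by the renamed name fails. Below is a minimal, self-contained C++ sketch of that mismatch; ToyDeviceMgr and the device names here are illustrative stand-ins, not TensorFlow's actual DeviceMgr API.

```cpp
#include <iostream>
#include <map>
#include <string>

// Toy stand-in for a device registry keyed by fully qualified device name.
// Illustrative only; this is not TensorFlow's DeviceMgr interface.
class ToyDeviceMgr {
 public:
  void AddDevice(const std::string& name) { devices_[name] = true; }
  bool LookupDevice(const std::string& name) const {
    return devices_.count(name) > 0;
  }

 private:
  std::map<std::string, bool> devices_;
};

int main() {
  // Worker-level registry: devices keep the names from server startup.
  ToyDeviceMgr worker_env_mgr;
  worker_env_mgr.AddDevice("/job:local/replica:0/task:0/device:GPU:0");

  // Session-level registry: ClusterSpec propagation registers the same
  // device under the job name supplied by the client's ClusterDef.
  ToyDeviceMgr session_mgr;
  session_mgr.AddDevice("/job:worker/replica:0/task:0/device:GPU:0");

  const std::string requested = "/job:worker/replica:0/task:0/device:GPU:0";

  // Looking up the renamed device in the worker-level registry fails
  // (the "Unknown device" failure), while the session-level registry
  // resolves it.
  std::cout << std::boolalpha
            << "worker_env lookup: " << worker_env_mgr.LookupDevice(requested)
            << "\nsession lookup:    " << session_mgr.LookupDevice(requested)
            << "\n";
  return 0;
}
```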
@@ -248,14 +248,15 @@ void BaseRemoteRendezvous::SameWorkerRecvDone(
     return;
   }
 
+  WorkerSession* sess = session();
   Device* src_device;
-  Status s = env_->device_mgr->LookupDevice(parsed.src_device, &src_device);
+  Status s = sess->device_mgr->LookupDevice(parsed.src_device, &src_device);
   if (!s.ok()) {
     done(s);
     return;
   }
   Device* dst_device;
-  s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
+  s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
   if (!s.ok()) {
     done(s);
     return;
@@ -372,7 +372,7 @@ void Master::CreateSession(const CreateSessionRequest* req,
       DeviceNameUtils::ParsedName name = d->parsed_name();
       if (name.job == *worker_cache_factory_options.job_name &&
           name.task == worker_cache_factory_options.task_index &&
-          name.type == "CPU") {
+          name.type == "CPU" && name.id == 0) {
         device_set->set_client_device(d.get());
       }
     }
@@ -399,7 +399,8 @@ void Master::CreateSession(const CreateSessionRequest* req,
       }
     }
 
-    CHECK(device_set->client_device());
+    CHECK(device_set->client_device()) << "No client device found. Missing "
+                                       << "CPU:0 device?";
 
     SessionOptions options;
     options.config = req->config();
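The `name.id == 0` guard added above addresses the second latent bug from the commit message: with a server configured for two CPU devices, every CPU device matched the old condition, so `set_client_device` ran once per device. Assuming the client device may only be set once (the DCHECK the commit message refers to), the following hypothetical sketch of that selection loop shows how the second call would trip the check without the id filter; ToyDevice and ToyDeviceSet are illustrative, not the real Master::CreateSession types.

```cpp
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

// Toy device description; only the fields the selection loop cares about.
struct ToyDevice {
  std::string type;  // e.g. "CPU"
  int id;            // e.g. 0 or 1
};

// Toy stand-in for DeviceSet: the client device may be set exactly once.
// Assumption: the real set_client_device enforces this with a DCHECK.
class ToyDeviceSet {
 public:
  void set_client_device(const ToyDevice* d) {
    assert(client_device_ == nullptr && "client device set more than once");
    client_device_ = d;
  }
  const ToyDevice* client_device() const { return client_device_; }

 private:
  const ToyDevice* client_device_ = nullptr;
};

int main() {
  // A server explicitly configured with two CPU devices
  // (device_count={"CPU": 2} in the test below).
  std::vector<ToyDevice> devices = {{"CPU", 0}, {"CPU", 1}};
  ToyDeviceSet device_set;

  for (const ToyDevice& d : devices) {
    // Without the `d.id == 0` check, both CPU devices match and the second
    // set_client_device call fires the assert above.
    if (d.type == "CPU" && d.id == 0) {
      device_set.set_client_device(&d);
    }
  }
  std::cout << "client device id: " << device_set.client_device()->id << "\n";
  return 0;
}
```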
@@ -48,6 +48,10 @@ struct WorkerEnv
 
   // device_mgr manages local devices (cpu and gpu). The WorkerService
   // is the network interface for managed devices.
+  //
+  // Note: Please use the device_mgr associated with your session if appropriate
+  // instead of this one. Using this device_mgr does not support ClusterSpec
+  // propagated sessions.
   DeviceMgr* device_mgr = nullptr;
 
   // A set of rendezvous keyed by step ids.
@@ -3074,6 +3074,7 @@ py_test(
     deps = [
         ":array_ops",
         ":client",
         ":client_testlib",
+        ":framework",
         ":framework_for_generated_wrappers",
         ":framework_test_lib",
@@ -38,6 +38,7 @@ from tensorflow.python.ops import resource_variable_ops # pylint: disable=unuse
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
 from tensorflow.python.training import server_lib
 
 ops._USE_C_API = True
@@ -135,6 +136,82 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
                         dev_stats.device and 'Const' == node_stats.node_name
                     ]))
 
+  def testFullDeviceNames(self):
+    server1 = server_lib.Server.create_local_server()
+    server2 = server_lib.Server.create_local_server()
+    cluster_def = cluster_pb2.ClusterDef()
+    job = cluster_def.job.add()
+    job.name = 'renamed_worker'
+    job.tasks[0] = server1.target[len('grpc://'):]
+    job.tasks[1] = server2.target[len('grpc://'):]
+    config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+    with ops.Graph().as_default() as g, ops.device(
+        '/job:renamed_worker/replica:0/task:1/device:CPU:0'):
+      const = constant_op.constant(17)
+    sess = session.Session(server1.target, config=config, graph=g)
+    run_options = config_pb2.RunOptions(
+        trace_level=config_pb2.RunOptions.FULL_TRACE)
+    run_metadata = config_pb2.RunMetadata()
+    output = sess.run(const, options=run_options, run_metadata=run_metadata)
+    self.assertEqual(17, output)
+    self.assertEqual(1,
+                     len([
+                         node_stats
+                         for dev_stats in run_metadata.step_stats.dev_stats
+                         for node_stats in dev_stats.node_stats
+                         if '/job:renamed_worker/replica:0/task:1/device:CPU:0'
+                         == dev_stats.device and 'Const' == node_stats.node_name
+                     ]))
+
+  @test_util.disable_c_api  # Operation._set_device doesn't work with C API
+  def testMultipleLocalDevices(self):
+    # Note: CPU->CPU transfers have a fast-path in
+    # BaseRemoteRendezvous::SameWorkerRecvDone that means the test doesn't
+    # actually capture the motivating bug unless run on a GPU machine.
+    #
+    # Example error message (before bugfix -- linebreaks added because lint):
+    #
+    # W0718 17:14:41.521534 190121 device_mgr.cc:107] Unknown device:
+    # /job:worker/replica:0/task:0/device:CPU:0 all devices:
+    # /job:local/replica:0/task:0/gpu:0,
+    # /job:local/replica:0/task:0/device:GPU:0,
+    # /job:local/replica:0/task:0/cpu:1, CPU:0, GPU:0,
+    # /job:local/replica:0/task:0/device:CPU:1,
+    # /job:local/replica:0/task:0/device:CPU:0, CPU:1,
+    # /job:local/replica:0/task:0/cpu:0
+    server_config = config_pb2.ConfigProto(device_count={'CPU': 2})
+    server1 = server_lib.Server.create_local_server(config=server_config)
+    server2 = server_lib.Server.create_local_server(config=server_config)
+    cluster_def = cluster_pb2.ClusterDef()
+    job = cluster_def.job.add()
+    job.name = 'worker'
+    job.tasks[0] = server1.target[len('grpc://'):]
+    job.tasks[1] = server2.target[len('grpc://'):]
+    config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+    with ops.Graph().as_default() as g:
+      with ops.device('/job:worker/task:1/cpu:1'):
+        input1 = constant_op.constant(17, dtypes.float32)
+      with ops.device('/job:worker/task:0/cpu:1'):
+        input2 = constant_op.constant(3, dtypes.float32)
+      with ops.device('/job:worker/task:1/cpu:0'):
+        sum1 = input1 + input2
+
+      if test.is_gpu_available():
+        device_str = '/job:worker/task:0/gpu:0'
+      else:
+        device_str = '/job:worker/task:0/cpu:1'
+      with ops.device(device_str):
+        sum2 = input2 + input1
+
+      with ops.device('/job:worker/task:0/cpu:0'):
+        sum3 = sum1 + sum2
+    sess = session.Session(server1.target, config=config, graph=g)
+    output = sess.run(sum3)
+    self.assertEqual(40, output)
+
   @test_util.disable_c_api  # Operation._set_device doesn't work with C API
   def testLegacyDeviceNames(self):
     server1 = server_lib.Server.create_local_server()
     server2 = server_lib.Server.create_local_server()