Filter num_accelerators by task_name and task_id where applicable.

PiperOrigin-RevId: 235622654
This commit is contained in:
Frank Chen 2019-02-25 16:20:53 -08:00 committed by TensorFlower Gardener
parent 268f66fb2f
commit 82cf60d9d7
2 changed files with 38 additions and 4 deletions

View File

@ -151,6 +151,11 @@ class ClusterResolver(object):
devices = get_accelerator_devices(master, config_proto)
mapping = collections.defaultdict(int)
for device in devices:
if task_type is not None and task_id is not None:
job_path = '/job:%s' % task_type
task_path = '/task:%s' % task_id
if job_path not in device.name or task_path not in device.name:
continue
mapping[device.device_type] += 1
return mapping

View File

@ -56,8 +56,8 @@ class BaseClusterResolverTest(test.TestCase):
"/job:worker/task:0/device:GPU:3",
]
device_list = [
session._DeviceAttributes(
name, "GPU", 1024, 0) for name in device_names
session._DeviceAttributes(name, "GPU", 1024, 0)
for name in device_names
]
mock_eager_list_devices.return_value = device_names
mock_list_devices.return_value = device_list
@ -80,8 +80,8 @@ class BaseClusterResolverTest(test.TestCase):
"/job:worker/task:0/device:GPU:3",
]
device_list = [
session._DeviceAttributes(
name, name[26:29], 1024, 0) for name in device_names
session._DeviceAttributes(name, name[26:29], 1024, 0)
for name in device_names
]
mock_eager_list_devices.return_value = device_names
mock_list_devices.return_value = device_list
@ -89,6 +89,35 @@ class BaseClusterResolverTest(test.TestCase):
resolver = MockBaseClusterResolver()
self.assertEqual(resolver.num_accelerators(), {"TPU": 4, "GPU": 4})
@mock.patch.object(eager.context, "list_devices")
@mock.patch.object(session.BaseSession, "list_devices")
def testNumAcceleratorsFilterTasks(self, mock_list_devices,
mock_eager_list_devices):
device_names = [
"/job:worker1/task:0/device:TPU:0",
"/job:worker1/task:0/device:TPU:1",
"/job:worker1/task:0/device:GPU:0",
"/job:worker1/task:0/device:GPU:1",
"/job:worker2/task:1/device:TPU:2",
"/job:worker2/task:2/device:TPU:3",
"/job:worker2/task:3/device:GPU:2",
"/job:worker2/task:4/device:GPU:3",
]
device_list = [
session._DeviceAttributes(name, name[27:30], 1024, 0)
for name in device_names
]
mock_eager_list_devices.return_value = device_names
mock_list_devices.return_value = device_list
resolver = MockBaseClusterResolver()
self.assertEqual(resolver.num_accelerators(task_type="worker1", task_id=0),
{"TPU": 2, "GPU": 2})
self.assertEqual(resolver.num_accelerators(task_type="worker2", task_id=3),
{"GPU": 1})
self.assertEqual(resolver.num_accelerators(task_type="worker2", task_id=4),
{"GPU": 1})
class UnionClusterResolverTest(test.TestCase):
# TODO(frankchn): Transform to parameterized test after it is included in the