Filter num_accelerators by task_name and task_id where applicable.
PiperOrigin-RevId: 235622654
This commit is contained in:
parent
268f66fb2f
commit
82cf60d9d7
@ -151,6 +151,11 @@ class ClusterResolver(object):
|
||||
devices = get_accelerator_devices(master, config_proto)
|
||||
mapping = collections.defaultdict(int)
|
||||
for device in devices:
|
||||
if task_type is not None and task_id is not None:
|
||||
job_path = '/job:%s' % task_type
|
||||
task_path = '/task:%s' % task_id
|
||||
if job_path not in device.name or task_path not in device.name:
|
||||
continue
|
||||
mapping[device.device_type] += 1
|
||||
return mapping
|
||||
|
||||
|
@ -56,8 +56,8 @@ class BaseClusterResolverTest(test.TestCase):
|
||||
"/job:worker/task:0/device:GPU:3",
|
||||
]
|
||||
device_list = [
|
||||
session._DeviceAttributes(
|
||||
name, "GPU", 1024, 0) for name in device_names
|
||||
session._DeviceAttributes(name, "GPU", 1024, 0)
|
||||
for name in device_names
|
||||
]
|
||||
mock_eager_list_devices.return_value = device_names
|
||||
mock_list_devices.return_value = device_list
|
||||
@ -80,8 +80,8 @@ class BaseClusterResolverTest(test.TestCase):
|
||||
"/job:worker/task:0/device:GPU:3",
|
||||
]
|
||||
device_list = [
|
||||
session._DeviceAttributes(
|
||||
name, name[26:29], 1024, 0) for name in device_names
|
||||
session._DeviceAttributes(name, name[26:29], 1024, 0)
|
||||
for name in device_names
|
||||
]
|
||||
mock_eager_list_devices.return_value = device_names
|
||||
mock_list_devices.return_value = device_list
|
||||
@ -89,6 +89,35 @@ class BaseClusterResolverTest(test.TestCase):
|
||||
resolver = MockBaseClusterResolver()
|
||||
self.assertEqual(resolver.num_accelerators(), {"TPU": 4, "GPU": 4})
|
||||
|
||||
@mock.patch.object(eager.context, "list_devices")
|
||||
@mock.patch.object(session.BaseSession, "list_devices")
|
||||
def testNumAcceleratorsFilterTasks(self, mock_list_devices,
|
||||
mock_eager_list_devices):
|
||||
device_names = [
|
||||
"/job:worker1/task:0/device:TPU:0",
|
||||
"/job:worker1/task:0/device:TPU:1",
|
||||
"/job:worker1/task:0/device:GPU:0",
|
||||
"/job:worker1/task:0/device:GPU:1",
|
||||
"/job:worker2/task:1/device:TPU:2",
|
||||
"/job:worker2/task:2/device:TPU:3",
|
||||
"/job:worker2/task:3/device:GPU:2",
|
||||
"/job:worker2/task:4/device:GPU:3",
|
||||
]
|
||||
device_list = [
|
||||
session._DeviceAttributes(name, name[27:30], 1024, 0)
|
||||
for name in device_names
|
||||
]
|
||||
mock_eager_list_devices.return_value = device_names
|
||||
mock_list_devices.return_value = device_list
|
||||
|
||||
resolver = MockBaseClusterResolver()
|
||||
self.assertEqual(resolver.num_accelerators(task_type="worker1", task_id=0),
|
||||
{"TPU": 2, "GPU": 2})
|
||||
self.assertEqual(resolver.num_accelerators(task_type="worker2", task_id=3),
|
||||
{"GPU": 1})
|
||||
self.assertEqual(resolver.num_accelerators(task_type="worker2", task_id=4),
|
||||
{"GPU": 1})
|
||||
|
||||
|
||||
class UnionClusterResolverTest(test.TestCase):
|
||||
# TODO(frankchn): Transform to parameterized test after it is included in the
|
||||
|
Loading…
x
Reference in New Issue
Block a user