Add back the lock in CollectiveKeys
MWMS uses CollectiveKeys directly when broadcasting variable initial values, so the lock in CollectiveAllReduce is not enough. This change also acquires the lock in all methods of CollectiveKeys, not only in get_group_key().

PiperOrigin-RevId: 323938375
Change-Id: I15ea98ff62952d0c3bd4d33f74067b4bad03d7cb
commit cb2cf6f56f
parent 561b8292a2
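The race the lock closes is the unsynchronized read-increment-return in the key getters. The following standalone sketch (illustrative only; UnsafeKeys, SafeKeys, and hand_out_keys are not TensorFlow code) shows how two threads calling an unlocked getter can be handed the same instance key, and how wrapping the body in `with self._lock:`, as the diff does, makes the sequence atomic.

import threading

class UnsafeKeys(object):
  """Mimics the pre-fix pattern: read, increment, return without a lock."""

  def __init__(self, start=1):
    self._op_instance_key = start

  def get_op_instance_key(self):
    v = self._op_instance_key   # another thread may read the same value here
    self._op_instance_key += 1
    return v

class SafeKeys(UnsafeKeys):
  """Same counter, but read-increment-return happens under a lock."""

  def __init__(self, start=1):
    super(SafeKeys, self).__init__(start)
    self._lock = threading.Lock()

  def get_op_instance_key(self):
    with self._lock:
      v = self._op_instance_key
      self._op_instance_key += 1
      return v

def hand_out_keys(keys, threads=8, per_thread=10000):
  out = []
  def worker():
    for _ in range(per_thread):
      out.append(keys.get_op_instance_key())
  ts = [threading.Thread(target=worker) for _ in range(threads)]
  for t in ts:
    t.start()
  for t in ts:
    t.join()
  return len(out), len(set(out))  # equal only if every key is unique

# With UnsafeKeys the two counts may differ (duplicate keys handed out);
# with SafeKeys they always match.
print(hand_out_keys(UnsafeKeys()))
print(hand_out_keys(SafeKeys()))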
@@ -19,6 +19,8 @@ from __future__ import division
 from __future__ import print_function
 
 import collections as pycoll
+import copy
+import threading
 
 from tensorflow.python.distribute import all_reduce
 from tensorflow.python.distribute import values as value_lib
@@ -244,6 +246,8 @@ class CollectiveKeys(object):
   "Graph key": an integer key that is unique key graph. This is used to support
   multiple graphs per client session. It must be non-zero and set in the
   `config` argument of each call to `session.run`.
+
+  This class is thread safe.
   """
 
   def __init__(self,
@@ -264,6 +268,7 @@ class CollectiveKeys(object):
     assert op_instance_key_start != variable_instance_key_start
     self._op_instance_key = op_instance_key_start
     self._variable_instance_key = variable_instance_key_start
+    self._lock = threading.Lock()
 
   def get_group_key(self, devices):
     """Returns a group key for the set of devices.
@@ -282,23 +287,36 @@
     # task_type and task_id.
     names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed])
     key_id = ','.join(names)
-    if key_id not in self._group_key_table:
-      new_key = self._group_key
-      self._group_key += 1
-      self._group_key_table[key_id] = new_key
-    return self._group_key_table[key_id]
+    with self._lock:
+      if key_id not in self._group_key_table:
+        new_key = self._group_key
+        self._group_key += 1
+        self._group_key_table[key_id] = new_key
+      return self._group_key_table[key_id]
 
   def get_op_instance_key(self):
     """Returns a new instance key for use in defining a collective op."""
-    v = self._op_instance_key
-    self._op_instance_key += 1
-    return v
+    with self._lock:
+      v = self._op_instance_key
+      self._op_instance_key += 1
+      return v
 
   def get_variable_instance_key(self):
     """Returns a new instance key for use in creating a Variable."""
-    v = self._variable_instance_key
-    self._variable_instance_key += 1
-    return v
+    with self._lock:
+      v = self._variable_instance_key
+      self._variable_instance_key += 1
+      return v
 
+  def __deepcopy__(self, memo):
+    # distribute_coordinator deep-copies the strategy object, so
+    # CollectiveKeys needs to support deep copy as well.
+    copied = CollectiveKeys()
+    copied._group_key = self._group_key
+    copied._group_key_table = copy.deepcopy(self._group_key_table, memo)
+    copied._op_instance_key = self._op_instance_key
+    copied._variable_instance_key = self._variable_instance_key
+    return copied
+
 
 def build_collective_reduce(input_tensors,
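A note on the new __deepcopy__: a threading.Lock cannot be deep-copied (attempting it raises a TypeError), and distribute_coordinator deep-copies the strategy object that owns CollectiveKeys. The sketch below (illustrative only; KeyedState is not a TensorFlow class) shows the same pattern the diff uses: build a fresh instance, which creates its own new lock, and copy only the plain counter and table state.

import copy
import threading

class KeyedState(object):
  """Holds a lock plus plain state; the default deepcopy would choke on the lock."""

  def __init__(self):
    self._group_key = 1
    self._group_key_table = {}
    self._lock = threading.Lock()

  def __deepcopy__(self, memo):
    # Fresh instance -> fresh, unheld lock; copy only the plain Python state.
    copied = KeyedState()
    copied._group_key = self._group_key
    copied._group_key_table = copy.deepcopy(self._group_key_table, memo)
    return copied

original = KeyedState()
clone = copy.deepcopy(original)           # works; without __deepcopy__ this raises TypeError
assert clone._lock is not original._lock
assert clone._group_key_table is not original._group_key_table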