Add back the lock in CollectiveKeys

MWMS is using CollectiveKeys directly when broadcasting variable initial values,
so the lock in CollectiveAllReduce is not enough.

This change also acquires the lock in all methods of CollectiveKeys, rather than only in get_group_key().

PiperOrigin-RevId: 323938375
Change-Id: I15ea98ff62952d0c3bd4d33f74067b4bad03d7cb
This commit is contained in:
Ran Chen 2020-07-29 22:23:50 -07:00 committed by TensorFlower Gardener
parent 561b8292a2
commit cb2cf6f56f

View File

@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
import collections as pycoll
import copy
import threading
from tensorflow.python.distribute import all_reduce
from tensorflow.python.distribute import values as value_lib
@ -244,6 +246,8 @@ class CollectiveKeys(object):
"Graph key": an integer key that is unique per graph. This is used to support
multiple graphs per client session. It must be non-zero and set in the
`config` argument of each call to `session.run`.
This class is thread safe.
"""
def __init__(self,
@ -264,6 +268,7 @@ class CollectiveKeys(object):
assert op_instance_key_start != variable_instance_key_start
self._op_instance_key = op_instance_key_start
self._variable_instance_key = variable_instance_key_start
self._lock = threading.Lock()
def get_group_key(self, devices):
"""Returns a group key for the set of devices.
@ -282,23 +287,36 @@ class CollectiveKeys(object):
# task_type and task_id.
names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed])
key_id = ','.join(names)
if key_id not in self._group_key_table:
new_key = self._group_key
self._group_key += 1
self._group_key_table[key_id] = new_key
return self._group_key_table[key_id]
with self._lock:
if key_id not in self._group_key_table:
new_key = self._group_key
self._group_key += 1
self._group_key_table[key_id] = new_key
return self._group_key_table[key_id]
def get_op_instance_key(self):
  """Returns a new instance key for use in defining a collective op.

  The current value of `self._op_instance_key` is handed out and the counter
  is advanced by one. Both steps happen while holding `self._lock`, so two
  threads calling this concurrently cannot receive the same key.

  Returns:
    The value of the op instance key counter before it was incremented.
  """
  # Read-and-increment must be atomic; an unlocked copy of this logic ahead
  # of the `with` block would bypass the lock and make it dead code.
  with self._lock:
    key = self._op_instance_key
    self._op_instance_key += 1
    return key
def get_variable_instance_key(self):
  """Returns a new instance key for use in creating a Variable.

  The current value of `self._variable_instance_key` is handed out and the
  counter is advanced by one. Both steps happen while holding `self._lock`,
  so two threads calling this concurrently cannot receive the same key.

  Returns:
    The value of the variable instance key counter before it was incremented.
  """
  # Read-and-increment must be atomic; an unlocked copy of this logic ahead
  # of the `with` block would bypass the lock and make it dead code.
  with self._lock:
    key = self._variable_instance_key
    self._variable_instance_key += 1
    return key
def __deepcopy__(self, memo):
  """Returns a new CollectiveKeys carrying over this object's key state.

  distribute_coordinator deep-copies the strategy object, which is why
  CollectiveKeys has to support deep copy too.

  Args:
    memo: the memo dictionary passed along by `copy.deepcopy`; it is forwarded
      when copying the group key table.

  Returns:
    A freshly constructed CollectiveKeys whose group key counter, group key
    table and both instance key counters match this object's.
  """
  clone = CollectiveKeys()
  # The table is the only mutable container copied here, so it is the only
  # attribute that goes through copy.deepcopy.
  clone._group_key_table = copy.deepcopy(self._group_key_table, memo)
  clone._group_key = self._group_key
  clone._variable_instance_key = self._variable_instance_key
  clone._op_instance_key = self._op_instance_key
  return clone
def build_collective_reduce(input_tensors,