diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py index 1d5c2c8f452..dc6dc4071bd 100644 --- a/tensorflow/python/distribute/cross_device_utils.py +++ b/tensorflow/python/distribute/cross_device_utils.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function import collections as pycoll +import copy +import threading from tensorflow.python.distribute import all_reduce from tensorflow.python.distribute import values as value_lib @@ -244,6 +246,8 @@ class CollectiveKeys(object): "Graph key": an integer key that is unique key graph. This is used to support multiple graphs per client session. It must be non-zero and set in the `config` argument of each call to `session.run`. + + This class is thread safe. """ def __init__(self, @@ -264,6 +268,7 @@ class CollectiveKeys(object): assert op_instance_key_start != variable_instance_key_start self._op_instance_key = op_instance_key_start self._variable_instance_key = variable_instance_key_start + self._lock = threading.Lock() def get_group_key(self, devices): """Returns a group key for the set of devices. @@ -282,23 +287,36 @@ class CollectiveKeys(object): # task_type and task_id. names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed]) key_id = ','.join(names) - if key_id not in self._group_key_table: - new_key = self._group_key - self._group_key += 1 - self._group_key_table[key_id] = new_key - return self._group_key_table[key_id] + with self._lock: + if key_id not in self._group_key_table: + new_key = self._group_key + self._group_key += 1 + self._group_key_table[key_id] = new_key + return self._group_key_table[key_id] def get_op_instance_key(self): """Returns a new instance key for use in defining a collective op.""" - v = self._op_instance_key - self._op_instance_key += 1 - return v + with self._lock: + v = self._op_instance_key + self._op_instance_key += 1 + return v def get_variable_instance_key(self): """Returns a new instance key for use in creating a Variable.""" - v = self._variable_instance_key - self._variable_instance_key += 1 - return v + with self._lock: + v = self._variable_instance_key + self._variable_instance_key += 1 + return v + + def __deepcopy__(self, memo): + # distribute_coordinator deep-copies the strategy object, so + # CollectiveKeys needs to support deep copy as well. + copied = CollectiveKeys() + copied._group_key = self._group_key + copied._group_key_table = copy.deepcopy(self._group_key_table, memo) + copied._op_instance_key = self._op_instance_key + copied._variable_instance_key = self._variable_instance_key + return copied def build_collective_reduce(input_tensors,