diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py index febdc2ae556..8813dad4952 100644 --- a/tensorflow/python/distribute/cross_device_utils.py +++ b/tensorflow/python/distribute/cross_device_utils.py @@ -35,6 +35,9 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nccl_ops +OP_INSTANCE_KEY_START_NUMBER = 100 + + def aggregate_gradients_using_nccl(replica_grads): """Aggregate gradients using nccl allreduce.""" agg_all_g_and_v = [] @@ -253,7 +256,7 @@ class CollectiveKeys(object): def __init__(self, group_key_start=1, - op_instance_key_start=100, + op_instance_key_start=OP_INSTANCE_KEY_START_NUMBER, variable_instance_key_start=1000000): """Initializes the object.