diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index e988499292e..216ec8b62d2 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -840,7 +840,7 @@ class Strategy(object): if dim is not None: # By returning a python value in the static shape case, we can # maybe get a fast path for reducing the denominator. - return numer, dim + return numer, array_ops.constant(dim, dtype=dtypes.int64) elif axis < 0: axis = axis + array_ops.rank(v) if v.shape.rank == 1: