diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 3d6303d200b..730f459c1e3 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -453,9 +453,11 @@ class ExponentialMovingAverage(object): # tensors, we rely on the existing device allocation mechanism. with ops.init_scope(): if isinstance(var, variables.Variable): + with ops.device(var.device): + initialized_value = var.initialized_value() avg = slot_creator.create_slot( var, - var.initialized_value(), + initialized_value, self.name, colocate_with_primary=True, copy_xla_sharding=True)