Getting initialized value from the same device as the original variable,

when creating EMA variables.

PiperOrigin-RevId: 356380269
Change-Id: I09d330dc4f321396a082abd8fda02ebc5199f5b0
This commit is contained in:
A. Unique TensorFlower 2021-02-08 16:31:19 -08:00 committed by TensorFlower Gardener
parent e920f7a7ea
commit b757b8a8db

View File

@ -453,9 +453,11 @@ class ExponentialMovingAverage(object):
# tensors, we rely on the existing device allocation mechanism.
with ops.init_scope():
if isinstance(var, variables.Variable):
with ops.device(var.device):
initialized_value = var.initialized_value()
avg = slot_creator.create_slot(
var,
var.initialized_value(),
initialized_value,
self.name,
colocate_with_primary=True,
copy_xla_sharding=True)