Getting initialized value from the same device as the original variable,
when creating EMA variables. PiperOrigin-RevId: 356380269 Change-Id: I09d330dc4f321396a082abd8fda02ebc5199f5b0
This commit is contained in:
parent
e920f7a7ea
commit
b757b8a8db
@ -453,9 +453,11 @@ class ExponentialMovingAverage(object):
|
||||
# tensors, we rely on the existing device allocation mechanism.
|
||||
with ops.init_scope():
|
||||
if isinstance(var, variables.Variable):
|
||||
with ops.device(var.device):
|
||||
initialized_value = var.initialized_value()
|
||||
avg = slot_creator.create_slot(
|
||||
var,
|
||||
var.initialized_value(),
|
||||
initialized_value,
|
||||
self.name,
|
||||
colocate_with_primary=True,
|
||||
copy_xla_sharding=True)
|
||||
|
Loading…
Reference in New Issue
Block a user