Getting initialized value from the same device as the original variable,
when creating EMA variables. PiperOrigin-RevId: 356380269 Change-Id: I09d330dc4f321396a082abd8fda02ebc5199f5b0
This commit is contained in:
parent
e920f7a7ea
commit
b757b8a8db
@ -453,9 +453,11 @@ class ExponentialMovingAverage(object):
|
|||||||
# tensors, we rely on the existing device allocation mechanism.
|
# tensors, we rely on the existing device allocation mechanism.
|
||||||
with ops.init_scope():
|
with ops.init_scope():
|
||||||
if isinstance(var, variables.Variable):
|
if isinstance(var, variables.Variable):
|
||||||
|
with ops.device(var.device):
|
||||||
|
initialized_value = var.initialized_value()
|
||||||
avg = slot_creator.create_slot(
|
avg = slot_creator.create_slot(
|
||||||
var,
|
var,
|
||||||
var.initialized_value(),
|
initialized_value,
|
||||||
self.name,
|
self.name,
|
||||||
colocate_with_primary=True,
|
colocate_with_primary=True,
|
||||||
copy_xla_sharding=True)
|
copy_xla_sharding=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user