Fix ResourceVariable placement during checkpointing to correctly colocate the
copy of the variable on the same machine. Addresses Issue #20914. PiperOrigin-RevId: 205317119
This commit is contained in:
parent
3647625e53
commit
8f130ff5b0
@ -126,7 +126,12 @@ class BaseSaverBuilder(object):
|
||||
def f():
|
||||
with ops.device(v.device):
|
||||
x = v.read_value()
|
||||
with ops.device("/device:CPU:0"):
|
||||
# To allow variables placed on non-CPU devices to be checkpointed,
|
||||
# we copy them to CPU on the same machine first.
|
||||
device_spec = pydev.DeviceSpec().parse_from_string(v.device)
|
||||
device_spec.merge_from(
|
||||
pydev.DeviceSpec().parse_from_string("/device:CPU:0"))
|
||||
with ops.device(device_spec.to_string()):
|
||||
return array_ops.identity(x)
|
||||
return f
|
||||
|
||||
|
@ -174,6 +174,24 @@ class SaverTest(test.TestCase):
|
||||
def testResourceBasic(self):
|
||||
self.basicSaveRestore(resource_variable_ops.ResourceVariable)
|
||||
|
||||
def testResourceColocation(self):
|
||||
partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2)
|
||||
with ops_lib.device("/job:ps/device:GPU:0"):
|
||||
v = variable_scope.get_variable("v0",
|
||||
shape=[10, 2],
|
||||
partitioner=partitioner,
|
||||
use_resource=True)
|
||||
saver_module.Saver({"v0": v}).build()
|
||||
save_op = None
|
||||
for op in ops_lib.get_default_graph().get_operations():
|
||||
if op.type == "SaveV2":
|
||||
save_op = op
|
||||
break
|
||||
assert save_op is not None
|
||||
for save_inp in save_op.inputs[3:]:
|
||||
# Input to SaveV2 op is placed on CPU of the same device as the Variable.
|
||||
self.assertEqual("/job:ps/device:CPU:0", save_inp.device)
|
||||
|
||||
def testResourceVariableReadOpsAddedDeterministically(self):
|
||||
graph_defs = []
|
||||
num_graphs = 10
|
||||
|
Loading…
Reference in New Issue
Block a user