Fix ResourceVariable placement during checkpointing to correctly colocate the

copy of the variable on the same machine. Addresses Issue #20914.

PiperOrigin-RevId: 205317119
This commit is contained in:
A. Unique TensorFlower 2018-07-19 15:58:10 -07:00 committed by TensorFlower Gardener
parent 3647625e53
commit 8f130ff5b0
2 changed files with 24 additions and 1 deletions

View File

@ -126,7 +126,12 @@ class BaseSaverBuilder(object):
def f():
with ops.device(v.device):
x = v.read_value()
with ops.device("/device:CPU:0"):
# To allow variables placed on non-CPU devices to be checkpointed,
# we copy them to CPU on the same machine first.
device_spec = pydev.DeviceSpec().parse_from_string(v.device)
device_spec.merge_from(
pydev.DeviceSpec().parse_from_string("/device:CPU:0"))
with ops.device(device_spec.to_string()):
return array_ops.identity(x)
return f

View File

@ -174,6 +174,24 @@ class SaverTest(test.TestCase):
def testResourceBasic(self):
self.basicSaveRestore(resource_variable_ops.ResourceVariable)
def testResourceColocation(self):
partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2)
with ops_lib.device("/job:ps/device:GPU:0"):
v = variable_scope.get_variable("v0",
shape=[10, 2],
partitioner=partitioner,
use_resource=True)
saver_module.Saver({"v0": v}).build()
save_op = None
for op in ops_lib.get_default_graph().get_operations():
if op.type == "SaveV2":
save_op = op
break
assert save_op is not None
for save_inp in save_op.inputs[3:]:
# Input to SaveV2 op is placed on CPU of the same device as the Variable.
self.assertEqual("/job:ps/device:CPU:0", save_inp.device)
def testResourceVariableReadOpsAddedDeterministically(self):
graph_defs = []
num_graphs = 10