From 8f130ff5b021efb94946ed9deb1341890763fd3f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Jul 2018 15:58:10 -0700 Subject: [PATCH] Fix ResourceVariable placement during checkpointing to correctly colocate the copy of the variable on the same machine. Addresses Issue #20914. PiperOrigin-RevId: 205317119 --- tensorflow/python/training/saver.py | 7 ++++++- tensorflow/python/training/saver_test.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 1ee975fbe48..11510d9928e 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -126,7 +126,12 @@ class BaseSaverBuilder(object): def f(): with ops.device(v.device): x = v.read_value() - with ops.device("/device:CPU:0"): + # To allow variables placed on non-CPU devices to be checkpointed, + # we copy them to CPU on the same machine first. + device_spec = pydev.DeviceSpec().parse_from_string(v.device) + device_spec.merge_from( + pydev.DeviceSpec().parse_from_string("/device:CPU:0")) + with ops.device(device_spec.to_string()): return array_ops.identity(x) return f diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index ae9c244aaf3..ecce8ae6bde 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -174,6 +174,24 @@ class SaverTest(test.TestCase): def testResourceBasic(self): self.basicSaveRestore(resource_variable_ops.ResourceVariable) + def testResourceColocation(self): + partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2) + with ops_lib.device("/job:ps/device:GPU:0"): + v = variable_scope.get_variable("v0", + shape=[10, 2], + partitioner=partitioner, + use_resource=True) + saver_module.Saver({"v0": v}).build() + save_op = None + for op in ops_lib.get_default_graph().get_operations(): + if op.type == "SaveV2": + save_op = op + break + 
assert save_op is not None + for save_inp in save_op.inputs[3:]: + # Input to SaveV2 op is placed on the CPU of the same job as the Variable. + self.assertEqual("/job:ps/device:CPU:0", save_inp.device) + def testResourceVariableReadOpsAddedDeterministically(self): graph_defs = [] num_graphs = 10