replica_device_setter should work for resource variables

Change: 152289915
This commit is contained in:
Alexandre Passos 2017-04-05 11:23:31 -08:00 committed by TensorFlower Gardener
parent 971b11dcca
commit 765006ea74
2 changed files with 8 additions and 1 deletions
tensorflow/python/training

View File

@ -198,7 +198,7 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
if ps_ops is None:
# TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be
# placed in the parameter server.
ps_ops = ["Variable", "VariableV2"]
ps_ops = ["Variable", "VariableV2", "VarHandleOp"]
if not merge_devices:
logging.warning(

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
@ -46,6 +47,12 @@ class DeviceSetterTest(test.TestCase):
self.assertDeviceEqual("/job:ps/task:1", w.initializer.device)
self.assertDeviceEqual("/job:worker/cpu:0", a.device)
def testResource(self):
with ops.device(
device_setter.replica_device_setter(cluster=self._cluster_spec)):
v = resource_variable_ops.ResourceVariable([1, 2])
self.assertDeviceEqual("/job:ps/task:0", v.device)
def testPS2TasksWithClusterSpecClass(self):
with ops.device(
device_setter.replica_device_setter(cluster=self._cluster_spec)):