replica_device_setter should work for resource variables
Change: 152289915
This commit is contained in:
parent
971b11dcca
commit
765006ea74
tensorflow/python/training
@ -198,7 +198,7 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
|
||||
if ps_ops is None:
|
||||
# TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be
|
||||
# placed in the parameter server.
|
||||
ps_ops = ["Variable", "VariableV2"]
|
||||
ps_ops = ["Variable", "VariableV2", "VarHandleOp"]
|
||||
|
||||
if not merge_devices:
|
||||
logging.warning(
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import device_setter
|
||||
@ -46,6 +47,12 @@ class DeviceSetterTest(test.TestCase):
|
||||
self.assertDeviceEqual("/job:ps/task:1", w.initializer.device)
|
||||
self.assertDeviceEqual("/job:worker/cpu:0", a.device)
|
||||
|
||||
def testResource(self):
|
||||
with ops.device(
|
||||
device_setter.replica_device_setter(cluster=self._cluster_spec)):
|
||||
v = resource_variable_ops.ResourceVariable([1, 2])
|
||||
self.assertDeviceEqual("/job:ps/task:0", v.device)
|
||||
|
||||
def testPS2TasksWithClusterSpecClass(self):
|
||||
with ops.device(
|
||||
device_setter.replica_device_setter(cluster=self._cluster_spec)):
|
||||
|
Loading…
Reference in New Issue
Block a user