Make resource variable shared name consistent with non-resource variables.

Remove colocation constraint from resource variable cached value with the
variable itself.
Change: 152192203
This commit is contained in:
A. Unique TensorFlower 2017-04-04 15:01:39 -08:00 committed by TensorFlower Gardener
parent a42b3fc598
commit 83afdc92d9
2 changed files with 48 additions and 7 deletions

View File

@ -195,6 +195,39 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertIsInstance(w.dtype, dtypes.DType)
self.assertEqual(v.dtype, w.dtype)
def testCachingDevice(self):
with ops.device("/job:server/task:1"):
v = resource_variable_ops.ResourceVariable(
2.0, caching_device="/job:localhost")
self.assertEqual("/job:localhost", v.value().device)
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
_ = v.value().op.get_attr("_class")
with ops.colocate_with(v.op):
w = resource_variable_ops.ResourceVariable(
2.0, caching_device="/job:localhost")
self.assertEqual("/job:localhost", w.value().device)
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
_ = w.value().op.get_attr("_class")
def testSharedName(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
v.initializer.run()
w = resource_variable_ops.var_handle_op(dtype=v.dtype.base_dtype,
shape=v.get_shape(),
shared_name="var1")
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, w_read.eval())
x = resource_variable_ops.var_handle_op(dtype=v.dtype.base_dtype,
shape=v.get_shape(),
shared_name="var1/")
x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
_ = x_read.eval()
if __name__ == "__main__":
test.main()

View File

@ -159,16 +159,15 @@ class ResourceVariable(object):
with ops.control_dependencies(None):
with ops.name_scope(name, "Variable", [] if init_from_fn else
[initial_value]) as name:
# pylint: disable=protected-access
true_name = ops._name_from_scope_name(name)
if init_from_fn:
# Use attr_scope and device(None) to simulate the behavior of
# colocate_with when the variable we want to colocate with doesn't
# yet exist.
# pylint: disable=protected-access
true_name = ops._name_from_scope_name(name)
attr = attr_value_pb2.AttrValue(
list=attr_value_pb2.AttrValue.ListValue(
s=[compat.as_bytes("loc:@%s" % true_name)]))
# pylint: disable=protected-access
with ops.get_default_graph()._attr_scope({"_class": attr}):
with ops.name_scope("Initializer"), ops.device(None):
self._initial_value = ops.convert_to_tensor(
@ -176,7 +175,8 @@ class ResourceVariable(object):
self._handle = gen_resource_variable_ops.var_handle_op(
shape=self._initial_value.get_shape(),
dtype=self._initial_value.dtype.base_dtype,
shared_name=name, name=name)
shared_name=true_name, name=name)
# pylint: enable=protected-access
# Or get the initial value from a Tensor or Python object.
else:
@ -185,7 +185,7 @@ class ResourceVariable(object):
self._handle = gen_resource_variable_ops.var_handle_op(
shape=self._initial_value.get_shape(),
dtype=self._initial_value.dtype.base_dtype,
shared_name=name, name=name)
shared_name=true_name, name=name)
self._dtype = self._initial_value.dtype.base_dtype
@ -201,8 +201,16 @@ class ResourceVariable(object):
self._handle, dtype=self._dtype)
self._graph_element = value
if caching_device is not None:
with ops.device(caching_device):
self._cached_value = array_ops.identity(value)
# Variables may be created in a tf.device() or ops.colocate_with()
# context. At the same time, users would expect caching device to be
# independent of this context, and/or would not expect the current
# device context to be merged with the caching device spec.
# Therefore we reset the colocation stack before creating the cached
# value. Note that resetting the colocation stack will also reset
# the device stack.
with ops.colocate_with(None, ignore_existing=True):
with ops.device(caching_device):
self._cached_value = array_ops.identity(value)
else:
self._cached_value = None
ops.add_to_collections(collections, self)