Make resource variable shared name consistent with non-resource variables.
Remove colocation constraint from resource variable cached value with the variable itself. Change: 152192203
This commit is contained in:
parent
a42b3fc598
commit
83afdc92d9
@ -195,6 +195,39 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
||||
self.assertIsInstance(w.dtype, dtypes.DType)
|
||||
self.assertEqual(v.dtype, w.dtype)
|
||||
|
||||
def testCachingDevice(self):
|
||||
with ops.device("/job:server/task:1"):
|
||||
v = resource_variable_ops.ResourceVariable(
|
||||
2.0, caching_device="/job:localhost")
|
||||
self.assertEqual("/job:localhost", v.value().device)
|
||||
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
|
||||
_ = v.value().op.get_attr("_class")
|
||||
|
||||
with ops.colocate_with(v.op):
|
||||
w = resource_variable_ops.ResourceVariable(
|
||||
2.0, caching_device="/job:localhost")
|
||||
self.assertEqual("/job:localhost", w.value().device)
|
||||
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
|
||||
_ = w.value().op.get_attr("_class")
|
||||
|
||||
def testSharedName(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
|
||||
v.initializer.run()
|
||||
|
||||
w = resource_variable_ops.var_handle_op(dtype=v.dtype.base_dtype,
|
||||
shape=v.get_shape(),
|
||||
shared_name="var1")
|
||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||
self.assertEqual(300.0, w_read.eval())
|
||||
|
||||
x = resource_variable_ops.var_handle_op(dtype=v.dtype.base_dtype,
|
||||
shape=v.get_shape(),
|
||||
shared_name="var1/")
|
||||
x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
|
||||
with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
|
||||
_ = x_read.eval()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -159,16 +159,15 @@ class ResourceVariable(object):
|
||||
with ops.control_dependencies(None):
|
||||
with ops.name_scope(name, "Variable", [] if init_from_fn else
|
||||
[initial_value]) as name:
|
||||
# pylint: disable=protected-access
|
||||
true_name = ops._name_from_scope_name(name)
|
||||
if init_from_fn:
|
||||
# Use attr_scope and device(None) to simulate the behavior of
|
||||
# colocate_with when the variable we want to colocate with doesn't
|
||||
# yet exist.
|
||||
# pylint: disable=protected-access
|
||||
true_name = ops._name_from_scope_name(name)
|
||||
attr = attr_value_pb2.AttrValue(
|
||||
list=attr_value_pb2.AttrValue.ListValue(
|
||||
s=[compat.as_bytes("loc:@%s" % true_name)]))
|
||||
# pylint: disable=protected-access
|
||||
with ops.get_default_graph()._attr_scope({"_class": attr}):
|
||||
with ops.name_scope("Initializer"), ops.device(None):
|
||||
self._initial_value = ops.convert_to_tensor(
|
||||
@ -176,7 +175,8 @@ class ResourceVariable(object):
|
||||
self._handle = gen_resource_variable_ops.var_handle_op(
|
||||
shape=self._initial_value.get_shape(),
|
||||
dtype=self._initial_value.dtype.base_dtype,
|
||||
shared_name=name, name=name)
|
||||
shared_name=true_name, name=name)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Or get the initial value from a Tensor or Python object.
|
||||
else:
|
||||
@ -185,7 +185,7 @@ class ResourceVariable(object):
|
||||
self._handle = gen_resource_variable_ops.var_handle_op(
|
||||
shape=self._initial_value.get_shape(),
|
||||
dtype=self._initial_value.dtype.base_dtype,
|
||||
shared_name=name, name=name)
|
||||
shared_name=true_name, name=name)
|
||||
|
||||
self._dtype = self._initial_value.dtype.base_dtype
|
||||
|
||||
@ -201,8 +201,16 @@ class ResourceVariable(object):
|
||||
self._handle, dtype=self._dtype)
|
||||
self._graph_element = value
|
||||
if caching_device is not None:
|
||||
with ops.device(caching_device):
|
||||
self._cached_value = array_ops.identity(value)
|
||||
# Variables may be created in a tf.device() or ops.colocate_with()
|
||||
# context. At the same time, users would expect caching device to be
|
||||
# independent of this context, and/or would not expect the current
|
||||
# device context to be merged with the caching device spec.
|
||||
# Therefore we reset the colocation stack before creating the cached
|
||||
# value. Note that resetting the colocation stack will also reset
|
||||
# the device stack.
|
||||
with ops.colocate_with(None, ignore_existing=True):
|
||||
with ops.device(caching_device):
|
||||
self._cached_value = array_ops.identity(value)
|
||||
else:
|
||||
self._cached_value = None
|
||||
ops.add_to_collections(collections, self)
|
||||
|
Loading…
Reference in New Issue
Block a user