Make resource variable shared name consistent with non-resource variables.
Remove colocation constraint from resource variable cached value with the variable itself. Change: 152192203
This commit is contained in:
parent
a42b3fc598
commit
83afdc92d9
@ -195,6 +195,39 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertIsInstance(w.dtype, dtypes.DType)
|
self.assertIsInstance(w.dtype, dtypes.DType)
|
||||||
self.assertEqual(v.dtype, w.dtype)
|
self.assertEqual(v.dtype, w.dtype)
|
||||||
|
|
||||||
|
def testCachingDevice(self):
|
||||||
|
with ops.device("/job:server/task:1"):
|
||||||
|
v = resource_variable_ops.ResourceVariable(
|
||||||
|
2.0, caching_device="/job:localhost")
|
||||||
|
self.assertEqual("/job:localhost", v.value().device)
|
||||||
|
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
|
||||||
|
_ = v.value().op.get_attr("_class")
|
||||||
|
|
||||||
|
with ops.colocate_with(v.op):
|
||||||
|
w = resource_variable_ops.ResourceVariable(
|
||||||
|
2.0, caching_device="/job:localhost")
|
||||||
|
self.assertEqual("/job:localhost", w.value().device)
|
||||||
|
with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
|
||||||
|
_ = w.value().op.get_attr("_class")
|
||||||
|
|
||||||
|
def testSharedName(self):
|
||||||
|
with self.test_session():
|
||||||
|
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
|
||||||
|
v.initializer.run()
|
||||||
|
|
||||||
|
w = resource_variable_ops.var_handle_op(dtype=v.dtype.base_dtype,
|
||||||
|
shape=v.get_shape(),
|
||||||
|
shared_name="var1")
|
||||||
|
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||||
|
self.assertEqual(300.0, w_read.eval())
|
||||||
|
|
||||||
|
x = resource_variable_ops.var_handle_op(dtype=v.dtype.base_dtype,
|
||||||
|
shape=v.get_shape(),
|
||||||
|
shared_name="var1/")
|
||||||
|
x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
|
||||||
|
with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
|
||||||
|
_ = x_read.eval()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -159,16 +159,15 @@ class ResourceVariable(object):
|
|||||||
with ops.control_dependencies(None):
|
with ops.control_dependencies(None):
|
||||||
with ops.name_scope(name, "Variable", [] if init_from_fn else
|
with ops.name_scope(name, "Variable", [] if init_from_fn else
|
||||||
[initial_value]) as name:
|
[initial_value]) as name:
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
true_name = ops._name_from_scope_name(name)
|
||||||
if init_from_fn:
|
if init_from_fn:
|
||||||
# Use attr_scope and device(None) to simulate the behavior of
|
# Use attr_scope and device(None) to simulate the behavior of
|
||||||
# colocate_with when the variable we want to colocate with doesn't
|
# colocate_with when the variable we want to colocate with doesn't
|
||||||
# yet exist.
|
# yet exist.
|
||||||
# pylint: disable=protected-access
|
|
||||||
true_name = ops._name_from_scope_name(name)
|
|
||||||
attr = attr_value_pb2.AttrValue(
|
attr = attr_value_pb2.AttrValue(
|
||||||
list=attr_value_pb2.AttrValue.ListValue(
|
list=attr_value_pb2.AttrValue.ListValue(
|
||||||
s=[compat.as_bytes("loc:@%s" % true_name)]))
|
s=[compat.as_bytes("loc:@%s" % true_name)]))
|
||||||
# pylint: disable=protected-access
|
|
||||||
with ops.get_default_graph()._attr_scope({"_class": attr}):
|
with ops.get_default_graph()._attr_scope({"_class": attr}):
|
||||||
with ops.name_scope("Initializer"), ops.device(None):
|
with ops.name_scope("Initializer"), ops.device(None):
|
||||||
self._initial_value = ops.convert_to_tensor(
|
self._initial_value = ops.convert_to_tensor(
|
||||||
@ -176,7 +175,8 @@ class ResourceVariable(object):
|
|||||||
self._handle = gen_resource_variable_ops.var_handle_op(
|
self._handle = gen_resource_variable_ops.var_handle_op(
|
||||||
shape=self._initial_value.get_shape(),
|
shape=self._initial_value.get_shape(),
|
||||||
dtype=self._initial_value.dtype.base_dtype,
|
dtype=self._initial_value.dtype.base_dtype,
|
||||||
shared_name=name, name=name)
|
shared_name=true_name, name=name)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
# Or get the initial value from a Tensor or Python object.
|
# Or get the initial value from a Tensor or Python object.
|
||||||
else:
|
else:
|
||||||
@ -185,7 +185,7 @@ class ResourceVariable(object):
|
|||||||
self._handle = gen_resource_variable_ops.var_handle_op(
|
self._handle = gen_resource_variable_ops.var_handle_op(
|
||||||
shape=self._initial_value.get_shape(),
|
shape=self._initial_value.get_shape(),
|
||||||
dtype=self._initial_value.dtype.base_dtype,
|
dtype=self._initial_value.dtype.base_dtype,
|
||||||
shared_name=name, name=name)
|
shared_name=true_name, name=name)
|
||||||
|
|
||||||
self._dtype = self._initial_value.dtype.base_dtype
|
self._dtype = self._initial_value.dtype.base_dtype
|
||||||
|
|
||||||
@ -201,8 +201,16 @@ class ResourceVariable(object):
|
|||||||
self._handle, dtype=self._dtype)
|
self._handle, dtype=self._dtype)
|
||||||
self._graph_element = value
|
self._graph_element = value
|
||||||
if caching_device is not None:
|
if caching_device is not None:
|
||||||
with ops.device(caching_device):
|
# Variables may be created in a tf.device() or ops.colocate_with()
|
||||||
self._cached_value = array_ops.identity(value)
|
# context. At the same time, users would expect caching device to be
|
||||||
|
# independent of this context, and/or would not expect the current
|
||||||
|
# device context to be merged with the caching device spec.
|
||||||
|
# Therefore we reset the colocation stack before creating the cached
|
||||||
|
# value. Note that resetting the colocation stack will also reset
|
||||||
|
# the device stack.
|
||||||
|
with ops.colocate_with(None, ignore_existing=True):
|
||||||
|
with ops.device(caching_device):
|
||||||
|
self._cached_value = array_ops.identity(value)
|
||||||
else:
|
else:
|
||||||
self._cached_value = None
|
self._cached_value = None
|
||||||
ops.add_to_collections(collections, self)
|
ops.add_to_collections(collections, self)
|
||||||
|
Loading…
Reference in New Issue
Block a user