From 83afdc92d9bdafeb65c3f84263ad09e337ff29e7 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 4 Apr 2017 15:01:39 -0800
Subject: [PATCH] Make resource variable shared name consistent with
 non-resource variables.

Remove colocation constraint from resource variable cached value with the
variable itself.

Change: 152192203
---
 .../resource_variable_ops_test.py           | 33 +++++++++++++++++++
 .../python/ops/resource_variable_ops.py     | 22 +++++++++----
 2 files changed, 48 insertions(+), 7 deletions(-)

diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 0b81dcb8afe..2fba15801cb 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -195,6 +195,39 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
     self.assertIsInstance(w.dtype, dtypes.DType)
     self.assertEqual(v.dtype, w.dtype)
 
+  def testCachingDevice(self):
+    with ops.device("/job:server/task:1"):
+      v = resource_variable_ops.ResourceVariable(
+          2.0, caching_device="/job:localhost")
+      self.assertEqual("/job:localhost", v.value().device)
+      with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
+        _ = v.value().op.get_attr("_class")
+
+      with ops.colocate_with(v.op):
+        w = resource_variable_ops.ResourceVariable(
+            2.0, caching_device="/job:localhost")
+        self.assertEqual("/job:localhost", w.value().device)
+        with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
+          _ = w.value().op.get_attr("_class")
+
+  def testSharedName(self):
+    with self.test_session():
+      v = resource_variable_ops.ResourceVariable(300.0, name="var1")
+      v.initializer.run()
+
+      w = resource_variable_ops.var_handle_op(dtype=v.dtype.base_dtype,
+                                              shape=v.get_shape(),
+                                              shared_name="var1")
+      w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
+      self.assertEqual(300.0, w_read.eval())
+
+      x = resource_variable_ops.var_handle_op(dtype=v.dtype.base_dtype,
+                                              shape=v.get_shape(),
+                                              shared_name="var1/")
+      x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
+      with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
+        _ = x_read.eval()
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 12b39babe0e..86e0cae27ac 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -159,16 +159,15 @@ class ResourceVariable(object):
     with ops.control_dependencies(None):
       with ops.name_scope(name, "Variable", []
                           if init_from_fn else [initial_value]) as name:
+        # pylint: disable=protected-access
+        true_name = ops._name_from_scope_name(name)
         if init_from_fn:
           # Use attr_scope and device(None) to simulate the behavior of
           # colocate_with when the variable we want to colocate with doesn't
           # yet exist.
-          # pylint: disable=protected-access
-          true_name = ops._name_from_scope_name(name)
           attr = attr_value_pb2.AttrValue(
               list=attr_value_pb2.AttrValue.ListValue(
                   s=[compat.as_bytes("loc:@%s" % true_name)]))
-          # pylint: disable=protected-access
           with ops.get_default_graph()._attr_scope({"_class": attr}):
             with ops.name_scope("Initializer"), ops.device(None):
               self._initial_value = ops.convert_to_tensor(
@@ -176,7 +175,8 @@ class ResourceVariable(object):
             self._handle = gen_resource_variable_ops.var_handle_op(
                 shape=self._initial_value.get_shape(),
                 dtype=self._initial_value.dtype.base_dtype,
-                shared_name=name, name=name)
+                shared_name=true_name, name=name)
+          # pylint: enable=protected-access
 
         # Or get the initial value from a Tensor or Python object.
         else:
@@ -185,7 +185,7 @@ class ResourceVariable(object):
           self._handle = gen_resource_variable_ops.var_handle_op(
               shape=self._initial_value.get_shape(),
               dtype=self._initial_value.dtype.base_dtype,
-              shared_name=name, name=name)
+              shared_name=true_name, name=name)
 
         self._dtype = self._initial_value.dtype.base_dtype
 
@@ -201,8 +201,16 @@ class ResourceVariable(object):
             self._handle, dtype=self._dtype)
         self._graph_element = value
         if caching_device is not None:
-          with ops.device(caching_device):
-            self._cached_value = array_ops.identity(value)
+          # Variables may be created in a tf.device() or ops.colocate_with()
+          # context. At the same time, users would expect caching device to
+          # be independent of this context, and/or would not expect the
+          # current device context to be merged with the caching device
+          # spec. Therefore we reset the colocation stack before creating
+          # the cached value. Note that resetting the colocation stack
+          # will also reset the device stack.
+          with ops.colocate_with(None, ignore_existing=True):
+            with ops.device(caching_device):
+              self._cached_value = array_ops.identity(value)
         else:
           self._cached_value = None
         ops.add_to_collections(collections, self)