diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py index f5c6255c346..ba9359d9234 100644 --- a/tensorflow/python/kernel_tests/partitioned_variables_test.py +++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py @@ -25,12 +25,15 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent class PartitionerCreatorsTest(test.TestCase): @@ -543,32 +546,6 @@ class PartitionedVariablesTestCase(test.TestCase): partitioned_variables.create_partitioned_variables( [10, 43], [1, 50], rnd.initialized_value()) - def testControlDepsNone(self): - with self.test_session() as session: - c = constant_op.constant(1.0) - with ops.control_dependencies([c]): - # d get the control dependency. - d = constant_op.constant(2.0) - # Partitioned variables do not. - var_x = variable_scope.get_variable( - "x", - shape=[2], - initializer=init_ops.ones_initializer(), - partitioner=partitioned_variables.variable_axis_size_partitioner(4)) - - ops_before_read = session.graph.get_operations() - var_x.as_tensor() # Caches the ops for subsequent reads. - reading_ops = [ - op for op in session.graph.get_operations() - if op not in ops_before_read - ] - - self.assertEqual([c.op], d.op.control_inputs) - # Tests that no control dependencies are added to reading a partitioned - # variable which is similar to reading a variable. - for op in reading_ops: - self.assertEqual([], op.control_inputs) - def testConcat(self): with self.test_session() as session: var_x = variable_scope.get_variable( @@ -594,6 +571,57 @@ class PartitionedVariablesTestCase(test.TestCase): variables.global_variables_initializer().run() self.assertAllClose(value.eval(), var_x.as_tensor().eval()) + def testVariableCreationInALoop(self): + """Tests the variable created inside a loop can be used outside the loop.""" + with self.test_session(): + with variable_scope.variable_scope("ascope") as scope: + def Body(i, _): + var_x = variable_scope.get_variable( + "x", + shape=[2], + initializer=init_ops.ones_initializer(), + partitioner=partitioned_variables.variable_axis_size_partitioner( + 4)) + return (i + 1, var_x.as_tensor()) + + cond = lambda i, _: i < 2 + _, x = control_flow_ops.while_loop( + cond, Body, (0, constant_op.constant([7, 8], dtypes.float32))) + variables.global_variables_initializer().run() + self.assertAllClose([1.0, 1.0], x.eval()) + + scope.reuse_variables() + var_x = variable_scope.get_variable( + "x", + shape=[2], + initializer=init_ops.ones_initializer(), + partitioner=partitioned_variables.variable_axis_size_partitioner(4)) + + self.assertAllClose([1.0, 1.0], var_x.as_tensor().eval()) + + def testReadInWhileLoop(self): + """Tests the value is current (not cached) when read within a loop.""" + with self.test_session(): + var_x = variable_scope.get_variable( + "x", + shape=[2], + initializer=init_ops.ones_initializer(), + partitioner=partitioned_variables.variable_axis_size_partitioner(4)) + + def Body(i, _): + # Use a SGD step to update the variable's value. + loss = math_ops.reduce_sum(var_x) + optimizer = gradient_descent.GradientDescentOptimizer(1.0) + minimize = optimizer.minimize(loss * 0.7) + with ops.control_dependencies([minimize]): + return (i + 1, var_x.as_tensor()) + + cond = lambda i, _: i < 2 + _, x = control_flow_ops.while_loop( + cond, Body, (0, constant_op.constant([7, 8], dtypes.float32))) + variables.global_variables_initializer().run() + self.assertAllClose([-0.4, -0.4], x.eval()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 464c1167d93..402ab2dd9d7 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -1917,15 +1917,10 @@ class PartitionedVariable(object): def as_tensor(self): """Returns the overall concatenated value as a `Tensor`. - The returned tensor will not inherit the control dependencies from the scope - where the value is used, which is similar to getting the value of - `Variable`. - Returns: `Tensor` containing the concatenated value. """ - with ops.control_dependencies(None): - return self._concat() + return self._concat() @staticmethod def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):