Remove control dependency escape from PartitionedVariable.concat().

Wrapping the as_tensor() method of PartitionedVariable is incorrect. It means, for example, that concatenation happens outside the context of any enclosing loop, so loop iterations would see stale values. In particular this is the wrong behavior for models wrapped in a in-graph training loop.

PiperOrigin-RevId: 208663593
This commit is contained in:
Peter Hawkins 2018-08-14 09:56:20 -07:00 committed by TensorFlower Gardener
parent 01267485b6
commit 1eb7db417a
2 changed files with 55 additions and 32 deletions

View File

@ -25,12 +25,15 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
class PartitionerCreatorsTest(test.TestCase):
@ -543,32 +546,6 @@ class PartitionedVariablesTestCase(test.TestCase):
partitioned_variables.create_partitioned_variables(
[10, 43], [1, 50], rnd.initialized_value())
def testControlDepsNone(self):
with self.test_session() as session:
c = constant_op.constant(1.0)
with ops.control_dependencies([c]):
# d get the control dependency.
d = constant_op.constant(2.0)
# Partitioned variables do not.
var_x = variable_scope.get_variable(
"x",
shape=[2],
initializer=init_ops.ones_initializer(),
partitioner=partitioned_variables.variable_axis_size_partitioner(4))
ops_before_read = session.graph.get_operations()
var_x.as_tensor() # Caches the ops for subsequent reads.
reading_ops = [
op for op in session.graph.get_operations()
if op not in ops_before_read
]
self.assertEqual([c.op], d.op.control_inputs)
# Tests that no control dependencies are added to reading a partitioned
# variable which is similar to reading a variable.
for op in reading_ops:
self.assertEqual([], op.control_inputs)
def testConcat(self):
with self.test_session() as session:
var_x = variable_scope.get_variable(
@ -594,6 +571,57 @@ class PartitionedVariablesTestCase(test.TestCase):
variables.global_variables_initializer().run()
self.assertAllClose(value.eval(), var_x.as_tensor().eval())
def testVariableCreationInALoop(self):
"""Tests the variable created inside a loop can be used outside the loop."""
with self.test_session():
with variable_scope.variable_scope("ascope") as scope:
def Body(i, _):
var_x = variable_scope.get_variable(
"x",
shape=[2],
initializer=init_ops.ones_initializer(),
partitioner=partitioned_variables.variable_axis_size_partitioner(
4))
return (i + 1, var_x.as_tensor())
cond = lambda i, _: i < 2
_, x = control_flow_ops.while_loop(
cond, Body, (0, constant_op.constant([7, 8], dtypes.float32)))
variables.global_variables_initializer().run()
self.assertAllClose([1.0, 1.0], x.eval())
scope.reuse_variables()
var_x = variable_scope.get_variable(
"x",
shape=[2],
initializer=init_ops.ones_initializer(),
partitioner=partitioned_variables.variable_axis_size_partitioner(4))
self.assertAllClose([1.0, 1.0], var_x.as_tensor().eval())
def testReadInWhileLoop(self):
"""Tests the value is current (not cached) when read within a loop."""
with self.test_session():
var_x = variable_scope.get_variable(
"x",
shape=[2],
initializer=init_ops.ones_initializer(),
partitioner=partitioned_variables.variable_axis_size_partitioner(4))
def Body(i, _):
# Use a SGD step to update the variable's value.
loss = math_ops.reduce_sum(var_x)
optimizer = gradient_descent.GradientDescentOptimizer(1.0)
minimize = optimizer.minimize(loss * 0.7)
with ops.control_dependencies([minimize]):
return (i + 1, var_x.as_tensor())
cond = lambda i, _: i < 2
_, x = control_flow_ops.while_loop(
cond, Body, (0, constant_op.constant([7, 8], dtypes.float32)))
variables.global_variables_initializer().run()
self.assertAllClose([-0.4, -0.4], x.eval())
if __name__ == "__main__":
test.main()

View File

@ -1917,14 +1917,9 @@ class PartitionedVariable(object):
def as_tensor(self):
"""Returns the overall concatenated value as a `Tensor`.
The returned tensor will not inherit the control dependencies from the scope
where the value is used, which is similar to getting the value of
`Variable`.
Returns:
`Tensor` containing the concatenated value.
"""
with ops.control_dependencies(None):
return self._concat()
@staticmethod