Issue a more user-friendly error message if a variable's initializer is from inside a control-flow scope, such as tf.cond() or tf.while_loop().

Fixes #8604.

PiperOrigin-RevId: 157516279
This commit is contained in:
Peter Hawkins 2017-05-30 14:55:51 -07:00 committed by TensorFlower Gardener
parent da2daf068a
commit 2994444bf6
4 changed files with 44 additions and 0 deletions

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@ -256,6 +257,19 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
v.initializer.run(feed_dict={v.initial_value: 3.0})
self.assertEqual(3.0, v.value().eval())
def testControlFlowInitialization(self):
"""Expects an error if an initializer is in a control-flow scope."""
def cond(i, _):
return i < 10
def body(i, _):
zero = array_ops.zeros([], dtype=dtypes.int32)
v = resource_variable_ops.ResourceVariable(initial_value=zero)
return (i + 1, v.read_value())
with self.assertRaisesRegexp(ValueError, "inside a control-flow"):
control_flow_ops.while_loop(cond, body, [0, 0])
if __name__ == "__main__":
test.main()

View File

@ -231,6 +231,19 @@ class VariablesTestCase(test.TestCase):
sess.run(v0.initializer)
sess.run(add)
def testControlFlowInitialization(self):
"""Expects an error if an initializer is in a control-flow scope."""
def cond(i, _):
return i < 10
def body(i, _):
zero = array_ops.zeros([], dtype=dtypes.int32)
v = variables.Variable(initial_value=zero)
return (i + 1, v.read_value())
with self.assertRaisesRegexp(ValueError, "inside a control-flow"):
control_flow_ops.while_loop(cond, body, [0, 0])
def testUseVariableAsTensor(self):
with self.test_session():
var_x = variables.Variable(2.0)

View File

@ -183,6 +183,14 @@ class ResourceVariable(variables.Variable):
else:
self._initial_value = ops.convert_to_tensor(
initial_value, name="initial_value", dtype=dtype)
# pylint: disable=protected-access
if self._initial_value.op._get_control_flow_context() is not None:
raise ValueError(
"Initializer for variable %s is from inside a control-flow "
"construct, such as a loop or conditional. When creating a "
"variable inside a loop or conditional, use a lambda as the "
"initializer." % name)
# pylint: enable=protected-access
self._handle = gen_resource_variable_ops.var_handle_op(
shape=self._initial_value.get_shape(),
dtype=self._initial_value.dtype.base_dtype,

View File

@ -282,11 +282,20 @@ class Variable(object):
shape,
self._initial_value.dtype.base_dtype,
name=name)
# pylint: enable=protected-access
# Or get the initial value from a Tensor or Python object.
else:
self._initial_value = ops.convert_to_tensor(
initial_value, name="initial_value", dtype=dtype)
# pylint: disable=protected-access
if self._initial_value.op._get_control_flow_context() is not None:
raise ValueError(
"Initializer for variable %s is from inside a control-flow "
"construct, such as a loop or conditional. When creating a "
"variable inside a loop or conditional, use a lambda as the "
"initializer." % name)
# pylint: enable=protected-access
shape = (self._initial_value.get_shape()
if validate_shape else tensor_shape.unknown_shape())
# In this case, the variable op can't be created until after the