Issue a more user-friendly error message if a variable's initializer is from inside a control-flow scope, such as tf.cond() or tf.while_loop().
Fixes #8604. PiperOrigin-RevId: 157516279
This commit is contained in:
parent
da2daf068a
commit
2994444bf6
@ -25,6 +25,7 @@ from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
@ -256,6 +257,19 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
||||
v.initializer.run(feed_dict={v.initial_value: 3.0})
|
||||
self.assertEqual(3.0, v.value().eval())
|
||||
|
||||
def testControlFlowInitialization(self):
|
||||
"""Expects an error if an initializer is in a control-flow scope."""
|
||||
def cond(i, _):
|
||||
return i < 10
|
||||
|
||||
def body(i, _):
|
||||
zero = array_ops.zeros([], dtype=dtypes.int32)
|
||||
v = resource_variable_ops.ResourceVariable(initial_value=zero)
|
||||
return (i + 1, v.read_value())
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "inside a control-flow"):
|
||||
control_flow_ops.while_loop(cond, body, [0, 0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -231,6 +231,19 @@ class VariablesTestCase(test.TestCase):
|
||||
sess.run(v0.initializer)
|
||||
sess.run(add)
|
||||
|
||||
def testControlFlowInitialization(self):
|
||||
"""Expects an error if an initializer is in a control-flow scope."""
|
||||
def cond(i, _):
|
||||
return i < 10
|
||||
|
||||
def body(i, _):
|
||||
zero = array_ops.zeros([], dtype=dtypes.int32)
|
||||
v = variables.Variable(initial_value=zero)
|
||||
return (i + 1, v.read_value())
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "inside a control-flow"):
|
||||
control_flow_ops.while_loop(cond, body, [0, 0])
|
||||
|
||||
def testUseVariableAsTensor(self):
|
||||
with self.test_session():
|
||||
var_x = variables.Variable(2.0)
|
||||
|
@ -183,6 +183,14 @@ class ResourceVariable(variables.Variable):
|
||||
else:
|
||||
self._initial_value = ops.convert_to_tensor(
|
||||
initial_value, name="initial_value", dtype=dtype)
|
||||
# pylint: disable=protected-access
|
||||
if self._initial_value.op._get_control_flow_context() is not None:
|
||||
raise ValueError(
|
||||
"Initializer for variable %s is from inside a control-flow "
|
||||
"construct, such as a loop or conditional. When creating a "
|
||||
"variable inside a loop or conditional, use a lambda as the "
|
||||
"initializer." % name)
|
||||
# pylint: enable=protected-access
|
||||
self._handle = gen_resource_variable_ops.var_handle_op(
|
||||
shape=self._initial_value.get_shape(),
|
||||
dtype=self._initial_value.dtype.base_dtype,
|
||||
|
@ -282,11 +282,20 @@ class Variable(object):
|
||||
shape,
|
||||
self._initial_value.dtype.base_dtype,
|
||||
name=name)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Or get the initial value from a Tensor or Python object.
|
||||
else:
|
||||
self._initial_value = ops.convert_to_tensor(
|
||||
initial_value, name="initial_value", dtype=dtype)
|
||||
# pylint: disable=protected-access
|
||||
if self._initial_value.op._get_control_flow_context() is not None:
|
||||
raise ValueError(
|
||||
"Initializer for variable %s is from inside a control-flow "
|
||||
"construct, such as a loop or conditional. When creating a "
|
||||
"variable inside a loop or conditional, use a lambda as the "
|
||||
"initializer." % name)
|
||||
# pylint: enable=protected-access
|
||||
shape = (self._initial_value.get_shape()
|
||||
if validate_shape else tensor_shape.unknown_shape())
|
||||
# In this case, the variable op can't be created until after the
|
||||
|
Loading…
Reference in New Issue
Block a user