Issue a more user-friendly error message if a variable's initializer is from inside a control-flow scope, such as tf.cond() or tf.while_loop().
Fixes #8604. PiperOrigin-RevId: 157516279
This commit is contained in:
parent
da2daf068a
commit
2994444bf6
@ -25,6 +25,7 @@ from tensorflow.python.framework import errors
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
@ -256,6 +257,19 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
|||||||
v.initializer.run(feed_dict={v.initial_value: 3.0})
|
v.initializer.run(feed_dict={v.initial_value: 3.0})
|
||||||
self.assertEqual(3.0, v.value().eval())
|
self.assertEqual(3.0, v.value().eval())
|
||||||
|
|
||||||
|
def testControlFlowInitialization(self):
|
||||||
|
"""Expects an error if an initializer is in a control-flow scope."""
|
||||||
|
def cond(i, _):
|
||||||
|
return i < 10
|
||||||
|
|
||||||
|
def body(i, _):
|
||||||
|
zero = array_ops.zeros([], dtype=dtypes.int32)
|
||||||
|
v = resource_variable_ops.ResourceVariable(initial_value=zero)
|
||||||
|
return (i + 1, v.read_value())
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "inside a control-flow"):
|
||||||
|
control_flow_ops.while_loop(cond, body, [0, 0])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -231,6 +231,19 @@ class VariablesTestCase(test.TestCase):
|
|||||||
sess.run(v0.initializer)
|
sess.run(v0.initializer)
|
||||||
sess.run(add)
|
sess.run(add)
|
||||||
|
|
||||||
|
def testControlFlowInitialization(self):
|
||||||
|
"""Expects an error if an initializer is in a control-flow scope."""
|
||||||
|
def cond(i, _):
|
||||||
|
return i < 10
|
||||||
|
|
||||||
|
def body(i, _):
|
||||||
|
zero = array_ops.zeros([], dtype=dtypes.int32)
|
||||||
|
v = variables.Variable(initial_value=zero)
|
||||||
|
return (i + 1, v.read_value())
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "inside a control-flow"):
|
||||||
|
control_flow_ops.while_loop(cond, body, [0, 0])
|
||||||
|
|
||||||
def testUseVariableAsTensor(self):
|
def testUseVariableAsTensor(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
var_x = variables.Variable(2.0)
|
var_x = variables.Variable(2.0)
|
||||||
|
@ -183,6 +183,14 @@ class ResourceVariable(variables.Variable):
|
|||||||
else:
|
else:
|
||||||
self._initial_value = ops.convert_to_tensor(
|
self._initial_value = ops.convert_to_tensor(
|
||||||
initial_value, name="initial_value", dtype=dtype)
|
initial_value, name="initial_value", dtype=dtype)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
if self._initial_value.op._get_control_flow_context() is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Initializer for variable %s is from inside a control-flow "
|
||||||
|
"construct, such as a loop or conditional. When creating a "
|
||||||
|
"variable inside a loop or conditional, use a lambda as the "
|
||||||
|
"initializer." % name)
|
||||||
|
# pylint: enable=protected-access
|
||||||
self._handle = gen_resource_variable_ops.var_handle_op(
|
self._handle = gen_resource_variable_ops.var_handle_op(
|
||||||
shape=self._initial_value.get_shape(),
|
shape=self._initial_value.get_shape(),
|
||||||
dtype=self._initial_value.dtype.base_dtype,
|
dtype=self._initial_value.dtype.base_dtype,
|
||||||
|
@ -282,11 +282,20 @@ class Variable(object):
|
|||||||
shape,
|
shape,
|
||||||
self._initial_value.dtype.base_dtype,
|
self._initial_value.dtype.base_dtype,
|
||||||
name=name)
|
name=name)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
# Or get the initial value from a Tensor or Python object.
|
# Or get the initial value from a Tensor or Python object.
|
||||||
else:
|
else:
|
||||||
self._initial_value = ops.convert_to_tensor(
|
self._initial_value = ops.convert_to_tensor(
|
||||||
initial_value, name="initial_value", dtype=dtype)
|
initial_value, name="initial_value", dtype=dtype)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
if self._initial_value.op._get_control_flow_context() is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Initializer for variable %s is from inside a control-flow "
|
||||||
|
"construct, such as a loop or conditional. When creating a "
|
||||||
|
"variable inside a loop or conditional, use a lambda as the "
|
||||||
|
"initializer." % name)
|
||||||
|
# pylint: enable=protected-access
|
||||||
shape = (self._initial_value.get_shape()
|
shape = (self._initial_value.get_shape()
|
||||||
if validate_shape else tensor_shape.unknown_shape())
|
if validate_shape else tensor_shape.unknown_shape())
|
||||||
# In this case, the variable op can't be created until after the
|
# In this case, the variable op can't be created until after the
|
||||||
|
Loading…
Reference in New Issue
Block a user