Ensure that functions created with tfe.defun
create ResourceVariables, not
RefVariables. Creating and using RefVariables inside functions was never fully supported and fails in obscure ways. PiperOrigin-RevId: 208666153
This commit is contained in:
parent
1eb7db417a
commit
4be575a6c2
@ -42,6 +42,7 @@ from tensorflow.python.ops import control_flow_ops
|
|||||||
from tensorflow.python.ops import functional_ops
|
from tensorflow.python.ops import functional_ops
|
||||||
from tensorflow.python.ops import gradients_impl
|
from tensorflow.python.ops import gradients_impl
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.training import distribute
|
from tensorflow.python.training import distribute
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -832,6 +833,8 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
|
|||||||
func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph())
|
func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph())
|
||||||
|
|
||||||
with func_graph.as_default(), AutomaticControlDependencies() as a:
|
with func_graph.as_default(), AutomaticControlDependencies() as a:
|
||||||
|
variable_scope.get_variable_scope().set_use_resource(True)
|
||||||
|
|
||||||
if signature is None:
|
if signature is None:
|
||||||
func_args = _get_defun_inputs_from_args(args)
|
func_args = _get_defun_inputs_from_args(args)
|
||||||
func_kwds = _get_defun_inputs_from_args(kwds)
|
func_kwds = _get_defun_inputs_from_args(kwds)
|
||||||
|
@ -397,6 +397,18 @@ class FunctionTest(test.TestCase):
|
|||||||
compiled = function.defun(f)
|
compiled = function.defun(f)
|
||||||
compiled()
|
compiled()
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testDefunForcesResourceVariables(self):
|
||||||
|
|
||||||
|
def variable_creator():
|
||||||
|
return variables.Variable(0.0).read_value()
|
||||||
|
|
||||||
|
defined = function.defun(variable_creator)
|
||||||
|
defined() # Create the variable.
|
||||||
|
self.assertEqual(len(defined.variables), 1)
|
||||||
|
self.assertIsInstance(
|
||||||
|
defined.variables[0], resource_variable_ops.ResourceVariable)
|
||||||
|
|
||||||
def testDefunDifferentiable(self):
|
def testDefunDifferentiable(self):
|
||||||
v = resource_variable_ops.ResourceVariable(1.0)
|
v = resource_variable_ops.ResourceVariable(1.0)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user