Ensure that functions created with tfe.defun
create ResourceVariables, not
RefVariables. Creating and using RefVariables inside functions was never fully supported and fails in obscure ways. PiperOrigin-RevId: 208666153
This commit is contained in:
parent
1eb7db417a
commit
4be575a6c2
@ -42,6 +42,7 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.training import distribute
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
@ -832,6 +833,8 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
|
||||
func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph())
|
||||
|
||||
with func_graph.as_default(), AutomaticControlDependencies() as a:
|
||||
variable_scope.get_variable_scope().set_use_resource(True)
|
||||
|
||||
if signature is None:
|
||||
func_args = _get_defun_inputs_from_args(args)
|
||||
func_kwds = _get_defun_inputs_from_args(kwds)
|
||||
|
@ -397,6 +397,18 @@ class FunctionTest(test.TestCase):
|
||||
compiled = function.defun(f)
|
||||
compiled()
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testDefunForcesResourceVariables(self):
|
||||
|
||||
def variable_creator():
|
||||
return variables.Variable(0.0).read_value()
|
||||
|
||||
defined = function.defun(variable_creator)
|
||||
defined() # Create the variable.
|
||||
self.assertEqual(len(defined.variables), 1)
|
||||
self.assertIsInstance(
|
||||
defined.variables[0], resource_variable_ops.ResourceVariable)
|
||||
|
||||
def testDefunDifferentiable(self):
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user