Ensure that functions created with tfe.defun create ResourceVariables, not

RefVariables.

Creating and using RefVariables inside functions was never fully supported
and fails in obscure ways.

PiperOrigin-RevId: 208666153
This commit is contained in:
Akshay Agrawal 2018-08-14 10:10:27 -07:00 committed by TensorFlower Gardener
parent 1eb7db417a
commit 4be575a6c2
2 changed files with 15 additions and 0 deletions

View File

@ -42,6 +42,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import distribute from tensorflow.python.training import distribute
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import nest from tensorflow.python.util import nest
@ -832,6 +833,8 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph()) func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph())
with func_graph.as_default(), AutomaticControlDependencies() as a: with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
if signature is None: if signature is None:
func_args = _get_defun_inputs_from_args(args) func_args = _get_defun_inputs_from_args(args)
func_kwds = _get_defun_inputs_from_args(kwds) func_kwds = _get_defun_inputs_from_args(kwds)

View File

@ -397,6 +397,18 @@ class FunctionTest(test.TestCase):
compiled = function.defun(f) compiled = function.defun(f)
compiled() compiled()
@test_util.run_in_graph_and_eager_modes
def testDefunForcesResourceVariables(self):
def variable_creator():
return variables.Variable(0.0).read_value()
defined = function.defun(variable_creator)
defined() # Create the variable.
self.assertEqual(len(defined.variables), 1)
self.assertIsInstance(
defined.variables[0], resource_variable_ops.ResourceVariable)
def testDefunDifferentiable(self): def testDefunDifferentiable(self):
v = resource_variable_ops.ResourceVariable(1.0) v = resource_variable_ops.ResourceVariable(1.0)