From 4be575a6c2fcefeacb62179ad87abfa4419b302a Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Tue, 14 Aug 2018 10:10:27 -0700 Subject: [PATCH] Ensure that functions created with `tfe.defun` create ResourceVariables, not RefVariables. Creating and using RefVariables inside functions was never fully supported and fails in obscure ways. PiperOrigin-RevId: 208666153 --- tensorflow/python/eager/function.py | 3 +++ tensorflow/python/eager/function_test.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index f87d88040f1..e7db5c15225 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -42,6 +42,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.training import distribute from tensorflow.python.util import compat from tensorflow.python.util import nest @@ -832,6 +833,8 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds, func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph()) with func_graph.as_default(), AutomaticControlDependencies() as a: + variable_scope.get_variable_scope().set_use_resource(True) + if signature is None: func_args = _get_defun_inputs_from_args(args) func_kwds = _get_defun_inputs_from_args(kwds) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 0488dc97521..7f28fc15e54 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -397,6 +397,18 @@ class FunctionTest(test.TestCase): compiled = function.defun(f) compiled() + @test_util.run_in_graph_and_eager_modes + def testDefunForcesResourceVariables(self): + + def variable_creator(): + return variables.Variable(0.0).read_value() + + defined = function.defun(variable_creator) + defined() # Create the variable. + self.assertEqual(len(defined.variables), 1) + self.assertIsInstance( + defined.variables[0], resource_variable_ops.ResourceVariable) + def testDefunDifferentiable(self): v = resource_variable_ops.ResourceVariable(1.0)