diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index ffac7d1eb36..d287ea2fcd4 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -352,6 +352,21 @@ class _DefinedFunction(object): if self._definition is not None or self._c_func is not None: return + # Copy variable collections (by reference) from the parent graph such that + # name based variable sharing (e.g. via tf.make_template) works between the + # func graph and parent graph. + variable_keys = [] + variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS) # pylint: disable=protected-access + variable_keys.append(vs._VARSTORE_KEY) # pylint: disable=protected-access + + collections_ref = {} + parent_collections_ref = ops.get_default_graph()._collections # pylint: disable=protected-access + for key in variable_keys: + if key not in parent_collections_ref: + parent_collections_ref[key] = collections_ref[key] = [] + else: + collections_ref[key] = parent_collections_ref[key] + temp_graph = func_graph_from_py_func( self._func, self._arg_names, @@ -359,6 +374,7 @@ class _DefinedFunction(object): self._func_name, self._capture_by_value, self._caller_device, + collections_ref=collections_ref, whitelisted_stateful_ops=self._whitelisted_stateful_ops, capture_resource_var_by_value=self._capture_resource_var_by_value) diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 21c6565ca97..57f50b888f5 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -47,6 +47,7 @@ from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -1753,13 +1754,64 @@ class VariableHoistingTest(test.TestCase): @test_util.run_deprecated_v1 def testBasic(self): - self._testSimpleModel(True) self._testSimpleModel(False) + self._testSimpleModel(True) @test_util.run_deprecated_v1 def testBasicResource(self): - self._testSimpleModel(True, use_resource=True) self._testSimpleModel(False, use_resource=True) + self._testSimpleModel(True, use_resource=True) + + +class TemplateTest(test.TestCase): + + @test_util.run_v1_only("make_template not supported in TF2") + def testBasic(self): + self.assertTemplateVariableSharing(use_resource=True, defun_first=False) + + @test_util.run_v1_only("make_template not supported in TF2") + def testBasicRef(self): + self.assertTemplateVariableSharing(use_resource=False, defun_first=False) + + @test_util.run_v1_only("make_template not supported in TF2") + def testBasicDefunFirst(self): + self.assertTemplateVariableSharing(use_resource=True, defun_first=True) + + @test_util.run_v1_only("make_template not supported in TF2") + def testBasicRefDefunFirst(self): + self.assertTemplateVariableSharing(use_resource=False, defun_first=True) + + def assertTemplateVariableSharing(self, use_resource, defun_first): + parameters = [] + + def MakeModel(x): + w = variable_scope.get_variable( + "w", (64, 64), + initializer=init_ops.random_uniform_initializer(seed=312), + use_resource=use_resource) + b = variable_scope.get_variable( + "b", (64), + initializer=init_ops.zeros_initializer(), + use_resource=use_resource) + parameters.extend((w, b)) + return math_ops.sigmoid(math_ops.matmul(x, w) + b) + + model = template.make_template("f", MakeModel, create_scope_now_=True) + + @function.Defun() + def ModelDefun(x): + return model(x) + + x = array_ops.placeholder(dtypes.float32) + if defun_first: + ModelDefun(x) + model(x) + else: + model(x) + ModelDefun(x) + w1, b1, w2, b2 = parameters # pylint: disable=unbalanced-tuple-unpacking + self.assertIs(w1, w2) + self.assertIs(b1, b2) class DevicePlacementTest(test.TestCase):