Support template variable reuse inside Defun
s.
PiperOrigin-RevId: 247156776
This commit is contained in:
parent
db354356d5
commit
fbc2e8bc3d
@ -352,6 +352,21 @@ class _DefinedFunction(object):
|
||||
if self._definition is not None or self._c_func is not None:
|
||||
return
|
||||
|
||||
# Copy variable collections (by reference) from the parent graph such that
|
||||
# name based variable sharing (e.g. via tf.make_template) works between the
|
||||
# func graph and parent graph.
|
||||
variable_keys = []
|
||||
variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS) # pylint: disable=protected-access
|
||||
variable_keys.append(vs._VARSTORE_KEY) # pylint: disable=protected-access
|
||||
|
||||
collections_ref = {}
|
||||
parent_collections_ref = ops.get_default_graph()._collections # pylint: disable=protected-access
|
||||
for key in variable_keys:
|
||||
if key not in parent_collections_ref:
|
||||
parent_collections_ref[key] = collections_ref[key] = []
|
||||
else:
|
||||
collections_ref[key] = parent_collections_ref[key]
|
||||
|
||||
temp_graph = func_graph_from_py_func(
|
||||
self._func,
|
||||
self._arg_names,
|
||||
@ -359,6 +374,7 @@ class _DefinedFunction(object):
|
||||
self._func_name,
|
||||
self._capture_by_value,
|
||||
self._caller_device,
|
||||
collections_ref=collections_ref,
|
||||
whitelisted_stateful_ops=self._whitelisted_stateful_ops,
|
||||
capture_resource_var_by_value=self._capture_resource_var_by_value)
|
||||
|
||||
|
@ -47,6 +47,7 @@ from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import template
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -1753,13 +1754,64 @@ class VariableHoistingTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasic(self):
|
||||
self._testSimpleModel(True)
|
||||
self._testSimpleModel(False)
|
||||
self._testSimpleModel(True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasicResource(self):
|
||||
self._testSimpleModel(True, use_resource=True)
|
||||
self._testSimpleModel(False, use_resource=True)
|
||||
self._testSimpleModel(True, use_resource=True)
|
||||
|
||||
|
||||
class TemplateTest(test.TestCase):
|
||||
|
||||
@test_util.run_v1_only("make_template not supported in TF2")
|
||||
def testBasic(self):
|
||||
self.assertTemplateVariableSharing(use_resource=True, defun_first=False)
|
||||
|
||||
@test_util.run_v1_only("make_template not supported in TF2")
|
||||
def testBasicRef(self):
|
||||
self.assertTemplateVariableSharing(use_resource=False, defun_first=False)
|
||||
|
||||
@test_util.run_v1_only("make_template not supported in TF2")
|
||||
def testBasicDefunFirst(self):
|
||||
self.assertTemplateVariableSharing(use_resource=True, defun_first=True)
|
||||
|
||||
@test_util.run_v1_only("make_template not supported in TF2")
|
||||
def testBasicRefDefunFirst(self):
|
||||
self.assertTemplateVariableSharing(use_resource=False, defun_first=True)
|
||||
|
||||
def assertTemplateVariableSharing(self, use_resource, defun_first):
|
||||
parameters = []
|
||||
|
||||
def MakeModel(x):
|
||||
w = variable_scope.get_variable(
|
||||
"w", (64, 64),
|
||||
initializer=init_ops.random_uniform_initializer(seed=312),
|
||||
use_resource=use_resource)
|
||||
b = variable_scope.get_variable(
|
||||
"b", (64),
|
||||
initializer=init_ops.zeros_initializer(),
|
||||
use_resource=use_resource)
|
||||
parameters.extend((w, b))
|
||||
return math_ops.sigmoid(math_ops.matmul(x, w) + b)
|
||||
|
||||
model = template.make_template("f", MakeModel, create_scope_now_=True)
|
||||
|
||||
@function.Defun()
|
||||
def ModelDefun(x):
|
||||
return model(x)
|
||||
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
if defun_first:
|
||||
ModelDefun(x)
|
||||
model(x)
|
||||
else:
|
||||
model(x)
|
||||
ModelDefun(x)
|
||||
w1, b1, w2, b2 = parameters # pylint: disable=unbalanced-tuple-unpacking
|
||||
self.assertIs(w1, w2)
|
||||
self.assertIs(b1, b2)
|
||||
|
||||
|
||||
class DevicePlacementTest(test.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user