Support template variable reuse inside Defuns.

PiperOrigin-RevId: 247156776
This commit is contained in:
Tom Hennigan 2019-05-07 23:24:05 -07:00 committed by TensorFlower Gardener
parent db354356d5
commit fbc2e8bc3d
2 changed files with 70 additions and 2 deletions

View File

@ -352,6 +352,21 @@ class _DefinedFunction(object):
if self._definition is not None or self._c_func is not None:
return
# Copy variable collections (by reference) from the parent graph such that
# name based variable sharing (e.g. via tf.make_template) works between the
# func graph and parent graph.
variable_keys = []
variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS) # pylint: disable=protected-access
variable_keys.append(vs._VARSTORE_KEY) # pylint: disable=protected-access
collections_ref = {}
parent_collections_ref = ops.get_default_graph()._collections # pylint: disable=protected-access
for key in variable_keys:
if key not in parent_collections_ref:
parent_collections_ref[key] = collections_ref[key] = []
else:
collections_ref[key] = parent_collections_ref[key]
temp_graph = func_graph_from_py_func(
self._func,
self._arg_names,
@ -359,6 +374,7 @@ class _DefinedFunction(object):
self._func_name,
self._capture_by_value,
self._caller_device,
collections_ref=collections_ref,
whitelisted_stateful_ops=self._whitelisted_stateful_ops,
capture_resource_var_by_value=self._capture_resource_var_by_value)

View File

@ -47,6 +47,7 @@ from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -1753,13 +1754,64 @@ class VariableHoistingTest(test.TestCase):
@test_util.run_deprecated_v1
def testBasic(self):
self._testSimpleModel(True)
self._testSimpleModel(False)
self._testSimpleModel(True)
@test_util.run_deprecated_v1
def testBasicResource(self):
self._testSimpleModel(True, use_resource=True)
self._testSimpleModel(False, use_resource=True)
self._testSimpleModel(True, use_resource=True)
class TemplateTest(test.TestCase):
@test_util.run_v1_only("make_template not supported in TF2")
def testBasic(self):
self.assertTemplateVariableSharing(use_resource=True, defun_first=False)
@test_util.run_v1_only("make_template not supported in TF2")
def testBasicRef(self):
self.assertTemplateVariableSharing(use_resource=False, defun_first=False)
@test_util.run_v1_only("make_template not supported in TF2")
def testBasicDefunFirst(self):
self.assertTemplateVariableSharing(use_resource=True, defun_first=True)
@test_util.run_v1_only("make_template not supported in TF2")
def testBasicRefDefunFirst(self):
self.assertTemplateVariableSharing(use_resource=False, defun_first=True)
def assertTemplateVariableSharing(self, use_resource, defun_first):
parameters = []
def MakeModel(x):
w = variable_scope.get_variable(
"w", (64, 64),
initializer=init_ops.random_uniform_initializer(seed=312),
use_resource=use_resource)
b = variable_scope.get_variable(
"b", (64),
initializer=init_ops.zeros_initializer(),
use_resource=use_resource)
parameters.extend((w, b))
return math_ops.sigmoid(math_ops.matmul(x, w) + b)
model = template.make_template("f", MakeModel, create_scope_now_=True)
@function.Defun()
def ModelDefun(x):
return model(x)
x = array_ops.placeholder(dtypes.float32)
if defun_first:
ModelDefun(x)
model(x)
else:
model(x)
ModelDefun(x)
w1, b1, w2, b2 = parameters # pylint: disable=unbalanced-tuple-unpacking
self.assertIs(w1, w2)
self.assertIs(b1, b2)
class DevicePlacementTest(test.TestCase):