Do not change the default graph in variable_scope when building a function.

When building a TensorFlow function, we need precise control over the default
graph.  This change ensures that, when a function is being built,
variable_scope preserves the default graph.

PiperOrigin-RevId: 175983226
This commit is contained in:
Akshay Agrawal 2017-11-16 10:31:07 -08:00 committed by TensorFlower Gardener
parent aa4162ac9f
commit de8453ff5d
3 changed files with 27 additions and 2 deletions

View File

@ -241,6 +241,7 @@ py_test(
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:function",
"//tensorflow/python/eager:test",
],
)

View File

@ -20,6 +20,7 @@ import gc
from tensorflow.contrib.eager.python import network
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
@ -87,6 +88,23 @@ class NetworkTest(test.TestCase):
result = net(constant_op.constant([[2.0]]))
self.assertEqual(34.0, self.evaluate(result))
# TODO(akshayka): This test should be changed once an API for compiling
# `call` into a defun is implemented.
def testReplacingNetworkCallWithDefun(self):
net = MyNetwork(name="abcd")
x = constant_op.constant([[2.0]])
net(x) # Force variables to be created.
self.evaluate(net.trainable_variables[0].assign([[17.0]]))
net.call = function.defun(net.call)
result = net(x) # Build and execute the TensorFlow function
self.assertEqual(34.0, self.evaluate(result))
# Force the creation of another TensorFlow function by changing input shape
y = constant_op.constant([[1.0], [2.0]])
result = net(y)
self.assertAllEqual([[17.0], [34.0]], self.evaluate(result))
# TODO(allenl): This test creates garbage in some Python versions
@test_util.run_in_graph_and_eager_modes()
def testNetworkSaveRestoreAlreadyBuilt(self):

View File

@ -1828,7 +1828,13 @@ class variable_scope(object): # pylint: disable=invalid-name
self._current_name_scope = None
def __enter__(self):
if self._in_graph_mode:
# If the default graph is building a function, then we should not replace it
# with the cached graph.
if ops.get_default_graph().building_function:
self._building_function = True
else:
self._building_function = False
if self._in_graph_mode and not self._building_function:
self._graph_context_manager = self._graph.as_default()
self._graph_context_manager.__enter__()
if self._cached_pure_variable_scope is not None:
@ -1907,7 +1913,7 @@ class variable_scope(object): # pylint: disable=invalid-name
type_arg, value_arg, traceback_arg)
if self._current_name_scope:
self._current_name_scope.__exit__(type_arg, value_arg, traceback_arg)
if self._in_graph_mode:
if self._in_graph_mode and not self._building_function:
self._graph_context_manager.__exit__(type_arg, value_arg, traceback_arg)