Do not change the default graph in variable_scope when building a function.
When building a TensorFlow function, we need precise control over the default graph. This change ensures that, when a function is being built, variable_scope preserves the default graph. PiperOrigin-RevId: 175983226
This commit is contained in:
parent
aa4162ac9f
commit
de8453ff5d
@ -241,6 +241,7 @@ py_test(
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/eager:function",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
)
|
||||
|
||||
@ -20,6 +20,7 @@ import gc
|
||||
|
||||
from tensorflow.contrib.eager.python import network
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors_impl
|
||||
@ -87,6 +88,23 @@ class NetworkTest(test.TestCase):
|
||||
result = net(constant_op.constant([[2.0]]))
|
||||
self.assertEqual(34.0, self.evaluate(result))
|
||||
|
||||
# TODO(akshayka): This test should be changed once an API for compiling
|
||||
# `call` into a defun is implemented.
|
||||
def testReplacingNetworkCallWithDefun(self):
|
||||
net = MyNetwork(name="abcd")
|
||||
x = constant_op.constant([[2.0]])
|
||||
net(x) # Force variables to be created.
|
||||
self.evaluate(net.trainable_variables[0].assign([[17.0]]))
|
||||
|
||||
net.call = function.defun(net.call)
|
||||
result = net(x) # Build and execute the TensorFlow function
|
||||
self.assertEqual(34.0, self.evaluate(result))
|
||||
|
||||
# Force the creation of another TensorFlow function by changing input shape
|
||||
y = constant_op.constant([[1.0], [2.0]])
|
||||
result = net(y)
|
||||
self.assertAllEqual([[17.0], [34.0]], self.evaluate(result))
|
||||
|
||||
# TODO(allenl): This test creates garbage in some Python versions
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testNetworkSaveRestoreAlreadyBuilt(self):
|
||||
|
||||
@ -1828,7 +1828,13 @@ class variable_scope(object): # pylint: disable=invalid-name
|
||||
self._current_name_scope = None
|
||||
|
||||
def __enter__(self):
|
||||
if self._in_graph_mode:
|
||||
# If the default graph is building a function, then we should not replace it
|
||||
# with the cached graph.
|
||||
if ops.get_default_graph().building_function:
|
||||
self._building_function = True
|
||||
else:
|
||||
self._building_function = False
|
||||
if self._in_graph_mode and not self._building_function:
|
||||
self._graph_context_manager = self._graph.as_default()
|
||||
self._graph_context_manager.__enter__()
|
||||
if self._cached_pure_variable_scope is not None:
|
||||
@ -1907,7 +1913,7 @@ class variable_scope(object): # pylint: disable=invalid-name
|
||||
type_arg, value_arg, traceback_arg)
|
||||
if self._current_name_scope:
|
||||
self._current_name_scope.__exit__(type_arg, value_arg, traceback_arg)
|
||||
if self._in_graph_mode:
|
||||
if self._in_graph_mode and not self._building_function:
|
||||
self._graph_context_manager.__exit__(type_arg, value_arg, traceback_arg)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user