disable_eager_execution actually disables eager execution

This makes it somewhat less safe to call (as there can be eager tensorso out
there which will mess up graph building) but makes it do what it says on the
label.

PiperOrigin-RevId: 226947417
This commit is contained in:
Alexandre Passos 2018-12-26 13:12:33 -08:00 committed by TensorFlower Gardener
parent 98bbee7afe
commit 83cb1f1c5e
3 changed files with 53 additions and 0 deletions

View File

@ -24,3 +24,14 @@ tf_py_test(
"//tensorflow/python:client_testlib",
],
)
tf_py_test(
name = "disable_v2_behavior_test",
size = "small",
srcs = ["disable_v2_behavior_test.py"],
additional_deps = [
":compat",
"//tensorflow/python:framework",
"//tensorflow/python:client_testlib",
],
)

View File

@ -0,0 +1,39 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for forward and backwards compatibility utilties."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
class DisableV2BehaviorTest(test.TestCase):
def test_basic(self):
t = constant_op.constant([1, 2, 3]) # creates a hidden context
self.assertTrue(isinstance(t, ops.EagerTensor))
compat.disable_v2_behavior()
t = constant_op.constant([1, 2, 3])
self.assertFalse(isinstance(t, ops.EagerTensor))
if __name__ == '__main__':
compat.enable_v2_behavior()
test.main()

View File

@ -5474,6 +5474,9 @@ def disable_eager_execution():
projects from TensorFlow 1.x to 2.x.
"""
context.default_execution_mode = context.GRAPH_MODE
c = context.context_safe()
if c is not None:
c._eager_context.is_eager = False # pylint: disable=protected-access
def enable_eager_execution_internal(config=None,