diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD index 9f2ce8c676e..f9a57b9c05b 100644 --- a/tensorflow/python/compat/BUILD +++ b/tensorflow/python/compat/BUILD @@ -24,3 +24,14 @@ tf_py_test( "//tensorflow/python:client_testlib", ], ) + +tf_py_test( + name = "disable_v2_behavior_test", + size = "small", + srcs = ["disable_v2_behavior_test.py"], + additional_deps = [ + ":compat", + "//tensorflow/python:framework", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/python/compat/disable_v2_behavior_test.py b/tensorflow/python/compat/disable_v2_behavior_test.py new file mode 100644 index 00000000000..221691b4830 --- /dev/null +++ b/tensorflow/python/compat/disable_v2_behavior_test.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for forward and backwards compatibility utilties.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.compat import compat +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +class DisableV2BehaviorTest(test.TestCase): + + def test_basic(self): + t = constant_op.constant([1, 2, 3]) # creates a hidden context + self.assertTrue(isinstance(t, ops.EagerTensor)) + compat.disable_v2_behavior() + t = constant_op.constant([1, 2, 3]) + self.assertFalse(isinstance(t, ops.EagerTensor)) + + +if __name__ == '__main__': + compat.enable_v2_behavior() + test.main() diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index e7a9af48662..5dc8e418a29 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -5474,6 +5474,9 @@ def disable_eager_execution(): projects from TensorFlow 1.x to 2.x. """ context.default_execution_mode = context.GRAPH_MODE + c = context.context_safe() + if c is not None: + c._eager_context.is_eager = False # pylint: disable=protected-access def enable_eager_execution_internal(config=None,