From c4f8aa9853a0ea91bb41d4a10340d08d57b9d664 Mon Sep 17 00:00:00 2001 From: Pankaj Kanwar Date: Wed, 6 Jan 2021 20:30:39 -0800 Subject: [PATCH] fixed the breakage. Re-submit plumbing of the enable_v2_behavior() call to C++ PiperOrigin-RevId: 350478865 Change-Id: Ib5baae7e822f646bbcd796f6760c9f7239f021af --- tensorflow/core/platform/BUILD | 31 ++++++ tensorflow/core/platform/enable_tf2_utils.cc | 49 ++++++++++ tensorflow/core/platform/enable_tf2_utils.h | 31 ++++++ .../core/platform/enable_tf2_utils_test.cc | 35 +++++++ tensorflow/python/BUILD | 11 ++- .../python/compat/disable_v2_behavior_test.py | 5 + tensorflow/python/framework/test_util.py | 7 +- tensorflow/python/framework/tf2_test.py | 97 ++++++++++--------- tensorflow/python/platform/BUILD | 16 +++ tensorflow/python/platform/enable_tf2.cc | 22 +++++ tensorflow/python/tf2.py | 16 +-- 11 files changed, 260 insertions(+), 60 deletions(-) create mode 100644 tensorflow/core/platform/enable_tf2_utils.cc create mode 100644 tensorflow/core/platform/enable_tf2_utils.h create mode 100644 tensorflow/core/platform/enable_tf2_utils_test.cc create mode 100644 tensorflow/python/platform/enable_tf2.cc diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index a680ea69bda..2ef10347d1e 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -981,6 +981,17 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "enable_tf2_utils", + srcs = ["enable_tf2_utils.cc"], + hdrs = ["enable_tf2_utils.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core/util:env_var", + ], + alwayslink = 1, +) + alias( name = "profile_utils_cpu_utils", actual = "//tensorflow/core/platform/profile_utils:profile_utils_cpu_utils", @@ -992,6 +1003,12 @@ filegroup( compatible_with = get_compatible_with_portable(), ) +filegroup( + name = "enable_tf2_hdr", + srcs = ["enable_tf2_utils.h"], + compatible_with = get_compatible_with_portable(), +) + tf_cc_tests( name = "low_level_library_tests", size = "small", @@ -1047,6 +1064,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "enable_tf2_utils_test", + size = "small", + srcs = [ + "enable_tf2_utils_test.cc", + ], + deps = [ + ":enable_tf2_utils", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/util:env_var", + ], +) + tf_cc_tests( name = "stacktrace_handler_test", size = "small", diff --git a/tensorflow/core/platform/enable_tf2_utils.cc b/tensorflow/core/platform/enable_tf2_utils.cc new file mode 100644 index 00000000000..ddcee7ae0c8 --- /dev/null +++ b/tensorflow/core/platform/enable_tf2_utils.cc @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/enable_tf2_utils.h" + +#include + +#include "tensorflow/core/util/env_var.h" + +namespace tensorflow { + +enum Enablement : uint8 { kFalse = 0, kTrue = 1, undefined = 2 }; + +// If this flag is set, we will use it as a signal to decide on whether to +// use the MLIR based TF-XLA bridge. +static std::atomic tf2_enabled{undefined}; + +// Determine whether or not the user has explicitly asked for tf2 execution. +// Will be used to determine whether to use the MLIR based bridge. +void set_tf2_execution(bool enabled) { + tf2_enabled = (enabled) ? Enablement::kTrue : Enablement::kFalse; +} + +bool tf2_execution_enabled() { + if (tf2_enabled == Enablement::undefined) { + static bool tf2_behavior_env_enabled = [] { + string tf2_env; + TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env)); + return tf2_env != "0"; + }(); + tf2_enabled = + (tf2_behavior_env_enabled) ? Enablement::kTrue : Enablement::kFalse; + } + return tf2_enabled; +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/enable_tf2_utils.h b/tensorflow/core/platform/enable_tf2_utils.h new file mode 100644 index 00000000000..49c611d7a24 --- /dev/null +++ b/tensorflow/core/platform/enable_tf2_utils.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TF_CORE_PLATFORM_TF2_UTILS_H_ +#define TF_CORE_PLATFORM_TF2_UTILS_H_ + +namespace tensorflow { + +// Sets the tf2 execution state. This can be used to indicate whether the user +// has explicitly asked for tf2 execution. +void set_tf2_execution(bool enabled); + +// Returns true or false depending on whether the user flag for tf2 execution +// has been set. The default is false. +bool tf2_execution_enabled(); + +} // namespace tensorflow + +#endif // TF_CORE_PLATFORM_TF2_UTILS_H_ diff --git a/tensorflow/core/platform/enable_tf2_utils_test.cc b/tensorflow/core/platform/enable_tf2_utils_test.cc new file mode 100644 index 00000000000..f50df9518fa --- /dev/null +++ b/tensorflow/core/platform/enable_tf2_utils_test.cc @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Testing TF2 enablement. + +#include "tensorflow/core/platform/enable_tf2_utils.h" + +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/env_var.h" + +namespace tensorflow { + +TEST(TF2EnabledTest, enabled_behavior) { + string tf2_env; + TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env)); + bool expected = (tf2_env != "0"); + EXPECT_EQ(tensorflow::tf2_execution_enabled(), expected); + tensorflow::set_tf2_execution(true); + EXPECT_TRUE(tensorflow::tf2_execution_enabled()); + tensorflow::set_tf2_execution(false); + EXPECT_FALSE(tensorflow::tf2_execution_enabled()); +} + +} // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 96e01db5069..0c971cc1026 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5266,7 +5266,10 @@ pywrap_tensorflow_macro( tf_additional_plugin_deps() + tf_additional_profiler_deps()) + if_xla_available([ "//tensorflow/compiler/aot:tfcompile_lib", - ]) + if_static(extra_deps = ["//tensorflow/core/platform:tensor_float_32_utils"]), + ]) + if_static(extra_deps = [ + "//tensorflow/core/platform:tensor_float_32_utils", + "//tensorflow/core/platform:enable_tf2_utils", + ]), ) # ** Targets for Windows build (start) ** @@ -6768,6 +6771,10 @@ py_library( name = "tf2", srcs = ["tf2.py"], srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/platform:_pywrap_tf2", + "//tensorflow/python/util:tf_export", + ], ) py_test( @@ -6779,6 +6786,8 @@ py_test( ":client_testlib", ":framework_combinations", ":tf2", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/data/kernel_tests:test_base", ], ) diff --git a/tensorflow/python/compat/disable_v2_behavior_test.py b/tensorflow/python/compat/disable_v2_behavior_test.py index 4b955d3f46d..5f45621e988 100644 --- a/tensorflow/python/compat/disable_v2_behavior_test.py +++ b/tensorflow/python/compat/disable_v2_behavior_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.compat import v2_compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.platform import _pywrap_tf2 from tensorflow.python.platform import test @@ -29,9 +30,13 @@ class DisableV2BehaviorTest(test.TestCase): def test_basic(self): t = constant_op.constant([1, 2, 3]) # creates a hidden context self.assertTrue(isinstance(t, ops.EagerTensor)) + t = _pywrap_tf2.is_enabled() + self.assertTrue(t) v2_compat.disable_v2_behavior() t = constant_op.constant([1, 2, 3]) self.assertFalse(isinstance(t, ops.EagerTensor)) + t = _pywrap_tf2.is_enabled() + self.assertFalse(t) if __name__ == '__main__': diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 1781a450360..d9e9ca088d1 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -2047,6 +2047,10 @@ class TensorFlowTestCase(googletest.TestCase): self._tempdir = None self._cached_session = None self._test_start_time = None + # This flag provides the ability to control whether the graph mode gets + # initialized for TF1 or not. Initializing for TF1, which is what was + # happening earlier, was preventing enablement of 'eager mode' in the test. + self._set_default_seed = True def setUp(self): super(TensorFlowTestCase, self).setUp() @@ -2061,7 +2065,8 @@ class TensorFlowTestCase(googletest.TestCase): # cleared first. ops._default_graph_stack.reset() # pylint: disable=protected-access ops.reset_default_graph() - random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED) + if self._set_default_seed: + random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED) # Reset summary writer in case another test used set_as_default() with their # summary writer. summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access diff --git a/tensorflow/python/framework/tf2_test.py b/tensorflow/python/framework/tf2_test.py index fe9dac4fe4b..1ded46d9ece 100644 --- a/tensorflow/python/framework/tf2_test.py +++ b/tensorflow/python/framework/tf2_test.py @@ -18,68 +18,73 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from absl.testing import parameterized from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import combinations +from tensorflow.python.platform import _pywrap_tf2 from tensorflow.python.platform import test -def set_environ(): - os.environ['TF2_BEHAVIOR'] = '1' - - -def unset_environ(): - os.environ['TF2_BEHAVIOR'] = '0' - - class EnablingTF2Behavior(test.TestCase, parameterized.TestCase): - def setUp(self): - super(EnablingTF2Behavior, self).setUp() - tf2._force_enable = None - if 'TF2_BEHAVIOR' in os.environ: - del os.environ['TF2_BEHAVIOR'] + def __init__(self, methodName): + super().__init__(methodName) + self._set_default_seed = False - actions = [tf2.enable, tf2.disable, set_environ, unset_environ] + @combinations.generate(test_base.v1_only_combinations()) + def test_tf1_enable_tf2_behaviour(self): + self.assertFalse(tf2.enabled()) + self.assertFalse(_pywrap_tf2.is_enabled()) - @combinations.generate( - combinations.combine( - action_0=actions, action_1=actions, - action_2=actions, action_3=actions)) - def test_scenarios(self, action_0, action_1, action_2, action_3): + v2_compat.enable_v2_behavior() + self.assertTrue(tf2.enabled()) + self.assertTrue(_pywrap_tf2.is_enabled()) - def state(action, enabled, disabled): - """Returns bool tuple (tf2_enabled, force_enabled, force_disabled).""" - if action is tf2.enable: - return True, True, False - elif action is tf2.disable: - return False, False, True - elif action is set_environ: - return not disabled, enabled, disabled - elif action is unset_environ: - return enabled, enabled, disabled - else: - raise ValueError('Unexpected action {}. {} are supported'.format( - action, EnablingTF2Behavior.actions)) + v2_compat.disable_v2_behavior() + self.assertFalse(tf2.enabled()) + self.assertFalse(_pywrap_tf2.is_enabled()) - action_0() - expected, enabled, disabled = state(action_0, False, False) - self.assertEqual(tf2.enabled(), expected) + @combinations.generate(test_base.v1_only_combinations()) + def test_tf1_disable_tf2_behaviour(self): + self.assertFalse(tf2.enabled()) + self.assertFalse(_pywrap_tf2.is_enabled()) - action_1() - expected, enabled, disabled = state(action_1, enabled, disabled) - self.assertEqual(tf2.enabled(), expected) + v2_compat.disable_v2_behavior() + self.assertFalse(tf2.enabled()) + self.assertFalse(_pywrap_tf2.is_enabled()) - action_2() - expected, enabled, disabled = state(action_2, enabled, disabled) - self.assertEqual(tf2.enabled(), expected) + v2_compat.enable_v2_behavior() + self.assertTrue(tf2.enabled()) + self.assertTrue(_pywrap_tf2.is_enabled()) - action_3() - expected, enabled, disabled = state(action_3, enabled, disabled) - self.assertEqual(tf2.enabled(), expected) + @combinations.generate(test_base.v2_only_combinations()) + def test_tf2_enable_tf2_behaviour(self): + self.assertTrue(tf2.enabled()) + self.assertTrue(_pywrap_tf2.is_enabled()) + + v2_compat.enable_v2_behavior() + self.assertTrue(tf2.enabled()) + self.assertTrue(_pywrap_tf2.is_enabled()) + + v2_compat.disable_v2_behavior() + self.assertFalse(tf2.enabled()) + self.assertFalse(_pywrap_tf2.is_enabled()) + + @combinations.generate(test_base.v2_only_combinations()) + def test_tf2_disable_tf2_behaviour(self): + self.assertTrue(tf2.enabled()) + self.assertTrue(_pywrap_tf2.is_enabled()) + + v2_compat.disable_v2_behavior() + self.assertFalse(tf2.enabled()) + self.assertFalse(_pywrap_tf2.is_enabled()) + + v2_compat.enable_v2_behavior() + self.assertTrue(tf2.enabled()) + self.assertTrue(_pywrap_tf2.is_enabled()) if __name__ == '__main__': diff --git a/tensorflow/python/platform/BUILD b/tensorflow/python/platform/BUILD index 024fcbc5c76..85484f591a0 100644 --- a/tensorflow/python/platform/BUILD +++ b/tensorflow/python/platform/BUILD @@ -7,6 +7,9 @@ load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additiona # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "pybind_extension") + # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_py_test") @@ -230,3 +233,16 @@ tf_py_test( ":platform", ], ) + +pybind_extension( + name = "_pywrap_tf2", + srcs = ["enable_tf2.cc"], + hdrs = ["//tensorflow/core/platform:enable_tf2_hdr"], + module_name = "_pywrap_tf2", + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:enable_tf2_utils", + "@pybind11", + ], +) diff --git a/tensorflow/python/platform/enable_tf2.cc b/tensorflow/python/platform/enable_tf2.cc new file mode 100644 index 00000000000..4411dfb2614 --- /dev/null +++ b/tensorflow/python/platform/enable_tf2.cc @@ -0,0 +1,22 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "pybind11/pybind11.h" +#include "tensorflow/core/platform/enable_tf2_utils.h" + +PYBIND11_MODULE(_pywrap_tf2, m) { + m.def("enable", &tensorflow::set_tf2_execution); + m.def("is_enabled", &tensorflow::tf2_execution_enabled); +} diff --git a/tensorflow/python/tf2.py b/tensorflow/python/tf2.py index 11d7b6f1aab..88d8a472aeb 100644 --- a/tensorflow/python/tf2.py +++ b/tensorflow/python/tf2.py @@ -22,29 +22,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - +from tensorflow.python.platform import _pywrap_tf2 from tensorflow.python.util.tf_export import tf_export -_force_enable = None - - def enable(): # Enables v2 behaviors. - global _force_enable - _force_enable = True + _pywrap_tf2.enable(True) def disable(): # Disables v2 behaviors. - global _force_enable - _force_enable = False + _pywrap_tf2.enable(False) @tf_export("__internal__.tf2.enabled", v1=[]) def enabled(): # Returns True iff TensorFlow 2.0 behavior should be enabled. - if _force_enable is None: - return os.getenv("TF2_BEHAVIOR", "0") != "0" - return _force_enable + return _pywrap_tf2.is_enabled()