fixed the breakage. Re-submit plumbing of the enable_v2_behavior() call to C++

PiperOrigin-RevId: 350478865
Change-Id: Ib5baae7e822f646bbcd796f6760c9f7239f021af
This commit is contained in:
Pankaj Kanwar 2021-01-06 20:30:39 -08:00 committed by TensorFlower Gardener
parent 5c1bc1cc40
commit c4f8aa9853
11 changed files with 260 additions and 60 deletions

View File

@ -981,6 +981,17 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "enable_tf2_utils",
srcs = ["enable_tf2_utils.cc"],
hdrs = ["enable_tf2_utils.h"],
copts = tf_copts(),
deps = [
"//tensorflow/core/util:env_var",
],
alwayslink = 1,
)
alias(
name = "profile_utils_cpu_utils",
actual = "//tensorflow/core/platform/profile_utils:profile_utils_cpu_utils",
@ -992,6 +1003,12 @@ filegroup(
compatible_with = get_compatible_with_portable(),
)
filegroup(
name = "enable_tf2_hdr",
srcs = ["enable_tf2_utils.h"],
compatible_with = get_compatible_with_portable(),
)
tf_cc_tests(
name = "low_level_library_tests",
size = "small",
@ -1047,6 +1064,20 @@ tf_cc_test(
],
)
tf_cc_test(
name = "enable_tf2_utils_test",
size = "small",
srcs = [
"enable_tf2_utils_test.cc",
],
deps = [
":enable_tf2_utils",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/util:env_var",
],
)
tf_cc_tests(
name = "stacktrace_handler_test",
size = "small",

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/enable_tf2_utils.h"
#include <atomic>
#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
enum Enablement : uint8 { kFalse = 0, kTrue = 1, undefined = 2 };
// If this flag is set, we will use it as a signal to decide on whether to
// use the MLIR based TF-XLA bridge.
static std::atomic<Enablement> tf2_enabled{undefined};
// Determine whether or not the user has explicitly asked for tf2 execution.
// Will be used to determine whether to use the MLIR based bridge.
void set_tf2_execution(bool enabled) {
tf2_enabled = (enabled) ? Enablement::kTrue : Enablement::kFalse;
}
bool tf2_execution_enabled() {
if (tf2_enabled == Enablement::undefined) {
static bool tf2_behavior_env_enabled = [] {
string tf2_env;
TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
return tf2_env != "0";
}();
tf2_enabled =
(tf2_behavior_env_enabled) ? Enablement::kTrue : Enablement::kFalse;
}
return tf2_enabled;
}
} // namespace tensorflow

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TF_CORE_PLATFORM_TF2_UTILS_H_
#define TF_CORE_PLATFORM_TF2_UTILS_H_
namespace tensorflow {
// Sets the tf2 execution state. This can be used to indicate whether the user
// has explicitly asked for tf2 execution.
void set_tf2_execution(bool enabled);
// Returns true or false depending on whether the user flag for tf2 execution
// has been set. The default is false.
bool tf2_execution_enabled();
} // namespace tensorflow
#endif // TF_CORE_PLATFORM_TF2_UTILS_H_

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Testing TF2 enablement.
#include "tensorflow/core/platform/enable_tf2_utils.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
TEST(TF2EnabledTest, enabled_behavior) {
string tf2_env;
TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
bool expected = (tf2_env != "0");
EXPECT_EQ(tensorflow::tf2_execution_enabled(), expected);
tensorflow::set_tf2_execution(true);
EXPECT_TRUE(tensorflow::tf2_execution_enabled());
tensorflow::set_tf2_execution(false);
EXPECT_FALSE(tensorflow::tf2_execution_enabled());
}
} // namespace tensorflow

View File

@ -5266,7 +5266,10 @@ pywrap_tensorflow_macro(
tf_additional_plugin_deps() +
tf_additional_profiler_deps()) + if_xla_available([
"//tensorflow/compiler/aot:tfcompile_lib",
]) + if_static(extra_deps = ["//tensorflow/core/platform:tensor_float_32_utils"]),
]) + if_static(extra_deps = [
"//tensorflow/core/platform:tensor_float_32_utils",
"//tensorflow/core/platform:enable_tf2_utils",
]),
)
# ** Targets for Windows build (start) **
@ -6768,6 +6771,10 @@ py_library(
name = "tf2",
srcs = ["tf2.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python/platform:_pywrap_tf2",
"//tensorflow/python/util:tf_export",
],
)
py_test(
@ -6779,6 +6786,8 @@ py_test(
":client_testlib",
":framework_combinations",
":tf2",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/data/kernel_tests:test_base",
],
)

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.compat import v2_compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import _pywrap_tf2
from tensorflow.python.platform import test
@ -29,9 +30,13 @@ class DisableV2BehaviorTest(test.TestCase):
def test_basic(self):
t = constant_op.constant([1, 2, 3]) # creates a hidden context
self.assertTrue(isinstance(t, ops.EagerTensor))
t = _pywrap_tf2.is_enabled()
self.assertTrue(t)
v2_compat.disable_v2_behavior()
t = constant_op.constant([1, 2, 3])
self.assertFalse(isinstance(t, ops.EagerTensor))
t = _pywrap_tf2.is_enabled()
self.assertFalse(t)
if __name__ == '__main__':

View File

@ -2047,6 +2047,10 @@ class TensorFlowTestCase(googletest.TestCase):
self._tempdir = None
self._cached_session = None
self._test_start_time = None
# This flag provides the ability to control whether the graph mode gets
# initialized for TF1 or not. Initializing for TF1, which is what was
# happening earlier, was preventing enablement of 'eager mode' in the test.
self._set_default_seed = True
def setUp(self):
super(TensorFlowTestCase, self).setUp()
@ -2061,7 +2065,8 @@ class TensorFlowTestCase(googletest.TestCase):
# cleared first.
ops._default_graph_stack.reset() # pylint: disable=protected-access
ops.reset_default_graph()
random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
if self._set_default_seed:
random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
# Reset summary writer in case another test used set_as_default() with their
# summary writer.
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access

View File

@ -18,68 +18,73 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import combinations
from tensorflow.python.platform import _pywrap_tf2
from tensorflow.python.platform import test
def set_environ():
os.environ['TF2_BEHAVIOR'] = '1'
def unset_environ():
os.environ['TF2_BEHAVIOR'] = '0'
class EnablingTF2Behavior(test.TestCase, parameterized.TestCase):
def setUp(self):
super(EnablingTF2Behavior, self).setUp()
tf2._force_enable = None
if 'TF2_BEHAVIOR' in os.environ:
del os.environ['TF2_BEHAVIOR']
def __init__(self, methodName):
super().__init__(methodName)
self._set_default_seed = False
actions = [tf2.enable, tf2.disable, set_environ, unset_environ]
@combinations.generate(test_base.v1_only_combinations())
def test_tf1_enable_tf2_behaviour(self):
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
@combinations.generate(
combinations.combine(
action_0=actions, action_1=actions,
action_2=actions, action_3=actions))
def test_scenarios(self, action_0, action_1, action_2, action_3):
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
def state(action, enabled, disabled):
"""Returns bool tuple (tf2_enabled, force_enabled, force_disabled)."""
if action is tf2.enable:
return True, True, False
elif action is tf2.disable:
return False, False, True
elif action is set_environ:
return not disabled, enabled, disabled
elif action is unset_environ:
return enabled, enabled, disabled
else:
raise ValueError('Unexpected action {}. {} are supported'.format(
action, EnablingTF2Behavior.actions))
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
action_0()
expected, enabled, disabled = state(action_0, False, False)
self.assertEqual(tf2.enabled(), expected)
@combinations.generate(test_base.v1_only_combinations())
def test_tf1_disable_tf2_behaviour(self):
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
action_1()
expected, enabled, disabled = state(action_1, enabled, disabled)
self.assertEqual(tf2.enabled(), expected)
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
action_2()
expected, enabled, disabled = state(action_2, enabled, disabled)
self.assertEqual(tf2.enabled(), expected)
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
action_3()
expected, enabled, disabled = state(action_3, enabled, disabled)
self.assertEqual(tf2.enabled(), expected)
@combinations.generate(test_base.v2_only_combinations())
def test_tf2_enable_tf2_behaviour(self):
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
@combinations.generate(test_base.v2_only_combinations())
def test_tf2_disable_tf2_behaviour(self):
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
if __name__ == '__main__':

View File

@ -7,6 +7,9 @@ load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additiona
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "pybind_extension")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_py_test")
@ -230,3 +233,16 @@ tf_py_test(
":platform",
],
)
pybind_extension(
name = "_pywrap_tf2",
srcs = ["enable_tf2.cc"],
hdrs = ["//tensorflow/core/platform:enable_tf2_hdr"],
module_name = "_pywrap_tf2",
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:enable_tf2_utils",
"@pybind11",
],
)

View File

@ -0,0 +1,22 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "pybind11/pybind11.h"
#include "tensorflow/core/platform/enable_tf2_utils.h"
PYBIND11_MODULE(_pywrap_tf2, m) {
m.def("enable", &tensorflow::set_tf2_execution);
m.def("is_enabled", &tensorflow::tf2_execution_enabled);
}

View File

@ -22,29 +22,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.python.platform import _pywrap_tf2
from tensorflow.python.util.tf_export import tf_export
_force_enable = None
def enable():
# Enables v2 behaviors.
global _force_enable
_force_enable = True
_pywrap_tf2.enable(True)
def disable():
# Disables v2 behaviors.
global _force_enable
_force_enable = False
_pywrap_tf2.enable(False)
@tf_export("__internal__.tf2.enabled", v1=[])
def enabled():
# Returns True iff TensorFlow 2.0 behavior should be enabled.
if _force_enable is None:
return os.getenv("TF2_BEHAVIOR", "0") != "0"
return _force_enable
return _pywrap_tf2.is_enabled()