fixed the breakage. Re-submit plumbing of the enable_v2_behavior() call to C++
PiperOrigin-RevId: 350478865 Change-Id: Ib5baae7e822f646bbcd796f6760c9f7239f021af
This commit is contained in:
parent
5c1bc1cc40
commit
c4f8aa9853
@ -981,6 +981,17 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "enable_tf2_utils",
|
||||
srcs = ["enable_tf2_utils.cc"],
|
||||
hdrs = ["enable_tf2_utils.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"//tensorflow/core/util:env_var",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "profile_utils_cpu_utils",
|
||||
actual = "//tensorflow/core/platform/profile_utils:profile_utils_cpu_utils",
|
||||
@ -992,6 +1003,12 @@ filegroup(
|
||||
compatible_with = get_compatible_with_portable(),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "enable_tf2_hdr",
|
||||
srcs = ["enable_tf2_utils.h"],
|
||||
compatible_with = get_compatible_with_portable(),
|
||||
)
|
||||
|
||||
tf_cc_tests(
|
||||
name = "low_level_library_tests",
|
||||
size = "small",
|
||||
@ -1047,6 +1064,20 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "enable_tf2_utils_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"enable_tf2_utils_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":enable_tf2_utils",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/util:env_var",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_tests(
|
||||
name = "stacktrace_handler_test",
|
||||
size = "small",
|
||||
|
49
tensorflow/core/platform/enable_tf2_utils.cc
Normal file
49
tensorflow/core/platform/enable_tf2_utils.cc
Normal file
@ -0,0 +1,49 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/enable_tf2_utils.h"
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
enum Enablement : uint8 { kFalse = 0, kTrue = 1, undefined = 2 };
|
||||
|
||||
// If this flag is set, we will use it as a signal to decide on whether to
|
||||
// use the MLIR based TF-XLA bridge.
|
||||
static std::atomic<Enablement> tf2_enabled{undefined};
|
||||
|
||||
// Determine whether or not the user has explicitly asked for tf2 execution.
|
||||
// Will be used to determine whether to use the MLIR based bridge.
|
||||
void set_tf2_execution(bool enabled) {
|
||||
tf2_enabled = (enabled) ? Enablement::kTrue : Enablement::kFalse;
|
||||
}
|
||||
|
||||
bool tf2_execution_enabled() {
|
||||
if (tf2_enabled == Enablement::undefined) {
|
||||
static bool tf2_behavior_env_enabled = [] {
|
||||
string tf2_env;
|
||||
TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
|
||||
return tf2_env != "0";
|
||||
}();
|
||||
tf2_enabled =
|
||||
(tf2_behavior_env_enabled) ? Enablement::kTrue : Enablement::kFalse;
|
||||
}
|
||||
return tf2_enabled;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
31
tensorflow/core/platform/enable_tf2_utils.h
Normal file
31
tensorflow/core/platform/enable_tf2_utils.h
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TF_CORE_PLATFORM_TF2_UTILS_H_
|
||||
#define TF_CORE_PLATFORM_TF2_UTILS_H_
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Sets the tf2 execution state. This can be used to indicate whether the user
|
||||
// has explicitly asked for tf2 execution.
|
||||
void set_tf2_execution(bool enabled);
|
||||
|
||||
// Returns true or false depending on whether the user flag for tf2 execution
|
||||
// has been set. The default is false.
|
||||
bool tf2_execution_enabled();
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TF_CORE_PLATFORM_TF2_UTILS_H_
|
35
tensorflow/core/platform/enable_tf2_utils_test.cc
Normal file
35
tensorflow/core/platform/enable_tf2_utils_test.cc
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Testing TF2 enablement.
|
||||
|
||||
#include "tensorflow/core/platform/enable_tf2_utils.h"
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TEST(TF2EnabledTest, enabled_behavior) {
|
||||
string tf2_env;
|
||||
TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
|
||||
bool expected = (tf2_env != "0");
|
||||
EXPECT_EQ(tensorflow::tf2_execution_enabled(), expected);
|
||||
tensorflow::set_tf2_execution(true);
|
||||
EXPECT_TRUE(tensorflow::tf2_execution_enabled());
|
||||
tensorflow::set_tf2_execution(false);
|
||||
EXPECT_FALSE(tensorflow::tf2_execution_enabled());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -5266,7 +5266,10 @@ pywrap_tensorflow_macro(
|
||||
tf_additional_plugin_deps() +
|
||||
tf_additional_profiler_deps()) + if_xla_available([
|
||||
"//tensorflow/compiler/aot:tfcompile_lib",
|
||||
]) + if_static(extra_deps = ["//tensorflow/core/platform:tensor_float_32_utils"]),
|
||||
]) + if_static(extra_deps = [
|
||||
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||
"//tensorflow/core/platform:enable_tf2_utils",
|
||||
]),
|
||||
)
|
||||
|
||||
# ** Targets for Windows build (start) **
|
||||
@ -6768,6 +6771,10 @@ py_library(
|
||||
name = "tf2",
|
||||
srcs = ["tf2.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python/platform:_pywrap_tf2",
|
||||
"//tensorflow/python/util:tf_export",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
@ -6779,6 +6786,8 @@ py_test(
|
||||
":client_testlib",
|
||||
":framework_combinations",
|
||||
":tf2",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import _pywrap_tf2
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -29,9 +30,13 @@ class DisableV2BehaviorTest(test.TestCase):
|
||||
def test_basic(self):
|
||||
t = constant_op.constant([1, 2, 3]) # creates a hidden context
|
||||
self.assertTrue(isinstance(t, ops.EagerTensor))
|
||||
t = _pywrap_tf2.is_enabled()
|
||||
self.assertTrue(t)
|
||||
v2_compat.disable_v2_behavior()
|
||||
t = constant_op.constant([1, 2, 3])
|
||||
self.assertFalse(isinstance(t, ops.EagerTensor))
|
||||
t = _pywrap_tf2.is_enabled()
|
||||
self.assertFalse(t)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -2047,6 +2047,10 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
self._tempdir = None
|
||||
self._cached_session = None
|
||||
self._test_start_time = None
|
||||
# This flag provides the ability to control whether the graph mode gets
|
||||
# initialized for TF1 or not. Initializing for TF1, which is what was
|
||||
# happening earlier, was preventing enablement of 'eager mode' in the test.
|
||||
self._set_default_seed = True
|
||||
|
||||
def setUp(self):
|
||||
super(TensorFlowTestCase, self).setUp()
|
||||
@ -2061,7 +2065,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
# cleared first.
|
||||
ops._default_graph_stack.reset() # pylint: disable=protected-access
|
||||
ops.reset_default_graph()
|
||||
random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
|
||||
if self._set_default_seed:
|
||||
random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
|
||||
# Reset summary writer in case another test used set_as_default() with their
|
||||
# summary writer.
|
||||
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
|
||||
|
@ -18,68 +18,73 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.platform import _pywrap_tf2
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def set_environ():
|
||||
os.environ['TF2_BEHAVIOR'] = '1'
|
||||
|
||||
|
||||
def unset_environ():
|
||||
os.environ['TF2_BEHAVIOR'] = '0'
|
||||
|
||||
|
||||
class EnablingTF2Behavior(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(EnablingTF2Behavior, self).setUp()
|
||||
tf2._force_enable = None
|
||||
if 'TF2_BEHAVIOR' in os.environ:
|
||||
del os.environ['TF2_BEHAVIOR']
|
||||
def __init__(self, methodName):
|
||||
super().__init__(methodName)
|
||||
self._set_default_seed = False
|
||||
|
||||
actions = [tf2.enable, tf2.disable, set_environ, unset_environ]
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def test_tf1_enable_tf2_behaviour(self):
|
||||
self.assertFalse(tf2.enabled())
|
||||
self.assertFalse(_pywrap_tf2.is_enabled())
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
action_0=actions, action_1=actions,
|
||||
action_2=actions, action_3=actions))
|
||||
def test_scenarios(self, action_0, action_1, action_2, action_3):
|
||||
v2_compat.enable_v2_behavior()
|
||||
self.assertTrue(tf2.enabled())
|
||||
self.assertTrue(_pywrap_tf2.is_enabled())
|
||||
|
||||
def state(action, enabled, disabled):
|
||||
"""Returns bool tuple (tf2_enabled, force_enabled, force_disabled)."""
|
||||
if action is tf2.enable:
|
||||
return True, True, False
|
||||
elif action is tf2.disable:
|
||||
return False, False, True
|
||||
elif action is set_environ:
|
||||
return not disabled, enabled, disabled
|
||||
elif action is unset_environ:
|
||||
return enabled, enabled, disabled
|
||||
else:
|
||||
raise ValueError('Unexpected action {}. {} are supported'.format(
|
||||
action, EnablingTF2Behavior.actions))
|
||||
v2_compat.disable_v2_behavior()
|
||||
self.assertFalse(tf2.enabled())
|
||||
self.assertFalse(_pywrap_tf2.is_enabled())
|
||||
|
||||
action_0()
|
||||
expected, enabled, disabled = state(action_0, False, False)
|
||||
self.assertEqual(tf2.enabled(), expected)
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def test_tf1_disable_tf2_behaviour(self):
|
||||
self.assertFalse(tf2.enabled())
|
||||
self.assertFalse(_pywrap_tf2.is_enabled())
|
||||
|
||||
action_1()
|
||||
expected, enabled, disabled = state(action_1, enabled, disabled)
|
||||
self.assertEqual(tf2.enabled(), expected)
|
||||
v2_compat.disable_v2_behavior()
|
||||
self.assertFalse(tf2.enabled())
|
||||
self.assertFalse(_pywrap_tf2.is_enabled())
|
||||
|
||||
action_2()
|
||||
expected, enabled, disabled = state(action_2, enabled, disabled)
|
||||
self.assertEqual(tf2.enabled(), expected)
|
||||
v2_compat.enable_v2_behavior()
|
||||
self.assertTrue(tf2.enabled())
|
||||
self.assertTrue(_pywrap_tf2.is_enabled())
|
||||
|
||||
action_3()
|
||||
expected, enabled, disabled = state(action_3, enabled, disabled)
|
||||
self.assertEqual(tf2.enabled(), expected)
|
||||
@combinations.generate(test_base.v2_only_combinations())
|
||||
def test_tf2_enable_tf2_behaviour(self):
|
||||
self.assertTrue(tf2.enabled())
|
||||
self.assertTrue(_pywrap_tf2.is_enabled())
|
||||
|
||||
v2_compat.enable_v2_behavior()
|
||||
self.assertTrue(tf2.enabled())
|
||||
self.assertTrue(_pywrap_tf2.is_enabled())
|
||||
|
||||
v2_compat.disable_v2_behavior()
|
||||
self.assertFalse(tf2.enabled())
|
||||
self.assertFalse(_pywrap_tf2.is_enabled())
|
||||
|
||||
@combinations.generate(test_base.v2_only_combinations())
|
||||
def test_tf2_disable_tf2_behaviour(self):
|
||||
self.assertTrue(tf2.enabled())
|
||||
self.assertTrue(_pywrap_tf2.is_enabled())
|
||||
|
||||
v2_compat.disable_v2_behavior()
|
||||
self.assertFalse(tf2.enabled())
|
||||
self.assertFalse(_pywrap_tf2.is_enabled())
|
||||
|
||||
v2_compat.enable_v2_behavior()
|
||||
self.assertTrue(tf2.enabled())
|
||||
self.assertTrue(_pywrap_tf2.is_enabled())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -7,6 +7,9 @@ load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additiona
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
|
||||
@ -230,3 +233,16 @@ tf_py_test(
|
||||
":platform",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_pywrap_tf2",
|
||||
srcs = ["enable_tf2.cc"],
|
||||
hdrs = ["//tensorflow/core/platform:enable_tf2_hdr"],
|
||||
module_name = "_pywrap_tf2",
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:enable_tf2_utils",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
22
tensorflow/python/platform/enable_tf2.cc
Normal file
22
tensorflow/python/platform/enable_tf2.cc
Normal file
@ -0,0 +1,22 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/core/platform/enable_tf2_utils.h"
|
||||
|
||||
PYBIND11_MODULE(_pywrap_tf2, m) {
|
||||
m.def("enable", &tensorflow::set_tf2_execution);
|
||||
m.def("is_enabled", &tensorflow::tf2_execution_enabled);
|
||||
}
|
@ -22,29 +22,21 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.platform import _pywrap_tf2
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
_force_enable = None
|
||||
|
||||
|
||||
def enable():
|
||||
# Enables v2 behaviors.
|
||||
global _force_enable
|
||||
_force_enable = True
|
||||
_pywrap_tf2.enable(True)
|
||||
|
||||
|
||||
def disable():
|
||||
# Disables v2 behaviors.
|
||||
global _force_enable
|
||||
_force_enable = False
|
||||
_pywrap_tf2.enable(False)
|
||||
|
||||
|
||||
@tf_export("__internal__.tf2.enabled", v1=[])
|
||||
def enabled():
|
||||
# Returns True iff TensorFlow 2.0 behavior should be enabled.
|
||||
if _force_enable is None:
|
||||
return os.getenv("TF2_BEHAVIOR", "0") != "0"
|
||||
return _force_enable
|
||||
return _pywrap_tf2.is_enabled()
|
||||
|
Loading…
Reference in New Issue
Block a user