fixed the breakage. Re-submit plumbing of the enable_v2_behavior() call to C++
PiperOrigin-RevId: 350478865 Change-Id: Ib5baae7e822f646bbcd796f6760c9f7239f021af
This commit is contained in:
parent
5c1bc1cc40
commit
c4f8aa9853
@ -981,6 +981,17 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "enable_tf2_utils",
|
||||||
|
srcs = ["enable_tf2_utils.cc"],
|
||||||
|
hdrs = ["enable_tf2_utils.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core/util:env_var",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
alias(
|
alias(
|
||||||
name = "profile_utils_cpu_utils",
|
name = "profile_utils_cpu_utils",
|
||||||
actual = "//tensorflow/core/platform/profile_utils:profile_utils_cpu_utils",
|
actual = "//tensorflow/core/platform/profile_utils:profile_utils_cpu_utils",
|
||||||
@ -992,6 +1003,12 @@ filegroup(
|
|||||||
compatible_with = get_compatible_with_portable(),
|
compatible_with = get_compatible_with_portable(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "enable_tf2_hdr",
|
||||||
|
srcs = ["enable_tf2_utils.h"],
|
||||||
|
compatible_with = get_compatible_with_portable(),
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_tests(
|
tf_cc_tests(
|
||||||
name = "low_level_library_tests",
|
name = "low_level_library_tests",
|
||||||
size = "small",
|
size = "small",
|
||||||
@ -1047,6 +1064,20 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "enable_tf2_utils_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = [
|
||||||
|
"enable_tf2_utils_test.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":enable_tf2_utils",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/util:env_var",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_tests(
|
tf_cc_tests(
|
||||||
name = "stacktrace_handler_test",
|
name = "stacktrace_handler_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
49
tensorflow/core/platform/enable_tf2_utils.cc
Normal file
49
tensorflow/core/platform/enable_tf2_utils.cc
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/enable_tf2_utils.h"
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
|
||||||
|
#include "tensorflow/core/util/env_var.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
enum Enablement : uint8 { kFalse = 0, kTrue = 1, undefined = 2 };
|
||||||
|
|
||||||
|
// If this flag is set, we will use it as a signal to decide on whether to
|
||||||
|
// use the MLIR based TF-XLA bridge.
|
||||||
|
static std::atomic<Enablement> tf2_enabled{undefined};
|
||||||
|
|
||||||
|
// Determine whether or not the user has explicitly asked for tf2 execution.
|
||||||
|
// Will be used to determine whether to use the MLIR based bridge.
|
||||||
|
void set_tf2_execution(bool enabled) {
|
||||||
|
tf2_enabled = (enabled) ? Enablement::kTrue : Enablement::kFalse;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool tf2_execution_enabled() {
|
||||||
|
if (tf2_enabled == Enablement::undefined) {
|
||||||
|
static bool tf2_behavior_env_enabled = [] {
|
||||||
|
string tf2_env;
|
||||||
|
TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
|
||||||
|
return tf2_env != "0";
|
||||||
|
}();
|
||||||
|
tf2_enabled =
|
||||||
|
(tf2_behavior_env_enabled) ? Enablement::kTrue : Enablement::kFalse;
|
||||||
|
}
|
||||||
|
return tf2_enabled;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
31
tensorflow/core/platform/enable_tf2_utils.h
Normal file
31
tensorflow/core/platform/enable_tf2_utils.h
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TF_CORE_PLATFORM_TF2_UTILS_H_
|
||||||
|
#define TF_CORE_PLATFORM_TF2_UTILS_H_
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Sets the tf2 execution state. This can be used to indicate whether the user
|
||||||
|
// has explicitly asked for tf2 execution.
|
||||||
|
void set_tf2_execution(bool enabled);
|
||||||
|
|
||||||
|
// Returns true or false depending on whether the user flag for tf2 execution
|
||||||
|
// has been set. The default is false.
|
||||||
|
bool tf2_execution_enabled();
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TF_CORE_PLATFORM_TF2_UTILS_H_
|
35
tensorflow/core/platform/enable_tf2_utils_test.cc
Normal file
35
tensorflow/core/platform/enable_tf2_utils_test.cc
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
// Testing TF2 enablement.
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/enable_tf2_utils.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/util/env_var.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
TEST(TF2EnabledTest, enabled_behavior) {
|
||||||
|
string tf2_env;
|
||||||
|
TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
|
||||||
|
bool expected = (tf2_env != "0");
|
||||||
|
EXPECT_EQ(tensorflow::tf2_execution_enabled(), expected);
|
||||||
|
tensorflow::set_tf2_execution(true);
|
||||||
|
EXPECT_TRUE(tensorflow::tf2_execution_enabled());
|
||||||
|
tensorflow::set_tf2_execution(false);
|
||||||
|
EXPECT_FALSE(tensorflow::tf2_execution_enabled());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -5266,7 +5266,10 @@ pywrap_tensorflow_macro(
|
|||||||
tf_additional_plugin_deps() +
|
tf_additional_plugin_deps() +
|
||||||
tf_additional_profiler_deps()) + if_xla_available([
|
tf_additional_profiler_deps()) + if_xla_available([
|
||||||
"//tensorflow/compiler/aot:tfcompile_lib",
|
"//tensorflow/compiler/aot:tfcompile_lib",
|
||||||
]) + if_static(extra_deps = ["//tensorflow/core/platform:tensor_float_32_utils"]),
|
]) + if_static(extra_deps = [
|
||||||
|
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||||
|
"//tensorflow/core/platform:enable_tf2_utils",
|
||||||
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
# ** Targets for Windows build (start) **
|
# ** Targets for Windows build (start) **
|
||||||
@ -6768,6 +6771,10 @@ py_library(
|
|||||||
name = "tf2",
|
name = "tf2",
|
||||||
srcs = ["tf2.py"],
|
srcs = ["tf2.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python/platform:_pywrap_tf2",
|
||||||
|
"//tensorflow/python/util:tf_export",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
@ -6779,6 +6786,8 @@ py_test(
|
|||||||
":client_testlib",
|
":client_testlib",
|
||||||
":framework_combinations",
|
":framework_combinations",
|
||||||
":tf2",
|
":tf2",
|
||||||
|
"//tensorflow/python/compat:v2_compat",
|
||||||
|
"//tensorflow/python/data/kernel_tests:test_base",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
from tensorflow.python.compat import v2_compat
|
from tensorflow.python.compat import v2_compat
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.platform import _pywrap_tf2
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -29,9 +30,13 @@ class DisableV2BehaviorTest(test.TestCase):
|
|||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
t = constant_op.constant([1, 2, 3]) # creates a hidden context
|
t = constant_op.constant([1, 2, 3]) # creates a hidden context
|
||||||
self.assertTrue(isinstance(t, ops.EagerTensor))
|
self.assertTrue(isinstance(t, ops.EagerTensor))
|
||||||
|
t = _pywrap_tf2.is_enabled()
|
||||||
|
self.assertTrue(t)
|
||||||
v2_compat.disable_v2_behavior()
|
v2_compat.disable_v2_behavior()
|
||||||
t = constant_op.constant([1, 2, 3])
|
t = constant_op.constant([1, 2, 3])
|
||||||
self.assertFalse(isinstance(t, ops.EagerTensor))
|
self.assertFalse(isinstance(t, ops.EagerTensor))
|
||||||
|
t = _pywrap_tf2.is_enabled()
|
||||||
|
self.assertFalse(t)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -2047,6 +2047,10 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
self._tempdir = None
|
self._tempdir = None
|
||||||
self._cached_session = None
|
self._cached_session = None
|
||||||
self._test_start_time = None
|
self._test_start_time = None
|
||||||
|
# This flag provides the ability to control whether the graph mode gets
|
||||||
|
# initialized for TF1 or not. Initializing for TF1, which is what was
|
||||||
|
# happening earlier, was preventing enablement of 'eager mode' in the test.
|
||||||
|
self._set_default_seed = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(TensorFlowTestCase, self).setUp()
|
super(TensorFlowTestCase, self).setUp()
|
||||||
@ -2061,7 +2065,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
# cleared first.
|
# cleared first.
|
||||||
ops._default_graph_stack.reset() # pylint: disable=protected-access
|
ops._default_graph_stack.reset() # pylint: disable=protected-access
|
||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
|
if self._set_default_seed:
|
||||||
|
random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
|
||||||
# Reset summary writer in case another test used set_as_default() with their
|
# Reset summary writer in case another test used set_as_default() with their
|
||||||
# summary writer.
|
# summary writer.
|
||||||
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
|
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
|
||||||
|
@ -18,68 +18,73 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
|
from tensorflow.python.compat import v2_compat
|
||||||
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.framework import combinations
|
from tensorflow.python.framework import combinations
|
||||||
|
from tensorflow.python.platform import _pywrap_tf2
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
def set_environ():
|
|
||||||
os.environ['TF2_BEHAVIOR'] = '1'
|
|
||||||
|
|
||||||
|
|
||||||
def unset_environ():
|
|
||||||
os.environ['TF2_BEHAVIOR'] = '0'
|
|
||||||
|
|
||||||
|
|
||||||
class EnablingTF2Behavior(test.TestCase, parameterized.TestCase):
|
class EnablingTF2Behavior(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def __init__(self, methodName):
|
||||||
super(EnablingTF2Behavior, self).setUp()
|
super().__init__(methodName)
|
||||||
tf2._force_enable = None
|
self._set_default_seed = False
|
||||||
if 'TF2_BEHAVIOR' in os.environ:
|
|
||||||
del os.environ['TF2_BEHAVIOR']
|
|
||||||
|
|
||||||
actions = [tf2.enable, tf2.disable, set_environ, unset_environ]
|
@combinations.generate(test_base.v1_only_combinations())
|
||||||
|
def test_tf1_enable_tf2_behaviour(self):
|
||||||
|
self.assertFalse(tf2.enabled())
|
||||||
|
self.assertFalse(_pywrap_tf2.is_enabled())
|
||||||
|
|
||||||
@combinations.generate(
|
v2_compat.enable_v2_behavior()
|
||||||
combinations.combine(
|
self.assertTrue(tf2.enabled())
|
||||||
action_0=actions, action_1=actions,
|
self.assertTrue(_pywrap_tf2.is_enabled())
|
||||||
action_2=actions, action_3=actions))
|
|
||||||
def test_scenarios(self, action_0, action_1, action_2, action_3):
|
|
||||||
|
|
||||||
def state(action, enabled, disabled):
|
v2_compat.disable_v2_behavior()
|
||||||
"""Returns bool tuple (tf2_enabled, force_enabled, force_disabled)."""
|
self.assertFalse(tf2.enabled())
|
||||||
if action is tf2.enable:
|
self.assertFalse(_pywrap_tf2.is_enabled())
|
||||||
return True, True, False
|
|
||||||
elif action is tf2.disable:
|
|
||||||
return False, False, True
|
|
||||||
elif action is set_environ:
|
|
||||||
return not disabled, enabled, disabled
|
|
||||||
elif action is unset_environ:
|
|
||||||
return enabled, enabled, disabled
|
|
||||||
else:
|
|
||||||
raise ValueError('Unexpected action {}. {} are supported'.format(
|
|
||||||
action, EnablingTF2Behavior.actions))
|
|
||||||
|
|
||||||
action_0()
|
@combinations.generate(test_base.v1_only_combinations())
|
||||||
expected, enabled, disabled = state(action_0, False, False)
|
def test_tf1_disable_tf2_behaviour(self):
|
||||||
self.assertEqual(tf2.enabled(), expected)
|
self.assertFalse(tf2.enabled())
|
||||||
|
self.assertFalse(_pywrap_tf2.is_enabled())
|
||||||
|
|
||||||
action_1()
|
v2_compat.disable_v2_behavior()
|
||||||
expected, enabled, disabled = state(action_1, enabled, disabled)
|
self.assertFalse(tf2.enabled())
|
||||||
self.assertEqual(tf2.enabled(), expected)
|
self.assertFalse(_pywrap_tf2.is_enabled())
|
||||||
|
|
||||||
action_2()
|
v2_compat.enable_v2_behavior()
|
||||||
expected, enabled, disabled = state(action_2, enabled, disabled)
|
self.assertTrue(tf2.enabled())
|
||||||
self.assertEqual(tf2.enabled(), expected)
|
self.assertTrue(_pywrap_tf2.is_enabled())
|
||||||
|
|
||||||
action_3()
|
@combinations.generate(test_base.v2_only_combinations())
|
||||||
expected, enabled, disabled = state(action_3, enabled, disabled)
|
def test_tf2_enable_tf2_behaviour(self):
|
||||||
self.assertEqual(tf2.enabled(), expected)
|
self.assertTrue(tf2.enabled())
|
||||||
|
self.assertTrue(_pywrap_tf2.is_enabled())
|
||||||
|
|
||||||
|
v2_compat.enable_v2_behavior()
|
||||||
|
self.assertTrue(tf2.enabled())
|
||||||
|
self.assertTrue(_pywrap_tf2.is_enabled())
|
||||||
|
|
||||||
|
v2_compat.disable_v2_behavior()
|
||||||
|
self.assertFalse(tf2.enabled())
|
||||||
|
self.assertFalse(_pywrap_tf2.is_enabled())
|
||||||
|
|
||||||
|
@combinations.generate(test_base.v2_only_combinations())
|
||||||
|
def test_tf2_disable_tf2_behaviour(self):
|
||||||
|
self.assertTrue(tf2.enabled())
|
||||||
|
self.assertTrue(_pywrap_tf2.is_enabled())
|
||||||
|
|
||||||
|
v2_compat.disable_v2_behavior()
|
||||||
|
self.assertFalse(tf2.enabled())
|
||||||
|
self.assertFalse(_pywrap_tf2.is_enabled())
|
||||||
|
|
||||||
|
v2_compat.enable_v2_behavior()
|
||||||
|
self.assertTrue(tf2.enabled())
|
||||||
|
self.assertTrue(_pywrap_tf2.is_enabled())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -7,6 +7,9 @@ load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additiona
|
|||||||
# buildifier: disable=same-origin-load
|
# buildifier: disable=same-origin-load
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||||
|
|
||||||
|
# buildifier: disable=same-origin-load
|
||||||
|
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||||
|
|
||||||
# buildifier: disable=same-origin-load
|
# buildifier: disable=same-origin-load
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||||
|
|
||||||
@ -230,3 +233,16 @@ tf_py_test(
|
|||||||
":platform",
|
":platform",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pybind_extension(
|
||||||
|
name = "_pywrap_tf2",
|
||||||
|
srcs = ["enable_tf2.cc"],
|
||||||
|
hdrs = ["//tensorflow/core/platform:enable_tf2_hdr"],
|
||||||
|
module_name = "_pywrap_tf2",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/platform:enable_tf2_utils",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
22
tensorflow/python/platform/enable_tf2.cc
Normal file
22
tensorflow/python/platform/enable_tf2.cc
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "tensorflow/core/platform/enable_tf2_utils.h"
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_tf2, m) {
|
||||||
|
m.def("enable", &tensorflow::set_tf2_execution);
|
||||||
|
m.def("is_enabled", &tensorflow::tf2_execution_enabled);
|
||||||
|
}
|
@ -22,29 +22,21 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
from tensorflow.python.platform import _pywrap_tf2
|
||||||
|
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
_force_enable = None
|
|
||||||
|
|
||||||
|
|
||||||
def enable():
|
def enable():
|
||||||
# Enables v2 behaviors.
|
# Enables v2 behaviors.
|
||||||
global _force_enable
|
_pywrap_tf2.enable(True)
|
||||||
_force_enable = True
|
|
||||||
|
|
||||||
|
|
||||||
def disable():
|
def disable():
|
||||||
# Disables v2 behaviors.
|
# Disables v2 behaviors.
|
||||||
global _force_enable
|
_pywrap_tf2.enable(False)
|
||||||
_force_enable = False
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("__internal__.tf2.enabled", v1=[])
|
@tf_export("__internal__.tf2.enabled", v1=[])
|
||||||
def enabled():
|
def enabled():
|
||||||
# Returns True iff TensorFlow 2.0 behavior should be enabled.
|
# Returns True iff TensorFlow 2.0 behavior should be enabled.
|
||||||
if _force_enable is None:
|
return _pywrap_tf2.is_enabled()
|
||||||
return os.getenv("TF2_BEHAVIOR", "0") != "0"
|
|
||||||
return _force_enable
|
|
||||||
|
Loading…
Reference in New Issue
Block a user