Add functions to switch between 1.x and 2.x global behavior.
PiperOrigin-RevId: 223595880
This commit is contained in:
parent
f30d3d01e9
commit
61c7cbca28
@ -35,8 +35,9 @@ _tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: di
|
||||
if _tf_api_dir not in __path__:
|
||||
__path__.append(_tf_api_dir)
|
||||
|
||||
# Calls to enable and disable features.
|
||||
enable_eager_execution() # pylint: disable=undefined-variable
|
||||
# Enable TF2 behaviors
|
||||
from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top
|
||||
_compat.enable_v2_behavior()
|
||||
|
||||
# These symbols appear because we import the python package which
|
||||
# in turn imports from tensorflow.core and tensorflow.python. They
|
||||
|
@ -9,7 +9,10 @@ py_library(
|
||||
srcs = ["compat.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = ["//tensorflow/python:util"],
|
||||
deps = [
|
||||
"//tensorflow/python:tf2",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
|
@ -23,6 +23,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import datetime
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import variable_scope
|
||||
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -132,3 +138,40 @@ def forward_compatibility_horizon(year, month, day):
|
||||
yield
|
||||
finally:
|
||||
_FORWARD_COMPATIBILITY_HORIZON = old_compat_date
|
||||
|
||||
|
||||
@tf_export(v1=["enable_v2_behavior"])
|
||||
def enable_v2_behavior():
|
||||
"""Enables TensorFlow 2.x behaviors.
|
||||
|
||||
This function can be called at the beginning of the program (before `Tensors`,
|
||||
`Graphs` or other structures have been created, and before devices have been
|
||||
initialized. It switches all global behaviors that are different between
|
||||
TensorFlow 1.x and 2.x to behave as intended for 2.x.
|
||||
|
||||
This function is called in the main TensorFlow `__init__.py` file, user should
|
||||
not need to call it, except during complex migrations.
|
||||
"""
|
||||
tf2.enable() # Switches TensorArrayV2 and control flow V2
|
||||
ops.enable_eager_execution()
|
||||
tensor_shape.enable_v2_tensorshape() # Also switched by tf2
|
||||
variable_scope.enable_resource_variables()
|
||||
|
||||
|
||||
@tf_export(v1=["disable_v2_behavior"])
|
||||
def disable_v2_behavior():
|
||||
"""Enables TensorFlow 2.x behaviors.
|
||||
|
||||
This function can be called at the beginning of the program (before `Tensors`,
|
||||
`Graphs` or other structures have been created, and before devices have been
|
||||
initialized. It switches all global behaviors that are different between
|
||||
TensorFlow 1.x and 2.x to behave as intended for 1.x.
|
||||
|
||||
User can call this function to disable 2.x behavior during complex migrations.
|
||||
"""
|
||||
tf2.disable() # Switches TensorArrayV2 and control flow V2
|
||||
ops.disable_eager_execution()
|
||||
tensor_shape.disable_v2_tensorshape() # Also switched by tf2
|
||||
variable_scope.disable_resource_variables()
|
||||
|
||||
|
||||
|
@ -5393,7 +5393,7 @@ def inside_function():
|
||||
return get_default_graph().building_function
|
||||
|
||||
|
||||
@tf_export("enable_eager_execution")
|
||||
@tf_export(v1=["enable_eager_execution"])
|
||||
def enable_eager_execution(config=None,
|
||||
device_policy=None,
|
||||
execution_mode=None):
|
||||
@ -5464,6 +5464,17 @@ def enable_eager_execution(config=None,
|
||||
server_def=None)
|
||||
|
||||
|
||||
@tf_export(v1=["disable_eager_execution"])
|
||||
def disable_eager_execution():
|
||||
"""Disables eager execution.
|
||||
|
||||
This function can only be called before any Graphs, Ops, or Tensors have been
|
||||
created. It can be used at the beginning of the program for complex migration
|
||||
projects from TensorFlow 1.x to 2.x.
|
||||
"""
|
||||
context.default_execution_mode = context.GRAPH_MODE
|
||||
|
||||
|
||||
def enable_eager_execution_internal(config=None,
|
||||
device_policy=None,
|
||||
execution_mode=None,
|
||||
@ -5471,6 +5482,7 @@ def enable_eager_execution_internal(config=None,
|
||||
"""Enables eager execution for the lifetime of this program.
|
||||
|
||||
Most of the doc string for enable_eager_execution is relevant here as well.
|
||||
|
||||
Args:
|
||||
config: See enable_eager_execution doc string
|
||||
device_policy: See enable_eager_execution doc string
|
||||
|
@ -25,6 +25,21 @@ from __future__ import print_function
|
||||
import os
|
||||
|
||||
|
||||
_force_enable = False
|
||||
|
||||
|
||||
def enable():
|
||||
"""Enables v2 behaviors."""
|
||||
global _force_enable
|
||||
_force_enable = True
|
||||
|
||||
|
||||
def disable():
|
||||
"""Disables v2 behaviors (TF2_BEHAVIOR env variable is still respected)."""
|
||||
global _force_enable
|
||||
_force_enable = False
|
||||
|
||||
|
||||
def enabled():
|
||||
"""Returns True iff TensorFlow 2.0 behavior should be enabled."""
|
||||
return os.getenv("TF2_BEHAVIOR") is not None
|
||||
return _force_enable or os.getenv("TF2_BEHAVIOR") is not None
|
||||
|
@ -1052,10 +1052,18 @@ tf_module {
|
||||
name: "dimension_value"
|
||||
argspec: "args=[\'dimension\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "disable_eager_execution"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "disable_resource_variables"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "disable_v2_behavior"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "disable_v2_tensorshape"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
@ -1096,6 +1104,10 @@ tf_module {
|
||||
name: "enable_resource_variables"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "enable_v2_behavior"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "enable_v2_tensorshape"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -608,10 +608,6 @@ tf_module {
|
||||
name: "einsum"
|
||||
argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "enable_eager_execution"
|
||||
argspec: "args=[\'config\', \'device_policy\', \'execution_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ensure_shape"
|
||||
argspec: "args=[\'x\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -126,7 +126,9 @@ renames = {
|
||||
'tf.digamma': 'tf.math.digamma',
|
||||
'tf.dimension_at_index': 'tf.compat.v1.dimension_at_index',
|
||||
'tf.dimension_value': 'tf.compat.v1.dimension_value',
|
||||
'tf.disable_eager_execution': 'tf.compat.v1.disable_eager_execution',
|
||||
'tf.disable_resource_variables': 'tf.compat.v1.disable_resource_variables',
|
||||
'tf.disable_v2_behavior': 'tf.compat.v1.disable_v2_behavior',
|
||||
'tf.disable_v2_tensorshape': 'tf.compat.v1.disable_v2_tensorshape',
|
||||
'tf.distributions.Bernoulli': 'tf.compat.v1.distributions.Bernoulli',
|
||||
'tf.distributions.Beta': 'tf.compat.v1.distributions.Beta',
|
||||
@ -147,7 +149,9 @@ renames = {
|
||||
'tf.distributions.Uniform': 'tf.compat.v1.distributions.Uniform',
|
||||
'tf.distributions.kl_divergence': 'tf.compat.v1.distributions.kl_divergence',
|
||||
'tf.div': 'tf.compat.v1.div',
|
||||
'tf.enable_eager_execution': 'tf.compat.v1.enable_eager_execution',
|
||||
'tf.enable_resource_variables': 'tf.compat.v1.enable_resource_variables',
|
||||
'tf.enable_v2_behavior': 'tf.compat.v1.enable_v2_behavior',
|
||||
'tf.enable_v2_tensorshape': 'tf.compat.v1.enable_v2_tensorshape',
|
||||
'tf.encode_base64': 'tf.io.encode_base64',
|
||||
'tf.erf': 'tf.math.erf',
|
||||
|
Loading…
Reference in New Issue
Block a user