Merge pull request #39443 from tensorflow/mm-cherrypick-keras-fix
Cherrypick keras fix
This commit is contained in:
commit
a5f7769777
@ -875,7 +875,7 @@ gen_api_init_files(
|
||||
output_files = TENSORFLOW_API_INIT_FILES_V1,
|
||||
output_package = "tensorflow._api.v1",
|
||||
root_file_name = "v1.py",
|
||||
root_init_template = "api_template_v1.__init__.py",
|
||||
root_init_template = "$(location api_template_v1.__init__.py)",
|
||||
)
|
||||
|
||||
gen_api_init_files(
|
||||
@ -898,7 +898,7 @@ gen_api_init_files(
|
||||
output_files = TENSORFLOW_API_INIT_FILES_V2,
|
||||
output_package = "tensorflow._api.v2",
|
||||
root_file_name = "v2.py",
|
||||
root_init_template = "api_template.__init__.py",
|
||||
root_init_template = "$(location api_template.__init__.py)",
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -89,6 +89,7 @@ except ImportError:
|
||||
# Enable TF2 behaviors
|
||||
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
|
||||
_compat.enable_v2_behavior()
|
||||
_major_api_version = 2
|
||||
|
||||
|
||||
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
|
||||
|
@ -104,6 +104,8 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-
|
||||
_current_module.app.flags = flags # pylint: disable=undefined-variable
|
||||
setattr(_current_module, "flags", flags)
|
||||
|
||||
_major_api_version = 1
|
||||
|
||||
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
|
||||
# running under pip.
|
||||
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
|
||||
|
@ -15,6 +15,7 @@ py_library(
|
||||
"//tensorflow/python:control_flow_v2_toggles",
|
||||
"//tensorflow/python:tf2",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/eager:monitoring",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.data.experimental.ops import random_ops
|
||||
from tensorflow.python.data.experimental.ops import readers as exp_readers
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.eager import monitoring
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import control_flow_v2_toggles
|
||||
@ -32,6 +33,11 @@ from tensorflow.python.ops import variable_scope
|
||||
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
# Metrics to track the status of v2_behavior
|
||||
_v2_behavior_usage_gauge = monitoring.BoolGauge(
|
||||
"/tensorflow/version/v2_behavior",
|
||||
"whether v2_behavior is enabled or disabled", "status")
|
||||
|
||||
|
||||
@tf_export(v1=["enable_v2_behavior"])
|
||||
def enable_v2_behavior():
|
||||
@ -45,6 +51,7 @@ def enable_v2_behavior():
|
||||
This function is called in the main TensorFlow `__init__.py` file, user should
|
||||
not need to call it, except during complex migrations.
|
||||
"""
|
||||
_v2_behavior_usage_gauge.get_cell("enable").set(True)
|
||||
# TF2 behavior is enabled if either 1) enable_v2_behavior() is called or
|
||||
# 2) the TF2_BEHAVIOR=1 environment variable is set. In the latter case,
|
||||
# the modules below independently check if tf2.enabled().
|
||||
@ -82,6 +89,7 @@ def disable_v2_behavior():
|
||||
|
||||
User can call this function to disable 2.x behavior during complex migrations.
|
||||
"""
|
||||
_v2_behavior_usage_gauge.get_cell("disable").set(True)
|
||||
tf2.disable()
|
||||
ops.disable_eager_execution()
|
||||
tensor_shape.disable_v2_tensorshape() # Also switched by tf2
|
||||
|
@ -43,6 +43,6 @@ def disable():
|
||||
def enabled():
|
||||
"""Returns True iff TensorFlow 2.0 behavior should be enabled."""
|
||||
if _force_enable is None:
|
||||
return os.getenv("TF2_BEHAVIOR", "0") != "0"
|
||||
return os.getenv("TF2_BEHAVIOR", "1") == "1"
|
||||
else:
|
||||
return _force_enable
|
||||
|
@ -84,10 +84,10 @@ def gen_api_init_files(
|
||||
"""
|
||||
root_init_template_flag = ""
|
||||
if root_init_template:
|
||||
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
|
||||
root_init_template_flag = "--root_init_template=" + root_init_template
|
||||
|
||||
primary_package = packages[0]
|
||||
api_gen_binary_target = ("create_" + primary_package + "_api_%d_%s") % (api_version, name)
|
||||
api_gen_binary_target = ("create_" + primary_package + "_api_%s") % name
|
||||
native.py_binary(
|
||||
name = api_gen_binary_target,
|
||||
srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
|
||||
|
@ -367,7 +367,9 @@ class ApiCompatibilityTest(test.TestCase):
|
||||
api_version=api_version)
|
||||
|
||||
def testAPIBackwardsCompatibility(self):
|
||||
api_version = 2 if '_api.v2' in tf.bitwise.__name__ else 1
|
||||
api_version = 1
|
||||
if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
|
||||
api_version = 2
|
||||
golden_file_pattern = os.path.join(
|
||||
resource_loader.get_root_dir_with_all_resources(),
|
||||
_KeyToFilePath('*', api_version))
|
||||
|
@ -24,6 +24,7 @@ import pkgutil
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.keras import layers
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -73,12 +74,21 @@ class ModuleTest(test.TestCase):
|
||||
tf.summary.image
|
||||
# If we use v2 API, check for create_file_writer,
|
||||
# otherwise check for FileWriter.
|
||||
if '._api.v2' in tf.bitwise.__name__:
|
||||
if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
|
||||
tf.summary.create_file_writer
|
||||
else:
|
||||
tf.summary.FileWriter
|
||||
# pylint: enable=pointless-statement
|
||||
|
||||
def testInternalKerasImport(self):
|
||||
normalization_parent = layers.BatchNormalization.__module__.split('.')[-1]
|
||||
if tf._major_api_version == 2:
|
||||
self.assertEqual('normalization_v2', normalization_parent)
|
||||
self.assertTrue(layers.BatchNormalization._USE_V2_BEHAVIOR)
|
||||
else:
|
||||
self.assertEqual('normalization', normalization_parent)
|
||||
self.assertFalse(layers.BatchNormalization._USE_V2_BEHAVIOR)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user