Merge pull request #39443 from tensorflow/mm-cherrypick-keras-fix

Cherrypick keras fix
This commit is contained in:
Mihai Maruseac 2020-05-12 14:40:13 +00:00 committed by GitHub
commit a5f7769777
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 31 additions and 7 deletions

View File

@ -875,7 +875,7 @@ gen_api_init_files(
output_files = TENSORFLOW_API_INIT_FILES_V1,
output_package = "tensorflow._api.v1",
root_file_name = "v1.py",
root_init_template = "api_template_v1.__init__.py",
root_init_template = "$(location api_template_v1.__init__.py)",
)
gen_api_init_files(
@ -898,7 +898,7 @@ gen_api_init_files(
output_files = TENSORFLOW_API_INIT_FILES_V2,
output_package = "tensorflow._api.v2",
root_file_name = "v2.py",
root_init_template = "api_template.__init__.py",
root_init_template = "$(location api_template.__init__.py)",
)
py_library(

View File

@ -89,6 +89,7 @@ except ImportError:
# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
_compat.enable_v2_behavior()
_major_api_version = 2
# Load all plugin libraries from site-packages/tensorflow-plugins if we are

View File

@ -104,6 +104,8 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-
_current_module.app.flags = flags # pylint: disable=undefined-variable
setattr(_current_module, "flags", flags)
_major_api_version = 1
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin

View File

@ -15,6 +15,7 @@ py_library(
"//tensorflow/python:control_flow_v2_toggles",
"//tensorflow/python:tf2",
"//tensorflow/python:util",
"//tensorflow/python/eager:monitoring",
],
)

View File

@ -25,6 +25,7 @@ from tensorflow.python.data.experimental.ops import random_ops
from tensorflow.python.data.experimental.ops import readers as exp_readers
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import control_flow_v2_toggles
@ -32,6 +33,11 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.util.tf_export import tf_export
# Metrics to track the status of v2_behavior
_v2_behavior_usage_gauge = monitoring.BoolGauge(
"/tensorflow/version/v2_behavior",
"whether v2_behavior is enabled or disabled", "status")
@tf_export(v1=["enable_v2_behavior"])
def enable_v2_behavior():
@ -45,6 +51,7 @@ def enable_v2_behavior():
This function is called in the main TensorFlow `__init__.py` file, user should
not need to call it, except during complex migrations.
"""
_v2_behavior_usage_gauge.get_cell("enable").set(True)
# TF2 behavior is enabled if either 1) enable_v2_behavior() is called or
# 2) the TF2_BEHAVIOR=1 environment variable is set. In the latter case,
# the modules below independently check if tf2.enabled().
@ -82,6 +89,7 @@ def disable_v2_behavior():
User can call this function to disable 2.x behavior during complex migrations.
"""
_v2_behavior_usage_gauge.get_cell("disable").set(True)
tf2.disable()
ops.disable_eager_execution()
tensor_shape.disable_v2_tensorshape() # Also switched by tf2

View File

@ -43,6 +43,6 @@ def disable():
def enabled():
"""Returns True iff TensorFlow 2.0 behavior should be enabled."""
if _force_enable is None:
return os.getenv("TF2_BEHAVIOR", "0") != "0"
return os.getenv("TF2_BEHAVIOR", "1") == "1"
else:
return _force_enable

View File

@ -84,10 +84,10 @@ def gen_api_init_files(
"""
root_init_template_flag = ""
if root_init_template:
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
root_init_template_flag = "--root_init_template=" + root_init_template
primary_package = packages[0]
api_gen_binary_target = ("create_" + primary_package + "_api_%d_%s") % (api_version, name)
api_gen_binary_target = ("create_" + primary_package + "_api_%s") % name
native.py_binary(
name = api_gen_binary_target,
srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],

View File

@ -367,7 +367,9 @@ class ApiCompatibilityTest(test.TestCase):
api_version=api_version)
def testAPIBackwardsCompatibility(self):
api_version = 2 if '_api.v2' in tf.bitwise.__name__ else 1
api_version = 1
if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
api_version = 2
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))

View File

@ -24,6 +24,7 @@ import pkgutil
import tensorflow as tf
from tensorflow.python import tf2
from tensorflow.python.keras import layers
from tensorflow.python.platform import test
@ -73,12 +74,21 @@ class ModuleTest(test.TestCase):
tf.summary.image
# If we use v2 API, check for create_file_writer,
# otherwise check for FileWriter.
if '._api.v2' in tf.bitwise.__name__:
if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
tf.summary.create_file_writer
else:
tf.summary.FileWriter
# pylint: enable=pointless-statement
def testInternalKerasImport(self):
normalization_parent = layers.BatchNormalization.__module__.split('.')[-1]
if tf._major_api_version == 2:
self.assertEqual('normalization_v2', normalization_parent)
self.assertTrue(layers.BatchNormalization._USE_V2_BEHAVIOR)
else:
self.assertEqual('normalization', normalization_parent)
self.assertFalse(layers.BatchNormalization._USE_V2_BEHAVIOR)
if __name__ == '__main__':
test.main()