Add _major_api_version to top level __init__.py file to tell when we import
tensorflow version 1 or version 2 api. + Minor change to the way root_init_template flag is passed in (now it should be a location path instead of file name). PiperOrigin-RevId: 286729526 Change-Id: I55ebaa0cfe0fe3db3f4d1e699082b1f7b11df4da
This commit is contained in:
parent
b37c967b37
commit
8f7f1e22b4
@ -860,7 +860,7 @@ gen_api_init_files(
|
||||
output_files = TENSORFLOW_API_INIT_FILES_V1,
|
||||
output_package = "tensorflow._api.v1",
|
||||
root_file_name = "v1.py",
|
||||
root_init_template = "api_template_v1.__init__.py",
|
||||
root_init_template = "$(location api_template_v1.__init__.py)",
|
||||
)
|
||||
|
||||
gen_api_init_files(
|
||||
@ -883,7 +883,7 @@ gen_api_init_files(
|
||||
output_files = TENSORFLOW_API_INIT_FILES_V2,
|
||||
output_package = "tensorflow._api.v2",
|
||||
root_file_name = "v2.py",
|
||||
root_init_template = "api_template.__init__.py",
|
||||
root_init_template = "$(location api_template.__init__.py)",
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -89,6 +89,7 @@ except ImportError:
|
||||
# Enable TF2 behaviors
|
||||
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
|
||||
_compat.enable_v2_behavior()
|
||||
_major_api_version = 2
|
||||
|
||||
|
||||
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
|
||||
|
@ -104,6 +104,8 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-
|
||||
_current_module.app.flags = flags # pylint: disable=undefined-variable
|
||||
setattr(_current_module, "flags", flags)
|
||||
|
||||
_major_api_version = 1
|
||||
|
||||
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
|
||||
# running under pip.
|
||||
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
|
||||
|
@ -84,10 +84,10 @@ def gen_api_init_files(
|
||||
"""
|
||||
root_init_template_flag = ""
|
||||
if root_init_template:
|
||||
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
|
||||
root_init_template_flag = "--root_init_template=" + root_init_template
|
||||
|
||||
primary_package = packages[0]
|
||||
api_gen_binary_target = ("create_" + primary_package + "_api_%d_%s") % (api_version, name)
|
||||
api_gen_binary_target = ("create_" + primary_package + "_api_%s") % name
|
||||
native.py_binary(
|
||||
name = api_gen_binary_target,
|
||||
srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
|
||||
|
@ -367,7 +367,9 @@ class ApiCompatibilityTest(test.TestCase):
|
||||
api_version=api_version)
|
||||
|
||||
def testAPIBackwardsCompatibility(self):
|
||||
api_version = 2 if '_api.v2' in tf.bitwise.__name__ else 1
|
||||
api_version = 1
|
||||
if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
|
||||
api_version = 2
|
||||
golden_file_pattern = os.path.join(
|
||||
resource_loader.get_root_dir_with_all_resources(),
|
||||
_KeyToFilePath('*', api_version))
|
||||
|
@ -73,7 +73,7 @@ class ModuleTest(test.TestCase):
|
||||
tf.summary.image
|
||||
# If we use v2 API, check for create_file_writer,
|
||||
# otherwise check for FileWriter.
|
||||
if '._api.v2' in tf.bitwise.__name__:
|
||||
if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
|
||||
tf.summary.create_file_writer
|
||||
else:
|
||||
tf.summary.FileWriter
|
||||
|
Loading…
Reference in New Issue
Block a user