Add _major_api_version to top level __init__.py file to tell when we import

tensorflow version 1 or version 2 api. + Minor change to the way root_init_template flag is passed in (now it should be a location path instead of file name).

PiperOrigin-RevId: 286729526
Change-Id: I55ebaa0cfe0fe3db3f4d1e699082b1f7b11df4da
This commit is contained in:
Anna R 2019-12-21 11:02:47 -08:00 committed by TensorFlower Gardener
parent b37c967b37
commit 8f7f1e22b4
6 changed files with 11 additions and 6 deletions

View File

@ -860,7 +860,7 @@ gen_api_init_files(
output_files = TENSORFLOW_API_INIT_FILES_V1,
output_package = "tensorflow._api.v1",
root_file_name = "v1.py",
root_init_template = "api_template_v1.__init__.py",
root_init_template = "$(location api_template_v1.__init__.py)",
)
gen_api_init_files(
@ -883,7 +883,7 @@ gen_api_init_files(
output_files = TENSORFLOW_API_INIT_FILES_V2,
output_package = "tensorflow._api.v2",
root_file_name = "v2.py",
root_init_template = "api_template.__init__.py",
root_init_template = "$(location api_template.__init__.py)",
)
py_library(

View File

@ -89,6 +89,7 @@ except ImportError:
# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
_compat.enable_v2_behavior()
_major_api_version = 2
# Load all plugin libraries from site-packages/tensorflow-plugins if we are

View File

@ -104,6 +104,8 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-
_current_module.app.flags = flags # pylint: disable=undefined-variable
setattr(_current_module, "flags", flags)
_major_api_version = 1
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin

View File

@ -84,10 +84,10 @@ def gen_api_init_files(
"""
root_init_template_flag = ""
if root_init_template:
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
root_init_template_flag = "--root_init_template=" + root_init_template
primary_package = packages[0]
api_gen_binary_target = ("create_" + primary_package + "_api_%d_%s") % (api_version, name)
api_gen_binary_target = ("create_" + primary_package + "_api_%s") % name
native.py_binary(
name = api_gen_binary_target,
srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],

View File

@ -367,7 +367,9 @@ class ApiCompatibilityTest(test.TestCase):
api_version=api_version)
def testAPIBackwardsCompatibility(self):
api_version = 2 if '_api.v2' in tf.bitwise.__name__ else 1
api_version = 1
if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
api_version = 2
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))

View File

@ -73,7 +73,7 @@ class ModuleTest(test.TestCase):
tf.summary.image
# If we use v2 API, check for create_file_writer,
# otherwise check for FileWriter.
if '._api.v2' in tf.bitwise.__name__:
if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
tf.summary.create_file_writer
else:
tf.summary.FileWriter