From 5df01bc5af4bfaa955eb7ed330cfaaebe96bb109 Mon Sep 17 00:00:00 2001 From: Anna R Date: Sat, 21 Dec 2019 11:02:47 -0800 Subject: [PATCH] Add _major_api_version to top level __init__.py file to tell when we import tensorflow version 1 or version 2 api. + Minor change to the way root_init_template flag is passed in (now it should be a location path instead of file name). PiperOrigin-RevId: 286729526 Change-Id: I55ebaa0cfe0fe3db3f4d1e699082b1f7b11df4da --- tensorflow/BUILD | 4 ++-- tensorflow/api_template.__init__.py | 1 + tensorflow/api_template_v1.__init__.py | 2 ++ tensorflow/python/tools/api/generator/api_gen.bzl | 4 ++-- tensorflow/tools/api/tests/api_compatibility_test.py | 4 +++- tensorflow/tools/api/tests/module_test.py | 2 +- 6 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 2ccb9854622..7e87e58d37e 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -875,7 +875,7 @@ gen_api_init_files( output_files = TENSORFLOW_API_INIT_FILES_V1, output_package = "tensorflow._api.v1", root_file_name = "v1.py", - root_init_template = "api_template_v1.__init__.py", + root_init_template = "$(location api_template_v1.__init__.py)", ) gen_api_init_files( @@ -898,7 +898,7 @@ gen_api_init_files( output_files = TENSORFLOW_API_INIT_FILES_V2, output_package = "tensorflow._api.v2", root_file_name = "v2.py", - root_init_template = "api_template.__init__.py", + root_init_template = "$(location api_template.__init__.py)", ) py_library( diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index c515cc76b9a..2a53f973f32 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -89,6 +89,7 @@ except ImportError: # Enable TF2 behaviors from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top _compat.enable_v2_behavior() +_major_api_version = 2 # Load all plugin libraries from site-packages/tensorflow-plugins if we are diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 2b2899c3fe0..b6dba2d35da 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -104,6 +104,8 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at- _current_module.app.flags = flags # pylint: disable=undefined-variable setattr(_current_module, "flags", flags) +_major_api_version = 1 + # Load all plugin libraries from site-packages/tensorflow-plugins if we are # running under pip. # TODO(gunan): Enable setting an environment variable to define arbitrary plugin diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl index b567a229177..6595960c341 100644 --- a/tensorflow/python/tools/api/generator/api_gen.bzl +++ b/tensorflow/python/tools/api/generator/api_gen.bzl @@ -84,10 +84,10 @@ def gen_api_init_files( """ root_init_template_flag = "" if root_init_template: - root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")" + root_init_template_flag = "--root_init_template=" + root_init_template primary_package = packages[0] - api_gen_binary_target = ("create_" + primary_package + "_api_%d_%s") % (api_version, name) + api_gen_binary_target = ("create_" + primary_package + "_api_%s") % name native.py_binary( name = api_gen_binary_target, srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"], diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index 383dbb4ab1f..321fc381290 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -367,7 +367,9 @@ class ApiCompatibilityTest(test.TestCase): api_version=api_version) def testAPIBackwardsCompatibility(self): - api_version = 2 if '_api.v2' in tf.bitwise.__name__ else 1 + api_version = 1 + if hasattr(tf, '_major_api_version') and tf._major_api_version == 2: + api_version = 2 golden_file_pattern = os.path.join( resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*', api_version)) diff --git a/tensorflow/tools/api/tests/module_test.py b/tensorflow/tools/api/tests/module_test.py index bf4e462391d..692ce94d006 100644 --- a/tensorflow/tools/api/tests/module_test.py +++ b/tensorflow/tools/api/tests/module_test.py @@ -74,7 +74,7 @@ class ModuleTest(test.TestCase): tf.summary.image # If we use v2 API, check for create_file_writer, # otherwise check for FileWriter. - if '._api.v2' in tf.bitwise.__name__: + if hasattr(tf, '_major_api_version') and tf._major_api_version == 2: tf.summary.create_file_writer else: tf.summary.FileWriter