Automated rollback of commit 1d1f7dfcbd

PiperOrigin-RevId: 266040561
This commit is contained in:
Anna R 2019-08-28 18:32:21 -07:00 committed by TensorFlower Gardener
parent d85eade584
commit 386da9758d
5 changed files with 53 additions and 10 deletions

View File

@ -800,8 +800,8 @@ genrule(
}),
outs = ["__init__.py"],
cmd = select({
"api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS)",
"//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS)",
"api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS) && sed -i'.original' 's:from . import:from . _api.v2 import:g' $(OUTS)",
"//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS) && sed -i'.original' 's:from . import:from ._api.v1 import:g' $(OUTS)",
}),
)

View File

@ -56,10 +56,10 @@ elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# Hook external TensorFlow modules.
# Import compat before trying to import summary from tensorboard, so that
# reexport_tf_summary can get compat from sys.modules
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2 # pylint: disable=pointless-statement
try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
@ -78,7 +78,7 @@ except ImportError:
pass
try:
from tensorflow.python.keras.api._v2 import keras
from .python.keras.api._v2 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)

View File

@ -60,6 +60,10 @@ elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# Hook external TensorFlow modules.
# Import compat before trying to import summary from tensorboard, so that
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2 # pylint: disable=pointless-statement
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
@ -69,7 +73,7 @@ except ImportError:
pass
try:
from tensorflow.python.keras.api._v1 import keras
from .python.keras.api._v1 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
@ -152,5 +156,4 @@ try:
except NameError:
pass
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
# pylint: enable=undefined-variable

View File

@ -195,8 +195,7 @@ class _ModuleInitCodeBuilder(object):
dest_module_name=parent_module,
dest_name=module_split[submodule_index])
else:
if submodule_index > 0:
import_from += '.' + '.'.join(module_split[:submodule_index])
import_from = '.'
self.add_import(
symbol=None,
source_module_name=import_from,
@ -429,6 +428,18 @@ def get_api_init_text(packages,
api_name, compat_api_version,
_COMPAT_MODULE_TEMPLATE % compat_api_version)
# Include compat.vN-1 under compat.vN.
# For e.g. import compat.v1 under compat.v2.compat
for version in compat_api_versions:
if version - 1 in compat_api_versions:
prev_version = 'v%d' % (version - 1)
module_code_builder.add_import(
symbol=None,
source_module_name='%s.compat' % output_package,
source_name=prev_version,
dest_module_name='compat.v%d.compat' % version,
dest_name=prev_version)
return module_code_builder.build()

View File

@ -23,6 +23,7 @@ import pkgutil
import tensorflow as tf
from tensorflow.python import tf2
from tensorflow.python.platform import test
@ -50,6 +51,34 @@ class ModuleTest(test.TestCase):
def testName(self):
self.assertEqual('tensorflow', tf.__name__)
def testBuiltInName(self):
# range is a built-in name in Python. Just checking that
# tf.range works fine.
if tf2.enabled():
self.assertEqual(
'tf.Tensor([1 2 3 4 5 6 7 8 9], shape=(9,), dtype=int32)',
str(tf.range(1, 10)))
else:
self.assertEqual(
'Tensor("range:0", shape=(9,), dtype=int32)',
str(tf.range(1, 10)))
def testCompatV2HasCompatV1(self):
# pylint: disable=pointless-statement
tf.compat.v2.compat.v1.keras
# pylint: enable=pointless-statement
def testSummaryMerged(self):
# pylint: disable=pointless-statement
tf.summary.image
# If we use v2 API, check for create_file_writer,
# otherwise check for FileWriter.
if '._api.v2' in tf.bitwise.__name__:
tf.summary.create_file_writer
else:
tf.summary.FileWriter
# pylint: enable=pointless-statement
if __name__ == '__main__':
test.main()