parent
d85eade584
commit
386da9758d
tensorflow
@ -800,8 +800,8 @@ genrule(
|
||||
}),
|
||||
outs = ["__init__.py"],
|
||||
cmd = select({
|
||||
"api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS)",
|
||||
"//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS)",
|
||||
"api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS) && sed -i'.original' 's:from . import:from . _api.v2 import:g' $(OUTS)",
|
||||
"//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS) && sed -i'.original' 's:from . import:from ._api.v1 import:g' $(OUTS)",
|
||||
}),
|
||||
)
|
||||
|
||||
|
@ -56,10 +56,10 @@ elif _tf_api_dir not in __path__:
|
||||
__path__.append(_tf_api_dir)
|
||||
|
||||
# Hook external TensorFlow modules.
|
||||
|
||||
# Import compat before trying to import summary from tensorboard, so that
|
||||
# reexport_tf_summary can get compat from sys.modules
|
||||
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
|
||||
# reexport_tf_summary can get compat from sys.modules. Only needed if using
|
||||
# lazy loading.
|
||||
_current_module.compat.v2 # pylint: disable=pointless-statement
|
||||
try:
|
||||
from tensorboard.summary._tf import summary
|
||||
_current_module.__path__ = (
|
||||
@ -78,7 +78,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from tensorflow.python.keras.api._v2 import keras
|
||||
from .python.keras.api._v2 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
setattr(_current_module, "keras", keras)
|
||||
|
@ -60,6 +60,10 @@ elif _tf_api_dir not in __path__:
|
||||
__path__.append(_tf_api_dir)
|
||||
|
||||
# Hook external TensorFlow modules.
|
||||
# Import compat before trying to import summary from tensorboard, so that
|
||||
# reexport_tf_summary can get compat from sys.modules. Only needed if using
|
||||
# lazy loading.
|
||||
_current_module.compat.v2 # pylint: disable=pointless-statement
|
||||
try:
|
||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||
_current_module.__path__ = (
|
||||
@ -69,7 +73,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from tensorflow.python.keras.api._v1 import keras
|
||||
from .python.keras.api._v1 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
setattr(_current_module, "keras", keras)
|
||||
@ -152,5 +156,4 @@ try:
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
|
||||
# pylint: enable=undefined-variable
|
||||
|
@ -195,8 +195,7 @@ class _ModuleInitCodeBuilder(object):
|
||||
dest_module_name=parent_module,
|
||||
dest_name=module_split[submodule_index])
|
||||
else:
|
||||
if submodule_index > 0:
|
||||
import_from += '.' + '.'.join(module_split[:submodule_index])
|
||||
import_from = '.'
|
||||
self.add_import(
|
||||
symbol=None,
|
||||
source_module_name=import_from,
|
||||
@ -429,6 +428,18 @@ def get_api_init_text(packages,
|
||||
api_name, compat_api_version,
|
||||
_COMPAT_MODULE_TEMPLATE % compat_api_version)
|
||||
|
||||
# Include compat.vN-1 under compat.vN.
|
||||
# For e.g. import compat.v1 under compat.v2.compat
|
||||
for version in compat_api_versions:
|
||||
if version - 1 in compat_api_versions:
|
||||
prev_version = 'v%d' % (version - 1)
|
||||
module_code_builder.add_import(
|
||||
symbol=None,
|
||||
source_module_name='%s.compat' % output_package,
|
||||
source_name=prev_version,
|
||||
dest_module_name='compat.v%d.compat' % version,
|
||||
dest_name=prev_version)
|
||||
|
||||
return module_code_builder.build()
|
||||
|
||||
|
||||
|
@ -23,6 +23,7 @@ import pkgutil
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -50,6 +51,34 @@ class ModuleTest(test.TestCase):
|
||||
def testName(self):
|
||||
self.assertEqual('tensorflow', tf.__name__)
|
||||
|
||||
def testBuiltInName(self):
|
||||
# range is a built-in name in Python. Just checking that
|
||||
# tf.range works fine.
|
||||
if tf2.enabled():
|
||||
self.assertEqual(
|
||||
'tf.Tensor([1 2 3 4 5 6 7 8 9], shape=(9,), dtype=int32)',
|
||||
str(tf.range(1, 10)))
|
||||
else:
|
||||
self.assertEqual(
|
||||
'Tensor("range:0", shape=(9,), dtype=int32)',
|
||||
str(tf.range(1, 10)))
|
||||
|
||||
def testCompatV2HasCompatV1(self):
|
||||
# pylint: disable=pointless-statement
|
||||
tf.compat.v2.compat.v1.keras
|
||||
# pylint: enable=pointless-statement
|
||||
|
||||
def testSummaryMerged(self):
|
||||
# pylint: disable=pointless-statement
|
||||
tf.summary.image
|
||||
# If we use v2 API, check for create_file_writer,
|
||||
# otherwise check for FileWriter.
|
||||
if '._api.v2' in tf.bitwise.__name__:
|
||||
tf.summary.create_file_writer
|
||||
else:
|
||||
tf.summary.FileWriter
|
||||
# pylint: enable=pointless-statement
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user