Import submodules using relative imports. That is, "from .b import c" instead of "from a.b import c" in TF python API. This helps with autocomplete in PyCharms.
Specifically, it makes autocomplete work for `tf.image.` (with `import tensorflow as tf`) or `image.` (with `from tensorflow import image`). PiperOrigin-RevId: 264901051
This commit is contained in:
parent
37ccc893db
commit
e32d1900e5
tensorflow
@ -811,8 +811,8 @@ genrule(
|
||||
}),
|
||||
outs = ["__init__.py"],
|
||||
cmd = select({
|
||||
"api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS)",
|
||||
"//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS)",
|
||||
"api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS) && sed -i 's:from . import:from . _api.v2 import:g' $(OUTS)",
|
||||
"//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS) && sed -i 's:from . import:from ._api.v1 import:g' $(OUTS)",
|
||||
}),
|
||||
)
|
||||
|
||||
|
@ -78,7 +78,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from tensorflow.python.keras.api._v2 import keras
|
||||
from .python.keras.api._v2 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
setattr(_current_module, "keras", keras)
|
||||
|
@ -69,7 +69,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from tensorflow.python.keras.api._v1 import keras
|
||||
from .python.keras.api._v1 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
setattr(_current_module, "keras", keras)
|
||||
|
@ -195,8 +195,7 @@ class _ModuleInitCodeBuilder(object):
|
||||
dest_module_name=parent_module,
|
||||
dest_name=module_split[submodule_index])
|
||||
else:
|
||||
if submodule_index > 0:
|
||||
import_from += '.' + '.'.join(module_split[:submodule_index])
|
||||
import_from = '.'
|
||||
self.add_import(
|
||||
symbol=None,
|
||||
source_module_name=import_from,
|
||||
|
@ -23,6 +23,7 @@ import pkgutil
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -50,6 +51,18 @@ class ModuleTest(test.TestCase):
|
||||
def testName(self):
|
||||
self.assertEqual('tensorflow', tf.__name__)
|
||||
|
||||
def testBuiltInName(self):
|
||||
# range is a built-in name in Python. Just checking that
|
||||
# tf.range works fine.
|
||||
if tf2.enabled():
|
||||
self.assertEqual(
|
||||
'tf.Tensor([1 2 3 4 5 6 7 8 9], shape=(9,), dtype=int32)',
|
||||
str(tf.range(1, 10)))
|
||||
else:
|
||||
self.assertEqual(
|
||||
'Tensor("range:0", shape=(9,), dtype=int32)',
|
||||
str(tf.range(1, 10)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user