Move the namer into the generic pyct library. Clean up dead code.
PiperOrigin-RevId: 302928833 Change-Id: I3e065722bbf46c2f13faccd5265d559e55a3121b
This commit is contained in:
parent
26a3d1c92d
commit
be2c7869f5
@ -24,7 +24,6 @@ py_library(
|
||||
"config_lib.py",
|
||||
"converter.py",
|
||||
"function_wrappers.py",
|
||||
"naming.py",
|
||||
"unsupported_features_checker.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
@ -79,14 +78,3 @@ py_test(
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "naming_test",
|
||||
srcs = ["naming_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":core",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
@ -30,9 +30,9 @@ from tensorflow.python.autograph import utils
|
||||
from tensorflow.python.autograph.core import config
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import function_wrappers
|
||||
from tensorflow.python.autograph.core import naming
|
||||
from tensorflow.python.autograph.lang import special_functions
|
||||
from tensorflow.python.autograph.pyct import loader
|
||||
from tensorflow.python.autograph.pyct import naming
|
||||
from tensorflow.python.autograph.pyct import origin_info
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import pretty_printer
|
||||
|
@ -1,129 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Symbol naming utilities."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import enum
|
||||
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.utils import misc
|
||||
|
||||
|
||||
class _NamingStyle(enum.Enum):
|
||||
SNAKE = 1
|
||||
CAMEL = 2
|
||||
|
||||
|
||||
class Namer(object):
|
||||
"""Symbol name generator."""
|
||||
|
||||
def __init__(self, global_namespace):
|
||||
self.global_namespace = global_namespace
|
||||
self.generated_names = set()
|
||||
|
||||
def _as_symbol_name(self, fqn, style=_NamingStyle.SNAKE):
|
||||
"""Returns a symbol name that matches a fully-qualified name.
|
||||
|
||||
The returned name is safe to use for Python symbols. Any special characters
|
||||
present in fqn are replaced according to the style argument.
|
||||
|
||||
Examples:
|
||||
|
||||
self._as_symbol_name('foo.bar', style=_NamingStyle.CAMEL) == 'FooBar'
|
||||
self._as_symbol_name('foo.bar', style=_NamingStyle.SNAKE) == 'foo_bar'
|
||||
|
||||
See the unit tests for more examples.
|
||||
|
||||
Args:
|
||||
fqn: Union[Text, Tuple[Text]] a fully-qualified symbol name. The qualifier
|
||||
may include module, class names, attributes, etc.
|
||||
style: _NamingStyle
|
||||
Returns:
|
||||
Text
|
||||
"""
|
||||
assert style in _NamingStyle
|
||||
|
||||
if isinstance(fqn, tuple):
|
||||
cn = '.'.join(fqn)
|
||||
else:
|
||||
cn = fqn
|
||||
|
||||
# Until we clean up the whole FQN mechanism, `fqn` may not be
|
||||
# canonical, that is, in can appear as ('foo.bar', 'baz')
|
||||
# This replaces any characters that might remain because of that.
|
||||
pieces = cn.split('.')
|
||||
|
||||
if style == _NamingStyle.CAMEL:
|
||||
pieces = tuple(misc.capitalize_initial(p) for p in pieces)
|
||||
return ''.join(pieces)
|
||||
elif style == _NamingStyle.SNAKE:
|
||||
return '_'.join(pieces)
|
||||
|
||||
def class_name(self, original_fqn):
|
||||
"""Returns the name of a converted class."""
|
||||
canonical_name = self._as_symbol_name(
|
||||
original_fqn, style=_NamingStyle.CAMEL)
|
||||
new_name_root = 'Tf%s' % canonical_name
|
||||
new_name = new_name_root
|
||||
n = 0
|
||||
while new_name in self.global_namespace:
|
||||
n += 1
|
||||
new_name = '%s_%d' % (new_name_root, n)
|
||||
self.generated_names.add(new_name)
|
||||
return new_name
|
||||
|
||||
def function_name(self, original_fqn):
|
||||
"""Returns the name of a converted function."""
|
||||
canonical_name = self._as_symbol_name(
|
||||
original_fqn, style=_NamingStyle.SNAKE)
|
||||
new_name_root = 'tf__%s' % canonical_name
|
||||
new_name = new_name_root
|
||||
n = 0
|
||||
while new_name in self.global_namespace:
|
||||
n += 1
|
||||
new_name = '%s_%d' % (new_name_root, n)
|
||||
self.generated_names.add(new_name)
|
||||
return new_name
|
||||
|
||||
def new_symbol(self, name_root, reserved_locals):
|
||||
"""See control_flow.SymbolNamer.new_symbol."""
|
||||
# reserved_locals may contain QNs.
|
||||
all_reserved_locals = set()
|
||||
for s in reserved_locals:
|
||||
if isinstance(s, qual_names.QN):
|
||||
all_reserved_locals.update(s.qn)
|
||||
elif isinstance(s, str):
|
||||
all_reserved_locals.add(s)
|
||||
else:
|
||||
raise ValueError('Unexpected symbol type "%s"' % type(s))
|
||||
|
||||
pieces = name_root.split('_')
|
||||
if pieces[-1].isdigit():
|
||||
name_root = '_'.join(pieces[:-1])
|
||||
n = int(pieces[-1])
|
||||
else:
|
||||
n = 0
|
||||
new_name = name_root
|
||||
|
||||
while (new_name in self.global_namespace or
|
||||
new_name in all_reserved_locals or new_name in self.generated_names):
|
||||
n += 1
|
||||
new_name = '%s_%d' % (name_root, n)
|
||||
|
||||
self.generated_names.add(new_name)
|
||||
return new_name
|
@ -48,12 +48,12 @@ from tensorflow.python.autograph.converters import slices
|
||||
from tensorflow.python.autograph.core import config
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import function_wrappers
|
||||
from tensorflow.python.autograph.core import naming
|
||||
from tensorflow.python.autograph.core import unsupported_features_checker
|
||||
from tensorflow.python.autograph.lang import special_functions
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import inspect_utils
|
||||
from tensorflow.python.autograph.pyct import loader
|
||||
from tensorflow.python.autograph.pyct import naming
|
||||
from tensorflow.python.autograph.pyct import origin_info
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import pretty_printer
|
||||
@ -572,7 +572,7 @@ def convert_func_to_ast(f, program_ctx, do_rename=True):
|
||||
if isinstance(node, gast.Lambda):
|
||||
new_name = namer.new_symbol('tf__lambda', ())
|
||||
elif do_rename:
|
||||
new_name = namer.function_name(f.__name__)
|
||||
new_name = namer.new_symbol('tf__' + f.__name__, ())
|
||||
else:
|
||||
new_name = f.__name__
|
||||
|
||||
|
@ -31,6 +31,7 @@ py_library(
|
||||
"inspect_utils.py",
|
||||
"loader.py",
|
||||
"loader_deprecated_py2.py",
|
||||
"naming.py",
|
||||
"origin_info.py",
|
||||
"parser.py",
|
||||
"pretty_printer.py",
|
||||
@ -133,6 +134,17 @@ sh_test(
|
||||
tags = ["no_oss"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "naming_test",
|
||||
srcs = ["naming_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":pyct",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "origin_info_test",
|
||||
srcs = ["origin_info_test.py"],
|
||||
|
57
tensorflow/python/autograph/pyct/naming.py
Normal file
57
tensorflow/python/autograph/pyct/naming.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Symbol naming utilities."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
|
||||
|
||||
class Namer(object):
|
||||
"""Symbol name generator."""
|
||||
|
||||
def __init__(self, global_namespace):
|
||||
self.global_namespace = global_namespace
|
||||
self.generated_names = set()
|
||||
|
||||
def new_symbol(self, name_root, reserved_locals):
|
||||
"""See control_flow.SymbolNamer.new_symbol."""
|
||||
# reserved_locals may contain QNs.
|
||||
all_reserved_locals = set()
|
||||
for s in reserved_locals:
|
||||
if isinstance(s, qual_names.QN):
|
||||
all_reserved_locals.update(s.qn)
|
||||
elif isinstance(s, str):
|
||||
all_reserved_locals.add(s)
|
||||
else:
|
||||
raise ValueError('Unexpected symbol type "%s"' % type(s))
|
||||
|
||||
pieces = name_root.split('_')
|
||||
if pieces[-1].isdigit():
|
||||
name_root = '_'.join(pieces[:-1])
|
||||
n = int(pieces[-1])
|
||||
else:
|
||||
n = 0
|
||||
new_name = name_root
|
||||
|
||||
while (new_name in self.global_namespace or
|
||||
new_name in all_reserved_locals or new_name in self.generated_names):
|
||||
n += 1
|
||||
new_name = '%s_%d' % (name_root, n)
|
||||
|
||||
self.generated_names.add(new_name)
|
||||
return new_name
|
@ -18,40 +18,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.core import naming
|
||||
from tensorflow.python.autograph.pyct import naming
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class NamerTest(test.TestCase):
|
||||
|
||||
def test_function_name_tracks_names(self):
|
||||
namer = naming.Namer({})
|
||||
self.assertEqual('tf__foo', namer.function_name('foo'))
|
||||
self.assertEqual('tf__bar', namer.function_name('bar'))
|
||||
self.assertItemsEqual(('tf__bar', 'tf__foo'), namer.generated_names)
|
||||
|
||||
def test_function_name_consistent(self):
|
||||
namer = naming.Namer({})
|
||||
self.assertEqual('tf__foo', namer.function_name('foo'))
|
||||
self.assertEqual('tf__foo', namer.function_name('foo'))
|
||||
|
||||
def test_function_name_unsanitized_fqn(self):
|
||||
namer = naming.Namer({})
|
||||
self.assertEqual('tf__foo_bar', namer.function_name('foo.bar'))
|
||||
self.assertEqual('tf__foo_bar_baz', namer.function_name(('foo.bar', 'baz')))
|
||||
|
||||
def test_class_name_basic(self):
|
||||
namer = naming.Namer({})
|
||||
self.assertEqual('TfFooBar', namer.class_name(('foo', 'Bar')))
|
||||
|
||||
def test_class_name_unsanitized_fqn(self):
|
||||
namer = naming.Namer({})
|
||||
self.assertEqual('TfFooBarBaz', namer.class_name(('foo.bar', 'Baz')))
|
||||
|
||||
def test_function_name_avoids_global_conflicts(self):
|
||||
namer = naming.Namer({'tf__foo': 1})
|
||||
self.assertEqual('tf__foo_1', namer.function_name('foo'))
|
||||
|
||||
def test_new_symbol_tracks_names(self):
|
||||
namer = naming.Namer({})
|
||||
self.assertEqual('temp', namer.new_symbol('temp', set()))
|
Loading…
Reference in New Issue
Block a user