Move the namer into the generic pyct library. Clean up dead code.

PiperOrigin-RevId: 302928833
Change-Id: I3e065722bbf46c2f13faccd5265d559e55a3121b
This commit is contained in:
Dan Moldovan 2020-03-25 11:12:10 -07:00 committed by TensorFlower Gardener
parent 26a3d1c92d
commit be2c7869f5
7 changed files with 73 additions and 173 deletions

View File

@ -24,7 +24,6 @@ py_library(
"config_lib.py",
"converter.py",
"function_wrappers.py",
"naming.py",
"unsupported_features_checker.py",
],
srcs_version = "PY2AND3",
@ -79,14 +78,3 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
py_test(
name = "naming_test",
srcs = ["naming_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":core",
"//tensorflow/python:client_testlib",
],
)

View File

@ -30,9 +30,9 @@ from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.core import naming
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer

View File

@ -1,129 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Symbol naming utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.utils import misc
class _NamingStyle(enum.Enum):
SNAKE = 1
CAMEL = 2
class Namer(object):
"""Symbol name generator."""
def __init__(self, global_namespace):
self.global_namespace = global_namespace
self.generated_names = set()
def _as_symbol_name(self, fqn, style=_NamingStyle.SNAKE):
"""Returns a symbol name that matches a fully-qualified name.
The returned name is safe to use for Python symbols. Any special characters
present in fqn are replaced according to the style argument.
Examples:
self._as_symbol_name('foo.bar', style=_NamingStyle.CAMEL) == 'FooBar'
self._as_symbol_name('foo.bar', style=_NamingStyle.SNAKE) == 'foo_bar'
See the unit tests for more examples.
Args:
fqn: Union[Text, Tuple[Text]] a fully-qualified symbol name. The qualifier
may include module, class names, attributes, etc.
style: _NamingStyle
Returns:
Text
"""
assert style in _NamingStyle
if isinstance(fqn, tuple):
cn = '.'.join(fqn)
else:
cn = fqn
# Until we clean up the whole FQN mechanism, `fqn` may not be
# canonical, that is, in can appear as ('foo.bar', 'baz')
# This replaces any characters that might remain because of that.
pieces = cn.split('.')
if style == _NamingStyle.CAMEL:
pieces = tuple(misc.capitalize_initial(p) for p in pieces)
return ''.join(pieces)
elif style == _NamingStyle.SNAKE:
return '_'.join(pieces)
def class_name(self, original_fqn):
"""Returns the name of a converted class."""
canonical_name = self._as_symbol_name(
original_fqn, style=_NamingStyle.CAMEL)
new_name_root = 'Tf%s' % canonical_name
new_name = new_name_root
n = 0
while new_name in self.global_namespace:
n += 1
new_name = '%s_%d' % (new_name_root, n)
self.generated_names.add(new_name)
return new_name
def function_name(self, original_fqn):
"""Returns the name of a converted function."""
canonical_name = self._as_symbol_name(
original_fqn, style=_NamingStyle.SNAKE)
new_name_root = 'tf__%s' % canonical_name
new_name = new_name_root
n = 0
while new_name in self.global_namespace:
n += 1
new_name = '%s_%d' % (new_name_root, n)
self.generated_names.add(new_name)
return new_name
def new_symbol(self, name_root, reserved_locals):
"""See control_flow.SymbolNamer.new_symbol."""
# reserved_locals may contain QNs.
all_reserved_locals = set()
for s in reserved_locals:
if isinstance(s, qual_names.QN):
all_reserved_locals.update(s.qn)
elif isinstance(s, str):
all_reserved_locals.add(s)
else:
raise ValueError('Unexpected symbol type "%s"' % type(s))
pieces = name_root.split('_')
if pieces[-1].isdigit():
name_root = '_'.join(pieces[:-1])
n = int(pieces[-1])
else:
n = 0
new_name = name_root
while (new_name in self.global_namespace or
new_name in all_reserved_locals or new_name in self.generated_names):
n += 1
new_name = '%s_%d' % (name_root, n)
self.generated_names.add(new_name)
return new_name

View File

@ -48,12 +48,12 @@ from tensorflow.python.autograph.converters import slices
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.core import naming
from tensorflow.python.autograph.core import unsupported_features_checker
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
@ -572,7 +572,7 @@ def convert_func_to_ast(f, program_ctx, do_rename=True):
if isinstance(node, gast.Lambda):
new_name = namer.new_symbol('tf__lambda', ())
elif do_rename:
new_name = namer.function_name(f.__name__)
new_name = namer.new_symbol('tf__' + f.__name__, ())
else:
new_name = f.__name__

View File

@ -31,6 +31,7 @@ py_library(
"inspect_utils.py",
"loader.py",
"loader_deprecated_py2.py",
"naming.py",
"origin_info.py",
"parser.py",
"pretty_printer.py",
@ -133,6 +134,17 @@ sh_test(
tags = ["no_oss"],
)
py_test(
name = "naming_test",
srcs = ["naming_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":pyct",
"//tensorflow/python:client_testlib",
],
)
py_test(
name = "origin_info_test",
srcs = ["origin_info_test.py"],

View File

@ -0,0 +1,57 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Symbol naming utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.pyct import qual_names
class Namer(object):
"""Symbol name generator."""
def __init__(self, global_namespace):
self.global_namespace = global_namespace
self.generated_names = set()
def new_symbol(self, name_root, reserved_locals):
"""See control_flow.SymbolNamer.new_symbol."""
# reserved_locals may contain QNs.
all_reserved_locals = set()
for s in reserved_locals:
if isinstance(s, qual_names.QN):
all_reserved_locals.update(s.qn)
elif isinstance(s, str):
all_reserved_locals.add(s)
else:
raise ValueError('Unexpected symbol type "%s"' % type(s))
pieces = name_root.split('_')
if pieces[-1].isdigit():
name_root = '_'.join(pieces[:-1])
n = int(pieces[-1])
else:
n = 0
new_name = name_root
while (new_name in self.global_namespace or
new_name in all_reserved_locals or new_name in self.generated_names):
n += 1
new_name = '%s_%d' % (name_root, n)
self.generated_names.add(new_name)
return new_name

View File

@ -18,40 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.core import naming
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.platform import test
class NamerTest(test.TestCase):
def test_function_name_tracks_names(self):
namer = naming.Namer({})
self.assertEqual('tf__foo', namer.function_name('foo'))
self.assertEqual('tf__bar', namer.function_name('bar'))
self.assertItemsEqual(('tf__bar', 'tf__foo'), namer.generated_names)
def test_function_name_consistent(self):
namer = naming.Namer({})
self.assertEqual('tf__foo', namer.function_name('foo'))
self.assertEqual('tf__foo', namer.function_name('foo'))
def test_function_name_unsanitized_fqn(self):
namer = naming.Namer({})
self.assertEqual('tf__foo_bar', namer.function_name('foo.bar'))
self.assertEqual('tf__foo_bar_baz', namer.function_name(('foo.bar', 'baz')))
def test_class_name_basic(self):
namer = naming.Namer({})
self.assertEqual('TfFooBar', namer.class_name(('foo', 'Bar')))
def test_class_name_unsanitized_fqn(self):
namer = naming.Namer({})
self.assertEqual('TfFooBarBaz', namer.class_name(('foo.bar', 'Baz')))
def test_function_name_avoids_global_conflicts(self):
namer = naming.Namer({'tf__foo': 1})
self.assertEqual('tf__foo_1', namer.function_name('foo'))
def test_new_symbol_tracks_names(self):
namer = naming.Namer({})
self.assertEqual('temp', namer.new_symbol('temp', set()))