diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD index d81723cf04c..655dc118a37 100644 --- a/tensorflow/python/autograph/core/BUILD +++ b/tensorflow/python/autograph/core/BUILD @@ -24,7 +24,6 @@ py_library( "config_lib.py", "converter.py", "function_wrappers.py", - "naming.py", "unsupported_features_checker.py", ], srcs_version = "PY2AND3", @@ -79,14 +78,3 @@ py_test( "//tensorflow/python:client_testlib", ], ) - -py_test( - name = "naming_test", - srcs = ["naming_test.py"], - python_version = "PY3", - srcs_version = "PY2AND3", - deps = [ - ":core", - "//tensorflow/python:client_testlib", - ], -) diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 4ea1187f8ed..4b170159b8b 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -30,9 +30,9 @@ from tensorflow.python.autograph import utils from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import function_wrappers -from tensorflow.python.autograph.core import naming from tensorflow.python.autograph.lang import special_functions from tensorflow.python.autograph.pyct import loader +from tensorflow.python.autograph.pyct import naming from tensorflow.python.autograph.pyct import origin_info from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import pretty_printer diff --git a/tensorflow/python/autograph/core/naming.py b/tensorflow/python/autograph/core/naming.py deleted file mode 100644 index 67a565a9270..00000000000 --- a/tensorflow/python/autograph/core/naming.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Symbol naming utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import enum - -from tensorflow.python.autograph.pyct import qual_names -from tensorflow.python.autograph.utils import misc - - -class _NamingStyle(enum.Enum): - SNAKE = 1 - CAMEL = 2 - - -class Namer(object): - """Symbol name generator.""" - - def __init__(self, global_namespace): - self.global_namespace = global_namespace - self.generated_names = set() - - def _as_symbol_name(self, fqn, style=_NamingStyle.SNAKE): - """Returns a symbol name that matches a fully-qualified name. - - The returned name is safe to use for Python symbols. Any special characters - present in fqn are replaced according to the style argument. - - Examples: - - self._as_symbol_name('foo.bar', style=_NamingStyle.CAMEL) == 'FooBar' - self._as_symbol_name('foo.bar', style=_NamingStyle.SNAKE) == 'foo_bar' - - See the unit tests for more examples. - - Args: - fqn: Union[Text, Tuple[Text]] a fully-qualified symbol name. The qualifier - may include module, class names, attributes, etc. - style: _NamingStyle - Returns: - Text - """ - assert style in _NamingStyle - - if isinstance(fqn, tuple): - cn = '.'.join(fqn) - else: - cn = fqn - - # Until we clean up the whole FQN mechanism, `fqn` may not be - # canonical, that is, in can appear as ('foo.bar', 'baz') - # This replaces any characters that might remain because of that. - pieces = cn.split('.') - - if style == _NamingStyle.CAMEL: - pieces = tuple(misc.capitalize_initial(p) for p in pieces) - return ''.join(pieces) - elif style == _NamingStyle.SNAKE: - return '_'.join(pieces) - - def class_name(self, original_fqn): - """Returns the name of a converted class.""" - canonical_name = self._as_symbol_name( - original_fqn, style=_NamingStyle.CAMEL) - new_name_root = 'Tf%s' % canonical_name - new_name = new_name_root - n = 0 - while new_name in self.global_namespace: - n += 1 - new_name = '%s_%d' % (new_name_root, n) - self.generated_names.add(new_name) - return new_name - - def function_name(self, original_fqn): - """Returns the name of a converted function.""" - canonical_name = self._as_symbol_name( - original_fqn, style=_NamingStyle.SNAKE) - new_name_root = 'tf__%s' % canonical_name - new_name = new_name_root - n = 0 - while new_name in self.global_namespace: - n += 1 - new_name = '%s_%d' % (new_name_root, n) - self.generated_names.add(new_name) - return new_name - - def new_symbol(self, name_root, reserved_locals): - """See control_flow.SymbolNamer.new_symbol.""" - # reserved_locals may contain QNs. - all_reserved_locals = set() - for s in reserved_locals: - if isinstance(s, qual_names.QN): - all_reserved_locals.update(s.qn) - elif isinstance(s, str): - all_reserved_locals.add(s) - else: - raise ValueError('Unexpected symbol type "%s"' % type(s)) - - pieces = name_root.split('_') - if pieces[-1].isdigit(): - name_root = '_'.join(pieces[:-1]) - n = int(pieces[-1]) - else: - n = 0 - new_name = name_root - - while (new_name in self.global_namespace or - new_name in all_reserved_locals or new_name in self.generated_names): - n += 1 - new_name = '%s_%d' % (name_root, n) - - self.generated_names.add(new_name) - return new_name diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index e14c8e2bfcf..7134c2c0b69 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -48,12 +48,12 @@ from tensorflow.python.autograph.converters import slices from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import function_wrappers -from tensorflow.python.autograph.core import naming from tensorflow.python.autograph.core import unsupported_features_checker from tensorflow.python.autograph.lang import special_functions from tensorflow.python.autograph.pyct import ast_util from tensorflow.python.autograph.pyct import inspect_utils from tensorflow.python.autograph.pyct import loader +from tensorflow.python.autograph.pyct import naming from tensorflow.python.autograph.pyct import origin_info from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import pretty_printer @@ -572,7 +572,7 @@ def convert_func_to_ast(f, program_ctx, do_rename=True): if isinstance(node, gast.Lambda): new_name = namer.new_symbol('tf__lambda', ()) elif do_rename: - new_name = namer.function_name(f.__name__) + new_name = namer.new_symbol('tf__' + f.__name__, ()) else: new_name = f.__name__ diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD index 5311392263c..7881b17f88b 100644 --- a/tensorflow/python/autograph/pyct/BUILD +++ b/tensorflow/python/autograph/pyct/BUILD @@ -31,6 +31,7 @@ py_library( "inspect_utils.py", "loader.py", "loader_deprecated_py2.py", + "naming.py", "origin_info.py", "parser.py", "pretty_printer.py", @@ -133,6 +134,17 @@ sh_test( tags = ["no_oss"], ) +py_test( + name = "naming_test", + srcs = ["naming_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "origin_info_test", srcs = ["origin_info_test.py"], diff --git a/tensorflow/python/autograph/pyct/naming.py b/tensorflow/python/autograph/pyct/naming.py new file mode 100644 index 00000000000..c7d239bd7e6 --- /dev/null +++ b/tensorflow/python/autograph/pyct/naming.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Symbol naming utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.autograph.pyct import qual_names + + +class Namer(object): + """Symbol name generator.""" + + def __init__(self, global_namespace): + self.global_namespace = global_namespace + self.generated_names = set() + + def new_symbol(self, name_root, reserved_locals): + """See control_flow.SymbolNamer.new_symbol.""" + # reserved_locals may contain QNs. + all_reserved_locals = set() + for s in reserved_locals: + if isinstance(s, qual_names.QN): + all_reserved_locals.update(s.qn) + elif isinstance(s, str): + all_reserved_locals.add(s) + else: + raise ValueError('Unexpected symbol type "%s"' % type(s)) + + pieces = name_root.split('_') + if pieces[-1].isdigit(): + name_root = '_'.join(pieces[:-1]) + n = int(pieces[-1]) + else: + n = 0 + new_name = name_root + + while (new_name in self.global_namespace or + new_name in all_reserved_locals or new_name in self.generated_names): + n += 1 + new_name = '%s_%d' % (name_root, n) + + self.generated_names.add(new_name) + return new_name diff --git a/tensorflow/python/autograph/core/naming_test.py b/tensorflow/python/autograph/pyct/naming_test.py similarity index 60% rename from tensorflow/python/autograph/core/naming_test.py rename to tensorflow/python/autograph/pyct/naming_test.py index 49526ed77f3..61fe22068e4 100644 --- a/tensorflow/python/autograph/core/naming_test.py +++ b/tensorflow/python/autograph/pyct/naming_test.py @@ -18,40 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.autograph.core import naming +from tensorflow.python.autograph.pyct import naming from tensorflow.python.platform import test class NamerTest(test.TestCase): - def test_function_name_tracks_names(self): - namer = naming.Namer({}) - self.assertEqual('tf__foo', namer.function_name('foo')) - self.assertEqual('tf__bar', namer.function_name('bar')) - self.assertItemsEqual(('tf__bar', 'tf__foo'), namer.generated_names) - - def test_function_name_consistent(self): - namer = naming.Namer({}) - self.assertEqual('tf__foo', namer.function_name('foo')) - self.assertEqual('tf__foo', namer.function_name('foo')) - - def test_function_name_unsanitized_fqn(self): - namer = naming.Namer({}) - self.assertEqual('tf__foo_bar', namer.function_name('foo.bar')) - self.assertEqual('tf__foo_bar_baz', namer.function_name(('foo.bar', 'baz'))) - - def test_class_name_basic(self): - namer = naming.Namer({}) - self.assertEqual('TfFooBar', namer.class_name(('foo', 'Bar'))) - - def test_class_name_unsanitized_fqn(self): - namer = naming.Namer({}) - self.assertEqual('TfFooBarBaz', namer.class_name(('foo.bar', 'Baz'))) - - def test_function_name_avoids_global_conflicts(self): - namer = naming.Namer({'tf__foo': 1}) - self.assertEqual('tf__foo_1', namer.function_name('foo')) - def test_new_symbol_tracks_names(self): namer = naming.Namer({}) self.assertEqual('temp', namer.new_symbol('temp', set()))