Add a mechanism internally in tf_export to allow grabbing symbols from their canonical names. This will be used internally inside TF for serializing to/from symbols by their exported tf api name.

PiperOrigin-RevId: 313666229
Change-Id: I418d7765f7eb901eda1f3075a44b899ca4530c3e
This commit is contained in:
Tomer Kaftan 2020-05-28 15:00:55 -07:00 committed by TensorFlower Gardener
parent 789ebcf00b
commit bfd14e0aa3
2 changed files with 52 additions and 0 deletions

View File

@ -94,6 +94,12 @@ class InvalidSymbolNameError(Exception):
"""Raised when trying to export symbol as an invalid or unallowed name."""
pass
_NAME_TO_SYMBOL_MAPPING = dict()
def get_symbol_from_name(name):
return _NAME_TO_SYMBOL_MAPPING.get(name)
def get_canonical_name_for_symbol(
symbol, api_name=TENSORFLOW_API_NAME,
@ -333,6 +339,11 @@ class api_export(object): # pylint: disable=invalid-name
_, undecorated_func = tf_decorator.unwrap(func)
self.set_attr(undecorated_func, api_names_attr, self._names)
self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)
for name in self._names:
_NAME_TO_SYMBOL_MAPPING[name] = func
for name_v1 in self._names_v1:
_NAME_TO_SYMBOL_MAPPING['compat.v1.%s' % name_v1] = func
return func
def set_attr(self, func, api_names_attr, names):

View File

@ -82,6 +82,33 @@ class ValidateExportTest(test.TestCase):
tf_export.get_v1_names(decorated_function))
self.assertEquals(['nameA', 'nameB'],
tf_export.get_v2_names(decorated_function))
self.assertEqual(tf_export.get_symbol_from_name('nameA'),
decorated_function)
self.assertEqual(tf_export.get_symbol_from_name('nameB'),
decorated_function)
self.assertEqual(
tf_export.get_symbol_from_name(
tf_export.get_canonical_name_for_symbol(decorated_function)),
decorated_function)
def testExportSingleFunctionV1Only(self):
export_decorator = tf_export.tf_export(v1=['nameA', 'nameB'])
decorated_function = export_decorator(_test_function)
self.assertEqual(decorated_function, _test_function)
self.assertAllEqual(('nameA', 'nameB'), decorated_function._tf_api_names_v1)
self.assertAllEqual(['nameA', 'nameB'],
tf_export.get_v1_names(decorated_function))
self.assertEqual([],
tf_export.get_v2_names(decorated_function))
self.assertEqual(tf_export.get_symbol_from_name('compat.v1.nameA'),
decorated_function)
self.assertEqual(tf_export.get_symbol_from_name('compat.v1.nameB'),
decorated_function)
self.assertEqual(
tf_export.get_symbol_from_name(
tf_export.get_canonical_name_for_symbol(
decorated_function, add_prefix_to_v1_names=True)),
decorated_function)
def testExportMultipleFunctions(self):
export_decorator1 = tf_export.tf_export('nameA', 'nameB')
@ -92,6 +119,20 @@ class ValidateExportTest(test.TestCase):
self.assertEquals(decorated_function2, _test_function2)
self.assertEquals(('nameA', 'nameB'), decorated_function1._tf_api_names)
self.assertEquals(('nameC', 'nameD'), decorated_function2._tf_api_names)
self.assertEqual(tf_export.get_symbol_from_name('nameB'),
decorated_function1)
self.assertEqual(tf_export.get_symbol_from_name('nameD'),
decorated_function2)
self.assertEqual(
tf_export.get_symbol_from_name(
tf_export.get_canonical_name_for_symbol(
decorated_function1)),
decorated_function1)
self.assertEqual(
tf_export.get_symbol_from_name(
tf_export.get_canonical_name_for_symbol(
decorated_function2)),
decorated_function2)
def testExportClasses(self):
export_decorator_a = tf_export.tf_export('TestClassA1')