Add a mechanism internally in tf_export to allow grabbing symbols from their canonical names. This will be used internally inside TF for serializing to/from symbols by their exported tf api name.
PiperOrigin-RevId: 313666229 Change-Id: I418d7765f7eb901eda1f3075a44b899ca4530c3e
This commit is contained in:
parent
789ebcf00b
commit
bfd14e0aa3
@ -94,6 +94,12 @@ class InvalidSymbolNameError(Exception):
|
||||
"""Raised when trying to export symbol as an invalid or unallowed name."""
|
||||
pass
|
||||
|
||||
_NAME_TO_SYMBOL_MAPPING = dict()
|
||||
|
||||
|
||||
def get_symbol_from_name(name):
|
||||
return _NAME_TO_SYMBOL_MAPPING.get(name)
|
||||
|
||||
|
||||
def get_canonical_name_for_symbol(
|
||||
symbol, api_name=TENSORFLOW_API_NAME,
|
||||
@ -333,6 +339,11 @@ class api_export(object): # pylint: disable=invalid-name
|
||||
_, undecorated_func = tf_decorator.unwrap(func)
|
||||
self.set_attr(undecorated_func, api_names_attr, self._names)
|
||||
self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)
|
||||
|
||||
for name in self._names:
|
||||
_NAME_TO_SYMBOL_MAPPING[name] = func
|
||||
for name_v1 in self._names_v1:
|
||||
_NAME_TO_SYMBOL_MAPPING['compat.v1.%s' % name_v1] = func
|
||||
return func
|
||||
|
||||
def set_attr(self, func, api_names_attr, names):
|
||||
|
@ -82,6 +82,33 @@ class ValidateExportTest(test.TestCase):
|
||||
tf_export.get_v1_names(decorated_function))
|
||||
self.assertEquals(['nameA', 'nameB'],
|
||||
tf_export.get_v2_names(decorated_function))
|
||||
self.assertEqual(tf_export.get_symbol_from_name('nameA'),
|
||||
decorated_function)
|
||||
self.assertEqual(tf_export.get_symbol_from_name('nameB'),
|
||||
decorated_function)
|
||||
self.assertEqual(
|
||||
tf_export.get_symbol_from_name(
|
||||
tf_export.get_canonical_name_for_symbol(decorated_function)),
|
||||
decorated_function)
|
||||
|
||||
def testExportSingleFunctionV1Only(self):
|
||||
export_decorator = tf_export.tf_export(v1=['nameA', 'nameB'])
|
||||
decorated_function = export_decorator(_test_function)
|
||||
self.assertEqual(decorated_function, _test_function)
|
||||
self.assertAllEqual(('nameA', 'nameB'), decorated_function._tf_api_names_v1)
|
||||
self.assertAllEqual(['nameA', 'nameB'],
|
||||
tf_export.get_v1_names(decorated_function))
|
||||
self.assertEqual([],
|
||||
tf_export.get_v2_names(decorated_function))
|
||||
self.assertEqual(tf_export.get_symbol_from_name('compat.v1.nameA'),
|
||||
decorated_function)
|
||||
self.assertEqual(tf_export.get_symbol_from_name('compat.v1.nameB'),
|
||||
decorated_function)
|
||||
self.assertEqual(
|
||||
tf_export.get_symbol_from_name(
|
||||
tf_export.get_canonical_name_for_symbol(
|
||||
decorated_function, add_prefix_to_v1_names=True)),
|
||||
decorated_function)
|
||||
|
||||
def testExportMultipleFunctions(self):
|
||||
export_decorator1 = tf_export.tf_export('nameA', 'nameB')
|
||||
@ -92,6 +119,20 @@ class ValidateExportTest(test.TestCase):
|
||||
self.assertEquals(decorated_function2, _test_function2)
|
||||
self.assertEquals(('nameA', 'nameB'), decorated_function1._tf_api_names)
|
||||
self.assertEquals(('nameC', 'nameD'), decorated_function2._tf_api_names)
|
||||
self.assertEqual(tf_export.get_symbol_from_name('nameB'),
|
||||
decorated_function1)
|
||||
self.assertEqual(tf_export.get_symbol_from_name('nameD'),
|
||||
decorated_function2)
|
||||
self.assertEqual(
|
||||
tf_export.get_symbol_from_name(
|
||||
tf_export.get_canonical_name_for_symbol(
|
||||
decorated_function1)),
|
||||
decorated_function1)
|
||||
self.assertEqual(
|
||||
tf_export.get_symbol_from_name(
|
||||
tf_export.get_canonical_name_for_symbol(
|
||||
decorated_function2)),
|
||||
decorated_function2)
|
||||
|
||||
def testExportClasses(self):
|
||||
export_decorator_a = tf_export.tf_export('TestClassA1')
|
||||
|
Loading…
Reference in New Issue
Block a user