Update tf_export.py to take constant name as an argument instead of a constant.

PiperOrigin-RevId: 168280613
This commit is contained in:
Anna R 2017-09-11 13:59:13 -07:00 committed by TensorFlower Gardener
parent c3d19e40ae
commit c0348bb552
3 changed files with 19 additions and 9 deletions

View File

@ -342,6 +342,17 @@ py_test(
], ],
) )
py_test(
name = "tf_export_test",
srcs = ["util/tf_export_test.py"],
srcs_version = "PY2AND3",
deps = [
":client_testlib",
":platform",
":util",
],
)
py_test( py_test(
name = "deprecation_test", name = "deprecation_test",
srcs = ["util/deprecation_test.py"], srcs = ["util/deprecation_test.py"],

View File

@ -34,7 +34,7 @@ tf_export('foo', 'bar.foo')(foo)
Exporting a constant Exporting a constant
```python ```python
foo = 1 foo = 1
tf_export("consts.foo").export_constant(__name__, foo) tf_export("consts.foo").export_constant(__name__, 'foo')
``` ```
""" """
from __future__ import absolute_import from __future__ import absolute_import
@ -102,7 +102,7 @@ class tf_export(object): # pylint: disable=invalid-name
# pylint: enable=protected-access # pylint: enable=protected-access
return func return func
def export_constant(self, module_name, value): def export_constant(self, module_name, name):
"""Store export information for constants/string literals. """Store export information for constants/string literals.
Export information is stored in the module where constants/string literals Export information is stored in the module where constants/string literals
@ -112,17 +112,17 @@ class tf_export(object): # pylint: disable=invalid-name
```python ```python
foo = 1 foo = 1
bar = 2 bar = 2
tf_export("consts.foo").export_constant(__name__, foo) tf_export("consts.foo").export_constant(__name__, 'foo')
tf_export("consts.bar").export_constant(__name__, bar) tf_export("consts.bar").export_constant(__name__, 'bar')
``` ```
Args: Args:
module_name: (string) Name of the module to store constant at. module_name: (string) Name of the module to store constant at.
value: Value of the constant. name: (string) Current constant name.
""" """
module = sys.modules[module_name] module = sys.modules[module_name]
if not hasattr(module, '_tf_api_constants'): if not hasattr(module, '_tf_api_constants'):
module._tf_api_constants = [] # pylint: disable=protected-access module._tf_api_constants = [] # pylint: disable=protected-access
# pylint: disable=protected-access # pylint: disable=protected-access
module._tf_api_constants.append((self._names, value)) module._tf_api_constants.append((self._names, name))

View File

@ -97,10 +97,9 @@ class ValidateExportTest(test.TestCase):
def testExportSingleConstant(self): def testExportSingleConstant(self):
module1 = self._CreateMockModule('module1') module1 = self._CreateMockModule('module1')
test_constant = 123
export_decorator = tf_export.tf_export('NAME_A', 'NAME_B') export_decorator = tf_export.tf_export('NAME_A', 'NAME_B')
export_decorator.export_constant('module1', test_constant) export_decorator.export_constant('module1', 'test_constant')
self.assertEquals([(('NAME_A', 'NAME_B'), 123)], self.assertEquals([(('NAME_A', 'NAME_B'), 'test_constant')],
module1._tf_api_constants) module1._tf_api_constants)
def testExportMultipleConstants(self): def testExportMultipleConstants(self):