Update tf_export.py to take constant name as an argument instead of a constant.

PiperOrigin-RevId: 168280613
This commit is contained in:
Anna R 2017-09-11 13:59:13 -07:00 committed by TensorFlower Gardener
parent c3d19e40ae
commit c0348bb552
3 changed files with 19 additions and 9 deletions

View File

@ -342,6 +342,17 @@ py_test(
],
)
py_test(
name = "tf_export_test",
srcs = ["util/tf_export_test.py"],
srcs_version = "PY2AND3",
deps = [
":client_testlib",
":platform",
":util",
],
)
py_test(
name = "deprecation_test",
srcs = ["util/deprecation_test.py"],

View File

@ -34,7 +34,7 @@ tf_export('foo', 'bar.foo')(foo)
Exporting a constant
```python
foo = 1
tf_export("consts.foo").export_constant(__name__, foo)
tf_export("consts.foo").export_constant(__name__, 'foo')
```
"""
from __future__ import absolute_import
@ -102,7 +102,7 @@ class tf_export(object): # pylint: disable=invalid-name
# pylint: enable=protected-access
return func
def export_constant(self, module_name, value):
def export_constant(self, module_name, name):
"""Store export information for constants/string literals.
Export information is stored in the module where constants/string literals
@ -112,17 +112,17 @@ class tf_export(object): # pylint: disable=invalid-name
```python
foo = 1
bar = 2
tf_export("consts.foo").export_constant(__name__, foo)
tf_export("consts.bar").export_constant(__name__, bar)
tf_export("consts.foo").export_constant(__name__, 'foo')
tf_export("consts.bar").export_constant(__name__, 'bar')
```
Args:
module_name: (string) Name of the module to store constant at.
value: Value of the constant.
name: (string) Current constant name.
"""
module = sys.modules[module_name]
if not hasattr(module, '_tf_api_constants'):
module._tf_api_constants = [] # pylint: disable=protected-access
# pylint: disable=protected-access
module._tf_api_constants.append((self._names, value))
module._tf_api_constants.append((self._names, name))

View File

@ -97,10 +97,9 @@ class ValidateExportTest(test.TestCase):
def testExportSingleConstant(self):
module1 = self._CreateMockModule('module1')
test_constant = 123
export_decorator = tf_export.tf_export('NAME_A', 'NAME_B')
export_decorator.export_constant('module1', test_constant)
self.assertEquals([(('NAME_A', 'NAME_B'), 123)],
export_decorator.export_constant('module1', 'test_constant')
self.assertEquals([(('NAME_A', 'NAME_B'), 'test_constant')],
module1._tf_api_constants)
def testExportMultipleConstants(self):