From c0348bb552f094788155a1d3e2b65672ffc4ec8a Mon Sep 17 00:00:00 2001 From: Anna R Date: Mon, 11 Sep 2017 13:59:13 -0700 Subject: [PATCH] Update tf_export.py to take constant name as an argument instead of a constant. PiperOrigin-RevId: 168280613 --- tensorflow/python/BUILD | 11 +++++++++++ tensorflow/python/util/tf_export.py | 12 ++++++------ tensorflow/python/util/tf_export_test.py | 5 ++--- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 0d8253bcdca..5753f4d0178 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -342,6 +342,17 @@ py_test( ], ) +py_test( + name = "tf_export_test", + srcs = ["util/tf_export_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":platform", + ":util", + ], +) + py_test( name = "deprecation_test", srcs = ["util/deprecation_test.py"], diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index 47396997d34..a30b8b13363 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -34,7 +34,7 @@ tf_export('foo', 'bar.foo')(foo) Exporting a constant ```python foo = 1 -tf_export("consts.foo").export_constant(__name__, foo) +tf_export("consts.foo").export_constant(__name__, 'foo') ``` """ from __future__ import absolute_import @@ -102,7 +102,7 @@ class tf_export(object): # pylint: disable=invalid-name # pylint: enable=protected-access return func - def export_constant(self, module_name, value): + def export_constant(self, module_name, name): """Store export information for constants/string literals. Export information is stored in the module where constants/string literals @@ -112,17 +112,17 @@ class tf_export(object): # pylint: disable=invalid-name ```python foo = 1 bar = 2 - tf_export("consts.foo").export_constant(__name__, foo) - tf_export("consts.bar").export_constant(__name__, bar) + tf_export("consts.foo").export_constant(__name__, 'foo') + tf_export("consts.bar").export_constant(__name__, 'bar') ``` Args: module_name: (string) Name of the module to store constant at. - value: Value of the constant. + name: (string) Current constant name. """ module = sys.modules[module_name] if not hasattr(module, '_tf_api_constants'): module._tf_api_constants = [] # pylint: disable=protected-access # pylint: disable=protected-access - module._tf_api_constants.append((self._names, value)) + module._tf_api_constants.append((self._names, name)) diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py index 3b7636c34e4..b9e26ecb333 100644 --- a/tensorflow/python/util/tf_export_test.py +++ b/tensorflow/python/util/tf_export_test.py @@ -97,10 +97,9 @@ class ValidateExportTest(test.TestCase): def testExportSingleConstant(self): module1 = self._CreateMockModule('module1') - test_constant = 123 export_decorator = tf_export.tf_export('NAME_A', 'NAME_B') - export_decorator.export_constant('module1', test_constant) - self.assertEquals([(('NAME_A', 'NAME_B'), 123)], + export_decorator.export_constant('module1', 'test_constant') + self.assertEquals([(('NAME_A', 'NAME_B'), 'test_constant')], module1._tf_api_constants) def testExportMultipleConstants(self):