Merge pull request #32773 from guillaumekln:fix-vocab-table-tracking

PiperOrigin-RevId: 273318926
This commit is contained in:
TensorFlower Gardener 2019-10-07 11:34:02 -07:00
commit 0a2af00d23
2 changed files with 18 additions and 0 deletions

View File

@ -42,7 +42,10 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable
from tensorflow.python.util import compat
class BaseLookupTableTest(test.TestCase):
@ -899,6 +902,19 @@ class StaticVocabularyTableTest(BaseLookupTableTest):
self.assertAllEqual([3, 1, 3], self.evaluate(out2))
self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size()))
def testStaticVocabularyTableAssetTracking(self):
vocab_file = self._createVocabFile("vocab.txt")
vocab_size = 3
oov_buckets = 1
table = self.getVocabularyTable()(lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size), oov_buckets)
object_graph_view = graph_view.ObjectGraphView(table)
objects = object_graph_view.list_objects()
assets = list(filter(lambda obj: isinstance(obj, tracking.Asset), objects))
self.assertLen(assets, 1)
self.assertEqual(
self.evaluate(assets[0].asset_path), compat.as_bytes(vocab_file))
def testSparseTensor(self):
vocab_file = self._createVocabFile("feat_to_id_7.txt")
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]

View File

@ -1118,6 +1118,8 @@ class StaticVocabularyTable(LookupInterface):
if initializer.value_dtype != dtypes.int64:
raise TypeError("Invalid value dtype, expected %s but got %s." %
(dtypes.int64, initializer.value_dtype))
if isinstance(initializer, trackable_base.Trackable):
self._initializer = self._track_trackable(initializer, "_initializer")
self._table = HashTable(initializer, default_value=-1)
name = name or self._table.name
else: