diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py index 2d249be1314..888cadd5344 100644 --- a/tensorflow/python/kernel_tests/lookup_ops_test.py +++ b/tensorflow/python/kernel_tests/lookup_ops_test.py @@ -42,7 +42,10 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import saver from tensorflow.python.training import server_lib +from tensorflow.python.training.tracking import graph_view +from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import util as trackable +from tensorflow.python.util import compat class BaseLookupTableTest(test.TestCase): @@ -899,6 +902,19 @@ class StaticVocabularyTableTest(BaseLookupTableTest): self.assertAllEqual([3, 1, 3], self.evaluate(out2)) self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size())) + def testStaticVocabularyTableAssetTracking(self): + vocab_file = self._createVocabFile("vocab.txt") + vocab_size = 3 + oov_buckets = 1 + table = self.getVocabularyTable()(lookup_ops.TextFileIdTableInitializer( + vocab_file, vocab_size=vocab_size), oov_buckets) + object_graph_view = graph_view.ObjectGraphView(table) + objects = object_graph_view.list_objects() + assets = list(filter(lambda obj: isinstance(obj, tracking.Asset), objects)) + self.assertLen(assets, 1) + self.assertEqual( + self.evaluate(assets[0].asset_path), compat.as_bytes(vocab_file)) + def testSparseTensor(self): vocab_file = self._createVocabFile("feat_to_id_7.txt") input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index 0aabb5c7ecb..43fcb61042e 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -1118,6 +1118,8 @@ class StaticVocabularyTable(LookupInterface): if initializer.value_dtype != dtypes.int64: raise TypeError("Invalid value dtype, expected %s but got %s." % (dtypes.int64, initializer.value_dtype)) + if isinstance(initializer, trackable_base.Trackable): + self._initializer = self._track_trackable(initializer, "_initializer") self._table = HashTable(initializer, default_value=-1) name = name or self._table.name else: