Merge pull request #32773 from guillaumekln:fix-vocab-table-tracking
PiperOrigin-RevId: 273318926
This commit is contained in:
commit
0a2af00d23
@ -42,7 +42,10 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import saver
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util as trackable
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class BaseLookupTableTest(test.TestCase):
|
||||
@ -899,6 +902,19 @@ class StaticVocabularyTableTest(BaseLookupTableTest):
|
||||
self.assertAllEqual([3, 1, 3], self.evaluate(out2))
|
||||
self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size()))
|
||||
|
||||
def testStaticVocabularyTableAssetTracking(self):
|
||||
vocab_file = self._createVocabFile("vocab.txt")
|
||||
vocab_size = 3
|
||||
oov_buckets = 1
|
||||
table = self.getVocabularyTable()(lookup_ops.TextFileIdTableInitializer(
|
||||
vocab_file, vocab_size=vocab_size), oov_buckets)
|
||||
object_graph_view = graph_view.ObjectGraphView(table)
|
||||
objects = object_graph_view.list_objects()
|
||||
assets = list(filter(lambda obj: isinstance(obj, tracking.Asset), objects))
|
||||
self.assertLen(assets, 1)
|
||||
self.assertEqual(
|
||||
self.evaluate(assets[0].asset_path), compat.as_bytes(vocab_file))
|
||||
|
||||
def testSparseTensor(self):
|
||||
vocab_file = self._createVocabFile("feat_to_id_7.txt")
|
||||
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
|
||||
|
@ -1118,6 +1118,8 @@ class StaticVocabularyTable(LookupInterface):
|
||||
if initializer.value_dtype != dtypes.int64:
|
||||
raise TypeError("Invalid value dtype, expected %s but got %s." %
|
||||
(dtypes.int64, initializer.value_dtype))
|
||||
if isinstance(initializer, trackable_base.Trackable):
|
||||
self._initializer = self._track_trackable(initializer, "_initializer")
|
||||
self._table = HashTable(initializer, default_value=-1)
|
||||
name = name or self._table.name
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user