Adding tf_export bindings for the tf.lookup API. Adds tf.lookup.StaticHashTable, tf.lookup.VocabularyTable, tf.lookup.experimental.DenseHashTable, tf.lookup.TextFileInitializer and tf.lookup.KeyValueTensorInitializer

PiperOrigin-RevId: 235990023
This commit is contained in:
Rohan Jain 2019-02-27 14:14:23 -08:00 committed by TensorFlower Gardener
parent 104098d46b
commit 6adc4f81d0
21 changed files with 839 additions and 385 deletions

View File

@ -30,6 +30,7 @@ from tensorflow.python.ops.lookup_ops import IdTableWithHashBuckets
from tensorflow.python.ops.lookup_ops import index_table_from_file from tensorflow.python.ops.lookup_ops import index_table_from_file
from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file
from tensorflow.python.ops.lookup_ops import InitializableLookupTableBase from tensorflow.python.ops.lookup_ops import InitializableLookupTableBase
from tensorflow.python.ops.lookup_ops import InitializableLookupTableBaseV1
from tensorflow.python.ops.lookup_ops import KeyValueTensorInitializer from tensorflow.python.ops.lookup_ops import KeyValueTensorInitializer
from tensorflow.python.ops.lookup_ops import LookupInterface from tensorflow.python.ops.lookup_ops import LookupInterface
from tensorflow.python.ops.lookup_ops import StrongHashSpec from tensorflow.python.ops.lookup_ops import StrongHashSpec
@ -284,7 +285,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None):
return table.lookup(tensor) return table.lookup(tensor)
class HashTable(InitializableLookupTableBase): class HashTable(InitializableLookupTableBaseV1):
"""A generic hash table implementation. """A generic hash table implementation.
Example usage: Example usage:

View File

@ -22,6 +22,7 @@ import tempfile
import numpy as np import numpy as np
import six import six
from tensorflow.python import tf2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import counter from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
@ -41,16 +42,34 @@ from tensorflow.python.training import server_lib
from tensorflow.python.training.tracking import util as trackable from tensorflow.python.training.tracking import util as trackable
class StaticHashTableTest(test.TestCase): class BaseLookupTableTest(test.TestCase):
def getHashTable(self):
if tf2.enabled():
return lookup_ops.StaticHashTable
else:
return lookup_ops.StaticHashTableV1
def getVocabularyTable(self):
if tf2.enabled():
return lookup_ops.StaticVocabularyTable
else:
return lookup_ops.StaticVocabularyTableV1
def initialize_table(self, table):
if not tf2.enabled():
self.evaluate(table.initializer)
class StaticHashTableTest(BaseLookupTableTest):
def testStaticHashTable(self): def testStaticHashTable(self):
with self.cached_session():
default_val = -1 default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"]) keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.StaticHashTable( table = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val) lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
self.evaluate(table.initializer) self.initialize_table(table)
self.assertAllEqual(3, self.evaluate(table.size())) self.assertAllEqual(3, self.evaluate(table.size()))
@ -68,13 +87,12 @@ class StaticHashTableTest(test.TestCase):
self.assertItemsEqual([0, 1, 2], self.evaluate(exported_values_tensor)) self.assertItemsEqual([0, 1, 2], self.evaluate(exported_values_tensor))
def testStaticHashTableFindHighRank(self): def testStaticHashTableFindHighRank(self):
with self.cached_session():
default_val = -1 default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"]) keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.StaticHashTable( table = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val) lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
self.evaluate(table.initializer) self.initialize_table(table)
self.assertAllEqual(3, self.evaluate(table.size())) self.assertAllEqual(3, self.evaluate(table.size()))
@ -86,14 +104,13 @@ class StaticHashTableTest(test.TestCase):
self.assertAllEqual([[0, 1], [-1, -1]], result) self.assertAllEqual([[0, 1], [-1, -1]], result)
def testStaticHashTableInitWithPythonArrays(self): def testStaticHashTableInitWithPythonArrays(self):
with self.cached_session():
default_val = -1 default_val = -1
keys = ["brain", "salad", "surgery"] keys = ["brain", "salad", "surgery"]
values = [0, 1, 2] values = [0, 1, 2]
table = lookup_ops.StaticHashTable( table = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer( lookup_ops.KeyValueTensorInitializer(
keys, values, value_dtype=dtypes.int64), default_val) keys, values, value_dtype=dtypes.int64), default_val)
self.evaluate(table.initializer) self.initialize_table(table)
self.assertAllEqual(3, self.evaluate(table.size())) self.assertAllEqual(3, self.evaluate(table.size()))
@ -104,13 +121,12 @@ class StaticHashTableTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result) self.assertAllEqual([0, 1, -1], result)
def testStaticHashTableInitWithNumPyArrays(self): def testStaticHashTableInitWithNumPyArrays(self):
with self.cached_session():
default_val = -1 default_val = -1
keys = np.array(["brain", "salad", "surgery"], dtype=np.str) keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
values = np.array([0, 1, 2], dtype=np.int64) values = np.array([0, 1, 2], dtype=np.int64)
table = lookup_ops.StaticHashTable( table = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val) lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
self.evaluate(table.initializer) self.initialize_table(table)
self.assertAllEqual(3, self.evaluate(table.size())) self.assertAllEqual(3, self.evaluate(table.size()))
@ -121,22 +137,20 @@ class StaticHashTableTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result) self.assertAllEqual([0, 1, -1], result)
def testMultipleStaticHashTables(self): def testMultipleStaticHashTables(self):
with self.cached_session():
default_val = -1 default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"]) keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64)
table1 = lookup_ops.StaticHashTable( table1 = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val) lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
table2 = lookup_ops.StaticHashTable( table2 = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val) lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
table3 = lookup_ops.StaticHashTable( table3 = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val) lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
self.evaluate(table1.initializer) self.initialize_table(table1)
self.evaluate(table2.initializer) self.initialize_table(table2)
self.evaluate(table3.initializer) self.initialize_table(table3)
self.assertAllEqual(3, self.evaluate(table1.size())) self.assertAllEqual(3, self.evaluate(table1.size()))
self.assertAllEqual(3, self.evaluate(table2.size())) self.assertAllEqual(3, self.evaluate(table2.size()))
self.assertAllEqual(3, self.evaluate(table3.size())) self.assertAllEqual(3, self.evaluate(table3.size()))
@ -152,13 +166,12 @@ class StaticHashTableTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3) self.assertAllEqual([0, 1, -1], out3)
def testStaticHashTableWithTensorDefault(self): def testStaticHashTableWithTensorDefault(self):
with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64) default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"]) keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.StaticHashTable( table = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val) lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
self.evaluate(table.initializer) self.initialize_table(table)
input_string = constant_op.constant(["brain", "salad", "tank"]) input_string = constant_op.constant(["brain", "salad", "tank"])
output = table.lookup(input_string) output = table.lookup(input_string)
@ -167,13 +180,12 @@ class StaticHashTableTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result) self.assertAllEqual([0, 1, -1], result)
def testStaticHashTableWithSparseTensorInput(self): def testStaticHashTableWithSparseTensorInput(self):
with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64) default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"]) keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.StaticHashTable( table = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val) lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
self.evaluate(table.initializer) self.initialize_table(table)
sp_indices = [[0, 0], [0, 1], [1, 0]] sp_indices = [[0, 0], [0, 1], [1, 0]]
sp_shape = [2, 2] sp_shape = [2, 2]
@ -190,13 +202,12 @@ class StaticHashTableTest(test.TestCase):
self.assertAllEqual(sp_shape, out_shape) self.assertAllEqual(sp_shape, out_shape)
def testSignatureMismatch(self): def testSignatureMismatch(self):
with self.cached_session():
default_val = -1 default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"]) keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.StaticHashTable( table = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val) lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
self.evaluate(table.initializer) self.initialize_table(table)
# Ref types do not produce a lookup signature mismatch. # Ref types do not produce a lookup signature mismatch.
input_string_ref = variables.Variable("brain") input_string_ref = variables.Variable("brain")
@ -208,22 +219,21 @@ class StaticHashTableTest(test.TestCase):
table.lookup(input_string) table.lookup(input_string)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
lookup_ops.StaticHashTable( self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), "UNK") lookup_ops.KeyValueTensorInitializer(keys, values), "UNK")
def testDTypes(self): def testDTypes(self):
with self.cached_session():
default_val = -1 default_val = -1
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
lookup_ops.StaticHashTable( self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(["a"], [1], [dtypes.string], lookup_ops.KeyValueTensorInitializer(["a"], [1], [dtypes.string],
dtypes.int64), default_val) dtypes.int64), default_val)
@test_util.run_deprecated_v1 @test_util.run_v1_only
def testNotInitialized(self): def testNotInitialized(self):
with self.cached_session(): with self.cached_session():
default_val = -1 default_val = -1
table = lookup_ops.StaticHashTable( table = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(["a"], [1], lookup_ops.KeyValueTensorInitializer(["a"], [1],
value_dtype=dtypes.int64), value_dtype=dtypes.int64),
default_val) default_val)
@ -234,31 +244,32 @@ class StaticHashTableTest(test.TestCase):
with self.assertRaisesOpError("Table not initialized"): with self.assertRaisesOpError("Table not initialized"):
self.evaluate(output) self.evaluate(output)
@test_util.run_deprecated_v1 @test_util.run_v1_only
def testInitializeTwice(self): def testInitializeTwice(self):
with self.cached_session(): with self.cached_session():
default_val = -1 default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"]) keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.StaticHashTable( table = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val) lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
self.evaluate(table.initializer) self.initialize_table(table)
with self.assertRaisesOpError("Table already initialized"): with self.assertRaisesOpError("Table already initialized"):
self.evaluate(table.initializer) self.initialize_table(table)
@test_util.run_deprecated_v1
def testInitializationWithInvalidDimensions(self): def testInitializationWithInvalidDimensions(self):
with self.cached_session():
default_val = -1 default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"]) keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64) values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
with self.assertRaises(ValueError): raised_error = ValueError
lookup_ops.StaticHashTable( if context.executing_eagerly():
raised_error = errors_impl.InvalidArgumentError
with self.assertRaises(raised_error):
self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val) lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
@test_util.run_deprecated_v1 @test_util.run_v1_only
def testMultipleSessions(self): def testMultipleSessions(self):
# Start a server # Start a server
server = server_lib.Server({"local0": ["localhost:0"]}, server = server_lib.Server({"local0": ["localhost:0"]},
@ -271,14 +282,14 @@ class StaticHashTableTest(test.TestCase):
default_val = -1 default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"]) keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.StaticHashTable( table = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), lookup_ops.KeyValueTensorInitializer(keys, values),
default_val, default_val,
name="t1") name="t1")
# Init the table in the first session. # Init the table in the first session.
with session1: with session1:
self.evaluate(table.initializer) self.initialize_table(table)
self.assertAllEqual(3, self.evaluate(table.size())) self.assertAllEqual(3, self.evaluate(table.size()))
# Init the table in the second session and verify that we do not get a # Init the table in the second session and verify that we do not get a
@ -288,13 +299,12 @@ class StaticHashTableTest(test.TestCase):
self.assertAllEqual(3, self.evaluate(table.size())) self.assertAllEqual(3, self.evaluate(table.size()))
def testStaticHashTableInt32String(self): def testStaticHashTableInt32String(self):
with self.cached_session():
default_val = "n/a" default_val = "n/a"
keys = constant_op.constant([0, 1, 2], dtypes.int32) keys = constant_op.constant([0, 1, 2], dtypes.int32)
values = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant(["brain", "salad", "surgery"])
table = lookup_ops.StaticHashTable( table = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val) lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
self.evaluate(table.initializer) self.initialize_table(table)
input_tensor = constant_op.constant([0, 1, -1]) input_tensor = constant_op.constant([0, 1, -1])
output = table.lookup(input_tensor) output = table.lookup(input_tensor)
@ -569,53 +579,45 @@ class IndexTableFromFile(test.TestCase):
self.assertIsNotNone(table.resource_handle) self.assertIsNotNone(table.resource_handle)
class KeyValueTensorInitializerTest(test.TestCase): class KeyValueTensorInitializerTest(BaseLookupTableTest):
def test_string(self): def test_string(self):
with ops.Graph().as_default(), self.cached_session():
init = lookup_ops.KeyValueTensorInitializer( init = lookup_ops.KeyValueTensorInitializer(
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64) ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
self.assertEqual("", init._shared_name) table = self.getHashTable()(init, default_value=-1)
table = lookup_ops.StaticHashTable(init, default_value=-1) self.initialize_table(table)
table.initializer.run()
def test_multiple_tables(self): def test_multiple_tables(self):
with ops.Graph().as_default(), self.cached_session():
with ops.name_scope("table_scope"): with ops.name_scope("table_scope"):
init1 = lookup_ops.KeyValueTensorInitializer( init1 = lookup_ops.KeyValueTensorInitializer(
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string,
dtypes.int64) dtypes.int64)
self.assertEqual("", init1._shared_name) table1 = self.getHashTable()(init1, default_value=-1)
table1 = lookup_ops.StaticHashTable(init1, default_value=-1) if not context.executing_eagerly():
self.assertEqual("hash_table", table1.name) self.assertEqual("hash_table", table1.name)
self.assertEqual("table_scope/hash_table", self.assertEqual("table_scope/hash_table",
table1.resource_handle.op.name) table1.resource_handle.op.name)
init2 = lookup_ops.KeyValueTensorInitializer( init2 = lookup_ops.KeyValueTensorInitializer(
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string,
dtypes.int64) dtypes.int64)
self.assertEqual("", init2._shared_name) table2 = self.getHashTable()(init2, default_value=-1)
table2 = lookup_ops.StaticHashTable(init2, default_value=-1) if not context.executing_eagerly():
self.assertEqual("hash_table_1", table2.name) self.assertEqual("hash_table_1", table2.name)
self.assertEqual("table_scope/hash_table_1", self.assertEqual("table_scope/hash_table_1",
table2.resource_handle.op.name) table2.resource_handle.op.name)
def test_int64(self): def test_int64(self):
with ops.Graph().as_default(), self.cached_session():
init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
dtypes.int64, dtypes.int64) dtypes.int64, dtypes.int64)
self.assertEqual("", init._shared_name) table = self.getHashTable()(init, default_value=-1)
table = lookup_ops.StaticHashTable(init, default_value=-1) self.initialize_table(table)
table.initializer.run()
def test_int32(self): def test_int32(self):
with ops.Graph().as_default(), self.cached_session():
init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
dtypes.int32, dtypes.int64) dtypes.int32, dtypes.int64)
self.assertEqual("", init._shared_name) with self.assertRaises(errors_impl.OpError):
table = lookup_ops.StaticHashTable(init, default_value=-1) table = self.getHashTable()(init, default_value=-1)
with self.assertRaisesRegexp(errors_impl.OpError, self.initialize_table(table)
"No OpKernel was registered"):
table.initializer.run()
class IndexTableFromTensor(test.TestCase): class IndexTableFromTensor(test.TestCase):
@ -866,7 +868,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
self.evaluate(features)) self.evaluate(features))
class InitializeTableFromFileOpTest(test.TestCase): class InitializeTableFromFileOpTest(BaseLookupTableTest):
def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
vocabulary_file = os.path.join(self.get_temp_dir(), basename) vocabulary_file = os.path.join(self.get_temp_dir(), basename)
@ -874,7 +876,6 @@ class InitializeTableFromFileOpTest(test.TestCase):
f.write("\n".join(values) + "\n") f.write("\n".join(values) + "\n")
return vocabulary_file return vocabulary_file
@test_util.run_in_graph_and_eager_modes
def testInitializeStringTable(self): def testInitializeStringTable(self):
vocabulary_file = self._createVocabFile("one_column_1.txt") vocabulary_file = self._createVocabFile("one_column_1.txt")
default_value = -1 default_value = -1
@ -882,8 +883,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER) dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
self.assertTrue("one_column_1.txt_-2_-1" in init._shared_name) self.assertTrue("one_column_1.txt_-2_-1" in init._shared_name)
table = lookup_ops.StaticHashTable(init, default_value) table = self.getHashTable()(init, default_value)
self.evaluate(table.initializer) self.initialize_table(table)
output = table.lookup(constant_op.constant(["brain", "salad", "tank"])) output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
@ -900,8 +901,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
vocabulary_file, dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE, vocabulary_file, dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE,
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER) dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
self.assertTrue("one_column_int64.txt_-2_-1" in init._shared_name) self.assertTrue("one_column_int64.txt_-2_-1" in init._shared_name)
table = lookup_ops.StaticHashTable(init, default_value) table = self.getHashTable()(init, default_value)
self.evaluate(table.initializer) self.initialize_table(table)
output = table.lookup( output = table.lookup(
constant_op.constant((42, 1, 11), dtype=dtypes.int64)) constant_op.constant((42, 1, 11), dtype=dtypes.int64))
@ -919,8 +920,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
init = lookup_ops.TextFileInitializer( init = lookup_ops.TextFileInitializer(
vocabulary_file, dtypes.int64, key_index, dtypes.string, value_index) vocabulary_file, dtypes.int64, key_index, dtypes.string, value_index)
self.assertTrue("one_column_2.txt_-1_-2" in init._shared_name) self.assertTrue("one_column_2.txt_-1_-2" in init._shared_name)
table = lookup_ops.StaticHashTable(init, default_value) table = self.getHashTable()(init, default_value)
self.evaluate(table.initializer) self.initialize_table(table)
input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64) input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
output = table.lookup(input_values) output = table.lookup(input_values)
@ -941,8 +942,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
init = lookup_ops.TextFileInitializer( init = lookup_ops.TextFileInitializer(
vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index) vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index)
self.assertTrue("three_columns.txt_1_2" in init._shared_name) self.assertTrue("three_columns.txt_1_2" in init._shared_name)
table = lookup_ops.StaticHashTable(init, default_value) table = self.getHashTable()(init, default_value)
self.evaluate(table.initializer) self.initialize_table(table)
input_string = constant_op.constant(["brain", "salad", "surgery"]) input_string = constant_op.constant(["brain", "salad", "surgery"])
output = table.lookup(input_string) output = table.lookup(input_string)
@ -963,8 +964,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index) vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index)
self.assertTrue("three_columns.txt_2_1" in init._shared_name) self.assertTrue("three_columns.txt_2_1" in init._shared_name)
with self.assertRaisesOpError("is not a valid"): with self.assertRaisesOpError("is not a valid"):
table = lookup_ops.StaticHashTable(init, default_value) table = self.getHashTable()(init, default_value)
self.evaluate(table.initializer) self.initialize_table(table)
def testInvalidDataType(self): def testInvalidDataType(self):
vocabulary_file = self._createVocabFile("one_column_3.txt") vocabulary_file = self._createVocabFile("one_column_3.txt")
@ -979,7 +980,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
key_index, dtypes.string, key_index, dtypes.string,
value_index) value_index)
self.assertTrue("one_column_3.txt_-2_-1" in init._shared_name) self.assertTrue("one_column_3.txt_-2_-1" in init._shared_name)
lookup_ops.StaticHashTable(init, default_value) self.getHashTable()(init, default_value)
def testInvalidIndex(self): def testInvalidIndex(self):
vocabulary_file = self._createVocabFile("one_column_4.txt") vocabulary_file = self._createVocabFile("one_column_4.txt")
@ -992,8 +993,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
self.assertTrue("one_column_4.txt_1_-1" in init._shared_name) self.assertTrue("one_column_4.txt_1_-1" in init._shared_name)
with self.assertRaisesOpError("Invalid number of columns"): with self.assertRaisesOpError("Invalid number of columns"):
table = lookup_ops.StaticHashTable(init, default_value) table = self.getHashTable()(init, default_value)
self.evaluate(table.initializer) self.initialize_table(table)
def testInitializeSameTableWithMultipleNodes(self): def testInitializeSameTableWithMultipleNodes(self):
vocabulary_file = self._createVocabFile("one_column_5.txt") vocabulary_file = self._createVocabFile("one_column_5.txt")
@ -1004,17 +1005,17 @@ class InitializeTableFromFileOpTest(test.TestCase):
vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER) dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
self.assertTrue("one_column_5.txt_-2_-1" in init1._shared_name) self.assertTrue("one_column_5.txt_-2_-1" in init1._shared_name)
table1 = lookup_ops.StaticHashTable(init1, default_value) table1 = self.getHashTable()(init1, default_value)
init2 = lookup_ops.TextFileInitializer( init2 = lookup_ops.TextFileInitializer(
vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER) dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
self.assertTrue("one_column_5.txt_-2_-1" in init2._shared_name) self.assertTrue("one_column_5.txt_-2_-1" in init2._shared_name)
table2 = lookup_ops.StaticHashTable(init2, default_value) table2 = self.getHashTable()(init2, default_value)
init3 = lookup_ops.TextFileInitializer( init3 = lookup_ops.TextFileInitializer(
vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER) dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
self.assertTrue("one_column_5.txt_-2_-1" in init3._shared_name) self.assertTrue("one_column_5.txt_-2_-1" in init3._shared_name)
table3 = lookup_ops.StaticHashTable(init3, default_value) table3 = self.getHashTable()(init3, default_value)
self.evaluate(lookup_ops.tables_initializer()) self.evaluate(lookup_ops.tables_initializer())
@ -1033,7 +1034,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with self.cached_session(): with self.cached_session():
default_value = -1 default_value = -1
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
lookup_ops.StaticHashTable( self.getHashTable()(
lookup_ops.TextFileInitializer( lookup_ops.TextFileInitializer(
"", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, "", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
@ -1052,10 +1053,10 @@ class InitializeTableFromFileOpTest(test.TestCase):
lookup_ops.TextFileIndex.LINE_NUMBER, lookup_ops.TextFileIndex.LINE_NUMBER,
vocab_size=vocab_size) vocab_size=vocab_size)
self.assertTrue("one_column6.txt_3_-2_-1" in init1._shared_name) self.assertTrue("one_column6.txt_3_-2_-1" in init1._shared_name)
table1 = lookup_ops.StaticHashTable(init1, default_value) table1 = self.getHashTable()(init1, default_value)
# Initialize from file. # Initialize from file.
self.evaluate(table1.initializer) self.initialize_table(table1)
self.assertEqual(vocab_size, self.evaluate(table1.size())) self.assertEqual(vocab_size, self.evaluate(table1.size()))
vocabulary_file2 = self._createVocabFile("one_column7.txt") vocabulary_file2 = self._createVocabFile("one_column7.txt")
@ -1069,8 +1070,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
vocab_size=vocab_size) vocab_size=vocab_size)
self.assertTrue("one_column7.txt_5_-2_-1" in init2._shared_name) self.assertTrue("one_column7.txt_5_-2_-1" in init2._shared_name)
with self.assertRaisesOpError("Invalid vocab_size"): with self.assertRaisesOpError("Invalid vocab_size"):
table2 = lookup_ops.StaticHashTable(init2, default_value) table2 = self.getHashTable()(init2, default_value)
self.evaluate(table2.initializer) self.initialize_table(table2)
vocab_size = 1 vocab_size = 1
vocabulary_file3 = self._createVocabFile("one_column3.txt") vocabulary_file3 = self._createVocabFile("one_column3.txt")
@ -1082,10 +1083,10 @@ class InitializeTableFromFileOpTest(test.TestCase):
lookup_ops.TextFileIndex.LINE_NUMBER, lookup_ops.TextFileIndex.LINE_NUMBER,
vocab_size=vocab_size) vocab_size=vocab_size)
self.assertTrue("one_column3.txt_1_-2_-1" in init3._shared_name) self.assertTrue("one_column3.txt_1_-2_-1" in init3._shared_name)
table3 = lookup_ops.StaticHashTable(init3, default_value) table3 = self.getHashTable()(init3, default_value)
# Smaller vocab size reads only vocab_size records. # Smaller vocab size reads only vocab_size records.
self.evaluate(table3.initializer) self.initialize_table(table3)
self.assertEqual(vocab_size, self.evaluate(table3.size())) self.assertEqual(vocab_size, self.evaluate(table3.size()))
@test_util.run_v1_only("placeholder usage") @test_util.run_v1_only("placeholder usage")
@ -1098,7 +1099,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
"old_file.txt", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, "old_file.txt", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER) dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
self.assertTrue("old_file.txt_-2_-1" in init._shared_name) self.assertTrue("old_file.txt_-2_-1" in init._shared_name)
table = lookup_ops.StaticHashTable(init, default_value) table = self.getHashTable()(init, default_value)
# Initialize with non existing file (old_file.txt) should fail. # Initialize with non existing file (old_file.txt) should fail.
# TODO(yleon): Update message, which might change per FileSystem. # TODO(yleon): Update message, which might change per FileSystem.
@ -1124,7 +1125,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
# Invalid data type # Invalid data type
other_type = constant_op.constant(1) other_type = constant_op.constant(1)
with self.assertRaises(Exception) as cm: with self.assertRaises(Exception) as cm:
lookup_ops.StaticHashTable( self.getHashTable()(
lookup_ops.TextFileInitializer( lookup_ops.TextFileInitializer(
other_type, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, other_type, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
@ -1135,7 +1136,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
filenames = constant_op.constant([vocabulary_file, vocabulary_file]) filenames = constant_op.constant([vocabulary_file, vocabulary_file])
if not context.executing_eagerly(): if not context.executing_eagerly():
with self.assertRaises(Exception) as cm: with self.assertRaises(Exception) as cm:
lookup_ops.StaticHashTable( self.getHashTable()(
lookup_ops.TextFileInitializer( lookup_ops.TextFileInitializer(
filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
@ -1143,7 +1144,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
self.assertTrue(isinstance(cm.exception, (ValueError, TypeError))) self.assertTrue(isinstance(cm.exception, (ValueError, TypeError)))
else: else:
with self.assertRaises(errors_impl.InvalidArgumentError): with self.assertRaises(errors_impl.InvalidArgumentError):
lookup_ops.StaticHashTable( self.getHashTable()(
lookup_ops.TextFileInitializer( lookup_ops.TextFileInitializer(
filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
@ -1157,9 +1158,9 @@ class InitializeTableFromFileOpTest(test.TestCase):
init = lookup_ops.TextFileStringTableInitializer( init = lookup_ops.TextFileStringTableInitializer(
vocab_file, vocab_size=vocab_size) vocab_file, vocab_size=vocab_size)
self.assertTrue("feat_to_id_1.txt_3_-1_-2", init._shared_name) self.assertTrue("feat_to_id_1.txt_3_-1_-2", init._shared_name)
table = lookup_ops.StaticHashTable(init, default_value) table = self.getHashTable()(init, default_value)
self.evaluate(table.initializer) self.initialize_table(table)
input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64) input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
@ -1176,8 +1177,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
init = lookup_ops.TextFileIdTableInitializer( init = lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size) vocab_file, vocab_size=vocab_size)
self.assertTrue("feat_to_id_2.txt_3_-1_-2", init._shared_name) self.assertTrue("feat_to_id_2.txt_3_-1_-2", init._shared_name)
table = lookup_ops.StaticHashTable(init, default_value) table = self.getHashTable()(init, default_value)
self.evaluate(table.initializer) self.initialize_table(table)
input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
@ -1194,8 +1195,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
init = lookup_ops.TextFileIdTableInitializer( init = lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64) vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64)
self.assertTrue("feat_to_id_3.txt_3_-1_-2", init._shared_name) self.assertTrue("feat_to_id_3.txt_3_-1_-2", init._shared_name)
table = lookup_ops.StaticHashTable(init, default_value) table = self.getHashTable()(init, default_value)
self.evaluate(table.initializer) self.initialize_table(table)
out = table.lookup( out = table.lookup(
constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)) constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64))
@ -2885,8 +2886,7 @@ class DenseHashTableBenchmark(MutableHashTableBenchmark):
deleted_key=-2) deleted_key=-2)
@test_util.run_all_in_graph_and_eager_modes class StaticVocabularyTableTest(BaseLookupTableTest):
class StaticVocabularyTableTest(test.TestCase):
def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
vocabulary_file = os.path.join(self.get_temp_dir(), basename) vocabulary_file = os.path.join(self.get_temp_dir(), basename)
@ -2896,14 +2896,13 @@ class StaticVocabularyTableTest(test.TestCase):
def testStringStaticVocabularyTable(self): def testStringStaticVocabularyTable(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt") vocab_file = self._createVocabFile("feat_to_id_1.txt")
with self.cached_session():
vocab_size = 3 vocab_size = 3
oov_buckets = 1 oov_buckets = 1
table = lookup_ops.StaticVocabularyTable( table = self.getVocabularyTable()(
lookup_ops.TextFileIdTableInitializer( lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size), oov_buckets) vocab_file, vocab_size=vocab_size), oov_buckets)
self.evaluate(table.initializer) self.initialize_table(table)
input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
@ -2913,16 +2912,15 @@ class StaticVocabularyTableTest(test.TestCase):
def testInt32StaticVocabularyTable(self): def testInt32StaticVocabularyTable(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000")) vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
with self.cached_session():
vocab_size = 3 vocab_size = 3
oov_buckets = 1 oov_buckets = 1
table = lookup_ops.StaticVocabularyTable( table = self.getVocabularyTable()(
lookup_ops.TextFileIdTableInitializer( lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
oov_buckets, oov_buckets,
lookup_key_dtype=dtypes.int32) lookup_key_dtype=dtypes.int32)
self.evaluate(table.initializer) self.initialize_table(table)
values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32) values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32)
@ -2932,15 +2930,14 @@ class StaticVocabularyTableTest(test.TestCase):
def testInt64StaticVocabularyTable(self): def testInt64StaticVocabularyTable(self):
vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000")) vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
with self.cached_session():
vocab_size = 3 vocab_size = 3
oov_buckets = 1 oov_buckets = 1
table = lookup_ops.StaticVocabularyTable( table = self.getVocabularyTable()(
lookup_ops.TextFileIdTableInitializer( lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
oov_buckets) oov_buckets)
self.evaluate(table.initializer) self.initialize_table(table)
values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64) values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)
@ -2949,13 +2946,12 @@ class StaticVocabularyTableTest(test.TestCase):
self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size())) self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))
def testStringStaticVocabularyTableNoInitializer(self): def testStringStaticVocabularyTableNoInitializer(self):
with self.cached_session():
oov_buckets = 5 oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns # Set a table that only uses hash buckets, for each input value returns
# an id calculated by fingerprint("input") mod oov_buckets. # an id calculated by fingerprint("input") mod oov_buckets.
table = lookup_ops.StaticVocabularyTable(None, oov_buckets) table = self.getVocabularyTable()(None, oov_buckets)
self.evaluate(table.initializer) self.initialize_table(table)
values = constant_op.constant(("brain", "salad", "surgery")) values = constant_op.constant(("brain", "salad", "surgery"))
@ -2971,16 +2967,15 @@ class StaticVocabularyTableTest(test.TestCase):
def testStaticVocabularyTableWithMultipleInitializers(self): def testStaticVocabularyTableWithMultipleInitializers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt") vocab_file = self._createVocabFile("feat_to_id_4.txt")
with self.cached_session():
vocab_size = 3 vocab_size = 3
oov_buckets = 3 oov_buckets = 3
init = lookup_ops.TextFileIdTableInitializer( init = lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size) vocab_file, vocab_size=vocab_size)
table1 = lookup_ops.StaticVocabularyTable( table1 = self.getVocabularyTable()(
init, oov_buckets, name="table1") init, oov_buckets, name="table1")
table2 = lookup_ops.StaticVocabularyTable( table2 = self.getVocabularyTable()(
init, oov_buckets, name="table2") init, oov_buckets, name="table2")
self.evaluate(lookup_ops.tables_initializer()) self.evaluate(lookup_ops.tables_initializer())
@ -3002,11 +2997,11 @@ class StaticVocabularyTableTest(test.TestCase):
with self.cached_session(): with self.cached_session():
vocab_size = 3 vocab_size = 3
oov_buckets = 1 oov_buckets = 1
table1 = lookup_ops.StaticVocabularyTable( table1 = self.getVocabularyTable()(
lookup_ops.TextFileIdTableInitializer( lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size), oov_buckets) vocab_file, vocab_size=vocab_size), oov_buckets)
self.evaluate(table1.initializer) self.initialize_table(table1)
input_string_1 = constant_op.constant( input_string_1 = constant_op.constant(
["brain", "salad", "surgery", "UNK"]) ["brain", "salad", "surgery", "UNK"])
@ -3022,7 +3017,7 @@ class StaticVocabularyTableTest(test.TestCase):
# Underlying lookup table already initialized in previous session. # Underlying lookup table already initialized in previous session.
# No need to initialize table2 # No need to initialize table2
table2 = lookup_ops.StaticVocabularyTable( table2 = self.getVocabularyTable()(
lookup_ops.TextFileIdTableInitializer( lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size), oov_buckets) vocab_file, vocab_size=vocab_size), oov_buckets)
@ -3037,22 +3032,21 @@ class StaticVocabularyTableTest(test.TestCase):
vocab_file = self._createVocabFile("feat_to_id_7.txt") vocab_file = self._createVocabFile("feat_to_id_7.txt")
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4] input_shape = [4, 4]
with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor( sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64), constant_op.constant(input_indices, dtypes.int64),
constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"], constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
dtypes.string), dtypes.string),
constant_op.constant(input_shape, dtypes.int64)) constant_op.constant(input_shape, dtypes.int64))
table = lookup_ops.StaticVocabularyTable( table = self.getVocabularyTable()(
lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3), 1) lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3), 1)
self.evaluate(table.initializer) self.initialize_table(table)
sp_ids = table.lookup(sp_features) sp_ids = table.lookup(sp_features)
self.assertAllEqual([5], sp_ids.values._shape_as_list()) self.assertAllEqual([5], sp_ids.values._shape_as_list())
sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run( sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
[sp_ids.indices, sp_ids.values, sp_ids.dense_shape]) [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
self.assertAllEqual(input_indices, sp_ids_ind) self.assertAllEqual(input_indices, sp_ids_ind)
@ -3062,24 +3056,23 @@ class StaticVocabularyTableTest(test.TestCase):
def testInt32SparseTensor(self): def testInt32SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4] input_shape = [4, 4]
with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor( sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64), constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32), constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
constant_op.constant(input_shape, dtypes.int64)) constant_op.constant(input_shape, dtypes.int64))
table = lookup_ops.StaticVocabularyTable( table = self.getVocabularyTable()(
lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
dtypes.int64, dtypes.int64), dtypes.int64, dtypes.int64),
1, 1,
lookup_key_dtype=dtypes.int32) lookup_key_dtype=dtypes.int32)
self.evaluate(table.initializer) self.initialize_table(table)
sp_ids = table.lookup(sp_features) sp_ids = table.lookup(sp_features)
self.assertAllEqual([5], sp_ids.values._shape_as_list()) self.assertAllEqual([5], sp_ids.values._shape_as_list())
sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run( sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
[sp_ids.indices, sp_ids.values, sp_ids.dense_shape]) [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
self.assertAllEqual(input_indices, sp_ids_ind) self.assertAllEqual(input_indices, sp_ids_ind)
@ -3089,22 +3082,21 @@ class StaticVocabularyTableTest(test.TestCase):
def testInt64SparseTensor(self): def testInt64SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4] input_shape = [4, 4]
with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor( sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64), constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64), constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
constant_op.constant(input_shape, dtypes.int64)) constant_op.constant(input_shape, dtypes.int64))
table = lookup_ops.StaticVocabularyTable( table = self.getVocabularyTable()(
lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
dtypes.int64, dtypes.int64), 1) dtypes.int64, dtypes.int64), 1)
self.evaluate(table.initializer) self.initialize_table(table)
sp_ids = table.lookup(sp_features) sp_ids = table.lookup(sp_features)
self.assertAllEqual([5], sp_ids.values._shape_as_list()) self.assertAllEqual([5], sp_ids.values._shape_as_list())
sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run( sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
[sp_ids.indices, sp_ids.values, sp_ids.dense_shape]) [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
self.assertAllEqual(input_indices, sp_ids_ind) self.assertAllEqual(input_indices, sp_ids_ind)
@ -3112,8 +3104,7 @@ class StaticVocabularyTableTest(test.TestCase):
self.assertAllEqual(input_shape, sp_ids_shape) self.assertAllEqual(input_shape, sp_ids_shape)
def testStaticVocabularyTableNoInnerTable(self): def testStaticVocabularyTableNoInnerTable(self):
with self.cached_session(): table = self.getVocabularyTable()(None, num_oov_buckets=1)
table = lookup_ops.StaticVocabularyTable(None, num_oov_buckets=1)
self.assertIsNone(table.resource_handle) self.assertIsNone(table.resource_handle)

View File

@ -174,15 +174,6 @@ class InitializableLookupTableBase(LookupInterface):
def _initialize(self): def _initialize(self):
return self._initializer.initialize(self) return self._initializer.initialize(self)
@property
def initializer(self):
return self._init_op
@property
@deprecated("2018-12-15", "Use `initializer` instead.")
def init(self):
return self.initializer
@property @property
def default_value(self): def default_value(self):
"""The default value of the table.""" """The default value of the table."""
@ -237,6 +228,14 @@ class InitializableLookupTableBase(LookupInterface):
return values return values
class InitializableLookupTableBaseV1(InitializableLookupTableBase):
@property
def initializer(self):
return self._init_op
@tf_export("lookup.StaticHashTable", v1=[])
class StaticHashTable(InitializableLookupTableBase): class StaticHashTable(InitializableLookupTableBase):
"""A generic hash table implementation. """A generic hash table implementation.
@ -311,8 +310,20 @@ class StaticHashTable(InitializableLookupTableBase):
return exported_keys, exported_values return exported_keys, exported_values
@tf_export(v1=["lookup.StaticHashTable"])
class StaticHashTableV1(StaticHashTable):
@property
def initializer(self):
return self._init_op
# For backwards compatibility. This will be removed in TF 2.0. # For backwards compatibility. This will be removed in TF 2.0.
HashTable = StaticHashTable class HashTable(StaticHashTableV1):
@property
def init(self):
return self.initializer
class TableInitializerBase(trackable_base.Trackable): class TableInitializerBase(trackable_base.Trackable):
@ -354,6 +365,7 @@ class TableInitializerBase(trackable_base.Trackable):
return shared_name return shared_name
@tf_export("lookup.KeyValueTensorInitializer")
class KeyValueTensorInitializer(TableInitializerBase): class KeyValueTensorInitializer(TableInitializerBase):
"""Table initializers given `keys` and `values` tensors.""" """Table initializers given `keys` and `values` tensors."""
@ -412,6 +424,7 @@ class TextFileIndex(object):
LINE_NUMBER = -1 LINE_NUMBER = -1
@tf_export("lookup.TextFileInitializer")
class TextFileInitializer(TableInitializerBase): class TextFileInitializer(TableInitializerBase):
"""Table initializers from a text file. """Table initializers from a text file.
@ -951,6 +964,7 @@ class IdTableWithHashBuckets(LookupInterface):
return ids return ids
@tf_export("lookup.StaticVocabularyTable", v1=[])
class StaticVocabularyTable(LookupInterface): class StaticVocabularyTable(LookupInterface):
"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets. """String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
@ -1063,18 +1077,6 @@ class StaticVocabularyTable(LookupInterface):
with ops.name_scope(None, "init"): with ops.name_scope(None, "init"):
return control_flow_ops.no_op() return control_flow_ops.no_op()
@property
def initializer(self):
if self._table is not None:
return self._table._init_op # pylint: disable=protected-access
with ops.name_scope(None, "init"):
return control_flow_ops.no_op()
@property
@deprecated("2018-12-15", "Use `initializer` instead.")
def init(self):
return self.initializer
@property @property
def resource_handle(self): def resource_handle(self):
if self._table is not None: if self._table is not None:
@ -1136,6 +1138,17 @@ class StaticVocabularyTable(LookupInterface):
return ids return ids
@tf_export(v1=["lookup.StaticVocabularyTable"])
class StaticVocabularyTableV1(StaticVocabularyTable):
@property
def initializer(self):
if self._table is not None:
return self._table._init_op # pylint: disable=protected-access
with ops.name_scope(None, "init"):
return control_flow_ops.no_op()
def index_table_from_file(vocabulary_file=None, def index_table_from_file(vocabulary_file=None,
num_oov_buckets=0, num_oov_buckets=0,
vocab_size=None, vocab_size=None,
@ -1244,7 +1257,7 @@ def index_table_from_file(vocabulary_file=None,
value_column_index=value_column_index, value_column_index=value_column_index,
delimiter=delimiter) delimiter=delimiter)
table = StaticHashTable(init, default_value) table = StaticHashTableV1(init, default_value)
if num_oov_buckets: if num_oov_buckets:
table = IdTableWithHashBuckets( table = IdTableWithHashBuckets(
table, table,
@ -1341,7 +1354,7 @@ def index_table_from_tensor(vocabulary_list,
table_keys.dtype.base_dtype, table_keys.dtype.base_dtype,
dtypes.int64, dtypes.int64,
name="table_init") name="table_init")
table = StaticHashTable(init, default_value) table = StaticHashTableV1(init, default_value)
if num_oov_buckets: if num_oov_buckets:
table = IdTableWithHashBuckets( table = IdTableWithHashBuckets(
table, table,
@ -1438,7 +1451,7 @@ def index_to_string_table_from_file(vocabulary_file,
delimiter=delimiter) delimiter=delimiter)
# TODO(yleon): Use a more effienct structure. # TODO(yleon): Use a more effienct structure.
return StaticHashTable(init, default_value) return StaticHashTableV1(init, default_value)
def index_to_string_table_from_tensor(vocabulary_list, def index_to_string_table_from_tensor(vocabulary_list,
@ -1499,7 +1512,7 @@ def index_to_string_table_from_tensor(vocabulary_list,
init = KeyValueTensorInitializer( init = KeyValueTensorInitializer(
keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init") keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init")
# TODO(yleon): Use a more effienct structure. # TODO(yleon): Use a more effienct structure.
return StaticHashTable(init, default_value) return StaticHashTableV1(init, default_value)
class MutableHashTable(LookupInterface): class MutableHashTable(LookupInterface):
@ -1733,6 +1746,7 @@ class MutableHashTable(LookupInterface):
self.op.resource_handle, restored_tensors[0], restored_tensors[1]) self.op.resource_handle, restored_tensors[0], restored_tensors[1])
@tf_export("lookup.experimental.DenseHashTable")
class DenseHashTable(LookupInterface): class DenseHashTable(LookupInterface):
"""A generic mutable hash table implementation using tensors as backing store. """A generic mutable hash table implementation using tensors as backing store.

View File

@ -30,6 +30,8 @@ TENSORFLOW_API_INIT_FILES = [
"lite/constants/__init__.py", "lite/constants/__init__.py",
"lite/experimental/__init__.py", "lite/experimental/__init__.py",
"lite/experimental/nn/__init__.py", "lite/experimental/nn/__init__.py",
"lookup/__init__.py",
"lookup/experimental/__init__.py",
"math/__init__.py", "math/__init__.py",
"nest/__init__.py", "nest/__init__.py",
"nn/__init__.py", "nn/__init__.py",

View File

@ -37,6 +37,8 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"lite/experimental/__init__.py", "lite/experimental/__init__.py",
"lite/experimental/nn/__init__.py", "lite/experimental/nn/__init__.py",
"logging/__init__.py", "logging/__init__.py",
"lookup/__init__.py",
"lookup/experimental/__init__.py",
"losses/__init__.py", "losses/__init__.py",
"manip/__init__.py", "manip/__init__.py",
"math/__init__.py", "math/__init__.py",

View File

@ -0,0 +1,23 @@
path: "tensorflow.lookup.KeyValueTensorInitializer"
tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.KeyValueTensorInitializer\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TableInitializerBase\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "key_dtype"
mtype: "<type \'property\'>"
}
member {
name: "value_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'keys\', \'values\', \'key_dtype\', \'value_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "initialize"
argspec: "args=[\'self\', \'table\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,50 @@
path: "tensorflow.lookup.StaticHashTable"
tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticHashTableV1\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticHashTable\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.InitializableLookupTableBase\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "default_value"
mtype: "<type \'property\'>"
}
member {
name: "initializer"
mtype: "<type \'property\'>"
}
member {
name: "key_dtype"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "resource_handle"
mtype: "<type \'property\'>"
}
member {
name: "value_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'initializer\', \'default_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "export"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "lookup"
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "size"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -0,0 +1,41 @@
path: "tensorflow.lookup.StaticVocabularyTable"
tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticVocabularyTableV1\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticVocabularyTable\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "initializer"
mtype: "<type \'property\'>"
}
member {
name: "key_dtype"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "resource_handle"
mtype: "<type \'property\'>"
}
member {
name: "value_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'initializer\', \'num_oov_buckets\', \'lookup_key_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "lookup"
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "size"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -0,0 +1,23 @@
path: "tensorflow.lookup.TextFileInitializer"
tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TextFileInitializer\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TableInitializerBase\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "key_dtype"
mtype: "<type \'property\'>"
}
member {
name: "value_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], "
}
member_method {
name: "initialize"
argspec: "args=[\'self\', \'table\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,56 @@
path: "tensorflow.lookup.experimental.DenseHashTable"
tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.DenseHashTable\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "key_dtype"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "resource_handle"
mtype: "<type \'property\'>"
}
member {
name: "value_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'key_dtype\', \'value_dtype\', \'default_value\', \'empty_key\', \'deleted_key\', \'initial_num_buckets\', \'name\', \'checkpoint\'], varargs=None, keywords=None, defaults=[\'None\', \'MutableDenseHashTable\', \'True\'], "
}
member_method {
name: "erase"
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "export"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "insert"
argspec: "args=[\'self\', \'keys\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "insert_or_assign"
argspec: "args=[\'self\', \'keys\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "lookup"
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "remove"
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "size"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -0,0 +1,7 @@
path: "tensorflow.lookup.experimental"
tf_module {
member {
name: "DenseHashTable"
mtype: "<type \'type\'>"
}
}

View File

@ -0,0 +1,23 @@
path: "tensorflow.lookup"
tf_module {
member {
name: "KeyValueTensorInitializer"
mtype: "<type \'type\'>"
}
member {
name: "StaticHashTable"
mtype: "<type \'type\'>"
}
member {
name: "StaticVocabularyTable"
mtype: "<type \'type\'>"
}
member {
name: "TextFileInitializer"
mtype: "<type \'type\'>"
}
member {
name: "experimental"
mtype: "<type \'module\'>"
}
}

View File

@ -464,6 +464,10 @@ tf_module {
name: "logging" name: "logging"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"
} }
member {
name: "lookup"
mtype: "<type \'module\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"

View File

@ -0,0 +1,23 @@
path: "tensorflow.lookup.KeyValueTensorInitializer"
tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.KeyValueTensorInitializer\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TableInitializerBase\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "key_dtype"
mtype: "<type \'property\'>"
}
member {
name: "value_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'keys\', \'values\', \'key_dtype\', \'value_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "initialize"
argspec: "args=[\'self\', \'table\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,45 @@
path: "tensorflow.lookup.StaticHashTable"
tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticHashTable\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.InitializableLookupTableBase\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "default_value"
mtype: "<type \'property\'>"
}
member {
name: "key_dtype"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "resource_handle"
mtype: "<type \'property\'>"
}
member {
name: "value_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'initializer\', \'default_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "export"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "lookup"
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "size"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -0,0 +1,36 @@
path: "tensorflow.lookup.StaticVocabularyTable"
tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticVocabularyTable\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "key_dtype"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "resource_handle"
mtype: "<type \'property\'>"
}
member {
name: "value_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'initializer\', \'num_oov_buckets\', \'lookup_key_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "lookup"
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "size"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -0,0 +1,23 @@
path: "tensorflow.lookup.TextFileInitializer"
tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TextFileInitializer\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TableInitializerBase\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "key_dtype"
mtype: "<type \'property\'>"
}
member {
name: "value_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], "
}
member_method {
name: "initialize"
argspec: "args=[\'self\', \'table\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,56 @@
path: "tensorflow.lookup.experimental.DenseHashTable"
tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.DenseHashTable\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "key_dtype"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "resource_handle"
mtype: "<type \'property\'>"
}
member {
name: "value_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'key_dtype\', \'value_dtype\', \'default_value\', \'empty_key\', \'deleted_key\', \'initial_num_buckets\', \'name\', \'checkpoint\'], varargs=None, keywords=None, defaults=[\'None\', \'MutableDenseHashTable\', \'True\'], "
}
member_method {
name: "erase"
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "export"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "insert"
argspec: "args=[\'self\', \'keys\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "insert_or_assign"
argspec: "args=[\'self\', \'keys\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "lookup"
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "remove"
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "size"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -0,0 +1,7 @@
path: "tensorflow.lookup.experimental"
tf_module {
member {
name: "DenseHashTable"
mtype: "<type \'type\'>"
}
}

View File

@ -0,0 +1,23 @@
path: "tensorflow.lookup"
tf_module {
member {
name: "KeyValueTensorInitializer"
mtype: "<type \'type\'>"
}
member {
name: "StaticHashTable"
mtype: "<type \'type\'>"
}
member {
name: "StaticVocabularyTable"
mtype: "<type \'type\'>"
}
member {
name: "TextFileInitializer"
mtype: "<type \'type\'>"
}
member {
name: "experimental"
mtype: "<type \'module\'>"
}
}

View File

@ -212,6 +212,10 @@ tf_module {
name: "lite" name: "lite"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"
} }
member {
name: "lookup"
mtype: "<type \'module\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"