Adding tf_export bindings for the tf.lookup API. Adds tf.lookup.StaticHashTable, tf.lookup.VocabularyTable, tf.lookup.experimental.DenseHashTable, tf.lookup.TextFileInitializer and tf.lookup.KeyValueTensorInitializer
PiperOrigin-RevId: 235990023
This commit is contained in:
parent
104098d46b
commit
6adc4f81d0
@ -30,6 +30,7 @@ from tensorflow.python.ops.lookup_ops import IdTableWithHashBuckets
|
|||||||
from tensorflow.python.ops.lookup_ops import index_table_from_file
|
from tensorflow.python.ops.lookup_ops import index_table_from_file
|
||||||
from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file
|
from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file
|
||||||
from tensorflow.python.ops.lookup_ops import InitializableLookupTableBase
|
from tensorflow.python.ops.lookup_ops import InitializableLookupTableBase
|
||||||
|
from tensorflow.python.ops.lookup_ops import InitializableLookupTableBaseV1
|
||||||
from tensorflow.python.ops.lookup_ops import KeyValueTensorInitializer
|
from tensorflow.python.ops.lookup_ops import KeyValueTensorInitializer
|
||||||
from tensorflow.python.ops.lookup_ops import LookupInterface
|
from tensorflow.python.ops.lookup_ops import LookupInterface
|
||||||
from tensorflow.python.ops.lookup_ops import StrongHashSpec
|
from tensorflow.python.ops.lookup_ops import StrongHashSpec
|
||||||
@ -284,7 +285,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None):
|
|||||||
return table.lookup(tensor)
|
return table.lookup(tensor)
|
||||||
|
|
||||||
|
|
||||||
class HashTable(InitializableLookupTableBase):
|
class HashTable(InitializableLookupTableBaseV1):
|
||||||
"""A generic hash table implementation.
|
"""A generic hash table implementation.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
|
@ -22,6 +22,7 @@ import tempfile
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.data.experimental.ops import counter
|
from tensorflow.python.data.experimental.ops import counter
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
@ -41,16 +42,34 @@ from tensorflow.python.training import server_lib
|
|||||||
from tensorflow.python.training.tracking import util as trackable
|
from tensorflow.python.training.tracking import util as trackable
|
||||||
|
|
||||||
|
|
||||||
class StaticHashTableTest(test.TestCase):
|
class BaseLookupTableTest(test.TestCase):
|
||||||
|
|
||||||
|
def getHashTable(self):
|
||||||
|
if tf2.enabled():
|
||||||
|
return lookup_ops.StaticHashTable
|
||||||
|
else:
|
||||||
|
return lookup_ops.StaticHashTableV1
|
||||||
|
|
||||||
|
def getVocabularyTable(self):
|
||||||
|
if tf2.enabled():
|
||||||
|
return lookup_ops.StaticVocabularyTable
|
||||||
|
else:
|
||||||
|
return lookup_ops.StaticVocabularyTableV1
|
||||||
|
|
||||||
|
def initialize_table(self, table):
|
||||||
|
if not tf2.enabled():
|
||||||
|
self.evaluate(table.initializer)
|
||||||
|
|
||||||
|
|
||||||
|
class StaticHashTableTest(BaseLookupTableTest):
|
||||||
|
|
||||||
def testStaticHashTable(self):
|
def testStaticHashTable(self):
|
||||||
with self.cached_session():
|
|
||||||
default_val = -1
|
default_val = -1
|
||||||
keys = constant_op.constant(["brain", "salad", "surgery"])
|
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
||||||
table = lookup_ops.StaticHashTable(
|
table = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
self.assertAllEqual(3, self.evaluate(table.size()))
|
self.assertAllEqual(3, self.evaluate(table.size()))
|
||||||
|
|
||||||
@ -68,13 +87,12 @@ class StaticHashTableTest(test.TestCase):
|
|||||||
self.assertItemsEqual([0, 1, 2], self.evaluate(exported_values_tensor))
|
self.assertItemsEqual([0, 1, 2], self.evaluate(exported_values_tensor))
|
||||||
|
|
||||||
def testStaticHashTableFindHighRank(self):
|
def testStaticHashTableFindHighRank(self):
|
||||||
with self.cached_session():
|
|
||||||
default_val = -1
|
default_val = -1
|
||||||
keys = constant_op.constant(["brain", "salad", "surgery"])
|
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
||||||
table = lookup_ops.StaticHashTable(
|
table = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
self.assertAllEqual(3, self.evaluate(table.size()))
|
self.assertAllEqual(3, self.evaluate(table.size()))
|
||||||
|
|
||||||
@ -86,14 +104,13 @@ class StaticHashTableTest(test.TestCase):
|
|||||||
self.assertAllEqual([[0, 1], [-1, -1]], result)
|
self.assertAllEqual([[0, 1], [-1, -1]], result)
|
||||||
|
|
||||||
def testStaticHashTableInitWithPythonArrays(self):
|
def testStaticHashTableInitWithPythonArrays(self):
|
||||||
with self.cached_session():
|
|
||||||
default_val = -1
|
default_val = -1
|
||||||
keys = ["brain", "salad", "surgery"]
|
keys = ["brain", "salad", "surgery"]
|
||||||
values = [0, 1, 2]
|
values = [0, 1, 2]
|
||||||
table = lookup_ops.StaticHashTable(
|
table = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(
|
lookup_ops.KeyValueTensorInitializer(
|
||||||
keys, values, value_dtype=dtypes.int64), default_val)
|
keys, values, value_dtype=dtypes.int64), default_val)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
self.assertAllEqual(3, self.evaluate(table.size()))
|
self.assertAllEqual(3, self.evaluate(table.size()))
|
||||||
|
|
||||||
@ -104,13 +121,12 @@ class StaticHashTableTest(test.TestCase):
|
|||||||
self.assertAllEqual([0, 1, -1], result)
|
self.assertAllEqual([0, 1, -1], result)
|
||||||
|
|
||||||
def testStaticHashTableInitWithNumPyArrays(self):
|
def testStaticHashTableInitWithNumPyArrays(self):
|
||||||
with self.cached_session():
|
|
||||||
default_val = -1
|
default_val = -1
|
||||||
keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
|
keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
|
||||||
values = np.array([0, 1, 2], dtype=np.int64)
|
values = np.array([0, 1, 2], dtype=np.int64)
|
||||||
table = lookup_ops.StaticHashTable(
|
table = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
self.assertAllEqual(3, self.evaluate(table.size()))
|
self.assertAllEqual(3, self.evaluate(table.size()))
|
||||||
|
|
||||||
@ -121,22 +137,20 @@ class StaticHashTableTest(test.TestCase):
|
|||||||
self.assertAllEqual([0, 1, -1], result)
|
self.assertAllEqual([0, 1, -1], result)
|
||||||
|
|
||||||
def testMultipleStaticHashTables(self):
|
def testMultipleStaticHashTables(self):
|
||||||
with self.cached_session():
|
|
||||||
|
|
||||||
default_val = -1
|
default_val = -1
|
||||||
keys = constant_op.constant(["brain", "salad", "surgery"])
|
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
||||||
|
|
||||||
table1 = lookup_ops.StaticHashTable(
|
table1 = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
table2 = lookup_ops.StaticHashTable(
|
table2 = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
table3 = lookup_ops.StaticHashTable(
|
table3 = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
|
|
||||||
self.evaluate(table1.initializer)
|
self.initialize_table(table1)
|
||||||
self.evaluate(table2.initializer)
|
self.initialize_table(table2)
|
||||||
self.evaluate(table3.initializer)
|
self.initialize_table(table3)
|
||||||
self.assertAllEqual(3, self.evaluate(table1.size()))
|
self.assertAllEqual(3, self.evaluate(table1.size()))
|
||||||
self.assertAllEqual(3, self.evaluate(table2.size()))
|
self.assertAllEqual(3, self.evaluate(table2.size()))
|
||||||
self.assertAllEqual(3, self.evaluate(table3.size()))
|
self.assertAllEqual(3, self.evaluate(table3.size()))
|
||||||
@ -152,13 +166,12 @@ class StaticHashTableTest(test.TestCase):
|
|||||||
self.assertAllEqual([0, 1, -1], out3)
|
self.assertAllEqual([0, 1, -1], out3)
|
||||||
|
|
||||||
def testStaticHashTableWithTensorDefault(self):
|
def testStaticHashTableWithTensorDefault(self):
|
||||||
with self.cached_session():
|
|
||||||
default_val = constant_op.constant(-1, dtypes.int64)
|
default_val = constant_op.constant(-1, dtypes.int64)
|
||||||
keys = constant_op.constant(["brain", "salad", "surgery"])
|
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
||||||
table = lookup_ops.StaticHashTable(
|
table = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
input_string = constant_op.constant(["brain", "salad", "tank"])
|
input_string = constant_op.constant(["brain", "salad", "tank"])
|
||||||
output = table.lookup(input_string)
|
output = table.lookup(input_string)
|
||||||
@ -167,13 +180,12 @@ class StaticHashTableTest(test.TestCase):
|
|||||||
self.assertAllEqual([0, 1, -1], result)
|
self.assertAllEqual([0, 1, -1], result)
|
||||||
|
|
||||||
def testStaticHashTableWithSparseTensorInput(self):
|
def testStaticHashTableWithSparseTensorInput(self):
|
||||||
with self.cached_session():
|
|
||||||
default_val = constant_op.constant(-1, dtypes.int64)
|
default_val = constant_op.constant(-1, dtypes.int64)
|
||||||
keys = constant_op.constant(["brain", "salad", "surgery"])
|
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
||||||
table = lookup_ops.StaticHashTable(
|
table = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
sp_indices = [[0, 0], [0, 1], [1, 0]]
|
sp_indices = [[0, 0], [0, 1], [1, 0]]
|
||||||
sp_shape = [2, 2]
|
sp_shape = [2, 2]
|
||||||
@ -190,13 +202,12 @@ class StaticHashTableTest(test.TestCase):
|
|||||||
self.assertAllEqual(sp_shape, out_shape)
|
self.assertAllEqual(sp_shape, out_shape)
|
||||||
|
|
||||||
def testSignatureMismatch(self):
|
def testSignatureMismatch(self):
|
||||||
with self.cached_session():
|
|
||||||
default_val = -1
|
default_val = -1
|
||||||
keys = constant_op.constant(["brain", "salad", "surgery"])
|
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
||||||
table = lookup_ops.StaticHashTable(
|
table = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
# Ref types do not produce a lookup signature mismatch.
|
# Ref types do not produce a lookup signature mismatch.
|
||||||
input_string_ref = variables.Variable("brain")
|
input_string_ref = variables.Variable("brain")
|
||||||
@ -208,22 +219,21 @@ class StaticHashTableTest(test.TestCase):
|
|||||||
table.lookup(input_string)
|
table.lookup(input_string)
|
||||||
|
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
lookup_ops.StaticHashTable(
|
self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), "UNK")
|
lookup_ops.KeyValueTensorInitializer(keys, values), "UNK")
|
||||||
|
|
||||||
def testDTypes(self):
|
def testDTypes(self):
|
||||||
with self.cached_session():
|
|
||||||
default_val = -1
|
default_val = -1
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
lookup_ops.StaticHashTable(
|
self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(["a"], [1], [dtypes.string],
|
lookup_ops.KeyValueTensorInitializer(["a"], [1], [dtypes.string],
|
||||||
dtypes.int64), default_val)
|
dtypes.int64), default_val)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_v1_only
|
||||||
def testNotInitialized(self):
|
def testNotInitialized(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
default_val = -1
|
default_val = -1
|
||||||
table = lookup_ops.StaticHashTable(
|
table = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(["a"], [1],
|
lookup_ops.KeyValueTensorInitializer(["a"], [1],
|
||||||
value_dtype=dtypes.int64),
|
value_dtype=dtypes.int64),
|
||||||
default_val)
|
default_val)
|
||||||
@ -234,31 +244,32 @@ class StaticHashTableTest(test.TestCase):
|
|||||||
with self.assertRaisesOpError("Table not initialized"):
|
with self.assertRaisesOpError("Table not initialized"):
|
||||||
self.evaluate(output)
|
self.evaluate(output)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_v1_only
|
||||||
def testInitializeTwice(self):
|
def testInitializeTwice(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
default_val = -1
|
default_val = -1
|
||||||
keys = constant_op.constant(["brain", "salad", "surgery"])
|
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
||||||
table = lookup_ops.StaticHashTable(
|
table = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
with self.assertRaisesOpError("Table already initialized"):
|
with self.assertRaisesOpError("Table already initialized"):
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testInitializationWithInvalidDimensions(self):
|
def testInitializationWithInvalidDimensions(self):
|
||||||
with self.cached_session():
|
|
||||||
default_val = -1
|
default_val = -1
|
||||||
keys = constant_op.constant(["brain", "salad", "surgery"])
|
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
|
values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
raised_error = ValueError
|
||||||
lookup_ops.StaticHashTable(
|
if context.executing_eagerly():
|
||||||
|
raised_error = errors_impl.InvalidArgumentError
|
||||||
|
with self.assertRaises(raised_error):
|
||||||
|
self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_v1_only
|
||||||
def testMultipleSessions(self):
|
def testMultipleSessions(self):
|
||||||
# Start a server
|
# Start a server
|
||||||
server = server_lib.Server({"local0": ["localhost:0"]},
|
server = server_lib.Server({"local0": ["localhost:0"]},
|
||||||
@ -271,14 +282,14 @@ class StaticHashTableTest(test.TestCase):
|
|||||||
default_val = -1
|
default_val = -1
|
||||||
keys = constant_op.constant(["brain", "salad", "surgery"])
|
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
||||||
table = lookup_ops.StaticHashTable(
|
table = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values),
|
lookup_ops.KeyValueTensorInitializer(keys, values),
|
||||||
default_val,
|
default_val,
|
||||||
name="t1")
|
name="t1")
|
||||||
|
|
||||||
# Init the table in the first session.
|
# Init the table in the first session.
|
||||||
with session1:
|
with session1:
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
self.assertAllEqual(3, self.evaluate(table.size()))
|
self.assertAllEqual(3, self.evaluate(table.size()))
|
||||||
|
|
||||||
# Init the table in the second session and verify that we do not get a
|
# Init the table in the second session and verify that we do not get a
|
||||||
@ -288,13 +299,12 @@ class StaticHashTableTest(test.TestCase):
|
|||||||
self.assertAllEqual(3, self.evaluate(table.size()))
|
self.assertAllEqual(3, self.evaluate(table.size()))
|
||||||
|
|
||||||
def testStaticHashTableInt32String(self):
|
def testStaticHashTableInt32String(self):
|
||||||
with self.cached_session():
|
|
||||||
default_val = "n/a"
|
default_val = "n/a"
|
||||||
keys = constant_op.constant([0, 1, 2], dtypes.int32)
|
keys = constant_op.constant([0, 1, 2], dtypes.int32)
|
||||||
values = constant_op.constant(["brain", "salad", "surgery"])
|
values = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
table = lookup_ops.StaticHashTable(
|
table = self.getHashTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
input_tensor = constant_op.constant([0, 1, -1])
|
input_tensor = constant_op.constant([0, 1, -1])
|
||||||
output = table.lookup(input_tensor)
|
output = table.lookup(input_tensor)
|
||||||
@ -569,53 +579,45 @@ class IndexTableFromFile(test.TestCase):
|
|||||||
self.assertIsNotNone(table.resource_handle)
|
self.assertIsNotNone(table.resource_handle)
|
||||||
|
|
||||||
|
|
||||||
class KeyValueTensorInitializerTest(test.TestCase):
|
class KeyValueTensorInitializerTest(BaseLookupTableTest):
|
||||||
|
|
||||||
def test_string(self):
|
def test_string(self):
|
||||||
with ops.Graph().as_default(), self.cached_session():
|
|
||||||
init = lookup_ops.KeyValueTensorInitializer(
|
init = lookup_ops.KeyValueTensorInitializer(
|
||||||
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
|
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
|
||||||
self.assertEqual("", init._shared_name)
|
table = self.getHashTable()(init, default_value=-1)
|
||||||
table = lookup_ops.StaticHashTable(init, default_value=-1)
|
self.initialize_table(table)
|
||||||
table.initializer.run()
|
|
||||||
|
|
||||||
def test_multiple_tables(self):
|
def test_multiple_tables(self):
|
||||||
with ops.Graph().as_default(), self.cached_session():
|
|
||||||
with ops.name_scope("table_scope"):
|
with ops.name_scope("table_scope"):
|
||||||
init1 = lookup_ops.KeyValueTensorInitializer(
|
init1 = lookup_ops.KeyValueTensorInitializer(
|
||||||
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string,
|
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string,
|
||||||
dtypes.int64)
|
dtypes.int64)
|
||||||
self.assertEqual("", init1._shared_name)
|
table1 = self.getHashTable()(init1, default_value=-1)
|
||||||
table1 = lookup_ops.StaticHashTable(init1, default_value=-1)
|
if not context.executing_eagerly():
|
||||||
self.assertEqual("hash_table", table1.name)
|
self.assertEqual("hash_table", table1.name)
|
||||||
self.assertEqual("table_scope/hash_table",
|
self.assertEqual("table_scope/hash_table",
|
||||||
table1.resource_handle.op.name)
|
table1.resource_handle.op.name)
|
||||||
init2 = lookup_ops.KeyValueTensorInitializer(
|
init2 = lookup_ops.KeyValueTensorInitializer(
|
||||||
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string,
|
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string,
|
||||||
dtypes.int64)
|
dtypes.int64)
|
||||||
self.assertEqual("", init2._shared_name)
|
table2 = self.getHashTable()(init2, default_value=-1)
|
||||||
table2 = lookup_ops.StaticHashTable(init2, default_value=-1)
|
if not context.executing_eagerly():
|
||||||
self.assertEqual("hash_table_1", table2.name)
|
self.assertEqual("hash_table_1", table2.name)
|
||||||
self.assertEqual("table_scope/hash_table_1",
|
self.assertEqual("table_scope/hash_table_1",
|
||||||
table2.resource_handle.op.name)
|
table2.resource_handle.op.name)
|
||||||
|
|
||||||
def test_int64(self):
|
def test_int64(self):
|
||||||
with ops.Graph().as_default(), self.cached_session():
|
|
||||||
init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
|
init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
|
||||||
dtypes.int64, dtypes.int64)
|
dtypes.int64, dtypes.int64)
|
||||||
self.assertEqual("", init._shared_name)
|
table = self.getHashTable()(init, default_value=-1)
|
||||||
table = lookup_ops.StaticHashTable(init, default_value=-1)
|
self.initialize_table(table)
|
||||||
table.initializer.run()
|
|
||||||
|
|
||||||
def test_int32(self):
|
def test_int32(self):
|
||||||
with ops.Graph().as_default(), self.cached_session():
|
|
||||||
init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
|
init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
|
||||||
dtypes.int32, dtypes.int64)
|
dtypes.int32, dtypes.int64)
|
||||||
self.assertEqual("", init._shared_name)
|
with self.assertRaises(errors_impl.OpError):
|
||||||
table = lookup_ops.StaticHashTable(init, default_value=-1)
|
table = self.getHashTable()(init, default_value=-1)
|
||||||
with self.assertRaisesRegexp(errors_impl.OpError,
|
self.initialize_table(table)
|
||||||
"No OpKernel was registered"):
|
|
||||||
table.initializer.run()
|
|
||||||
|
|
||||||
|
|
||||||
class IndexTableFromTensor(test.TestCase):
|
class IndexTableFromTensor(test.TestCase):
|
||||||
@ -866,7 +868,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
|
|||||||
self.evaluate(features))
|
self.evaluate(features))
|
||||||
|
|
||||||
|
|
||||||
class InitializeTableFromFileOpTest(test.TestCase):
|
class InitializeTableFromFileOpTest(BaseLookupTableTest):
|
||||||
|
|
||||||
def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
|
def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
|
||||||
vocabulary_file = os.path.join(self.get_temp_dir(), basename)
|
vocabulary_file = os.path.join(self.get_temp_dir(), basename)
|
||||||
@ -874,7 +876,6 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
f.write("\n".join(values) + "\n")
|
f.write("\n".join(values) + "\n")
|
||||||
return vocabulary_file
|
return vocabulary_file
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def testInitializeStringTable(self):
|
def testInitializeStringTable(self):
|
||||||
vocabulary_file = self._createVocabFile("one_column_1.txt")
|
vocabulary_file = self._createVocabFile("one_column_1.txt")
|
||||||
default_value = -1
|
default_value = -1
|
||||||
@ -882,8 +883,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||||
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
|
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
|
||||||
self.assertTrue("one_column_1.txt_-2_-1" in init._shared_name)
|
self.assertTrue("one_column_1.txt_-2_-1" in init._shared_name)
|
||||||
table = lookup_ops.StaticHashTable(init, default_value)
|
table = self.getHashTable()(init, default_value)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
|
output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
|
||||||
|
|
||||||
@ -900,8 +901,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
vocabulary_file, dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE,
|
vocabulary_file, dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||||
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
|
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
|
||||||
self.assertTrue("one_column_int64.txt_-2_-1" in init._shared_name)
|
self.assertTrue("one_column_int64.txt_-2_-1" in init._shared_name)
|
||||||
table = lookup_ops.StaticHashTable(init, default_value)
|
table = self.getHashTable()(init, default_value)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
output = table.lookup(
|
output = table.lookup(
|
||||||
constant_op.constant((42, 1, 11), dtype=dtypes.int64))
|
constant_op.constant((42, 1, 11), dtype=dtypes.int64))
|
||||||
@ -919,8 +920,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
init = lookup_ops.TextFileInitializer(
|
init = lookup_ops.TextFileInitializer(
|
||||||
vocabulary_file, dtypes.int64, key_index, dtypes.string, value_index)
|
vocabulary_file, dtypes.int64, key_index, dtypes.string, value_index)
|
||||||
self.assertTrue("one_column_2.txt_-1_-2" in init._shared_name)
|
self.assertTrue("one_column_2.txt_-1_-2" in init._shared_name)
|
||||||
table = lookup_ops.StaticHashTable(init, default_value)
|
table = self.getHashTable()(init, default_value)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
|
input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
|
||||||
output = table.lookup(input_values)
|
output = table.lookup(input_values)
|
||||||
@ -941,8 +942,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
init = lookup_ops.TextFileInitializer(
|
init = lookup_ops.TextFileInitializer(
|
||||||
vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index)
|
vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index)
|
||||||
self.assertTrue("three_columns.txt_1_2" in init._shared_name)
|
self.assertTrue("three_columns.txt_1_2" in init._shared_name)
|
||||||
table = lookup_ops.StaticHashTable(init, default_value)
|
table = self.getHashTable()(init, default_value)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
input_string = constant_op.constant(["brain", "salad", "surgery"])
|
input_string = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
output = table.lookup(input_string)
|
output = table.lookup(input_string)
|
||||||
@ -963,8 +964,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index)
|
vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index)
|
||||||
self.assertTrue("three_columns.txt_2_1" in init._shared_name)
|
self.assertTrue("three_columns.txt_2_1" in init._shared_name)
|
||||||
with self.assertRaisesOpError("is not a valid"):
|
with self.assertRaisesOpError("is not a valid"):
|
||||||
table = lookup_ops.StaticHashTable(init, default_value)
|
table = self.getHashTable()(init, default_value)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
def testInvalidDataType(self):
|
def testInvalidDataType(self):
|
||||||
vocabulary_file = self._createVocabFile("one_column_3.txt")
|
vocabulary_file = self._createVocabFile("one_column_3.txt")
|
||||||
@ -979,7 +980,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
key_index, dtypes.string,
|
key_index, dtypes.string,
|
||||||
value_index)
|
value_index)
|
||||||
self.assertTrue("one_column_3.txt_-2_-1" in init._shared_name)
|
self.assertTrue("one_column_3.txt_-2_-1" in init._shared_name)
|
||||||
lookup_ops.StaticHashTable(init, default_value)
|
self.getHashTable()(init, default_value)
|
||||||
|
|
||||||
def testInvalidIndex(self):
|
def testInvalidIndex(self):
|
||||||
vocabulary_file = self._createVocabFile("one_column_4.txt")
|
vocabulary_file = self._createVocabFile("one_column_4.txt")
|
||||||
@ -992,8 +993,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
self.assertTrue("one_column_4.txt_1_-1" in init._shared_name)
|
self.assertTrue("one_column_4.txt_1_-1" in init._shared_name)
|
||||||
|
|
||||||
with self.assertRaisesOpError("Invalid number of columns"):
|
with self.assertRaisesOpError("Invalid number of columns"):
|
||||||
table = lookup_ops.StaticHashTable(init, default_value)
|
table = self.getHashTable()(init, default_value)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
def testInitializeSameTableWithMultipleNodes(self):
|
def testInitializeSameTableWithMultipleNodes(self):
|
||||||
vocabulary_file = self._createVocabFile("one_column_5.txt")
|
vocabulary_file = self._createVocabFile("one_column_5.txt")
|
||||||
@ -1004,17 +1005,17 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||||
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
|
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
|
||||||
self.assertTrue("one_column_5.txt_-2_-1" in init1._shared_name)
|
self.assertTrue("one_column_5.txt_-2_-1" in init1._shared_name)
|
||||||
table1 = lookup_ops.StaticHashTable(init1, default_value)
|
table1 = self.getHashTable()(init1, default_value)
|
||||||
init2 = lookup_ops.TextFileInitializer(
|
init2 = lookup_ops.TextFileInitializer(
|
||||||
vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||||
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
|
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
|
||||||
self.assertTrue("one_column_5.txt_-2_-1" in init2._shared_name)
|
self.assertTrue("one_column_5.txt_-2_-1" in init2._shared_name)
|
||||||
table2 = lookup_ops.StaticHashTable(init2, default_value)
|
table2 = self.getHashTable()(init2, default_value)
|
||||||
init3 = lookup_ops.TextFileInitializer(
|
init3 = lookup_ops.TextFileInitializer(
|
||||||
vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||||
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
|
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
|
||||||
self.assertTrue("one_column_5.txt_-2_-1" in init3._shared_name)
|
self.assertTrue("one_column_5.txt_-2_-1" in init3._shared_name)
|
||||||
table3 = lookup_ops.StaticHashTable(init3, default_value)
|
table3 = self.getHashTable()(init3, default_value)
|
||||||
|
|
||||||
self.evaluate(lookup_ops.tables_initializer())
|
self.evaluate(lookup_ops.tables_initializer())
|
||||||
|
|
||||||
@ -1033,7 +1034,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
default_value = -1
|
default_value = -1
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
lookup_ops.StaticHashTable(
|
self.getHashTable()(
|
||||||
lookup_ops.TextFileInitializer(
|
lookup_ops.TextFileInitializer(
|
||||||
"", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
"", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||||
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
|
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
|
||||||
@ -1052,10 +1053,10 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
lookup_ops.TextFileIndex.LINE_NUMBER,
|
lookup_ops.TextFileIndex.LINE_NUMBER,
|
||||||
vocab_size=vocab_size)
|
vocab_size=vocab_size)
|
||||||
self.assertTrue("one_column6.txt_3_-2_-1" in init1._shared_name)
|
self.assertTrue("one_column6.txt_3_-2_-1" in init1._shared_name)
|
||||||
table1 = lookup_ops.StaticHashTable(init1, default_value)
|
table1 = self.getHashTable()(init1, default_value)
|
||||||
|
|
||||||
# Initialize from file.
|
# Initialize from file.
|
||||||
self.evaluate(table1.initializer)
|
self.initialize_table(table1)
|
||||||
self.assertEqual(vocab_size, self.evaluate(table1.size()))
|
self.assertEqual(vocab_size, self.evaluate(table1.size()))
|
||||||
|
|
||||||
vocabulary_file2 = self._createVocabFile("one_column7.txt")
|
vocabulary_file2 = self._createVocabFile("one_column7.txt")
|
||||||
@ -1069,8 +1070,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
vocab_size=vocab_size)
|
vocab_size=vocab_size)
|
||||||
self.assertTrue("one_column7.txt_5_-2_-1" in init2._shared_name)
|
self.assertTrue("one_column7.txt_5_-2_-1" in init2._shared_name)
|
||||||
with self.assertRaisesOpError("Invalid vocab_size"):
|
with self.assertRaisesOpError("Invalid vocab_size"):
|
||||||
table2 = lookup_ops.StaticHashTable(init2, default_value)
|
table2 = self.getHashTable()(init2, default_value)
|
||||||
self.evaluate(table2.initializer)
|
self.initialize_table(table2)
|
||||||
|
|
||||||
vocab_size = 1
|
vocab_size = 1
|
||||||
vocabulary_file3 = self._createVocabFile("one_column3.txt")
|
vocabulary_file3 = self._createVocabFile("one_column3.txt")
|
||||||
@ -1082,10 +1083,10 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
lookup_ops.TextFileIndex.LINE_NUMBER,
|
lookup_ops.TextFileIndex.LINE_NUMBER,
|
||||||
vocab_size=vocab_size)
|
vocab_size=vocab_size)
|
||||||
self.assertTrue("one_column3.txt_1_-2_-1" in init3._shared_name)
|
self.assertTrue("one_column3.txt_1_-2_-1" in init3._shared_name)
|
||||||
table3 = lookup_ops.StaticHashTable(init3, default_value)
|
table3 = self.getHashTable()(init3, default_value)
|
||||||
|
|
||||||
# Smaller vocab size reads only vocab_size records.
|
# Smaller vocab size reads only vocab_size records.
|
||||||
self.evaluate(table3.initializer)
|
self.initialize_table(table3)
|
||||||
self.assertEqual(vocab_size, self.evaluate(table3.size()))
|
self.assertEqual(vocab_size, self.evaluate(table3.size()))
|
||||||
|
|
||||||
@test_util.run_v1_only("placeholder usage")
|
@test_util.run_v1_only("placeholder usage")
|
||||||
@ -1098,7 +1099,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
"old_file.txt", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
"old_file.txt", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||||
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
|
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
|
||||||
self.assertTrue("old_file.txt_-2_-1" in init._shared_name)
|
self.assertTrue("old_file.txt_-2_-1" in init._shared_name)
|
||||||
table = lookup_ops.StaticHashTable(init, default_value)
|
table = self.getHashTable()(init, default_value)
|
||||||
|
|
||||||
# Initialize with non existing file (old_file.txt) should fail.
|
# Initialize with non existing file (old_file.txt) should fail.
|
||||||
# TODO(yleon): Update message, which might change per FileSystem.
|
# TODO(yleon): Update message, which might change per FileSystem.
|
||||||
@ -1124,7 +1125,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
# Invalid data type
|
# Invalid data type
|
||||||
other_type = constant_op.constant(1)
|
other_type = constant_op.constant(1)
|
||||||
with self.assertRaises(Exception) as cm:
|
with self.assertRaises(Exception) as cm:
|
||||||
lookup_ops.StaticHashTable(
|
self.getHashTable()(
|
||||||
lookup_ops.TextFileInitializer(
|
lookup_ops.TextFileInitializer(
|
||||||
other_type, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
other_type, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||||
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
|
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
|
||||||
@ -1135,7 +1136,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
filenames = constant_op.constant([vocabulary_file, vocabulary_file])
|
filenames = constant_op.constant([vocabulary_file, vocabulary_file])
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
with self.assertRaises(Exception) as cm:
|
with self.assertRaises(Exception) as cm:
|
||||||
lookup_ops.StaticHashTable(
|
self.getHashTable()(
|
||||||
lookup_ops.TextFileInitializer(
|
lookup_ops.TextFileInitializer(
|
||||||
filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||||
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
|
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
|
||||||
@ -1143,7 +1144,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
self.assertTrue(isinstance(cm.exception, (ValueError, TypeError)))
|
self.assertTrue(isinstance(cm.exception, (ValueError, TypeError)))
|
||||||
else:
|
else:
|
||||||
with self.assertRaises(errors_impl.InvalidArgumentError):
|
with self.assertRaises(errors_impl.InvalidArgumentError):
|
||||||
lookup_ops.StaticHashTable(
|
self.getHashTable()(
|
||||||
lookup_ops.TextFileInitializer(
|
lookup_ops.TextFileInitializer(
|
||||||
filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||||
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
|
dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
|
||||||
@ -1157,9 +1158,9 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
init = lookup_ops.TextFileStringTableInitializer(
|
init = lookup_ops.TextFileStringTableInitializer(
|
||||||
vocab_file, vocab_size=vocab_size)
|
vocab_file, vocab_size=vocab_size)
|
||||||
self.assertTrue("feat_to_id_1.txt_3_-1_-2", init._shared_name)
|
self.assertTrue("feat_to_id_1.txt_3_-1_-2", init._shared_name)
|
||||||
table = lookup_ops.StaticHashTable(init, default_value)
|
table = self.getHashTable()(init, default_value)
|
||||||
|
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
|
input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
|
||||||
|
|
||||||
@ -1176,8 +1177,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
init = lookup_ops.TextFileIdTableInitializer(
|
init = lookup_ops.TextFileIdTableInitializer(
|
||||||
vocab_file, vocab_size=vocab_size)
|
vocab_file, vocab_size=vocab_size)
|
||||||
self.assertTrue("feat_to_id_2.txt_3_-1_-2", init._shared_name)
|
self.assertTrue("feat_to_id_2.txt_3_-1_-2", init._shared_name)
|
||||||
table = lookup_ops.StaticHashTable(init, default_value)
|
table = self.getHashTable()(init, default_value)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
|
input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
|
||||||
|
|
||||||
@ -1194,8 +1195,8 @@ class InitializeTableFromFileOpTest(test.TestCase):
|
|||||||
init = lookup_ops.TextFileIdTableInitializer(
|
init = lookup_ops.TextFileIdTableInitializer(
|
||||||
vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64)
|
vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64)
|
||||||
self.assertTrue("feat_to_id_3.txt_3_-1_-2", init._shared_name)
|
self.assertTrue("feat_to_id_3.txt_3_-1_-2", init._shared_name)
|
||||||
table = lookup_ops.StaticHashTable(init, default_value)
|
table = self.getHashTable()(init, default_value)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
out = table.lookup(
|
out = table.lookup(
|
||||||
constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64))
|
constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64))
|
||||||
@ -2885,8 +2886,7 @@ class DenseHashTableBenchmark(MutableHashTableBenchmark):
|
|||||||
deleted_key=-2)
|
deleted_key=-2)
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
class StaticVocabularyTableTest(BaseLookupTableTest):
|
||||||
class StaticVocabularyTableTest(test.TestCase):
|
|
||||||
|
|
||||||
def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
|
def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
|
||||||
vocabulary_file = os.path.join(self.get_temp_dir(), basename)
|
vocabulary_file = os.path.join(self.get_temp_dir(), basename)
|
||||||
@ -2896,14 +2896,13 @@ class StaticVocabularyTableTest(test.TestCase):
|
|||||||
|
|
||||||
def testStringStaticVocabularyTable(self):
|
def testStringStaticVocabularyTable(self):
|
||||||
vocab_file = self._createVocabFile("feat_to_id_1.txt")
|
vocab_file = self._createVocabFile("feat_to_id_1.txt")
|
||||||
with self.cached_session():
|
|
||||||
vocab_size = 3
|
vocab_size = 3
|
||||||
oov_buckets = 1
|
oov_buckets = 1
|
||||||
table = lookup_ops.StaticVocabularyTable(
|
table = self.getVocabularyTable()(
|
||||||
lookup_ops.TextFileIdTableInitializer(
|
lookup_ops.TextFileIdTableInitializer(
|
||||||
vocab_file, vocab_size=vocab_size), oov_buckets)
|
vocab_file, vocab_size=vocab_size), oov_buckets)
|
||||||
|
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
|
input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
|
||||||
|
|
||||||
@ -2913,16 +2912,15 @@ class StaticVocabularyTableTest(test.TestCase):
|
|||||||
|
|
||||||
def testInt32StaticVocabularyTable(self):
|
def testInt32StaticVocabularyTable(self):
|
||||||
vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
|
vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
|
||||||
with self.cached_session():
|
|
||||||
vocab_size = 3
|
vocab_size = 3
|
||||||
oov_buckets = 1
|
oov_buckets = 1
|
||||||
table = lookup_ops.StaticVocabularyTable(
|
table = self.getVocabularyTable()(
|
||||||
lookup_ops.TextFileIdTableInitializer(
|
lookup_ops.TextFileIdTableInitializer(
|
||||||
vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
|
vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
|
||||||
oov_buckets,
|
oov_buckets,
|
||||||
lookup_key_dtype=dtypes.int32)
|
lookup_key_dtype=dtypes.int32)
|
||||||
|
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32)
|
values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32)
|
||||||
|
|
||||||
@ -2932,15 +2930,14 @@ class StaticVocabularyTableTest(test.TestCase):
|
|||||||
|
|
||||||
def testInt64StaticVocabularyTable(self):
|
def testInt64StaticVocabularyTable(self):
|
||||||
vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
|
vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
|
||||||
with self.cached_session():
|
|
||||||
vocab_size = 3
|
vocab_size = 3
|
||||||
oov_buckets = 1
|
oov_buckets = 1
|
||||||
table = lookup_ops.StaticVocabularyTable(
|
table = self.getVocabularyTable()(
|
||||||
lookup_ops.TextFileIdTableInitializer(
|
lookup_ops.TextFileIdTableInitializer(
|
||||||
vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
|
vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
|
||||||
oov_buckets)
|
oov_buckets)
|
||||||
|
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)
|
values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)
|
||||||
|
|
||||||
@ -2949,13 +2946,12 @@ class StaticVocabularyTableTest(test.TestCase):
|
|||||||
self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))
|
self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))
|
||||||
|
|
||||||
def testStringStaticVocabularyTableNoInitializer(self):
|
def testStringStaticVocabularyTableNoInitializer(self):
|
||||||
with self.cached_session():
|
|
||||||
oov_buckets = 5
|
oov_buckets = 5
|
||||||
|
|
||||||
# Set a table that only uses hash buckets, for each input value returns
|
# Set a table that only uses hash buckets, for each input value returns
|
||||||
# an id calculated by fingerprint("input") mod oov_buckets.
|
# an id calculated by fingerprint("input") mod oov_buckets.
|
||||||
table = lookup_ops.StaticVocabularyTable(None, oov_buckets)
|
table = self.getVocabularyTable()(None, oov_buckets)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
values = constant_op.constant(("brain", "salad", "surgery"))
|
values = constant_op.constant(("brain", "salad", "surgery"))
|
||||||
|
|
||||||
@ -2971,16 +2967,15 @@ class StaticVocabularyTableTest(test.TestCase):
|
|||||||
|
|
||||||
def testStaticVocabularyTableWithMultipleInitializers(self):
|
def testStaticVocabularyTableWithMultipleInitializers(self):
|
||||||
vocab_file = self._createVocabFile("feat_to_id_4.txt")
|
vocab_file = self._createVocabFile("feat_to_id_4.txt")
|
||||||
with self.cached_session():
|
|
||||||
vocab_size = 3
|
vocab_size = 3
|
||||||
oov_buckets = 3
|
oov_buckets = 3
|
||||||
|
|
||||||
init = lookup_ops.TextFileIdTableInitializer(
|
init = lookup_ops.TextFileIdTableInitializer(
|
||||||
vocab_file, vocab_size=vocab_size)
|
vocab_file, vocab_size=vocab_size)
|
||||||
table1 = lookup_ops.StaticVocabularyTable(
|
table1 = self.getVocabularyTable()(
|
||||||
init, oov_buckets, name="table1")
|
init, oov_buckets, name="table1")
|
||||||
|
|
||||||
table2 = lookup_ops.StaticVocabularyTable(
|
table2 = self.getVocabularyTable()(
|
||||||
init, oov_buckets, name="table2")
|
init, oov_buckets, name="table2")
|
||||||
|
|
||||||
self.evaluate(lookup_ops.tables_initializer())
|
self.evaluate(lookup_ops.tables_initializer())
|
||||||
@ -3002,11 +2997,11 @@ class StaticVocabularyTableTest(test.TestCase):
|
|||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
vocab_size = 3
|
vocab_size = 3
|
||||||
oov_buckets = 1
|
oov_buckets = 1
|
||||||
table1 = lookup_ops.StaticVocabularyTable(
|
table1 = self.getVocabularyTable()(
|
||||||
lookup_ops.TextFileIdTableInitializer(
|
lookup_ops.TextFileIdTableInitializer(
|
||||||
vocab_file, vocab_size=vocab_size), oov_buckets)
|
vocab_file, vocab_size=vocab_size), oov_buckets)
|
||||||
|
|
||||||
self.evaluate(table1.initializer)
|
self.initialize_table(table1)
|
||||||
|
|
||||||
input_string_1 = constant_op.constant(
|
input_string_1 = constant_op.constant(
|
||||||
["brain", "salad", "surgery", "UNK"])
|
["brain", "salad", "surgery", "UNK"])
|
||||||
@ -3022,7 +3017,7 @@ class StaticVocabularyTableTest(test.TestCase):
|
|||||||
|
|
||||||
# Underlying lookup table already initialized in previous session.
|
# Underlying lookup table already initialized in previous session.
|
||||||
# No need to initialize table2
|
# No need to initialize table2
|
||||||
table2 = lookup_ops.StaticVocabularyTable(
|
table2 = self.getVocabularyTable()(
|
||||||
lookup_ops.TextFileIdTableInitializer(
|
lookup_ops.TextFileIdTableInitializer(
|
||||||
vocab_file, vocab_size=vocab_size), oov_buckets)
|
vocab_file, vocab_size=vocab_size), oov_buckets)
|
||||||
|
|
||||||
@ -3037,22 +3032,21 @@ class StaticVocabularyTableTest(test.TestCase):
|
|||||||
vocab_file = self._createVocabFile("feat_to_id_7.txt")
|
vocab_file = self._createVocabFile("feat_to_id_7.txt")
|
||||||
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
|
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
|
||||||
input_shape = [4, 4]
|
input_shape = [4, 4]
|
||||||
with self.cached_session() as sess:
|
|
||||||
sp_features = sparse_tensor.SparseTensor(
|
sp_features = sparse_tensor.SparseTensor(
|
||||||
constant_op.constant(input_indices, dtypes.int64),
|
constant_op.constant(input_indices, dtypes.int64),
|
||||||
constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
|
constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
|
||||||
dtypes.string),
|
dtypes.string),
|
||||||
constant_op.constant(input_shape, dtypes.int64))
|
constant_op.constant(input_shape, dtypes.int64))
|
||||||
|
|
||||||
table = lookup_ops.StaticVocabularyTable(
|
table = self.getVocabularyTable()(
|
||||||
lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3), 1)
|
lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3), 1)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
sp_ids = table.lookup(sp_features)
|
sp_ids = table.lookup(sp_features)
|
||||||
|
|
||||||
self.assertAllEqual([5], sp_ids.values._shape_as_list())
|
self.assertAllEqual([5], sp_ids.values._shape_as_list())
|
||||||
|
|
||||||
sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
|
sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
|
||||||
[sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
|
[sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
|
||||||
|
|
||||||
self.assertAllEqual(input_indices, sp_ids_ind)
|
self.assertAllEqual(input_indices, sp_ids_ind)
|
||||||
@ -3062,24 +3056,23 @@ class StaticVocabularyTableTest(test.TestCase):
|
|||||||
def testInt32SparseTensor(self):
|
def testInt32SparseTensor(self):
|
||||||
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
|
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
|
||||||
input_shape = [4, 4]
|
input_shape = [4, 4]
|
||||||
with self.cached_session() as sess:
|
|
||||||
sp_features = sparse_tensor.SparseTensor(
|
sp_features = sparse_tensor.SparseTensor(
|
||||||
constant_op.constant(input_indices, dtypes.int64),
|
constant_op.constant(input_indices, dtypes.int64),
|
||||||
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
|
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
|
||||||
constant_op.constant(input_shape, dtypes.int64))
|
constant_op.constant(input_shape, dtypes.int64))
|
||||||
|
|
||||||
table = lookup_ops.StaticVocabularyTable(
|
table = self.getVocabularyTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
|
lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
|
||||||
dtypes.int64, dtypes.int64),
|
dtypes.int64, dtypes.int64),
|
||||||
1,
|
1,
|
||||||
lookup_key_dtype=dtypes.int32)
|
lookup_key_dtype=dtypes.int32)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
sp_ids = table.lookup(sp_features)
|
sp_ids = table.lookup(sp_features)
|
||||||
|
|
||||||
self.assertAllEqual([5], sp_ids.values._shape_as_list())
|
self.assertAllEqual([5], sp_ids.values._shape_as_list())
|
||||||
|
|
||||||
sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
|
sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
|
||||||
[sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
|
[sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
|
||||||
|
|
||||||
self.assertAllEqual(input_indices, sp_ids_ind)
|
self.assertAllEqual(input_indices, sp_ids_ind)
|
||||||
@ -3089,22 +3082,21 @@ class StaticVocabularyTableTest(test.TestCase):
|
|||||||
def testInt64SparseTensor(self):
|
def testInt64SparseTensor(self):
|
||||||
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
|
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
|
||||||
input_shape = [4, 4]
|
input_shape = [4, 4]
|
||||||
with self.cached_session() as sess:
|
|
||||||
sp_features = sparse_tensor.SparseTensor(
|
sp_features = sparse_tensor.SparseTensor(
|
||||||
constant_op.constant(input_indices, dtypes.int64),
|
constant_op.constant(input_indices, dtypes.int64),
|
||||||
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
|
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
|
||||||
constant_op.constant(input_shape, dtypes.int64))
|
constant_op.constant(input_shape, dtypes.int64))
|
||||||
|
|
||||||
table = lookup_ops.StaticVocabularyTable(
|
table = self.getVocabularyTable()(
|
||||||
lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
|
lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
|
||||||
dtypes.int64, dtypes.int64), 1)
|
dtypes.int64, dtypes.int64), 1)
|
||||||
self.evaluate(table.initializer)
|
self.initialize_table(table)
|
||||||
|
|
||||||
sp_ids = table.lookup(sp_features)
|
sp_ids = table.lookup(sp_features)
|
||||||
|
|
||||||
self.assertAllEqual([5], sp_ids.values._shape_as_list())
|
self.assertAllEqual([5], sp_ids.values._shape_as_list())
|
||||||
|
|
||||||
sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
|
sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
|
||||||
[sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
|
[sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
|
||||||
|
|
||||||
self.assertAllEqual(input_indices, sp_ids_ind)
|
self.assertAllEqual(input_indices, sp_ids_ind)
|
||||||
@ -3112,8 +3104,7 @@ class StaticVocabularyTableTest(test.TestCase):
|
|||||||
self.assertAllEqual(input_shape, sp_ids_shape)
|
self.assertAllEqual(input_shape, sp_ids_shape)
|
||||||
|
|
||||||
def testStaticVocabularyTableNoInnerTable(self):
|
def testStaticVocabularyTableNoInnerTable(self):
|
||||||
with self.cached_session():
|
table = self.getVocabularyTable()(None, num_oov_buckets=1)
|
||||||
table = lookup_ops.StaticVocabularyTable(None, num_oov_buckets=1)
|
|
||||||
self.assertIsNone(table.resource_handle)
|
self.assertIsNone(table.resource_handle)
|
||||||
|
|
||||||
|
|
||||||
|
@ -174,15 +174,6 @@ class InitializableLookupTableBase(LookupInterface):
|
|||||||
def _initialize(self):
|
def _initialize(self):
|
||||||
return self._initializer.initialize(self)
|
return self._initializer.initialize(self)
|
||||||
|
|
||||||
@property
|
|
||||||
def initializer(self):
|
|
||||||
return self._init_op
|
|
||||||
|
|
||||||
@property
|
|
||||||
@deprecated("2018-12-15", "Use `initializer` instead.")
|
|
||||||
def init(self):
|
|
||||||
return self.initializer
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_value(self):
|
def default_value(self):
|
||||||
"""The default value of the table."""
|
"""The default value of the table."""
|
||||||
@ -237,6 +228,14 @@ class InitializableLookupTableBase(LookupInterface):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class InitializableLookupTableBaseV1(InitializableLookupTableBase):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def initializer(self):
|
||||||
|
return self._init_op
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("lookup.StaticHashTable", v1=[])
|
||||||
class StaticHashTable(InitializableLookupTableBase):
|
class StaticHashTable(InitializableLookupTableBase):
|
||||||
"""A generic hash table implementation.
|
"""A generic hash table implementation.
|
||||||
|
|
||||||
@ -311,8 +310,20 @@ class StaticHashTable(InitializableLookupTableBase):
|
|||||||
return exported_keys, exported_values
|
return exported_keys, exported_values
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export(v1=["lookup.StaticHashTable"])
|
||||||
|
class StaticHashTableV1(StaticHashTable):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def initializer(self):
|
||||||
|
return self._init_op
|
||||||
|
|
||||||
|
|
||||||
# For backwards compatibility. This will be removed in TF 2.0.
|
# For backwards compatibility. This will be removed in TF 2.0.
|
||||||
HashTable = StaticHashTable
|
class HashTable(StaticHashTableV1):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def init(self):
|
||||||
|
return self.initializer
|
||||||
|
|
||||||
|
|
||||||
class TableInitializerBase(trackable_base.Trackable):
|
class TableInitializerBase(trackable_base.Trackable):
|
||||||
@ -354,6 +365,7 @@ class TableInitializerBase(trackable_base.Trackable):
|
|||||||
return shared_name
|
return shared_name
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("lookup.KeyValueTensorInitializer")
|
||||||
class KeyValueTensorInitializer(TableInitializerBase):
|
class KeyValueTensorInitializer(TableInitializerBase):
|
||||||
"""Table initializers given `keys` and `values` tensors."""
|
"""Table initializers given `keys` and `values` tensors."""
|
||||||
|
|
||||||
@ -412,6 +424,7 @@ class TextFileIndex(object):
|
|||||||
LINE_NUMBER = -1
|
LINE_NUMBER = -1
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("lookup.TextFileInitializer")
|
||||||
class TextFileInitializer(TableInitializerBase):
|
class TextFileInitializer(TableInitializerBase):
|
||||||
"""Table initializers from a text file.
|
"""Table initializers from a text file.
|
||||||
|
|
||||||
@ -951,6 +964,7 @@ class IdTableWithHashBuckets(LookupInterface):
|
|||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("lookup.StaticVocabularyTable", v1=[])
|
||||||
class StaticVocabularyTable(LookupInterface):
|
class StaticVocabularyTable(LookupInterface):
|
||||||
"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
|
"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
|
||||||
|
|
||||||
@ -1063,18 +1077,6 @@ class StaticVocabularyTable(LookupInterface):
|
|||||||
with ops.name_scope(None, "init"):
|
with ops.name_scope(None, "init"):
|
||||||
return control_flow_ops.no_op()
|
return control_flow_ops.no_op()
|
||||||
|
|
||||||
@property
|
|
||||||
def initializer(self):
|
|
||||||
if self._table is not None:
|
|
||||||
return self._table._init_op # pylint: disable=protected-access
|
|
||||||
with ops.name_scope(None, "init"):
|
|
||||||
return control_flow_ops.no_op()
|
|
||||||
|
|
||||||
@property
|
|
||||||
@deprecated("2018-12-15", "Use `initializer` instead.")
|
|
||||||
def init(self):
|
|
||||||
return self.initializer
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def resource_handle(self):
|
def resource_handle(self):
|
||||||
if self._table is not None:
|
if self._table is not None:
|
||||||
@ -1136,6 +1138,17 @@ class StaticVocabularyTable(LookupInterface):
|
|||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export(v1=["lookup.StaticVocabularyTable"])
|
||||||
|
class StaticVocabularyTableV1(StaticVocabularyTable):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def initializer(self):
|
||||||
|
if self._table is not None:
|
||||||
|
return self._table._init_op # pylint: disable=protected-access
|
||||||
|
with ops.name_scope(None, "init"):
|
||||||
|
return control_flow_ops.no_op()
|
||||||
|
|
||||||
|
|
||||||
def index_table_from_file(vocabulary_file=None,
|
def index_table_from_file(vocabulary_file=None,
|
||||||
num_oov_buckets=0,
|
num_oov_buckets=0,
|
||||||
vocab_size=None,
|
vocab_size=None,
|
||||||
@ -1244,7 +1257,7 @@ def index_table_from_file(vocabulary_file=None,
|
|||||||
value_column_index=value_column_index,
|
value_column_index=value_column_index,
|
||||||
delimiter=delimiter)
|
delimiter=delimiter)
|
||||||
|
|
||||||
table = StaticHashTable(init, default_value)
|
table = StaticHashTableV1(init, default_value)
|
||||||
if num_oov_buckets:
|
if num_oov_buckets:
|
||||||
table = IdTableWithHashBuckets(
|
table = IdTableWithHashBuckets(
|
||||||
table,
|
table,
|
||||||
@ -1341,7 +1354,7 @@ def index_table_from_tensor(vocabulary_list,
|
|||||||
table_keys.dtype.base_dtype,
|
table_keys.dtype.base_dtype,
|
||||||
dtypes.int64,
|
dtypes.int64,
|
||||||
name="table_init")
|
name="table_init")
|
||||||
table = StaticHashTable(init, default_value)
|
table = StaticHashTableV1(init, default_value)
|
||||||
if num_oov_buckets:
|
if num_oov_buckets:
|
||||||
table = IdTableWithHashBuckets(
|
table = IdTableWithHashBuckets(
|
||||||
table,
|
table,
|
||||||
@ -1438,7 +1451,7 @@ def index_to_string_table_from_file(vocabulary_file,
|
|||||||
delimiter=delimiter)
|
delimiter=delimiter)
|
||||||
|
|
||||||
# TODO(yleon): Use a more effienct structure.
|
# TODO(yleon): Use a more effienct structure.
|
||||||
return StaticHashTable(init, default_value)
|
return StaticHashTableV1(init, default_value)
|
||||||
|
|
||||||
|
|
||||||
def index_to_string_table_from_tensor(vocabulary_list,
|
def index_to_string_table_from_tensor(vocabulary_list,
|
||||||
@ -1499,7 +1512,7 @@ def index_to_string_table_from_tensor(vocabulary_list,
|
|||||||
init = KeyValueTensorInitializer(
|
init = KeyValueTensorInitializer(
|
||||||
keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init")
|
keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init")
|
||||||
# TODO(yleon): Use a more effienct structure.
|
# TODO(yleon): Use a more effienct structure.
|
||||||
return StaticHashTable(init, default_value)
|
return StaticHashTableV1(init, default_value)
|
||||||
|
|
||||||
|
|
||||||
class MutableHashTable(LookupInterface):
|
class MutableHashTable(LookupInterface):
|
||||||
@ -1733,6 +1746,7 @@ class MutableHashTable(LookupInterface):
|
|||||||
self.op.resource_handle, restored_tensors[0], restored_tensors[1])
|
self.op.resource_handle, restored_tensors[0], restored_tensors[1])
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("lookup.experimental.DenseHashTable")
|
||||||
class DenseHashTable(LookupInterface):
|
class DenseHashTable(LookupInterface):
|
||||||
"""A generic mutable hash table implementation using tensors as backing store.
|
"""A generic mutable hash table implementation using tensors as backing store.
|
||||||
|
|
||||||
|
@ -30,6 +30,8 @@ TENSORFLOW_API_INIT_FILES = [
|
|||||||
"lite/constants/__init__.py",
|
"lite/constants/__init__.py",
|
||||||
"lite/experimental/__init__.py",
|
"lite/experimental/__init__.py",
|
||||||
"lite/experimental/nn/__init__.py",
|
"lite/experimental/nn/__init__.py",
|
||||||
|
"lookup/__init__.py",
|
||||||
|
"lookup/experimental/__init__.py",
|
||||||
"math/__init__.py",
|
"math/__init__.py",
|
||||||
"nest/__init__.py",
|
"nest/__init__.py",
|
||||||
"nn/__init__.py",
|
"nn/__init__.py",
|
||||||
|
@ -37,6 +37,8 @@ TENSORFLOW_API_INIT_FILES_V1 = [
|
|||||||
"lite/experimental/__init__.py",
|
"lite/experimental/__init__.py",
|
||||||
"lite/experimental/nn/__init__.py",
|
"lite/experimental/nn/__init__.py",
|
||||||
"logging/__init__.py",
|
"logging/__init__.py",
|
||||||
|
"lookup/__init__.py",
|
||||||
|
"lookup/experimental/__init__.py",
|
||||||
"losses/__init__.py",
|
"losses/__init__.py",
|
||||||
"manip/__init__.py",
|
"manip/__init__.py",
|
||||||
"math/__init__.py",
|
"math/__init__.py",
|
||||||
|
@ -0,0 +1,23 @@
|
|||||||
|
path: "tensorflow.lookup.KeyValueTensorInitializer"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.KeyValueTensorInitializer\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TableInitializerBase\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "key_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "value_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'values\', \'key_dtype\', \'value_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "initialize"
|
||||||
|
argspec: "args=[\'self\', \'table\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,50 @@
|
|||||||
|
path: "tensorflow.lookup.StaticHashTable"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticHashTableV1\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticHashTable\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.InitializableLookupTableBase\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "default_value"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "initializer"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "key_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "name"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "resource_handle"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "value_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'initializer\', \'default_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "export"
|
||||||
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "lookup"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "size"
|
||||||
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,41 @@
|
|||||||
|
path: "tensorflow.lookup.StaticVocabularyTable"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticVocabularyTableV1\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticVocabularyTable\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "initializer"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "key_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "name"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "resource_handle"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "value_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'initializer\', \'num_oov_buckets\', \'lookup_key_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "lookup"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "size"
|
||||||
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,23 @@
|
|||||||
|
path: "tensorflow.lookup.TextFileInitializer"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TextFileInitializer\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TableInitializerBase\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "key_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "value_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "initialize"
|
||||||
|
argspec: "args=[\'self\', \'table\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,56 @@
|
|||||||
|
path: "tensorflow.lookup.experimental.DenseHashTable"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.DenseHashTable\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "key_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "name"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "resource_handle"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "value_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'key_dtype\', \'value_dtype\', \'default_value\', \'empty_key\', \'deleted_key\', \'initial_num_buckets\', \'name\', \'checkpoint\'], varargs=None, keywords=None, defaults=[\'None\', \'MutableDenseHashTable\', \'True\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "erase"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "export"
|
||||||
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "insert"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "insert_or_assign"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "lookup"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "remove"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "size"
|
||||||
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,7 @@
|
|||||||
|
path: "tensorflow.lookup.experimental"
|
||||||
|
tf_module {
|
||||||
|
member {
|
||||||
|
name: "DenseHashTable"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
|
}
|
23
tensorflow/tools/api/golden/v1/tensorflow.lookup.pbtxt
Normal file
23
tensorflow/tools/api/golden/v1/tensorflow.lookup.pbtxt
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
path: "tensorflow.lookup"
|
||||||
|
tf_module {
|
||||||
|
member {
|
||||||
|
name: "KeyValueTensorInitializer"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "StaticHashTable"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "StaticVocabularyTable"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "TextFileInitializer"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "experimental"
|
||||||
|
mtype: "<type \'module\'>"
|
||||||
|
}
|
||||||
|
}
|
@ -464,6 +464,10 @@ tf_module {
|
|||||||
name: "logging"
|
name: "logging"
|
||||||
mtype: "<type \'module\'>"
|
mtype: "<type \'module\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "lookup"
|
||||||
|
mtype: "<type \'module\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "losses"
|
name: "losses"
|
||||||
mtype: "<type \'module\'>"
|
mtype: "<type \'module\'>"
|
||||||
|
@ -0,0 +1,23 @@
|
|||||||
|
path: "tensorflow.lookup.KeyValueTensorInitializer"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.KeyValueTensorInitializer\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TableInitializerBase\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "key_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "value_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'values\', \'key_dtype\', \'value_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "initialize"
|
||||||
|
argspec: "args=[\'self\', \'table\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,45 @@
|
|||||||
|
path: "tensorflow.lookup.StaticHashTable"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticHashTable\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.InitializableLookupTableBase\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "default_value"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "key_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "name"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "resource_handle"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "value_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'initializer\', \'default_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "export"
|
||||||
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "lookup"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "size"
|
||||||
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,36 @@
|
|||||||
|
path: "tensorflow.lookup.StaticVocabularyTable"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticVocabularyTable\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "key_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "name"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "resource_handle"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "value_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'initializer\', \'num_oov_buckets\', \'lookup_key_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "lookup"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "size"
|
||||||
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,23 @@
|
|||||||
|
path: "tensorflow.lookup.TextFileInitializer"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TextFileInitializer\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TableInitializerBase\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "key_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "value_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "initialize"
|
||||||
|
argspec: "args=[\'self\', \'table\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,56 @@
|
|||||||
|
path: "tensorflow.lookup.experimental.DenseHashTable"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.DenseHashTable\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "key_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "name"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "resource_handle"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "value_dtype"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'key_dtype\', \'value_dtype\', \'default_value\', \'empty_key\', \'deleted_key\', \'initial_num_buckets\', \'name\', \'checkpoint\'], varargs=None, keywords=None, defaults=[\'None\', \'MutableDenseHashTable\', \'True\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "erase"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "export"
|
||||||
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "insert"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "insert_or_assign"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "lookup"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "remove"
|
||||||
|
argspec: "args=[\'self\', \'keys\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "size"
|
||||||
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,7 @@
|
|||||||
|
path: "tensorflow.lookup.experimental"
|
||||||
|
tf_module {
|
||||||
|
member {
|
||||||
|
name: "DenseHashTable"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
|
}
|
23
tensorflow/tools/api/golden/v2/tensorflow.lookup.pbtxt
Normal file
23
tensorflow/tools/api/golden/v2/tensorflow.lookup.pbtxt
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
path: "tensorflow.lookup"
|
||||||
|
tf_module {
|
||||||
|
member {
|
||||||
|
name: "KeyValueTensorInitializer"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "StaticHashTable"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "StaticVocabularyTable"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "TextFileInitializer"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "experimental"
|
||||||
|
mtype: "<type \'module\'>"
|
||||||
|
}
|
||||||
|
}
|
@ -212,6 +212,10 @@ tf_module {
|
|||||||
name: "lite"
|
name: "lite"
|
||||||
mtype: "<type \'module\'>"
|
mtype: "<type \'module\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "lookup"
|
||||||
|
mtype: "<type \'module\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "losses"
|
name: "losses"
|
||||||
mtype: "<type \'module\'>"
|
mtype: "<type \'module\'>"
|
||||||
|
Loading…
Reference in New Issue
Block a user