# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lookup ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile

import numpy as np
import six

from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as reader_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable
from tensorflow.python.util import compat


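# These tests cover both API generations: BaseLookupTableTest below returns
# the TF2 symbols (StaticHashTable, StaticVocabularyTable) or their V1
# counterparts depending on tf2.enabled(), and only evaluates table
# initializers explicitly outside of TF2, where construction alone does not
# initialize the table.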
class BaseLookupTableTest(test.TestCase):

  def getHashTable(self):
    if tf2.enabled():
      return lookup_ops.StaticHashTable
    else:
      return lookup_ops.StaticHashTableV1

  def getVocabularyTable(self):
    if tf2.enabled():
      return lookup_ops.StaticVocabularyTable
    else:
      return lookup_ops.StaticVocabularyTableV1

  def initialize_table(self, table):
    if not tf2.enabled():
      self.evaluate(table.initializer)


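# The hash-table tests below share one flow: build an initializer, construct
# a table from it, initialize it (a no-op under TF2), then look keys up. A
# minimal sketch of that flow, assuming the equivalent public tf.lookup API:
#
#   table = tf.lookup.StaticHashTable(
#       tf.lookup.KeyValueTensorInitializer(["a", "b"], [0, 1]),
#       default_value=-1)
#   table.lookup(tf.constant(["a", "z"]))  # => [0, -1]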
class StaticHashTableTest(BaseLookupTableTest):

  def testStaticHashTable(self):
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = self.getHashTable()(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
    self.initialize_table(table)

    self.assertAllEqual(3, self.evaluate(table.size()))

    input_string = constant_op.constant(["brain", "salad", "tank"])
    output = table.lookup(input_string)
    self.assertAllEqual([3], output.get_shape())

    result = self.evaluate(output)
    self.assertAllEqual([0, 1, -1], result)

    exported_keys_tensor, exported_values_tensor = table.export()

    self.assertItemsEqual([b"brain", b"salad", b"surgery"],
                          self.evaluate(exported_keys_tensor))
    self.assertItemsEqual([0, 1, 2], self.evaluate(exported_values_tensor))

  def testStaticHashTableFindHighRank(self):
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = self.getHashTable()(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
    self.initialize_table(table)

    self.assertAllEqual(3, self.evaluate(table.size()))

    input_string = constant_op.constant([["brain", "salad"],
                                         ["tank", "tarkus"]])
    output = table.lookup(input_string)

    result = self.evaluate(output)
    self.assertAllEqual([[0, 1], [-1, -1]], result)

  def testStaticHashTableInitWithPythonArrays(self):
    default_val = -1
    keys = ["brain", "salad", "surgery"]
    values = [0, 1, 2]
    table = self.getHashTable()(
        lookup_ops.KeyValueTensorInitializer(
            keys, values, value_dtype=dtypes.int64), default_val)
    self.initialize_table(table)

    self.assertAllEqual(3, self.evaluate(table.size()))

    input_string = constant_op.constant(["brain", "salad", "tank"])
    output = table.lookup(input_string)

    result = self.evaluate(output)
    self.assertAllEqual([0, 1, -1], result)

  def testStaticHashTableInitWithNumPyArrays(self):
    default_val = -1
keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
|
|
    values = np.array([0, 1, 2], dtype=np.int64)
    table = self.getHashTable()(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
    self.initialize_table(table)

    self.assertAllEqual(3, self.evaluate(table.size()))

    input_string = constant_op.constant(["brain", "salad", "tank"])
    output = table.lookup(input_string)

    result = self.evaluate(output)
    self.assertAllEqual([0, 1, -1], result)

  def testMultipleStaticHashTables(self):
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)

    table1 = self.getHashTable()(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
    table2 = self.getHashTable()(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
    table3 = self.getHashTable()(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)

    self.initialize_table(table1)
    self.initialize_table(table2)
    self.initialize_table(table3)
    self.assertAllEqual(3, self.evaluate(table1.size()))
    self.assertAllEqual(3, self.evaluate(table2.size()))
    self.assertAllEqual(3, self.evaluate(table3.size()))

    input_string = constant_op.constant(["brain", "salad", "tank"])
    output1 = table1.lookup(input_string)
    output2 = table2.lookup(input_string)
    output3 = table3.lookup(input_string)

    out1, out2, out3 = self.evaluate([output1, output2, output3])
    self.assertAllEqual([0, 1, -1], out1)
    self.assertAllEqual([0, 1, -1], out2)
    self.assertAllEqual([0, 1, -1], out3)

  def testStaticHashTableWithTensorDefault(self):
    default_val = constant_op.constant(-1, dtypes.int64)
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = self.getHashTable()(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
    self.initialize_table(table)

    input_string = constant_op.constant(["brain", "salad", "tank"])
    output = table.lookup(input_string)

    result = self.evaluate(output)
    self.assertAllEqual([0, 1, -1], result)

  def testStaticHashTableWithSparseTensorInput(self):
    default_val = constant_op.constant(-1, dtypes.int64)
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = self.getHashTable()(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
    self.initialize_table(table)

    sp_indices = [[0, 0], [0, 1], [1, 0]]
    sp_shape = [2, 2]
    input_tensor = sparse_tensor.SparseTensor(
        constant_op.constant(sp_indices, dtypes.int64),
        constant_op.constant(["brain", "salad", "tank"]),
        constant_op.constant(sp_shape, dtypes.int64))
    output = table.lookup(input_tensor)

    out_indices, out_values, out_shape = self.evaluate(output)

    self.assertAllEqual([0, 1, -1], out_values)
    self.assertAllEqual(sp_indices, out_indices)
    self.assertAllEqual(sp_shape, out_shape)

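  # As the SparseTensor test above asserts, table.lookup on a SparseTensor
  # applies only to the .values tensor; the indices and dense_shape pass
  # through unchanged, so the result keeps the input's sparsity pattern.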
  def testSignatureMismatch(self):
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = self.getHashTable()(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
    self.initialize_table(table)

    # Ref types do not produce a lookup signature mismatch.
    input_string_ref = variables.Variable("brain")
    self.evaluate(input_string_ref.initializer)
    self.assertEqual(0, self.evaluate(table.lookup(input_string_ref)))

    input_string = constant_op.constant([1, 2, 3], dtypes.int64)
    with self.assertRaises(TypeError):
      table.lookup(input_string)

    with self.assertRaises(TypeError):
      self.getHashTable()(
          lookup_ops.KeyValueTensorInitializer(keys, values), "UNK")

  def testDTypes(self):
    default_val = -1
    with self.assertRaises(TypeError):
      self.getHashTable()(
          lookup_ops.KeyValueTensorInitializer(["a"], [1], [dtypes.string],
                                               dtypes.int64), default_val)

  @test_util.run_v1_only("(Cached) Sessions not available in TF2.0")
  def testNotInitialized(self):
    with self.cached_session():
      default_val = -1
      table = self.getHashTable()(
          lookup_ops.KeyValueTensorInitializer(["a"], [1],
                                               value_dtype=dtypes.int64),
          default_val)

      input_string = constant_op.constant(["brain", "salad", "surgery"])
      output = table.lookup(input_string)

      with self.assertRaisesOpError("Table not initialized"):
        self.evaluate(output)

  @test_util.run_v1_only("(Cached) Sessions not available in TF2.0")
  def testInitializeTwice(self):
    with self.cached_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = self.getHashTable()(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
      self.initialize_table(table)
      # Make sure that initializing twice doesn't throw any errors.
      self.initialize_table(table)

  def testInitializationWithInvalidDimensions(self):
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)

    raised_error = ValueError
    if context.executing_eagerly():
      raised_error = errors_impl.InvalidArgumentError
    with self.assertRaises(raised_error):
      self.getHashTable()(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)

  @test_util.run_v1_only("Sessions not available in TF2.0")
  def testMultipleSessions(self):
    # Start a server
    server = server_lib.Server({"local0": ["localhost:0"]},
                               protocol="grpc",
                               start=True)
    # Create two sessions sharing the same state
    session1 = session.Session(server.target)
    session2 = session.Session(server.target)

    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = self.getHashTable()(
        lookup_ops.KeyValueTensorInitializer(keys, values),
        default_val,
        name="t1")

    # Init the table in the first session.
    with session1:
      self.initialize_table(table)
      self.assertAllEqual(3, self.evaluate(table.size()))

    # Init the table in the second session and verify that we do not get a
    # "Table already initialized" error.
    with session2:
      table.initializer.run()
      self.assertAllEqual(3, self.evaluate(table.size()))

  @test_util.run_v2_only
  def testImportedHashTable(self):
    g = ops.Graph()
    with g.as_default():
      t = lookup_ops.StaticHashTable(
          lookup_ops.KeyValueTensorInitializer(["a"], [1]),
          2)
      init_op = t._init_op
      op = t.lookup(ops.convert_to_tensor(["a"]))
      meta_graph = saver.export_meta_graph()

    def f():
      saver.import_meta_graph(meta_graph)
      return ops.get_default_graph().get_tensor_by_name(op.name)

    wrapped = wrap_function.wrap_function(f, [])
    pruned_init_fn = wrapped.prune(
        (), [wrapped.graph.get_operation_by_name(init_op.name)])
    self.evaluate(pruned_init_fn())
    self.assertAllEqual([1], wrapped())

  def testStaticHashTableInt32String(self):
    default_val = "n/a"
    keys = constant_op.constant([0, 1, 2], dtypes.int32)
    values = constant_op.constant(["brain", "salad", "surgery"])
    table = self.getHashTable()(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
    self.initialize_table(table)

    input_tensor = constant_op.constant([0, 1, -1])
    output = table.lookup(input_tensor)

    result = self.evaluate(output)
    self.assertAllEqual([b"brain", b"salad", b"n/a"], result)

  def testTableUseInFunction(self):
    if not context.executing_eagerly():
      self.skipTest("Only Eager mode test.")
    keys = constant_op.constant([0, 1, 2], dtypes.int32)
    values = constant_op.constant(["brain", "salad", "surgery"])
    table = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
        keys, values), "n/a")

    @function.defun()
    def lookup_table_func(k):
      return table.lookup(k)

    result = lookup_table_func(constant_op.constant([0, 1, -1]))
    self.assertAllEqual([b"brain", b"salad", b"n/a"], result)
    result = lookup_table_func(constant_op.constant([2, -1, 1]))
    self.assertAllEqual([b"surgery", b"n/a", b"salad"], result)

  def testTableCreatedInFunction(self):
    if not context.executing_eagerly():
      self.skipTest("Only Eager mode test.")
    keys = constant_op.constant([0, 1, 2], dtypes.int32)
    values = constant_op.constant(["brain", "salad", "surgery"])

    @function.defun()
    def lookup_table_func(k):
      table = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
          keys, values), "n/a")
      return table.lookup(k)

    result = lookup_table_func(constant_op.constant([0, 1, -1]))
    self.assertAllEqual([b"brain", b"salad", b"n/a"], result)
    result = lookup_table_func(constant_op.constant([2, -1, 1]))
    self.assertAllEqual([b"surgery", b"n/a", b"salad"], result)

  def testTwoTablesInControlFlow(self):
    keys = constant_op.constant([1, 2, 3], dtypes.int32)
    values = constant_op.constant([5, 10, 15], dtypes.int32)

    def table_func1(x):
      table = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
          keys, values), -1)
      return table.lookup(x)

    elems = np.array([2, 4, 1], dtype=np.int32)
    result1 = map_fn.map_fn(table_func1, elems, dtype=dtypes.int32)

    def table_func2(x):
      table = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
          keys, values), -1)
      return table.lookup(x)

    elems = np.array([2, 4, 1], dtype=np.int32)
    result2 = map_fn.map_fn(table_func2, elems, dtype=dtypes.int32)

    self.evaluate(lookup_ops.tables_initializer())

    self.assertAllEqual([10, -1, 5], self.evaluate(result1))
    self.assertAllEqual([10, -1, 5], self.evaluate(result2))

  @test_util.enable_control_flow_v2
  def testLookupTableInWhileV2(self):
    lookup = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
        constant_op.constant([2, 5], dtype=dtypes.int64),
        constant_op.constant([-10.0, 1], dtype=dtypes.float32)), -1)

    beta = variables.Variable(1.0, trainable=True)

    @def_function.function
    def get_loss(unused_beta):
      return map_fn.map_fn(
          lookup.lookup,
          constant_op.constant([2, 3], dtype=dtypes.int64),
          dtype=dtypes.float32)

    with backprop.GradientTape() as tape:
      loss = get_loss(beta)

    self.assertIsNone(tape.gradient(loss, beta))

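  # Note on the gradient assertion above: get_loss ignores its argument, so
  # the loss never actually consumes beta and the tape reports a None
  # gradient (rather than zero), even under control flow v2.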
  @test_util.enable_control_flow_v2
  def testLookupTableInCondV2(self):
    lookup = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
        constant_op.constant([2, 5], dtype=dtypes.int64),
        constant_op.constant([-10.0, 1], dtype=dtypes.float32)), -1)

    beta = variables.Variable(1.0, trainable=True)

    @def_function.function
    def get_loss(beta):

      def true_fn():
        return lookup.lookup(constant_op.constant(2, dtype=dtypes.int64))

      def false_fn():
        return constant_op.constant(0, dtype=dtypes.float32)

      return beta * control_flow_ops.cond(
          constant_op.constant(True), true_fn=true_fn, false_fn=false_fn)

    with backprop.GradientTape() as tape:
      loss = get_loss(beta)
    grad = tape.gradient(loss, beta)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual(grad, -10.)


class KeyValueTensorInitializerTest(BaseLookupTableTest):

  def test_string(self):
    init = lookup_ops.KeyValueTensorInitializer(
        ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
    table = self.getHashTable()(init, default_value=-1)
    self.initialize_table(table)

  def test_multiple_tables(self):
    with ops.name_scope("table_scope"):
      init1 = lookup_ops.KeyValueTensorInitializer(
          ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
      table1 = self.getHashTable()(init1, default_value=-1)
      if not context.executing_eagerly():
        self.assertEqual("hash_table", table1.name)
        self.assertEqual("table_scope/hash_table",
                         table1.resource_handle.op.name)
      init2 = lookup_ops.KeyValueTensorInitializer(
          ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
      table2 = self.getHashTable()(init2, default_value=-1)
      if not context.executing_eagerly():
        self.assertEqual("hash_table_1", table2.name)
        self.assertEqual("table_scope/hash_table_1",
                         table2.resource_handle.op.name)

  def test_int64(self):
    init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
                                                dtypes.int64, dtypes.int64)
    table = self.getHashTable()(init, default_value=-1)
    self.initialize_table(table)

  def test_int32(self):
    init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
                                                dtypes.int32, dtypes.int64)
    with self.assertRaises(errors_impl.OpError):
      table = self.getHashTable()(init, default_value=-1)
      self.initialize_table(table)


class DatasetInitializerTest(BaseLookupTableTest):

  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(values) + "\n")
    return vocabulary_file

  def test_basic(self):
    keys = dataset_ops.Dataset.range(100)
    values = dataset_ops.Dataset.range(100).map(
        lambda x: string_ops.as_string(x * 2))
    ds = dataset_ops.Dataset.zip((keys, values))
    init = lookup_ops.DatasetInitializer(ds)
    table = self.getHashTable()(init, default_value="")
    self.initialize_table(table)

    output = table.lookup(constant_op.constant([0, 2, 5], dtypes.int64))
    result = self.evaluate(output)
    self.assertAllEqual(["0", "4", "10"], result)

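  # DatasetInitializer accepts any tf.data.Dataset of scalar (key, value)
  # pairs, so tables can be built from arbitrary input pipelines. A minimal
  # sketch, where key_ds and value_ds stand for any two scalar datasets of
  # equal length:
  #
  #   ds = dataset_ops.Dataset.zip((key_ds, value_ds))
  #   table = lookup_ops.StaticHashTable(
  #       lookup_ops.DatasetInitializer(ds), default_value="")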
  def test_basic_bad_shape(self):
    keys = dataset_ops.Dataset.range(100)
    values = dataset_ops.Dataset.range(100).map(
        lambda x: string_ops.as_string(x * 2))
    values = values.batch(4)
    ds = dataset_ops.Dataset.zip((keys, values))
    with self.assertRaises(ValueError):
      lookup_ops.DatasetInitializer(ds)

  def test_from_file(self):
    vocabulary_file = self._createVocabFile("test.txt", ("one", "two", "three"))
    ds = reader_ops.TextLineDataset(vocabulary_file)
    ds = ds.enumerate(start=1)
    init = lookup_ops.DatasetInitializer(ds)
    table = self.getHashTable()(init, default_value="")
    self.initialize_table(table)

    output = table.lookup(constant_op.constant([2, 3, 4], dtypes.int64))
    result = self.evaluate(output)
    self.assertAllEqual(["two", "three", ""], result)

  def test_from_multiple_files(self):
    vocabulary_file1 = self._createVocabFile("test1.txt",
                                             ("one", "two", "three"))
    vocabulary_file2 = self._createVocabFile("test2.txt",
                                             ("four", "five", "six"))
    ds = reader_ops.TextLineDataset([vocabulary_file1, vocabulary_file2])
    ds = ds.enumerate(start=1)
    init = lookup_ops.DatasetInitializer(ds)
    table = self.getHashTable()(init, default_value="")
    self.initialize_table(table)

    output = table.lookup(constant_op.constant([2, 3, 4], dtypes.int64))
    result = self.evaluate(output)
    self.assertAllEqual(["two", "three", "four"], result)

  def test_map_variable(self):
    ds = dataset_ops.Dataset.range(100)
    captured_var = variables.Variable(0)

    def func(_):
      return captured_var.assign_add(1)

    ds = ds.map(func)
    ds = ds.enumerate(start=1)
    init = lookup_ops.DatasetInitializer(ds)
    table = self.getHashTable()(init, default_value=-1)
    self.evaluate(captured_var.initializer)
    self.initialize_table(table)

    output = table.lookup(constant_op.constant([1, 2, 101], dtypes.int64))
    result = self.evaluate(output)
    self.assertAllEqual([1, 2, -1], result)


class InitializeTableFromFileOpTest(BaseLookupTableTest):

  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(values) + "\n")
    return vocabulary_file

  def testInitializeStringTable(self):
    vocabulary_file = self._createVocabFile("one_column_1.txt")
    default_value = -1
    init = lookup_ops.TextFileInitializer(
        vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
        dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
    self.assertIn("one_column_1.txt_-2_-1", init._shared_name)
    table = self.getHashTable()(init, default_value)
    self.initialize_table(table)

    output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))

    result = self.evaluate(output)
    self.assertAllEqual([0, 1, -1], result)

  def testInitializeInt64Table(self):
    vocabulary_file = self._createVocabFile(
        "one_column_int64.txt", values=("42", "1", "-1000"))

    with self.cached_session():
      default_value = -1
      init = lookup_ops.TextFileInitializer(
          vocabulary_file, dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE,
          dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
      self.assertIn("one_column_int64.txt_-2_-1", init._shared_name)
      table = self.getHashTable()(init, default_value)
      self.initialize_table(table)

      output = table.lookup(
          constant_op.constant((42, 1, 11), dtype=dtypes.int64))

      result = self.evaluate(output)
      self.assertAllEqual([0, 1, -1], result)

  def testInitializeIndexTable(self):
    vocabulary_file = self._createVocabFile("one_column_2.txt")

    with self.cached_session():
      default_value = "UNK"
      key_index = lookup_ops.TextFileIndex.LINE_NUMBER
      value_index = lookup_ops.TextFileIndex.WHOLE_LINE
      init = lookup_ops.TextFileInitializer(
          vocabulary_file, dtypes.int64, key_index, dtypes.string, value_index)
      self.assertIn("one_column_2.txt_-1_-2", init._shared_name)
      table = self.getHashTable()(init, default_value)
      self.initialize_table(table)

      input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
      output = table.lookup(input_values)

      result = self.evaluate(output)
      self.assertAllEqual([b"brain", b"salad", b"surgery", b"UNK"], result)

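  # TextFileInitializer column selectors: a non-negative index picks a
  # TAB-separated column, while the sentinels TextFileIndex.WHOLE_LINE (-2)
  # and TextFileIndex.LINE_NUMBER (-1) use the full line or its line number.
  # The _shared_name suffixes asserted in these tests (e.g. "..._-2_-1")
  # encode the key and value indices in that order.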
  def testMultiColumn(self):
    vocabulary_file = os.path.join(self.get_temp_dir(), "three_columns.txt")
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")

    with self.cached_session():
      default_value = -1
      key_index = 1
      value_index = 2

      init = lookup_ops.TextFileInitializer(
          vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index)
      self.assertIn("three_columns.txt_1_2", init._shared_name)
      table = self.getHashTable()(init, default_value)
      self.initialize_table(table)

      input_string = constant_op.constant(["brain", "salad", "surgery"])
      output = table.lookup(input_string)

      result = self.evaluate(output)
      self.assertAllEqual([1, 5, 6], result)

  def testInvalidDataTypeInMultiColumn(self):
    vocabulary_file = os.path.join(self.get_temp_dir(), "three_columns.txt")
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")

    with self.cached_session():
      default_value = -1
      key_index = 2
      value_index = 1
      init = lookup_ops.TextFileInitializer(
          vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index)
      self.assertIn("three_columns.txt_2_1", init._shared_name)
      with self.assertRaisesOpError("is not a valid"):
        table = self.getHashTable()(init, default_value)
        self.initialize_table(table)

  def testInvalidDataType(self):
    vocabulary_file = self._createVocabFile("one_column_3.txt")

    with self.cached_session():
      default_value = "UNK"
      key_index = lookup_ops.TextFileIndex.WHOLE_LINE
      value_index = lookup_ops.TextFileIndex.LINE_NUMBER

      with self.assertRaises(ValueError):
        init = lookup_ops.TextFileInitializer(vocabulary_file, dtypes.int64,
                                              key_index, dtypes.string,
                                              value_index)
        self.assertIn("one_column_3.txt_-2_-1", init._shared_name)
        self.getHashTable()(init, default_value)

  def testInvalidIndex(self):
    vocabulary_file = self._createVocabFile("one_column_4.txt")
    with self.cached_session():
      default_value = -1
      key_index = 1  # second column of the line
      value_index = lookup_ops.TextFileIndex.LINE_NUMBER
      init = lookup_ops.TextFileInitializer(
          vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index)
      self.assertIn("one_column_4.txt_1_-1", init._shared_name)

      with self.assertRaisesOpError("Invalid number of columns"):
        table = self.getHashTable()(init, default_value)
        self.initialize_table(table)

  def testInitializeSameTableWithMultipleNodes(self):
    vocabulary_file = self._createVocabFile("one_column_5.txt")

    with self.cached_session():
      default_value = -1
      init1 = lookup_ops.TextFileInitializer(
          vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
          dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
      self.assertIn("one_column_5.txt_-2_-1", init1._shared_name)
      table1 = self.getHashTable()(init1, default_value)
      init2 = lookup_ops.TextFileInitializer(
          vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
          dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
      self.assertIn("one_column_5.txt_-2_-1", init2._shared_name)
      table2 = self.getHashTable()(init2, default_value)
      init3 = lookup_ops.TextFileInitializer(
          vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
          dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
      self.assertIn("one_column_5.txt_-2_-1", init3._shared_name)
      table3 = self.getHashTable()(init3, default_value)

      self.evaluate(lookup_ops.tables_initializer())

      input_string = constant_op.constant(["brain", "salad", "tank"])

      output1 = table1.lookup(input_string)
      output2 = table2.lookup(input_string)
      output3 = table3.lookup(input_string)

      out1, out2, out3 = self.evaluate([output1, output2, output3])
      self.assertAllEqual([0, 1, -1], out1)
      self.assertAllEqual([0, 1, -1], out2)
      self.assertAllEqual([0, 1, -1], out3)

  def testInitializeTableWithNoFilename(self):
    with self.cached_session():
      default_value = -1
      with self.assertRaises(ValueError):
        self.getHashTable()(lookup_ops.TextFileInitializer(
            "", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
            dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), default_value)

  def testInitializeWithVocabSize(self):
    with self.cached_session():
      default_value = -1
      vocab_size = 3
      vocabulary_file1 = self._createVocabFile("one_column6.txt")
      init1 = lookup_ops.TextFileInitializer(
          vocabulary_file1,
          dtypes.string,
          lookup_ops.TextFileIndex.WHOLE_LINE,
          dtypes.int64,
          lookup_ops.TextFileIndex.LINE_NUMBER,
          vocab_size=vocab_size)
      self.assertIn("one_column6.txt_3_-2_-1", init1._shared_name)
      table1 = self.getHashTable()(init1, default_value)

      # Initialize from file.
      self.initialize_table(table1)
      self.assertEqual(vocab_size, self.evaluate(table1.size()))

      vocabulary_file2 = self._createVocabFile("one_column7.txt")
      vocab_size = 5
      init2 = lookup_ops.TextFileInitializer(
          vocabulary_file2,
          dtypes.string,
          lookup_ops.TextFileIndex.WHOLE_LINE,
          dtypes.int64,
          lookup_ops.TextFileIndex.LINE_NUMBER,
          vocab_size=vocab_size)
      self.assertIn("one_column7.txt_5_-2_-1", init2._shared_name)
      with self.assertRaisesOpError("Invalid vocab_size"):
        table2 = self.getHashTable()(init2, default_value)
        self.initialize_table(table2)

      vocab_size = 1
      vocabulary_file3 = self._createVocabFile("one_column3.txt")
      init3 = lookup_ops.TextFileInitializer(
          vocabulary_file3,
          dtypes.string,
          lookup_ops.TextFileIndex.WHOLE_LINE,
          dtypes.int64,
          lookup_ops.TextFileIndex.LINE_NUMBER,
          vocab_size=vocab_size)
      self.assertIn("one_column3.txt_1_-2_-1", init3._shared_name)
      table3 = self.getHashTable()(init3, default_value)

      # Smaller vocab size reads only vocab_size records.
      self.initialize_table(table3)
      self.assertEqual(vocab_size, self.evaluate(table3.size()))

  @test_util.run_v1_only("placeholder usage")
  def testFeedVocabularyName(self):
    vocabulary_file = self._createVocabFile("feed_vocabulary.txt")

    with self.cached_session():
      default_value = -1
      init = lookup_ops.TextFileInitializer(
          "old_file.txt", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
          dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
      self.assertIn("old_file.txt_-2_-1", init._shared_name)
      table = self.getHashTable()(init, default_value)

      # Initializing with a non-existent file (old_file.txt) should fail.
      # TODO(yleon): Update message, which might change per FileSystem.
      with self.assertRaisesOpError("old_file.txt"):
        table.initializer.run()

      # Initialize the model feeding the vocabulary file.
      filenames = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
      table.initializer.run(feed_dict={filenames[0]: vocabulary_file})

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = self.evaluate(output)
      self.assertAllEqual([0, 1, -1], result)

  def testInvalidFilenames(self):
    vocabulary_file = self._createVocabFile("filename_shape.txt")

    with self.cached_session():
      default_value = -1

      # Invalid data type
      other_type = constant_op.constant(1)
      with self.assertRaises(Exception) as cm:
        self.getHashTable()(lookup_ops.TextFileInitializer(
            other_type, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
            dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), default_value)
      self.assertIsInstance(cm.exception, (ValueError, TypeError))

      # Non-scalar filename
      filenames = constant_op.constant([vocabulary_file, vocabulary_file])
      if not context.executing_eagerly():
        with self.assertRaises(Exception) as cm:
          self.getHashTable()(lookup_ops.TextFileInitializer(
              filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
              dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
                              default_value)
        self.assertIsInstance(cm.exception, (ValueError, TypeError))
      else:
        with self.assertRaises(errors_impl.InvalidArgumentError):
          self.getHashTable()(lookup_ops.TextFileInitializer(
              filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
              dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
                              default_value)

  def testIdToStringTable(self):
    vocab_file = self._createVocabFile("feat_to_id_1.txt")
    with self.cached_session():
      default_value = "UNK"
      vocab_size = 3
      init = lookup_ops.TextFileStringTableInitializer(
          vocab_file, vocab_size=vocab_size)
self.assertTrue("feat_to_id_1.txt_3_-1_-2", init._shared_name)
|
|
      table = self.getHashTable()(init, default_value)

      self.initialize_table(table)

      input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)

      out = table.lookup(input_values)
      self.assertAllEqual([b"brain", b"salad", b"surgery", b"UNK"],
                          self.evaluate(out))
      self.assertEqual(vocab_size, self.evaluate(table.size()))

  def testStringToIdTable(self):
    vocab_file = self._createVocabFile("feat_to_id_2.txt")
    with self.cached_session():
      default_value = -1
      vocab_size = 3
      init = lookup_ops.TextFileIdTableInitializer(
          vocab_file, vocab_size=vocab_size)
self.assertTrue("feat_to_id_2.txt_3_-1_-2", init._shared_name)
|
|
      table = self.getHashTable()(init, default_value)
      self.initialize_table(table)

      input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])

      out = table.lookup(input_string)
      self.assertAllEqual([0, 1, 2, -1], self.evaluate(out))
      self.assertEqual(vocab_size, self.evaluate(table.size()))

  def testInt64ToIdTable(self):
    vocab_file = self._createVocabFile(
        "feat_to_id_3.txt", values=("42", "1", "-1000"))
    with self.cached_session():
      default_value = -1
      vocab_size = 3
      init = lookup_ops.TextFileIdTableInitializer(
          vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64)
self.assertTrue("feat_to_id_3.txt_3_-1_-2", init._shared_name)
|
|
      table = self.getHashTable()(init, default_value)
      self.initialize_table(table)

      out = table.lookup(
          constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64))
      self.assertAllEqual((0, 1, 2, -1), self.evaluate(out))
      self.assertEqual(vocab_size, self.evaluate(table.size()))


class StaticVocabularyTableTest(BaseLookupTableTest):

  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(values) + "\n")
    return vocabulary_file

  def testStringStaticVocabularyTable(self):
    vocab_file = self._createVocabFile("feat_to_id_1.txt")
    vocab_size = 3
    oov_buckets = 1
    table = self.getVocabularyTable()(lookup_ops.TextFileIdTableInitializer(
        vocab_file, vocab_size=vocab_size), oov_buckets)

    self.initialize_table(table)

    input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])

    out = table.lookup(input_string)
    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))

  def testInt32StaticVocabularyTable(self):
    vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
    vocab_size = 3
    oov_buckets = 1
    table = self.getVocabularyTable()(
        lookup_ops.TextFileIdTableInitializer(
            vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
        oov_buckets,
        lookup_key_dtype=dtypes.int32)

    self.initialize_table(table)

    values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32)

    out = table.lookup(values)
    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))

  def testInt64StaticVocabularyTable(self):
    vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
    vocab_size = 3
    oov_buckets = 1
    table = self.getVocabularyTable()(lookup_ops.TextFileIdTableInitializer(
        vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), oov_buckets)

    self.initialize_table(table)

    values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)

    out = table.lookup(values)
    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))

  def testStringStaticVocabularyTableNoInitializer(self):
    oov_buckets = 5

    # Create a table that uses only hash buckets: each input value is mapped
    # to an id computed as fingerprint(input) mod oov_buckets.
    table = self.getVocabularyTable()(None, oov_buckets)
    self.initialize_table(table)

    values = constant_op.constant(("brain", "salad", "surgery"))

    out = table.lookup(values)
    self.assertAllEqual(
        [
            3,  # fingerprint("brain") mod 5.
            1,  # fingerprint("salad") mod 5.
            4  # fingerprint("surgery") mod 5
        ],
        self.evaluate(out))
    self.assertEqual(oov_buckets, self.evaluate(table.size()))

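  # Id assignment sketch: with a vocabulary of size V and B OOV buckets,
  # in-vocab keys keep their file line numbers in [0, V) and any other key
  # maps to V + (fingerprint(key) mod B), so every id falls in [0, V + B).
  # That is why table.size() in these tests reports vocab_size + oov_buckets.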
  def testStaticVocabularyTableWithMultipleInitializers(self):
    vocab_file = self._createVocabFile("feat_to_id_4.txt")
    vocab_size = 3
    oov_buckets = 3

    init = lookup_ops.TextFileIdTableInitializer(
        vocab_file, vocab_size=vocab_size)
    table1 = self.getVocabularyTable()(init, oov_buckets, name="table1")

    table2 = self.getVocabularyTable()(init, oov_buckets, name="table2")

    self.evaluate(lookup_ops.tables_initializer())

    input_string = constant_op.constant(
        ["fruit", "brain", "salad", "surgery", "UNK"])

    out1 = table1.lookup(input_string)
    out2 = table2.lookup(input_string)

    out1, out2 = self.evaluate([out1, out2])
    self.assertAllEqual([5, 0, 1, 2, 5], out1)
    self.assertAllEqual([5, 0, 1, 2, 5], out2)
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size()))
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size()))

  def testStaticVocabularyTableInitializationAcrossSessions(self):
    vocab_file = self._createVocabFile("feat_to_id_5.txt")
    with self.cached_session():
      vocab_size = 3
      oov_buckets = 1
      table1 = self.getVocabularyTable()(lookup_ops.TextFileIdTableInitializer(
          vocab_file, vocab_size=vocab_size), oov_buckets)

      self.initialize_table(table1)

      input_string_1 = constant_op.constant(
          ["brain", "salad", "surgery", "UNK"])

      out1 = table1.lookup(input_string_1)

      self.assertAllEqual([0, 1, 2, 3], self.evaluate(out1))
      self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size()))

    with self.cached_session():
      vocab_size = 3
      oov_buckets = 1

      # The underlying lookup table was already initialized in the previous
      # session, so there is no need to initialize table2.
      table2 = self.getVocabularyTable()(lookup_ops.TextFileIdTableInitializer(
          vocab_file, vocab_size=vocab_size), oov_buckets)

      input_string_2 = constant_op.constant(["fruit", "salad", "UNK"])

      out2 = table2.lookup(input_string_2)

      self.assertAllEqual([3, 1, 3], self.evaluate(out2))
      self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size()))

  def testStaticVocabularyTableAssetTracking(self):
    vocab_file = self._createVocabFile("vocab.txt")
    vocab_size = 3
    oov_buckets = 1
    table = self.getVocabularyTable()(lookup_ops.TextFileIdTableInitializer(
        vocab_file, vocab_size=vocab_size), oov_buckets)
    object_graph_view = graph_view.ObjectGraphView(table)
    objects = object_graph_view.list_objects()
    assets = list(filter(lambda obj: isinstance(obj, tracking.Asset), objects))
    self.assertLen(assets, 1)
    self.assertEqual(
        self.evaluate(assets[0].asset_path), compat.as_bytes(vocab_file))

  def testSparseTensor(self):
    vocab_file = self._createVocabFile("feat_to_id_7.txt")
    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
    input_shape = [4, 4]
    sp_features = sparse_tensor.SparseTensor(
        constant_op.constant(input_indices, dtypes.int64),
        constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
                             dtypes.string),
        constant_op.constant(input_shape, dtypes.int64))

    table = self.getVocabularyTable()(lookup_ops.TextFileIdTableInitializer(
        vocab_file, vocab_size=3), 1)
    self.initialize_table(table)

    sp_ids = table.lookup(sp_features)

    self.assertAllEqual([5], sp_ids.values._shape_as_list())

    sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
        [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])

    self.assertAllEqual(input_indices, sp_ids_ind)
    self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
    self.assertAllEqual(input_shape, sp_ids_shape)

  def testInt32SparseTensor(self):
    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
    input_shape = [4, 4]
    sp_features = sparse_tensor.SparseTensor(
        constant_op.constant(input_indices, dtypes.int64),
        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
        constant_op.constant(input_shape, dtypes.int64))

    table = self.getVocabularyTable()(
        lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
                                             dtypes.int64, dtypes.int64),
        1,
        lookup_key_dtype=dtypes.int32)
    self.initialize_table(table)

    sp_ids = table.lookup(sp_features)

    self.assertAllEqual([5], sp_ids.values._shape_as_list())

    sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
        [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])

    self.assertAllEqual(input_indices, sp_ids_ind)
    self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
    self.assertAllEqual(input_shape, sp_ids_shape)

  def testInt64SparseTensor(self):
    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
    input_shape = [4, 4]
    sp_features = sparse_tensor.SparseTensor(
        constant_op.constant(input_indices, dtypes.int64),
        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
        constant_op.constant(input_shape, dtypes.int64))

    table = self.getVocabularyTable()(lookup_ops.KeyValueTensorInitializer(
        (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), 1)
    self.initialize_table(table)

    sp_ids = table.lookup(sp_features)

    self.assertAllEqual([5], sp_ids.values._shape_as_list())

    sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
        [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])

    self.assertAllEqual(input_indices, sp_ids_ind)
    self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
    self.assertAllEqual(input_shape, sp_ids_shape)

  def testStaticVocabularyTableNoInnerTable(self):
    table = self.getVocabularyTable()(None, num_oov_buckets=1)
    self.assertIsNone(table.resource_handle)


class DenseHashTableOpTest(test.TestCase):

  def testBasic(self):
    with self.cached_session():

      keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
      values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=-1,
          empty_key=0,
          deleted_key=-1)
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(4, self.evaluate(table.size()))

      remove_string = constant_op.constant([12, 15], dtypes.int64)
      self.evaluate(table.remove(remove_string))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant([11, 12, 15], dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([3], output.get_shape())

      result = self.evaluate(output)
      self.assertAllEqual([0, -1, -1], result)

  def testBasicBool(self):
    with self.cached_session():

      keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
      values = constant_op.constant([True, True, True, True], dtypes.bool)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.bool,
          default_value=False,
          empty_key=0,
          deleted_key=-1)
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(4, self.evaluate(table.size()))

      remove_string = constant_op.constant([11, 15], dtypes.int64)
      self.evaluate(table.remove(remove_string))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant([11, 12, 15], dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([3], output.get_shape())

      result = self.evaluate(output)
      self.assertAllEqual([False, True, False], result)

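  # DenseHashTable reserves two sentinel keys that must differ from each
  # other and from every real key: empty_key marks never-used buckets and
  # deleted_key marks tombstones left behind by remove(). The next test
  # checks that identical sentinels are rejected at construction time.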
  def testSameEmptyAndDeletedKey(self):
    with self.cached_session():
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "Empty and deleted keys"):
        table = lookup_ops.DenseHashTable(
            dtypes.int64,
            dtypes.int64,
            default_value=-1,
            empty_key=42,
            deleted_key=42)
        self.assertAllEqual(0, self.evaluate(table.size()))

  @test_util.run_v1_only("uses placeholders")
  def testLookupUnknownShape(self):
    with self.cached_session():
      keys = constant_op.constant([11, 12, 13], dtypes.int64)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=-1,
          empty_key=0,
          deleted_key=-1)

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      placeholder_keys = array_ops.placeholder(dtypes.int64)
      output = table.lookup(placeholder_keys)
      self.assertAllEqual(None, output.get_shape())
      result = output.eval({placeholder_keys: [11, 12, 15]})
      self.assertAllEqual([0, 1, -1], result)

  def testMapStringToFloat(self):
    with self.cached_session():

      keys = constant_op.constant(["a", "b", "c", "d"], dtypes.string)
      values = constant_op.constant([0.0, 1.1, 2.2, 3.3], dtypes.float32)
      default_value = constant_op.constant(-1.5, dtypes.float32)
      table = lookup_ops.DenseHashTable(
          dtypes.string,
          dtypes.float32,
          default_value=default_value,
          empty_key="",
          deleted_key="$")
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(4, self.evaluate(table.size()))

      remove_string = constant_op.constant(["b", "e"])
      self.evaluate(table.remove(remove_string))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant(["a", "b", "d", "e"], dtypes.string)
      output = table.lookup(input_string)
      self.assertAllEqual([4], output.get_shape())

      result = self.evaluate(output)
      self.assertAllClose([0, -1.5, 3.3, -1.5], result)

  def testMapInt64ToFloat(self):
    for float_dtype in [dtypes.float32, dtypes.float64]:
      with self.cached_session():

        keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
        values = constant_op.constant([0.0, 1.1, 2.2, 3.3], float_dtype)
        default_value = constant_op.constant(-1.5, float_dtype)
        table = lookup_ops.DenseHashTable(
            dtypes.int64,
            float_dtype,
            default_value=default_value,
            empty_key=0,
            deleted_key=-1)
        self.assertAllEqual(0, self.evaluate(table.size()))

        self.evaluate(table.insert(keys, values))
        self.assertAllEqual(4, self.evaluate(table.size()))

        remove_string = constant_op.constant([12, 15], dtypes.int64)
        self.evaluate(table.remove(remove_string))
        self.assertAllEqual(3, self.evaluate(table.size()))

        input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64)
        output = table.lookup(input_string)
        self.assertAllEqual([4], output.get_shape())

        result = self.evaluate(output)
        self.assertAllClose([0, -1.5, 3.3, -1.5], result)

  def testVectorValues(self):
    with self.cached_session():
      keys = constant_op.constant([11, 12, 13], dtypes.int64)
      values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]],
                                    dtypes.int64)
      default_value = constant_op.constant([-1, -2, -3, -4], dtypes.int64)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=default_value,
          empty_key=0,
          deleted_key=-1,
          initial_num_buckets=4)
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))
      self.assertAllEqual(4, len(self.evaluate(table.export()[0])))

      self.evaluate(
          table.insert(
              constant_op.constant([14], dtypes.int64),
              constant_op.constant([[2, 3, 4, 5]], dtypes.int64)))
      self.assertAllEqual(4, self.evaluate(table.size()))
      self.assertAllEqual(8, len(self.evaluate(table.export()[0])))

      remove_string = constant_op.constant([12, 16], dtypes.int64)
      self.evaluate(table.remove(remove_string))
      self.assertAllEqual(3, self.evaluate(table.size()))
      self.assertAllEqual(8, len(self.evaluate(table.export()[0])))

      input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([4, 4],
                          output.shape,
                          msg="Saw shape: %s" % output.shape)

      result = self.evaluate(output)
      self.assertAllEqual(
          [[0, 1, 2, 3], [-1, -2, -3, -4], [2, 3, 4, 5], [-1, -2, -3, -4]],
          result)

  def testVectorKeys(self):
    with self.cached_session():
      keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64)
      values = constant_op.constant([10, 11, 12], dtypes.int64)
      empty_key = constant_op.constant([0, 3], dtypes.int64)
      deleted_key = constant_op.constant([-1, -1], dtypes.int64)
      default_value = constant_op.constant(-1, dtypes.int64)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=default_value,
          empty_key=empty_key,
          deleted_key=deleted_key,
          initial_num_buckets=8)
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      self.evaluate(
          table.insert(
              constant_op.constant([[0, 0]], dtypes.int64),
              constant_op.constant([13], dtypes.int64)))
      self.assertAllEqual(4, self.evaluate(table.size()))
      self.assertAllEqual(8, len(self.evaluate(table.export()[0])))

      remove_string = constant_op.constant([[1, 2], [7, 8]], dtypes.int64)
      self.evaluate(table.remove(remove_string))
      self.assertAllEqual(3, self.evaluate(table.size()))
      self.assertAllEqual(8, len(self.evaluate(table.export()[0])))

      input_string = constant_op.constant([[0, 1], [1, 2], [1, 3], [0, 2]],
                                          dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([4], output.get_shape())

      result = self.evaluate(output)
      self.assertAllEqual([10, -1, 12, -1], result)

  def testResize(self):
    with self.cached_session():
      keys = constant_op.constant([11, 12, 13], dtypes.int64)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=-1,
          empty_key=0,
          deleted_key=-1,
          initial_num_buckets=4)
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))
      self.assertAllEqual(4, len(self.evaluate(table.export()[0])))

      keys2 = constant_op.constant([12, 99], dtypes.int64)
      self.evaluate(table.remove(keys2))
      self.assertAllEqual(2, self.evaluate(table.size()))
      self.assertAllEqual(4, len(self.evaluate(table.export()[0])))

      keys3 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64)
      values3 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64)

      self.evaluate(table.insert(keys3, values3))
      self.assertAllEqual(6, self.evaluate(table.size()))
      self.assertAllEqual(16, len(self.evaluate(table.export()[0])))

      keys4 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18],
                                   dtypes.int64)
      output = table.lookup(keys4)
      self.assertAllEqual([-1, 0, -1, 3, 4, 5, 6, 7, -1], self.evaluate(output))

  def testExport(self):
    with self.cached_session():

      keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
      values = constant_op.constant([1, 2, 3, 4], dtypes.int64)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=-1,
          empty_key=100,
          deleted_key=200,
          initial_num_buckets=8)
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(4, self.evaluate(table.size()))

      keys2 = constant_op.constant([12, 15], dtypes.int64)
      self.evaluate(table.remove(keys2))
      self.assertAllEqual(3, self.evaluate(table.size()))

      exported_keys, exported_values = table.export()

      np_keys = self.evaluate(exported_keys)
      np_values = self.evaluate(exported_values)

      self.assertAllEqual(8, len(np_keys))
      self.assertAllEqual(8, len(np_values))

      # Pair up keys and values, dropping the extra leading dimension added
      # by dstack.
      pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0]
      # sort by key
      pairs = pairs[pairs[:, 0].argsort()]
      self.assertAllEqual([[11, 1], [13, 3], [14, 4], [100, 0], [100, 0],
                           [100, 0], [100, 0], [200, 2]], pairs)

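  # export() reveals the raw bucket layout rather than just live entries: the
  # sorted pairs above contain the three live (key, value) entries, empty_key
  # (100) sentinels for the four never-used buckets, and a deleted_key (200)
  # tombstone whose slot still holds the removed key's old value.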
@test_util.run_v1_only("Saver V1 only")
|
|
def testSaveRestore(self):
|
|
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
|
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
|
|
|
with self.session(graph=ops.Graph()) as sess:
|
|
default_value = -1
|
|
empty_key = 0
|
|
deleted_key = -1
|
|
keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
|
|
values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
|
|
table = lookup_ops.DenseHashTable(
|
|
dtypes.int64,
|
|
dtypes.int64,
|
|
default_value=default_value,
|
|
empty_key=empty_key,
|
|
deleted_key=deleted_key,
|
|
name="t1",
|
|
checkpoint=True,
|
|
initial_num_buckets=32)
|
|
|
|
save = saver.Saver()
|
|
|
|
self.assertAllEqual(0, table.size().eval())
|
|
table.insert(keys, values).run()
|
|
self.assertAllEqual(4, table.size().eval())
|
|
self.assertAllEqual(32, len(table.export()[0].eval()))
|
|
|
|
keys2 = constant_op.constant([12, 15], dtypes.int64)
|
|
table.remove(keys2).run()
|
|
self.assertAllEqual(3, table.size().eval())
|
|
self.assertAllEqual(32, len(table.export()[0].eval()))
|
|
|
|
val = save.save(sess, save_path)
|
|
self.assertIsInstance(val, six.string_types)
|
|
self.assertEqual(save_path, val)
|
|
|
|
with self.session(graph=ops.Graph()) as sess:
|
|
table = lookup_ops.DenseHashTable(
|
|
dtypes.int64,
|
|
dtypes.int64,
|
|
default_value=default_value,
|
|
empty_key=empty_key,
|
|
deleted_key=deleted_key,
|
|
name="t1",
|
|
checkpoint=True,
|
|
initial_num_buckets=64)
|
|
table.insert(
|
|
constant_op.constant([11, 14], dtypes.int64),
|
|
constant_op.constant([12, 24], dtypes.int64)).run()
|
|
self.assertAllEqual(2, table.size().eval())
|
|
self.assertAllEqual(64, len(table.export()[0].eval()))
|
|
|
|
save = saver.Saver()
|
|
|
|
# Restore the saved values in the parameter nodes.
|
|
save.restore(sess, save_path)
|
|
|
|
self.assertAllEqual(3, table.size().eval())
|
|
self.assertAllEqual(32, len(table.export()[0].eval()))
|
|
|
|
input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
|
|
output = table.lookup(input_string)
|
|
self.assertAllEqual([-1, 0, -1, 2, 3], output.eval())
|
|
|
|
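  # Restoring a table saved with checkpoint=True replaces its entire
  # contents, including the bucket count: the restore-side tables in these
  # save/restore tests are created with 64 buckets and pre-populated, yet
  # after restore they report the saved 32 buckets and the saved entries only.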
@test_util.run_v1_only("Saver V1 only")
|
|
def testSaveRestoreOnlyTable(self):
|
|
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
|
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
|
|
|
with self.session(graph=ops.Graph()) as sess:
|
|
default_value = -1
|
|
empty_key = 0
|
|
deleted_key = -1
|
|
keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
|
|
values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
|
|
table = lookup_ops.DenseHashTable(
|
|
dtypes.int64,
|
|
dtypes.int64,
|
|
default_value=default_value,
|
|
empty_key=empty_key,
|
|
deleted_key=deleted_key,
|
|
name="t1",
|
|
checkpoint=True,
|
|
initial_num_buckets=32)
|
|
|
|
save = saver.Saver([table])
|
|
|
|
self.assertAllEqual(0, table.size().eval())
|
|
table.insert(keys, values).run()
|
|
self.assertAllEqual(4, table.size().eval())
|
|
self.assertAllEqual(32, len(table.export()[0].eval()))
|
|
|
|
keys2 = constant_op.constant([12, 15], dtypes.int64)
|
|
table.remove(keys2).run()
|
|
self.assertAllEqual(3, table.size().eval())
|
|
self.assertAllEqual(32, len(table.export()[0].eval()))
|
|
|
|
val = save.save(sess, save_path)
|
|
self.assertIsInstance(val, six.string_types)
|
|
self.assertEqual(save_path, val)
|
|
|
|
with self.session(graph=ops.Graph()) as sess:
|
|
table = lookup_ops.DenseHashTable(
|
|
dtypes.int64,
|
|
dtypes.int64,
|
|
default_value=default_value,
|
|
empty_key=empty_key,
|
|
deleted_key=deleted_key,
|
|
name="t1",
|
|
checkpoint=True,
|
|
initial_num_buckets=64)
|
|
table.insert(
|
|
constant_op.constant([11, 14], dtypes.int64),
|
|
constant_op.constant([12, 24], dtypes.int64)).run()
|
|
self.assertAllEqual(2, table.size().eval())
|
|
self.assertAllEqual(64, len(table.export()[0].eval()))
|
|
|
|
save = saver.Saver([table])
|
|
|
|
# Restore the saved values in the parameter nodes.
|
|
save.restore(sess, save_path)
|
|
|
|
self.assertAllEqual(3, table.size().eval())
|
|
self.assertAllEqual(32, len(table.export()[0].eval()))
|
|
|
|
input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
|
|
output = table.lookup(input_string)
|
|
self.assertAllEqual([-1, 0, -1, 2, 3], output.eval())
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testObjectSaveRestore(self):
|
|
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
|
save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
|
|
|
default_value = -1
|
|
empty_key = 0
|
|
deleted_key = -1
|
|
keys = constant_op.constant([11, 12, 13], dtypes.int64)
|
|
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
|
save_table = lookup_ops.DenseHashTable(
|
|
dtypes.int64,
|
|
dtypes.int64,
|
|
default_value=default_value,
|
|
empty_key=empty_key,
|
|
deleted_key=deleted_key,
|
|
name="t1",
|
|
checkpoint=True,
|
|
initial_num_buckets=32)
|
|
|
|
save_checkpoint = trackable.Checkpoint(table=save_table)
|
|
|
|
self.assertAllEqual(0, self.evaluate(save_table.size()))
|
|
self.evaluate(save_table.insert(keys, values))
|
|
self.assertAllEqual(3, self.evaluate(save_table.size()))
|
|
self.assertAllEqual(32, len(self.evaluate(save_table.export()[0])))
|
|
|
|
save_path = save_checkpoint.save(save_prefix)
|
|
del save_table, save_checkpoint
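
    # With the original table and checkpoint objects released, restore the
    # saved entries into a brand-new table below.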
    load_table = lookup_ops.DenseHashTable(
        dtypes.int64,
        dtypes.int64,
        default_value=default_value,
        empty_key=empty_key,
        deleted_key=deleted_key,
        name="t1",
        checkpoint=True,
        initial_num_buckets=64)
    self.evaluate(
        load_table.insert(
            constant_op.constant([11, 14], dtypes.int64),
            constant_op.constant([12, 24], dtypes.int64)))
    self.assertAllEqual(2, self.evaluate(load_table.size()))
    self.assertAllEqual(64, len(self.evaluate(load_table.export()[0])))

    restore_checkpoint = trackable.Checkpoint(table=load_table)

    # Restore the saved values in the parameter nodes.
    restore_checkpoint.restore(save_path).run_restore_ops()

    self.assertAllEqual(3, self.evaluate(load_table.size()))
    self.assertAllEqual(32, len(self.evaluate(load_table.export()[0])))

    input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
    output = load_table.lookup(input_string)
    self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))

  @test_util.run_v1_only("Saver V1 only")
  def testVectorSaveRestore(self):
    save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    with self.session(graph=ops.Graph()) as sess:
      empty_key = constant_op.constant([11, 13], dtypes.int64)
      deleted_key = constant_op.constant([-2, -3], dtypes.int64)
      default_value = constant_op.constant([-1, -2], dtypes.int64)
      keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]],
                                  dtypes.int64)
      values = constant_op.constant([[0, 1], [2, 3], [2, 4], [4, 5]],
                                    dtypes.int64)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=default_value,
          empty_key=empty_key,
          deleted_key=deleted_key,
          name="t1",
          checkpoint=True,
          initial_num_buckets=32)

      save = saver.Saver()

      self.assertAllEqual(0, table.size().eval())
      table.insert(keys, values).run()
      self.assertAllEqual(4, table.size().eval())
      self.assertAllEqual(32, len(table.export()[0].eval()))

      keys2 = constant_op.constant([[12, 13], [16, 17]], dtypes.int64)
      table.remove(keys2).run()
      self.assertAllEqual(3, table.size().eval())
      self.assertAllEqual(32, len(table.export()[0].eval()))

      val = save.save(sess, save_path)
      self.assertIsInstance(val, six.string_types)
      self.assertEqual(save_path, val)

    with self.session(graph=ops.Graph()) as sess:
      empty_key = constant_op.constant([11, 13], dtypes.int64)
      deleted_key = constant_op.constant([-2, -3], dtypes.int64)
      default_value = constant_op.constant([-1, -2], dtypes.int64)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=default_value,
          empty_key=empty_key,
          deleted_key=deleted_key,
          name="t1",
          checkpoint=True,
          initial_num_buckets=64)
      table.insert(
          constant_op.constant([[11, 12], [13, 15]], dtypes.int64),
          constant_op.constant([[21, 22], [23, 24]], dtypes.int64)).run()
      self.assertAllEqual(2, table.size().eval())
      self.assertAllEqual(64, len(table.export()[0].eval()))

      save = saver.Saver()

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)

      self.assertAllEqual(3, table.size().eval())
      self.assertAllEqual(32, len(table.export()[0].eval()))

      input_string = constant_op.constant(
          [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([[0, 1], [2, 3], [-1, -2], [4, 5], [-1, -2]],
                          output.eval())

  @test_util.run_v1_only("Saver V1 only")
  def testVectorScalarSaveRestore(self):
    save_dir = os.path.join(self.get_temp_dir(), "vector_scalar_save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    with self.session(graph=ops.Graph()) as sess:
      empty_key = constant_op.constant([11, 13], dtypes.int64)
      deleted_key = constant_op.constant([-1, -1], dtypes.int64)
      default_value = constant_op.constant(-1, dtypes.int64)
      keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]],
                                  dtypes.int64)
      values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=default_value,
          empty_key=empty_key,
          deleted_key=deleted_key,
          name="t2",
          checkpoint=True,
          initial_num_buckets=32)

      save = saver.Saver()

      self.assertAllEqual(0, table.size().eval())
      table.insert(keys, values).run()
      self.assertAllEqual(4, table.size().eval())
      self.assertAllEqual(32, len(table.export()[0].eval()))

      keys2 = constant_op.constant([[12, 13], [15, 16]], dtypes.int64)
      table.remove(keys2).run()
      self.assertAllEqual(3, table.size().eval())
      self.assertAllEqual(32, len(table.export()[0].eval()))

      val = save.save(sess, save_path)
      self.assertIsInstance(val, six.string_types)
      self.assertEqual(save_path, val)

    with self.session(graph=ops.Graph()) as sess:
      empty_key = constant_op.constant([11, 13], dtypes.int64)
      deleted_key = constant_op.constant([-1, -1], dtypes.int64)
      default_value = constant_op.constant(-1, dtypes.int64)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=default_value,
          empty_key=empty_key,
          deleted_key=deleted_key,
          name="t2",
          checkpoint=True,
          initial_num_buckets=64)
      table.insert(
          constant_op.constant([[11, 12], [13, 15]], dtypes.int64),
          constant_op.constant([3, 4], dtypes.int64)).run()
      self.assertAllEqual(2, table.size().eval())
      self.assertAllEqual(64, len(table.export()[0].eval()))

      save = saver.Saver()

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)

      self.assertAllEqual(3, table.size().eval())
      self.assertAllEqual(32, len(table.export()[0].eval()))

      input_string = constant_op.constant(
          [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([0, 1, -1, 3, -1], output.eval())

  def testReprobe(self):
    with self.cached_session():
      # Insert 6 keys into a table with 8 buckets.
      # The values are chosen to make sure collisions occur when using GCC STL
      keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64)
      values = constant_op.constant([51, 52, 53, 54, 55, 56], dtypes.int64)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=-1,
          empty_key=0,
          deleted_key=-1,
          initial_num_buckets=8)
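      # With 6 keys in only 8 buckets, several insertions collide and the
      # table has to re-probe for free slots.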
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(6, self.evaluate(table.size()))

      input_string = constant_op.constant([10, 11, 12, 13, 14, 19, 20, 21, 22],
                                          dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([9], output.get_shape())

      result = self.evaluate(output)
      self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result)

  def testCustomEmptyKey(self):
    with self.cached_session():
      keys = constant_op.constant([11, 0, 13], dtypes.int64)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=-1,
          empty_key=12,
          deleted_key=-1)
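      # Because empty_key is 12 rather than the usual 0, the key 0 can be
      # stored and looked up like any other key.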
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant([11, 0, 15], dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([3], output.get_shape())

      result = self.evaluate(output)
      self.assertAllEqual([0, 1, -1], result)

  def testErrors(self):
    with self.cached_session():
      table = lookup_ops.DenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=-1,
          empty_key=0,
          deleted_key=-1)

      # Inserting the empty key returns an error
      keys1 = constant_op.constant([11, 0], dtypes.int64)
      values1 = constant_op.constant([0, 1], dtypes.int64)
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "empty_key"):
        self.evaluate(table.insert(keys1, values1))

      # Looking up the empty key returns an error
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "empty_key"):
        self.evaluate(table.lookup(keys1))

      # Inserting the deleted key returns an error
      keys2 = constant_op.constant([11, -1], dtypes.int64)
      values2 = constant_op.constant([0, 1], dtypes.int64)
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "deleted_key"):
        self.evaluate(table.insert(keys2, values2))

      # Looking up the deleted key returns an error
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "deleted_key"):
        self.evaluate(table.lookup(keys2))

      # Arbitrary tensors of keys are not supported
      keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64)
      values = constant_op.constant([[11, 0], [12, 1]], dtypes.int64)
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "Expected key shape"):
        self.evaluate(table.lookup(keys))
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "Expected key shape"):
        self.evaluate(table.insert(keys, values))
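
      # The bucket count is required to be a power of two, so an
      # initial_num_buckets of 12 is rejected.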
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "Number of buckets must be"):
        table2 = lookup_ops.DenseHashTable(
            dtypes.int64,
            dtypes.int64,
            default_value=-1,
            empty_key=17,
            deleted_key=-1,
            initial_num_buckets=12)
        self.assertAllEqual(0, self.evaluate(table2.size()))

      with self.assertRaisesRegexp(
          errors_impl.InvalidArgumentError,
          "Empty and deleted keys must have same shape"):
        table3 = lookup_ops.DenseHashTable(
            dtypes.int64,
            dtypes.int64,
            default_value=-1,
            empty_key=42,
            deleted_key=[1, 2])
        self.assertAllEqual(0, self.evaluate(table3.size()))

      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "Empty and deleted keys cannot be equal"):
        table4 = lookup_ops.DenseHashTable(
            dtypes.int64,
            dtypes.int64,
            default_value=-1,
            empty_key=42,
            deleted_key=42)
        self.assertAllEqual(0, self.evaluate(table4.size()))

      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "Empty and deleted keys cannot be equal"):
        table5 = lookup_ops.DenseHashTable(
            dtypes.int64,
            dtypes.int64,
            default_value=-1,
            empty_key=[1, 2, 3],
            deleted_key=[1, 2, 3])
        self.assertAllEqual(0, self.evaluate(table5.size()))

  @test_util.run_in_graph_and_eager_modes
  def testStringToResource(self):
    v = variables.Variable(1.)
    v1 = variables.Variable(1.)
    table = lookup_ops.DenseHashTable(
        dtypes.string,
        dtypes.resource,
        default_value=v.handle,
        empty_key="<empty>",
        deleted_key="<deleted>")
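    # The table stores resource handles (here, variable handles) as values,
    # keyed by strings; misses return the default variable's handle.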
    self.assertEqual([], table.lookup("not_found").shape)
    table.insert("v1", v1.handle)
    self.assertEqual([], table.lookup("v1").shape)


class IndexTableFromFile(test.TestCase):

  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(values) + "\n")
    return vocabulary_file

  def test_string_index_table_from_file(self):
    vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
    with self.cached_session():
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file, num_oov_buckets=1)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((1, 2, 3), self.evaluate(ids))

  def test_string_index_table_from_multicolumn_file(self):
    vocabulary_file = self._createVocabFile(
        "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
    with self.cached_session():
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file,
          num_oov_buckets=1,
          key_column_index=0,
          value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((1, 2, 3), self.evaluate(ids))

  def test_string_index_table_from_multicolumn_file_custom_delimiter(self):
    vocabulary_file = self._createVocabFile(
        "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
    with self.cached_session():
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file,
          num_oov_buckets=1,
          key_column_index=0,
          value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
          delimiter=" ")
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((1, 2, 3), self.evaluate(ids))

  def test_string_index_table_from_file_tensor_filename(self):
    vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
    with self.cached_session():
      vocabulary_file = constant_op.constant(vocabulary_file)
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file, num_oov_buckets=1)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((1, 2, 3), self.evaluate(ids))
      if not context.executing_eagerly():
        self.assertEqual(
            1, len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)))

  @test_util.run_v1_only("placeholder usage")
  def test_string_index_table_from_file_placeholder_filename(self):
    vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
    with self.cached_session():
      vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      with self.assertRaises(errors_impl.OpError):
        self.evaluate(ids)

      feed_dict = {vocabulary_placeholder.name: vocabulary_file}
      lookup_ops.tables_initializer().run(feed_dict=feed_dict)
      self.assertAllEqual((1, 2, 3), self.evaluate(ids))
      self.assertEqual(
          0, len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)))

  def test_int32_index_table_from_file(self):
    vocabulary_file = self._createVocabFile(
        "f2i_vocab2.txt", values=("42", "1", "-1000"))
    with self.cached_session():
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file,
          num_oov_buckets=1,
          key_dtype=dtypes.int32)
      ids = table.lookup(
          constant_op.constant((1, -1000, 11), dtype=dtypes.int32))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((1, 2, 3), self.evaluate(ids))

  def test_int64_index_table_from_file(self):
    vocabulary_file = self._createVocabFile(
        "f2i_vocab3.txt", values=("42", "1", "-1000"))
    with self.cached_session():
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file,
          num_oov_buckets=1,
          key_dtype=dtypes.int64)
      ids = table.lookup(
          constant_op.constant((1, -1000, 11), dtype=dtypes.int64))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((1, 2, 3), self.evaluate(ids))

  def test_index_table_from_file_with_default_value(self):
    default_value = -42
    vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
    with self.cached_session():
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file, default_value=default_value)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((1, 2, default_value), self.evaluate(ids))

  def test_index_table_from_file_with_oov_buckets(self):
    vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
    with self.cached_session():
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file, num_oov_buckets=1000)
      ids = table.lookup(
          constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual(
          (
              1,  # From vocabulary file.
              2,  # From vocabulary file.
              867,  # 3 + fingerprint("tarkus") mod 1000.
              860),  # 3 + fingerprint("toccata") mod 1000.
          self.evaluate(ids))

  def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
    self.assertRaises(
        ValueError, lookup_ops.index_table_from_file, vocabulary_file="")

  def test_index_table_from_file_fails_with_empty_vocabulary(self):
    self.assertRaises(
        ValueError, lookup_ops.index_table_from_file, vocabulary_file=None)

  def test_index_table_from_file_str_fails_with_zero_size_vocabulary(self):
    vocabulary_file = self._createVocabFile("zero_vocab_str.txt")
    self.assertRaisesRegexp(
        ValueError,
        "vocab_size must be greater than 0, got 0. "
        "vocabulary_file: .*zero_vocab_str.txt",
        lookup_ops.index_table_from_file,
        vocabulary_file=vocabulary_file,
        vocab_size=0)

  def test_index_table_from_file_tensor_fails_with_zero_size_vocabulary(self):
    vocabulary_file = constant_op.constant(
        self._createVocabFile("zero_vocab_tensor.txt"))
    self.assertRaisesRegexp(
        ValueError,
        "vocab_size must be greater than 0, got 0. "
        "vocabulary_file: .*zero_vocab_tensor.txt",
        lookup_ops.index_table_from_file,
        vocabulary_file=vocabulary_file,
        vocab_size=0)

  def test_index_table_from_file_with_vocab_size_too_small(self):
    vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
    with self.cached_session():
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file, vocab_size=2)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((1, -1, -1), self.evaluate(ids))
      self.assertEqual(2, self.evaluate(table.size()))

  def test_index_table_from_file_with_vocab_size_too_large(self):
    vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
    with self.cached_session():
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "Invalid vocab_size"):
        table = lookup_ops.index_table_from_file(
            vocabulary_file=vocabulary_file, vocab_size=4)
        self.evaluate(table.initializer)

  def test_index_table_from_file_with_vocab_size(self):
    vocabulary_file = self._createVocabFile("f2i_vocab8.txt")

    self.assertRaises(
        ValueError,
        lookup_ops.index_table_from_file,
        vocabulary_file=vocabulary_file,
        vocab_size=0)

    with self.cached_session():
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file, vocab_size=3)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((1, 2, -1), self.evaluate(ids))
      self.assertEqual(3, self.evaluate(table.size()))

  def test_index_table_from_file_with_invalid_hashers(self):
    vocabulary_file = self._createVocabFile("invalid_hasher.txt")
    with self.cached_session():
      with self.assertRaises(TypeError):
        lookup_ops.index_table_from_file(
            vocabulary_file=vocabulary_file,
            vocab_size=3,
            num_oov_buckets=1,
            hasher_spec=1)

      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file,
          vocab_size=3,
          num_oov_buckets=1,
          hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))
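
      # The unsupported hasher_spec is only rejected when a lookup is
      # attempted, not at table construction.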
      self.assertRaises(ValueError, table.lookup,
                        constant_op.constant(["salad", "surgery", "tarkus"]))

  def test_index_table_from_file_table_ref_with_oov_buckets(self):
    vocabulary_file = self._createVocabFile("f2i_vocab9.txt")
    with self.cached_session():
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file, num_oov_buckets=1)
      self.assertIsNotNone(table.resource_handle)

  def test_index_table_from_file_table_ref_without_oov_buckets(self):
    vocabulary_file = self._createVocabFile("f2i_vocab10.txt")
    with self.cached_session():
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file, num_oov_buckets=0)
      self.assertIsNotNone(table.resource_handle)


class IndexTableFromTensor(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_index_table_from_tensor_with_tensor_init(self):
    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)

    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(
            table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))))
    else:
      # Reinitializing a table in eager should work.
      table = lookup_ops.index_table_from_tensor(
          vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
    self.evaluate(lookup_ops.tables_initializer())
    ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
    self.assertAllEqual((1, 2, 3), self.evaluate(ids))

  def test_int32_index_table_from_tensor_with_tensor_init(self):
    with self.cached_session():
      table = lookup_ops.index_table_from_tensor(
          vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
      ids = table.lookup(
          constant_op.constant((1, -1000, 11), dtype=dtypes.int32))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.FailedPreconditionError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((1, 2, 3), self.evaluate(ids))

  def test_int64_index_table_from_tensor_with_tensor_init(self):
    with self.cached_session():
      table = lookup_ops.index_table_from_tensor(
          vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
      ids = table.lookup(
          constant_op.constant((1, -1000, 11), dtype=dtypes.int64))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.FailedPreconditionError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((1, 2, 3), self.evaluate(ids))

  def test_index_table_from_tensor_with_default_value(self):
    default_value = -42
    with self.cached_session():
      table = lookup_ops.index_table_from_tensor(
          vocabulary_list=["brain", "salad", "surgery"],
          default_value=default_value)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.FailedPreconditionError):
          self.evaluate(ids)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((1, 2, default_value), self.evaluate(ids))

  def test_index_table_from_tensor_missing_vocabulary_list(self):
    with self.cached_session():
      with self.assertRaisesRegexp(ValueError,
                                   "vocabulary_list must be specified"):
        lookup_ops.index_table_from_tensor(
            vocabulary_list=None, num_oov_buckets=1)

  def test_index_table_from_tensor_empty_vocabulary_list(self):
    with self.cached_session():
      with self.assertRaisesRegexp(
          errors_impl.OpError, "keys and values cannot be empty"):
        _ = lookup_ops.index_table_from_tensor(
            vocabulary_list=np.array([], dtype=np.str_), num_oov_buckets=1)
        self.evaluate(lookup_ops.tables_initializer())

  def test_index_table_from_tensor_with_invalid_hashers(self):
    with self.cached_session():
      with self.assertRaises(TypeError):
        lookup_ops.index_table_from_tensor(
            vocabulary_list=["brain", "salad", "surgery"],
            num_oov_buckets=1,
            hasher_spec=1)

      table = lookup_ops.index_table_from_tensor(
          vocabulary_list=["brain", "salad", "surgery"],
          num_oov_buckets=1,
          hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))

      self.assertRaises(ValueError, table.lookup,
                        constant_op.constant(["salad", "surgery", "tarkus"]))


class IndexToStringTableFromFileTest(test.TestCase):

  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(values) + "\n")
    return vocabulary_file

  def test_index_to_string_table(self):
    vocabulary_path = self._createVocabFile("i2f_vocab1.txt")
    # vocabulary_file supports string and tensor
    type_funcs = [str, constant_op.constant]
    for type_func in type_funcs:
      vocabulary_file = type_func(vocabulary_path)
      with self.cached_session():
        table = lookup_ops.index_to_string_table_from_file(
            vocabulary_file=vocabulary_file)
        features = table.lookup(
            constant_op.constant([0, 1, 2, 3], dtypes.int64))
        if not context.executing_eagerly():
          with self.assertRaises(errors_impl.OpError):
            self.evaluate(features)
        self.evaluate(lookup_ops.tables_initializer())
        self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
                            self.evaluate(features))

  def test_index_to_string_table_from_multicolumn_file(self):
    vocabulary_file = self._createVocabFile(
        "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
    with self.cached_session():
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file,
          key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
          value_column_index=0)
      features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(features)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
                          self.evaluate(features))

  def test_index_to_string_table_from_multicolumn_file_custom_delimiter(self):
    vocabulary_file = self._createVocabFile(
        "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
    with self.cached_session():
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file,
          key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
          value_column_index=0,
          delimiter=" ")
      features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(features)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
                          self.evaluate(features))

  def test_index_to_string_table_with_default_value(self):
    default_value = b"NONE"
    vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
    with self.cached_session():
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file, default_value=default_value)
      features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(features)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((b"salad", b"surgery", default_value),
                          self.evaluate(features))

  def test_index_to_string_table_with_vocab_size_too_small(self):
    default_value = b"NONE"
    vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
    with self.cached_session():
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file,
          vocab_size=2,
          default_value=default_value)
      features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(features)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((b"salad", default_value, default_value),
                          self.evaluate(features))

  def test_index_to_string_table_with_vocab_size_too_large(self):
    vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
    with self.cached_session():
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "Invalid vocab_size"):
        _ = lookup_ops.index_to_string_table_from_file(
            vocabulary_file=vocabulary_file, vocab_size=4)
        self.evaluate(lookup_ops.tables_initializer())

  def test_index_to_string_table_with_vocab_size(self):
    vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
    with self.cached_session():
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file, vocab_size=3)
      features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))

      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(features)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((b"salad", b"surgery", b"UNK"),
                          self.evaluate(features))


class IndexToStringTableFromTensorTest(test.TestCase):

  def test_index_to_string_table_from_tensor(self):
    with self.cached_session():
      vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
      table = lookup_ops.index_to_string_table_from_tensor(
          vocabulary_list=vocabulary_list)

      indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
      features = table.lookup(indices)
      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(features)
      self.evaluate(lookup_ops.tables_initializer())

      self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
                          self.evaluate(features))

  def test_duplicate_entries(self):
    with self.cached_session():
      vocabulary_list = constant_op.constant(["hello", "hello"])
      table = lookup_ops.index_to_string_table_from_tensor(
          vocabulary_list=vocabulary_list)
      indices = constant_op.constant([0, 1, 4], dtypes.int64)
      features = table.lookup(indices)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((b"hello", b"hello", b"UNK"),
                          self.evaluate(features))

  def test_index_to_string_with_default_value(self):
    default_value = b"NONE"
    with self.cached_session():
      vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
      table = lookup_ops.index_to_string_table_from_tensor(
          vocabulary_list=vocabulary_list, default_value=default_value)
      indices = constant_op.constant([1, 2, 4], dtypes.int64)
      features = table.lookup(indices)
      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(features)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((b"salad", b"surgery", default_value),
                          self.evaluate(features))


class IdTableWithHashBucketsTest(test.TestCase):

  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(values) + "\n")
    return vocabulary_file

  @test_util.run_deprecated_v1
  def testStringIdTableWithHashBuckets(self):
    vocab_file = self._createVocabFile("feat_to_id_1.txt")
    with self.cached_session():
      default_value = -1
      vocab_size = 3
      oov_buckets = 1
      table = lookup_ops.IdTableWithHashBuckets(
          lookup_ops.StaticHashTable(
              lookup_ops.TextFileIdTableInitializer(
                  vocab_file, vocab_size=vocab_size), default_value),
          oov_buckets)

      table.initializer.run()

      input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])

      out = table.lookup(input_string)
      self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
      self.assertEqual(vocab_size + oov_buckets, table.size().eval())

  @test_util.run_deprecated_v1
  def testInt32IdTableWithHashBuckets(self):
    vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
    with self.cached_session():
      default_value = -1
      vocab_size = 3
      oov_buckets = 1
      table = lookup_ops.IdTableWithHashBuckets(
          lookup_ops.StaticHashTable(
              lookup_ops.TextFileIdTableInitializer(
                  vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
              default_value),
          oov_buckets,
          key_dtype=dtypes.int32)

      table.initializer.run()

      values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32)

      out = table.lookup(values)
      self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
      self.assertEqual(vocab_size + oov_buckets, table.size().eval())

  @test_util.run_deprecated_v1
  def testInt64IdTableWithHashBuckets(self):
    vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
    with self.cached_session():
      default_value = -1
      vocab_size = 3
      oov_buckets = 1
      table = lookup_ops.IdTableWithHashBuckets(
          lookup_ops.StaticHashTable(
              lookup_ops.TextFileIdTableInitializer(
                  vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
              default_value), oov_buckets)

      table.initializer.run()

      values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)

      out = table.lookup(values)
      self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
      self.assertEqual(vocab_size + oov_buckets, table.size().eval())

  @test_util.run_deprecated_v1
  def testStringIdTableWithOnlyHashBucket(self):
    with self.cached_session():
      oov_buckets = 5

      # Set up a table that only uses hash buckets: for each input value, it
      # returns an id calculated by fingerprint("input") mod oov_buckets.
      table = lookup_ops.IdTableWithHashBuckets(None, oov_buckets)
      table.initializer.run()

      values = constant_op.constant(("brain", "salad", "surgery"))

      out = table.lookup(values)
      self.assertAllEqual(
          [
              3,  # fingerprint("brain") mod 5.
              1,  # fingerprint("salad") mod 5.
              4  # fingerprint("surgery") mod 5
          ],
          self.evaluate(out))
      self.assertEqual(oov_buckets, table.size().eval())

  @test_util.run_deprecated_v1
  def testInt32IdTableWithOnlyHashBucket(self):
    with self.cached_session():
      oov_buckets = 5

      # Set up a table that only uses hash buckets: for each input value, it
      # returns an id calculated by fingerprint("input") mod oov_buckets.
      table = lookup_ops.IdTableWithHashBuckets(
          None, oov_buckets, key_dtype=dtypes.int32)
      table.initializer.run()

      input_string = constant_op.constant([42, 1, -1000], dtype=dtypes.int32)

      out = table.lookup(input_string)
      self.assertAllEqual(
          [
              1,  # fingerprint("42") mod 5.
              4,  # fingerprint("1") mod 5.
              2  # fingerprint("-1000") mod 5
          ],
          self.evaluate(out))
      self.assertEqual(oov_buckets, table.size().eval())

  def testFloat64IdTableWithOnlyHashBucket(self):
    with self.cached_session():
      with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
        lookup_ops.IdTableWithHashBuckets(
            None, num_oov_buckets=5, key_dtype=dtypes.float64)

  def testBoolIdTableWithOnlyHashBucket(self):
    with self.cached_session():
      with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
        lookup_ops.IdTableWithHashBuckets(
            None, num_oov_buckets=5, key_dtype=dtypes.bool)

  @test_util.run_deprecated_v1
  def testIdTableWithHashBucketsWithMultipleInitializers(self):
    vocab_file = self._createVocabFile("feat_to_id_4.txt")
    with self.cached_session() as sess:
      default_value = -1
      vocab_size = 3
      oov_buckets = 3

      vocab_table = lookup_ops.StaticHashTable(
          lookup_ops.TextFileIdTableInitializer(
              vocab_file, vocab_size=vocab_size), default_value)
      table1 = lookup_ops.IdTableWithHashBuckets(
          vocab_table,
          oov_buckets,
          hasher_spec=lookup_ops.FastHashSpec,
          name="table1")

      table2 = lookup_ops.IdTableWithHashBuckets(
          vocab_table,
          oov_buckets,
          hasher_spec=lookup_ops.StrongHashSpec((1, 2)),
          name="table2")
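
      # table1 and table2 share the same vocabulary table but use different
      # hashers, so an OOV input may land in different buckets in each table.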
      lookup_ops.tables_initializer().run()

      input_string = constant_op.constant(
          ["fruit", "brain", "salad", "surgery", "UNK"])

      out1 = table1.lookup(input_string)
      out2 = table2.lookup(input_string)

      out1, out2 = self.evaluate([out1, out2])
      self.assertAllEqual([5, 0, 1, 2, 5], out1)
      self.assertAllEqual([5, 0, 1, 2, 3], out2)
      self.assertEqual(vocab_size + oov_buckets, table1.size().eval())
      self.assertEqual(vocab_size + oov_buckets, table2.size().eval())
      test_util.assert_ops_in_graph({
          "table1_Lookup/hash_bucket": "StringToHashBucketFast",
          "table2_Lookup/hash_bucket": "StringToHashBucketStrong",
      }, sess.graph)

  @test_util.run_deprecated_v1
  def testIdTableWithHashBucketsInitializationAcrossSessions(self):
    vocab_file = self._createVocabFile("feat_to_id_5.txt")
    with self.cached_session():
      default_value = -1
      vocab_size = 3
      oov_buckets = 1
      table1 = lookup_ops.IdTableWithHashBuckets(
          lookup_ops.StaticHashTable(
              lookup_ops.TextFileIdTableInitializer(
                  vocab_file, vocab_size=vocab_size), default_value),
          oov_buckets)

      table1.initializer.run()

      input_string_1 = constant_op.constant(
          ["brain", "salad", "surgery", "UNK"])

      out1 = table1.lookup(input_string_1)

      self.assertAllEqual([0, 1, 2, 3], self.evaluate(out1))
      self.assertEqual(vocab_size + oov_buckets, table1.size().eval())

    with self.cached_session():
      default_value = -1
      vocab_size = 3
      oov_buckets = 1

      # Underlying lookup table already initialized in previous session.
      # No need to call table2.initializer.run()
      table2 = lookup_ops.IdTableWithHashBuckets(
          lookup_ops.StaticHashTable(
              lookup_ops.TextFileIdTableInitializer(
                  vocab_file, vocab_size=vocab_size), default_value),
          oov_buckets)

      input_string_2 = constant_op.constant(["fruit", "salad", "UNK"])

      out2 = table2.lookup(input_string_2)

      self.assertAllEqual([3, 1, 3], self.evaluate(out2))
      self.assertEqual(vocab_size + oov_buckets, table2.size().eval())

  @test_util.run_deprecated_v1
  def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
    vocab_file = self._createVocabFile("feat_to_id_6.txt")
    with self.cached_session() as sess:
      default_value1 = -1
      vocab_size = 3
      oov_buckets = 0
      table1 = lookup_ops.IdTableWithHashBuckets(
          lookup_ops.StaticHashTable(
              lookup_ops.TextFileIdTableInitializer(
                  vocab_file, vocab_size=vocab_size), default_value1),
          oov_buckets)

      default_value2 = -2
      table2 = lookup_ops.IdTableWithHashBuckets(
          lookup_ops.StaticHashTable(
              lookup_ops.TextFileIdTableInitializer(
                  vocab_file, vocab_size=vocab_size), default_value2),
          oov_buckets)

      lookup_ops.tables_initializer().run()

      input_string_1 = constant_op.constant(
          ["brain", "salad", "surgery", "UNK"])
      input_string_2 = constant_op.constant(["fruit", "salad", "UNK"])

      out1 = table1.lookup(input_string_1)
      out2 = table2.lookup(input_string_2)

      out1, out2 = self.evaluate([out1, out2])
      self.assertAllEqual([0, 1, 2, -1], out1)
      self.assertAllEqual([-2, 1, -2], out2)
      self.assertEqual(vocab_size + oov_buckets, table1.size().eval())
      self.assertEqual(vocab_size + oov_buckets, table2.size().eval())

  @test_util.run_deprecated_v1
  def testSparseTensor(self):
    vocab_file = self._createVocabFile("feat_to_id_7.txt")
    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
    input_shape = [4, 4]
    with self.cached_session() as sess:
      sp_features = sparse_tensor.SparseTensor(
          constant_op.constant(input_indices, dtypes.int64),
          constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
                               dtypes.string),
          constant_op.constant(input_shape, dtypes.int64))

      table = lookup_ops.IdTableWithHashBuckets(
          lookup_ops.StaticHashTable(
              lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3),
              -1), 1)
      table.initializer.run()

      sp_ids = table.lookup(sp_features)

      self.assertAllEqual([5], sp_ids.values._shape_as_list())
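
      # Only the values are mapped to ids; the sparse indices and dense_shape
      # pass through unchanged, as verified below.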
      sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
          [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])

      self.assertAllEqual(input_indices, sp_ids_ind)
      self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
      self.assertAllEqual(input_shape, sp_ids_shape)

  @test_util.run_deprecated_v1
  def testInt32SparseTensor(self):
    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
    input_shape = [4, 4]
    with self.cached_session() as sess:
      sp_features = sparse_tensor.SparseTensor(
          constant_op.constant(input_indices, dtypes.int64),
          constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
          constant_op.constant(input_shape, dtypes.int64))

      table = lookup_ops.IdTableWithHashBuckets(
          lookup_ops.StaticHashTable(
              lookup_ops.KeyValueTensorInitializer(
                  (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), -1),
          1,
          key_dtype=dtypes.int32)
      table.initializer.run()

      sp_ids = table.lookup(sp_features)

      self.assertAllEqual([5], sp_ids.values._shape_as_list())

      sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
          [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])

      self.assertAllEqual(input_indices, sp_ids_ind)
      self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
      self.assertAllEqual(input_shape, sp_ids_shape)

  @test_util.run_deprecated_v1
  def testInt64SparseTensor(self):
    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
    input_shape = [4, 4]
    with self.cached_session() as sess:
      sp_features = sparse_tensor.SparseTensor(
          constant_op.constant(input_indices, dtypes.int64),
          constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
          constant_op.constant(input_shape, dtypes.int64))

      table = lookup_ops.IdTableWithHashBuckets(
          lookup_ops.StaticHashTable(
              lookup_ops.KeyValueTensorInitializer(
                  (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), -1),
          1,
          key_dtype=dtypes.int64)
      table.initializer.run()

      sp_ids = table.lookup(sp_features)

      self.assertAllEqual([5], sp_ids.values._shape_as_list())

      sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
          [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])

      self.assertAllEqual(input_indices, sp_ids_ind)
      self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
      self.assertAllEqual(input_shape, sp_ids_shape)

  def testIdTableWithHashBucketsWithInvalidHashers(self):
    vocab_file = self._createVocabFile("feat_to_id_4.txt")
    with self.cached_session():
      default_value = -1
      vocab_size = 3
      oov_buckets = 1
      lookup_table = lookup_ops.StaticHashTable(
          lookup_ops.TextFileIdTableInitializer(
              vocab_file, vocab_size=vocab_size), default_value)

      with self.assertRaises(TypeError):
        lookup_ops.IdTableWithHashBuckets(
            lookup_table, oov_buckets, hasher_spec=1)

      table = lookup_ops.IdTableWithHashBuckets(
          lookup_table,
          oov_buckets,
          hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))

      input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])

      with self.assertRaises(ValueError):
        table.lookup(input_string)

      with self.assertRaises(ValueError):
        table = lookup_ops.IdTableWithHashBuckets(
            lookup_table,
            oov_buckets,
            hasher_spec=lookup_ops.StrongHashSpec([]))

      with self.assertRaises(ValueError):
        table = lookup_ops.IdTableWithHashBuckets(
            lookup_table,
            oov_buckets,
            hasher_spec=lookup_ops.StrongHashSpec([1, 2, 3]))

      with self.assertRaises(TypeError):
        table = lookup_ops.IdTableWithHashBuckets(
            lookup_table,
            oov_buckets,
            hasher_spec=lookup_ops.StrongHashSpec([None, 2]))

  def testIdTableWithHashBucketsNoInnerTable(self):
    with self.cached_session():
      table = lookup_ops.IdTableWithHashBuckets(None, num_oov_buckets=1)
      self.assertIsNone(table.resource_handle)


class MutableHashTableOpTest(test.TestCase):

  def testMutableHashTable(self):
    with self.cached_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"])
      values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(4, self.evaluate(table.size()))

      remove_string = constant_op.constant(["tarkus", "tank"])
      self.evaluate(table.remove(remove_string))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)
      self.assertAllEqual([3], output.get_shape())

      result = self.evaluate(output)
      self.assertAllEqual([0, 1, -1], result)

      exported_keys, exported_values = table.export()

      # exported data is in the order of the internal map, i.e. undefined
      sorted_keys = np.sort(self.evaluate(exported_keys))
      sorted_values = np.sort(self.evaluate(exported_values))
      self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
      self.assertAllEqual([0, 1, 2], sorted_values)

  @test_util.run_v1_only("SaverV1")
  def testSaveRestore(self):
    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    with self.session(graph=ops.Graph()) as sess:
      v0 = variables.Variable(10.0, name="v0")
      v1 = variables.Variable(20.0, name="v1")

      default_val = -1
      keys = constant_op.constant(["b", "c", "d"], dtypes.string)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableHashTable(
          dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)

      save = saver.Saver()
      self.evaluate(variables.global_variables_initializer())

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

      self.assertAllEqual(0, self.evaluate(table.size()))
      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      val = save.save(sess, save_path)
      self.assertIsInstance(val, six.string_types)
      self.assertEqual(save_path, val)

    with self.session(graph=ops.Graph()) as sess:
      v0 = variables.Variable(-1.0, name="v0")
      v1 = variables.Variable(-1.0, name="v1")
      default_val = -1
      table = lookup_ops.MutableHashTable(
          dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
      self.evaluate(
          table.insert(
              constant_op.constant(["a", "c"], dtypes.string),
              constant_op.constant([12, 24], dtypes.int64)))
      self.assertAllEqual(2, self.evaluate(table.size()))

      save = saver.Saver()

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant(["a", "b", "c", "d", "e"],
                                          dtypes.string)
      output = table.lookup(input_string)
      self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))

  @test_util.run_v1_only("SaverV1")
  def testSaveRestoreOnlyTable(self):
    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    with self.session(graph=ops.Graph()) as sess:
      v0 = variables.Variable(10.0, name="v0")
      v1 = variables.Variable(20.0, name="v1")

      default_val = -1
      keys = constant_op.constant(["b", "c", "d"], dtypes.string)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableHashTable(
          dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)

      save = saver.Saver([table])
      self.evaluate(variables.global_variables_initializer())

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

      self.assertAllEqual(0, self.evaluate(table.size()))
      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      val = save.save(sess, save_path)
      self.assertIsInstance(val, six.string_types)
      self.assertEqual(save_path, val)

    with self.session(graph=ops.Graph()) as sess:
      default_val = -1
      table = lookup_ops.MutableHashTable(
          dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
      self.evaluate(
          table.insert(
              constant_op.constant(["a", "c"], dtypes.string),
              constant_op.constant([12, 24], dtypes.int64)))
      self.assertAllEqual(2, self.evaluate(table.size()))

      save = saver.Saver([table])

      # Restore the saved values. Only the table was given to the Saver, so
      # only the table is restored here.
      save.restore(sess, save_path)

      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant(["a", "b", "c", "d", "e"],
                                          dtypes.string)
      output = table.lookup(input_string)
      self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))

  @test_util.run_in_graph_and_eager_modes
  def testObjectSaveRestore(self):
    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    v0 = variables.Variable(10.0, name="v0")
    v1 = variables.Variable(20.0, name="v1")

    default_val = -1
    keys = constant_op.constant(["b", "c", "d"], dtypes.string)
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup_ops.MutableHashTable(
        dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)

    checkpoint = trackable.Checkpoint(table=table, v0=v0, v1=v1)
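    # Track the table alongside the variables so that all three are written
    # to (and read from) the same object-based checkpoint.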
|
|
self.evaluate([v0.initializer, v1.initializer])
|
|
|
|
# Check that the parameter nodes have been initialized.
|
|
self.assertEqual(10.0, self.evaluate(v0))
|
|
self.assertEqual(20.0, self.evaluate(v1))
|
|
|
|
self.assertAllEqual(0, self.evaluate(table.size()))
|
|
self.evaluate(table.insert(keys, values))
|
|
self.assertAllEqual(3, self.evaluate(table.size()))
|
|
|
|
save_path = checkpoint.save(save_prefix)
|
|
del table, checkpoint, v0, v1
|
|
|
|
v0 = variables.Variable(-1.0, name="v0")
|
|
v1 = variables.Variable(-1.0, name="v1")
|
|
default_val = -1
|
|
table = lookup_ops.MutableHashTable(
|
|
dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
|
|
self.evaluate(
|
|
table.insert(
|
|
constant_op.constant(["a", "c"], dtypes.string),
|
|
constant_op.constant([12, 24], dtypes.int64)))
|
|
self.assertAllEqual(2, self.evaluate(table.size()))
|
|
|
|
checkpoint = trackable.Checkpoint(table=table, v0=v0, v1=v1)
|
|
|
|
# Restore the saved values in the parameter nodes.
|
|
checkpoint.restore(save_path).run_restore_ops()
|
|
# Check that the parameter nodes have been restored.
|
|
self.assertEqual(10.0, self.evaluate(v0))
|
|
self.assertEqual(20.0, self.evaluate(v1))
|
|
|
|
self.assertAllEqual(3, self.evaluate(table.size()))
|
|
|
|
input_string = constant_op.constant(["a", "b", "c", "d", "e"],
|
|
dtypes.string)
|
|
output = table.lookup(input_string)
|
|
self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))
|
|
|
|
@test_util.run_v1_only("Multiple sessions")
|
|
def testSharing(self):
|
|
# Start a server to store the table state
|
|
server = server_lib.Server({"local0": ["localhost:0"]},
|
|
protocol="grpc",
|
|
start=True)
|
|
# Create two sessions sharing the same state
|
|
session1 = session.Session(server.target)
|
|
session2 = session.Session(server.target)
|
|
|
|
table = lookup_ops.MutableHashTable(
|
|
dtypes.int64, dtypes.string, "-", name="t1")
|
|
|
|
# Populate the table in the first session
|
|
with session1:
|
|
self.assertAllEqual(0, table.size().eval())
|
|
|
|
keys = constant_op.constant([11, 12], dtypes.int64)
|
|
values = constant_op.constant(["a", "b"])
|
|
table.insert(keys, values).run()
|
|
self.assertAllEqual(2, table.size().eval())
|
|
|
|
output = table.lookup(constant_op.constant([11, 12, 13], dtypes.int64))
|
|
self.assertAllEqual([b"a", b"b", b"-"], output.eval())
|
|
|
|
# Verify that we can access the shared data from the second session
|
|
with session2:
|
|
self.assertAllEqual(2, table.size().eval())
|
|
|
|
output = table.lookup(constant_op.constant([10, 11, 12], dtypes.int64))
|
|
self.assertAllEqual([b"-", b"a", b"b"], output.eval())
|
|
|
|

  def testMutableHashTableOfTensors(self):
    with self.cached_session():
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"])
      values = constant_op.constant([[0, 1], [2, 3], [4, 5], [6, 7]],
                                    dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(4, self.evaluate(table.size()))
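
      # Removing a key that is not present ("tank") is a no-op, so only
      # "tarkus" is removed.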
      remove_string = constant_op.constant(["tarkus", "tank"])
      self.evaluate(table.remove(remove_string))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)
      self.assertAllEqual([3, 2], output.get_shape())

      result = self.evaluate(output)
      self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result)

      exported_keys, exported_values = table.export()
      # exported data is in the order of the internal map, i.e. undefined
      sorted_keys = np.sort(self.evaluate(exported_keys))
      sorted_values = np.sort(self.evaluate(exported_values), axis=0)
      self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
      sorted_expected_values = np.sort([[4, 5], [2, 3], [0, 1]], axis=0)
      self.assertAllEqual(sorted_expected_values, sorted_values)

  def testMutableHashTableExportInsert(self):
    with self.cached_session():
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
      table1 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                           default_val)
      self.assertAllEqual(0, self.evaluate(table1.size()))
      self.evaluate(table1.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table1.size()))

      input_string = constant_op.constant(["brain", "salad", "tank"])
      expected_output = [[0, 1], [2, 3], [-1, -1]]
      output1 = table1.lookup(input_string)
      self.assertAllEqual(expected_output, self.evaluate(output1))

      exported_keys, exported_values = table1.export()
      self.assertAllEqual(3, self.evaluate(exported_keys).size)
      self.assertAllEqual(6, self.evaluate(exported_values).size)

      # Populate a second table from the exported data
      table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                           default_val)
      self.assertAllEqual(0, self.evaluate(table2.size()))
      self.evaluate(table2.insert(exported_keys, exported_values))
      self.assertAllEqual(3, self.evaluate(table2.size()))

      # Verify lookup result is still the same
      output2 = table2.lookup(input_string)
      self.assertAllEqual(expected_output, self.evaluate(output2))

  def testMutableHashTableOfTensorsInvalidShape(self):
    with self.cached_session():
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      # Shape [6] instead of [3, 2]
      values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int64)
      with self.assertRaisesOpError("Expected shape"):
        self.evaluate(table.insert(keys, values))

      # Shape [2, 3] instead of [3, 2]
      values = constant_op.constant([[0, 1, 2], [3, 4, 5]], dtypes.int64)
      with self.assertRaisesOpError("Expected shape"):
        self.evaluate(table.insert(keys, values))

      # Shape [2, 2] instead of [3, 2]
      values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
      with self.assertRaisesOpError("Expected shape"):
        self.evaluate(table.insert(keys, values))

      # Shape [3, 1] instead of [3, 2]
      values = constant_op.constant([[0], [2], [4]], dtypes.int64)
      with self.assertRaisesOpError("Expected shape"):
        self.evaluate(table.insert(keys, values))

      # Valid insert
      values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

  def testMutableHashTableInvalidDefaultValue(self):
    with self.cached_session():
      default_val = constant_op.constant([[-1, -1]], dtypes.int64)
      with self.assertRaisesOpError("Default value must be a vector"):
        table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                            default_val)
        self.assertAllEqual(0, self.evaluate(table.size()))

  def testMutableHashTableDuplicateInsert(self):
    with self.cached_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery", "brain"])
      values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)
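
      # The duplicated key "brain" keeps the value inserted last (3, not 0).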
      result = self.evaluate(output)
      self.assertAllEqual([3, 1, -1], result)

  def testMutableHashTableFindHighRank(self):
    with self.cached_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant([["brain", "salad"],
                                           ["tank", "tarkus"]])
      output = table.lookup(input_string)
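      # For scalar values, the output shape matches the [2, 2] key shape.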
      self.assertAllEqual([2, 2], output.get_shape())

      result = self.evaluate(output)
      self.assertAllEqual([[0, 1], [-1, -1]], result)

  def testMutableHashTableInsertHighRank(self):
    with self.cached_session():
      default_val = -1
      keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
      values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)
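
      # Matrix-shaped keys and values are accepted; all four key/value pairs
      # are inserted.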
      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(4, self.evaluate(table.size()))

      input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"])
      output = table.lookup(input_string)

      result = self.evaluate(output)
      self.assertAllEqual([0, 1, 3, -1], result)

  def testMutableHashTableRemoveHighRank(self):
    with self.cached_session():
      default_val = -1
      keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
      values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(4, self.evaluate(table.size()))
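
      # "tarkus" was never inserted, so only "salad" is removed.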
      remove_string = constant_op.constant(["salad", "tarkus"])
      self.evaluate(table.remove(remove_string))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"])
      output = table.lookup(input_string)

      result = self.evaluate(output)
      self.assertAllEqual([0, -1, 3, -1], result)

  def testMutableHashTableOfTensorsFindHighRank(self):
    with self.cached_session():
      default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
                                    dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant([["brain", "salad"],
                                           ["tank", "tarkus"]])
      output = table.lookup(input_string)
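      # The output shape is the key shape plus the value shape: [2, 2] keys
      # with length-3 values yield shape [2, 2, 3].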
      self.assertAllEqual([2, 2, 3], output.get_shape())

      result = self.evaluate(output)
      self.assertAllEqual(
          [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)

  def testMutableHashTableOfTensorsRemoveHighRank(self):
    with self.cached_session():
      default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
                                    dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))
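
      # Rank-2 removal keys are accepted; the absent key "tank" is ignored.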
      remove_string = constant_op.constant([["brain", "tank"]])
      self.evaluate(table.remove(remove_string))
      self.assertAllEqual(2, self.evaluate(table.size()))

      input_string = constant_op.constant([["brain", "salad"],
                                           ["surgery", "tank"]])
      output = table.lookup(input_string)
      self.assertAllEqual([2, 2, 3], output.get_shape())

      result = self.evaluate(output)
      self.assertAllEqual(
          [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], result)

  def testMultipleMutableHashTables(self):
    with self.cached_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)

      table1 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                           default_val)
      table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                           default_val)
      table3 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                           default_val)
      self.evaluate(table1.insert(keys, values))
      self.evaluate(table2.insert(keys, values))
      self.evaluate(table3.insert(keys, values))

      self.assertAllEqual(3, self.evaluate(table1.size()))
      self.assertAllEqual(3, self.evaluate(table2.size()))
      self.assertAllEqual(3, self.evaluate(table3.size()))

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output1 = table1.lookup(input_string)
      output2 = table2.lookup(input_string)
      output3 = table3.lookup(input_string)

      out1, out2, out3 = self.evaluate([output1, output2, output3])
      self.assertAllEqual([0, 1, -1], out1)
      self.assertAllEqual([0, 1, -1], out2)
      self.assertAllEqual([0, 1, -1], out3)

  def testMutableHashTableWithTensorDefault(self):
    with self.cached_session():
      default_val = constant_op.constant(-1, dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = self.evaluate(output)
      self.assertAllEqual([0, 1, -1], result)

  def testSignatureMismatch(self):
    with self.cached_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      # insert with keys of the wrong type
      with self.assertRaises(ValueError):
        self.evaluate(table.insert(constant_op.constant([4, 5, 6]), values))

      # insert with values of the wrong type
      with self.assertRaises(ValueError):
        self.evaluate(table.insert(keys, constant_op.constant(["a", "b", "c"])))
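
      # The rejected inserts must not have modified the table.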
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string_ref = variables.Variable("brain")
      input_int64_ref = variables.Variable(-1, dtype=dtypes.int64)
      self.evaluate(variables.global_variables_initializer())

      # Ref types do not produce an insert signature mismatch.
      self.evaluate(table.insert(input_string_ref, input_int64_ref))
      self.assertAllEqual(3, self.evaluate(table.size()))

      # Ref types do not produce a lookup signature mismatch.
      self.assertEqual(-1, self.evaluate(table.lookup(input_string_ref)))

      # lookup with keys of the wrong type
      input_string = constant_op.constant([1, 2, 3], dtypes.int64)
      with self.assertRaises(ValueError):
        self.evaluate(table.lookup(input_string))

      # default value of the wrong type
      with self.assertRaises(TypeError):
        lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, "UNK")

  def testMutableHashTableStringFloat(self):
    with self.cached_session():
      default_val = -1.5
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1.1, 2.2], dtypes.float32)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32,
                                          default_val)
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = self.evaluate(output)
      self.assertAllClose([0, 1.1, default_val], result)

  def testMutableHashTableIntFloat(self):
    with self.cached_session():
      default_val = -1.0
      keys = constant_op.constant([3, 7, 0], dtypes.int64)
      values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32)
      table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.float32,
                                          default_val)
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant([7, 0, 11], dtypes.int64)
      output = table.lookup(input_string)

      result = self.evaluate(output)
      self.assertAllClose([-1.2, 9.9, default_val], result)

  def testMutableHashTableInt64String(self):
    with self.cached_session():
      default_val = "n/a"
      keys = constant_op.constant([0, 1, 2], dtypes.int64)
      values = constant_op.constant(["brain", "salad", "surgery"])
      table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.string,
                                          default_val)
      self.assertAllEqual(0, self.evaluate(table.size()))

      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant([0, 1, 3], dtypes.int64)
      output = table.lookup(input_string)

      result = self.evaluate(output)
      self.assertAllEqual((b"brain", b"salad", b"n/a"), result)


class MutableHashTableBenchmark(test.Benchmark):

  def _create_table(self):
    return lookup_ops.MutableHashTable(dtypes.int64, dtypes.float32, 0.0)

  def benchmark_single_repeated_scalar_insert_scalar(self):
    table = self._create_table()
    value = variables.Variable(1.0)
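    # Every iteration re-inserts key 0, so the table never grows beyond one
    # entry.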
    insert = table.insert(0, value)
    size = table.size()
    with session.Session() as sess:
      sess.run(value.initializer)
      self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000)
      assert sess.run(size) == 1

  def benchmark_many_repeated_scalar_insert_scalar(self):
    table = self._create_table()
    c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next()
    value = variables.Variable(1.0)
    insert = table.insert(c, value)
    size = table.size()
    with session.Session() as sess:
      sess.run(value.initializer)
      self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000)
      assert sess.run(size) >= 10000

  def benchmark_single_repeated_batch_32_insert_scalar(self):
    table = self._create_table()
    value = variables.Variable([1.0] * 32)
    insert = table.insert(list(range(32)), value)
    size = table.size()
    with session.Session() as sess:
      sess.run(value.initializer)
      self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000)
      assert sess.run(size) == 32

  def benchmark_many_repeated_batch_32_insert_scalar(self):
    table = self._create_table()
    c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next()
    value = variables.Variable([1.0] * 32)
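    # Each iteration inserts a fresh, disjoint batch of 32 keys,
    # [32*c, 32*c + 31], for the c-th value drawn from the counter.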
    insert = table.insert(32 * c + list(range(32)), value)
    size = table.size()
    with session.Session() as sess:
      sess.run(value.initializer)
      self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000)
      assert sess.run(size) >= 1000 * 32


class DenseHashTableBenchmark(MutableHashTableBenchmark):

  def _create_table(self):
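    # DenseHashTable reserves empty_key and deleted_key as internal sentinels;
    # they must be distinct and must never be inserted or looked up.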
    return lookup_ops.DenseHashTable(
        dtypes.int64,
        dtypes.float32,
        default_value=0.0,
        empty_key=-1,
        deleted_key=-2)


if __name__ == "__main__":
  test.main()