STT-tensorflow/tensorflow/contrib/lookup/lookup_ops_test.py
A. Unique TensorFlower 9654cf9e2d Add tf.tables_initializer as a replacement for tf.initialize_all_tables, and
update callers in tensorflow.

Deprecated tf.initialize_all_tables.
Change: 144575599
2017-01-15 09:07:24 -08:00

2027 lines
77 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.lookup.lookup_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import numpy as np
import six
from tensorflow.contrib.lookup import lookup_ops
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
class HashTableOpTest(test.TestCase):
  """Tests for the immutable `lookup_ops.HashTable`.

  Covers initialization (tensors, Python lists, NumPy arrays), lookup with
  default values for missing keys, high-rank and SparseTensor inputs, dtype
  validation, and multi-session initialization semantics.
  """

  def testHashTable(self):
    """Init/size/lookup round trip; missing keys map to the default value."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
      table.init.run()

      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)
      # Lookup output shape follows the query shape.
      self.assertAllEqual([3], output.get_shape())

      result = output.eval()
      self.assertAllEqual([0, 1, -1], result)

  def testHashTableFindHighRank(self):
    """Lookup with a rank-2 query preserves the query's shape."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
      table.init.run()

      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(
          [["brain", "salad"], ["tank", "tarkus"]])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual([[0, 1], [-1, -1]], result)

  def testHashTableInitWithPythonArrays(self):
    """Initializer accepts plain Python lists when value_dtype is given."""
    with self.test_session():
      default_val = -1
      keys = ["brain", "salad", "surgery"]
      values = [0, 1, 2]
      table = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(
              keys, values, value_dtype=dtypes.int64),
          default_val)
      table.init.run()

      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual([0, 1, -1], result)

  def testHashTableInitWithNumPyArrays(self):
    """Initializer accepts NumPy arrays for keys and values."""
    with self.test_session():
      default_val = -1
      # Use the builtin `str` here: `np.str` was merely a deprecated alias
      # for it and has been removed in NumPy >= 1.24.
      keys = np.array(["brain", "salad", "surgery"], dtype=str)
      values = np.array([0, 1, 2], dtype=np.int64)
      table = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
      table.init.run()

      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual([0, 1, -1], result)

  def testMultipleHashTables(self):
    """tables_initializer() initializes every table in the graph at once."""
    with self.test_session() as sess:
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)

      table1 = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
      table2 = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
      table3 = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)

      data_flow_ops.tables_initializer().run()
      self.assertAllEqual(3, table1.size().eval())
      self.assertAllEqual(3, table2.size().eval())
      self.assertAllEqual(3, table3.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output1 = table1.lookup(input_string)
      output2 = table2.lookup(input_string)
      output3 = table3.lookup(input_string)

      out1, out2, out3 = sess.run([output1, output2, output3])
      self.assertAllEqual([0, 1, -1], out1)
      self.assertAllEqual([0, 1, -1], out2)
      self.assertAllEqual([0, 1, -1], out3)

  def testHashTableWithTensorDefault(self):
    """The default value may be a Tensor rather than a Python scalar."""
    with self.test_session():
      default_val = constant_op.constant(-1, dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
      table.init.run()

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual([0, 1, -1], result)

  def testHashTableWithSparseTensorInput(self):
    """Lookup on a SparseTensor maps values, keeping indices and shape."""
    with self.test_session() as sess:
      default_val = constant_op.constant(-1, dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
      table.init.run()

      sp_indices = [[0, 0], [0, 1], [1, 0]]
      sp_shape = [2, 2]
      input_tensor = sparse_tensor.SparseTensor(
          constant_op.constant(sp_indices, dtypes.int64),
          constant_op.constant(["brain", "salad", "tank"]),
          constant_op.constant(sp_shape, dtypes.int64))
      output = table.lookup(input_tensor)

      out_indices, out_values, out_shape = sess.run(output)

      self.assertAllEqual([0, 1, -1], out_values)
      self.assertAllEqual(sp_indices, out_indices)
      self.assertAllEqual(sp_shape, out_shape)

  def testSignatureMismatch(self):
    """Wrong key dtype or default-value dtype raises TypeError."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
      table.init.run()

      # Querying with int64 keys against a string-keyed table.
      input_string = constant_op.constant([1, 2, 3], dtypes.int64)
      with self.assertRaises(TypeError):
        table.lookup(input_string)

      # Default value dtype must match the value dtype.
      with self.assertRaises(TypeError):
        lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer(keys, values), "UNK")

  def testDTypes(self):
    """Mismatched initializer dtypes raise TypeError at construction."""
    with self.test_session():
      default_val = -1
      with self.assertRaises(TypeError):
        lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer(["a"], [1], [dtypes.string],
                                                 dtypes.int64), default_val)

  def testNotInitialized(self):
    """Lookup before running the initializer fails at run time."""
    with self.test_session():
      default_val = -1
      table = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(
              ["a"], [1], value_dtype=dtypes.int64),
          default_val)

      input_string = constant_op.constant(["brain", "salad", "surgery"])
      output = table.lookup(input_string)

      with self.assertRaisesOpError("Table not initialized"):
        output.eval()

  def testInitializeTwice(self):
    """Running the initializer twice in one session is an error."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.HashTable(
          lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
      table.init.run()

      with self.assertRaisesOpError("Table already initialized"):
        table.init.run()

  def testInitializationWithInvalidDimensions(self):
    """Keys and values of different lengths raise ValueError."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
      with self.assertRaises(ValueError):
        lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer(keys, values), default_val)

  def testMultipleSessions(self):
    """Re-initializing the same shared table from a second session is a no-op,
    not a "Table already initialized" error."""
    # Start a server
    server = server_lib.Server(
        {
            "local0": ["localhost:0"]
        }, protocol="grpc", start=True)

    # Create two sessions sharing the same state
    session1 = session.Session(server.target)
    session2 = session.Session(server.target)

    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup_ops.HashTable(
        lookup_ops.KeyValueTensorInitializer(keys, values),
        default_val,
        name="t1")

    # Init the table in the first session.
    with session1:
      table.init.run()
      self.assertAllEqual(3, table.size().eval())

    # Init the table in the second session and verify that we do not get a
    # "Table already initialized" error.
    with session2:
      table.init.run()
      self.assertAllEqual(3, table.size().eval())
class MutableHashTableOpTest(test.TestCase):
  """Tests for `lookup_ops.MutableHashTable`.

  Covers insert/lookup/export, checkpoint save/restore, cross-session state
  sharing via a gRPC server, tensor-valued entries, shape and dtype
  validation, and several key/value dtype combinations.
  """

  def testMutableHashTable(self):
    """Insert/size/lookup/export round trip with scalar int64 values."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)
      self.assertAllEqual([3], output.get_shape())

      result = output.eval()
      self.assertAllEqual([0, 1, -1], result)

      exported_keys, exported_values = table.export()
      # Export shapes are unknown statically (table size is dynamic).
      self.assertAllEqual([None], exported_keys.get_shape().as_list())
      self.assertAllEqual([None], exported_values.get_shape().as_list())

      # exported data is in the order of the internal map, i.e. undefined
      sorted_keys = np.sort(exported_keys.eval())
      sorted_values = np.sort(exported_values.eval())
      self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
      self.assertAllEqual([0, 1, 2], sorted_values)

  def testSaveRestore(self):
    """Checkpointed table contents survive save/restore; restore also
    overwrites entries inserted before the restore."""
    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    with self.test_session(graph=ops.Graph()) as sess:
      v0 = variables.Variable(10.0, name="v0")
      v1 = variables.Variable(20.0, name="v1")

      default_val = -1
      keys = constant_op.constant(["b", "c", "d"], dtypes.string)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableHashTable(
          dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)

      save = saver.Saver()
      variables.global_variables_initializer().run()

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, v0.eval())
      self.assertEqual(20.0, v1.eval())

      self.assertAllEqual(0, table.size().eval())
      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      val = save.save(sess, save_path)
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path, val)

    with self.test_session(graph=ops.Graph()) as sess:
      v0 = variables.Variable(-1.0, name="v0")
      v1 = variables.Variable(-1.0, name="v1")
      default_val = -1
      table = lookup_ops.MutableHashTable(
          dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)

      # Pre-populate with different data; restore should replace it.
      table.insert(
          constant_op.constant(["a", "c"], dtypes.string),
          constant_op.constant([12, 24], dtypes.int64)).run()
      self.assertAllEqual(2, table.size().eval())

      save = saver.Saver()

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, v0.eval())
      self.assertEqual(20.0, v1.eval())

      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["a", "b", "c", "d", "e"],
                                          dtypes.string)
      output = table.lookup(input_string)
      self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())

  def testSharing(self):
    """Two sessions on the same server share one table's state."""
    # Start a server to store the table state
    server = server_lib.Server(
        {
            "local0": ["localhost:0"]
        }, protocol="grpc", start=True)
    # Create two sessions sharing the same state
    session1 = session.Session(server.target)
    session2 = session.Session(server.target)

    table = lookup_ops.MutableHashTable(
        dtypes.int64, dtypes.string, "-", name="t1")

    # Populate the table in the first session
    with session1:
      self.assertAllEqual(0, table.size().eval())

      keys = constant_op.constant([11, 12], dtypes.int64)
      values = constant_op.constant(["a", "b"])
      table.insert(keys, values).run()
      self.assertAllEqual(2, table.size().eval())

      output = table.lookup(constant_op.constant([11, 12, 13], dtypes.int64))
      self.assertAllEqual([b"a", b"b", b"-"], output.eval())

    # Verify that we can access the shared data from the second session
    with session2:
      self.assertAllEqual(2, table.size().eval())

      output = table.lookup(constant_op.constant([10, 11, 12], dtypes.int64))
      self.assertAllEqual([b"-", b"a", b"b"], output.eval())

  def testMutableHashTableOfTensors(self):
    """Vector-valued entries; lookup output shape is query shape + [2]."""
    with self.test_session():
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)
      self.assertAllEqual([3, 2], output.get_shape())

      result = output.eval()
      self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result)

      exported_keys, exported_values = table.export()
      self.assertAllEqual([None], exported_keys.get_shape().as_list())
      self.assertAllEqual([None, 2], exported_values.get_shape().as_list())
      # exported data is in the order of the internal map, i.e. undefined
      sorted_keys = np.sort(exported_keys.eval())
      sorted_values = np.sort(exported_values.eval())
      self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
      self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values)

  def testMutableHashTableExportInsert(self):
    """A second table populated from export() yields identical lookups."""
    with self.test_session():
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
      table1 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                           default_val)
      self.assertAllEqual(0, table1.size().eval())
      table1.insert(keys, values).run()
      self.assertAllEqual(3, table1.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      expected_output = [[0, 1], [2, 3], [-1, -1]]
      output1 = table1.lookup(input_string)
      self.assertAllEqual(expected_output, output1.eval())

      exported_keys, exported_values = table1.export()
      self.assertAllEqual(3, exported_keys.eval().size)
      self.assertAllEqual(6, exported_values.eval().size)

      # Populate a second table from the exported data
      table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                           default_val)
      self.assertAllEqual(0, table2.size().eval())
      table2.insert(exported_keys, exported_values).run()
      self.assertAllEqual(3, table2.size().eval())

      # Verify lookup result is still the same
      output2 = table2.lookup(input_string)
      self.assertAllEqual(expected_output, output2.eval())

  def testMutableHashTableOfTensorsInvalidShape(self):
    """Inserts whose value shape doesn't match the default's fail."""
    with self.test_session():
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      # Shape [6] instead of [3, 2]
      values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int64)
      with self.assertRaisesOpError("Expected shape"):
        table.insert(keys, values).run()

      # Shape [2,3] instead of [3, 2]
      values = constant_op.constant([[0, 1, 2], [3, 4, 5]], dtypes.int64)
      with self.assertRaisesOpError("Expected shape"):
        table.insert(keys, values).run()

      # Shape [2, 2] instead of [3, 2]
      values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
      with self.assertRaisesOpError("Expected shape"):
        table.insert(keys, values).run()

      # Shape [3, 1] instead of [3, 2]
      values = constant_op.constant([[0], [2], [4]], dtypes.int64)
      with self.assertRaisesOpError("Expected shape"):
        table.insert(keys, values).run()

      # Valid Insert
      values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

  def testMutableHashTableInvalidDefaultValue(self):
    """A rank-2 default value is rejected (must be scalar or vector)."""
    with self.test_session():
      default_val = constant_op.constant([[-1, -1]], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)
      with self.assertRaisesOpError("Default value must be a vector"):
        self.assertAllEqual(0, table.size().eval())

  def testMutableHashTableDuplicateInsert(self):
    """A later insert for the same key overwrites the earlier value."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery", "brain"])
      values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = output.eval()
      # "brain" was inserted twice; the second value (3) wins.
      self.assertAllEqual([3, 1, -1], result)

  def testMutableHashTableFindHighRank(self):
    """Lookup with a rank-2 query preserves the query's shape."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(
          [["brain", "salad"], ["tank", "tarkus"]])
      output = table.lookup(input_string)
      self.assertAllEqual([2, 2], output.get_shape())

      result = output.eval()
      self.assertAllEqual([[0, 1], [-1, -1]], result)

  def testMutableHashTableInsertHighRank(self):
    """Rank-2 keys/values are flattened on insert."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
      values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      table.insert(keys, values).run()
      self.assertAllEqual(4, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual([0, 1, 3, -1], result)

  def testMutableHashTableOfTensorsFindHighRank(self):
    """Rank-2 query with vector values: output shape is [2, 2, 3]."""
    with self.test_session():
      default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
                                    dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(
          [["brain", "salad"], ["tank", "tarkus"]])
      output = table.lookup(input_string)
      self.assertAllEqual([2, 2, 3], output.get_shape())

      result = output.eval()
      self.assertAllEqual(
          [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)

  def testMultipleMutableHashTables(self):
    """Several independent tables coexist in one graph/session."""
    with self.test_session() as sess:
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)

      table1 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                           default_val)
      table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                           default_val)
      table3 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                           default_val)
      table1.insert(keys, values).run()
      table2.insert(keys, values).run()
      table3.insert(keys, values).run()
      self.assertAllEqual(3, table1.size().eval())
      self.assertAllEqual(3, table2.size().eval())
      self.assertAllEqual(3, table3.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output1 = table1.lookup(input_string)
      output2 = table2.lookup(input_string)
      output3 = table3.lookup(input_string)

      out1, out2, out3 = sess.run([output1, output2, output3])
      self.assertAllEqual([0, 1, -1], out1)
      self.assertAllEqual([0, 1, -1], out2)
      self.assertAllEqual([0, 1, -1], out3)

  def testMutableHashTableWithTensorDefault(self):
    """The default value may be a Tensor rather than a Python scalar."""
    with self.test_session():
      default_val = constant_op.constant(-1, dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual([0, 1, -1], result)

  def testSignatureMismatch(self):
    """Wrong dtypes for insert, lookup, or default value raise TypeError."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)

      # insert with keys of the wrong type
      with self.assertRaises(TypeError):
        table.insert(constant_op.constant([4, 5, 6]), values).run()

      # insert with values of the wrong type
      with self.assertRaises(TypeError):
        table.insert(keys, constant_op.constant(["a", "b", "c"])).run()

      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      # lookup with keys of the wrong type
      input_string = constant_op.constant([1, 2, 3], dtypes.int64)
      with self.assertRaises(TypeError):
        table.lookup(input_string).eval()

      # default value of the wrong type
      with self.assertRaises(TypeError):
        lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, "UNK")

  def testMutableHashTableStringFloat(self):
    """String keys with float32 values and a float default."""
    with self.test_session():
      default_val = -1.5
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1.1, 2.2], dtypes.float32)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32,
                                          default_val)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllClose([0, 1.1, -1.5], result)

  def testMutableHashTableInt64String(self):
    """int64 keys with string values and a string default."""
    with self.test_session():
      default_val = "n/a"
      keys = constant_op.constant([0, 1, 2], dtypes.int64)
      values = constant_op.constant(["brain", "salad", "surgery"])
      table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.string,
                                          default_val)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant([0, 1, 3], dtypes.int64)
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual((b"brain", b"salad", b"n/a"), result)
class MutableDenseHashTableOpTest(test.TestCase):
def testBasic(self):
with self.test_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
input_string = constant_op.constant([11, 12, 15], dtypes.int64)
output = table.lookup(input_string)
self.assertAllEqual([3], output.get_shape())
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
def testLookupUnknownShape(self):
with self.test_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
placeholder_keys = array_ops.placeholder(dtypes.int64)
output = table.lookup(placeholder_keys)
self.assertAllEqual(None, output.get_shape())
result = output.eval({placeholder_keys: [11, 12, 15]})
self.assertAllEqual([0, 1, -1], result)
def testMapStringToFloat(self):
with self.test_session():
keys = constant_op.constant(["a", "b", "c"], dtypes.string)
values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32)
default_value = constant_op.constant(-1.5, dtypes.float32)
table = lookup_ops.MutableDenseHashTable(
dtypes.string,
dtypes.float32,
default_value=default_value,
empty_key="")
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
input_string = constant_op.constant(["a", "b", "d"], dtypes.string)
output = table.lookup(input_string)
self.assertAllEqual([3], output.get_shape())
result = output.eval()
self.assertAllClose([0, 1.1, -1.5], result)
def testMapInt64ToFloat(self):
for float_dtype in [dtypes.float32, dtypes.float64]:
with self.test_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0.0, 1.1, 2.2], float_dtype)
default_value = constant_op.constant(-1.5, float_dtype)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64, float_dtype, default_value=default_value, empty_key=0)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
input_string = constant_op.constant([11, 12, 15], dtypes.int64)
output = table.lookup(input_string)
self.assertAllEqual([3], output.get_shape())
result = output.eval()
self.assertAllClose([0, 1.1, -1.5], result)
def testVectorValues(self):
with self.test_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]],
dtypes.int64)
default_value = constant_op.constant([-1, -2, -3, -4], dtypes.int64)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=0,
initial_num_buckets=4)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(4, len(table.export()[0].eval()))
table.insert(
constant_op.constant([14], dtypes.int64),
constant_op.constant([[2, 3, 4, 5]], dtypes.int64)).run()
self.assertAllEqual(4, table.size().eval())
self.assertAllEqual(8, len(table.export()[0].eval()))
input_string = constant_op.constant([11, 12, 15], dtypes.int64)
output = table.lookup(input_string)
self.assertAllEqual([3, 4], output.get_shape())
result = output.eval()
self.assertAllEqual([[0, 1, 2, 3], [3, 4, 5, 6], [-1, -2, -3, -4]],
result)
def testVectorKeys(self):
with self.test_session():
keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64)
values = constant_op.constant([10, 11, 12], dtypes.int64)
empty_key = constant_op.constant([0, 3], dtypes.int64)
default_value = constant_op.constant(-1, dtypes.int64)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
initial_num_buckets=8)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
table.insert(
constant_op.constant([[0, 0]], dtypes.int64),
constant_op.constant([13], dtypes.int64)).run()
self.assertAllEqual(4, table.size().eval())
self.assertAllEqual(8, len(table.export()[0].eval()))
input_string = constant_op.constant([[0, 1], [1, 2], [0, 2]],
dtypes.int64)
output = table.lookup(input_string)
self.assertAllEqual([3], output.get_shape())
result = output.eval()
self.assertAllEqual([10, 11, -1], result)
def testResize(self):
with self.test_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=-1,
empty_key=0,
initial_num_buckets=4)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(4, len(table.export()[0].eval()))
keys2 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64)
values2 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64)
table.insert(keys2, values2).run()
self.assertAllEqual(7, table.size().eval())
self.assertAllEqual(16, len(table.export()[0].eval()))
keys3 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18],
dtypes.int64)
output = table.lookup(keys3)
self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval())
def testExport(self):
with self.test_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([1, 2, 3], dtypes.int64)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=-1,
empty_key=100,
initial_num_buckets=8)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
exported_keys, exported_values = table.export()
self.assertAllEqual([None], exported_keys.get_shape().as_list())
self.assertAllEqual([None], exported_values.get_shape().as_list())
np_keys = exported_keys.eval()
np_values = exported_values.eval()
self.assertAllEqual(8, len(np_keys))
self.assertAllEqual(8, len(np_values))
# pair up keys and values, drop extra added dimension
pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0]
# sort by key
pairs = pairs[pairs[:, 0].argsort()]
self.assertAllEqual([[11, 1], [12, 2], [13, 3], [100, 0], [100, 0],
[100, 0], [100, 0], [100, 0]], pairs)
  def testSaveRestore(self):
    """Checkpointed dense table survives save/restore.

    Restore replaces entries inserted beforehand and also restores the
    saved bucket count (32), overriding the new graph's initial 64.
    """
    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    with self.test_session(graph=ops.Graph()) as sess:
      default_value = -1
      empty_key = 0
      keys = constant_op.constant([11, 12, 13], dtypes.int64)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableDenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=default_value,
          empty_key=empty_key,
          name="t1",
          checkpoint=True,
          initial_num_buckets=32)

      save = saver.Saver()

      self.assertAllEqual(0, table.size().eval())
      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())
      self.assertAllEqual(32, len(table.export()[0].eval()))

      val = save.save(sess, save_path)
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path, val)

    with self.test_session(graph=ops.Graph()) as sess:
      table = lookup_ops.MutableDenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=default_value,
          empty_key=empty_key,
          name="t1",
          checkpoint=True,
          initial_num_buckets=64)

      # Pre-populate with different data; restore should replace it.
      table.insert(
          constant_op.constant([11, 14], dtypes.int64),
          constant_op.constant([12, 24], dtypes.int64)).run()
      self.assertAllEqual(2, table.size().eval())
      self.assertAllEqual(64, len(table.export()[0].eval()))

      save = saver.Saver()

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)

      self.assertAllEqual(3, table.size().eval())
      # Bucket count from the checkpoint (32) replaces the initial 64.
      self.assertAllEqual(32, len(table.export()[0].eval()))

      input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
def testVectorSaveRestore(self):
save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
with self.test_session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
default_value = constant_op.constant([-1, -2], dtypes.int64)
keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64)
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
name="t1",
checkpoint=True,
initial_num_buckets=32)
save = saver.Saver()
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))
val = save.save(sess, save_path)
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
with self.test_session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
default_value = constant_op.constant([-1, -2], dtypes.int64)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
name="t1",
checkpoint=True,
initial_num_buckets=64)
table.insert(
constant_op.constant([[11, 12], [13, 15]], dtypes.int64),
constant_op.constant([[21, 22], [23, 24]], dtypes.int64)).run()
self.assertAllEqual(2, table.size().eval())
self.assertAllEqual(64, len(table.export()[0].eval()))
save = saver.Saver()
# Restore the saved values in the parameter nodes.
save.restore(sess, save_path)
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))
input_string = constant_op.constant(
[[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64)
output = table.lookup(input_string)
self.assertAllEqual([[0, 1], [2, 3], [-1, -2], [4, 5], [-1, -2]],
output.eval())
def testVectorScalarSaveRestore(self):
save_dir = os.path.join(self.get_temp_dir(), "vector_scalar_save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
with self.test_session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
default_value = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
name="t2",
checkpoint=True,
initial_num_buckets=32)
save = saver.Saver()
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))
val = save.save(sess, save_path)
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
with self.test_session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
default_value = constant_op.constant(-1, dtypes.int64)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
name="t2",
checkpoint=True,
initial_num_buckets=64)
table.insert(
constant_op.constant([[11, 12], [13, 15]], dtypes.int64),
constant_op.constant([3, 4], dtypes.int64)).run()
self.assertAllEqual(2, table.size().eval())
self.assertAllEqual(64, len(table.export()[0].eval()))
save = saver.Saver()
# Restore the saved values in the parameter nodes.
save.restore(sess, save_path)
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))
input_string = constant_op.constant(
[[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64)
output = table.lookup(input_string)
self.assertAllEqual([0, 1, -1, 2, -1], output.eval())
def testReprobe(self):
with self.test_session():
# Insert 6 keys into a table with 8 buckets.
# The values are chosen to make sure collisions occur when using GCC STL
keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64)
values = constant_op.constant([51, 52, 53, 54, 55, 56], dtypes.int64)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=-1,
empty_key=0,
initial_num_buckets=8)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(6, table.size().eval())
input_string = constant_op.constant([10, 11, 12, 13, 14, 19, 20, 21, 22],
dtypes.int64)
output = table.lookup(input_string)
self.assertAllEqual([9], output.get_shape())
result = output.eval()
self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result)
def testCustomEmptyKey(self):
with self.test_session():
keys = constant_op.constant([11, 0, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.MutableDenseHashTable(
dtypes.int64, dtypes.int64, default_value=-1, empty_key=12)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
input_string = constant_op.constant([11, 0, 15], dtypes.int64)
output = table.lookup(input_string)
self.assertAllEqual([3], output.get_shape())
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
  def testErrors(self):
    """Exercises MutableDenseHashTable error cases.

    Covers: inserting the reserved empty_key, looking it up, key tensors
    whose shape the table rejects, and an invalid initial_num_buckets.
    """
    with self.test_session():
      table = lookup_ops.MutableDenseHashTable(
          dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)

      # Inserting the empty key returns an error
      keys = constant_op.constant([11, 0], dtypes.int64)
      values = constant_op.constant([0, 1], dtypes.int64)
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "empty_key"):
        table.insert(keys, values).run()

      # Looking up the empty key returns an error
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "empty_key"):
        table.lookup(keys).eval()

      # Arbitrary tensors of keys are not supported
      keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64)
      values = constant_op.constant([[11, 0], [12, 1]], dtypes.int64)
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "Expected key shape"):
        table.lookup(keys).eval()
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "Expected key shape"):
        table.insert(keys, values).run()

      # initial_num_buckets=12 is rejected; the size() call forces table
      # creation, which is where the validation fires.
      table2 = lookup_ops.MutableDenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=-1,
          empty_key=17,
          initial_num_buckets=12)
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "Number of buckets must be"):
        self.assertAllEqual(0, table2.size().eval())
class StringToIndexTableFromFile(test.TestCase):
  """Tests for string_to_index_table_from_file.

  Each test builds a table from a small vocabulary file and checks lookups
  both before initialization (expected to fail) and after.
  """

  def _createVocabFile(self, basename):
    # Writes "brain\nsalad\nsurgery\n" into the temp dir and returns the path.
    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(["brain", "salad", "surgery"]) + "\n")
    return vocabulary_file

  def test_string_to_index_table_from_file(self):
    """An OOV word falls into the single extra bucket (id vocab_size=3)."""
    vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
    with self.test_session():
      table = lookup_ops.string_to_index_table_from_file(
          vocabulary_file=vocabulary_file, num_oov_buckets=1)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      # Lookup before the table is initialized must fail.
      self.assertRaises(errors_impl.OpError, ids.eval)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((1, 2, 3), ids.eval())

  def test_string_to_index_table_from_file_with_default_value(self):
    """Without OOV buckets, misses return default_value."""
    default_value = -42
    vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
    with self.test_session():
      table = lookup_ops.string_to_index_table_from_file(
          vocabulary_file=vocabulary_file, default_value=default_value)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      self.assertRaises(errors_impl.OpError, ids.eval)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((1, 2, default_value), ids.eval())

  def test_string_to_index_table_from_file_with_oov_buckets(self):
    """OOV ids are vocab_size + fingerprint(word) mod num_oov_buckets."""
    vocabulary_file = self._createVocabFile("f2i_vocab3.txt")
    with self.test_session():
      table = lookup_ops.string_to_index_table_from_file(
          vocabulary_file=vocabulary_file, num_oov_buckets=1000)
      ids = table.lookup(
          constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))

      self.assertRaises(errors_impl.OpError, ids.eval)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual(
          (
              1,  # From vocabulary file.
              2,  # From vocabulary file.
              867,  # 3 + fingerprint("tarkus") mod 1000.
              860),  # 3 + fingerprint("toccata") mod 1000.
          ids.eval())

  def test_string_to_index_table_from_file_with_only_oov_buckets(self):
    """A None vocabulary_file is rejected."""
    self.assertRaises(
        ValueError,
        lookup_ops.string_to_index_table_from_file,
        vocabulary_file=None)

  def test_string_to_index_table_from_file_with_vocab_size_too_small(self):
    """vocab_size=2 truncates the file; the third word becomes a miss."""
    vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
    with self.test_session():
      table = lookup_ops.string_to_index_table_from_file(
          vocabulary_file=vocabulary_file, vocab_size=2)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      self.assertRaises(errors_impl.OpError, ids.eval)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((1, -1, -1), ids.eval())
      self.assertEqual(2, table.size().eval())

  def test_string_to_index_table_from_file_with_vocab_size_too_large(self):
    """vocab_size larger than the file fails at initialization."""
    vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
    with self.test_session():
      table = lookup_ops.string_to_index_table_from_file(
          vocabulary_file=vocabulary_file, vocab_size=4)
      self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                              "Invalid vocab_size", table.init.run)

  def test_string_to_index_table_from_file_with_vocab_size(self):
    """vocab_size=0 is rejected; an exact vocab_size works normally."""
    vocabulary_file = self._createVocabFile("f2i_vocab7.txt")

    self.assertRaises(
        ValueError,
        lookup_ops.string_to_index_table_from_file,
        vocabulary_file=vocabulary_file,
        vocab_size=0)

    with self.test_session():
      table = lookup_ops.string_to_index_table_from_file(
          vocabulary_file=vocabulary_file, vocab_size=3)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      self.assertRaises(errors_impl.OpError, ids.eval)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((1, 2, -1), ids.eval())
      self.assertEqual(3, table.size().eval())

  def test_string_to_index_table_from_file_with_invalid_hashers(self):
    """Bad hasher_spec types/keys raise at construction or lookup time."""
    vocabulary_file = self._createVocabFile("invalid_hasher.txt")
    with self.test_session():
      with self.assertRaises(TypeError):
        lookup_ops.string_to_index_table_from_file(
            vocabulary_file=vocabulary_file,
            vocab_size=3,
            num_oov_buckets=1,
            hasher_spec=1)

      table = lookup_ops.string_to_index_table_from_file(
          vocabulary_file=vocabulary_file,
          vocab_size=3,
          num_oov_buckets=1,
          hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))

      self.assertRaises(ValueError, table.lookup,
                        constant_op.constant(["salad", "surgery", "tarkus"]))
class StringToIndexTableFromTensor(test.TestCase):
  """Tests for string_to_index_table_from_tensor (in-memory mapping)."""

  def test_string_to_index_table_from_tensor_with_tensor_init(self):
    """An OOV word falls into the single extra bucket (id 3)."""
    with self.test_session():
      table = lookup_ops.string_to_index_table_from_tensor(
          mapping=["brain", "salad", "surgery"], num_oov_buckets=1)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      # Lookup before the table is initialized must fail.
      self.assertRaises(errors_impl.OpError, ids.eval)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((1, 2, 3), ids.eval())

  def test_string_to_index_table_from_tensor_with_default_value(self):
    """Without OOV buckets, misses return default_value."""
    default_value = -42
    with self.test_session():
      table = lookup_ops.string_to_index_table_from_tensor(
          mapping=["brain", "salad", "surgery"], default_value=default_value)
      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

      self.assertRaises(errors_impl.OpError, ids.eval)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((1, 2, default_value), ids.eval())

  def test_string_to_index_table_from_tensor_with_only_oov_buckets(self):
    """A None mapping is rejected."""
    with self.test_session():
      with self.assertRaises(ValueError):
        lookup_ops.string_to_index_table_from_tensor(
            mapping=None, num_oov_buckets=1)

  def test_string_to_index_table_from_tensor_with_invalid_hashers(self):
    """Bad hasher_spec types/keys raise at construction or lookup time."""
    with self.test_session():
      with self.assertRaises(TypeError):
        lookup_ops.string_to_index_table_from_tensor(
            mapping=["brain", "salad", "surgery"],
            num_oov_buckets=1,
            hasher_spec=1)

      table = lookup_ops.string_to_index_table_from_tensor(
          mapping=["brain", "salad", "surgery"],
          num_oov_buckets=1,
          hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))

      self.assertRaises(ValueError, table.lookup,
                        constant_op.constant(["salad", "surgery", "tarkus"]))
class StringToIndexTest(test.TestCase):
  """Tests for the string_to_index convenience wrapper."""

  def test_string_to_index(self):
    with self.test_session():
      vocab = constant_op.constant(["brain", "salad", "surgery"])
      query = constant_op.constant(["salad", "surgery", "tarkus"])
      indices = lookup_ops.string_to_index(query, mapping=vocab)
      # Lookup before initialization must fail.
      self.assertRaises(errors_impl.OpError, indices.eval)

      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((1, 2, -1), indices.eval())

  def test_duplicate_entries(self):
    with self.test_session():
      vocab = constant_op.constant(["hello", "hello"])
      query = constant_op.constant(["hello", "hola"])
      indices = lookup_ops.string_to_index(query, mapping=vocab)
      # Duplicate vocabulary entries make initialization fail.
      self.assertRaises(errors_impl.OpError,
                        data_flow_ops.tables_initializer().run)

  def test_string_to_index_with_default_value(self):
    missing_id = -42
    with self.test_session():
      vocab = constant_op.constant(["brain", "salad", "surgery"])
      query = constant_op.constant(["salad", "surgery", "tarkus"])
      indices = lookup_ops.string_to_index(
          query, mapping=vocab, default_value=missing_id)
      self.assertRaises(errors_impl.OpError, indices.eval)

      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((1, 2, missing_id), indices.eval())
class IndexToStringTableFromFileTest(test.TestCase):
  """Tests for index_to_string_table_from_file."""

  def _createVocabFile(self, basename):
    # Writes "brain\nsalad\nsurgery\n" into the temp dir and returns the path.
    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(["brain", "salad", "surgery"]) + "\n")
    return vocabulary_file

  def test_index_to_string_table(self):
    """In-range ids map to lines; out-of-range ids map to "UNK"."""
    vocabulary_file = self._createVocabFile("i2f_vocab1.txt")
    with self.test_session():
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file)
      features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
      # Lookup before the table is initialized must fail.
      self.assertRaises(errors_impl.OpError, features.eval)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
                          features.eval())

  def test_index_to_string_table_with_default_value(self):
    """A custom default_value replaces "UNK" for out-of-range ids."""
    default_value = b"NONE"
    vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
    with self.test_session():
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file, default_value=default_value)
      features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
      self.assertRaises(errors_impl.OpError, features.eval)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((b"salad", b"surgery", default_value),
                          features.eval())

  def test_index_to_string_table_with_vocab_size_too_small(self):
    """vocab_size=2 truncates the file; id 2 already gets the default."""
    default_value = b"NONE"
    vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
    with self.test_session():
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file,
          vocab_size=2,
          default_value=default_value)
      features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
      self.assertRaises(errors_impl.OpError, features.eval)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((b"salad", default_value, default_value),
                          features.eval())

  def test_index_to_string_table_with_vocab_size_too_large(self):
    """vocab_size larger than the file fails at initialization."""
    vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
    with self.test_session():
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file, vocab_size=4)
      features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))

      self.assertRaises(errors_impl.OpError, features.eval)
      init = data_flow_ops.tables_initializer()
      self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                              "Invalid vocab_size", init.run)

  def test_index_to_string_table_with_vocab_size(self):
    """An exactly-matching vocab_size initializes and looks up normally."""
    vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
    with self.test_session():
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file, vocab_size=3)
      features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))

      self.assertRaises(errors_impl.OpError, features.eval)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval())
class IndexToStringTableFromTensorTest(test.TestCase):
  """Tests for index_to_string_table_from_tensor (in-memory mapping)."""

  def test_index_to_string_table_from_tensor(self):
    """In-range ids map to mapping entries; out-of-range ids map to "UNK"."""
    with self.test_session():
      mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
      table = lookup_ops.index_to_string_table_from_tensor(
          mapping=mapping_strings)

      indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
      features = table.lookup(indices)
      # Lookup before the table is initialized must fail.
      self.assertRaises(errors_impl.OpError, features.eval)
      data_flow_ops.tables_initializer().run()

      self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
                          features.eval())

  def test_duplicate_entries(self):
    """Duplicate mapping entries are allowed in the index-to-string direction."""
    with self.test_session():
      mapping_strings = constant_op.constant(["hello", "hello"])
      table = lookup_ops.index_to_string_table_from_tensor(
          mapping=mapping_strings)
      indices = constant_op.constant([0, 1, 4], dtypes.int64)
      features = table.lookup(indices)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval())

  def test_index_to_string_with_default_value(self):
    """A custom default_value replaces "UNK" for out-of-range ids."""
    default_value = b"NONE"
    with self.test_session():
      mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
      table = lookup_ops.index_to_string_table_from_tensor(
          mapping=mapping_strings, default_value=default_value)
      indices = constant_op.constant([1, 2, 4], dtypes.int64)
      features = table.lookup(indices)
      self.assertRaises(errors_impl.OpError, features.eval)
      data_flow_ops.tables_initializer().run()

      self.assertAllEqual((b"salad", b"surgery", default_value),
                          features.eval())
class IndexToStringTest(test.TestCase):
  """Tests for the index_to_string convenience wrapper."""

  def test_index_to_string(self):
    with self.test_session():
      vocab = constant_op.constant(["brain", "salad", "surgery"])
      ids = constant_op.constant([0, 1, 2, 3], dtypes.int64)
      strings = lookup_ops.index_to_string(ids, mapping=vocab)
      # Lookup before initialization must fail.
      self.assertRaises(errors_impl.OpError, strings.eval)

      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
                          strings.eval())

  def test_duplicate_entries(self):
    with self.test_session():
      vocab = constant_op.constant(["hello", "hello"])
      ids = constant_op.constant([0, 1, 4], dtypes.int64)
      strings = lookup_ops.index_to_string(ids, mapping=vocab)
      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((b"hello", b"hello", b"UNK"), strings.eval())

      # A second initialization run is expected to raise.
      self.assertRaises(errors_impl.OpError,
                        data_flow_ops.tables_initializer().run)

  def test_index_to_string_with_default_value(self):
    fallback = b"NONE"
    with self.test_session():
      vocab = constant_op.constant(["brain", "salad", "surgery"])
      ids = constant_op.constant([1, 2, 4], dtypes.int64)
      strings = lookup_ops.index_to_string(
          ids, mapping=vocab, default_value=fallback)
      self.assertRaises(errors_impl.OpError, strings.eval)

      data_flow_ops.tables_initializer().run()
      self.assertAllEqual((b"salad", b"surgery", fallback), strings.eval())
class InitializeTableFromFileOpTest(test.TestCase):
  """Tests for initializing HashTables from text files.

  Uses the non-deprecated assertEqual (assertEquals is a deprecated
  unittest alias).
  """

  def _createVocabFile(self, basename):
    """Writes a three-word vocabulary file and returns its path."""
    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(["brain", "salad", "surgery"]) + "\n")
    return vocabulary_file

  def testInitializeTable(self):
    """Whole-line keys map to their line numbers; misses return -1."""
    vocabulary_file = self._createVocabFile("one_column_1.txt")

    with self.test_session():
      default_value = -1
      table = lookup_ops.HashTable(
          lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string,
                                         lookup_ops.TextFileIndex.WHOLE_LINE,
                                         dtypes.int64,
                                         lookup_ops.TextFileIndex.LINE_NUMBER),
          default_value)
      table.init.run()

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual([0, 1, -1], result)

  def testInitializeIndexTable(self):
    """Line numbers as keys map back to whole lines; misses return "UNK"."""
    vocabulary_file = self._createVocabFile("one_column_2.txt")

    with self.test_session():
      default_value = "UNK"
      key_index = lookup_ops.TextFileIndex.LINE_NUMBER
      value_index = lookup_ops.TextFileIndex.WHOLE_LINE
      table = lookup_ops.HashTable(
          lookup_ops.TextFileInitializer(vocabulary_file, dtypes.int64,
                                         key_index, dtypes.string, value_index),
          default_value)
      table.init.run()

      input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
      output = table.lookup(input_values)

      result = output.eval()
      self.assertAllEqual([b"brain", b"salad", b"surgery", b"UNK"], result)

  def testMultiColumn(self):
    """Keys and values can come from arbitrary tab-separated columns."""
    vocabulary_file = os.path.join(self.get_temp_dir(), "three_columns.txt")
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")

    with self.test_session():
      default_value = -1
      key_index = 1
      value_index = 2

      table = lookup_ops.HashTable(
          lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string,
                                         key_index, dtypes.int64, value_index),
          default_value)
      table.init.run()

      input_string = constant_op.constant(["brain", "salad", "surgery"])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual([1, 5, 6], result)

  def testInvalidDataTypeInMultiColumn(self):
    """Initialization fails when a column can't parse as the declared dtype.

    Value column 1 holds words but is declared int64, so parsing fails.
    """
    vocabulary_file = os.path.join(self.get_temp_dir(), "three_columns.txt")
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")

    with self.test_session():
      default_value = -1
      key_index = 2
      value_index = 1
      table = lookup_ops.HashTable(
          lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string,
                                         key_index, dtypes.int64, value_index),
          default_value)
      with self.assertRaisesOpError("is not a valid"):
        table.init.run()

  def testInvalidDataType(self):
    """WHOLE_LINE keys must be string-typed; int64 keys are rejected."""
    vocabulary_file = self._createVocabFile("one_column_3.txt")

    with self.test_session():
      default_value = "UNK"
      key_index = lookup_ops.TextFileIndex.WHOLE_LINE
      value_index = lookup_ops.TextFileIndex.LINE_NUMBER

      with self.assertRaises(ValueError):
        lookup_ops.HashTable(
            lookup_ops.TextFileInitializer(vocabulary_file, dtypes.int64,
                                           key_index, dtypes.string,
                                           value_index), default_value)

  def testInvalidIndex(self):
    """A column index past the end of a line fails at init time."""
    vocabulary_file = self._createVocabFile("one_column_4.txt")
    with self.test_session():
      default_value = -1
      key_index = 1  # second column of the line
      value_index = lookup_ops.TextFileIndex.LINE_NUMBER
      table = lookup_ops.HashTable(
          lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string,
                                         key_index, dtypes.int64, value_index),
          default_value)

      with self.assertRaisesOpError("Invalid number of columns"):
        table.init.run()

  def testInitializeSameTableWithMultipleNodes(self):
    """Three table nodes sharing a shared_name behave as one table."""
    vocabulary_file = self._createVocabFile("one_column_5.txt")

    with self.test_session() as sess:
      shared_name = "shared-one-columm"
      default_value = -1
      table1 = lookup_ops.HashTable(
          lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string,
                                         lookup_ops.TextFileIndex.WHOLE_LINE,
                                         dtypes.int64,
                                         lookup_ops.TextFileIndex.LINE_NUMBER),
          default_value,
          shared_name=shared_name)
      table2 = lookup_ops.HashTable(
          lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string,
                                         lookup_ops.TextFileIndex.WHOLE_LINE,
                                         dtypes.int64,
                                         lookup_ops.TextFileIndex.LINE_NUMBER),
          default_value,
          shared_name=shared_name)
      table3 = lookup_ops.HashTable(
          lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string,
                                         lookup_ops.TextFileIndex.WHOLE_LINE,
                                         dtypes.int64,
                                         lookup_ops.TextFileIndex.LINE_NUMBER),
          default_value,
          shared_name=shared_name)

      data_flow_ops.tables_initializer().run()

      input_string = constant_op.constant(["brain", "salad", "tank"])

      output1 = table1.lookup(input_string)
      output2 = table2.lookup(input_string)
      output3 = table3.lookup(input_string)

      out1, out2, out3 = sess.run([output1, output2, output3])
      self.assertAllEqual([0, 1, -1], out1)
      self.assertAllEqual([0, 1, -1], out2)
      self.assertAllEqual([0, 1, -1], out3)

  def testInitializeTableWithNoFilename(self):
    """An empty filename is rejected at construction time."""
    with self.test_session():
      default_value = -1
      with self.assertRaises(ValueError):
        lookup_ops.HashTable(
            lookup_ops.TextFileInitializer(
                "", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
                dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
            default_value)

  def testInitializeWithVocabSize(self):
    """vocab_size must be <= the number of lines; smaller truncates."""
    with self.test_session():
      default_value = -1
      vocab_size = 3
      vocabulary_file1 = self._createVocabFile("one_column6.txt")
      table1 = lookup_ops.HashTable(
          lookup_ops.TextFileInitializer(
              vocabulary_file1,
              dtypes.string,
              lookup_ops.TextFileIndex.WHOLE_LINE,
              dtypes.int64,
              lookup_ops.TextFileIndex.LINE_NUMBER,
              vocab_size=vocab_size),
          default_value)

      # Initialize from file.
      table1.init.run()
      self.assertEqual(vocab_size, table1.size().eval())

      vocabulary_file2 = self._createVocabFile("one_column7.txt")
      vocab_size = 5
      table2 = lookup_ops.HashTable(
          lookup_ops.TextFileInitializer(
              vocabulary_file2,
              dtypes.string,
              lookup_ops.TextFileIndex.WHOLE_LINE,
              dtypes.int64,
              lookup_ops.TextFileIndex.LINE_NUMBER,
              vocab_size=vocab_size),
          default_value)
      with self.assertRaisesOpError("Invalid vocab_size"):
        table2.init.run()

      vocab_size = 1
      vocabulary_file3 = self._createVocabFile("one_column3.txt")
      table3 = lookup_ops.HashTable(
          lookup_ops.TextFileInitializer(
              vocabulary_file3,
              dtypes.string,
              lookup_ops.TextFileIndex.WHOLE_LINE,
              dtypes.int64,
              lookup_ops.TextFileIndex.LINE_NUMBER,
              vocab_size=vocab_size),
          default_value)

      # Smaller vocab size reads only vocab_size records.
      table3.init.run()
      self.assertEqual(vocab_size, table3.size().eval())

  def testFeedVocabularyName(self):
    """The vocabulary filename can be fed through ASSET_FILEPATHS."""
    vocabulary_file = self._createVocabFile("feed_vocabulary.txt")

    with self.test_session():
      default_value = -1
      table = lookup_ops.HashTable(
          lookup_ops.TextFileInitializer("old_file.txt", dtypes.string,
                                         lookup_ops.TextFileIndex.WHOLE_LINE,
                                         dtypes.int64,
                                         lookup_ops.TextFileIndex.LINE_NUMBER),
          default_value)

      # Initialize with non existing file (old_file.txt) should fail.
      # TODO(yleon): Update message, which might change per FileSystem.
      with self.assertRaisesOpError("old_file.txt"):
        table.init.run()

      # Initialize the model feeding the vocabulary file.
      filenames = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
      table.init.run(feed_dict={filenames[0]: vocabulary_file})

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual([0, 1, -1], result)

  def testInvalidFilenames(self):
    """Non-string and non-scalar filename tensors are rejected."""
    vocabulary_file = self._createVocabFile("filename_shape.txt")

    with self.test_session():
      default_value = -1

      # Invalid data type
      other_type = constant_op.constant(1)
      with self.assertRaises(ValueError):
        lookup_ops.HashTable(
            lookup_ops.TextFileInitializer(
                other_type, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
                dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
            default_value)

      # Non-scalar filename
      filenames = constant_op.constant([vocabulary_file, vocabulary_file])
      with self.assertRaises(ValueError):
        lookup_ops.HashTable(
            lookup_ops.TextFileInitializer(
                filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
                dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
            default_value)

  def testIdToStringTable(self):
    """TextFileStringTableInitializer: int64 ids to string lines."""
    vocab_file = self._createVocabFile("feat_to_id_1.txt")
    with self.test_session():
      default_value = "UNK"
      vocab_size = 3
      table = lookup_ops.HashTable(
          lookup_ops.TextFileStringTableInitializer(
              vocab_file, vocab_size=vocab_size),
          default_value)

      table.init.run()

      input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)

      out = table.lookup(input_values)
      self.assertAllEqual([b"brain", b"salad", b"surgery", b"UNK"], out.eval())
      self.assertEqual(vocab_size, table.size().eval())

  def testStringToIdTable(self):
    """TextFileIdTableInitializer: string lines to int64 ids."""
    vocab_file = self._createVocabFile("feat_to_id_2.txt")
    with self.test_session():
      default_value = -1
      vocab_size = 3
      table = lookup_ops.HashTable(
          lookup_ops.TextFileIdTableInitializer(
              vocab_file, vocab_size=vocab_size),
          default_value)
      table.init.run()

      input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])

      out = table.lookup(input_string)
      self.assertAllEqual([0, 1, 2, -1], out.eval())
      self.assertEqual(vocab_size, table.size().eval())
class IdTableWithHashBucketsTest(test.TestCase):
def _createVocabFile(self, basename):
vocabulary_file = os.path.join(self.get_temp_dir(), basename)
with open(vocabulary_file, "w") as f:
f.write("\n".join(["brain", "salad", "surgery"]) + "\n")
return vocabulary_file
def testIdTableWithHashBucketsInit(self):
vocab_file = self._createVocabFile("feat_to_id_3.txt")
with self.test_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
table = lookup_ops.IdTableWithHashBuckets(
lookup_ops.HashTable(
lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value),
oov_buckets)
table.init.run()
input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
out = table.lookup(input_string)
self.assertAllEqual([0, 1, 2, 3], out.eval())
self.assertEquals(vocab_size + oov_buckets, table.size().eval())
def testIdTableWithOnlyHashBucket(self):
with self.test_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
# an id calculated by fingerprint("input") mod oov_buckets.
table = lookup_ops.IdTableWithHashBuckets(None, oov_buckets)
table.init.run()
input_string = constant_op.constant(["brain", "salad", "surgery"])
out = table.lookup(input_string)
self.assertAllEqual(
[
3, # fingerprint("brain") mod 5.
1, # fingerprint("salad") mod 5.
4 # fingerprint("surgery") mod 5
],
out.eval())
self.assertEquals(oov_buckets, table.size().eval())
def testIdTableWithHashBucketsWithMultipleInitializers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
with self.test_session() as sess:
default_value = -1
vocab_size = 3
oov_buckets = 3
vocab_table = lookup_ops.HashTable(
lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value)
table1 = lookup_ops.IdTableWithHashBuckets(
vocab_table,
oov_buckets,
hasher_spec=lookup_ops.FastHashSpec,
name="table1")
table2 = lookup_ops.IdTableWithHashBuckets(
vocab_table,
oov_buckets,
hasher_spec=lookup_ops.StrongHashSpec((1, 2)),
name="table2")
data_flow_ops.tables_initializer().run()
input_string = constant_op.constant(
["fruit", "brain", "salad", "surgery", "UNK"])
out1 = table1.lookup(input_string)
out2 = table2.lookup(input_string)
out1, out2 = sess.run([out1, out2])
self.assertAllEqual([5, 0, 1, 2, 5], out1)
self.assertAllEqual([5, 0, 1, 2, 3], out2)
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
self.assertEquals(vocab_size + oov_buckets, table2.size().eval())
test_util.assert_ops_in_graph({
"table1_Lookup/hash_bucket": "StringToHashBucketFast",
"table2_Lookup/hash_bucket": "StringToHashBucketStrong",
}, sess.graph)
def testIdTableWithHashBucketsInitializationAcrossSessions(self):
vocab_file = self._createVocabFile("feat_to_id_5.txt")
shared_name = "across-sessions"
with self.test_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
table1 = lookup_ops.IdTableWithHashBuckets(
lookup_ops.HashTable(
lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value,
shared_name=shared_name),
oov_buckets)
table1.init.run()
input_string_1 = constant_op.constant(
["brain", "salad", "surgery", "UNK"])
out1 = table1.lookup(input_string_1)
self.assertAllEqual([0, 1, 2, 3], out1.eval())
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
with self.test_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
# Underlying lookup table already initialized in previous session.
# No need to call table2.init.run()
table2 = lookup_ops.IdTableWithHashBuckets(
lookup_ops.HashTable(
lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value,
shared_name=shared_name),
oov_buckets)
input_string_2 = constant_op.constant(["fruit", "salad", "UNK"])
out2 = table2.lookup(input_string_2)
self.assertAllEqual([3, 1, 3], out2.eval())
self.assertEquals(vocab_size + oov_buckets, table2.size().eval())
def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
vocab_file = self._createVocabFile("feat_to_id_6.txt")
with self.test_session() as sess:
default_value1 = -1
vocab_size = 3
oov_buckets = 0
table1 = lookup_ops.IdTableWithHashBuckets(
lookup_ops.HashTable(
lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value1),
oov_buckets)
default_value2 = -2
table2 = lookup_ops.IdTableWithHashBuckets(
lookup_ops.HashTable(
lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value2),
oov_buckets)
data_flow_ops.tables_initializer().run()
input_string_1 = constant_op.constant(
["brain", "salad", "surgery", "UNK"])
input_string_2 = constant_op.constant(["fruit", "salad", "UNK"])
out1 = table1.lookup(input_string_1)
out2 = table2.lookup(input_string_2)
out1, out2 = sess.run([out1, out2])
self.assertAllEqual([0, 1, 2, -1], out1)
self.assertAllEqual([-2, 1, -2], out2)
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
self.assertEquals(vocab_size + oov_buckets, table2.size().eval())
def testSparseTensor(self):
vocab_file = self._createVocabFile("feat_to_id_7.txt")
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
with self.test_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
dtypes.string),
constant_op.constant(input_shape, dtypes.int64))
table = lookup_ops.IdTableWithHashBuckets(
lookup_ops.HashTable(
lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=3),
-1),
1)
table.init.run()
sp_ids = table.lookup(sp_features)
self.assertAllEqual([5], sp_ids.values._shape_as_list())
sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
[sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
self.assertAllEqual(input_indices, sp_ids_ind)
self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
self.assertAllEqual(input_shape, sp_ids_shape)
def testIdTableWithHashBucketsWithInvalidHashers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
with self.test_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
lookup_table = lookup_ops.HashTable(
lookup_ops.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value)
with self.assertRaises(TypeError):
lookup_ops.IdTableWithHashBuckets(
lookup_table, oov_buckets, hasher_spec=1)
table = lookup_ops.IdTableWithHashBuckets(
lookup_table,
oov_buckets,
hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))
input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
with self.assertRaises(ValueError):
table.lookup(input_string)
with self.assertRaises(ValueError):
table = lookup_ops.IdTableWithHashBuckets(
lookup_table,
oov_buckets,
hasher_spec=lookup_ops.StrongHashSpec([]))
with self.assertRaises(ValueError):
table = lookup_ops.IdTableWithHashBuckets(
lookup_table,
oov_buckets,
hasher_spec=lookup_ops.StrongHashSpec([1, 2, 3]))
with self.assertRaises(TypeError):
table = lookup_ops.IdTableWithHashBuckets(
lookup_table,
oov_buckets,
hasher_spec=lookup_ops.StrongHashSpec([None, 2]))
# Run all test cases in this module when executed as a script.
if __name__ == "__main__":
  test.main()