Internal change

PiperOrigin-RevId: 357124297
Change-Id: I5aa7c2d89379643a767053802589320eb3476cd6
A. Unique TensorFlower 2021-02-11 21:58:57 -08:00 committed by TensorFlower Gardener
parent 2ad0a33259
commit 0a501caa91
18 changed files with 548 additions and 65 deletions

View File

@@ -72,6 +72,7 @@ class GenerateVocabRemappingOp : public OpKernel {
kUnusedLookupDelim,
-1, // key_index, use the line number.
-2, // value_index, use the whole line/token.
0, // No offset.
context->env(), new_vocab_table));
OP_REQUIRES(context,
new_vocab_offset_ + num_new_vocab_ <= new_vocab_table->size(),
@@ -101,6 +102,7 @@ class GenerateVocabRemappingOp : public OpKernel {
old_vocab_filename, old_vocab_size_, kUnusedLookupDelim,
-2, // key_index, use the whole line/token.
-1, // value_index, use the line number.
0, // No offset.
context->env(), old_vocab_table));
// Fill out new_ids = [new_vocab_offset, new_vocab_offset + 1, ...,

View File

@@ -105,6 +105,7 @@ class InitializeTableFromTextFileOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size", &vocab_size_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index", &key_index_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index", &value_index_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("offset", &offset_));
string delimiter;
OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter", &delimiter));
OP_REQUIRES(ctx, delimiter.size() == 1,
@@ -141,7 +142,7 @@ class InitializeTableFromTextFileOp : public OpKernel {
}
OP_REQUIRES_OK(ctx, lookup::InitializeTableFromTextFile(
vocab_filename, vocab_size_, delimiter_, key_index_,
value_index_, ctx->env(), table));
value_index_, offset_, ctx->env(), table));
if (ctx->track_allocations()) {
ctx->record_persistent_memory_allocation(table->MemoryUsed() -
memory_used_before);
@@ -154,6 +155,7 @@ class InitializeTableFromTextFileOp : public OpKernel {
char delimiter_;
int64 key_index_;
int64 value_index_;
int64 offset_;
TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp);
};

View File

@@ -77,7 +77,7 @@ class TextFileLineIterator
// delimiter.
Status Init(const string& filename, int64 vocab_size, char delimiter,
DataType key_dtype, int64 key_index, DataType value_dtype,
int64 value_index, Env* env) {
int64 value_index, int64 offset, Env* env) {
filename_ = filename;
vocab_size_ = vocab_size;
delimiter_ = delimiter;
@@ -93,6 +93,7 @@ class TextFileLineIterator
input_buffer_.reset(new io::InputBuffer(file_.get(), kInputBufferSize));
valid_ = true;
next_id_ = 0;
offset_ = offset;
ignore_split_ = std::max(key_index_, value_index_) < 0;
Next();
return status_;
@@ -143,6 +144,7 @@ class TextFileLineIterator
return;
}
}
status_ = SetValue(line, tokens, key_index_, &key_);
if (!status_.ok()) {
valid_ = false;
@@ -186,6 +188,7 @@ class TextFileLineIterator
int64 value_index_;
Env* env_;
int64 next_id_;
int64 offset_;
int64 vocab_size_;
string filename_;
char delimiter_;
@@ -199,7 +202,7 @@ class TextFileLineIterator
Status SetValue(const string& line, const std::vector<string>& tokens,
int64 index, Tensor* tensor) {
if (index == kLineNumber) {
tensor->flat<int64>()(0) = next_id_;
tensor->flat<int64>()(0) = next_id_ + offset_;
return Status::OK();
}
const string& token = (index == kWholeLine) ? line : tokens[index];
@@ -212,7 +215,7 @@ class TextFileLineIterator
return errors::InvalidArgument("Field ", token, " in line ", next_id_,
" is not a valid int32.");
}
tensor->flat<int32>()(0) = value;
tensor->flat<int32>()(0) = value + offset_;
} break;
case DT_INT64: {
int64 value;
@@ -352,7 +355,7 @@ Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,
// Helper function to initialize an InitializableLookupTable from a text file.
Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
char delimiter, int32 key_index,
int32 value_index, Env* env,
int32 value_index, int64 offset, Env* env,
InitializableLookupTable* table) {
if (key_index == kLineNumber && table->key_dtype() != DT_INT64) {
return errors::InvalidArgument(
@@ -380,7 +383,8 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
TextFileLineIterator iter;
TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter, key_dtype,
key_index, value_dtype, value_index, env));
key_index, value_dtype, value_index, offset,
env));
// For initialization from files, ignore if the table is already
// initialized. The table shared name should contain the filename to
// avoid trying to initialize the same table from the same file at the same
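
The net effect of the offset plumbing above: whenever a value comes from the line number (kLineNumber) or is parsed as an integer field, SetValue adds the offset before storing it. A minimal Python sketch of the resulting mapping, using hypothetical file contents (this just restates the iterator's arithmetic):

# Hypothetical vocab file, one token per line.
vocab_lines = ["earth", "wind", "and", "fire"]
offset = 2  # e.g., reserve 0 for a mask token and 1 for an OOV bucket

# With key_index == kWholeLine and value_index == kLineNumber, the
# iterator emits (token, line_number + offset) pairs:
table = {token: line_number + offset
         for line_number, token in enumerate(vocab_lines)}
assert table == {"earth": 2, "wind": 3, "and": 4, "fire": 5}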

View File

@@ -53,7 +53,7 @@ Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,
Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
char delimiter, int32 key_index,
int32 value_index, Env* env,
int32 value_index, int64 offset, Env* env,
InitializableLookupTable* table);
// Initializes `table` from `dataset` by iterating over it. Caller retains

View File

@@ -480,6 +480,7 @@ REGISTER_OP("InitializeTableFromTextFile")
.Attr("value_index: int >= -2")
.Attr("vocab_size: int >= -1 = -1")
.Attr("delimiter: string = '\t'")
.Attr("offset: int = 0")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
@@ -497,6 +498,7 @@ REGISTER_OP("InitializeTableFromTextFileV2")
.Attr("value_index: int >= -2")
.Attr("vocab_size: int >= -1 = -1")
.Attr("delimiter: string = '\t'")
.Attr("offset: int = 0")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));

View File

@@ -802,8 +802,14 @@ class TrackableWeightHandler(object):
# TODO(b/141682913): Figure out why this is private and fix it.
saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access
if len(saveables) != 1:
raise ValueError('Only Trackables with one Saveable are supported.')
# 'Saveables' won't exist when we're passed a legacy TF1 table like
# a StaticHashTable.
if not saveables:
self._num_tensors = 0
self._setter = lambda weights: None
self._getter = lambda: []
elif len(saveables) == 1:
saveable = list(saveables)[0]
if ops.executing_eagerly_outside_functions():
@@ -828,9 +834,14 @@ class TrackableWeightHandler(object):
tensor = spec.tensor
self._placeholder_tensors.append(
array_ops.placeholder(tensor.dtype, tensor.shape))
self._assign_op = self._saveable.restore(self._placeholder_tensors, None)
self._assign_op = self._saveable.restore(self._placeholder_tensors,
None)
self._setter = self._set_weights_v1
self._getter = lambda: [spec.tensor for spec in self._saveable.specs]
else:
raise ValueError('Only Trackables with one Saveable are supported. '
'The Trackable %s has %d Saveables.' %
(trackable, len(saveables)))
@property
def num_tensors(self):

View File

@@ -80,6 +80,21 @@ tf_py_test(
],
)
tf_py_test(
name = "index_lookup_forward_benchmark",
srcs = ["index_lookup_forward_benchmark.py"],
python_version = "PY3",
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:platform_benchmark",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/keras/layers/preprocessing:index_lookup",
],
)
tf_py_test(
name = "normalization_adapt_benchmark",
srcs = ["normalization_adapt_benchmark.py"],

View File

@@ -0,0 +1,146 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark for Keras text vectorization preprocessing layer's adapt method."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import string
import time
import numpy as np
from tensorflow.python import keras
from tensorflow.python.compat import v2_compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.keras.layers.preprocessing import index_lookup
from tensorflow.python.ops import lookup_ops
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
v2_compat.enable_v2_behavior()
# tensor_gen creates batches of random two-letter ASCII tokens (both lowercase
# and uppercase). The number of unique strings is ~2,700.
def tensor_gen(batch, num_elements):
data = []
for _ in range(batch):
batch_element = []
for _ in range(num_elements - 1):
tok = "".join(random.choice(string.ascii_letters) for i in range(2))
batch_element.append(tok)
batch_element.append("") # Explicitly test the empty string.
data.append(batch_element)
return constant_op.constant(data)
def get_vocab():
vocab = list(
set([a + b for a in string.ascii_letters for b in string.ascii_letters])) # pylint:disable=g-complex-comprehension
vocab.sort()
return vocab
# This class uses TestCase for get_temp_dir().
class BenchmarkLookup(benchmark.TensorFlowBenchmark):
"""Benchmark the index lookup layer's forward pass."""
def _write_to_temp_file(self, file_name, vocab_list):
vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt")
with gfile.GFile(vocab_path, "w") as writer:
for vocab in vocab_list:
writer.write(vocab + "\n")
writer.flush()
writer.close()
return vocab_path
def run_numpy_implementation(self, data, vocab):
"""Test the python implementation."""
input_t = keras.Input(shape=(), dtype=dtypes.string)
layer = index_lookup.IndexLookup(
vocabulary=vocab,
max_tokens=None,
num_oov_indices=1,
mask_token="",
oov_token="OOV",
dtype=dtypes.string)
out_t = layer(input_t)
model = keras.Model(input_t, out_t)
num_repeats = 5
starts = []
ends = []
_ = model(data)
for _ in range(num_repeats):
starts.append(time.time())
out = model(data)
ends.append(time.time())
avg_time = np.mean(np.array(ends) - np.array(starts))
return avg_time, out
def bm_adapt_implementation(self, num_elements, batch_size):
"""Test the KPL adapt implementation."""
vocab = get_vocab()
vocab_file = self._write_to_temp_file("vocab", vocab)
vocabulary_initializer = lookup_ops.TextFileInitializer(
filename=vocab_file,
key_dtype=dtypes.string,
key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
value_dtype=dtypes.int64,
value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
value_index_offset=2)
input_t = keras.Input(shape=(), dtype=dtypes.string)
layer = index_lookup.IndexLookup(
vocabulary=vocabulary_initializer,
max_tokens=None,
num_oov_indices=1,
mask_token="",
oov_token="OOV",
dtype=dtypes.string)
out_t = layer(input_t)
model = keras.Model(input_t, out_t)
num_repeats = 5
starts = []
ends = []
data = tensor_gen(batch_size, num_elements)
_ = model(data)
for _ in range(num_repeats):
starts.append(time.time())
_ = model(data)
ends.append(time.time())
avg_time = np.mean(np.array(ends) - np.array(starts))
baseline, _ = self.run_numpy_implementation(data, vocab)
extras = {
"numpy implementation baseline": baseline,
"delta seconds": (baseline - avg_time),
"delta percent": ((baseline - avg_time) / baseline) * 100
}
name = "index_lookup_forward|%s_elements|batch_%s" % (num_elements,
batch_size)
self.report_benchmark(
iters=num_repeats, wall_time=avg_time, extras=extras, name=name)
def benchmark_vocab_size_by_batch(self):
for tensor_size in [100, 1000, 10000]:
for batch in [1, 16, 2048]:
self.bm_adapt_implementation(tensor_size, batch)
if __name__ == "__main__":
test.main()
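
Assuming the standard TensorFlow benchmark entry point that test.main() dispatches to, this file can be run directly, e.g. python index_lookup_forward_benchmark.py --benchmarks=benchmark_vocab_size_by_batch.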

View File

@@ -184,18 +184,6 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
self._value_dtype = dtypes.int64
oov_value = self._oov_value
self._table = lookup_ops.MutableHashTable(
key_dtype=self._key_dtype,
value_dtype=self._value_dtype,
default_value=oov_value,
name=(self._name + "_index_table"))
tracked_table = self._add_trackable(self._table, trainable=False)
# This is a workaround for summary() on this layer. Because the table is
# not mutable during training, the effective number of parameters (and so
# the weight shape) is 0; we add this as an attr so that the parameter
# counting code in the Model object doesn't throw an attribute error.
tracked_table.shape = tensor_shape.TensorShape((0,))
if self.num_oov_indices <= 1:
oov_indices = None
else:
@@ -203,11 +191,28 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
oov_end = oov_start + num_oov_indices
oov_indices = list(range(oov_start, oov_end))
if vocabulary is not None and isinstance(vocabulary,
lookup_ops.TextFileInitializer):
self._table = self._static_table_class()(
vocabulary, default_value=oov_value)
self._table_handler = table_utils.TableHandler(
table=self._table,
mask_token=mask_token,
oov_tokens=oov_indices,
use_v1_apis=self._use_v1_apis())
self.max_tokens = (
self._table_handler.vocab_size() + self.num_oov_indices +
(0 if mask_token is None else 1))
else:
self._table = lookup_ops.MutableHashTable(
key_dtype=self._key_dtype,
value_dtype=self._value_dtype,
default_value=oov_value,
name=(self._name + "_index_table"))
self._table_handler = table_utils.TableHandler(
table=self._table,
oov_tokens=oov_indices,
use_v1_apis=self._use_v1_apis())
if vocabulary is not None:
self.set_vocabulary(vocabulary)
@@ -232,6 +237,13 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
dtype=K.floatx(),
initializer=initializer)
tracked_table = self._add_trackable(self._table, trainable=False)
# This is a workaround for summary() on this layer. Because the table is
# not mutable during training, the effective number of parameters (and so
# the weight shape) is 0; we add this as an attr so that the parameter
# counting code in the Model object doesn't throw an attribute error.
tracked_table.shape = tensor_shape.TensorShape((0,))
def compute_output_shape(self, input_shape):
if self.output_mode != INT:
return tensor_shape.TensorShape([input_shape[0], self.max_tokens])
@@ -538,6 +550,9 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
def _use_v1_apis(self):
return False
def _static_table_class(self):
return lookup_ops.StaticHashTable
class _IndexLookupAccumulator(
collections.namedtuple("Accumulator",
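
With the branch above, IndexLookup now accepts a lookup_ops.TextFileInitializer as its vocabulary argument and backs itself with a static (immutable) hash table instead of a MutableHashTable; max_tokens is then derived from the file-backed table size plus the mask and OOV slots. A minimal usage sketch, mirroring the new test_int_output_file_vocab test further down (the vocab file path is hypothetical):

from tensorflow.python.framework import dtypes
from tensorflow.python.keras.layers.preprocessing import index_lookup
from tensorflow.python.ops import lookup_ops

# "vocab.txt" is a hypothetical file with one token per line.
initializer = lookup_ops.TextFileInitializer(
    filename="vocab.txt",
    key_dtype=dtypes.string,
    key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
    value_dtype=dtypes.int64,
    value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
    value_index_offset=2)  # reserve 0 (mask) and 1 (one OOV bucket)
layer = index_lookup.IndexLookup(
    vocabulary=initializer,
    max_tokens=None,
    num_oov_indices=1,
    mask_token="",
    oov_token="[OOV]",
    dtype=dtypes.string)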

View File

@@ -41,7 +41,9 @@ from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -703,6 +705,15 @@ class CategoricalEncodingAdaptTest(
class IndexLookupOutputTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):
def _write_to_temp_file(self, file_name, vocab_list):
vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt")
with gfile.GFile(vocab_path, "w") as writer:
for vocab in vocab_list:
writer.write(vocab + "\n")
writer.flush()
writer.close()
return vocab_path
def test_int_output(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire"],
@@ -958,7 +969,60 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
layer_output = layer(input_data)
self.assertAllEqual(layer_output.shape.as_list(), [16, 2])
def test_int_output_file_vocab(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire"],
["fire", "", "earth", "michigan"]])
expected_output = [[2, 3, 4, 5], [5, 0, 2, 1]]
vocab_file = self._write_to_temp_file("temp", vocab_data)
vocabulary_initializer = lookup_ops.TextFileInitializer(
filename=vocab_file,
key_dtype=dtypes.string,
key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
value_dtype=dtypes.int64,
value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
value_index_offset=2)
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
layer = get_layer_class()(
vocabulary=vocabulary_initializer,
max_tokens=None,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
dtype=dtypes.string)
int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data)
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
def test_int_output_int_file_vocab(self):
vocab_data = ["10", "20", "30", "40"]
input_array = np.array([[10, 20, 30, 40], [40, 0, 10, 42]])
expected_output = [[2, 3, 4, 5], [5, 0, 2, 1]]
vocab_file = self._write_to_temp_file("temp", vocab_data)
vocabulary_initializer = lookup_ops.TextFileInitializer(
filename=vocab_file,
key_dtype=dtypes.int64,
key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
value_dtype=dtypes.int64,
value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
value_index_offset=2)
input_data = keras.Input(shape=(None,), dtype=dtypes.int64)
layer = get_layer_class()(
vocabulary=vocabulary_initializer,
max_tokens=None,
num_oov_indices=1,
mask_token=0,
oov_token=-1,
dtype=dtypes.int64)
int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data)
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
@keras_parameterized.run_all_keras_modes
class IndexLookupVocabularyTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest

View File

@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.keras.engine import base_preprocessing_layer_v1
from tensorflow.python.keras.layers.preprocessing import index_lookup
from tensorflow.python.ops import lookup_ops
class IndexLookup(index_lookup.IndexLookup,
@@ -58,3 +59,6 @@ class IndexLookup(index_lookup.IndexLookup,
def _use_v1_apis(self):
return True
def _static_table_class(self):
return lookup_ops.StaticHashTableV1

View File

@@ -27,6 +27,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
@@ -38,8 +39,24 @@ from tensorflow.python.platform import gfile
class TableHandler(object):
"""Wrapper object that holds a lookup table and provides accessors."""
def __init__(self, table, oov_tokens=None, use_v1_apis=False):
def __init__(self,
table,
oov_tokens=None,
mask_token=None,
use_v1_apis=False):
self.table = table
# If we are using V1 APIs, and the table has an initializer, we need to run
# it. However, not all tables have initializers, so we try-except here.
if use_v1_apis:
try:
K.get_session().run(self.table.initializer)
except AttributeError:
pass
self.mutable = isinstance(table, lookup_ops.MutableHashTable)
self.mask_token = mask_token
self.use_v1_apis = use_v1_apis
if oov_tokens is None:
self.oov_tokens = oov_tokens
@@ -56,10 +73,17 @@ class TableHandler(object):
return self._eval(self.table.size())
def clear(self):
if not self.mutable:
raise RuntimeError("Unable to clear a statically-backed table.")
keys, _ = self.table.export()
self._run(self.table.remove(keys))
def insert(self, keys, values):
"""Insert values into the backed table."""
if not self.mutable:
raise RuntimeError("Unable to insert into a statically-backed table.")
if len(values) != len(keys):
raise RuntimeError("Size mismatch between values and key arrays. "
"Keys had size %s, values had size %s." %
@@ -90,12 +114,35 @@ class TableHandler(object):
return array_ops.where(oov_locations, oov_values, lookups)
def _lookup_and_mask(self, inputs):
"""Return a lookup with any location with the mask_token masked to 0."""
lookups = self.table.lookup(inputs)
# If we don't need to handle masking, return the lookup values directly.
if self.mask_token is None:
return lookups
# If we do need to handle masking: the mask token is deliberately absent
# from the static table, so looking it up returns the OOV default value.
# Keep OOV locations at the default value, then inject 0 wherever the
# mask token appeared in the inputs. (This is inefficient, but we can't
# adjust the table safely, so we don't have a choice.)
oov_locations = math_ops.equal(lookups, self.table._default_value) # pylint: disable=protected-access
oov_values = array_ops.ones_like(
lookups, dtype=self.table._value_dtype) * self.table._default_value # pylint: disable=protected-access
adjusted_lookups = array_ops.where(oov_locations, oov_values, lookups)
# Inject 0s wherever the mask token was in the inputs.
mask_locations = math_ops.equal(inputs, self.mask_token)
return array_ops.where(
mask_locations,
array_ops.zeros_like(lookups, dtype=self.table._value_dtype), # pylint: disable=protected-access
adjusted_lookups) # pylint: disable=protected-access
def _ragged_lookup(self, inputs):
"""Perform a table lookup on a ragged tensor."""
# The table lookup ops don't natively support ragged tensors, so if we have
# a RT we need to use map_flat_values to look up every element.
indexed_data = ragged_functional_ops.map_flat_values(
self.table.lookup, inputs)
self._lookup_and_mask, inputs)
indexed_data = ragged_functional_ops.map_flat_values(
self._replace_oov_buckets, inputs, indexed_data)
# table.lookup is not shape-preserving, so we need to set the shape here.
@@ -107,7 +154,7 @@ class TableHandler(object):
def _sparse_lookup(self, inputs):
"""Perform a table lookup on a sparse tensor."""
values = self.table.lookup(inputs.values)
values = self._lookup_and_mask(inputs.values)
values = self._replace_oov_buckets(inputs.values, values)
indexed_data = sparse_tensor.SparseTensor(inputs.indices, values,
inputs.dense_shape)
@@ -118,7 +165,7 @@ class TableHandler(object):
def _tensor_lookup(self, inputs):
"""Perform a table lookup on a tf.tensor."""
values = self.table.lookup(inputs)
values = self._lookup_and_mask(inputs)
indexed_data = self._replace_oov_buckets(inputs, values)
# (b/149446477): output does not preserve input shape.
indexed_data.set_shape(inputs.shape)
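
A plain-Python restatement of _lookup_and_mask above, under the assumptions the static-table tests below use (the mask token is absent from the vocab file, and the table's default value marks OOV):

DEFAULT = -7     # the table's default value, as in the tests below
MASK_TOKEN = ""  # mask token; deliberately absent from the vocab file

def lookup_and_mask(inputs, table):
  # The mask token is not in the table, so it looks up to DEFAULT,
  # just like any OOV token.
  lookups = [table.get(x, DEFAULT) for x in inputs]
  # Force mask-token locations to index 0; OOV locations stay at
  # DEFAULT for _replace_oov_buckets to rewrite afterwards.
  return [0 if x == MASK_TOKEN else v for x, v in zip(inputs, lookups)]

table = {"earth": 2, "wind": 3, "and": 4, "fire": 5}  # offset=2 vocab
assert lookup_and_mask(["earth", "", "michigan"], table) == [2, 0, -7]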

View File

@@ -45,6 +45,41 @@ def get_table(dtype=dtypes.string, oov_tokens=None):
table, oov_tokens, use_v1_apis=(not context.executing_eagerly()))
def get_static_table(tmpdir,
vocab_list,
mask_token=None,
dtype=dtypes.string,
oov_tokens=None):
vocabulary_file = os.path.join(tmpdir, "tmp_vocab.txt")
if dtype == dtypes.string:
with open(vocabulary_file, "w") as f:
f.write("\n".join(vocab_list) + "\n")
else:
with open(vocabulary_file, "w") as f:
f.write("\n".join([str(v) for v in vocab_list]) + "\n")
offset = ((0 if mask_token is None else 1) +
(len(oov_tokens) if oov_tokens is not None else 0))
init = lookup_ops.TextFileInitializer(
vocabulary_file,
dtype,
lookup_ops.TextFileIndex.WHOLE_LINE,
dtypes.int64,
lookup_ops.TextFileIndex.LINE_NUMBER,
value_index_offset=offset)
if context.executing_eagerly():
table = lookup_ops.StaticHashTable(init, default_value=-7)
else:
table = lookup_ops.StaticHashTableV1(init, default_value=-7)
return table_utils.TableHandler(
table,
oov_tokens,
mask_token=mask_token,
use_v1_apis=(not context.executing_eagerly()))
@keras_parameterized.run_all_keras_modes
class CategoricalEncodingInputTest(
keras_parameterized.TestCase,
@@ -252,6 +287,132 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
self.assertAllEqual(expected_output, output_data)
@keras_parameterized.run_all_keras_modes
class StaticIndexLookupOutputTest(
keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):
def test_int_output_default_lookup_value(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire"],
["fire", "and", "earth", "michigan"]])
expected_output = [[1, 2, 3, 4], [4, 3, 1, -7]]
table = get_static_table(
tmpdir=self.get_temp_dir(),
vocab_list=vocab_data,
mask_token="",
oov_tokens=None)
output_data = table.lookup(input_array)
self.assertAllEqual(expected_output, output_data)
def test_output_shape(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire"],
["fire", "and", "earth", "michigan"]])
table = get_static_table(
tmpdir=self.get_temp_dir(), vocab_list=vocab_data, oov_tokens=None)
output_data = table.lookup(input_array)
self.assertAllEqual(input_array.shape[1:], output_data.shape[1:])
def test_int_output_no_reserved_zero_default_lookup_value(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire"],
["fire", "and", "earth", "michigan"]])
expected_output = [[0, 1, 2, 3], [3, 2, 0, -7]]
table = get_static_table(
tmpdir=self.get_temp_dir(), vocab_list=vocab_data, oov_tokens=None)
output_data = table.lookup(input_array)
self.assertAllEqual(expected_output, output_data)
@keras_parameterized.run_all_keras_modes
class CategoricalEncodingStaticInputTest(
keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):
def test_sparse_string_input(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]],
values=["fire", "michigan"],
dense_shape=[3, 4])
expected_indices = [[0, 0], [1, 2]]
expected_values = [5, 1]
expected_dense_shape = [3, 4]
table = get_static_table(
tmpdir=self.get_temp_dir(),
vocab_list=vocab_data,
mask_token="",
oov_tokens=[1])
output_data = table.lookup(input_array)
self.assertAllEqual(expected_indices, output_data.indices)
self.assertAllEqual(expected_values, output_data.values)
self.assertAllEqual(expected_dense_shape, output_data.dense_shape)
def test_sparse_int_input(self):
vocab_data = np.array([10, 11, 12, 13], dtype=np.int64)
input_array = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]],
values=np.array([13, 32], dtype=np.int64),
dense_shape=[3, 4])
expected_indices = [[0, 0], [1, 2]]
expected_values = [5, 1]
expected_dense_shape = [3, 4]
table = get_static_table(
tmpdir=self.get_temp_dir(),
vocab_list=vocab_data,
dtype=dtypes.int64,
mask_token=0,
oov_tokens=[1])
output_data = table.lookup(input_array)
self.assertAllEqual(expected_indices, output_data.indices)
self.assertAllEqual(expected_values, output_data.values)
self.assertAllEqual(expected_dense_shape, output_data.dense_shape)
def test_ragged_string_input(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = ragged_factory_ops.constant(
[["earth", "wind", "fire"], ["fire", "and", "earth", "michigan"]])
expected_output = [[2, 3, 5], [5, 4, 2, 1]]
table = get_static_table(
tmpdir=self.get_temp_dir(),
vocab_list=vocab_data,
mask_token="",
oov_tokens=[1])
output_data = table.lookup(input_array)
self.assertAllEqual(expected_output, output_data)
def test_ragged_int_input(self):
vocab_data = np.array([10, 11, 12, 13], dtype=np.int64)
input_array = ragged_factory_ops.constant([[10, 11, 13], [13, 12, 10, 42]],
dtype=np.int64)
expected_output = [[2, 3, 5], [5, 4, 2, 1]]
table = get_static_table(
tmpdir=self.get_temp_dir(),
vocab_list=vocab_data,
dtype=dtypes.int64,
mask_token=0,
oov_tokens=[1])
output_data = table.lookup(input_array)
self.assertAllEqual(expected_output, output_data)
class GetVocabularyFromFileTest(test.TestCase):
def setUp(self):

View File

@@ -647,7 +647,8 @@ class TextFileInitializer(TableInitializerBase):
value_index,
vocab_size=None,
delimiter="\t",
name=None):
name=None,
value_index_offset=0):
"""Constructs a table initializer object to populate from a text file.
It generates one key-value pair per line. The type of table key and
@@ -675,6 +676,13 @@ class TextFileInitializer(TableInitializerBase):
vocab_size: The number of elements in the file, if known.
delimiter: The delimiter to separate fields in a line.
name: A name for the operation (optional).
value_index_offset: A number to add to all indices extracted from the file.
This is useful for cases where a user would like to reserve one or more
low index values for control characters. For instance, if you would
like to ensure that no vocabulary item is mapped to index 0 (so you can
reserve 0 for a masking value), you can set value_index_offset to 1;
this will mean that the first vocabulary element is mapped to 1
instead of 0.
Raises:
ValueError: when the filename is empty, or when the table key and value
@@ -718,6 +726,7 @@ class TextFileInitializer(TableInitializerBase):
self._name = name
self._filename = self._track_trackable(
trackable.Asset(filename), "_filename")
self._offset = value_index_offset
super(TextFileInitializer, self).__init__(key_dtype, value_dtype)
@@ -740,7 +749,8 @@ class TextFileInitializer(TableInitializerBase):
self._filename, dtypes.string, name="asset_filepath")
init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
table.resource_handle, filename, self._key_index, self._value_index,
-1 if self._vocab_size is None else self._vocab_size, self._delimiter)
-1 if self._vocab_size is None else self._vocab_size, self._delimiter,
self._offset)
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
# If the filename tensor is anything other than a string constant (e.g.,
# if it is a placeholder) then it does not make sense to track it as an

View File

@@ -14,7 +14,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], "
argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\', \'value_index_offset\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\', \'0\'], "
}
member_method {
name: "initialize"

View File

@@ -1950,11 +1950,11 @@ tf_module {
}
member_method {
name: "InitializeTableFromTextFile"
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
}
member_method {
name: "InitializeTableFromTextFileV2"
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
}
member_method {
name: "InitializeTableV2"

View File

@@ -14,7 +14,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], "
argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\', \'value_index_offset\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\', \'0\'], "
}
member_method {
name: "initialize"

View File

@@ -1950,11 +1950,11 @@ tf_module {
}
member_method {
name: "InitializeTableFromTextFile"
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
}
member_method {
name: "InitializeTableFromTextFileV2"
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
}
member_method {
name: "InitializeTableV2"