Internal change
PiperOrigin-RevId: 357124297 Change-Id: I5aa7c2d89379643a767053802589320eb3476cd6
This commit is contained in:
parent
2ad0a33259
commit
0a501caa91
@ -72,6 +72,7 @@ class GenerateVocabRemappingOp : public OpKernel {
|
||||
kUnusedLookupDelim,
|
||||
-1, // key_index, use the line number.
|
||||
-2, // value_index, use the whole line/token.
|
||||
0, // No offset.
|
||||
context->env(), new_vocab_table));
|
||||
OP_REQUIRES(context,
|
||||
new_vocab_offset_ + num_new_vocab_ <= new_vocab_table->size(),
|
||||
@ -101,6 +102,7 @@ class GenerateVocabRemappingOp : public OpKernel {
|
||||
old_vocab_filename, old_vocab_size_, kUnusedLookupDelim,
|
||||
-2, // key_index, use the whole line/token.
|
||||
-1, // value_index, use the line number.
|
||||
0, // No offset.
|
||||
context->env(), old_vocab_table));
|
||||
|
||||
// Fill out new_ids = [new_vocab_offset, new_vocab_offset + 1, ...,
|
||||
|
@ -105,6 +105,7 @@ class InitializeTableFromTextFileOp : public OpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size", &vocab_size_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index", &key_index_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index", &value_index_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("offset", &offset_));
|
||||
string delimiter;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter", &delimiter));
|
||||
OP_REQUIRES(ctx, delimiter.size() == 1,
|
||||
@ -141,7 +142,7 @@ class InitializeTableFromTextFileOp : public OpKernel {
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, lookup::InitializeTableFromTextFile(
|
||||
vocab_filename, vocab_size_, delimiter_, key_index_,
|
||||
value_index_, ctx->env(), table));
|
||||
value_index_, offset_, ctx->env(), table));
|
||||
if (ctx->track_allocations()) {
|
||||
ctx->record_persistent_memory_allocation(table->MemoryUsed() -
|
||||
memory_used_before);
|
||||
@ -154,6 +155,7 @@ class InitializeTableFromTextFileOp : public OpKernel {
|
||||
char delimiter_;
|
||||
int64 key_index_;
|
||||
int64 value_index_;
|
||||
int64 offset_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp);
|
||||
};
|
||||
|
@ -77,7 +77,7 @@ class TextFileLineIterator
|
||||
// delimiter.
|
||||
Status Init(const string& filename, int64 vocab_size, char delimiter,
|
||||
DataType key_dtype, int64 key_index, DataType value_dtype,
|
||||
int64 value_index, Env* env) {
|
||||
int64 value_index, int64 offset, Env* env) {
|
||||
filename_ = filename;
|
||||
vocab_size_ = vocab_size;
|
||||
delimiter_ = delimiter;
|
||||
@ -93,6 +93,7 @@ class TextFileLineIterator
|
||||
input_buffer_.reset(new io::InputBuffer(file_.get(), kInputBufferSize));
|
||||
valid_ = true;
|
||||
next_id_ = 0;
|
||||
offset_ = offset;
|
||||
ignore_split_ = std::max(key_index_, value_index_) < 0;
|
||||
Next();
|
||||
return status_;
|
||||
@ -143,6 +144,7 @@ class TextFileLineIterator
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
status_ = SetValue(line, tokens, key_index_, &key_);
|
||||
if (!status_.ok()) {
|
||||
valid_ = false;
|
||||
@ -186,6 +188,7 @@ class TextFileLineIterator
|
||||
int64 value_index_;
|
||||
Env* env_;
|
||||
int64 next_id_;
|
||||
int64 offset_;
|
||||
int64 vocab_size_;
|
||||
string filename_;
|
||||
char delimiter_;
|
||||
@ -199,7 +202,7 @@ class TextFileLineIterator
|
||||
Status SetValue(const string& line, const std::vector<string>& tokens,
|
||||
int64 index, Tensor* tensor) {
|
||||
if (index == kLineNumber) {
|
||||
tensor->flat<int64>()(0) = next_id_;
|
||||
tensor->flat<int64>()(0) = next_id_ + offset_;
|
||||
return Status::OK();
|
||||
}
|
||||
const string& token = (index == kWholeLine) ? line : tokens[index];
|
||||
@ -212,7 +215,7 @@ class TextFileLineIterator
|
||||
return errors::InvalidArgument("Field ", token, " in line ", next_id_,
|
||||
" is not a valid int32.");
|
||||
}
|
||||
tensor->flat<int32>()(0) = value;
|
||||
tensor->flat<int32>()(0) = value + offset_;
|
||||
} break;
|
||||
case DT_INT64: {
|
||||
int64 value;
|
||||
@ -352,7 +355,7 @@ Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,
|
||||
// Helper function to initialize an InitializableLookupTable from a text file.
|
||||
Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
|
||||
char delimiter, int32 key_index,
|
||||
int32 value_index, Env* env,
|
||||
int32 value_index, int64 offset, Env* env,
|
||||
InitializableLookupTable* table) {
|
||||
if (key_index == kLineNumber && table->key_dtype() != DT_INT64) {
|
||||
return errors::InvalidArgument(
|
||||
@ -380,7 +383,8 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
|
||||
|
||||
TextFileLineIterator iter;
|
||||
TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter, key_dtype,
|
||||
key_index, value_dtype, value_index, env));
|
||||
key_index, value_dtype, value_index, offset,
|
||||
env));
|
||||
// For initialization from files, ignore if the table is already
|
||||
// initialized. The table shared name should contain the filename to
|
||||
// avoid trying to initialize the same table from the same file at the same
|
||||
|
@ -53,7 +53,7 @@ Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,
|
||||
|
||||
Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
|
||||
char delimiter, int32 key_index,
|
||||
int32 value_index, Env* env,
|
||||
int32 value_index, int64 offset, Env* env,
|
||||
InitializableLookupTable* table);
|
||||
|
||||
// Initializes `table` from `dataset` by iterating over it. Caller retains
|
||||
|
@ -480,6 +480,7 @@ REGISTER_OP("InitializeTableFromTextFile")
|
||||
.Attr("value_index: int >= -2")
|
||||
.Attr("vocab_size: int >= -1 = -1")
|
||||
.Attr("delimiter: string = '\t'")
|
||||
.Attr("offset: int = 0")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle handle;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
|
||||
@ -497,6 +498,7 @@ REGISTER_OP("InitializeTableFromTextFileV2")
|
||||
.Attr("value_index: int >= -2")
|
||||
.Attr("vocab_size: int >= -1 = -1")
|
||||
.Attr("delimiter: string = '\t'")
|
||||
.Attr("offset: int = 0")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle handle;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
|
||||
|
@ -802,35 +802,46 @@ class TrackableWeightHandler(object):
|
||||
|
||||
# TODO(b/141682913): Figure out why this is private and fix it.
|
||||
saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access
|
||||
if len(saveables) != 1:
|
||||
raise ValueError('Only Trackables with one Saveable are supported.')
|
||||
saveable = list(saveables)[0]
|
||||
# 'Saveables' won't exist when we're passed a legacy TF1 table like
|
||||
# a StaticHashTable.
|
||||
if not saveables:
|
||||
self._num_tensors = 0
|
||||
self._setter = lambda weights: None
|
||||
self._getter = lambda: []
|
||||
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
# If we're in eager mode, we need to defer calling the Trackable's
|
||||
# saveable() callable until data export time.
|
||||
# However, it is safe to call the saveable as many times as we want, so
|
||||
# we will call it now to figure out how many tensors this Trackable will
|
||||
# produce.
|
||||
self._saveable = saveable
|
||||
self._num_tensors = len(self._saveable().specs)
|
||||
self._setter = lambda weights: self._saveable().restore(weights, None)
|
||||
self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
|
||||
elif len(saveables) == 1:
|
||||
saveable = list(saveables)[0]
|
||||
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
# If we're in eager mode, we need to defer calling the Trackable's
|
||||
# saveable() callable until data export time.
|
||||
# However, it is safe to call the saveable as many times as we want, so
|
||||
# we will call it now to figure out how many tensors this Trackable will
|
||||
# produce.
|
||||
self._saveable = saveable
|
||||
self._num_tensors = len(self._saveable().specs)
|
||||
self._setter = lambda weights: self._saveable().restore(weights, None)
|
||||
self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
|
||||
else:
|
||||
# If we're in Graph mode, we need to evaluate the Saveable only once and
|
||||
# cache the resulting restore graph. Failing to do this will result in
|
||||
# new assignment ops being added to the graph each time set_weights() is
|
||||
# called.
|
||||
self._placeholder_tensors = []
|
||||
self._saveable = saveable()
|
||||
self._num_tensors = len(self._saveable.specs)
|
||||
for spec in self._saveable.specs:
|
||||
tensor = spec.tensor
|
||||
self._placeholder_tensors.append(
|
||||
array_ops.placeholder(tensor.dtype, tensor.shape))
|
||||
self._assign_op = self._saveable.restore(self._placeholder_tensors,
|
||||
None)
|
||||
self._setter = self._set_weights_v1
|
||||
self._getter = lambda: [spec.tensor for spec in self._saveable.specs]
|
||||
else:
|
||||
# If we're in Graph mode, we need to evaluate the Saveable only once and
|
||||
# cache the resulting restore graph. Failing to do this will result in
|
||||
# new assignment ops being added to the graph each time set_weights() is
|
||||
# called.
|
||||
self._placeholder_tensors = []
|
||||
self._saveable = saveable()
|
||||
self._num_tensors = len(self._saveable.specs)
|
||||
for spec in self._saveable.specs:
|
||||
tensor = spec.tensor
|
||||
self._placeholder_tensors.append(
|
||||
array_ops.placeholder(tensor.dtype, tensor.shape))
|
||||
self._assign_op = self._saveable.restore(self._placeholder_tensors, None)
|
||||
self._setter = self._set_weights_v1
|
||||
self._getter = lambda: [spec.tensor for spec in self._saveable.specs]
|
||||
raise ValueError('Only Trackables with one Saveable are supported. '
|
||||
'The Trackable %s has %d Saveables.' %
|
||||
(trackable, len(saveables)))
|
||||
|
||||
@property
|
||||
def num_tensors(self):
|
||||
|
@ -80,6 +80,21 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "index_lookup_forward_benchmark",
|
||||
srcs = ["index_lookup_forward_benchmark.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
"//tensorflow/python:platform_benchmark",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/keras/layers/preprocessing:index_lookup",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "normalization_adapt_benchmark",
|
||||
srcs = ["normalization_adapt_benchmark.py"],
|
||||
|
@ -0,0 +1,146 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Benchmark for Keras text vectorization preprocessing layer's adapt method."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.keras.layers.preprocessing import index_lookup
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.platform import benchmark
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
v2_compat.enable_v2_behavior()
|
||||
|
||||
|
||||
# word_gen creates random sequences of ASCII letters (both lowercase and upper).
|
||||
# The number of unique strings is ~2,700.
|
||||
def tensor_gen(batch, num_elements):
|
||||
data = []
|
||||
for _ in range(batch):
|
||||
batch_element = []
|
||||
for _ in range(num_elements - 1):
|
||||
tok = "".join(random.choice(string.ascii_letters) for i in range(2))
|
||||
batch_element.append(tok)
|
||||
batch_element.append("") # Explicitly test the empty string.
|
||||
data.append(batch_element)
|
||||
return constant_op.constant(data)
|
||||
|
||||
|
||||
def get_vocab():
|
||||
vocab = list(
|
||||
set([a + b for a in string.ascii_letters for b in string.ascii_letters])) # pylint:disable=g-complex-comprehension
|
||||
vocab.sort()
|
||||
return vocab
|
||||
|
||||
|
||||
# This class uses TestCase for get_temp_dir().
|
||||
class BenchmarkLookup(benchmark.TensorFlowBenchmark):
|
||||
"""Benchmark the index lookup layer's forward pass."""
|
||||
|
||||
def _write_to_temp_file(self, file_name, vocab_list):
|
||||
vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt")
|
||||
with gfile.GFile(vocab_path, "w") as writer:
|
||||
for vocab in vocab_list:
|
||||
writer.write(vocab + "\n")
|
||||
writer.flush()
|
||||
writer.close()
|
||||
return vocab_path
|
||||
|
||||
def run_numpy_implementation(self, data, vocab):
|
||||
"""Test the python implementation."""
|
||||
input_t = keras.Input(shape=(), dtype=dtypes.string)
|
||||
layer = index_lookup.IndexLookup(
|
||||
vocabulary=vocab,
|
||||
max_tokens=None,
|
||||
num_oov_indices=1,
|
||||
mask_token="",
|
||||
oov_token="OOV",
|
||||
dtype=dtypes.string)
|
||||
out_t = layer(input_t)
|
||||
model = keras.Model(input_t, out_t)
|
||||
num_repeats = 5
|
||||
starts = []
|
||||
ends = []
|
||||
_ = model(data)
|
||||
for _ in range(num_repeats):
|
||||
starts.append(time.time())
|
||||
out = model(data)
|
||||
ends.append(time.time())
|
||||
avg_time = np.mean(np.array(ends) - np.array(starts))
|
||||
return avg_time, out
|
||||
|
||||
def bm_adapt_implementation(self, num_elements, batch_size):
|
||||
"""Test the KPL adapt implementation."""
|
||||
vocab = get_vocab()
|
||||
vocab_file = self._write_to_temp_file("vocab", vocab)
|
||||
vocabulary_initializer = lookup_ops.TextFileInitializer(
|
||||
filename=vocab_file,
|
||||
key_dtype=dtypes.string,
|
||||
key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||
value_dtype=dtypes.int64,
|
||||
value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
|
||||
value_index_offset=2)
|
||||
input_t = keras.Input(shape=(), dtype=dtypes.string)
|
||||
layer = index_lookup.IndexLookup(
|
||||
vocabulary=vocabulary_initializer,
|
||||
max_tokens=None,
|
||||
num_oov_indices=1,
|
||||
mask_token="",
|
||||
oov_token="OOV",
|
||||
dtype=dtypes.string)
|
||||
out_t = layer(input_t)
|
||||
model = keras.Model(input_t, out_t)
|
||||
num_repeats = 5
|
||||
starts = []
|
||||
ends = []
|
||||
data = tensor_gen(batch_size, num_elements)
|
||||
_ = model(data)
|
||||
for _ in range(num_repeats):
|
||||
starts.append(time.time())
|
||||
_ = model(data)
|
||||
ends.append(time.time())
|
||||
avg_time = np.mean(np.array(ends) - np.array(starts))
|
||||
baseline, _ = self.run_numpy_implementation(data, vocab)
|
||||
extras = {
|
||||
"numpy implementation baseline": baseline,
|
||||
"delta seconds": (baseline - avg_time),
|
||||
"delta percent": ((baseline - avg_time) / baseline) * 100
|
||||
}
|
||||
name = "index_lookup_forward|%s_elements|batch_%s" % (num_elements,
|
||||
batch_size)
|
||||
self.report_benchmark(
|
||||
iters=num_repeats, wall_time=avg_time, extras=extras, name=name)
|
||||
|
||||
def benchmark_vocab_size_by_batch(self):
|
||||
for tensor_size in [100, 1000, 10000]:
|
||||
for batch in [1, 16, 2048]:
|
||||
self.bm_adapt_implementation(tensor_size, batch)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -184,18 +184,6 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
self._value_dtype = dtypes.int64
|
||||
oov_value = self._oov_value
|
||||
|
||||
self._table = lookup_ops.MutableHashTable(
|
||||
key_dtype=self._key_dtype,
|
||||
value_dtype=self._value_dtype,
|
||||
default_value=oov_value,
|
||||
name=(self._name + "_index_table"))
|
||||
tracked_table = self._add_trackable(self._table, trainable=False)
|
||||
# This is a workaround for summary() on this layer. Because the table is
|
||||
# not mutable during training, the effective number of parameters (and so
|
||||
# the weight shape) is 0; we add this as an attr so that the parameter
|
||||
# counting code in the Model object doesn't throw an attribute error.
|
||||
tracked_table.shape = tensor_shape.TensorShape((0,))
|
||||
|
||||
if self.num_oov_indices <= 1:
|
||||
oov_indices = None
|
||||
else:
|
||||
@ -203,13 +191,30 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
oov_end = oov_start + num_oov_indices
|
||||
oov_indices = list(range(oov_start, oov_end))
|
||||
|
||||
self._table_handler = table_utils.TableHandler(
|
||||
table=self._table,
|
||||
oov_tokens=oov_indices,
|
||||
use_v1_apis=self._use_v1_apis())
|
||||
|
||||
if vocabulary is not None:
|
||||
self.set_vocabulary(vocabulary)
|
||||
if vocabulary is not None and isinstance(vocabulary,
|
||||
lookup_ops.TextFileInitializer):
|
||||
self._table = self._static_table_class()(
|
||||
vocabulary, default_value=oov_value)
|
||||
self._table_handler = table_utils.TableHandler(
|
||||
table=self._table,
|
||||
mask_token=mask_token,
|
||||
oov_tokens=oov_indices,
|
||||
use_v1_apis=self._use_v1_apis())
|
||||
self.max_tokens = (
|
||||
self._table_handler.vocab_size() + self.num_oov_indices +
|
||||
(0 if mask_token is None else 1))
|
||||
else:
|
||||
self._table = lookup_ops.MutableHashTable(
|
||||
key_dtype=self._key_dtype,
|
||||
value_dtype=self._value_dtype,
|
||||
default_value=oov_value,
|
||||
name=(self._name + "_index_table"))
|
||||
self._table_handler = table_utils.TableHandler(
|
||||
table=self._table,
|
||||
oov_tokens=oov_indices,
|
||||
use_v1_apis=self._use_v1_apis())
|
||||
if vocabulary is not None:
|
||||
self.set_vocabulary(vocabulary)
|
||||
|
||||
if self.output_mode == TFIDF:
|
||||
# The TF-IDF weight may have a (None,) tensorshape. This creates
|
||||
@ -232,6 +237,13 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
dtype=K.floatx(),
|
||||
initializer=initializer)
|
||||
|
||||
tracked_table = self._add_trackable(self._table, trainable=False)
|
||||
# This is a workaround for summary() on this layer. Because the table is
|
||||
# not mutable during training, the effective number of parameters (and so
|
||||
# the weight shape) is 0; we add this as an attr so that the parameter
|
||||
# counting code in the Model object doesn't throw an attribute error.
|
||||
tracked_table.shape = tensor_shape.TensorShape((0,))
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
if self.output_mode != INT:
|
||||
return tensor_shape.TensorShape([input_shape[0], self.max_tokens])
|
||||
@ -538,6 +550,9 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
def _use_v1_apis(self):
|
||||
return False
|
||||
|
||||
def _static_table_class(self):
|
||||
return lookup_ops.StaticHashTable
|
||||
|
||||
|
||||
class _IndexLookupAccumulator(
|
||||
collections.namedtuple("Accumulator",
|
||||
|
@ -41,7 +41,9 @@ from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
|
||||
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
|
||||
from tensorflow.python.keras.saving import save
|
||||
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -703,6 +705,15 @@ class CategoricalEncodingAdaptTest(
|
||||
class IndexLookupOutputTest(keras_parameterized.TestCase,
|
||||
preprocessing_test_utils.PreprocessingLayerTest):
|
||||
|
||||
def _write_to_temp_file(self, file_name, vocab_list):
|
||||
vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt")
|
||||
with gfile.GFile(vocab_path, "w") as writer:
|
||||
for vocab in vocab_list:
|
||||
writer.write(vocab + "\n")
|
||||
writer.flush()
|
||||
writer.close()
|
||||
return vocab_path
|
||||
|
||||
def test_int_output(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
input_array = np.array([["earth", "wind", "and", "fire"],
|
||||
@ -958,7 +969,60 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
||||
layer_output = layer(input_data)
|
||||
self.assertAllEqual(layer_output.shape.as_list(), [16, 2])
|
||||
|
||||
def test_int_output_file_vocab(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
input_array = np.array([["earth", "wind", "and", "fire"],
|
||||
["fire", "", "earth", "michigan"]])
|
||||
expected_output = [[2, 3, 4, 5], [5, 0, 2, 1]]
|
||||
|
||||
vocab_file = self._write_to_temp_file("temp", vocab_data)
|
||||
vocabulary_initializer = lookup_ops.TextFileInitializer(
|
||||
filename=vocab_file,
|
||||
key_dtype=dtypes.string,
|
||||
key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||
value_dtype=dtypes.int64,
|
||||
value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
|
||||
value_index_offset=2)
|
||||
|
||||
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
|
||||
layer = get_layer_class()(
|
||||
vocabulary=vocabulary_initializer,
|
||||
max_tokens=None,
|
||||
num_oov_indices=1,
|
||||
mask_token="",
|
||||
oov_token="[OOV]",
|
||||
dtype=dtypes.string)
|
||||
int_data = layer(input_data)
|
||||
model = keras.Model(inputs=input_data, outputs=int_data)
|
||||
output_dataset = model.predict(input_array)
|
||||
self.assertAllEqual(expected_output, output_dataset)
|
||||
|
||||
def test_int_output_int_file_vocab(self):
|
||||
vocab_data = ["10", "20", "30", "40"]
|
||||
input_array = np.array([[10, 20, 30, 40], [40, 0, 10, 42]])
|
||||
expected_output = [[2, 3, 4, 5], [5, 0, 2, 1]]
|
||||
|
||||
vocab_file = self._write_to_temp_file("temp", vocab_data)
|
||||
vocabulary_initializer = lookup_ops.TextFileInitializer(
|
||||
filename=vocab_file,
|
||||
key_dtype=dtypes.int64,
|
||||
key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||
value_dtype=dtypes.int64,
|
||||
value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
|
||||
value_index_offset=2)
|
||||
|
||||
input_data = keras.Input(shape=(None,), dtype=dtypes.int64)
|
||||
layer = get_layer_class()(
|
||||
vocabulary=vocabulary_initializer,
|
||||
max_tokens=None,
|
||||
num_oov_indices=1,
|
||||
mask_token=0,
|
||||
oov_token=-1,
|
||||
dtype=dtypes.int64)
|
||||
int_data = layer(input_data)
|
||||
model = keras.Model(inputs=input_data, outputs=int_data)
|
||||
output_dataset = model.predict(input_array)
|
||||
self.assertAllEqual(expected_output, output_dataset)
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class IndexLookupVocabularyTest(keras_parameterized.TestCase,
|
||||
preprocessing_test_utils.PreprocessingLayerTest
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.keras.engine import base_preprocessing_layer_v1
|
||||
from tensorflow.python.keras.layers.preprocessing import index_lookup
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
|
||||
|
||||
class IndexLookup(index_lookup.IndexLookup,
|
||||
@ -58,3 +59,6 @@ class IndexLookup(index_lookup.IndexLookup,
|
||||
|
||||
def _use_v1_apis(self):
|
||||
return True
|
||||
|
||||
def _static_table_class(self):
|
||||
return lookup_ops.StaticHashTableV1
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_functional_ops
|
||||
@ -38,8 +39,24 @@ from tensorflow.python.platform import gfile
|
||||
class TableHandler(object):
|
||||
"""Wrapper object that holds a lookup table and provides accessors."""
|
||||
|
||||
def __init__(self, table, oov_tokens=None, use_v1_apis=False):
|
||||
def __init__(self,
|
||||
table,
|
||||
oov_tokens=None,
|
||||
mask_token=None,
|
||||
use_v1_apis=False):
|
||||
self.table = table
|
||||
|
||||
# If we are using V1 APIs, and the table has an initializer, we need to run
|
||||
# it. However, not all tables have initializers, so we try-except here.
|
||||
if use_v1_apis:
|
||||
try:
|
||||
K.get_session().run(self.table.initializer)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
self.mutable = isinstance(table, lookup_ops.MutableHashTable)
|
||||
self.mask_token = mask_token
|
||||
|
||||
self.use_v1_apis = use_v1_apis
|
||||
if oov_tokens is None:
|
||||
self.oov_tokens = oov_tokens
|
||||
@ -56,10 +73,17 @@ class TableHandler(object):
|
||||
return self._eval(self.table.size())
|
||||
|
||||
def clear(self):
|
||||
if not self.mutable:
|
||||
return RuntimeError("Unable to clear a statically-backed table.")
|
||||
|
||||
keys, _ = self.table.export()
|
||||
self._run(self.table.remove(keys))
|
||||
|
||||
def insert(self, keys, values):
|
||||
"""Insert values into the backed table."""
|
||||
if not self.mutable:
|
||||
raise RuntimeError("Unable to insert into a statically-backed table.")
|
||||
|
||||
if len(values) != len(keys):
|
||||
raise RuntimeError("Size mismatch between values and key arrays. "
|
||||
"Keys had size %s, values had size %s." %
|
||||
@ -90,12 +114,35 @@ class TableHandler(object):
|
||||
|
||||
return array_ops.where(oov_locations, oov_values, lookups)
|
||||
|
||||
def _lookup_and_mask(self, inputs):
|
||||
"""Return a lookup with any location with the mask_token masked to 0."""
|
||||
lookups = self.table.lookup(inputs)
|
||||
# If we don't need to handle masking, return the lookup values directly.
|
||||
if self.mask_token is None:
|
||||
return lookups
|
||||
|
||||
# If we do need to handle masking, increment all the lookup values by 1
|
||||
# to account for the mask value at location 0. This also increments the
|
||||
# OOV value, so replace that. (This is inefficient, but we can't adjust
|
||||
# the table safely, so we don't have a choice.)
|
||||
oov_locations = math_ops.equal(lookups, self.table._default_value) # pylint: disable=protected-access
|
||||
oov_values = array_ops.ones_like(
|
||||
lookups, dtype=self.table._value_dtype) * self.table._default_value # pylint: disable=protected-access
|
||||
adjusted_lookups = array_ops.where(oov_locations, oov_values, lookups)
|
||||
|
||||
# Inject 0s wherever the mask token was in the inputs.
|
||||
mask_locations = math_ops.equal(inputs, self.mask_token)
|
||||
return array_ops.where(
|
||||
mask_locations,
|
||||
array_ops.zeros_like(lookups, dtype=self.table._value_dtype), # pylint: disable=protected-access
|
||||
adjusted_lookups) # pylint: disable=protected-access
|
||||
|
||||
def _ragged_lookup(self, inputs):
|
||||
"""Perform a table lookup on a ragged tensor."""
|
||||
# The table lookup ops don't natively support ragged tensors, so if we have
|
||||
# a RT we need to use map_flat_values to look up every element.
|
||||
indexed_data = ragged_functional_ops.map_flat_values(
|
||||
self.table.lookup, inputs)
|
||||
self._lookup_and_mask, inputs)
|
||||
indexed_data = ragged_functional_ops.map_flat_values(
|
||||
self._replace_oov_buckets, inputs, indexed_data)
|
||||
# table.lookup is not shape-preserving, so we need to set the shape here.
|
||||
@ -107,7 +154,7 @@ class TableHandler(object):
|
||||
|
||||
def _sparse_lookup(self, inputs):
|
||||
"""Perform a table lookup on a sparse tensor."""
|
||||
values = self.table.lookup(inputs.values)
|
||||
values = self._lookup_and_mask(inputs.values)
|
||||
values = self._replace_oov_buckets(inputs.values, values)
|
||||
indexed_data = sparse_tensor.SparseTensor(inputs.indices, values,
|
||||
inputs.dense_shape)
|
||||
@ -118,7 +165,7 @@ class TableHandler(object):
|
||||
|
||||
def _tensor_lookup(self, inputs):
|
||||
"""Perform a table lookup on a tf.tensor."""
|
||||
values = self.table.lookup(inputs)
|
||||
values = self._lookup_and_mask(inputs)
|
||||
indexed_data = self._replace_oov_buckets(inputs, values)
|
||||
# (b/149446477): output does not preserve input shape.
|
||||
indexed_data.set_shape(inputs.shape)
|
||||
|
@ -45,6 +45,41 @@ def get_table(dtype=dtypes.string, oov_tokens=None):
|
||||
table, oov_tokens, use_v1_apis=(not context.executing_eagerly()))
|
||||
|
||||
|
||||
def get_static_table(tmpdir,
|
||||
vocab_list,
|
||||
mask_token=None,
|
||||
dtype=dtypes.string,
|
||||
oov_tokens=None):
|
||||
vocabulary_file = os.path.join(tmpdir, "tmp_vocab.txt")
|
||||
|
||||
if dtype == dtypes.string:
|
||||
with open(vocabulary_file, "w") as f:
|
||||
f.write("\n".join(vocab_list) + "\n")
|
||||
else:
|
||||
with open(vocabulary_file, "w") as f:
|
||||
f.write("\n".join([str(v) for v in vocab_list]) + "\n")
|
||||
|
||||
offset = ((0 if mask_token is None else 1) +
|
||||
(len(oov_tokens) if oov_tokens is not None else 0))
|
||||
init = lookup_ops.TextFileInitializer(
|
||||
vocabulary_file,
|
||||
dtype,
|
||||
lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||
dtypes.int64,
|
||||
lookup_ops.TextFileIndex.LINE_NUMBER,
|
||||
value_index_offset=offset)
|
||||
if context.executing_eagerly():
|
||||
table = lookup_ops.StaticHashTable(init, default_value=-7)
|
||||
else:
|
||||
table = lookup_ops.StaticHashTableV1(init, default_value=-7)
|
||||
|
||||
return table_utils.TableHandler(
|
||||
table,
|
||||
oov_tokens,
|
||||
mask_token=mask_token,
|
||||
use_v1_apis=(not context.executing_eagerly()))
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class CategoricalEncodingInputTest(
|
||||
keras_parameterized.TestCase,
|
||||
@ -252,6 +287,132 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
||||
self.assertAllEqual(expected_output, output_data)
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class StaticIndexLookupOutputTest(
|
||||
keras_parameterized.TestCase,
|
||||
preprocessing_test_utils.PreprocessingLayerTest):
|
||||
|
||||
def test_int_output_default_lookup_value(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
input_array = np.array([["earth", "wind", "and", "fire"],
|
||||
["fire", "and", "earth", "michigan"]])
|
||||
expected_output = [[1, 2, 3, 4], [4, 3, 1, -7]]
|
||||
|
||||
table = get_static_table(
|
||||
tmpdir=self.get_temp_dir(),
|
||||
vocab_list=vocab_data,
|
||||
mask_token="",
|
||||
oov_tokens=None)
|
||||
output_data = table.lookup(input_array)
|
||||
|
||||
self.assertAllEqual(expected_output, output_data)
|
||||
|
||||
def test_output_shape(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
input_array = np.array([["earth", "wind", "and", "fire"],
|
||||
["fire", "and", "earth", "michigan"]])
|
||||
|
||||
table = get_static_table(
|
||||
tmpdir=self.get_temp_dir(), vocab_list=vocab_data, oov_tokens=None)
|
||||
output_data = table.lookup(input_array)
|
||||
|
||||
self.assertAllEqual(input_array.shape[1:], output_data.shape[1:])
|
||||
|
||||
def test_int_output_no_reserved_zero_default_lookup_value(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
input_array = np.array([["earth", "wind", "and", "fire"],
|
||||
["fire", "and", "earth", "michigan"]])
|
||||
expected_output = [[0, 1, 2, 3], [3, 2, 0, -7]]
|
||||
|
||||
table = get_static_table(
|
||||
tmpdir=self.get_temp_dir(), vocab_list=vocab_data, oov_tokens=None)
|
||||
output_data = table.lookup(input_array)
|
||||
|
||||
self.assertAllEqual(expected_output, output_data)
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class CategoricalEncodingStaticInputTest(
|
||||
keras_parameterized.TestCase,
|
||||
preprocessing_test_utils.PreprocessingLayerTest):
|
||||
|
||||
def test_sparse_string_input(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
input_array = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 2]],
|
||||
values=["fire", "michigan"],
|
||||
dense_shape=[3, 4])
|
||||
|
||||
expected_indices = [[0, 0], [1, 2]]
|
||||
expected_values = [5, 1]
|
||||
expected_dense_shape = [3, 4]
|
||||
|
||||
table = get_static_table(
|
||||
tmpdir=self.get_temp_dir(),
|
||||
vocab_list=vocab_data,
|
||||
mask_token="",
|
||||
oov_tokens=[1])
|
||||
output_data = table.lookup(input_array)
|
||||
|
||||
self.assertAllEqual(expected_indices, output_data.indices)
|
||||
self.assertAllEqual(expected_values, output_data.values)
|
||||
self.assertAllEqual(expected_dense_shape, output_data.dense_shape)
|
||||
|
||||
def test_sparse_int_input(self):
|
||||
vocab_data = np.array([10, 11, 12, 13], dtype=np.int64)
|
||||
input_array = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 2]],
|
||||
values=np.array([13, 32], dtype=np.int64),
|
||||
dense_shape=[3, 4])
|
||||
|
||||
expected_indices = [[0, 0], [1, 2]]
|
||||
expected_values = [5, 1]
|
||||
expected_dense_shape = [3, 4]
|
||||
|
||||
table = get_static_table(
|
||||
tmpdir=self.get_temp_dir(),
|
||||
vocab_list=vocab_data,
|
||||
dtype=dtypes.int64,
|
||||
mask_token=0,
|
||||
oov_tokens=[1])
|
||||
output_data = table.lookup(input_array)
|
||||
|
||||
self.assertAllEqual(expected_indices, output_data.indices)
|
||||
self.assertAllEqual(expected_values, output_data.values)
|
||||
self.assertAllEqual(expected_dense_shape, output_data.dense_shape)
|
||||
|
||||
def test_ragged_string_input(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
input_array = ragged_factory_ops.constant(
|
||||
[["earth", "wind", "fire"], ["fire", "and", "earth", "michigan"]])
|
||||
expected_output = [[2, 3, 5], [5, 4, 2, 1]]
|
||||
|
||||
table = get_static_table(
|
||||
tmpdir=self.get_temp_dir(),
|
||||
vocab_list=vocab_data,
|
||||
mask_token="",
|
||||
oov_tokens=[1])
|
||||
output_data = table.lookup(input_array)
|
||||
|
||||
self.assertAllEqual(expected_output, output_data)
|
||||
|
||||
def test_ragged_int_input(self):
|
||||
vocab_data = np.array([10, 11, 12, 13], dtype=np.int64)
|
||||
input_array = ragged_factory_ops.constant([[10, 11, 13], [13, 12, 10, 42]],
|
||||
dtype=np.int64)
|
||||
expected_output = [[2, 3, 5], [5, 4, 2, 1]]
|
||||
|
||||
table = get_static_table(
|
||||
tmpdir=self.get_temp_dir(),
|
||||
vocab_list=vocab_data,
|
||||
dtype=dtypes.int64,
|
||||
mask_token=0,
|
||||
oov_tokens=[1])
|
||||
output_data = table.lookup(input_array)
|
||||
|
||||
self.assertAllEqual(expected_output, output_data)
|
||||
|
||||
|
||||
class GetVocabularyFromFileTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -647,7 +647,8 @@ class TextFileInitializer(TableInitializerBase):
|
||||
value_index,
|
||||
vocab_size=None,
|
||||
delimiter="\t",
|
||||
name=None):
|
||||
name=None,
|
||||
value_index_offset=0):
|
||||
"""Constructs a table initializer object to populate from a text file.
|
||||
|
||||
It generates one key-value pair per line. The type of table key and
|
||||
@ -675,6 +676,13 @@ class TextFileInitializer(TableInitializerBase):
|
||||
vocab_size: The number of elements in the file, if known.
|
||||
delimiter: The delimiter to separate fields in a line.
|
||||
name: A name for the operation (optional).
|
||||
value_index_offset: A number to add to all indices extracted from the file
|
||||
This is useful for cases where a user would like to reserve one or more
|
||||
low index values for control characters. For instance, if you would
|
||||
like to ensure that no vocabulary item is mapped to index 0 (so you can
|
||||
reserve 0 for a masking value), you can set value_index_offset to 1;
|
||||
this will mean that the first vocabulary element is mapped to 1
|
||||
instead of 0.
|
||||
|
||||
Raises:
|
||||
ValueError: when the filename is empty, or when the table key and value
|
||||
@ -718,6 +726,7 @@ class TextFileInitializer(TableInitializerBase):
|
||||
self._name = name
|
||||
self._filename = self._track_trackable(
|
||||
trackable.Asset(filename), "_filename")
|
||||
self._offset = value_index_offset
|
||||
|
||||
super(TextFileInitializer, self).__init__(key_dtype, value_dtype)
|
||||
|
||||
@ -740,7 +749,8 @@ class TextFileInitializer(TableInitializerBase):
|
||||
self._filename, dtypes.string, name="asset_filepath")
|
||||
init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
|
||||
table.resource_handle, filename, self._key_index, self._value_index,
|
||||
-1 if self._vocab_size is None else self._vocab_size, self._delimiter)
|
||||
-1 if self._vocab_size is None else self._vocab_size, self._delimiter,
|
||||
self._offset)
|
||||
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
|
||||
# If the filename tensor is anything other than a string constant (e.g.,
|
||||
# if it is a placeholder) then it does not make sense to track it as an
|
||||
|
@ -14,7 +14,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\', \'value_index_offset\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\', \'0\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "initialize"
|
||||
|
@ -1950,11 +1950,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "InitializeTableFromTextFile"
|
||||
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
|
||||
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "InitializeTableFromTextFileV2"
|
||||
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
|
||||
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "InitializeTableV2"
|
||||
|
@ -14,7 +14,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\', \'value_index_offset\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\', \'0\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "initialize"
|
||||
|
@ -1950,11 +1950,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "InitializeTableFromTextFile"
|
||||
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
|
||||
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "InitializeTableFromTextFileV2"
|
||||
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
|
||||
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "InitializeTableV2"
|
||||
|
Loading…
Reference in New Issue
Block a user