diff --git a/tensorflow/core/kernels/generate_vocab_remapping_op.cc b/tensorflow/core/kernels/generate_vocab_remapping_op.cc index d4cf83896f5..e60abc45acb 100644 --- a/tensorflow/core/kernels/generate_vocab_remapping_op.cc +++ b/tensorflow/core/kernels/generate_vocab_remapping_op.cc @@ -72,6 +72,7 @@ class GenerateVocabRemappingOp : public OpKernel { kUnusedLookupDelim, -1, // key_index, use the line number. -2, // value_index, use the whole line/token. + 0, // No offset. context->env(), new_vocab_table)); OP_REQUIRES(context, new_vocab_offset_ + num_new_vocab_ <= new_vocab_table->size(), @@ -101,6 +102,7 @@ class GenerateVocabRemappingOp : public OpKernel { old_vocab_filename, old_vocab_size_, kUnusedLookupDelim, -2, // key_index, use the whole line/token. -1, // value_index, use the line number. + 0, // No offset. context->env(), old_vocab_table)); // Fill out new_ids = [new_vocab_offset, new_vocab_offset + 1, ..., diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc index cb757ac930b..d21ac547db2 100644 --- a/tensorflow/core/kernels/lookup_table_init_op.cc +++ b/tensorflow/core/kernels/lookup_table_init_op.cc @@ -105,6 +105,7 @@ class InitializeTableFromTextFileOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size", &vocab_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index", &key_index_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index", &value_index_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("offset", &offset_)); string delimiter; OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter", &delimiter)); OP_REQUIRES(ctx, delimiter.size() == 1, @@ -141,7 +142,7 @@ class InitializeTableFromTextFileOp : public OpKernel { } OP_REQUIRES_OK(ctx, lookup::InitializeTableFromTextFile( vocab_filename, vocab_size_, delimiter_, key_index_, - value_index_, ctx->env(), table)); + value_index_, offset_, ctx->env(), table)); if (ctx->track_allocations()) { ctx->record_persistent_memory_allocation(table->MemoryUsed() - memory_used_before); @@ -154,6 +155,7 @@ class InitializeTableFromTextFileOp : public OpKernel { char delimiter_; int64 key_index_; int64 value_index_; + int64 offset_; TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp); }; diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc index d07b525a6bd..aa39063b2dc 100644 --- a/tensorflow/core/kernels/lookup_util.cc +++ b/tensorflow/core/kernels/lookup_util.cc @@ -77,7 +77,7 @@ class TextFileLineIterator // delimiter. 
Status Init(const string& filename, int64 vocab_size, char delimiter, DataType key_dtype, int64 key_index, DataType value_dtype, - int64 value_index, Env* env) { + int64 value_index, int64 offset, Env* env) { filename_ = filename; vocab_size_ = vocab_size; delimiter_ = delimiter; @@ -93,6 +93,7 @@ class TextFileLineIterator input_buffer_.reset(new io::InputBuffer(file_.get(), kInputBufferSize)); valid_ = true; next_id_ = 0; + offset_ = offset; ignore_split_ = std::max(key_index_, value_index_) < 0; Next(); return status_; @@ -143,6 +144,7 @@ class TextFileLineIterator return; } } + status_ = SetValue(line, tokens, key_index_, &key_); if (!status_.ok()) { valid_ = false; @@ -186,6 +188,7 @@ class TextFileLineIterator int64 value_index_; Env* env_; int64 next_id_; + int64 offset_; int64 vocab_size_; string filename_; char delimiter_; @@ -199,7 +202,7 @@ class TextFileLineIterator Status SetValue(const string& line, const std::vector<string>& tokens, int64 index, Tensor* tensor) { if (index == kLineNumber) { - tensor->flat<int64>()(0) = next_id_; + tensor->flat<int64>()(0) = next_id_ + offset_; return Status::OK(); } const string& token = (index == kWholeLine) ? line : tokens[index]; @@ -212,7 +215,7 @@ class TextFileLineIterator return errors::InvalidArgument("Field ", token, " in line ", next_id_, " is not a valid int32."); } - tensor->flat<int32>()(0) = value; + tensor->flat<int32>()(0) = value + offset_; } break; case DT_INT64: { int64 value; @@ -352,7 +355,7 @@ Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype, // Helper function to initialize an InitializableLookupTable from a text file. Status InitializeTableFromTextFile(const string& filename, int64 vocab_size, char delimiter, int32 key_index, - int32 value_index, Env* env, + int32 value_index, int64 offset, Env* env, InitializableLookupTable* table) { if (key_index == kLineNumber && table->key_dtype() != DT_INT64) { return errors::InvalidArgument( @@ -380,7 +383,8 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size, TextFileLineIterator iter; TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter, key_dtype, - key_index, value_dtype, value_index, env)); + key_index, value_dtype, value_index, offset, + env)); // For initialization from files, ignore if the table is already // initialized. The table shared name should contain the filename to // avoid trying to initialize the same table from the same file at the same diff --git a/tensorflow/core/kernels/lookup_util.h b/tensorflow/core/kernels/lookup_util.h index 7e53ed5db51..26974ab2acb 100644 --- a/tensorflow/core/kernels/lookup_util.h +++ b/tensorflow/core/kernels/lookup_util.h @@ -53,7 +53,7 @@ Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype, Status InitializeTableFromTextFile(const string& filename, int64 vocab_size, char delimiter, int32 key_index, - int32 value_index, Env* env, + int32 value_index, int64 offset, Env* env, InitializableLookupTable* table); // Initializes `table` from `dataset` by iterating over it.
Caller retains diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc index 05aa229336d..f9aea527846 100644 --- a/tensorflow/core/ops/lookup_ops.cc +++ b/tensorflow/core/ops/lookup_ops.cc @@ -480,6 +480,7 @@ REGISTER_OP("InitializeTableFromTextFile") .Attr("value_index: int >= -2") .Attr("vocab_size: int >= -1 = -1") .Attr("delimiter: string = '\t'") + .Attr("offset: int = 0") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); @@ -497,6 +498,7 @@ REGISTER_OP("InitializeTableFromTextFileV2") .Attr("value_index: int >= -2") .Attr("vocab_size: int >= -1 = -1") .Attr("delimiter: string = '\t'") + .Attr("offset: int = 0") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index a098cdbd729..c9c04994940 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -802,35 +802,46 @@ class TrackableWeightHandler(object): # TODO(b/141682913): Figure out why this is private and fix it. saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access - if len(saveables) != 1: - raise ValueError('Only Trackables with one Saveable are supported.') - saveable = list(saveables)[0] + # 'Saveables' won't exist when we're passed a legacy TF1 table like + # a StaticHashTable. + if not saveables: + self._num_tensors = 0 + self._setter = lambda weights: None + self._getter = lambda: [] - if ops.executing_eagerly_outside_functions(): - # If we're in eager mode, we need to defer calling the Trackable's - # saveable() callable until data export time. - # However, it is safe to call the saveable as many times as we want, so - # we will call it now to figure out how many tensors this Trackable will - # produce. - self._saveable = saveable - self._num_tensors = len(self._saveable().specs) - self._setter = lambda weights: self._saveable().restore(weights, None) - self._getter = lambda: [spec.tensor for spec in self._saveable().specs] + elif len(saveables) == 1: + saveable = list(saveables)[0] + + if ops.executing_eagerly_outside_functions(): + # If we're in eager mode, we need to defer calling the Trackable's + # saveable() callable until data export time. + # However, it is safe to call the saveable as many times as we want, so + # we will call it now to figure out how many tensors this Trackable will + # produce. + self._saveable = saveable + self._num_tensors = len(self._saveable().specs) + self._setter = lambda weights: self._saveable().restore(weights, None) + self._getter = lambda: [spec.tensor for spec in self._saveable().specs] + else: + # If we're in Graph mode, we need to evaluate the Saveable only once and + # cache the resulting restore graph. Failing to do this will result in + # new assignment ops being added to the graph each time set_weights() is + # called. 
+ self._placeholder_tensors = [] + self._saveable = saveable() + self._num_tensors = len(self._saveable.specs) + for spec in self._saveable.specs: + tensor = spec.tensor + self._placeholder_tensors.append( + array_ops.placeholder(tensor.dtype, tensor.shape)) + self._assign_op = self._saveable.restore(self._placeholder_tensors, + None) + self._setter = self._set_weights_v1 + self._getter = lambda: [spec.tensor for spec in self._saveable.specs] else: - # If we're in Graph mode, we need to evaluate the Saveable only once and - # cache the resulting restore graph. Failing to do this will result in - # new assignment ops being added to the graph each time set_weights() is - # called. - self._placeholder_tensors = [] - self._saveable = saveable() - self._num_tensors = len(self._saveable.specs) - for spec in self._saveable.specs: - tensor = spec.tensor - self._placeholder_tensors.append( - array_ops.placeholder(tensor.dtype, tensor.shape)) - self._assign_op = self._saveable.restore(self._placeholder_tensors, None) - self._setter = self._set_weights_v1 - self._getter = lambda: [spec.tensor for spec in self._saveable.specs] + raise ValueError('Only Trackables with one Saveable are supported. ' + 'The Trackable %s has %d Saveables.' % + (trackable, len(saveables))) @property def num_tensors(self): diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD index f46c06dd366..11776603ff4 100644 --- a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD @@ -80,6 +80,21 @@ tf_py_test( ], ) +tf_py_test( + name = "index_lookup_forward_benchmark", + srcs = ["index_lookup_forward_benchmark.py"], + python_version = "PY3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:platform_benchmark", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/keras/layers/preprocessing:index_lookup", + ], +) + tf_py_test( name = "normalization_adapt_benchmark", srcs = ["normalization_adapt_benchmark.py"], diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py new file mode 100644 index 00000000000..0e264fb49b8 --- /dev/null +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py @@ -0,0 +1,146 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Benchmark for the Keras index lookup preprocessing layer's forward pass.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import random +import string +import time + +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.compat import v2_compat +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.keras.layers.preprocessing import index_lookup +from tensorflow.python.ops import lookup_ops +from tensorflow.python.platform import benchmark +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test + +v2_compat.enable_v2_behavior() + + +# tensor_gen creates random sequences of two-character ASCII tokens (lowercase +# and uppercase letters). The number of unique strings is ~2,700. +def tensor_gen(batch, num_elements): + data = [] + for _ in range(batch): + batch_element = [] + for _ in range(num_elements - 1): + tok = "".join(random.choice(string.ascii_letters) for i in range(2)) + batch_element.append(tok) + batch_element.append("") # Explicitly test the empty string. + data.append(batch_element) + return constant_op.constant(data) + + +def get_vocab(): + vocab = list( + set([a + b for a in string.ascii_letters for b in string.ascii_letters])) # pylint:disable=g-complex-comprehension + vocab.sort() + return vocab + + +# This class uses TestCase for get_temp_dir(). +class BenchmarkLookup(benchmark.TensorFlowBenchmark): + """Benchmark the index lookup layer's forward pass.""" + + def _write_to_temp_file(self, file_name, vocab_list): + vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt") + with gfile.GFile(vocab_path, "w") as writer: + for vocab in vocab_list: + writer.write(vocab + "\n") + writer.flush() + writer.close() + return vocab_path + + def run_numpy_implementation(self, data, vocab): + """Run the layer with an in-memory vocabulary list as the baseline.""" + input_t = keras.Input(shape=(), dtype=dtypes.string) + layer = index_lookup.IndexLookup( + vocabulary=vocab, + max_tokens=None, + num_oov_indices=1, + mask_token="", + oov_token="OOV", + dtype=dtypes.string) + out_t = layer(input_t) + model = keras.Model(input_t, out_t) + num_repeats = 5 + starts = [] + ends = [] + _ = model(data) + for _ in range(num_repeats): + starts.append(time.time()) + out = model(data) + ends.append(time.time()) + avg_time = np.mean(np.array(ends) - np.array(starts)) + return avg_time, out + + def bm_adapt_implementation(self, num_elements, batch_size): + """Benchmark the layer's forward pass with a file-backed vocabulary.""" + vocab = get_vocab() + vocab_file = self._write_to_temp_file("vocab", vocab) + vocabulary_initializer = lookup_ops.TextFileInitializer( + filename=vocab_file, + key_dtype=dtypes.string, + key_index=lookup_ops.TextFileIndex.WHOLE_LINE, + value_dtype=dtypes.int64, + value_index=lookup_ops.TextFileIndex.LINE_NUMBER, + value_index_offset=2) + input_t = keras.Input(shape=(), dtype=dtypes.string) + layer = index_lookup.IndexLookup( + vocabulary=vocabulary_initializer, + max_tokens=None, + num_oov_indices=1, + mask_token="", + oov_token="OOV", + dtype=dtypes.string) + out_t = layer(input_t) + model = keras.Model(input_t, out_t) + num_repeats = 5 + starts = [] + ends = [] + data = tensor_gen(batch_size, num_elements) + _ = model(data) + for _ in range(num_repeats): + starts.append(time.time()) + _ = model(data) + ends.append(time.time()) + avg_time =
np.mean(np.array(ends) - np.array(starts)) + baseline, _ = self.run_numpy_implementation(data, vocab) + extras = { + "numpy implementation baseline": baseline, + "delta seconds": (baseline - avg_time), + "delta percent": ((baseline - avg_time) / baseline) * 100 + } + name = "index_lookup_forward|%s_elements|batch_%s" % (num_elements, + batch_size) + self.report_benchmark( + iters=num_repeats, wall_time=avg_time, extras=extras, name=name) + + def benchmark_vocab_size_by_batch(self): + for tensor_size in [100, 1000, 10000]: + for batch in [1, 16, 2048]: + self.bm_adapt_implementation(tensor_size, batch) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup.py b/tensorflow/python/keras/layers/preprocessing/index_lookup.py index f985769070b..46c37db190b 100644 --- a/tensorflow/python/keras/layers/preprocessing/index_lookup.py +++ b/tensorflow/python/keras/layers/preprocessing/index_lookup.py @@ -184,18 +184,6 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): self._value_dtype = dtypes.int64 oov_value = self._oov_value - self._table = lookup_ops.MutableHashTable( - key_dtype=self._key_dtype, - value_dtype=self._value_dtype, - default_value=oov_value, - name=(self._name + "_index_table")) - tracked_table = self._add_trackable(self._table, trainable=False) - # This is a workaround for summary() on this layer. Because the table is - # not mutable during training, the effective number of parameters (and so - # the weight shape) is 0; we add this as an attr so that the parameter - # counting code in the Model object doesn't throw an attribute error. - tracked_table.shape = tensor_shape.TensorShape((0,)) - if self.num_oov_indices <= 1: oov_indices = None else: @@ -203,13 +191,30 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): oov_end = oov_start + num_oov_indices oov_indices = list(range(oov_start, oov_end)) - self._table_handler = table_utils.TableHandler( - table=self._table, - oov_tokens=oov_indices, - use_v1_apis=self._use_v1_apis()) - - if vocabulary is not None: - self.set_vocabulary(vocabulary) + if vocabulary is not None and isinstance(vocabulary, + lookup_ops.TextFileInitializer): + self._table = self._static_table_class()( + vocabulary, default_value=oov_value) + self._table_handler = table_utils.TableHandler( + table=self._table, + mask_token=mask_token, + oov_tokens=oov_indices, + use_v1_apis=self._use_v1_apis()) + self.max_tokens = ( + self._table_handler.vocab_size() + self.num_oov_indices + + (0 if mask_token is None else 1)) + else: + self._table = lookup_ops.MutableHashTable( + key_dtype=self._key_dtype, + value_dtype=self._value_dtype, + default_value=oov_value, + name=(self._name + "_index_table")) + self._table_handler = table_utils.TableHandler( + table=self._table, + oov_tokens=oov_indices, + use_v1_apis=self._use_v1_apis()) + if vocabulary is not None: + self.set_vocabulary(vocabulary) if self.output_mode == TFIDF: # The TF-IDF weight may have a (None,) tensorshape. This creates @@ -232,6 +237,13 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): dtype=K.floatx(), initializer=initializer) + tracked_table = self._add_trackable(self._table, trainable=False) + # This is a workaround for summary() on this layer. 
Because the table is + # not mutable during training, the effective number of parameters (and so + # the weight shape) is 0; we add this as an attr so that the parameter + # counting code in the Model object doesn't throw an attribute error. + tracked_table.shape = tensor_shape.TensorShape((0,)) + def compute_output_shape(self, input_shape): if self.output_mode != INT: return tensor_shape.TensorShape([input_shape[0], self.max_tokens]) @@ -538,6 +550,9 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): def _use_v1_apis(self): return False + def _static_table_class(self): + return lookup_ops.StaticHashTable + class _IndexLookupAccumulator( collections.namedtuple("Accumulator", diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py index b845dd4cf28..eb609634e2a 100644 --- a/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py +++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py @@ -41,7 +41,9 @@ from tensorflow.python.keras.layers.preprocessing import index_lookup_v1 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.keras.saving import save from tensorflow.python.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -703,6 +705,15 @@ class CategoricalEncodingAdaptTest( class IndexLookupOutputTest(keras_parameterized.TestCase, preprocessing_test_utils.PreprocessingLayerTest): + def _write_to_temp_file(self, file_name, vocab_list): + vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt") + with gfile.GFile(vocab_path, "w") as writer: + for vocab in vocab_list: + writer.write(vocab + "\n") + writer.flush() + writer.close() + return vocab_path + def test_int_output(self): vocab_data = ["earth", "wind", "and", "fire"] input_array = np.array([["earth", "wind", "and", "fire"], @@ -958,7 +969,60 @@ class IndexLookupOutputTest(keras_parameterized.TestCase, layer_output = layer(input_data) self.assertAllEqual(layer_output.shape.as_list(), [16, 2]) + def test_int_output_file_vocab(self): + vocab_data = ["earth", "wind", "and", "fire"] + input_array = np.array([["earth", "wind", "and", "fire"], + ["fire", "", "earth", "michigan"]]) + expected_output = [[2, 3, 4, 5], [5, 0, 2, 1]] + vocab_file = self._write_to_temp_file("temp", vocab_data) + vocabulary_initializer = lookup_ops.TextFileInitializer( + filename=vocab_file, + key_dtype=dtypes.string, + key_index=lookup_ops.TextFileIndex.WHOLE_LINE, + value_dtype=dtypes.int64, + value_index=lookup_ops.TextFileIndex.LINE_NUMBER, + value_index_offset=2) + + input_data = keras.Input(shape=(None,), dtype=dtypes.string) + layer = get_layer_class()( + vocabulary=vocabulary_initializer, + max_tokens=None, + num_oov_indices=1, + mask_token="", + oov_token="[OOV]", + dtype=dtypes.string) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + + def test_int_output_int_file_vocab(self): + vocab_data = ["10", "20", "30", "40"] + input_array = np.array([[10, 20, 30, 40], [40, 0, 10, 42]]) + expected_output = [[2, 3, 4, 5], [5, 0, 2, 1]] + + vocab_file = self._write_to_temp_file("temp", vocab_data) + vocabulary_initializer = 
lookup_ops.TextFileInitializer( + filename=vocab_file, + key_dtype=dtypes.int64, + key_index=lookup_ops.TextFileIndex.WHOLE_LINE, + value_dtype=dtypes.int64, + value_index=lookup_ops.TextFileIndex.LINE_NUMBER, + value_index_offset=2) + + input_data = keras.Input(shape=(None,), dtype=dtypes.int64) + layer = get_layer_class()( + vocabulary=vocabulary_initializer, + max_tokens=None, + num_oov_indices=1, + mask_token=0, + oov_token=-1, + dtype=dtypes.int64) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) @keras_parameterized.run_all_keras_modes class IndexLookupVocabularyTest(keras_parameterized.TestCase, preprocessing_test_utils.PreprocessingLayerTest diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py index 47fea11dd57..c710108dd5b 100644 --- a/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py +++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.keras.engine import base_preprocessing_layer_v1 from tensorflow.python.keras.layers.preprocessing import index_lookup +from tensorflow.python.ops import lookup_ops class IndexLookup(index_lookup.IndexLookup, @@ -58,3 +59,6 @@ class IndexLookup(index_lookup.IndexLookup, def _use_v1_apis(self): return True + + def _static_table_class(self): + return lookup_ops.StaticHashTableV1 diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils.py b/tensorflow/python/keras/layers/preprocessing/table_utils.py index 56b07fb4270..e7fe9177174 100644 --- a/tensorflow/python/keras/layers/preprocessing/table_utils.py +++ b/tensorflow/python/keras/layers/preprocessing/table_utils.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import backend as K from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops.ragged import ragged_functional_ops @@ -38,8 +39,24 @@ from tensorflow.python.platform import gfile class TableHandler(object): """Wrapper object that holds a lookup table and provides accessors.""" - def __init__(self, table, oov_tokens=None, use_v1_apis=False): + def __init__(self, + table, + oov_tokens=None, + mask_token=None, + use_v1_apis=False): self.table = table + + # If we are using V1 APIs, and the table has an initializer, we need to run + # it. However, not all tables have initializers, so we try-except here. 
+ if use_v1_apis: + try: + K.get_session().run(self.table.initializer) + except AttributeError: + pass + + self.mutable = isinstance(table, lookup_ops.MutableHashTable) + self.mask_token = mask_token + self.use_v1_apis = use_v1_apis if oov_tokens is None: self.oov_tokens = oov_tokens @@ -56,10 +73,17 @@ class TableHandler(object): return self._eval(self.table.size()) def clear(self): + if not self.mutable: + raise RuntimeError("Unable to clear a statically-backed table.") + keys, _ = self.table.export() self._run(self.table.remove(keys)) def insert(self, keys, values): + """Insert values into the backing table.""" + if not self.mutable: + raise RuntimeError("Unable to insert into a statically-backed table.") + if len(values) != len(keys): raise RuntimeError("Size mismatch between values and key arrays. " "Keys had size %s, values had size %s." % @@ -90,12 +114,35 @@ class TableHandler(object): return array_ops.where(oov_locations, oov_values, lookups) + def _lookup_and_mask(self, inputs): + """Return lookups, with any location holding the mask_token mapped to 0.""" + lookups = self.table.lookup(inputs) + # If we don't need to handle masking, return the lookup values directly. + if self.mask_token is None: + return lookups + + # The table's values were initialized with an index offset, which is what + # reserves location 0 for the mask token, so successful lookups need no + # shifting here. Normalize any locations that missed the table to the + # OOV default value before injecting the mask index. (This is inefficient, + # but we can't adjust the table safely, so we don't have a choice.) + oov_locations = math_ops.equal(lookups, self.table._default_value) # pylint: disable=protected-access + oov_values = array_ops.ones_like( + lookups, dtype=self.table._value_dtype) * self.table._default_value # pylint: disable=protected-access + adjusted_lookups = array_ops.where(oov_locations, oov_values, lookups) + + # Inject 0s wherever the mask token was in the inputs. + mask_locations = math_ops.equal(inputs, self.mask_token) + return array_ops.where( + mask_locations, + array_ops.zeros_like(lookups, dtype=self.table._value_dtype), # pylint: disable=protected-access + adjusted_lookups) # pylint: disable=protected-access + def _ragged_lookup(self, inputs): """Perform a table lookup on a ragged tensor.""" # The table lookup ops don't natively support ragged tensors, so if we have # a RT we need to use map_flat_values to look up every element. indexed_data = ragged_functional_ops.map_flat_values( - self.table.lookup, inputs) + self._lookup_and_mask, inputs) indexed_data = ragged_functional_ops.map_flat_values( self._replace_oov_buckets, inputs, indexed_data) # table.lookup is not shape-preserving, so we need to set the shape here. @@ -107,7 +154,7 @@ class TableHandler(object): def _sparse_lookup(self, inputs): """Perform a table lookup on a sparse tensor.""" - values = self.table.lookup(inputs.values) + values = self._lookup_and_mask(inputs.values) values = self._replace_oov_buckets(inputs.values, values) indexed_data = sparse_tensor.SparseTensor(inputs.indices, values, inputs.dense_shape) @@ -118,7 +165,7 @@ class TableHandler(object): def _tensor_lookup(self, inputs): """Perform a table lookup on a tf.tensor.""" - values = self.table.lookup(inputs) + values = self._lookup_and_mask(inputs) indexed_data = self._replace_oov_buckets(inputs, values) # (b/149446477): output does not preserve input shape.
indexed_data.set_shape(inputs.shape) diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils_test.py b/tensorflow/python/keras/layers/preprocessing/table_utils_test.py index 05b18d1924a..d23eb975dc1 100644 --- a/tensorflow/python/keras/layers/preprocessing/table_utils_test.py +++ b/tensorflow/python/keras/layers/preprocessing/table_utils_test.py @@ -45,6 +45,41 @@ def get_table(dtype=dtypes.string, oov_tokens=None): table, oov_tokens, use_v1_apis=(not context.executing_eagerly())) +def get_static_table(tmpdir, + vocab_list, + mask_token=None, + dtype=dtypes.string, + oov_tokens=None): + vocabulary_file = os.path.join(tmpdir, "tmp_vocab.txt") + + if dtype == dtypes.string: + with open(vocabulary_file, "w") as f: + f.write("\n".join(vocab_list) + "\n") + else: + with open(vocabulary_file, "w") as f: + f.write("\n".join([str(v) for v in vocab_list]) + "\n") + + offset = ((0 if mask_token is None else 1) + + (len(oov_tokens) if oov_tokens is not None else 0)) + init = lookup_ops.TextFileInitializer( + vocabulary_file, + dtype, + lookup_ops.TextFileIndex.WHOLE_LINE, + dtypes.int64, + lookup_ops.TextFileIndex.LINE_NUMBER, + value_index_offset=offset) + if context.executing_eagerly(): + table = lookup_ops.StaticHashTable(init, default_value=-7) + else: + table = lookup_ops.StaticHashTableV1(init, default_value=-7) + + return table_utils.TableHandler( + table, + oov_tokens, + mask_token=mask_token, + use_v1_apis=(not context.executing_eagerly())) + + @keras_parameterized.run_all_keras_modes class CategoricalEncodingInputTest( keras_parameterized.TestCase, @@ -252,6 +287,132 @@ class IndexLookupOutputTest(keras_parameterized.TestCase, self.assertAllEqual(expected_output, output_data) +@keras_parameterized.run_all_keras_modes +class StaticIndexLookupOutputTest( + keras_parameterized.TestCase, + preprocessing_test_utils.PreprocessingLayerTest): + + def test_int_output_default_lookup_value(self): + vocab_data = ["earth", "wind", "and", "fire"] + input_array = np.array([["earth", "wind", "and", "fire"], + ["fire", "and", "earth", "michigan"]]) + expected_output = [[1, 2, 3, 4], [4, 3, 1, -7]] + + table = get_static_table( + tmpdir=self.get_temp_dir(), + vocab_list=vocab_data, + mask_token="", + oov_tokens=None) + output_data = table.lookup(input_array) + + self.assertAllEqual(expected_output, output_data) + + def test_output_shape(self): + vocab_data = ["earth", "wind", "and", "fire"] + input_array = np.array([["earth", "wind", "and", "fire"], + ["fire", "and", "earth", "michigan"]]) + + table = get_static_table( + tmpdir=self.get_temp_dir(), vocab_list=vocab_data, oov_tokens=None) + output_data = table.lookup(input_array) + + self.assertAllEqual(input_array.shape[1:], output_data.shape[1:]) + + def test_int_output_no_reserved_zero_default_lookup_value(self): + vocab_data = ["earth", "wind", "and", "fire"] + input_array = np.array([["earth", "wind", "and", "fire"], + ["fire", "and", "earth", "michigan"]]) + expected_output = [[0, 1, 2, 3], [3, 2, 0, -7]] + + table = get_static_table( + tmpdir=self.get_temp_dir(), vocab_list=vocab_data, oov_tokens=None) + output_data = table.lookup(input_array) + + self.assertAllEqual(expected_output, output_data) + + +@keras_parameterized.run_all_keras_modes +class CategoricalEncodingStaticInputTest( + keras_parameterized.TestCase, + preprocessing_test_utils.PreprocessingLayerTest): + + def test_sparse_string_input(self): + vocab_data = ["earth", "wind", "and", "fire"] + input_array = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], + 
values=["fire", "michigan"], + dense_shape=[3, 4]) + + expected_indices = [[0, 0], [1, 2]] + expected_values = [5, 1] + expected_dense_shape = [3, 4] + + table = get_static_table( + tmpdir=self.get_temp_dir(), + vocab_list=vocab_data, + mask_token="", + oov_tokens=[1]) + output_data = table.lookup(input_array) + + self.assertAllEqual(expected_indices, output_data.indices) + self.assertAllEqual(expected_values, output_data.values) + self.assertAllEqual(expected_dense_shape, output_data.dense_shape) + + def test_sparse_int_input(self): + vocab_data = np.array([10, 11, 12, 13], dtype=np.int64) + input_array = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], + values=np.array([13, 32], dtype=np.int64), + dense_shape=[3, 4]) + + expected_indices = [[0, 0], [1, 2]] + expected_values = [5, 1] + expected_dense_shape = [3, 4] + + table = get_static_table( + tmpdir=self.get_temp_dir(), + vocab_list=vocab_data, + dtype=dtypes.int64, + mask_token=0, + oov_tokens=[1]) + output_data = table.lookup(input_array) + + self.assertAllEqual(expected_indices, output_data.indices) + self.assertAllEqual(expected_values, output_data.values) + self.assertAllEqual(expected_dense_shape, output_data.dense_shape) + + def test_ragged_string_input(self): + vocab_data = ["earth", "wind", "and", "fire"] + input_array = ragged_factory_ops.constant( + [["earth", "wind", "fire"], ["fire", "and", "earth", "michigan"]]) + expected_output = [[2, 3, 5], [5, 4, 2, 1]] + + table = get_static_table( + tmpdir=self.get_temp_dir(), + vocab_list=vocab_data, + mask_token="", + oov_tokens=[1]) + output_data = table.lookup(input_array) + + self.assertAllEqual(expected_output, output_data) + + def test_ragged_int_input(self): + vocab_data = np.array([10, 11, 12, 13], dtype=np.int64) + input_array = ragged_factory_ops.constant([[10, 11, 13], [13, 12, 10, 42]], + dtype=np.int64) + expected_output = [[2, 3, 5], [5, 4, 2, 1]] + + table = get_static_table( + tmpdir=self.get_temp_dir(), + vocab_list=vocab_data, + dtype=dtypes.int64, + mask_token=0, + oov_tokens=[1]) + output_data = table.lookup(input_array) + + self.assertAllEqual(expected_output, output_data) + + class GetVocabularyFromFileTest(test.TestCase): def setUp(self): diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index 63773ee0f95..c5b9eea85fb 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -647,7 +647,8 @@ class TextFileInitializer(TableInitializerBase): value_index, vocab_size=None, delimiter="\t", - name=None): + name=None, + value_index_offset=0): """Constructs a table initializer object to populate from a text file. It generates one key-value pair per line. The type of table key and @@ -675,6 +676,13 @@ class TextFileInitializer(TableInitializerBase): vocab_size: The number of elements in the file, if known. delimiter: The delimiter to separate fields in a line. name: A name for the operation (optional). + value_index_offset: A number to add to all indices extracted from the file. + This is useful for cases where a user would like to reserve one or more + low index values for control characters. For instance, if you would + like to ensure that no vocabulary item is mapped to index 0 (so you can + reserve 0 for a masking value), you can set value_index_offset to 1; + this will mean that the first vocabulary element is mapped to 1 + instead of 0.
Raises: ValueError: when the filename is empty, or when the table key and value @@ -718,6 +726,7 @@ class TextFileInitializer(TableInitializerBase): self._name = name self._filename = self._track_trackable( trackable.Asset(filename), "_filename") + self._offset = value_index_offset super(TextFileInitializer, self).__init__(key_dtype, value_dtype) @@ -740,7 +749,8 @@ class TextFileInitializer(TableInitializerBase): self._filename, dtypes.string, name="asset_filepath") init_op = gen_lookup_ops.initialize_table_from_text_file_v2( table.resource_handle, filename, self._key_index, self._value_index, - -1 if self._vocab_size is None else self._vocab_size, self._delimiter) + -1 if self._vocab_size is None else self._vocab_size, self._delimiter, + self._offset) ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) # If the filename tensor is anything other than a string constant (e.g., # if it is a placeholder) then it does not make sense to track it as an diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lookup.-text-file-initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lookup.-text-file-initializer.pbtxt index ff9a0ce6e7d..7c69b37421b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lookup.-text-file-initializer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lookup.-text-file-initializer.pbtxt @@ -14,7 +14,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], " + argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\', \'value_index_offset\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\', \'0\'], " } member_method { name: "initialize" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 1cdb121acc0..a9bdab234f3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1950,11 +1950,11 @@ tf_module { } member_method { name: "InitializeTableFromTextFile" - argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], " + argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], " } member_method { name: "InitializeTableFromTextFileV2" - argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], " + argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], " } member_method { name: "InitializeTableV2" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lookup.-text-file-initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lookup.-text-file-initializer.pbtxt index ff9a0ce6e7d..7c69b37421b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.lookup.-text-file-initializer.pbtxt +++ 
b/tensorflow/tools/api/golden/v2/tensorflow.lookup.-text-file-initializer.pbtxt @@ -14,7 +14,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], " + argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\', \'value_index_offset\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\', \'0\'], " } member_method { name: "initialize" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 1cdb121acc0..a9bdab234f3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1950,11 +1950,11 @@ tf_module { } member_method { name: "InitializeTableFromTextFile" - argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], " + argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], " } member_method { name: "InitializeTableFromTextFileV2" - argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], " + argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], " } member_method { name: "InitializeTableV2"
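
The end-to-end effect of the new offset plumbing: TextFileInitializer(value_index_offset=N) shifts every line-number value by N before it is written into the table, so low indices stay free for control tokens. A minimal sketch, assuming a TF build that includes this change; the vocabulary file path and contents are illustrative.

import tensorflow as tf

# Write a four-token vocabulary, one token per line (line numbers 0..3).
with open("vocab.txt", "w") as f:
  f.write("earth\nwind\nand\nfire\n")

# Shift all line-number values by 2, reserving index 0 for a mask token and
# index 1 for a single OOV bucket.
init = tf.lookup.TextFileInitializer(
    filename="vocab.txt",
    key_dtype=tf.string,
    key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
    value_dtype=tf.int64,
    value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
    value_index_offset=2)
table = tf.lookup.StaticHashTable(init, default_value=-1)

print(table.lookup(tf.constant(["earth", "fire", "michigan"])))
# [2, 5, -1]: lines 0..3 map to 2..5; misses fall through to default_value.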
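On the Keras side, the same initializer can now be passed directly to IndexLookup, which backs it with a StaticHashTable (StaticHashTableV1 under V1 APIs) instead of copying the vocabulary into a MutableHashTable; TableHandler then injects the mask and OOV indices at lookup time. A sketch mirroring test_int_output_file_vocab above; the imports are the internal paths the tests use, not a public API.

import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.keras.layers.preprocessing import index_lookup
from tensorflow.python.ops import lookup_ops

# "vocab.txt" as in the previous sketch: earth, wind, and, fire.
init = lookup_ops.TextFileInitializer(
    filename="vocab.txt",
    key_dtype=dtypes.string,
    key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
    value_dtype=dtypes.int64,
    value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
    value_index_offset=2)  # offset = 1 mask token + 1 OOV index

layer = index_lookup.IndexLookup(
    vocabulary=init,
    max_tokens=None,
    num_oov_indices=1,
    mask_token="",
    oov_token="[OOV]",
    dtype=dtypes.string)

inputs = keras.Input(shape=(None,), dtype=dtypes.string)
model = keras.Model(inputs, layer(inputs))
print(model.predict(np.array([["earth", "", "michigan"]])))
# [[2 0 1]]: the vocabulary starts at index 2, the mask token "" maps to 0,
# and the unknown token maps to the OOV index 1.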