Internal change

PiperOrigin-RevId: 357124297
Change-Id: I5aa7c2d89379643a767053802589320eb3476cd6

Parent: 2ad0a33259
Commit: 0a501caa91
@@ -72,6 +72,7 @@ class GenerateVocabRemappingOp : public OpKernel {
                        kUnusedLookupDelim,
                        -1,  // key_index, use the line number.
                        -2,  // value_index, use the whole line/token.
+                       0,   // No offset.
                        context->env(), new_vocab_table));
     OP_REQUIRES(context,
                 new_vocab_offset_ + num_new_vocab_ <= new_vocab_table->size(),
@@ -101,6 +102,7 @@ class GenerateVocabRemappingOp : public OpKernel {
                        old_vocab_filename, old_vocab_size_, kUnusedLookupDelim,
                        -2,  // key_index, use the whole line/token.
                        -1,  // value_index, use the line number.
+                       0,   // No offset.
                        context->env(), old_vocab_table));

    // Fill out new_ids = [new_vocab_offset, new_vocab_offset + 1, ...,
@@ -105,6 +105,7 @@ class InitializeTableFromTextFileOp : public OpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size", &vocab_size_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index", &key_index_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index", &value_index_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("offset", &offset_));
     string delimiter;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter", &delimiter));
     OP_REQUIRES(ctx, delimiter.size() == 1,
@@ -141,7 +142,7 @@ class InitializeTableFromTextFileOp : public OpKernel {
     }
     OP_REQUIRES_OK(ctx, lookup::InitializeTableFromTextFile(
                             vocab_filename, vocab_size_, delimiter_, key_index_,
-                            value_index_, ctx->env(), table));
+                            value_index_, offset_, ctx->env(), table));
     if (ctx->track_allocations()) {
       ctx->record_persistent_memory_allocation(table->MemoryUsed() -
                                                memory_used_before);
@@ -154,6 +155,7 @@ class InitializeTableFromTextFileOp : public OpKernel {
   char delimiter_;
   int64 key_index_;
   int64 value_index_;
+  int64 offset_;

   TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp);
 };
@@ -77,7 +77,7 @@ class TextFileLineIterator
   // delimiter.
   Status Init(const string& filename, int64 vocab_size, char delimiter,
               DataType key_dtype, int64 key_index, DataType value_dtype,
-              int64 value_index, Env* env) {
+              int64 value_index, int64 offset, Env* env) {
     filename_ = filename;
     vocab_size_ = vocab_size;
     delimiter_ = delimiter;
@@ -93,6 +93,7 @@ class TextFileLineIterator
     input_buffer_.reset(new io::InputBuffer(file_.get(), kInputBufferSize));
     valid_ = true;
     next_id_ = 0;
+    offset_ = offset;
     ignore_split_ = std::max(key_index_, value_index_) < 0;
     Next();
     return status_;
@@ -143,6 +144,7 @@ class TextFileLineIterator
         return;
       }
     }
+
     status_ = SetValue(line, tokens, key_index_, &key_);
     if (!status_.ok()) {
       valid_ = false;
@@ -186,6 +188,7 @@ class TextFileLineIterator
   int64 value_index_;
   Env* env_;
   int64 next_id_;
+  int64 offset_;
   int64 vocab_size_;
   string filename_;
   char delimiter_;
@@ -199,7 +202,7 @@ class TextFileLineIterator
   Status SetValue(const string& line, const std::vector<string>& tokens,
                   int64 index, Tensor* tensor) {
     if (index == kLineNumber) {
-      tensor->flat<int64>()(0) = next_id_;
+      tensor->flat<int64>()(0) = next_id_ + offset_;
       return Status::OK();
     }
     const string& token = (index == kWholeLine) ? line : tokens[index];
@@ -212,7 +215,7 @@ class TextFileLineIterator
           return errors::InvalidArgument("Field ", token, " in line ", next_id_,
                                          " is not a valid int32.");
         }
-        tensor->flat<int32>()(0) = value;
+        tensor->flat<int32>()(0) = value + offset_;
       } break;
       case DT_INT64: {
         int64 value;
@@ -352,7 +355,7 @@ Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,
 // Helper function to initialize an InitializableLookupTable from a text file.
 Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
                                    char delimiter, int32 key_index,
-                                   int32 value_index, Env* env,
+                                   int32 value_index, int64 offset, Env* env,
                                    InitializableLookupTable* table) {
   if (key_index == kLineNumber && table->key_dtype() != DT_INT64) {
     return errors::InvalidArgument(
@@ -380,7 +383,8 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,

   TextFileLineIterator iter;
   TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter, key_dtype,
-                               key_index, value_dtype, value_index, env));
+                               key_index, value_dtype, value_index, offset,
+                               env));
   // For initialization from files, ignore if the table is already
   // initialized. The table shared name should contain the filename to
   // avoid trying to initialize the same table from the same file at the same
@@ -53,7 +53,7 @@ Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,

 Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
                                    char delimiter, int32 key_index,
-                                   int32 value_index, Env* env,
+                                   int32 value_index, int64 offset, Env* env,
                                    InitializableLookupTable* table);

 // Initializes `table` from `dataset` by iterating over it. Caller retains
@@ -480,6 +480,7 @@ REGISTER_OP("InitializeTableFromTextFile")
     .Attr("value_index: int >= -2")
     .Attr("vocab_size: int >= -1 = -1")
     .Attr("delimiter: string = '\t'")
+    .Attr("offset: int = 0")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle handle;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
@@ -497,6 +498,7 @@ REGISTER_OP("InitializeTableFromTextFileV2")
     .Attr("value_index: int >= -2")
     .Attr("vocab_size: int >= -1 = -1")
     .Attr("delimiter: string = '\t'")
+    .Attr("offset: int = 0")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle handle;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
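Taken together, the kernel and op-registration hunks above thread a new `offset` attribute from `InitializeTableFromTextFile(V2)` down to `TextFileLineIterator`, which adds it to every value derived from a line number (and to parsed integer values). A minimal sketch of that semantics, outside TensorFlow; the helper name `build_vocab_map` is illustrative only:

```python
# Emulates the effect of the new `offset` attr when value_index is
# LINE_NUMBER (-1): each token maps to its line number plus the offset,
# leaving ids 0 .. offset-1 free for special values such as mask/OOV.
def build_vocab_map(lines, offset=0):
    return {token: line_num + offset for line_num, token in enumerate(lines)}

vocab = ["earth", "wind", "and", "fire"]
assert build_vocab_map(vocab) == {"earth": 0, "wind": 1, "and": 2, "fire": 3}
assert build_vocab_map(vocab, offset=2)["earth"] == 2  # ids now start at 2
```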
@@ -802,8 +802,14 @@ class TrackableWeightHandler(object):

     # TODO(b/141682913): Figure out why this is private and fix it.
     saveables = trackable._gather_saveables_for_checkpoint().values()  # pylint: disable=protected-access
-    if len(saveables) != 1:
-      raise ValueError('Only Trackables with one Saveable are supported.')
-    saveable = list(saveables)[0]
+    # 'Saveables' won't exist when we're passed a legacy TF1 table like
+    # a StaticHashTable.
+    if not saveables:
+      self._num_tensors = 0
+      self._setter = lambda weights: None
+      self._getter = lambda: []
+
+    elif len(saveables) == 1:
+      saveable = list(saveables)[0]

     if ops.executing_eagerly_outside_functions():
@@ -828,9 +834,14 @@ class TrackableWeightHandler(object):
         tensor = spec.tensor
         self._placeholder_tensors.append(
             array_ops.placeholder(tensor.dtype, tensor.shape))
-      self._assign_op = self._saveable.restore(self._placeholder_tensors, None)
+      self._assign_op = self._saveable.restore(self._placeholder_tensors,
+                                               None)
       self._setter = self._set_weights_v1
       self._getter = lambda: [spec.tensor for spec in self._saveable.specs]
+    else:
+      raise ValueError('Only Trackables with one Saveable are supported. '
+                       'The Trackable %s has %d Saveables.' %
+                       (trackable, len(saveables)))

   @property
   def num_tensors(self):
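The `TrackableWeightHandler` hunks replace the old single-Saveable assertion with a three-way branch: zero saveables (e.g. a legacy TF1 `StaticHashTable`) become a no-op handler, one saveable is wrapped as before, and anything else raises. A hedged standalone sketch of that control flow, with the mode-dependent setter details elided (`make_handlers` is illustrative, not TF API):

```python
# Illustrative only; mirrors the branching this hunk introduces.
def make_handlers(saveables, trackable):
    if not saveables:  # legacy TF1 tables expose no saveables
        return 0, (lambda weights: None), (lambda: [])
    elif len(saveables) == 1:
        saveable = list(saveables)[0]
        num_tensors = len(saveable.specs)
        getter = lambda: [spec.tensor for spec in saveable.specs]
        return num_tensors, None, getter  # real setter is eager/graph specific
    else:
        raise ValueError('Only Trackables with one Saveable are supported. '
                         'The Trackable %s has %d Saveables.' %
                         (trackable, len(saveables)))
```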
@@ -80,6 +80,21 @@ tf_py_test(
     ],
 )

+tf_py_test(
+    name = "index_lookup_forward_benchmark",
+    srcs = ["index_lookup_forward_benchmark.py"],
+    python_version = "PY3",
+    deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:extra_py_tests_deps",
+        "//tensorflow/python:platform_benchmark",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python/compat:v2_compat",
+        "//tensorflow/python/keras/layers/preprocessing:index_lookup",
+    ],
+)
+
 tf_py_test(
     name = "normalization_adapt_benchmark",
     srcs = ["normalization_adapt_benchmark.py"],
@@ -0,0 +1,146 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for Keras text vectorization preprocessing layer's adapt method."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import random
+import string
+import time
+
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras.layers.preprocessing import index_lookup
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.platform import benchmark
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+
+v2_compat.enable_v2_behavior()
+
+
+# word_gen creates random sequences of ASCII letters (both lowercase and upper).
+# The number of unique strings is ~2,700.
+def tensor_gen(batch, num_elements):
+  data = []
+  for _ in range(batch):
+    batch_element = []
+    for _ in range(num_elements - 1):
+      tok = "".join(random.choice(string.ascii_letters) for i in range(2))
+      batch_element.append(tok)
+    batch_element.append("")  # Explicitly test the empty string.
+    data.append(batch_element)
+  return constant_op.constant(data)
+
+
+def get_vocab():
+  vocab = list(
+      set([a + b for a in string.ascii_letters for b in string.ascii_letters]))  # pylint:disable=g-complex-comprehension
+  vocab.sort()
+  return vocab
+
+
+# This class uses TestCase for get_temp_dir().
+class BenchmarkLookup(benchmark.TensorFlowBenchmark):
+  """Benchmark the index lookup layer's forward pass."""
+
+  def _write_to_temp_file(self, file_name, vocab_list):
+    vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt")
+    with gfile.GFile(vocab_path, "w") as writer:
+      for vocab in vocab_list:
+        writer.write(vocab + "\n")
+      writer.flush()
+      writer.close()
+    return vocab_path
+
+  def run_numpy_implementation(self, data, vocab):
+    """Test the python implementation."""
+    input_t = keras.Input(shape=(), dtype=dtypes.string)
+    layer = index_lookup.IndexLookup(
+        vocabulary=vocab,
+        max_tokens=None,
+        num_oov_indices=1,
+        mask_token="",
+        oov_token="OOV",
+        dtype=dtypes.string)
+    out_t = layer(input_t)
+    model = keras.Model(input_t, out_t)
+    num_repeats = 5
+    starts = []
+    ends = []
+    _ = model(data)
+    for _ in range(num_repeats):
+      starts.append(time.time())
+      out = model(data)
+      ends.append(time.time())
+    avg_time = np.mean(np.array(ends) - np.array(starts))
+    return avg_time, out
+
+  def bm_adapt_implementation(self, num_elements, batch_size):
+    """Test the KPL adapt implementation."""
+    vocab = get_vocab()
+    vocab_file = self._write_to_temp_file("vocab", vocab)
+    vocabulary_initializer = lookup_ops.TextFileInitializer(
+        filename=vocab_file,
+        key_dtype=dtypes.string,
+        key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
+        value_dtype=dtypes.int64,
+        value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
+        value_index_offset=2)
+    input_t = keras.Input(shape=(), dtype=dtypes.string)
+    layer = index_lookup.IndexLookup(
+        vocabulary=vocabulary_initializer,
+        max_tokens=None,
+        num_oov_indices=1,
+        mask_token="",
+        oov_token="OOV",
+        dtype=dtypes.string)
+    out_t = layer(input_t)
+    model = keras.Model(input_t, out_t)
+    num_repeats = 5
+    starts = []
+    ends = []
+    data = tensor_gen(batch_size, num_elements)
+    _ = model(data)
+    for _ in range(num_repeats):
+      starts.append(time.time())
+      _ = model(data)
+      ends.append(time.time())
+    avg_time = np.mean(np.array(ends) - np.array(starts))
+    baseline, _ = self.run_numpy_implementation(data, vocab)
+    extras = {
+        "numpy implementation baseline": baseline,
+        "delta seconds": (baseline - avg_time),
+        "delta percent": ((baseline - avg_time) / baseline) * 100
+    }
+    name = "index_lookup_forward|%s_elements|batch_%s" % (num_elements,
+                                                          batch_size)
+    self.report_benchmark(
+        iters=num_repeats, wall_time=avg_time, extras=extras, name=name)
+
+  def benchmark_vocab_size_by_batch(self):
+    for tensor_size in [100, 1000, 10000]:
+      for batch in [1, 16, 2048]:
+        self.bm_adapt_implementation(tensor_size, batch)
+
+
+if __name__ == "__main__":
+  test.main()
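Note the benchmark passes `value_index_offset=2` so that file-derived ids line up with what `IndexLookup` produces for an in-memory vocabulary: id 0 is reserved for the mask token `""` and id 1 for the single OOV index. The arithmetic, matching the `get_static_table` helper added in the tests further below:

```python
mask_token = ""      # occupies id 0
num_oov_indices = 1  # occupies id 1
offset = (0 if mask_token is None else 1) + num_oov_indices
assert offset == 2   # hence value_index_offset=2 above
```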
@@ -184,18 +184,6 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
       self._value_dtype = dtypes.int64
       oov_value = self._oov_value

-    self._table = lookup_ops.MutableHashTable(
-        key_dtype=self._key_dtype,
-        value_dtype=self._value_dtype,
-        default_value=oov_value,
-        name=(self._name + "_index_table"))
-    tracked_table = self._add_trackable(self._table, trainable=False)
-    # This is a workaround for summary() on this layer. Because the table is
-    # not mutable during training, the effective number of parameters (and so
-    # the weight shape) is 0; we add this as an attr so that the parameter
-    # counting code in the Model object doesn't throw an attribute error.
-    tracked_table.shape = tensor_shape.TensorShape((0,))
-
     if self.num_oov_indices <= 1:
       oov_indices = None
     else:
@@ -203,11 +191,28 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
       oov_end = oov_start + num_oov_indices
       oov_indices = list(range(oov_start, oov_end))

+    if vocabulary is not None and isinstance(vocabulary,
+                                             lookup_ops.TextFileInitializer):
+      self._table = self._static_table_class()(
+          vocabulary, default_value=oov_value)
+      self._table_handler = table_utils.TableHandler(
+          table=self._table,
+          mask_token=mask_token,
+          oov_tokens=oov_indices,
+          use_v1_apis=self._use_v1_apis())
+      self.max_tokens = (
+          self._table_handler.vocab_size() + self.num_oov_indices +
+          (0 if mask_token is None else 1))
+    else:
+      self._table = lookup_ops.MutableHashTable(
+          key_dtype=self._key_dtype,
+          value_dtype=self._value_dtype,
+          default_value=oov_value,
+          name=(self._name + "_index_table"))
       self._table_handler = table_utils.TableHandler(
           table=self._table,
           oov_tokens=oov_indices,
           use_v1_apis=self._use_v1_apis())

     if vocabulary is not None:
       self.set_vocabulary(vocabulary)
@@ -232,6 +237,13 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
           dtype=K.floatx(),
           initializer=initializer)

+    tracked_table = self._add_trackable(self._table, trainable=False)
+    # This is a workaround for summary() on this layer. Because the table is
+    # not mutable during training, the effective number of parameters (and so
+    # the weight shape) is 0; we add this as an attr so that the parameter
+    # counting code in the Model object doesn't throw an attribute error.
+    tracked_table.shape = tensor_shape.TensorShape((0,))
+
   def compute_output_shape(self, input_shape):
     if self.output_mode != INT:
       return tensor_shape.TensorShape([input_shape[0], self.max_tokens])
@@ -538,6 +550,9 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
   def _use_v1_apis(self):
     return False

+  def _static_table_class(self):
+    return lookup_ops.StaticHashTable
+

 class _IndexLookupAccumulator(
     collections.namedtuple("Accumulator",
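For file-backed vocabularies the layer now fixes `max_tokens` up front from the static table size plus the reserved slots. A small sketch of that arithmetic (the function name is illustrative):

```python
def compute_max_tokens(table_vocab_size, num_oov_indices, mask_token):
    # Table entries, plus OOV slots, plus one mask slot if a mask token is set.
    return (table_vocab_size + num_oov_indices +
            (0 if mask_token is None else 1))

assert compute_max_tokens(4, 1, "") == 6    # 4 words + 1 OOV + 1 mask
assert compute_max_tokens(4, 1, None) == 5  # no mask slot reserved
```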
@@ -41,7 +41,9 @@ from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.keras.saving import save
 from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
+from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test


@@ -703,6 +705,15 @@ class CategoricalEncodingAdaptTest(
 class IndexLookupOutputTest(keras_parameterized.TestCase,
                             preprocessing_test_utils.PreprocessingLayerTest):

+  def _write_to_temp_file(self, file_name, vocab_list):
+    vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt")
+    with gfile.GFile(vocab_path, "w") as writer:
+      for vocab in vocab_list:
+        writer.write(vocab + "\n")
+      writer.flush()
+      writer.close()
+    return vocab_path
+
   def test_int_output(self):
     vocab_data = ["earth", "wind", "and", "fire"]
     input_array = np.array([["earth", "wind", "and", "fire"],
@@ -958,7 +969,60 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
     layer_output = layer(input_data)
     self.assertAllEqual(layer_output.shape.as_list(), [16, 2])

+  def test_int_output_file_vocab(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+    input_array = np.array([["earth", "wind", "and", "fire"],
+                            ["fire", "", "earth", "michigan"]])
+    expected_output = [[2, 3, 4, 5], [5, 0, 2, 1]]
+
+    vocab_file = self._write_to_temp_file("temp", vocab_data)
+    vocabulary_initializer = lookup_ops.TextFileInitializer(
+        filename=vocab_file,
+        key_dtype=dtypes.string,
+        key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
+        value_dtype=dtypes.int64,
+        value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
+        value_index_offset=2)
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+    layer = get_layer_class()(
+        vocabulary=vocabulary_initializer,
+        max_tokens=None,
+        num_oov_indices=1,
+        mask_token="",
+        oov_token="[OOV]",
+        dtype=dtypes.string)
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+    output_dataset = model.predict(input_array)
+    self.assertAllEqual(expected_output, output_dataset)
+
+  def test_int_output_int_file_vocab(self):
+    vocab_data = ["10", "20", "30", "40"]
+    input_array = np.array([[10, 20, 30, 40], [40, 0, 10, 42]])
+    expected_output = [[2, 3, 4, 5], [5, 0, 2, 1]]
+
+    vocab_file = self._write_to_temp_file("temp", vocab_data)
+    vocabulary_initializer = lookup_ops.TextFileInitializer(
+        filename=vocab_file,
+        key_dtype=dtypes.int64,
+        key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
+        value_dtype=dtypes.int64,
+        value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
+        value_index_offset=2)
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.int64)
+    layer = get_layer_class()(
+        vocabulary=vocabulary_initializer,
+        max_tokens=None,
+        num_oov_indices=1,
+        mask_token=0,
+        oov_token=-1,
+        dtype=dtypes.int64)
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+    output_dataset = model.predict(input_array)
+    self.assertAllEqual(expected_output, output_dataset)
+

 @keras_parameterized.run_all_keras_modes
 class IndexLookupVocabularyTest(keras_parameterized.TestCase,
                                 preprocessing_test_utils.PreprocessingLayerTest
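The expected outputs in `test_int_output_file_vocab` follow directly from the slot layout: mask `""` at 0, the single OOV index at 1, and file entries starting at the offset of 2. A worked check of the second input row:

```python
# Assumed id layout behind expected_output = [[2, 3, 4, 5], [5, 0, 2, 1]].
ids = {"": 0, "[OOV]": 1, "earth": 2, "wind": 3, "and": 4, "fire": 5}
row = [ids.get(tok, 1) for tok in ["fire", "", "earth", "michigan"]]
assert row == [5, 0, 2, 1]  # "michigan" is OOV -> 1, "" is masked -> 0
```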
@@ -21,6 +21,7 @@ from __future__ import print_function

 from tensorflow.python.keras.engine import base_preprocessing_layer_v1
 from tensorflow.python.keras.layers.preprocessing import index_lookup
+from tensorflow.python.ops import lookup_ops


 class IndexLookup(index_lookup.IndexLookup,
@@ -58,3 +59,6 @@ class IndexLookup(index_lookup.IndexLookup,

   def _use_v1_apis(self):
     return True
+
+  def _static_table_class(self):
+    return lookup_ops.StaticHashTableV1
@@ -27,6 +27,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.ragged import ragged_functional_ops
@@ -38,8 +39,24 @@ from tensorflow.python.platform import gfile
 class TableHandler(object):
   """Wrapper object that holds a lookup table and provides accessors."""

-  def __init__(self, table, oov_tokens=None, use_v1_apis=False):
+  def __init__(self,
+               table,
+               oov_tokens=None,
+               mask_token=None,
+               use_v1_apis=False):
     self.table = table
+
+    # If we are using V1 APIs, and the table has an initializer, we need to run
+    # it. However, not all tables have initializers, so we try-except here.
+    if use_v1_apis:
+      try:
+        K.get_session().run(self.table.initializer)
+      except AttributeError:
+        pass
+
+    self.mutable = isinstance(table, lookup_ops.MutableHashTable)
+    self.mask_token = mask_token
+
     self.use_v1_apis = use_v1_apis
     if oov_tokens is None:
       self.oov_tokens = oov_tokens
@@ -56,10 +73,17 @@ class TableHandler(object):
     return self._eval(self.table.size())

   def clear(self):
+    if not self.mutable:
+      raise RuntimeError("Unable to clear a statically-backed table.")
+
     keys, _ = self.table.export()
     self._run(self.table.remove(keys))

   def insert(self, keys, values):
+    """Insert values into the backed table."""
+    if not self.mutable:
+      raise RuntimeError("Unable to insert into a statically-backed table.")
+
     if len(values) != len(keys):
       raise RuntimeError("Size mismatch between values and key arrays. "
                          "Keys had size %s, values had size %s." %
@@ -90,12 +114,35 @@ class TableHandler(object):

     return array_ops.where(oov_locations, oov_values, lookups)

+  def _lookup_and_mask(self, inputs):
+    """Return a lookup with any location with the mask_token masked to 0."""
+    lookups = self.table.lookup(inputs)
+    # If we don't need to handle masking, return the lookup values directly.
+    if self.mask_token is None:
+      return lookups
+
+    # If we do need to handle masking, increment all the lookup values by 1
+    # to account for the mask value at location 0. This also increments the
+    # OOV value, so replace that. (This is inefficient, but we can't adjust
+    # the table safely, so we don't have a choice.)
+    oov_locations = math_ops.equal(lookups, self.table._default_value)  # pylint: disable=protected-access
+    oov_values = array_ops.ones_like(
+        lookups, dtype=self.table._value_dtype) * self.table._default_value  # pylint: disable=protected-access
+    adjusted_lookups = array_ops.where(oov_locations, oov_values, lookups)
+
+    # Inject 0s wherever the mask token was in the inputs.
+    mask_locations = math_ops.equal(inputs, self.mask_token)
+    return array_ops.where(
+        mask_locations,
+        array_ops.zeros_like(lookups, dtype=self.table._value_dtype),  # pylint: disable=protected-access
+        adjusted_lookups)
+
   def _ragged_lookup(self, inputs):
     """Perform a table lookup on a ragged tensor."""
     # The table lookup ops don't natively support ragged tensors, so if we have
     # a RT we need to use map_flat_values to look up every element.
     indexed_data = ragged_functional_ops.map_flat_values(
-        self.table.lookup, inputs)
+        self._lookup_and_mask, inputs)
     indexed_data = ragged_functional_ops.map_flat_values(
         self._replace_oov_buckets, inputs, indexed_data)
     # table.lookup is not shape-preserving, so we need to set the shape here.
@@ -107,7 +154,7 @@ class TableHandler(object):

   def _sparse_lookup(self, inputs):
     """Perform a table lookup on a sparse tensor."""
-    values = self.table.lookup(inputs.values)
+    values = self._lookup_and_mask(inputs.values)
     values = self._replace_oov_buckets(inputs.values, values)
     indexed_data = sparse_tensor.SparseTensor(inputs.indices, values,
                                               inputs.dense_shape)
@@ -118,7 +165,7 @@ class TableHandler(object):

   def _tensor_lookup(self, inputs):
     """Perform a table lookup on a tf.tensor."""
-    values = self.table.lookup(inputs)
+    values = self._lookup_and_mask(inputs)
     indexed_data = self._replace_oov_buckets(inputs, values)
     # (b/149446477): output does not preserve input shape.
     indexed_data.set_shape(inputs.shape)
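The net effect of the new `_lookup_and_mask` is that positions holding the mask token come out as 0 while everything else keeps its table value (OOV inputs keep the table's default). A hedged numpy model of that behavior; `vocab_to_id` and the -7 default are illustrative:

```python
import numpy as np

def lookup_and_mask(inputs, vocab_to_id, mask_token, default_value=-7):
    # Table lookup with a default for misses, then zero out mask positions.
    lookups = np.array([vocab_to_id.get(x, default_value) for x in inputs])
    mask_locations = np.array([x == mask_token for x in inputs])
    return np.where(mask_locations, 0, lookups)

vocab_to_id = {"earth": 2, "wind": 3}  # ids already offset past 0 and 1
out = lookup_and_mask(["earth", "", "mars"], vocab_to_id, mask_token="")
assert out.tolist() == [2, 0, -7]      # mask -> 0, OOV -> default value
```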
@@ -45,6 +45,41 @@ def get_table(dtype=dtypes.string, oov_tokens=None):
       table, oov_tokens, use_v1_apis=(not context.executing_eagerly()))


+def get_static_table(tmpdir,
+                     vocab_list,
+                     mask_token=None,
+                     dtype=dtypes.string,
+                     oov_tokens=None):
+  vocabulary_file = os.path.join(tmpdir, "tmp_vocab.txt")
+
+  if dtype == dtypes.string:
+    with open(vocabulary_file, "w") as f:
+      f.write("\n".join(vocab_list) + "\n")
+  else:
+    with open(vocabulary_file, "w") as f:
+      f.write("\n".join([str(v) for v in vocab_list]) + "\n")
+
+  offset = ((0 if mask_token is None else 1) +
+            (len(oov_tokens) if oov_tokens is not None else 0))
+  init = lookup_ops.TextFileInitializer(
+      vocabulary_file,
+      dtype,
+      lookup_ops.TextFileIndex.WHOLE_LINE,
+      dtypes.int64,
+      lookup_ops.TextFileIndex.LINE_NUMBER,
+      value_index_offset=offset)
+  if context.executing_eagerly():
+    table = lookup_ops.StaticHashTable(init, default_value=-7)
+  else:
+    table = lookup_ops.StaticHashTableV1(init, default_value=-7)
+
+  return table_utils.TableHandler(
+      table,
+      oov_tokens,
+      mask_token=mask_token,
+      use_v1_apis=(not context.executing_eagerly()))
+
+
 @keras_parameterized.run_all_keras_modes
 class CategoricalEncodingInputTest(
     keras_parameterized.TestCase,
@@ -252,6 +287,132 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
     self.assertAllEqual(expected_output, output_data)


+@keras_parameterized.run_all_keras_modes
+class StaticIndexLookupOutputTest(
+    keras_parameterized.TestCase,
+    preprocessing_test_utils.PreprocessingLayerTest):
+
+  def test_int_output_default_lookup_value(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+    input_array = np.array([["earth", "wind", "and", "fire"],
+                            ["fire", "and", "earth", "michigan"]])
+    expected_output = [[1, 2, 3, 4], [4, 3, 1, -7]]
+
+    table = get_static_table(
+        tmpdir=self.get_temp_dir(),
+        vocab_list=vocab_data,
+        mask_token="",
+        oov_tokens=None)
+    output_data = table.lookup(input_array)
+
+    self.assertAllEqual(expected_output, output_data)
+
+  def test_output_shape(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+    input_array = np.array([["earth", "wind", "and", "fire"],
+                            ["fire", "and", "earth", "michigan"]])
+
+    table = get_static_table(
+        tmpdir=self.get_temp_dir(), vocab_list=vocab_data, oov_tokens=None)
+    output_data = table.lookup(input_array)
+
+    self.assertAllEqual(input_array.shape[1:], output_data.shape[1:])
+
+  def test_int_output_no_reserved_zero_default_lookup_value(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+    input_array = np.array([["earth", "wind", "and", "fire"],
+                            ["fire", "and", "earth", "michigan"]])
+    expected_output = [[0, 1, 2, 3], [3, 2, 0, -7]]
+
+    table = get_static_table(
+        tmpdir=self.get_temp_dir(), vocab_list=vocab_data, oov_tokens=None)
+    output_data = table.lookup(input_array)
+
+    self.assertAllEqual(expected_output, output_data)
+
+
+@keras_parameterized.run_all_keras_modes
+class CategoricalEncodingStaticInputTest(
+    keras_parameterized.TestCase,
+    preprocessing_test_utils.PreprocessingLayerTest):
+
+  def test_sparse_string_input(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+    input_array = sparse_tensor.SparseTensor(
+        indices=[[0, 0], [1, 2]],
+        values=["fire", "michigan"],
+        dense_shape=[3, 4])
+
+    expected_indices = [[0, 0], [1, 2]]
+    expected_values = [5, 1]
+    expected_dense_shape = [3, 4]
+
+    table = get_static_table(
+        tmpdir=self.get_temp_dir(),
+        vocab_list=vocab_data,
+        mask_token="",
+        oov_tokens=[1])
+    output_data = table.lookup(input_array)
+
+    self.assertAllEqual(expected_indices, output_data.indices)
+    self.assertAllEqual(expected_values, output_data.values)
+    self.assertAllEqual(expected_dense_shape, output_data.dense_shape)
+
+  def test_sparse_int_input(self):
+    vocab_data = np.array([10, 11, 12, 13], dtype=np.int64)
+    input_array = sparse_tensor.SparseTensor(
+        indices=[[0, 0], [1, 2]],
+        values=np.array([13, 32], dtype=np.int64),
+        dense_shape=[3, 4])
+
+    expected_indices = [[0, 0], [1, 2]]
+    expected_values = [5, 1]
+    expected_dense_shape = [3, 4]
+
+    table = get_static_table(
+        tmpdir=self.get_temp_dir(),
+        vocab_list=vocab_data,
+        dtype=dtypes.int64,
+        mask_token=0,
+        oov_tokens=[1])
+    output_data = table.lookup(input_array)
+
+    self.assertAllEqual(expected_indices, output_data.indices)
+    self.assertAllEqual(expected_values, output_data.values)
+    self.assertAllEqual(expected_dense_shape, output_data.dense_shape)
+
+  def test_ragged_string_input(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+    input_array = ragged_factory_ops.constant(
+        [["earth", "wind", "fire"], ["fire", "and", "earth", "michigan"]])
+    expected_output = [[2, 3, 5], [5, 4, 2, 1]]
+
+    table = get_static_table(
+        tmpdir=self.get_temp_dir(),
+        vocab_list=vocab_data,
+        mask_token="",
+        oov_tokens=[1])
+    output_data = table.lookup(input_array)
+
+    self.assertAllEqual(expected_output, output_data)
+
+  def test_ragged_int_input(self):
+    vocab_data = np.array([10, 11, 12, 13], dtype=np.int64)
+    input_array = ragged_factory_ops.constant([[10, 11, 13], [13, 12, 10, 42]],
+                                              dtype=np.int64)
+    expected_output = [[2, 3, 5], [5, 4, 2, 1]]
+
+    table = get_static_table(
+        tmpdir=self.get_temp_dir(),
+        vocab_list=vocab_data,
+        dtype=dtypes.int64,
+        mask_token=0,
+        oov_tokens=[1])
+    output_data = table.lookup(input_array)
+
+    self.assertAllEqual(expected_output, output_data)
+
+
 class GetVocabularyFromFileTest(test.TestCase):

   def setUp(self):
@@ -647,7 +647,8 @@ class TextFileInitializer(TableInitializerBase):
                value_index,
                vocab_size=None,
                delimiter="\t",
-               name=None):
+               name=None,
+               value_index_offset=0):
     """Constructs a table initializer object to populate from a text file.

     It generates one key-value pair per line. The type of table key and
@@ -675,6 +676,13 @@ class TextFileInitializer(TableInitializerBase):
       vocab_size: The number of elements in the file, if known.
       delimiter: The delimiter to separate fields in a line.
       name: A name for the operation (optional).
+      value_index_offset: A number to add to all indices extracted from the
+        file. This is useful for cases where a user would like to reserve one
+        or more low index values for control characters. For instance, if you
+        would like to ensure that no vocabulary item is mapped to index 0 (so
+        you can reserve 0 for a masking value), you can set value_index_offset
+        to 1; this will mean that the first vocabulary element is mapped to 1
+        instead of 0.

     Raises:
       ValueError: when the filename is empty, or when the table key and value
@@ -718,6 +726,7 @@ class TextFileInitializer(TableInitializerBase):
     self._name = name
     self._filename = self._track_trackable(
         trackable.Asset(filename), "_filename")
+    self._offset = value_index_offset

     super(TextFileInitializer, self).__init__(key_dtype, value_dtype)

@@ -740,7 +749,8 @@ class TextFileInitializer(TableInitializerBase):
         self._filename, dtypes.string, name="asset_filepath")
     init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
         table.resource_handle, filename, self._key_index, self._value_index,
-        -1 if self._vocab_size is None else self._vocab_size, self._delimiter)
+        -1 if self._vocab_size is None else self._vocab_size, self._delimiter,
+        self._offset)
     ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
     # If the filename tensor is anything other than a string constant (e.g.,
     # if it is a placeholder) then it does not make sense to track it as an
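End-to-end, the new `value_index_offset` argument is used the way the tests in this commit use it: build a static table whose first vocabulary entry maps to 2, leaving ids 0 and 1 free for mask and OOV. A usage sketch (the file name is a placeholder):

```python
import tensorflow as tf

# vocab.txt contains one token per line: earth, wind, and, fire.
init = tf.lookup.TextFileInitializer(
    "vocab.txt",
    tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
    tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER,
    value_index_offset=2)
table = tf.lookup.StaticHashTable(init, default_value=-1)
# table.lookup(tf.constant(["earth", "fire"])) evaluates to [2, 5].
```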
@@ -14,7 +14,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], "
+    argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\', \'value_index_offset\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\', \'0\'], "
   }
   member_method {
     name: "initialize"
@@ -1950,11 +1950,11 @@ tf_module {
   }
   member_method {
     name: "InitializeTableFromTextFile"
-    argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
+    argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
   }
   member_method {
     name: "InitializeTableFromTextFileV2"
-    argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
+    argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
   }
   member_method {
     name: "InitializeTableV2"
@@ -14,7 +14,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\'], "
+    argspec: "args=[\'self\', \'filename\', \'key_dtype\', \'key_index\', \'value_dtype\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\', \'value_index_offset\'], varargs=None, keywords=None, defaults=[\'None\', \'\\t\', \'None\', \'0\'], "
   }
   member_method {
     name: "initialize"
@@ -1950,11 +1950,11 @@ tf_module {
   }
   member_method {
     name: "InitializeTableFromTextFile"
-    argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
+    argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
   }
   member_method {
     name: "InitializeTableFromTextFileV2"
-    argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
+    argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'offset\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'0\', \'None\'], "
   }
   member_method {
     name: "InitializeTableV2"