Open-source skip-gram ops

PiperOrigin-RevId: 157655970
Wei Ho 2017-05-31 17:07:05 -07:00 committed by TensorFlower Gardener
parent faac0331c2
commit 458f94c128
12 changed files with 1377 additions and 0 deletions

View File: tensorflow/BUILD

@@ -281,6 +281,7 @@ filegroup(
"//tensorflow/contrib/tensor_forest/hybrid:all_files",
"//tensorflow/contrib/tensorboard:all_files",
"//tensorflow/contrib/testing:all_files",
"//tensorflow/contrib/text:all_files",
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
"//tensorflow/contrib/training:all_files",
"//tensorflow/contrib/util:all_files",

View File: tensorflow/contrib/BUILD

@@ -66,6 +66,7 @@ py_library(
"//tensorflow/contrib/tensor_forest:init_py",
"//tensorflow/contrib/tensorboard",
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/contrib/text:text_py",
"//tensorflow/contrib/tfprof",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
@@ -84,6 +85,7 @@ cc_library(
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
"//tensorflow/contrib/nccl:nccl_kernels",
"//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
"//tensorflow/contrib/text:all_kernels",
],
)
@@ -98,6 +100,7 @@ cc_library(
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
"//tensorflow/contrib/nccl:nccl_ops_op_lib",
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
"//tensorflow/contrib/text:all_ops",
],
)

View File: tensorflow/contrib/cmake/tf_core_kernels.cmake

@@ -82,6 +82,8 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc"
"${tensorflow_source_dir}/tensorflow/contrib/text/kernels/skip_gram_kernels.cc"
"${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc"
)
list(APPEND tf_core_kernels_srcs ${tf_contrib_kernels_srcs})
endif(tensorflow_BUILD_CONTRIB_KERNELS)

View File: tensorflow/contrib/cmake/tf_core_ops.cmake

@@ -81,6 +81,7 @@ GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contri
GENERATE_CONTRIB_OP_LIBRARY(seq2seq_beam_search "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_hybrid "${tensor_forest_hybrid_srcs}")
GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc")
########################################################

View File: tensorflow/contrib/cmake/tf_python.cmake

@@ -507,6 +507,11 @@ add_python_module("tensorflow/contrib/tensor_forest/python/ops")
add_python_module("tensorflow/contrib/testing")
add_python_module("tensorflow/contrib/testing/python")
add_python_module("tensorflow/contrib/testing/python/framework")
add_python_module("tensorflow/contrib/text")
add_python_module("tensorflow/contrib/text/kernels")
add_python_module("tensorflow/contrib/text/ops")
add_python_module("tensorflow/contrib/text/python")
add_python_module("tensorflow/contrib/text/python/ops")
add_python_module("tensorflow/contrib/tfprof" DONTCOPY) # SWIG wrapper not implemented.
#add_python_module("tensorflow/contrib/tfprof/python")
#add_python_module("tensorflow/contrib/tfprof/python/tools")
@@ -644,6 +649,8 @@ GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_tensor_forest_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_hybrid_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/hybrid/ops/gen_training_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py)
GENERATE_PYTHON_OP_LIB("stateless_random_ops"

View File: tensorflow/contrib/text/BUILD

@@ -0,0 +1,119 @@
# Description:
# Contains parts of TensorFlow that are experimental or unstable and which
# are not supported.
package(default_visibility = [
"//learning/brain:__subpackages__",
"//tensorflow:__subpackages__",
])
licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",
"tf_custom_op_py_library",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
"tf_kernel_library",
)
tf_custom_op_py_library(
name = "text_py",
srcs = [
"__init__.py",
"python/ops/__init__.py",
"python/ops/skip_gram_ops.py",
],
dso = [
":python/ops/_skip_gram_ops.so",
],
kernels = [
":all_kernels",
":all_ops",
],
srcs_version = "PY2AND3",
deps = [
":gen_skip_gram_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:ops",
"//tensorflow/python:platform",
"//tensorflow/python:training",
],
)
tf_kernel_library(
name = "skip_gram_kernels",
srcs = ["kernels/skip_gram_kernels.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
],
alwayslink = 1,
)
cc_library(
name = "all_kernels",
deps = [":skip_gram_kernels"],
)
tf_custom_op_library(
name = "python/ops/_skip_gram_ops.so",
srcs = [
"kernels/skip_gram_kernels.cc",
"ops/skip_gram_ops.cc",
],
)
tf_gen_op_libs(
op_lib_names = ["skip_gram_ops"],
)
cc_library(
name = "all_ops",
deps = [":skip_gram_ops_op_lib"],
)
tf_gen_op_wrapper_py(
name = "gen_skip_gram_ops",
out = "python/ops/gen_skip_gram_ops.py",
deps = [":skip_gram_ops_op_lib"],
)
py_test(
name = "skip_gram_ops_test",
size = "medium",
srcs = ["python/ops/skip_gram_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
":text_py",
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:lookup_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_seed",
"//tensorflow/python:training",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
)

View File: tensorflow/contrib/text/__init__.py

@@ -0,0 +1,30 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Text-processing ops.
@@skip_gram_sample
@@skip_gram_sample_with_text_vocab
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.text.python.ops import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

View File: tensorflow/contrib/text/kernels/skip_gram_kernels.cc

@@ -0,0 +1,139 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <string>
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/util/guarded_philox_random.h"
namespace tensorflow {
template <typename T>
class SkipGramGenerateCandidatesOp : public OpKernel {
public:
explicit SkipGramGenerateCandidatesOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, generator_.Init(context));
}
void Compute(OpKernelContext* context) override {
const Tensor* input_tensor;
OP_REQUIRES_OK(context, context->input("input_tensor", &input_tensor));
const auto input = input_tensor->flat<T>();
const Tensor* min_skips_tensor;
OP_REQUIRES_OK(context, context->input("min_skips", &min_skips_tensor));
const int min_skips = *(min_skips_tensor->scalar<int>().data());
const Tensor* max_skips_tensor;
OP_REQUIRES_OK(context, context->input("max_skips", &max_skips_tensor));
const int max_skips = *(max_skips_tensor->scalar<int>().data());
OP_REQUIRES(
context, min_skips >= 0 && max_skips >= 0,
errors::InvalidArgument("Both min_skips and max_skips must be >= 0."));
OP_REQUIRES(context, min_skips <= max_skips,
errors::InvalidArgument("min_skips must be <= max_skips."));
const Tensor* start_tensor;
OP_REQUIRES_OK(context, context->input("start", &start_tensor));
const int start = *(start_tensor->scalar<int>().data());
const Tensor* limit_tensor;
OP_REQUIRES_OK(context, context->input("limit", &limit_tensor));
const int limit = *(limit_tensor->scalar<int>().data());
const int end =
limit < 0 ? input.size()
: std::min(start + limit, static_cast<int>(input.size()));
const Tensor* emit_self_tensor;
OP_REQUIRES_OK(context,
context->input("emit_self_as_target", &emit_self_tensor));
const bool emit_self_as_target = *(emit_self_tensor->scalar<bool>().data());
std::vector<T> tokens;
std::vector<T> labels;
// Reserve the number of random numbers we will use - we use one for each
// token between start and end.
random::PhiloxRandom local_gen =
generator_.ReserveSamples32(end - start + 1);
random::SimplePhilox rng(&local_gen);
// For each token in the sentence, pick a random skip, then generate
// (token, label) pairs for all labels whose distances from the token are
// within the range [-skip, skip].
for (int i = start; i < end; ++i) {
const int skips = min_skips + rng.Uniform(max_skips - min_skips + 1);
for (int j = -skips; j <= skips; ++j) {
if ((i + j < start) || (i + j >= end) ||
(j == 0 && !emit_self_as_target)) {
continue;
}
tokens.push_back(input(i));
labels.push_back(input(i + j));
}
}
Tensor* tokens_output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(
"tokens", TensorShape({static_cast<int>(tokens.size())}),
&tokens_output));
Tensor* labels_output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(
"labels", TensorShape({static_cast<int>(labels.size())}),
&labels_output));
OP_REQUIRES(
context, tokens_output->IsSameSize(*labels_output),
errors::Internal(strings::StrCat(
"Mismatch between tokens_output shape of ",
tokens_output->shape().DebugString(),
" and labels_output shape of ",
labels_output->shape().DebugString(),
". This should never happen - contact ami-team@ if it does.")));
// Copies results to output tensors.
for (int i = 0; i < tokens.size(); ++i) {
tokens_output->vec<T>()(i) = tokens[i];
labels_output->vec<T>()(i) = labels[i];
}
}
private:
GuardedPhiloxRandom generator_;
};
#define REGISTER_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("SkipGramGenerateCandidates") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \
SkipGramGenerateCandidatesOp<type>)
REGISTER_KERNEL(string);
REGISTER_KERNEL(int64);
REGISTER_KERNEL(int32);
REGISTER_KERNEL(int16);
// TODO(weiho): Add other types if the need arises.
#undef REGISTER_KERNEL
} // namespace tensorflow
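The candidate-generation loop above is straightforward to mirror outside the kernel. As a rough illustrative sketch only (plain Python with the standard random module in place of the op's Philox generator, and ignoring the start/limit window handling), the same pairing logic looks like:

import random

def skip_gram_candidates(tokens, min_skips, max_skips, emit_self_as_target=False):
    # One random window size per token, inclusive of max_skips, mirroring
    # min_skips + rng.Uniform(max_skips - min_skips + 1) in the kernel.
    out_tokens, out_labels = [], []
    for i, token in enumerate(tokens):
        skips = random.randint(min_skips, max_skips)
        for j in range(-skips, skips + 1):
            if i + j < 0 or i + j >= len(tokens):
                continue  # neighbor falls outside the input
            if j == 0 and not emit_self_as_target:
                continue  # skip the token itself unless explicitly requested
            out_tokens.append(token)
            out_labels.append(tokens[i + j])
    return out_tokens, out_labels

print(skip_gram_candidates(["the", "quick", "brown", "fox"], 1, 2))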

View File: tensorflow/contrib/text/ops/skip_gram_ops.cc

@@ -0,0 +1,54 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
REGISTER_OP("SkipGramGenerateCandidates")
.Input("input_tensor: T")
.Input("min_skips: int32")
.Input("max_skips: int32")
.Input("start: int32")
.Input("limit: int32")
.Input("emit_self_as_target: bool")
.Output("tokens: T")
.Output("labels: T")
.Attr("T: type")
// The seed attributes are needed by GuardedPhiloxRandom
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// input_tensor must be of rank-1.
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
// All other args must be scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
// Due to possible randomness in selecting skips, we only know that the
// outputs will be of rank-1, but not their sizes.
c->set_output(0, c->Vector(c->UnknownDim()));
c->set_output(1, c->Vector(c->UnknownDim()));
return Status::OK();
})
.Doc(R"doc(
Generates skip-gram token and label paired Tensors from the input tensor.
See docs for the public-facing skip_gram_sample() Python op for more details.
)doc");
} // namespace tensorflow
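Because the shape function above only constrains the outputs to rank-1 vectors of unknown length, the Python-side tensors carry no static size. A minimal graph-construction check, assuming the contrib package from this change is importable as `tensorflow.contrib.text` under TF 1.x, might look like:

import tensorflow as tf
from tensorflow.contrib import text

tokens, labels = text.skip_gram_sample(
    tf.constant(["the", "quick", "brown", "fox"]), min_skips=1, max_skips=2)
# Only the rank is known when the graph is built; the length depends on the
# randomly chosen window per token and is unknown until the op runs.
print(tokens.shape)  # (?,)
print(labels.shape)  # (?,)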

View File: tensorflow/contrib/text/python/ops/__init__.py

@@ -0,0 +1,22 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various contrib ops related to text-processing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.text.python.ops.skip_gram_ops import skip_gram_sample
from tensorflow.contrib.text.python.ops.skip_gram_ops import skip_gram_sample_with_text_vocab

View File: tensorflow/contrib/text/python/ops/skip_gram_ops.py

@@ -0,0 +1,428 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Skip-gram sampling ops from https://arxiv.org/abs/1301.3781."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import csv
from tensorflow.contrib import lookup
from tensorflow.contrib.text.python.ops import gen_skip_gram_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader
from tensorflow.python.training import input as input_ops
_skip_gram_ops_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_skip_gram_ops.so"))
ops.NotDifferentiable("SkipGramGenerateCandidates")
def skip_gram_sample(input_tensor,
min_skips=1,
max_skips=5,
start=0,
limit=-1,
emit_self_as_target=False,
vocab_freq_table=None,
vocab_min_count=None,
vocab_subsampling=None,
corpus_size=None,
batch_size=None,
batch_capacity=None,
seed=None,
name=None):
"""Generates skip-gram token and label paired Tensors from the input tensor.
Generates skip-gram `("token", "label")` pairs using each element in the
rank-1 `input_tensor` as a token. The window size used for each token will be
randomly selected from the range specified by `[min_skips, max_skips]`,
inclusive. See https://arxiv.org/abs/1301.3781 for more details about
skip-gram.
For example, given `input_tensor = ["the", "quick", "brown", "fox", "jumps"]`,
`min_skips = 1`, `max_skips = 2`, `emit_self_as_target = False`, the output
`(tokens, labels)` pairs for the token "quick" will be randomly selected from
either `(tokens=["quick", "quick"], labels=["the", "brown"])` for 1 skip, or
`(tokens=["quick", "quick", "quick"], labels=["the", "brown", "fox"])` for 2
skips.
If `emit_self_as_target = True`, each token will also be emitted as a label
for itself. From the previous example, the output will be either
`(tokens=["quick", "quick", "quick"], labels=["the", "quick", "brown"])` for 1
skip, or `(tokens=["quick", "quick", "quick", "quick"], labels=["the",
"quick", "brown", "fox"])` for 2 skips.
The same process is repeated for each element of `input_tensor` and
concatenated together into the two output rank-1 `Tensors` (one for all the
tokens, another for all the labels).
If `vocab_freq_table` is specified, tokens in `input_tensor` that are not
present in the vocabulary are discarded. Tokens whose frequency counts are
below `vocab_min_count` are also discarded. Tokens whose frequency proportions
in the corpus exceed `vocab_subsampling` may be randomly down-sampled. See
Eq. 5 in http://arxiv.org/abs/1310.4546 for more details about subsampling.
Due to the random window sizes used for each token, the lengths of the outputs
are non-deterministic, unless `batch_size` is specified to batch the outputs
to always return `Tensors` of length `batch_size`.
Args:
input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates.
min_skips: `int` or scalar `Tensor` specifying the minimum window size to
randomly use for each token. Must be >= 0 and <= `max_skips`. If
`min_skips` and `max_skips` are both 0, the only label emitted will be
the token itself when `emit_self_as_target = True`; otherwise, no output
is emitted.
max_skips: `int` or scalar `Tensor` specifying the maximum window size to
randomly use for each token. Must be >= 0.
start: `int` or scalar `Tensor` specifying the position in
`input_tensor` from which to start generating skip-gram candidates.
limit: `int` or scalar `Tensor` specifying the maximum number of
elements in `input_tensor` to use in generating skip-gram candidates. -1
means to use the rest of the `Tensor` after `start`.
emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
each token as a label for itself.
vocab_freq_table: (Optional) A lookup table (subclass of
`lookup.InitializableLookupTableBase`) that maps tokens to their raw
frequency counts. If specified, any token in `input_tensor` that is not
found in `vocab_freq_table` will be filtered out before generating
skip-gram candidates. While this will typically map to integer raw
frequency counts, it could also map to float frequency proportions.
`vocab_min_count` and `corpus_size` should be in the same units as this.
vocab_min_count: (Optional) `int`, `float`, or scalar `Tensor` specifying
minimum frequency threshold (from `vocab_freq_table`) for a token to be
kept in `input_tensor`. If this is specified, `vocab_freq_table` must also
be specified, and they should both be in the same units.
vocab_subsampling: (Optional) `float` specifying frequency proportion
threshold for tokens from `input_tensor`. Tokens that occur more
frequently (based on the ratio of the token's `vocab_freq_table` value to
the `corpus_size`) will be randomly down-sampled. Reasonable starting
values may be around 1e-3 or 1e-5. If this is specified, both
`vocab_freq_table` and `corpus_size` must also be specified. See Eq. 5
in http://arxiv.org/abs/1310.4546 for more details.
corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
total number of tokens in the corpus (e.g., sum of all the frequency
counts of `vocab_freq_table`). Used with `vocab_subsampling` for
down-sampling frequently occurring tokens. If this is specified,
`vocab_freq_table` and `vocab_subsampling` must also be specified.
batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
batch_capacity: (Optional) `int` specifying batch capacity for the queue
used for batching returned `Tensors`. Only has an effect if
`batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
seed: (Optional) `int` used to create a random seed for window size and
subsampling. See `set_random_seed` docs for behavior.
name: (Optional) A `string` name or a name scope for the operations.
Returns:
A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
length `batch_size`; if `batch_size` is not specified, they will be of
random length, though they will be in sync with each other as long as they
are evaluated together.
Raises:
ValueError: If `vocab_freq_table` is not provided, but `vocab_min_count`,
`vocab_subsampling`, or `corpus_size` is specified. If `vocab_subsampling`
and `corpus_size` are not both present or both absent.
"""
if vocab_freq_table is None and (vocab_min_count is not None or
vocab_subsampling is not None or
corpus_size is not None):
raise ValueError(
"vocab_freq_table is not provided, but vocab_min_count={}, "
"vocab_subsampling={}, or corpus_size={} is not None. These settings "
"are useless without a vocab_freq_table.".format(
vocab_min_count, vocab_subsampling, corpus_size))
if (vocab_subsampling is None) != (corpus_size is None):
raise ValueError(
"vocab_subsampling is {} while corpus_size is {} - both must be "
"provided in order for subsampling to work.".format(
vocab_subsampling, corpus_size))
with ops.name_scope(
name,
"skip_gram_sample",
values=[input_tensor, min_skips, max_skips, start, limit]):
input_tensor = _filter_input(
input_tensor=input_tensor,
vocab_freq_table=vocab_freq_table,
vocab_min_count=vocab_min_count,
vocab_subsampling=vocab_subsampling,
corpus_size=corpus_size,
seed=seed)
seed1, seed2 = random_seed.get_seed(seed)
tokens, labels = gen_skip_gram_ops.skip_gram_generate_candidates(
input_tensor=input_tensor,
min_skips=min_skips,
max_skips=max_skips,
start=start,
limit=limit,
emit_self_as_target=emit_self_as_target,
# Note that seed here should be seed1! This is due to
# GuardedPhiloxRandom's hard-coded attributes of "seed" and "seed2".
seed=seed1,
seed2=seed2)
# TODO(weiho): If the need arises, add support for sparse input_tensor that
# figures out sentence boundaries, then calls
# skip_gram_generate_candidates() on each sentence.
# Batches the (tokens, labels) outputs so that they will be of deterministic
# batch_size, to facilitate feeding them into the rest of the network.
if batch_size is not None and batch_size > 0:
batch_capacity = (batch_capacity
if (batch_capacity is not None and batch_capacity > 0)
else 100 * batch_size)
return input_ops.batch(
[tokens, labels],
batch_size,
capacity=batch_capacity,
enqueue_many=True)
return tokens, labels
def skip_gram_sample_with_text_vocab(input_tensor,
vocab_freq_file,
vocab_token_index=0,
vocab_token_dtype=dtypes.string,
vocab_freq_index=1,
vocab_freq_dtype=dtypes.float64,
vocab_delimiter=",",
vocab_min_count=0,
vocab_subsampling=None,
min_skips=1,
max_skips=5,
start=0,
limit=-1,
emit_self_as_target=False,
batch_size=None,
batch_capacity=None,
seed=None,
name=None):
"""Skip-gram sampling with a text vocabulary file.
Wrapper around `skip_gram_sample()` for use with a text vocabulary file. The
vocabulary file is expected to be a plain-text file, with lines of
`vocab_delimiter`-separated columns. The `vocab_token_index` column should
contain the vocabulary term, while the `vocab_freq_index` column should
contain the number of times that term occurs in the corpus. For example, with
a text vocabulary file of:
```
bonjour,fr,42
hello,en,777
hola,es,99
```
You should set `vocab_delimiter=","`, `vocab_token_index=0`, and
`vocab_freq_index=2`.
See `skip_gram_sample()` documentation for more details about the skip-gram
sampling process.
Args:
input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates.
vocab_freq_file: `string` specifying full file path to the text vocab file.
vocab_token_index: `int` specifying which column in the text vocab file
contains the tokens.
vocab_token_dtype: `DType` specifying the format of the tokens in the text
vocab file.
vocab_freq_index: `int` specifying which column in the text vocab file
contains the frequency counts of the tokens.
vocab_freq_dtype: `DType` specifying the format of the frequency counts in
the text vocab file.
vocab_delimiter: `string` specifying the delimiter used in the text vocab
file.
vocab_min_count: `int`, `float`, or scalar `Tensor` specifying
minimum frequency threshold (from `vocab_freq_file`) for a token to be
kept in `input_tensor`. This should correspond with `vocab_freq_dtype`.
vocab_subsampling: (Optional) `float` specifying frequency proportion
threshold for tokens from `input_tensor`. Tokens that occur more
frequently will be randomly down-sampled. Reasonable starting values may
be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546 for
more details.
min_skips: `int` or scalar `Tensor` specifying the minimum window size to
randomly use for each token. Must be >= 0 and <= `max_skips`. If
`min_skips` and `max_skips` are both 0, the only label emitted will be
the token itself.
max_skips: `int` or scalar `Tensor` specifying the maximum window size to
randomly use for each token. Must be >= 0.
start: `int` or scalar `Tensor` specifying the position in `input_tensor`
from which to start generating skip-gram candidates.
limit: `int` or scalar `Tensor` specifying the maximum number of elements in
`input_tensor` to use in generating skip-gram candidates. -1 means to use
the rest of the `Tensor` after `start`.
emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
each token as a label for itself.
batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
batch_capacity: (Optional) `int` specifying batch capacity for the queue
used for batching returned `Tensors`. Only has an effect if
`batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
seed: (Optional) `int` used to create a random seed for window size and
subsampling. See
[`set_random_seed`](../../g3doc/python/constant_op.md#set_random_seed)
for behavior.
name: (Optional) A `string` name or a name scope for the operations.
Returns:
A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
length `batch_size`; if `batch_size` is not specified, they will be of
random length, though they will be in sync with each other as long as they
are evaluated together.
Raises:
ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0 or
exceeds the number of columns in `vocab_freq_file`. If `vocab_token_index`
and `vocab_freq_index` are both set to the same column. If any token in
`vocab_freq_file` has a negative frequency.
"""
if vocab_token_index < 0 or vocab_freq_index < 0:
raise ValueError(
"vocab_token_index={} and vocab_freq_index={} must both be >= 0.".
format(vocab_token_index, vocab_freq_index))
if vocab_token_index == vocab_freq_index:
raise ValueError(
"vocab_token_index and vocab_freq_index should be different, but are "
"both {}.".format(vocab_token_index))
# Iterates through the vocab file and calculates the number of vocab terms as
# well as the total corpus size (by summing the frequency counts of all the
# vocab terms).
corpus_size = 0.0
vocab_size = 0
with gfile.GFile(vocab_freq_file, mode="r") as f:
reader = csv.reader(f, delimiter=vocab_delimiter)
for row in reader:
if vocab_token_index >= len(row) or vocab_freq_index >= len(row):
raise ValueError(
"Row in vocab file only has {} columns, so vocab_token_index={} or "
"vocab_freq_index={} is out of bounds. Row content: {}".format(
len(row), vocab_token_index, vocab_freq_index, row))
vocab_size += 1
freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index])
if freq < 0:
raise ValueError(
"Row in vocab file has negative frequency of {}. Row content: {}".
format(freq, row))
# Note: tokens whose frequencies are below vocab_min_count will still
# contribute to the total corpus size used for vocab subsampling.
corpus_size += freq
vocab_freq_table = lookup.HashTable(
lookup.TextFileInitializer(
filename=vocab_freq_file,
key_dtype=vocab_token_dtype,
key_index=vocab_token_index,
value_dtype=vocab_freq_dtype,
value_index=vocab_freq_index,
vocab_size=vocab_size,
delimiter=vocab_delimiter),
# For vocab terms not in vocab file, use a default value of -1.
default_value=-1)
return skip_gram_sample(
input_tensor,
min_skips=min_skips,
max_skips=max_skips,
start=start,
limit=limit,
emit_self_as_target=emit_self_as_target,
vocab_freq_table=vocab_freq_table,
vocab_min_count=vocab_min_count,
vocab_subsampling=vocab_subsampling,
# corpus_size is not used unless vocab_subsampling is specified.
corpus_size=None if vocab_subsampling is None else corpus_size,
batch_size=batch_size,
batch_capacity=batch_capacity,
seed=seed,
name=name)
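# Editor's illustrative sketch (not part of the original change): with the
# three-column vocab file shown in the docstring above (token,language,count),
# the frequency counts live in column index 2. The helper name and the file
# path passed to it are hypothetical.
def _example_skip_gram_with_text_vocab(vocab_csv_path):
  """Sketch: wires a token,language,count vocab file into the sampler."""
  return skip_gram_sample_with_text_vocab(
      input_tensor=ops.convert_to_tensor(["hello", "hola", "bonjour"]),
      vocab_freq_file=vocab_csv_path,
      vocab_token_index=0,
      vocab_freq_index=2,
      vocab_delimiter=",",
      min_skips=1,
      max_skips=2)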
def _filter_input(input_tensor, vocab_freq_table, vocab_min_count,
vocab_subsampling, corpus_size, seed):
"""Filters input tensor based on vocab freq, threshold, and subsampling."""
if vocab_freq_table is None:
return input_tensor
if not isinstance(vocab_freq_table, lookup.InitializableLookupTableBase):
raise ValueError(
"vocab_freq_table must be a subclass of "
"InitializableLookupTableBase (such as HashTable) instead of type "
"{}.".format(type(vocab_freq_table)))
with ops.name_scope(
"filter_vocab", values=[vocab_freq_table, input_tensor, vocab_min_count]):
freq = vocab_freq_table.lookup(input_tensor)
# Filters out elements in input_tensor that are not found in
# vocab_freq_table (table returns a default value of -1 specified above when
# an element is not found).
mask = math_ops.not_equal(freq, vocab_freq_table.default_value)
# Filters out elements whose vocab frequencies are less than the threshold.
if vocab_min_count is not None:
cast_threshold = math_ops.cast(vocab_min_count, freq.dtype)
mask = math_ops.logical_and(mask,
math_ops.greater_equal(freq, cast_threshold))
input_tensor = array_ops.boolean_mask(input_tensor, mask)
freq = array_ops.boolean_mask(freq, mask)
if not vocab_subsampling:
return input_tensor
if vocab_subsampling < 0 or vocab_subsampling > 1:
raise ValueError(
"Invalid vocab_subsampling={} - it should be within range [0, 1].".
format(vocab_subsampling))
# Subsamples the input tokens based on vocabulary frequency and
# vocab_subsampling threshold (i.e., randomly discard commonly appearing
# tokens).
with ops.name_scope(
"subsample_vocab", values=[input_tensor, freq, vocab_subsampling]):
corpus_size = math_ops.cast(corpus_size, dtypes.float64)
freq = math_ops.cast(freq, dtypes.float64)
vocab_subsampling = math_ops.cast(vocab_subsampling, dtypes.float64)
# From tensorflow_models/tutorials/embedding/word2vec_kernels.cc, which is
# supposed to correlate with Eq. 5 in http://arxiv.org/abs/1310.4546.
keep_prob = ((math_ops.sqrt(freq /
(vocab_subsampling * corpus_size)) + 1.0) *
(vocab_subsampling * corpus_size / freq))
random_prob = random_ops.random_uniform(
array_ops.shape(freq),
minval=0,
maxval=1,
dtype=dtypes.float64,
seed=seed)
mask = math_ops.less_equal(random_prob, keep_prob)
return array_ops.boolean_mask(input_tensor, mask)
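For reference, the keep probability computed above is keep_prob = (sqrt(f / (t * C)) + 1) * (t * C / f), with f the token frequency, t the `vocab_subsampling` threshold, and C the corpus size, intended to follow Eq. 5 in http://arxiv.org/abs/1310.4546. A worked check using the same numbers as the tests below (corpus size 100, subsampling 0.05):

import math

def keep_prob(freq, subsampling, corpus_size):
    threshold = subsampling * corpus_size
    return (math.sqrt(freq / threshold) + 1.0) * (threshold / freq)

print(round(keep_prob(30, 0.05, 100), 2))  # 0.57 -> "the" is kept ~57% of the time
print(round(keep_prob(8, 0.05, 100), 2))   # 1.42 -> "life" (keep_prob > 1) is always kept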

View File: tensorflow/contrib/text/python/ops/skip_gram_ops_test.py

@@ -0,0 +1,571 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Skip-gram sampling ops tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import csv
import os
from tensorflow.contrib import lookup
from tensorflow.contrib import text
from tensorflow.contrib.text.python.ops import skip_gram_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner_impl
class SkipGramOpsTest(test.TestCase):
def _split_tokens_labels(self, output):
tokens = [x[0] for x in output]
labels = [x[1] for x in output]
return tokens, labels
def test_skip_gram_sample_skips_2(self):
"""Tests skip-gram with min_skips = max_skips = 2."""
input_tensor = constant_op.constant(
[b"the", b"quick", b"brown", b"fox", b"jumps"])
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=2, max_skips=2)
expected_tokens, expected_labels = self._split_tokens_labels([
(b"the", b"quick"),
(b"the", b"brown"),
(b"quick", b"the"),
(b"quick", b"brown"),
(b"quick", b"fox"),
(b"brown", b"the"),
(b"brown", b"quick"),
(b"brown", b"fox"),
(b"brown", b"jumps"),
(b"fox", b"quick"),
(b"fox", b"brown"),
(b"fox", b"jumps"),
(b"jumps", b"brown"),
(b"jumps", b"fox"),
])
with self.test_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
def test_skip_gram_sample_emit_self(self):
"""Tests skip-gram with emit_self_as_target = True."""
input_tensor = constant_op.constant(
[b"the", b"quick", b"brown", b"fox", b"jumps"])
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=2, max_skips=2, emit_self_as_target=True)
expected_tokens, expected_labels = self._split_tokens_labels([
(b"the", b"the"),
(b"the", b"quick"),
(b"the", b"brown"),
(b"quick", b"the"),
(b"quick", b"quick"),
(b"quick", b"brown"),
(b"quick", b"fox"),
(b"brown", b"the"),
(b"brown", b"quick"),
(b"brown", b"brown"),
(b"brown", b"fox"),
(b"brown", b"jumps"),
(b"fox", b"quick"),
(b"fox", b"brown"),
(b"fox", b"fox"),
(b"fox", b"jumps"),
(b"jumps", b"brown"),
(b"jumps", b"fox"),
(b"jumps", b"jumps"),
])
with self.test_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
def test_skip_gram_sample_skips_0(self):
"""Tests skip-gram with min_skips = max_skips = 0."""
input_tensor = constant_op.constant([b"the", b"quick", b"brown"])
# If emit_self_as_target is False (default), output will be empty.
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False)
with self.test_session():
self.assertEqual(0, tokens.eval().size)
self.assertEqual(0, labels.eval().size)
# If emit_self_as_target is True, each token will be its own label.
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=0, max_skips=0, emit_self_as_target=True)
expected_tokens, expected_labels = self._split_tokens_labels([
(b"the", b"the"),
(b"quick", b"quick"),
(b"brown", b"brown"),
])
with self.test_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
def test_skip_gram_sample_skips_exceed_length(self):
"""Tests skip-gram when min/max_skips exceed length of input."""
input_tensor = constant_op.constant([b"the", b"quick", b"brown"])
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=100, max_skips=100)
expected_tokens, expected_labels = self._split_tokens_labels([
(b"the", b"quick"),
(b"the", b"brown"),
(b"quick", b"the"),
(b"quick", b"brown"),
(b"brown", b"the"),
(b"brown", b"quick"),
])
with self.test_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
def test_skip_gram_sample_start_limit(self):
"""Tests skip-gram over a limited portion of the input."""
input_tensor = constant_op.constant(
[b"foo", b"the", b"quick", b"brown", b"bar"])
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=1, max_skips=1, start=1, limit=3)
expected_tokens, expected_labels = self._split_tokens_labels([
(b"the", b"quick"),
(b"quick", b"the"),
(b"quick", b"brown"),
(b"brown", b"quick"),
])
with self.test_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
def test_skip_gram_sample_limit_exceeds(self):
"""Tests skip-gram when limit exceeds the length of the input."""
input_tensor = constant_op.constant([b"foo", b"the", b"quick", b"brown"])
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=1, max_skips=1, start=1, limit=100)
expected_tokens, expected_labels = self._split_tokens_labels([
(b"the", b"quick"),
(b"quick", b"the"),
(b"quick", b"brown"),
(b"brown", b"quick"),
])
with self.test_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
def test_skip_gram_sample_random_skips(self):
"""Tests skip-gram with min_skips != max_skips, with random output."""
# The number of outputs is non-deterministic in this case, so set random
# seed to help ensure the outputs remain constant for this test case.
random_seed.set_random_seed(42)
input_tensor = constant_op.constant(
[b"the", b"quick", b"brown", b"fox", b"jumps", b"over"])
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=1, max_skips=2, seed=9)
expected_tokens, expected_labels = self._split_tokens_labels([
(b"the", b"quick"),
(b"the", b"brown"),
(b"quick", b"the"),
(b"quick", b"brown"),
(b"quick", b"fox"),
(b"brown", b"the"),
(b"brown", b"quick"),
(b"brown", b"fox"),
(b"brown", b"jumps"),
(b"fox", b"brown"),
(b"fox", b"jumps"),
(b"jumps", b"fox"),
(b"jumps", b"over"),
(b"over", b"fox"),
(b"over", b"jumps"),
])
with self.test_session() as sess:
tokens_eval, labels_eval = sess.run([tokens, labels])
self.assertAllEqual(expected_tokens, tokens_eval)
self.assertAllEqual(expected_labels, labels_eval)
def test_skip_gram_sample_random_skips_default_seed(self):
"""Tests outputs are still random when no op-level seed is specified."""
# This is needed since tests set a graph-level seed by default. We want to
# explicitly avoid setting both graph-level seed and op-level seed, to
# simulate behavior under non-test settings when the user doesn't provide a
# seed to us. This results in random_seed.get_seed() returning None for both
# seeds, forcing the C++ kernel to execute its default seed logic.
random_seed.set_random_seed(None)
# Uses an input tensor with 10 words, with possible skip ranges in [1,
# 5]. Thus, the probability that two random samplings would result in the
# same outputs is 1/5^10 ~ 1e-7 (aka the probability of this test being
# flaky).
input_tensor = constant_op.constant([str(x) for x in range(10)])
# Do not provide an op-level seed here!
tokens_1, labels_1 = text.skip_gram_sample(
input_tensor, min_skips=1, max_skips=5)
tokens_2, labels_2 = text.skip_gram_sample(
input_tensor, min_skips=1, max_skips=5)
with self.test_session() as sess:
tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run(
[tokens_1, labels_1, tokens_2, labels_2])
if len(tokens_1_eval) == len(tokens_2_eval):
self.assertNotEqual(tokens_1_eval.tolist(), tokens_2_eval.tolist())
if len(labels_1_eval) == len(labels_2_eval):
self.assertNotEqual(labels_1_eval.tolist(), labels_2_eval.tolist())
def test_skip_gram_sample_batch(self):
"""Tests skip-gram with batching."""
input_tensor = constant_op.constant([b"the", b"quick", b"brown", b"fox"])
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=1, max_skips=1, batch_size=3)
expected_tokens, expected_labels = self._split_tokens_labels([
(b"the", b"quick"),
(b"quick", b"the"),
(b"quick", b"brown"),
(b"brown", b"quick"),
(b"brown", b"fox"),
(b"fox", b"brown"),
])
with self.test_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
tokens_eval, labels_eval = sess.run([tokens, labels])
self.assertAllEqual(expected_tokens[:3], tokens_eval)
self.assertAllEqual(expected_labels[:3], labels_eval)
tokens_eval, labels_eval = sess.run([tokens, labels])
self.assertAllEqual(expected_tokens[3:6], tokens_eval)
self.assertAllEqual(expected_labels[3:6], labels_eval)
coord.request_stop()
coord.join(threads)
def test_skip_gram_sample_non_string_input(self):
"""Tests skip-gram with non-string input."""
input_tensor = constant_op.constant([1, 2, 3], dtype=dtypes.int16)
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=1, max_skips=1)
expected_tokens, expected_labels = self._split_tokens_labels([
(1, 2),
(2, 1),
(2, 3),
(3, 2),
])
with self.test_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
def test_skip_gram_sample_errors(self):
"""Tests various errors raised by skip_gram_sample()."""
input_tensor = constant_op.constant([b"the", b"quick", b"brown"])
invalid_skips = (
# min_skips and max_skips must be >= 0.
(-1, 2),
(1, -2),
# min_skips must be <= max_skips.
(2, 1))
for min_skips, max_skips in invalid_skips:
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=min_skips, max_skips=max_skips)
with self.test_session() as sess, self.assertRaises(
errors.InvalidArgumentError):
sess.run([tokens, labels])
# input_tensor must be of rank 1.
with self.assertRaises(ValueError):
invalid_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]])
text.skip_gram_sample(invalid_tensor)
# vocab_freq_table must be provided if vocab_min_count, vocab_subsampling,
# or corpus_size is specified.
dummy_input = constant_op.constant([""])
with self.assertRaises(ValueError):
text.skip_gram_sample(
dummy_input, vocab_freq_table=None, vocab_min_count=1)
with self.assertRaises(ValueError):
text.skip_gram_sample(
dummy_input, vocab_freq_table=None, vocab_subsampling=1e-5)
with self.assertRaises(ValueError):
text.skip_gram_sample(dummy_input, vocab_freq_table=None, corpus_size=100)
with self.assertRaises(ValueError):
text.skip_gram_sample(
dummy_input,
vocab_freq_table=None,
vocab_subsampling=1e-5,
corpus_size=100)
# vocab_subsampling and corpus_size must both be present or absent.
dummy_table = lookup.HashTable(
lookup.KeyValueTensorInitializer([b"foo"], [10]), -1)
with self.assertRaises(ValueError):
text.skip_gram_sample(
dummy_input,
vocab_freq_table=dummy_table,
vocab_subsampling=None,
corpus_size=100)
with self.assertRaises(ValueError):
text.skip_gram_sample(
dummy_input,
vocab_freq_table=dummy_table,
vocab_subsampling=1e-5,
corpus_size=None)
def test_filter_input_filter_vocab(self):
"""Tests input filtering based on vocab frequency table and thresholds."""
input_tensor = constant_op.constant(
[b"the", b"answer", b"to", b"life", b"and", b"universe"])
keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"])
values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
vocab_freq_table = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), -1)
with self.test_session():
vocab_freq_table.init.run()
# No vocab_freq_table specified - output should be the same as input.
no_table_output = skip_gram_ops._filter_input(
input_tensor=input_tensor,
vocab_freq_table=None,
vocab_min_count=None,
vocab_subsampling=None,
corpus_size=None,
seed=None)
self.assertAllEqual(input_tensor.eval(), no_table_output.eval())
# vocab_freq_table specified, but no vocab_min_count - output should have
# filtered out tokens not in the table (b"answer").
table_output = skip_gram_ops._filter_input(
input_tensor=input_tensor,
vocab_freq_table=vocab_freq_table,
vocab_min_count=None,
vocab_subsampling=None,
corpus_size=None,
seed=None)
self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"],
table_output.eval())
# vocab_freq_table and vocab_min_count specified - output should have
# filtered out tokens whose frequencies are below the threshold
# (b"and": 0, b"life": 1).
threshold_output = skip_gram_ops._filter_input(
input_tensor=input_tensor,
vocab_freq_table=vocab_freq_table,
vocab_min_count=2,
vocab_subsampling=None,
corpus_size=None,
seed=None)
self.assertAllEqual([b"the", b"to", b"universe"], threshold_output.eval())
def test_filter_input_subsample_vocab(self):
"""Tests input filtering based on vocab subsampling."""
# The outputs are non-deterministic, so set random seed to help ensure that
# the outputs remain constant for testing.
random_seed.set_random_seed(42)
input_tensor = constant_op.constant([
# keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
b"the",
b"answer", # Not in vocab. (Always discarded)
b"to", # keep_prob = 0.75.
b"life", # keep_prob > 1. (Always kept)
b"and", # keep_prob = 0.48.
b"universe" # Below vocab threshold of 3. (Always discarded)
])
keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"])
values = constant_op.constant([40, 8, 30, 20, 2], dtypes.int64)
vocab_freq_table = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), -1)
with self.test_session():
vocab_freq_table.init.run()
output = skip_gram_ops._filter_input(
input_tensor=input_tensor,
vocab_freq_table=vocab_freq_table,
vocab_min_count=3,
vocab_subsampling=0.05,
corpus_size=math_ops.reduce_sum(values),
seed=9)
self.assertAllEqual([b"the", b"to", b"life", b"and"], output.eval())
def _make_text_vocab_freq_file(self):
filepath = os.path.join(test.get_temp_dir(), "vocab_freq.txt")
with open(filepath, "w") as f:
writer = csv.writer(f)
writer.writerows([
["and", 40],
["life", 8],
["the", 30],
["to", 20],
["universe", 2],
])
return filepath
def _make_text_vocab_float_file(self):
filepath = os.path.join(test.get_temp_dir(), "vocab_freq_float.txt")
with open(filepath, "w") as f:
writer = csv.writer(f)
writer.writerows([
["and", 0.4],
["life", 0.08],
["the", 0.3],
["to", 0.2],
["universe", 0.02],
])
return filepath
def test_skip_gram_sample_with_text_vocab_filter_vocab(self):
"""Tests skip-gram sampling with text vocab and freq threshold filtering."""
input_tensor = constant_op.constant([
b"the",
b"answer", # Will be filtered before candidate generation.
b"to",
b"life",
b"and",
b"universe" # Will be filtered before candidate generation.
])
# b"answer" is not in vocab file, and b"universe"'s frequency is below
# threshold of 3.
vocab_freq_file = self._make_text_vocab_freq_file()
tokens, labels = text.skip_gram_sample_with_text_vocab(
input_tensor=input_tensor,
vocab_freq_file=vocab_freq_file,
vocab_token_index=0,
vocab_freq_index=1,
vocab_min_count=3,
min_skips=1,
max_skips=1)
expected_tokens, expected_labels = self._split_tokens_labels([
(b"the", b"to"),
(b"to", b"the"),
(b"to", b"life"),
(b"life", b"to"),
(b"life", b"and"),
(b"and", b"life"),
])
with self.test_session():
lookup_ops.tables_initializer().run()
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count,
vocab_freq_dtype):
# The outputs are non-deterministic, so set random seed to help ensure that
# the outputs remain constant for testing.
random_seed.set_random_seed(42)
input_tensor = constant_op.constant([
# keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
b"the",
b"answer", # Not in vocab. (Always discarded)
b"to", # keep_prob = 0.75.
b"life", # keep_prob > 1. (Always kept)
b"and", # keep_prob = 0.48.
b"universe" # Below vocab threshold of 3. (Always discarded)
])
# keep_prob calculated from vocab file with relative frequencies of:
# and: 40
# life: 8
# the: 30
# to: 20
# universe: 2
tokens, labels = text.skip_gram_sample_with_text_vocab(
input_tensor=input_tensor,
vocab_freq_file=vocab_freq_file,
vocab_token_index=0,
vocab_freq_index=1,
vocab_freq_dtype=vocab_freq_dtype,
vocab_min_count=vocab_min_count,
vocab_subsampling=0.05,
min_skips=1,
max_skips=1,
seed=123)
expected_tokens, expected_labels = self._split_tokens_labels([
(b"the", b"to"),
(b"to", b"the"),
(b"to", b"life"),
(b"life", b"to"),
])
with self.test_session() as sess:
lookup_ops.tables_initializer().run()
tokens_eval, labels_eval = sess.run([tokens, labels])
self.assertAllEqual(expected_tokens, tokens_eval)
self.assertAllEqual(expected_labels, labels_eval)
def test_skip_gram_sample_with_text_vocab_subsample_vocab(self):
"""Tests skip-gram sampling with text vocab and vocab subsampling."""
# Vocab file frequencies
# and: 40
# life: 8
# the: 30
# to: 20
# universe: 2
self._text_vocab_subsample_vocab_helper(
vocab_freq_file=self._make_text_vocab_freq_file(),
vocab_min_count=3,
vocab_freq_dtype=dtypes.int64)
def test_skip_gram_sample_with_text_vocab_subsample_vocab_float(self):
"""Tests skip-gram sampling with text vocab and subsampling with floats."""
# Vocab file frequencies
# and: 0.4
# life: 0.08
# the: 0.3
# to: 0.2
# universe: 0.02
self._text_vocab_subsample_vocab_helper(
vocab_freq_file=self._make_text_vocab_float_file(),
vocab_min_count=0.03,
vocab_freq_dtype=dtypes.float32)
def test_skip_gram_sample_with_text_vocab_errors(self):
"""Tests various errors raised by skip_gram_sample_with_text_vocab()."""
dummy_input = constant_op.constant([""])
vocab_freq_file = self._make_text_vocab_freq_file()
invalid_indices = (
# vocab_token_index can't be negative.
(-1, 0),
# vocab_freq_index can't be negative.
(0, -1),
# vocab_token_index can't be equal to vocab_freq_index.
(0, 0),
(1, 1),
# vocab_freq_file only has two columns.
(0, 2),
(2, 0))
for vocab_token_index, vocab_freq_index in invalid_indices:
with self.assertRaises(ValueError):
text.skip_gram_sample_with_text_vocab(
input_tensor=dummy_input,
vocab_freq_file=vocab_freq_file,
vocab_token_index=vocab_token_index,
vocab_freq_index=vocab_freq_index)
if __name__ == "__main__":
test.main()