From 458f94c128fa5f72085be9a2489765615e1951a7 Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Wed, 31 May 2017 17:07:05 -0700 Subject: [PATCH] Open-source skip-gram ops PiperOrigin-RevId: 157655970 --- tensorflow/BUILD | 1 + tensorflow/contrib/BUILD | 3 + .../contrib/cmake/tf_core_kernels.cmake | 2 + tensorflow/contrib/cmake/tf_core_ops.cmake | 1 + tensorflow/contrib/cmake/tf_python.cmake | 7 + tensorflow/contrib/text/BUILD | 119 ++++ tensorflow/contrib/text/__init__.py | 30 + .../contrib/text/kernels/skip_gram_kernels.cc | 139 +++++ tensorflow/contrib/text/ops/skip_gram_ops.cc | 54 ++ .../contrib/text/python/ops/__init__.py | 22 + .../contrib/text/python/ops/skip_gram_ops.py | 428 +++++++++++++ .../text/python/ops/skip_gram_ops_test.py | 571 ++++++++++++++++++ 12 files changed, 1377 insertions(+) create mode 100644 tensorflow/contrib/text/BUILD create mode 100644 tensorflow/contrib/text/__init__.py create mode 100644 tensorflow/contrib/text/kernels/skip_gram_kernels.cc create mode 100644 tensorflow/contrib/text/ops/skip_gram_ops.cc create mode 100644 tensorflow/contrib/text/python/ops/__init__.py create mode 100644 tensorflow/contrib/text/python/ops/skip_gram_ops.py create mode 100644 tensorflow/contrib/text/python/ops/skip_gram_ops_test.py diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 1ea62694d43..b90dc1b2050 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -281,6 +281,7 @@ filegroup( "//tensorflow/contrib/tensor_forest/hybrid:all_files", "//tensorflow/contrib/tensorboard:all_files", "//tensorflow/contrib/testing:all_files", + "//tensorflow/contrib/text:all_files", "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files", "//tensorflow/contrib/training:all_files", "//tensorflow/contrib/util:all_files", diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index e6f927b8b8a..d3fb30ca50f 100755 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -66,6 +66,7 @@ py_library( "//tensorflow/contrib/tensor_forest:init_py", "//tensorflow/contrib/tensorboard", "//tensorflow/contrib/testing:testing_py", + "//tensorflow/contrib/text:text_py", "//tensorflow/contrib/tfprof", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", @@ -84,6 +85,7 @@ cc_library( "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/nccl:nccl_kernels", "//tensorflow/contrib/tensor_forest:tensor_forest_kernels", + "//tensorflow/contrib/text:all_kernels", ], ) @@ -98,6 +100,7 @@ cc_library( "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", + "//tensorflow/contrib/text:all_ops", ], ) diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index ac3b9004969..500b917ac99 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -82,6 +82,8 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc" + "${tensorflow_source_dir}/tensorflow/contrib/text/kernels/skip_gram_kernels.cc" + "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc" ) list(APPEND tf_core_kernels_srcs ${tf_contrib_kernels_srcs}) 
endif(tensorflow_BUILD_CONTRIB_KERNELS) diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index c2eedb00b63..3c2f89c6c82 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -81,6 +81,7 @@ GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contri GENERATE_CONTRIB_OP_LIBRARY(seq2seq_beam_search "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_hybrid "${tensor_forest_hybrid_srcs}") +GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc") ######################################################## diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 480d93723d8..c9b5e20cf8e 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -507,6 +507,11 @@ add_python_module("tensorflow/contrib/tensor_forest/python/ops") add_python_module("tensorflow/contrib/testing") add_python_module("tensorflow/contrib/testing/python") add_python_module("tensorflow/contrib/testing/python/framework") +add_python_module("tensorflow/contrib/text") +add_python_module("tensorflow/contrib/text/kernels") +add_python_module("tensorflow/contrib/text/ops") +add_python_module("tensorflow/contrib/text/python") +add_python_module("tensorflow/contrib/text/python/ops") add_python_module("tensorflow/contrib/tfprof" DONTCOPY) # SWIG wrapper not implemented. #add_python_module("tensorflow/contrib/tfprof/python") #add_python_module("tensorflow/contrib/tfprof/python/tools") @@ -644,6 +649,8 @@ GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_tensor_forest_ops.py) GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_hybrid_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/hybrid/ops/gen_training_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py) GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py) GENERATE_PYTHON_OP_LIB("stateless_random_ops" diff --git a/tensorflow/contrib/text/BUILD b/tensorflow/contrib/text/BUILD new file mode 100644 index 00000000000..ff69c4e2cbe --- /dev/null +++ b/tensorflow/contrib/text/BUILD @@ -0,0 +1,119 @@ +# Description: +# contains parts of TensorFlow that are experimental or unstable and which +# are not supported. 
+ +package(default_visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", +]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", +) + +tf_custom_op_py_library( + name = "text_py", + srcs = [ + "__init__.py", + "python/ops/__init__.py", + "python/ops/skip_gram_ops.py", + ], + dso = [ + ":python/ops/_skip_gram_ops.so", + ], + kernels = [ + ":all_kernels", + ":all_ops", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_skip_gram_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + ], +) + +tf_kernel_library( + name = "skip_gram_kernels", + srcs = ["kernels/skip_gram_kernels.cc"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/eigen3", + ], + alwayslink = 1, +) + +cc_library( + name = "all_kernels", + deps = [":skip_gram_kernels"], +) + +tf_custom_op_library( + name = "python/ops/_skip_gram_ops.so", + srcs = [ + "kernels/skip_gram_kernels.cc", + "ops/skip_gram_ops.cc", + ], +) + +tf_gen_op_libs( + op_lib_names = ["skip_gram_ops"], +) + +cc_library( + name = "all_ops", + deps = [":skip_gram_ops_op_lib"], +) + +tf_gen_op_wrapper_py( + name = "gen_skip_gram_ops", + out = "python/ops/gen_skip_gram_ops.py", + deps = [":skip_gram_ops_op_lib"], +) + +py_test( + name = "skip_gram_ops_test", + size = "medium", + srcs = ["python/ops/skip_gram_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":text_py", + "//tensorflow/contrib/lookup:lookup_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_seed", + "//tensorflow/python:training", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) diff --git a/tensorflow/contrib/text/__init__.py b/tensorflow/contrib/text/__init__.py new file mode 100644 index 00000000000..35e66231890 --- /dev/null +++ b/tensorflow/contrib/text/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Text-processing ops. 
+
+@@skip_gram_sample
+@@skip_gram_sample_with_text_vocab
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.text.python.ops import *
+# pylint: enable=unused-import,wildcard-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/text/kernels/skip_gram_kernels.cc b/tensorflow/contrib/text/kernels/skip_gram_kernels.cc
new file mode 100644
index 00000000000..3cd0b5f72b5
--- /dev/null
+++ b/tensorflow/contrib/text/kernels/skip_gram_kernels.cc
@@ -0,0 +1,139 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+template <typename T>
+class SkipGramGenerateCandidatesOp : public OpKernel {
+ public:
+  explicit SkipGramGenerateCandidatesOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, generator_.Init(context));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor* input_tensor;
+    OP_REQUIRES_OK(context, context->input("input_tensor", &input_tensor));
+    const auto input = input_tensor->flat<T>();
+
+    const Tensor* min_skips_tensor;
+    OP_REQUIRES_OK(context, context->input("min_skips", &min_skips_tensor));
+    const int min_skips = *(min_skips_tensor->scalar<int>().data());
+    const Tensor* max_skips_tensor;
+    OP_REQUIRES_OK(context, context->input("max_skips", &max_skips_tensor));
+    const int max_skips = *(max_skips_tensor->scalar<int>().data());
+
+    OP_REQUIRES(
+        context, min_skips >= 0 && max_skips >= 0,
+        errors::InvalidArgument("Both min_skips and max_skips must be >= 0."));
+    OP_REQUIRES(context, min_skips <= max_skips,
+                errors::InvalidArgument("min_skips must be <= max_skips."));
+
+    const Tensor* start_tensor;
+    OP_REQUIRES_OK(context, context->input("start", &start_tensor));
+    const int start = *(start_tensor->scalar<int>().data());
+    const Tensor* limit_tensor;
+    OP_REQUIRES_OK(context, context->input("limit", &limit_tensor));
+    const int limit = *(limit_tensor->scalar<int>().data());
+    const int end =
+        limit < 0 ? input.size()
+                  : std::min(start + limit, static_cast<int>(input.size()));
+
+    const Tensor* emit_self_tensor;
+    OP_REQUIRES_OK(context,
+                   context->input("emit_self_as_target", &emit_self_tensor));
+    const bool emit_self_as_target =
+        *(emit_self_tensor->scalar<bool>().data());
+
+    std::vector<T> tokens;
+    std::vector<T> labels;
+
+    // Reserve the number of random numbers we will use - we use one for each
+    // token between start and end.
+    random::PhiloxRandom local_gen =
+        generator_.ReserveSamples32(end - start + 1);
+    random::SimplePhilox rng(&local_gen);
+
+    // For each token in the sentence, picks a random skip, then generates
+    // (token, label) pairs for all labels whose distances from the token are
+    // within the range [-skip, skip].
+    for (int i = start; i < end; ++i) {
+      const int skips = min_skips + rng.Uniform(max_skips - min_skips + 1);
+      for (int j = -skips; j <= skips; ++j) {
+        if ((i + j < start) || (i + j >= end) ||
+            (j == 0 && !emit_self_as_target)) {
+          continue;
+        }
+        tokens.push_back(input(i));
+        labels.push_back(input(i + j));
+      }
+    }
+
+    Tensor* tokens_output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       "tokens", TensorShape({static_cast<int>(tokens.size())}),
+                       &tokens_output));
+    Tensor* labels_output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       "labels", TensorShape({static_cast<int>(labels.size())}),
+                       &labels_output));
+    OP_REQUIRES(
+        context, tokens_output->IsSameSize(*labels_output),
+        errors::Internal(strings::StrCat(
+            "Mismatch between tokens_output shape of ",
+            tokens_output->shape().DebugString(),
+            " and labels_output shape of ",
+            labels_output->shape().DebugString(),
+            ". This should never happen - contact ami-team@ if it does.")));
+
+    // Copies results to output tensors.
+    for (int i = 0; i < tokens.size(); ++i) {
+      tokens_output->vec<T>()(i) = tokens[i];
+      labels_output->vec<T>()(i) = labels[i];
+    }
+  }
+
+ private:
+  GuardedPhiloxRandom generator_;
+};
+
+#define REGISTER_KERNEL(type)                                \
+  REGISTER_KERNEL_BUILDER(Name("SkipGramGenerateCandidates") \
+                              .Device(DEVICE_CPU)            \
+                              .TypeConstraint<type>("T"),    \
+                          SkipGramGenerateCandidatesOp<type>)
+
+REGISTER_KERNEL(string);
+REGISTER_KERNEL(int64);
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(int16);
+// TODO(weiho): Add other types if the need arises.
+
+#undef REGISTER_KERNEL
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/text/ops/skip_gram_ops.cc b/tensorflow/contrib/text/ops/skip_gram_ops.cc
new file mode 100644
index 00000000000..9a7a20d81a9
--- /dev/null
+++ b/tensorflow/contrib/text/ops/skip_gram_ops.cc
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { +REGISTER_OP("SkipGramGenerateCandidates") + .Input("input_tensor: T") + .Input("min_skips: int32") + .Input("max_skips: int32") + .Input("start: int32") + .Input("limit: int32") + .Input("emit_self_as_target: bool") + .Output("tokens: T") + .Output("labels: T") + .Attr("T: type") + // The seed attributes are needed by GuardedPhiloxRandom + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // input_tensor must be of rank-1. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); + // All other args must be scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + + // Due to possible randomness in selecting skips, we only know that the + // outputs will be of rank-1, but not their sizes. + c->set_output(0, c->Vector(c->UnknownDim())); + c->set_output(1, c->Vector(c->UnknownDim())); + return Status::OK(); + }) + .Doc(R"doc( +Generates skip-gram token and label paired Tensors from the input tensor. +See docs for the public-facing skip_gram_sample() Python op for more details. +)doc"); +} // namespace tensorflow diff --git a/tensorflow/contrib/text/python/ops/__init__.py b/tensorflow/contrib/text/python/ops/__init__.py new file mode 100644 index 00000000000..bb47266dd2b --- /dev/null +++ b/tensorflow/contrib/text/python/ops/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Various contrib ops related to text-processing.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.text.python.ops.skip_gram_ops import skip_gram_sample +from tensorflow.contrib.text.python.ops.skip_gram_ops import skip_gram_sample_with_text_vocab diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops.py b/tensorflow/contrib/text/python/ops/skip_gram_ops.py new file mode 100644 index 00000000000..410ee517e03 --- /dev/null +++ b/tensorflow/contrib/text/python/ops/skip_gram_ops.py @@ -0,0 +1,428 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Skip-gram sampling ops from https://arxiv.org/abs/1301.3781."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import csv
+
+from tensorflow.contrib import lookup
+from tensorflow.contrib.text.python.ops import gen_skip_gram_ops
+from tensorflow.contrib.util import loader
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.training import input as input_ops
+
+_skip_gram_ops_so = loader.load_op_library(
+    resource_loader.get_path_to_datafile("_skip_gram_ops.so"))
+
+ops.NotDifferentiable("SkipGramGenerateCandidates")
+
+
+def skip_gram_sample(input_tensor,
+                     min_skips=1,
+                     max_skips=5,
+                     start=0,
+                     limit=-1,
+                     emit_self_as_target=False,
+                     vocab_freq_table=None,
+                     vocab_min_count=None,
+                     vocab_subsampling=None,
+                     corpus_size=None,
+                     batch_size=None,
+                     batch_capacity=None,
+                     seed=None,
+                     name=None):
+  """Generates skip-gram token and label paired Tensors from the input tensor.
+
+  Generates skip-gram `("token", "label")` pairs using each element in the
+  rank-1 `input_tensor` as a token. The window size used for each token will be
+  randomly selected from the range specified by `[min_skips, max_skips]`,
+  inclusive. See https://arxiv.org/abs/1301.3781 for more details about
+  skip-gram.
+
+  For example, given `input_tensor = ["the", "quick", "brown", "fox", "jumps"]`,
+  `min_skips = 1`, `max_skips = 2`, `emit_self_as_target = False`, the output
+  `(tokens, labels)` pairs for the token "quick" will be randomly selected from
+  either `(tokens=["quick", "quick"], labels=["the", "brown"])` for 1 skip, or
+  `(tokens=["quick", "quick", "quick"], labels=["the", "brown", "fox"])` for 2
+  skips.
+
+  If `emit_self_as_target = True`, each token will also be emitted as a label
+  for itself. From the previous example, the output will be either
+  `(tokens=["quick", "quick", "quick"], labels=["the", "quick", "brown"])` for 1
+  skip, or `(tokens=["quick", "quick", "quick", "quick"], labels=["the",
+  "quick", "brown", "fox"])` for 2 skips.
+
+  The same process is repeated for each element of `input_tensor`, and the
+  results are concatenated into the two output rank-1 `Tensors` (one for all
+  the tokens, another for all the labels).
+
+  If `vocab_freq_table` is specified, tokens in `input_tensor` that are not
+  present in the vocabulary are discarded. Tokens whose frequency counts are
+  below `vocab_min_count` are also discarded. Tokens whose frequency proportions
+  in the corpus exceed `vocab_subsampling` may be randomly down-sampled. See
+  Eq. 5 in http://arxiv.org/abs/1310.4546 for more details about subsampling.
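+
+  Concretely, when subsampling is enabled, a token with frequency `freq` is
+  kept with probability `keep_prob` as computed by the `_filter_input()`
+  helper later in this file. A plain-Python sketch of that formula (the
+  actual code performs the same computation with `math_ops` on float64
+  `Tensors`; `threshold` is an illustrative local name):
+
+  ```
+  import math
+
+  threshold = vocab_subsampling * corpus_size
+  keep_prob = (math.sqrt(freq / threshold) + 1.0) * (threshold / freq)
+  # Tokens with keep_prob >= 1 are always kept; the rest are kept with
+  # probability keep_prob.
+  ```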
+
+  Due to the random window sizes used for each token, the lengths of the
+  outputs are non-deterministic, unless `batch_size` is specified to batch the
+  outputs to always return `Tensors` of length `batch_size`.
+
+  Args:
+    input_tensor: A rank-1 `Tensor` from which to generate skip-gram
+      candidates.
+    min_skips: `int` or scalar `Tensor` specifying the minimum window size to
+      randomly use for each token. Must be >= 0 and <= `max_skips`. If
+      `min_skips` and `max_skips` are both 0, the only label emitted will be
+      the token itself when `emit_self_as_target = True`; otherwise, there
+      will be no output.
+    max_skips: `int` or scalar `Tensor` specifying the maximum window size to
+      randomly use for each token. Must be >= 0.
+    start: `int` or scalar `Tensor` specifying the position in `input_tensor`
+      from which to start generating skip-gram candidates.
+    limit: `int` or scalar `Tensor` specifying the maximum number of elements
+      in `input_tensor` to use in generating skip-gram candidates. -1 means to
+      use the rest of the `Tensor` after `start`.
+    emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
+      each token as a label for itself.
+    vocab_freq_table: (Optional) A lookup table (subclass of
+      `lookup.InitializableLookupTableBase`) that maps tokens to their raw
+      frequency counts. If specified, any token in `input_tensor` that is not
+      found in `vocab_freq_table` will be filtered out before generating
+      skip-gram candidates. While this will typically map to integer raw
+      frequency counts, it could also map to float frequency proportions.
+      `vocab_min_count` and `corpus_size` should be in the same units as this.
+    vocab_min_count: (Optional) `int`, `float`, or scalar `Tensor` specifying
+      minimum frequency threshold (from `vocab_freq_table`) for a token to be
+      kept in `input_tensor`. If this is specified, `vocab_freq_table` must
+      also be specified - and they should both be in the same units.
+    vocab_subsampling: (Optional) `float` specifying frequency proportion
+      threshold for tokens from `input_tensor`. Tokens that occur more
+      frequently (based on the ratio of the token's `vocab_freq_table` value to
+      the `corpus_size`) will be randomly down-sampled. Reasonable starting
+      values may be around 1e-3 or 1e-5. If this is specified, both
+      `vocab_freq_table` and `corpus_size` must also be specified. See Eq. 5
+      in http://arxiv.org/abs/1310.4546 for more details.
+    corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
+      total number of tokens in the corpus (e.g., sum of all the frequency
+      counts of `vocab_freq_table`). Used with `vocab_subsampling` for
+      down-sampling frequently occurring tokens. If this is specified,
+      `vocab_freq_table` and `vocab_subsampling` must also be specified.
+    batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
+    batch_capacity: (Optional) `int` specifying batch capacity for the queue
+      used for batching returned `Tensors`. Only has an effect if
+      `batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
+    seed: (Optional) `int` used to create a random seed for window size and
+      subsampling. See `set_random_seed` docs for behavior.
+    name: (Optional) A `string` name or a name scope for the operations.
+
+  Returns:
+    A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
+    rank-1 and has the same type as `input_tensor`.
The `Tensors` will be of + length `batch_size`; if `batch_size` is not specified, they will be of + random length, though they will be in sync with each other as long as they + are evaluated together. + + Raises: + ValueError: If `vocab_freq_table` is not provided, but `vocab_min_count`, + `vocab_subsampling`, or `corpus_size` is specified. If `vocab_subsampling` + and `corpus_size` are not both present or both absent. + """ + + if vocab_freq_table is None and (vocab_min_count is not None or + vocab_subsampling is not None or + corpus_size is not None): + raise ValueError( + "vocab_freq_table is not provided, but vocab_min_count={}, " + "vocab_subsampling={}, or corpus_size={} is not None. These settings " + "are useless without a vocab_freq_table.".format( + vocab_min_count, vocab_subsampling, corpus_size)) + + if (vocab_subsampling is None) != (corpus_size is None): + raise ValueError( + "vocab_subsampling is {} while corpus_size is {} - both must be " + "provided in order for subsampling to work.".format( + vocab_subsampling, corpus_size)) + + with ops.name_scope( + name, + "skip_gram_sample", + values=[input_tensor, min_skips, max_skips, start, limit]): + + input_tensor = _filter_input( + input_tensor=input_tensor, + vocab_freq_table=vocab_freq_table, + vocab_min_count=vocab_min_count, + vocab_subsampling=vocab_subsampling, + corpus_size=corpus_size, + seed=seed) + + seed1, seed2 = random_seed.get_seed(seed) + tokens, labels = gen_skip_gram_ops.skip_gram_generate_candidates( + input_tensor=input_tensor, + min_skips=min_skips, + max_skips=max_skips, + start=start, + limit=limit, + emit_self_as_target=emit_self_as_target, + # Note that seed here should be seed1! This is due to + # GuardedPhiloxRandom's hard-coded attributes of "seed" and "seed2". + seed=seed1, + seed2=seed2) + + # TODO(weiho): If the need arises, add support for sparse input_tensor that + # figures out sentence boundaries, then calls + # skip_gram_generate_candidates() on each sentence. + + # Batches the (tokens, labels) outputs so that they will be of deterministic + # batch_size, to facilitate feeding them into the rest of the network. + if batch_size is not None and batch_size > 0: + batch_capacity = (batch_capacity + if (batch_capacity is not None and batch_capacity > 0) + else 100 * batch_size) + return input_ops.batch( + [tokens, labels], + batch_size, + capacity=batch_capacity, + enqueue_many=True) + + return tokens, labels + + +def skip_gram_sample_with_text_vocab(input_tensor, + vocab_freq_file, + vocab_token_index=0, + vocab_token_dtype=dtypes.string, + vocab_freq_index=1, + vocab_freq_dtype=dtypes.float64, + vocab_delimiter=",", + vocab_min_count=0, + vocab_subsampling=None, + min_skips=1, + max_skips=5, + start=0, + limit=-1, + emit_self_as_target=False, + batch_size=None, + batch_capacity=None, + seed=None, + name=None): + """Skip-gram sampling with a text vocabulary file. + + Wrapper around `skip_gram_sample()` for use with a text vocabulary file. The + vocabulary file is expected to be a plain-text file, with lines of + `vocab_delimiter`-separated columns. The `vocab_token_index` column should + contain the vocabulary term, while the `vocab_freq_index` column should + contain the number of times that term occurs in the corpus. For example, with + a text vocabulary file of: + + ``` + bonjour,fr,42 + hello,en,777 + hola,es,99 + ``` + + You should set `vocab_delimiter=","`, `vocab_token_index=0`, and + `vocab_freq_index=2`. 
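+
+  For example, given the vocabulary file above, a minimal call might look like
+  the following sketch (the file path is hypothetical):
+
+  ```
+  tokens, labels = skip_gram_sample_with_text_vocab(
+      input_tensor,
+      vocab_freq_file="/tmp/vocab_freq.txt",  # Hypothetical path.
+      vocab_token_index=0,
+      vocab_freq_index=2,
+      vocab_freq_dtype=dtypes.int64,
+      vocab_delimiter=",",
+      min_skips=1,
+      max_skips=2)
+  ```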
+ + See `skip_gram_sample()` documentation for more details about the skip-gram + sampling process. + + Args: + input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates. + vocab_freq_file: `string` specifying full file path to the text vocab file. + vocab_token_index: `int` specifying which column in the text vocab file + contains the tokens. + vocab_token_dtype: `DType` specifying the format of the tokens in the text + vocab file. + vocab_freq_index: `int` specifying which column in the text vocab file + contains the frequency counts of the tokens. + vocab_freq_dtype: `DType` specifying the format of the frequency counts in + the text vocab file. + vocab_delimiter: `string` specifying the delimiter used in the text vocab + file. + vocab_min_count: `int`, `float`, or scalar `Tensor` specifying + minimum frequency threshold (from `vocab_freq_file`) for a token to be + kept in `input_tensor`. This should correspond with `vocab_freq_dtype`. + vocab_subsampling: (Optional) `float` specifying frequency proportion + threshold for tokens from `input_tensor`. Tokens that occur more + frequently will be randomly down-sampled. Reasonable starting values may + be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546 for + more details. + min_skips: `int` or scalar `Tensor` specifying the minimum window size to + randomly use for each token. Must be >= 0 and <= `max_skips`. If + `min_skips` and `max_skips` are both 0, the only label outputted will be + the token itself. + max_skips: `int` or scalar `Tensor` specifying the maximum window size to + randomly use for each token. Must be >= 0. + start: `int` or scalar `Tensor` specifying the position in `input_tensor` + from which to start generating skip-gram candidates. + limit: `int` or scalar `Tensor` specifying the maximum number of elements in + `input_tensor` to use in generating skip-gram candidates. -1 means to use + the rest of the `Tensor` after `start`. + emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit + each token as a label for itself. + batch_size: (Optional) `int` specifying batch size of returned `Tensors`. + batch_capacity: (Optional) `int` specifying batch capacity for the queue + used for batching returned `Tensors`. Only has an effect if + `batch_size` > 0. Defaults to 100 * `batch_size` if not specified. + seed: (Optional) `int` used to create a random seed for window size and + subsampling. See + [`set_random_seed`](../../g3doc/python/constant_op.md#set_random_seed) + for behavior. + name: (Optional) A `string` name or a name scope for the operations. + + Returns: + A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of + rank-1 and has the same type as `input_tensor`. The `Tensors` will be of + length `batch_size`; if `batch_size` is not specified, they will be of + random length, though they will be in sync with each other as long as they + are evaluated together. + + Raises: + ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0 or + exceeds the number of columns in `vocab_freq_file`. If `vocab_token_index` + and `vocab_freq_index` are both set to the same column. If any token in + `vocab_freq_file` has a negative frequency. + """ + + if vocab_token_index < 0 or vocab_freq_index < 0: + raise ValueError( + "vocab_token_index={} and vocab_freq_index={} must both be >= 0.". 
+ format(vocab_token_index, vocab_freq_index)) + if vocab_token_index == vocab_freq_index: + raise ValueError( + "vocab_token_index and vocab_freq_index should be different, but are " + "both {}.".format(vocab_token_index)) + + # Iterates through the vocab file and calculates the number of vocab terms as + # well as the total corpus size (by summing the frequency counts of all the + # vocab terms). + corpus_size = 0.0 + vocab_size = 0 + with gfile.GFile(vocab_freq_file, mode="r") as f: + reader = csv.reader(f, delimiter=vocab_delimiter) + for row in reader: + if vocab_token_index >= len(row) or vocab_freq_index >= len(row): + raise ValueError( + "Row in vocab file only has {} columns, so vocab_token_index={} or " + "vocab_freq_index={} is out of bounds. Row content: {}".format( + len(row), vocab_token_index, vocab_freq_index, row)) + vocab_size += 1 + freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index]) + if freq < 0: + raise ValueError( + "Row in vocab file has negative frequency of {}. Row content: {}". + format(freq, row)) + # Note: tokens whose frequencies are below vocab_min_count will still + # contribute to the total corpus size used for vocab subsampling. + corpus_size += freq + + vocab_freq_table = lookup.HashTable( + lookup.TextFileInitializer( + filename=vocab_freq_file, + key_dtype=vocab_token_dtype, + key_index=vocab_token_index, + value_dtype=vocab_freq_dtype, + value_index=vocab_freq_index, + vocab_size=vocab_size, + delimiter=vocab_delimiter), + # For vocab terms not in vocab file, use a default value of -1. + default_value=-1) + + return skip_gram_sample( + input_tensor, + min_skips=min_skips, + max_skips=max_skips, + start=start, + limit=limit, + emit_self_as_target=emit_self_as_target, + vocab_freq_table=vocab_freq_table, + vocab_min_count=vocab_min_count, + vocab_subsampling=vocab_subsampling, + # corpus_size is not used unless vocab_subsampling is specified. + corpus_size=None if vocab_subsampling is None else corpus_size, + batch_size=batch_size, + batch_capacity=batch_capacity, + seed=seed, + name=name) + + +def _filter_input(input_tensor, vocab_freq_table, vocab_min_count, + vocab_subsampling, corpus_size, seed): + """Filters input tensor based on vocab freq, threshold, and subsampling.""" + if vocab_freq_table is None: + return input_tensor + + if not isinstance(vocab_freq_table, lookup.InitializableLookupTableBase): + raise ValueError( + "vocab_freq_table must be a subclass of " + "InitializableLookupTableBase (such as HashTable) instead of type " + "{}.".format(type(vocab_freq_table))) + + with ops.name_scope( + "filter_vocab", values=[vocab_freq_table, input_tensor, vocab_min_count]): + freq = vocab_freq_table.lookup(input_tensor) + # Filters out elements in input_tensor that are not found in + # vocab_freq_table (table returns a default value of -1 specified above when + # an element is not found). + mask = math_ops.not_equal(freq, vocab_freq_table.default_value) + + # Filters out elements whose vocab frequencies are less than the threshold. + if vocab_min_count is not None: + cast_threshold = math_ops.cast(vocab_min_count, freq.dtype) + mask = math_ops.logical_and(mask, + math_ops.greater_equal(freq, cast_threshold)) + + input_tensor = array_ops.boolean_mask(input_tensor, mask) + freq = array_ops.boolean_mask(freq, mask) + + if not vocab_subsampling: + return input_tensor + + if vocab_subsampling < 0 or vocab_subsampling > 1: + raise ValueError( + "Invalid vocab_subsampling={} - it should be within range [0, 1].". 
+        format(vocab_subsampling))
+
+  # Subsamples the input tokens based on vocabulary frequency and
+  # vocab_subsampling threshold (i.e., randomly discards commonly occurring
+  # tokens).
+  with ops.name_scope(
+      "subsample_vocab", values=[input_tensor, freq, vocab_subsampling]):
+    corpus_size = math_ops.cast(corpus_size, dtypes.float64)
+    freq = math_ops.cast(freq, dtypes.float64)
+    vocab_subsampling = math_ops.cast(vocab_subsampling, dtypes.float64)
+
+    # From tensorflow_models/tutorials/embedding/word2vec_kernels.cc, which is
+    # supposed to correspond to Eq. 5 in http://arxiv.org/abs/1310.4546.
+    keep_prob = ((math_ops.sqrt(freq /
+                                (vocab_subsampling * corpus_size)) + 1.0) *
+                 (vocab_subsampling * corpus_size / freq))
+    random_prob = random_ops.random_uniform(
+        array_ops.shape(freq),
+        minval=0,
+        maxval=1,
+        dtype=dtypes.float64,
+        seed=seed)
+
+    mask = math_ops.less_equal(random_prob, keep_prob)
+    return array_ops.boolean_mask(input_tensor, mask)
diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
new file mode 100644
index 00000000000..d989942f732
--- /dev/null
+++ b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
@@ -0,0 +1,571 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Skip-gram sampling ops tests.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import csv +import os + +from tensorflow.contrib import lookup +from tensorflow.contrib import text +from tensorflow.contrib.text.python.ops import skip_gram_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test +from tensorflow.python.training import coordinator +from tensorflow.python.training import queue_runner_impl + + +class SkipGramOpsTest(test.TestCase): + + def _split_tokens_labels(self, output): + tokens = [x[0] for x in output] + labels = [x[1] for x in output] + return tokens, labels + + def test_skip_gram_sample_skips_2(self): + """Tests skip-gram with min_skips = max_skips = 2.""" + input_tensor = constant_op.constant( + [b"the", b"quick", b"brown", b"fox", b"jumps"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=2, max_skips=2) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"quick"), + (b"the", b"brown"), + (b"quick", b"the"), + (b"quick", b"brown"), + (b"quick", b"fox"), + (b"brown", b"the"), + (b"brown", b"quick"), + (b"brown", b"fox"), + (b"brown", b"jumps"), + (b"fox", b"quick"), + (b"fox", b"brown"), + (b"fox", b"jumps"), + (b"jumps", b"brown"), + (b"jumps", b"fox"), + ]) + with self.test_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_emit_self(self): + """Tests skip-gram with emit_self_as_target = True.""" + input_tensor = constant_op.constant( + [b"the", b"quick", b"brown", b"fox", b"jumps"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=2, max_skips=2, emit_self_as_target=True) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"the"), + (b"the", b"quick"), + (b"the", b"brown"), + (b"quick", b"the"), + (b"quick", b"quick"), + (b"quick", b"brown"), + (b"quick", b"fox"), + (b"brown", b"the"), + (b"brown", b"quick"), + (b"brown", b"brown"), + (b"brown", b"fox"), + (b"brown", b"jumps"), + (b"fox", b"quick"), + (b"fox", b"brown"), + (b"fox", b"fox"), + (b"fox", b"jumps"), + (b"jumps", b"brown"), + (b"jumps", b"fox"), + (b"jumps", b"jumps"), + ]) + with self.test_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_skips_0(self): + """Tests skip-gram with min_skips = max_skips = 0.""" + input_tensor = constant_op.constant([b"the", b"quick", b"brown"]) + + # If emit_self_as_target is False (default), output will be empty. + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False) + with self.test_session(): + self.assertEqual(0, tokens.eval().size) + self.assertEqual(0, labels.eval().size) + + # If emit_self_as_target is True, each token will be its own label. 
+ tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=0, max_skips=0, emit_self_as_target=True) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"the"), + (b"quick", b"quick"), + (b"brown", b"brown"), + ]) + with self.test_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_skips_exceed_length(self): + """Tests skip-gram when min/max_skips exceed length of input.""" + input_tensor = constant_op.constant([b"the", b"quick", b"brown"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=100, max_skips=100) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"quick"), + (b"the", b"brown"), + (b"quick", b"the"), + (b"quick", b"brown"), + (b"brown", b"the"), + (b"brown", b"quick"), + ]) + with self.test_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_start_limit(self): + """Tests skip-gram over a limited portion of the input.""" + input_tensor = constant_op.constant( + [b"foo", b"the", b"quick", b"brown", b"bar"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=1, start=1, limit=3) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"quick"), + (b"quick", b"the"), + (b"quick", b"brown"), + (b"brown", b"quick"), + ]) + with self.test_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_limit_exceeds(self): + """Tests skip-gram when limit exceeds the length of the input.""" + input_tensor = constant_op.constant([b"foo", b"the", b"quick", b"brown"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=1, start=1, limit=100) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"quick"), + (b"quick", b"the"), + (b"quick", b"brown"), + (b"brown", b"quick"), + ]) + with self.test_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_random_skips(self): + """Tests skip-gram with min_skips != max_skips, with random output.""" + # The number of outputs is non-deterministic in this case, so set random + # seed to help ensure the outputs remain constant for this test case. + random_seed.set_random_seed(42) + + input_tensor = constant_op.constant( + [b"the", b"quick", b"brown", b"fox", b"jumps", b"over"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=2, seed=9) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"quick"), + (b"the", b"brown"), + (b"quick", b"the"), + (b"quick", b"brown"), + (b"quick", b"fox"), + (b"brown", b"the"), + (b"brown", b"quick"), + (b"brown", b"fox"), + (b"brown", b"jumps"), + (b"fox", b"brown"), + (b"fox", b"jumps"), + (b"jumps", b"fox"), + (b"jumps", b"over"), + (b"over", b"fox"), + (b"over", b"jumps"), + ]) + with self.test_session() as sess: + tokens_eval, labels_eval = sess.run([tokens, labels]) + self.assertAllEqual(expected_tokens, tokens_eval) + self.assertAllEqual(expected_labels, labels_eval) + + def test_skip_gram_sample_random_skips_default_seed(self): + """Tests outputs are still random when no op-level seed is specified.""" + # This is needed since tests set a graph-level seed by default. 
We want to + # explicitly avoid setting both graph-level seed and op-level seed, to + # simulate behavior under non-test settings when the user doesn't provide a + # seed to us. This results in random_seed.get_seed() returning None for both + # seeds, forcing the C++ kernel to execute its default seed logic. + random_seed.set_random_seed(None) + + # Uses an input tensor with 10 words, with possible skip ranges in [1, + # 5]. Thus, the probability that two random samplings would result in the + # same outputs is 1/5^10 ~ 1e-7 (aka the probability of this test being + # flaky). + input_tensor = constant_op.constant([str(x) for x in range(10)]) + + # Do not provide an op-level seed here! + tokens_1, labels_1 = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=5) + tokens_2, labels_2 = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=5) + + with self.test_session() as sess: + tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run( + [tokens_1, labels_1, tokens_2, labels_2]) + + if len(tokens_1_eval) == len(tokens_2_eval): + self.assertNotEqual(tokens_1_eval.tolist(), tokens_2_eval.tolist()) + if len(labels_1_eval) == len(labels_2_eval): + self.assertNotEqual(labels_1_eval.tolist(), labels_2_eval.tolist()) + + def test_skip_gram_sample_batch(self): + """Tests skip-gram with batching.""" + input_tensor = constant_op.constant([b"the", b"quick", b"brown", b"fox"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=1, batch_size=3) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"quick"), + (b"quick", b"the"), + (b"quick", b"brown"), + (b"brown", b"quick"), + (b"brown", b"fox"), + (b"fox", b"brown"), + ]) + with self.test_session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) + + tokens_eval, labels_eval = sess.run([tokens, labels]) + self.assertAllEqual(expected_tokens[:3], tokens_eval) + self.assertAllEqual(expected_labels[:3], labels_eval) + tokens_eval, labels_eval = sess.run([tokens, labels]) + self.assertAllEqual(expected_tokens[3:6], tokens_eval) + self.assertAllEqual(expected_labels[3:6], labels_eval) + + coord.request_stop() + coord.join(threads) + + def test_skip_gram_sample_non_string_input(self): + """Tests skip-gram with non-string input.""" + input_tensor = constant_op.constant([1, 2, 3], dtype=dtypes.int16) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=1) + expected_tokens, expected_labels = self._split_tokens_labels([ + (1, 2), + (2, 1), + (2, 3), + (3, 2), + ]) + with self.test_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_errors(self): + """Tests various errors raised by skip_gram_sample().""" + input_tensor = constant_op.constant([b"the", b"quick", b"brown"]) + + invalid_skips = ( + # min_skips and max_skips must be >= 0. + (-1, 2), + (1, -2), + # min_skips must be <= max_skips. + (2, 1)) + for min_skips, max_skips in invalid_skips: + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=min_skips, max_skips=max_skips) + with self.test_session() as sess, self.assertRaises( + errors.InvalidArgumentError): + sess.run([tokens, labels]) + + # input_tensor must be of rank 1. 
+ with self.assertRaises(ValueError): + invalid_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]]) + text.skip_gram_sample(invalid_tensor) + + # vocab_freq_table must be provided if vocab_min_count, vocab_subsampling, + # or corpus_size is specified. + dummy_input = constant_op.constant([""]) + with self.assertRaises(ValueError): + text.skip_gram_sample( + dummy_input, vocab_freq_table=None, vocab_min_count=1) + with self.assertRaises(ValueError): + text.skip_gram_sample( + dummy_input, vocab_freq_table=None, vocab_subsampling=1e-5) + with self.assertRaises(ValueError): + text.skip_gram_sample(dummy_input, vocab_freq_table=None, corpus_size=100) + with self.assertRaises(ValueError): + text.skip_gram_sample( + dummy_input, + vocab_freq_table=None, + vocab_subsampling=1e-5, + corpus_size=100) + + # vocab_subsampling and corpus_size must both be present or absent. + dummy_table = lookup.HashTable( + lookup.KeyValueTensorInitializer([b"foo"], [10]), -1) + with self.assertRaises(ValueError): + text.skip_gram_sample( + dummy_input, + vocab_freq_table=dummy_table, + vocab_subsampling=None, + corpus_size=100) + with self.assertRaises(ValueError): + text.skip_gram_sample( + dummy_input, + vocab_freq_table=dummy_table, + vocab_subsampling=1e-5, + corpus_size=None) + + def test_filter_input_filter_vocab(self): + """Tests input filtering based on vocab frequency table and thresholds.""" + input_tensor = constant_op.constant( + [b"the", b"answer", b"to", b"life", b"and", b"universe"]) + keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"]) + values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64) + vocab_freq_table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), -1) + + with self.test_session(): + vocab_freq_table.init.run() + + # No vocab_freq_table specified - output should be the same as input. + no_table_output = skip_gram_ops._filter_input( + input_tensor=input_tensor, + vocab_freq_table=None, + vocab_min_count=None, + vocab_subsampling=None, + corpus_size=None, + seed=None) + self.assertAllEqual(input_tensor.eval(), no_table_output.eval()) + + # vocab_freq_table specified, but no vocab_min_count - output should have + # filtered out tokens not in the table (b"answer"). + table_output = skip_gram_ops._filter_input( + input_tensor=input_tensor, + vocab_freq_table=vocab_freq_table, + vocab_min_count=None, + vocab_subsampling=None, + corpus_size=None, + seed=None) + self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"], + table_output.eval()) + + # vocab_freq_table and vocab_min_count specified - output should have + # filtered out tokens whose frequencies are below the threshold + # (b"and": 0, b"life": 1). + threshold_output = skip_gram_ops._filter_input( + input_tensor=input_tensor, + vocab_freq_table=vocab_freq_table, + vocab_min_count=2, + vocab_subsampling=None, + corpus_size=None, + seed=None) + self.assertAllEqual([b"the", b"to", b"universe"], threshold_output.eval()) + + def test_filter_input_subsample_vocab(self): + """Tests input filtering based on vocab subsampling.""" + # The outputs are non-deterministic, so set random seed to help ensure that + # the outputs remain constant for testing. + random_seed.set_random_seed(42) + + input_tensor = constant_op.constant([ + # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57. + b"the", + b"answer", # Not in vocab. (Always discarded) + b"to", # keep_prob = 0.75. + b"life", # keep_prob > 1. (Always kept) + b"and", # keep_prob = 0.48. 
+ b"universe" # Below vocab threshold of 3. (Always discarded) + ]) + keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"]) + values = constant_op.constant([40, 8, 30, 20, 2], dtypes.int64) + vocab_freq_table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), -1) + + with self.test_session(): + vocab_freq_table.init.run() + output = skip_gram_ops._filter_input( + input_tensor=input_tensor, + vocab_freq_table=vocab_freq_table, + vocab_min_count=3, + vocab_subsampling=0.05, + corpus_size=math_ops.reduce_sum(values), + seed=9) + self.assertAllEqual([b"the", b"to", b"life", b"and"], output.eval()) + + def _make_text_vocab_freq_file(self): + filepath = os.path.join(test.get_temp_dir(), "vocab_freq.txt") + with open(filepath, "w") as f: + writer = csv.writer(f) + writer.writerows([ + ["and", 40], + ["life", 8], + ["the", 30], + ["to", 20], + ["universe", 2], + ]) + return filepath + + def _make_text_vocab_float_file(self): + filepath = os.path.join(test.get_temp_dir(), "vocab_freq_float.txt") + with open(filepath, "w") as f: + writer = csv.writer(f) + writer.writerows([ + ["and", 0.4], + ["life", 0.08], + ["the", 0.3], + ["to", 0.2], + ["universe", 0.02], + ]) + return filepath + + def test_skip_gram_sample_with_text_vocab_filter_vocab(self): + """Tests skip-gram sampling with text vocab and freq threshold filtering.""" + input_tensor = constant_op.constant([ + b"the", + b"answer", # Will be filtered before candidate generation. + b"to", + b"life", + b"and", + b"universe" # Will be filtered before candidate generation. + ]) + + # b"answer" is not in vocab file, and b"universe"'s frequency is below + # threshold of 3. + vocab_freq_file = self._make_text_vocab_freq_file() + + tokens, labels = text.skip_gram_sample_with_text_vocab( + input_tensor=input_tensor, + vocab_freq_file=vocab_freq_file, + vocab_token_index=0, + vocab_freq_index=1, + vocab_min_count=3, + min_skips=1, + max_skips=1) + + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"to"), + (b"to", b"the"), + (b"to", b"life"), + (b"life", b"to"), + (b"life", b"and"), + (b"and", b"life"), + ]) + with self.test_session(): + lookup_ops.tables_initializer().run() + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count, + vocab_freq_dtype): + # The outputs are non-deterministic, so set random seed to help ensure that + # the outputs remain constant for testing. + random_seed.set_random_seed(42) + + input_tensor = constant_op.constant([ + # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57. + b"the", + b"answer", # Not in vocab. (Always discarded) + b"to", # keep_prob = 0.75. + b"life", # keep_prob > 1. (Always kept) + b"and", # keep_prob = 0.48. + b"universe" # Below vocab threshold of 3. 
(Always discarded) + ]) + # keep_prob calculated from vocab file with relative frequencies of: + # and: 40 + # life: 8 + # the: 30 + # to: 20 + # universe: 2 + + tokens, labels = text.skip_gram_sample_with_text_vocab( + input_tensor=input_tensor, + vocab_freq_file=vocab_freq_file, + vocab_token_index=0, + vocab_freq_index=1, + vocab_freq_dtype=vocab_freq_dtype, + vocab_min_count=vocab_min_count, + vocab_subsampling=0.05, + min_skips=1, + max_skips=1, + seed=123) + + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"to"), + (b"to", b"the"), + (b"to", b"life"), + (b"life", b"to"), + ]) + with self.test_session() as sess: + lookup_ops.tables_initializer().run() + tokens_eval, labels_eval = sess.run([tokens, labels]) + self.assertAllEqual(expected_tokens, tokens_eval) + self.assertAllEqual(expected_labels, labels_eval) + + def test_skip_gram_sample_with_text_vocab_subsample_vocab(self): + """Tests skip-gram sampling with text vocab and vocab subsampling.""" + # Vocab file frequencies + # and: 40 + # life: 8 + # the: 30 + # to: 20 + # universe: 2 + self._text_vocab_subsample_vocab_helper( + vocab_freq_file=self._make_text_vocab_freq_file(), + vocab_min_count=3, + vocab_freq_dtype=dtypes.int64) + + def test_skip_gram_sample_with_text_vocab_subsample_vocab_float(self): + """Tests skip-gram sampling with text vocab and subsampling with floats.""" + # Vocab file frequencies + # and: 0.4 + # life: 0.08 + # the: 0.3 + # to: 0.2 + # universe: 0.02 + self._text_vocab_subsample_vocab_helper( + vocab_freq_file=self._make_text_vocab_float_file(), + vocab_min_count=0.03, + vocab_freq_dtype=dtypes.float32) + + def test_skip_gram_sample_with_text_vocab_errors(self): + """Tests various errors raised by skip_gram_sample_with_text_vocab().""" + dummy_input = constant_op.constant([""]) + vocab_freq_file = self._make_text_vocab_freq_file() + + invalid_indices = ( + # vocab_token_index can't be negative. + (-1, 0), + # vocab_freq_index can't be negative. + (0, -1), + # vocab_token_index can't be equal to vocab_freq_index. + (0, 0), + (1, 1), + # vocab_freq_file only has two columns. + (0, 2), + (2, 0)) + + for vocab_token_index, vocab_freq_index in invalid_indices: + with self.assertRaises(ValueError): + text.skip_gram_sample_with_text_vocab( + input_tensor=dummy_input, + vocab_freq_file=vocab_freq_file, + vocab_token_index=vocab_token_index, + vocab_freq_index=vocab_freq_index) + + +if __name__ == "__main__": + test.main()