Open-source skip-gram ops
PiperOrigin-RevId: 157655970
Commit 458f94c128 (parent faac0331c2)
@@ -281,6 +281,7 @@ filegroup(
         "//tensorflow/contrib/tensor_forest/hybrid:all_files",
         "//tensorflow/contrib/tensorboard:all_files",
         "//tensorflow/contrib/testing:all_files",
+        "//tensorflow/contrib/text:all_files",
         "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
         "//tensorflow/contrib/training:all_files",
         "//tensorflow/contrib/util:all_files",
@@ -66,6 +66,7 @@ py_library(
         "//tensorflow/contrib/tensor_forest:init_py",
         "//tensorflow/contrib/tensorboard",
         "//tensorflow/contrib/testing:testing_py",
+        "//tensorflow/contrib/text:text_py",
         "//tensorflow/contrib/tfprof",
         "//tensorflow/contrib/training:training_py",
         "//tensorflow/contrib/util:util_py",
@@ -84,6 +85,7 @@ cc_library(
         "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
         "//tensorflow/contrib/nccl:nccl_kernels",
         "//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
+        "//tensorflow/contrib/text:all_kernels",
     ],
 )
@@ -98,6 +100,7 @@ cc_library(
         "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
         "//tensorflow/contrib/nccl:nccl_ops_op_lib",
         "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
+        "//tensorflow/contrib/text:all_ops",
     ],
 )
@@ -82,6 +82,8 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
       "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc"
       "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc"
       "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc"
+      "${tensorflow_source_dir}/tensorflow/contrib/text/kernels/skip_gram_kernels.cc"
+      "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc"
   )
   list(APPEND tf_core_kernels_srcs ${tf_contrib_kernels_srcs})
 endif(tensorflow_BUILD_CONTRIB_KERNELS)
@@ -81,6 +81,7 @@ GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contri
 GENERATE_CONTRIB_OP_LIBRARY(seq2seq_beam_search "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc")
 GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc")
 GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_hybrid "${tensor_forest_hybrid_srcs}")
+GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc")
 GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc")

 ########################################################
@@ -507,6 +507,11 @@ add_python_module("tensorflow/contrib/tensor_forest/python/ops")
 add_python_module("tensorflow/contrib/testing")
 add_python_module("tensorflow/contrib/testing/python")
 add_python_module("tensorflow/contrib/testing/python/framework")
+add_python_module("tensorflow/contrib/text")
+add_python_module("tensorflow/contrib/text/kernels")
+add_python_module("tensorflow/contrib/text/ops")
+add_python_module("tensorflow/contrib/text/python")
+add_python_module("tensorflow/contrib/text/python/ops")
 add_python_module("tensorflow/contrib/tfprof" DONTCOPY) # SWIG wrapper not implemented.
 #add_python_module("tensorflow/contrib/tfprof/python")
 #add_python_module("tensorflow/contrib/tfprof/python/tools")
@@ -644,6 +649,8 @@ GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_ops"
   DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_tensor_forest_ops.py)
 GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_hybrid_ops"
   DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/hybrid/ops/gen_training_ops.py)
+GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops"
+  DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py)
 GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops"
   DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py)
 GENERATE_PYTHON_OP_LIB("stateless_random_ops"
tensorflow/contrib/text/BUILD (new file, 119 lines)
@@ -0,0 +1,119 @@
# Description:
# contains parts of TensorFlow that are experimental or unstable and which
# are not supported.

package(default_visibility = [
    "//learning/brain:__subpackages__",
    "//tensorflow:__subpackages__",
])

licenses(["notice"])  # Apache 2.0

load(
    "//tensorflow:tensorflow.bzl",
    "tf_custom_op_library",
    "tf_custom_op_py_library",
    "tf_gen_op_libs",
    "tf_gen_op_wrapper_py",
    "tf_kernel_library",
)

tf_custom_op_py_library(
    name = "text_py",
    srcs = [
        "__init__.py",
        "python/ops/__init__.py",
        "python/ops/skip_gram_ops.py",
    ],
    dso = [
        ":python/ops/_skip_gram_ops.so",
    ],
    kernels = [
        ":all_kernels",
        ":all_ops",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":gen_skip_gram_ops",
        "//tensorflow/contrib/util:util_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:training",
    ],
)

tf_kernel_library(
    name = "skip_gram_kernels",
    srcs = ["kernels/skip_gram_kernels.cc"],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//third_party/eigen3",
    ],
    alwayslink = 1,
)

cc_library(
    name = "all_kernels",
    deps = [":skip_gram_kernels"],
)

tf_custom_op_library(
    name = "python/ops/_skip_gram_ops.so",
    srcs = [
        "kernels/skip_gram_kernels.cc",
        "ops/skip_gram_ops.cc",
    ],
)

tf_gen_op_libs(
    op_lib_names = ["skip_gram_ops"],
)

cc_library(
    name = "all_ops",
    deps = [":skip_gram_ops_op_lib"],
)

tf_gen_op_wrapper_py(
    name = "gen_skip_gram_ops",
    out = "python/ops/gen_skip_gram_ops.py",
    deps = [":skip_gram_ops_op_lib"],
)

py_test(
    name = "skip_gram_ops_test",
    size = "medium",
    srcs = ["python/ops/skip_gram_ops_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":text_py",
        "//tensorflow/contrib/lookup:lookup_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:lookup_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:random_seed",
        "//tensorflow/python:training",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
)
tensorflow/contrib/text/__init__.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Text-processing ops.

@@skip_gram_sample
@@skip_gram_sample_with_text_vocab
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.text.python.ops import *
# pylint: enable=unused-import,wildcard-import

from tensorflow.python.util.all_util import remove_undocumented

remove_undocumented(__name__)
tensorflow/contrib/text/kernels/skip_gram_kernels.cc (new file, 139 lines)
@@ -0,0 +1,139 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/util/guarded_philox_random.h"

namespace tensorflow {

template <typename T>
class SkipGramGenerateCandidatesOp : public OpKernel {
 public:
  explicit SkipGramGenerateCandidatesOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, generator_.Init(context));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor* input_tensor;
    OP_REQUIRES_OK(context, context->input("input_tensor", &input_tensor));
    const auto input = input_tensor->flat<T>();

    const Tensor* min_skips_tensor;
    OP_REQUIRES_OK(context, context->input("min_skips", &min_skips_tensor));
    const int min_skips = *(min_skips_tensor->scalar<int>().data());
    const Tensor* max_skips_tensor;
    OP_REQUIRES_OK(context, context->input("max_skips", &max_skips_tensor));
    const int max_skips = *(max_skips_tensor->scalar<int>().data());

    OP_REQUIRES(
        context, min_skips >= 0 && max_skips >= 0,
        errors::InvalidArgument("Both min_skips and max_skips must be >= 0."));
    OP_REQUIRES(context, min_skips <= max_skips,
                errors::InvalidArgument("min_skips must be <= max_skips."));

    const Tensor* start_tensor;
    OP_REQUIRES_OK(context, context->input("start", &start_tensor));
    const int start = *(start_tensor->scalar<int>().data());
    const Tensor* limit_tensor;
    OP_REQUIRES_OK(context, context->input("limit", &limit_tensor));
    const int limit = *(limit_tensor->scalar<int>().data());
    const int end =
        limit < 0 ? input.size()
                  : std::min(start + limit, static_cast<int>(input.size()));

    const Tensor* emit_self_tensor;
    OP_REQUIRES_OK(context,
                   context->input("emit_self_as_target", &emit_self_tensor));
    const bool emit_self_as_target = *(emit_self_tensor->scalar<bool>().data());

    std::vector<T> tokens;
    std::vector<T> labels;

    // Reserve the number of random numbers we will use - we use one for each
    // token between start and end.
    random::PhiloxRandom local_gen =
        generator_.ReserveSamples32(end - start + 1);
    random::SimplePhilox rng(&local_gen);

    // For each token in the sentence, pick a random skip, then generates
    // (token, label) pairs for all labels whose distances from the token are
    // within the range [-skip, skip].
    for (int i = start; i < end; ++i) {
      const int skips = min_skips + rng.Uniform(max_skips - min_skips + 1);
      for (int j = -skips; j <= skips; ++j) {
        if ((i + j < start) || (i + j >= end) ||
            (j == 0 && !emit_self_as_target)) {
          continue;
        }
        tokens.push_back(input(i));
        labels.push_back(input(i + j));
      }
    }

    Tensor* tokens_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       "tokens", TensorShape({static_cast<int>(tokens.size())}),
                       &tokens_output));
    Tensor* labels_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       "labels", TensorShape({static_cast<int>(labels.size())}),
                       &labels_output));
    OP_REQUIRES(
        context, tokens_output->IsSameSize(*labels_output),
        errors::Internal(strings::StrCat(
            "Mismatch between tokens_output shape of ",
            tokens_output->shape().DebugString(),
            " and labels_output shape of ",
            labels_output->shape().DebugString(),
            ". This should never happen - contact ami-team@ if it does.")));

    // Copies results to output tensors.
    for (int i = 0; i < tokens.size(); ++i) {
      tokens_output->vec<T>()(i) = tokens[i];
      labels_output->vec<T>()(i) = labels[i];
    }
  }

 private:
  GuardedPhiloxRandom generator_;
};

#define REGISTER_KERNEL(type)                                 \
  REGISTER_KERNEL_BUILDER(Name("SkipGramGenerateCandidates")  \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("T"),     \
                          SkipGramGenerateCandidatesOp<type>)

REGISTER_KERNEL(string);
REGISTER_KERNEL(int64);
REGISTER_KERNEL(int32);
REGISTER_KERNEL(int16);
// TODO(weiho): Add other types if the need arises.

#undef REGISTER_KERNEL

}  // namespace tensorflow
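For readers skimming the kernel above: the inner loop simply draws one random window size per position and emits every (token, label) pair inside that window. The following plain-Python sketch mirrors that pairing logic; it is illustrative only, not part of this commit, and the helper name `skip_gram_pairs` is made up here for the example.

```python
# Illustrative sketch of the kernel's candidate generation (not in this commit).
import random


def skip_gram_pairs(tokens, min_skips=1, max_skips=5, emit_self_as_target=False,
                    start=0, limit=-1, seed=None):
    """Returns parallel (token, label) lists, like the kernel's two outputs."""
    rng = random.Random(seed)
    end = len(tokens) if limit < 0 else min(start + limit, len(tokens))
    out_tokens, out_labels = [], []
    for i in range(start, end):
        # One random window size per token, drawn from [min_skips, max_skips].
        skips = rng.randint(min_skips, max_skips)
        for j in range(-skips, skips + 1):
            if i + j < start or i + j >= end or (j == 0 and not emit_self_as_target):
                continue
            out_tokens.append(tokens[i])
            out_labels.append(tokens[i + j])
    return out_tokens, out_labels


print(skip_gram_pairs(["the", "quick", "brown", "fox"], 1, 1, seed=0))
```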
tensorflow/contrib/text/ops/skip_gram_ops.cc (new file, 54 lines)
@@ -0,0 +1,54 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {
REGISTER_OP("SkipGramGenerateCandidates")
    .Input("input_tensor: T")
    .Input("min_skips: int32")
    .Input("max_skips: int32")
    .Input("start: int32")
    .Input("limit: int32")
    .Input("emit_self_as_target: bool")
    .Output("tokens: T")
    .Output("labels: T")
    .Attr("T: type")
    // The seed attributes are needed by GuardedPhiloxRandom
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // input_tensor must be of rank-1.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
      // All other args must be scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));

      // Due to possible randomness in selecting skips, we only know that the
      // outputs will be of rank-1, but not their sizes.
      c->set_output(0, c->Vector(c->UnknownDim()));
      c->set_output(1, c->Vector(c->UnknownDim()));
      return Status::OK();
    })
    .Doc(R"doc(
Generates skip-gram token and label paired Tensors from the input tensor.
See docs for the public-facing skip_gram_sample() Python op for more details.
)doc");
}  // namespace tensorflow
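The op registered above is exposed to Python through the generated wrapper `gen_skip_gram_ops` (see `tf_gen_op_wrapper_py` in the BUILD file). As a rough sketch only, not part of this commit, a direct call to that wrapper would look like the following; in practice the public `skip_gram_sample()` added below is the supported entry point and handles seed plumbing and vocabulary filtering.

```python
# Sketch: raw generated-op call. Assumes the custom op library has been built
# and loaded; normally you would use text.skip_gram_sample() instead.
from tensorflow.contrib.text.python.ops import gen_skip_gram_ops
from tensorflow.python.framework import constant_op

tokens, labels = gen_skip_gram_ops.skip_gram_generate_candidates(
    input_tensor=constant_op.constant([b"the", b"quick", b"brown", b"fox"]),
    min_skips=1,
    max_skips=2,
    start=0,
    limit=-1,
    emit_self_as_target=False,
    seed=0,
    seed2=0)
# Per the shape function, both outputs are rank-1 vectors of unknown length.
```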
tensorflow/contrib/text/python/ops/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various contrib ops related to text-processing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.text.python.ops.skip_gram_ops import skip_gram_sample
from tensorflow.contrib.text.python.ops.skip_gram_ops import skip_gram_sample_with_text_vocab
tensorflow/contrib/text/python/ops/skip_gram_ops.py (new file, 428 lines)
@@ -0,0 +1,428 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Skip-gram sampling ops from https://arxiv.org/abs/1301.3781."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import csv
|
||||
|
||||
from tensorflow.contrib import lookup
|
||||
from tensorflow.contrib.text.python.ops import gen_skip_gram_ops
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.training import input as input_ops
|
||||
|
||||
_checkpoint_ops_so = loader.load_op_library(
|
||||
resource_loader.get_path_to_datafile("_skip_gram_ops.so"))
|
||||
|
||||
ops.NotDifferentiable("SkipGramGenerateCandidates")
|
||||
|
||||
|
||||
def skip_gram_sample(input_tensor,
|
||||
min_skips=1,
|
||||
max_skips=5,
|
||||
start=0,
|
||||
limit=-1,
|
||||
emit_self_as_target=False,
|
||||
vocab_freq_table=None,
|
||||
vocab_min_count=None,
|
||||
vocab_subsampling=None,
|
||||
corpus_size=None,
|
||||
batch_size=None,
|
||||
batch_capacity=None,
|
||||
seed=None,
|
||||
name=None):
|
||||
"""Generates skip-gram token and label paired Tensors from the input tensor.
|
||||
|
||||
Generates skip-gram `("token", "label")` pairs using each element in the
|
||||
rank-1 `input_tensor` as a token. The window size used for each token will be
|
||||
randomly selected from the range specified by `[min_skips, max_skips]`,
|
||||
inclusive. See https://arxiv.org/abs/1301.3781 for more details about
|
||||
skip-gram.
|
||||
|
||||
For example, given `input_tensor = ["the", "quick", "brown", "fox", "jumps"]`,
|
||||
`min_skips = 1`, `max_skips = 2`, `emit_self_as_target = False`, the output
|
||||
`(tokens, labels)` pairs for the token "quick" will be randomly selected from
|
||||
either `(tokens=["quick", "quick"], labels=["the", "brown"])` for 1 skip, or
|
||||
`(tokens=["quick", "quick", "quick"], labels=["the", "brown", "fox"])` for 2
|
||||
skips.
|
||||
|
||||
If `emit_self_as_target = True`, each token will also be emitted as a label
|
||||
for itself. From the previous example, the output will be either
|
||||
`(tokens=["quick", "quick", "quick"], labels=["the", "quick", "brown"])` for 1
|
||||
skip, or `(tokens=["quick", "quick", "quick", "quick"], labels=["the",
|
||||
"quick", "brown", "fox"])` for 2 skips.
|
||||
|
||||
The same process is repeated for each element of `input_tensor` and
|
||||
concatenated together into the two output rank-1 `Tensors` (one for all the
|
||||
tokens, another for all the labels).
|
||||
|
||||
If `vocab_freq_table` is specified, tokens in `input_tensor` that are not
|
||||
present in the vocabulary are discarded. Tokens whose frequency counts are
|
||||
below `vocab_min_count` are also discarded. Tokens whose frequency proportions
|
||||
in the corpus exceed `vocab_subsampling` may be randomly down-sampled. See
|
||||
Eq. 5 in http://arxiv.org/abs/1310.4546 for more details about subsampling.
|
||||
|
||||
Due to the random window sizes used for each token, the lengths of the outputs
|
||||
are non-deterministic, unless `batch_size` is specified to batch the outputs
|
||||
to always return `Tensors` of length `batch_size`.
|
||||
|
||||
Args:
|
||||
input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates.
|
||||
min_skips: `int` or scalar `Tensor` specifying the minimum window size to
|
||||
randomly use for each token. Must be >= 0 and <= `max_skips`. If
|
||||
`min_skips` and `max_skips` are both 0, the only label outputted will be
|
||||
the token itself when `emit_self_as_target = True` - or no output
|
||||
otherwise.
|
||||
max_skips: `int` or scalar `Tensor` specifying the maximum window size to
|
||||
randomly use for each token. Must be >= 0.
|
||||
start: `int` or scalar `Tensor` specifying the position in
|
||||
`input_tensor` from which to start generating skip-gram candidates.
|
||||
limit: `int` or scalar `Tensor` specifying the maximum number of
|
||||
elements in `input_tensor` to use in generating skip-gram candidates. -1
|
||||
means to use the rest of the `Tensor` after `start`.
|
||||
emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
|
||||
each token as a label for itself.
|
||||
vocab_freq_table: (Optional) A lookup table (subclass of
|
||||
`lookup.InitializableLookupTableBase`) that maps tokens to their raw
|
||||
frequency counts. If specified, any token in `input_tensor` that is not
|
||||
found in `vocab_freq_table` will be filtered out before generating
|
||||
skip-gram candidates. While this will typically map to integer raw
|
||||
frequency counts, it could also map to float frequency proportions.
|
||||
`vocab_min_count` and `corpus_size` should be in the same units as this.
|
||||
vocab_min_count: (Optional) `int`, `float`, or scalar `Tensor` specifying
|
||||
minimum frequency threshold (from `vocab_freq_table`) for a token to be
|
||||
kept in `input_tensor`. If this is specified, `vocab_freq_table` must also
|
||||
be specified - and they should both be in the same units.
|
||||
vocab_subsampling: (Optional) `float` specifying frequency proportion
|
||||
threshold for tokens from `input_tensor`. Tokens that occur more
|
||||
frequently (based on the ratio of the token's `vocab_freq_table` value to
|
||||
the `corpus_size`) will be randomly down-sampled. Reasonable starting
|
||||
values may be around 1e-3 or 1e-5. If this is specified, both
|
||||
`vocab_freq_table` and `corpus_size` must also be specified. See Eq. 5
|
||||
in http://arxiv.org/abs/1310.4546 for more details.
|
||||
corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
|
||||
total number of tokens in the corpus (e.g., sum of all the frequency
|
||||
counts of `vocab_freq_table`). Used with `vocab_subsampling` for
|
||||
down-sampling frequently occurring tokens. If this is specified,
|
||||
`vocab_freq_table` and `vocab_subsampling` must also be specified.
|
||||
batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
|
||||
batch_capacity: (Optional) `int` specifying batch capacity for the queue
|
||||
used for batching returned `Tensors`. Only has an effect if
|
||||
`batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
|
||||
seed: (Optional) `int` used to create a random seed for window size and
|
||||
subsampling. See `set_random_seed` docs for behavior.
|
||||
name: (Optional) A `string` name or a name scope for the operations.
|
||||
|
||||
Returns:
|
||||
A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
|
||||
rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
|
||||
length `batch_size`; if `batch_size` is not specified, they will be of
|
||||
random length, though they will be in sync with each other as long as they
|
||||
are evaluated together.
|
||||
|
||||
Raises:
|
||||
ValueError: If `vocab_freq_table` is not provided, but `vocab_min_count`,
|
||||
`vocab_subsampling`, or `corpus_size` is specified. If `vocab_subsampling`
|
||||
and `corpus_size` are not both present or both absent.
|
||||
"""
|
||||
|
||||
if vocab_freq_table is None and (vocab_min_count is not None or
|
||||
vocab_subsampling is not None or
|
||||
corpus_size is not None):
|
||||
raise ValueError(
|
||||
"vocab_freq_table is not provided, but vocab_min_count={}, "
|
||||
"vocab_subsampling={}, or corpus_size={} is not None. These settings "
|
||||
"are useless without a vocab_freq_table.".format(
|
||||
vocab_min_count, vocab_subsampling, corpus_size))
|
||||
|
||||
if (vocab_subsampling is None) != (corpus_size is None):
|
||||
raise ValueError(
|
||||
"vocab_subsampling is {} while corpus_size is {} - both must be "
|
||||
"provided in order for subsampling to work.".format(
|
||||
vocab_subsampling, corpus_size))
|
||||
|
||||
with ops.name_scope(
|
||||
name,
|
||||
"skip_gram_sample",
|
||||
values=[input_tensor, min_skips, max_skips, start, limit]):
|
||||
|
||||
input_tensor = _filter_input(
|
||||
input_tensor=input_tensor,
|
||||
vocab_freq_table=vocab_freq_table,
|
||||
vocab_min_count=vocab_min_count,
|
||||
vocab_subsampling=vocab_subsampling,
|
||||
corpus_size=corpus_size,
|
||||
seed=seed)
|
||||
|
||||
seed1, seed2 = random_seed.get_seed(seed)
|
||||
tokens, labels = gen_skip_gram_ops.skip_gram_generate_candidates(
|
||||
input_tensor=input_tensor,
|
||||
min_skips=min_skips,
|
||||
max_skips=max_skips,
|
||||
start=start,
|
||||
limit=limit,
|
||||
emit_self_as_target=emit_self_as_target,
|
||||
# Note that seed here should be seed1! This is due to
|
||||
# GuardedPhiloxRandom's hard-coded attributes of "seed" and "seed2".
|
||||
seed=seed1,
|
||||
seed2=seed2)
|
||||
|
||||
# TODO(weiho): If the need arises, add support for sparse input_tensor that
|
||||
# figures out sentence boundaries, then calls
|
||||
# skip_gram_generate_candidates() on each sentence.
|
||||
|
||||
# Batches the (tokens, labels) outputs so that they will be of deterministic
|
||||
# batch_size, to facilitate feeding them into the rest of the network.
|
||||
if batch_size is not None and batch_size > 0:
|
||||
batch_capacity = (batch_capacity
|
||||
if (batch_capacity is not None and batch_capacity > 0)
|
||||
else 100 * batch_size)
|
||||
return input_ops.batch(
|
||||
[tokens, labels],
|
||||
batch_size,
|
||||
capacity=batch_capacity,
|
||||
enqueue_many=True)
|
||||
|
||||
return tokens, labels
|
||||
|
||||
|
||||
def skip_gram_sample_with_text_vocab(input_tensor,
|
||||
vocab_freq_file,
|
||||
vocab_token_index=0,
|
||||
vocab_token_dtype=dtypes.string,
|
||||
vocab_freq_index=1,
|
||||
vocab_freq_dtype=dtypes.float64,
|
||||
vocab_delimiter=",",
|
||||
vocab_min_count=0,
|
||||
vocab_subsampling=None,
|
||||
min_skips=1,
|
||||
max_skips=5,
|
||||
start=0,
|
||||
limit=-1,
|
||||
emit_self_as_target=False,
|
||||
batch_size=None,
|
||||
batch_capacity=None,
|
||||
seed=None,
|
||||
name=None):
|
||||
"""Skip-gram sampling with a text vocabulary file.
|
||||
|
||||
Wrapper around `skip_gram_sample()` for use with a text vocabulary file. The
|
||||
vocabulary file is expected to be a plain-text file, with lines of
|
||||
`vocab_delimiter`-separated columns. The `vocab_token_index` column should
|
||||
contain the vocabulary term, while the `vocab_freq_index` column should
|
||||
contain the number of times that term occurs in the corpus. For example, with
|
||||
a text vocabulary file of:
|
||||
|
||||
```
|
||||
bonjour,fr,42
|
||||
hello,en,777
|
||||
hola,es,99
|
||||
```
|
||||
|
||||
You should set `vocab_delimiter=","`, `vocab_token_index=0`, and
|
||||
`vocab_freq_index=2`.
|
||||
|
||||
See `skip_gram_sample()` documentation for more details about the skip-gram
|
||||
sampling process.
|
||||
|
||||
Args:
|
||||
input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates.
|
||||
vocab_freq_file: `string` specifying full file path to the text vocab file.
|
||||
vocab_token_index: `int` specifying which column in the text vocab file
|
||||
contains the tokens.
|
||||
vocab_token_dtype: `DType` specifying the format of the tokens in the text
|
||||
vocab file.
|
||||
vocab_freq_index: `int` specifying which column in the text vocab file
|
||||
contains the frequency counts of the tokens.
|
||||
vocab_freq_dtype: `DType` specifying the format of the frequency counts in
|
||||
the text vocab file.
|
||||
vocab_delimiter: `string` specifying the delimiter used in the text vocab
|
||||
file.
|
||||
vocab_min_count: `int`, `float`, or scalar `Tensor` specifying
|
||||
minimum frequency threshold (from `vocab_freq_file`) for a token to be
|
||||
kept in `input_tensor`. This should correspond with `vocab_freq_dtype`.
|
||||
vocab_subsampling: (Optional) `float` specifying frequency proportion
|
||||
threshold for tokens from `input_tensor`. Tokens that occur more
|
||||
frequently will be randomly down-sampled. Reasonable starting values may
|
||||
be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546 for
|
||||
more details.
|
||||
min_skips: `int` or scalar `Tensor` specifying the minimum window size to
|
||||
randomly use for each token. Must be >= 0 and <= `max_skips`. If
|
||||
`min_skips` and `max_skips` are both 0, the only label outputted will be
|
||||
the token itself.
|
||||
max_skips: `int` or scalar `Tensor` specifying the maximum window size to
|
||||
randomly use for each token. Must be >= 0.
|
||||
start: `int` or scalar `Tensor` specifying the position in `input_tensor`
|
||||
from which to start generating skip-gram candidates.
|
||||
limit: `int` or scalar `Tensor` specifying the maximum number of elements in
|
||||
`input_tensor` to use in generating skip-gram candidates. -1 means to use
|
||||
the rest of the `Tensor` after `start`.
|
||||
emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
|
||||
each token as a label for itself.
|
||||
batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
|
||||
batch_capacity: (Optional) `int` specifying batch capacity for the queue
|
||||
used for batching returned `Tensors`. Only has an effect if
|
||||
`batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
|
||||
seed: (Optional) `int` used to create a random seed for window size and
|
||||
subsampling. See
|
||||
[`set_random_seed`](../../g3doc/python/constant_op.md#set_random_seed)
|
||||
for behavior.
|
||||
name: (Optional) A `string` name or a name scope for the operations.
|
||||
|
||||
Returns:
|
||||
A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
|
||||
rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
|
||||
length `batch_size`; if `batch_size` is not specified, they will be of
|
||||
random length, though they will be in sync with each other as long as they
|
||||
are evaluated together.
|
||||
|
||||
Raises:
|
||||
ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0 or
|
||||
exceeds the number of columns in `vocab_freq_file`. If `vocab_token_index`
|
||||
and `vocab_freq_index` are both set to the same column. If any token in
|
||||
`vocab_freq_file` has a negative frequency.
|
||||
"""
|
||||
|
||||
if vocab_token_index < 0 or vocab_freq_index < 0:
|
||||
raise ValueError(
|
||||
"vocab_token_index={} and vocab_freq_index={} must both be >= 0.".
|
||||
format(vocab_token_index, vocab_freq_index))
|
||||
if vocab_token_index == vocab_freq_index:
|
||||
raise ValueError(
|
||||
"vocab_token_index and vocab_freq_index should be different, but are "
|
||||
"both {}.".format(vocab_token_index))
|
||||
|
||||
# Iterates through the vocab file and calculates the number of vocab terms as
|
||||
# well as the total corpus size (by summing the frequency counts of all the
|
||||
# vocab terms).
|
||||
corpus_size = 0.0
|
||||
vocab_size = 0
|
||||
with gfile.GFile(vocab_freq_file, mode="r") as f:
|
||||
reader = csv.reader(f, delimiter=vocab_delimiter)
|
||||
for row in reader:
|
||||
if vocab_token_index >= len(row) or vocab_freq_index >= len(row):
|
||||
raise ValueError(
|
||||
"Row in vocab file only has {} columns, so vocab_token_index={} or "
|
||||
"vocab_freq_index={} is out of bounds. Row content: {}".format(
|
||||
len(row), vocab_token_index, vocab_freq_index, row))
|
||||
vocab_size += 1
|
||||
freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index])
|
||||
if freq < 0:
|
||||
raise ValueError(
|
||||
"Row in vocab file has negative frequency of {}. Row content: {}".
|
||||
format(freq, row))
|
||||
# Note: tokens whose frequencies are below vocab_min_count will still
|
||||
# contribute to the total corpus size used for vocab subsampling.
|
||||
corpus_size += freq
|
||||
|
||||
vocab_freq_table = lookup.HashTable(
|
||||
lookup.TextFileInitializer(
|
||||
filename=vocab_freq_file,
|
||||
key_dtype=vocab_token_dtype,
|
||||
key_index=vocab_token_index,
|
||||
value_dtype=vocab_freq_dtype,
|
||||
value_index=vocab_freq_index,
|
||||
vocab_size=vocab_size,
|
||||
delimiter=vocab_delimiter),
|
||||
# For vocab terms not in vocab file, use a default value of -1.
|
||||
default_value=-1)
|
||||
|
||||
return skip_gram_sample(
|
||||
input_tensor,
|
||||
min_skips=min_skips,
|
||||
max_skips=max_skips,
|
||||
start=start,
|
||||
limit=limit,
|
||||
emit_self_as_target=emit_self_as_target,
|
||||
vocab_freq_table=vocab_freq_table,
|
||||
vocab_min_count=vocab_min_count,
|
||||
vocab_subsampling=vocab_subsampling,
|
||||
# corpus_size is not used unless vocab_subsampling is specified.
|
||||
corpus_size=None if vocab_subsampling is None else corpus_size,
|
||||
batch_size=batch_size,
|
||||
batch_capacity=batch_capacity,
|
||||
seed=seed,
|
||||
name=name)
|
||||
|
||||
|
||||
def _filter_input(input_tensor, vocab_freq_table, vocab_min_count,
|
||||
vocab_subsampling, corpus_size, seed):
|
||||
"""Filters input tensor based on vocab freq, threshold, and subsampling."""
|
||||
if vocab_freq_table is None:
|
||||
return input_tensor
|
||||
|
||||
if not isinstance(vocab_freq_table, lookup.InitializableLookupTableBase):
|
||||
raise ValueError(
|
||||
"vocab_freq_table must be a subclass of "
|
||||
"InitializableLookupTableBase (such as HashTable) instead of type "
|
||||
"{}.".format(type(vocab_freq_table)))
|
||||
|
||||
with ops.name_scope(
|
||||
"filter_vocab", values=[vocab_freq_table, input_tensor, vocab_min_count]):
|
||||
freq = vocab_freq_table.lookup(input_tensor)
|
||||
# Filters out elements in input_tensor that are not found in
|
||||
# vocab_freq_table (table returns a default value of -1 specified above when
|
||||
# an element is not found).
|
||||
mask = math_ops.not_equal(freq, vocab_freq_table.default_value)
|
||||
|
||||
# Filters out elements whose vocab frequencies are less than the threshold.
|
||||
if vocab_min_count is not None:
|
||||
cast_threshold = math_ops.cast(vocab_min_count, freq.dtype)
|
||||
mask = math_ops.logical_and(mask,
|
||||
math_ops.greater_equal(freq, cast_threshold))
|
||||
|
||||
input_tensor = array_ops.boolean_mask(input_tensor, mask)
|
||||
freq = array_ops.boolean_mask(freq, mask)
|
||||
|
||||
if not vocab_subsampling:
|
||||
return input_tensor
|
||||
|
||||
if vocab_subsampling < 0 or vocab_subsampling > 1:
|
||||
raise ValueError(
|
||||
"Invalid vocab_subsampling={} - it should be within range [0, 1].".
|
||||
format(vocab_subsampling))
|
||||
|
||||
# Subsamples the input tokens based on vocabulary frequency and
|
||||
# vocab_subsampling threshold (i.e., randomly discard commonly appearing
|
||||
# tokens).
|
||||
with ops.name_scope(
|
||||
"subsample_vocab", values=[input_tensor, freq, vocab_subsampling]):
|
||||
corpus_size = math_ops.cast(corpus_size, dtypes.float64)
|
||||
freq = math_ops.cast(freq, dtypes.float64)
|
||||
vocab_subsampling = math_ops.cast(vocab_subsampling, dtypes.float64)
|
||||
|
||||
# From tensorflow_models/tutorials/embedding/word2vec_kernels.cc, which is
|
||||
# supposed to correlate with Eq. 5 in http://arxiv.org/abs/1310.4546.
|
||||
keep_prob = ((math_ops.sqrt(freq /
|
||||
(vocab_subsampling * corpus_size)) + 1.0) *
|
||||
(vocab_subsampling * corpus_size / freq))
|
||||
random_prob = random_ops.random_uniform(
|
||||
array_ops.shape(freq),
|
||||
minval=0,
|
||||
maxval=1,
|
||||
dtype=dtypes.float64,
|
||||
seed=seed)
|
||||
|
||||
mask = math_ops.less_equal(random_prob, keep_prob)
|
||||
return array_ops.boolean_mask(input_tensor, mask)
|
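End to end, the new module is used through the `tensorflow.contrib.text` package added in this commit. A minimal usage sketch follows; it is not part of the commit, and the output pairs shown in the comment are only one possible random result.

```python
# Minimal usage sketch for the new public API (TensorFlow 1.x graph mode).
import tensorflow as tf
from tensorflow.contrib import text

sentence = tf.constant([b"the", b"quick", b"brown", b"fox", b"jumps"])
# One random window size in [1, 2] is drawn per token; every (token, label)
# pair inside that window is emitted.
tokens, labels = text.skip_gram_sample(
    sentence, min_skips=1, max_skips=2, emit_self_as_target=False, seed=42)

with tf.Session() as sess:
    t, l = sess.run([tokens, labels])
    print(list(zip(t, l)))  # e.g. [(b'the', b'quick'), (b'quick', b'the'), ...]
```

The `skip_gram_sample_with_text_vocab()` variant layers vocabulary filtering and frequency-based subsampling on top of this, driven by a delimiter-separated frequency file as described in its docstring.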
tensorflow/contrib/text/python/ops/skip_gram_ops_test.py (new file, 571 lines)
@@ -0,0 +1,571 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Skip-gram sampling ops tests."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import csv
|
||||
import os
|
||||
|
||||
from tensorflow.contrib import lookup
|
||||
from tensorflow.contrib import text
|
||||
from tensorflow.contrib.text.python.ops import skip_gram_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import coordinator
|
||||
from tensorflow.python.training import queue_runner_impl
|
||||
|
||||
|
||||
class SkipGramOpsTest(test.TestCase):
|
||||
|
||||
def _split_tokens_labels(self, output):
|
||||
tokens = [x[0] for x in output]
|
||||
labels = [x[1] for x in output]
|
||||
return tokens, labels
|
||||
|
||||
def test_skip_gram_sample_skips_2(self):
|
||||
"""Tests skip-gram with min_skips = max_skips = 2."""
|
||||
input_tensor = constant_op.constant(
|
||||
[b"the", b"quick", b"brown", b"fox", b"jumps"])
|
||||
tokens, labels = text.skip_gram_sample(
|
||||
input_tensor, min_skips=2, max_skips=2)
|
||||
expected_tokens, expected_labels = self._split_tokens_labels([
|
||||
(b"the", b"quick"),
|
||||
(b"the", b"brown"),
|
||||
(b"quick", b"the"),
|
||||
(b"quick", b"brown"),
|
||||
(b"quick", b"fox"),
|
||||
(b"brown", b"the"),
|
||||
(b"brown", b"quick"),
|
||||
(b"brown", b"fox"),
|
||||
(b"brown", b"jumps"),
|
||||
(b"fox", b"quick"),
|
||||
(b"fox", b"brown"),
|
||||
(b"fox", b"jumps"),
|
||||
(b"jumps", b"brown"),
|
||||
(b"jumps", b"fox"),
|
||||
])
|
||||
with self.test_session():
|
||||
self.assertAllEqual(expected_tokens, tokens.eval())
|
||||
self.assertAllEqual(expected_labels, labels.eval())
|
||||
|
||||
def test_skip_gram_sample_emit_self(self):
|
||||
"""Tests skip-gram with emit_self_as_target = True."""
|
||||
input_tensor = constant_op.constant(
|
||||
[b"the", b"quick", b"brown", b"fox", b"jumps"])
|
||||
tokens, labels = text.skip_gram_sample(
|
||||
input_tensor, min_skips=2, max_skips=2, emit_self_as_target=True)
|
||||
expected_tokens, expected_labels = self._split_tokens_labels([
|
||||
(b"the", b"the"),
|
||||
(b"the", b"quick"),
|
||||
(b"the", b"brown"),
|
||||
(b"quick", b"the"),
|
||||
(b"quick", b"quick"),
|
||||
(b"quick", b"brown"),
|
||||
(b"quick", b"fox"),
|
||||
(b"brown", b"the"),
|
||||
(b"brown", b"quick"),
|
||||
(b"brown", b"brown"),
|
||||
(b"brown", b"fox"),
|
||||
(b"brown", b"jumps"),
|
||||
(b"fox", b"quick"),
|
||||
(b"fox", b"brown"),
|
||||
(b"fox", b"fox"),
|
||||
(b"fox", b"jumps"),
|
||||
(b"jumps", b"brown"),
|
||||
(b"jumps", b"fox"),
|
||||
(b"jumps", b"jumps"),
|
||||
])
|
||||
with self.test_session():
|
||||
self.assertAllEqual(expected_tokens, tokens.eval())
|
||||
self.assertAllEqual(expected_labels, labels.eval())
|
||||
|
||||
def test_skip_gram_sample_skips_0(self):
|
||||
"""Tests skip-gram with min_skips = max_skips = 0."""
|
||||
input_tensor = constant_op.constant([b"the", b"quick", b"brown"])
|
||||
|
||||
# If emit_self_as_target is False (default), output will be empty.
|
||||
tokens, labels = text.skip_gram_sample(
|
||||
input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False)
|
||||
with self.test_session():
|
||||
self.assertEqual(0, tokens.eval().size)
|
||||
self.assertEqual(0, labels.eval().size)
|
||||
|
||||
# If emit_self_as_target is True, each token will be its own label.
|
||||
tokens, labels = text.skip_gram_sample(
|
||||
input_tensor, min_skips=0, max_skips=0, emit_self_as_target=True)
|
||||
expected_tokens, expected_labels = self._split_tokens_labels([
|
||||
(b"the", b"the"),
|
||||
(b"quick", b"quick"),
|
||||
(b"brown", b"brown"),
|
||||
])
|
||||
with self.test_session():
|
||||
self.assertAllEqual(expected_tokens, tokens.eval())
|
||||
self.assertAllEqual(expected_labels, labels.eval())
|
||||
|
||||
def test_skip_gram_sample_skips_exceed_length(self):
|
||||
"""Tests skip-gram when min/max_skips exceed length of input."""
|
||||
input_tensor = constant_op.constant([b"the", b"quick", b"brown"])
|
||||
tokens, labels = text.skip_gram_sample(
|
||||
input_tensor, min_skips=100, max_skips=100)
|
||||
expected_tokens, expected_labels = self._split_tokens_labels([
|
||||
(b"the", b"quick"),
|
||||
(b"the", b"brown"),
|
||||
(b"quick", b"the"),
|
||||
(b"quick", b"brown"),
|
||||
(b"brown", b"the"),
|
||||
(b"brown", b"quick"),
|
||||
])
|
||||
with self.test_session():
|
||||
self.assertAllEqual(expected_tokens, tokens.eval())
|
||||
self.assertAllEqual(expected_labels, labels.eval())
|
||||
|
||||
def test_skip_gram_sample_start_limit(self):
|
||||
"""Tests skip-gram over a limited portion of the input."""
|
||||
input_tensor = constant_op.constant(
|
||||
[b"foo", b"the", b"quick", b"brown", b"bar"])
|
||||
tokens, labels = text.skip_gram_sample(
|
||||
input_tensor, min_skips=1, max_skips=1, start=1, limit=3)
|
||||
expected_tokens, expected_labels = self._split_tokens_labels([
|
||||
(b"the", b"quick"),
|
||||
(b"quick", b"the"),
|
||||
(b"quick", b"brown"),
|
||||
(b"brown", b"quick"),
|
||||
])
|
||||
with self.test_session():
|
||||
self.assertAllEqual(expected_tokens, tokens.eval())
|
||||
self.assertAllEqual(expected_labels, labels.eval())
|
||||
|
||||
def test_skip_gram_sample_limit_exceeds(self):
|
||||
"""Tests skip-gram when limit exceeds the length of the input."""
|
||||
input_tensor = constant_op.constant([b"foo", b"the", b"quick", b"brown"])
|
||||
tokens, labels = text.skip_gram_sample(
|
||||
input_tensor, min_skips=1, max_skips=1, start=1, limit=100)
|
||||
expected_tokens, expected_labels = self._split_tokens_labels([
|
||||
(b"the", b"quick"),
|
||||
(b"quick", b"the"),
|
||||
(b"quick", b"brown"),
|
||||
(b"brown", b"quick"),
|
||||
])
|
||||
with self.test_session():
|
||||
self.assertAllEqual(expected_tokens, tokens.eval())
|
||||
self.assertAllEqual(expected_labels, labels.eval())
|
||||
|
||||
def test_skip_gram_sample_random_skips(self):
|
||||
"""Tests skip-gram with min_skips != max_skips, with random output."""
|
||||
# The number of outputs is non-deterministic in this case, so set random
|
||||
# seed to help ensure the outputs remain constant for this test case.
|
||||
random_seed.set_random_seed(42)
|
||||
|
||||
input_tensor = constant_op.constant(
|
||||
[b"the", b"quick", b"brown", b"fox", b"jumps", b"over"])
|
||||
tokens, labels = text.skip_gram_sample(
|
||||
input_tensor, min_skips=1, max_skips=2, seed=9)
|
||||
expected_tokens, expected_labels = self._split_tokens_labels([
|
||||
(b"the", b"quick"),
|
||||
(b"the", b"brown"),
|
||||
(b"quick", b"the"),
|
||||
(b"quick", b"brown"),
|
||||
(b"quick", b"fox"),
|
||||
(b"brown", b"the"),
|
||||
(b"brown", b"quick"),
|
||||
(b"brown", b"fox"),
|
||||
(b"brown", b"jumps"),
|
||||
(b"fox", b"brown"),
|
||||
(b"fox", b"jumps"),
|
||||
(b"jumps", b"fox"),
|
||||
(b"jumps", b"over"),
|
||||
(b"over", b"fox"),
|
||||
(b"over", b"jumps"),
|
||||
])
|
||||
with self.test_session() as sess:
|
||||
tokens_eval, labels_eval = sess.run([tokens, labels])
|
||||
self.assertAllEqual(expected_tokens, tokens_eval)
|
||||
self.assertAllEqual(expected_labels, labels_eval)
|
||||
|
||||
def test_skip_gram_sample_random_skips_default_seed(self):
|
||||
"""Tests outputs are still random when no op-level seed is specified."""
|
||||
# This is needed since tests set a graph-level seed by default. We want to
|
||||
# explicitly avoid setting both graph-level seed and op-level seed, to
|
||||
# simulate behavior under non-test settings when the user doesn't provide a
|
||||
# seed to us. This results in random_seed.get_seed() returning None for both
|
||||
# seeds, forcing the C++ kernel to execute its default seed logic.
|
||||
random_seed.set_random_seed(None)
|
||||
|
||||
# Uses an input tensor with 10 words, with possible skip ranges in [1,
|
||||
# 5]. Thus, the probability that two random samplings would result in the
|
||||
# same outputs is 1/5^10 ~ 1e-7 (aka the probability of this test being
|
||||
# flaky).
|
||||
input_tensor = constant_op.constant([str(x) for x in range(10)])
|
||||
|
||||
# Do not provide an op-level seed here!
|
||||
tokens_1, labels_1 = text.skip_gram_sample(
|
||||
input_tensor, min_skips=1, max_skips=5)
|
||||
tokens_2, labels_2 = text.skip_gram_sample(
|
||||
input_tensor, min_skips=1, max_skips=5)
|
||||
|
||||
with self.test_session() as sess:
|
||||
tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run(
|
||||
[tokens_1, labels_1, tokens_2, labels_2])
|
||||
|
||||
if len(tokens_1_eval) == len(tokens_2_eval):
|
||||
self.assertNotEqual(tokens_1_eval.tolist(), tokens_2_eval.tolist())
|
||||
if len(labels_1_eval) == len(labels_2_eval):
|
||||
self.assertNotEqual(labels_1_eval.tolist(), labels_2_eval.tolist())
|
||||
|
||||
def test_skip_gram_sample_batch(self):
|
||||
"""Tests skip-gram with batching."""
|
||||
input_tensor = constant_op.constant([b"the", b"quick", b"brown", b"fox"])
|
||||
tokens, labels = text.skip_gram_sample(
|
||||
input_tensor, min_skips=1, max_skips=1, batch_size=3)
|
||||
expected_tokens, expected_labels = self._split_tokens_labels([
|
||||
(b"the", b"quick"),
|
||||
(b"quick", b"the"),
|
||||
(b"quick", b"brown"),
|
||||
(b"brown", b"quick"),
|
||||
(b"brown", b"fox"),
|
||||
(b"fox", b"brown"),
|
||||
])
|
||||
with self.test_session() as sess:
|
||||
coord = coordinator.Coordinator()
|
||||
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
|
||||
|
||||
tokens_eval, labels_eval = sess.run([tokens, labels])
|
||||
self.assertAllEqual(expected_tokens[:3], tokens_eval)
|
||||
self.assertAllEqual(expected_labels[:3], labels_eval)
|
||||
tokens_eval, labels_eval = sess.run([tokens, labels])
|
||||
self.assertAllEqual(expected_tokens[3:6], tokens_eval)
|
||||
self.assertAllEqual(expected_labels[3:6], labels_eval)
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
def test_skip_gram_sample_non_string_input(self):
|
||||
"""Tests skip-gram with non-string input."""
|
||||
input_tensor = constant_op.constant([1, 2, 3], dtype=dtypes.int16)
|
||||
tokens, labels = text.skip_gram_sample(
|
||||
input_tensor, min_skips=1, max_skips=1)
|
||||
expected_tokens, expected_labels = self._split_tokens_labels([
|
||||
(1, 2),
|
||||
(2, 1),
|
||||
(2, 3),
|
||||
(3, 2),
|
||||
])
|
||||
with self.test_session():
|
||||
self.assertAllEqual(expected_tokens, tokens.eval())
|
||||
self.assertAllEqual(expected_labels, labels.eval())
|
||||
|
||||
def test_skip_gram_sample_errors(self):
|
||||
"""Tests various errors raised by skip_gram_sample()."""
|
||||
input_tensor = constant_op.constant([b"the", b"quick", b"brown"])
|
||||
|
||||
invalid_skips = (
|
||||
# min_skips and max_skips must be >= 0.
|
||||
(-1, 2),
|
||||
(1, -2),
|
||||
# min_skips must be <= max_skips.
|
||||
(2, 1))
|
||||
for min_skips, max_skips in invalid_skips:
|
||||
tokens, labels = text.skip_gram_sample(
|
||||
input_tensor, min_skips=min_skips, max_skips=max_skips)
|
||||
with self.test_session() as sess, self.assertRaises(
|
||||
errors.InvalidArgumentError):
|
||||
sess.run([tokens, labels])
|
||||
|
||||
# input_tensor must be of rank 1.
|
||||
with self.assertRaises(ValueError):
|
||||
invalid_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]])
|
||||
text.skip_gram_sample(invalid_tensor)
|
||||
|
||||
# vocab_freq_table must be provided if vocab_min_count, vocab_subsampling,
|
||||
# or corpus_size is specified.
|
||||
dummy_input = constant_op.constant([""])
|
||||
with self.assertRaises(ValueError):
|
||||
text.skip_gram_sample(
|
||||
dummy_input, vocab_freq_table=None, vocab_min_count=1)
|
||||
with self.assertRaises(ValueError):
|
||||
text.skip_gram_sample(
|
||||
dummy_input, vocab_freq_table=None, vocab_subsampling=1e-5)
|
||||
with self.assertRaises(ValueError):
|
||||
text.skip_gram_sample(dummy_input, vocab_freq_table=None, corpus_size=100)
|
||||
with self.assertRaises(ValueError):
|
||||
text.skip_gram_sample(
|
||||
dummy_input,
|
||||
vocab_freq_table=None,
|
||||
vocab_subsampling=1e-5,
|
||||
corpus_size=100)
|
||||
|
||||
# vocab_subsampling and corpus_size must both be present or absent.
|
||||
dummy_table = lookup.HashTable(
|
||||
lookup.KeyValueTensorInitializer([b"foo"], [10]), -1)
|
||||
with self.assertRaises(ValueError):
|
||||
text.skip_gram_sample(
|
||||
dummy_input,
|
||||
vocab_freq_table=dummy_table,
|
||||
vocab_subsampling=None,
|
||||
corpus_size=100)
|
||||
with self.assertRaises(ValueError):
|
||||
text.skip_gram_sample(
|
||||
dummy_input,
|
||||
vocab_freq_table=dummy_table,
|
||||
vocab_subsampling=1e-5,
|
||||
corpus_size=None)
|
||||
|
||||
def test_filter_input_filter_vocab(self):
|
||||
"""Tests input filtering based on vocab frequency table and thresholds."""
|
||||
input_tensor = constant_op.constant(
|
||||
[b"the", b"answer", b"to", b"life", b"and", b"universe"])
|
||||
keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"])
|
||||
values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
|
||||
vocab_freq_table = lookup.HashTable(
|
||||
lookup.KeyValueTensorInitializer(keys, values), -1)
|
||||
|
||||
with self.test_session():
|
||||
vocab_freq_table.init.run()
|
||||
|
||||
# No vocab_freq_table specified - output should be the same as input.
|
||||
no_table_output = skip_gram_ops._filter_input(
|
||||
input_tensor=input_tensor,
|
||||
vocab_freq_table=None,
|
||||
vocab_min_count=None,
|
||||
vocab_subsampling=None,
|
||||
corpus_size=None,
|
||||
seed=None)
|
||||
self.assertAllEqual(input_tensor.eval(), no_table_output.eval())
|
||||
|
||||
# vocab_freq_table specified, but no vocab_min_count - output should have
|
||||
# filtered out tokens not in the table (b"answer").
|
||||
table_output = skip_gram_ops._filter_input(
|
||||
input_tensor=input_tensor,
|
||||
vocab_freq_table=vocab_freq_table,
|
||||
vocab_min_count=None,
|
||||
vocab_subsampling=None,
|
||||
corpus_size=None,
|
||||
seed=None)
|
||||
self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"],
|
||||
table_output.eval())
|
||||
|
||||
# vocab_freq_table and vocab_min_count specified - output should have
|
||||
# filtered out tokens whose frequencies are below the threshold
|
||||
# (b"and": 0, b"life": 1).
|
||||
threshold_output = skip_gram_ops._filter_input(
|
||||
input_tensor=input_tensor,
|
||||
vocab_freq_table=vocab_freq_table,
|
||||
vocab_min_count=2,
|
||||
vocab_subsampling=None,
|
||||
corpus_size=None,
|
||||
seed=None)
|
||||
self.assertAllEqual([b"the", b"to", b"universe"], threshold_output.eval())
|
||||
|
||||
def test_filter_input_subsample_vocab(self):
|
||||
"""Tests input filtering based on vocab subsampling."""
|
||||
# The outputs are non-deterministic, so set random seed to help ensure that
|
||||
# the outputs remain constant for testing.
|
||||
random_seed.set_random_seed(42)
|
||||
|
||||
input_tensor = constant_op.constant([
|
||||
# keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
|
||||
b"the",
|
||||
b"answer", # Not in vocab. (Always discarded)
|
||||
b"to", # keep_prob = 0.75.
|
||||
b"life", # keep_prob > 1. (Always kept)
|
||||
b"and", # keep_prob = 0.48.
|
||||
b"universe" # Below vocab threshold of 3. (Always discarded)
|
||||
])
|
||||
keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"])
|
||||
values = constant_op.constant([40, 8, 30, 20, 2], dtypes.int64)
|
||||
vocab_freq_table = lookup.HashTable(
|
||||
lookup.KeyValueTensorInitializer(keys, values), -1)
|
||||
|
||||
with self.test_session():
|
||||
vocab_freq_table.init.run()
|
||||
output = skip_gram_ops._filter_input(
|
||||
input_tensor=input_tensor,
|
||||
vocab_freq_table=vocab_freq_table,
|
||||
vocab_min_count=3,
|
||||
vocab_subsampling=0.05,
|
||||
corpus_size=math_ops.reduce_sum(values),
|
||||
seed=9)
|
||||
self.assertAllEqual([b"the", b"to", b"life", b"and"], output.eval())
|
||||
|
||||
def _make_text_vocab_freq_file(self):
|
||||
filepath = os.path.join(test.get_temp_dir(), "vocab_freq.txt")
|
||||
with open(filepath, "w") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerows([
|
||||
["and", 40],
|
||||
["life", 8],
|
||||
["the", 30],
|
||||
["to", 20],
|
||||
["universe", 2],
|
||||
])
|
||||
return filepath
|
||||
|
||||
def _make_text_vocab_float_file(self):
|
||||
filepath = os.path.join(test.get_temp_dir(), "vocab_freq_float.txt")
|
||||
with open(filepath, "w") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerows([
|
||||
["and", 0.4],
|
||||
["life", 0.08],
|
||||
["the", 0.3],
|
||||
["to", 0.2],
|
||||
["universe", 0.02],
|
||||
])
|
||||
return filepath
|
||||
|
||||
def test_skip_gram_sample_with_text_vocab_filter_vocab(self):
|
||||
"""Tests skip-gram sampling with text vocab and freq threshold filtering."""
|
||||
input_tensor = constant_op.constant([
|
||||
b"the",
|
||||
b"answer", # Will be filtered before candidate generation.
|
||||
b"to",
|
||||
b"life",
|
||||
b"and",
|
||||
b"universe" # Will be filtered before candidate generation.
|
||||
])
|
||||
|
||||
# b"answer" is not in vocab file, and b"universe"'s frequency is below
|
||||
# threshold of 3.
|
||||
vocab_freq_file = self._make_text_vocab_freq_file()
|
||||
|
||||
tokens, labels = text.skip_gram_sample_with_text_vocab(
|
||||
input_tensor=input_tensor,
|
||||
vocab_freq_file=vocab_freq_file,
|
||||
vocab_token_index=0,
|
||||
vocab_freq_index=1,
|
||||
vocab_min_count=3,
|
||||
min_skips=1,
|
||||
max_skips=1)
|
||||
|
||||
expected_tokens, expected_labels = self._split_tokens_labels([
|
||||
(b"the", b"to"),
|
||||
(b"to", b"the"),
|
||||
(b"to", b"life"),
|
||||
(b"life", b"to"),
|
||||
(b"life", b"and"),
|
||||
(b"and", b"life"),
|
||||
])
|
||||
with self.test_session():
|
||||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual(expected_tokens, tokens.eval())
|
||||
self.assertAllEqual(expected_labels, labels.eval())
|
||||
|
||||
def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count,
|
||||
vocab_freq_dtype):
|
||||
# The outputs are non-deterministic, so set random seed to help ensure that
|
||||
# the outputs remain constant for testing.
|
||||
random_seed.set_random_seed(42)
|
||||
|
||||
input_tensor = constant_op.constant([
|
||||
# keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
|
||||
b"the",
|
||||
b"answer", # Not in vocab. (Always discarded)
|
||||
b"to", # keep_prob = 0.75.
|
||||
b"life", # keep_prob > 1. (Always kept)
|
||||
b"and", # keep_prob = 0.48.
|
||||
b"universe" # Below vocab threshold of 3. (Always discarded)
|
||||
])
|
||||
# keep_prob calculated from vocab file with relative frequencies of:
|
||||
# and: 40
|
||||
# life: 8
|
||||
# the: 30
|
||||
# to: 20
|
||||
# universe: 2
|
||||
|
||||
tokens, labels = text.skip_gram_sample_with_text_vocab(
|
||||
input_tensor=input_tensor,
|
||||
vocab_freq_file=vocab_freq_file,
|
||||
vocab_token_index=0,
|
||||
vocab_freq_index=1,
|
||||
vocab_freq_dtype=vocab_freq_dtype,
|
||||
vocab_min_count=vocab_min_count,
|
||||
vocab_subsampling=0.05,
|
||||
min_skips=1,
|
||||
max_skips=1,
|
||||
seed=123)
|
||||
|
||||
expected_tokens, expected_labels = self._split_tokens_labels([
|
||||
(b"the", b"to"),
|
||||
(b"to", b"the"),
|
||||
(b"to", b"life"),
|
||||
(b"life", b"to"),
|
||||
])
|
||||
with self.test_session() as sess:
|
||||
lookup_ops.tables_initializer().run()
|
||||
tokens_eval, labels_eval = sess.run([tokens, labels])
|
||||
self.assertAllEqual(expected_tokens, tokens_eval)
|
||||
self.assertAllEqual(expected_labels, labels_eval)
|
||||
|
||||
def test_skip_gram_sample_with_text_vocab_subsample_vocab(self):
|
||||
"""Tests skip-gram sampling with text vocab and vocab subsampling."""
|
||||
# Vocab file frequencies
|
||||
# and: 40
|
||||
# life: 8
|
||||
# the: 30
|
||||
# to: 20
|
||||
# universe: 2
|
||||
self._text_vocab_subsample_vocab_helper(
|
||||
vocab_freq_file=self._make_text_vocab_freq_file(),
|
||||
vocab_min_count=3,
|
||||
vocab_freq_dtype=dtypes.int64)
|
||||
|
||||
def test_skip_gram_sample_with_text_vocab_subsample_vocab_float(self):
|
||||
"""Tests skip-gram sampling with text vocab and subsampling with floats."""
|
||||
# Vocab file frequencies
|
||||
# and: 0.4
|
||||
# life: 0.08
|
||||
# the: 0.3
|
||||
# to: 0.2
|
||||
# universe: 0.02
|
||||
self._text_vocab_subsample_vocab_helper(
|
||||
vocab_freq_file=self._make_text_vocab_float_file(),
|
||||
vocab_min_count=0.03,
|
||||
vocab_freq_dtype=dtypes.float32)
|
||||
|
||||
def test_skip_gram_sample_with_text_vocab_errors(self):
|
||||
"""Tests various errors raised by skip_gram_sample_with_text_vocab()."""
|
||||
dummy_input = constant_op.constant([""])
|
||||
vocab_freq_file = self._make_text_vocab_freq_file()
|
||||
|
||||
invalid_indices = (
|
||||
# vocab_token_index can't be negative.
|
||||
(-1, 0),
|
||||
# vocab_freq_index can't be negative.
|
||||
(0, -1),
|
||||
# vocab_token_index can't be equal to vocab_freq_index.
|
||||
(0, 0),
|
||||
(1, 1),
|
||||
# vocab_freq_file only has two columns.
|
||||
(0, 2),
|
||||
(2, 0))
|
||||
|
||||
for vocab_token_index, vocab_freq_index in invalid_indices:
|
||||
with self.assertRaises(ValueError):
|
||||
text.skip_gram_sample_with_text_vocab(
|
||||
input_tensor=dummy_input,
|
||||
vocab_freq_file=vocab_freq_file,
|
||||
vocab_token_index=vocab_token_index,
|
||||
vocab_freq_index=vocab_freq_index)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()