Open-source ops to load and remap matrix (2-D) tensors. (Used for loading/remapping embeddings, warm-starting weights/biases, etc.)

PiperOrigin-RevId: 157148893
This commit is contained in:
Wei Ho 2017-05-25 13:54:07 -07:00 committed by TensorFlower Gardener
parent 6993d7f7c5
commit 4659e562c9
23 changed files with 1822 additions and 2 deletions

View File

@ -78,6 +78,8 @@ cc_library(
deps = [ deps = [
"//tensorflow/contrib/batching:batch_ops_kernels", "//tensorflow/contrib/batching:batch_ops_kernels",
"//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels",
"//tensorflow/contrib/framework:generate_vocab_remapping_kernel",
"//tensorflow/contrib/framework:load_and_remap_matrix_kernel",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
"//tensorflow/contrib/nccl:nccl_kernels", "//tensorflow/contrib/nccl:nccl_kernels",

View File

@ -47,6 +47,10 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc"
#"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/decode_audio_op.cc" #"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/decode_audio_op.cc"
#"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc" #"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/generate_vocab_remapping_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/framework/ops/checkpoint_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/framework/ops/generate_vocab_remapping_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"

View File

@ -69,6 +69,8 @@ file(GLOB_RECURSE tensor_forest_hybrid_srcs
GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(framework_checkpoint "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/checkpoint_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(framework_generate_vocab_remapping "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/generate_vocab_remapping_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(input_pipeline "${tensorflow_source_dir}/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(input_pipeline "${tensorflow_source_dir}/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc")

View File

@ -619,6 +619,10 @@ GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_clustering_ops.py) DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_clustering_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops" GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_factorization_ops.py) DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_factorization_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_framework_checkpoint_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/framework/python/ops/gen_checkpoint_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_framework_generate_vocab_remapping_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/framework/python/ops/gen_generate_vocab_remapping_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_framework_variable_ops" GENERATE_PYTHON_OP_LIB("contrib_framework_variable_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/framework/python/ops/gen_variable_ops.py) DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/framework/python/ops/gen_variable_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_input_pipeline_ops" GENERATE_PYTHON_OP_LIB("contrib_input_pipeline_ops"

View File

@ -25,6 +25,7 @@ tf_custom_op_py_library(
"python/framework/__init__.py", "python/framework/__init__.py",
"python/framework/checkpoint_utils.py", "python/framework/checkpoint_utils.py",
"python/framework/experimental.py", "python/framework/experimental.py",
"python/framework/load_and_remap_matrix_ops.py",
"python/framework/tensor_util.py", "python/framework/tensor_util.py",
"python/ops/__init__.py", "python/ops/__init__.py",
"python/ops/arg_scope.py", "python/ops/arg_scope.py",
@ -34,13 +35,20 @@ tf_custom_op_py_library(
], ],
dso = [ dso = [
":python/ops/_variable_ops.so", ":python/ops/_variable_ops.so",
":python/framework/_checkpoint_ops.so",
], ],
kernels = [ kernels = [
":checkpoint_ops_op_lib",
":generate_vocab_remapping_kernel",
":generate_vocab_remapping_ops_op_lib",
":load_and_remap_matrix_kernel",
":variable_kernels", ":variable_kernels",
":variable_ops_op_lib", ":variable_ops_op_lib",
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":gen_checkpoint_ops",
":gen_generate_vocab_remapping_ops",
":gen_variable_ops", ":gen_variable_ops",
"//tensorflow/contrib/util:util_py", "//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
@ -81,6 +89,32 @@ tf_kernel_library(
alwayslink = 1, alwayslink = 1,
) )
tf_kernel_library(
name = "generate_vocab_remapping_kernel",
srcs = ["kernels/generate_vocab_remapping_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:lookup_table_init_op",
"//tensorflow/core/kernels:lookup_table_op",
"//third_party/eigen3",
],
alwayslink = 1,
)
tf_kernel_library(
name = "load_and_remap_matrix_kernel",
srcs = ["kernels/load_and_remap_matrix_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/util/tensor_bundle",
"//third_party/eigen3",
],
alwayslink = 1,
)
tf_custom_op_library( tf_custom_op_library(
name = "python/ops/_variable_ops.so", name = "python/ops/_variable_ops.so",
srcs = [ srcs = [
@ -94,13 +128,33 @@ tf_custom_op_library(
], ],
) )
tf_custom_op_library(
name = "python/framework/_checkpoint_ops.so",
srcs = [
"kernels/generate_vocab_remapping_op.cc",
"kernels/load_and_remap_matrix_op.cc",
"ops/checkpoint_ops.cc",
"ops/generate_vocab_remapping_ops.cc",
],
deps = [
"//tensorflow/core/kernels:lookup_headers_lib",
"//tensorflow/core/util/tensor_bundle:tensor_bundle_headers_lib",
],
)
tf_gen_op_libs( tf_gen_op_libs(
op_lib_names = ["variable_ops"], op_lib_names = [
"checkpoint_ops",
"generate_vocab_remapping_ops",
"variable_ops",
],
) )
cc_library( cc_library(
name = "all_ops", name = "all_ops",
deps = [ deps = [
":checkpoint_ops_op_lib",
":generate_vocab_remapping_ops_op_lib",
":variable_ops_op_lib", ":variable_ops_op_lib",
], ],
) )
@ -113,6 +167,22 @@ tf_gen_op_wrapper_py(
], ],
) )
tf_gen_op_wrapper_py(
name = "gen_generate_vocab_remapping_ops",
out = "python/ops/gen_generate_vocab_remapping_ops.py",
deps = [
":generate_vocab_remapping_ops_op_lib",
],
)
tf_gen_op_wrapper_py(
name = "gen_checkpoint_ops",
out = "python/ops/gen_checkpoint_ops.py",
deps = [
":checkpoint_ops_op_lib",
],
)
py_test( py_test(
name = "arg_scope_test", name = "arg_scope_test",
size = "small", size = "small",
@ -227,6 +297,43 @@ py_test(
], ],
) )
filegroup(
name = "load_and_remap_matrix_testdata",
srcs = [
"testdata/bundle_checkpoint.data-00000-of-00001",
"testdata/bundle_checkpoint.index",
"testdata/bundle_checkpoint_vocab.txt",
"testdata/bundle_checkpoint_vocab_with_oov.txt",
"testdata/keyword.txt",
"testdata/keyword_new.txt",
"testdata/keyword_shifted.txt",
],
)
py_test(
name = "load_and_remap_matrix_ops_test",
size = "medium",
srcs = ["python/framework/load_and_remap_matrix_ops_test.py"],
data = [":load_and_remap_matrix_testdata"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":framework_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:partitioned_variables",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
filegroup( filegroup(
name = "all_files", name = "all_files",
srcs = glob( srcs = glob(

View File

@ -74,6 +74,9 @@ See the @{$python/contrib.framework} guide.
@@list_variables @@list_variables
@@load_variable @@load_variable
@@init_from_checkpoint @@init_from_checkpoint
@@load_and_remap_matrix_initializer
@@load_embedding_initializer
@@load_linear_multiclass_bias_initializer
""" """
from __future__ import absolute_import from __future__ import absolute_import

View File

@ -0,0 +1,173 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/lookup_table_init_op.h"
#include "tensorflow/core/kernels/lookup_table_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace {
// lookup::InitializeTableFromTextFile requires a delimiter even though we use
// the entire line for vocabularies.
constexpr char kUnusedLookupDelim = '\t';
} // namespace
// This Op generates a vocab remapping Tensor from an old and new vocabulary
// file that maps new ID's to old ID's.
class GenerateVocabRemappingOp : public OpKernel {
public:
explicit GenerateVocabRemappingOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context,
context->GetAttr("new_vocab_offset", &new_vocab_offset_));
OP_REQUIRES_OK(context, context->GetAttr("num_new_vocab", &num_new_vocab_));
}
void Compute(OpKernelContext* context) override {
const Tensor* new_vocab_file_tensor;
OP_REQUIRES_OK(context,
context->input("new_vocab_file", &new_vocab_file_tensor));
OP_REQUIRES(context,
TensorShapeUtils::IsScalar(new_vocab_file_tensor->shape()),
errors::InvalidArgument(
"new_vocab_file should be a single string, but got ",
new_vocab_file_tensor->shape().DebugString()));
// Build a new ID->token lookup table.
const string& new_vocab_filename =
new_vocab_file_tensor->scalar<string>()();
OP_REQUIRES(context, !new_vocab_filename.empty(),
errors::InvalidArgument("new vocab filename cannot be empty."));
lookup::HashTable<int64, string>* new_vocab_table =
new lookup::HashTable<int64, string>(context, this);
core::ScopedUnref unref_new(new_vocab_table);
// Note: we pass -1 (unknown) for vocab_size, which is supposed to be the
// total elements in file. This is different from num_new_vocab_, which
// accounts for partitioning.
OP_REQUIRES_OK(context, lookup::InitializeTableFromTextFile(
new_vocab_filename,
-1, // vocab_size
kUnusedLookupDelim,
-1, // key_index, use the line number.
-2, // value_index, use the whole line/token.
context->env(), new_vocab_table));
OP_REQUIRES(context,
new_vocab_offset_ + num_new_vocab_ <= new_vocab_table->size(),
errors::InvalidArgument("lookup table size must be larger than "
"last new vocab entry's line"));
const Tensor* old_vocab_file_tensor;
OP_REQUIRES_OK(context,
context->input("old_vocab_file", &old_vocab_file_tensor));
OP_REQUIRES(context,
TensorShapeUtils::IsScalar(old_vocab_file_tensor->shape()),
errors::InvalidArgument(
"old_vocab_file should be a single string, but got ",
old_vocab_file_tensor->shape().DebugString()));
// Build a token->old ID lookup table.
const string& old_vocab_filename =
old_vocab_file_tensor->scalar<string>()();
OP_REQUIRES(context, !old_vocab_filename.empty(),
errors::InvalidArgument("new vocab filename cannot be empty."));
lookup::HashTable<string, int64>* old_vocab_table =
new lookup::HashTable<string, int64>(context, this);
core::ScopedUnref unref_old(old_vocab_table);
// Note: we pass -1 (unknown) for vocab_size, which is supposed to be the
// total elements in file. This is different from num_new_vocab_, which
// accounts for partitioning.
OP_REQUIRES_OK(context, lookup::InitializeTableFromTextFile(
old_vocab_filename,
-1, // vocab_size
kUnusedLookupDelim,
-2, // key_index, use the whole line/token.
-1, // value_index, use the line number.
context->env(), old_vocab_table));
// Fill out new_ids = [new_vocab_offset, new_vocab_offset + 1, ...,
// new_vocab_offset + num_new_vocab_]
// The double look-up requires a few temporary Tensors.
Tensor new_ids;
OP_REQUIRES_OK(
context, context->allocate_temp(DT_INT64, TensorShape({num_new_vocab_}),
&new_ids));
auto new_ids_vec = new_ids.vec<int64>();
// Note that we should always be able to find tokens for all new ID's, given
// that the lookup table is constructed with the vocabulary file itself
// (see the check on offset and table size post-initialization).
Tensor default_token;
OP_REQUIRES_OK(
context, context->allocate_temp(
DT_STRING, TensorShape({num_new_vocab_}), &default_token));
auto default_token_vec = default_token.vec<string>();
default_token_vec.setConstant("" /* NOT_FOUND_TOKEN */);
Tensor default_id;
OP_REQUIRES_OK(
context, context->allocate_temp(DT_INT64, TensorShape({num_new_vocab_}),
&default_id));
auto default_id_vec = default_id.vec<int64>();
default_id_vec.setConstant(-1 /* NOT_FOUND_ID */);
for (int i = 0; i < num_new_vocab_; ++i) {
new_ids_vec(i) = static_cast<int64>(i + new_vocab_offset_);
}
Tensor tokens;
OP_REQUIRES_OK(context,
context->allocate_temp(
DT_STRING, TensorShape({num_new_vocab_}), &tokens));
Tensor* remapping;
OP_REQUIRES_OK(context,
context->allocate_output(
"remapping", TensorShape({num_new_vocab_}), &remapping));
// In the corner case where num_new_vocab_ is 0 (we are dealing with an
// OOV-only partition), we should not do this lookup.
if (num_new_vocab_ != 0) {
OP_REQUIRES_OK(context, new_vocab_table->Find(context, new_ids, &tokens,
default_token));
OP_REQUIRES_OK(context, old_vocab_table->Find(context, tokens, remapping,
default_id));
}
// Iterate through remapping to calculate num_present.
const auto remapping_vec = remapping->vec<int64>();
int num_present = 0;
for (int i = 0; i < num_new_vocab_; ++i) {
if (remapping_vec(i) != -1 /* NOT_FOUND_ID */) {
++num_present;
}
}
Tensor* num_present_t;
OP_REQUIRES_OK(context,
context->allocate_output("num_present", TensorShape({}),
&num_present_t));
num_present_t->scalar<int>()() = num_present;
}
private:
int new_vocab_offset_;
int num_new_vocab_;
};
REGISTER_KERNEL_BUILDER(Name("GenerateVocabRemapping").Device(DEVICE_CPU),
GenerateVocabRemappingOp);
} // namespace tensorflow

View File

@ -0,0 +1,259 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
namespace tensorflow {
// This op loads a rank-2 Tensor (matrix) from a TensorFlow checkpoint (V2) and
// swaps around the rows/columns according to row_remapping/col_remapping.
// "Missing" cells are initialized with values from initializing_values.
class LoadAndRemapMatrixOp : public OpKernel {
public:
explicit LoadAndRemapMatrixOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("num_rows", &num_rows_));
OP_REQUIRES_OK(context, context->GetAttr("num_cols", &num_cols_));
}
void Compute(OpKernelContext* context) override {
// Checks what we're remapping and inverts the relevant remapping Tensors to
// be maps with key = old ID, value = new ID.
std::vector<std::pair<int64, int64>> old_row_to_new_row_pairs;
std::vector<bool> row_id_present(num_rows_);
const Tensor* row_remapping_t;
OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t));
const auto row_remapping = row_remapping_t->vec<int64>();
OP_REQUIRES(context, row_remapping.size() == num_rows_,
errors::InvalidArgument(strings::StrCat(
"Size of row_remapping is ", row_remapping.size(),
" intead of being equal to num_rows=", num_rows_)));
old_row_to_new_row_pairs.reserve(num_rows_);
for (int i = 0; i < row_remapping.size(); ++i) {
if (row_remapping(i) < 0) continue;
row_id_present[i] = true;
old_row_to_new_row_pairs.push_back(std::make_pair(row_remapping(i), i));
}
// Processes the remapping for columns.
std::unordered_map<int64, int64> old_col_to_new_col_map;
std::vector<bool> col_id_present(num_cols_);
const Tensor* col_remapping_t;
OP_REQUIRES_OK(context, context->input("col_remapping", &col_remapping_t));
const auto col_remapping = col_remapping_t->vec<int64>();
// Note that we always "remap rows", even when the row vocabulary does
// not change, because partitioning requires a mapping from partitioned
// Variables to the full checkpoints we load.
const bool remap_cols = col_remapping.size() > 0;
if (remap_cols) {
OP_REQUIRES(
context, col_remapping.size() == num_cols_,
errors::InvalidArgument(strings::StrCat(
"Provided col_remapping, but its size is ", col_remapping.size(),
" instead of being equal to num_cols=", num_cols_)));
for (int i = 0; i < col_remapping.size(); ++i) {
const int64 old_col = col_remapping(i);
if (old_col < 0) continue;
col_id_present[i] = true;
OP_REQUIRES(
context,
gtl::InsertIfNotPresent(&old_col_to_new_col_map, old_col, i),
errors::Unimplemented(strings::StrCat(
"Old column ID ", old_col, " is mapped to both new column ID ",
old_col_to_new_col_map[old_col], " and ", i,
", which is not currently supported - but could be "
"implemented.")));
}
} else {
col_id_present.clear();
col_id_present.resize(num_cols_, true);
}
// Processes the checkpoint source and the provided Tensor name.
const Tensor* ckpt_path_t;
OP_REQUIRES_OK(context, context->input("ckpt_path", &ckpt_path_t));
const string ckpt_path = *(ckpt_path_t->scalar<string>().data());
const Tensor* old_tensor_name_t;
OP_REQUIRES_OK(context,
context->input("old_tensor_name", &old_tensor_name_t));
const string old_tensor_name =
*(old_tensor_name_t->scalar<string>().data());
LOG(INFO) << "Processing checkpoint : " << ckpt_path;
BundleReader reader(context->env(), ckpt_path);
OP_REQUIRES_OK(context, reader.status());
DataType tensor_type;
TensorShape tensor_shape;
OP_REQUIRES_OK(context, reader.LookupDtypeAndShape(
old_tensor_name, &tensor_type, &tensor_shape));
OP_REQUIRES(context, tensor_type == DT_FLOAT,
errors::InvalidArgument(strings::StrCat(
"Tensor ", old_tensor_name, " has invalid type ",
DataTypeString(tensor_type), " instead of expected type ",
DataTypeString(DT_FLOAT))));
// This op is limited to loading Tensors of rank 2 (matrices).
OP_REQUIRES(
context, tensor_shape.dims() == 2,
errors::InvalidArgument(strings::StrCat(
"Tensor ", old_tensor_name, " has shape ",
tensor_shape.DebugString(), " of invalid rank ",
tensor_shape.dims(), " instead of expected shape of rank 2.")));
if (!remap_cols) {
// TODO(weiho): Consider relaxing this restriction to allow partial column
// loading (even when no column remapping is specified) if there turns out
// to be a use case for it.
OP_REQUIRES(context, num_cols_ == tensor_shape.dim_size(1),
errors::InvalidArgument(strings::StrCat(
"Tensor ", old_tensor_name, " has shape ",
tensor_shape.DebugString(),
", where the size of its 2nd dimension is ",
tensor_shape.dim_size(1),
" instead of being equal to num_cols=", num_cols_)));
}
// Uses TensorSlice to selectively read rows of interest from the old
// tensor. Given BundleReader's use of RandomAccessFile and InputBuffer,
// there shouldn't too many more additional disk seeks when compared to
// loading the old tensor in chunks, once we sort the row IDs. Even if there
// are locality concerns with some reading patterns, that just means if we
// had read it in chunks, then we would have had to read, copy, and process
// then discard many redundant rows - so we should come out ahead this way.
// In addition, this frees us from having to hold the entire old tensor in
// memory.
std::sort(old_row_to_new_row_pairs.begin(), old_row_to_new_row_pairs.end());
std::vector<TensorSlice> tensor_slices;
tensor_slices.reserve(old_row_to_new_row_pairs.size());
TensorSlice slice(tensor_shape.dims());
for (const auto& pair : old_row_to_new_row_pairs) {
OP_REQUIRES(
context, pair.first < tensor_shape.dim_size(0),
errors::InvalidArgument(strings::StrCat(
"Trying to read row ", pair.first, " from tensor ",
old_tensor_name, ", which only has ", tensor_shape.dim_size(0),
" rows (with shape ", tensor_shape.DebugString(), ").")));
slice.set_start(0, pair.first);
slice.set_length(0, 1);
tensor_slices.push_back(slice);
}
// Allocates the output matrix.
Tensor* output_matrix_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("output_matrix",
TensorShape({num_rows_, num_cols_}),
&output_matrix_t));
auto output_matrix = output_matrix_t->matrix<float>();
// Iterates through tensor slices and copies over values from the old tensor
// to the output matrix.
Tensor loaded_tensor_t(DT_FLOAT,
TensorShape({1, tensor_shape.dim_size(1)}));
for (int i = 0; i < tensor_slices.size(); ++i) {
const int64 new_row = old_row_to_new_row_pairs[i].second;
if (i % 500000 == 0) {
LOG(INFO) << "Processing slice " << i << " of " << tensor_slices.size()
<< " - corresponding to old row "
<< old_row_to_new_row_pairs[i].first << " of "
<< tensor_shape.dim_size(0);
}
OP_REQUIRES_OK(context,
reader.LookupSlice(old_tensor_name, tensor_slices[i],
&loaded_tensor_t));
// Copies over the row element-by-element, in case remapping is needed
// along the column axis.
const auto& loaded_tensor = loaded_tensor_t.flat<float>();
for (int old_col = 0; old_col < loaded_tensor.size(); ++old_col) {
int64 new_col = old_col;
if (remap_cols) {
const int64* new_col_ptr =
gtl::FindOrNull(old_col_to_new_col_map, old_col);
if (new_col_ptr == nullptr) {
// Column remapping is specified, but this column is not found in
// old_col_to_new_col_map, so we leave it uninitialized, to be
// filled in with initializing_values later.
continue;
}
new_col = *new_col_ptr;
}
OP_REQUIRES(context,
new_row < num_rows_ && new_col < num_cols_ &&
new_row >= 0 && new_col >= 0,
errors::Internal(strings::StrCat(
"new_row=", new_row, " and new_col=", new_col,
" should have been less than num_rows_=", num_rows_,
" and num_cols_=", num_cols_,
" and non-negative. This should never have happened "
"if the code were correct. Please file a bug.")));
output_matrix(new_row, new_col) = loaded_tensor(old_col);
}
}
LOG(INFO) << "Copied " << tensor_slices.size()
<< " rows from old matrix (with " << tensor_shape.dim_size(0)
<< " rows) to new matrix (with " << num_rows_ << " rows).";
// At this point, there are potentially whole rows/columns uninitialized
// (corresponding to the indices where row_id_present/col_id_present are
// false). We fill this in cell-by-cell using row_id_present and
// col_id_present while dequeuing from the initializing_values vector.
const Tensor* initializing_values_t;
OP_REQUIRES_OK(
context, context->input("initializing_values", &initializing_values_t));
const auto initializing_values = initializing_values_t->flat<float>();
int64 initializing_values_index = 0;
for (int i = 0; i < num_rows_; ++i) {
for (int j = 0; j < num_cols_; ++j) {
if (!row_id_present[i] || !col_id_present[j]) {
output_matrix(i, j) = initializing_values(initializing_values_index);
++initializing_values_index;
}
}
}
// Checks that we used all the given initializing values.
OP_REQUIRES(
context, initializing_values_index == initializing_values.size(),
errors::InvalidArgument(
"initializing_values contained ", initializing_values.size(),
" elements, but only ", initializing_values_index,
" elements were used to fill in missing values."));
}
private:
int64 num_rows_;
int64 num_cols_;
};
REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU),
LoadAndRemapMatrixOp);
} // namespace tensorflow

View File

@ -0,0 +1,102 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
REGISTER_OP("LoadAndRemapMatrix")
.Input("ckpt_path: string")
.Input("old_tensor_name: string")
.Input("row_remapping: int64")
.Input("col_remapping: int64")
.Input("initializing_values: float")
.Attr("num_rows: int >= 0")
.Attr("num_cols: int >= 1")
.Output("output_matrix: float")
// TODO(b/30502450): Setting the op as being stateful prevents it from being
// executed more often than expected (possibly due to stateful ops not being
// subject to constant folding?). This op is usually slow and may require
// multiple disk reads, so we want to minimize the number of times it's
// executed redundantly.
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
int64 num_rows;
TF_RETURN_IF_ERROR(c->GetAttr("num_rows", &num_rows));
int64 num_cols;
TF_RETURN_IF_ERROR(c->GetAttr("num_cols", &num_cols));
c->set_output(0, c->Matrix(num_rows, num_cols));
return Status::OK();
})
.Doc(R"doc(
Loads a 2-D (matrix) `Tensor` with name `old_tensor_name` from the checkpoint
at `ckpt_path` and potentially reorders its rows and columns using the
specified remappings.
Most users should use one of the wrapper initializers (such as
`tf.contrib.framework.load_and_remap_matrix_initializer`) instead of this
function directly.
The remappings are 1-D tensors with the following properties:
* `row_remapping` must have exactly `num_rows` entries. Row `i` of the output
matrix will be initialized from the row corresponding to index
`row_remapping[i]` in the old `Tensor` from the checkpoint.
* `col_remapping` must have either 0 entries (indicating that no column
reordering is needed) or `num_cols` entries. If specified, column `j` of the
output matrix will be initialized from the column corresponding to index
`col_remapping[j]` in the old `Tensor` from the checkpoint.
* A value of -1 in either of the remappings signifies a "missing" entry. In that
case, values from the `initializing_values` tensor will be used to fill that
missing row or column. If `row_remapping` has `r` missing entries and
`col_remapping` has `c` missing entries, then the following condition must be
true:
`(r * num_cols) + (c * num_rows) - (r * c) == len(initializing_values)`
The remapping tensors can be generated using the GenerateVocabRemapping op.
As an example, with row_remapping = [1, 0, -1], col_remapping = [0, 2, -1],
initializing_values = [0.5, -0.5, 0.25, -0.25, 42], and w(i, j) representing
the value from row i, column j of the old tensor in the checkpoint, the output
matrix will look like the following:
[[w(1, 0), w(1, 2), 0.5],
[w(0, 0), w(0, 2), -0.5],
[0.25, -0.25, 42]]
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) from
which the old matrix `Tensor` will be loaded.
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
row_remapping: An int `Tensor` of row remappings (generally created by
`generate_vocab_remapping`). Even if no row remapping is needed, this must
still be an index-valued Tensor (e.g. [0, 1, 2, ...]), or a shifted
index-valued `Tensor` (e.g. [8, 9, 10, ...], for partitioned `Variables`).
col_remapping: An int `Tensor` of column remappings (generally created by
`generate_vocab_remapping`). May be a size-0 `Tensor` if only row remapping
is to be done (e.g. column ordering is the same).
initializing_values: A float `Tensor` containing values to fill in for cells
in the output matrix that are not loaded from the checkpoint. Length must be
exactly the same as the number of missing / new cells.
num_rows: Number of rows (length of the 1st dimension) in the output matrix.
num_cols: Number of columns (length of the 2nd dimension) in the output matrix.
output_matrix: Output matrix containing existing values loaded from the
checkpoint, and with any missing values filled in from initializing_values.
)doc");
} // namespace tensorflow

View File

@ -0,0 +1,77 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
REGISTER_OP("GenerateVocabRemapping")
.Input("new_vocab_file: string")
.Input("old_vocab_file: string")
.Attr("new_vocab_offset: int >= 0")
.Attr("num_new_vocab: int >= 0")
.Output("remapping: int64")
.Output("num_present: int32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
int64 new_vocab_offset;
TF_RETURN_IF_ERROR(c->GetAttr("new_vocab_offset", &new_vocab_offset));
int64 num_new_vocab;
TF_RETURN_IF_ERROR(c->GetAttr("num_new_vocab", &num_new_vocab));
c->set_output(0, c->Vector(num_new_vocab));
c->set_output(1, c->Scalar());
return Status::OK();
})
.Doc(R"doc(
Given a path to new and old vocabulary files, returns a remapping Tensor of
length `num_new_vocab`, where `remapping[i]` contains the row number in the old
vocabulary that corresponds to row `i` in the new vocabulary (starting at line
`new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
in the new vocabulary is not in the old vocabulary. `num_vocab_offset` enables
use in the partitioned variable case, and should generally be set through
examining partitioning info. The format of the files should be a text file,
with each line containing a single entity within the vocabulary.
For example, with `new_vocab_file` a text file containing each of the following
elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3],
`num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be
`[0, -1, 2]`.
The op also returns a count of how many entries in the new vocabulary
were present in the old vocabulary, which is used to calculate the number of
values to initialize in a weight matrix remapping
This functionality can be used to remap both row vocabularies (typically,
features) and column vocabularies (typically, classes) from TensorFlow
checkpoints. Note that the partitioning logic relies on contiguous vocabularies
corresponding to div-partitioned variables. Moreover, the underlying remapping
uses an IndexTable (as opposed to an inexact CuckooTable), so client code should
use the corresponding index_table_from_file() as the FeatureColumn framework
does (as opposed to tf.feature_to_id(), which uses a CuckooTable).
new_vocab_file: Path to the new vocab file.
old_vocab_file: Path to the old vocab file.
new_vocab_offset: How many entries into the new vocab file to start reading.
num_new_vocab: Number of entries in the new vocab file to remap.
remapping: A Tensor of length num_new_vocab where the element at index i
is equal to the old ID that maps to the new ID i. This element is -1 for any
new ID that is not found in the old vocabulary.
num_present: Number of new vocab entries found in old vocab.
)doc");
} // namespace tensorflow

View File

@ -21,6 +21,7 @@ from __future__ import print_function
# pylint: disable=wildcard-import # pylint: disable=wildcard-import
from tensorflow.contrib.framework.python.framework.checkpoint_utils import * from tensorflow.contrib.framework.python.framework.checkpoint_utils import *
from tensorflow.contrib.framework.python.framework.experimental import experimental from tensorflow.contrib.framework.python.framework.experimental import experimental
from tensorflow.contrib.framework.python.framework.load_and_remap_matrix_ops import *
from tensorflow.contrib.framework.python.framework.tensor_util import * from tensorflow.contrib.framework.python.framework.tensor_util import *
# pylint: enable=wildcard-import # pylint: enable=wildcard-import
from tensorflow.python.util import decorator_utils from tensorflow.python.util import decorator_utils

View File

@ -0,0 +1,491 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for generating and loading vocab remappings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import gen_checkpoint_ops
from tensorflow.contrib.framework.python.ops import gen_generate_vocab_remapping_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import resource_loader
_checkpoint_ops_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_checkpoint_ops.so"))
ops.NotDifferentiable("GenerateVocabRemapping")
ops.NotDifferentiable("LoadAndRemapMatrix")
def _load_and_remap_matrix(ckpt_path,
old_tensor_name,
new_row_vocab_offset,
num_rows_to_load,
new_col_vocab_size,
initializer,
old_row_vocab_file=None,
new_row_vocab_file=None,
old_col_vocab_file=None,
new_col_vocab_file=None,
num_row_oov_buckets=0,
num_col_oov_buckets=0):
"""Loads a 2-D (matrix) `Tensor` from checkpoint.
Generates 1D-remappings for rows and columns using the
`GenerateVocabRemapping` op, and initializes any anticipated values with the
provided initializer. Then, uses the `LoadAndRemapMatrix` op to create a
matrix that loads existing values from the checkpoint, while filling out
"missing" values with the newly initialized values. See
contrib/framework/ops/checkpoint_ops.cc for more information on the wrapped
functionality (LoadAndRemapMatrix). This wrapper can be used to perform only
row remapping or only col remapping. If only row remapping is desired,
{new,old}_col_vocab_file should be `None`, and vice versa for column
remapping.
NOTE: This only supports div-partitioning the vocabulary on the 1st dimension
(row axis) via `new_row_vocab_offset`.
Args:
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
from which the old matrix `Tensor` will be loaded.
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
new_row_vocab_offset: A 0-indexed integer representing what line to
start reading at in the new row vocabulary. Used for partitioned
variables.
num_rows_to_load: Number of rows to load for the new vocabulary (note: to
support variable partitioning and partial loading, this does not need to
be the same as the number of entries in `new_row_vocab_file`).
new_col_vocab_size: Number of columns to load - should be the same as the
number of entries in `new_col_vocab_file`, since we don't support
partitioning along the column axis.
initializer: Callable initializer function that accepts a 1-D tensor as the
arg to specify the shape of the returned tensor. Used to initialize
missing values.
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old row vocabulary file. Can be None, which represents no
remapping on the row axis.
new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
to the new row vocabulary file. Can be None, which represents no remapping
on the row axis - in which case, `new_row_vocab_offset` and
`num_rows_to_load` work under the assumption that the new row vocab is the
same as the old row vocab.
old_col_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old column vocabulary file. Can be None, which represents no
remapping on the column axis.
new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
to the new column vocabulary file. Can be None, which represents no
remapping on the column axis - in which case, `new_col_vocab_size` works
under the assumption that the new col vocab is the same as the old col
vocab.
num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
to append. Must be >= 0.
num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
columns to append. Must be >= 0.
Returns:
A Tensor of shape `[num_rows_to_load + num_row_oov_buckets,
new_col_vocab_size + num_col_oov_buckets]`, with values loaded from the
specified tensor in the checkpoint, and any missing or OOV values
initialized with the given `initializer`.
Raises:
ValueError: If `num_row_oov_buckets` or `num_col_oov_buckets` < 0.
ValueError: If either `old_row_vocab_file` or `new_row_vocab_file` is
provided, while the other is not. Same for `old_col_vocab_file` and
`new_col_vocab_file`.
ValueError: If neither row vocabs or col vocabs are provided.
"""
if num_row_oov_buckets < 0:
raise ValueError("num_row_oov_buckets must be >= 0, but received %d" %
num_row_oov_buckets)
if num_col_oov_buckets < 0:
raise ValueError("num_col_oov_buckets must be >= 0, but received %d" %
num_col_oov_buckets)
if bool(old_row_vocab_file) != bool(new_row_vocab_file):
raise ValueError(
"old_row_vocab_file and new_row_vocab_file must both be specified or "
"left unspecified. old_row_vocab_file='{}', new_row_vocab_file='{}'".
format(old_row_vocab_file, new_row_vocab_file))
if bool(old_col_vocab_file) != bool(new_col_vocab_file):
raise ValueError(
"old_col_vocab_file and new_col_vocab_file must both be specified or "
"left unspecified. old_col_vocab_file='{}', new_col_vocab_file='{}'".
format(old_col_vocab_file, new_col_vocab_file))
remap_rows = new_row_vocab_file and old_row_vocab_file
remap_cols = new_col_vocab_file and old_col_vocab_file
if not (remap_rows or remap_cols):
raise ValueError(
"Must provide either row or column vocab files. If no remapping is "
"necessary, consider using `tf.contrib.framework.init_from_checkpoint` "
"instead.")
num_rows_present = num_rows_to_load
if remap_rows:
row_remapping, num_rows_present = (
gen_generate_vocab_remapping_ops.generate_vocab_remapping(
new_vocab_file=new_row_vocab_file,
old_vocab_file=old_row_vocab_file,
new_vocab_offset=new_row_vocab_offset,
num_new_vocab=num_rows_to_load))
else:
# Even when the rows are not being reordered, we still need to generate a
# remapping to account for initializing partitioned Variables (when
# new_row_vocab_offset is non-zero).
row_remapping = math_ops.range(
new_row_vocab_offset,
new_row_vocab_offset + num_rows_to_load,
dtype=dtypes.int64)
col_remapping = []
num_cols_present = new_col_vocab_size
if remap_cols:
col_remapping, num_cols_present = (
gen_generate_vocab_remapping_ops.generate_vocab_remapping(
new_vocab_file=new_col_vocab_file,
old_vocab_file=old_col_vocab_file,
new_vocab_offset=0, # Offset is unused for cols (no partitioning).
num_new_vocab=new_col_vocab_size))
init_vals = initializer([
num_rows_to_load * new_col_vocab_size -
num_rows_present * num_cols_present, 1
])
return_tensor = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=ckpt_path,
old_tensor_name=old_tensor_name,
row_remapping=row_remapping,
col_remapping=col_remapping,
initializing_values=init_vals,
num_rows=num_rows_to_load,
num_cols=new_col_vocab_size)
# Add OOV row(s) and column(s).
if num_row_oov_buckets > 0:
init_row_oov_val = initializer([num_row_oov_buckets, new_col_vocab_size])
init_row_oov_val = ops.convert_to_tensor(init_row_oov_val)
return_tensor = array_ops.concat([return_tensor, init_row_oov_val], 0)
if num_col_oov_buckets > 0:
# We need to add any row OOV to the new column shape.
init_col_oov_val = initializer(
[num_rows_to_load + num_row_oov_buckets, num_col_oov_buckets])
init_col_oov_val = ops.convert_to_tensor(init_col_oov_val)
return_tensor = array_ops.concat([return_tensor, init_col_oov_val], 1)
return return_tensor
def load_and_remap_matrix_initializer(ckpt_path,
old_tensor_name,
new_row_vocab_size,
new_col_vocab_size,
old_row_vocab_file=None,
new_row_vocab_file=None,
old_col_vocab_file=None,
new_col_vocab_file=None,
num_row_oov_buckets=0,
num_col_oov_buckets=0,
initializer=None):
r"""Returns a var initializer for loading and remapping a 2-D (matrix) tensor.
The returned initializer loads a 2-D (matrix) `Tensor` with name
`old_tensor_name` from the checkpoint at `ckpt_path`. It will reorder the
rows/columns according to the specified vocab files and append additional
out-of-vocabulary rows/columns according to the number of OOV buckets.
The format of the file at the `{old,new}_{row,col}_vocab_file` path should be
a text file, with each line containing a single entity within the vocabulary.
Let the function `line_of(f, "x")` return the 0-indexed line number of the
entity "x" in file f, and the function `entity_at(f, i)` return the entity at
line i of file f. Then, row i of the new output matrix will be taken from row
`line_of(old_row_vocab_file, entity_at(new_row_vocab_file, i))` of the old
matrix. If any entity in `new_row_vocab_file` is not found in
`old_row_vocab_file`, that row is considered a "missing" row, and its values
will be initialized using the `initializer` arg. The same logic also applies
for the columns.
For example, assuming that:
* `old_row_vocab_file` contains "mercury\nvenus\nmars"
* `new_row_vocab_file` contains "venus\njupiter\nmercury"
* `old_col_vocab_file` contains "good\nbetter\nbest"
* `new_col_vocab_file` contains "good\nbest\nfantastic"
* `initializer` returns the natural numbers `[1, 2, 3, 4, ...]`
* `w(i, j)` represents the value from row i, column j of the old matrix
Then the new output matrix will look like:
`[[w(1, 0), w(1, 2), 1],
[2, 3, 4],
[w(0, 0), w(0, 2), 5]]`
If we further specify that:
* `num_row_oov_buckets` == 2
* `num_col_oov_buckets` == 1
Then the new output matrix will look like:
`[[w(1, 0), w(1, 2), 1, 12],
[2, 3, 4, 13],
[w(0, 0), w(0, 2), 5, 14],
[6, 7, 8, 15],
[9, 10, 11, 16]]`
If `{old,new}_row_vocab_file` are None, we assume that the old and new row
vocab files are the same, and no row remapping is done. If
`{old,new}_col_vocab_file` are None, we assume that the old and new column
vocab files are the same, and no column remapping is done.
The returned initializer only supports div-partitioning along the row axis. It
does not support partitioning along the column axis or mod-partitioning.
NOTE: When this is used to warm-start variables, client code should use
`tf.lookup.index_table_from_tensor()` like
contrib/layers/python/layers/feature_column.py does, as opposed to
`tf.feature_to_id()` - in order to ensure the underlying lookup tables are the
same.
Args:
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
from which the old matrix `Tensor` will be loaded.
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
new_row_vocab_size: `int` specifying the number of entries in
`new_row_vocab_file`. If no row remapping is needed (no row vocab
provided), this should be equal to the number of rows to load from the old
matrix (which can theoretically be smaller than the number of rows in the
old matrix).
new_col_vocab_size: `int` specifying the number of entries in
`new_col_vocab_file`. If no column remapping is needed (no column vocab
provided), this should be equal to the number of columns in the old
matrix.
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old row vocabulary file. Can be None, which represents no
remapping on the row axis.
new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
to the new row vocabulary file. Can be None, which represents no remapping
on the row axis.
old_col_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old column vocabulary file. Can be None, which represents no
remapping on the column axis.
new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
to the new column vocabulary file. Can be None, which represents no
remapping on the column axis.
num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
to append. Must be >= 0.
num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
columns to append. Must be >= 0.
initializer: Initializer function to initialize missing values. Accepts a
1-D tensor as the arg to specify the shape of the returned tensor. If
`None`, defaults to using `zeros_initializer()`.
Returns:
A variable initializer function that should be used to initialize a
(potentially partitioned) `Variable` whose complete shape is
`[new_row_vocab_size + num_row_oov_buckets, new_col_vocab_size +
num_col_oov_buckets]`.
Raises:
TypeError: If `initializer` is specified but not callable.
"""
if initializer is None:
# TODO(b/25671353): Consider using sqrt(6/(fan_in + fan_out)) instead, from
# Glorot and Bengio, 2010.
initializer = init_ops.zeros_initializer()
if not callable(initializer):
raise TypeError(
"initializer must be callable, instead of being {} of type {}.".format(
initializer, type(initializer)))
def _initializer(shape, dtype=dtypes.float32, partition_info=None):
"""Variable initializer.
Args:
shape: Shape of `Tensor` to return. Should include OOV on both axes.
dtype: Must be float32.
partition_info: variable_scope._PartitionInfo.
Returns:
`Tensor` of shape `shape`.
Raises:
TypeError: If `dtype` is anything other than float32.
ValueError: For shape mismatch upon invocation.
"""
# Sanity checks.
if dtype != dtypes.float32:
raise TypeError(
"Currently, only float32 is supported. Received dtype: {}".format(
dtype))
if len(shape) != 2:
raise ValueError("Expected 2-dim shape, but received: {}".format(shape))
if shape[0] <= 0:
raise ValueError(
"Expected 1st dim of shape to be > 0, but received shape: {}".format(
shape))
if shape[1] != (new_col_vocab_size + num_col_oov_buckets):
raise ValueError(
"Expected 2nd dim of shape to be new_col_vocab_size ({}) + "
"num_col_oov_buckets ({}) = {}, but received shape: {}".format(
new_col_vocab_size, num_col_oov_buckets,
new_col_vocab_size + num_col_oov_buckets, shape))
offset = 0
if partition_info is not None:
offset = partition_info.single_offset(shape)
if offset + shape[0] > new_row_vocab_size + num_row_oov_buckets:
raise ValueError(
"Trying to initialize {} additional rows after {} rows have already "
"been initialized, which would exceed expected total row count of "
"new_row_vocab_size ({}) + num_row_oov_buckets ({}) = {}.".format(
shape[0], offset, new_row_vocab_size, num_row_oov_buckets,
new_row_vocab_size + num_row_oov_buckets))
row_oov_buckets_to_use = min(shape[0],
max(0, offset + shape[0] - new_row_vocab_size))
num_rows_to_load = shape[0] - row_oov_buckets_to_use
return _load_and_remap_matrix(
ckpt_path=ckpt_path,
old_tensor_name=old_tensor_name,
new_row_vocab_offset=offset,
num_rows_to_load=num_rows_to_load,
new_col_vocab_size=new_col_vocab_size,
initializer=initializer,
old_row_vocab_file=old_row_vocab_file,
new_row_vocab_file=new_row_vocab_file,
old_col_vocab_file=old_col_vocab_file,
new_col_vocab_file=new_col_vocab_file,
num_row_oov_buckets=row_oov_buckets_to_use,
num_col_oov_buckets=num_col_oov_buckets)
return _initializer
def load_embedding_initializer(ckpt_path,
embedding_tensor_name,
new_vocab_size,
embedding_dim,
old_vocab_file,
new_vocab_file,
num_oov_buckets=0,
initializer=None):
"""Returns a variable initializer for loading pre-trained embeddings.
Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
embedding weights and remapping according to the provided vocab files. See
docs for `load_and_remap_matrix_initializer()` for more details.
NOTE: Only for use with div-partitioned variables / vocabularies.
Args:
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
from which the old matrix `Tensor` will be loaded.
embedding_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
new_vocab_size: Number of entries in the new vocab.
embedding_dim: `int` specifying the dimension of the embedding vectors from
the checkpoint. Must match the number of columns in the old embedding
matrix.
old_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old vocabulary file.
new_vocab_file: A scalar `Tensor` of type `string` containing the
path to the new vocabulary file.
num_oov_buckets: `int` specifying the number of out-of-vocabulary
buckets to use. Must be >= 0.
initializer: Initializer function that accepts a 1-D tensor as the arg to
specify the shape of the returned tensor. If `None`, defaults to using
`truncated_normal_initializer()`.
Returns:
A variable initializer function.
"""
if initializer is None:
# TODO(b/25671353): This should be kept in sync with the stddev used by
# feature_column.py's _EmbeddingColumn.
initializer = init_ops.truncated_normal_initializer(
stddev=1.0 /
math_ops.sqrt(math_ops.cast(embedding_dim, dtypes.float32)))
return load_and_remap_matrix_initializer(
ckpt_path=ckpt_path,
old_tensor_name=embedding_tensor_name,
new_row_vocab_size=new_vocab_size,
new_col_vocab_size=embedding_dim,
old_row_vocab_file=old_vocab_file,
new_row_vocab_file=new_vocab_file,
old_col_vocab_file=None,
new_col_vocab_file=None,
num_row_oov_buckets=num_oov_buckets,
num_col_oov_buckets=0,
initializer=initializer)
def load_linear_multiclass_bias_initializer(ckpt_path,
bias_tensor_name,
new_class_vocab_size,
old_class_vocab_file,
new_class_vocab_file,
num_class_oov_buckets=0,
initializer=None):
"""Loads pre-trained multi-class biases for linear models from checkpoint.
Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
multi-class bias and remapping according to the provided vocab files. See docs
for `load_and_remap_matrix_initializer()` for more details. In this case, the
provided row_vocab is the class vocabulary, and the expected shape is
`[new_class_vocab_size, 1]`.
Args:
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
from which the old matrix `Tensor` will be loaded.
bias_tensor_name: Tensor name to load from in the checkpoints.
new_class_vocab_size: Number of entries in the new class vocab.
old_class_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old class vocabulary file.
new_class_vocab_file: A scalar `Tensor` of type `string` containing the
path to the new class vocabulary file.
num_class_oov_buckets: `int` specifying the number of out-of-vocabulary
buckets to use for the classes. Must be >= 0.
initializer: Initializer function that accepts a 1-D tensor as the arg to
specify the shape of the returned tensor. If `None`, defaults to using
`zeros_initializer()`.
Returns:
A variable initializer function.
"""
# Linear multi-class biases should be zero-initialized.
if initializer is None:
initializer = init_ops.zeros_initializer()
return load_and_remap_matrix_initializer(
ckpt_path=ckpt_path,
old_tensor_name=bias_tensor_name,
new_row_vocab_size=new_class_vocab_size,
new_col_vocab_size=1,
old_row_vocab_file=old_class_vocab_file,
new_row_vocab_file=new_class_vocab_file,
old_col_vocab_file=None,
new_col_vocab_file=None,
num_row_oov_buckets=num_class_oov_buckets,
num_col_oov_buckets=0,
initializer=initializer)

View File

@ -0,0 +1,558 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for the op to generate vocab remapping."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework.python.framework import load_and_remap_matrix_ops
from tensorflow.contrib.framework.python.ops import gen_checkpoint_ops
from tensorflow.contrib.framework.python.ops.gen_generate_vocab_remapping_ops import generate_vocab_remapping
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.training import saver
FLAGS = flags.FLAGS
_TESTDATA_PATH = 'contrib/framework/testdata'
class GenerateVocabRemappingTest(test.TestCase):
"""Tests for the generate_vocab_remapping() method."""
def setUp(self):
self.new_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'keyword_shifted.txt')
self.old_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt')
def test_generate_remapping_with_no_vocab_changes(self):
"""Tests where vocab does not change at all."""
remapping, num_present = generate_vocab_remapping(
new_vocab_file=self.old_vocab_file,
old_vocab_file=self.old_vocab_file,
num_new_vocab=3,
new_vocab_offset=0)
expected_remapping = range(0, 3)
expected_num_present = 3
with self.test_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
def test_generate_remapping_with_shifted_vocab(self):
"""Tests where vocab is the same, but shifted / ordered differently."""
remapping, num_present = generate_vocab_remapping(
new_vocab_file=self.new_vocab_file,
old_vocab_file=self.old_vocab_file,
num_new_vocab=3,
new_vocab_offset=0)
expected_remapping = [2, 0, 1]
expected_num_present = 3
with self.test_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
def test_generate_remapping_with_offset(self):
"""Tests offset and num_new_vocab logic."""
remapping, num_present = generate_vocab_remapping(
new_vocab_file=self.new_vocab_file,
old_vocab_file=self.old_vocab_file,
num_new_vocab=1,
new_vocab_offset=1)
expected_remapping = [0]
expected_num_present = 1
with self.test_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
class LoadAndRemapMatrixTest(test.TestCase):
"""Tests for the load_and_remap_weight_matrix() op."""
def setUp(self):
ops.reset_default_graph()
self.old_num_rows = 5
self.old_num_cols = 16
self.matrix_value = np.reshape(
range(0, self.old_num_rows * self.old_num_cols), (self.old_num_rows,
self.old_num_cols))
with variable_scope.variable_scope('some_scope'):
matrix = variable_scope.get_variable(
'matrix',
dtype=dtypes.float32,
initializer=constant_op.constant(
self.matrix_value, dtype=dtypes.float32))
self.old_tensor_name = 'some_scope/matrix'
save = saver.Saver([matrix])
with self.test_session() as sess:
variables.global_variables_initializer().run()
self.bundle_file = os.path.join(test.get_temp_dir(), 'bundle_checkpoint')
save.save(sess, self.bundle_file)
def test_load_and_remap_no_missing(self):
"""Tests the op's load and remap where there are no missing entries."""
# No column remapping, new weight matrix has second row, then first row.
row_remapping = [1, 0]
remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=row_remapping,
col_remapping=[],
initializing_values=[],
num_rows=2,
num_cols=self.old_num_cols)
with self.test_session():
self.assertAllClose(self.matrix_value[row_remapping],
remapped_weight_matrix.eval())
# No row remapping, new weight matrix has third col, then first col.
row_remapping = list(range(self.old_num_rows))
col_remapping = [2, 0]
remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=row_remapping,
col_remapping=col_remapping,
initializing_values=[],
num_rows=len(row_remapping),
num_cols=len(col_remapping))
with self.test_session():
self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
remapped_weight_matrix.eval())
# Both row and column remappings.
row_remapping = [1, 0, 4]
col_remapping = [1, 15]
remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=row_remapping,
col_remapping=col_remapping,
initializing_values=[],
num_rows=len(row_remapping),
num_cols=len(col_remapping))
with self.test_session():
self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
remapped_weight_matrix.eval())
def test_load_and_remap_with_init(self):
"""Tests the op's load and remap where there are missing entries."""
init_val = 42
remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=[2, -1, 0],
col_remapping=[1, -1],
initializing_values=[init_val] * 4,
num_rows=3,
num_cols=2)
expected_remapped_weight_matrix = np.reshape(
[33, init_val, init_val, init_val, 1, init_val], [3, 2])
with self.test_session():
self.assertAllClose(expected_remapped_weight_matrix,
remapped_weight_matrix.eval())
def test_load_and_remap_all_missing_rows(self):
"""Tests when all the rows are missing and need to be initialized."""
num_rows = 7
initializing_values = [42] * num_rows * self.old_num_cols
remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=[-1] * num_rows,
col_remapping=[],
initializing_values=initializing_values,
num_rows=num_rows,
num_cols=self.old_num_cols)
with self.test_session():
self.assertAllClose(
np.reshape(initializing_values, (num_rows, self.old_num_cols)),
remapped_weight_matrix.eval())
def test_load_and_remap_all_missing_rows_and_cols(self):
"""Tests when all the rows & cols are missing and need to be initialized."""
num_rows = 7
num_cols = 4
initializing_values = [42] * num_rows * num_cols
remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=[-1] * num_rows,
col_remapping=[-1] * num_cols,
initializing_values=initializing_values,
num_rows=num_rows,
num_cols=num_cols)
with self.test_session():
self.assertAllClose(
np.reshape(initializing_values, (num_rows, num_cols)),
remapped_weight_matrix.eval())
def test_load_and_remap_duplicate_row_remapping(self):
"""Tests when an old row maps to multiple new rows.
(This should usually not happen when using public APIs).
"""
row_remapping = [1, 0, 0, 0, 1, 2]
remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=row_remapping,
col_remapping=[],
initializing_values=[],
num_rows=len(row_remapping),
num_cols=self.old_num_cols)
with self.test_session():
self.assertAllClose(self.matrix_value[row_remapping],
remapped_weight_matrix.eval())
def test_load_and_remap_invalid_col_remapping(self):
"""Tests that an error is raised when an old col maps to multiple new cols.
(This should usually not happen when using public APIs).
"""
col_remapping = [1, 0, 0, 0, 1, 2]
remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=list(range(self.old_num_rows)),
col_remapping=col_remapping,
initializing_values=[],
num_rows=self.old_num_rows,
num_cols=len(col_remapping))
with self.test_session(), self.assertRaises(errors.UnimplementedError):
remapped_weight_matrix.eval()
class LoadAndRemapWrappersTest(test.TestCase):
"""Tests for the functionality of the Python wrappers."""
def setUp(self):
self.bundle_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint')
self.new_feature_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint_vocab.txt')
self.old_feature_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH),
'bundle_checkpoint_vocab_with_oov.txt')
self.new_class_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'keyword_new.txt')
self.old_class_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt')
self.init_val = 42
def _init_val_initializer(shape, dtype=None, partition_info=None):
del dtype, partition_info # Unused by this unit-testing initializer.
return array_ops.tile(
constant_op.constant([[self.init_val]], dtype=dtypes.float32), shape)
self.initializer = _init_val_initializer
def test_load_and_remap_matrix(self):
"""Tests the end-to-end loading / remapping of weights."""
# load_and_remap_matrix() is the generalized wrapper that takes in row and
# column vocabulary files, calls the relevant remappings, and returns the
# weight matrix. Take this example to be linear multi-class by providing
# both row and column vocabularies.
remapped_matrix = load_and_remap_matrix_ops._load_and_remap_matrix(
new_row_vocab_file=self.new_feature_vocab_file,
old_row_vocab_file=self.old_feature_vocab_file,
num_rows_to_load=4,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
new_row_vocab_offset=1,
initializer=self.initializer,
num_row_oov_buckets=1,
num_col_oov_buckets=1)
# [4 in vocab + 1 oov features, 4 in vocab + 1 oov classes]. The offset
# means we read
expected_remapped_matrix = np.concatenate(
[
np.reshape([18, 34, 50, self.init_val, self.init_val], [5, 1]),
np.reshape([16, 32, 48, self.init_val, self.init_val], [5, 1]),
np.reshape([self.init_val] * 5, [5, 1]),
np.reshape([17, 33, 49, self.init_val, self.init_val], [5, 1]),
np.reshape([self.init_val] * 5, [5, 1])
],
axis=1)
with self.test_session():
self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
def test_load_and_remap_output_layer_weight_initializer_linear(self):
"""Tests for the output layer initializer in the linear multi-class case."""
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
new_row_vocab_size=5,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
new_row_vocab_file=self.new_feature_vocab_file,
old_row_vocab_file=self.old_feature_vocab_file,
num_row_oov_buckets=1,
num_col_oov_buckets=1,
initializer=self.initializer))
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50, self.init_val, self.init_val], [6, 1]),
np.reshape([0, 16, 32, 48, self.init_val, self.init_val], [6, 1]),
np.reshape([self.init_val] * 6, [6, 1]),
np.reshape([1, 17, 33, 49, self.init_val, self.init_val], [6, 1]),
np.reshape([self.init_val] * 6, [6, 1])
],
axis=1)
# The new weight matrix is of size
# [5 feature vocab + 1 feature OOV, 4 class vocab + 1 class OOV]. Use a
# partitioned variable to confirm that the offset logic works.
remapped_matrix = variable_scope.get_variable(
name='linear/obtained_weight_matrix',
shape=[6, 5],
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
def test_load_and_remap_output_layer_weight_initializer_dnn_output(self):
"""Tests for the output layer initializer in the DNN output case."""
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
new_row_vocab_size=5,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
num_col_oov_buckets=1,
initializer=self.initializer))
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50, 66], [5, 1]),
np.reshape([0, 16, 32, 48, 64], [5, 1]),
np.reshape([self.init_val] * 5, [5, 1]),
np.reshape([1, 17, 33, 49, 65], [5, 1]),
np.reshape([self.init_val] * 5, [5, 1])
],
axis=1)
# The new weight matrix is of size
# [5-sized input layer, 4 class vocab + 1 class OOV].
remapped_matrix = variable_scope.get_variable(
name='dnn_output/obtained_weight_matrix',
shape=[5, 5],
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
def test_initializer_with_oov_only_partition(self):
"""Tests for the output layer initializer where one partition is all OOV."""
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
new_row_vocab_size=5,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
new_row_vocab_file=self.new_feature_vocab_file,
old_row_vocab_file=self.old_feature_vocab_file,
num_row_oov_buckets=5,
num_col_oov_buckets=1,
initializer=self.initializer))
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50] + [self.init_val] * 6, [10, 1]),
np.reshape([0, 16, 32, 48] + [self.init_val] * 6, [10, 1]),
np.reshape([self.init_val] * 10, [10, 1]),
np.reshape([1, 17, 33, 49] + [self.init_val] * 6, [10, 1]),
np.reshape([self.init_val] * 10, [10, 1]),
],
axis=1)
# The new weight matrix is of size
# [5 feature vocab + 5 feature OOV, 4 class vocab + 1 class OOV]. The
# second partition has only OOV.
remapped_matrix = variable_scope.get_variable(
name='linear_all_oov/obtained_weight_matrix',
shape=[10, 5],
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
def test_load_and_remap_linear_multiclass_initializer_default_init(self):
"""Tests where the zeros_initializer default is used for linear."""
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
new_row_vocab_size=5,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
new_row_vocab_file=self.new_feature_vocab_file,
old_row_vocab_file=self.old_feature_vocab_file,
num_row_oov_buckets=1,
num_col_oov_buckets=1))
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50, 0, 0], [6, 1]),
np.reshape([0, 16, 32, 48, 0, 0], [6, 1]),
np.reshape([0] * 6, [6, 1]),
np.reshape([1, 17, 33, 49, 0, 0], [6, 1]),
np.reshape([0] * 6, [6, 1])
],
axis=1)
remapped_matrix = variable_scope.get_variable(
name='linear_init_fallback/obtained_weight_matrix',
shape=[6, 5],
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
def test_load_embedding_initializer(self):
"""Tests for the load_embedding initializer wrapper."""
embedding_loading_initializer = (
contrib_framework.load_embedding_initializer(
new_vocab_file=self.new_feature_vocab_file,
old_vocab_file=self.old_feature_vocab_file,
new_vocab_size=5,
embedding_dim=16,
embedding_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
num_oov_buckets=1,
initializer=self.initializer))
expected_remapped_embeddings = np.concatenate(
[
np.reshape(range(64), [4, 16]),
np.reshape([self.init_val] * 32, [2, 16]),
],
axis=0)
# The new weight matrix is of size
# [5 feature vocab + 1 feature OOV, 16 (embedding dimension)], where the
# last vocab row (2nd last row) is newly initialized (wasn't found in
# previous vocab) and the actual last row is OOV and also newly initialized.
# Use a partitioned variable to confirm that the offset logic works.
remapped_embeddings = variable_scope.get_variable(
name='embedding/obtained_embedding_matrix',
shape=[6, 16],
initializer=embedding_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_embeddings,
remapped_embeddings.as_tensor().eval())
class LoadMulticlassBiasTest(test.TestCase):
"""Tests for the load_linear_multiclass_bias_initializer functionality."""
def setUp(self):
ops.reset_default_graph()
dim = 1
num = 3
with ops.name_scope('some_scope'):
# Basically from 0 to dim*num-1.
flat_data = math_ops.linspace(0.0, dim * num - 1, dim * num)
bias = variables.Variable(
array_ops.reshape(flat_data, (num, dim)), name='bias')
save = saver.Saver([bias])
with self.test_session() as sess:
variables.global_variables_initializer().run()
self.bundle_file = os.path.join(test.get_temp_dir(), 'bias_checkpoint')
save.save(sess, self.bundle_file)
self.new_class_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'keyword_new.txt')
self.old_class_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt')
self.init_val = 42
def _init_val_initializer(shape, dtype=None, partition_info=None):
del dtype, partition_info # Unused by this unit-testing initializer.
return array_ops.tile(
constant_op.constant([[self.init_val]], dtype=dtypes.float32), shape)
self.initializer = _init_val_initializer
def test_load_linear_multiclass_bias_initializer(self):
"""Tests for the bias initializer wrapper."""
bias_loading_initializer = (
contrib_framework.load_linear_multiclass_bias_initializer(
new_class_vocab_file=self.new_class_vocab_file,
old_class_vocab_file=self.old_class_vocab_file,
new_class_vocab_size=4,
bias_tensor_name='some_scope/bias',
ckpt_path=[self.bundle_file],
num_class_oov_buckets=1,
initializer=self.initializer))
expected_remapped_bias_vector = np.reshape(
[2, 0, self.init_val, 1, self.init_val], [5, 1])
# The new bias vector is of size [4 class vocab + 1 class OOV, 1].
remapped_bias_vector = variable_scope.get_variable(
name='bias/obtained_bias_vector',
shape=[5, 1],
initializer=bias_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(3))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_bias_vector,
remapped_bias_vector.as_tensor().eval())
if __name__ == '__main__':
test.main()

Binary file not shown.

View File

@ -0,0 +1,5 @@
zero
one
two
three
four

View File

@ -0,0 +1,4 @@
zero
one
two
three

View File

@ -0,0 +1,3 @@
knitting
eminem
MISSING

View File

@ -0,0 +1,4 @@
MISSING
knitting
flask
eminem

View File

@ -0,0 +1,3 @@
MISSING
knitting
eminem

View File

@ -1336,6 +1336,11 @@ cc_library(
], ],
) )
cc_header_only_library(
name = "lookup_headers_lib",
deps = [":lookup"],
)
DATA_FLOW_DEPS = [ DATA_FLOW_DEPS = [
":bounds_check", ":bounds_check",
":concat_lib", ":concat_lib",

View File

@ -3,11 +3,16 @@
package( package(
default_visibility = ["//visibility:public"], default_visibility = ["//visibility:public"],
features = ["-parse_headers"],
) )
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_copts") load(
"//tensorflow:tensorflow.bzl",
"cc_header_only_library",
"tf_copts",
)
# To be exported to tensorflow/core:mobile_srcs. # To be exported to tensorflow/core:mobile_srcs.
filegroup( filegroup(
@ -43,6 +48,13 @@ cc_library(
], ],
) )
cc_header_only_library(
name = "tensor_bundle_headers_lib",
deps = [
":tensor_bundle",
],
)
cc_library( cc_library(
name = "naming", name = "naming",
srcs = ["naming.cc"], srcs = ["naming.cc"],

View File

@ -56,6 +56,7 @@ BLACKLIST = [
"//tensorflow/contrib/factorization/examples:mnist", "//tensorflow/contrib/factorization/examples:mnist",
"//tensorflow/contrib/factorization/examples:mnist.py", "//tensorflow/contrib/factorization/examples:mnist.py",
"//tensorflow/contrib/factorization:factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", # pylint:disable=line-too-long "//tensorflow/contrib/factorization:factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", # pylint:disable=line-too-long
"//tensorflow/contrib/framework:load_and_remap_matrix_testdata",
"//tensorflow/contrib/bayesflow:reinforce_simple_example", "//tensorflow/contrib/bayesflow:reinforce_simple_example",
"//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long
] ]