Open-source ops to load and remap matrix (2-D) tensors. (Used for loading/remapping embeddings, warm-starting weights/biases, etc.)
PiperOrigin-RevId: 157148893
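For reviewers, a minimal usage sketch of the API added in this change (not part of the commit itself): it shows how the exported load_embedding_initializer wrapper is intended to warm-start an embedding variable. The checkpoint path, tensor name, vocab file paths, and sizes below are hypothetical placeholders.

    import tensorflow as tf

    # Hypothetical paths/names; only the shape of the API is illustrated here.
    embedding_init = tf.contrib.framework.load_embedding_initializer(
        ckpt_path="/tmp/old_model/model.ckpt-100",
        embedding_tensor_name="input_layer/embedding/weights",
        new_vocab_size=10000,
        embedding_dim=16,
        old_vocab_file="/tmp/old_vocab.txt",
        new_vocab_file="/tmp/new_vocab.txt",
        num_oov_buckets=1)

    # The returned initializer targets a variable whose complete shape is
    # [new_vocab_size + num_oov_buckets, embedding_dim].
    embeddings = tf.get_variable(
        "embeddings", shape=[10000 + 1, 16], initializer=embedding_init)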
@@ -78,6 +78,8 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/contrib/batching:batch_ops_kernels",
|
||||
"//tensorflow/contrib/factorization/kernels:all_kernels",
|
||||
"//tensorflow/contrib/framework:generate_vocab_remapping_kernel",
|
||||
"//tensorflow/contrib/framework:load_and_remap_matrix_kernel",
|
||||
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
|
||||
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
|
||||
"//tensorflow/contrib/nccl:nccl_kernels",
|
||||
|
@@ -47,6 +47,10 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc"
|
||||
#"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/decode_audio_op.cc"
|
||||
#"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/generate_vocab_remapping_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/framework/ops/checkpoint_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/framework/ops/generate_vocab_remapping_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"
|
||||
|
@@ -69,6 +69,8 @@ file(GLOB_RECURSE tensor_forest_hybrid_srcs
|
||||
GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(framework_checkpoint "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/checkpoint_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(framework_generate_vocab_remapping "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/generate_vocab_remapping_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(input_pipeline "${tensorflow_source_dir}/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc")
|
||||
|
@@ -619,6 +619,10 @@ GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_clustering_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_factorization_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_framework_checkpoint_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/framework/python/ops/gen_checkpoint_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_framework_generate_vocab_remapping_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/framework/python/ops/gen_generate_vocab_remapping_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_framework_variable_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/framework/python/ops/gen_variable_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_input_pipeline_ops"
|
||||
|
@@ -25,6 +25,7 @@ tf_custom_op_py_library(
|
||||
"python/framework/__init__.py",
|
||||
"python/framework/checkpoint_utils.py",
|
||||
"python/framework/experimental.py",
|
||||
"python/framework/load_and_remap_matrix_ops.py",
|
||||
"python/framework/tensor_util.py",
|
||||
"python/ops/__init__.py",
|
||||
"python/ops/arg_scope.py",
|
||||
@@ -34,13 +35,20 @@ tf_custom_op_py_library(
|
||||
],
|
||||
dso = [
|
||||
":python/ops/_variable_ops.so",
|
||||
":python/framework/_checkpoint_ops.so",
|
||||
],
|
||||
kernels = [
|
||||
":checkpoint_ops_op_lib",
|
||||
":generate_vocab_remapping_kernel",
|
||||
":generate_vocab_remapping_ops_op_lib",
|
||||
":load_and_remap_matrix_kernel",
|
||||
":variable_kernels",
|
||||
":variable_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":gen_checkpoint_ops",
|
||||
":gen_generate_vocab_remapping_ops",
|
||||
":gen_variable_ops",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
@@ -81,6 +89,32 @@ tf_kernel_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "generate_vocab_remapping_kernel",
|
||||
srcs = ["kernels/generate_vocab_remapping_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/kernels:lookup_table_init_op",
|
||||
"//tensorflow/core/kernels:lookup_table_op",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "load_and_remap_matrix_kernel",
|
||||
srcs = ["kernels/load_and_remap_matrix_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/util/tensor_bundle",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_custom_op_library(
|
||||
name = "python/ops/_variable_ops.so",
|
||||
srcs = [
|
||||
@@ -94,13 +128,33 @@ tf_custom_op_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_library(
|
||||
name = "python/framework/_checkpoint_ops.so",
|
||||
srcs = [
|
||||
"kernels/generate_vocab_remapping_op.cc",
|
||||
"kernels/load_and_remap_matrix_op.cc",
|
||||
"ops/checkpoint_ops.cc",
|
||||
"ops/generate_vocab_remapping_ops.cc",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core/kernels:lookup_headers_lib",
|
||||
"//tensorflow/core/util/tensor_bundle:tensor_bundle_headers_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = ["variable_ops"],
|
||||
op_lib_names = [
|
||||
"checkpoint_ops",
|
||||
"generate_vocab_remapping_ops",
|
||||
"variable_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "all_ops",
|
||||
deps = [
|
||||
":checkpoint_ops_op_lib",
|
||||
":generate_vocab_remapping_ops_op_lib",
|
||||
":variable_ops_op_lib",
|
||||
],
|
||||
)
|
||||
@@ -113,6 +167,22 @@ tf_gen_op_wrapper_py(
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_generate_vocab_remapping_ops",
|
||||
out = "python/ops/gen_generate_vocab_remapping_ops.py",
|
||||
deps = [
|
||||
":generate_vocab_remapping_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_checkpoint_ops",
|
||||
out = "python/ops/gen_checkpoint_ops.py",
|
||||
deps = [
|
||||
":checkpoint_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "arg_scope_test",
|
||||
size = "small",
|
||||
@@ -227,6 +297,43 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "load_and_remap_matrix_testdata",
|
||||
srcs = [
|
||||
"testdata/bundle_checkpoint.data-00000-of-00001",
|
||||
"testdata/bundle_checkpoint.index",
|
||||
"testdata/bundle_checkpoint_vocab.txt",
|
||||
"testdata/bundle_checkpoint_vocab_with_oov.txt",
|
||||
"testdata/keyword.txt",
|
||||
"testdata/keyword_new.txt",
|
||||
"testdata/keyword_shifted.txt",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "load_and_remap_matrix_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["python/framework/load_and_remap_matrix_ops_test.py"],
|
||||
data = [":load_and_remap_matrix_testdata"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":framework_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:partitioned_variables",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@@ -74,6 +74,9 @@ See the @{$python/contrib.framework} guide.
|
||||
@@list_variables
|
||||
@@load_variable
|
||||
@@init_from_checkpoint
|
||||
@@load_and_remap_matrix_initializer
|
||||
@@load_embedding_initializer
|
||||
@@load_linear_multiclass_bias_initializer
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
@@ -0,0 +1,173 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/lookup_table_init_op.h"
|
||||
#include "tensorflow/core/kernels/lookup_table_op.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
// lookup::InitializeTableFromTextFile requires a delimiter even though we use
|
||||
// the entire line for vocabularies.
|
||||
constexpr char kUnusedLookupDelim = '\t';
|
||||
} // namespace
|
||||
|
||||
// This Op generates a vocab remapping Tensor from an old and new vocabulary
|
||||
// file that maps new IDs to old IDs.
|
||||
class GenerateVocabRemappingOp : public OpKernel {
|
||||
public:
|
||||
explicit GenerateVocabRemappingOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("new_vocab_offset", &new_vocab_offset_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_new_vocab", &num_new_vocab_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor* new_vocab_file_tensor;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input("new_vocab_file", &new_vocab_file_tensor));
|
||||
OP_REQUIRES(context,
|
||||
TensorShapeUtils::IsScalar(new_vocab_file_tensor->shape()),
|
||||
errors::InvalidArgument(
|
||||
"new_vocab_file should be a single string, but got ",
|
||||
new_vocab_file_tensor->shape().DebugString()));
|
||||
|
||||
// Build a new ID->token lookup table.
|
||||
const string& new_vocab_filename =
|
||||
new_vocab_file_tensor->scalar<string>()();
|
||||
OP_REQUIRES(context, !new_vocab_filename.empty(),
|
||||
errors::InvalidArgument("new vocab filename cannot be empty."));
|
||||
lookup::HashTable<int64, string>* new_vocab_table =
|
||||
new lookup::HashTable<int64, string>(context, this);
|
||||
core::ScopedUnref unref_new(new_vocab_table);
|
||||
// Note: we pass -1 (unknown) for vocab_size, which is supposed to be the
|
||||
// total elements in file. This is different from num_new_vocab_, which
|
||||
// accounts for partitioning.
|
||||
OP_REQUIRES_OK(context, lookup::InitializeTableFromTextFile(
|
||||
new_vocab_filename,
|
||||
-1, // vocab_size
|
||||
kUnusedLookupDelim,
|
||||
-1, // key_index, use the line number.
|
||||
-2, // value_index, use the whole line/token.
|
||||
context->env(), new_vocab_table));
|
||||
OP_REQUIRES(context,
|
||||
new_vocab_offset_ + num_new_vocab_ <= new_vocab_table->size(),
|
||||
errors::InvalidArgument("lookup table size must be larger than "
|
||||
"last new vocab entry's line"));
|
||||
|
||||
const Tensor* old_vocab_file_tensor;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input("old_vocab_file", &old_vocab_file_tensor));
|
||||
OP_REQUIRES(context,
|
||||
TensorShapeUtils::IsScalar(old_vocab_file_tensor->shape()),
|
||||
errors::InvalidArgument(
|
||||
"old_vocab_file should be a single string, but got ",
|
||||
old_vocab_file_tensor->shape().DebugString()));
|
||||
// Build a token->old ID lookup table.
|
||||
const string& old_vocab_filename =
|
||||
old_vocab_file_tensor->scalar<string>()();
|
||||
OP_REQUIRES(context, !old_vocab_filename.empty(),
|
||||
errors::InvalidArgument("new vocab filename cannot be empty."));
|
||||
lookup::HashTable<string, int64>* old_vocab_table =
|
||||
new lookup::HashTable<string, int64>(context, this);
|
||||
core::ScopedUnref unref_old(old_vocab_table);
|
||||
// Note: we pass -1 (unknown) for vocab_size, which is supposed to be the
|
||||
// total elements in file. This is different from num_new_vocab_, which
|
||||
// accounts for partitioning.
|
||||
OP_REQUIRES_OK(context, lookup::InitializeTableFromTextFile(
|
||||
old_vocab_filename,
|
||||
-1, // vocab_size
|
||||
kUnusedLookupDelim,
|
||||
-2, // key_index, use the whole line/token.
|
||||
-1, // value_index, use the line number.
|
||||
context->env(), old_vocab_table));
|
||||
|
||||
// Fill out new_ids = [new_vocab_offset, new_vocab_offset + 1, ...,
|
||||
// new_vocab_offset + num_new_vocab_]
|
||||
// The double look-up requires a few temporary Tensors.
|
||||
Tensor new_ids;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_temp(DT_INT64, TensorShape({num_new_vocab_}),
|
||||
&new_ids));
|
||||
auto new_ids_vec = new_ids.vec<int64>();
|
||||
// Note that we should always be able to find tokens for all new IDs, given
|
||||
// that the lookup table is constructed with the vocabulary file itself
|
||||
// (see the check on offset and table size post-initialization).
|
||||
Tensor default_token;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_temp(
|
||||
DT_STRING, TensorShape({num_new_vocab_}), &default_token));
|
||||
auto default_token_vec = default_token.vec<string>();
|
||||
default_token_vec.setConstant("" /* NOT_FOUND_TOKEN */);
|
||||
|
||||
Tensor default_id;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_temp(DT_INT64, TensorShape({num_new_vocab_}),
|
||||
&default_id));
|
||||
auto default_id_vec = default_id.vec<int64>();
|
||||
default_id_vec.setConstant(-1 /* NOT_FOUND_ID */);
|
||||
|
||||
for (int i = 0; i < num_new_vocab_; ++i) {
|
||||
new_ids_vec(i) = static_cast<int64>(i + new_vocab_offset_);
|
||||
}
|
||||
Tensor tokens;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_temp(
|
||||
DT_STRING, TensorShape({num_new_vocab_}), &tokens));
|
||||
Tensor* remapping;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(
|
||||
"remapping", TensorShape({num_new_vocab_}), &remapping));
|
||||
// In the corner case where num_new_vocab_ is 0 (we are dealing with an
|
||||
// OOV-only partition), we should not do this lookup.
|
||||
if (num_new_vocab_ != 0) {
|
||||
OP_REQUIRES_OK(context, new_vocab_table->Find(context, new_ids, &tokens,
|
||||
default_token));
|
||||
OP_REQUIRES_OK(context, old_vocab_table->Find(context, tokens, remapping,
|
||||
default_id));
|
||||
}
|
||||
// Iterate through remapping to calculate num_present.
|
||||
const auto remapping_vec = remapping->vec<int64>();
|
||||
int num_present = 0;
|
||||
for (int i = 0; i < num_new_vocab_; ++i) {
|
||||
if (remapping_vec(i) != -1 /* NOT_FOUND_ID */) {
|
||||
++num_present;
|
||||
}
|
||||
}
|
||||
Tensor* num_present_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output("num_present", TensorShape({}),
|
||||
&num_present_t));
|
||||
num_present_t->scalar<int>()() = num_present;
|
||||
}
|
||||
|
||||
private:
|
||||
int new_vocab_offset_;
|
||||
int num_new_vocab_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("GenerateVocabRemapping").Device(DEVICE_CPU),
|
||||
GenerateVocabRemappingOp);
|
||||
|
||||
} // namespace tensorflow
|
tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc (new file, 259 lines)
@@ -0,0 +1,259 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This op loads a rank-2 Tensor (matrix) from a TensorFlow checkpoint (V2) and
|
||||
// swaps around the rows/columns according to row_remapping/col_remapping.
|
||||
// "Missing" cells are initialized with values from initializing_values.
|
||||
class LoadAndRemapMatrixOp : public OpKernel {
|
||||
public:
|
||||
explicit LoadAndRemapMatrixOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_rows", &num_rows_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_cols", &num_cols_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// Checks what we're remapping and inverts the relevant remapping Tensors to
|
||||
// be maps with key = old ID, value = new ID.
|
||||
std::vector<std::pair<int64, int64>> old_row_to_new_row_pairs;
|
||||
std::vector<bool> row_id_present(num_rows_);
|
||||
const Tensor* row_remapping_t;
|
||||
OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t));
|
||||
const auto row_remapping = row_remapping_t->vec<int64>();
|
||||
OP_REQUIRES(context, row_remapping.size() == num_rows_,
|
||||
errors::InvalidArgument(strings::StrCat(
|
||||
"Size of row_remapping is ", row_remapping.size(),
|
||||
" intead of being equal to num_rows=", num_rows_)));
|
||||
old_row_to_new_row_pairs.reserve(num_rows_);
|
||||
for (int i = 0; i < row_remapping.size(); ++i) {
|
||||
if (row_remapping(i) < 0) continue;
|
||||
row_id_present[i] = true;
|
||||
old_row_to_new_row_pairs.push_back(std::make_pair(row_remapping(i), i));
|
||||
}
|
||||
|
||||
// Processes the remapping for columns.
|
||||
std::unordered_map<int64, int64> old_col_to_new_col_map;
|
||||
std::vector<bool> col_id_present(num_cols_);
|
||||
const Tensor* col_remapping_t;
|
||||
OP_REQUIRES_OK(context, context->input("col_remapping", &col_remapping_t));
|
||||
const auto col_remapping = col_remapping_t->vec<int64>();
|
||||
// Note that we always "remap rows", even when the row vocabulary does
|
||||
// not change, because partitioning requires a mapping from partitioned
|
||||
// Variables to the full checkpoints we load.
|
||||
const bool remap_cols = col_remapping.size() > 0;
|
||||
if (remap_cols) {
|
||||
OP_REQUIRES(
|
||||
context, col_remapping.size() == num_cols_,
|
||||
errors::InvalidArgument(strings::StrCat(
|
||||
"Provided col_remapping, but its size is ", col_remapping.size(),
|
||||
" instead of being equal to num_cols=", num_cols_)));
|
||||
for (int i = 0; i < col_remapping.size(); ++i) {
|
||||
const int64 old_col = col_remapping(i);
|
||||
if (old_col < 0) continue;
|
||||
col_id_present[i] = true;
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
gtl::InsertIfNotPresent(&old_col_to_new_col_map, old_col, i),
|
||||
errors::Unimplemented(strings::StrCat(
|
||||
"Old column ID ", old_col, " is mapped to both new column ID ",
|
||||
old_col_to_new_col_map[old_col], " and ", i,
|
||||
", which is not currently supported - but could be "
|
||||
"implemented.")));
|
||||
}
|
||||
} else {
|
||||
col_id_present.clear();
|
||||
col_id_present.resize(num_cols_, true);
|
||||
}
|
||||
|
||||
// Processes the checkpoint source and the provided Tensor name.
|
||||
const Tensor* ckpt_path_t;
|
||||
OP_REQUIRES_OK(context, context->input("ckpt_path", &ckpt_path_t));
|
||||
const string ckpt_path = *(ckpt_path_t->scalar<string>().data());
|
||||
const Tensor* old_tensor_name_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input("old_tensor_name", &old_tensor_name_t));
|
||||
const string old_tensor_name =
|
||||
*(old_tensor_name_t->scalar<string>().data());
|
||||
|
||||
LOG(INFO) << "Processing checkpoint : " << ckpt_path;
|
||||
BundleReader reader(context->env(), ckpt_path);
|
||||
OP_REQUIRES_OK(context, reader.status());
|
||||
|
||||
DataType tensor_type;
|
||||
TensorShape tensor_shape;
|
||||
OP_REQUIRES_OK(context, reader.LookupDtypeAndShape(
|
||||
old_tensor_name, &tensor_type, &tensor_shape));
|
||||
OP_REQUIRES(context, tensor_type == DT_FLOAT,
|
||||
errors::InvalidArgument(strings::StrCat(
|
||||
"Tensor ", old_tensor_name, " has invalid type ",
|
||||
DataTypeString(tensor_type), " instead of expected type ",
|
||||
DataTypeString(DT_FLOAT))));
|
||||
// This op is limited to loading Tensors of rank 2 (matrices).
|
||||
OP_REQUIRES(
|
||||
context, tensor_shape.dims() == 2,
|
||||
errors::InvalidArgument(strings::StrCat(
|
||||
"Tensor ", old_tensor_name, " has shape ",
|
||||
tensor_shape.DebugString(), " of invalid rank ",
|
||||
tensor_shape.dims(), " instead of expected shape of rank 2.")));
|
||||
|
||||
if (!remap_cols) {
|
||||
// TODO(weiho): Consider relaxing this restriction to allow partial column
|
||||
// loading (even when no column remapping is specified) if there turns out
|
||||
// to be a use case for it.
|
||||
OP_REQUIRES(context, num_cols_ == tensor_shape.dim_size(1),
|
||||
errors::InvalidArgument(strings::StrCat(
|
||||
"Tensor ", old_tensor_name, " has shape ",
|
||||
tensor_shape.DebugString(),
|
||||
", where the size of its 2nd dimension is ",
|
||||
tensor_shape.dim_size(1),
|
||||
" instead of being equal to num_cols=", num_cols_)));
|
||||
}
|
||||
|
||||
// Uses TensorSlice to selectively read rows of interest from the old
|
||||
// tensor. Given BundleReader's use of RandomAccessFile and InputBuffer,
|
||||
// there shouldn't be too many more additional disk seeks when compared to
|
||||
// loading the old tensor in chunks, once we sort the row IDs. Even if there
|
||||
// are locality concerns with some reading patterns, that just means if we
|
||||
// had read it in chunks, then we would have had to read, copy, and process
|
||||
// then discard many redundant rows - so we should come out ahead this way.
|
||||
// In addition, this frees us from having to hold the entire old tensor in
|
||||
// memory.
|
||||
std::sort(old_row_to_new_row_pairs.begin(), old_row_to_new_row_pairs.end());
|
||||
std::vector<TensorSlice> tensor_slices;
|
||||
tensor_slices.reserve(old_row_to_new_row_pairs.size());
|
||||
TensorSlice slice(tensor_shape.dims());
|
||||
for (const auto& pair : old_row_to_new_row_pairs) {
|
||||
OP_REQUIRES(
|
||||
context, pair.first < tensor_shape.dim_size(0),
|
||||
errors::InvalidArgument(strings::StrCat(
|
||||
"Trying to read row ", pair.first, " from tensor ",
|
||||
old_tensor_name, ", which only has ", tensor_shape.dim_size(0),
|
||||
" rows (with shape ", tensor_shape.DebugString(), ").")));
|
||||
slice.set_start(0, pair.first);
|
||||
slice.set_length(0, 1);
|
||||
tensor_slices.push_back(slice);
|
||||
}
|
||||
|
||||
// Allocates the output matrix.
|
||||
Tensor* output_matrix_t = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output("output_matrix",
|
||||
TensorShape({num_rows_, num_cols_}),
|
||||
&output_matrix_t));
|
||||
auto output_matrix = output_matrix_t->matrix<float>();
|
||||
|
||||
// Iterates through tensor slices and copies over values from the old tensor
|
||||
// to the output matrix.
|
||||
Tensor loaded_tensor_t(DT_FLOAT,
|
||||
TensorShape({1, tensor_shape.dim_size(1)}));
|
||||
for (int i = 0; i < tensor_slices.size(); ++i) {
|
||||
const int64 new_row = old_row_to_new_row_pairs[i].second;
|
||||
if (i % 500000 == 0) {
|
||||
LOG(INFO) << "Processing slice " << i << " of " << tensor_slices.size()
|
||||
<< " - corresponding to old row "
|
||||
<< old_row_to_new_row_pairs[i].first << " of "
|
||||
<< tensor_shape.dim_size(0);
|
||||
}
|
||||
OP_REQUIRES_OK(context,
|
||||
reader.LookupSlice(old_tensor_name, tensor_slices[i],
|
||||
&loaded_tensor_t));
|
||||
|
||||
// Copies over the row element-by-element, in case remapping is needed
|
||||
// along the column axis.
|
||||
const auto& loaded_tensor = loaded_tensor_t.flat<float>();
|
||||
for (int old_col = 0; old_col < loaded_tensor.size(); ++old_col) {
|
||||
int64 new_col = old_col;
|
||||
if (remap_cols) {
|
||||
const int64* new_col_ptr =
|
||||
gtl::FindOrNull(old_col_to_new_col_map, old_col);
|
||||
if (new_col_ptr == nullptr) {
|
||||
// Column remapping is specified, but this column is not found in
|
||||
// old_col_to_new_col_map, so we leave it uninitialized, to be
|
||||
// filled in with initializing_values later.
|
||||
continue;
|
||||
}
|
||||
new_col = *new_col_ptr;
|
||||
}
|
||||
|
||||
OP_REQUIRES(context,
|
||||
new_row < num_rows_ && new_col < num_cols_ &&
|
||||
new_row >= 0 && new_col >= 0,
|
||||
errors::Internal(strings::StrCat(
|
||||
"new_row=", new_row, " and new_col=", new_col,
|
||||
" should have been less than num_rows_=", num_rows_,
|
||||
" and num_cols_=", num_cols_,
|
||||
" and non-negative. This should never have happened "
|
||||
"if the code were correct. Please file a bug.")));
|
||||
output_matrix(new_row, new_col) = loaded_tensor(old_col);
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "Copied " << tensor_slices.size()
|
||||
<< " rows from old matrix (with " << tensor_shape.dim_size(0)
|
||||
<< " rows) to new matrix (with " << num_rows_ << " rows).";
|
||||
|
||||
// At this point, there are potentially whole rows/columns uninitialized
|
||||
// (corresponding to the indices where row_id_present/col_id_present are
|
||||
// false). We fill this in cell-by-cell using row_id_present and
|
||||
// col_id_present while dequeuing from the initializing_values vector.
|
||||
const Tensor* initializing_values_t;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->input("initializing_values", &initializing_values_t));
|
||||
const auto initializing_values = initializing_values_t->flat<float>();
|
||||
int64 initializing_values_index = 0;
|
||||
for (int i = 0; i < num_rows_; ++i) {
|
||||
for (int j = 0; j < num_cols_; ++j) {
|
||||
if (!row_id_present[i] || !col_id_present[j]) {
|
||||
output_matrix(i, j) = initializing_values(initializing_values_index);
|
||||
++initializing_values_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that we used all the given initializing values.
|
||||
OP_REQUIRES(
|
||||
context, initializing_values_index == initializing_values.size(),
|
||||
errors::InvalidArgument(
|
||||
"initializing_values contained ", initializing_values.size(),
|
||||
" elements, but only ", initializing_values_index,
|
||||
" elements were used to fill in missing values."));
|
||||
}
|
||||
|
||||
private:
|
||||
int64 num_rows_;
|
||||
int64 num_cols_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU),
|
||||
LoadAndRemapMatrixOp);
|
||||
|
||||
} // namespace tensorflow
|
tensorflow/contrib/framework/ops/checkpoint_ops.cc (new file, 102 lines)
@@ -0,0 +1,102 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
REGISTER_OP("LoadAndRemapMatrix")
|
||||
.Input("ckpt_path: string")
|
||||
.Input("old_tensor_name: string")
|
||||
.Input("row_remapping: int64")
|
||||
.Input("col_remapping: int64")
|
||||
.Input("initializing_values: float")
|
||||
.Attr("num_rows: int >= 0")
|
||||
.Attr("num_cols: int >= 1")
|
||||
.Output("output_matrix: float")
|
||||
// TODO(b/30502450): Setting the op as being stateful prevents it from being
|
||||
// executed more often than expected (possibly due to stateful ops not being
|
||||
// subject to constant folding?). This op is usually slow and may require
|
||||
// multiple disk reads, so we want to minimize the number of times it's
|
||||
// executed redundantly.
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
|
||||
int64 num_rows;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("num_rows", &num_rows));
|
||||
int64 num_cols;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("num_cols", &num_cols));
|
||||
|
||||
c->set_output(0, c->Matrix(num_rows, num_cols));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Loads a 2-D (matrix) `Tensor` with name `old_tensor_name` from the checkpoint
|
||||
at `ckpt_path` and potentially reorders its rows and columns using the
|
||||
specified remappings.
|
||||
|
||||
Most users should use one of the wrapper initializers (such as
|
||||
`tf.contrib.framework.load_and_remap_matrix_initializer`) instead of this
|
||||
function directly.
|
||||
|
||||
The remappings are 1-D tensors with the following properties:
|
||||
|
||||
* `row_remapping` must have exactly `num_rows` entries. Row `i` of the output
|
||||
matrix will be initialized from the row corresponding to index
|
||||
`row_remapping[i]` in the old `Tensor` from the checkpoint.
|
||||
* `col_remapping` must have either 0 entries (indicating that no column
|
||||
reordering is needed) or `num_cols` entries. If specified, column `j` of the
|
||||
output matrix will be initialized from the column corresponding to index
|
||||
`col_remapping[j]` in the old `Tensor` from the checkpoint.
|
||||
* A value of -1 in either of the remappings signifies a "missing" entry. In that
|
||||
case, values from the `initializing_values` tensor will be used to fill that
|
||||
missing row or column. If `row_remapping` has `r` missing entries and
|
||||
`col_remapping` has `c` missing entries, then the following condition must be
|
||||
true:
|
||||
|
||||
`(r * num_cols) + (c * num_rows) - (r * c) == len(initializing_values)`
|
||||
|
||||
The remapping tensors can be generated using the GenerateVocabRemapping op.
|
||||
|
||||
As an example, with row_remapping = [1, 0, -1], col_remapping = [0, 2, -1],
|
||||
initializing_values = [0.5, -0.5, 0.25, -0.25, 42], and w(i, j) representing
|
||||
the value from row i, column j of the old tensor in the checkpoint, the output
|
||||
matrix will look like the following:
|
||||
|
||||
[[w(1, 0), w(1, 2), 0.5],
|
||||
[w(0, 0), w(0, 2), -0.5],
|
||||
[0.25, -0.25, 42]]
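In this example, `row_remapping` has `r = 1` missing entry and `col_remapping` has
`c = 1` missing entry, so the condition above gives
`(1 * 3) + (1 * 3) - (1 * 1) = 5`, which matches `len(initializing_values) = 5`.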
|
||||
|
||||
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) from
|
||||
which the old matrix `Tensor` will be loaded.
|
||||
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
|
||||
row_remapping: An int `Tensor` of row remappings (generally created by
|
||||
`generate_vocab_remapping`). Even if no row remapping is needed, this must
|
||||
still be an index-valued Tensor (e.g. [0, 1, 2, ...]), or a shifted
|
||||
index-valued `Tensor` (e.g. [8, 9, 10, ...], for partitioned `Variables`).
|
||||
col_remapping: An int `Tensor` of column remappings (generally created by
|
||||
`generate_vocab_remapping`). May be a size-0 `Tensor` if only row remapping
|
||||
is to be done (e.g. column ordering is the same).
|
||||
initializing_values: A float `Tensor` containing values to fill in for cells
|
||||
in the output matrix that are not loaded from the checkpoint. Length must be
|
||||
exactly the same as the number of missing / new cells.
|
||||
num_rows: Number of rows (length of the 1st dimension) in the output matrix.
|
||||
num_cols: Number of columns (length of the 2nd dimension) in the output matrix.
|
||||
output_matrix: Output matrix containing existing values loaded from the
|
||||
checkpoint, and with any missing values filled in from initializing_values.
|
||||
)doc");
|
||||
} // namespace tensorflow
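For illustration only (not part of this file): the Doc-string example above can be exercised through the generated Python wrapper added elsewhere in this change. The checkpoint path and tensor name below are hypothetical placeholders.

    from tensorflow.contrib.framework.python.ops import gen_checkpoint_ops

    # Same remappings and initializing values as in the Doc-string example above.
    output_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path="/tmp/old_model/model.ckpt-100",  # hypothetical V2 checkpoint
        old_tensor_name="some_scope/old_matrix",    # hypothetical 2-D float tensor
        row_remapping=[1, 0, -1],
        col_remapping=[0, 2, -1],
        initializing_values=[0.5, -0.5, 0.25, -0.25, 42],
        num_rows=3,
        num_cols=3)
    # When evaluated in a session, output_matrix has shape [3, 3] and matches
    # the matrix shown in the Doc string above.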
|
@@ -0,0 +1,77 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
REGISTER_OP("GenerateVocabRemapping")
|
||||
.Input("new_vocab_file: string")
|
||||
.Input("old_vocab_file: string")
|
||||
.Attr("new_vocab_offset: int >= 0")
|
||||
.Attr("num_new_vocab: int >= 0")
|
||||
.Output("remapping: int64")
|
||||
.Output("num_present: int32")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
|
||||
int64 new_vocab_offset;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("new_vocab_offset", &new_vocab_offset));
|
||||
int64 num_new_vocab;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("num_new_vocab", &num_new_vocab));
|
||||
|
||||
c->set_output(0, c->Vector(num_new_vocab));
|
||||
c->set_output(1, c->Scalar());
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Given a path to new and old vocabulary files, returns a remapping Tensor of
|
||||
length `num_new_vocab`, where `remapping[i]` contains the row number in the old
|
||||
vocabulary that corresponds to row `i` in the new vocabulary (starting at line
|
||||
`new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
|
||||
in the new vocabulary is not in the old vocabulary. `new_vocab_offset` enables
|
||||
use in the partitioned variable case, and should generally be set through
|
||||
examining partitioning info. The format of the files should be a text file,
|
||||
with each line containing a single entity within the vocabulary.
|
||||
|
||||
For example, with `new_vocab_file` a text file containing each of the following
|
||||
elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3],
|
||||
`num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be
|
||||
`[0, -1, 2]`.
|
||||
|
||||
The op also returns a count of how many entries in the new vocabulary
|
||||
were present in the old vocabulary, which is used to calculate the number of
|
||||
values to initialize in a weight matrix remapping.
|
||||
|
||||
This functionality can be used to remap both row vocabularies (typically,
|
||||
features) and column vocabularies (typically, classes) from TensorFlow
|
||||
checkpoints. Note that the partitioning logic relies on contiguous vocabularies
|
||||
corresponding to div-partitioned variables. Moreover, the underlying remapping
|
||||
uses an IndexTable (as opposed to an inexact CuckooTable), so client code should
|
||||
use the corresponding index_table_from_file() as the FeatureColumn framework
|
||||
does (as opposed to tf.feature_to_id(), which uses a CuckooTable).
|
||||
|
||||
new_vocab_file: Path to the new vocab file.
|
||||
old_vocab_file: Path to the old vocab file.
|
||||
new_vocab_offset: How many entries into the new vocab file to start reading.
|
||||
num_new_vocab: Number of entries in the new vocab file to remap.
|
||||
remapping: A Tensor of length num_new_vocab where the element at index i
|
||||
is equal to the old ID that maps to the new ID i. This element is -1 for any
|
||||
new ID that is not found in the old vocabulary.
|
||||
num_present: Number of new vocab entries found in old vocab.
|
||||
)doc");
|
||||
} // namespace tensorflow
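For illustration only (not part of this file): the Doc-string example above maps onto the generated Python wrapper added in this change as follows. The vocab file paths are hypothetical placeholders.

    from tensorflow.contrib.framework.python.ops import gen_generate_vocab_remapping_ops

    # new_vocab.txt contains f0, f1, f2, f3 (one entity per line);
    # old_vocab.txt contains f1, f0, f3.
    remapping, num_present = (
        gen_generate_vocab_remapping_ops.generate_vocab_remapping(
            new_vocab_file="/tmp/new_vocab.txt",
            old_vocab_file="/tmp/old_vocab.txt",
            new_vocab_offset=1,
            num_new_vocab=3))
    # When evaluated in a session, remapping is [0, -1, 2] and num_present is 2
    # (f1 and f3 from the new vocab are present in the old vocab).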
|
@@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.framework.python.framework.checkpoint_utils import *
|
||||
from tensorflow.contrib.framework.python.framework.experimental import experimental
|
||||
from tensorflow.contrib.framework.python.framework.load_and_remap_matrix_ops import *
|
||||
from tensorflow.contrib.framework.python.framework.tensor_util import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util import decorator_utils
|
||||
|
@@ -0,0 +1,491 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Operations for generating and loading vocab remappings."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import gen_checkpoint_ops
|
||||
from tensorflow.contrib.framework.python.ops import gen_generate_vocab_remapping_ops
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
_checkpoint_ops_so = loader.load_op_library(
|
||||
resource_loader.get_path_to_datafile("_checkpoint_ops.so"))
|
||||
|
||||
ops.NotDifferentiable("GenerateVocabRemapping")
|
||||
ops.NotDifferentiable("LoadAndRemapMatrix")
|
||||
|
||||
|
||||
def _load_and_remap_matrix(ckpt_path,
|
||||
old_tensor_name,
|
||||
new_row_vocab_offset,
|
||||
num_rows_to_load,
|
||||
new_col_vocab_size,
|
||||
initializer,
|
||||
old_row_vocab_file=None,
|
||||
new_row_vocab_file=None,
|
||||
old_col_vocab_file=None,
|
||||
new_col_vocab_file=None,
|
||||
num_row_oov_buckets=0,
|
||||
num_col_oov_buckets=0):
|
||||
"""Loads a 2-D (matrix) `Tensor` from checkpoint.
|
||||
|
||||
Generates 1D-remappings for rows and columns using the
|
||||
`GenerateVocabRemapping` op, and initializes any anticipated values with the
|
||||
provided initializer. Then, uses the `LoadAndRemapMatrix` op to create a
|
||||
matrix that loads existing values from the checkpoint, while filling out
|
||||
"missing" values with the newly initialized values. See
|
||||
contrib/framework/ops/checkpoint_ops.cc for more information on the wrapped
|
||||
functionality (LoadAndRemapMatrix). This wrapper can be used to perform only
|
||||
row remapping or only col remapping. If only row remapping is desired,
|
||||
{new,old}_col_vocab_file should be `None`, and vice versa for column
|
||||
remapping.
|
||||
|
||||
NOTE: This only supports div-partitioning the vocabulary on the 1st dimension
|
||||
(row axis) via `new_row_vocab_offset`.
|
||||
|
||||
Args:
|
||||
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
|
||||
from which the old matrix `Tensor` will be loaded.
|
||||
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
|
||||
new_row_vocab_offset: A 0-indexed integer representing what line to
|
||||
start reading at in the new row vocabulary. Used for partitioned
|
||||
variables.
|
||||
num_rows_to_load: Number of rows to load for the new vocabulary (note: to
|
||||
support variable partitioning and partial loading, this does not need to
|
||||
be the same as the number of entries in `new_row_vocab_file`).
|
||||
new_col_vocab_size: Number of columns to load - should be the same as the
|
||||
number of entries in `new_col_vocab_file`, since we don't support
|
||||
partitioning along the column axis.
|
||||
initializer: Callable initializer function that accepts a 1-D tensor as the
|
||||
arg to specify the shape of the returned tensor. Used to initialize
|
||||
missing values.
|
||||
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the old row vocabulary file. Can be None, which represents no
|
||||
remapping on the row axis.
|
||||
new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
|
||||
to the new row vocabulary file. Can be None, which represents no remapping
|
||||
on the row axis - in which case, `new_row_vocab_offset` and
|
||||
`num_rows_to_load` work under the assumption that the new row vocab is the
|
||||
same as the old row vocab.
|
||||
old_col_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the old column vocabulary file. Can be None, which represents no
|
||||
remapping on the column axis.
|
||||
new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
|
||||
to the new column vocabulary file. Can be None, which represents no
|
||||
remapping on the column axis - in which case, `new_col_vocab_size` works
|
||||
under the assumption that the new col vocab is the same as the old col
|
||||
vocab.
|
||||
num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
|
||||
to append. Must be >= 0.
|
||||
num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
|
||||
columns to append. Must be >= 0.
|
||||
|
||||
Returns:
|
||||
A Tensor of shape `[num_rows_to_load + num_row_oov_buckets,
|
||||
new_col_vocab_size + num_col_oov_buckets]`, with values loaded from the
|
||||
specified tensor in the checkpoint, and any missing or OOV values
|
||||
initialized with the given `initializer`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `num_row_oov_buckets` or `num_col_oov_buckets` < 0.
|
||||
ValueError: If either `old_row_vocab_file` or `new_row_vocab_file` is
|
||||
provided, while the other is not. Same for `old_col_vocab_file` and
|
||||
`new_col_vocab_file`.
|
||||
ValueError: If neither row vocabs nor col vocabs are provided.
|
||||
"""
|
||||
if num_row_oov_buckets < 0:
|
||||
raise ValueError("num_row_oov_buckets must be >= 0, but received %d" %
|
||||
num_row_oov_buckets)
|
||||
if num_col_oov_buckets < 0:
|
||||
raise ValueError("num_col_oov_buckets must be >= 0, but received %d" %
|
||||
num_col_oov_buckets)
|
||||
|
||||
if bool(old_row_vocab_file) != bool(new_row_vocab_file):
|
||||
raise ValueError(
|
||||
"old_row_vocab_file and new_row_vocab_file must both be specified or "
|
||||
"left unspecified. old_row_vocab_file='{}', new_row_vocab_file='{}'".
|
||||
format(old_row_vocab_file, new_row_vocab_file))
|
||||
if bool(old_col_vocab_file) != bool(new_col_vocab_file):
|
||||
raise ValueError(
|
||||
"old_col_vocab_file and new_col_vocab_file must both be specified or "
|
||||
"left unspecified. old_col_vocab_file='{}', new_col_vocab_file='{}'".
|
||||
format(old_col_vocab_file, new_col_vocab_file))
|
||||
|
||||
remap_rows = new_row_vocab_file and old_row_vocab_file
|
||||
remap_cols = new_col_vocab_file and old_col_vocab_file
|
||||
if not (remap_rows or remap_cols):
|
||||
raise ValueError(
|
||||
"Must provide either row or column vocab files. If no remapping is "
|
||||
"necessary, consider using `tf.contrib.framework.init_from_checkpoint` "
|
||||
"instead.")
|
||||
|
||||
num_rows_present = num_rows_to_load
|
||||
if remap_rows:
|
||||
row_remapping, num_rows_present = (
|
||||
gen_generate_vocab_remapping_ops.generate_vocab_remapping(
|
||||
new_vocab_file=new_row_vocab_file,
|
||||
old_vocab_file=old_row_vocab_file,
|
||||
new_vocab_offset=new_row_vocab_offset,
|
||||
num_new_vocab=num_rows_to_load))
|
||||
else:
|
||||
# Even when the rows are not being reordered, we still need to generate a
|
||||
# remapping to account for initializing partitioned Variables (when
|
||||
# new_row_vocab_offset is non-zero).
|
||||
row_remapping = math_ops.range(
|
||||
new_row_vocab_offset,
|
||||
new_row_vocab_offset + num_rows_to_load,
|
||||
dtype=dtypes.int64)
|
||||
|
||||
col_remapping = []
|
||||
num_cols_present = new_col_vocab_size
|
||||
if remap_cols:
|
||||
col_remapping, num_cols_present = (
|
||||
gen_generate_vocab_remapping_ops.generate_vocab_remapping(
|
||||
new_vocab_file=new_col_vocab_file,
|
||||
old_vocab_file=old_col_vocab_file,
|
||||
new_vocab_offset=0, # Offset is unused for cols (no partitioning).
|
||||
num_new_vocab=new_col_vocab_size))
|
||||
|
||||
init_vals = initializer([
|
||||
num_rows_to_load * new_col_vocab_size -
|
||||
num_rows_present * num_cols_present, 1
|
||||
])
|
||||
return_tensor = gen_checkpoint_ops.load_and_remap_matrix(
|
||||
ckpt_path=ckpt_path,
|
||||
old_tensor_name=old_tensor_name,
|
||||
row_remapping=row_remapping,
|
||||
col_remapping=col_remapping,
|
||||
initializing_values=init_vals,
|
||||
num_rows=num_rows_to_load,
|
||||
num_cols=new_col_vocab_size)
|
||||
|
||||
# Add OOV row(s) and column(s).
|
||||
if num_row_oov_buckets > 0:
|
||||
init_row_oov_val = initializer([num_row_oov_buckets, new_col_vocab_size])
|
||||
init_row_oov_val = ops.convert_to_tensor(init_row_oov_val)
|
||||
return_tensor = array_ops.concat([return_tensor, init_row_oov_val], 0)
|
||||
if num_col_oov_buckets > 0:
|
||||
# We need to add any row OOV to the new column shape.
|
||||
init_col_oov_val = initializer(
|
||||
[num_rows_to_load + num_row_oov_buckets, num_col_oov_buckets])
|
||||
init_col_oov_val = ops.convert_to_tensor(init_col_oov_val)
|
||||
return_tensor = array_ops.concat([return_tensor, init_col_oov_val], 1)
|
||||
|
||||
return return_tensor
|
||||
|
||||
|
||||
def load_and_remap_matrix_initializer(ckpt_path,
|
||||
old_tensor_name,
|
||||
new_row_vocab_size,
|
||||
new_col_vocab_size,
|
||||
old_row_vocab_file=None,
|
||||
new_row_vocab_file=None,
|
||||
old_col_vocab_file=None,
|
||||
new_col_vocab_file=None,
|
||||
num_row_oov_buckets=0,
|
||||
num_col_oov_buckets=0,
|
||||
initializer=None):
|
||||
r"""Returns a var initializer for loading and remapping a 2-D (matrix) tensor.
|
||||
|
||||
The returned initializer loads a 2-D (matrix) `Tensor` with name
|
||||
`old_tensor_name` from the checkpoint at `ckpt_path`. It will reorder the
|
||||
rows/columns according to the specified vocab files and append additional
|
||||
out-of-vocabulary rows/columns according to the number of OOV buckets.
|
||||
|
||||
The format of the file at the `{old,new}_{row,col}_vocab_file` path should be
|
||||
a text file, with each line containing a single entity within the vocabulary.
|
||||
Let the function `line_of(f, "x")` return the 0-indexed line number of the
|
||||
entity "x" in file f, and the function `entity_at(f, i)` return the entity at
|
||||
line i of file f. Then, row i of the new output matrix will be taken from row
|
||||
`line_of(old_row_vocab_file, entity_at(new_row_vocab_file, i))` of the old
|
||||
matrix. If any entity in `new_row_vocab_file` is not found in
|
||||
`old_row_vocab_file`, that row is considered a "missing" row, and its values
|
||||
will be initialized using the `initializer` arg. The same logic also applies
|
||||
for the columns.
|
||||
|
||||
For example, assuming that:
|
||||
|
||||
* `old_row_vocab_file` contains "mercury\nvenus\nmars"
|
||||
* `new_row_vocab_file` contains "venus\njupiter\nmercury"
|
||||
* `old_col_vocab_file` contains "good\nbetter\nbest"
|
||||
* `new_col_vocab_file` contains "good\nbest\nfantastic"
|
||||
* `initializer` returns the natural numbers `[1, 2, 3, 4, ...]`
|
||||
* `w(i, j)` represents the value from row i, column j of the old matrix
|
||||
|
||||
Then the new output matrix will look like:
|
||||
|
||||
`[[w(1, 0), w(1, 2), 1],
|
||||
[2, 3, 4],
|
||||
[w(0, 0), w(0, 2), 5]]`
|
||||
|
||||
If we further specify that:
|
||||
|
||||
* `num_row_oov_buckets` == 2
|
||||
* `num_col_oov_buckets` == 1
|
||||
|
||||
Then the new output matrix will look like:
|
||||
|
||||
`[[w(1, 0), w(1, 2), 1, 12],
|
||||
[2, 3, 4, 13],
|
||||
[w(0, 0), w(0, 2), 5, 14],
|
||||
[6, 7, 8, 15],
|
||||
[9, 10, 11, 16]]`
|
||||
|
||||
If `{old,new}_row_vocab_file` are None, we assume that the old and new row
|
||||
vocab files are the same, and no row remapping is done. If
|
||||
`{old,new}_col_vocab_file` are None, we assume that the old and new column
|
||||
vocab files are the same, and no column remapping is done.
|
||||
|
||||
The returned initializer only supports div-partitioning along the row axis. It
|
||||
does not support partitioning along the column axis or mod-partitioning.
|
||||
|
||||
NOTE: When this is used to warm-start variables, client code should use
|
||||
`tf.lookup.index_table_from_tensor()` like
|
||||
contrib/layers/python/layers/feature_column.py does, as opposed to
|
||||
`tf.feature_to_id()` - in order to ensure the underlying lookup tables are the
|
||||
same.
|
||||
|
||||
Args:
|
||||
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
|
||||
from which the old matrix `Tensor` will be loaded.
|
||||
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
|
||||
new_row_vocab_size: `int` specifying the number of entries in
|
||||
`new_row_vocab_file`. If no row remapping is needed (no row vocab
|
||||
provided), this should be equal to the number of rows to load from the old
|
||||
matrix (which can theoretically be smaller than the number of rows in the
|
||||
old matrix).
|
||||
new_col_vocab_size: `int` specifying the number of entries in
|
||||
`new_col_vocab_file`. If no column remapping is needed (no column vocab
|
||||
provided), this should be equal to the number of columns in the old
|
||||
matrix.
|
||||
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the old row vocabulary file. Can be None, which represents no
|
||||
remapping on the row axis.
|
||||
new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
|
||||
to the new row vocabulary file. Can be None, which represents no remapping
|
||||
on the row axis.
|
||||
old_col_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
      path to the old column vocabulary file. Can be None, which represents no
      remapping on the column axis.
    new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
      to the new column vocabulary file. Can be None, which represents no
      remapping on the column axis.
    num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
      to append. Must be >= 0.
    num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
      columns to append. Must be >= 0.
    initializer: Initializer function to initialize missing values. Accepts a
      1-D tensor as the arg to specify the shape of the returned tensor. If
      `None`, defaults to using `zeros_initializer()`.

  Returns:
    A variable initializer function that should be used to initialize a
    (potentially partitioned) `Variable` whose complete shape is
    `[new_row_vocab_size + num_row_oov_buckets, new_col_vocab_size +
    num_col_oov_buckets]`.

  Raises:
    TypeError: If `initializer` is specified but not callable.
  """
  if initializer is None:
    # TODO(b/25671353): Consider using sqrt(6/(fan_in + fan_out)) instead, from
    # Glorot and Bengio, 2010.
    initializer = init_ops.zeros_initializer()

  if not callable(initializer):
    raise TypeError(
        "initializer must be callable, instead of being {} of type {}.".format(
            initializer, type(initializer)))

  def _initializer(shape, dtype=dtypes.float32, partition_info=None):
    """Variable initializer.

    Args:
      shape: Shape of `Tensor` to return. Should include OOV on both axes.
      dtype: Must be float32.
      partition_info: variable_scope._PartitionInfo.

    Returns:
      `Tensor` of shape `shape`.

    Raises:
      TypeError: If `dtype` is anything other than float32.
      ValueError: For shape mismatch upon invocation.
    """
    # Sanity checks.
    if dtype != dtypes.float32:
      raise TypeError(
          "Currently, only float32 is supported. Received dtype: {}".format(
              dtype))
    if len(shape) != 2:
      raise ValueError("Expected 2-dim shape, but received: {}".format(shape))
    if shape[0] <= 0:
      raise ValueError(
          "Expected 1st dim of shape to be > 0, but received shape: {}".format(
              shape))
    if shape[1] != (new_col_vocab_size + num_col_oov_buckets):
      raise ValueError(
          "Expected 2nd dim of shape to be new_col_vocab_size ({}) + "
          "num_col_oov_buckets ({}) = {}, but received shape: {}".format(
              new_col_vocab_size, num_col_oov_buckets,
              new_col_vocab_size + num_col_oov_buckets, shape))

    offset = 0
    if partition_info is not None:
      offset = partition_info.single_offset(shape)

    if offset + shape[0] > new_row_vocab_size + num_row_oov_buckets:
      raise ValueError(
          "Trying to initialize {} additional rows after {} rows have already "
          "been initialized, which would exceed expected total row count of "
          "new_row_vocab_size ({}) + num_row_oov_buckets ({}) = {}.".format(
              shape[0], offset, new_row_vocab_size, num_row_oov_buckets,
              new_row_vocab_size + num_row_oov_buckets))

    row_oov_buckets_to_use = min(shape[0],
                                 max(0, offset + shape[0] - new_row_vocab_size))
    num_rows_to_load = shape[0] - row_oov_buckets_to_use

    return _load_and_remap_matrix(
        ckpt_path=ckpt_path,
        old_tensor_name=old_tensor_name,
        new_row_vocab_offset=offset,
        num_rows_to_load=num_rows_to_load,
        new_col_vocab_size=new_col_vocab_size,
        initializer=initializer,
        old_row_vocab_file=old_row_vocab_file,
        new_row_vocab_file=new_row_vocab_file,
        old_col_vocab_file=old_col_vocab_file,
        new_col_vocab_file=new_col_vocab_file,
        num_row_oov_buckets=row_oov_buckets_to_use,
        num_col_oov_buckets=num_col_oov_buckets)

  return _initializer

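(Illustrative sketch, not part of this change: the snippet below shows one way the initializer returned by `load_and_remap_matrix_initializer()` might be consumed, mirroring the partitioned-variable usage exercised in the tests added later in this commit. The checkpoint path, vocabulary file paths, variable name, and sizes are hypothetical placeholders.)

from tensorflow.contrib import framework as contrib_framework
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope

loading_initializer = contrib_framework.load_and_remap_matrix_initializer(
    ckpt_path=['/tmp/model.ckpt'],  # Hypothetical checkpoint path.
    old_tensor_name='some_scope/embeddings',
    new_row_vocab_size=5,
    new_col_vocab_size=4,
    old_row_vocab_file='/tmp/old_row_vocab.txt',  # Hypothetical vocab files.
    new_row_vocab_file='/tmp/new_row_vocab.txt',
    old_col_vocab_file='/tmp/old_col_vocab.txt',
    new_col_vocab_file='/tmp/new_col_vocab.txt',
    num_row_oov_buckets=1,
    num_col_oov_buckets=1)

# The variable's complete shape must be [new_row_vocab_size +
# num_row_oov_buckets, new_col_vocab_size + num_col_oov_buckets] = [6, 5].
remapped_matrix = variable_scope.get_variable(
    name='obtained_weight_matrix',
    shape=[6, 5],
    initializer=loading_initializer,
    partitioner=partitioned_variables.fixed_size_partitioner(2))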
def load_embedding_initializer(ckpt_path,
                               embedding_tensor_name,
                               new_vocab_size,
                               embedding_dim,
                               old_vocab_file,
                               new_vocab_file,
                               num_oov_buckets=0,
                               initializer=None):
  """Returns a variable initializer for loading pre-trained embeddings.

  Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
  embedding weights and remapping according to the provided vocab files. See
  docs for `load_and_remap_matrix_initializer()` for more details.

  NOTE: Only for use with div-partitioned variables / vocabularies.

  Args:
    ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
      from which the old matrix `Tensor` will be loaded.
    embedding_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
    new_vocab_size: Number of entries in the new vocab.
    embedding_dim: `int` specifying the dimension of the embedding vectors from
      the checkpoint. Must match the number of columns in the old embedding
      matrix.
    old_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the old vocabulary file.
    new_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the new vocabulary file.
    num_oov_buckets: `int` specifying the number of out-of-vocabulary
      buckets to use. Must be >= 0.
    initializer: Initializer function that accepts a 1-D tensor as the arg to
      specify the shape of the returned tensor. If `None`, defaults to using
      `truncated_normal_initializer()`.

  Returns:
    A variable initializer function.
  """
  if initializer is None:
    # TODO(b/25671353): This should be kept in sync with the stddev used by
    # feature_column.py's _EmbeddingColumn.
    initializer = init_ops.truncated_normal_initializer(
        stddev=1.0 /
        math_ops.sqrt(math_ops.cast(embedding_dim, dtypes.float32)))

  return load_and_remap_matrix_initializer(
      ckpt_path=ckpt_path,
      old_tensor_name=embedding_tensor_name,
      new_row_vocab_size=new_vocab_size,
      new_col_vocab_size=embedding_dim,
      old_row_vocab_file=old_vocab_file,
      new_row_vocab_file=new_vocab_file,
      old_col_vocab_file=None,
      new_col_vocab_file=None,
      num_row_oov_buckets=num_oov_buckets,
      num_col_oov_buckets=0,
      initializer=initializer)

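(Illustrative sketch, not part of this change: a minimal embedding warm-starting example for `load_embedding_initializer()`. The vocab sizes, checkpoint tensor name, and file paths are hypothetical placeholders.)

from tensorflow.contrib import framework as contrib_framework
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope

embedding_loading_initializer = contrib_framework.load_embedding_initializer(
    ckpt_path=['/tmp/model.ckpt'],  # Hypothetical checkpoint path.
    embedding_tensor_name='old_scope/embeddings',  # Hypothetical tensor name.
    new_vocab_size=100,
    embedding_dim=16,
    old_vocab_file='/tmp/old_vocab.txt',  # Hypothetical vocab files.
    new_vocab_file='/tmp/new_vocab.txt',
    num_oov_buckets=1)

# Rows whose tokens appear in both vocab files are copied from the checkpoint;
# brand-new tokens and the OOV bucket fall back to the default
# truncated_normal_initializer().
embeddings = variable_scope.get_variable(
    name='embeddings',
    shape=[100 + 1, 16],  # [new_vocab_size + num_oov_buckets, embedding_dim].
    initializer=embedding_loading_initializer,
    partitioner=partitioned_variables.fixed_size_partitioner(2))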
def load_linear_multiclass_bias_initializer(ckpt_path,
                                            bias_tensor_name,
                                            new_class_vocab_size,
                                            old_class_vocab_file,
                                            new_class_vocab_file,
                                            num_class_oov_buckets=0,
                                            initializer=None):
  """Loads pre-trained multi-class biases for linear models from checkpoint.

  Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
  multi-class bias and remapping according to the provided vocab files. See
  docs for `load_and_remap_matrix_initializer()` for more details. In this
  case, the provided row_vocab is the class vocabulary, and the expected shape
  is `[new_class_vocab_size, 1]`.

  Args:
    ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
      from which the old matrix `Tensor` will be loaded.
    bias_tensor_name: Tensor name to load from in the checkpoints.
    new_class_vocab_size: Number of entries in the new class vocab.
    old_class_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the old class vocabulary file.
    new_class_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the new class vocabulary file.
    num_class_oov_buckets: `int` specifying the number of out-of-vocabulary
      buckets to use for the classes. Must be >= 0.
    initializer: Initializer function that accepts a 1-D tensor as the arg to
      specify the shape of the returned tensor. If `None`, defaults to using
      `zeros_initializer()`.

  Returns:
    A variable initializer function.
  """
  # Linear multi-class biases should be zero-initialized.
  if initializer is None:
    initializer = init_ops.zeros_initializer()

  return load_and_remap_matrix_initializer(
      ckpt_path=ckpt_path,
      old_tensor_name=bias_tensor_name,
      new_row_vocab_size=new_class_vocab_size,
      new_col_vocab_size=1,
      old_row_vocab_file=old_class_vocab_file,
      new_row_vocab_file=new_class_vocab_file,
      old_col_vocab_file=None,
      new_col_vocab_file=None,
      num_row_oov_buckets=num_class_oov_buckets,
      num_col_oov_buckets=0,
      initializer=initializer)
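(Illustrative sketch, not part of this change: remapping a linear model's per-class biases with `load_linear_multiclass_bias_initializer()`. The checkpoint path, bias tensor name, and class vocab files are hypothetical placeholders.)

from tensorflow.contrib import framework as contrib_framework
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope

bias_loading_initializer = (
    contrib_framework.load_linear_multiclass_bias_initializer(
        ckpt_path=['/tmp/model.ckpt'],  # Hypothetical checkpoint path.
        bias_tensor_name='some_scope/bias',  # Hypothetical tensor name.
        new_class_vocab_size=4,
        old_class_vocab_file='/tmp/old_classes.txt',  # Hypothetical vocab files.
        new_class_vocab_file='/tmp/new_classes.txt',
        num_class_oov_buckets=1))

# Biases are treated as a [num_classes, 1] matrix; classes missing from the old
# vocab and the OOV bucket default to zeros_initializer().
class_bias = variable_scope.get_variable(
    name='class_bias',
    shape=[4 + 1, 1],  # [new_class_vocab_size + num_class_oov_buckets, 1].
    initializer=bias_loading_initializer,
    partitioner=partitioned_variables.fixed_size_partitioner(3))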
@ -0,0 +1,558 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for the op to generate vocab remapping."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np

from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework.python.framework import load_and_remap_matrix_ops
from tensorflow.contrib.framework.python.ops import gen_checkpoint_ops
from tensorflow.contrib.framework.python.ops.gen_generate_vocab_remapping_ops import generate_vocab_remapping
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.training import saver

FLAGS = flags.FLAGS
_TESTDATA_PATH = 'contrib/framework/testdata'

class GenerateVocabRemappingTest(test.TestCase):
  """Tests for the generate_vocab_remapping() method."""

  def setUp(self):
    self.new_vocab_file = os.path.join(
        test.test_src_dir_path(_TESTDATA_PATH), 'keyword_shifted.txt')
    self.old_vocab_file = os.path.join(
        test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt')

  def test_generate_remapping_with_no_vocab_changes(self):
    """Tests where vocab does not change at all."""
    remapping, num_present = generate_vocab_remapping(
        new_vocab_file=self.old_vocab_file,
        old_vocab_file=self.old_vocab_file,
        num_new_vocab=3,
        new_vocab_offset=0)
    expected_remapping = range(0, 3)
    expected_num_present = 3
    with self.test_session():
      self.assertAllEqual(expected_remapping, remapping.eval())
      self.assertAllEqual(expected_num_present, num_present.eval())

  def test_generate_remapping_with_shifted_vocab(self):
    """Tests where vocab is the same, but shifted / ordered differently."""
    remapping, num_present = generate_vocab_remapping(
        new_vocab_file=self.new_vocab_file,
        old_vocab_file=self.old_vocab_file,
        num_new_vocab=3,
        new_vocab_offset=0)
    expected_remapping = [2, 0, 1]
    expected_num_present = 3
    with self.test_session():
      self.assertAllEqual(expected_remapping, remapping.eval())
      self.assertAllEqual(expected_num_present, num_present.eval())

  def test_generate_remapping_with_offset(self):
    """Tests offset and num_new_vocab logic."""
    remapping, num_present = generate_vocab_remapping(
        new_vocab_file=self.new_vocab_file,
        old_vocab_file=self.old_vocab_file,
        num_new_vocab=1,
        new_vocab_offset=1)
    expected_remapping = [0]
    expected_num_present = 1
    with self.test_session():
      self.assertAllEqual(expected_remapping, remapping.eval())
      self.assertAllEqual(expected_num_present, num_present.eval())

class LoadAndRemapMatrixTest(test.TestCase):
  """Tests for the load_and_remap_matrix() op."""

  def setUp(self):
    ops.reset_default_graph()
    self.old_num_rows = 5
    self.old_num_cols = 16
    self.matrix_value = np.reshape(
        range(0, self.old_num_rows * self.old_num_cols), (self.old_num_rows,
                                                          self.old_num_cols))
    with variable_scope.variable_scope('some_scope'):
      matrix = variable_scope.get_variable(
          'matrix',
          dtype=dtypes.float32,
          initializer=constant_op.constant(
              self.matrix_value, dtype=dtypes.float32))
      self.old_tensor_name = 'some_scope/matrix'

    save = saver.Saver([matrix])
    with self.test_session() as sess:
      variables.global_variables_initializer().run()
      self.bundle_file = os.path.join(test.get_temp_dir(), 'bundle_checkpoint')
      save.save(sess, self.bundle_file)

  def test_load_and_remap_no_missing(self):
    """Tests the op's load and remap where there are no missing entries."""

    # No column remapping, new weight matrix has second row, then first row.
    row_remapping = [1, 0]
    remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=row_remapping,
        col_remapping=[],
        initializing_values=[],
        num_rows=2,
        num_cols=self.old_num_cols)
    with self.test_session():
      self.assertAllClose(self.matrix_value[row_remapping],
                          remapped_weight_matrix.eval())

    # No row remapping, new weight matrix has third col, then first col.
    row_remapping = list(range(self.old_num_rows))
    col_remapping = [2, 0]
    remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=row_remapping,
        col_remapping=col_remapping,
        initializing_values=[],
        num_rows=len(row_remapping),
        num_cols=len(col_remapping))
    with self.test_session():
      self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
                          remapped_weight_matrix.eval())

    # Both row and column remappings.
    row_remapping = [1, 0, 4]
    col_remapping = [1, 15]
    remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=row_remapping,
        col_remapping=col_remapping,
        initializing_values=[],
        num_rows=len(row_remapping),
        num_cols=len(col_remapping))
    with self.test_session():
      self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
                          remapped_weight_matrix.eval())

  def test_load_and_remap_with_init(self):
    """Tests the op's load and remap where there are missing entries."""
    init_val = 42
    remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[2, -1, 0],
        col_remapping=[1, -1],
        initializing_values=[init_val] * 4,
        num_rows=3,
        num_cols=2)

    expected_remapped_weight_matrix = np.reshape(
        [33, init_val, init_val, init_val, 1, init_val], [3, 2])

    with self.test_session():
      self.assertAllClose(expected_remapped_weight_matrix,
                          remapped_weight_matrix.eval())

  def test_load_and_remap_all_missing_rows(self):
    """Tests when all the rows are missing and need to be initialized."""
    num_rows = 7
    initializing_values = [42] * num_rows * self.old_num_cols
    remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[-1] * num_rows,
        col_remapping=[],
        initializing_values=initializing_values,
        num_rows=num_rows,
        num_cols=self.old_num_cols)
    with self.test_session():
      self.assertAllClose(
          np.reshape(initializing_values, (num_rows, self.old_num_cols)),
          remapped_weight_matrix.eval())

  def test_load_and_remap_all_missing_rows_and_cols(self):
    """Tests when all the rows & cols are missing and need to be initialized."""
    num_rows = 7
    num_cols = 4
    initializing_values = [42] * num_rows * num_cols
    remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[-1] * num_rows,
        col_remapping=[-1] * num_cols,
        initializing_values=initializing_values,
        num_rows=num_rows,
        num_cols=num_cols)
    with self.test_session():
      self.assertAllClose(
          np.reshape(initializing_values, (num_rows, num_cols)),
          remapped_weight_matrix.eval())

  def test_load_and_remap_duplicate_row_remapping(self):
    """Tests when an old row maps to multiple new rows.

    (This should usually not happen when using public APIs).
    """
    row_remapping = [1, 0, 0, 0, 1, 2]
    remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=row_remapping,
        col_remapping=[],
        initializing_values=[],
        num_rows=len(row_remapping),
        num_cols=self.old_num_cols)
    with self.test_session():
      self.assertAllClose(self.matrix_value[row_remapping],
                          remapped_weight_matrix.eval())

  def test_load_and_remap_invalid_col_remapping(self):
    """Tests that an error is raised when an old col maps to multiple new cols.

    (This should usually not happen when using public APIs).
    """
    col_remapping = [1, 0, 0, 0, 1, 2]
    remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=list(range(self.old_num_rows)),
        col_remapping=col_remapping,
        initializing_values=[],
        num_rows=self.old_num_rows,
        num_cols=len(col_remapping))
    with self.test_session(), self.assertRaises(errors.UnimplementedError):
      remapped_weight_matrix.eval()

class LoadAndRemapWrappersTest(test.TestCase):
  """Tests for the functionality of the Python wrappers."""

  def setUp(self):
    self.bundle_file = os.path.join(
        test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint')
    self.new_feature_vocab_file = os.path.join(
        test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint_vocab.txt')
    self.old_feature_vocab_file = os.path.join(
        test.test_src_dir_path(_TESTDATA_PATH),
        'bundle_checkpoint_vocab_with_oov.txt')
    self.new_class_vocab_file = os.path.join(
        test.test_src_dir_path(_TESTDATA_PATH), 'keyword_new.txt')
    self.old_class_vocab_file = os.path.join(
        test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt')
    self.init_val = 42

    def _init_val_initializer(shape, dtype=None, partition_info=None):
      del dtype, partition_info  # Unused by this unit-testing initializer.
      return array_ops.tile(
          constant_op.constant([[self.init_val]], dtype=dtypes.float32), shape)

    self.initializer = _init_val_initializer

  def test_load_and_remap_matrix(self):
    """Tests the end-to-end loading / remapping of weights."""
    # load_and_remap_matrix() is the generalized wrapper that takes in row and
    # column vocabulary files, calls the relevant remappings, and returns the
    # weight matrix. Take this example to be linear multi-class by providing
    # both row and column vocabularies.
    remapped_matrix = load_and_remap_matrix_ops._load_and_remap_matrix(
        new_row_vocab_file=self.new_feature_vocab_file,
        old_row_vocab_file=self.old_feature_vocab_file,
        num_rows_to_load=4,
        new_col_vocab_file=self.new_class_vocab_file,
        old_col_vocab_file=self.old_class_vocab_file,
        new_col_vocab_size=4,
        old_tensor_name='some_scope/embeddings',
        ckpt_path=[self.bundle_file],
        new_row_vocab_offset=1,
        initializer=self.initializer,
        num_row_oov_buckets=1,
        num_col_oov_buckets=1)

    # [4 in vocab + 1 oov features, 4 in vocab + 1 oov classes]. The offset
    # means we load rows starting from the second entry of the new row vocab.
    expected_remapped_matrix = np.concatenate(
        [
            np.reshape([18, 34, 50, self.init_val, self.init_val], [5, 1]),
            np.reshape([16, 32, 48, self.init_val, self.init_val], [5, 1]),
            np.reshape([self.init_val] * 5, [5, 1]),
            np.reshape([17, 33, 49, self.init_val, self.init_val], [5, 1]),
            np.reshape([self.init_val] * 5, [5, 1])
        ],
        axis=1)

    with self.test_session():
      self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())

  def test_load_and_remap_output_layer_weight_initializer_linear(self):
    """Tests for the output layer initializer in the linear multi-class case."""
    loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
        new_row_vocab_size=5,
        new_col_vocab_file=self.new_class_vocab_file,
        old_col_vocab_file=self.old_class_vocab_file,
        new_col_vocab_size=4,
        old_tensor_name='some_scope/embeddings',
        ckpt_path=[self.bundle_file],
        new_row_vocab_file=self.new_feature_vocab_file,
        old_row_vocab_file=self.old_feature_vocab_file,
        num_row_oov_buckets=1,
        num_col_oov_buckets=1,
        initializer=self.initializer))

    expected_remapped_matrix = np.concatenate(
        [
            np.reshape([2, 18, 34, 50, self.init_val, self.init_val], [6, 1]),
            np.reshape([0, 16, 32, 48, self.init_val, self.init_val], [6, 1]),
            np.reshape([self.init_val] * 6, [6, 1]),
            np.reshape([1, 17, 33, 49, self.init_val, self.init_val], [6, 1]),
            np.reshape([self.init_val] * 6, [6, 1])
        ],
        axis=1)

    # The new weight matrix is of size
    # [5 feature vocab + 1 feature OOV, 4 class vocab + 1 class OOV]. Use a
    # partitioned variable to confirm that the offset logic works.
    remapped_matrix = variable_scope.get_variable(
        name='linear/obtained_weight_matrix',
        shape=[6, 5],
        initializer=loading_initializer,
        partitioner=partitioned_variables.fixed_size_partitioner(2))

    with self.test_session():
      variables.global_variables_initializer().run()
      self.assertAllClose(expected_remapped_matrix,
                          remapped_matrix.as_tensor().eval())

  def test_load_and_remap_output_layer_weight_initializer_dnn_output(self):
    """Tests for the output layer initializer in the DNN output case."""
    loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
        new_row_vocab_size=5,
        new_col_vocab_file=self.new_class_vocab_file,
        old_col_vocab_file=self.old_class_vocab_file,
        new_col_vocab_size=4,
        old_tensor_name='some_scope/embeddings',
        ckpt_path=[self.bundle_file],
        num_col_oov_buckets=1,
        initializer=self.initializer))

    expected_remapped_matrix = np.concatenate(
        [
            np.reshape([2, 18, 34, 50, 66], [5, 1]),
            np.reshape([0, 16, 32, 48, 64], [5, 1]),
            np.reshape([self.init_val] * 5, [5, 1]),
            np.reshape([1, 17, 33, 49, 65], [5, 1]),
            np.reshape([self.init_val] * 5, [5, 1])
        ],
        axis=1)

    # The new weight matrix is of size
    # [5-sized input layer, 4 class vocab + 1 class OOV].
    remapped_matrix = variable_scope.get_variable(
        name='dnn_output/obtained_weight_matrix',
        shape=[5, 5],
        initializer=loading_initializer,
        partitioner=partitioned_variables.fixed_size_partitioner(2))

    with self.test_session():
      variables.global_variables_initializer().run()
      self.assertAllClose(expected_remapped_matrix,
                          remapped_matrix.as_tensor().eval())

  def test_initializer_with_oov_only_partition(self):
    """Tests for the output layer initializer where one partition is all OOV."""
    loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
        new_row_vocab_size=5,
        new_col_vocab_file=self.new_class_vocab_file,
        old_col_vocab_file=self.old_class_vocab_file,
        new_col_vocab_size=4,
        old_tensor_name='some_scope/embeddings',
        ckpt_path=[self.bundle_file],
        new_row_vocab_file=self.new_feature_vocab_file,
        old_row_vocab_file=self.old_feature_vocab_file,
        num_row_oov_buckets=5,
        num_col_oov_buckets=1,
        initializer=self.initializer))

    expected_remapped_matrix = np.concatenate(
        [
            np.reshape([2, 18, 34, 50] + [self.init_val] * 6, [10, 1]),
            np.reshape([0, 16, 32, 48] + [self.init_val] * 6, [10, 1]),
            np.reshape([self.init_val] * 10, [10, 1]),
            np.reshape([1, 17, 33, 49] + [self.init_val] * 6, [10, 1]),
            np.reshape([self.init_val] * 10, [10, 1]),
        ],
        axis=1)

    # The new weight matrix is of size
    # [5 feature vocab + 5 feature OOV, 4 class vocab + 1 class OOV]. The
    # second partition has only OOV.
    remapped_matrix = variable_scope.get_variable(
        name='linear_all_oov/obtained_weight_matrix',
        shape=[10, 5],
        initializer=loading_initializer,
        partitioner=partitioned_variables.fixed_size_partitioner(2))

    with self.test_session():
      variables.global_variables_initializer().run()
      self.assertAllClose(expected_remapped_matrix,
                          remapped_matrix.as_tensor().eval())

  def test_load_and_remap_linear_multiclass_initializer_default_init(self):
    """Tests where the zeros_initializer default is used for linear."""
    loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
        new_row_vocab_size=5,
        new_col_vocab_file=self.new_class_vocab_file,
        old_col_vocab_file=self.old_class_vocab_file,
        new_col_vocab_size=4,
        old_tensor_name='some_scope/embeddings',
        ckpt_path=[self.bundle_file],
        new_row_vocab_file=self.new_feature_vocab_file,
        old_row_vocab_file=self.old_feature_vocab_file,
        num_row_oov_buckets=1,
        num_col_oov_buckets=1))

    expected_remapped_matrix = np.concatenate(
        [
            np.reshape([2, 18, 34, 50, 0, 0], [6, 1]),
            np.reshape([0, 16, 32, 48, 0, 0], [6, 1]),
            np.reshape([0] * 6, [6, 1]),
            np.reshape([1, 17, 33, 49, 0, 0], [6, 1]),
            np.reshape([0] * 6, [6, 1])
        ],
        axis=1)

    remapped_matrix = variable_scope.get_variable(
        name='linear_init_fallback/obtained_weight_matrix',
        shape=[6, 5],
        initializer=loading_initializer,
        partitioner=partitioned_variables.fixed_size_partitioner(2))

    with self.test_session():
      variables.global_variables_initializer().run()
      self.assertAllClose(expected_remapped_matrix,
                          remapped_matrix.as_tensor().eval())

  def test_load_embedding_initializer(self):
    """Tests for the load_embedding initializer wrapper."""
    embedding_loading_initializer = (
        contrib_framework.load_embedding_initializer(
            new_vocab_file=self.new_feature_vocab_file,
            old_vocab_file=self.old_feature_vocab_file,
            new_vocab_size=5,
            embedding_dim=16,
            embedding_tensor_name='some_scope/embeddings',
            ckpt_path=[self.bundle_file],
            num_oov_buckets=1,
            initializer=self.initializer))

    expected_remapped_embeddings = np.concatenate(
        [
            np.reshape(range(64), [4, 16]),
            np.reshape([self.init_val] * 32, [2, 16]),
        ],
        axis=0)

    # The new weight matrix is of size
    # [5 feature vocab + 1 feature OOV, 16 (embedding dimension)], where the
    # last vocab row (2nd last row) is newly initialized (wasn't found in
    # previous vocab) and the actual last row is OOV and also newly initialized.
    # Use a partitioned variable to confirm that the offset logic works.
    remapped_embeddings = variable_scope.get_variable(
        name='embedding/obtained_embedding_matrix',
        shape=[6, 16],
        initializer=embedding_loading_initializer,
        partitioner=partitioned_variables.fixed_size_partitioner(2))

    with self.test_session():
      variables.global_variables_initializer().run()
      self.assertAllClose(expected_remapped_embeddings,
                          remapped_embeddings.as_tensor().eval())

class LoadMulticlassBiasTest(test.TestCase):
  """Tests for the load_linear_multiclass_bias_initializer functionality."""

  def setUp(self):
    ops.reset_default_graph()
    dim = 1
    num = 3
    with ops.name_scope('some_scope'):
      # Basically from 0 to dim*num-1.
      flat_data = math_ops.linspace(0.0, dim * num - 1, dim * num)
      bias = variables.Variable(
          array_ops.reshape(flat_data, (num, dim)), name='bias')
    save = saver.Saver([bias])
    with self.test_session() as sess:
      variables.global_variables_initializer().run()
      self.bundle_file = os.path.join(test.get_temp_dir(), 'bias_checkpoint')
      save.save(sess, self.bundle_file)

    self.new_class_vocab_file = os.path.join(
        test.test_src_dir_path(_TESTDATA_PATH), 'keyword_new.txt')
    self.old_class_vocab_file = os.path.join(
        test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt')
    self.init_val = 42

    def _init_val_initializer(shape, dtype=None, partition_info=None):
      del dtype, partition_info  # Unused by this unit-testing initializer.
      return array_ops.tile(
          constant_op.constant([[self.init_val]], dtype=dtypes.float32), shape)

    self.initializer = _init_val_initializer

  def test_load_linear_multiclass_bias_initializer(self):
    """Tests for the bias initializer wrapper."""
    bias_loading_initializer = (
        contrib_framework.load_linear_multiclass_bias_initializer(
            new_class_vocab_file=self.new_class_vocab_file,
            old_class_vocab_file=self.old_class_vocab_file,
            new_class_vocab_size=4,
            bias_tensor_name='some_scope/bias',
            ckpt_path=[self.bundle_file],
            num_class_oov_buckets=1,
            initializer=self.initializer))

    expected_remapped_bias_vector = np.reshape(
        [2, 0, self.init_val, 1, self.init_val], [5, 1])

    # The new bias vector is of size [4 class vocab + 1 class OOV, 1].
    remapped_bias_vector = variable_scope.get_variable(
        name='bias/obtained_bias_vector',
        shape=[5, 1],
        initializer=bias_loading_initializer,
        partitioner=partitioned_variables.fixed_size_partitioner(3))

    with self.test_session():
      variables.global_variables_initializer().run()
      self.assertAllClose(expected_remapped_bias_vector,
                          remapped_bias_vector.as_tensor().eval())


if __name__ == '__main__':
  test.main()
BIN  tensorflow/contrib/framework/testdata/bundle_checkpoint.data-00000-of-00001  Normal file
Binary file not shown.

BIN  tensorflow/contrib/framework/testdata/bundle_checkpoint.index  Normal file
Binary file not shown.

5  tensorflow/contrib/framework/testdata/bundle_checkpoint_vocab.txt  Normal file
@ -0,0 +1,5 @@
zero
one
two
three
four

4  tensorflow/contrib/framework/testdata/bundle_checkpoint_vocab_with_oov.txt  Normal file
@ -0,0 +1,4 @@
zero
one
two
three

3  tensorflow/contrib/framework/testdata/keyword.txt  Normal file
@ -0,0 +1,3 @@
knitting
eminem
MISSING

4  tensorflow/contrib/framework/testdata/keyword_new.txt  Normal file
@ -0,0 +1,4 @@
MISSING
knitting
flask
eminem

3  tensorflow/contrib/framework/testdata/keyword_shifted.txt  Normal file
@ -0,0 +1,3 @@
MISSING
knitting
eminem
@ -1336,6 +1336,11 @@ cc_library(
    ],
)

cc_header_only_library(
    name = "lookup_headers_lib",
    deps = [":lookup"],
)

DATA_FLOW_DEPS = [
    ":bounds_check",
    ":concat_lib",

@ -3,11 +3,16 @@

package(
    default_visibility = ["//visibility:public"],
    features = ["-parse_headers"],
)

licenses(["notice"])  # Apache 2.0

load("//tensorflow:tensorflow.bzl", "tf_copts")
load(
    "//tensorflow:tensorflow.bzl",
    "cc_header_only_library",
    "tf_copts",
)

# To be exported to tensorflow/core:mobile_srcs.
filegroup(

@ -43,6 +48,13 @@ cc_library(
    ],
)

cc_header_only_library(
    name = "tensor_bundle_headers_lib",
    deps = [
        ":tensor_bundle",
    ],
)

cc_library(
    name = "naming",
    srcs = ["naming.cc"],

@ -56,6 +56,7 @@ BLACKLIST = [
    "//tensorflow/contrib/factorization/examples:mnist",
    "//tensorflow/contrib/factorization/examples:mnist.py",
    "//tensorflow/contrib/factorization:factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO",  # pylint:disable=line-too-long
    "//tensorflow/contrib/framework:load_and_remap_matrix_testdata",
    "//tensorflow/contrib/bayesflow:reinforce_simple_example",
    "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py",  # pylint:disable=line-too-long
]