Expose an axis argument for VocabInfo, which allows warm-starting along the second axis of Tensors through tf.train.warm_start. Note that the underlying initializer already has this functionality (for example, for output layers).

PiperOrigin-RevId: 211709879
Eddie Zhou 2018-09-05 15:24:38 -07:00 committed by TensorFlower Gardener
parent ebf6d259fd
commit 47b1af2a3a
8 changed files with 235 additions and 26 deletions
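
For context, here is a minimal usage sketch of the new argument (TF 1.x API). The variable name "output_layer/kernel", the vocabulary file paths, and the sizes are illustrative placeholders, not values taken from this change:

import tensorflow as tf

# Hypothetical vocabulary files and sizes, for illustration only.
class_vocab_info = tf.estimator.VocabInfo(
    new_vocab="new_class_vocab.txt",
    new_vocab_size=5,
    num_oov_buckets=0,  # class labels typically need no OOV buckets
    old_vocab="old_class_vocab.txt",
    old_vocab_size=8,
    backup_initializer=tf.glorot_uniform_initializer(),
    axis=1)  # warm-start along the second (column) axis of the kernel

# Call after the model's variables have been built (e.g., inside model_fn).
tf.train.warm_start(
    ckpt_to_initialize_from="/tmp/prev_model_dir",
    vars_to_warm_start=".*output_layer.*",
    var_name_to_vocab_info={"output_layer/kernel": class_vocab_info})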

View File

@ -2056,7 +2056,7 @@ class WarmStartSettings(
var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
`tf.estimator.VocabInfo`. The variable names should be "full" variables,
not the names of the partitions. If not explicitly provided, the variable
-is assumed to have no vocabulary.
+is assumed to have no (changes to) vocabulary.
var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
name of the previously-trained variable in `ckpt_to_initialize_from`. If
not explicitly provided, the name of the variable is assumed to be same

View File

@ -268,7 +268,8 @@ def _load_and_remap_matrix_initializer(ckpt_path,
vocab files are the same, and no column remapping is done.
The returned initializer only supports div-partitioning along the row axis. It
-does not support partitioning along the column axis or mod-partitioning.
+does not support partitioning along the column axis (as this is not common in
+practice) or mod-partitioning.
NOTE: When this is used to warm-start variables, client code should use
`tf.lookup.index_table_from_tensor()` like

View File

@ -41,6 +41,7 @@ class VocabInfo(
"old_vocab",
"old_vocab_size",
"backup_initializer",
"axis",
])):
"""Vocabulary information for warm-starting.
@ -62,6 +63,42 @@ class VocabInfo(
backup_initializer: [Optional] A variable initializer used for variables
corresponding to new vocabulary entries and OOV. If not provided, these
entries will be zero-initialized.
axis: [Optional] Denotes what axis the vocabulary corresponds to. The
default, 0, corresponds to the most common use case (embeddings or
linear weights for binary classification / regression). An axis of 1
could be used for warm-starting output layers with class vocabularies.
For example:
embeddings_vocab_info = tf.VocabInfo(
new_vocab='embeddings_vocab',
new_vocab_size=100,
num_oov_buckets=1,
old_vocab='pretrained_embeddings_vocab',
old_vocab_size=10000,
backup_initializer=tf.truncated_normal_initializer(
mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
axis=0)
softmax_output_layer_kernel_vocab_info = tf.VocabInfo(
new_vocab='class_vocab',
new_vocab_size=5,
num_oov_buckets=0, # No OOV for classes.
old_vocab='old_class_vocab',
old_vocab_size=8,
backup_initializer=tf.glorot_uniform_initializer(),
axis=1)
softmax_output_layer_bias_vocab_info = tf.VocabInfo(
new_vocab='class_vocab',
new_vocab_size=5,
num_oov_buckets=0, # No OOV for classes.
old_vocab='old_class_vocab',
old_vocab_size=8,
backup_initializer=tf.zeros_initializer(),
axis=0)
Currently, only axis=0 and axis=1 are supported.
"""
def __new__(cls,
@ -70,7 +107,12 @@ class VocabInfo(
num_oov_buckets,
old_vocab,
old_vocab_size=-1,
-backup_initializer=None):
+backup_initializer=None,
+axis=0):
if axis != 0 and axis != 1:
raise ValueError("The only supported values for the axis argument are 0 "
"and 1. Provided axis: {}".format(axis))
return super(VocabInfo, cls).__new__(
cls,
new_vocab,
@ -79,6 +121,7 @@ class VocabInfo(
old_vocab,
old_vocab_size,
backup_initializer,
axis,
)
@ -149,7 +192,8 @@ def _warm_start_var_with_vocab(var,
previous_vocab_size=-1,
current_oov_buckets=0,
prev_tensor_name=None,
-initializer=None):
+initializer=None,
+axis=0):
"""Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.
Use this method when the `var` is backed by vocabulary. This method stitches
@ -180,6 +224,7 @@ def _warm_start_var_with_vocab(var,
None, we lookup tensor with same name as given `var`.
initializer: Variable initializer to be used for missing entries. If None,
missing entries will be zero-initialized.
axis: Axis of the variable that the provided vocabulary corresponds to.
Raises:
ValueError: If required args are not provided.
@ -204,6 +249,8 @@ def _warm_start_var_with_vocab(var,
# Assume tensor name remains the same.
prev_tensor_name = _infer_var_name(var)
# TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
total_v_first_axis = sum([v.get_shape().as_list()[0] for v in var])
for v in var:
v_shape = v.get_shape().as_list()
slice_info = v._get_save_slice_info()
@ -213,19 +260,45 @@ def _warm_start_var_with_vocab(var,
full_shape=slice_info.full_shape,
var_offset=slice_info.var_offset)
# TODO(eddz): Support cases where class vocabularies need remapping too.
if axis == 0:
new_row_vocab_size = current_vocab_size
new_col_vocab_size = v_shape[1]
old_row_vocab_size = previous_vocab_size
old_row_vocab_file = prev_vocab_path
new_row_vocab_file = current_vocab_path
old_col_vocab_file = None
new_col_vocab_file = None
num_row_oov_buckets = current_oov_buckets
num_col_oov_buckets = 0
elif axis == 1:
# Note that we must compute this value across all partitions, whereas
# in the axis = 0 case, we can simply use v_shape[1] because we don't
# allow partitioning across axis = 1.
new_row_vocab_size = total_v_first_axis
new_col_vocab_size = current_vocab_size
old_row_vocab_size = -1
old_row_vocab_file = None
new_row_vocab_file = None
old_col_vocab_file = prev_vocab_path
new_col_vocab_file = current_vocab_path
num_row_oov_buckets = 0
num_col_oov_buckets = current_oov_buckets
else:
raise ValueError("The only supported values for the axis argument are 0 "
"and 1. Provided axis: {}".format(axis))
init = checkpoint_ops._load_and_remap_matrix_initializer(
ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
old_tensor_name=prev_tensor_name,
-new_row_vocab_size=current_vocab_size,
-new_col_vocab_size=v_shape[1],
-old_row_vocab_size=previous_vocab_size,
-old_row_vocab_file=prev_vocab_path,
-new_row_vocab_file=current_vocab_path,
-old_col_vocab_file=None,
-new_col_vocab_file=None,
-num_row_oov_buckets=current_oov_buckets,
-num_col_oov_buckets=0,
+new_row_vocab_size=new_row_vocab_size,
+new_col_vocab_size=new_col_vocab_size,
+old_row_vocab_size=old_row_vocab_size,
+old_row_vocab_file=old_row_vocab_file,
+new_row_vocab_file=new_row_vocab_file,
+old_col_vocab_file=old_col_vocab_file,
+new_col_vocab_file=new_col_vocab_file,
+num_row_oov_buckets=num_row_oov_buckets,
+num_col_oov_buckets=num_col_oov_buckets,
initializer=initializer)
new_init_val = ops.convert_to_tensor(
init(shape=v_shape, partition_info=partition_info))
@ -374,7 +447,8 @@ def warm_start(ckpt_to_initialize_from,
previous_vocab_size=vocab_info.old_vocab_size,
current_oov_buckets=vocab_info.num_oov_buckets,
prev_tensor_name=prev_var_name,
-initializer=vocab_info.backup_initializer)
+initializer=vocab_info.backup_initializer,
+axis=vocab_info.axis)
else:
# For the special value of vars_to_warm_start = None,
# we only warm-start variables with explicitly specified vocabularies.
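
To make the axis=1 behavior concrete before the tests that exercise it, here is a hand-rolled NumPy illustration of the column remapping (a sketch of the expected effect with a zero backup initializer, not the actual implementation); it reproduces the expected values in testWarmStartVarWithColumnVocab below:

import numpy as np

# Old checkpoint kernel: rows are input features, columns follow the old class vocab.
old_kernel = np.array([[0.5, 0.3], [1.0, 0.8], [1.5, 1.2], [2.0, 2.3]])
old_vocab = ["apple", "orange"]
new_vocab = ["orange", "apple", "banana"]  # reordered, plus one brand-new class

# axis=1 warm-start: each new column copies the matching old column when the
# class existed in the old vocab; new classes fall back to the backup
# initializer (zeros in this illustration).
new_kernel = np.zeros((old_kernel.shape[0], len(new_vocab)))
for j, token in enumerate(new_vocab):
    if token in old_vocab:
        new_kernel[:, j] = old_kernel[:, old_vocab.index(token)]

expected = np.array([[0.3, 0.5, 0.0],
                     [0.8, 1.0, 0.0],
                     [1.2, 1.5, 0.0],
                     [2.3, 2.0, 0.0]])
np.testing.assert_allclose(new_kernel, expected)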

View File

@ -107,7 +107,7 @@ class WarmStartingUtilTest(test.TestCase):
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
sess.run(variables.global_variables_initializer())
-self.assertAllEqual(prev_val, fruit_weights.eval(sess))
+self.assertAllClose(prev_val, fruit_weights.eval(sess))
def testWarmStartVarPrevVarPartitioned(self):
_, weights = self._create_prev_run_var(
@ -123,7 +123,7 @@ class WarmStartingUtilTest(test.TestCase):
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
sess.run(variables.global_variables_initializer())
-self.assertAllEqual(prev_val, fruit_weights.eval(sess))
+self.assertAllClose(prev_val, fruit_weights.eval(sess))
def testWarmStartVarCurrentVarPartitioned(self):
_, prev_val = self._create_prev_run_var(
@ -143,7 +143,7 @@ class WarmStartingUtilTest(test.TestCase):
fruit_weights = fruit_weights._get_variable_list()
new_val = np.concatenate(
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
-self.assertAllEqual(prev_val, new_val)
+self.assertAllClose(prev_val, new_val)
def testWarmStartVarBothVarsPartitioned(self):
_, weights = self._create_prev_run_var(
@ -170,7 +170,7 @@ class WarmStartingUtilTest(test.TestCase):
fruit_weights = fruit_weights._get_variable_list()
new_val = np.concatenate(
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
-self.assertAllEqual(prev_val, new_val)
+self.assertAllClose(prev_val, new_val)
def testWarmStartVarWithVocab(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
@ -189,9 +189,34 @@ class WarmStartingUtilTest(test.TestCase):
ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
self.get_temp_dir(), prev_vocab_path)
sess.run(variables.global_variables_initializer())
-self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
+self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
def testWarmStartVarWithColumnVocab(self):
prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
self._create_prev_run_var(
"fruit_output_layer",
initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
# New vocab with elements in reverse order and one new element.
new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
[0., 0., 0.]])
ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
current_vocab_size=3,
prev_ckpt=self.get_temp_dir(),
prev_vocab_path=prev_vocab_path,
axis=1)
sess.run(variables.global_variables_initializer())
self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
[2.3, 2., 0.]], fruit_output_layer.eval(sess))
def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@ -215,7 +240,7 @@ class WarmStartingUtilTest(test.TestCase):
previous_vocab_size=2)
sess.run(variables.global_variables_initializer())
# Old vocabulary limited to ['apple', 'banana'].
-self.assertAllEqual([[0.], [0.], [1.], [0.5], [0.]],
+self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
def testWarmStartVarWithVocabPrevVarPartitioned(self):
@ -238,9 +263,36 @@ class WarmStartingUtilTest(test.TestCase):
ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
self.get_temp_dir(), prev_vocab_path)
sess.run(variables.global_variables_initializer())
-self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
+self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
def testWarmStartVarWithColumnVocabPrevVarPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
self._create_prev_run_var(
"fruit_output_layer",
shape=[4, 2],
initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
partitioner=lambda shape, dtype: [2, 1])
# New vocab with elements in reverse order and one new element.
new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
[0., 0., 0.]])
ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
current_vocab_size=3,
prev_ckpt=self.get_temp_dir(),
prev_vocab_path=prev_vocab_path,
axis=1)
sess.run(variables.global_variables_initializer())
self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
[2.3, 2., 0.]], fruit_output_layer.eval(sess))
def testWarmStartVarWithVocabCurrentVarPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@ -269,11 +321,43 @@ class WarmStartingUtilTest(test.TestCase):
self.assertTrue(
isinstance(fruit_weights, variables.PartitionedVariable))
fruit_weights_vars = fruit_weights._get_variable_list()
-self.assertAllEqual([[2.], [1.5], [1.]],
+self.assertAllClose([[2.], [1.5], [1.]],
fruit_weights_vars[0].eval(sess))
-self.assertAllEqual([[0.5], [0.], [0.]],
+self.assertAllClose([[0.5], [0.], [0.]],
fruit_weights_vars[1].eval(sess))
def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
self._create_prev_run_var(
"fruit_output_layer",
initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
# New vocab with elements in reverse order and one new element.
new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
shape=[4, 3],
initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
[0., 0., 0.]],
partitioner=lambda shape, dtype: [2, 1])
ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
current_vocab_size=3,
prev_ckpt=self.get_temp_dir(),
prev_vocab_path=prev_vocab_path,
axis=1)
sess.run(variables.global_variables_initializer())
self.assertTrue(
isinstance(fruit_output_layer, variables.PartitionedVariable))
fruit_output_layer_vars = fruit_output_layer._get_variable_list()
self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
fruit_output_layer_vars[0].eval(sess))
self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
fruit_output_layer_vars[1].eval(sess))
def testWarmStartVarWithVocabBothVarsPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@ -301,11 +385,45 @@ class WarmStartingUtilTest(test.TestCase):
self.assertTrue(
isinstance(fruit_weights, variables.PartitionedVariable))
fruit_weights_vars = fruit_weights._get_variable_list()
-self.assertAllEqual([[2.], [1.5], [1.]],
+self.assertAllClose([[2.], [1.5], [1.]],
fruit_weights_vars[0].eval(sess))
-self.assertAllEqual([[0.5], [0.], [0.]],
+self.assertAllClose([[0.5], [0.], [0.]],
fruit_weights_vars[1].eval(sess))
def testWarmStartVarWithColumnVocabBothVarsPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
self._create_prev_run_var(
"fruit_output_layer",
shape=[4, 2],
initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
partitioner=lambda shape, dtype: [2, 1])
# New vocab with elements in reverse order and one new element.
new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
shape=[4, 3],
initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
[0., 0., 0.]],
partitioner=lambda shape, dtype: [2, 1])
ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
current_vocab_size=3,
prev_ckpt=self.get_temp_dir(),
prev_vocab_path=prev_vocab_path,
axis=1)
sess.run(variables.global_variables_initializer())
self.assertTrue(
isinstance(fruit_output_layer, variables.PartitionedVariable))
fruit_output_layer_vars = fruit_output_layer._get_variable_list()
self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
fruit_output_layer_vars[0].eval(sess))
self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
fruit_output_layer_vars[1].eval(sess))
def testWarmStart_ListOfVariables(self):
# Save checkpoint from which to warm-start.
_, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],

View File

@ -3,6 +3,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
name: "axis"
mtype: "<type \'property\'>"
}
member {
name: "backup_initializer"
mtype: "<type \'property\'>"

View File

@ -3,6 +3,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
name: "axis"
mtype: "<type \'property\'>"
}
member {
name: "backup_initializer"
mtype: "<type \'property\'>"

View File

@ -3,6 +3,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
name: "axis"
mtype: "<type \'property\'>"
}
member {
name: "backup_initializer"
mtype: "<type \'property\'>"

View File

@ -3,6 +3,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
name: "axis"
mtype: "<type \'property\'>"
}
member {
name: "backup_initializer"
mtype: "<type \'property\'>"