diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index e44a69b374c..0f20acefdfb 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -2056,7 +2056,7 @@ class WarmStartSettings( var_name_to_vocab_info: [Optional] Dict of variable names (strings) to `tf.estimator.VocabInfo`. The variable names should be "full" variables, not the names of the partitions. If not explicitly provided, the variable - is assumed to have no vocabulary. + is assumed to have no (changes to) vocabulary. var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to name of the previously-trained variable in `ckpt_to_initialize_from`. If not explicitly provided, the name of the variable is assumed to be same diff --git a/tensorflow/python/training/checkpoint_ops.py b/tensorflow/python/training/checkpoint_ops.py index a6e9662b730..cfd9b39ddc4 100644 --- a/tensorflow/python/training/checkpoint_ops.py +++ b/tensorflow/python/training/checkpoint_ops.py @@ -268,7 +268,8 @@ def _load_and_remap_matrix_initializer(ckpt_path, vocab files are the same, and no column remapping is done. The returned initializer only supports div-partitioning along the row axis. It - does not support partitioning along the column axis or mod-partitioning. + does not support partitioning along the column axis (as this is not common in + practice) or mod-partitioning. NOTE: When this is used to warm-start variables, client code should use `tf.lookup.index_table_from_tensor()` like diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py index c0dd46bfa5e..bea9bb6dffa 100644 --- a/tensorflow/python/training/warm_starting_util.py +++ b/tensorflow/python/training/warm_starting_util.py @@ -41,6 +41,7 @@ class VocabInfo( "old_vocab", "old_vocab_size", "backup_initializer", + "axis", ])): """Vocabulary information for warm-starting. @@ -62,6 +63,42 @@ class VocabInfo( backup_initializer: [Optional] A variable initializer used for variables corresponding to new vocabulary entries and OOV. If not provided, these entries will be zero-initialized. + axis: [Optional] Denotes what axis the vocabulary corresponds to. The + default, 0, corresponds to the most common use case (embeddings or + linear weights for binary classification / regression). An axis of 1 + could be used for warm-starting output layers with class vocabularies. + + For example: + + embeddings_vocab_info = tf.VocabInfo( + new_vocab='embeddings_vocab', + new_vocab_size=100, + num_oov_buckets=1, + old_vocab='pretrained_embeddings_vocab', + old_vocab_size=10000, + backup_initializer=tf.truncated_normal_initializer( + mean=0.0, stddev=(1 / math.sqrt(embedding_dim))), + axis=0) + + softmax_output_layer_kernel_vocab_info = tf.VocabInfo( + new_vocab='class_vocab', + new_vocab_size=5, + num_oov_buckets=0, # No OOV for classes. + old_vocab='old_class_vocab', + old_vocab_size=8, + backup_initializer=tf.glorot_uniform_initializer(), + axis=1) + + softmax_output_layer_bias_vocab_info = tf.VocabInfo( + new_vocab='class_vocab', + new_vocab_size=5, + num_oov_buckets=0, # No OOV for classes. + old_vocab='old_class_vocab', + old_vocab_size=8, + backup_initializer=tf.zeros_initializer(), + axis=0) + + Currently, only axis=0 and axis=1 are supported. 
""" def __new__(cls, @@ -70,7 +107,12 @@ class VocabInfo( num_oov_buckets, old_vocab, old_vocab_size=-1, - backup_initializer=None): + backup_initializer=None, + axis=0): + if axis != 0 and axis != 1: + raise ValueError("The only supported values for the axis argument are 0 " + "and 1. Provided axis: {}".format(axis)) + return super(VocabInfo, cls).__new__( cls, new_vocab, @@ -79,6 +121,7 @@ class VocabInfo( old_vocab, old_vocab_size, backup_initializer, + axis, ) @@ -149,7 +192,8 @@ def _warm_start_var_with_vocab(var, previous_vocab_size=-1, current_oov_buckets=0, prev_tensor_name=None, - initializer=None): + initializer=None, + axis=0): """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`. Use this method when the `var` is backed by vocabulary. This method stitches @@ -180,6 +224,7 @@ def _warm_start_var_with_vocab(var, None, we lookup tensor with same name as given `var`. initializer: Variable initializer to be used for missing entries. If None, missing entries will be zero-initialized. + axis: Axis of the variable that the provided vocabulary corresponds to. Raises: ValueError: If required args are not provided. @@ -204,6 +249,8 @@ def _warm_start_var_with_vocab(var, # Assume tensor name remains the same. prev_tensor_name = _infer_var_name(var) + # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases). + total_v_first_axis = sum([v.get_shape().as_list()[0] for v in var]) for v in var: v_shape = v.get_shape().as_list() slice_info = v._get_save_slice_info() @@ -213,19 +260,45 @@ def _warm_start_var_with_vocab(var, full_shape=slice_info.full_shape, var_offset=slice_info.var_offset) - # TODO(eddz): Support cases where class vocabularies need remapping too. + if axis == 0: + new_row_vocab_size = current_vocab_size + new_col_vocab_size = v_shape[1] + old_row_vocab_size = previous_vocab_size + old_row_vocab_file = prev_vocab_path + new_row_vocab_file = current_vocab_path + old_col_vocab_file = None + new_col_vocab_file = None + num_row_oov_buckets = current_oov_buckets + num_col_oov_buckets = 0 + elif axis == 1: + # Note that we must compute this value across all partitions, whereas + # in the axis = 0 case, we can simply use v_shape[1] because we don't + # allow partitioning across axis = 1. + new_row_vocab_size = total_v_first_axis + new_col_vocab_size = current_vocab_size + old_row_vocab_size = -1 + old_row_vocab_file = None + new_row_vocab_file = None + old_col_vocab_file = prev_vocab_path + new_col_vocab_file = current_vocab_path + num_row_oov_buckets = 0 + num_col_oov_buckets = current_oov_buckets + else: + raise ValueError("The only supported values for the axis argument are 0 " + "and 1. 
Provided axis: {}".format(axis)) + init = checkpoint_ops._load_and_remap_matrix_initializer( ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt), old_tensor_name=prev_tensor_name, - new_row_vocab_size=current_vocab_size, - new_col_vocab_size=v_shape[1], - old_row_vocab_size=previous_vocab_size, - old_row_vocab_file=prev_vocab_path, - new_row_vocab_file=current_vocab_path, - old_col_vocab_file=None, - new_col_vocab_file=None, - num_row_oov_buckets=current_oov_buckets, - num_col_oov_buckets=0, + new_row_vocab_size=new_row_vocab_size, + new_col_vocab_size=new_col_vocab_size, + old_row_vocab_size=old_row_vocab_size, + old_row_vocab_file=old_row_vocab_file, + new_row_vocab_file=new_row_vocab_file, + old_col_vocab_file=old_col_vocab_file, + new_col_vocab_file=new_col_vocab_file, + num_row_oov_buckets=num_row_oov_buckets, + num_col_oov_buckets=num_col_oov_buckets, initializer=initializer) new_init_val = ops.convert_to_tensor( init(shape=v_shape, partition_info=partition_info)) @@ -374,7 +447,8 @@ def warm_start(ckpt_to_initialize_from, previous_vocab_size=vocab_info.old_vocab_size, current_oov_buckets=vocab_info.num_oov_buckets, prev_tensor_name=prev_var_name, - initializer=vocab_info.backup_initializer) + initializer=vocab_info.backup_initializer, + axis=vocab_info.axis) else: # For the special value of vars_to_warm_start = None, # we only warm-start variables with explicitly specified vocabularies. diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py index 70a84bc3f6e..3ee0f6aaa2e 100644 --- a/tensorflow/python/training/warm_starting_util_test.py +++ b/tensorflow/python/training/warm_starting_util_test.py @@ -107,7 +107,7 @@ class WarmStartingUtilTest(test.TestCase): "fruit_weights", initializer=[[0.], [0.], [0.], [0.]]) ws_util._warm_start_var(fruit_weights, self.get_temp_dir()) sess.run(variables.global_variables_initializer()) - self.assertAllEqual(prev_val, fruit_weights.eval(sess)) + self.assertAllClose(prev_val, fruit_weights.eval(sess)) def testWarmStartVarPrevVarPartitioned(self): _, weights = self._create_prev_run_var( @@ -123,7 +123,7 @@ class WarmStartingUtilTest(test.TestCase): "fruit_weights", initializer=[[0.], [0.], [0.], [0.]]) ws_util._warm_start_var(fruit_weights, self.get_temp_dir()) sess.run(variables.global_variables_initializer()) - self.assertAllEqual(prev_val, fruit_weights.eval(sess)) + self.assertAllClose(prev_val, fruit_weights.eval(sess)) def testWarmStartVarCurrentVarPartitioned(self): _, prev_val = self._create_prev_run_var( @@ -143,7 +143,7 @@ class WarmStartingUtilTest(test.TestCase): fruit_weights = fruit_weights._get_variable_list() new_val = np.concatenate( [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0) - self.assertAllEqual(prev_val, new_val) + self.assertAllClose(prev_val, new_val) def testWarmStartVarBothVarsPartitioned(self): _, weights = self._create_prev_run_var( @@ -170,7 +170,7 @@ class WarmStartingUtilTest(test.TestCase): fruit_weights = fruit_weights._get_variable_list() new_val = np.concatenate( [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0) - self.assertAllEqual(prev_val, new_val) + self.assertAllClose(prev_val, new_val) def testWarmStartVarWithVocab(self): prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], @@ -189,9 +189,34 @@ class WarmStartingUtilTest(test.TestCase): ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5, self.get_temp_dir(), prev_vocab_path) 
sess.run(variables.global_variables_initializer()) - self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]], + self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]], fruit_weights.eval(sess)) + def testWarmStartVarWithColumnVocab(self): + prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab") + self._create_prev_run_var( + "fruit_output_layer", + initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]]) + + # New vocab with elements in reverse order and one new element. + new_vocab_path = self._write_vocab(["orange", "apple", "banana"], + "new_vocab") + # New session and new graph. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + fruit_output_layer = variable_scope.get_variable( + "fruit_output_layer", + initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], + [0., 0., 0.]]) + ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path, + current_vocab_size=3, + prev_ckpt=self.get_temp_dir(), + prev_vocab_path=prev_vocab_path, + axis=1) + sess.run(variables.global_variables_initializer()) + self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.], + [2.3, 2., 0.]], fruit_output_layer.eval(sess)) + def testWarmStartVarWithVocabConstrainedOldVocabSize(self): prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], "old_vocab") @@ -215,7 +240,7 @@ class WarmStartingUtilTest(test.TestCase): previous_vocab_size=2) sess.run(variables.global_variables_initializer()) # Old vocabulary limited to ['apple', 'banana']. - self.assertAllEqual([[0.], [0.], [1.], [0.5], [0.]], + self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]], fruit_weights.eval(sess)) def testWarmStartVarWithVocabPrevVarPartitioned(self): @@ -238,9 +263,36 @@ class WarmStartingUtilTest(test.TestCase): ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5, self.get_temp_dir(), prev_vocab_path) sess.run(variables.global_variables_initializer()) - self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]], + self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]], fruit_weights.eval(sess)) + def testWarmStartVarWithColumnVocabPrevVarPartitioned(self): + prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab") + self._create_prev_run_var( + "fruit_output_layer", + shape=[4, 2], + initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]], + partitioner=lambda shape, dtype: [2, 1]) + + # New vocab with elements in reverse order and one new element. + new_vocab_path = self._write_vocab(["orange", "apple", "banana"], + "new_vocab") + # New session and new graph. 
+ with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + fruit_output_layer = variable_scope.get_variable( + "fruit_output_layer", + initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], + [0., 0., 0.]]) + ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path, + current_vocab_size=3, + prev_ckpt=self.get_temp_dir(), + prev_vocab_path=prev_vocab_path, + axis=1) + sess.run(variables.global_variables_initializer()) + self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.], + [2.3, 2., 0.]], fruit_output_layer.eval(sess)) + def testWarmStartVarWithVocabCurrentVarPartitioned(self): prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], "old_vocab") @@ -269,11 +321,43 @@ class WarmStartingUtilTest(test.TestCase): self.assertTrue( isinstance(fruit_weights, variables.PartitionedVariable)) fruit_weights_vars = fruit_weights._get_variable_list() - self.assertAllEqual([[2.], [1.5], [1.]], + self.assertAllClose([[2.], [1.5], [1.]], fruit_weights_vars[0].eval(sess)) - self.assertAllEqual([[0.5], [0.], [0.]], + self.assertAllClose([[0.5], [0.], [0.]], fruit_weights_vars[1].eval(sess)) + def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self): + prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab") + self._create_prev_run_var( + "fruit_output_layer", + initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]]) + + # New vocab with elements in reverse order and one new element. + new_vocab_path = self._write_vocab(["orange", "apple", "banana"], + "new_vocab") + # New session and new graph. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + fruit_output_layer = variable_scope.get_variable( + "fruit_output_layer", + shape=[4, 3], + initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], + [0., 0., 0.]], + partitioner=lambda shape, dtype: [2, 1]) + ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path, + current_vocab_size=3, + prev_ckpt=self.get_temp_dir(), + prev_vocab_path=prev_vocab_path, + axis=1) + sess.run(variables.global_variables_initializer()) + self.assertTrue( + isinstance(fruit_output_layer, variables.PartitionedVariable)) + fruit_output_layer_vars = fruit_output_layer._get_variable_list() + self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]], + fruit_output_layer_vars[0].eval(sess)) + self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]], + fruit_output_layer_vars[1].eval(sess)) + def testWarmStartVarWithVocabBothVarsPartitioned(self): prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], "old_vocab") @@ -301,11 +385,45 @@ class WarmStartingUtilTest(test.TestCase): self.assertTrue( isinstance(fruit_weights, variables.PartitionedVariable)) fruit_weights_vars = fruit_weights._get_variable_list() - self.assertAllEqual([[2.], [1.5], [1.]], + self.assertAllClose([[2.], [1.5], [1.]], fruit_weights_vars[0].eval(sess)) - self.assertAllEqual([[0.5], [0.], [0.]], + self.assertAllClose([[0.5], [0.], [0.]], fruit_weights_vars[1].eval(sess)) + def testWarmStartVarWithColumnVocabBothVarsPartitioned(self): + prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab") + self._create_prev_run_var( + "fruit_output_layer", + shape=[4, 2], + initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]], + partitioner=lambda shape, dtype: [2, 1]) + + # New vocab with elements in reverse order and one new element. + new_vocab_path = self._write_vocab(["orange", "apple", "banana"], + "new_vocab") + # New session and new graph. 
+ with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + fruit_output_layer = variable_scope.get_variable( + "fruit_output_layer", + shape=[4, 3], + initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], + [0., 0., 0.]], + partitioner=lambda shape, dtype: [2, 1]) + ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path, + current_vocab_size=3, + prev_ckpt=self.get_temp_dir(), + prev_vocab_path=prev_vocab_path, + axis=1) + sess.run(variables.global_variables_initializer()) + self.assertTrue( + isinstance(fruit_output_layer, variables.PartitionedVariable)) + fruit_output_layer_vars = fruit_output_layer._get_variable_list() + self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]], + fruit_output_layer_vars[0].eval(sess)) + self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]], + fruit_output_layer_vars[1].eval(sess)) + def testWarmStart_ListOfVariables(self): # Save checkpoint from which to warm-start. _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1], diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt index 5301b94eb36..b6942cb7ed8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt @@ -3,6 +3,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "axis" + mtype: "" + } member { name: "backup_initializer" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt index 4ce7cb11116..39b946b82f3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt @@ -3,6 +3,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "axis" + mtype: "" + } member { name: "backup_initializer" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt index 5301b94eb36..b6942cb7ed8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt @@ -3,6 +3,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "axis" + mtype: "" + } member { name: "backup_initializer" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt index 4ce7cb11116..39b946b82f3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt @@ -3,6 +3,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "axis" + mtype: "" + } member { name: "backup_initializer" mtype: ""
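
For reference, here is a minimal sketch (not part of the patch) of how the new `axis` argument could be threaded through `tf.estimator.WarmStartSettings` when a softmax output layer's class vocabulary changes between runs, mirroring the `VocabInfo` docstring examples added above. The checkpoint path, vocabulary file paths, vocabulary sizes, and variable names (`dnn/logits/kernel`, `dnn/logits/bias`) are hypothetical placeholders; only `tf.estimator.VocabInfo`, `tf.estimator.WarmStartSettings`, and the `axis` keyword itself come from this change.

    import tensorflow as tf

    # Hypothetical paths; substitute real vocabulary files and a real checkpoint.
    OLD_CLASS_VOCAB = "/tmp/old_class_vocab.txt"  # classes used by the previous run
    NEW_CLASS_VOCAB = "/tmp/new_class_vocab.txt"  # reordered / extended class list
    PREV_CKPT = "/tmp/prev_model_dir"

    # Output-layer kernel: rows are hidden units, columns are classes, so the
    # class vocabulary lies along axis=1 (the case this patch enables).
    output_kernel_vocab_info = tf.estimator.VocabInfo(
        new_vocab=NEW_CLASS_VOCAB,
        new_vocab_size=5,
        num_oov_buckets=0,  # No OOV buckets for class labels.
        old_vocab=OLD_CLASS_VOCAB,
        old_vocab_size=8,
        backup_initializer=tf.glorot_uniform_initializer(),
        axis=1)

    # Output-layer bias: one entry per class, so the same vocabulary is on axis=0.
    output_bias_vocab_info = tf.estimator.VocabInfo(
        new_vocab=NEW_CLASS_VOCAB,
        new_vocab_size=5,
        num_oov_buckets=0,
        old_vocab=OLD_CLASS_VOCAB,
        old_vocab_size=8,
        backup_initializer=tf.zeros_initializer(),
        axis=0)

    # The variable names must match the output-layer variables of the model being
    # warm-started; the ones below are placeholders.
    ws = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=PREV_CKPT,
        vars_to_warm_start=".*",
        var_name_to_vocab_info={
            "dnn/logits/kernel": output_kernel_vocab_info,
            "dnn/logits/bias": output_bias_vocab_info,
        })

    # Any canned or custom Estimator then picks this up via `warm_start_from`:
    # estimator = tf.estimator.DNNClassifier(..., warm_start_from=ws)

Internally, `axis=1` routes the vocabulary files into the `old_col_vocab_file` / `new_col_vocab_file` arguments of `_load_and_remap_matrix_initializer` (column remapping), while `axis=0` keeps the existing row-remapping path. Note that the patch still carries a TODO for rank-1 variables such as biases, so the bias mapping above mirrors the added docstring example rather than a fully supported path.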