From 47b1af2a3a724a5d783ae06ca0e0e78b30e0799b Mon Sep 17 00:00:00 2001
From: Eddie Zhou <eddz@google.com>
Date: Wed, 5 Sep 2018 15:24:38 -0700
Subject: [PATCH] Expose an axis argument for VocabInfo, which allows for
 warm-starting of the second axis of Tensors through tf.train.warm_start. 
 Note that the underlying initializer already has this functionality (for
 example, for output layers).

PiperOrigin-RevId: 211709879
---
 tensorflow/python/estimator/estimator.py      |   2 +-
 tensorflow/python/training/checkpoint_ops.py  |   3 +-
 .../python/training/warm_starting_util.py     | 100 +++++++++++--
 .../training/warm_starting_util_test.py       | 140 ++++++++++++++++--
 .../v1/tensorflow.estimator.-vocab-info.pbtxt |   4 +
 .../v1/tensorflow.train.-vocab-info.pbtxt     |   4 +
 .../v2/tensorflow.estimator.-vocab-info.pbtxt |   4 +
 .../v2/tensorflow.train.-vocab-info.pbtxt     |   4 +
 8 files changed, 235 insertions(+), 26 deletions(-)

diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index e44a69b374c..0f20acefdfb 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -2056,7 +2056,7 @@ class WarmStartSettings(
     var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
       `tf.estimator.VocabInfo`. The variable names should be "full" variables,
       not the names of the partitions.  If not explicitly provided, the variable
-      is assumed to have no vocabulary.
+      is assumed to have no (changes to) vocabulary.
     var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
       name of the previously-trained variable in `ckpt_to_initialize_from`. If
       not explicitly provided, the name of the variable is assumed to be same
diff --git a/tensorflow/python/training/checkpoint_ops.py b/tensorflow/python/training/checkpoint_ops.py
index a6e9662b730..cfd9b39ddc4 100644
--- a/tensorflow/python/training/checkpoint_ops.py
+++ b/tensorflow/python/training/checkpoint_ops.py
@@ -268,7 +268,8 @@ def _load_and_remap_matrix_initializer(ckpt_path,
   vocab files are the same, and no column remapping is done.
 
   The returned initializer only supports div-partitioning along the row axis. It
-  does not support partitioning along the column axis or mod-partitioning.
+  does not support partitioning along the column axis (as this is not common in
+  practice) or mod-partitioning.
 
   NOTE: When this is used to warm-start variables, client code should use
   `tf.lookup.index_table_from_tensor()` like
diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py
index c0dd46bfa5e..bea9bb6dffa 100644
--- a/tensorflow/python/training/warm_starting_util.py
+++ b/tensorflow/python/training/warm_starting_util.py
@@ -41,6 +41,7 @@ class VocabInfo(
         "old_vocab",
         "old_vocab_size",
         "backup_initializer",
+        "axis",
     ])):
   """Vocabulary information for warm-starting.
 
@@ -62,6 +63,42 @@ class VocabInfo(
     backup_initializer: [Optional] A variable initializer used for variables
       corresponding to new vocabulary entries and OOV. If not provided, these
       entries will be zero-initialized.
+    axis: [Optional] Denotes what axis the vocabulary corresponds to.  The
+      default, 0, corresponds to the most common use case (embeddings or
+      linear weights for binary classification / regression).  An axis of 1
+      could be used for warm-starting output layers with class vocabularies.
+
+      For example:
+
+      embeddings_vocab_info = tf.VocabInfo(
+          new_vocab='embeddings_vocab',
+          new_vocab_size=100,
+          num_oov_buckets=1,
+          old_vocab='pretrained_embeddings_vocab',
+          old_vocab_size=10000,
+          backup_initializer=tf.truncated_normal_initializer(
+              mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
+          axis=0)
+
+      softmax_output_layer_kernel_vocab_info = tf.VocabInfo(
+          new_vocab='class_vocab',
+          new_vocab_size=5,
+          num_oov_buckets=0,  # No OOV for classes.
+          old_vocab='old_class_vocab',
+          old_vocab_size=8,
+          backup_initializer=tf.glorot_uniform_initializer(),
+          axis=1)
+
+      softmax_output_layer_bias_vocab_info = tf.VocabInfo(
+          new_vocab='class_vocab',
+          new_vocab_size=5,
+          num_oov_buckets=0,  # No OOV for classes.
+          old_vocab='old_class_vocab',
+          old_vocab_size=8,
+          backup_initializer=tf.zeros_initializer(),
+          axis=0)
+
+      Currently, only axis=0 and axis=1 are supported.
   """
 
   def __new__(cls,
@@ -70,7 +107,12 @@ class VocabInfo(
               num_oov_buckets,
               old_vocab,
               old_vocab_size=-1,
-              backup_initializer=None):
+              backup_initializer=None,
+              axis=0):
+    if axis != 0 and axis != 1:
+      raise ValueError("The only supported values for the axis argument are 0 "
+                       "and 1.  Provided axis: {}".format(axis))
+
     return super(VocabInfo, cls).__new__(
         cls,
         new_vocab,
@@ -79,6 +121,7 @@ class VocabInfo(
         old_vocab,
         old_vocab_size,
         backup_initializer,
+        axis,
     )
 
 
@@ -149,7 +192,8 @@ def _warm_start_var_with_vocab(var,
                                previous_vocab_size=-1,
                                current_oov_buckets=0,
                                prev_tensor_name=None,
-                               initializer=None):
+                               initializer=None,
+                               axis=0):
   """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.
 
   Use this method when the `var` is backed by vocabulary. This method stitches
@@ -180,6 +224,7 @@ def _warm_start_var_with_vocab(var,
       None, we lookup tensor with same name as given `var`.
     initializer: Variable initializer to be used for missing entries.  If None,
       missing entries will be zero-initialized.
+    axis: Axis of the variable that the provided vocabulary corresponds to.
 
   Raises:
     ValueError: If required args are not provided.
@@ -204,6 +249,8 @@ def _warm_start_var_with_vocab(var,
     # Assume tensor name remains the same.
     prev_tensor_name = _infer_var_name(var)
 
+  # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
+  total_v_first_axis = sum([v.get_shape().as_list()[0] for v in var])
   for v in var:
     v_shape = v.get_shape().as_list()
     slice_info = v._get_save_slice_info()
@@ -213,19 +260,45 @@ def _warm_start_var_with_vocab(var,
           full_shape=slice_info.full_shape,
           var_offset=slice_info.var_offset)
 
-    # TODO(eddz): Support cases where class vocabularies need remapping too.
+    if axis == 0:
+      new_row_vocab_size = current_vocab_size
+      new_col_vocab_size = v_shape[1]
+      old_row_vocab_size = previous_vocab_size
+      old_row_vocab_file = prev_vocab_path
+      new_row_vocab_file = current_vocab_path
+      old_col_vocab_file = None
+      new_col_vocab_file = None
+      num_row_oov_buckets = current_oov_buckets
+      num_col_oov_buckets = 0
+    elif axis == 1:
+      # Note that we must compute this value across all partitions, whereas
+      # in the axis = 0 case, we can simply use v_shape[1] because we don't
+      # allow partitioning across axis = 1.
+      new_row_vocab_size = total_v_first_axis
+      new_col_vocab_size = current_vocab_size
+      old_row_vocab_size = -1
+      old_row_vocab_file = None
+      new_row_vocab_file = None
+      old_col_vocab_file = prev_vocab_path
+      new_col_vocab_file = current_vocab_path
+      num_row_oov_buckets = 0
+      num_col_oov_buckets = current_oov_buckets
+    else:
+      raise ValueError("The only supported values for the axis argument are 0 "
+                       "and 1.  Provided axis: {}".format(axis))
+
     init = checkpoint_ops._load_and_remap_matrix_initializer(
         ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
         old_tensor_name=prev_tensor_name,
-        new_row_vocab_size=current_vocab_size,
-        new_col_vocab_size=v_shape[1],
-        old_row_vocab_size=previous_vocab_size,
-        old_row_vocab_file=prev_vocab_path,
-        new_row_vocab_file=current_vocab_path,
-        old_col_vocab_file=None,
-        new_col_vocab_file=None,
-        num_row_oov_buckets=current_oov_buckets,
-        num_col_oov_buckets=0,
+        new_row_vocab_size=new_row_vocab_size,
+        new_col_vocab_size=new_col_vocab_size,
+        old_row_vocab_size=old_row_vocab_size,
+        old_row_vocab_file=old_row_vocab_file,
+        new_row_vocab_file=new_row_vocab_file,
+        old_col_vocab_file=old_col_vocab_file,
+        new_col_vocab_file=new_col_vocab_file,
+        num_row_oov_buckets=num_row_oov_buckets,
+        num_col_oov_buckets=num_col_oov_buckets,
         initializer=initializer)
     new_init_val = ops.convert_to_tensor(
         init(shape=v_shape, partition_info=partition_info))
@@ -374,7 +447,8 @@ def warm_start(ckpt_to_initialize_from,
           previous_vocab_size=vocab_info.old_vocab_size,
           current_oov_buckets=vocab_info.num_oov_buckets,
           prev_tensor_name=prev_var_name,
-          initializer=vocab_info.backup_initializer)
+          initializer=vocab_info.backup_initializer,
+          axis=vocab_info.axis)
     else:
       # For the special value of vars_to_warm_start = None,
       # we only warm-start variables with explicitly specified vocabularies.
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
index 70a84bc3f6e..3ee0f6aaa2e 100644
--- a/tensorflow/python/training/warm_starting_util_test.py
+++ b/tensorflow/python/training/warm_starting_util_test.py
@@ -107,7 +107,7 @@ class WarmStartingUtilTest(test.TestCase):
             "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
         ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
         sess.run(variables.global_variables_initializer())
-        self.assertAllEqual(prev_val, fruit_weights.eval(sess))
+        self.assertAllClose(prev_val, fruit_weights.eval(sess))
 
   def testWarmStartVarPrevVarPartitioned(self):
     _, weights = self._create_prev_run_var(
@@ -123,7 +123,7 @@ class WarmStartingUtilTest(test.TestCase):
             "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
         ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
         sess.run(variables.global_variables_initializer())
-        self.assertAllEqual(prev_val, fruit_weights.eval(sess))
+        self.assertAllClose(prev_val, fruit_weights.eval(sess))
 
   def testWarmStartVarCurrentVarPartitioned(self):
     _, prev_val = self._create_prev_run_var(
@@ -143,7 +143,7 @@ class WarmStartingUtilTest(test.TestCase):
         fruit_weights = fruit_weights._get_variable_list()
         new_val = np.concatenate(
             [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
-        self.assertAllEqual(prev_val, new_val)
+        self.assertAllClose(prev_val, new_val)
 
   def testWarmStartVarBothVarsPartitioned(self):
     _, weights = self._create_prev_run_var(
@@ -170,7 +170,7 @@ class WarmStartingUtilTest(test.TestCase):
         fruit_weights = fruit_weights._get_variable_list()
         new_val = np.concatenate(
             [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
-        self.assertAllEqual(prev_val, new_val)
+        self.assertAllClose(prev_val, new_val)
 
   def testWarmStartVarWithVocab(self):
     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
@@ -189,9 +189,34 @@ class WarmStartingUtilTest(test.TestCase):
         ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
                                            self.get_temp_dir(), prev_vocab_path)
         sess.run(variables.global_variables_initializer())
-        self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
+        self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
                             fruit_weights.eval(sess))
 
+  def testWarmStartVarWithColumnVocab(self):
+    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+    self._create_prev_run_var(
+        "fruit_output_layer",
+        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
+
+    # New vocab with elements in reverse order and one new element.
+    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+                                       "new_vocab")
+    # New session and new graph.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        fruit_output_layer = variable_scope.get_variable(
+            "fruit_output_layer",
+            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+                         [0., 0., 0.]])
+        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+                                           current_vocab_size=3,
+                                           prev_ckpt=self.get_temp_dir(),
+                                           prev_vocab_path=prev_vocab_path,
+                                           axis=1)
+        sess.run(variables.global_variables_initializer())
+        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
+                             [2.3, 2., 0.]], fruit_output_layer.eval(sess))
+
   def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                         "old_vocab")
@@ -215,7 +240,7 @@ class WarmStartingUtilTest(test.TestCase):
             previous_vocab_size=2)
         sess.run(variables.global_variables_initializer())
         # Old vocabulary limited to ['apple', 'banana'].
-        self.assertAllEqual([[0.], [0.], [1.], [0.5], [0.]],
+        self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]],
                             fruit_weights.eval(sess))
 
   def testWarmStartVarWithVocabPrevVarPartitioned(self):
@@ -238,9 +263,36 @@ class WarmStartingUtilTest(test.TestCase):
         ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
                                            self.get_temp_dir(), prev_vocab_path)
         sess.run(variables.global_variables_initializer())
-        self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
+        self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
                             fruit_weights.eval(sess))
 
+  def testWarmStartVarWithColumnVocabPrevVarPartitioned(self):
+    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+    self._create_prev_run_var(
+        "fruit_output_layer",
+        shape=[4, 2],
+        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
+        partitioner=lambda shape, dtype: [2, 1])
+
+    # New vocab with elements in reverse order and one new element.
+    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+                                       "new_vocab")
+    # New session and new graph.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        fruit_output_layer = variable_scope.get_variable(
+            "fruit_output_layer",
+            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+                         [0., 0., 0.]])
+        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+                                           current_vocab_size=3,
+                                           prev_ckpt=self.get_temp_dir(),
+                                           prev_vocab_path=prev_vocab_path,
+                                           axis=1)
+        sess.run(variables.global_variables_initializer())
+        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
+                             [2.3, 2., 0.]], fruit_output_layer.eval(sess))
+
   def testWarmStartVarWithVocabCurrentVarPartitioned(self):
     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                         "old_vocab")
@@ -269,11 +321,43 @@ class WarmStartingUtilTest(test.TestCase):
         self.assertTrue(
             isinstance(fruit_weights, variables.PartitionedVariable))
         fruit_weights_vars = fruit_weights._get_variable_list()
-        self.assertAllEqual([[2.], [1.5], [1.]],
+        self.assertAllClose([[2.], [1.5], [1.]],
                             fruit_weights_vars[0].eval(sess))
-        self.assertAllEqual([[0.5], [0.], [0.]],
+        self.assertAllClose([[0.5], [0.], [0.]],
                             fruit_weights_vars[1].eval(sess))
 
+  def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self):
+    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+    self._create_prev_run_var(
+        "fruit_output_layer",
+        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
+
+    # New vocab with elements in reverse order and one new element.
+    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+                                       "new_vocab")
+    # New session and new graph.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        fruit_output_layer = variable_scope.get_variable(
+            "fruit_output_layer",
+            shape=[4, 3],
+            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+                         [0., 0., 0.]],
+            partitioner=lambda shape, dtype: [2, 1])
+        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+                                           current_vocab_size=3,
+                                           prev_ckpt=self.get_temp_dir(),
+                                           prev_vocab_path=prev_vocab_path,
+                                           axis=1)
+        sess.run(variables.global_variables_initializer())
+        self.assertTrue(
+            isinstance(fruit_output_layer, variables.PartitionedVariable))
+        fruit_output_layer_vars = fruit_output_layer._get_variable_list()
+        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
+                            fruit_output_layer_vars[0].eval(sess))
+        self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
+                            fruit_output_layer_vars[1].eval(sess))
+
   def testWarmStartVarWithVocabBothVarsPartitioned(self):
     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                         "old_vocab")
@@ -301,11 +385,45 @@ class WarmStartingUtilTest(test.TestCase):
         self.assertTrue(
             isinstance(fruit_weights, variables.PartitionedVariable))
         fruit_weights_vars = fruit_weights._get_variable_list()
-        self.assertAllEqual([[2.], [1.5], [1.]],
+        self.assertAllClose([[2.], [1.5], [1.]],
                             fruit_weights_vars[0].eval(sess))
-        self.assertAllEqual([[0.5], [0.], [0.]],
+        self.assertAllClose([[0.5], [0.], [0.]],
                             fruit_weights_vars[1].eval(sess))
 
+  def testWarmStartVarWithColumnVocabBothVarsPartitioned(self):
+    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+    self._create_prev_run_var(
+        "fruit_output_layer",
+        shape=[4, 2],
+        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
+        partitioner=lambda shape, dtype: [2, 1])
+
+    # New vocab with elements in reverse order and one new element.
+    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+                                       "new_vocab")
+    # New session and new graph.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        fruit_output_layer = variable_scope.get_variable(
+            "fruit_output_layer",
+            shape=[4, 3],
+            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+                         [0., 0., 0.]],
+            partitioner=lambda shape, dtype: [2, 1])
+        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+                                           current_vocab_size=3,
+                                           prev_ckpt=self.get_temp_dir(),
+                                           prev_vocab_path=prev_vocab_path,
+                                           axis=1)
+        sess.run(variables.global_variables_initializer())
+        self.assertTrue(
+            isinstance(fruit_output_layer, variables.PartitionedVariable))
+        fruit_output_layer_vars = fruit_output_layer._get_variable_list()
+        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
+                            fruit_output_layer_vars[0].eval(sess))
+        self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
+                            fruit_output_layer_vars[1].eval(sess))
+
   def testWarmStart_ListOfVariables(self):
     # Save checkpoint from which to warm-start.
     _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt
index 5301b94eb36..b6942cb7ed8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt
@@ -3,6 +3,10 @@ tf_class {
   is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<type \'tuple\'>"
+  member {
+    name: "axis"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "backup_initializer"
     mtype: "<type \'property\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt
index 4ce7cb11116..39b946b82f3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt
@@ -3,6 +3,10 @@ tf_class {
   is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<type \'tuple\'>"
+  member {
+    name: "axis"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "backup_initializer"
     mtype: "<type \'property\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt
index 5301b94eb36..b6942cb7ed8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt
@@ -3,6 +3,10 @@ tf_class {
   is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<type \'tuple\'>"
+  member {
+    name: "axis"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "backup_initializer"
     mtype: "<type \'property\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
index 4ce7cb11116..39b946b82f3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
@@ -3,6 +3,10 @@ tf_class {
   is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<type \'tuple\'>"
+  member {
+    name: "axis"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "backup_initializer"
     mtype: "<type \'property\'>"