From b220af894a67d875b8059cfa78f0cb28fa72f834 Mon Sep 17 00:00:00 2001
From: Revan Sopher <rsopher@google.com>
Date: Wed, 1 Apr 2020 13:23:59 -0700
Subject: [PATCH] Add EnqueueTPUEmbeddingRaggedTensorBatch for RaggedTensor
 support.

PiperOrigin-RevId: 304250071
Change-Id: If1f0d7a8716c95a090f28d085a46ffa9c3e9053e
---
 ...EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt |  77 +++++++
 tensorflow/core/ops/ops.pbtxt                 |  95 ++++++++
 tensorflow/core/ops/tpu_embedding_ops.cc      |  16 ++
 tensorflow/python/tpu/ops/tpu_ops.py          |  73 +++++++
 tensorflow/python/tpu/tpu_embedding.py        | 202 +++++++++++++++---
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |   4 +
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |   4 +
 7 files changed, 441 insertions(+), 30 deletions(-)
 create mode 100644 tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt

diff --git a/tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt
new file mode 100644
index 00000000000..cdcdd6d06b0
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt
@@ -0,0 +1,77 @@
+op {
+  graph_op_name: "EnqueueTPUEmbeddingRaggedTensorBatch"
+  visibility: HIDDEN
+  in_arg {
+    name: "sample_splits"
+    description: <<END
+A list of rank 1 Tensors specifying the break points for splitting
+embedding_indices and aggregation_weights into rows.
+It corresponds to ids.row_splits in embedding_lookup(), when ids is a
+RaggedTensor.
+END
+  }
+  in_arg {
+    name: "embedding_indices"
+    description: <<END
+A list of rank 1 Tensors, indices into the embedding tables.
+It corresponds to ids.values in embedding_lookup(), when ids is a RaggedTensor.
+END
+  }
+  in_arg {
+    name: "aggregation_weights"
+    description: <<END
+A list of rank 1 Tensors containing per training example
+aggregation weights. It corresponds to the values field of a RaggedTensor
+with the same row_splits as ids in embedding_lookup(), when ids is a
+RaggedTensor.
+END
+  }
+  in_arg {
+    name: "mode_override"
+    description: <<END
+A string input that overrides the mode specified in the
+TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference',
+'training', 'backward_pass_only'}. When set to 'unspecified', the mode set
+in TPUEmbeddingConfiguration is used, otherwise mode_override is used.
+END
+  }
+  attr {
+    name: "device_ordinal"
+    description: <<END
+The TPU device to use. Should be >= 0 and less than the number
+of TPU cores in the task on which the node is placed.
+END
+  }
+  attr {
+    name: "combiners"
+    description: <<END
+A list of string scalars, one for each embedding table that specify
+how to normalize the embedding activations after weighted summation.
+Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have
+the sum of the weights be 0 for 'mean' or the sum of the squared weights be
+0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for
+all tables.
+END
+  }
+  attr {
+    name: "table_ids"
+    description: <<END
+A list of integers specifying the identifier of the embedding table
+(offset of TableDescriptor in the TPUEmbeddingConfiguration) to lookup the
+corresponding input. The ith input is looked up using table_ids[i]. The size
+of the table_ids list must be equal to that of sample_indices,
+embedding_indices and aggregation_weights.
+END
+  }
+  summary: "Eases the porting of code that uses tf.nn.embedding_lookup()."
+  description: <<END
+sample_splits[i], embedding_indices[i] and aggregation_weights[i] correspond
+to the ith feature. table_ids[i] indicates which embedding table to look up ith
+feature.
+
+The tensors at corresponding positions in two of the input lists,
+embedding_indices and aggregation_weights, must have the same shape, i.e. rank 1
+with dim_size() equal to the total number of lookups into the table described by
+the corresponding feature.
+END
+}
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index c4d9daeeb47..91750889b95 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -13184,6 +13184,101 @@ op {
   }
   is_stateful: true
 }
+op {
+  name: "EnqueueTPUEmbeddingRaggedTensorBatch"
+  input_arg {
+    name: "sample_indices"
+    type_attr: "T1"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "embedding_indices"
+    type_attr: "T2"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "aggregation_weights"
+    type_attr: "T3"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "mode_override"
+    type: DT_STRING
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T3"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "combiners"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "table_ids"
+    type: "list(int)"
+  }
+  attr {
+    name: "max_sequence_lengths"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  is_stateful: true
+}
 op {
   name: "EnsureShape"
   input_arg {
diff --git a/tensorflow/core/ops/tpu_embedding_ops.cc b/tensorflow/core/ops/tpu_embedding_ops.cc
index 821dff7c64a..164d78e8e9e 100644
--- a/tensorflow/core/ops/tpu_embedding_ops.cc
+++ b/tensorflow/core/ops/tpu_embedding_ops.cc
@@ -168,4 +168,20 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
     .SetIsStateful()
     .SetShapeFn(shape_inference::UnknownShape);
 
+REGISTER_OP("EnqueueTPUEmbeddingRaggedTensorBatch")
+    .Input("sample_splits: N * T1")
+    .Input("embedding_indices: N * T2")
+    .Input("aggregation_weights: N * T3")
+    .Input("mode_override: string")
+    .Attr("T1: {int32,int64} = DT_INT32")
+    .Attr("T2: {int32,int64} = DT_INT32")
+    .Attr("T3: {float32,float64} = DT_FLOAT")
+    .Attr("N: int >= 1")
+    .Attr("device_ordinal: int = -1")
+    .Attr("combiners: list(string) = []")
+    .Attr("table_ids: list(int)")
+    .Attr("max_sequence_lengths: list(int) = []")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::UnknownShape);
+
 }  // namespace tensorflow
diff --git a/tensorflow/python/tpu/ops/tpu_ops.py b/tensorflow/python/tpu/ops/tpu_ops.py
index c1ea3641757..8facb1fdad7 100644
--- a/tensorflow/python/tpu/ops/tpu_ops.py
+++ b/tensorflow/python/tpu/ops/tpu_ops.py
@@ -444,3 +444,76 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
 
 enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
     gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
+
+
+# pylint: disable=protected-access
+def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
+                                              embedding_indices,
+                                              aggregation_weights,
+                                              table_ids,
+                                              device_ordinal,
+                                              max_sequence_lengths=None,
+                                              combiners=None,
+                                              mode_override=None,
+                                              name=None):
+  """A placeholder op for enqueueing embedding IDs to the TPU.
+
+  Args:
+    sample_splits: A list of rank 1 Tensors specifying the break points for
+      splitting embedding_indices and aggregation_weights into rows. It
+      corresponds to ids.row_splits in embedding_lookup(), when ids is a
+      RaggedTensor. Both int32 and int64 are allowed and will be converted to
+      int32 internally.
+    embedding_indices: A list of rank 1 Tensors, indices into the embedding
+      tables. It corresponds to ids.values in embedding_lookup(), when ids is a
+      RaggedTensor. Both int32 and int64 are allowed and will be converted to
+      int32 internally.
+    aggregation_weights: A list of rank 1 Tensors containing per training
+      example aggregation weights. It corresponds to the values field of a
+      RaggedTensor with the same row_splits as ids in embedding_lookup(), when
+      ids is a RaggedTensor. Both float32 and float64 are allowed and will be
+      converted to float32 internally.
+    table_ids: A list of integers specifying the identifier of the embedding
+      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
+      lookup the corresponding input. The ith input is looked up using
+      table_ids[i]. The size of the table_ids list must be equal to that of
+      sample_indices, embedding_indices and aggregation_weights.
+    device_ordinal: The TPU device to use. Should be >= 0 and less than the
+      number of TPU cores in the task on which the node is placed.
+    max_sequence_lengths: A list of integers, the size of which is equal to
+      sample_indices. If equal to 0, the corresponding feature is considered to
+      be a non-sequence feature, If greater than 0, the corresponding feature is
+      a sequence feature with the given maximal length. If None, then we assume
+      a list of all zeroes.
+    combiners: A list of string scalars, one for each embedding table that
+      specify how to normalize the embedding activations after weighted
+      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
+      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
+      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
+      is to use 'sum' for all tables (optional).
+    mode_override: A string input that overrides the mode specified in the
+      TPUEmbeddingConfiguration. Supported values are {'unspecified',
+      'inference', 'training', 'backward_pass_only'}. When set to 'unspecified',
+      the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
+      is used (optional).
+    name: A name for the operation (optional).
+
+  Returns:
+    An EnqueueTPUEmbeddingRaggedTensorBatch operation.
+  """
+  if mode_override is None:
+    mode_override = "unspecified"
+  return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
+      sample_splits=sample_splits,
+      embedding_indices=embedding_indices,
+      aggregation_weights=aggregation_weights,
+      table_ids=table_ids,
+      device_ordinal=device_ordinal,
+      max_sequence_lengths=max_sequence_lengths,
+      combiners=combiners,
+      mode_override=mode_override,
+      name=name)
+
+
+enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = (
+    gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__)
diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index e3dbe7fb93f..e24188eaf16 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -205,6 +205,48 @@ class EnqueueData(
         aggregation_weights=weights.values if weights is not None else None)
 
 
+class RaggedEnqueueData(
+    collections.namedtuple(
+        'RaggedEnqueueData',
+        ['embedding_indices', 'sample_splits', 'aggregation_weights'])):
+  """RaggedTensor Data to be enqueued through generate_enqueue_ops()."""
+
+  def __new__(cls,
+              embedding_indices,
+              sample_splits=None,
+              aggregation_weights=None):
+    """Data to be enqueued through generate_enqueue_ops().
+
+    Args:
+      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
+        corresponds to ids.values in embedding_lookup(), when ids is a
+        RaggedTensor. Both int32 and int64 are allowed and will be converted to
+        int32 internally.
+      sample_splits: A rank 1 Tensor specifying the break points for splitting
+        embedding_indices and aggregation_weights into rows. It corresponds to
+        ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both
+        int32 and int64 are allowed and will be converted to int32 internally.
+      aggregation_weights: A rank 1 Tensor containing per training example
+        aggregation weights. It corresponds to the values field of a
+        RaggedTensor with the same row_splits as ids in embedding_lookup(), when
+        ids is a RaggedTensor.
+
+    Returns:
+      An RaggedEnqueueData tuple.
+
+    """
+    return super(RaggedEnqueueData,
+                 cls).__new__(cls, embedding_indices, sample_splits,
+                              aggregation_weights)
+
+  @staticmethod
+  def from_ragged_tensor(rg_tensor, weights=None):
+    return RaggedEnqueueData(
+        rg_tensor.values,
+        rg_tensor.row_splits,
+        aggregation_weights=weights.values if weights is not None else None)
+
+
 def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
   """Convenient function for generate_enqueue_ops().
 
@@ -229,6 +271,30 @@ def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
   return enqueue_datas_list
 
 
+def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
+  """Convenient function for generate_enqueue_ops().
+
+  Args:
+    rg_tensors_list: a list of dictionary mapping from string of feature names
+      to RaggedTensor. Each dictionary is for one TPU core. Dictionaries for the
+      same host should be contiguous on the list.
+
+  Returns:
+    enqueue_datas_list: a list of dictionary mapping from string
+      of feature names to RaggedEnqueueData. Each dictionary is for one
+      TPU core. Dictionaries for the same host should be contiguous
+      on the list.
+
+  """
+  enqueue_datas_list = []
+  for rg_tensors in rg_tensors_list:
+    enqueue_datas = collections.OrderedDict(
+        (k, RaggedEnqueueData.from_ragged_tensor(v))
+        for k, v in six.iteritems(rg_tensors))
+    enqueue_datas_list.append(enqueue_datas)
+  return enqueue_datas_list
+
+
 AdamSlotVariableNames = collections.namedtuple(
     'AdamSlotVariableNames', ['m', 'v'])
 
@@ -1159,7 +1225,12 @@ class TPUEmbedding(object):
                            slot_variables_by_table,
                            load_ops, retrieve_ops)
 
-  def generate_enqueue_ops(self, enqueue_datas_list, mode_override=None):
+  def generate_enqueue_ops(
+      self,
+      enqueue_datas_list,
+      mode_override=None,
+      ragged=False,
+  ):
     """Generate enqueue ops.
 
     Args:
@@ -1172,6 +1243,8 @@ class TPUEmbedding(object):
         'inference', 'training', 'backward_pass_only'}. When set to
         'unspecified', the mode set in TPUEmbeddingConfiguration is used,
         otherwise mode_override is used (optional).
+      ragged: If True, creates RaggedTensor enqueue ops rather than
+        SparseTensor.
 
     Returns:
       Ops to enqueue to TPU for embedding.
@@ -1182,6 +1255,7 @@ class TPUEmbedding(object):
             enqueue_datas,
             device_ordinal=i % self._num_cores_per_host,
             mode_override=mode_override,
+            ragged=ragged,
         ) for i, enqueue_datas in enumerate(enqueue_datas_list)
     ]
 
@@ -1211,28 +1285,50 @@ class TPUEmbedding(object):
       for feature, enqueue_data in six.iteritems(enqueue_datas):
         combiner = self._table_to_config_dict[
             self._feature_to_config_dict[feature].table_id].combiner
-        if not isinstance(enqueue_data, EnqueueData):
-          raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
-                           'not mapped to `EnqueueData`. `feature`: {}'.format(
-                               i, feature))
 
-        if enqueue_data.sample_indices is None and combiner:
-          logging.warn('No sample indices set for features %f table %f but '
-                       'combiner is set to %s.', feature,
-                       self._feature_to_config_dict[feature].table_id, combiner)
+        if isinstance(enqueue_data, EnqueueData):
+          if enqueue_data.sample_indices is None and combiner:
+            logging.warn(
+                'No sample indices set for features %f table %f but '
+                'combiner is set to %s.', feature,
+                self._feature_to_config_dict[feature].table_id, combiner)
+          if (enqueue_data.sample_indices is not None and
+              enqueue_data.sample_indices.device !=
+              enqueue_data.embedding_indices.device):
+            raise ValueError(
+                'Device of sample_indices does not agree with '
+                'that of embedding_indices for feature {}.'.format(feature))
+          if (enqueue_data.aggregation_weights is not None and
+              enqueue_data.aggregation_weights.device !=
+              enqueue_data.embedding_indices.device):
+            raise ValueError(
+                'Device of aggregation_weights does not agree with '
+                'that of embedding_indices for feature {}.'.format(feature))
 
-        if (enqueue_data.sample_indices is not None and
-            enqueue_data.sample_indices.device !=
-            enqueue_data.embedding_indices.device):
+        elif isinstance(enqueue_data, RaggedEnqueueData):
+          if enqueue_data.sample_splits is None and combiner:
+            logging.warn(
+                'No sample splits set for features %f table %f but '
+                'combiner is set to %s.', feature,
+                self._feature_to_config_dict[feature].table_id, combiner)
+          if (enqueue_data.sample_splits is not None and
+              enqueue_data.sample_splits.device !=
+              enqueue_data.embedding_indices.device):
+            raise ValueError(
+                'Device of sample_splits does not agree with '
+                'that of embedding_indices for feature {}.'.format(feature))
+          if (enqueue_data.aggregation_weights is not None and
+              enqueue_data.aggregation_weights.device !=
+              enqueue_data.embedding_indices.device):
+            raise ValueError(
+                'Device of aggregation_weights does not agree with '
+                'that of embedding_indices for feature {}.'.format(feature))
+
+        else:
           raise ValueError(
-              'Device of sample_indices does not agree with '
-              'that of embedding_indices for feature {}.'.format(feature))
-        if (enqueue_data.aggregation_weights is not None and
-            enqueue_data.aggregation_weights.device !=
-            enqueue_data.embedding_indices.device):
-          raise ValueError(
-              'Device of aggregation_weights does not agree with '
-              'that of embedding_indices for feature {}.'.format(feature))
+              '`enqueue_datas_list[{}]` has a feature that is not mapped to '
+              '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format(
+                  i, feature))
         # Check all features are on the same device.
         if device is None:
           device = enqueue_data.embedding_indices.device
@@ -1257,23 +1353,69 @@ class TPUEmbedding(object):
       else:
         contiguous_device = device
 
-  def _generate_enqueue_op(
-      self, enqueue_datas, device_ordinal, mode_override=None):
+  def _generate_enqueue_op(self,
+                           enqueue_datas,
+                           device_ordinal,
+                           mode_override=None,
+                           ragged=False):
+    """Creates op for enqueuing batch to TPU."""
     enqueue_data0 = list(enqueue_datas.values())[0]
     with ops.colocate_with(enqueue_data0.embedding_indices):
-      return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
-          device_ordinal=device_ordinal,
-          combiners=self._combiners,
-          mode_override=mode_override,
-          **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas)
-      )
+      if ragged:
+        # note that this is currently identical in behavior
+        return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
+            device_ordinal=device_ordinal,
+            combiners=self._combiners,
+            mode_override=mode_override,
+            **self._format_for_tpu_embedding_ragged_tensor_batch(enqueue_datas))
+      else:
+        return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
+            device_ordinal=device_ordinal,
+            combiners=self._combiners,
+            mode_override=mode_override,
+            **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
+
+  def _format_for_tpu_embedding_ragged_tensor_batch(self, enqueue_datas):
+    """Format sparse features for `enqueue_tpu_embedding_ragged_tensor_batch()`.
+
+    Args:
+      enqueue_datas: a `Dict` of `RaggedEnqueueData` objects for embedding.
+
+    Returns:
+      Dict of arguments for `enqueue_tpu_embedding_ragged_tensor_batch()`.
+    """
+
+    kwargs = {
+        'sample_splits': [],
+        'embedding_indices': [],
+        'aggregation_weights': [],
+        'table_ids': [],
+        'max_sequence_lengths': [],
+    }
+    for table_id, table in enumerate(self._table_to_features_dict):
+      features = self._table_to_features_dict[table]
+      for feature in features:
+        enqueue_data = enqueue_datas[feature]
+
+        kwargs['sample_splits'].append(enqueue_data.sample_splits)
+
+        kwargs['aggregation_weights'].append(
+            enqueue_data.aggregation_weights if enqueue_data.aggregation_weights
+            is not None else array_ops.zeros((0,), dtype=dtypes.float32))
+
+        kwargs['embedding_indices'].append(enqueue_data.embedding_indices)
+
+        kwargs['table_ids'].append(table_id)
+        kwargs['max_sequence_lengths'].append(
+            self._feature_to_config_dict[feature].max_sequence_length)
+
+    return kwargs
 
   def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
     """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
 
     Args:
-      enqueue_datas: a `Dict` of tensors for embedding. Can be sparse or
-      dense.
+      enqueue_datas: a `Dict` of `EnqueueData` objects for embedding.
 
     Returns:
       Dict of arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 80aca6304c0..af2a47fb3b9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1244,6 +1244,10 @@ tf_module {
     name: "EnqueueTPUEmbeddingIntegerBatch"
     argspec: "args=[\'batch\', \'mode_override\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
   }
+  member_method {
+    name: "EnqueueTPUEmbeddingRaggedTensorBatch"
+    argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
+  }
   member_method {
     name: "EnqueueTPUEmbeddingSparseBatch"
     argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'device_ordinal\', \'combiners\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 80aca6304c0..af2a47fb3b9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1244,6 +1244,10 @@ tf_module {
     name: "EnqueueTPUEmbeddingIntegerBatch"
     argspec: "args=[\'batch\', \'mode_override\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
   }
+  member_method {
+    name: "EnqueueTPUEmbeddingRaggedTensorBatch"
+    argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
+  }
   member_method {
     name: "EnqueueTPUEmbeddingSparseBatch"
     argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'device_ordinal\', \'combiners\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'None\'], "