diff --git a/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto b/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto
index 038c7a1b8aa..7e321158091 100644
--- a/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto
+++ b/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto
@@ -87,6 +87,23 @@ message TPUEmbeddingConfiguration {
   // problem.
   bool pipeline_execution_with_tensor_core = 7;
 
+  // Directory where embedding lookup statistics are stored. These statistics
+  // summarize information about the inputs to the embedding lookup
+  // operation, in particular, the average number of embedding IDs per example
+  // and how well the embedding IDs are load balanced across the system. The
+  // lookup statistics are used during TPU initialization for embedding table
+  // partitioning. Collection of lookup statistics is done at runtime by
+  // profiling the embedding inputs: only 3% of input samples are profiled to
+  // minimize host CPU overhead. Once a suitable number of samples are
+  // profiled, the lookup statistics are saved to table-specific files in the
+  // profile data directory generally at the end of a TPU training loop. The
+  // filename corresponding to each table is obtained by hashing table-specific
+  // parameters (e.g., table name and number of features) and global
+  // configuration parameters (e.g., sharding strategy and TPU worker task
+  // count). The same profile data directory can be shared amongst several
+  // models to reuse embedding lookup statistics.
+  string profile_data_directory = 9;
+
   // Extended output layout information; deprecated and now ignored.
   TPUEmbeddingOutputLayout output_layout = 8 [deprecated = true];
 }
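
For reference, a minimal sketch of how the new field can be set on the generated configuration proto from Python, assuming the standard `tpu_embedding_configuration_pb2` bindings (the directory path below is a placeholder):

```python
# Illustrative sketch only: populate the new profile_data_directory field on
# the generated TPUEmbeddingConfiguration proto. The path is a placeholder;
# any persistent directory shared across training jobs would work.
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc

config = elc.TPUEmbeddingConfiguration()
config.profile_data_directory = '/tmp/embedding_profiles'  # placeholder path

# When the field is left empty, no lookup statistics are read or written.
print(config.profile_data_directory)
```
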
diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index 7c42bb2c41f..35ff13bed0a 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -52,8 +52,13 @@ INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE
 #  as AdagradParameters etc instead of learning_rate.
 class TableConfig(
     collections.namedtuple('TableConfig', [
-        'vocabulary_size', 'dimension', 'initializer', 'combiner',
-        'hot_id_replication', 'learning_rate', 'learning_rate_fn',
+        'vocabulary_size',
+        'dimension',
+        'initializer',
+        'combiner',
+        'hot_id_replication',
+        'learning_rate',
+        'learning_rate_fn',
         'optimization_parameters',
     ])):
   """Embedding table configuration."""
@@ -85,16 +90,16 @@ class TableConfig(
       hot_id_replication: If true, enables hot id replication, which can make
         embedding lookups faster if there are some hot rows in the table.
       learning_rate: float, static learning rate for this table. If
-        learning_rate and learning_rate_fn are both `None`, static learning
-        rate as specified in local `optimization_parameters` will be used.
-        In case local `optimization_parameters` is `None`, global
+        learning_rate and learning_rate_fn are both `None`, static learning rate
+        as specified in local `optimization_parameters` will be used. In case
+        local `optimization_parameters` is `None`, global
         `optimization_parameters` in `TPUEmbedding` constructor will be used.
         `learning_rate_fn` must be `None` if `learning_rate` is not `None.
       learning_rate_fn: string, use dynamic learning rate given by the function.
         This function function will be passed the current global step. If
-        learning_rate and learning_rate_fn are both `None`, static
-        learning rate as specified in `optimization_parameters` is used.
-        `learning_rate` must be `None` if `learning_rate_fn` is not `None.
+        learning_rate and learning_rate_fn are both `None`, static learning rate
+        as specified in `optimization_parameters` is used. `learning_rate` must
+        be `None` if `learning_rate_fn` is not `None`.
       optimization_parameters: `AdagradParameters`, `AdamParameters`,
         `Stochasticgradientdescentparameters`. Specifies table level optimizer.
         If it's `None` global optimizer in `TPUEmbedding` constructor is used.
@@ -127,8 +132,8 @@ class TableConfig(
 
     if learning_rate is not None and learning_rate_fn is not None:
       raise ValueError('At most one of learning_rate and learning_rate_fn '
-                       'can be None; got {} and {}'
-                       .format(learning_rate, learning_rate_fn))
+                       'can be set; got {} and {}'.format(
+                           learning_rate, learning_rate_fn))
 
     if optimization_parameters is not None:
       if not isinstance(optimization_parameters, _OptimizationParameters):
@@ -144,15 +149,11 @@ class TableConfig(
 
 
 class FeatureConfig(
-    collections.namedtuple(
-        'FeatureConfig',
-        ['table_id', 'max_sequence_length', 'weight_key'])):
+    collections.namedtuple('FeatureConfig',
+                           ['table_id', 'max_sequence_length', 'weight_key'])):
   """Feature configuration."""
 
-  def __new__(cls,
-              table_id,
-              max_sequence_length=0,
-              weight_key=None):
+  def __new__(cls, table_id, max_sequence_length=0, weight_key=None):
     """Feature configuration.
 
     Args:
@@ -171,8 +172,8 @@ class FeatureConfig(
       ValueError: if `max_sequence_length` non-negative.
     """
     if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
-      raise ValueError('Invalid max_sequence_length {}.'.format(
-          max_sequence_length))
+      raise ValueError(
+          'Invalid max_sequence_length {}.'.format(max_sequence_length))
 
     return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length,
                                              weight_key)
@@ -191,19 +192,19 @@ class EnqueueData(
     """Data to be enqueued through generate_enqueue_ops().
 
     Args:
-      embedding_indices: A rank 1 Tensors, indices into the embedding tables. It
+      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
         corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
         and int64 are allowed and will be converted to int32 internally.
-      sample_indices: A rank 2 Tensors specifying the training example to which
+      sample_indices: A rank 2 Tensor specifying the training example to which
         the corresponding embedding_indices and aggregation_weights values
         belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
         If it is None, we assume each embedding_indices belongs to a different
         sample. Both int32 and int64 are allowed and will be converted to int32
         internally.
-      aggregation_weights: A rank 1 Tensors containing aggregation weights.
-        It corresponds to sp_weights.values in embedding_lookup_sparse(). If it
-        is None, we assume all weights are 1. Both float32 and float64 are
-        allowed and will be converted to float32 internally.
+      aggregation_weights: A rank 1 Tensor containing aggregation weights. It
+        corresponds to sp_weights.values in embedding_lookup_sparse(). If it is
+        None, we assume all weights are 1. Both float32 and float64 are allowed
+        and will be converted to float32 internally.
 
     Returns:
       An EnqueueData tuple.
@@ -310,11 +311,11 @@ def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
   return enqueue_datas_list
 
 
-AdamSlotVariableNames = collections.namedtuple(
-    'AdamSlotVariableNames', ['m', 'v'])
+AdamSlotVariableNames = collections.namedtuple('AdamSlotVariableNames',
+                                               ['m', 'v'])
 
-AdagradSlotVariableName = collections.namedtuple(
-    'AdagradSlotVariableName', ['accumulator'])
+AdagradSlotVariableName = collections.namedtuple('AdagradSlotVariableName',
+                                                 ['accumulator'])
 
 MomentumSlotVariableName = collections.namedtuple('MomentumSlotVariableName',
                                                   ['momenta'])
@@ -325,11 +326,11 @@ RMSPropSlotVariableNames = collections.namedtuple('RMSPropSlotVariableNames',
 ProximalAdagradSlotVariableName = collections.namedtuple(
     'ProximalAdagradSlotVariableName', ['accumulator'])
 
-FtrlSlotVariableName = collections.namedtuple(
-    'FtrlSlotVariableName', ['accumulator', 'linear'])
+FtrlSlotVariableName = collections.namedtuple('FtrlSlotVariableName',
+                                              ['accumulator', 'linear'])
 
-ProximalYogiSlotVariableNames = collections.namedtuple(
-    'ProximalYogiSlotVariableNames', ['v', 'm'])
+ProximalYogiSlotVariableNames = collections.namedtuple(
+    'ProximalYogiSlotVariableNames', ['v', 'm'])
 
 AdamSlotVariables = collections.namedtuple(
     'AdamSlotVariables', ['m', 'v'])
@@ -340,22 +340,21 @@ MomentumSlotVariable = collections.namedtuple('MomentumSlotVariable',
 RMSPropSlotVariables = collections.namedtuple('RMSPropSlotVariables',
                                               ['ms', 'mom'])
 
-AdagradSlotVariable = collections.namedtuple(
-    'AdagradSlotVariable', ['accumulator'])
+AdagradSlotVariable = collections.namedtuple('AdagradSlotVariable',
+                                             ['accumulator'])
 
 ProximalAdagradSlotVariable = collections.namedtuple(
     'ProximalAdagradSlotVariable', ['accumulator'])
 
-FtrlSlotVariable = collections.namedtuple(
-    'FtrlSlotVariable', ['accumulator', 'linear'])
+FtrlSlotVariable = collections.namedtuple('FtrlSlotVariable',
+                                          ['accumulator', 'linear'])
 
 ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables',
                                                    ['v', 'm'])
 
-VariablesAndOps = collections.namedtuple(
-    'VariablesAndOps',
-    ['embedding_variables_by_table', 'slot_variables_by_table',
-     'load_ops', 'retrieve_ops']
+VariablesAndOps = collections.namedtuple('VariablesAndOps', [
+    'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops',
+    'retrieve_ops']
 )
 
 
@@ -424,7 +423,6 @@ class AdagradParameters(_OptimizationParameters):
       use_gradient_accumulation: setting this to `False` makes embedding
         gradients calculation less accurate but faster. Please see
         `optimization_parameters.proto` for details.
-        for details.
       clip_weight_min: the minimum value to clip by; None means -infinity.
       clip_weight_max: the maximum value to clip by; None means +infinity.
       weight_decay_factor: amount of weight decay to apply; None means that the
@@ -560,19 +558,18 @@ class AdamParameters(_OptimizationParameters):
 
     Args:
       learning_rate: a floating point value. The learning rate.
-      beta1: A float value.
-        The exponential decay rate for the 1st moment estimates.
-      beta2: A float value.
-        The exponential decay rate for the 2nd moment estimates.
+      beta1: A float value. The exponential decay rate for the 1st moment
+        estimates.
+      beta2: A float value. The exponential decay rate for the 2nd moment
+        estimates.
       epsilon: A small constant for numerical stability.
-      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster.
-        Please see `optimization_parameters.proto` for details.
+      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. See
+        `optimization_parameters.proto` for details.
       sum_inside_sqrt: This improves training speed. Please see
         `optimization_parameters.proto` for details.
       use_gradient_accumulation: setting this to `False` makes embedding
         gradients calculation less accurate but faster. Please see
         `optimization_parameters.proto` for details.
-        for details.
       clip_weight_min: the minimum value to clip by; None means -infinity.
       clip_weight_max: the maximum value to clip by; None means +infinity.
       weight_decay_factor: amount of weight decay to apply; None means that the
@@ -656,19 +653,18 @@ class FtrlParameters(_OptimizationParameters):
     Args:
       learning_rate: a floating point value. The learning rate.
       learning_rate_power: A float value, must be less or equal to zero.
-        Controls how the learning rate decreases during training. Use zero for
-        a fixed learning rate. See section 3.1 in the
+        Controls how the learning rate decreases during training. Use zero for a
+        fixed learning rate. See section 3.1 in the
         [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
-      initial_accumulator_value: The starting value for accumulators.
-        Only zero or positive values are allowed.
-      l1_regularization_strength: A float value, must be greater than or
-        equal to zero.
-      l2_regularization_strength: A float value, must be greater than or
-        equal to zero.
+      initial_accumulator_value: The starting value for accumulators. Only zero
+        or positive values are allowed.
+      l1_regularization_strength: A float value, must be greater than or equal
+        to zero.
+      l2_regularization_strength: A float value, must be greater than or equal
+        to zero.
       use_gradient_accumulation: setting this to `False` makes embedding
         gradients calculation less accurate but faster. Please see
         `optimization_parameters.proto` for details.
-        for details.
       clip_weight_min: the minimum value to clip by; None means -infinity.
       clip_weight_max: the maximum value to clip by; None means +infinity.
       weight_decay_factor: amount of weight decay to apply; None means that the
@@ -728,13 +724,15 @@ class ProximalYogiParameters(_OptimizationParameters):
   """Optimization parameters for Proximal Yogi with TPU embeddings.
 
   Implements the Yogi optimizer as described in
-  [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization).
+  [Adaptive Methods for Nonconvex
+  Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization).
 
   Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
   `optimization_parameters` argument to set the optimizer and its parameters.
   See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
   for more details.
   """
+
   # pylint: enable=line-too-long
 
   def __init__(
@@ -1140,6 +1138,7 @@ class TPUEmbedding(object):
                cluster_def=None,
                pipeline_execution_with_tensor_core=False,
                partition_strategy='div',
+               profile_data_directory=None,
                device_config=None,
                master_job_name=None):
     """API for using TPU for embedding lookups.
@@ -1166,6 +1165,21 @@ class TPUEmbedding(object):
       partition_strategy: A string, either 'mod' or 'div', specifying how to map
         the lookup id to the embedding tensor. For more information see
         `tf.nn.embedding_lookup_sparse`.
+      profile_data_directory: Directory where embedding lookup statistics are
+        stored. These statistics summarize information about the inputs to the
+        embedding lookup operation, in particular, the average number of
+        embedding IDs per example and how well the embedding IDs are load
+        balanced across the system. The lookup statistics are used during TPU
+        initialization for embedding table partitioning. Collection of lookup
+        statistics is done at runtime by profiling the embedding inputs: only
+        3% of input samples are profiled to minimize host CPU overhead. Once
+        a suitable number of samples are profiled, the lookup statistics are
+        saved to table-specific files in the profile data directory generally
+        at the end of a TPU training loop. The filename corresponding to each
+        table is obtained by hashing table-specific parameters (e.g., table
+        name and number of features) and global configuration parameters (e.g.,
+        sharding strategy and task count). The same profile data directory can
+        be shared among several models to reuse embedding lookup statistics.
       device_config: A DeviceConfig instance, used when `master` and
         `cluster_def` are both `None`.
       master_job_name: if set, overrides the master job name used to schedule
@@ -1179,6 +1193,8 @@ class TPUEmbedding(object):
           'Invalid partition_strategy {}'.format(partition_strategy))
     self._partition_strategy = partition_strategy
 
+    self._profile_data_directory = profile_data_directory
+
     _validate_table_to_config_dict(table_to_config_dict)
     # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
     self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
@@ -1220,14 +1236,14 @@ class TPUEmbedding(object):
       self._num_hosts = tpu_system_metadata.num_hosts
       if master_job_name is None:
         try:
-          master_job_name = tpu_system_metadata_lib.master_job(master,
-                                                               cluster_def)
+          master_job_name = tpu_system_metadata_lib.master_job(
+              master, cluster_def)
         except ValueError as e:
           raise ValueError(str(e) + ' Please specify a master_job_name.')
       self._hosts = []
       for device in tpu_system_metadata.devices:
-        if 'device:CPU:' in device.name and (
-            master_job_name is None or master_job_name in device.name):
+        if 'device:CPU:' in device.name and (master_job_name is None or
+                                             master_job_name in device.name):
           self._hosts.append(device.name)
       self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host
       self._num_cores = tpu_system_metadata.num_cores
@@ -1244,11 +1260,10 @@ class TPUEmbedding(object):
       if optimization_parameters is not None:
         raise ValueError('`optimization_parameters` should be `None` '
                          'for inference mode.')
-      self._optimization_parameters = (
-          StochasticGradientDescentParameters(1.))
+      self._optimization_parameters = StochasticGradientDescentParameters(1.)
     else:
-      raise ValueError('`mode` only supports {} and {}; got {}.'
-                       .format(TRAINING, INFERENCE, mode))
+      raise ValueError('`mode` only supports {} and {}; got {}.'.format(
+          TRAINING, INFERENCE, mode))
     self._mode = mode
 
     # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler`
@@ -1259,11 +1274,13 @@ class TPUEmbedding(object):
 
     self._pipeline_execution_with_tensor_core = (
         pipeline_execution_with_tensor_core)
-    self._learning_rate_fn = list(set(
-        c.learning_rate_fn for c in self._table_to_config_dict.values()
-        if c.learning_rate_fn is not None))
+    self._learning_rate_fn = list(
+        set(c.learning_rate_fn
+            for c in self._table_to_config_dict.values()
+            if c.learning_rate_fn is not None))
     self._learning_rate_fn_to_tag = {
-        fn: id for id, fn in enumerate(self._learning_rate_fn)}
+        fn: id for id, fn in enumerate(self._learning_rate_fn)
+    }
 
     self._config_proto = self._create_config_proto()
 
@@ -1403,10 +1420,13 @@ class TPUEmbedding(object):
         elc.TPUEmbeddingConfiguration.MOD)
     config_proto.pipeline_execution_with_tensor_core = (
         self._pipeline_execution_with_tensor_core)
+    if self._profile_data_directory:
+      config_proto.profile_data_directory = self._profile_data_directory
 
     return config_proto
 
-  def create_variables_and_ops(self, embedding_variable_name_by_table=None,
+  def create_variables_and_ops(self,
+                               embedding_variable_name_by_table=None,
                                slot_variable_names_by_table=None):
     """Create embedding and slot variables, with ops to load and retrieve them.
 
@@ -1425,8 +1445,8 @@ class TPUEmbedding(object):
 
     Args:
       embedding_variable_name_by_table: A dictionary mapping from string of
-        table name to string of embedding variable name. If `None`,
-        defaults from `get_default_slot_variable_names()` will be used.
+        table name to string of embedding variable name. If `None`, defaults
+        from `get_default_slot_variable_names()` will be used.
       slot_variable_names_by_table: A dictionary mapping from string of table
         name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If
         `None`, defaults from `get_default_slot_variable_names()` will be used.
@@ -1510,8 +1530,7 @@ class TPUEmbedding(object):
       return retrieve_ops_list
 
     return VariablesAndOps(embedding_variables_by_table,
-                           slot_variables_by_table,
-                           load_ops, retrieve_ops)
+                           slot_variables_by_table, load_ops, retrieve_ops)
 
   def generate_enqueue_ops(
       self,
@@ -1522,10 +1541,9 @@ class TPUEmbedding(object):
     """Generate enqueue ops.
 
     Args:
-      enqueue_datas_list: a list of dictionary mapping from string
-        of feature names to EnqueueData. Each dictionary is for one
-        TPU core. Dictionaries for the same host should be contiguous
-        on the list.
+      enqueue_datas_list: a list of dictionaries mapping feature name strings
+        to EnqueueData. Each dictionary is for one TPU core. Dictionaries for
+        the same host should be contiguous in the list.
       mode_override: A string input that overrides the mode specified in the
         TPUEmbeddingConfiguration. Supported values are {'unspecified',
         'inference', 'training', 'backward_pass_only'}. When set to
@@ -1723,8 +1741,8 @@ class TPUEmbedding(object):
             if enqueue_data.sample_indices is not None else int_zeros)
 
         kwargs['aggregation_weights'].append(
-            enqueue_data.aggregation_weights if
-            enqueue_data.aggregation_weights is not None else float_zeros)
+            enqueue_data.aggregation_weights
+            if enqueue_data.aggregation_weights is not None else float_zeros)
 
         kwargs['embedding_indices'].append(enqueue_data.embedding_indices)
 
@@ -1763,14 +1781,13 @@ class TPUEmbedding(object):
           feature_index = feature_index + 1
         else:
           activations[feature] = (
-              table_activations[:, feature_index:(feature_index+seq_length), :])
+              table_activations[:,
+                                feature_index:(feature_index + seq_length), :])
           feature_index = feature_index + seq_length
 
     return activations
 
-  def generate_send_gradients_op(self,
-                                 feature_to_gradient_dict,
-                                 step=None):
+  def generate_send_gradients_op(self, feature_to_gradient_dict, step=None):
     """Send gradient to TPU embedding.
 
     Args:
@@ -1786,8 +1803,8 @@ class TPUEmbedding(object):
     """
     if self._mode != TRAINING:
       raise RuntimeError('Only in training mode gradients need to '
-                         'be sent to TPU embedding; got mode {}.'
-                         .format(self._mode))
+                         'be sent to TPU embedding; got mode {}.'.format(
+                             self._mode))
     if step is None and self._learning_rate_fn:
       raise ValueError('There are dynamic learning rates but step is None.')
 
@@ -1808,8 +1825,10 @@ class TPUEmbedding(object):
 
     return tpu_ops.send_tpu_embedding_gradients(
         inputs=gradients,
-        learning_rates=[math_ops.cast(fn(step), dtype=dtypes.float32)
-                        for fn in self._learning_rate_fn],
+        learning_rates=[
+            math_ops.cast(fn(step), dtype=dtypes.float32)
+            for fn in self._learning_rate_fn
+        ],
         config=self.config_proto.SerializeToString())
 
   def _get_optimizer_handler_by_table(self):
@@ -1835,21 +1854,21 @@ def _validate_table_to_config_dict(table_to_config_dict):
 def _validate_feature_to_config_dict(table_to_config_dict,
                                      feature_to_config_dict):
   """Validate `feature_to_config_dict`."""
-  used_table_set = set([feature.table_id
-                        for feature in feature_to_config_dict.values()])
+  used_table_set = set(
+      [feature.table_id for feature in feature_to_config_dict.values()])
   table_set = set(table_to_config_dict.keys())
 
   unused_table_set = table_set - used_table_set
   if unused_table_set:
-    raise ValueError('`table_to_config_dict` specifies table that is not '
-                     'used in `feature_to_config_dict`: {}.'
-                     .format(unused_table_set))
+    raise ValueError(
+        '`table_to_config_dict` specifies a table that is not '
+        'used in `feature_to_config_dict`: {}.'.format(unused_table_set))
 
   extra_table_set = used_table_set - table_set
   if extra_table_set:
-    raise ValueError('`feature_to_config_dict` refers to a table that is not '
-                     'specified in `table_to_config_dict`: {}.'
-                     .format(extra_table_set))
+    raise ValueError(
+        '`feature_to_config_dict` refers to a table that is not '
+        'specified in `table_to_config_dict`: {}.'.format(extra_table_set))
 
 
 def _validate_batch_size(batch_size, num_cores):
@@ -1867,10 +1886,9 @@ def _validate_optimization_parameters(optimization_parameters,
 
   Args:
       optimization_parameters: global optimizer provided in `TPUEmbedding`
-         constructor.
+        constructor.
       table_to_config_dict: A dictionary mapping from string of table name to
         `TableConfig`.
-
   """
   tbl_optimizer_missing = False
   for _, table_config in table_to_config_dict.items():
@@ -2107,8 +2125,7 @@ class _AdamHandler(_OptimizerHandler):
       load_op_list = []
       config = config_proto
       for host_id, table_variable, m_variable, v_variable in (zip(
-          range(num_hosts), table_variables,
-          m_variables, v_variables)):
+          range(num_hosts), table_variables, m_variables, v_variables)):
         with ops.colocate_with(table_variable):
           load_parameters_op = (
               tpu_ops.load_tpu_embedding_adam_parameters(
@@ -2134,8 +2151,7 @@ class _AdamHandler(_OptimizerHandler):
       retrieve_op_list = []
       config = config_proto
       for host_id, table_variable, m_variable, v_variable in (zip(
-          range(num_hosts), table_variables,
-          m_variables, v_variables)):
+          range(num_hosts), table_variables, m_variables, v_variables)):
         with ops.colocate_with(table_variable):
           retrieved_table, retrieved_m, retrieved_v = (
               tpu_ops.retrieve_tpu_embedding_adam_parameters(
@@ -2174,8 +2190,9 @@ class _FtrlHandler(_OptimizerHandler):
   def get_default_slot_variable_names(self, table):
     # These match the default slot variable names created by
     # tf.train.FtrlOptimizer.
-    return FtrlSlotVariableName('{}/{}'.format(table, 'Ftrl'),  # accumulator
-                                '{}/{}'.format(table, 'Ftrl_1'))  # linear
+    return FtrlSlotVariableName(
+        '{}/{}'.format(table, 'Ftrl'),  # accumulator
+        '{}/{}'.format(table, 'Ftrl_1'))  # linear
 
   def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                                table_config, table_variables, config_proto):
@@ -2197,8 +2214,7 @@ class _FtrlHandler(_OptimizerHandler):
         embedding_dimension=table_config.dimension,
         collections=[ops.GraphKeys.GLOBAL_VARIABLES],
         initializer=linear_initializer)
-    slot_variables = FtrlSlotVariable(accumulator_variables,
-                                      linear_variables)
+    slot_variables = FtrlSlotVariable(accumulator_variables, linear_variables)
 
     def load_ops_fn():
       """Returns the retrieve ops for Ftrl embedding tables.
@@ -2539,8 +2555,7 @@ class _StochasticGradientDescentHandler(_OptimizerHandler):
       """
       load_op_list = []
       config = config_proto
-      for host_id, table_variable in (zip(
-          range(num_hosts), table_variables)):
+      for host_id, table_variable in enumerate(table_variables):
         with ops.colocate_with(table_variable):
           load_parameters_op = (
               tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters(
@@ -2561,8 +2576,7 @@ class _StochasticGradientDescentHandler(_OptimizerHandler):
       """
       retrieve_op_list = []
       config = config_proto
-      for host_id, table_variable in (zip(
-          range(num_hosts), table_variables)):
+      for host_id, table_variable in enumerate(table_variables):
         with ops.colocate_with(table_variable):
           retrieved_table = (
               tpu_ops
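
A hedged sketch of how the new `profile_data_directory` constructor argument might be wired up alongside the existing table and feature configuration classes. The table name, sizes, master address, and directory below are placeholder values, and constructing `TPUEmbedding` queries TPU system metadata, so it needs a reachable TPU worker to actually run:

```python
# Illustrative sketch only. All names, sizes, addresses, and paths below are
# placeholders; this will not execute without a reachable TPU worker.
from tensorflow.python.tpu import tpu_embedding

table_config = tpu_embedding.TableConfig(vocabulary_size=100000, dimension=64)
feature_config = tpu_embedding.FeatureConfig(table_id='video_id_table')

embedding = tpu_embedding.TPUEmbedding(
    table_to_config_dict={'video_id_table': table_config},
    feature_to_config_dict={'watched_video_id': feature_config},
    batch_size=1024,
    mode=tpu_embedding.TRAINING,
    master='grpc://10.0.0.2:8470',  # placeholder TPU master address
    optimization_parameters=tpu_embedding.AdagradParameters(learning_rate=0.05),
    profile_data_directory='/tmp/embedding_profiles')  # placeholder directory

# The directory is copied into the TPUEmbeddingConfiguration proto built by
# _create_config_proto(), so the runtime can persist and later reuse lookup
# statistics for embedding table partitioning.
```
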
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.experimental.-embedding-config-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.experimental.-embedding-config-spec.pbtxt
index 355c57269fd..ebcf27eea53 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.experimental.-embedding-config-spec.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.experimental.-embedding-config-spec.pbtxt
@@ -31,6 +31,10 @@ tf_class {
     name: "pipeline_execution_with_tensor_core"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "profile_data_directory"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "table_to_config_dict"
     mtype: "<type \'property\'>"