From f5743fefbd969822e96bf5e54fa774846ab499c8 Mon Sep 17 00:00:00 2001 From: Geeta Chavan Date: Tue, 2 Mar 2021 14:32:10 -0800 Subject: [PATCH] [Cherrypick:r2.4] Add TPU embedding profile data directory to both TPU estimator and TPU embedding. --- .../tpu/tpu_embedding_configuration.proto | 17 ++ tensorflow/python/tpu/tpu_embedding.py | 238 +++++++++--------- ....experimental.-embedding-config-spec.pbtxt | 4 + 3 files changed, 147 insertions(+), 112 deletions(-) diff --git a/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto b/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto index 038c7a1b8aa..7e321158091 100644 --- a/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto +++ b/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto @@ -87,6 +87,23 @@ message TPUEmbeddingConfiguration { // problem. bool pipeline_execution_with_tensor_core = 7; + // Directory where embedding lookup statistics are stored. These statistics + // summarize information about the inputs to the embedding lookup + // operation, in particular, the average number of embedding IDs per example + // and how well the embedding IDs are load balanced across the system. The + // lookup statistics are used during TPU initialization for embedding table + // partitioning. Collection of lookup statistics is done at runtime by + // profiling the embedding inputs: only 3% of input samples are profiled to + // minimize host CPU overhead. Once a suitable number of samples are + // profiled, the lookup statistics are saved to table-specific files in the + // profile data directory generally at the end of a TPU training loop. The + // filename corresponding to each table is obtained by hashing table specific + // parameters (e.g., table name and number of features) and global + // configuration parameters (e.g., sharding strategy and TPU worker task + // count). The same profile data directory can be shared amongst several + // models to reuse embedding lookup statistics. + string profile_data_directory = 9; + // Extended output layout information; deprecated and now ignored. TPUEmbeddingOutputLayout output_layout = 8 [deprecated = true]; } diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py index 7c42bb2c41f..35ff13bed0a 100644 --- a/tensorflow/python/tpu/tpu_embedding.py +++ b/tensorflow/python/tpu/tpu_embedding.py @@ -52,8 +52,13 @@ INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE # as AdagradParameters etc instead of learning_rate. class TableConfig( collections.namedtuple('TableConfig', [ - 'vocabulary_size', 'dimension', 'initializer', 'combiner', - 'hot_id_replication', 'learning_rate', 'learning_rate_fn', + 'vocabulary_size', + 'dimension', + 'initializer', + 'combiner', + 'hot_id_replication', + 'learning_rate', + 'learning_rate_fn', 'optimization_parameters', ])): """Embedding table configuration.""" @@ -85,16 +90,16 @@ class TableConfig( hot_id_replication: If true, enables hot id replication, which can make embedding lookups faster if there are some hot rows in the table. learning_rate: float, static learning rate for this table. If - learning_rate and learning_rate_fn are both `None`, static learning - rate as specified in local `optimization_parameters` will be used. - In case local `optimization_parameters` is `None`, global + learning_rate and learning_rate_fn are both `None`, static learning rate + as specified in local `optimization_parameters` will be used. 
In case + local `optimization_parameters` is `None`, global `optimization_parameters` in `TPUEmbedding` constructor will be used. `learning_rate_fn` must be `None` if `learning_rate` is not `None. learning_rate_fn: string, use dynamic learning rate given by the function. This function function will be passed the current global step. If - learning_rate and learning_rate_fn are both `None`, static - learning rate as specified in `optimization_parameters` is used. - `learning_rate` must be `None` if `learning_rate_fn` is not `None. + learning_rate and learning_rate_fn are both `None`, static learning rate + as specified in `optimization_parameters` is used. `learning_rate` must + be `None` if `learning_rate_fn` is not `None. optimization_parameters: `AdagradParameters`, `AdamParameters`, `Stochasticgradientdescentparameters`. Specifies table level optimizer. If it's `None` global optimizer in `TPUEmbedding` constructor is used. @@ -127,8 +132,8 @@ class TableConfig( if learning_rate is not None and learning_rate_fn is not None: raise ValueError('At most one of learning_rate and learning_rate_fn ' - 'can be None; got {} and {}' - .format(learning_rate, learning_rate_fn)) + 'can be None; got {} and {}'.format( + learning_rate, learning_rate_fn)) if optimization_parameters is not None: if not isinstance(optimization_parameters, _OptimizationParameters): @@ -144,15 +149,11 @@ class TableConfig( class FeatureConfig( - collections.namedtuple( - 'FeatureConfig', - ['table_id', 'max_sequence_length', 'weight_key'])): + collections.namedtuple('FeatureConfig', + ['table_id', 'max_sequence_length', 'weight_key'])): """Feature configuration.""" - def __new__(cls, - table_id, - max_sequence_length=0, - weight_key=None): + def __new__(cls, table_id, max_sequence_length=0, weight_key=None): """Feature configuration. Args: @@ -171,8 +172,8 @@ class FeatureConfig( ValueError: if `max_sequence_length` non-negative. """ if not isinstance(max_sequence_length, int) or max_sequence_length < 0: - raise ValueError('Invalid max_sequence_length {}.'.format( - max_sequence_length)) + raise ValueError( + 'Invalid max_sequence_length {}.'.format(max_sequence_length)) return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length, weight_key) @@ -191,19 +192,19 @@ class EnqueueData( """Data to be enqueued through generate_enqueue_ops(). Args: - embedding_indices: A rank 1 Tensors, indices into the embedding tables. It + embedding_indices: A rank 1 Tensor, indices into the embedding tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32 and int64 are allowed and will be converted to int32 internally. - sample_indices: A rank 2 Tensors specifying the training example to which + sample_indices: A rank 2 Tensor specifying the training example to which the corresponding embedding_indices and aggregation_weights values belong. It corresponds to sp_ids.indices in embedding_lookup_sparse(). If it is None, we assume each embedding_indices belongs to a different sample. Both int32 and int64 are allowed and will be converted to int32 internally. - aggregation_weights: A rank 1 Tensors containing aggregation weights. - It corresponds to sp_weights.values in embedding_lookup_sparse(). If it - is None, we assume all weights are 1. Both float32 and float64 are - allowed and will be converted to float32 internally. + aggregation_weights: A rank 1 Tensor containing aggregation weights. It + corresponds to sp_weights.values in embedding_lookup_sparse(). If it is + None, we assume all weights are 1. 
Both float32 and float64 are allowed + and will be converted to float32 internally. Returns: An EnqueueData tuple. @@ -310,11 +311,11 @@ def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list): return enqueue_datas_list -AdamSlotVariableNames = collections.namedtuple( - 'AdamSlotVariableNames', ['m', 'v']) +AdamSlotVariableNames = collections.namedtuple('AdamSlotVariableNames', + ['m', 'v']) -AdagradSlotVariableName = collections.namedtuple( - 'AdagradSlotVariableName', ['accumulator']) +AdagradSlotVariableName = collections.namedtuple('AdagradSlotVariableName', + ['accumulator']) MomentumSlotVariableName = collections.namedtuple('MomentumSlotVariableName', ['momenta']) @@ -325,11 +326,10 @@ RMSPropSlotVariableNames = collections.namedtuple('RMSPropSlotVariableNames', ProximalAdagradSlotVariableName = collections.namedtuple( 'ProximalAdagradSlotVariableName', ['accumulator']) -FtrlSlotVariableName = collections.namedtuple( - 'FtrlSlotVariableName', ['accumulator', 'linear']) +FtrlSlotVariableName = collections.namedtuple('FtrlSlotVariableName', + ['accumulator', 'linear']) -ProximalYogiSlotVariableNames = collections.namedtuple( - 'ProximalYogiSlotVariableNames', ['v', 'm']) +ProximalYogiSlotVariableNames = collections.namedtuple('ProximalYogiSlotVariableNames', ['v', 'm']) AdamSlotVariables = collections.namedtuple( 'AdamSlotVariables', ['m', 'v']) @@ -340,22 +340,21 @@ MomentumSlotVariable = collections.namedtuple('MomentumSlotVariable', RMSPropSlotVariables = collections.namedtuple('RMSPropSlotVariables', ['ms', 'mom']) -AdagradSlotVariable = collections.namedtuple( - 'AdagradSlotVariable', ['accumulator']) +AdagradSlotVariable = collections.namedtuple('AdagradSlotVariable', + ['accumulator']) ProximalAdagradSlotVariable = collections.namedtuple( 'ProximalAdagradSlotVariable', ['accumulator']) -FtrlSlotVariable = collections.namedtuple( - 'FtrlSlotVariable', ['accumulator', 'linear']) +FtrlSlotVariable = collections.namedtuple('FtrlSlotVariable', + ['accumulator', 'linear']) ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables', ['v', 'm']) -VariablesAndOps = collections.namedtuple( - 'VariablesAndOps', - ['embedding_variables_by_table', 'slot_variables_by_table', - 'load_ops', 'retrieve_ops'] +VariablesAndOps = collections.namedtuple('VariablesAndOps',[ + 'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops', + 'retrieve_ops'] ) @@ -424,7 +423,6 @@ class AdagradParameters(_OptimizationParameters): use_gradient_accumulation: setting this to `False` makes embedding gradients calculation less accurate but faster. Please see `optimization_parameters.proto` for details. - for details. clip_weight_min: the minimum value to clip by; None means -infinity. clip_weight_max: the maximum value to clip by; None means +infinity. weight_decay_factor: amount of weight decay to apply; None means that the @@ -560,19 +558,18 @@ class AdamParameters(_OptimizationParameters): Args: learning_rate: a floating point value. The learning rate. - beta1: A float value. - The exponential decay rate for the 1st moment estimates. - beta2: A float value. - The exponential decay rate for the 2nd moment estimates. + beta1: A float value. The exponential decay rate for the 1st moment + estimates. + beta2: A float value. The exponential decay rate for the 2nd moment + estimates. epsilon: A small constant for numerical stability. - lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. - Please see `optimization_parameters.proto` for details. 
+ lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. See + `optimization_parameters.proto` for details. sum_inside_sqrt: This improves training speed. Please see `optimization_parameters.proto` for details. use_gradient_accumulation: setting this to `False` makes embedding gradients calculation less accurate but faster. Please see `optimization_parameters.proto` for details. - for details. clip_weight_min: the minimum value to clip by; None means -infinity. clip_weight_max: the maximum value to clip by; None means +infinity. weight_decay_factor: amount of weight decay to apply; None means that the @@ -656,19 +653,18 @@ class FtrlParameters(_OptimizationParameters): Args: learning_rate: a floating point value. The learning rate. learning_rate_power: A float value, must be less or equal to zero. - Controls how the learning rate decreases during training. Use zero for - a fixed learning rate. See section 3.1 in the + Controls how the learning rate decreases during training. Use zero for a + fixed learning rate. See section 3.1 in the [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf). - initial_accumulator_value: The starting value for accumulators. - Only zero or positive values are allowed. - l1_regularization_strength: A float value, must be greater than or - equal to zero. - l2_regularization_strength: A float value, must be greater than or - equal to zero. + initial_accumulator_value: The starting value for accumulators. Only zero + or positive values are allowed. + l1_regularization_strength: A float value, must be greater than or equal + to zero. + l2_regularization_strength: A float value, must be greater than or equal + to zero. use_gradient_accumulation: setting this to `False` makes embedding gradients calculation less accurate but faster. Please see `optimization_parameters.proto` for details. - for details. clip_weight_min: the minimum value to clip by; None means -infinity. clip_weight_max: the maximum value to clip by; None means +infinity. weight_decay_factor: amount of weight decay to apply; None means that the @@ -728,13 +724,15 @@ class ProximalYogiParameters(_OptimizationParameters): """Optimization parameters for Proximal Yogi with TPU embeddings. Implements the Yogi optimizer as described in - [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization). + [Adaptive Methods for Nonconvex + Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization). Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the `optimization_parameters` argument to set the optimizer and its parameters. See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` for more details. """ + # pylint: enable=line-too-long def __init__( @@ -1140,6 +1138,7 @@ class TPUEmbedding(object): cluster_def=None, pipeline_execution_with_tensor_core=False, partition_strategy='div', + profile_data_directory=None, device_config=None, master_job_name=None): """API for using TPU for embedding lookups. @@ -1166,6 +1165,21 @@ class TPUEmbedding(object): partition_strategy: A string, either 'mod' or 'div', specifying how to map the lookup id to the embedding tensor. For more information see `tf.nn.embedding_lookup_sparse`. + profile_data_directory: Directory where embedding lookup statistics are + stored. 
These statistics summarize information about the inputs to the + embedding lookup operation, in particular, the average number of + embedding IDs per example and how well the embedding IDs are load + balanced across the system. The lookup statistics are used during TPU + initialization for embedding table partitioning. Collection of lookup + statistics is done at runtime by profiling the embedding inputs: only + 3% of input samples are profiled to minimize host CPU overhead. Once + a suitable number of samples are profiled, the lookup statistics are + saved to table-specific files in the profile data directory generally + at the end of a TPU training loop. The filename corresponding to each + table is obtained by hashing table specific parameters (e.g., table + name and number of features) and global configuration parameters (e.g., + sharding strategy and task count). The same profile data directory can + be shared among several models to reuse embedding lookup statistics. device_config: A DeviceConfig instance, used when `master` and `cluster_def` are both `None`. master_job_name: if set, overrides the master job name used to schedule @@ -1179,6 +1193,8 @@ class TPUEmbedding(object): 'Invalid partition_strategy {}'.format(partition_strategy)) self._partition_strategy = partition_strategy + self._profile_data_directory = profile_data_directory + _validate_table_to_config_dict(table_to_config_dict) # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`. self._table_to_config_dict = _create_ordered_dict(table_to_config_dict) @@ -1220,14 +1236,14 @@ class TPUEmbedding(object): self._num_hosts = tpu_system_metadata.num_hosts if master_job_name is None: try: - master_job_name = tpu_system_metadata_lib.master_job(master, - cluster_def) + master_job_name = tpu_system_metadata_lib.master_job( + master, cluster_def) except ValueError as e: raise ValueError(str(e) + ' Please specify a master_job_name.') self._hosts = [] for device in tpu_system_metadata.devices: - if 'device:CPU:' in device.name and ( - master_job_name is None or master_job_name in device.name): + if 'device:CPU:' in device.name and (master_job_name is None or + master_job_name in device.name): self._hosts.append(device.name) self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host self._num_cores = tpu_system_metadata.num_cores @@ -1244,11 +1260,10 @@ class TPUEmbedding(object): if optimization_parameters is not None: raise ValueError('`optimization_parameters` should be `None` ' 'for inference mode.') - self._optimization_parameters = ( - StochasticGradientDescentParameters(1.)) + self._optimization_parameters = (StochasticGradientDescentParameters(1.)) else: - raise ValueError('`mode` only supports {} and {}; got {}.' 
- .format(TRAINING, INFERENCE, mode)) + raise ValueError('`mode` only supports {} and {}; got {}.'.format( + TRAINING, INFERENCE, mode)) self._mode = mode # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler` @@ -1259,11 +1274,13 @@ class TPUEmbedding(object): self._pipeline_execution_with_tensor_core = ( pipeline_execution_with_tensor_core) - self._learning_rate_fn = list(set( - c.learning_rate_fn for c in self._table_to_config_dict.values() - if c.learning_rate_fn is not None)) + self._learning_rate_fn = list( + set(c.learning_rate_fn + for c in self._table_to_config_dict.values() + if c.learning_rate_fn is not None)) self._learning_rate_fn_to_tag = { - fn: id for id, fn in enumerate(self._learning_rate_fn)} + fn: id for id, fn in enumerate(self._learning_rate_fn) + } self._config_proto = self._create_config_proto() @@ -1403,10 +1420,13 @@ class TPUEmbedding(object): elc.TPUEmbeddingConfiguration.MOD) config_proto.pipeline_execution_with_tensor_core = ( self._pipeline_execution_with_tensor_core) + if self._profile_data_directory: + config_proto.profile_data_directory = self._profile_data_directory return config_proto - def create_variables_and_ops(self, embedding_variable_name_by_table=None, + def create_variables_and_ops(self, + embedding_variable_name_by_table=None, slot_variable_names_by_table=None): """Create embedding and slot variables, with ops to load and retrieve them. @@ -1425,8 +1445,8 @@ class TPUEmbedding(object): Args: embedding_variable_name_by_table: A dictionary mapping from string of - table name to string of embedding variable name. If `None`, - defaults from `get_default_slot_variable_names()` will be used. + table name to string of embedding variable name. If `None`, defaults + from `get_default_slot_variable_names()` will be used. slot_variable_names_by_table: A dictionary mapping from string of table name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If `None`, defaults from `get_default_slot_variable_names()` will be used. @@ -1510,8 +1530,7 @@ class TPUEmbedding(object): return retrieve_ops_list return VariablesAndOps(embedding_variables_by_table, - slot_variables_by_table, - load_ops, retrieve_ops) + slot_variables_by_table, load_ops, retrieve_ops) def generate_enqueue_ops( self, @@ -1522,10 +1541,9 @@ class TPUEmbedding(object): """Generate enqueue ops. Args: - enqueue_datas_list: a list of dictionary mapping from string - of feature names to EnqueueData. Each dictionary is for one - TPU core. Dictionaries for the same host should be contiguous - on the list. + enqueue_datas_list: a list of dictionary mapping from string of feature + names to EnqueueData. Each dictionary is for one TPU core. Dictionaries + for the same host should be contiguous in the list. mode_override: A string input that overrides the mode specified in the TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference', 'training', 'backward_pass_only'}. 
When set to @@ -1723,8 +1741,8 @@ class TPUEmbedding(object): if enqueue_data.sample_indices is not None else int_zeros) kwargs['aggregation_weights'].append( - enqueue_data.aggregation_weights if - enqueue_data.aggregation_weights is not None else float_zeros) + enqueue_data.aggregation_weights + if enqueue_data.aggregation_weights is not None else float_zeros) kwargs['embedding_indices'].append(enqueue_data.embedding_indices) @@ -1763,14 +1781,13 @@ class TPUEmbedding(object): feature_index = feature_index + 1 else: activations[feature] = ( - table_activations[:, feature_index:(feature_index+seq_length), :]) + table_activations[:, + feature_index:(feature_index + seq_length), :]) feature_index = feature_index + seq_length return activations - def generate_send_gradients_op(self, - feature_to_gradient_dict, - step=None): + def generate_send_gradients_op(self, feature_to_gradient_dict, step=None): """Send gradient to TPU embedding. Args: @@ -1786,8 +1803,8 @@ class TPUEmbedding(object): """ if self._mode != TRAINING: raise RuntimeError('Only in training mode gradients need to ' - 'be sent to TPU embedding; got mode {}.' - .format(self._mode)) + 'be sent to TPU embedding; got mode {}.'.format( + self._mode)) if step is None and self._learning_rate_fn: raise ValueError('There are dynamic learning rates but step is None.') @@ -1808,8 +1825,10 @@ class TPUEmbedding(object): return tpu_ops.send_tpu_embedding_gradients( inputs=gradients, - learning_rates=[math_ops.cast(fn(step), dtype=dtypes.float32) - for fn in self._learning_rate_fn], + learning_rates=[ + math_ops.cast(fn(step), dtype=dtypes.float32) + for fn in self._learning_rate_fn + ], config=self.config_proto.SerializeToString()) def _get_optimizer_handler_by_table(self): @@ -1835,21 +1854,21 @@ def _validate_table_to_config_dict(table_to_config_dict): def _validate_feature_to_config_dict(table_to_config_dict, feature_to_config_dict): """Validate `feature_to_config_dict`.""" - used_table_set = set([feature.table_id - for feature in feature_to_config_dict.values()]) + used_table_set = set( + [feature.table_id for feature in feature_to_config_dict.values()]) table_set = set(table_to_config_dict.keys()) unused_table_set = table_set - used_table_set if unused_table_set: - raise ValueError('`table_to_config_dict` specifies table that is not ' - 'used in `feature_to_config_dict`: {}.' - .format(unused_table_set)) + raise ValueError( + '`table_to_config_dict` specifies table that is not ' + 'used in `feature_to_config_dict`: {}.'.format(unused_table_set)) extra_table_set = used_table_set - table_set if extra_table_set: - raise ValueError('`feature_to_config_dict` refers to a table that is not ' - 'specified in `table_to_config_dict`: {}.' - .format(extra_table_set)) + raise ValueError( + '`feature_to_config_dict` refers to a table that is not ' + 'specified in `table_to_config_dict`: {}.'.format(extra_table_set)) def _validate_batch_size(batch_size, num_cores): @@ -1867,10 +1886,9 @@ def _validate_optimization_parameters(optimization_parameters, Args: optimization_parameters: global optimizer provided in `TPUEmbedding` - constructor. + constructor. table_to_config_dict: A dictionary mapping from string of table name to `TableConfig`. 
- """ tbl_optimizer_missing = False for _, table_config in table_to_config_dict.items(): @@ -2107,8 +2125,7 @@ class _AdamHandler(_OptimizerHandler): load_op_list = [] config = config_proto for host_id, table_variable, m_variable, v_variable in (zip( - range(num_hosts), table_variables, - m_variables, v_variables)): + range(num_hosts), table_variables, m_variables, v_variables)): with ops.colocate_with(table_variable): load_parameters_op = ( tpu_ops.load_tpu_embedding_adam_parameters( @@ -2134,8 +2151,7 @@ class _AdamHandler(_OptimizerHandler): retrieve_op_list = [] config = config_proto for host_id, table_variable, m_variable, v_variable in (zip( - range(num_hosts), table_variables, - m_variables, v_variables)): + range(num_hosts), table_variables, m_variables, v_variables)): with ops.colocate_with(table_variable): retrieved_table, retrieved_m, retrieved_v = ( tpu_ops.retrieve_tpu_embedding_adam_parameters( @@ -2174,8 +2190,9 @@ class _FtrlHandler(_OptimizerHandler): def get_default_slot_variable_names(self, table): # These match the default slot variable names created by # tf.train.FtrlOptimizer. - return FtrlSlotVariableName('{}/{}'.format(table, 'Ftrl'), # accumulator - '{}/{}'.format(table, 'Ftrl_1')) # linear + return FtrlSlotVariableName( + '{}/{}'.format(table, 'Ftrl'), # accumulator + '{}/{}'.format(table, 'Ftrl_1')) # linear def create_variables_and_ops(self, table, slot_variable_names, num_hosts, table_config, table_variables, config_proto): @@ -2197,8 +2214,7 @@ class _FtrlHandler(_OptimizerHandler): embedding_dimension=table_config.dimension, collections=[ops.GraphKeys.GLOBAL_VARIABLES], initializer=linear_initializer) - slot_variables = FtrlSlotVariable(accumulator_variables, - linear_variables) + slot_variables = FtrlSlotVariable(accumulator_variables, linear_variables) def load_ops_fn(): """Returns the retrieve ops for Ftrl embedding tables. @@ -2539,8 +2555,7 @@ class _StochasticGradientDescentHandler(_OptimizerHandler): """ load_op_list = [] config = config_proto - for host_id, table_variable in (zip( - range(num_hosts), table_variables)): + for host_id, table_variable in enumerate (table_variables): with ops.colocate_with(table_variable): load_parameters_op = ( tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters( @@ -2561,8 +2576,7 @@ class _StochasticGradientDescentHandler(_OptimizerHandler): """ retrieve_op_list = [] config = config_proto - for host_id, table_variable in (zip( - range(num_hosts), table_variables)): + for host_id, table_variable in enumerate (table_variables): with ops.colocate_with(table_variable): retrieved_table = ( tpu_ops diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.experimental.-embedding-config-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.experimental.-embedding-config-spec.pbtxt index 355c57269fd..ebcf27eea53 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.experimental.-embedding-config-spec.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.experimental.-embedding-config-spec.pbtxt @@ -31,6 +31,10 @@ tf_class { name: "pipeline_execution_with_tensor_core" mtype: "" } + member { + name: "profile_data_directory" + mtype: "" + } member { name: "table_to_config_dict" mtype: ""