[Cherrypick:r2.4] Add TPU embedding profile data directory to both TPU estimator and TPU embedding.
parent 85c8b2a817
commit f5743fefbd
@@ -87,6 +87,23 @@ message TPUEmbeddingConfiguration {
  // problem.
  bool pipeline_execution_with_tensor_core = 7;

  // Directory where embedding lookup statistics are stored. These statistics
  // summarize information about the inputs to the embedding lookup
  // operation, in particular, the average number of embedding IDs per example
  // and how well the embedding IDs are load balanced across the system. The
  // lookup statistics are used during TPU initialization for embedding table
  // partitioning. Collection of lookup statistics is done at runtime by
  // profiling the embedding inputs: only 3% of input samples are profiled to
  // minimize host CPU overhead. Once a suitable number of samples are
  // profiled, the lookup statistics are saved to table-specific files in the
  // profile data directory, generally at the end of a TPU training loop. The
  // filename corresponding to each table is obtained by hashing table-specific
  // parameters (e.g., table name and number of features) and global
  // configuration parameters (e.g., sharding strategy and TPU worker task
  // count). The same profile data directory can be shared among several
  // models to reuse embedding lookup statistics.
  string profile_data_directory = 9;

  // Extended output layout information; deprecated and now ignored.
  TPUEmbeddingOutputLayout output_layout = 8 [deprecated = true];
}
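The new `profile_data_directory` field can be set directly on the generated configuration proto. A minimal sketch (editor's illustration, not part of the commit), assuming the standard `tpu_embedding_configuration_pb2` Python bindings and a made-up output path:

# Sketch: populate the new field on the generated proto. The import alias and
# the directory value are assumptions for illustration.
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc

config = elc.TPUEmbeddingConfiguration()
config.pipeline_execution_with_tensor_core = True
# Several models may point at the same directory to reuse cached lookup stats.
config.profile_data_directory = '/tmp/tpu_embedding_profiles'
print(config)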
@@ -52,8 +52,13 @@ INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE
# as AdagradParameters etc instead of learning_rate.
class TableConfig(
    collections.namedtuple('TableConfig', [
        'vocabulary_size', 'dimension', 'initializer', 'combiner',
        'hot_id_replication', 'learning_rate', 'learning_rate_fn',
        'vocabulary_size',
        'dimension',
        'initializer',
        'combiner',
        'hot_id_replication',
        'learning_rate',
        'learning_rate_fn',
        'optimization_parameters',
    ])):
  """Embedding table configuration."""
@@ -85,16 +90,16 @@ class TableConfig(
      hot_id_replication: If true, enables hot id replication, which can make
        embedding lookups faster if there are some hot rows in the table.
      learning_rate: float, static learning rate for this table. If
        learning_rate and learning_rate_fn are both `None`, static learning
        rate as specified in local `optimization_parameters` will be used.
        In case local `optimization_parameters` is `None`, global
        learning_rate and learning_rate_fn are both `None`, static learning rate
        as specified in local `optimization_parameters` will be used. In case
        local `optimization_parameters` is `None`, global
        `optimization_parameters` in `TPUEmbedding` constructor will be used.
        `learning_rate_fn` must be `None` if `learning_rate` is not `None`.
      learning_rate_fn: string, use dynamic learning rate given by the function.
        This function will be passed the current global step. If
        learning_rate and learning_rate_fn are both `None`, static
        learning rate as specified in `optimization_parameters` is used.
        `learning_rate` must be `None` if `learning_rate_fn` is not `None`.
        learning_rate and learning_rate_fn are both `None`, static learning rate
        as specified in `optimization_parameters` is used. `learning_rate` must
        be `None` if `learning_rate_fn` is not `None`.
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `StochasticGradientDescentParameters`. Specifies table level optimizer.
        If it's `None`, the global optimizer in `TPUEmbedding` constructor is used.
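To make the `learning_rate` / `learning_rate_fn` exclusivity concrete, here is a hedged sketch (editor's illustration) of a table driven by a dynamic learning rate. It assumes the module is importable as `tensorflow.python.tpu.tpu_embedding`; the schedule and sizes are made up.

import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_embedding

def decayed_lr(global_step):
  # The TPU embedding layer passes the current global step to this function.
  return tf.train.exponential_decay(
      0.1, global_step, decay_steps=10000, decay_rate=0.9)

# learning_rate stays None because learning_rate_fn is set; at most one of the
# two may be provided.
video_table = tpu_embedding.TableConfig(
    vocabulary_size=1000000,
    dimension=64,
    initializer=tf.truncated_normal_initializer(stddev=0.02),
    combiner='mean',
    learning_rate_fn=decayed_lr)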
@@ -127,8 +132,8 @@ class TableConfig(

    if learning_rate is not None and learning_rate_fn is not None:
      raise ValueError('At most one of learning_rate and learning_rate_fn '
                       'can be None; got {} and {}'
                       .format(learning_rate, learning_rate_fn))
                       'can be None; got {} and {}'.format(
                           learning_rate, learning_rate_fn))

    if optimization_parameters is not None:
      if not isinstance(optimization_parameters, _OptimizationParameters):
@@ -144,15 +149,11 @@ class TableConfig(


class FeatureConfig(
    collections.namedtuple(
        'FeatureConfig',
        ['table_id', 'max_sequence_length', 'weight_key'])):
    collections.namedtuple('FeatureConfig',
                           ['table_id', 'max_sequence_length', 'weight_key'])):
  """Feature configuration."""

  def __new__(cls,
              table_id,
              max_sequence_length=0,
              weight_key=None):
  def __new__(cls, table_id, max_sequence_length=0, weight_key=None):
    """Feature configuration.

    Args:
@@ -171,8 +172,8 @@ class FeatureConfig(
      ValueError: if `max_sequence_length` is not a non-negative integer.
    """
    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
      raise ValueError('Invalid max_sequence_length {}.'.format(
          max_sequence_length))
      raise ValueError(
          'Invalid max_sequence_length {}.'.format(max_sequence_length))

    return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length,
                                             weight_key)
@@ -191,19 +192,19 @@ class EnqueueData(
    """Data to be enqueued through generate_enqueue_ops().

    Args:
      embedding_indices: A rank 1 Tensors, indices into the embedding tables. It
      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
        corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
        and int64 are allowed and will be converted to int32 internally.
      sample_indices: A rank 2 Tensors specifying the training example to which
      sample_indices: A rank 2 Tensor specifying the training example to which
        the corresponding embedding_indices and aggregation_weights values
        belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
        If it is None, we assume each embedding_indices belongs to a different
        sample. Both int32 and int64 are allowed and will be converted to int32
        internally.
      aggregation_weights: A rank 1 Tensors containing aggregation weights.
        It corresponds to sp_weights.values in embedding_lookup_sparse(). If it
        is None, we assume all weights are 1. Both float32 and float64 are
        allowed and will be converted to float32 internally.
      aggregation_weights: A rank 1 Tensor containing aggregation weights. It
        corresponds to sp_weights.values in embedding_lookup_sparse(). If it is
        None, we assume all weights are 1. Both float32 and float64 are allowed
        and will be converted to float32 internally.

    Returns:
      An EnqueueData tuple.
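As a concrete illustration of the three fields, a hedged sketch (editor's illustration) that unpacks a `tf.SparseTensor` of lookup ids into an `EnqueueData` tuple; the ids and shapes are arbitrary.

import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_embedding

sp_ids = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],  # (example, position) pairs
    values=[12, 7, 42],                # embedding row ids
    dense_shape=[2, 2])

watched_videos = tpu_embedding.EnqueueData(
    embedding_indices=sp_ids.values,   # rank 1 ids
    sample_indices=sp_ids.indices,     # rank 2 example mapping
    aggregation_weights=None)          # None means all weights are 1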
@@ -310,11 +311,11 @@ def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
  return enqueue_datas_list


AdamSlotVariableNames = collections.namedtuple(
    'AdamSlotVariableNames', ['m', 'v'])
AdamSlotVariableNames = collections.namedtuple('AdamSlotVariableNames',
                                               ['m', 'v'])

AdagradSlotVariableName = collections.namedtuple(
    'AdagradSlotVariableName', ['accumulator'])
AdagradSlotVariableName = collections.namedtuple('AdagradSlotVariableName',
                                                 ['accumulator'])

MomentumSlotVariableName = collections.namedtuple('MomentumSlotVariableName',
                                                  ['momenta'])
@@ -325,11 +326,10 @@ RMSPropSlotVariableNames = collections.namedtuple('RMSPropSlotVariableNames',
ProximalAdagradSlotVariableName = collections.namedtuple(
    'ProximalAdagradSlotVariableName', ['accumulator'])

FtrlSlotVariableName = collections.namedtuple(
    'FtrlSlotVariableName', ['accumulator', 'linear'])
FtrlSlotVariableName = collections.namedtuple('FtrlSlotVariableName',
                                              ['accumulator', 'linear'])

ProximalYogiSlotVariableNames = collections.namedtuple(
    'ProximalYogiSlotVariableNames', ['v', 'm'])
ProximalYogiSlotVariableNames = collections.namedtuple('ProximalYogiSlotVariableNames', ['v', 'm'])

AdamSlotVariables = collections.namedtuple(
    'AdamSlotVariables', ['m', 'v'])
@@ -340,22 +340,21 @@ MomentumSlotVariable = collections.namedtuple('MomentumSlotVariable',
RMSPropSlotVariables = collections.namedtuple('RMSPropSlotVariables',
                                              ['ms', 'mom'])

AdagradSlotVariable = collections.namedtuple(
    'AdagradSlotVariable', ['accumulator'])
AdagradSlotVariable = collections.namedtuple('AdagradSlotVariable',
                                             ['accumulator'])

ProximalAdagradSlotVariable = collections.namedtuple(
    'ProximalAdagradSlotVariable', ['accumulator'])

FtrlSlotVariable = collections.namedtuple(
    'FtrlSlotVariable', ['accumulator', 'linear'])
FtrlSlotVariable = collections.namedtuple('FtrlSlotVariable',
                                          ['accumulator', 'linear'])

ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables',
                                                   ['v', 'm'])

VariablesAndOps = collections.namedtuple(
    'VariablesAndOps',
    ['embedding_variables_by_table', 'slot_variables_by_table',
     'load_ops', 'retrieve_ops']
VariablesAndOps = collections.namedtuple('VariablesAndOps', [
    'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops',
    'retrieve_ops']
)

@@ -424,7 +423,6 @@ class AdagradParameters(_OptimizationParameters):
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
        for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
@@ -560,19 +558,18 @@ class AdamParameters(_OptimizationParameters):

    Args:
      learning_rate: a floating point value. The learning rate.
      beta1: A float value.
        The exponential decay rate for the 1st moment estimates.
      beta2: A float value.
        The exponential decay rate for the 2nd moment estimates.
      beta1: A float value. The exponential decay rate for the 1st moment
        estimates.
      beta2: A float value. The exponential decay rate for the 2nd moment
        estimates.
      epsilon: A small constant for numerical stability.
      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster.
        Please see `optimization_parameters.proto` for details.
      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. See
        `optimization_parameters.proto` for details.
      sum_inside_sqrt: This improves training speed. Please see
        `optimization_parameters.proto` for details.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
        for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
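A hedged sketch (editor's illustration) of how these parameters are typically supplied, here as a table-level optimizer that overrides the global one; the hyperparameters and table sizes are arbitrary.

from tensorflow.python.tpu import tpu_embedding

adam = tpu_embedding.AdamParameters(
    learning_rate=0.01,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-8,
    lazy_adam=True,                  # faster variant, see the note above
    use_gradient_accumulation=True)  # more accurate gradients, a bit slower

clicks_table = tpu_embedding.TableConfig(
    vocabulary_size=50000,
    dimension=32,
    optimization_parameters=adam)    # table-level optimizer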
@@ -656,19 +653,18 @@ class FtrlParameters(_OptimizationParameters):
    Args:
      learning_rate: a floating point value. The learning rate.
      learning_rate_power: A float value, must be less or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate. See section 3.1 in the
        Controls how the learning rate decreases during training. Use zero for a
        fixed learning rate. See section 3.1 in the
        [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
      initial_accumulator_value: The starting value for accumulators.
        Only zero or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or
        equal to zero.
      l2_regularization_strength: A float value, must be greater than or
        equal to zero.
      initial_accumulator_value: The starting value for accumulators. Only zero
        or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
        for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
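Similarly for FTRL, a hedged sketch (editor's illustration) using only the arguments documented above; the values are illustrative.

from tensorflow.python.tpu import tpu_embedding

ftrl = tpu_embedding.FtrlParameters(
    learning_rate=0.05,
    learning_rate_power=-0.5,          # must be <= 0; 0 means a fixed rate
    initial_accumulator_value=0.1,     # zero or positive
    l1_regularization_strength=0.001,
    l2_regularization_strength=0.001)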
@@ -728,13 +724,15 @@ class ProximalYogiParameters(_OptimizationParameters):
  """Optimization parameters for Proximal Yogi with TPU embeddings.

  Implements the Yogi optimizer as described in
  [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization).
  [Adaptive Methods for Nonconvex
  Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization).

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.
  """

  # pylint: enable=line-too-long

  def __init__(
@@ -1140,6 +1138,7 @@ class TPUEmbedding(object):
               cluster_def=None,
               pipeline_execution_with_tensor_core=False,
               partition_strategy='div',
               profile_data_directory=None,
               device_config=None,
               master_job_name=None):
    """API for using TPU for embedding lookups.
@@ -1166,6 +1165,21 @@ class TPUEmbedding(object):
      partition_strategy: A string, either 'mod' or 'div', specifying how to map
        the lookup id to the embedding tensor. For more information see
        `tf.nn.embedding_lookup_sparse`.
      profile_data_directory: Directory where embedding lookup statistics are
        stored. These statistics summarize information about the inputs to the
        embedding lookup operation, in particular, the average number of
        embedding IDs per example and how well the embedding IDs are load
        balanced across the system. The lookup statistics are used during TPU
        initialization for embedding table partitioning. Collection of lookup
        statistics is done at runtime by profiling the embedding inputs: only
        3% of input samples are profiled to minimize host CPU overhead. Once
        a suitable number of samples are profiled, the lookup statistics are
        saved to table-specific files in the profile data directory, generally
        at the end of a TPU training loop. The filename corresponding to each
        table is obtained by hashing table-specific parameters (e.g., table
        name and number of features) and global configuration parameters (e.g.,
        sharding strategy and task count). The same profile data directory can
        be shared among several models to reuse embedding lookup statistics.
      device_config: A DeviceConfig instance, used when `master` and
        `cluster_def` are both `None`.
      master_job_name: if set, overrides the master job name used to schedule
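Putting the new constructor argument in context, a hedged sketch (editor's illustration) of the low-level API. It reuses `video_table` from the earlier sketch; the master address and GCS path are made up, and constructing the object requires a reachable TPU worker.

from tensorflow.python.tpu import tpu_embedding

embedding = tpu_embedding.TPUEmbedding(
    table_to_config_dict={'video': video_table},
    feature_to_config_dict={'watched': tpu_embedding.FeatureConfig('video')},
    batch_size=128,                          # global batch size
    mode=tpu_embedding.TRAINING,
    master='grpc://10.0.0.2:8470',           # hypothetical TPU master
    optimization_parameters=tpu_embedding.StochasticGradientDescentParameters(
        learning_rate=0.05),
    profile_data_directory='gs://my-bucket/tpu_embedding_profiles')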
@@ -1179,6 +1193,8 @@ class TPUEmbedding(object):
          'Invalid partition_strategy {}'.format(partition_strategy))
    self._partition_strategy = partition_strategy

    self._profile_data_directory = profile_data_directory

    _validate_table_to_config_dict(table_to_config_dict)
    # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
    self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
@@ -1220,14 +1236,14 @@ class TPUEmbedding(object):
    self._num_hosts = tpu_system_metadata.num_hosts
    if master_job_name is None:
      try:
        master_job_name = tpu_system_metadata_lib.master_job(master,
                                                             cluster_def)
        master_job_name = tpu_system_metadata_lib.master_job(
            master, cluster_def)
      except ValueError as e:
        raise ValueError(str(e) + ' Please specify a master_job_name.')
    self._hosts = []
    for device in tpu_system_metadata.devices:
      if 'device:CPU:' in device.name and (
          master_job_name is None or master_job_name in device.name):
      if 'device:CPU:' in device.name and (master_job_name is None or
                                           master_job_name in device.name):
        self._hosts.append(device.name)
    self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host
    self._num_cores = tpu_system_metadata.num_cores
@@ -1244,11 +1260,10 @@ class TPUEmbedding(object):
      if optimization_parameters is not None:
        raise ValueError('`optimization_parameters` should be `None` '
                         'for inference mode.')
      self._optimization_parameters = (
          StochasticGradientDescentParameters(1.))
      self._optimization_parameters = (StochasticGradientDescentParameters(1.))
    else:
      raise ValueError('`mode` only supports {} and {}; got {}.'
                       .format(TRAINING, INFERENCE, mode))
      raise ValueError('`mode` only supports {} and {}; got {}.'.format(
          TRAINING, INFERENCE, mode))
    self._mode = mode

    # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler`
@@ -1259,11 +1274,13 @@ class TPUEmbedding(object):

    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)
    self._learning_rate_fn = list(set(
        c.learning_rate_fn for c in self._table_to_config_dict.values()
        if c.learning_rate_fn is not None))
    self._learning_rate_fn = list(
        set(c.learning_rate_fn
            for c in self._table_to_config_dict.values()
            if c.learning_rate_fn is not None))
    self._learning_rate_fn_to_tag = {
        fn: id for id, fn in enumerate(self._learning_rate_fn)}
        fn: id for id, fn in enumerate(self._learning_rate_fn)
    }

    self._config_proto = self._create_config_proto()

@@ -1403,10 +1420,13 @@ class TPUEmbedding(object):
          elc.TPUEmbeddingConfiguration.MOD)
    config_proto.pipeline_execution_with_tensor_core = (
        self._pipeline_execution_with_tensor_core)
    if self._profile_data_directory:
      config_proto.profile_data_directory = self._profile_data_directory

    return config_proto

  def create_variables_and_ops(self, embedding_variable_name_by_table=None,
  def create_variables_and_ops(self,
                               embedding_variable_name_by_table=None,
                               slot_variable_names_by_table=None):
    """Create embedding and slot variables, with ops to load and retrieve them.

@@ -1425,8 +1445,8 @@ class TPUEmbedding(object):

    Args:
      embedding_variable_name_by_table: A dictionary mapping from string of
        table name to string of embedding variable name. If `None`,
        defaults from `get_default_slot_variable_names()` will be used.
        table name to string of embedding variable name. If `None`, defaults
        from `get_default_slot_variable_names()` will be used.
      slot_variable_names_by_table: A dictionary mapping from string of table
        name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If
        `None`, defaults from `get_default_slot_variable_names()` will be used.
@@ -1510,8 +1530,7 @@ class TPUEmbedding(object):
      return retrieve_ops_list

    return VariablesAndOps(embedding_variables_by_table,
                           slot_variables_by_table,
                           load_ops, retrieve_ops)
                           slot_variables_by_table, load_ops, retrieve_ops)

  def generate_enqueue_ops(
      self,
@@ -1522,10 +1541,9 @@ class TPUEmbedding(object):
    """Generate enqueue ops.

    Args:
      enqueue_datas_list: a list of dictionary mapping from string
        of feature names to EnqueueData. Each dictionary is for one
        TPU core. Dictionaries for the same host should be contiguous
        on the list.
      enqueue_datas_list: a list of dictionary mapping from string of feature
        names to EnqueueData. Each dictionary is for one TPU core. Dictionaries
        for the same host should be contiguous in the list.
      mode_override: A string input that overrides the mode specified in the
        TPUEmbeddingConfiguration. Supported values are {'unspecified',
        'inference', 'training', 'backward_pass_only'}. When set to
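To illustrate the expected shape of `enqueue_datas_list`, a hedged sketch (editor's illustration) building on the earlier `embedding` and `watched_videos` sketches: one dict per TPU core, with dicts for cores on the same host adjacent in the list. A real input pipeline would give each core its own tensors rather than reusing one EnqueueData object.

# Eight cores on a single host, so the per-core dicts are contiguous.
enqueue_datas_list = [{'watched': watched_videos} for _ in range(8)]
enqueue_ops = embedding.generate_enqueue_ops(enqueue_datas_list)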
@@ -1723,8 +1741,8 @@ class TPUEmbedding(object):
          if enqueue_data.sample_indices is not None else int_zeros)

      kwargs['aggregation_weights'].append(
          enqueue_data.aggregation_weights if
          enqueue_data.aggregation_weights is not None else float_zeros)
          enqueue_data.aggregation_weights
          if enqueue_data.aggregation_weights is not None else float_zeros)

      kwargs['embedding_indices'].append(enqueue_data.embedding_indices)

@@ -1763,14 +1781,13 @@ class TPUEmbedding(object):
        feature_index = feature_index + 1
      else:
        activations[feature] = (
            table_activations[:, feature_index:(feature_index+seq_length), :])
            table_activations[:,
                              feature_index:(feature_index + seq_length), :])
        feature_index = feature_index + seq_length

    return activations

  def generate_send_gradients_op(self,
                                 feature_to_gradient_dict,
                                 step=None):
  def generate_send_gradients_op(self, feature_to_gradient_dict, step=None):
    """Send gradient to TPU embedding.

    Args:
@@ -1786,8 +1803,8 @@ class TPUEmbedding(object):
    """
    if self._mode != TRAINING:
      raise RuntimeError('Only in training mode gradients need to '
                         'be sent to TPU embedding; got mode {}.'
                         .format(self._mode))
                         'be sent to TPU embedding; got mode {}.'.format(
                             self._mode))
    if step is None and self._learning_rate_fn:
      raise ValueError('There are dynamic learning rates but step is None.')

@@ -1808,8 +1825,10 @@ class TPUEmbedding(object):

    return tpu_ops.send_tpu_embedding_gradients(
        inputs=gradients,
        learning_rates=[math_ops.cast(fn(step), dtype=dtypes.float32)
                        for fn in self._learning_rate_fn],
        learning_rates=[
            math_ops.cast(fn(step), dtype=dtypes.float32)
            for fn in self._learning_rate_fn
        ],
        config=self.config_proto.SerializeToString())

  def _get_optimizer_handler_by_table(self):
@@ -1835,21 +1854,21 @@ def _validate_table_to_config_dict(table_to_config_dict):
def _validate_feature_to_config_dict(table_to_config_dict,
                                     feature_to_config_dict):
  """Validate `feature_to_config_dict`."""
  used_table_set = set([feature.table_id
                        for feature in feature_to_config_dict.values()])
  used_table_set = set(
      [feature.table_id for feature in feature_to_config_dict.values()])
  table_set = set(table_to_config_dict.keys())

  unused_table_set = table_set - used_table_set
  if unused_table_set:
    raise ValueError('`table_to_config_dict` specifies table that is not '
                     'used in `feature_to_config_dict`: {}.'
                     .format(unused_table_set))
    raise ValueError(
        '`table_to_config_dict` specifies table that is not '
        'used in `feature_to_config_dict`: {}.'.format(unused_table_set))

  extra_table_set = used_table_set - table_set
  if extra_table_set:
    raise ValueError('`feature_to_config_dict` refers to a table that is not '
                     'specified in `table_to_config_dict`: {}.'
                     .format(extra_table_set))
    raise ValueError(
        '`feature_to_config_dict` refers to a table that is not '
        'specified in `table_to_config_dict`: {}.'.format(extra_table_set))


def _validate_batch_size(batch_size, num_cores):
@@ -1867,10 +1886,9 @@ def _validate_optimization_parameters(optimization_parameters,

  Args:
    optimization_parameters: global optimizer provided in `TPUEmbedding`
        constructor.
      constructor.
    table_to_config_dict: A dictionary mapping from string of table name to
      `TableConfig`.

  """
  tbl_optimizer_missing = False
  for _, table_config in table_to_config_dict.items():
@@ -2107,8 +2125,7 @@ class _AdamHandler(_OptimizerHandler):
      load_op_list = []
      config = config_proto
      for host_id, table_variable, m_variable, v_variable in (zip(
          range(num_hosts), table_variables,
          m_variables, v_variables)):
          range(num_hosts), table_variables, m_variables, v_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_adam_parameters(
@@ -2134,8 +2151,7 @@ class _AdamHandler(_OptimizerHandler):
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable, m_variable, v_variable in (zip(
          range(num_hosts), table_variables,
          m_variables, v_variables)):
          range(num_hosts), table_variables, m_variables, v_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_m, retrieved_v = (
              tpu_ops.retrieve_tpu_embedding_adam_parameters(
@@ -2174,8 +2190,9 @@ class _FtrlHandler(_OptimizerHandler):
  def get_default_slot_variable_names(self, table):
    # These match the default slot variable names created by
    # tf.train.FtrlOptimizer.
    return FtrlSlotVariableName('{}/{}'.format(table, 'Ftrl'),  # accumulator
                                '{}/{}'.format(table, 'Ftrl_1'))  # linear
    return FtrlSlotVariableName(
        '{}/{}'.format(table, 'Ftrl'),  # accumulator
        '{}/{}'.format(table, 'Ftrl_1'))  # linear

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
@@ -2197,8 +2214,7 @@ class _FtrlHandler(_OptimizerHandler):
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=linear_initializer)
    slot_variables = FtrlSlotVariable(accumulator_variables,
                                      linear_variables)
    slot_variables = FtrlSlotVariable(accumulator_variables, linear_variables)

    def load_ops_fn():
      """Returns the retrieve ops for Ftrl embedding tables.
@@ -2539,8 +2555,7 @@ class _StochasticGradientDescentHandler(_OptimizerHandler):
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable in (zip(
          range(num_hosts), table_variables)):
      for host_id, table_variable in enumerate(table_variables):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters(
@@ -2561,8 +2576,7 @@ class _StochasticGradientDescentHandler(_OptimizerHandler):
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable in (zip(
          range(num_hosts), table_variables)):
      for host_id, table_variable in enumerate(table_variables):
        with ops.colocate_with(table_variable):
          retrieved_table = (
              tpu_ops
@@ -31,6 +31,10 @@ tf_class {
    name: "pipeline_execution_with_tensor_core"
    mtype: "<type \'property\'>"
  }
  member {
    name: "profile_data_directory"
    mtype: "<type \'property\'>"
  }
  member {
    name: "table_to_config_dict"
    mtype: "<type \'property\'>"