Update tpu_embedding to separate per-table, table-specific optimization parameters from global settings. Enable use of the tpu_embedding optimizer classes in tpu_estimator.
PiperOrigin-RevId: 235998702
parent 434bf84d50
commit 4c6563e4d8
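The net effect for estimator users: the scalar optimizer arguments that `EmbeddingConfigSpec` used to take (`learning_rate`, `optimizer_type`, `adagrad_initial_accumulator`, ...) are replaced by a single `optimization_parameters` object built from the `tpu_embedding` optimizer classes. A minimal sketch of the new call shape, using import paths that appear in this diff; the feature column and hyperparameter values are illustrative assumptions, not part of this change:

```python
# Illustrative sketch; column setup and hyperparameters are assumed, not part of this commit.
import tensorflow as tf
from tensorflow.python.tpu._tpu_estimator_embedding import EmbeddingConfigSpec
from tensorflow.python.tpu.tpu_embedding import AdagradParameters

# A plain numeric column is enough to exercise the spec here; real models would
# also pass TPU embedding columns.
columns = [tf.feature_column.numeric_column('price')]

# Old style (removed by this change):
#   EmbeddingConfigSpec(feature_columns=columns, learning_rate=0.05,
#                       optimizer_type='adagrad', adagrad_initial_accumulator=0.1)
# New style: one parameters object describes the embedding optimizer.
spec = EmbeddingConfigSpec(
    feature_columns=columns,
    optimization_parameters=AdagradParameters(
        learning_rate=0.05, initial_accumulator=0.1,
        use_gradient_accumulation=True))
```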
tensorflow/python/tpu
@@ -25,6 +25,9 @@ from tensorflow.python.feature_column import feature_column as core_fc
 from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
 from tensorflow.python.tpu import feature_column as tpu_fc
 from tensorflow.python.tpu import tpu_embedding
+from tensorflow.python.tpu.tpu_embedding import AdagradParameters
+from tensorflow.python.tpu.tpu_embedding import AdamParameters
+from tensorflow.python.tpu.tpu_embedding import StochasticGradientDescentParameters
 
 # pylint: disable=protected-access
 _TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn,
@@ -33,6 +36,8 @@ _EMBEDDING_COLUMN_CLASSES = (core_fc._EmbeddingColumn,
                              core_fc_lib.EmbeddingColumn,
                              core_fc._SharedEmbeddingColumn)
 _SUPPORTED_FEATURE_COLUMNS = (core_fc._NumericColumn, core_fc_lib.NumericColumn)
+_SUPPORTED_OPTIMIZERS = (AdagradParameters, AdamParameters,
+                         StochasticGradientDescentParameters)
 
 # pylint: enable=protected-access
 
@@ -139,68 +144,25 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns):
   return table_to_config, feature_to_table
 
 
-def _get_tpu_embedding_optimization_parameters(embedding_config_spec):
-  """Get tpu_embedding._OptimizationParameters from EmbeddingConfigSpec."""
-  if embedding_config_spec.optimizer_type == 'adagrad':
-    return tpu_embedding.AdagradParameters(
-        embedding_config_spec.learning_rate,
-        embedding_config_spec.adagrad_initial_accumulator,
-        embedding_config_spec.use_gradient_accumulation)
-  elif embedding_config_spec.optimizer_type == 'sgd':
-    return tpu_embedding.StochasticGradientDescentParameters(
-        embedding_config_spec.learning_rate,
-        embedding_config_spec.use_gradient_accumulation)
-  elif embedding_config_spec.optimizer_type == 'adam':
-    return tpu_embedding.AdamParameters(
-        embedding_config_spec.learning_rate,
-        embedding_config_spec.adam_parameters.beta1,
-        embedding_config_spec.adam_parameters.beta2,
-        embedding_config_spec.adam_parameters.epsilon,
-        use_gradient_accumulation=embedding_config_spec
-        .use_gradient_accumulation)
-  else:
-    raise ValueError('optimizer_type must be adagrad or sgd or adam for now.')
-
-
-AdamParameters = collections.namedtuple('AdamParameters',
-                                        ['beta1', 'beta2', 'epsilon'])
-
-
 # TODO(shizhiw): Improve the API to support more optimizer parameters in API.
 class EmbeddingConfigSpec(
     collections.namedtuple('EmbeddingConfigSpec', [
-        'feature_columns', 'learning_rate', 'optimizer_type',
-        'adagrad_initial_accumulator', 'clipping_limit',
-        'use_gradient_accumulation', 'adam_parameters'
+        'feature_columns', 'optimization_parameters', 'clipping_limit',
     ])):
   """Class to keep track of embedding config specification."""
 
   def __new__(cls,
               feature_columns,
-              learning_rate,
-              optimizer_type='adagrad',
-              adagrad_initial_accumulator=None,
-              clipping_limit=None,
-              use_gradient_accumulation=False,
-              adam_parameters=None):
+              optimization_parameters,
+              clipping_limit=None):
     """Creates an EmbeddingConfigSpec instance.
 
     Args:
       feature_columns: All `FeatureColumn`s used by model.
-      learning_rate: embedding optimizer learning rate.
-      optimizer_type: (String) Name of the optimizer for embedding gradients
-        updates. Must be either 'adagrad' (`tf.train.AdagradOptimizer`, default
-        value), 'sgd' (`tf.train.GradientDescentOptimizer`), or 'adam'
-        (`tf.contrib.opt.LazyAdamOptimizer`) for lazy Adam. This optimizer will
-        be applied to all embedding variables specified by `feature_columns`.
-      adagrad_initial_accumulator: Initial accumulator for Adagrad. Used when
-        optimizer_type is 'adagrad'. Default is `0.1`.
+      optimization_parameters: An instance of `AdagradParameters`,
+        `AdamParameters` or `StochasticGradientDescentParameters`. This
+        optimizer will be applied to all embedding variables specified by
+        `feature_columns`.
       clipping_limit: (Optional) Clipping limit (absolute value).
-      use_gradient_accumulation: (Experimental) Whether to accumulate the
-        gradients across TPU embedding mini-batches. Gradient accumulation does
-        not affect SGD and therefore this is applicable only for Adagrad.
-      adam_parameters: AdamParameters. Used when optimizer_type is 'adam'.
-        Default is 0.9 for beta1, 0.999 for beta2 and 1e-8 for epsilon.
 
     Returns:
       An EmbeddingConfigSpec instance.
@@ -210,9 +172,7 @@ class EmbeddingConfigSpec(
       TypeError: If the feature columns are not of the correct type (one of
         _SUPPORTED_FEATURE_COLUMNS, _TPU_EMBEDDING_COLUMN_CLASSES OR
         _EMBEDDING_COLUMN_CLASSES).
-      ValueError: If use_gradient_accumulation is True for SGD.
-      ValueError: If `optimizer_type` is not one of "adagrad" or "sgd" or
-        "adam".
+      ValueError: If `optimization_parameters` is not one of the required types.
     """
     if not feature_columns:
       raise ValueError('`feature_columns` cannot be `None` or empty.')
@@ -229,38 +189,16 @@ class EmbeddingConfigSpec(
             'All feature columns must be supported types in {}. Got {}'.format(
                 supported_classes, type(column)))
 
-    if optimizer_type == 'adagrad':
-      if adagrad_initial_accumulator is None:
-        adagrad_initial_accumulator = 0.1
-      if adagrad_initial_accumulator <= 0:
-        raise ValueError('Adagrad initial_accumulator must be positive')
-    elif optimizer_type == 'sgd':
-      if use_gradient_accumulation:
-        raise ValueError('Gradient accumulation makes sense for Adagrad only.')
-    elif optimizer_type == 'adam':
-      if adam_parameters is None:
-        adam_parameters = AdamParameters(0.9, 0.999, 1e-8)
-      if adam_parameters.beta1 < 0. or adam_parameters.beta1 >= 1.:
-        raise ValueError('beta1 must be between 0. and 1; got {}.'.format(
-            adam_parameters.beta1))
-      if adam_parameters.beta2 < 0. or adam_parameters.beta2 >= 1.:
-        raise ValueError('beta2 must be between 0. and 1; got {}.'.format(
-            adam_parameters.beta2))
-      if adam_parameters.epsilon <= 0.:
-        raise ValueError('epsilon must be positive; got {}.'.format(
-            adam_parameters.epsilon))
-    else:
-      raise ValueError('optimizer_type must be adagrad or sgd or adam for now.')
+    if not isinstance(optimization_parameters, _SUPPORTED_OPTIMIZERS):
+      raise ValueError('optimization_parameters must be an instance of type '
+                       '{}. Got {}.'.format(_SUPPORTED_OPTIMIZERS,
+                                            type(optimization_parameters)))
 
     return super(EmbeddingConfigSpec, cls).__new__(
         cls,
         feature_columns=feature_columns,
-        learning_rate=learning_rate,
-        optimizer_type=optimizer_type,
-        adagrad_initial_accumulator=adagrad_initial_accumulator,
-        clipping_limit=clipping_limit,
-        use_gradient_accumulation=use_gradient_accumulation,
-        adam_parameters=adam_parameters)
+        optimization_parameters=optimization_parameters,
+        clipping_limit=clipping_limit)
 
 
 class EmbeddingConfig(object):
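Since the per-optimizer field checks move into the parameter classes themselves, the spec now only type-checks `optimization_parameters`. A small sketch of the failure mode for the old calling convention (values illustrative):

```python
# Illustrative: passing a bare learning rate where a parameters object is now required.
import tensorflow as tf
from tensorflow.python.tpu._tpu_estimator_embedding import EmbeddingConfigSpec

try:
  EmbeddingConfigSpec(
      feature_columns=[tf.feature_column.numeric_column('price')],
      optimization_parameters=0.05)
except ValueError as e:
  print(e)  # optimization_parameters must be an instance of type (...). Got <class 'float'>.
```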
@@ -283,8 +221,6 @@ class EmbeddingConfig(object):
     self._table_to_config_dict, self._feature_to_table_dict = (
         get_tpu_embedding_config_from_feature_columns(
             embedding_config_spec.feature_columns))
-    self._optimization_parameters = _get_tpu_embedding_optimization_parameters(
-        self._embedding_config_spec)
     self._mode_to_tpu_embedding_dict = {}
     self.dummy_table_variables = None
 
@@ -317,7 +253,7 @@ class EmbeddingConfig(object):
           batch_size,
           tpu_embedding_mode,
           master,
-          self._optimization_parameters,
+          self._embedding_config_spec.optimization_parameters,
           cluster_def,
       )
       return tpu_embedding_
@@ -116,42 +116,33 @@ VariablesAndOps = collections.namedtuple(
 )
 
 
-# TODO(shizhiw): Factor `use_gradient_accumulation` and
-# `pipeline_execution_with_tensor_core` out of `_OptimizationParameters`.
 class _OptimizationParameters(object):
   """Parameters common to all optimizations."""
 
-  def __init__(self, learning_rate, use_gradient_accumulation,
-               pipeline_execution_with_tensor_core):
+  def __init__(self, learning_rate, use_gradient_accumulation):
     self.learning_rate = learning_rate
     self.use_gradient_accumulation = use_gradient_accumulation
-    self.pipeline_execution_with_tensor_core = (
-        pipeline_execution_with_tensor_core)
 
 
 class AdagradParameters(_OptimizationParameters):
   """Optimization parameters for Adagrad."""
 
-  def __init__(self, learning_rate, initial_accumulator,
-               use_gradient_accumulation=False,
-               pipeline_execution_with_tensor_core=True):
+  def __init__(self, learning_rate, initial_accumulator=0.1,
+               use_gradient_accumulation=True):
     """Optimization parameters for Adagrad.
 
     Args:
       learning_rate: used for updating embedding table.
       initial_accumulator: initial accumulator for Adagrad.
-      use_gradient_accumulation: setting this to `True` makes embedding
-        gradients calculation more accurate but slower. Please see
-        `optimization_parameters.proto` for details.
-        for details.
-      pipeline_execution_with_tensor_core: setting this to `True` makes training
-        faster, but trained model will be different if step N and step N+1
-        involve the same set of embedding ID. Please see
-        `tpu_embedding_configuration.proto` for details.
+      use_gradient_accumulation: setting this to `False` makes embedding
+        gradients calculation less accurate but faster. Please see
+        `optimization_parameters.proto` for details.
+        for details.
     """
     super(AdagradParameters, self).__init__(learning_rate,
-                                            use_gradient_accumulation,
-                                            pipeline_execution_with_tensor_core)
+                                            use_gradient_accumulation)
+    if initial_accumulator <= 0:
+      raise ValueError('Adagrad initial_accumulator must be positive')
     self.initial_accumulator = initial_accumulator
 
 
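`AdagradParameters` now carries the defaults and the accumulator check that previously lived in `EmbeddingConfigSpec`. A quick sketch (values illustrative):

```python
from tensorflow.python.tpu.tpu_embedding import AdagradParameters

# initial_accumulator defaults to 0.1 and gradient accumulation is now on by default.
adagrad = AdagradParameters(learning_rate=0.1)

try:
  AdagradParameters(learning_rate=0.1, initial_accumulator=0.0)
except ValueError as e:
  print(e)  # Adagrad initial_accumulator must be positive
```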
@@ -164,8 +155,7 @@ class AdamParameters(_OptimizationParameters):
                epsilon=1e-08,
                lazy_adam=True,
                sum_inside_sqrt=True,
-               use_gradient_accumulation=False,
-               pipeline_execution_with_tensor_core=True):
+               use_gradient_accumulation=True):
     """Optimization parameters for Adam.
 
     Args:
@@ -179,18 +169,23 @@ class AdamParameters(_OptimizationParameters):
         Please see `optimization_parameters.proto` for details.
       sum_inside_sqrt: This improves training speed. Please see
         `optimization_parameters.proto` for details.
-      use_gradient_accumulation: setting this to `True` makes embedding
-        gradients calculation more accurate but slower. Please see
+      use_gradient_accumulation: setting this to `False` makes embedding
+        gradients calculation less accurate but faster. Please see
         `optimization_parameters.proto` for details.
+        for details.
-      pipeline_execution_with_tensor_core: setting this to `True` makes training
-        faster, but trained model will be different if step N and step N+1
-        involve the same set of embedding ID. Please see
-        `tpu_embedding_configuration.proto` for details.
     """
     super(AdamParameters, self).__init__(learning_rate,
-                                         use_gradient_accumulation,
-                                         pipeline_execution_with_tensor_core)
+                                         use_gradient_accumulation)
+    if beta1 < 0. or beta1 >= 1.:
+      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
+    if beta2 < 0. or beta2 >= 1.:
+      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
+    if epsilon <= 0.:
+      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
+    if not use_gradient_accumulation and not lazy_adam:
+      raise ValueError(
+          'When disabling Lazy Adam, gradient accumulation must be used.')
 
     self.beta1 = beta1
     self.beta2 = beta2
     self.epsilon = epsilon
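Likewise, the Adam hyperparameter checks move into `AdamParameters`, plus a new constraint tying non-lazy Adam to gradient accumulation. A sketch; the `beta1`/`beta2` keyword names are taken from the validation code above, and the values are illustrative:

```python
from tensorflow.python.tpu.tpu_embedding import AdamParameters

adam = AdamParameters(learning_rate=1e-3)  # beta1/beta2/epsilon keep their defaults

try:
  AdamParameters(learning_rate=1e-3, beta2=1.5)
except ValueError as e:
  print(e)  # beta2 must be between 0. and 1; got 1.5.

try:
  AdamParameters(learning_rate=1e-3, lazy_adam=False, use_gradient_accumulation=False)
except ValueError as e:
  print(e)  # When disabling Lazy Adam, gradient accumulation must be used.
```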
@@ -203,20 +198,11 @@ class StochasticGradientDescentParameters(_OptimizationParameters):
 
   Args:
     learning_rate: a floating point value. The learning rate.
-    use_gradient_accumulation: setting this to `True` makes embedding
-      gradients calculation more accurate but slower. Please see
-      `optimization_parameters.proto` for details.
-    pipeline_execution_with_tensor_core: setting this to `True` makes training
-      faster, but trained model will be different if step N and step N+1
-      involve the same set of embedding ID. Please see
-      `tpu_embedding_configuration.proto` for details.
-  """
+  """
 
-  def __init__(self, learning_rate, use_gradient_accumulation=False,
-               pipeline_execution_with_tensor_core=True):
+  def __init__(self, learning_rate):
     super(StochasticGradientDescentParameters, self).__init__(
-        learning_rate, use_gradient_accumulation,
-        pipeline_execution_with_tensor_core)
+        learning_rate, False)
 
 
 class TPUEmbedding(object):
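For SGD the optimizer object reduces to a learning rate; gradient accumulation is always disabled for it. A tiny sketch:

```python
from tensorflow.python.tpu.tpu_embedding import StochasticGradientDescentParameters

sgd = StochasticGradientDescentParameters(learning_rate=0.01)
print(sgd.use_gradient_accumulation)  # False: SGD never accumulates gradients
```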
@@ -309,7 +295,8 @@ class TPUEmbedding(object):
                mode,
                master,
                optimization_parameters=None,
-               cluster_def=None):
+               cluster_def=None,
+               pipeline_execution_with_tensor_core=True):
     """API for using TPU for embedding lookups.
 
     Args:
@@ -326,6 +313,10 @@ class TPUEmbedding(object):
         `StochasticGradientDescentParameters`. Must be set in training and must
         be `None` in inference.
       cluster_def: A ClusterDef object describing the TPU cluster.
+      pipeline_execution_with_tensor_core: setting this to `True` makes training
+        faster, but trained model will be different if step N and step N+1
+        involve the same set of embedding ID. Please see
+        `tpu_embedding_configuration.proto` for details.
 
     Raises:
       ValueError: if any input is invalid.
@@ -384,6 +375,8 @@ class TPUEmbedding(object):
     # on get_slot().
     self._optimizer_handler = _get_optimization_handler(
         self._optimization_parameters)
+    self._pipeline_execution_with_tensor_core = (
+        pipeline_execution_with_tensor_core)
 
     self._config_proto = self._create_config_proto()
 
@@ -483,7 +476,7 @@ class TPUEmbedding(object):
     config_proto.num_tensor_cores = self._num_cores
     config_proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT
     config_proto.pipeline_execution_with_tensor_core = (
-        self._optimization_parameters.pipeline_execution_with_tensor_core)
+        self._pipeline_execution_with_tensor_core)
 
     return config_proto
 
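`pipeline_execution_with_tensor_core` now travels on the `TPUEmbedding` constructor (and from there into the generated config proto) instead of on the optimizer parameters. A rough sketch of the call shape; the `TableConfig` arguments, the `TRAINING` mode constant, and the master address are assumptions for illustration, and actually executing this requires a TPU system:

```python
# Rough sketch; TableConfig arguments, TRAINING constant and master address are assumed.
from tensorflow.python.tpu import tpu_embedding

table_to_config_dict = {
    'video': tpu_embedding.TableConfig(vocabulary_size=1000, dimension=16),
}
feature_to_table_dict = {'video_id': 'video'}

embedding = tpu_embedding.TPUEmbedding(
    table_to_config_dict,
    feature_to_table_dict,
    batch_size=128,
    mode=tpu_embedding.TRAINING,
    master='',  # placeholder TPU master address
    optimization_parameters=tpu_embedding.AdagradParameters(learning_rate=0.05),
    pipeline_execution_with_tensor_core=True)
```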
@@ -940,8 +933,7 @@ class _AdamHandler(_OptimizerHandler):
                   table_name=table,
                   num_shards=num_hosts,
                   shard_id=host_id))
-
-        load_op_list.append(load_parameters_op)
+          load_op_list.append(load_parameters_op)
       return load_op_list
 
     def retrieve_ops_fn():
@@ -72,7 +72,9 @@ from tensorflow.python.tpu import tpu_feed
 from tensorflow.python.tpu import tpu_function
 from tensorflow.python.tpu import training_loop
 from tensorflow.python.tpu import util as util_lib
+from tensorflow.python.tpu._tpu_estimator_embedding import AdagradParameters  # pylint: disable=unused-import
 from tensorflow.python.tpu._tpu_estimator_embedding import AdamParameters  # pylint: disable=unused-import
+from tensorflow.python.tpu._tpu_estimator_embedding import StochasticGradientDescentParameters  # pylint: disable=unused-import
 from tensorflow.python.tpu._tpu_estimator_embedding import EmbeddingConfigSpec  # pylint: disable=unused-import
 from tensorflow.python.tpu.ops import tpu_ops
 from tensorflow.python.training import basic_session_run_hooks
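With these re-exports, estimator code can pull the optimizer parameter classes from the same module that exposes `EmbeddingConfigSpec` instead of importing `tpu_embedding` directly. A sketch, assuming the `tensorflow.python.tpu.tpu_estimator` module path for the file shown in this hunk:

```python
# Assumed module path for the file touched by this hunk.
from tensorflow.python.tpu.tpu_estimator import AdagradParameters
from tensorflow.python.tpu.tpu_estimator import EmbeddingConfigSpec
```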