Make Proximal Yogi available in Python TPU Embedding API
PiperOrigin-RevId: 301657986
Change-Id: I3f1cbc88bdf3fbb729ca16e1597d6d29d76ec464
This commit is contained in:
parent b2a5472997
commit 1e0821e601
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "LoadTPUEmbeddingProximalYogiParameters"
+  visibility: HIDDEN
+}
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "LoadTPUEmbeddingProximalYogiParametersGradAccumDebug"
+  visibility: HIDDEN
+}
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RetrieveTPUEmbeddingProximalYogiParameters"
+  visibility: HIDDEN
+}
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug"
+  visibility: HIDDEN
+}
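All four new ops are registered with `visibility: HIDDEN`, so they get no public `tf.*` endpoint and are reached through their generated snake_case wrappers (the raw-op goldens at the bottom of this diff still list them). A minimal graph-mode sketch of the wrapper that this commit's Python code calls; the shapes and names here are made up for illustration, and actually running the op requires an initialized TPU system:

    import tensorflow.compat.v1 as tf
    from tensorflow.python.tpu.ops import tpu_ops

    tf.disable_v2_behavior()

    # CPU-side copies of one 8 x 4 embedding table and its two Yogi slots.
    table_var = tf.Variable(tf.zeros([8, 4]))
    v_var = tf.Variable(tf.fill([8, 4], 1e-6))  # second-moment slot
    m_var = tf.Variable(tf.zeros([8, 4]))       # first-moment slot

    # Builds (but does not run) the hidden op via its generated wrapper.
    load_op = tpu_ops.load_tpu_embedding_proximal_yogi_parameters(
        parameters=table_var,
        v=v_var,
        m=m_var,
        table_name='table_0',
        num_shards=1,
        shard_id=0)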
@@ -546,13 +546,13 @@ Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg,
     case OptimizationAlgorithm::kCenteredRmsProp:
     case OptimizationAlgorithm::kMdlAdagradLight:
     case OptimizationAlgorithm::kAdadelta:
-    case OptimizationAlgorithm::kProximalAdagrad: {
+    case OptimizationAlgorithm::kProximalAdagrad:
+    case OptimizationAlgorithm::kProximalYogi: {
       *internal = false;
       return Status::OK();
     }
     case OptimizationAlgorithm::kBoundedAdagrad:
-    case OptimizationAlgorithm::kOnlineYogi:
-    case OptimizationAlgorithm::kProximalYogi: {
+    case OptimizationAlgorithm::kOnlineYogi: {
       *internal = true;
       return Status::OK();
     }
@@ -241,6 +241,9 @@ ProximalAdagradSlotVariableName = collections.namedtuple(
 FtrlSlotVariableName = collections.namedtuple(
     'FtrlSlotVariableName', ['accumulator', 'linear'])
 
+ProximalYogiSlotVariableNames = collections.namedtuple(
+    'ProximalYogiSlotVariableNames', ['v', 'm'])
+
 AdamSlotVariables = collections.namedtuple(
     'AdamSlotVariables', ['m', 'v'])
 
@@ -253,6 +256,9 @@ ProximalAdagradSlotVariable = collections.namedtuple(
 FtrlSlotVariable = collections.namedtuple(
     'FtrlSlotVariable', ['accumulator', 'linear'])
 
+ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables',
+                                                   ['v', 'm'])
+
 VariablesAndOps = collections.namedtuple(
     'VariablesAndOps',
     ['embedding_variables_by_table', 'slot_variables_by_table',
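These slot containers are plain `collections.namedtuple`s; the `v`/`m` order matches the slot names generated by the handler added further down. A tiny runnable illustration (the table name 'table_0' is made up):

    import collections

    ProximalYogiSlotVariableNames = collections.namedtuple(
        'ProximalYogiSlotVariableNames', ['v', 'm'])

    # Mirrors _ProximalYogiHandler.get_default_slot_variable_names() below.
    table = 'table_0'
    names = ProximalYogiSlotVariableNames(
        '{}/{}'.format(table, 'ProximalYogi'),    # v
        '{}/{}_1'.format(table, 'ProximalYogi'))  # m
    print(names.v)  # table_0/ProximalYogi
    print(names.m)  # table_0/ProximalYogi_1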
@@ -545,6 +551,83 @@ class FtrlParameters(_OptimizationParameters):
     self.l2_regularization_strength = l2_regularization_strength
 
 
+class ProximalYogiParameters(_OptimizationParameters):
+  # pylint: disable=line-too-long
+  """Optimization parameters for Proximal Yogi with TPU embeddings.
+
+  Implements the Yogi optimizer as described in
+  [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization).
+
+  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
+  `optimization_parameters` argument to set the optimizer and its parameters.
+  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
+  for more details.
+  """
+  # pylint: enable=line-too-long
+
+  def __init__(self,
+               learning_rate=0.01,
+               beta1=0.9,
+               beta2=0.999,
+               epsilon=1e-3,
+               l1_regularization_strength=0.0,
+               l2_regularization_strength=0.0,
+               initial_accumulator_value=1e-6,
+               use_gradient_accumulation=True,
+               clip_weight_min=None,
+               clip_weight_max=None,
+               weight_decay_factor=None,
+               multiply_weight_decay_factor_by_learning_rate=None):
+    """Optimization parameters for Proximal Yogi.
+
+    Args:
+      learning_rate: a floating point value. The learning rate.
+      beta1: A float value. The exponential decay rate for the 1st moment
+        estimates.
+      beta2: A float value. The exponential decay rate for the 2nd moment
+        estimates.
+      epsilon: A small constant for numerical stability.
+      l1_regularization_strength: A float value, must be greater than or equal
+        to zero.
+      l2_regularization_strength: A float value, must be greater than or equal
+        to zero.
+      initial_accumulator_value: The starting value for accumulators. Only zero
+        or positive values are allowed.
+      use_gradient_accumulation: setting this to `False` makes embedding
+        gradients calculation less accurate but faster. Please see
+        `optimization_parameters.proto` for details.
+      clip_weight_min: the minimum value to clip by; None means -infinity.
+      clip_weight_max: the maximum value to clip by; None means +infinity.
+      weight_decay_factor: amount of weight decay to apply; None means that the
+        weights are not decayed.
+      multiply_weight_decay_factor_by_learning_rate: if true,
+        `weight_decay_factor` is multiplied by the current learning rate.
+    """
+    super(ProximalYogiParameters,
+          self).__init__(learning_rate, use_gradient_accumulation,
+                         clip_weight_min, clip_weight_max, weight_decay_factor,
+                         multiply_weight_decay_factor_by_learning_rate)
+    if beta1 < 0. or beta1 >= 1.:
+      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
+    if beta2 < 0. or beta2 >= 1.:
+      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
+    if epsilon <= 0.:
+      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
+    if l1_regularization_strength < 0.:
+      raise ValueError('l1_regularization_strength must be greater than or '
+                       'equal to 0; got {}.'.format(l1_regularization_strength))
+    if l2_regularization_strength < 0.:
+      raise ValueError('l2_regularization_strength must be greater than or '
+                       'equal to 0; got {}.'.format(l2_regularization_strength))
+
+    self.beta1 = beta1
+    self.beta2 = beta2
+    self.epsilon = epsilon
+    self.l1_regularization_strength = l1_regularization_strength
+    self.l2_regularization_strength = l2_regularization_strength
+    self.initial_accumulator_value = initial_accumulator_value
+
+
 @tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters'])
 class StochasticGradientDescentParameters(_OptimizationParameters):
   """Optimization parameters for stochastic gradient descent for TPU embeddings.
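For context, the update rule the docstring's citation refers to is (our paraphrase of the Yogi paper, not text from this commit; \alpha is the learning rate):

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = v_{t-1} - (1 - \beta_2) \,\mathrm{sign}(v_{t-1} - g_t^2)\, g_t^2
    \theta_t = \theta_{t-1} - \alpha\, m_t / (\sqrt{v_t} + \epsilon)

Yogi differs from Adam only in the signed, additive second-moment update; the "proximal" variant, as the name and the `l1`/`l2` proto fields set by the handler below suggest, additionally applies an l1/l2 proximal step to the weights. This also explains the slot initializers created below: `v` starts at `initial_accumulator_value` (it sits under a square root), while `m` starts at zero.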
@@ -1706,6 +1789,102 @@ class _FtrlHandler(_OptimizerHandler):
     return slot_variables, load_ops_fn, retrieve_ops_fn
 
 
+class _ProximalYogiHandler(_OptimizerHandler):
+  """Handles Proximal Yogi specific logic."""
+
+  def set_optimization_parameters(self, table_descriptor):
+    table_descriptor.optimization_parameters.proximal_yogi.SetInParent()
+    table_descriptor.optimization_parameters.proximal_yogi.beta1 = (
+        self._optimization_parameters.beta1)
+    table_descriptor.optimization_parameters.proximal_yogi.beta2 = (
+        self._optimization_parameters.beta2)
+    table_descriptor.optimization_parameters.proximal_yogi.epsilon = (
+        self._optimization_parameters.epsilon)
+    table_descriptor.optimization_parameters.proximal_yogi.l1 = (
+        self._optimization_parameters.l1_regularization_strength)
+    table_descriptor.optimization_parameters.proximal_yogi.l2 = (
+        self._optimization_parameters.l2_regularization_strength)
+
+  def get_default_slot_variable_names(self, table):
+    return ProximalYogiSlotVariableNames(
+        '{}/{}'.format(table, 'ProximalYogi'),  # v
+        '{}/{}_1'.format(table, 'ProximalYogi'))  # m
+
+  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
+                               table_config, table_variables, config_proto):
+    v_initializer = init_ops.constant_initializer(
+        self._optimization_parameters.initial_accumulator_value)
+    v_variables = _create_partitioned_variables(
+        name=slot_variable_names.v,
+        num_hosts=num_hosts,
+        vocabulary_size=table_config.vocabulary_size,
+        embedding_dimension=table_config.dimension,
+        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+        initializer=v_initializer)
+    m_initializer = init_ops.zeros_initializer()
+    m_variables = _create_partitioned_variables(
+        name=slot_variable_names.m,
+        num_hosts=num_hosts,
+        vocabulary_size=table_config.vocabulary_size,
+        embedding_dimension=table_config.dimension,
+        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+        initializer=m_initializer)
+    slot_variables = ProximalYogiSlotVariables(v_variables, m_variables)
+
+    def load_ops_fn():
+      """Returns the load ops for Proximal Yogi embedding tables.
+
+      Returns:
+        A list of ops to load embedding and slot variables from CPU to TPU.
+      """
+      load_op_list = []
+      config = config_proto
+      for host_id, table_variable, v_variable, m_variable in (zip(
+          range(num_hosts), table_variables, v_variables, m_variables)):
+        with ops.colocate_with(table_variable):
+          load_parameters_op = (
+              tpu_ops.load_tpu_embedding_proximal_yogi_parameters(
+                  parameters=table_variable,
+                  v=v_variable,
+                  m=m_variable,
+                  table_name=table,
+                  num_shards=num_hosts,
+                  shard_id=host_id,
+                  config=config))
+        # Set config to None to enforce that config is only loaded to the first
+        # table.
+        config = None
+        load_op_list.append(load_parameters_op)
+      return load_op_list
+
+    def retrieve_ops_fn():
+      """Returns the retrieve ops for Proximal Yogi embedding tables.
+
+      Returns:
+        A list of ops to retrieve embedding and slot variables from TPU to CPU.
+      """
+      retrieve_op_list = []
+      config = config_proto
+      for host_id, table_variable, v_variable, m_variable in (zip(
+          range(num_hosts), table_variables, v_variables, m_variables)):
+        with ops.colocate_with(table_variable):
+          retrieved_table, retrieved_v, retrieved_m = (
+              tpu_ops.retrieve_tpu_embedding_proximal_yogi_parameters(
+                  table_name=table,
+                  num_shards=num_hosts,
+                  shard_id=host_id,
+                  config=config))
+          retrieve_parameters_op = control_flow_ops.group(
+              state_ops.assign(table_variable, retrieved_table),
+              state_ops.assign(v_variable, retrieved_v),
+              state_ops.assign(m_variable, retrieved_m))
+        config = None
+        retrieve_op_list.append(retrieve_parameters_op)
+      return retrieve_op_list
+
+    return slot_variables, load_ops_fn, retrieve_ops_fn
+
+
 class _StochasticGradientDescentHandler(_OptimizerHandler):
   """Handles stochastic gradient descent specific logic."""
 
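The handler only wires configuration and the CPU-to-TPU transfer of the table and its two slots; the update itself runs in the TPU runtime. As a sanity check of the slot semantics (emphatically not the TPU kernel), here is a NumPy sketch of one step under the standard Yogi formulas plus a per-coordinate closed-form l1/l2 proximal step; the helper name and the simplified proximal form are our own:

    import numpy as np

    def proximal_yogi_step(w, v, m, g, lr=0.01, beta1=0.9, beta2=0.999,
                           epsilon=1e-3, l1=0.0, l2=0.0):
      """One illustrative Proximal Yogi step (NumPy sketch, not the TPU op)."""
      v = v - (1.0 - beta2) * np.sign(v - g * g) * (g * g)  # Yogi 2nd moment
      m = beta1 * m + (1.0 - beta1) * g                     # 1st moment (as in Adam)
      lr_eff = lr / (np.sqrt(v) + epsilon)                  # per-coordinate step size
      w_raw = w - lr_eff * m
      # l1/l2 proximal step: soft-thresholding followed by shrinkage.
      w = np.sign(w_raw) * np.maximum(np.abs(w_raw) - lr_eff * l1, 0.0) / (
          1.0 + lr_eff * l2)
      return w, v, m

    w = np.zeros(4)
    v = np.full(4, 1e-6)  # matches the v slot's initial_accumulator_value
    m = np.zeros(4)       # matches the m slot's zeros_initializer
    w, v, m = proximal_yogi_step(w, v, m, g=np.array([0.1, -0.2, 0.3, 0.0]))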
@@ -1779,6 +1958,8 @@ def _get_optimization_handler(optimization_parameters):
     return _AdamHandler(optimization_parameters)
   elif isinstance(optimization_parameters, FtrlParameters):
     return _FtrlHandler(optimization_parameters)
+  elif isinstance(optimization_parameters, ProximalYogiParameters):
+    return _ProximalYogiHandler(optimization_parameters)
   elif isinstance(optimization_parameters, StochasticGradientDescentParameters):
     return _StochasticGradientDescentHandler(optimization_parameters)
   else:
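With the dispatch above in place, usage follows the class docstring: build the parameters and hand them to `tf.estimator.tpu.experimental.EmbeddingConfigSpec`. Note that, unlike `StochasticGradientDescentParameters`, this commit adds no `tf_export` for `ProximalYogiParameters`, so the sketch imports it from the internal module; `my_feature_columns` and the surrounding estimator setup are hypothetical:

    import tensorflow.compat.v1 as tf
    from tensorflow.python.tpu import tpu_embedding

    params = tpu_embedding.ProximalYogiParameters(
        learning_rate=0.01,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-3,
        l2_regularization_strength=1e-4)

    # Out-of-range hyperparameters fail fast in __init__, e.g.
    # tpu_embedding.ProximalYogiParameters(beta2=1.0) raises ValueError.

    embedding_config_spec = tf.estimator.tpu.experimental.EmbeddingConfigSpec(
        feature_columns=my_feature_columns,  # hypothetical feature columns
        optimization_parameters=params)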
@@ -2072,6 +2072,14 @@ tf_module {
     name: "LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug"
     argspec: "args=[\'parameters\', \'accumulators\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
   }
+  member_method {
+    name: "LoadTPUEmbeddingProximalYogiParameters"
+    argspec: "args=[\'parameters\', \'v\', \'m\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
+  }
+  member_method {
+    name: "LoadTPUEmbeddingProximalYogiParametersGradAccumDebug"
+    argspec: "args=[\'parameters\', \'v\', \'m\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
+  }
   member_method {
     name: "LoadTPUEmbeddingRMSPropParameters"
     argspec: "args=[\'parameters\', \'ms\', \'mom\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
@@ -3608,6 +3616,14 @@ tf_module {
     name: "RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug"
     argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
   }
+  member_method {
+    name: "RetrieveTPUEmbeddingProximalYogiParameters"
+    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
+  }
+  member_method {
+    name: "RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug"
+    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
+  }
   member_method {
     name: "RetrieveTPUEmbeddingRMSPropParameters"
     argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
(The eight member_method additions above are repeated verbatim a second time in the original diff, evidently for a parallel golden API .pbtxt file.)