Make Proximal Yogi available in Python TPU Embedding API

PiperOrigin-RevId: 301657986
Change-Id: I3f1cbc88bdf3fbb729ca16e1597d6d29d76ec464
Philip Pham authored 2020-03-18 13:42:34 -07:00, committed by TensorFlower Gardener
parent b2a5472997
commit 1e0821e601
8 changed files with 232 additions and 3 deletions


@@ -0,0 +1,4 @@
op {
graph_op_name: "LoadTPUEmbeddingProximalYogiParameters"
visibility: HIDDEN
}


@@ -0,0 +1,4 @@
op {
graph_op_name: "LoadTPUEmbeddingProximalYogiParametersGradAccumDebug"
visibility: HIDDEN
}


@@ -0,0 +1,4 @@
op {
graph_op_name: "RetrieveTPUEmbeddingProximalYogiParameters"
visibility: HIDDEN
}


@@ -0,0 +1,4 @@
op {
graph_op_name: "RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug"
visibility: HIDDEN
}
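
These four api_def entries register the new load/retrieve ops with HIDDEN visibility: the ops stay out of the documented public namespace but remain callable through tf.raw_ops, whose golden signatures are updated at the end of this commit. As a hedged illustration (the shapes are hypothetical, and actually executing the op requires a configured TPU embedding system), loading one shard of a table with the argspec recorded in the goldens would look like:

import tensorflow.compat.v1 as tf

# Hypothetical 8-row, 64-dimensional shard of one embedding table.
params = tf.zeros([8, 64])
v = tf.fill([8, 64], 1e-6)  # second-moment slot (1e-6 matches the Python default)
m = tf.zeros([8, 64])       # first-moment slot

load_op = tf.raw_ops.LoadTPUEmbeddingProximalYogiParameters(
    parameters=params, v=v, m=m,
    num_shards=1, shard_id=0, table_name='my_table')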


@@ -546,13 +546,13 @@ Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg,
     case OptimizationAlgorithm::kCenteredRmsProp:
     case OptimizationAlgorithm::kMdlAdagradLight:
     case OptimizationAlgorithm::kAdadelta:
-    case OptimizationAlgorithm::kProximalAdagrad: {
+    case OptimizationAlgorithm::kProximalAdagrad:
+    case OptimizationAlgorithm::kProximalYogi: {
       *internal = false;
       return Status::OK();
     }
     case OptimizationAlgorithm::kBoundedAdagrad:
-    case OptimizationAlgorithm::kOnlineYogi:
-    case OptimizationAlgorithm::kProximalYogi: {
+    case OptimizationAlgorithm::kOnlineYogi: {
       *internal = true;
       return Status::OK();
     }


@@ -241,6 +241,9 @@ ProximalAdagradSlotVariableName = collections.namedtuple(
FtrlSlotVariableName = collections.namedtuple(
'FtrlSlotVariableName', ['accumulator', 'linear'])
ProximalYogiSlotVariableNames = collections.namedtuple(
'ProximalYogiSlotVariableNames', ['v', 'm'])
AdamSlotVariables = collections.namedtuple(
'AdamSlotVariables', ['m', 'v'])
@@ -253,6 +256,9 @@ ProximalAdagradSlotVariable = collections.namedtuple(
FtrlSlotVariable = collections.namedtuple(
'FtrlSlotVariable', ['accumulator', 'linear'])
ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables',
['v', 'm'])
VariablesAndOps = collections.namedtuple(
'VariablesAndOps',
['embedding_variables_by_table', 'slot_variables_by_table',
@@ -545,6 +551,83 @@ class FtrlParameters(_OptimizationParameters):
self.l2_regularization_strength = l2_regularization_strength
class ProximalYogiParameters(_OptimizationParameters):
# pylint: disable=line-too-long
"""Optimization parameters for Proximal Yogi with TPU embeddings.
Implements the Yogi optimizer as described in
[Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization).
Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
`optimization_parameters` argument to set the optimizer and its parameters.
See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
for more details.
"""
# pylint: enable=line-too-long
def __init__(self,
learning_rate=0.01,
beta1=0.9,
beta2=0.999,
epsilon=1e-3,
l1_regularization_strength=0.0,
l2_regularization_strength=0.0,
initial_accumulator_value=1e-6,
use_gradient_accumulation=True,
clip_weight_min=None,
clip_weight_max=None,
weight_decay_factor=None,
multiply_weight_decay_factor_by_learning_rate=None):
"""Optimization parameters for Proximal Yogi.
Args:
learning_rate: a floating point value. The learning rate.
beta1: A float value. The exponential decay rate for the 1st moment
estimates.
beta2: A float value. The exponential decay rate for the 2nd moment
estimates.
epsilon: A small constant for numerical stability.
l1_regularization_strength: A float value, must be greater than or equal
to zero.
l2_regularization_strength: A float value, must be greater than or equal
to zero.
initial_accumulator_value: The starting value for accumulators. Only zero
or positive values are allowed.
use_gradient_accumulation: setting this to `False` makes embedding
gradient computation less accurate but faster. Please see
`optimization_parameters.proto` for details.
clip_weight_min: the minimum value to clip by; None means -infinity.
clip_weight_max: the maximum value to clip by; None means +infinity.
weight_decay_factor: amount of weight decay to apply; None means that the
weights are not decayed.
multiply_weight_decay_factor_by_learning_rate: if true,
`weight_decay_factor` is multiplied by the current learning rate.
"""
super(ProximalYogiParameters,
self).__init__(learning_rate, use_gradient_accumulation,
clip_weight_min, clip_weight_max, weight_decay_factor,
multiply_weight_decay_factor_by_learning_rate)
if beta1 < 0. or beta1 >= 1.:
raise ValueError('beta1 must be in the range [0, 1); got {}.'.format(beta1))
if beta2 < 0. or beta2 >= 1.:
raise ValueError('beta2 must be in the range [0, 1); got {}.'.format(beta2))
if epsilon <= 0.:
raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
if l1_regularization_strength < 0.:
raise ValueError('l1_regularization_strength must be greater than or '
'equal to 0; got {}.'.format(l1_regularization_strength))
if l2_regularization_strength < 0.:
raise ValueError('l2_regularization_strength must be greater than or '
'equal to 0; got {}.'.format(l2_regularization_strength))
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.l1_regularization_strength = l1_regularization_strength
self.l2_regularization_strength = l2_regularization_strength
self.initial_accumulator_value = initial_accumulator_value
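
As an editorial sketch (not part of this change), the new class is used like the other optimizer parameter classes in this module: construct it and hand it to the estimator via the `optimization_parameters` argument named in the docstring above. `my_feature_columns` is a hypothetical placeholder.

import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_embedding

yogi = tpu_embedding.ProximalYogiParameters(
    learning_rate=0.01,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-3,
    l2_regularization_strength=1e-4)

embedding_config = tf.estimator.tpu.experimental.EmbeddingConfigSpec(
    feature_columns=my_feature_columns,  # hypothetical
    optimization_parameters=yogi)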
@tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters'])
class StochasticGradientDescentParameters(_OptimizationParameters):
"""Optimization parameters for stochastic gradient descent for TPU embeddings.
@@ -1706,6 +1789,102 @@ class _FtrlHandler(_OptimizerHandler):
return slot_variables, load_ops_fn, retrieve_ops_fn
class _ProximalYogiHandler(_OptimizerHandler):
"""Handles Proximal Yogi specific logic."""
def set_optimization_parameters(self, table_descriptor):
table_descriptor.optimization_parameters.proximal_yogi.SetInParent()
table_descriptor.optimization_parameters.proximal_yogi.beta1 = (
self._optimization_parameters.beta1)
table_descriptor.optimization_parameters.proximal_yogi.beta2 = (
self._optimization_parameters.beta2)
table_descriptor.optimization_parameters.proximal_yogi.epsilon = (
self._optimization_parameters.epsilon)
table_descriptor.optimization_parameters.proximal_yogi.l1 = (
self._optimization_parameters.l1_regularization_strength)
table_descriptor.optimization_parameters.proximal_yogi.l2 = (
self._optimization_parameters.l2_regularization_strength)
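
For reference, the assignments above populate the `proximal_yogi` submessage of the table's optimization-parameters proto. A rough hand-built equivalent (assuming the tpu_embedding_configuration_pb2 bindings; the field names come directly from the assignments above):

from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2

config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
table = config.table_descriptor.add()
yogi = table.optimization_parameters.proximal_yogi
yogi.beta1 = 0.9    # exponential decay rate for the first moment
yogi.beta2 = 0.999  # exponential decay rate for the second moment
yogi.epsilon = 1e-3
yogi.l1 = 0.0       # l1_regularization_strength
yogi.l2 = 0.0       # l2_regularization_strength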
def get_default_slot_variable_names(self, table):
return ProximalYogiSlotVariableNames(
'{}/{}'.format(table, 'ProximalYogi'), # v
'{}/{}_1'.format(table, 'ProximalYogi')) # m
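
For a hypothetical table named 'videos', the format strings above evaluate to:

ProximalYogiSlotVariableNames(v='videos/ProximalYogi', m='videos/ProximalYogi_1')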
def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
table_config, table_variables, config_proto):
v_initializer = init_ops.constant_initializer(
self._optimization_parameters.initial_accumulator_value)
v_variables = _create_partitioned_variables(
name=slot_variable_names.v,
num_hosts=num_hosts,
vocabulary_size=table_config.vocabulary_size,
embedding_dimension=table_config.dimension,
collections=[ops.GraphKeys.GLOBAL_VARIABLES],
initializer=v_initializer)
m_initializer = init_ops.zeros_initializer()
m_variables = _create_partitioned_variables(
name=slot_variable_names.m,
num_hosts=num_hosts,
vocabulary_size=table_config.vocabulary_size,
embedding_dimension=table_config.dimension,
collections=[ops.GraphKeys.GLOBAL_VARIABLES],
initializer=m_initializer)
slot_variables = ProximalYogiSlotVariables(v_variables, m_variables)
def load_ops_fn():
"""Returns the load ops for Proximal Yogi embedding tables.
Returns:
A list of ops to load embedding and slot variables from CPU to TPU.
"""
load_op_list = []
config = config_proto
for host_id, table_variable, v_variable, m_variable in (zip(
range(num_hosts), table_variables, v_variables, m_variables)):
with ops.colocate_with(table_variable):
load_parameters_op = (
tpu_ops.load_tpu_embedding_proximal_yogi_parameters(
parameters=table_variable,
v=v_variable,
m=m_variable,
table_name=table,
num_shards=num_hosts,
shard_id=host_id,
config=config))
# Pass the config proto only with the first table's load op; clear it
# for the remaining tables.
config = None
load_op_list.append(load_parameters_op)
return load_op_list
def retrieve_ops_fn():
"""Returns the retrieve ops for Proximal Yogi embedding tables.
Returns:
A list of ops to retrieve embedding and slot variables from TPU to CPU.
"""
retrieve_op_list = []
config = config_proto
for host_id, table_variable, v_variable, m_variable in (zip(
range(num_hosts), table_variables, v_variables, m_variables)):
with ops.colocate_with(table_variable):
retrieved_table, retrieved_v, retrieved_m = (
tpu_ops.retrieve_tpu_embedding_proximal_yogi_parameters(
table_name=table,
num_shards=num_hosts,
shard_id=host_id,
config=config))
retrieve_parameters_op = control_flow_ops.group(
state_ops.assign(table_variable, retrieved_table),
state_ops.assign(v_variable, retrieved_v),
state_ops.assign(m_variable, retrieved_m))
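# As with the load ops above, pass the config proto only with the first op.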
config = None
retrieve_op_list.append(retrieve_parameters_op)
return retrieve_op_list
return slot_variables, load_ops_fn, retrieve_ops_fn
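
The returned closures follow the same contract as the other handlers in this file: the load ops push host-side values to the TPU after variable initialization, and the retrieve ops copy trained values back before checkpointing. A hedged sketch of that life cycle (the surrounding arguments are hypothetical):

import tensorflow.compat.v1 as tf

handler = _ProximalYogiHandler(ProximalYogiParameters(learning_rate=0.01))
slots, load_ops_fn, retrieve_ops_fn = handler.create_variables_and_ops(
    table, slot_names, num_hosts, table_config, table_variables, config_proto)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(load_ops_fn())      # CPU -> TPU before training
  # ... run training ...
  sess.run(retrieve_ops_fn())  # TPU -> CPU before saving a checkpoint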
class _StochasticGradientDescentHandler(_OptimizerHandler):
"""Handles stochastic gradient descent specific logic."""
@@ -1779,6 +1958,8 @@ def _get_optimization_handler(optimization_parameters):
return _AdamHandler(optimization_parameters)
elif isinstance(optimization_parameters, FtrlParameters):
return _FtrlHandler(optimization_parameters)
elif isinstance(optimization_parameters, ProximalYogiParameters):
return _ProximalYogiHandler(optimization_parameters)
elif isinstance(optimization_parameters, StochasticGradientDescentParameters):
return _StochasticGradientDescentHandler(optimization_parameters)
else:


@@ -2072,6 +2072,14 @@ tf_module {
name: "LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug"
argspec: "args=[\'parameters\', \'accumulators\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "LoadTPUEmbeddingProximalYogiParameters"
argspec: "args=[\'parameters\', \'v\', \'m\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "LoadTPUEmbeddingProximalYogiParametersGradAccumDebug"
argspec: "args=[\'parameters\', \'v\', \'m\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "LoadTPUEmbeddingRMSPropParameters"
argspec: "args=[\'parameters\', \'ms\', \'mom\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
@@ -3608,6 +3616,14 @@ tf_module {
name: "RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug"
argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "RetrieveTPUEmbeddingProximalYogiParameters"
argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug"
argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "RetrieveTPUEmbeddingRMSPropParameters"
argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "


@@ -2072,6 +2072,14 @@ tf_module {
name: "LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug"
argspec: "args=[\'parameters\', \'accumulators\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "LoadTPUEmbeddingProximalYogiParameters"
argspec: "args=[\'parameters\', \'v\', \'m\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "LoadTPUEmbeddingProximalYogiParametersGradAccumDebug"
argspec: "args=[\'parameters\', \'v\', \'m\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "LoadTPUEmbeddingRMSPropParameters"
argspec: "args=[\'parameters\', \'ms\', \'mom\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
@@ -3608,6 +3616,14 @@ tf_module {
name: "RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug"
argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "RetrieveTPUEmbeddingProximalYogiParameters"
argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug"
argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "RetrieveTPUEmbeddingRMSPropParameters"
argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "