Add support for the FrequencyEstimator optimizer to the mid-level TPU embedding API.

PiperOrigin-RevId: 345126452
Change-Id: I7cc3345a3efd8f34bfd49b6f9b88ec9c48de76bd
Bruce Fontaine 2020-12-01 16:42:55 -08:00 committed by TensorFlower Gardener
parent 8493ce6116
commit 249c207846
8 changed files with 339 additions and 0 deletions

tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingFrequencyEstimatorParameters.pbtxt

@@ -0,0 +1,24 @@
op {
  graph_op_name: "LoadTPUEmbeddingFrequencyEstimatorParameters"
  visibility: HIDDEN
  in_arg {
    name: "parameters"
    description: <<END
Value of parameters used in the frequency estimator optimization algorithm.
END
  }
  in_arg {
    name: "last_hit_step"
    description: <<END
Value of last_hit_step used in the frequency estimator optimization algorithm.
END
  }
  summary: "Load frequency estimator embedding parameters."
  description: <<END
An op that loads optimization parameters into HBM for embedding. Must be
preceded by a ConfigureTPUEmbeddingHost op that sets up the correct
embedding table configuration. For example, this op is used to install
parameters that are loaded from a checkpoint before a training loop is
executed.
END
}

tensorflow/core/api_def/base_api/api_def_LoadTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug.pbtxt

@@ -0,0 +1,31 @@
op {
  graph_op_name: "LoadTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug"
  visibility: HIDDEN
  in_arg {
    name: "parameters"
    description: <<END
Value of parameters used in the frequency estimator optimization algorithm.
END
  }
  in_arg {
    name: "last_hit_step"
    description: <<END
Value of last_hit_step used in the frequency estimator optimization algorithm.
END
  }
  in_arg {
    name: "gradient_accumulators"
    description: <<END
Value of gradient_accumulators used in the frequency estimator optimization
algorithm.
END
  }
  summary: "Load frequency estimator embedding parameters with debug support."
  description: <<END
An op that loads optimization parameters into HBM for embedding. Must be
preceded by a ConfigureTPUEmbeddingHost op that sets up the correct
embedding table configuration. For example, this op is used to install
parameters that are loaded from a checkpoint before a training loop is
executed.
END
}

tensorflow/core/api_def/base_api/api_def_RetrieveTPUEmbeddingFrequencyEstimatorParameters.pbtxt

@@ -0,0 +1,24 @@
op {
  graph_op_name: "RetrieveTPUEmbeddingFrequencyEstimatorParameters"
  visibility: HIDDEN
  out_arg {
    name: "parameters"
    description: <<END
Parameter parameters updated by the frequency estimator optimization algorithm.
END
  }
  out_arg {
    name: "last_hit_step"
    description: <<END
Parameter last_hit_step updated by the frequency estimator optimization
algorithm.
END
  }
  summary: "Retrieve frequency estimator embedding parameters."
  description: <<END
An op that retrieves optimization parameters from embedding to host
memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up
the correct embedding table configuration. For example, this op is
used to retrieve updated parameters before saving a checkpoint.
END
}

tensorflow/core/api_def/base_api/api_def_RetrieveTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug.pbtxt

@@ -0,0 +1,33 @@
op {
  graph_op_name: "RetrieveTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug"
  visibility: HIDDEN
  out_arg {
    name: "parameters"
    description: <<END
Parameter parameters updated by the frequency estimator optimization algorithm.
END
  }
  out_arg {
    name: "last_hit_step"
    description: <<END
Parameter last_hit_step updated by the frequency estimator optimization
algorithm.
END
  }
  out_arg {
    name: "gradient_accumulators"
    description: <<END
Parameter gradient_accumulators updated by the frequency estimator optimization
algorithm.
END
  }
  summary: "Retrieve frequency estimator embedding parameters with debug support."
  description: <<END
An op that retrieves optimization parameters from embedding to host
memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up
the correct embedding table configuration. For example, this op is
used to retrieve updated parameters before saving a checkpoint.
END
}
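
These four ops form the two halves of the checkpointing protocol: load before
training, retrieve before saving. As a hedged sketch (not part of this change;
the `tf.raw_ops` endpoints and argument names follow the op registrations and
golden API updates below, while the session setup and the `num_hosts`,
`host_id`, and variable names are illustrative), retrieving a shard back to
host memory looks like:

```
import tensorflow.compat.v1 as tf

# Assumes ConfigureTPUEmbeddingHost has already run, and that table_variable /
# last_hit_step_variable are the host-side shards for this host.
retrieved_table, retrieved_last_hit_step = (
    tf.raw_ops.RetrieveTPUEmbeddingFrequencyEstimatorParameters(
        num_shards=num_hosts,
        shard_id=host_id,
        table_name='freq_table'))
retrieve_op = tf.group(
    tf.assign(table_variable, retrieved_table),
    tf.assign(last_hit_step_variable, retrieved_last_hit_step))
```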

tensorflow/core/ops/tpu_embedding_ops.cc

@@ -511,5 +511,51 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());

REGISTER_OP("LoadTPUEmbeddingFrequencyEstimatorParameters")
    .Input("parameters: float32")
    .Input("last_hit_step: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());

REGISTER_OP("LoadTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug")
    .Input("parameters: float32")
    .Input("last_hit_step: float32")
    .Input("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());

REGISTER_OP("RetrieveTPUEmbeddingFrequencyEstimatorParameters")
    .Output("parameters: float32")
    .Output("last_hit_step: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());

REGISTER_OP("RetrieveTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug")
    .Output("parameters: float32")
    .Output("last_hit_step: float32")
    .Output("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
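
The load direction is symmetric. A minimal sketch, assuming the same
illustrative shard variables as in the example above (real users go through
the mid-level API's load_ops_fn rather than calling the raw op directly):

```
import tensorflow.compat.v1 as tf

# Pushes checkpoint-restored values into HBM before the training loop starts.
load_op = tf.raw_ops.LoadTPUEmbeddingFrequencyEstimatorParameters(
    parameters=table_variable,             # float32 shard of the table
    last_hit_step=last_hit_step_variable,  # float32, same shape
    num_shards=num_hosts,
    shard_id=host_id,
    table_name='freq_table')
```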
} // namespace tpu
} // namespace tensorflow

tensorflow/python/tpu/tpu_embedding.py

@@ -332,6 +332,9 @@ FtrlSlotVariableName = collections.namedtuple('FtrlSlotVariableName',
ProximalYogiSlotVariableNames = collections.namedtuple(
    'ProximalYogiSlotVariableNames', ['v', 'm'])
FrequencyEstimatorSlotVariableName = collections.namedtuple(
    'FrequencyEstimatorSlotVariableName', ['last_hit_step'])
AdamSlotVariables = collections.namedtuple('AdamSlotVariables', ['m', 'v'])
MomentumSlotVariable = collections.namedtuple('MomentumSlotVariable',
@@ -352,6 +355,9 @@ FtrlSlotVariable = collections.namedtuple('FtrlSlotVariable',
ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables',
                                                   ['v', 'm'])
FrequencyEstimatorSlotVariables = collections.namedtuple(
    'FrequencyEstimatorSlotVariables', ['last_hit_step'])
VariablesAndOps = collections.namedtuple('VariablesAndOps', [
    'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops',
    'retrieve_ops'
@@ -1034,6 +1040,64 @@ class StochasticGradientDescentParameters(_OptimizationParameters):
    )
class FrequencyEstimatorParameters(_OptimizationParameters):
  """Optimization parameters for Frequency Estimator TPU embeddings.

  This is a non-standard optimizer, which returns the estimated frequency of
  lookup for the feature passed to it. It should only be used on a table of
  width 1. The gradient fed back to the TPU embedding should always be zero.
  This can be accomplished by using `tf.stop_gradient` on the feature before
  using it.

  You must use the dynamic learning rate mechanism to set the 'learning rate'
  for this table to be a float32 cast of the global training step counter.

  See `tensorflow/core/protobuf/tpu/optimization_parameters.proto` for more
  details on this optimizer.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=FrequencyEstimatorParameters(
              tau=0.1, max_delta=100.0, outlier_threshold=10.0,
              weight_exponent=1.0),
          ...))
  ```
  """

  def __init__(self, tau: float, max_delta: float, outlier_threshold: float,
               weight_exponent: float):
    """Optimization parameters for frequency estimator.

    Args:
      tau: Learning rate between (0, 1) that is used to update the array.
      max_delta: Maximum value of delta, the difference between the current
        global step and the last global step at which the row was sampled.
      outlier_threshold: Threshold used to determine whether the current update
        is an outlier.
      weight_exponent: The weight exponent used to transform the estimated
        delta into weights.
    """
    super(FrequencyEstimatorParameters, self).__init__(
        learning_rate=1.0,
        use_gradient_accumulation=True,
        clip_weight_min=None,
        clip_weight_max=None,
        weight_decay_factor=None,
        multiply_weight_decay_factor_by_learning_rate=None,
    )
    self.tau = tau
    self.max_delta = max_delta
    self.outlier_threshold = outlier_threshold
    self.weight_exponent = weight_exponent
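
To make the dynamic learning rate requirement concrete, here is a minimal
sketch of wiring this optimizer to a table. It assumes the mid-level API's
`TableConfig` accepts a per-table `learning_rate_fn` (called with the global
step) and per-table `optimization_parameters`; all numeric values are
illustrative:

```
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_embedding

freq_params = tpu_embedding.FrequencyEstimatorParameters(
    tau=0.1, max_delta=100.0, outlier_threshold=10.0, weight_exponent=1.0)

table = tpu_embedding.TableConfig(
    vocabulary_size=100000,
    dimension=1,  # FrequencyEstimator tables must have dimension 1.
    # The table's 'learning rate' must be the global step cast to float32.
    learning_rate_fn=lambda global_step: tf.cast(global_step, tf.float32),
    optimization_parameters=freq_params)
```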
DeviceConfig = collections.namedtuple('DeviceConfig',
                                      ['num_hosts', 'num_cores', 'job_name'])
@@ -2559,6 +2623,89 @@ class _RMSPropHandler(_OptimizerHandler):
    return slot_variables, load_ops_fn, retrieve_ops_fn
class _FrequencyEstimatorHandler(_OptimizerHandler):
  """Handles frequency estimator specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.frequency_estimator.SetInParent()
    freq = table_descriptor.optimization_parameters.frequency_estimator
    freq.tau = self._optimization_parameters.tau
    freq.max_delta = self._optimization_parameters.max_delta
    freq.outlier_threshold = self._optimization_parameters.outlier_threshold
    freq.weight_exponent = self._optimization_parameters.weight_exponent

  def get_default_slot_variable_names(self, table):
    return FrequencyEstimatorSlotVariableName(
        '{}/FrequencyEstimator'.format(table))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    if table_config.dimension != 1:
      raise ValueError('FrequencyEstimator tables should only have a dimension '
                       'of 1. Received dimension {}'.format(
                           table_config.dimension))

    last_hit_step_variables = _create_partitioned_variables(
        name=slot_variable_names.last_hit_step,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=init_ops.zeros_initializer(),
    )
    slot_variables = FrequencyEstimatorSlotVariables(last_hit_step_variables)

    def load_ops_fn():
      """Returns the load ops for Frequency Estimator embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable, last_hit_step_variable in (zip(
          range(num_hosts), table_variables, last_hit_step_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_frequency_estimator_parameters(
                  parameters=table_variable,
                  last_hit_step=last_hit_step_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Frequency Estimator embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable, last_hit_step_variable in (zip(
          range(num_hosts), table_variables, last_hit_step_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_last_hit_step = (
              tpu_ops.retrieve_tpu_embedding_frequency_estimator_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config,
              ))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(last_hit_step_variable,
                               retrieved_last_hit_step))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn
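
A hedged sketch of how these callbacks are consumed (the surrounding
session-based loop and the `handler`, `sess`, and argument names are
assumptions, not part of this change): create_variables_and_ops returns the
slot variables plus the two functions, whose ops run once each at the
boundaries of training.

```
slot_vars, load_ops_fn, retrieve_ops_fn = handler.create_variables_and_ops(
    table, slot_variable_names, num_hosts, table_config, table_variables,
    config_proto)
sess.run(load_ops_fn())      # push checkpointed values to HBM
# ... run the training loop ...
sess.run(retrieve_ops_fn())  # pull updated values back before saving
```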
class _StochasticGradientDescentHandler(_OptimizerHandler):
  """Handles stochastic gradient descent specific logic."""
@@ -2638,6 +2785,8 @@ def _get_optimization_handler(optimization_parameters):
    return _MomentumHandler(optimization_parameters)
  elif isinstance(optimization_parameters, RMSPropParameters):
    return _RMSPropHandler(optimization_parameters)
  elif isinstance(optimization_parameters, FrequencyEstimatorParameters):
    return _FrequencyEstimatorHandler(optimization_parameters)
  raise NotImplementedError()
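
A quick sanity check of the new dispatch arm (constructor values are
illustrative):

```
handler = _get_optimization_handler(
    FrequencyEstimatorParameters(
        tau=0.1, max_delta=100.0, outlier_threshold=10.0,
        weight_exponent=1.0))
assert isinstance(handler, _FrequencyEstimatorHandler)
```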

tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt

@@ -2160,6 +2160,14 @@ tf_module {
    name: "LoadTPUEmbeddingFTRLParametersGradAccumDebug"
    argspec: "args=[\'parameters\', \'accumulators\', \'linears\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingFrequencyEstimatorParameters"
    argspec: "args=[\'parameters\', \'last_hit_step\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug"
    argspec: "args=[\'parameters\', \'last_hit_step\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingMDLAdagradLightParameters"
    argspec: "args=[\'parameters\', \'accumulators\', \'weights\', \'benefits\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
@@ -3748,6 +3756,14 @@ tf_module {
    name: "RetrieveTPUEmbeddingFTRLParametersGradAccumDebug"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingFrequencyEstimatorParameters"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingMDLAdagradLightParameters"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "

tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt

@@ -2160,6 +2160,14 @@ tf_module {
    name: "LoadTPUEmbeddingFTRLParametersGradAccumDebug"
    argspec: "args=[\'parameters\', \'accumulators\', \'linears\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingFrequencyEstimatorParameters"
    argspec: "args=[\'parameters\', \'last_hit_step\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug"
    argspec: "args=[\'parameters\', \'last_hit_step\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingMDLAdagradLightParameters"
    argspec: "args=[\'parameters\', \'accumulators\', \'weights\', \'benefits\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
@@ -3748,6 +3756,14 @@ tf_module {
    name: "RetrieveTPUEmbeddingFTRLParametersGradAccumDebug"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingFrequencyEstimatorParameters"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingMDLAdagradLightParameters"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "