Add support for FrequencyEstimator optimizer to the mid level tpu embedding API.
PiperOrigin-RevId: 345126452
Change-Id: I7cc3345a3efd8f34bfd49b6f9b88ec9c48de76bd
parent 8493ce6116
commit 249c207846
@ -0,0 +1,24 @@
op {
  graph_op_name: "LoadTPUEmbeddingFrequencyEstimatorParameters"
  visibility: HIDDEN
  in_arg {
    name: "parameters"
    description: <<END
Value of parameters used in the frequency estimator optimization algorithm.
END
  }
  in_arg {
    name: "last_hit_step"
    description: <<END
Value of last_hit_step used in the frequency estimator optimization algorithm.
END
  }
  summary: "Load frequency estimator embedding parameters."
  description: <<END
An op that loads optimization parameters into HBM for embedding. Must be
preceded by a ConfigureTPUEmbeddingHost op that sets up the correct
embedding table configuration. For example, this op is used to install
parameters that are loaded from a checkpoint before a training loop is
executed.
END
}
@ -0,0 +1,31 @@
op {
  graph_op_name: "LoadTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug"
  visibility: HIDDEN
  in_arg {
    name: "parameters"
    description: <<END
Value of parameters used in the frequency estimator optimization algorithm.
END
  }
  in_arg {
    name: "last_hit_step"
    description: <<END
Value of last_hit_step used in the frequency estimator optimization algorithm.
END
  }
  in_arg {
    name: "gradient_accumulators"
    description: <<END
Value of gradient_accumulators used in the frequency estimator optimization
algorithm.
END
  }
  summary: "Load frequency estimator embedding parameters with debug support."
  description: <<END
An op that loads optimization parameters into HBM for embedding. Must be
preceded by a ConfigureTPUEmbeddingHost op that sets up the correct
embedding table configuration. For example, this op is used to install
parameters that are loaded from a checkpoint before a training loop is
executed.
END
}
@ -0,0 +1,24 @@
op {
  graph_op_name: "RetrieveTPUEmbeddingFrequencyEstimatorParameters"
  visibility: HIDDEN
  out_arg {
    name: "parameters"
    description: <<END
Parameter parameters updated by the frequency estimator optimization algorithm.
END
  }
  out_arg {
    name: "last_hit_step"
    description: <<END
Parameter last_hit_step updated by the frequency estimator optimization
algorithm.
END
  }
  summary: "Retrieve frequency estimator embedding parameters."
  description: <<END
An op that retrieves optimization parameters from embedding to host
memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up
the correct embedding table configuration. For example, this op is
used to retrieve updated parameters before saving a checkpoint.
END
}
@ -0,0 +1,33 @@
op {
  graph_op_name:
      "RetrieveTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug"
  visibility: HIDDEN
  out_arg {
    name: "parameters"
    description: <<END
Parameter parameters updated by the frequency estimator optimization algorithm.
END
  }
  out_arg {
    name: "last_hit_step"
    description: <<END
Parameter last_hit_step updated by the frequency estimator optimization
algorithm.
END
  }
  out_arg {
    name: "gradient_accumulators"
    description: <<END
Parameter gradient_accumulators updated by the frequency estimator optimization
algorithm.
END
  }
  summary:
      "Retrieve frequency estimator embedding parameters with debug support."
  description: <<END
An op that retrieves optimization parameters from embedding to host
memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up
the correct embedding table configuration. For example, this op is
used to retrieve updated parameters before saving a checkpoint.
END
}
@ -511,5 +511,51 @@ REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());

REGISTER_OP("LoadTPUEmbeddingFrequencyEstimatorParameters")
    .Input("parameters: float32")
    .Input("last_hit_step: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());

REGISTER_OP("LoadTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug")
    .Input("parameters: float32")
    .Input("last_hit_step: float32")
    .Input("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());

REGISTER_OP("RetrieveTPUEmbeddingFrequencyEstimatorParameters")
    .Output("parameters: float32")
    .Output("last_hit_step: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());

REGISTER_OP("RetrieveTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug")
    .Output("parameters: float32")
    .Output("last_hit_step: float32")
    .Output("gradient_accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());

}  // namespace tpu
}  // namespace tensorflow
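These registered ops surface as generated Python wrappers in `tensorflow.python.tpu.ops.tpu_ops`; their argspecs appear in the golden API files below. As a rough, non-authoritative sketch — assuming a single host (`num_shards=1`) and a width-1 table named `'freq_table'` that has already been configured through the usual TPU embedding setup ops — a manual load/retrieve round trip looks like:

```python
# Sketch only: the table name, shape, and single-host setup are assumptions
# for illustration; the embedding system must already be configured.
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu.ops import tpu_ops

parameters = tf.get_variable('freq_table', shape=[1024, 1], dtype=tf.float32)
last_hit_step = tf.get_variable(
    'freq_table/last_hit_step', shape=[1024, 1], dtype=tf.float32,
    initializer=tf.zeros_initializer())

# Push the host copies of the parameters and the slot variable into TPU HBM.
load_op = tpu_ops.load_tpu_embedding_frequency_estimator_parameters(
    parameters=parameters,
    last_hit_step=last_hit_step,
    table_name='freq_table',
    num_shards=1,
    shard_id=0)

# Pull the updated values back, e.g. before writing a checkpoint.
retrieved_params, retrieved_last_hit_step = (
    tpu_ops.retrieve_tpu_embedding_frequency_estimator_parameters(
        table_name='freq_table', num_shards=1, shard_id=0))
retrieve_op = tf.group(
    tf.assign(parameters, retrieved_params),
    tf.assign(last_hit_step, retrieved_last_hit_step))
```

The mid-level API's handler later in this change builds exactly these ops, one pair per host shard.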
@ -332,6 +332,9 @@ FtrlSlotVariableName = collections.namedtuple('FtrlSlotVariableName',
ProximalYogiSlotVariableNames = collections.namedtuple(
    'ProximalYogiSlotVariableNames', ['v', 'm'])

FrequencyEstimatorSlotVariableName = collections.namedtuple(
    'FrequencyEstimatorSlotVariableName', ['last_hit_step'])

AdamSlotVariables = collections.namedtuple('AdamSlotVariables', ['m', 'v'])

MomentumSlotVariable = collections.namedtuple('MomentumSlotVariable',
@ -352,6 +355,9 @@ FtrlSlotVariable = collections.namedtuple('FtrlSlotVariable',
ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables',
                                                   ['v', 'm'])

FrequencyEstimatorSlotVariables = collections.namedtuple(
    'FrequencyEstimatorSlotVariables', ['last_hit_step'])

VariablesAndOps = collections.namedtuple('VariablesAndOps', [
    'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops',
    'retrieve_ops'
@ -1034,6 +1040,64 @@ class StochasticGradientDescentParameters(_OptimizationParameters):
    )


class FrequencyEstimatorParameters(_OptimizationParameters):
  """Optimization parameters for Frequency Estimator TPU embeddings.

  This is a non-standard optimizer, which returns the estimated frequency of
  lookup for the feature passed to it. It should only be used on a table of
  width 1. The gradient fed back to the TPU embedding should always be zero.
  This can be accomplished by using `tf.stop_gradient` on the feature before
  using it.

  You must use the dynamic learning rate mechanism to set the 'learning rate'
  for this table to be a float32 cast of the global training step counter.

  See `tensorflow/core/protobuf/tpu/optimization_parameters.proto` for more
  details on this optimizer.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=FrequencyEstimatorParameters(
              tau=0.1, max_delta=100.0, outlier_threshold=10.0,
              weight_exponent=1.0),  # illustrative values
          ...))
  ```
  """

  def __init__(self, tau: float, max_delta: float, outlier_threshold: float,
               weight_exponent: float):
    """Optimization parameters for frequency estimator.

    Args:
      tau: Learning rate between (0, 1) that is used to update the array.
      max_delta: Maximum value of delta, the difference between the current
        global step and the last global step at which the row was sampled.
      outlier_threshold: Threshold used to determine whether the current update
        is an outlier.
      weight_exponent: The weight exponent used to transform the estimated
        delta into weights.
    """
    super(FrequencyEstimatorParameters, self).__init__(
        learning_rate=1.0,
        use_gradient_accumulation=True,
        clip_weight_min=None,
        clip_weight_max=None,
        weight_decay_factor=None,
        multiply_weight_decay_factor_by_learning_rate=None,
    )
    self.tau = tau
    self.max_delta = max_delta
    self.outlier_threshold = outlier_threshold
    self.weight_exponent = weight_exponent


DeviceConfig = collections.namedtuple('DeviceConfig',
                                      ['num_hosts', 'num_cores', 'job_name'])

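For intuition, the per-row update that `optimization_parameters.proto` describes can be sketched in plain Python. This is a best-effort reading of the proto comments (the estimate tracks the average number of global steps between lookups of a row, so its reciprocal estimates the lookup frequency); the exact on-device rule, including how `weight_exponent` shapes the returned value, is defined by the proto and the TPU implementation, not this sketch:

```python
# Rough, non-authoritative sketch of one frequency-estimator update for a
# single row that was just looked up. `avg_delta` is the width-1 table value;
# `last_hit_step` is its slot variable.
def frequency_estimator_update(avg_delta, last_hit_step, global_step,
                               tau, max_delta, outlier_threshold):
  # Steps since this row was last looked up, capped at max_delta.
  delta = min(global_step - last_hit_step, max_delta)
  if delta > outlier_threshold * avg_delta:
    # An unusually large gap is treated as an outlier: reset the estimate
    # rather than averaging it in (assumed reading of outlier_threshold).
    avg_delta = delta
  else:
    # Exponential moving average of the inter-lookup gap, rate tau in (0, 1).
    avg_delta = (1.0 - tau) * avg_delta + tau * delta
  # Estimated lookup frequency per global step is roughly 1 / avg_delta.
  return avg_delta, global_step  # new estimate, updated last_hit_step
```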
@ -2559,6 +2623,89 @@ class _RMSPropHandler(_OptimizerHandler):
    return slot_variables, load_ops_fn, retrieve_ops_fn


class _FrequencyEstimatorHandler(_OptimizerHandler):
  """Handles frequency estimator specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.frequency_estimator.SetInParent()
    freq = table_descriptor.optimization_parameters.frequency_estimator
    freq.tau = self._optimization_parameters.tau
    freq.max_delta = self._optimization_parameters.max_delta
    freq.outlier_threshold = self._optimization_parameters.outlier_threshold
    freq.weight_exponent = self._optimization_parameters.weight_exponent

  def get_default_slot_variable_names(self, table):
    return FrequencyEstimatorSlotVariableName(
        '{}/FrequencyEstimator'.format(table))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    if table_config.dimension != 1:
      raise ValueError('FrequencyEstimator tables should only have a dimension '
                       'of 1. Received dimension {}'.format(
                           table_config.dimension))

    last_hit_step_variables = _create_partitioned_variables(
        name=slot_variable_names.last_hit_step,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=init_ops.zeros_initializer(),
    )
    slot_variables = FrequencyEstimatorSlotVariables(last_hit_step_variables)

    def load_ops_fn():
      """Returns the load ops for Frequency Estimator embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable, last_hit_step_variable in (zip(
          range(num_hosts), table_variables, last_hit_step_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_frequency_estimator_parameters(
                  parameters=table_variable,
                  last_hit_step=last_hit_step_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        # Only the first shard's op carries the config proto.
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Frequency Estimator embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable, last_hit_step_variable in (zip(
          range(num_hosts), table_variables, last_hit_step_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_last_hit_step = (
              tpu_ops.retrieve_tpu_embedding_frequency_estimator_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config,
              ))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(last_hit_step_variable,
                               retrieved_last_hit_step))
        # As above, only the first shard's op carries the config proto.
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn

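A minimal sketch of wiring this optimizer into the mid-level API, assuming `TableConfig` in this file exposes `learning_rate_fn` and per-table `optimization_parameters` hooks; the table name and sizes are illustrative. Per the class docstring, the table's 'learning rate' must be a float32 cast of the global step:

```python
# Sketch only, under the assumptions stated above.
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_embedding

freq_params = tpu_embedding.FrequencyEstimatorParameters(
    tau=0.1, max_delta=100.0, outlier_threshold=10.0, weight_exponent=1.0)

table = tpu_embedding.TableConfig(
    vocabulary_size=1024,
    dimension=1,  # FrequencyEstimator tables must have dimension 1.
    # Dynamic learning rate: a float32 cast of the global training step.
    learning_rate_fn=lambda: tf.cast(
        tf.train.get_or_create_global_step(), tf.float32),
    optimization_parameters=freq_params)
```

The feature looked up from such a table should be wrapped in `tf.stop_gradient` so that a zero gradient is fed back, as the class docstring requires.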

class _StochasticGradientDescentHandler(_OptimizerHandler):
  """Handles stochastic gradient descent specific logic."""

@ -2638,6 +2785,8 @@ def _get_optimization_handler(optimization_parameters):
    return _MomentumHandler(optimization_parameters)
  elif isinstance(optimization_parameters, RMSPropParameters):
    return _RMSPropHandler(optimization_parameters)
  elif isinstance(optimization_parameters, FrequencyEstimatorParameters):
    return _FrequencyEstimatorHandler(optimization_parameters)
  raise NotImplementedError()
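A quick illustration of the new dispatch branch, using the module-internal `_get_optimization_handler` with illustrative parameter values:

```python
# Within this module: the new elif branch routes FrequencyEstimatorParameters
# to its handler. Parameter values are illustrative only.
params = FrequencyEstimatorParameters(
    tau=0.1, max_delta=100.0, outlier_threshold=10.0, weight_exponent=1.0)
handler = _get_optimization_handler(params)
assert isinstance(handler, _FrequencyEstimatorHandler)
```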
@ -2160,6 +2160,14 @@ tf_module {
    name: "LoadTPUEmbeddingFTRLParametersGradAccumDebug"
    argspec: "args=[\'parameters\', \'accumulators\', \'linears\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingFrequencyEstimatorParameters"
    argspec: "args=[\'parameters\', \'last_hit_step\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug"
    argspec: "args=[\'parameters\', \'last_hit_step\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingMDLAdagradLightParameters"
    argspec: "args=[\'parameters\', \'accumulators\', \'weights\', \'benefits\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
@ -3748,6 +3756,14 @@ tf_module {
    name: "RetrieveTPUEmbeddingFTRLParametersGradAccumDebug"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingFrequencyEstimatorParameters"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingMDLAdagradLightParameters"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
@ -2160,6 +2160,14 @@ tf_module {
    name: "LoadTPUEmbeddingFTRLParametersGradAccumDebug"
    argspec: "args=[\'parameters\', \'accumulators\', \'linears\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingFrequencyEstimatorParameters"
    argspec: "args=[\'parameters\', \'last_hit_step\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug"
    argspec: "args=[\'parameters\', \'last_hit_step\', \'gradient_accumulators\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingMDLAdagradLightParameters"
    argspec: "args=[\'parameters\', \'accumulators\', \'weights\', \'benefits\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
@ -3748,6 +3756,14 @@ tf_module {
    name: "RetrieveTPUEmbeddingFTRLParametersGradAccumDebug"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingFrequencyEstimatorParameters"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingFrequencyEstimatorParametersGradAccumDebug"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
  }
  member_method {
    name: "RetrieveTPUEmbeddingMDLAdagradLightParameters"
    argspec: "args=[\'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "