# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Companion classes for mid level API for TPU Embeddings in TF2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import abc
import math
import typing
from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union

from absl import logging
import six

from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.types import core
from tensorflow.python.util.tf_export import tf_export


TableVariable = TypeVar("TableVariable", sharded_variable.ShardedVariable,
                        tf_variables.Variable)
SlotVarCreationFnType = Callable[
    [TableVariable, List[Text], List[init_ops_v2.Initializer]],
    Dict[Text, TableVariable]]
ClipValueType = Union[Tuple[float, float], float]

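# Illustrative note on `ClipValueType` (a sketch, not part of this module's
# API surface): a plain float `v` passed as `clipvalue` is treated as the
# symmetric range `(-v, v)`, while a tuple gives explicit `(min, max)` bounds
# and either entry may be None to leave that side unclipped. For example:
#
#   tf.tpu.experimental.embedding.SGD(0.1, clipvalue=2.0)          # clips to [-2, 2]
#   tf.tpu.experimental.embedding.SGD(0.1, clipvalue=(None, 1.0))  # clips the max only

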
@six.add_metaclass(abc.ABCMeta)
class _Optimizer(object):
  """Base class for all optimizers, with common parameters."""

  def __init__(
      self,
      learning_rate: Union[float, Callable[[], float]],
      use_gradient_accumulation: bool,
      clip_weight_min: Optional[float],
      clip_weight_max: Optional[float],
      weight_decay_factor: Optional[float],
      multiply_weight_decay_factor_by_learning_rate: bool,
      clipvalue: Optional[ClipValueType] = None,
      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None):
    self.learning_rate = learning_rate
    self.use_gradient_accumulation = use_gradient_accumulation
    self.clip_weight_min = clip_weight_min
    self.clip_weight_max = clip_weight_max
    if not use_gradient_accumulation and clipvalue is not None:
      raise ValueError("Received non-None gradient clipping limit {} but "
                       "use_gradient_accumulation is not set to True.".format(
                           clipvalue))
    if clipvalue is None:
      clipvalue = (None, None)
    elif not isinstance(clipvalue, tuple):
      clipvalue = (-1. * clipvalue, clipvalue)
    self.clip_gradient_min, self.clip_gradient_max = clipvalue

    self.weight_decay_factor = weight_decay_factor
    self.multiply_weight_decay_factor_by_learning_rate = (
        multiply_weight_decay_factor_by_learning_rate)

    if (slot_variable_creation_fn is not None and
        not callable(slot_variable_creation_fn)):
      raise ValueError("slot_variable_creation_fn must be either None or a "
                       "callable.")
    self.slot_variable_creation_fn = slot_variable_creation_fn

  @abc.abstractmethod
  def _slot_names(self) -> List[Text]:
    """Returns the names of all the slot variables.

    This does not include the 'parameters' variable and these names must match
    the names of the slot variables as used in the corresponding
    `tpu_ops.load_tpu_embedding_*` ops.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
    """Returns initializers for slot variables.

    This returns a parallel list to self._slot_names().
    """
    raise NotImplementedError

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    """Sets the optimizer fields in the OptimizationParameters."""
    if self.use_gradient_accumulation:
      parameters.gradient_accumulation_status = (
          optimization_parameters_pb2.GradientAccumulationStatus.ENABLED)
    else:
      parameters.gradient_accumulation_status = (
          optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)

    if self.clip_weight_min is not None:
      parameters.clipping_limits.lower.value = self.clip_weight_min

    if self.clip_weight_max is not None:
      parameters.clipping_limits.upper.value = self.clip_weight_max

    if self.clip_gradient_min is not None:
      parameters.gradient_clipping_limits.lower.value = self.clip_gradient_min

    if self.clip_gradient_max is not None:
      parameters.gradient_clipping_limits.upper.value = self.clip_gradient_max

    if self.weight_decay_factor:
      parameters.weight_decay_factor = self.weight_decay_factor
      if self.multiply_weight_decay_factor_by_learning_rate:
        parameters.multiply_weight_decay_factor_by_learning_rate = True

  @abc.abstractmethod
  def _load(self) -> Callable[..., ops.Operation]:
    """Returns the load function for the optimizer."""
    raise NotImplementedError

  @abc.abstractmethod
  def _retrieve(self) -> Callable[..., core.Tensor]:
    """Returns the retrieve function for the optimizer."""
    raise NotImplementedError

  def _create_slots(
      self, table: "TableConfig",
      variable_creator: Callable[[Text, init_ops_v2.Initializer],
                                 tf_variables.Variable]
  ) -> Dict[Text, tf_variables.Variable]:
    """Creates slot variables for table.

    Args:
      table: The table variable to create slots for.
      variable_creator: A function which creates variables. Takes parameters
        'name', 'initializer'.

    Returns:
      A dict of variables, keyed by self._slot_names().
    """
    if self.slot_variable_creation_fn is not None:
      return self.slot_variable_creation_fn(table, self._slot_names(),
                                            self._slot_initializers())
    else:
      slots = {}
      for slot, initializer in zip(self._slot_names(),
                                   self._slot_initializers()):
        slots[slot] = variable_creator(slot, initializer)
      return slots


@tf_export("tpu.experimental.embedding.SGD")
|
|
class SGD(_Optimizer):
|
|
"""Optimization parameters for stochastic gradient descent for TPU embeddings.
|
|
|
|
Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
|
|
argument to set the global optimizer and its parameters:
|
|
|
|
```
|
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
|
...
|
|
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
|
|
```
|
|
|
|
This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
|
|
optimizer parameter to set a table specific optimizer. This will override the
|
|
optimizer and parameters for global embedding optimizer defined above:
|
|
|
|
```
|
|
table_one = tf.tpu.experimental.embedding.TableConfig(
|
|
vocabulary_size=...,
|
|
dim=...,
|
|
optimizer=tf.tpu.experimental.embedding.SGD(0.2))
|
|
table_two = tf.tpu.experimental.embedding.TableConfig(
|
|
vocabulary_size=...,
|
|
dim=...)
|
|
|
|
feature_config = (
|
|
tf.tpu.experimental.embedding.FeatureConfig(
|
|
table=table_one),
|
|
tf.tpu.experimental.embedding.FeatureConfig(
|
|
table=table_two))
|
|
|
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
|
feature_config=feature_config,
|
|
batch_size=...
|
|
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
|
|
```
|
|
|
|
In the above example, the first feature will be looked up in a table that has
|
|
a learning rate of 0.2 while the second feature will be looked up in a table
|
|
that has a learning rate of 0.1.
|
|
|
|
See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
|
|
complete description of these parameters and their impacts on the optimizer
|
|
algorithm.
|
|
"""
|
|
|
|
def __init__(self,
|
|
learning_rate: Union[float, Callable[[], float]] = 0.01,
|
|
clip_weight_min: Optional[float] = None,
|
|
clip_weight_max: Optional[float] = None,
|
|
weight_decay_factor: Optional[float] = None,
|
|
multiply_weight_decay_factor_by_learning_rate: bool = None,
|
|
clipvalue: Optional[ClipValueType] = None):
|
|
"""Optimization parameters for stochastic gradient descent.
|
|
|
|
Args:
|
|
learning_rate: The learning rate. It should be a floating point value or a
|
|
callable taking no arguments for a dynamic learning rate.
|
|
clip_weight_min: the minimum value to clip by; None means -infinity.
|
|
clip_weight_max: the maximum value to clip by; None means +infinity.
|
|
weight_decay_factor: amount of weight decay to apply; None means that the
|
|
weights are not decayed. Weights are decayed by multiplying the weight
|
|
by this factor each step.
|
|
multiply_weight_decay_factor_by_learning_rate: if true,
|
|
`weight_decay_factor` is multiplied by the current learning rate.
|
|
clipvalue: Controls clipping of the gradient. Set to either a single
|
|
positive scalar value to get clipping or a tiple of scalar values (min,
|
|
max) to set a separate maximum or minimum. If one of the two entries is
|
|
None, then there will be no clipping that direction. Note if this is
|
|
set, you may see a decrease in performance as gradient accumulation
|
|
will be enabled (it is normally off for SGD as it has no affect on
|
|
accuracy). See
|
|
'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for more
|
|
information on gradient accumulation and its impact on tpu embeddings.
|
|
"""
|
|
use_gradient_accumulation = clipvalue is not None
|
|
|
|
super(SGD, self).__init__(
|
|
learning_rate, use_gradient_accumulation, clip_weight_min,
|
|
clip_weight_max, weight_decay_factor,
|
|
multiply_weight_decay_factor_by_learning_rate, clipvalue)
|
|
|
|
def _slot_names(self) -> List[Text]:
|
|
return []
|
|
|
|
def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
|
|
return []
|
|
|
|
def _set_optimization_parameters(
|
|
self, parameters: optimization_parameters_pb2.OptimizationParameters):
|
|
super(SGD, self)._set_optimization_parameters(parameters)
|
|
parameters.stochastic_gradient_descent.SetInParent()
|
|
|
|
def _load(self) -> Callable[..., ops.Operation]:
|
|
return tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters
|
|
|
|
def _retrieve(self) -> Callable[..., core.Tensor]:
|
|
return tpu_ops.retrieve_tpu_embedding_stochastic_gradient_descent_parameters
|
|
|
|
|
|
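# Illustrative sketch (not part of this module): the optimizers here accept
# `learning_rate` either as a float or as a zero-argument callable, which is
# how a dynamic learning rate is supplied. The function name `current_lr` is
# hypothetical, and when (and how often) the callable is evaluated is up to
# the mid level API.
#
#   def current_lr():
#     return 0.05
#
#   optimizer = tf.tpu.experimental.embedding.SGD(learning_rate=current_lr)

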
@tf_export("tpu.experimental.embedding.Adagrad")
|
|
class Adagrad(_Optimizer):
|
|
"""Optimization parameters for Adagrad with TPU embeddings.
|
|
|
|
Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
|
|
argument to set the global optimizer and its parameters:
|
|
|
|
```python
|
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
|
...
|
|
optimizer=tf.tpu.experimental.embedding.Adagrad(0.1))
|
|
```
|
|
|
|
This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
|
|
optimizer parameter to set a table specific optimizer. This will override the
|
|
optimizer and parameters for global embedding optimizer defined above:
|
|
|
|
```python
|
|
table_one = tf.tpu.experimental.embedding.TableConfig(
|
|
vocabulary_size=...,
|
|
dim=...,
|
|
optimizer=tf.tpu.experimental.embedding.Adagrad(0.2))
|
|
table_two = tf.tpu.experimental.embedding.TableConfig(
|
|
vocabulary_size=...,
|
|
dim=...)
|
|
|
|
feature_config = (
|
|
tf.tpu.experimental.embedding.FeatureConfig(
|
|
table=table_one),
|
|
tf.tpu.experimental.embedding.FeatureConfig(
|
|
table=table_two))
|
|
|
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
|
feature_config=feature_config,
|
|
batch_size=...
|
|
optimizer=tf.tpu.experimental.embedding.Adagrad(0.1))
|
|
```
|
|
|
|
In the above example, the first feature will be looked up in a table that has
|
|
a learning rate of 0.2 while the second feature will be looked up in a table
|
|
that has a learning rate of 0.1.
|
|
|
|
See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
|
|
complete description of these parameters and their impacts on the optimizer
|
|
algorithm.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
learning_rate: float = 0.001,
|
|
initial_accumulator_value: float = 0.1,
|
|
use_gradient_accumulation: bool = True,
|
|
clip_weight_min: Optional[float] = None,
|
|
clip_weight_max: Optional[float] = None,
|
|
weight_decay_factor: Optional[float] = None,
|
|
multiply_weight_decay_factor_by_learning_rate: bool = None,
|
|
slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
|
|
clipvalue: Optional[ClipValueType] = None):
|
|
"""Optimization parameters for Adagrad.
|
|
|
|
Args:
|
|
learning_rate: The learning rate. It should be a floating point value or a
|
|
callable taking no arguments for a dynamic learning rate.
|
|
initial_accumulator_value: initial accumulator for Adagrad.
|
|
use_gradient_accumulation: setting this to `False` makes embedding
|
|
gradients calculation less accurate but faster.
|
|
clip_weight_min: the minimum value to clip by; None means -infinity.
|
|
clip_weight_max: the maximum value to clip by; None means +infinity.
|
|
weight_decay_factor: amount of weight decay to apply; None means that the
|
|
weights are not decayed.
|
|
multiply_weight_decay_factor_by_learning_rate: if true,
|
|
`weight_decay_factor` is multiplied by the current learning rate.
|
|
slot_variable_creation_fn: If you wish do directly control the creation of
|
|
the slot variables, set this to a callable taking three parameters: a
|
|
table variable, a list of slot names to create for it, and a list of
|
|
initializers. This function should return a dict with the slot names
|
|
as keys and the created variables as values with types matching the
|
|
table variable. When set to None (the default), uses the built-in
|
|
variable creation.
|
|
clipvalue: Controls clipping of the gradient. Set to either a single
|
|
positive scalar value to get clipping or a tuple of scalar values (min,
|
|
max) to set a separate maximum or minimum. If one of the two entries is
|
|
None, then there will be no clipping that direction.
|
|
"""
|
|
super(Adagrad, self).__init__(
|
|
learning_rate, use_gradient_accumulation, clip_weight_min,
|
|
clip_weight_max, weight_decay_factor,
|
|
multiply_weight_decay_factor_by_learning_rate, clipvalue,
|
|
slot_variable_creation_fn)
|
|
if initial_accumulator_value <= 0:
|
|
raise ValueError("Adagrad initial_accumulator_value must be positive")
|
|
self.initial_accumulator_value = initial_accumulator_value
|
|
|
|
def _slot_names(self) -> List[Text]:
|
|
return ["accumulators"]
|
|
|
|
def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
|
|
return [init_ops_v2.Constant(self.initial_accumulator_value)]
|
|
|
|
def _set_optimization_parameters(
|
|
self, parameters: optimization_parameters_pb2.OptimizationParameters):
|
|
super(Adagrad, self)._set_optimization_parameters(parameters)
|
|
parameters.adagrad.SetInParent()
|
|
|
|
def _load(self) -> Callable[..., ops.Operation]:
|
|
return tpu_ops.load_tpu_embedding_adagrad_parameters
|
|
|
|
def _retrieve(self) -> Callable[..., core.Tensor]:
|
|
return tpu_ops.retrieve_tpu_embedding_adagrad_parameters
|
|
|
|
|
|
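# Illustrative sketch (not part of this module): per the documented contract
# above, a custom `slot_variable_creation_fn` takes the table variable, the
# slot names, and a parallel list of initializers, and returns a dict of
# created variables keyed by slot name. The helper name `create_slots` and the
# use of `tf.Variable` are assumptions for illustration; the real callable
# only needs to honor that contract.
#
#   def create_slots(table_variable, slot_names, slot_initializers):
#     slots = {}
#     for name, initializer in zip(slot_names, slot_initializers):
#       slots[name] = tf.Variable(
#           initializer(shape=table_variable.shape, dtype=tf.float32),
#           name=name, trainable=False)
#     return slots
#
#   optimizer = tf.tpu.experimental.embedding.Adagrad(
#       learning_rate=0.1, slot_variable_creation_fn=create_slots)

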
@tf_export("tpu.experimental.embedding.Adam")
|
|
class Adam(_Optimizer):
|
|
"""Optimization parameters for Adam with TPU embeddings.
|
|
|
|
Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
|
|
argument to set the global optimizer and its parameters:
|
|
|
|
NOTE: By default this optimizer is lazy, i.e. it will not apply the gradient
|
|
update of zero to rows that were not looked up. You can change this behavior
|
|
by setting `lazy_adam` to `False`.
|
|
|
|
```python
|
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
|
...
|
|
optimizer=tf.tpu.experimental.embedding.Adam(0.1))
|
|
```
|
|
|
|
This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
|
|
optimizer parameter to set a table specific optimizer. This will override the
|
|
optimizer and parameters for global embedding optimizer defined above:
|
|
|
|
```python
|
|
table_one = tf.tpu.experimental.embedding.TableConfig(
|
|
vocabulary_size=...,
|
|
dim=...,
|
|
optimizer=tf.tpu.experimental.embedding.Adam(0.2))
|
|
table_two = tf.tpu.experimental.embedding.TableConfig(
|
|
vocabulary_size=...,
|
|
dim=...)
|
|
|
|
feature_config = (
|
|
tf.tpu.experimental.embedding.FeatureConfig(
|
|
table=table_one),
|
|
tf.tpu.experimental.embedding.FeatureConfig(
|
|
table=table_two))
|
|
|
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
|
feature_config=feature_config,
|
|
batch_size=...
|
|
optimizer=tf.tpu.experimental.embedding.Adam(0.1))
|
|
```
|
|
|
|
In the above example, the first feature will be looked up in a table that has
|
|
a learning rate of 0.2 while the second feature will be looked up in a table
|
|
that has a learning rate of 0.1.
|
|
|
|
See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
|
|
complete description of these parameters and their impacts on the optimizer
|
|
algorithm.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
learning_rate: Union[float, Callable[[], float]] = 0.001,
|
|
beta_1: float = 0.9,
|
|
beta_2: float = 0.999,
|
|
epsilon: float = 1e-07,
|
|
lazy_adam: bool = True,
|
|
sum_inside_sqrt: bool = True,
|
|
use_gradient_accumulation: bool = True,
|
|
clip_weight_min: Optional[float] = None,
|
|
clip_weight_max: Optional[float] = None,
|
|
weight_decay_factor: Optional[float] = None,
|
|
multiply_weight_decay_factor_by_learning_rate: bool = None,
|
|
slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
|
|
clipvalue: Optional[ClipValueType] = None):
|
|
"""Optimization parameters for Adam.
|
|
|
|
See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
|
|
complete description of these parameters and their impacts on the optimizer
|
|
algorithm.
|
|
|
|
Args:
|
|
learning_rate: The learning rate. It should be a floating point value or a
|
|
callable taking no arguments for a dynamic learning rate.
|
|
beta_1: A float value. The exponential decay rate for the 1st moment
|
|
estimates.
|
|
beta_2: A float value. The exponential decay rate for the 2nd moment
|
|
estimates.
|
|
epsilon: A small constant for numerical stability.
|
|
lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster.
|
|
sum_inside_sqrt: When this is true, the Adam update formula is changed
|
|
from `m / (sqrt(v) + epsilon)` to `m / sqrt(v + epsilon**2)`. This
|
|
option improves the performance of TPU training and is not expected to
|
|
harm model quality.
|
|
use_gradient_accumulation: Setting this to `False` makes embedding
|
|
gradients calculation less accurate but faster.
|
|
clip_weight_min: the minimum value to clip by; None means -infinity.
|
|
clip_weight_max: the maximum value to clip by; None means +infinity.
|
|
weight_decay_factor: amount of weight decay to apply; None means that the
|
|
weights are not decayed.
|
|
multiply_weight_decay_factor_by_learning_rate: if true,
|
|
`weight_decay_factor` is multiplied by the current learning rate.
|
|
slot_variable_creation_fn: If you wish do directly control the creation of
|
|
the slot variables, set this to a callable taking three parameters: a
|
|
table variable, a list of slot names to create for it, and a list of
|
|
initializers. This function should return a dict with the slot names
|
|
as keys and the created variables as values with types matching the
|
|
table variable. When set to None (the default), uses the built-in
|
|
variable creation.
|
|
clipvalue: Controls clipping of the gradient. Set to either a single
|
|
positive scalar value to get clipping or a tiple of scalar values (min,
|
|
max) to set a separate maximum or minimum. If one of the two entries is
|
|
None, then there will be no clipping that direction.
|
|
"""
|
|
super(Adam, self).__init__(
|
|
learning_rate, use_gradient_accumulation, clip_weight_min,
|
|
clip_weight_max, weight_decay_factor,
|
|
multiply_weight_decay_factor_by_learning_rate, clipvalue,
|
|
slot_variable_creation_fn)
|
|
if beta_1 < 0. or beta_1 >= 1.:
|
|
raise ValueError("beta1 must be in the range [0, 1), but received {}."
|
|
.format(beta_1))
|
|
if beta_2 < 0. or beta_2 >= 1.:
|
|
raise ValueError("beta2 must be in the range [0, 1), but received {}."
|
|
.format(beta_2))
|
|
if epsilon <= 0.:
|
|
raise ValueError("epsilon must be positive; got {}.".format(epsilon))
|
|
if not use_gradient_accumulation and not lazy_adam:
|
|
raise ValueError(
|
|
"When disabling Lazy Adam, gradient accumulation must be used.")
|
|
|
|
self.beta_1 = beta_1
|
|
self.beta_2 = beta_2
|
|
self.epsilon = epsilon
|
|
self.lazy_adam = lazy_adam
|
|
self.sum_inside_sqrt = sum_inside_sqrt
|
|
|
|
def _slot_names(self) -> List[Text]:
|
|
return ["momenta", "velocities"]
|
|
|
|
def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
|
|
return [init_ops_v2.Constant(), init_ops_v2.Constant()]
|
|
|
|
def _set_optimization_parameters(
|
|
self, parameters: optimization_parameters_pb2.OptimizationParameters):
|
|
super(Adam, self)._set_optimization_parameters(parameters)
|
|
parameters.adam.beta1 = self.beta_1
|
|
parameters.adam.beta2 = self.beta_2
|
|
parameters.adam.epsilon = self.epsilon
|
|
parameters.adam.use_non_lazy_adam = not self.lazy_adam
|
|
parameters.adam.use_sum_inside_sqrt = self.sum_inside_sqrt
|
|
|
|
def _load(self) -> Callable[..., ops.Operation]:
|
|
return tpu_ops.load_tpu_embedding_adam_parameters
|
|
|
|
def _retrieve(self) -> Callable[..., core.Tensor]:
|
|
return tpu_ops.retrieve_tpu_embedding_adam_parameters
|
|
|
|
|
|
@tf_export("tpu.experimental.embedding.TableConfig")
|
|
class TableConfig(object):
|
|
"""Configuration data for one embedding table.
|
|
|
|
This class holds the configuration data for a single embedding table. It is
|
|
used as the `table` parameter of a
|
|
`tf.tpu.experimental.embedding.FeatureConfig`. Multiple
|
|
`tf.tpu.experimental.embedding.FeatureConfig` objects can use the same
|
|
`tf.tpu.experimental.embedding.TableConfig` object. In this case a shared
|
|
table will be created for those feature lookups.
|
|
|
|
```python
|
|
table_config_one = tf.tpu.experimental.embedding.TableConfig(
|
|
vocabulary_size=...,
|
|
dim=...)
|
|
table_config_two = tf.tpu.experimental.embedding.TableConfig(
|
|
vocabulary_size=...,
|
|
dim=...)
|
|
feature_config = {
|
|
'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
|
|
table=table_config_one),
|
|
'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
|
|
table=table_config_one),
|
|
'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
|
|
table=table_config_two)}
|
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
|
feature_config=feature_config,
|
|
batch_size=...
|
|
optimizer=tf.tpu.experimental.embedding.Adam(0.1))
|
|
```
|
|
|
|
The above configuration has 2 tables, and three features. The first two
|
|
features will be looked up in the first table and the third feature will be
|
|
looked up in the second table.
|
|
|
|
"""
|
|
|
|
def __init__(self,
|
|
vocabulary_size: int,
|
|
dim: int,
|
|
initializer: Optional[Callable[[Any], None]],
|
|
optimizer: Optional[_Optimizer] = None,
|
|
combiner: Text = "mean",
|
|
name: Optional[Text] = None):
|
|
"""Embedding table configuration.
|
|
|
|
Args:
|
|
vocabulary_size: Size of the table's vocabulary (number of rows).
|
|
dim: The embedding dimension (width) of the table.
|
|
initializer: A callable initializer taking one parameter, the shape of the
|
|
variable that will be initialized. Will be called once per task, to
|
|
initialize that task's shard of the embedding table. If not specified,
|
|
defaults to `truncated_normal_initializer` with mean `0.0` and standard
|
|
deviation `1/sqrt(dim)`.
|
|
optimizer: An optional instance of an optimizer parameters class, instance
|
|
of one of `tf.tpu.experimental.embedding.SGD`,
|
|
`tf.tpu.experimental.embedding.Adagrad` or
|
|
`tf.tpu.experimental.embedding.Adam`. It set will override the global
|
|
optimizer passed to `tf.tpu.experimental.embedding.TPUEmbedding`.
|
|
combiner: A string specifying how to reduce if there are multiple entries
|
|
in a single row. Currently 'mean', 'sqrtn', 'sum' are supported, with
|
|
'mean' the default. 'sqrtn' often achieves good accuracy, in particular
|
|
with bag-of-words columns. For more information, see
|
|
`tf.nn.embedding_lookup_sparse`.
|
|
name: An optional string used to name the table. Useful for debugging.
|
|
|
|
Returns:
|
|
`TableConfig`.
|
|
|
|
Raises:
|
|
ValueError: if `vocabulary_size` is not a positive integer.
|
|
ValueError: if `dim` is not a positive integer.
|
|
ValueError: if `initializer` is specified and is not callable.
|
|
ValueError: if `combiner` is not supported.
|
|
"""
|
|
if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
|
|
raise ValueError("Invalid vocabulary_size {}.".format(vocabulary_size))
|
|
|
|
if not isinstance(dim, int) or dim < 1:
|
|
raise ValueError("Invalid dim {}.".format(dim))
|
|
|
|
if (initializer is not None) and (not callable(initializer)):
|
|
raise ValueError("initializer must be callable if specified.")
|
|
if initializer is None:
|
|
initializer = init_ops_v2.TruncatedNormal(mean=0.0,
|
|
stddev=1/math.sqrt(dim))
|
|
|
|
if combiner not in ("mean", "sum", "sqrtn"):
|
|
raise ValueError("Invalid combiner {}".format(combiner))
|
|
|
|
self.vocabulary_size = vocabulary_size
|
|
self.dim = dim
|
|
self.initializer = initializer
|
|
self.optimizer = optimizer
|
|
self.combiner = combiner
|
|
self.name = name
|
|
|
|
def __repr__(self):
|
|
# If using the default initializer, just print "None" for clarity.
|
|
initializer = self.initializer
|
|
|
|
if isinstance(initializer, init_ops_v2.TruncatedNormal):
|
|
# PY2 type checking can't infer type of initializer even after if.
|
|
initializer = typing.cast(init_ops_v2.TruncatedNormal, initializer)
|
|
if (initializer.mean == 0.0
|
|
and math.isclose(initializer.stddev, 1/math.sqrt(self.dim))): # pytype: disable=module-attr (math.isclose not in PY2)
|
|
initializer = None
|
|
|
|
return (
|
|
"TableConfig(vocabulary_size={vocabulary_size!r}, dim={dim!r}, "
|
|
"initializer={initializer!r}, optimizer={optimizer!r}, "
|
|
"combiner={combiner!r}, name={name!r})".format(
|
|
vocabulary_size=self.vocabulary_size,
|
|
dim=self.dim,
|
|
initializer=initializer,
|
|
optimizer=self.optimizer,
|
|
combiner=self.combiner,
|
|
name=self.name,)
|
|
)
|
|
|
|
|
|
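# Illustrative sketch (not part of this module): a custom `initializer` for
# `TableConfig` is a callable invoked with the shape of the task's shard of
# the table (the default initializer also accepts a dtype, so a keyword
# `dtype` argument is tolerated here). The name `scaled_uniform_init` and the
# scale 0.05 are arbitrary choices for illustration.
#
#   def scaled_uniform_init(shape, dtype=None):
#     return tf.random.uniform(shape, minval=-0.05, maxval=0.05,
#                              dtype=dtype or tf.float32)
#
#   table = tf.tpu.experimental.embedding.TableConfig(
#       vocabulary_size=10000,
#       dim=64,
#       initializer=scaled_uniform_init)

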
@tf_export("tpu.experimental.embedding.FeatureConfig")
|
|
class FeatureConfig(object):
|
|
"""Configuration data for one embedding feature.
|
|
|
|
This class holds the configuration data for a single embedding feature. The
|
|
main use is to assign features to `tf.tpu.experimental.embedding.TableConfig`s
|
|
via the table parameter:
|
|
|
|
```python
|
|
table_config_one = tf.tpu.experimental.embedding.TableConfig(
|
|
vocabulary_size=...,
|
|
dim=...)
|
|
table_config_two = tf.tpu.experimental.embedding.TableConfig(
|
|
vocabulary_size=...,
|
|
dim=...)
|
|
feature_config = {
|
|
'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
|
|
table=table_config_one),
|
|
'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
|
|
table=table_config_one),
|
|
'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
|
|
table=table_config_two)}
|
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
|
feature_config=feature_config,
|
|
batch_size=...
|
|
optimizer=tf.tpu.experimental.embedding.Adam(0.1))
|
|
```
|
|
|
|
The above configuration has 2 tables, and three features. The first two
|
|
features will be looked up in the first table and the third feature will be
|
|
looked up in the second table.
|
|
|
|
When feeding features into `embedding.enqueue` they can be `tf.Tensor`s,
|
|
`tf.SparseTensor`s or `tf.RaggedTensor`s. When the argument
|
|
`max_sequence_length` is 0, the default, you should expect a output of
|
|
`embedding.dequeue` for this feature of shape `(batch_size, dim)`. If
|
|
`max_sequence_length` is greater than 0, the feature is embedded as a sequence
|
|
and padded up to the given length. The shape of the output for this feature
|
|
will be `(batch_size, max_sequence_length, dim)`.
|
|
"""
|
|
|
|
def __init__(self,
|
|
table: TableConfig,
|
|
max_sequence_length: int = 0,
|
|
name: Optional[Text] = None):
|
|
"""Feature configuration.
|
|
|
|
Args:
|
|
table: An instance of `tf.tpu.experimental.embedding.TableConfig`,
|
|
describing the table in which this feature should be looked up.
|
|
max_sequence_length: If positive, the feature is a sequence feature with
|
|
the corresponding maximum sequence length. If the sequence is longer
|
|
than this, it will be truncated. If 0, the feature is not a sequence
|
|
feature.
|
|
name: An optional name for the feature, useful for debugging.
|
|
|
|
Returns:
|
|
`FeatureConfig`.
|
|
|
|
Raises:
|
|
ValueError: if `table` is not an instance of
|
|
`tf.tpu.experimental.embedding.TableConfig`.
|
|
ValueError: if `max_sequence_length` not an integer or is negative.
|
|
"""
|
|
if not isinstance(table, TableConfig):
|
|
raise ValueError("table is type {}, expected "
|
|
"`tf.tpu.experimental.embedding.TableConfig`".format(
|
|
type(table)))
|
|
|
|
if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
|
|
raise ValueError("Invalid max_sequence_length {}.".format(
|
|
max_sequence_length))
|
|
|
|
self.table = table
|
|
self.max_sequence_length = max_sequence_length
|
|
self.name = name
|
|
|
|
def __repr__(self):
|
|
return (
|
|
"FeatureConfig(table={table!r}, "
|
|
"max_sequence_length={max_sequence_length!r}, name={name!r})"
|
|
.format(
|
|
table=self.table,
|
|
max_sequence_length=self.max_sequence_length,
|
|
name=self.name)
|
|
)
|
|
|
|
|
|
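# Illustrative sketch (not part of this module): a sequence feature simply
# sets `max_sequence_length` on its `FeatureConfig`. The names and sizes below
# are arbitrary; per the docstring above, the dequeued activation for such a
# feature has shape `(batch_size, max_sequence_length, dim)` instead of
# `(batch_size, dim)`.
#
#   watch_history_table = tf.tpu.experimental.embedding.TableConfig(
#       vocabulary_size=100000, dim=32)
#   watch_history_feature = tf.tpu.experimental.embedding.FeatureConfig(
#       table=watch_history_table,
#       max_sequence_length=50,
#       name='watch_history')

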
def log_tpu_embedding_configuration(
    config: tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration) -> None:
  """Logs a TPUEmbeddingConfiguration proto across multiple statements.

  Args:
    config: TPUEmbeddingConfiguration proto to log. Necessary because
      logging.info has a maximum length per log statement, which particularly
      large configs can exceed.
  """
  logging.info("Beginning log of TPUEmbeddingConfiguration.")
  for line in str(config).splitlines():
    logging.info(line)
  logging.info("Done with log of TPUEmbeddingConfiguration.")
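

# Illustrative sketch (not part of this module): `log_tpu_embedding_configuration`
# expects an already built `TPUEmbeddingConfiguration` proto. Here
# `serialized_config` is an assumed bytes object holding a serialized
# configuration; building one from scratch is out of scope.
#
#   config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
#   config.ParseFromString(serialized_config)
#   log_tpu_embedding_configuration(config)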