Support gradient multiplier for embeddings in TPUEstimator.
PiperOrigin-RevId: 241354959
commit daf85eddac
parent bf955a8662
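For context, a rough usage sketch of the new experimental_gradient_multiplier_fn field (not part of this change): the feature columns, the warm-up schedule, and the import locations are placeholders; only EmbeddingConfigSpec, AdagradParameters, and the new field itself come from the diff below.

    # Illustrative only. Import EmbeddingConfigSpec and AdagradParameters from
    # wherever your TF build exposes them; my_embedding_columns is a placeholder
    # for already-constructed TPU embedding feature columns.
    import tensorflow as tf

    def embedding_grad_multiplier_fn(global_step):
      # Hypothetical schedule: warm embedding gradients up to full strength
      # over the first 1000 training steps.
      step = tf.cast(global_step, tf.float32)
      return tf.minimum(step / 1000.0, 1.0)

    spec = EmbeddingConfigSpec(
        feature_columns=my_embedding_columns,
        optimization_parameters=AdagradParameters(learning_rate=0.1),
        experimental_gradient_multiplier_fn=embedding_grad_multiplier_fn)

    # The spec is then handed to TPUEstimator through its embedding_config_spec
    # argument (assumption based on the surrounding TPUEstimator code).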
@@ -25,11 +25,14 @@ import six
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.feature_column import feature_column as core_fc
 from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.tpu import feature_column as tpu_fc
 from tensorflow.python.tpu import tpu_embedding
 from tensorflow.python.tpu.tpu_embedding import AdagradParameters
 from tensorflow.python.tpu.tpu_embedding import AdamParameters
 from tensorflow.python.tpu.tpu_embedding import StochasticGradientDescentParameters
+from tensorflow.python.training import training

 # pylint: disable=protected-access
 _TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn,
@@ -150,7 +153,8 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns):
 class EmbeddingConfigSpec(
     collections.namedtuple('EmbeddingConfigSpec', [
         'feature_columns', 'optimization_parameters', 'clipping_limit',
-        'pipeline_execution_with_tensor_core'
+        'pipeline_execution_with_tensor_core',
+        'experimental_gradient_multiplier_fn'
     ])):
   """Class to keep track of embedding config specification."""

@@ -158,7 +162,8 @@ class EmbeddingConfigSpec(
               feature_columns,
               optimization_parameters,
               clipping_limit=None,
-              pipeline_execution_with_tensor_core=False):
+              pipeline_execution_with_tensor_core=False,
+              experimental_gradient_multiplier_fn=None):
     """Creates an EmbeddingConfigSpec instance.

     Args:
@@ -172,6 +177,8 @@ class EmbeddingConfigSpec(
         faster, but trained model will be different if step N and step N+1
         involve the same set of embedding IDs. Please see
         `tpu_embedding_configuration.proto` for details.
+      experimental_gradient_multiplier_fn: (Optional) A Fn taking global step as
+        input returning the current multiplier for all embedding gradients.

     Returns:
       An EmbeddingConfigSpec instance.
@@ -208,7 +215,8 @@ class EmbeddingConfigSpec(
         feature_columns=feature_columns,
         optimization_parameters=optimization_parameters,
         clipping_limit=clipping_limit,
-        pipeline_execution_with_tensor_core=pipeline_execution_with_tensor_core)
+        pipeline_execution_with_tensor_core=pipeline_execution_with_tensor_core,
+        experimental_gradient_multiplier_fn=experimental_gradient_multiplier_fn)


 class EmbeddingConfig(object):
@@ -221,6 +229,9 @@ class EmbeddingConfig(object):

   def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size,
                num_hosts, num_cores, run_config):
+    if not embedding_config_spec:
+      raise ValueError('embedding_config_spec cannot be None.')
+
     self._embedding_config_spec = embedding_config_spec
     self._train_batch_size = train_batch_size
     self._eval_batch_size = eval_batch_size
@@ -234,6 +245,15 @@
     self._mode_to_tpu_embedding_dict = {}
     self.dummy_table_variables = None

+    self._grad_multiplier_fn = (
+        embedding_config_spec.experimental_gradient_multiplier_fn)
+
+  def get_grad_multiplier(self):
+    if self._grad_multiplier_fn:
+      return ops.convert_to_tensor(
+          self._grad_multiplier_fn(training.get_global_step()),
+          dtype=dtypes.float32)
+
   def has_embedding_tables(self):
     return bool(self._table_to_config_dict)

@@ -1488,8 +1488,14 @@ class _ModelFnWrapper(object):
             tpu_embedding_gradient.get_gradients_through_dummy_table_variables(
                 tpu_embedding_)
         )
+        grad_multiplier = self._ctx.embedding_config.get_grad_multiplier()
+        if grad_multiplier is not None:
+          scaled_gradients = collections.OrderedDict(
+              (k, v * grad_multiplier) for k, v in six.iteritems(gradients))
+        else:
+          scaled_gradients = gradients
         apply_sparse_grads = [
-            tpu_embedding_.generate_send_gradients_op(gradients)
+            tpu_embedding_.generate_send_gradients_op(scaled_gradients)
         ]

       # We must run train_op to update the variables prior to running the
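For completeness, a minimal standalone sketch of what the new code path computes, assuming a dict of per-table embedding gradients. The table names, values, and step schedule below are invented; only the OrderedDict scaling mirrors the hunk above.

    import collections

    import six
    import tensorflow as tf  # TF 1.x graph-mode style, matching this era

    # Hypothetical per-table gradients, keyed like the dict returned by
    # tpu_embedding_gradient.get_gradients_through_dummy_table_variables().
    gradients = collections.OrderedDict([
        ('table_a', tf.constant([[0.5, -1.0]])),
        ('table_b', tf.constant([[2.0, 4.0]])),
    ])

    global_step = tf.train.get_or_create_global_step()

    # Same shape as EmbeddingConfig.get_grad_multiplier(): a user-supplied fn
    # of the global step, converted to a float32 tensor.
    grad_multiplier = tf.minimum(tf.cast(global_step, tf.float32) / 1000.0, 1.0)

    # Mirrors the new branch in _ModelFnWrapper: every embedding gradient is
    # scaled by the multiplier before generate_send_gradients_op is called.
    scaled_gradients = collections.OrderedDict(
        (k, v * grad_multiplier) for k, v in six.iteritems(gradients))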