Support gradient multiplier for embeddings in TPUEstimator.

PiperOrigin-RevId: 241354959
A. Unique TensorFlower authored 2019-04-01 10:46:20 -07:00; committed by TensorFlower Gardener
parent bf955a8662
commit daf85eddac
2 changed files with 30 additions and 4 deletions


@@ -25,11 +25,14 @@ import six
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.feature_column import feature_column as core_fc
 from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.tpu import feature_column as tpu_fc
 from tensorflow.python.tpu import tpu_embedding
 from tensorflow.python.tpu.tpu_embedding import AdagradParameters
 from tensorflow.python.tpu.tpu_embedding import AdamParameters
 from tensorflow.python.tpu.tpu_embedding import StochasticGradientDescentParameters
+from tensorflow.python.training import training
 # pylint: disable=protected-access
 _TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn,
@@ -150,7 +153,8 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns):
 class EmbeddingConfigSpec(
     collections.namedtuple('EmbeddingConfigSpec', [
         'feature_columns', 'optimization_parameters', 'clipping_limit',
-        'pipeline_execution_with_tensor_core'
+        'pipeline_execution_with_tensor_core',
+        'experimental_gradient_multiplier_fn'
     ])):
   """Class to keep track of embedding config specification."""
@@ -158,7 +162,8 @@ class EmbeddingConfigSpec(
               feature_columns,
               optimization_parameters,
               clipping_limit=None,
-              pipeline_execution_with_tensor_core=False):
+              pipeline_execution_with_tensor_core=False,
+              experimental_gradient_multiplier_fn=None):
     """Creates an EmbeddingConfigSpec instance.

     Args:
@@ -172,6 +177,8 @@ class EmbeddingConfigSpec(
         faster, but the trained model will be different if step N and step N+1
         involve the same set of embedding IDs. Please see
         `tpu_embedding_configuration.proto` for details.
+      experimental_gradient_multiplier_fn: (Optional) A function that takes
+        the global step and returns the multiplier for all embedding gradients.

     Returns:
       An EmbeddingConfigSpec instance.
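
For context, the new field plugs into the `EmbeddingConfigSpec` constructor that users pass to `TPUEstimator`. A minimal usage sketch, assuming the TF 1.14-era public exports (`tf.estimator.tpu.experimental.EmbeddingConfigSpec`, `tf.tpu.experimental.AdagradParameters`) and a `feature_columns` list defined elsewhere; the linear warmup schedule is a hypothetical example, not part of this commit:

import tensorflow as tf

def embedding_grad_multiplier_fn(global_step):
  # Hypothetical schedule: ramp embedding gradients from 0 to 1 over
  # the first 10,000 steps, then leave them unscaled.
  step = tf.cast(global_step, tf.float32)
  return tf.minimum(step / 10000.0, 1.0)

embedding_config_spec = tf.estimator.tpu.experimental.EmbeddingConfigSpec(
    feature_columns=feature_columns,  # assumed to be defined elsewhere
    optimization_parameters=tf.tpu.experimental.AdagradParameters(
        learning_rate=0.1),
    experimental_gradient_multiplier_fn=embedding_grad_multiplier_fn)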
@@ -208,7 +215,8 @@ class EmbeddingConfigSpec(
         feature_columns=feature_columns,
         optimization_parameters=optimization_parameters,
         clipping_limit=clipping_limit,
-        pipeline_execution_with_tensor_core=pipeline_execution_with_tensor_core)
+        pipeline_execution_with_tensor_core=pipeline_execution_with_tensor_core,
+        experimental_gradient_multiplier_fn=experimental_gradient_multiplier_fn)


 class EmbeddingConfig(object):
@@ -221,6 +229,9 @@ class EmbeddingConfig(object):
   def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size,
                num_hosts, num_cores, run_config):
+    if not embedding_config_spec:
+      raise ValueError('embedding_config_spec cannot be None.')
+
     self._embedding_config_spec = embedding_config_spec
     self._train_batch_size = train_batch_size
     self._eval_batch_size = eval_batch_size
@@ -234,6 +245,15 @@ class EmbeddingConfig(object):
     self._mode_to_tpu_embedding_dict = {}
     self.dummy_table_variables = None

+    self._grad_multiplier_fn = (
+        embedding_config_spec.experimental_gradient_multiplier_fn)
+
+  def get_grad_multiplier(self):
+    if self._grad_multiplier_fn:
+      return ops.convert_to_tensor(
+          self._grad_multiplier_fn(training.get_global_step()),
+          dtype=dtypes.float32)
+
   def has_embedding_tables(self):
     return bool(self._table_to_config_dict)
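
Note that `get_grad_multiplier` returns `None` when no multiplier fn is configured; otherwise it evaluates the fn at the current global step and coerces the result to a float32 tensor, so the fn may return a plain Python float or a tensor. A standalone sketch of that contract, with a hypothetical constant-valued fn:

import tensorflow as tf

def constant_multiplier_fn(global_step):  # hypothetical multiplier fn
  del global_step  # the multiplier need not depend on the step
  return 0.5  # a plain Python float is fine; it is converted below

global_step = tf.train.get_or_create_global_step()
multiplier = tf.convert_to_tensor(
    constant_multiplier_fn(global_step), dtype=tf.float32)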


@@ -1488,8 +1488,14 @@ class _ModelFnWrapper(object):
           tpu_embedding_gradient.get_gradients_through_dummy_table_variables(
               tpu_embedding_)
       )
+      grad_multiplier = self._ctx.embedding_config.get_grad_multiplier()
+      if grad_multiplier is not None:
+        scaled_gradients = collections.OrderedDict(
+            (k, v * grad_multiplier) for k, v in six.iteritems(gradients))
+      else:
+        scaled_gradients = gradients
       apply_sparse_grads = [
-          tpu_embedding_.generate_send_gradients_op(gradients)
+          tpu_embedding_.generate_send_gradients_op(scaled_gradients)
       ]
       # We must run train_op to update the variables prior to running the
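
The scaling above multiplies every gradient in the (feature name -> gradient tensor) `OrderedDict` by one scalar, preserving the ordering before the gradients are handed to the send op. A self-contained sketch of the same pattern, with made-up feature names and shapes:

import collections

import six
import tensorflow as tf

gradients = collections.OrderedDict([
    ('video_id', tf.ones([8, 64])),    # hypothetical feature -> gradient
    ('query_token', tf.ones([8, 32])),
])
grad_multiplier = tf.constant(0.5, dtype=tf.float32)
scaled_gradients = collections.OrderedDict(
    (k, v * grad_multiplier) for k, v in six.iteritems(gradients))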