Fix gradient computation bug in TPU embedding mid-level API.

PiperOrigin-RevId: 347744033
Change-Id: I685274217865fca9d5aa0f34bbbce618dcac5f13
A. Unique TensorFlower 2020-12-15 20:09:11 -08:00 committed by TensorFlower Gardener
parent 2e61843949
commit 0f8a32c6cc
2 changed files with 36 additions and 45 deletions

View File

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =============================================================================
 """Operations for TPUs."""
 from __future__ import absolute_import
@@ -56,8 +57,8 @@ def all_to_all(x,
     split_count: The number of splits, this number must equal to the sub-group
       size(group_assignment.get_shape()[1])
     group_assignment: Optional 2d int32 lists with shape [num_groups,
-      num_replicas_per_group]. `group_assignment[i]` represents the replica ids
-      in the ith subgroup.
+      num_replicas_per_group]. `group_assignment[i]` represents the replica
+      ids in the ith subgroup.
     name: Optional op name.

   Returns:
@@ -96,8 +97,8 @@ def cross_replica_sum(x, group_assignment=None, name=None):
   Args:
     x: The local tensor to the sum.
     group_assignment: Optional 2d int32 lists with shape [num_groups,
-      num_replicas_per_group]. `group_assignment[i]` represents the replica ids
-      in the ith subgroup.
+      num_replicas_per_group]. `group_assignment[i]` represents the replica
+      ids in the ith subgroup.
     name: Optional op name.

   Returns:
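As an aside for readers of this hunk, the group_assignment semantics described in the docstring can be illustrated with a small sketch. The replica grouping below is invented for the example and uses the public tf.compat.v1.tpu.cross_replica_sum wrapper; it is not code from this change.

import tensorflow.compat.v1 as tf

# Hypothetical grouping: 4 replicas split into two independent reduction groups.
# group_assignment[i] lists the replica ids in the ith subgroup, so replicas
# 0 and 1 sum with each other, and replicas 2 and 3 sum with each other.
group_assignment = [[0, 1], [2, 3]]

def replica_step(x):
  # Must be called inside a TPU replica context; sums x only across the
  # replicas that share the caller's subgroup.
  return tf.tpu.cross_replica_sum(x, group_assignment=group_assignment)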
@@ -167,8 +168,8 @@ def _embedding_activations_grad(activations_op, grad_wrt_activations):
   g = ops.get_default_graph()
   table_id = activations_op.get_attr("table_id")
   lookup_id = activations_op.get_attr("lookup_id")
-  table_gradients = g.get_collection_ref("tpu_embedding_gradients_table_%d" %
-                                         table_id)
+  table_gradients = g.get_collection_ref(
+      "tpu_embedding_gradients_table_%d" % table_id)
   if not table_gradients:
     raise RuntimeError(
@@ -180,15 +181,6 @@ def _embedding_activations_grad(activations_op, grad_wrt_activations):
         " train_op = opt.minimize(loss)\n"
         "\n")
-  if table_gradients[lookup_id] is not None:
-    raise RuntimeError(
-        "Duplicate gradients (w.r.t. TPUEmbedding activations) generated for "
-        "table_id {} and lookup_id {}. This happens when there are multiple "
-        "calls to tf.gradients in a graph containing TPU embeddings. "
-        "TF cannot identify which gradient to use for updating the embedding "
-        "variables. Consider placing tf.StopGradient around tensors where "
-        "variable update is not required.".format(table_id, lookup_id))
   table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
   return [
       # RegisterGradient requires that value be returned for all inputs. Since
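Taken together, the two hunks above change how gradients w.r.t. TPUEmbedding activations are recorded: the collection lookup is re-wrapped, and the "Duplicate gradients" guard is dropped, so a later gradient for the same (table_id, lookup_id) slot simply replaces the earlier one. Below is a simplified sketch of that bookkeeping, not the actual TensorFlow implementation; the helper name record_embedding_gradient is hypothetical.

import tensorflow.compat.v1 as tf

def record_embedding_gradient(table_id, lookup_id, grad):
  # Fetch the mutable per-table gradients list stored on the default graph.
  table_gradients = tf.get_default_graph().get_collection_ref(
      "tpu_embedding_gradients_table_%d" % table_id)
  if not table_gradients:
    raise RuntimeError(
        "Gradient collection for table %d was never created." % table_id)
  # With the duplicate-gradient check removed, recording a gradient for a slot
  # that is already populated overwrites it instead of raising an error.
  table_gradients[lookup_id] = tf.identity(grad)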
@@ -230,10 +222,10 @@ def infeed_dequeue_tuple(dtypes, shapes, name=None):
   """A placeholder op for values fed into the TPU simultaneously as a tuple.

   Args:
-    dtypes: A list of `tf.DType`s that has length `>= 1`. The element types of
-      each element in `outputs`.
-    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). The
-      shapes of each tensor in `outputs`.
+    dtypes: A list of `tf.DType`s that has length `>= 1`.
+      The element types of each element in `outputs`.
+    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
+      The shapes of each tensor in `outputs`.
     name: A name for the operation (optional).

   Returns:
@@ -249,8 +241,6 @@ def infeed_dequeue_tuple(dtypes, shapes, name=None):
           "{} is not a supported TPU infeed type. Supported types are: "
           "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
   return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
 # pylint: enable=redefined-outer-name
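To make the dtypes/shapes arguments documented above concrete, here is a minimal illustration; the element types and shapes are invented for the example and are not taken from this change.

import tensorflow.compat.v1 as tf

# Hypothetical infeed tuple: a batch of float feature vectors plus int32 labels.
dtypes = [tf.float32, tf.int32]              # element type of each tuple member
shapes = [tf.TensorShape([128, 64]), [128]]  # each entry is a TensorShape or a list of ints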
@@ -264,17 +254,18 @@ def send_tpu_embedding_gradients(inputs,
   Args:
     inputs: A TensorList of gradients with which to update embedding tables.
       This argument has the same length and shapes as the return value of
-      RecvTPUEmbeddingActivations, but contains gradients of the model's loss
-      with respect to the embedding activations. The embedding tables are
-      updated from these gradients via the optimizers specified in the TPU
-      embedding configuration given to tpu.initialize_system.
+      RecvTPUEmbeddingActivations, but contains gradients of the model's
+      loss with respect to the embedding activations. The embedding tables
+      are updated from these gradients via the optimizers specified in the
+      TPU embedding configuration given to tpu.initialize_system.
     config: Serialized TPUEmbeddingConfiguration proto.
     learning_rates: A TensorList of float32 scalars, one for each dynamic
       learning rate tag: see the comments in
       //third_party/tensorflow/core/protobuf/tpu/
-      optimization_parameters.proto. Multiple tables can share the same
-      dynamic learning rate tag as specified in the configuration. If the
-      learning rates for all tables are constant, this list should be empty.
+      optimization_parameters.proto.
+      Multiple tables can share the same dynamic learning rate tag as
+      specified in the configuration. If the learning rates for all tables
+      are constant, this list should be empty.
     name: A name for the operation (optional).

   Returns:
@@ -336,8 +327,8 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices,
   """A placeholder op for enqueueing embedding IDs to the TPU.

   Args:
-    sample_indices: A list of rank 1 Tensors specifying the training example and
-      feature to which the corresponding embedding_indices and
+    sample_indices: A list of rank 1 Tensors specifying the training example
+      and feature to which the corresponding embedding_indices and
       aggregation_weights values belong. sample_indices[i] must equal b * nf +
       f, where nf is the number of features from the corresponding table, f is
       in [0, nf), and b is in [0, batch size). Both int32 and int64 are allowed,
@@ -345,9 +336,9 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices,
     embedding_indices: A list of rank 1 Tensors, indices into the embedding
       tables. Both int32 and int64 are allowed and will be converted to int32
       internally.
-    aggregation_weights: A list of rank 1 Tensors containing per sample -- i.e.,
-      per (training example, feature) -- aggregation weights. Both float32 and
-      float64 are allowed and will be converted to float32 internally.
+    aggregation_weights: A list of rank 1 Tensors containing per sample --
+      i.e. per (training example, feature) -- aggregation weights. Both float32
+      and float64 are allowed and will be converted to float32 internally.
     device_ordinal: The TPU device to use. Should be >= 0 and less than the
       number of TPU cores in the task on which the node is placed.
     combiners: A list of string scalars, one for each embedding table that
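The sample_indices convention restated in the hunk above (sample_indices[i] must equal b * nf + f) can be made concrete with a toy calculation; the batch size and feature count below are invented for illustration.

# Hypothetical table with nf = 2 features and a batch of 3 training examples.
nf = 2
batch_size = 3
# Flattened index for each (example b, feature f) pair: b * nf + f.
sample_indices = [b * nf + f for b in range(batch_size) for f in range(nf)]
print(sample_indices)  # [0, 1, 2, 3, 4, 5]: example 0 owns 0-1, example 1 owns 2-3, ...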
@@ -395,20 +386,20 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
   """A placeholder op for enqueueing embedding IDs to the TPU.

   Args:
-    sample_indices: A list of rank 2 Tensors specifying the training example to
-      which the corresponding embedding_indices and aggregation_weights values
-      belong. It corresponds to sp_ids.indices in embedding_lookup_sparse(). If
-      the size of its first dimension is 0, we assume each embedding_indices
-      belongs to a different sample. Both int32 and int64 are allowed and will
-      be converted to int32 internally.
+    sample_indices: A list of rank 2 Tensors specifying the training example
+      to which the corresponding embedding_indices and aggregation_weights
+      values belong. It corresponds to sp_ids.indices in
+      embedding_lookup_sparse(). If the size of its first dimension is 0, we
+      assume each embedding_indices belongs to a different sample. Both int32
+      and int64 are allowed and will be converted to int32 internally.
     embedding_indices: A list of rank 1 Tensors, indices into the embedding
       tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both
       int32 and int64 are allowed and will be converted to int32 internally.
     aggregation_weights: A list of rank 1 Tensors containing per training
       example aggregation weights. It corresponds to sp_weights.values in
       embedding_lookup_sparse(). If the size of its first dimension is 0, we
-      assume all weights are 1. Both float32 and float64 are allowed and will be
-      converted to float32 internally.
+      assume all weights are 1. Both float32 and float64 are allowed and will
+      be converted to float32 internally.
     table_ids: A list of integers specifying the identifier of the embedding
       table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
       lookup the corresponding input. The ith input is looked up using

View File

@@ -91,8 +91,8 @@ def create_dummy_table_variables(tpu_embedding):
     if table_gradients:
       raise RuntimeError(
          'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
-    num_features = len(tpu_embedding.table_to_features_dict[table])
-    table_gradients.extend([None for _ in range(num_features)])
+    table_gradients.extend(
+        [None] * len(tpu_embedding.table_to_features_dict[table]))

   return (dummy_table_variables,
           variables.variables_initializer(
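The visible change in this last hunk is behavior-preserving: extending the collection with [None] * len(...) yields the same placeholder list as the comprehension over range(num_features) that it replaces, since None is immutable and sharing references is harmless here. A quick check with an invented feature count:

num_features = 4  # illustrative value only
assert [None] * num_features == [None for _ in range(num_features)]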