Fix gradient computation bug in TPU embedding mid-level API.

PiperOrigin-RevId: 347744033
Change-Id: I685274217865fca9d5aa0f34bbbce618dcac5f13
A. Unique TensorFlower authored 2020-12-15 20:09:11 -08:00; committed by TensorFlower Gardener
parent 2e61843949
commit 0f8a32c6cc
2 changed files with 36 additions and 45 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for TPUs."""
from __future__ import absolute_import
@@ -56,8 +57,8 @@ def all_to_all(x,
split_count: The number of splits; this number must equal the sub-group
size (group_assignment.get_shape()[1]).
group_assignment: Optional 2d int32 lists with shape [num_groups,
num_replicas_per_group]. `group_assignment[i]` represents the replica ids
in the ith subgroup.
num_replicas_per_group]. `group_assignment[i]` represents the replica
ids in the ith subgroup.
name: Optional op name.
Returns:
@@ -96,8 +97,8 @@ def cross_replica_sum(x, group_assignment=None, name=None):
Args:
x: The local tensor to sum.
group_assignment: Optional 2d int32 lists with shape [num_groups,
num_replicas_per_group]. `group_assignment[i]` represents the replica ids
in the ith subgroup.
num_replicas_per_group]. `group_assignment[i]` represents the replica
ids in the ith subgroup.
name: Optional op name.
Returns:
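For context, a minimal sketch of how a group_assignment might be constructed for cross_replica_sum (and, analogously, all_to_all); the replica layout below is an illustrative assumption and is not part of this diff.

# Sketch: sum a per-replica tensor only within pairs of replicas.
# Assumes a 4-replica computation split into 2 reduction groups.
from tensorflow.python.tpu.ops import tpu_ops

# Shape [num_groups, num_replicas_per_group]; group_assignment[i] lists the
# replica ids that form the i-th subgroup.
group_assignment = [[0, 1], [2, 3]]

def grouped_sum(x):
  # Called inside a replicated TPU computation; each replica receives the
  # sum of `x` over the replicas in its own subgroup only.
  return tpu_ops.cross_replica_sum(x, group_assignment=group_assignment)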
@@ -167,8 +168,8 @@ def _embedding_activations_grad(activations_op, grad_wrt_activations):
g = ops.get_default_graph()
table_id = activations_op.get_attr("table_id")
lookup_id = activations_op.get_attr("lookup_id")
table_gradients = g.get_collection_ref("tpu_embedding_gradients_table_%d" %
table_id)
table_gradients = g.get_collection_ref(
"tpu_embedding_gradients_table_%d" % table_id)
if not table_gradients:
raise RuntimeError(
@@ -180,15 +181,6 @@ def _embedding_activations_grad(activations_op, grad_wrt_activations):
" train_op = opt.minimize(loss)\n"
"\n")
if table_gradients[lookup_id] is not None:
raise RuntimeError(
"Duplicate gradients (w.r.t. TPUEmbedding activations) generated for "
"table_id {} and lookup_id {}. This happens when there are multiple "
"calls to tf.gradients in a graph containing TPU embeddings. "
"TF cannot identify which gradient to use for updating the embedding "
"variables. Consider placing tf.StopGradient around tensors where "
"variable update is not required.".format(table_id, lookup_id))
table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
return [
# RegisterGradient requires that value be returned for all inputs. Since
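As background for the change above (a sketch, not the library's exact code): the mid-level API pre-populates a per-table graph collection with one None slot per feature lookup, and the gradient function fills each slot with the gradient w.r.t. that lookup's activations. The toy sizes below are assumptions.

# Sketch of the tpu_embedding_gradients_table_<id> collection protocol
# (TF1 graph mode), using assumed toy sizes.
import tensorflow.compat.v1 as tf

table_id, num_features = 0, 3
g = tf.get_default_graph()

# Setup phase (cf. create_dummy_table_variables): reserve one slot per
# feature lookup of this table.
table_gradients = g.get_collection_ref(
    "tpu_embedding_gradients_table_%d" % table_id)
table_gradients.extend([None] * num_features)

# Gradient phase (cf. _embedding_activations_grad): record the gradient
# w.r.t. the activations of a particular lookup.
def record_gradient(lookup_id, grad_wrt_activations):
  grads = g.get_collection_ref("tpu_embedding_gradients_table_%d" % table_id)
  grads[lookup_id] = tf.identity(grad_wrt_activations)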
@@ -230,10 +222,10 @@ def infeed_dequeue_tuple(dtypes, shapes, name=None):
"""A placeholder op for values fed into the TPU simultaneously as a tuple.
Args:
dtypes: A list of `tf.DType`s that has length `>= 1`. The element types of
each element in `outputs`.
shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). The
shapes of each tensor in `outputs`.
dtypes: A list of `tf.DType`s that has length `>= 1`.
The element types of each element in `outputs`.
shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
The shapes of each tensor in `outputs`.
name: A name for the operation (optional).
Returns:
@@ -249,8 +241,6 @@ def infeed_dequeue_tuple(dtypes, shapes, name=None):
"{} is not a supported TPU infeed type. Supported types are: "
"{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
# pylint: enable=redefined-outer-name
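A hedged usage sketch for infeed_dequeue_tuple (not part of this diff); the dtypes and shapes are illustrative placeholders.

# Sketch: dequeue a (features, labels) tuple inside a TPU computation.
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu.ops import tpu_ops

def tpu_step():
  # One dtype and one shape per element of the infeed tuple (assumed here:
  # a [128, 64] float feature matrix and a [128] int label vector).
  features, labels = tpu_ops.infeed_dequeue_tuple(
      dtypes=[tf.float32, tf.int32],
      shapes=[[128, 64], [128]])
  return features, labels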
@@ -263,18 +253,19 @@ def send_tpu_embedding_gradients(inputs,
Args:
inputs: A TensorList of gradients with which to update embedding tables.
This argument has the same length and shapes as the return value of
RecvTPUEmbeddingActivations, but contains gradients of the model's loss
with respect to the embedding activations. The embedding tables are
updated from these gradients via the optimizers specified in the TPU
embedding configuration given to tpu.initialize_system.
This argument has the same length and shapes as the return value of
RecvTPUEmbeddingActivations, but contains gradients of the model's
loss with respect to the embedding activations. The embedding tables
are updated from these gradients via the optimizers specified in the
TPU embedding configuration given to tpu.initialize_system.
config: Serialized TPUEmbeddingConfiguration proto.
learning_rates: A TensorList of float32 scalars, one for each dynamic
learning rate tag: see the comments in
//third_party/tensorflow/core/protobuf/tpu/
optimization_parameters.proto. Multiple tables can share the same
dynamic learning rate tag as specified in the configuration. If the
learning rates for all tables are constant, this list should be empty.
//third_party/tensorflow/core/protobuf/tpu/
optimization_parameters.proto.
Multiple tables can share the same dynamic learning rate tag as
specified in the configuration. If the learning rates for all tables
are constant, this list should be empty.
name: A name for the operation (optional).
Returns:
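For orientation, a sketch of the recv/send pairing this docstring describes; the compute_loss helper, the two configured features, and the serialized config value are assumptions, and this code is not part of the commit.

# Sketch: feed gradients of the loss w.r.t. the embedding activations back
# to the TPU embedding engine.
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu.ops import tpu_ops

def embedding_train_step(config, compute_loss):
  # One activation tensor per configured feature (assumed: 2).
  activations = tpu_ops.recv_tpu_embedding_activations(
      num_outputs=2, config=config)
  loss = compute_loss(activations)
  # Same length/shapes as `activations`, but holding d(loss)/d(activation).
  grads = tf.gradients(loss, activations)
  return tpu_ops.send_tpu_embedding_gradients(inputs=grads, config=config)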
@@ -336,8 +327,8 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices,
"""A placeholder op for enqueueing embedding IDs to the TPU.
Args:
sample_indices: A list of rank 1 Tensors specifying the training example and
feature to which the corresponding embedding_indices and
sample_indices: A list of rank 1 Tensors specifying the training example
and feature to which the corresponding embedding_indices and
aggregation_weights values belong. sample_indices[i] must equal b * nf +
f, where nf is the number of features from the corresponding table, f is
in [0, nf), and b is in [0, batch size). Both int32 and int64 are allowed,
@@ -345,9 +336,9 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices,
embedding_indices: A list of rank 1 Tensors, indices into the embedding
tables. Both int32 and int64 are allowed and will be converted to int32
internally.
aggregation_weights: A list of rank 1 Tensors containing per sample -- i.e.,
per (training example, feature) -- aggregation weights. Both float32 and
float64 are allowed and will be converted to float32 internally.
aggregation_weights: A list of rank 1 Tensors containing per sample --
i.e. per (training example, feature) -- aggregation weights. Both float32
and float64 are allowed and will be converted to float32 internally.
device_ordinal: The TPU device to use. Should be >= 0 and less than the
number of TPU cores in the task on which the node is placed.
combiners: A list of string scalars, one for each embedding table that
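To make the sample_indices layout (sample_indices[i] == b * nf + f) concrete, a small illustrative construction with toy sizes, not part of this diff:

# Toy layout for one table with nf = 2 features and a batch of 2 examples;
# each (example, feature) pair contributes one id and one weight.
batch_size, nf = 2, 2

# sample_indices[i] = b * nf + f for example b in [0, batch_size) and
# feature f in [0, nf).
sample_indices = [b * nf + f for b in range(batch_size) for f in range(nf)]
# -> [0, 1, 2, 3]

embedding_indices = [17, 3, 42, 8]           # ids into the embedding table
aggregation_weights = [1.0, 1.0, 1.0, 1.0]   # per (example, feature) weights

# The op itself takes one rank-1 tensor per table, so for a one-table setup
# each of these three lists would be wrapped in a single-element list.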
@@ -395,20 +386,20 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
"""A placeholder op for enqueueing embedding IDs to the TPU.
Args:
sample_indices: A list of rank 2 Tensors specifying the training example to
which the corresponding embedding_indices and aggregation_weights values
belong. It corresponds to sp_ids.indices in embedding_lookup_sparse(). If
the size of its first dimension is 0, we assume each embedding_indices
belongs to a different sample. Both int32 and int64 are allowed and will
be converted to int32 internally.
sample_indices: A list of rank 2 Tensors specifying the training example
to which the corresponding embedding_indices and aggregation_weights
values belong. It corresponds to sp_ids.indices in
embedding_lookup_sparse(). If the size of its first dimension is 0, we
assume each embedding_indices belongs to a different sample. Both int32
and int64 are allowed and will be converted to int32 internally.
embedding_indices: A list of rank 1 Tensors, indices into the embedding
tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both
int32 and int64 are allowed and will be converted to int32 internally.
aggregation_weights: A list of rank 1 Tensors containing per training
example aggregation weights. It corresponds to sp_weights.values in
embedding_lookup_sparse(). If the size of its first dimension is 0, we
assume all weights are 1. Both float32 and float64 are allowed and will be
converted to float32 internally.
assume all weights are 1. Both float32 and float64 are allowed and will
be converted to float32 internally.
table_ids: A list of integers specifying the identifier of the embedding
table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
lookup the corresponding input. The ith input is looked up using

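As a hedged illustration of the sp_ids/sp_weights correspondence described above, assuming a single table and toy values (not part of this diff):

# Sketch: derive the per-table argument lists for
# enqueue_tpu_embedding_sparse_tensor_batch from a tf.SparseTensor of ids.
import tensorflow.compat.v1 as tf

sp_ids = tf.SparseTensor(indices=[[0, 0], [1, 0], [1, 1]],
                         values=[17, 42, 8],
                         dense_shape=[2, 2])

sample_indices = [sp_ids.indices]      # rank 2, like sp_ids.indices
embedding_indices = [sp_ids.values]    # rank 1 ids into the table
aggregation_weights = [tf.zeros([0])]  # empty => every weight treated as 1
table_ids = [0]                        # offset of the TableDescriptor to use
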
View File

@@ -91,8 +91,8 @@ def create_dummy_table_variables(tpu_embedding):
if table_gradients:
raise RuntimeError(
'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
num_features = len(tpu_embedding.table_to_features_dict[table])
table_gradients.extend([None for _ in range(num_features)])
table_gradients.extend(
[None] * len(tpu_embedding.table_to_features_dict[table]))
return (dummy_table_variables,
variables.variables_initializer(