Fix gradient computation bug in TPU embedding mid-level API.

PiperOrigin-RevId: 347744033
Change-Id: I685274217865fca9d5aa0f34bbbce618dcac5f13
commit 0f8a32c6cc
parent 2e61843949
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =============================================================================
+
 """Operations for TPUs."""
 
 from __future__ import absolute_import
@@ -56,8 +57,8 @@ def all_to_all(x,
     split_count: The number of splits, this number must equal to the sub-group
       size(group_assignment.get_shape()[1])
     group_assignment: Optional 2d int32 lists with shape [num_groups,
-      num_replicas_per_group]. `group_assignment[i]` represents the replica ids
-      in the ith subgroup.
+      num_replicas_per_group]. `group_assignment[i]` represents the replica
+      ids in the ith subgroup.
     name: Optional op name.
 
   Returns:
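
For context, a minimal usage sketch of all_to_all (not part of this commit; the replica function, shapes, and group layout are assumptions). It illustrates the constraint documented above: split_count must equal the sub-group size, group_assignment.get_shape()[1].

    import tensorflow as tf
    from tensorflow.python.tpu.ops import tpu_ops

    def replica_fn():
      # Runs on each of four replicas inside a replicated TPU computation.
      x = tf.ones([4, 128])
      # Split x into 4 pieces along dim 0, exchange the pieces with the other
      # replicas in the group, and concatenate what is received along dim 0.
      return tpu_ops.all_to_all(
          x,
          concat_dimension=0,
          split_dimension=0,
          split_count=4,  # equals the sub-group size, 4 replicas per group
          group_assignment=[[0, 1, 2, 3]])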
@@ -96,8 +97,8 @@ def cross_replica_sum(x, group_assignment=None, name=None):
   Args:
     x: The local tensor to the sum.
     group_assignment: Optional 2d int32 lists with shape [num_groups,
-      num_replicas_per_group]. `group_assignment[i]` represents the replica ids
-      in the ith subgroup.
+      num_replicas_per_group]. `group_assignment[i]` represents the replica
+      ids in the ith subgroup.
     name: Optional op name.
 
   Returns:
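
A hedged sketch of the group_assignment semantics documented above (four replicas assumed; not part of this commit). Each replica receives the sum over its own subgroup only.

    import tensorflow as tf
    from tensorflow.python.tpu.ops import tpu_ops

    def replica_fn(loss):
      # Replicas 0 and 1 are summed together; replicas 2 and 3 form a second,
      # independent reduction group.
      return tpu_ops.cross_replica_sum(loss, group_assignment=[[0, 1], [2, 3]])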
@@ -167,8 +168,8 @@ def _embedding_activations_grad(activations_op, grad_wrt_activations):
   g = ops.get_default_graph()
   table_id = activations_op.get_attr("table_id")
   lookup_id = activations_op.get_attr("lookup_id")
-  table_gradients = g.get_collection_ref("tpu_embedding_gradients_table_%d" %
-                                         table_id)
+  table_gradients = g.get_collection_ref(
+      "tpu_embedding_gradients_table_%d" % table_id)
 
   if not table_gradients:
     raise RuntimeError(
@@ -180,15 +181,6 @@ def _embedding_activations_grad(activations_op, grad_wrt_activations):
         " train_op = opt.minimize(loss)\n"
         "\n")
 
-  if table_gradients[lookup_id] is not None:
-    raise RuntimeError(
-        "Duplicate gradients (w.r.t. TPUEmbedding activations) generated for "
-        "table_id {} and lookup_id {}. This happens when there are multiple "
-        "calls to tf.gradients in a graph containing TPU embeddings. "
-        "TF cannot identify which gradient to use for updating the embedding "
-        "variables. Consider placing tf.StopGradient around tensors where "
-        "variable update is not required.".format(table_id, lookup_id))
-
   table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
   return [
       # RegisterGradient requires that value be returned for all inputs. Since
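
The deleted block raised as soon as a graph issued more than one tf.gradients call reaching the same embedding activations; after this change the collection slot is simply overwritten, so the most recently computed gradient appears to win. A hypothetical graph-mode repro of the pattern that previously raised (`activations` and `model_vars` are placeholders for the TPUEmbeddingActivations output and the model's trainable variables; neither is defined in this commit):

    import tensorflow.compat.v1 as tf  # graph mode assumed

    loss_a = tf.reduce_sum(activations)
    loss_b = tf.reduce_mean(activations)
    grads_a = tf.gradients(loss_a, model_vars)  # fills table_gradients[lookup_id]
    grads_b = tf.gradients(loss_b, model_vars)  # before this commit: RuntimeError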
@@ -230,10 +222,10 @@ def infeed_dequeue_tuple(dtypes, shapes, name=None):
   """A placeholder op for values fed into the TPU simultaneously as a tuple.
 
   Args:
-    dtypes: A list of `tf.DType`s that has length `>= 1`. The element types of
-      each element in `outputs`.
-    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). The
-      shapes of each tensor in `outputs`.
+    dtypes: A list of `tf.DType`s that has length `>= 1`.
+      The element types of each element in `outputs`.
+    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
+      The shapes of each tensor in `outputs`.
     name: A name for the operation (optional).
 
   Returns:
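
A minimal sketch of the dtypes/shapes pairing documented above (the shapes and a matching host-side infeed enqueue are assumptions, not part of this commit):

    import tensorflow as tf
    from tensorflow.python.tpu.ops import tpu_ops

    # Dequeues an (images, labels) tuple previously enqueued to this core.
    images, labels = tpu_ops.infeed_dequeue_tuple(
        dtypes=[tf.float32, tf.int32],       # element type of each tuple member
        shapes=[[128, 224, 224, 3], [128]])  # shape of each tuple member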
@@ -249,8 +241,6 @@ def infeed_dequeue_tuple(dtypes, shapes, name=None):
         "{} is not a supported TPU infeed type. Supported types are: "
         "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
   return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
-
 
 # pylint: enable=redefined-outer-name
-
 
@@ -263,18 +253,19 @@ def send_tpu_embedding_gradients(inputs,
 
   Args:
     inputs: A TensorList of gradients with which to update embedding tables.
-      This argument has the same length and shapes as the return value of
-      RecvTPUEmbeddingActivations, but contains gradients of the model's loss
-      with respect to the embedding activations. The embedding tables are
-      updated from these gradients via the optimizers specified in the TPU
-      embedding configuration given to tpu.initialize_system.
+      This argument has the same length and shapes as the return value of
+      RecvTPUEmbeddingActivations, but contains gradients of the model's
+      loss with respect to the embedding activations. The embedding tables
+      are updated from these gradients via the optimizers specified in the
+      TPU embedding configuration given to tpu.initialize_system.
     config: Serialized TPUEmbeddingConfiguration proto.
     learning_rates: A TensorList of float32 scalars, one for each dynamic
       learning rate tag: see the comments in
-      //third_party/tensorflow/core/protobuf/tpu/
-      optimization_parameters.proto. Multiple tables can share the same
-      dynamic learning rate tag as specified in the configuration. If the
-      learning rates for all tables are constant, this list should be empty.
+      //third_party/tensorflow/core/protobuf/tpu/
+      optimization_parameters.proto.
+      Multiple tables can share the same dynamic learning rate tag as
+      specified in the configuration. If the learning rates for all tables
+      are constant, this list should be empty.
     name: A name for the operation (optional).
 
   Returns:
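
A hedged sketch tying these arguments together (`grads` and `config_str` are assumed to exist in scope; `grads` must match the activations returned by the corresponding recv op in length and shape):

    # Inside the replicated TPU computation, after computing gradients of the
    # loss w.r.t. each received embedding activation:
    tpu_ops.send_tpu_embedding_gradients(
        inputs=grads,       # one gradient per activation, same shapes
        learning_rates=[],  # empty: every table uses a constant learning rate
        config=config_str)  # serialized TPUEmbeddingConfiguration proto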
@@ -336,8 +327,8 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices,
   """A placeholder op for enqueueing embedding IDs to the TPU.
 
   Args:
-    sample_indices: A list of rank 1 Tensors specifying the training example and
-      feature to which the corresponding embedding_indices and
+    sample_indices: A list of rank 1 Tensors specifying the training example
+      and feature to which the corresponding embedding_indices and
       aggregation_weights values belong. sample_indices[i] must equal b * nf +
       f, where nf is the number of features from the corresponding table, f is
       in [0, nf), and b is in [0, batch size). Both int32 and int64 are allowed,
@@ -345,9 +336,9 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices,
     embedding_indices: A list of rank 1 Tensors, indices into the embedding
       tables. Both int32 and int64 are allowed and will be converted to int32
       internally.
-    aggregation_weights: A list of rank 1 Tensors containing per sample -- i.e.,
-      per (training example, feature) -- aggregation weights. Both float32 and
-      float64 are allowed and will be converted to float32 internally.
+    aggregation_weights: A list of rank 1 Tensors containing per sample --
+      i.e. per (training example, feature) -- aggregation weights. Both float32
+      and float64 are allowed and will be converted to float32 internally.
     device_ordinal: The TPU device to use. Should be >= 0 and less than the
       number of TPU cores in the task on which the node is placed.
     combiners: A list of string scalars, one for each embedding table that
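
The sample_indices formula above (sample_indices[i] == b * nf + f) is easiest to see with one feature per table. A hypothetical single-table, single-feature sketch (nf == 1, batch size 4; all values are made-up illustrations):

    import tensorflow as tf
    from tensorflow.python.tpu.ops import tpu_ops

    enqueue_op = tpu_ops.enqueue_tpu_embedding_sparse_batch(
        # With nf == 1, sample_indices[i] is just the training-example id b.
        sample_indices=[tf.constant([0, 1, 1, 3])],
        embedding_indices=[tf.constant([12, 7, 42, 3])],  # table rows to look up
        aggregation_weights=[tf.constant([1.0, 0.5, 0.5, 1.0])],
        device_ordinal=0,   # first TPU core of this task
        combiners=["sum"])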
@@ -395,20 +386,20 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
   """A placeholder op for enqueueing embedding IDs to the TPU.
 
   Args:
-    sample_indices: A list of rank 2 Tensors specifying the training example to
-      which the corresponding embedding_indices and aggregation_weights values
-      belong. It corresponds to sp_ids.indices in embedding_lookup_sparse(). If
-      the size of its first dimension is 0, we assume each embedding_indices
-      belongs to a different sample. Both int32 and int64 are allowed and will
-      be converted to int32 internally.
+    sample_indices: A list of rank 2 Tensors specifying the training example
+      to which the corresponding embedding_indices and aggregation_weights
+      values belong. It corresponds to sp_ids.indices in
+      embedding_lookup_sparse(). If the size of its first dimension is 0, we
+      assume each embedding_indices belongs to a different sample. Both int32
+      and int64 are allowed and will be converted to int32 internally.
     embedding_indices: A list of rank 1 Tensors, indices into the embedding
       tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both
      int32 and int64 are allowed and will be converted to int32 internally.
     aggregation_weights: A list of rank 1 Tensors containing per training
       example aggregation weights. It corresponds to sp_weights.values in
       embedding_lookup_sparse(). If the size of its first dimension is 0, we
-      assume all weights are 1. Both float32 and float64 are allowed and will be
-      converted to float32 internally.
+      assume all weights are 1. Both float32 and float64 are allowed and will
+      be converted to float32 internally.
     table_ids: A list of integers specifying the identifier of the embedding
       table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
       lookup the corresponding input. The ith input is looked up using
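
The correspondence to embedding_lookup_sparse() reads as follows in a hedged sketch (one SparseTensor feature and table 0 are assumptions; a size-0 weights tensor means all weights are 1):

    import tensorflow as tf
    from tensorflow.python.tpu.ops import tpu_ops

    sp_ids = tf.sparse.SparseTensor(indices=[[0, 0], [1, 0], [1, 1]],
                                    values=[12, 7, 42], dense_shape=[4, 2])
    enqueue_op = tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
        sample_indices=[tf.cast(sp_ids.indices, tf.int32)],    # sp_ids.indices
        embedding_indices=[tf.cast(sp_ids.values, tf.int32)],  # sp_ids.values
        aggregation_weights=[tf.zeros([0], tf.float32)],       # size 0 => all 1s
        table_ids=[0],
        device_ordinal=0,
        combiners=["mean"])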

@@ -91,8 +91,8 @@ def create_dummy_table_variables(tpu_embedding):
     if table_gradients:
       raise RuntimeError(
           'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
-    num_features = len(tpu_embedding.table_to_features_dict[table])
-    table_gradients.extend([None for _ in range(num_features)])
+    table_gradients.extend(
+        [None] * len(tpu_embedding.table_to_features_dict[table]))
 
   return (dummy_table_variables,
           variables.variables_initializer(
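
This replacement is behavior-preserving: both expressions build one None placeholder per feature of the table, and the new form simply drops the intermediate num_features binding. A quick check:

    n = 3  # stand-in for len(tpu_embedding.table_to_features_dict[table])
    assert [None for _ in range(n)] == [None] * n == [None, None, None]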