Fix gradient computation bug in TPU embedding mid-level API.
PiperOrigin-RevId: 347744033
Change-Id: I685274217865fca9d5aa0f34bbbce618dcac5f13
commit 0f8a32c6cc
parent 2e61843949
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =============================================================================
+
 """Operations for TPUs."""
 
 from __future__ import absolute_import
@@ -56,8 +57,8 @@ def all_to_all(x,
     split_count: The number of splits, this number must equal to the sub-group
      size(group_assignment.get_shape()[1])
     group_assignment: Optional 2d int32 lists with shape [num_groups,
-      num_replicas_per_group]. `group_assignment[i]` represents the replica ids
-      in the ith subgroup.
+      num_replicas_per_group]. `group_assignment[i]` represents the replica
+      ids in the ith subgroup.
     name: Optional op name.
 
   Returns:
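For illustration only (not part of this diff), a minimal sketch of the split_count constraint described above, with an assumed layout of four replicas in two sub-groups:

    # Two sub-groups of two replicas each; the replica ids are illustrative.
    group_assignment = [[0, 1], [2, 3]]
    # split_count must equal the sub-group size, i.e. the second dimension
    # of group_assignment.
    split_count = len(group_assignment[0])  # == 2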
@@ -96,8 +97,8 @@ def cross_replica_sum(x, group_assignment=None, name=None):
   Args:
     x: The local tensor to the sum.
     group_assignment: Optional 2d int32 lists with shape [num_groups,
-      num_replicas_per_group]. `group_assignment[i]` represents the replica ids
-      in the ith subgroup.
+      num_replicas_per_group]. `group_assignment[i]` represents the replica
+      ids in the ith subgroup.
     name: Optional op name.
 
   Returns:
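A hedged usage sketch of the group_assignment argument documented above; the eight-replica layout and the tpu_ops import path are assumptions, not part of this change:

    from tensorflow.python.tpu.ops import tpu_ops  # import path assumed

    def summed_per_group(x):
      # Sum x across replicas {0, 1, 2, 3} and, separately, across {4, 5, 6, 7}.
      # Intended to run inside a replicated TPU computation.
      group_assignment = [[0, 1, 2, 3], [4, 5, 6, 7]]
      return tpu_ops.cross_replica_sum(x, group_assignment=group_assignment)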
@@ -167,8 +168,8 @@ def _embedding_activations_grad(activations_op, grad_wrt_activations):
   g = ops.get_default_graph()
   table_id = activations_op.get_attr("table_id")
   lookup_id = activations_op.get_attr("lookup_id")
-  table_gradients = g.get_collection_ref("tpu_embedding_gradients_table_%d" %
-                                         table_id)
+  table_gradients = g.get_collection_ref(
+      "tpu_embedding_gradients_table_%d" % table_id)
 
   if not table_gradients:
     raise RuntimeError(
@@ -180,15 +181,6 @@ def _embedding_activations_grad(activations_op, grad_wrt_activations):
         " train_op = opt.minimize(loss)\n"
         "\n")
 
-  if table_gradients[lookup_id] is not None:
-    raise RuntimeError(
-        "Duplicate gradients (w.r.t. TPUEmbedding activations) generated for "
-        "table_id {} and lookup_id {}. This happens when there are multiple "
-        "calls to tf.gradients in a graph containing TPU embeddings. "
-        "TF cannot identify which gradient to use for updating the embedding "
-        "variables. Consider placing tf.StopGradient around tensors where "
-        "variable update is not required.".format(table_id, lookup_id))
-
   table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
   return [
       # RegisterGradient requires that value be returned for all inputs. Since
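For context, a minimal sketch of the graph-collection lookup this gradient function relies on; the helper name is hypothetical, and the sketch assumes the TPU embedding activations have already been built in the default graph:

    from tensorflow.python.framework import ops

    def table_gradient_slots(table_id):
      # Each table's gradient slots live in a collection named
      # "tpu_embedding_gradients_table_<table_id>"; get_collection_ref returns
      # the live list that the gradient function above writes into.
      g = ops.get_default_graph()
      return g.get_collection_ref("tpu_embedding_gradients_table_%d" % table_id)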
@@ -230,10 +222,10 @@ def infeed_dequeue_tuple(dtypes, shapes, name=None):
   """A placeholder op for values fed into the TPU simultaneously as a tuple.
 
   Args:
-    dtypes: A list of `tf.DType`s that has length `>= 1`. The element types of
-      each element in `outputs`.
-    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). The
-      shapes of each tensor in `outputs`.
+    dtypes: A list of `tf.DType`s that has length `>= 1`.
+      The element types of each element in `outputs`.
+    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
+      The shapes of each tensor in `outputs`.
     name: A name for the operation (optional).
 
   Returns:
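A hedged usage sketch of the dtypes/shapes pairing documented above; the batch and image dimensions are illustrative, and the import path is an assumption:

    import tensorflow.compat.v1 as tf
    from tensorflow.python.tpu.ops import tpu_ops  # import path assumed

    # dtypes and shapes are parallel lists, one entry per dequeued tensor.
    images, labels = tpu_ops.infeed_dequeue_tuple(
        dtypes=[tf.float32, tf.int32],
        shapes=[[128, 224, 224, 3], [128]])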
@@ -249,8 +241,6 @@ def infeed_dequeue_tuple(dtypes, shapes, name=None):
           "{} is not a supported TPU infeed type. Supported types are: "
           "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
   return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
-
-
 # pylint: enable=redefined-outer-name
 
 
@@ -264,17 +254,18 @@ def send_tpu_embedding_gradients(inputs,
   Args:
     inputs: A TensorList of gradients with which to update embedding tables.
       This argument has the same length and shapes as the return value of
-      RecvTPUEmbeddingActivations, but contains gradients of the model's loss
-      with respect to the embedding activations. The embedding tables are
-      updated from these gradients via the optimizers specified in the TPU
-      embedding configuration given to tpu.initialize_system.
+      RecvTPUEmbeddingActivations, but contains gradients of the model's
+      loss with respect to the embedding activations. The embedding tables
+      are updated from these gradients via the optimizers specified in the
+      TPU embedding configuration given to tpu.initialize_system.
     config: Serialized TPUEmbeddingConfiguration proto.
     learning_rates: A TensorList of float32 scalars, one for each dynamic
       learning rate tag: see the comments in
       //third_party/tensorflow/core/protobuf/tpu/
-      optimization_parameters.proto. Multiple tables can share the same
-      dynamic learning rate tag as specified in the configuration. If the
-      learning rates for all tables are constant, this list should be empty.
+      optimization_parameters.proto.
+      Multiple tables can share the same dynamic learning rate tag as
+      specified in the configuration. If the learning rates for all tables
+      are constant, this list should be empty.
     name: A name for the operation (optional).
 
   Returns:
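A hedged sketch of the pairing the docstring describes: gradients of the loss with respect to the received activations are sent back for the optimizer update. `loss`, `activations`, and `embedding_config_str` are assumed to exist elsewhere, and constant learning rates are assumed, so the learning_rates list is omitted:

    import tensorflow.compat.v1 as tf
    from tensorflow.python.tpu.ops import tpu_ops  # import path assumed

    # grads has the same length and shapes as the received activations.
    grads = tf.gradients(loss, activations)
    send_op = tpu_ops.send_tpu_embedding_gradients(
        inputs=grads, config=embedding_config_str)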
@@ -336,8 +327,8 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices,
   """A placeholder op for enqueueing embedding IDs to the TPU.
 
   Args:
-    sample_indices: A list of rank 1 Tensors specifying the training example and
-      feature to which the corresponding embedding_indices and
+    sample_indices: A list of rank 1 Tensors specifying the training example
+      and feature to which the corresponding embedding_indices and
       aggregation_weights values belong. sample_indices[i] must equal b * nf +
       f, where nf is the number of features from the corresponding table, f is
       in [0, nf), and b is in [0, batch size). Both int32 and int64 are allowed,
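A worked instance of the sample_indices rule above, with illustrative values:

    nf, b, f = 3, 2, 1      # 3 features in the table, training example 2, feature 1
    assert b * nf + f == 7  # so the corresponding sample_indices entry is 7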
@@ -345,9 +336,9 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices,
     embedding_indices: A list of rank 1 Tensors, indices into the embedding
       tables. Both int32 and int64 are allowed and will be converted to int32
       internally.
-    aggregation_weights: A list of rank 1 Tensors containing per sample -- i.e.,
-      per (training example, feature) -- aggregation weights. Both float32 and
-      float64 are allowed and will be converted to float32 internally.
+    aggregation_weights: A list of rank 1 Tensors containing per sample --
+      i.e. per (training example, feature) -- aggregation weights. Both float32
+      and float64 are allowed and will be converted to float32 internally.
     device_ordinal: The TPU device to use. Should be >= 0 and less than the
       number of TPU cores in the task on which the node is placed.
     combiners: A list of string scalars, one for each embedding table that
@@ -395,20 +386,20 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
   """A placeholder op for enqueueing embedding IDs to the TPU.
 
   Args:
-    sample_indices: A list of rank 2 Tensors specifying the training example to
-      which the corresponding embedding_indices and aggregation_weights values
-      belong. It corresponds to sp_ids.indices in embedding_lookup_sparse(). If
-      the size of its first dimension is 0, we assume each embedding_indices
-      belongs to a different sample. Both int32 and int64 are allowed and will
-      be converted to int32 internally.
+    sample_indices: A list of rank 2 Tensors specifying the training example
+      to which the corresponding embedding_indices and aggregation_weights
+      values belong. It corresponds to sp_ids.indices in
+      embedding_lookup_sparse(). If the size of its first dimension is 0, we
+      assume each embedding_indices belongs to a different sample. Both int32
+      and int64 are allowed and will be converted to int32 internally.
     embedding_indices: A list of rank 1 Tensors, indices into the embedding
       tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both
       int32 and int64 are allowed and will be converted to int32 internally.
     aggregation_weights: A list of rank 1 Tensors containing per training
       example aggregation weights. It corresponds to sp_weights.values in
       embedding_lookup_sparse(). If the size of its first dimension is 0, we
-      assume all weights are 1. Both float32 and float64 are allowed and will be
-      converted to float32 internally.
+      assume all weights are 1. Both float32 and float64 are allowed and will
+      be converted to float32 internally.
     table_ids: A list of integers specifying the identifier of the embedding
       table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
       lookup the corresponding input. The ith input is looked up using
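For illustration, the correspondence with embedding_lookup_sparse() spelled out above, assuming sp_ids and sp_weights are SparseTensors built elsewhere:

    sample_indices = [sp_ids.indices]          # which training example each id belongs to
    embedding_indices = [sp_ids.values]        # the ids looked up in the embedding table
    aggregation_weights = [sp_weights.values]  # per-id aggregation weights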
@@ -91,8 +91,8 @@ def create_dummy_table_variables(tpu_embedding):
     if table_gradients:
       raise RuntimeError(
           'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
-    num_features = len(tpu_embedding.table_to_features_dict[table])
-    table_gradients.extend([None for _ in range(num_features)])
+    table_gradients.extend(
+        [None] * len(tpu_embedding.table_to_features_dict[table]))
 
   return (dummy_table_variables,
           variables.variables_initializer(
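The rewritten lines above are behavior-preserving; a quick check with an illustrative feature count:

    num_features = 4  # illustrative
    assert [None for _ in range(num_features)] == [None] * num_features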