Add EnqueueTPUEmbeddingRaggedTensorBatch for RaggedTensor support.
PiperOrigin-RevId: 304250071
Change-Id: If1f0d7a8716c95a090f28d085a46ffa9c3e9053e
@@ -0,0 +1,77 @@
+op {
+  graph_op_name: "EnqueueTPUEmbeddingRaggedTensorBatch"
+  visibility: HIDDEN
+  in_arg {
+    name: "sample_splits"
+    description: <<END
+A list of rank 1 Tensors specifying the break points for splitting
+embedding_indices and aggregation_weights into rows.
+It corresponds to ids.row_splits in embedding_lookup(), when ids is a
+RaggedTensor.
+END
+  }
+  in_arg {
+    name: "embedding_indices"
+    description: <<END
+A list of rank 1 Tensors, indices into the embedding tables.
+It corresponds to ids.values in embedding_lookup(), when ids is a RaggedTensor.
+END
+  }
+  in_arg {
+    name: "aggregation_weights"
+    description: <<END
+A list of rank 1 Tensors containing per training example
+aggregation weights. It corresponds to the values field of a RaggedTensor
+with the same row_splits as ids in embedding_lookup(), when ids is a
+RaggedTensor.
+END
+  }
+  in_arg {
+    name: "mode_override"
+    description: <<END
+A string input that overrides the mode specified in the
+TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference',
+'training', 'backward_pass_only'}. When set to 'unspecified', the mode set
+in TPUEmbeddingConfiguration is used, otherwise mode_override is used.
+END
+  }
+  attr {
+    name: "device_ordinal"
+    description: <<END
+The TPU device to use. Should be >= 0 and less than the number
+of TPU cores in the task on which the node is placed.
+END
+  }
+  attr {
+    name: "combiners"
+    description: <<END
+A list of string scalars, one for each embedding table, specifying
+how to normalize the embedding activations after weighted summation.
+Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have
+the sum of the weights be 0 for 'mean' or the sum of the squared weights be
+0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for
+all tables.
+END
+  }
+  attr {
+    name: "table_ids"
+    description: <<END
+A list of integers specifying the identifier of the embedding table
+(offset of TableDescriptor in the TPUEmbeddingConfiguration) to look up the
+corresponding input. The ith input is looked up using table_ids[i]. The size
+of the table_ids list must be equal to that of sample_splits,
+embedding_indices and aggregation_weights.
+END
+  }
+  summary: "Eases the porting of code that uses tf.nn.embedding_lookup()."
+  description: <<END
+sample_splits[i], embedding_indices[i] and aggregation_weights[i] correspond
+to the ith feature. table_ids[i] indicates which embedding table to use for
+the ith feature.
+
+The tensors at corresponding positions in two of the input lists,
+embedding_indices and aggregation_weights, must have the same shape, i.e. rank 1
+with dim_size() equal to the total number of lookups into the table described by
+the corresponding feature.
+END
+}
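As the argument descriptions above indicate, sample_splits and embedding_indices carry the row_splits and values of a RaggedTensor. A minimal sketch of that correspondence, using the standard tf.ragged API (the tensor contents are purely illustrative):

import tensorflow as tf

# Three training examples with 3, 1 and 2 embedding ids respectively.
ids = tf.ragged.constant([[3, 1, 4], [1], [5, 9]], dtype=tf.int64)

sample_splits = ids.row_splits      # [0, 3, 4, 6]: break points between rows
embedding_indices = ids.values      # [3, 1, 4, 1, 5, 9]: flat lookup ids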
@@ -13184,6 +13184,101 @@ op {
   }
   is_stateful: true
 }
+op {
+  name: "EnqueueTPUEmbeddingRaggedTensorBatch"
+  input_arg {
name: "sample_indices"
+    type_attr: "T1"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "embedding_indices"
+    type_attr: "T2"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "aggregation_weights"
+    type_attr: "T3"
+    number_attr: "N"
+  }
+  input_arg {
+    name: "mode_override"
+    type: DT_STRING
+  }
+  attr {
+    name: "T1"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T2"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    name: "T3"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "device_ordinal"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "combiners"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "table_ids"
+    type: "list(int)"
+  }
+  attr {
+    name: "max_sequence_lengths"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  is_stateful: true
+}
 op {
   name: "EnsureShape"
   input_arg {
@@ -168,4 +168,20 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
     .SetIsStateful()
     .SetShapeFn(shape_inference::UnknownShape);
 
+REGISTER_OP("EnqueueTPUEmbeddingRaggedTensorBatch")
+    .Input("sample_splits: N * T1")
+    .Input("embedding_indices: N * T2")
+    .Input("aggregation_weights: N * T3")
+    .Input("mode_override: string")
+    .Attr("T1: {int32,int64} = DT_INT32")
+    .Attr("T2: {int32,int64} = DT_INT32")
+    .Attr("T3: {float32,float64} = DT_FLOAT")
+    .Attr("N: int >= 1")
+    .Attr("device_ordinal: int = -1")
+    .Attr("combiners: list(string) = []")
+    .Attr("table_ids: list(int)")
+    .Attr("max_sequence_lengths: list(int) = []")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::UnknownShape);
+
 }  // namespace tensorflow
@@ -444,3 +444,76 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
 
 enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
     gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
+
+
+# pylint: disable=protected-access
+def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
+                                              embedding_indices,
+                                              aggregation_weights,
+                                              table_ids,
+                                              device_ordinal,
+                                              max_sequence_lengths=None,
+                                              combiners=None,
+                                              mode_override=None,
+                                              name=None):
+  """A placeholder op for enqueueing embedding IDs to the TPU.
+
+  Args:
+    sample_splits: A list of rank 1 Tensors specifying the break points for
+      splitting embedding_indices and aggregation_weights into rows. It
+      corresponds to ids.row_splits in embedding_lookup(), when ids is a
+      RaggedTensor. Both int32 and int64 are allowed and will be converted to
+      int32 internally.
+    embedding_indices: A list of rank 1 Tensors, indices into the embedding
+      tables. It corresponds to ids.values in embedding_lookup(), when ids is a
+      RaggedTensor. Both int32 and int64 are allowed and will be converted to
+      int32 internally.
+    aggregation_weights: A list of rank 1 Tensors containing per training
+      example aggregation weights. It corresponds to the values field of a
+      RaggedTensor with the same row_splits as ids in embedding_lookup(), when
+      ids is a RaggedTensor. Both float32 and float64 are allowed and will be
+      converted to float32 internally.
+    table_ids: A list of integers specifying the identifier of the embedding
+      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
+      look up the corresponding input. The ith input is looked up using
+      table_ids[i]. The size of the table_ids list must be equal to that of
+      sample_splits, embedding_indices and aggregation_weights.
+    device_ordinal: The TPU device to use. Should be >= 0 and less than the
+      number of TPU cores in the task on which the node is placed.
+    max_sequence_lengths: A list of integers, the size of which is equal to
+      that of sample_splits. If an element is 0, the corresponding feature is
+      considered to be a non-sequence feature; if greater than 0, the
+      corresponding feature is a sequence feature with the given maximal
+      length. If None, a list of all zeros is assumed.
+    combiners: A list of string scalars, one for each embedding table,
+      specifying how to normalize the embedding activations after weighted
+      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
+      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
+      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
+      is to use 'sum' for all tables (optional).
+    mode_override: A string input that overrides the mode specified in the
+      TPUEmbeddingConfiguration. Supported values are {'unspecified',
+      'inference', 'training', 'backward_pass_only'}. When set to
+      'unspecified', the mode set in TPUEmbeddingConfiguration is used,
+      otherwise mode_override is used (optional).
+    name: A name for the operation (optional).
+
+  Returns:
+    An EnqueueTPUEmbeddingRaggedTensorBatch operation.
+  """
+  if mode_override is None:
+    mode_override = "unspecified"
+  return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
+      sample_splits=sample_splits,
+      embedding_indices=embedding_indices,
+      aggregation_weights=aggregation_weights,
+      table_ids=table_ids,
+      device_ordinal=device_ordinal,
+      max_sequence_lengths=max_sequence_lengths,
+      combiners=combiners,
+      mode_override=mode_override,
+      name=name)
+
+
+enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = (
+    gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__)
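A hedged usage sketch of the wrapper above, for a single feature and a single table. The import path mirrors the internal module this change edits (tensorflow.python.tpu.ops); the tensors, table id 0 and device_ordinal=0 are illustrative, and actually executing the op requires a configured TPU system:

import tensorflow as tf
from tensorflow.python.tpu.ops import tpu_ops

ids = tf.ragged.constant([[7, 3], [9]], dtype=tf.int32)

enqueue = tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
    sample_splits=[ids.row_splits],
    embedding_indices=[ids.values],
    # An empty float32 vector means "no per-id weights", matching the
    # zeros((0,)) placeholder used by tpu_embedding.py later in this change.
    aggregation_weights=[tf.zeros((0,), dtype=tf.float32)],
    table_ids=[0],
    device_ordinal=0)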
@@ -205,6 +205,48 @@ class EnqueueData(
         aggregation_weights=weights.values if weights is not None else None)
 
 
+class RaggedEnqueueData(
+    collections.namedtuple(
+        'RaggedEnqueueData',
+        ['embedding_indices', 'sample_splits', 'aggregation_weights'])):
+  """RaggedTensor data to be enqueued through generate_enqueue_ops()."""
+
+  def __new__(cls,
+              embedding_indices,
+              sample_splits=None,
+              aggregation_weights=None):
+    """Data to be enqueued through generate_enqueue_ops().
+
+    Args:
+      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
+        corresponds to ids.values in embedding_lookup(), when ids is a
+        RaggedTensor. Both int32 and int64 are allowed and will be converted to
+        int32 internally.
+      sample_splits: A rank 1 Tensor specifying the break points for splitting
+        embedding_indices and aggregation_weights into rows. It corresponds to
+        ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both
+        int32 and int64 are allowed and will be converted to int32 internally.
+      aggregation_weights: A rank 1 Tensor containing per training example
+        aggregation weights. It corresponds to the values field of a
+        RaggedTensor with the same row_splits as ids in embedding_lookup(),
+        when ids is a RaggedTensor.
+
+    Returns:
+      A RaggedEnqueueData tuple.
+    """
+    return super(RaggedEnqueueData,
+                 cls).__new__(cls, embedding_indices, sample_splits,
+                              aggregation_weights)
+
+  @staticmethod
+  def from_ragged_tensor(rg_tensor, weights=None):
+    return RaggedEnqueueData(
+        rg_tensor.values,
+        rg_tensor.row_splits,
+        aggregation_weights=weights.values if weights is not None else None)
+
+
 def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
   """Convenient function for generate_enqueue_ops().
 
@@ -229,6 +271,30 @@ def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
   return enqueue_datas_list
 
 
+def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
+  """Convenience function for generate_enqueue_ops().
+
+  Args:
+    rg_tensors_list: a list of dictionaries mapping from string of feature
+      names to RaggedTensor. Each dictionary is for one TPU core. Dictionaries
+      for the same host should be contiguous in the list.
+
+  Returns:
+    enqueue_datas_list: a list of dictionaries mapping from string of feature
+      names to RaggedEnqueueData. Each dictionary is for one TPU core.
+      Dictionaries for the same host should be contiguous in the list.
+  """
+  enqueue_datas_list = []
+  for rg_tensors in rg_tensors_list:
+    enqueue_datas = collections.OrderedDict(
+        (k, RaggedEnqueueData.from_ragged_tensor(v))
+        for k, v in six.iteritems(rg_tensors))
+    enqueue_datas_list.append(enqueue_datas)
+  return enqueue_datas_list
+
+
 AdamSlotVariableNames = collections.namedtuple(
     'AdamSlotVariableNames', ['m', 'v'])
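A small sketch of the new helpers above in isolation, assuming the internal module path tensorflow.python.tpu.tpu_embedding that this change edits (the feature name and tensor contents are illustrative):

import tensorflow as tf
from tensorflow.python.tpu import tpu_embedding

rg = tf.ragged.constant([[1, 5], [3]])
data = tpu_embedding.RaggedEnqueueData.from_ragged_tensor(rg)
# data.embedding_indices is rg.values; data.sample_splits is rg.row_splits.

# One dictionary per TPU core, mapping feature name -> RaggedTensor.
rg_tensors_list = [
    {'watched': rg},                                  # core 0
    {'watched': tf.ragged.constant([[2], [4, 8]])},   # core 1
]
enqueue_datas_list = (
    tpu_embedding.get_enqueue_datas_list_from_ragged_tensors_list(
        rg_tensors_list))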
@@ -1159,7 +1225,12 @@ class TPUEmbedding(object):
                                      slot_variables_by_table,
                                      load_ops, retrieve_ops)
 
-  def generate_enqueue_ops(self, enqueue_datas_list, mode_override=None):
+  def generate_enqueue_ops(
+      self,
+      enqueue_datas_list,
+      mode_override=None,
+      ragged=False,
+  ):
     """Generate enqueue ops.
 
     Args:
@@ -1172,6 +1243,8 @@ class TPUEmbedding(object):
         'inference', 'training', 'backward_pass_only'}. When set to
         'unspecified', the mode set in TPUEmbeddingConfiguration is used,
         otherwise mode_override is used (optional).
+      ragged: If True, creates RaggedTensor enqueue ops rather than
+        SparseTensor.
 
     Returns:
       Ops to enqueue to TPU for embedding.
@@ -1182,6 +1255,7 @@ class TPUEmbedding(object):
             enqueue_datas,
             device_ordinal=i % self._num_cores_per_host,
             mode_override=mode_override,
+            ragged=ragged,
         ) for i, enqueue_datas in enumerate(enqueue_datas_list)
     ]
 
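Tying the pieces together, a hedged sketch of how the new ragged path might be driven. Here `embedding` stands for an already-constructed tpu_embedding.TPUEmbedding whose feature and table configuration matches the enqueued data; constructing it is not shown:

from tensorflow.python.tpu import tpu_embedding


def build_ragged_enqueue_ops(embedding, rg_tensors_list):
  """Sketch: convert per-core RaggedTensors and enqueue them via the ragged op."""
  enqueue_datas_list = (
      tpu_embedding.get_enqueue_datas_list_from_ragged_tensors_list(
          rg_tensors_list))
  # ragged=True routes _generate_enqueue_op to
  # enqueue_tpu_embedding_ragged_tensor_batch instead of the sparse variant.
  return embedding.generate_enqueue_ops(enqueue_datas_list, ragged=True)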
@@ -1211,28 +1285,50 @@ class TPUEmbedding(object):
       for feature, enqueue_data in six.iteritems(enqueue_datas):
         combiner = self._table_to_config_dict[
             self._feature_to_config_dict[feature].table_id].combiner
-        if not isinstance(enqueue_data, EnqueueData):
-          raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
-                           'not mapped to `EnqueueData`. `feature`: {}'.format(
-                               i, feature))
-
-        if enqueue_data.sample_indices is None and combiner:
-          logging.warn('No sample indices set for features %f table %f but '
-                       'combiner is set to %s.', feature,
-                       self._feature_to_config_dict[feature].table_id, combiner)
-
-        if (enqueue_data.sample_indices is not None and
-            enqueue_data.sample_indices.device !=
-            enqueue_data.embedding_indices.device):
-          raise ValueError(
-              'Device of sample_indices does not agree with '
-              'that of embedding_indices for feature {}.'.format(feature))
-        if (enqueue_data.aggregation_weights is not None and
-            enqueue_data.aggregation_weights.device !=
-            enqueue_data.embedding_indices.device):
-          raise ValueError(
-              'Device of aggregation_weights does not agree with '
-              'that of embedding_indices for feature {}.'.format(feature))
+        if isinstance(enqueue_data, EnqueueData):
+          if enqueue_data.sample_indices is None and combiner:
+            logging.warn(
+                'No sample indices set for feature %s table %s but '
+                'combiner is set to %s.', feature,
+                self._feature_to_config_dict[feature].table_id, combiner)
+          if (enqueue_data.sample_indices is not None and
+              enqueue_data.sample_indices.device !=
+              enqueue_data.embedding_indices.device):
+            raise ValueError(
+                'Device of sample_indices does not agree with '
+                'that of embedding_indices for feature {}.'.format(feature))
+          if (enqueue_data.aggregation_weights is not None and
+              enqueue_data.aggregation_weights.device !=
+              enqueue_data.embedding_indices.device):
+            raise ValueError(
+                'Device of aggregation_weights does not agree with '
+                'that of embedding_indices for feature {}.'.format(feature))
+        elif isinstance(enqueue_data, RaggedEnqueueData):
+          if enqueue_data.sample_splits is None and combiner:
+            logging.warn(
+                'No sample splits set for feature %s table %s but '
+                'combiner is set to %s.', feature,
+                self._feature_to_config_dict[feature].table_id, combiner)
+          if (enqueue_data.sample_splits is not None and
+              enqueue_data.sample_splits.device !=
+              enqueue_data.embedding_indices.device):
+            raise ValueError(
+                'Device of sample_splits does not agree with '
+                'that of embedding_indices for feature {}.'.format(feature))
+          if (enqueue_data.aggregation_weights is not None and
+              enqueue_data.aggregation_weights.device !=
+              enqueue_data.embedding_indices.device):
+            raise ValueError(
+                'Device of aggregation_weights does not agree with '
+                'that of embedding_indices for feature {}.'.format(feature))
+        else:
+          raise ValueError(
+              '`enqueue_datas_list[{}]` has a feature that is not mapped to '
+              '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format(
+                  i, feature))
 
         # Check all features are on the same device.
         if device is None:
           device = enqueue_data.embedding_indices.device
@@ -1257,23 +1353,69 @@ class TPUEmbedding(object):
       else:
         contiguous_device = device
 
-  def _generate_enqueue_op(
-      self, enqueue_datas, device_ordinal, mode_override=None):
+  def _generate_enqueue_op(self,
+                           enqueue_datas,
+                           device_ordinal,
+                           mode_override=None,
+                           ragged=False):
     """Creates op for enqueuing batch to TPU."""
     enqueue_data0 = list(enqueue_datas.values())[0]
     with ops.colocate_with(enqueue_data0.embedding_indices):
-      return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
-          device_ordinal=device_ordinal,
-          combiners=self._combiners,
-          mode_override=mode_override,
-          **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas)
-      )
+      if ragged:
+        # Note that this is currently identical in behavior.
+        return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
+            device_ordinal=device_ordinal,
+            combiners=self._combiners,
+            mode_override=mode_override,
+            **self._format_for_tpu_embedding_ragged_tensor_batch(enqueue_datas))
+      else:
+        return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
+            device_ordinal=device_ordinal,
+            combiners=self._combiners,
+            mode_override=mode_override,
+            **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
+
+  def _format_for_tpu_embedding_ragged_tensor_batch(self, enqueue_datas):
+    """Format ragged features for `enqueue_tpu_embedding_ragged_tensor_batch()`.
+
+    Args:
+      enqueue_datas: a `Dict` of `RaggedEnqueueData` objects for embedding.
+
+    Returns:
+      Dict of arguments for `enqueue_tpu_embedding_ragged_tensor_batch()`.
+    """
+
+    kwargs = {
+        'sample_splits': [],
+        'embedding_indices': [],
+        'aggregation_weights': [],
+        'table_ids': [],
+        'max_sequence_lengths': [],
+    }
+    for table_id, table in enumerate(self._table_to_features_dict):
+      features = self._table_to_features_dict[table]
+      for feature in features:
+        enqueue_data = enqueue_datas[feature]
+
+        kwargs['sample_splits'].append(enqueue_data.sample_splits)
+
+        kwargs['aggregation_weights'].append(
+            enqueue_data.aggregation_weights if enqueue_data.aggregation_weights
+            is not None else array_ops.zeros((0,), dtype=dtypes.float32))
+
+        kwargs['embedding_indices'].append(enqueue_data.embedding_indices)
+
+        kwargs['table_ids'].append(table_id)
+        kwargs['max_sequence_lengths'].append(
+            self._feature_to_config_dict[feature].max_sequence_length)
+
+    return kwargs
 
   def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
     """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
 
     Args:
-      enqueue_datas: a `Dict` of tensors for embedding. Can be sparse or
-        dense.
+      enqueue_datas: a `Dict` of `EnqueueData` objects for embedding.
 
     Returns:
       Dict of arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
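For concreteness, a sketch of the argument structure the ragged formatting helper above assembles: parallel, per-feature lists, with an empty float32 vector standing in for missing aggregation weights. Two features sharing table 0 are assumed here purely for illustration:

import tensorflow as tf

ids_a = tf.ragged.constant([[1, 2], [3]])
ids_b = tf.ragged.constant([[4], [5, 6]])

kwargs = {
    'sample_splits': [ids_a.row_splits, ids_b.row_splits],
    'embedding_indices': [ids_a.values, ids_b.values],
    'aggregation_weights': [tf.zeros((0,), dtype=tf.float32)] * 2,
    'table_ids': [0, 0],
    'max_sequence_lengths': [0, 0],
}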
@@ -1244,6 +1244,10 @@ tf_module {
     name: "EnqueueTPUEmbeddingIntegerBatch"
     argspec: "args=[\'batch\', \'mode_override\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
   }
+  member_method {
+    name: "EnqueueTPUEmbeddingRaggedTensorBatch"
+    argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
+  }
   member_method {
     name: "EnqueueTPUEmbeddingSparseBatch"
     argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'device_ordinal\', \'combiners\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'None\'], "
@@ -1244,6 +1244,10 @@ tf_module {
     name: "EnqueueTPUEmbeddingIntegerBatch"
     argspec: "args=[\'batch\', \'mode_override\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
   }
+  member_method {
+    name: "EnqueueTPUEmbeddingRaggedTensorBatch"
+    argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
+  }
   member_method {
     name: "EnqueueTPUEmbeddingSparseBatch"
     argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'device_ordinal\', \'combiners\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'None\'], "