Support aggregation weights for embedding lookup with TPU in TPUEstimator.
PiperOrigin-RevId: 241827445
Commit bb5880c426 (parent bf3bd1c026)
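
A minimal usage sketch of the new API (illustrative only, not part of this
change; the feature name `video_ids`, the example values, and the
already-configured TPUEmbedding object `tpu_embedding_` are assumptions):

    import collections
    import tensorflow as tf
    from tensorflow.python.tpu import tpu_embedding

    ids = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                          values=[3, 7, 5], dense_shape=[2, 2])
    weights = tf.SparseTensor(indices=ids.indices, values=[0.5, 1.5, 1.0],
                              dense_shape=ids.dense_shape)

    # Aggregation weights ride along with the ids; passing weights=None
    # means "all weights are 1".
    enqueue_datas = collections.OrderedDict(
        video_ids=tpu_embedding.EnqueueData.from_sparse_tensor(ids, weights))

    # One dict per TPU core; dicts for the same host must be contiguous.
    enqueue_ops = tpu_embedding_.generate_enqueue_ops([enqueue_datas])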
--- a/tensorflow/python/tpu/_tpu_estimator_embedding.py
+++ b/tensorflow/python/tpu/_tpu_estimator_embedding.py
@@ -20,13 +20,13 @@ from __future__ import print_function
 
 import collections
 
-import six
-
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.feature_column import feature_column as core_fc
 from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import math_ops
 from tensorflow.python.tpu import feature_column as tpu_fc
 from tensorflow.python.tpu import tpu_embedding
 from tensorflow.python.tpu.tpu_embedding import AdagradParameters
@@ -107,18 +107,15 @@ def get_full_variable_names(
   return embedding_variable_name_by_table, slot_variable_names_by_table
 
 
-def get_tpu_embedding_config_from_feature_columns(feature_columns):
-  """Create configs for TPUEmbedding from a list of feature columns.
+def get_configs_from_feature_columns(feature_columns):
+  """Create configs for TPUEmbedding etc. from a list of feature columns.
 
-  This function will place one embedding tensor per table and the return is
-  intended to be used as input to TPUEmbedding.
-
   Args:
     feature_columns: a list of supported feature columns.
 
   Returns:
-    A pair of dicts, the first maps tables to their config, the second maps
-    features to tables.
+    A tuple of dicts, the first maps tables to their config, the second maps
+    features to tables, and the third maps features to weight key names.
   """
 
   allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn)  # pylint: disable=protected-access
@@ -131,6 +128,7 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns):
 
   table_to_config = {}
   feature_to_table = {}
+  feature_to_weight_key_name = {}
   for column in feature_columns:
     feature_name = column.get_feature_key_name()
     table_name = _get_table_name_from_embedding_var_name(
@@ -140,6 +138,7 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns):
           'Feature column {} is used with multiple embeddings and this is '
           'not supported.'.format(feature_name))
     feature_to_table[feature_name] = table_name
+    feature_to_weight_key_name[feature_name] = column.get_weight_key_name()
     vocabulary_size, dimension = column.get_embedding_table_size()
     table_to_config[table_name] = tpu_embedding.TableConfig(
         vocabulary_size=vocabulary_size,
@@ -147,7 +146,7 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns):
         initializer=column.get_initializer(),
         combiner=column.get_combiner())
 
-  return table_to_config, feature_to_table
+  return table_to_config, feature_to_table, feature_to_weight_key_name
 
 
 class EmbeddingConfigSpec(
@@ -239,9 +238,10 @@ class EmbeddingConfig(object):
     self._num_cores = num_cores
     self._run_config = run_config
 
-    self._table_to_config_dict, self._feature_to_table_dict = (
-        get_tpu_embedding_config_from_feature_columns(
-            embedding_config_spec.feature_columns))
+    (self._table_to_config_dict, self._feature_to_table_dict,
+     self.feature_to_weight_key_name_dict) = (
+         get_configs_from_feature_columns(
+             embedding_config_spec.feature_columns))
     self._mode_to_tpu_embedding_dict = {}
     self.dummy_table_variables = None
 
@@ -305,19 +305,61 @@ class EmbeddingConfig(object):
 
 def split_inputs(ctx, features, labels):
   """Splits the dense and sparse tensors inside the features and labels."""
-  sparse_features = collections.OrderedDict()
+  enqueue_datas = collections.OrderedDict()
   if ctx.embedding_config:
     tpu_embedding_ = ctx.embedding_config.tpu_embedding
+    feature_to_weight_key_name_dict = (
+        ctx.embedding_config.feature_to_weight_key_name_dict)
     for feature_key in tpu_embedding_.feature_to_table_dict:
-      sparse_features[feature_key] = features.pop(feature_key)
+      sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
+      weight_key_name = feature_to_weight_key_name_dict[feature_key]
+      if isinstance(sparse_feature, sparse_tensor.SparseTensor):
+        weights = _get_weights_from_features(weight_key_name, features)
+        enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(
+            sparse_feature, weights)
+      else:
+        if weight_key_name is not None:
+          raise ValueError(
+              'Found weights {} for weighted_categorical_column, which is not '
+              'compatible with sparse feature {} enqueued as dense tensor.'
+              .format(weight_key_name, feature_key))
+        enqueue_data = tpu_embedding.EnqueueData(sparse_feature)
+      enqueue_datas[feature_key] = enqueue_data
 
-  for v in six.itervalues(sparse_features):
-    if not v.dtype.is_integer:
-      raise ValueError('SparseTensor with string as values are not supported. '
-                       'If you are using vocabulary_file_categorical_column or '
-                       'vocabulary_list_categorical_column, please call '
-                       'your_column.categorical_column._transform_feature({'
-                       'your_column.key: features[your_column.key]}) in'
-                       'your input_fn() to convert string to int.')
+  return features, labels, enqueue_datas
 
-  return features, labels, sparse_features
+
+def _get_sparse_feature_from_feature(feature_key, features):
+  """Pop and return sparse feature."""
+  sparse_feature = features.pop(feature_key)
+  if not sparse_feature.dtype.is_integer:
+    raise ValueError('SparseTensor with string as values are not supported. '
+                     'If you are using vocabulary_file_categorical_column or '
+                     'vocabulary_list_categorical_column, please call '
+                     'your_column.categorical_column._transform_feature({{'
+                     'your_column.key: features[your_column.key]}}) in '
+                     'your input_fn() to convert string to int. '
+                     'feature_key = {}.'.format(feature_key))
+  return sparse_feature
+
+
+def _get_weights_from_features(weight_key_name, features):
+  """Pop and return feature for weights, possibly None."""
+  weights = None
+  if weight_key_name is not None:
+    if weight_key_name in features:
+      weights = features.pop(weight_key_name)
+    else:
+      raise ValueError(
+          'Cannot find weights {} for weighted_categorical_column.'
+          ' Please check if the weights are present in feature dict. Also'
+          ' note weight-sharing among weighted_categorical_column is not '
+          'supported on TPU.'.format(weight_key_name))
+    if not isinstance(weights, sparse_tensor.SparseTensor):
+      raise ValueError(
+          'weighted_categorical_column with weight key name {} has dense '
+          'weights. Dense weights are not supported on TPU. Please use '
+          'sparse weights instead.'.format(weight_key_name))
+    if weights.dtype is not dtypes.float32:
+      weights = math_ops.to_float(weights)
+  return weights
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -28,7 +28,6 @@ from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
 from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
@@ -97,6 +96,73 @@ class TableConfig(
                                            initializer, combiner)
 
 
+class EnqueueData(
+    collections.namedtuple(
+        'EnqueueData',
+        ['embedding_indices', 'sample_indices', 'aggregation_weights'])):
+  """Data to be enqueued through generate_enqueue_ops()."""
+
+  def __new__(cls,
+              embedding_indices,
+              sample_indices=None,
+              aggregation_weights=None):
+    """Data to be enqueued through generate_enqueue_ops().
+
+    Args:
+      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
+        corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
+        and int64 are allowed and will be converted to int32 internally.
+      sample_indices: A rank 2 Tensor specifying the training example to which
+        the corresponding embedding_indices and aggregation_weights values
+        belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
+        If it is None, we assume each embedding_indices belongs to a different
+        sample. Both int32 and int64 are allowed and will be converted to int32
+        internally.
+      aggregation_weights: A rank 1 Tensor containing per training example
+        aggregation weights. It corresponds to sp_weights.values in
+        embedding_lookup_sparse(). If it is None, we assume all weights are 1.
+        Both float32 and float64 are allowed and will be converted to float32
+        internally.
+
+    Returns:
+      An EnqueueData tuple.
+
+    """
+    return super(EnqueueData, cls).__new__(cls, embedding_indices,
+                                           sample_indices, aggregation_weights)
+
+  @staticmethod
+  def from_sparse_tensor(sp_tensor, weights=None):
+    return EnqueueData(
+        sp_tensor.values,
+        sp_tensor.indices,
+        aggregation_weights=weights.values if weights is not None else None)
+
+
+def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
+  """Convenience function for generate_enqueue_ops().
+
+  Args:
+    sp_tensors_list: a list of dictionaries mapping from string of feature
+      names to SparseTensor. Each dictionary is for one TPU core. Dictionaries
+      for the same host should be contiguous in the list.
+
+  Returns:
+    enqueue_datas_list: a list of dictionaries mapping from string
+      of feature names to EnqueueData. Each dictionary is for one
+      TPU core. Dictionaries for the same host should be contiguous
+      in the list.
+
+  """
+  enqueue_datas_list = []
+  for sp_tensors in sp_tensors_list:
+    enqueue_datas = collections.OrderedDict(
+        (k, EnqueueData.from_sparse_tensor(v))
+        for k, v in six.iteritems(sp_tensors))
+    enqueue_datas_list.append(enqueue_datas)
+  return enqueue_datas_list
+
+
 AdamSlotVariableNames = collections.namedtuple(
     'AdamSlotVariableNames', ['m', 'v'])
 
@@ -564,119 +630,148 @@ class TPUEmbedding(object):
                            slot_variables_by_table,
                            load_ops, retrieve_ops)
 
-  def generate_enqueue_ops(self, sparse_features_list):
+  def generate_enqueue_ops(self, enqueue_datas_list):
     """Generate enqueue ops.
 
     Args:
-      sparse_features_list: a list of dictionary mapping from string
-        of feature names to sparse tensor. Each dictionary is for one
+      enqueue_datas_list: a list of dictionaries mapping from string
+        of feature names to EnqueueData. Each dictionary is for one
         TPU core. Dictionaries for the same host should be contiguous
         on the list.
 
     Returns:
       Ops to enqueue to TPU for embedding.
     """
-    self._validate_generate_enqueue_ops_sparse_features_list(
-        sparse_features_list)
+    self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list)
     return [
         self._generate_enqueue_op(
-            sparse_features, device_ordinal=i % self._num_cores_per_host)
-        for i, sparse_features in enumerate(sparse_features_list)
+            enqueue_datas, device_ordinal=i % self._num_cores_per_host)
+        for i, enqueue_datas in enumerate(enqueue_datas_list)
     ]
 
-  def _validate_generate_enqueue_ops_sparse_features_list(
-      self, sparse_features_list):
-    """Validate `sparse_features_list`."""
+  def _validate_generate_enqueue_ops_enqueue_datas_list(self,
+                                                        enqueue_datas_list):
+    """Validate `enqueue_datas_list`."""
     feature_set = set(self._feature_to_table_dict.keys())
     contiguous_device = None
-    for i, sparse_features in enumerate(sparse_features_list):
-      used_feature_set = set(sparse_features.keys())
+    for i, enqueue_datas in enumerate(enqueue_datas_list):
+      used_feature_set = set(enqueue_datas.keys())
 
       # Check features are valid.
       missing_feature_set = feature_set - used_feature_set
       if missing_feature_set:
-        raise ValueError('`sparse_features_list[{}]` misses a feature that is '
+        raise ValueError('`enqueue_datas_list[{}]` misses a feature that is '
                          'in `feature_to_config_dict`: {}.'.format(
                              i, missing_feature_set))
 
       extra_feature_set = used_feature_set - feature_set
       if extra_feature_set:
-        raise ValueError('`sparse_features_list[{}]` has a feature that is not '
+        raise ValueError('`enqueue_datas_list[{}]` has a feature that is not '
                          'in `feature_to_config_dict`: {}.'.format(
                              i, extra_feature_set))
 
       device = None
       device_feature = None
-      for feature, tensor in six.iteritems(sparse_features):
+      for feature, enqueue_data in six.iteritems(enqueue_datas):
         combiner = self._table_to_config_dict[
             self._feature_to_table_dict[feature]].combiner
-        if not isinstance(tensor, sparse_tensor.SparseTensor) and combiner:
-          raise ValueError('`sparse_features_list[{}]` has a feature that is '
-                           'not mapped to `SparseTensor` and has a combiner. '
-                           '`feature`: {}, combiner: {}'.format(
-                               i, feature, combiner))
+        if not isinstance(enqueue_data, EnqueueData):
+          raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
+                           'not mapped to `EnqueueData`. `feature`: {}'.format(
+                               i, feature))
+
+        if enqueue_data.sample_indices is None and combiner:
+          raise ValueError('`enqueue_datas_list[{}]` has a feature that has '
+                           'a combiner but no `sample_indices`. '
+                           '`feature`: {}, combiner: {}.'.format(
+                               i, feature, combiner))
+
+        if (enqueue_data.sample_indices is not None and
+            enqueue_data.sample_indices.op.device !=
+            enqueue_data.embedding_indices.op.device):
+          raise ValueError(
+              'Device of sample_indices does not agree with '
+              'that of embedding_indices for feature {}.'.format(feature))
+        if (enqueue_data.aggregation_weights is not None and
+            enqueue_data.aggregation_weights.op.device !=
+            enqueue_data.embedding_indices.op.device):
+          raise ValueError(
+              'Device of aggregation_weights does not agree with '
+              'that of embedding_indices for feature {}.'.format(feature))
 
         # Check all features are on the same device.
         if device is None:
-          device = tensor.op.device
+          device = enqueue_data.embedding_indices.op.device
           device_feature = feature
         else:
-          if device != tensor.op.device:
+          if device != enqueue_data.embedding_indices.op.device:
             raise ValueError('Devices are different between features in '
-                             '`sparse_features_list[{}]`; '
+                             '`enqueue_datas_list[{}]`; '
                              'devices: {}, {}; features: {}, {}.'.format(
-                                 i, device, tensor.op.device, feature,
-                                 device_feature))
+                                 i, device,
+                                 enqueue_data.embedding_indices.op.device,
+                                 feature, device_feature))
 
       if i % self._num_cores_per_host:
         if device != contiguous_device:
-          raise ValueError('We expect the `sparse_features` which are on the '
+          raise ValueError('We expect the `enqueue_datas` which are on the '
                            'same host to be contiguous in '
-                           '`sparse_features_list`, '
-                           '`sparse_features_list[{}]` is on device {}, '
+                           '`enqueue_datas_list`, '
+                           '`enqueue_datas_list[{}]` is on device {}, '
                            'but is expected to be on device {}.'.format(
                                i, device, contiguous_device))
       else:
         contiguous_device = device
 
-  def _generate_enqueue_op(self, sparse_features, device_ordinal):
-    with ops.colocate_with(list(sparse_features.values())[0]):
-      sample_idcs, embedding_idcs, aggregation_weights, table_ids = (
-          self._format_for_tpu_embedding_sparse_tensor_batch(sparse_features))
+  def _generate_enqueue_op(self, enqueue_datas, device_ordinal):
+    enqueue_data0 = list(enqueue_datas.values())[0]
+    with ops.colocate_with(enqueue_data0.embedding_indices):
+      (sample_indices_list, embedding_indices_list, aggregation_weights_list,
+       table_ids) = (
+           self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
       return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
-          sample_idcs,
-          embedding_idcs,
-          aggregation_weights,
+          sample_indices_list,
+          embedding_indices_list,
+          aggregation_weights_list,
          table_ids,
          device_ordinal=device_ordinal,
          combiners=self._combiners)
 
-  def _format_for_tpu_embedding_sparse_tensor_batch(self, sparse_features):
+  def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
     """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
 
     Args:
-      sparse_features: a `Dict` of tensors for embedding. Can be sparse or
+      enqueue_datas: a `Dict` of tensors for embedding. Can be sparse or
        dense.
 
     Returns:
       Arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
     """
 
-    sample_idcs, embedding_idcs, aggregation_weights, table_ids = [], [], [], []
+    (sample_indices_list, embedding_indices_list, aggregation_weights_list,
+     table_ids) = [], [], [], []
     for table_id, table in enumerate(self._table_to_features_dict):
       features = self._table_to_features_dict[table]
       for feature in features:
-        tensor = sparse_features[feature]
-        if not isinstance(tensor, sparse_tensor.SparseTensor):
-          sample_idcs.append(array_ops.zeros([0], dtype=dtypes.int32))
-          embedding_idcs.append(tensor)
-        else:
-          sample_idcs.append(tensor.indices)
-          embedding_idcs.append(tensor.values)
-          aggregation_weights.append(array_ops.zeros([0]))
+        enqueue_data = enqueue_datas[feature]
+        sample_indices = (
+            enqueue_data.sample_indices
+            if enqueue_data.sample_indices is not None else array_ops.zeros(
+                (0,), dtype=dtypes.int32))
+        sample_indices_list.append(sample_indices)
+
+        aggregation_weights = (
+            enqueue_data.aggregation_weights if
+            enqueue_data.aggregation_weights is not None else array_ops.zeros(
+                (0,), dtype=dtypes.float32))
+        aggregation_weights_list.append(aggregation_weights)
+
+        embedding_indices_list.append(enqueue_data.embedding_indices)
+
         table_ids.append(table_id)
 
-    return sample_idcs, embedding_idcs, aggregation_weights, table_ids
+    return (sample_indices_list, embedding_indices_list,
+            aggregation_weights_list, table_ids)
 
   def get_activations(self):
     """Get activations for features.
--- a/tensorflow/python/tpu/tpu_estimator.py
+++ b/tensorflow/python/tpu/tpu_estimator.py
@@ -891,7 +891,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
   """Generates the per_host enqueue ops."""
   control_deps = []
   per_host_sharded_inputs = []
-  sparse_features_list = []
+  enqueue_datas_list = []
   num_replicas_per_host = ctx.num_of_replicas_per_host
   cached_signals = None
   with ops.device(device):
@@ -910,9 +910,9 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
       else:
         cached_signals = signals
 
-      features, labels, sparse_features = (
+      features, labels, enqueue_data = (
           _tpu_estimator_embedding.split_inputs(ctx, features, labels))
-      sparse_features_list.append(sparse_features)
+      enqueue_datas_list.append(enqueue_data)
 
       inputs_structure_recorder.validate_and_record_structure(
           features, labels)
@@ -945,7 +945,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
     if ctx.embedding_config:
       per_host_enqueue_ops.extend(
           ctx.embedding_config.tpu_embedding.generate_enqueue_ops(
-              sparse_features_list))
+              enqueue_datas_list))
 
     if signals is None:
       return per_host_enqueue_ops
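
Migration note: callers that previously passed a list of per-core
{feature_name: SparseTensor} dicts to generate_enqueue_ops() can wrap each
SparseTensor via the new helper, which builds EnqueueData with no
aggregation weights. A rough sketch (assuming `sp_tensors_list` is such a
list and `tpu_embedding_` is a configured TPUEmbedding instance):

    from tensorflow.python.tpu import tpu_embedding

    enqueue_datas_list = (
        tpu_embedding.get_enqueue_datas_list_from_sparse_tensors_list(
            sp_tensors_list))
    enqueue_ops = tpu_embedding_.generate_enqueue_ops(enqueue_datas_list)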