# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU Feature Column Library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math

import enum

from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu.feature_column import _is_running_on_cpu
from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name
from tensorflow.python.tpu.feature_column import _SUPPORTED_CATEGORICAL_COLUMNS_V2
from tensorflow.python.tpu.feature_column import _SUPPORTED_SEQUENCE_COLUMNS
from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access

_ALLOWED_DEVICES = ['cpu', 'tpu_tensor_core', 'tpu_embedding_core']
_TENSOR_CORE_MASK_KEY_SUFFIX = '__TENSOR_CORE_MASK'

class EmbeddingDevice(enum.Enum):
  CPU = 1
  TPU_TENSOR_CORE = 2
  TPU_EMBEDDING_CORE = 3

@tf_export(v1=['tpu.experimental.embedding_column'])
def embedding_column_v2(categorical_column,
                        dimension,
                        combiner='mean',
                        initializer=None,
                        max_sequence_length=0,
                        learning_rate_fn=None,
                        embedding_lookup_device=None,
                        tensor_core_shape=None,
                        use_safe_embedding_lookup=True):
  """TPU version of `tf.compat.v1.feature_column.embedding_column`.

  Note that the interface for `tf.tpu.experimental.embedding_column` is
  different from that of `tf.compat.v1.feature_column.embedding_column`: The
  following arguments are NOT supported: `ckpt_to_load_from`,
  `tensor_name_in_ckpt`, `max_norm` and `trainable`.

  Use this function in place of `tf.compat.v1.feature_column.embedding_column`
  when you want to use the TPU to accelerate your embedding lookups via TPU
  embeddings.

  ```
  column = tf.feature_column.categorical_column_with_identity(...)
  tpu_column = tf.tpu.experimental.embedding_column(column, 10)
  ...
  def model_fn(features):
    dense_feature = tf.keras.layers.DenseFeatures([tpu_column])
    embedded_feature = dense_feature(features)
    ...

  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=model_fn,
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          feature_columns=[tpu_column],
          ...))
  ```

  Args:
    categorical_column: A categorical column returned from
      `categorical_column_with_identity`, `weighted_categorical_column`,
      `categorical_column_with_vocabulary_file`,
      `categorical_column_with_vocabulary_list`,
      `sequence_categorical_column_with_identity`,
      `sequence_categorical_column_with_vocabulary_file`,
      `sequence_categorical_column_with_vocabulary_list`
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row for a non-sequence column. For more information, see
      `tf.feature_column.embedding_column`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    max_sequence_length: A non-negative integer specifying the max sequence
      length. Any sequence shorter than this will be padded with 0 embeddings
      and any sequence longer will be truncated. This must be positive for
      sequence features and 0 for non-sequence features.
    learning_rate_fn: A function that takes global step and returns learning
      rate for the embedding table. If you intend to use the same learning
      rate for multiple embedding tables, please ensure that you pass the
      exact same python function to all calls of embedding_column, otherwise
      performance may suffer.
    embedding_lookup_device: The device on which to run the embedding lookup.
      Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core".
      If specifying "tpu_tensor_core", a tensor_core_shape must be supplied.
      If not specified, the default behavior is embedding lookup on
      "tpu_embedding_core" for training and "cpu" for inference.
      Valid options for training : ["tpu_embedding_core", "tpu_tensor_core"]
      Valid options for serving : ["cpu", "tpu_tensor_core"]
      For training, tpu_embedding_core is good for large embedding vocab (>1M),
      otherwise, tpu_tensor_core is often sufficient.
      For serving, doing embedding lookup on tpu_tensor_core during serving is
      a way to reduce host cpu usage in cases where that is a bottleneck.
    tensor_core_shape: If supplied, a list of integers which specifies the
      intended dense shape to run embedding lookup for this feature on
      TensorCore. The batch dimension can be left None or -1 to indicate a
      dynamic shape. Only rank 2 shapes are currently supported.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true, consider turning off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error though the output result might be 0 or omitted.

  Returns:
    A `_TPUEmbeddingColumnV2`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if `initializer` is specified but not callable.
  """

  if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS_V2):
    raise TypeError(
        'categorical_column for tpu '
        'embedding_column must be type %s, got %s.' % (' or '.join([
            cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS_V2
        ]), type(categorical_column)))
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if tensor_core_shape and len(tensor_core_shape) != 2:
    raise ValueError(
        'tensor_core_shape must be size 2. Got {}.'.format(tensor_core_shape))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  if (embedding_lookup_device and
      embedding_lookup_device not in _ALLOWED_DEVICES):
    raise ValueError(
        'If set, embedding_lookup_device must be in {}.'.format(
            _ALLOWED_DEVICES))

  if embedding_lookup_device == 'cpu':
    embedding_lookup_device = EmbeddingDevice.CPU
  elif embedding_lookup_device == 'tpu_tensor_core':
    embedding_lookup_device = EmbeddingDevice.TPU_TENSOR_CORE
  elif embedding_lookup_device == 'tpu_embedding_core':
    embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE

  if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE:
    if not tensor_core_shape:
      raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
                       'tensor_core_shape to be set.')
    if isinstance(categorical_column, _SUPPORTED_SEQUENCE_COLUMNS):
      raise ValueError('embedding_lookup_device=tpu_tensor_core currently does '
                       'not support sequence columns.')

  if not embedding_lookup_device:
    return _TPUEmbeddingColumnV2(
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn,
        use_safe_embedding_lookup=use_safe_embedding_lookup)
  else:
    return _TPUDeviceSpecificEmbeddingColumnV2(
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn,
        embedding_lookup_device=embedding_lookup_device,
        tensor_core_shape=tensor_core_shape,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

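# Illustrative sketch (not part of the library): requesting a TensorCore
# lookup requires embedding_lookup_device='tpu_tensor_core' together with a
# rank-2 tensor_core_shape. The feature name and sizes below are hypothetical.
#
#   video_id = tf.feature_column.categorical_column_with_identity(
#       'video_id', num_buckets=100000)
#   tpu_column = tf.tpu.experimental.embedding_column(
#       video_id,
#       dimension=32,
#       embedding_lookup_device='tpu_tensor_core',
#       tensor_core_shape=[None, 5])  # batch dim left dynamic
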
@tf_export(v1=['tpu.experimental.shared_embedding_columns'])
def shared_embedding_columns_v2(categorical_columns,
                                dimension,
                                combiner='mean',
                                initializer=None,
                                shared_embedding_collection_name=None,
                                max_sequence_lengths=None,
                                learning_rate_fn=None,
                                embedding_lookup_device=None,
                                tensor_core_shape=None,
                                use_safe_embedding_lookup=True):
  """TPU version of `tf.compat.v1.feature_column.shared_embedding_columns`.

  Note that the interface for `tf.tpu.experimental.shared_embedding_columns` is
  different from that of `tf.compat.v1.feature_column.shared_embedding_columns`:
  The following arguments are NOT supported: `ckpt_to_load_from`,
  `tensor_name_in_ckpt`, `max_norm` and `trainable`.

  Use this function in place of
  `tf.compat.v1.feature_column.shared_embedding_columns` when you want to use
  the TPU to accelerate your embedding lookups via TPU embeddings.

  ```
  column_a = tf.feature_column.categorical_column_with_identity(...)
  column_b = tf.feature_column.categorical_column_with_identity(...)
  tpu_columns = tf.tpu.experimental.shared_embedding_columns(
      [column_a, column_b], 10)
  ...
  def model_fn(features):
    dense_feature = tf.keras.layers.DenseFeatures(tpu_columns)
    embedded_feature = dense_feature(features)
    ...

  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=model_fn,
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          feature_columns=tpu_columns,
          ...))
  ```

  Args:
    categorical_columns: A list of categorical columns returned from
      `categorical_column_with_identity`, `weighted_categorical_column`,
      `categorical_column_with_vocabulary_file`,
      `categorical_column_with_vocabulary_list`,
      `sequence_categorical_column_with_identity`,
      `sequence_categorical_column_with_vocabulary_file`,
      `sequence_categorical_column_with_vocabulary_list`
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row for a non-sequence column. For more information, see
      `tf.feature_column.embedding_column`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
      `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional name of the collection where
      shared embedding weights are added. If not given, a reasonable name will
      be chosen based on the names of `categorical_columns`. This is also used
      in `variable_scope` when creating shared embedding weights.
    max_sequence_lengths: A list of non-negative integers, either None or
      empty or the same length as the argument categorical_columns. Entries
      corresponding to non-sequence columns must be 0 and entries
      corresponding to sequence columns specify the max sequence length for
      the column. Any sequence shorter than this will be padded with 0
      embeddings and any sequence longer will be truncated.
    learning_rate_fn: A function that takes global step and returns learning
      rate for the embedding table. If you intend to use the same learning
      rate for multiple embedding tables, please ensure that you pass the
      exact same python function to all calls of shared_embedding_columns,
      otherwise performance may suffer.
    embedding_lookup_device: The device on which to run the embedding lookup.
      Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core". If
      specifying "tpu_tensor_core", a tensor_core_shape must be supplied.
      If not specified, the default behavior is embedding lookup on
      "tpu_embedding_core" for training and "cpu" for inference.
      Valid options for training : ["tpu_embedding_core", "tpu_tensor_core"]
      Valid options for serving : ["cpu", "tpu_tensor_core"]
      For training, tpu_embedding_core is good for large embedding vocab (>1M),
      otherwise, tpu_tensor_core is often sufficient.
      For serving, doing embedding lookup on tpu_tensor_core during serving is
      a way to reduce host cpu usage in cases where that is a bottleneck.
    tensor_core_shape: If supplied, a list of integers which specifies the
      intended dense shape to run embedding lookup for this feature on
      TensorCore. The batch dimension can be left None or -1 to indicate a
      dynamic shape. Only rank 2 shapes are currently supported.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true, consider turning off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error though the output result might be 0 or omitted.

  Returns:
    A list of `_TPUSharedEmbeddingColumnV2`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if `initializer` is specified but not callable.
    ValueError: if `max_sequence_lengths` is specified and not the same length
      as `categorical_columns`.
    ValueError: if `max_sequence_lengths` is positive for a non-sequence column
      or 0 for a sequence column.
  """

  for categorical_column in categorical_columns:
    if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS_V2):
      raise TypeError(
          'categorical_column for tpu '
          'shared_embedding_columns must be type %s, got %s.' % (' or '.join([
              cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS_V2
          ]), type(categorical_column)))

  if not max_sequence_lengths:
    max_sequence_lengths = [0] * len(categorical_columns)
  if len(max_sequence_lengths) != len(categorical_columns):
    raise ValueError('max_sequence_lengths and categorical_columns must be of '
                     'the same length. len(max_sequence_lengths)={} '
                     'len(categorical_columns)={}.'.format(
                         len(max_sequence_lengths), len(categorical_columns)))

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if tensor_core_shape and len(tensor_core_shape) != 2:
    raise ValueError(
        'tensor_core_shape must be size 2. Got {}.'.format(tensor_core_shape))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if
  # the user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
  num_buckets = sorted_columns[0]._num_buckets  # pylint: disable=protected-access

  for c in sorted_columns[1:]:
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              sorted_columns[0], num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  tpu_columns = []

  column_creator = fc_lib.SharedEmbeddingColumnCreator(
      dimension=dimension, initializer=initializer, ckpt_to_load_from=None,
      tensor_name_in_ckpt=None, num_buckets=num_buckets, trainable=None,
      name=shared_embedding_collection_name)

  if (embedding_lookup_device and
      embedding_lookup_device not in _ALLOWED_DEVICES):
    raise ValueError(
        'If set, embedding_lookup_device must be in {}.'.format(
            _ALLOWED_DEVICES))

  if embedding_lookup_device == 'cpu':
    embedding_lookup_device = EmbeddingDevice.CPU
  elif embedding_lookup_device == 'tpu_tensor_core':
    embedding_lookup_device = EmbeddingDevice.TPU_TENSOR_CORE
  elif embedding_lookup_device == 'tpu_embedding_core':
    embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE

  if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE:
    if not tensor_core_shape:
      raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
                       'tensor_core_shape to be set.')
    for c in sorted_columns:
      if isinstance(c, _SUPPORTED_SEQUENCE_COLUMNS):
        raise ValueError('embedding_lookup_device=tpu_tensor_core currently '
                         'does not support sequence columns.')

  # Create the state (_SharedEmbeddingColumnLayer) here.
  for categorical_column, max_sequence_length in zip(
      categorical_columns, max_sequence_lengths):
    if not embedding_lookup_device:
      column = _TPUSharedEmbeddingColumnV2(
          categorical_column=categorical_column,
          shared_embedding_column_creator=column_creator,
          combiner=combiner,
          initializer=initializer,
          shared_embedding_collection_name=shared_embedding_collection_name,
          max_sequence_length=max_sequence_length,
          learning_rate_fn=learning_rate_fn,
          use_safe_embedding_lookup=use_safe_embedding_lookup)
    else:
      column = _TPUSharedDeviceSpecificEmbeddingColumnV2(
          categorical_column=categorical_column,
          shared_embedding_column_creator=column_creator,
          combiner=combiner,
          initializer=initializer,
          shared_embedding_collection_name=shared_embedding_collection_name,
          max_sequence_length=max_sequence_length,
          learning_rate_fn=learning_rate_fn,
          embedding_lookup_device=embedding_lookup_device,
          tensor_core_shape=tensor_core_shape,
          use_safe_embedding_lookup=use_safe_embedding_lookup)
    tpu_columns.append(column)

  return tpu_columns

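# Illustrative sketch (not part of the library): mixing a sequence column with
# a non-sequence column requires one max_sequence_lengths entry per column,
# with 0 for the non-sequence one, and all columns must share num_buckets.
# The feature names and sizes below are hypothetical.
#
#   watched = tf.feature_column.sequence_categorical_column_with_identity(
#       'watched', num_buckets=100000)
#   rated = tf.feature_column.categorical_column_with_identity(
#       'rated', num_buckets=100000)
#   columns = tf.tpu.experimental.shared_embedding_columns(
#       [watched, rated], dimension=32, max_sequence_lengths=[10, 0])
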
class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
  """Core Embedding Column."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner='mean',
              initializer=None,
              max_sequence_length=0,
              learning_rate_fn=None,
              use_safe_embedding_lookup=True,
              bypass_scope_validation=False):
    del bypass_scope_validation
    return fc_lib.EmbeddingColumn.__new__(
        cls,
        categorical_column,
        dimension,
        combiner=combiner,
        initializer=initializer,
        ckpt_to_load_from=None,
        tensor_name_in_ckpt=None,
        max_norm=None,
        trainable=True,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  def __getnewargs__(self):
    return (self._tpu_categorical_column, self.dimension, self.combiner,
            self.initializer, self._max_sequence_length, self._learning_rate_fn)

  def __deepcopy__(self, memo):
    return _TPUEmbeddingColumnV2(
        *(copy.deepcopy(a, memo) for a in self.__getnewargs__()))

  def __init__(self,
               categorical_column,
               dimension,
               combiner='mean',
               initializer=None,
               max_sequence_length=0,
               learning_rate_fn=None,
               use_safe_embedding_lookup=True,
               bypass_scope_validation=False):
    _TPUBaseEmbeddingColumn.__init__(
        self,
        categorical_column,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn)
    self._key = None
    # If true, scope validation is skipped to allow the same column to be used
    # in multiple variable scopes. By default, this is False, and we expect a
    # 1:1 mapping between feature columns and scopes.
    self._bypass_scope_validation = bypass_scope_validation

  def get_combiner(self):
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets, self.dimension)

  def get_feature_key_name(self):
    """get_feature_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """get_weight_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """get_embedding_var_name."""
    return self.categorical_column.name

  def get_initializer(self):
    return self.initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.EmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    tensor = inputs.get(self.get_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

    return tensor

  def create_state(self, state_manager):
    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn.create_state(
          self, state_manager)

    # Create state is called for the EmbeddingColumn to create its embedding
    # variables under feature column V2, if we are on TPU so record the scope
    # here.
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

  def get_dense_tensor(self, transformation_cache, state_manager):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.EmbeddingColumn.get_dense_tensor(
            self, transformation_cache, state_manager)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn.get_dense_tensor(
          self, transformation_cache, state_manager)

    # TPU mode
    # Get the embeddings from the FeatureTransformationCache.
    tensor = transformation_cache.get(self.get_feature_key_name(),
                                      state_manager)

    return tensor

  def _get_sequence_dense_tensor(
      self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.EmbeddingColumn._get_sequence_dense_tensor(
            self, inputs, weight_collections, trainable)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn._get_sequence_dense_tensor(
          self, inputs, weight_collections, trainable)

    tensor = inputs.get(self.get_feature_key_name())
    tensor_lengths = inputs.get(self.get_sequence_length_feature_key_name())

    # inputs is a _LazyBuilder and for rank 1 tensors, it calls expand_dims(-1).
    # We need to undo this to match the standard CPU sequence embedding.
    tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

    return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.EmbeddingColumn.get_sequence_dense_tensor(
            self, transformation_cache, state_manager)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn.get_sequence_dense_tensor(
          self, transformation_cache, state_manager)

    tensor = transformation_cache.get(self.get_feature_key_name(),
                                      state_manager)
    tensor_lengths = transformation_cache.get(
        self.get_sequence_length_feature_key_name(),
        state_manager)

    # FeatureTransformationCache expands rank 1 tensors (like sequence length)
    # to rank 2. We need to undo this to match the standard CPU sequence
    # embedding.
    tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

    return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)

class _TPUSharedEmbeddingColumnV2(_TPUBaseEmbeddingColumn,
                                  fc_lib.SharedEmbeddingColumn):
  """Core Shared Embedding Column."""

  def __new__(cls,
              categorical_column,
              shared_embedding_column_creator,
              combiner='mean',
              initializer=None,
              shared_embedding_collection_name=None,
              max_sequence_length=0,
              learning_rate_fn=None,
              use_safe_embedding_lookup=True):
    return fc_lib.SharedEmbeddingColumn.__new__(
        cls,
        categorical_column,
        combiner=combiner,
        shared_embedding_column_creator=shared_embedding_column_creator,
        max_norm=None,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  def __getnewargs__(self):
    return (self._tpu_categorical_column, self.shared_embedding_column_creator,
            self.combiner, self._initializer,
            self._shared_embedding_collection_name, self._max_sequence_length,
            self._learning_rate_fn)

  def __deepcopy__(self, memo):
    return _TPUSharedEmbeddingColumnV2(
        *(copy.deepcopy(a, memo) for a in self.__getnewargs__()))

  def __init__(self,
               categorical_column,
               shared_embedding_column_creator,
               combiner='mean',
               initializer=None,
               shared_embedding_collection_name=None,
               max_sequence_length=0,
               learning_rate_fn=None,
               use_safe_embedding_lookup=True):

    _TPUBaseEmbeddingColumn.__init__(
        self,
        categorical_column,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn)
    self._initializer = initializer
    self._shared_embedding_collection_name = shared_embedding_collection_name

  def get_combiner(self):
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets,
            self.shared_embedding_column_creator.dimension)

  def get_feature_key_name(self):
    """get_feature_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """get_weight_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """get_embedding_var_name."""
    return self._shared_embedding_collection_name

  def get_initializer(self):
    return self._initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor_internal(
      self, transformation_cache, state_manager):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.SharedEmbeddingColumn._get_dense_tensor_internal(
            self, transformation_cache, state_manager)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.SharedEmbeddingColumn._get_dense_tensor_internal(
          self, transformation_cache, state_manager)

    # TPU mode
    # Get the embeddings from the FeatureTransformationCache.
    tensor = transformation_cache.get(self.get_feature_key_name(),
                                      state_manager)

    # Add to collection for _create_tpu_embedding_variables_and_ops
    # Note that in Feature Column V2, shared embeddings have no scope.
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        self.shared_embedding_column_creator._name,
        is_shared_embedding=True)
    return tensor

  def get_sequence_dense_tensor(
      self, transformation_cache, state_manager):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.SharedEmbeddingColumn.get_sequence_dense_tensor(
            self, transformation_cache, state_manager)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.SharedEmbeddingColumn.get_sequence_dense_tensor(
          self, transformation_cache, state_manager)

    tensor = self._get_dense_tensor_internal(
        transformation_cache, state_manager)
    tensor_lengths = transformation_cache.get(
        self.get_sequence_length_feature_key_name(),
        state_manager)

    # FeatureTransformationCache expands rank 1 tensors (like sequence length)
    # to rank 2. We need to undo this to match the standard CPU sequence
    # embedding.
    tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

    return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)

def split_sequence_columns_v2(feature_columns):
  """Split a list of _TPUEmbeddingColumn into sequence and non-sequence columns.

  For use in a TPUEstimator model_fn function. E.g.

  def model_fn(features):
    sequence_columns, feature_columns = (
        tf.tpu.feature_column.split_sequence_columns(feature_columns))
    input = tf.feature_column.input_layer(
        features=features, feature_columns=feature_columns)
    sequence_features, sequence_lengths = (
        tf.contrib.feature_column.sequence_input_layer(
            features=features, feature_columns=sequence_columns))

  Args:
    feature_columns: A list of _TPUEmbeddingColumns to split.

  Returns:
    Two lists of _TPUEmbeddingColumns, the first is the sequence columns and
    the second is the non-sequence columns.
  """
  sequence_columns = []
  non_sequence_columns = []
  for column in feature_columns:
    if not isinstance(column, (_TPUEmbeddingColumnV2,
                               _TPUSharedEmbeddingColumnV2)):
      raise TypeError(
          'column must be a _TPUEmbeddingColumnV2 or '
          '_TPUSharedEmbeddingColumnV2 but got %s instead.' % (type(column)))
    if column.is_sequence_column():
      sequence_columns.append(column)
    else:
      non_sequence_columns.append(column)
  return sequence_columns, non_sequence_columns

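# Illustrative sketch (not part of the library): separating the two kinds of
# columns before building the input layers, as the docstring above describes.
# The name `all_tpu_columns` is hypothetical.
#
#   seq_cols, dense_cols = split_sequence_columns_v2(all_tpu_columns)
#   # dense_cols feed the dense input layer; seq_cols feed a sequence input
#   # layer, which also returns per-example sequence lengths.
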
def sparse_embedding_aggregate_slice(params,
                                     values_and_values_mask,
                                     combiner='mean',
                                     name='sparse_embedding_aggregate_slice'):
  """Uses XLA's dynamic slice operations to perform embedding lookups.

  From third_party/cloud_tpu/models/movielens/tpu_embedding.py

  Args:
    params: Tensor of embedding table. Rank 2 (table_size x embedding dim)
    values_and_values_mask: A two-tuple that contains:
      values - Tensor of embedding indices. Rank 2 (batch x n_indices)
      values_mask - Tensor of mask / weights. Rank 2 (batch x n_indices)
    combiner: The combiner to use for the embedding lookup. Currently supports
      'sum' and 'mean'.
    name: Optional name scope for created ops

  Returns:
    Rank 2 tensor of aggregated (per batch element) embedding vectors.

  Raises:
    ValueError: Combiner is not supported.
  """
  values, values_mask = values_and_values_mask  # unpack the two-tuple
  with ops.name_scope(name):
    _, embedding_dimension = params.get_shape().as_list()
    n_batch, n_indices_padded = values.get_shape().as_list()
    if not n_batch:
      n_batch = -1

    emb_lookup = array_ops.reshape(
        embedding_ops.embedding_lookup(
            params, array_ops.reshape(values, [n_batch, n_indices_padded])),
        [n_batch, n_indices_padded, embedding_dimension])

    values_mask_broadcast = array_ops.reshape(values_mask,
                                              [n_batch, n_indices_padded, 1])
    aggregate_emb = math_ops.reduce_sum(
        emb_lookup * values_mask_broadcast, axis=1)
    if combiner == 'sum':
      return aggregate_emb
    elif combiner == 'mean':
      # In the case we have an empty row, both aggregate_emb and
      # math_ops.reduce_sum(values_mask_broadcast, axis=1) will be 0. Thus,
      # we can take the max of it with a non-zero value to prevent NaNs. Note
      # that math_ops.reduce_sum(values_mask_broadcast, axis=1) will have
      # integer values so 1.0 is the smallest value.
      return aggregate_emb / math_ops.maximum(
          math_ops.reduce_sum(values_mask_broadcast, axis=1), 1.0)
    else:
      raise ValueError('Dense TPU Embedding does not support combiner '
                       'other than sum and mean.')

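# Worked example (illustrative only): with a padded row of ids [3, 7, 0] and
# mask [1., 1., 0.], the masked sum adds the embeddings for ids 3 and 7, and
# the 'mean' combiner divides by max(1 + 1 + 0, 1.0) = 2. An all-zero mask
# divides by max(0, 1.0) = 1, so an empty row yields a zero vector rather
# than NaN.
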
def pad_sparse_embedding_lookup_indices(sparse_indices, padded_size):
  """Creates statically-sized Tensors containing indices and weights.

  From third_party/cloud_tpu/models/movielens/tpu_embedding.py

  Also computes sparse_indices.values % embedding_table_size, for equivalent
  functionality to sparse_column_with_integerized_feature. The returned
  padded weight Tensor also doubles as a mask indicating which values in
  the returned padded indices Tensor are indices versus padded zeros.

  Args:
    sparse_indices: SparseTensor of embedding lookup indices.
    padded_size: Number of columns of the returned Tensors. Indices which fall
      out of bounds will be truncated to the padded size.

  Returns:
    (sparse_indices.values padded to the specified size,
     a mask the same size as the returned padded values in which 0s
     indicate padded locations and 1s (or values from sparse_weights)
     indicate actual values)
  """
  batch_size = sparse_indices.dense_shape[0]
  sparse_indices = sparse_ops.sparse_slice(sparse_indices, [0, 0],
                                           [batch_size, padded_size])
  indices, values = sparse_indices.indices, sparse_indices.values

  padded_values = array_ops.scatter_nd(
      indices,
      math_ops.cast(values, dtypes.int32),
      shape=(batch_size, padded_size))

  weights = array_ops.ones_like(values, dtype=dtypes.float32)
  padded_mask = array_ops.scatter_nd(
      indices, weights, shape=(batch_size, padded_size))

  return padded_values, padded_mask

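# Illustrative example (not part of the library): a SparseTensor with
# dense_shape [2, 4] and entries {(0, 0): 11, (0, 1): 12, (1, 0): 13}, padded
# with padded_size=3, yields
#   padded_values == [[11, 12, 0], [13, 0, 0]]   (int32)
#   padded_mask   == [[1., 1., 0.], [1., 0., 0.]]  (float32)
# Entries in columns >= padded_size are dropped by the sparse_slice above.
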
def _check_invalid_cases(embedding_lookup_device):
  """Checks for invalid embedding_lookup_device configurations."""
  if (tpu.under_tpu_inference_context() and
      embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE):
    raise ValueError(
        'Using embedding_lookup_device=tpu_embedding_core during inference '
        'is not supported.')
  if embedding_lookup_device == EmbeddingDevice.CPU:
    if not tpu.under_tpu_inference_context():
      raise ValueError(
          'Using TPUEmbeddingColumn with embedding_lookup_device="cpu" '
          'during training is not supported.')

class _TPUDeviceSpecificEmbeddingColumnV2(_TPUEmbeddingColumnV2):
  """TPUEmbeddingColumn which allows serving on TensorCore."""

  def __new__(cls, *args, **kwargs):
    # For __new__, just capture the inference dense shape and call parent.
    if 'tensor_core_shape' in kwargs:
      cls._tensor_core_shape = kwargs['tensor_core_shape']
      del kwargs['tensor_core_shape']
    if 'embedding_lookup_device' in kwargs:
      cls._embedding_lookup_device = kwargs['embedding_lookup_device']
      del kwargs['embedding_lookup_device']
    return _TPUEmbeddingColumnV2.__new__(cls, *args, **kwargs)

  def __init__(self, *args, **kwargs):
    # For __init__, just capture the inference dense shape and call parent.
    if 'tensor_core_shape' in kwargs:
      self._tensor_core_shape = kwargs['tensor_core_shape']
      del kwargs['tensor_core_shape']
    if 'embedding_lookup_device' in kwargs:
      self._embedding_lookup_device = kwargs['embedding_lookup_device']
      del kwargs['embedding_lookup_device']
    _TPUEmbeddingColumnV2.__init__(self, *args, **kwargs)

  def __deepcopy__(self, memo):
    return _TPUDeviceSpecificEmbeddingColumnV2(
        *(copy.deepcopy(a, memo) for a in self.__getnewargs__()),
        tensor_core_shape=self._tensor_core_shape,
        embedding_lookup_device=self._embedding_lookup_device)

  def create_state(self, state_manager):
    _check_invalid_cases(self._embedding_lookup_device)
    # CPU case.
    is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU
    is_cpu = is_cpu or _is_running_on_cpu()
    if is_cpu:
      return fc_lib.EmbeddingColumn.create_state(self, state_manager)
    # TPU_EMBEDDING_CORE case.
    elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self).create_state(state_manager)

    # TPU_TENSOR_CORE case.
    return fc_lib.EmbeddingColumn.create_state(self, state_manager)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Private method that follows get_dense_tensor."""
    _check_invalid_cases(self._embedding_lookup_device)
    # CPU Case.
    is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU
    is_cpu = is_cpu or _is_running_on_cpu()
    if is_cpu:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self).get_dense_tensor(transformation_cache, state_manager)
    # TPU_EMBEDDING_CORE case.
    elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self).get_dense_tensor(transformation_cache, state_manager)

    # TPU_TENSOR_CORE case.
    if tpu.under_tpu_inference_context():
      # For inference, use outside compile to densify and pad the input tensors.
      sparse_tensor = transformation_cache.get(self.categorical_column.name,
                                               state_manager)

      def host_computation():
        return pad_sparse_embedding_lookup_indices(sparse_tensor,
                                                   self._tensor_core_shape[1])

      values, mask = tpu.outside_compilation(host_computation)
    else:
      # For training, the inputs should already have been densified and padded.
      values = transformation_cache.get(self.categorical_column.name,
                                        state_manager)
      mask = transformation_cache.get(
          self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX,
          state_manager)
    embedding_weights = state_manager.get_variable(
        self, name='embedding_weights')
    return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
                                            self.get_combiner())

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    _check_invalid_cases(self._embedding_lookup_device)
    # CPU Case.
    is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU
    is_cpu = is_cpu or _is_running_on_cpu()
    if is_cpu:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor(inputs, weight_collections,
                                           trainable)
    # TPU_EMBEDDING_CORE case.
    elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor(inputs, weight_collections,
                                           trainable)

    # TPU_TENSOR_CORE case.
    if tpu.under_tpu_inference_context():
      # For inference, use outside compile to densify and pad the input tensors.
      sparse_tensor = inputs.get(self.get_feature_key_name())

      def host_computation():
        return pad_sparse_embedding_lookup_indices(sparse_tensor,
                                                   self._tensor_core_shape[1])

      values, mask = tpu.outside_compilation(host_computation)
    else:
      # For training, the inputs should already have been densified and padded.
      values = inputs.get(self.get_feature_key_name())
      mask = inputs.get(self.get_feature_key_name() +
                        _TENSOR_CORE_MASK_KEY_SUFFIX)

    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
    if (weight_collections and
        ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
      weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
    embedding_weights = variable_scope.get_variable(
        name='embedding_weights',
        shape=embedding_shape,
        dtype=dtypes.float32,
        initializer=self.initializer,
        trainable=self.trainable and trainable,
        collections=weight_collections)
    return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
                                            self.get_combiner())

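# Usage note (illustrative): during training on TensorCore, the device-specific
# columns above expect the input pipeline to provide an already densified id
# tensor under the feature key, plus a companion mask under the feature key
# suffixed with _TENSOR_CORE_MASK_KEY_SUFFIX. For a hypothetical feature named
# 'video_id' that means also supplying 'video_id__TENSOR_CORE_MASK'. At
# inference time the SparseTensor is instead densified on the host via
# pad_sparse_embedding_lookup_indices under tpu.outside_compilation.
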
class _TPUSharedDeviceSpecificEmbeddingColumnV2(_TPUSharedEmbeddingColumnV2):
  """TPUSharedEmbeddingColumnV2 which allows serving on TensorCore."""

  def __new__(cls, *args, **kwargs):
    # For __new__, just capture the inference dense shape and call parent.
    if 'tensor_core_shape' in kwargs:
      cls._tensor_core_shape = kwargs['tensor_core_shape']
      del kwargs['tensor_core_shape']
    if 'embedding_lookup_device' in kwargs:
      cls._embedding_lookup_device = kwargs['embedding_lookup_device']
      del kwargs['embedding_lookup_device']

    return _TPUSharedEmbeddingColumnV2.__new__(cls, *args, **kwargs)

  def __init__(self, *args, **kwargs):
    # For __init__, just capture the inference dense shape and call parent.
    if 'tensor_core_shape' in kwargs:
      self._tensor_core_shape = kwargs['tensor_core_shape']
      del kwargs['tensor_core_shape']
    if 'embedding_lookup_device' in kwargs:
      self._embedding_lookup_device = kwargs['embedding_lookup_device']
      del kwargs['embedding_lookup_device']
    _TPUSharedEmbeddingColumnV2.__init__(self, *args, **kwargs)

  def __deepcopy__(self, memo):
    return _TPUSharedDeviceSpecificEmbeddingColumnV2(
        *(copy.deepcopy(a, memo) for a in self.__getnewargs__()),
        tensor_core_shape=self._tensor_core_shape,
        embedding_lookup_device=self._embedding_lookup_device)

  def _get_dense_tensor_internal(self, transformation_cache, state_manager):
    """Private method that follows _get_dense_tensor_internal."""
    _check_invalid_cases(self._embedding_lookup_device)
    # CPU Case.
    is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU
    is_cpu = is_cpu or _is_running_on_cpu()
    if is_cpu:
      return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor_internal(transformation_cache,
                                                    state_manager)
    # TPU_EMBEDDING_CORE case.
    if self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor_internal(transformation_cache,
                                                    state_manager)

    # TPU_TENSOR_CORE case.
    if tpu.under_tpu_inference_context():
      # For inference, use outside compile to densify and pad the input tensors.
      sparse_tensor = transformation_cache.get(self.categorical_column.name,
                                               state_manager)

      def host_computation():
        return pad_sparse_embedding_lookup_indices(sparse_tensor,
                                                   self._tensor_core_shape[1])

      values, mask = tpu.outside_compilation(host_computation)
    else:
      # For training, the inputs should already have been densified and padded.
      values = transformation_cache.get(self.categorical_column.name,
                                        state_manager)
      mask = transformation_cache.get(
          self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX,
          state_manager)

    # Do a dense embedding lookup on TensorCore.
    embedding_weights = self.shared_embedding_column_creator.embedding_weights
    return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
                                            self.get_combiner())