Enable TensorCore embeddings for training via FeatureColumnV2.
PiperOrigin-RevId: 312340625 Change-Id: I559aba797a8f1a37ecec1e4ee71cd027701ae6dd
This commit is contained in:
parent
b3387c0c19
commit
82143c1ad8
tensorflow
python/tpu
tools/api/golden/v1
@ -31,15 +31,18 @@ from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.tpu import tpu
|
||||
from tensorflow.python.tpu.feature_column import _is_running_on_cpu
|
||||
from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name
|
||||
from tensorflow.python.tpu.feature_column import _SUPPORTED_CATEGORICAL_COLUMNS_V2
|
||||
from tensorflow.python.tpu.feature_column import _SUPPORTED_SEQUENCE_COLUMNS
|
||||
from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
# pylint: disable=protected-access
|
||||
|
||||
_ALLOWED_DEVICES = ['cpu', 'tpu_tensor_core', 'tpu_embedding_core']
|
||||
_TENSOR_CORE_MASK_KEY_SUFFIX = '__TENSOR_CORE_MASK'
|
||||
|
||||
|
||||
class EmbeddingDevice(enum.Enum):
|
||||
@ -174,10 +177,13 @@ def embedding_column_v2(categorical_column,
|
||||
elif embedding_lookup_device == 'tpu_embedding_core':
|
||||
embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE
|
||||
|
||||
if (embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE and
|
||||
not tensor_core_shape):
|
||||
raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
|
||||
'tensor_core_shape to be set.')
|
||||
if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE:
|
||||
if not tensor_core_shape:
|
||||
raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
|
||||
'tensor_core_shape to be set.')
|
||||
if isinstance(categorical_column, _SUPPORTED_SEQUENCE_COLUMNS):
|
||||
raise ValueError('embedding_lookup_device=tpu_tensor_core currently does '
|
||||
'not support sequence columns.')
|
||||
|
||||
if not embedding_lookup_device:
|
||||
return _TPUEmbeddingColumnV2(
|
||||
@ -372,10 +378,14 @@ def shared_embedding_columns_v2(categorical_columns,
|
||||
elif embedding_lookup_device == 'tpu_embedding_core':
|
||||
embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE
|
||||
|
||||
if (embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE and
|
||||
not tensor_core_shape):
|
||||
raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
|
||||
'tensor_core_shape to be set.')
|
||||
if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE:
|
||||
if not tensor_core_shape:
|
||||
raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
|
||||
'tensor_core_shape to be set.')
|
||||
for c in sorted_columns:
|
||||
if isinstance(c, _SUPPORTED_SEQUENCE_COLUMNS):
|
||||
raise ValueError('embedding_lookup_device=tpu_tensor_core currently '
|
||||
'does not support sequence columns.')
|
||||
|
||||
# Create the state (_SharedEmbeddingColumnLayer) here.
|
||||
for categorical_column, max_sequence_length in zip(
|
||||
@ -807,7 +817,13 @@ def sparse_embedding_aggregate_slice(params,
|
||||
if combiner == 'sum':
|
||||
return aggregate_emb
|
||||
elif combiner == 'mean':
|
||||
return aggregate_emb / math_ops.reduce_sum(values_mask_broadcast, axis=1)
|
||||
# In the case we have an empty row, both aggregate_emb and
|
||||
# math_ops.reduce_sum(values_mask_broadcast, axis=1) will be 0. Thus,
|
||||
# we can take max it with a non-zero value to prevent NaNs. Note that
|
||||
# math_ops.reduce_sum(values_mask_broadcast, axis=1) will have integer
|
||||
# values so 1.0 is the smallest value.
|
||||
return aggregate_emb / math_ops.maximum(
|
||||
math_ops.reduce_sum(values_mask_broadcast, axis=1), 1.0)
|
||||
else:
|
||||
raise ValueError('Dense TPU Embedding does not support combiner '
|
||||
'other than sum and mean.')
|
||||
@ -851,6 +867,20 @@ def pad_sparse_embedding_lookup_indices(sparse_indices, padded_size):
|
||||
return padded_values, padded_mask
|
||||
|
||||
|
||||
def _check_invalid_cases(embedding_lookup_device):
|
||||
"""Checks for invalid embedding_lookup_device configurations."""
|
||||
if (tpu.under_tpu_inference_context() and
|
||||
embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE):
|
||||
raise ValueError(
|
||||
'Using embedding_lookup_device=tpu_embedding_core during inference '
|
||||
'is not supported.')
|
||||
if embedding_lookup_device == EmbeddingDevice.CPU:
|
||||
if not tpu.under_tpu_inference_context():
|
||||
raise ValueError(
|
||||
'Using TPUEmbeddingColumn with embedding_lookup_device="cpu" '
|
||||
'during training is not supported.')
|
||||
|
||||
|
||||
class _TPUDeviceSpecificEmbeddingColumnV2(_TPUEmbeddingColumnV2):
|
||||
"""TPUEmbeddingColumn which allows serving on TensorCore."""
|
||||
|
||||
@ -874,46 +904,105 @@ class _TPUDeviceSpecificEmbeddingColumnV2(_TPUEmbeddingColumnV2):
|
||||
del kwargs['embedding_lookup_device']
|
||||
_TPUEmbeddingColumnV2.__init__(self, *args, **kwargs)
|
||||
|
||||
def create_state(self, state_manager):
|
||||
if (tpu.under_tpu_inference_context() and
|
||||
self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE):
|
||||
raise ValueError(
|
||||
'Using embedding_lookup_device=tpu_embedding_core during inference '
|
||||
'is not supported.')
|
||||
if self._embedding_lookup_device == EmbeddingDevice.CPU:
|
||||
if tpu.under_tpu_inference_context():
|
||||
return fc_lib.EmbeddingColumn.create_state(self, state_manager)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Using TPUEmbeddingColumn with embedding_lookup_device="cpu" '
|
||||
'during training is not supported.')
|
||||
def __deepcopy__(self, memo):
|
||||
return _TPUDeviceSpecificEmbeddingColumnV2(
|
||||
*(copy.deepcopy(a, memo) for a in self.__getnewargs__()),
|
||||
tensor_core_shape=self._tensor_core_shape,
|
||||
embedding_lookup_device=self._embedding_lookup_device)
|
||||
|
||||
return super(_TPUDeviceSpecificEmbeddingColumnV2,
|
||||
self).create_state(state_manager)
|
||||
def create_state(self, state_manager):
|
||||
_check_invalid_cases(self._embedding_lookup_device)
|
||||
# CPU case.
|
||||
if self._embedding_lookup_device == EmbeddingDevice.CPU or _is_running_on_cpu(
|
||||
):
|
||||
return fc_lib.EmbeddingColumn.create_state(self, state_manager)
|
||||
# TPU_EMBEDDING_CORE case.
|
||||
elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
|
||||
return super(_TPUDeviceSpecificEmbeddingColumnV2,
|
||||
self).create_state(state_manager)
|
||||
|
||||
# TPU_EMBEDDING_CORE case.
|
||||
return fc_lib.EmbeddingColumn.create_state(self, state_manager)
|
||||
|
||||
def get_dense_tensor(self, transformation_cache, state_manager):
|
||||
"""Private method that follows get_dense_tensor."""
|
||||
|
||||
# If we aren't inferencing on TensorCore, just delegate to parent.
|
||||
if not tpu.under_tpu_inference_context() or not self._tensor_core_shape:
|
||||
_check_invalid_cases(self._embedding_lookup_device)
|
||||
# CPU Case.
|
||||
if self._embedding_lookup_device == EmbeddingDevice.CPU or _is_running_on_cpu(
|
||||
):
|
||||
return super(_TPUDeviceSpecificEmbeddingColumnV2,
|
||||
self).get_dense_tensor(transformation_cache, state_manager)
|
||||
# TPU_EMBEDDING_CORE case.
|
||||
elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
|
||||
return super(_TPUDeviceSpecificEmbeddingColumnV2,
|
||||
self).get_dense_tensor(transformation_cache, state_manager)
|
||||
sparse_tensor = transformation_cache.get(self.categorical_column.name,
|
||||
state_manager)
|
||||
|
||||
# Use outside compile to densify and pad the input tensors.
|
||||
def host_computation():
|
||||
return pad_sparse_embedding_lookup_indices(sparse_tensor,
|
||||
self._tensor_core_shape[1])
|
||||
# TPU_EMBEDDING_CORE cases.
|
||||
if tpu.under_tpu_inference_context():
|
||||
# For inference, use outside compile to densify and pad the input tensors.
|
||||
sparse_tensor = transformation_cache.get(self.categorical_column.name,
|
||||
state_manager)
|
||||
|
||||
values, mask = tpu.outside_compilation(host_computation)
|
||||
def host_computation():
|
||||
return pad_sparse_embedding_lookup_indices(sparse_tensor,
|
||||
self._tensor_core_shape[1])
|
||||
|
||||
# Do a dense embedding lookup on TensorCore.
|
||||
embedding_weights = state_manager.get_variable(self, 'embedding_weights')
|
||||
embedding = sparse_embedding_aggregate_slice(embedding_weights,
|
||||
(values, mask),
|
||||
self.get_combiner())
|
||||
return embedding
|
||||
values, mask = tpu.outside_compilation(host_computation)
|
||||
else:
|
||||
# For training, the inputs should already have been densified and padded.
|
||||
values = transformation_cache.get(self.categorical_column.name,
|
||||
state_manager)
|
||||
mask = transformation_cache.get(
|
||||
self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX,
|
||||
state_manager)
|
||||
embedding_weights = state_manager.get_variable(
|
||||
self, name='embedding_weights')
|
||||
return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
|
||||
self.get_combiner())
|
||||
|
||||
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
|
||||
_check_invalid_cases(self._embedding_lookup_device)
|
||||
# CPU Case.
|
||||
if self._embedding_lookup_device == EmbeddingDevice.CPU or _is_running_on_cpu(
|
||||
):
|
||||
return super(_TPUDeviceSpecificEmbeddingColumnV2,
|
||||
self)._get_dense_tensor(inputs, weight_collections,
|
||||
trainable)
|
||||
# TPU_EMBEDDING_CORE case.
|
||||
elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
|
||||
return super(_TPUDeviceSpecificEmbeddingColumnV2,
|
||||
self)._get_dense_tensor(inputs, weight_collections,
|
||||
trainable)
|
||||
|
||||
# TPU_EMBEDDING_CORE cases.
|
||||
if tpu.under_tpu_inference_context():
|
||||
# For inference, use outside compile to densify and pad the input tensors.
|
||||
sparse_tensor = inputs.get(self.get_feature_key_name())
|
||||
|
||||
def host_computation():
|
||||
return pad_sparse_embedding_lookup_indices(sparse_tensor,
|
||||
self._tensor_core_shape[1])
|
||||
|
||||
values, mask = tpu.outside_compilation(host_computation)
|
||||
else:
|
||||
# For training, the inputs should already have been densified and padded.
|
||||
values = inputs.get(self.get_feature_key_name())
|
||||
mask = inputs.get(self.get_feature_key_name() +
|
||||
_TENSOR_CORE_MASK_KEY_SUFFIX)
|
||||
|
||||
embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
|
||||
if (weight_collections and
|
||||
ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
|
||||
weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
embedding_weights = variable_scope.get_variable(
|
||||
name='embedding_weights',
|
||||
shape=embedding_shape,
|
||||
dtype=dtypes.float32,
|
||||
initializer=self.initializer,
|
||||
trainable=self.trainable and trainable,
|
||||
collections=weight_collections)
|
||||
return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
|
||||
self.get_combiner())
|
||||
|
||||
|
||||
class _TPUSharedDeviceSpecificEmbeddingColumnV2(_TPUSharedEmbeddingColumnV2):
|
||||
@ -940,34 +1029,47 @@ class _TPUSharedDeviceSpecificEmbeddingColumnV2(_TPUSharedEmbeddingColumnV2):
|
||||
del kwargs['embedding_lookup_device']
|
||||
_TPUSharedEmbeddingColumnV2.__init__(self, *args, **kwargs)
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
return _TPUSharedDeviceSpecificEmbeddingColumnV2(
|
||||
*(copy.deepcopy(a, memo) for a in self.__getnewargs__()),
|
||||
tensor_core_shape=self._tensor_core_shape,
|
||||
embedding_lookup_device=self._embedding_lookup_device)
|
||||
|
||||
def _get_dense_tensor_internal(self, transformation_cache, state_manager):
|
||||
"""Private method that follows _get_dense_tensor_internal."""
|
||||
if (tpu.under_tpu_inference_context() and
|
||||
self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE):
|
||||
raise ValueError('Using embedding_lookup_device=tpu_embedding_core '
|
||||
'during inference is not supported.')
|
||||
if self._embedding_lookup_device == EmbeddingDevice.CPU:
|
||||
if tpu.under_tpu_inference_context():
|
||||
return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
|
||||
self)._get_dense_tensor_internal(transformation_cache,
|
||||
state_manager)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Using TPUSharedEmbeddingColumn with '
|
||||
'embedding_lookup_device="cpu" during training is not supported.')
|
||||
sparse_tensor = transformation_cache.get(self.categorical_column.name,
|
||||
state_manager)
|
||||
_check_invalid_cases(self._embedding_lookup_device)
|
||||
# CPU Case.
|
||||
if self._embedding_lookup_device == EmbeddingDevice.CPU or _is_running_on_cpu(
|
||||
):
|
||||
return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
|
||||
self)._get_dense_tensor_internal(transformation_cache,
|
||||
state_manager)
|
||||
# TPU_EMBEDDING_CORE case.
|
||||
if self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
|
||||
return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
|
||||
self)._get_dense_tensor_internal(transformation_cache,
|
||||
state_manager)
|
||||
|
||||
# Use outside compile to densify and pad the input tensors.
|
||||
def host_computation():
|
||||
return pad_sparse_embedding_lookup_indices(sparse_tensor,
|
||||
self._tensor_core_shape[1])
|
||||
# TPU_EMBEDDING_CORE cases.
|
||||
if tpu.under_tpu_inference_context():
|
||||
# For inference, use outside compile to densify and pad the input tensors.
|
||||
sparse_tensor = transformation_cache.get(self.categorical_column.name,
|
||||
state_manager)
|
||||
|
||||
values, mask = tpu.outside_compilation(host_computation)
|
||||
def host_computation():
|
||||
return pad_sparse_embedding_lookup_indices(sparse_tensor,
|
||||
self._tensor_core_shape[1])
|
||||
|
||||
values, mask = tpu.outside_compilation(host_computation)
|
||||
else:
|
||||
# For training, the inputs should already have been densified and padded.
|
||||
values = transformation_cache.get(self.categorical_column.name,
|
||||
state_manager)
|
||||
mask = transformation_cache.get(
|
||||
self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX,
|
||||
state_manager)
|
||||
|
||||
# Do a dense embedding lookup on TensorCore.
|
||||
embedding_weights = self.shared_embedding_column_creator.embedding_weights
|
||||
embedding = sparse_embedding_aggregate_slice(embedding_weights,
|
||||
(values, mask),
|
||||
self.get_combiner())
|
||||
return embedding
|
||||
return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
|
||||
self.get_combiner())
|
||||
|
@ -34,6 +34,7 @@ from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import feature_column_v2 as tpu_fc
|
||||
from tensorflow.python.tpu import tpu
|
||||
from tensorflow.python.tpu import tpu_function
|
||||
|
||||
|
||||
def _initialized_session():
|
||||
@ -514,50 +515,119 @@ class DeviceSpecificEmbeddingColumnTestV2(test.TestCase,
|
||||
embedding_lookup_device='tpu_tensor_core',
|
||||
tensor_core_shape=[None, 3])
|
||||
|
||||
# Run in TPUInferenceContext so that we hit the intended densification case.
|
||||
# Run in TPUContexts so that we hit the intended densification case.
|
||||
context = tpu._TPUInferenceContext('tpu_inference')
|
||||
context.Enter()
|
||||
with tpu_function.tpu_shard_context(1):
|
||||
dense_features = fc_lib.DenseFeatures(embedding_column)
|
||||
# Sqrtn combiner not supported for now.
|
||||
if combiner == 'sqrtn':
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Dense TPU Embedding does not support combiner'):
|
||||
embedding_lookup = dense_features(input_features)
|
||||
return
|
||||
if combiner == 'mean':
|
||||
expected_lookups = (
|
||||
# example 0:
|
||||
(7., 11.), # ids [2], embedding = [7, 11]
|
||||
# example 1:
|
||||
(2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) =
|
||||
# [2, 3.5]
|
||||
)
|
||||
elif combiner == 'sum':
|
||||
expected_lookups = (
|
||||
# example 0:
|
||||
(7., 11.), # ids [2], embedding = [7, 11]
|
||||
# example 1:
|
||||
(4., 7), # ids [0, 1], embedding = sum([1, 2] + [3, 5]) = [4, 7]
|
||||
)
|
||||
|
||||
dense_features = fc_lib.DenseFeatures(embedding_column)
|
||||
# Sqrtn combiner not supported for now.
|
||||
if combiner == 'sqrtn':
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Dense TPU Embedding does not support combiner'):
|
||||
embedding_lookup = dense_features(input_features)
|
||||
return
|
||||
if combiner == 'mean':
|
||||
embedding_lookup = dense_features(input_features)
|
||||
|
||||
# Assert expected embedding variable and lookups.
|
||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
if shared:
|
||||
self.assertCountEqual(('inp_shared_embedding:0',),
|
||||
tuple([v.name for v in global_vars]))
|
||||
else:
|
||||
self.assertCountEqual(
|
||||
('dense_features/inp_embedding/embedding_weights:0',),
|
||||
tuple([v.name for v in global_vars]))
|
||||
|
||||
embedding_var = global_vars[0]
|
||||
with _initialized_session():
|
||||
self.assertAllEqual(embedding_values, embedding_var.eval())
|
||||
eval_res = embedding_lookup.eval()
|
||||
self.assertAllEqual(expected_lookups, eval_res)
|
||||
context.Exit()
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def test_empty_row(self):
|
||||
# Inputs.
|
||||
vocabulary_size = 3
|
||||
input_sparse_tensor = sparse_tensor.SparseTensorValue(
|
||||
# example 0, ids []
|
||||
# example 1, ids [0, 1, 3]
|
||||
indices=((1, 0), (1, 1), (1, 4)),
|
||||
values=(0, 1, 3),
|
||||
dense_shape=(2, 5))
|
||||
input_features = {'inp': input_sparse_tensor}
|
||||
|
||||
# Embedding variable.
|
||||
embedding_dimension = 2
|
||||
embedding_values = (
|
||||
(1., 2.), # id 0
|
||||
(3., 5.), # id 1
|
||||
(7., 11.), # id 2
|
||||
(13., 17.) # id 3
|
||||
)
|
||||
|
||||
def _initializer(shape, dtype, partition_info=None):
|
||||
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
|
||||
self.assertEqual(dtypes.float32, dtype)
|
||||
self.assertIsNone(partition_info)
|
||||
return embedding_values
|
||||
|
||||
# Build columns.
|
||||
categorical_column_input = fc_lib.categorical_column_with_identity(
|
||||
key='inp', num_buckets=vocabulary_size)
|
||||
|
||||
# Set tensor_core_shape to be [None, 20] to ensure some padding and
|
||||
# dynamic batch size.
|
||||
embedding_column = tpu_fc.embedding_column_v2(
|
||||
categorical_column_input,
|
||||
dimension=embedding_dimension,
|
||||
initializer=_initializer,
|
||||
combiner='mean',
|
||||
embedding_lookup_device='tpu_tensor_core',
|
||||
tensor_core_shape=[None, 3])
|
||||
|
||||
# Run in TPUContexts so that we hit the intended densification case.
|
||||
context = tpu._TPUInferenceContext('tpu_inference')
|
||||
context.Enter()
|
||||
with tpu_function.tpu_shard_context(1):
|
||||
dense_features = fc_lib.DenseFeatures(embedding_column)
|
||||
expected_lookups = (
|
||||
# example 0:
|
||||
(7., 11.), # ids [2], embedding = [7, 11]
|
||||
(0., 0.), # ids [], embedding = [0, 0]
|
||||
# example 1:
|
||||
(2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
|
||||
)
|
||||
elif combiner == 'sum':
|
||||
expected_lookups = (
|
||||
# example 0:
|
||||
(7., 11.), # ids [2], embedding = [7, 11]
|
||||
# example 1:
|
||||
(4., 7), # ids [0, 1], embedding = sum([1, 2] + [3, 5]) = [4, 7]
|
||||
)
|
||||
|
||||
embedding_lookup = dense_features(input_features)
|
||||
embedding_lookup = dense_features(input_features)
|
||||
|
||||
# Assert expected embedding variable and lookups.
|
||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
if shared:
|
||||
self.assertCountEqual(('inp_shared_embedding:0',),
|
||||
tuple([v.name for v in global_vars]))
|
||||
else:
|
||||
# Assert expected embedding variable and lookups.
|
||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
self.assertCountEqual(
|
||||
('dense_features/inp_embedding/embedding_weights:0',),
|
||||
tuple([v.name for v in global_vars]))
|
||||
|
||||
embedding_var = global_vars[0]
|
||||
with _initialized_session():
|
||||
self.assertAllEqual(embedding_values, embedding_var.eval())
|
||||
eval_res = embedding_lookup.eval()
|
||||
self.assertAllEqual(expected_lookups, eval_res)
|
||||
context.Exit()
|
||||
embedding_var = global_vars[0]
|
||||
with _initialized_session():
|
||||
self.assertAllEqual(embedding_values, embedding_var.eval())
|
||||
eval_res = embedding_lookup.eval()
|
||||
self.assertAllEqual(expected_lookups, eval_res)
|
||||
context.Exit()
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def test_error_dense_shape_invalid(self):
|
||||
|
@ -35,6 +35,10 @@ tf_class {
|
||||
name: "table_to_config_dict"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "tensor_core_feature_columns"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user