Fix the case where an embedding column with use_safe_embedding_lookup=False is used with variable partitioning.
PiperOrigin-RevId: 311782693
Change-Id: I38b59943a25adbe77e9f3f01c49a713876cc3f22
This commit is contained in:
parent
321d3d9fd0
commit
2db0d85d05
|
@ -2546,7 +2546,7 @@ class _EmbeddingColumn(
|
||||||
embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
|
embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
|
||||||
if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
|
if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
|
||||||
sparse_id_rank <= 2):
|
sparse_id_rank <= 2):
|
||||||
embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse
|
embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
|
||||||
# Return embedding lookup result.
|
# Return embedding lookup result.
|
||||||
return embedding_lookup_sparse(
|
return embedding_lookup_sparse(
|
||||||
embedding_weights,
|
embedding_weights,
|
||||||
|
@ -2696,7 +2696,7 @@ class _SharedEmbeddingColumn(
|
||||||
embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
|
embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
|
||||||
if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
|
if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
|
||||||
sparse_id_rank <= 2):
|
sparse_id_rank <= 2):
|
||||||
embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse
|
embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
|
||||||
# Return embedding lookup result.
|
# Return embedding lookup result.
|
||||||
return embedding_lookup_sparse(
|
return embedding_lookup_sparse(
|
||||||
embedding_weights,
|
embedding_weights,
|
||||||
|
|
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||||
import collections
|
import collections
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.core.example import example_pb2
|
from tensorflow.core.example import example_pb2
|
||||||
|
@ -852,9 +853,9 @@ class HashedCategoricalColumnTest(test.TestCase):
|
||||||
'aaa': inputs
|
'aaa': inputs
|
||||||
}), weight_collections=('my_weights',))
|
}), weight_collections=('my_weights',))
|
||||||
|
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual([],
|
||||||
[], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||||
self.assertItemsEqual([], ops.get_collection('my_weights'))
|
self.assertCountEqual([], ops.get_collection('my_weights'))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_dense_input(self):
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
|
@ -1714,10 +1715,10 @@ class LinearModelTest(test.TestCase):
|
||||||
# We check the mapping by checking that we have the right keys,
|
# We check the mapping by checking that we have the right keys,
|
||||||
# and that the values (output_tensors) were indeed the ones used to
|
# and that the values (output_tensors) were indeed the ones used to
|
||||||
# form the input layer.
|
# form the input layer.
|
||||||
self.assertItemsEqual(all_cols, cols_to_output_tensors.keys())
|
self.assertCountEqual(all_cols, cols_to_output_tensors.keys())
|
||||||
input_layer_inputs = [tensor for tensor in input_layer.op.inputs[:-1]]
|
input_layer_inputs = [tensor for tensor in input_layer.op.inputs[:-1]]
|
||||||
output_tensors = [tensor for tensor in cols_to_output_tensors.values()]
|
output_tensors = [tensor for tensor in cols_to_output_tensors.values()]
|
||||||
self.assertItemsEqual(input_layer_inputs, output_tensors)
|
self.assertCountEqual(input_layer_inputs, output_tensors)
|
||||||
|
|
||||||
def test_dense_collection(self):
|
def test_dense_collection(self):
|
||||||
price = fc._numeric_column('price')
|
price = fc._numeric_column('price')
|
||||||
|
@ -2841,7 +2842,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
||||||
cols_to_vars = {}
|
cols_to_vars = {}
|
||||||
all_cols = [price1, dense_feature_bucketized, some_embedding_column]
|
all_cols = [price1, dense_feature_bucketized, some_embedding_column]
|
||||||
fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
|
fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
|
||||||
self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
|
self.assertCountEqual(list(cols_to_vars.keys()), all_cols)
|
||||||
self.assertEqual(0, len(cols_to_vars[price1]))
|
self.assertEqual(0, len(cols_to_vars[price1]))
|
||||||
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
|
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
|
||||||
self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
|
self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
|
||||||
|
@ -2891,7 +2892,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
||||||
shared_embedding_a, shared_embedding_b
|
shared_embedding_a, shared_embedding_b
|
||||||
]
|
]
|
||||||
fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
|
fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
|
||||||
self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
|
self.assertCountEqual(list(cols_to_vars.keys()), all_cols)
|
||||||
self.assertEqual(0, len(cols_to_vars[price1]))
|
self.assertEqual(0, len(cols_to_vars[price1]))
|
||||||
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
|
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
|
||||||
self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
|
self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
|
||||||
|
@ -2927,7 +2928,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
||||||
'input_from_feature_columns',
|
'input_from_feature_columns',
|
||||||
partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)):
|
partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)):
|
||||||
fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
|
fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
|
||||||
self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
|
self.assertCountEqual(list(cols_to_vars.keys()), all_cols)
|
||||||
self.assertEqual(0, len(cols_to_vars[price1]))
|
self.assertEqual(0, len(cols_to_vars[price1]))
|
||||||
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
|
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
|
||||||
self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
|
self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
|
||||||
|
@ -3043,7 +3044,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
||||||
'input_layer/sparse_feature_embedding/embedding_weights:0',
|
'input_layer/sparse_feature_embedding/embedding_weights:0',
|
||||||
'input_layer_1/sparse_feature_embedding/embedding_weights:0'
|
'input_layer_1/sparse_feature_embedding/embedding_weights:0'
|
||||||
]
|
]
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
|
|
||||||
|
@ -3077,7 +3078,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
||||||
# Make sure that only 1 variable gets created in this case.
|
# Make sure that only 1 variable gets created in this case.
|
||||||
self.assertEqual(1, len(
|
self.assertEqual(1, len(
|
||||||
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
|
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
|
['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
|
|
||||||
|
@ -3129,7 +3130,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
||||||
# Make sure that only 1 variable gets created in this case.
|
# Make sure that only 1 variable gets created in this case.
|
||||||
self.assertEqual(1, len(
|
self.assertEqual(1, len(
|
||||||
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
|
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
|
['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
|
|
||||||
|
@ -3618,9 +3619,9 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
||||||
'aaa': inputs
|
'aaa': inputs
|
||||||
}), weight_collections=('my_weights',))
|
}), weight_collections=('my_weights',))
|
||||||
|
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual([],
|
||||||
[], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||||
self.assertItemsEqual([], ops.get_collection('my_weights'))
|
self.assertCountEqual([], ops.get_collection('my_weights'))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_dense_input(self):
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
|
@ -4058,9 +4059,9 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
||||||
'aaa': inputs
|
'aaa': inputs
|
||||||
}), weight_collections=('my_weights',))
|
}), weight_collections=('my_weights',))
|
||||||
|
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual([],
|
||||||
[], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||||
self.assertItemsEqual([], ops.get_collection('my_weights'))
|
self.assertCountEqual([], ops.get_collection('my_weights'))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_dense_input(self):
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
|
@ -4363,9 +4364,9 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
||||||
'aaa': inputs
|
'aaa': inputs
|
||||||
}), weight_collections=('my_weights',))
|
}), weight_collections=('my_weights',))
|
||||||
|
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual([],
|
||||||
[], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||||
self.assertItemsEqual([], ops.get_collection('my_weights'))
|
self.assertCountEqual([], ops.get_collection('my_weights'))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_get_sparse_tensors_dense_input(self):
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
|
@ -4820,7 +4821,7 @@ class IndicatorColumnTest(test.TestCase):
|
||||||
self.assertAllClose([[0., 1., 1., 0.]], self.evaluate(net))
|
self.assertAllClose([[0., 1., 1., 0.]], self.evaluate(net))
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingColumnTest(test.TestCase):
|
class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
|
@ -4956,10 +4957,29 @@ class EmbeddingColumnTest(test.TestCase):
|
||||||
_assert_sparse_tensor_value(self, self.evaluate(output_a),
|
_assert_sparse_tensor_value(self, self.evaluate(output_a),
|
||||||
self.evaluate(output_embedded))
|
self.evaluate(output_embedded))
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
{
|
||||||
|
'testcase_name': 'use_safe_embedding_lookup',
|
||||||
|
'use_safe_embedding_lookup': True,
|
||||||
|
'partition_variables': False,
|
||||||
|
}, {
|
||||||
|
'testcase_name': 'dont_use_safe_embedding_lookup',
|
||||||
|
'use_safe_embedding_lookup': False,
|
||||||
|
'partition_variables': False,
|
||||||
|
}, {
|
||||||
|
'testcase_name': 'use_safe_embedding_lookup_partitioned',
|
||||||
|
'use_safe_embedding_lookup': True,
|
||||||
|
'partition_variables': True,
|
||||||
|
}, {
|
||||||
|
'testcase_name': 'dont_use_safe_embedding_lookup_partitioned',
|
||||||
|
'use_safe_embedding_lookup': False,
|
||||||
|
'partition_variables': True,
|
||||||
|
})
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor(self):
|
def test_get_dense_tensor(self, use_safe_embedding_lookup,
|
||||||
|
partition_variables):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 4
|
||||||
sparse_input = sparse_tensor.SparseTensorValue(
|
sparse_input = sparse_tensor.SparseTensorValue(
|
||||||
# example 0, ids [2]
|
# example 0, ids [2]
|
||||||
# example 1, ids [0, 1]
|
# example 1, ids [0, 1]
|
||||||
|
@ -4974,12 +4994,20 @@ class EmbeddingColumnTest(test.TestCase):
|
||||||
embedding_values = (
|
embedding_values = (
|
||||||
(1., 2.), # id 0
|
(1., 2.), # id 0
|
||||||
(3., 5.), # id 1
|
(3., 5.), # id 1
|
||||||
(7., 11.) # id 2
|
(7., 11.), # id 2
|
||||||
|
(9., 13.) # id 3
|
||||||
)
|
)
|
||||||
def _initializer(shape, dtype, partition_info):
|
|
||||||
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
|
def _initializer(shape, dtype, partition_info=None):
|
||||||
|
if partition_variables:
|
||||||
|
self.assertEqual([vocabulary_size, embedding_dimension],
|
||||||
|
partition_info.full_shape)
|
||||||
|
self.assertAllEqual((2, embedding_dimension), shape)
|
||||||
|
else:
|
||||||
|
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
|
||||||
|
self.assertIsNone(partition_info)
|
||||||
|
|
||||||
self.assertEqual(dtypes.float32, dtype)
|
self.assertEqual(dtypes.float32, dtype)
|
||||||
self.assertIsNone(partition_info)
|
|
||||||
return embedding_values
|
return embedding_values
|
||||||
|
|
||||||
# Expected lookup result, using combiner='mean'.
|
# Expected lookup result, using combiner='mean'.
|
||||||
|
@ -4997,25 +5025,43 @@ class EmbeddingColumnTest(test.TestCase):
|
||||||
# Build columns.
|
# Build columns.
|
||||||
categorical_column = fc._categorical_column_with_identity(
|
categorical_column = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=vocabulary_size)
|
key='aaa', num_buckets=vocabulary_size)
|
||||||
embedding_column = fc._embedding_column(
|
partitioner = None
|
||||||
categorical_column,
|
if partition_variables:
|
||||||
dimension=embedding_dimension,
|
partitioner = partitioned_variables.fixed_size_partitioner(2, axis=0)
|
||||||
initializer=_initializer)
|
with variable_scope.variable_scope('vars', partitioner=partitioner):
|
||||||
|
embedding_column = fc._embedding_column(
|
||||||
|
categorical_column,
|
||||||
|
dimension=embedding_dimension,
|
||||||
|
initializer=_initializer,
|
||||||
|
use_safe_embedding_lookup=use_safe_embedding_lookup)
|
||||||
|
|
||||||
# Provide sparse input and get dense result.
|
# Provide sparse input and get dense result.
|
||||||
embedding_lookup = embedding_column._get_dense_tensor(
|
embedding_lookup = embedding_column._get_dense_tensor(
|
||||||
_LazyBuilder({
|
_LazyBuilder({'aaa': sparse_input}))
|
||||||
'aaa': sparse_input
|
|
||||||
}))
|
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('embedding_weights:0',),
|
if partition_variables:
|
||||||
tuple([v.name for v in global_vars]))
|
self.assertCountEqual(('vars/embedding_weights/part_0:0',
|
||||||
|
'vars/embedding_weights/part_1:0'),
|
||||||
|
tuple([v.name for v in global_vars]))
|
||||||
|
else:
|
||||||
|
self.assertCountEqual(('vars/embedding_weights:0',),
|
||||||
|
tuple([v.name for v in global_vars]))
|
||||||
|
for v in global_vars:
|
||||||
|
self.assertIsInstance(v, variables_lib.Variable)
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
||||||
|
|
||||||
|
if use_safe_embedding_lookup:
|
||||||
|
self.assertIn('SparseFillEmptyRows',
|
||||||
|
[x.type for x in ops.get_default_graph().get_operations()])
|
||||||
|
else:
|
||||||
|
self.assertNotIn(
|
||||||
|
'SparseFillEmptyRows',
|
||||||
|
[x.type for x in ops.get_default_graph().get_operations()])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_3d(self):
|
def test_get_dense_tensor_3d(self):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
|
@ -5072,7 +5118,7 @@ class EmbeddingColumnTest(test.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('embedding_weights:0',),
|
self.assertCountEqual(('embedding_weights:0',),
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
||||||
|
@ -5102,11 +5148,11 @@ class EmbeddingColumnTest(test.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('embedding_weights:0',),
|
self.assertCountEqual(('embedding_weights:0',),
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
my_vars = ops.get_collection('my_vars')
|
my_vars = ops.get_collection('my_vars')
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(('embedding_weights:0',),
|
||||||
('embedding_weights:0',), tuple([v.name for v in my_vars]))
|
tuple([v.name for v in my_vars]))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_placeholder_inputs(self):
|
def test_get_dense_tensor_placeholder_inputs(self):
|
||||||
|
@ -5169,8 +5215,8 @@ class EmbeddingColumnTest(test.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(('embedding_weights:0',),
|
||||||
('embedding_weights:0',), tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
||||||
self.assertAllEqual(expected_lookups, embedding_lookup.eval(
|
self.assertAllEqual(expected_lookups, embedding_lookup.eval(
|
||||||
|
@ -5233,8 +5279,8 @@ class EmbeddingColumnTest(test.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(('embedding_weights:0',),
|
||||||
('embedding_weights:0',), tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup))
|
||||||
|
@ -5280,14 +5326,14 @@ class EmbeddingColumnTest(test.TestCase):
|
||||||
'linear_model/aaa_embedding/weights:0',
|
'linear_model/aaa_embedding/weights:0',
|
||||||
'linear_model/aaa_embedding/embedding_weights:0',
|
'linear_model/aaa_embedding/embedding_weights:0',
|
||||||
)
|
)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
trainable_vars = {
|
trainable_vars = {
|
||||||
v.name: v for v in ops.get_collection(
|
v.name: v for v in ops.get_collection(
|
||||||
ops.GraphKeys.TRAINABLE_VARIABLES)
|
ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
}
|
}
|
||||||
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
|
self.assertCountEqual(expected_var_names, trainable_vars.keys())
|
||||||
bias = trainable_vars['linear_model/bias_weights:0']
|
bias = trainable_vars['linear_model/bias_weights:0']
|
||||||
embedding_weights = trainable_vars[
|
embedding_weights = trainable_vars[
|
||||||
'linear_model/aaa_embedding/embedding_weights:0']
|
'linear_model/aaa_embedding/embedding_weights:0']
|
||||||
|
@ -5361,14 +5407,14 @@ class EmbeddingColumnTest(test.TestCase):
|
||||||
'linear_model/aaa_embedding/weights:0',
|
'linear_model/aaa_embedding/weights:0',
|
||||||
'linear_model/aaa_embedding/embedding_weights:0',
|
'linear_model/aaa_embedding/embedding_weights:0',
|
||||||
)
|
)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
trainable_vars = {
|
trainable_vars = {
|
||||||
v.name: v
|
v.name: v
|
||||||
for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
}
|
}
|
||||||
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
|
self.assertCountEqual(expected_var_names, trainable_vars.keys())
|
||||||
bias = trainable_vars['linear_model/bias_weights:0']
|
bias = trainable_vars['linear_model/bias_weights:0']
|
||||||
embedding_weights = trainable_vars[
|
embedding_weights = trainable_vars[
|
||||||
'linear_model/aaa_embedding/embedding_weights:0']
|
'linear_model/aaa_embedding/embedding_weights:0']
|
||||||
|
@ -5450,13 +5496,11 @@ class EmbeddingColumnTest(test.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(('input_layer/aaa_embedding/embedding_weights:0',),
|
||||||
('input_layer/aaa_embedding/embedding_weights:0',),
|
tuple([v.name for v in global_vars]))
|
||||||
tuple([v.name for v in global_vars]))
|
|
||||||
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(('input_layer/aaa_embedding/embedding_weights:0',),
|
||||||
('input_layer/aaa_embedding/embedding_weights:0',),
|
tuple([v.name for v in trainable_vars]))
|
||||||
tuple([v.name for v in trainable_vars]))
|
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self.assertAllEqual(embedding_values, trainable_vars[0].eval())
|
self.assertAllEqual(embedding_values, trainable_vars[0].eval())
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(input_layer))
|
self.assertAllEqual(expected_lookups, self.evaluate(input_layer))
|
||||||
|
@ -5513,17 +5557,16 @@ class EmbeddingColumnTest(test.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(('input_layer/aaa_embedding/embedding_weights:0',),
|
||||||
('input_layer/aaa_embedding/embedding_weights:0',),
|
tuple([v.name for v in global_vars]))
|
||||||
tuple([v.name for v in global_vars]))
|
self.assertCountEqual([],
|
||||||
self.assertItemsEqual(
|
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
|
||||||
[], ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
|
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
self.assertAllEqual(embedding_values, global_vars[0].eval())
|
||||||
self.assertAllEqual(expected_lookups, self.evaluate(input_layer))
|
self.assertAllEqual(expected_lookups, self.evaluate(input_layer))
|
||||||
|
|
||||||
|
|
||||||
class SharedEmbeddingColumnTest(test.TestCase):
|
class SharedEmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
|
@ -5772,33 +5815,59 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
||||||
_assert_sparse_tensor_value(self, self.evaluate(output_b),
|
_assert_sparse_tensor_value(self, self.evaluate(output_b),
|
||||||
self.evaluate(output_b_embedded))
|
self.evaluate(output_b_embedded))
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
{
|
||||||
|
'testcase_name': 'use_safe_embedding_lookup',
|
||||||
|
'use_safe_embedding_lookup': True,
|
||||||
|
'partition_variables': False,
|
||||||
|
}, {
|
||||||
|
'testcase_name': 'dont_use_safe_embedding_lookup',
|
||||||
|
'use_safe_embedding_lookup': False,
|
||||||
|
'partition_variables': False,
|
||||||
|
}, {
|
||||||
|
'testcase_name': 'use_safe_embedding_lookup_partitioned',
|
||||||
|
'use_safe_embedding_lookup': True,
|
||||||
|
'partition_variables': True,
|
||||||
|
}, {
|
||||||
|
'testcase_name': 'dont_use_safe_embedding_lookup_partitioned',
|
||||||
|
'use_safe_embedding_lookup': False,
|
||||||
|
'partition_variables': True,
|
||||||
|
})
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor(self):
|
def test_get_dense_tensor(self, use_safe_embedding_lookup,
|
||||||
|
partition_variables):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 4
|
||||||
# -1 values are ignored.
|
# -1 values are ignored.
|
||||||
input_a = np.array(
|
input_a = np.array([
|
||||||
[[2, -1, -1], # example 0, ids [2]
|
[2, -1, -1], # example 0, ids [2]
|
||||||
[0, 1, -1]]) # example 1, ids [0, 1]
|
[0, 1, -1]
|
||||||
input_b = np.array(
|
]) # example 1, ids [0, 1]
|
||||||
[[0, -1, -1], # example 0, ids [0]
|
input_b = np.array([
|
||||||
[-1, -1, -1]]) # example 1, ids []
|
[0, -1, -1], # example 0, ids [0]
|
||||||
input_features = {
|
[-1, -1, -1]
|
||||||
'aaa': input_a,
|
]) # example 1, ids []
|
||||||
'bbb': input_b
|
input_features = {'aaa': input_a, 'bbb': input_b}
|
||||||
}
|
|
||||||
|
|
||||||
# Embedding variable.
|
# Embedding variable.
|
||||||
embedding_dimension = 2
|
embedding_dimension = 2
|
||||||
embedding_values = (
|
embedding_values = (
|
||||||
(1., 2.), # id 0
|
(1., 2.), # id 0
|
||||||
(3., 5.), # id 1
|
(3., 5.), # id 1
|
||||||
(7., 11.) # id 2
|
(7., 11.), # id 2
|
||||||
|
(9., 13.) # id 3
|
||||||
)
|
)
|
||||||
def _initializer(shape, dtype, partition_info):
|
|
||||||
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
|
def _initializer(shape, dtype, partition_info=None):
|
||||||
|
if partition_variables:
|
||||||
|
self.assertEqual([vocabulary_size, embedding_dimension],
|
||||||
|
partition_info.full_shape)
|
||||||
|
self.assertAllEqual((2, embedding_dimension), shape)
|
||||||
|
else:
|
||||||
|
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
|
||||||
|
self.assertIsNone(partition_info)
|
||||||
|
|
||||||
self.assertEqual(dtypes.float32, dtype)
|
self.assertEqual(dtypes.float32, dtype)
|
||||||
self.assertIsNone(partition_info)
|
|
||||||
return embedding_values
|
return embedding_values
|
||||||
|
|
||||||
# Expected lookup result, using combiner='mean'.
|
# Expected lookup result, using combiner='mean'.
|
||||||
|
@ -5808,38 +5877,65 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
||||||
# example 1:
|
# example 1:
|
||||||
(2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
|
(2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
|
||||||
)
|
)
|
||||||
expected_lookups_b = (
|
if use_safe_embedding_lookup:
|
||||||
# example 0:
|
expected_lookups_b = (
|
||||||
(1., 2.), # ids [0], embedding = [1, 2]
|
# example 0:
|
||||||
# example 1:
|
(1., 2.), # ids [0], embedding = [1, 2]
|
||||||
(0., 0.), # ids [], embedding = [0, 0]
|
# example 1:
|
||||||
)
|
(0., 0.), # ids [], embedding = [0, 0]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
expected_lookups_b = (
|
||||||
|
# example 0:
|
||||||
|
(1., 2.), # ids [0], embedding = [1, 2]
|
||||||
|
)
|
||||||
|
|
||||||
# Build columns.
|
# Build columns.
|
||||||
categorical_column_a = fc._categorical_column_with_identity(
|
categorical_column_a = fc._categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=vocabulary_size)
|
key='aaa', num_buckets=vocabulary_size)
|
||||||
categorical_column_b = fc._categorical_column_with_identity(
|
categorical_column_b = fc._categorical_column_with_identity(
|
||||||
key='bbb', num_buckets=vocabulary_size)
|
key='bbb', num_buckets=vocabulary_size)
|
||||||
embedding_column_a, embedding_column_b = fc_new.shared_embedding_columns(
|
|
||||||
[categorical_column_a, categorical_column_b],
|
|
||||||
dimension=embedding_dimension,
|
|
||||||
initializer=_initializer)
|
|
||||||
|
|
||||||
# Provide sparse input and get dense result.
|
partitioner = None
|
||||||
embedding_lookup_a = embedding_column_a._get_dense_tensor(
|
if partition_variables:
|
||||||
_LazyBuilder(input_features))
|
partitioner = partitioned_variables.fixed_size_partitioner(2, axis=0)
|
||||||
embedding_lookup_b = embedding_column_b._get_dense_tensor(
|
|
||||||
_LazyBuilder(input_features))
|
|
||||||
|
|
||||||
|
with variable_scope.variable_scope('vars', partitioner=partitioner):
|
||||||
|
embedding_column_a, embedding_column_b = fc_new.shared_embedding_columns(
|
||||||
|
[categorical_column_a, categorical_column_b],
|
||||||
|
dimension=embedding_dimension,
|
||||||
|
initializer=_initializer,
|
||||||
|
use_safe_embedding_lookup=use_safe_embedding_lookup)
|
||||||
|
# Provide sparse input and get dense result.
|
||||||
|
embedding_lookup_a = embedding_column_a._get_dense_tensor(
|
||||||
|
_LazyBuilder(input_features))
|
||||||
|
embedding_lookup_b = embedding_column_b._get_dense_tensor(
|
||||||
|
_LazyBuilder(input_features))
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('embedding_weights:0',),
|
if partition_variables:
|
||||||
tuple([v.name for v in global_vars]))
|
self.assertCountEqual(('vars/embedding_weights/part_0:0',
|
||||||
|
'vars/embedding_weights/part_1:0'),
|
||||||
|
tuple([v.name for v in global_vars]))
|
||||||
|
else:
|
||||||
|
self.assertCountEqual(('vars/embedding_weights:0',),
|
||||||
|
tuple([v.name for v in global_vars]))
|
||||||
embedding_var = global_vars[0]
|
embedding_var = global_vars[0]
|
||||||
with _initialized_session():
|
|
||||||
self.assertAllEqual(embedding_values, self.evaluate(embedding_var))
|
self.evaluate(variables_lib.global_variables_initializer())
|
||||||
self.assertAllEqual(expected_lookups_a, self.evaluate(embedding_lookup_a))
|
self.evaluate(lookup_ops.tables_initializer())
|
||||||
self.assertAllEqual(expected_lookups_b, self.evaluate(embedding_lookup_b))
|
|
||||||
|
self.assertAllEqual(embedding_values, self.evaluate(embedding_var))
|
||||||
|
self.assertAllEqual(expected_lookups_a, self.evaluate(embedding_lookup_a))
|
||||||
|
self.assertAllEqual(expected_lookups_b, self.evaluate(embedding_lookup_b))
|
||||||
|
|
||||||
|
if use_safe_embedding_lookup:
|
||||||
|
self.assertIn('SparseFillEmptyRows',
|
||||||
|
[x.type for x in ops.get_default_graph().get_operations()])
|
||||||
|
else:
|
||||||
|
self.assertNotIn(
|
||||||
|
'SparseFillEmptyRows',
|
||||||
|
[x.type for x in ops.get_default_graph().get_operations()])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor_weight_collections(self):
|
def test_get_dense_tensor_weight_collections(self):
|
||||||
|
@ -5886,11 +5982,11 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
|
('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
|
||||||
tuple(v.name for v in global_vars))
|
tuple(v.name for v in global_vars))
|
||||||
my_vars = ops.get_collection('my_vars')
|
my_vars = ops.get_collection('my_vars')
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
|
('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
|
||||||
tuple(v.name for v in my_vars))
|
tuple(v.name for v in my_vars))
|
||||||
|
|
||||||
|
@ -5997,14 +6093,14 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
||||||
'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
|
'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
|
||||||
'linear_model/aaa_bbb_shared_embedding_1/weights:0',
|
'linear_model/aaa_bbb_shared_embedding_1/weights:0',
|
||||||
)
|
)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
trainable_vars = {
|
trainable_vars = {
|
||||||
v.name: v for v in ops.get_collection(
|
v.name: v for v in ops.get_collection(
|
||||||
ops.GraphKeys.TRAINABLE_VARIABLES)
|
ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
}
|
}
|
||||||
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
|
self.assertCountEqual(expected_var_names, trainable_vars.keys())
|
||||||
bias = trainable_vars['linear_model/bias_weights:0']
|
bias = trainable_vars['linear_model/bias_weights:0']
|
||||||
embedding_weights = trainable_vars[
|
embedding_weights = trainable_vars[
|
||||||
'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
|
'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
|
||||||
|
@ -6091,14 +6187,14 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
||||||
'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
|
'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
|
||||||
'linear_model/aaa_bbb_shared_embedding_1/weights:0',
|
'linear_model/aaa_bbb_shared_embedding_1/weights:0',
|
||||||
)
|
)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
trainable_vars = {
|
trainable_vars = {
|
||||||
v.name: v
|
v.name: v
|
||||||
for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
}
|
}
|
||||||
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
|
self.assertCountEqual(expected_var_names, trainable_vars.keys())
|
||||||
bias = trainable_vars['linear_model/bias_weights:0']
|
bias = trainable_vars['linear_model/bias_weights:0']
|
||||||
embedding_weights = trainable_vars[
|
embedding_weights = trainable_vars[
|
||||||
'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
|
'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
|
||||||
|
@ -6195,16 +6291,16 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
|
['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
if trainable:
|
if trainable:
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
|
['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
|
||||||
tuple([v.name for v in trainable_vars]))
|
tuple([v.name for v in trainable_vars]))
|
||||||
else:
|
else:
|
||||||
self.assertItemsEqual([], tuple([v.name for v in trainable_vars]))
|
self.assertCountEqual([], tuple([v.name for v in trainable_vars]))
|
||||||
shared_embedding_vars = global_vars
|
shared_embedding_vars = global_vars
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval())
|
self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval())
|
||||||
|
|
|
@ -3263,7 +3263,7 @@ class EmbeddingColumn(
|
||||||
embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
|
embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
|
||||||
if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
|
if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
|
||||||
sparse_id_rank <= 2):
|
sparse_id_rank <= 2):
|
||||||
embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse
|
embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
|
||||||
# Return embedding lookup result.
|
# Return embedding lookup result.
|
||||||
return embedding_lookup_sparse(
|
return embedding_lookup_sparse(
|
||||||
embedding_weights,
|
embedding_weights,
|
||||||
|
@ -3558,7 +3558,7 @@ class SharedEmbeddingColumn(
|
||||||
embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
|
embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
|
||||||
if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
|
if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
|
||||||
sparse_id_rank <= 2):
|
sparse_id_rank <= 2):
|
||||||
embedding_lookup_sparse = (embedding_ops.embedding_lookup_sparse)
|
embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
|
||||||
# Return embedding lookup result.
|
# Return embedding lookup result.
|
||||||
return embedding_lookup_sparse(
|
return embedding_lookup_sparse(
|
||||||
embedding_weights,
|
embedding_weights,
|
||||||
|
|
|
@ -2087,7 +2087,7 @@ class LinearModelTest(test.TestCase):
|
||||||
for var in model.variables:
|
for var in model.variables:
|
||||||
self.assertIsInstance(var, variables_lib.VariableV1)
|
self.assertIsInstance(var, variables_lib.VariableV1)
|
||||||
variable_names = [var.name for var in model.variables]
|
variable_names = [var.name for var in model.variables]
|
||||||
self.assertItemsEqual([
|
self.assertCountEqual([
|
||||||
'linear_model/dense_feature_bucketized/weights:0',
|
'linear_model/dense_feature_bucketized/weights:0',
|
||||||
'linear_model/price1/weights:0',
|
'linear_model/price1/weights:0',
|
||||||
'linear_model/sparse_feature_embedding/embedding_weights:0',
|
'linear_model/sparse_feature_embedding/embedding_weights:0',
|
||||||
|
@ -2731,10 +2731,10 @@ class OldLinearModelTest(test.TestCase):
|
||||||
# We check the mapping by checking that we have the right keys,
|
# We check the mapping by checking that we have the right keys,
|
||||||
# and that the values (output_tensors) were indeed the ones used to
|
# and that the values (output_tensors) were indeed the ones used to
|
||||||
# form the input layer.
|
# form the input layer.
|
||||||
self.assertItemsEqual(all_cols, cols_to_output_tensors.keys())
|
self.assertCountEqual(all_cols, cols_to_output_tensors.keys())
|
||||||
input_layer_inputs = [tensor for tensor in input_layer.op.inputs[:-1]]
|
input_layer_inputs = [tensor for tensor in input_layer.op.inputs[:-1]]
|
||||||
output_tensors = [tensor for tensor in cols_to_output_tensors.values()]
|
output_tensors = [tensor for tensor in cols_to_output_tensors.values()]
|
||||||
self.assertItemsEqual(input_layer_inputs, output_tensors)
|
self.assertCountEqual(input_layer_inputs, output_tensors)
|
||||||
|
|
||||||
def test_dense_collection(self):
|
def test_dense_collection(self):
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
|
@ -3411,7 +3411,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
||||||
cols_to_vars = {}
|
cols_to_vars = {}
|
||||||
all_cols = [price1, dense_feature_bucketized, some_embedding_column]
|
all_cols = [price1, dense_feature_bucketized, some_embedding_column]
|
||||||
fc_old.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
|
fc_old.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
|
||||||
self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
|
self.assertCountEqual(list(cols_to_vars.keys()), all_cols)
|
||||||
self.assertEqual(0, len(cols_to_vars[price1]))
|
self.assertEqual(0, len(cols_to_vars[price1]))
|
||||||
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
|
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
|
||||||
self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
|
self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
|
||||||
|
@ -3461,7 +3461,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
||||||
shared_embedding_a, shared_embedding_b
|
shared_embedding_a, shared_embedding_b
|
||||||
]
|
]
|
||||||
fc_old.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
|
fc_old.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
|
||||||
self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
|
self.assertCountEqual(list(cols_to_vars.keys()), all_cols)
|
||||||
self.assertEqual(0, len(cols_to_vars[price1]))
|
self.assertEqual(0, len(cols_to_vars[price1]))
|
||||||
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
|
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
|
||||||
self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
|
self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
|
||||||
|
@ -3497,7 +3497,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
||||||
'input_from_feature_columns',
|
'input_from_feature_columns',
|
||||||
partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)):
|
partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)):
|
||||||
fc_old.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
|
fc_old.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
|
||||||
self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
|
self.assertCountEqual(list(cols_to_vars.keys()), all_cols)
|
||||||
self.assertEqual(0, len(cols_to_vars[price1]))
|
self.assertEqual(0, len(cols_to_vars[price1]))
|
||||||
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
|
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
|
||||||
self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
|
self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
|
||||||
|
@ -3616,7 +3616,7 @@ class FunctionalInputLayerTest(test.TestCase):
|
||||||
'input_layer/sparse_feature_embedding/embedding_weights:0',
|
'input_layer/sparse_feature_embedding/embedding_weights:0',
|
||||||
'input_layer_1/sparse_feature_embedding/embedding_weights:0'
|
'input_layer_1/sparse_feature_embedding/embedding_weights:0'
|
||||||
]
|
]
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
|
|
||||||
|
@ -5904,7 +5904,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('embedding_weights:0',),
|
self.assertCountEqual(('embedding_weights:0',),
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
|
|
||||||
self.evaluate(variables_lib.global_variables_initializer())
|
self.evaluate(variables_lib.global_variables_initializer())
|
||||||
|
@ -5968,7 +5968,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('embedding_weights:0',),
|
self.assertCountEqual(('embedding_weights:0',),
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
|
|
||||||
self.evaluate(variables_lib.global_variables_initializer())
|
self.evaluate(variables_lib.global_variables_initializer())
|
||||||
|
@ -6036,7 +6036,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('embedding_weights:0',),
|
self.assertCountEqual(('embedding_weights:0',),
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
|
|
||||||
self.evaluate(variables_lib.global_variables_initializer())
|
self.evaluate(variables_lib.global_variables_initializer())
|
||||||
|
@ -6109,7 +6109,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('embedding_weights:0',),
|
self.assertCountEqual(('embedding_weights:0',),
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
|
|
||||||
self.evaluate(variables_lib.global_variables_initializer())
|
self.evaluate(variables_lib.global_variables_initializer())
|
||||||
|
@ -6180,7 +6180,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('embedding_weights:0',),
|
self.assertCountEqual(('embedding_weights:0',),
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
|
|
||||||
self.evaluate(variables_lib.global_variables_initializer())
|
self.evaluate(variables_lib.global_variables_initializer())
|
||||||
|
@ -6230,14 +6230,14 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
'linear_model/aaa_embedding/weights:0',
|
'linear_model/aaa_embedding/weights:0',
|
||||||
'linear_model/aaa_embedding/embedding_weights:0',
|
'linear_model/aaa_embedding/embedding_weights:0',
|
||||||
)
|
)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
trainable_vars = {
|
trainable_vars = {
|
||||||
v.name: v
|
v.name: v
|
||||||
for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
}
|
}
|
||||||
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
|
self.assertCountEqual(expected_var_names, trainable_vars.keys())
|
||||||
bias = trainable_vars['linear_model/bias_weights:0']
|
bias = trainable_vars['linear_model/bias_weights:0']
|
||||||
embedding_weights = trainable_vars[
|
embedding_weights = trainable_vars[
|
||||||
'linear_model/aaa_embedding/embedding_weights:0']
|
'linear_model/aaa_embedding/embedding_weights:0']
|
||||||
|
@ -6274,15 +6274,25 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
{
|
{
|
||||||
'testcase_name': 'use_safe_embedding_lookup',
|
'testcase_name': 'use_safe_embedding_lookup',
|
||||||
'use_safe_embedding_lookup': True
|
'use_safe_embedding_lookup': True,
|
||||||
|
'partition_variables': False,
|
||||||
}, {
|
}, {
|
||||||
'testcase_name': 'dont_use_safe_embedding_lookup',
|
'testcase_name': 'dont_use_safe_embedding_lookup',
|
||||||
'use_safe_embedding_lookup': False
|
'use_safe_embedding_lookup': False,
|
||||||
|
'partition_variables': False,
|
||||||
|
}, {
|
||||||
|
'testcase_name': 'use_safe_embedding_lookup_partitioned',
|
||||||
|
'use_safe_embedding_lookup': True,
|
||||||
|
'partition_variables': True,
|
||||||
|
}, {
|
||||||
|
'testcase_name': 'dont_use_safe_embedding_lookup_partitioned',
|
||||||
|
'use_safe_embedding_lookup': False,
|
||||||
|
'partition_variables': True,
|
||||||
})
|
})
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_dense_features(self, use_safe_embedding_lookup):
|
def test_dense_features(self, use_safe_embedding_lookup, partition_variables):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 4
|
||||||
sparse_input = sparse_tensor.SparseTensorValue(
|
sparse_input = sparse_tensor.SparseTensorValue(
|
||||||
# example 0, ids [2]
|
# example 0, ids [2]
|
||||||
# example 1, ids [0, 1]
|
# example 1, ids [0, 1]
|
||||||
|
@ -6297,13 +6307,20 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
embedding_values = (
|
embedding_values = (
|
||||||
(1., 2.), # id 0
|
(1., 2.), # id 0
|
||||||
(3., 5.), # id 1
|
(3., 5.), # id 1
|
||||||
(7., 11.) # id 2
|
(7., 11.), # id 2
|
||||||
|
(9., 13.) # id 3
|
||||||
)
|
)
|
||||||
|
|
||||||
def _initializer(shape, dtype, partition_info=None):
|
def _initializer(shape, dtype, partition_info=None):
|
||||||
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
|
if partition_variables:
|
||||||
|
self.assertEqual([vocabulary_size, embedding_dimension],
|
||||||
|
partition_info.full_shape)
|
||||||
|
self.assertAllEqual((2, embedding_dimension), shape)
|
||||||
|
else:
|
||||||
|
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
|
||||||
|
self.assertIsNone(partition_info)
|
||||||
|
|
||||||
self.assertEqual(dtypes.float32, dtype)
|
self.assertEqual(dtypes.float32, dtype)
|
||||||
self.assertIsNone(partition_info)
|
|
||||||
return embedding_values
|
return embedding_values
|
||||||
|
|
||||||
# Expected lookup result, using combiner='mean'.
|
# Expected lookup result, using combiner='mean'.
|
||||||
|
@ -6321,25 +6338,43 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
# Build columns.
|
# Build columns.
|
||||||
categorical_column = fc.categorical_column_with_identity(
|
categorical_column = fc.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=vocabulary_size)
|
key='aaa', num_buckets=vocabulary_size)
|
||||||
embedding_column = fc.embedding_column(
|
partitioner = None
|
||||||
categorical_column,
|
if partition_variables:
|
||||||
dimension=embedding_dimension,
|
partitioner = partitioned_variables.fixed_size_partitioner(2, axis=0)
|
||||||
initializer=_initializer,
|
with variable_scope.variable_scope('vars', partitioner=partitioner):
|
||||||
use_safe_embedding_lookup=use_safe_embedding_lookup)
|
embedding_column = fc.embedding_column(
|
||||||
|
categorical_column,
|
||||||
|
dimension=embedding_dimension,
|
||||||
|
initializer=_initializer,
|
||||||
|
use_safe_embedding_lookup=use_safe_embedding_lookup)
|
||||||
|
|
||||||
# Provide sparse input and get dense result.
|
# Provide sparse input and get dense result.
|
||||||
l = df.DenseFeatures((embedding_column,))
|
l = df.DenseFeatures((embedding_column,))
|
||||||
dense_features = l({'aaa': sparse_input})
|
dense_features = l({'aaa': sparse_input})
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('dense_features/aaa_embedding/embedding_weights:0',),
|
if partition_variables:
|
||||||
tuple([v.name for v in global_vars]))
|
self.assertCountEqual(
|
||||||
|
('vars/dense_features/aaa_embedding/embedding_weights/part_0:0',
|
||||||
|
'vars/dense_features/aaa_embedding/embedding_weights/part_1:0'),
|
||||||
|
tuple([v.name for v in global_vars]))
|
||||||
|
else:
|
||||||
|
self.assertCountEqual(
|
||||||
|
('vars/dense_features/aaa_embedding/embedding_weights:0',),
|
||||||
|
tuple([v.name for v in global_vars]))
|
||||||
for v in global_vars:
|
for v in global_vars:
|
||||||
self.assertIsInstance(v, variables_lib.Variable)
|
self.assertIsInstance(v, variables_lib.Variable)
|
||||||
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
self.assertItemsEqual(('dense_features/aaa_embedding/embedding_weights:0',),
|
if partition_variables:
|
||||||
tuple([v.name for v in trainable_vars]))
|
self.assertCountEqual(
|
||||||
|
('vars/dense_features/aaa_embedding/embedding_weights/part_0:0',
|
||||||
|
'vars/dense_features/aaa_embedding/embedding_weights/part_1:0'),
|
||||||
|
tuple([v.name for v in trainable_vars]))
|
||||||
|
else:
|
||||||
|
self.assertCountEqual(
|
||||||
|
('vars/dense_features/aaa_embedding/embedding_weights:0',),
|
||||||
|
tuple([v.name for v in trainable_vars]))
|
||||||
|
|
||||||
self.evaluate(variables_lib.global_variables_initializer())
|
self.evaluate(variables_lib.global_variables_initializer())
|
||||||
self.evaluate(lookup_ops.tables_initializer())
|
self.evaluate(lookup_ops.tables_initializer())
|
||||||
|
@ -6410,9 +6445,9 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('dense_features/aaa_embedding/embedding_weights:0',),
|
self.assertCountEqual(('dense_features/aaa_embedding/embedding_weights:0',),
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
self.assertItemsEqual([],
|
self.assertCountEqual([],
|
||||||
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
|
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
|
||||||
|
|
||||||
self.evaluate(variables_lib.global_variables_initializer())
|
self.evaluate(variables_lib.global_variables_initializer())
|
||||||
|
@ -6475,10 +6510,10 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('input_layer/aaa_embedding/embedding_weights:0',),
|
self.assertCountEqual(('input_layer/aaa_embedding/embedding_weights:0',),
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
self.assertItemsEqual(('input_layer/aaa_embedding/embedding_weights:0',),
|
self.assertCountEqual(('input_layer/aaa_embedding/embedding_weights:0',),
|
||||||
tuple([v.name for v in trainable_vars]))
|
tuple([v.name for v in trainable_vars]))
|
||||||
|
|
||||||
self.evaluate(variables_lib.global_variables_initializer())
|
self.evaluate(variables_lib.global_variables_initializer())
|
||||||
|
@ -6528,14 +6563,14 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
'linear_model/aaa_embedding/weights:0',
|
'linear_model/aaa_embedding/weights:0',
|
||||||
'linear_model/aaa_embedding/embedding_weights:0',
|
'linear_model/aaa_embedding/embedding_weights:0',
|
||||||
)
|
)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
trainable_vars = {
|
trainable_vars = {
|
||||||
v.name: v
|
v.name: v
|
||||||
for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
}
|
}
|
||||||
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
|
self.assertCountEqual(expected_var_names, trainable_vars.keys())
|
||||||
bias = trainable_vars['linear_model/bias_weights:0']
|
bias = trainable_vars['linear_model/bias_weights:0']
|
||||||
embedding_weights = trainable_vars[
|
embedding_weights = trainable_vars[
|
||||||
'linear_model/aaa_embedding/embedding_weights:0']
|
'linear_model/aaa_embedding/embedding_weights:0']
|
||||||
|
@ -6610,14 +6645,14 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
'linear_model/aaa_embedding/weights:0',
|
'linear_model/aaa_embedding/weights:0',
|
||||||
'linear_model/aaa_embedding/embedding_weights:0',
|
'linear_model/aaa_embedding/embedding_weights:0',
|
||||||
)
|
)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
trainable_vars = {
|
trainable_vars = {
|
||||||
v.name: v
|
v.name: v
|
||||||
for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
}
|
}
|
||||||
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
|
self.assertCountEqual(expected_var_names, trainable_vars.keys())
|
||||||
bias = trainable_vars['linear_model/bias_weights:0']
|
bias = trainable_vars['linear_model/bias_weights:0']
|
||||||
embedding_weights = trainable_vars[
|
embedding_weights = trainable_vars[
|
||||||
'linear_model/aaa_embedding/embedding_weights:0']
|
'linear_model/aaa_embedding/embedding_weights:0']
|
||||||
|
@ -6972,15 +7007,26 @@ class SharedEmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
{
|
{
|
||||||
'testcase_name': 'use_safe_embedding_lookup',
|
'testcase_name': 'use_safe_embedding_lookup',
|
||||||
'use_safe_embedding_lookup': True
|
'use_safe_embedding_lookup': True,
|
||||||
|
'partition_variables': False,
|
||||||
}, {
|
}, {
|
||||||
'testcase_name': 'dont_use_safe_embedding_lookup',
|
'testcase_name': 'dont_use_safe_embedding_lookup',
|
||||||
'use_safe_embedding_lookup': False
|
'use_safe_embedding_lookup': False,
|
||||||
|
'partition_variables': False,
|
||||||
|
}, {
|
||||||
|
'testcase_name': 'use_safe_embedding_lookup_partitioned',
|
||||||
|
'use_safe_embedding_lookup': True,
|
||||||
|
'partition_variables': True,
|
||||||
|
}, {
|
||||||
|
'testcase_name': 'dont_use_safe_embedding_lookup_partitioned',
|
||||||
|
'use_safe_embedding_lookup': False,
|
||||||
|
'partition_variables': True,
|
||||||
})
|
})
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_get_dense_tensor(self, use_safe_embedding_lookup):
|
def test_get_dense_tensor(self, use_safe_embedding_lookup,
|
||||||
|
partition_variables):
|
||||||
# Inputs.
|
# Inputs.
|
||||||
vocabulary_size = 3
|
vocabulary_size = 4
|
||||||
# -1 values are ignored.
|
# -1 values are ignored.
|
||||||
input_a = np.array([
|
input_a = np.array([
|
||||||
[2, -1, -1], # example 0, ids [2]
|
[2, -1, -1], # example 0, ids [2]
|
||||||
|
@ -6997,13 +7043,20 @@ class SharedEmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
embedding_values = (
|
embedding_values = (
|
||||||
(1., 2.), # id 0
|
(1., 2.), # id 0
|
||||||
(3., 5.), # id 1
|
(3., 5.), # id 1
|
||||||
(7., 11.) # id 2
|
(7., 11.), # id 2
|
||||||
|
(9., 13.) # id 3
|
||||||
)
|
)
|
||||||
|
|
||||||
def _initializer(shape, dtype, partition_info=None):
|
def _initializer(shape, dtype, partition_info=None):
|
||||||
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
|
if partition_variables:
|
||||||
|
self.assertEqual([vocabulary_size, embedding_dimension],
|
||||||
|
partition_info.full_shape)
|
||||||
|
self.assertAllEqual((2, embedding_dimension), shape)
|
||||||
|
else:
|
||||||
|
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
|
||||||
|
self.assertIsNone(partition_info)
|
||||||
|
|
||||||
self.assertEqual(dtypes.float32, dtype)
|
self.assertEqual(dtypes.float32, dtype)
|
||||||
self.assertIsNone(partition_info)
|
|
||||||
return embedding_values
|
return embedding_values
|
||||||
|
|
||||||
# Expected lookup result, using combiner='mean'.
|
# Expected lookup result, using combiner='mean'.
|
||||||
|
@ -7031,22 +7084,32 @@ class SharedEmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
key='aaa', num_buckets=vocabulary_size)
|
key='aaa', num_buckets=vocabulary_size)
|
||||||
categorical_column_b = fc.categorical_column_with_identity(
|
categorical_column_b = fc.categorical_column_with_identity(
|
||||||
key='bbb', num_buckets=vocabulary_size)
|
key='bbb', num_buckets=vocabulary_size)
|
||||||
embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
|
|
||||||
[categorical_column_a, categorical_column_b],
|
|
||||||
dimension=embedding_dimension,
|
|
||||||
initializer=_initializer,
|
|
||||||
use_safe_embedding_lookup=use_safe_embedding_lookup)
|
|
||||||
|
|
||||||
# Provide sparse input and get dense result.
|
partitioner = None
|
||||||
embedding_lookup_a = embedding_column_a.get_dense_tensor(
|
if partition_variables:
|
||||||
fc.FeatureTransformationCache(input_features), None)
|
partitioner = partitioned_variables.fixed_size_partitioner(2, axis=0)
|
||||||
embedding_lookup_b = embedding_column_b.get_dense_tensor(
|
|
||||||
fc.FeatureTransformationCache(input_features), None)
|
with variable_scope.variable_scope('vars', partitioner=partitioner):
|
||||||
|
embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
|
||||||
|
[categorical_column_a, categorical_column_b],
|
||||||
|
dimension=embedding_dimension,
|
||||||
|
initializer=_initializer,
|
||||||
|
use_safe_embedding_lookup=use_safe_embedding_lookup)
|
||||||
|
# Provide sparse input and get dense result.
|
||||||
|
embedding_lookup_a = embedding_column_a.get_dense_tensor(
|
||||||
|
fc.FeatureTransformationCache(input_features), None)
|
||||||
|
embedding_lookup_b = embedding_column_b.get_dense_tensor(
|
||||||
|
fc.FeatureTransformationCache(input_features), None)
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('aaa_bbb_shared_embedding:0',),
|
if partition_variables:
|
||||||
tuple([v.name for v in global_vars]))
|
self.assertCountEqual(('vars/aaa_bbb_shared_embedding/part_0:0',
|
||||||
|
'vars/aaa_bbb_shared_embedding/part_1:0'),
|
||||||
|
tuple([v.name for v in global_vars]))
|
||||||
|
else:
|
||||||
|
self.assertCountEqual(('vars/aaa_bbb_shared_embedding:0',),
|
||||||
|
tuple([v.name for v in global_vars]))
|
||||||
embedding_var = global_vars[0]
|
embedding_var = global_vars[0]
|
||||||
|
|
||||||
self.evaluate(variables_lib.global_variables_initializer())
|
self.evaluate(variables_lib.global_variables_initializer())
|
||||||
|
@ -7279,14 +7342,14 @@ class SharedEmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
'aaa_bbb_shared_embedding:0',
|
'aaa_bbb_shared_embedding:0',
|
||||||
'linear_model/bbb_shared_embedding/weights:0',
|
'linear_model/bbb_shared_embedding/weights:0',
|
||||||
)
|
)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
expected_var_names,
|
expected_var_names,
|
||||||
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||||
trainable_vars = {
|
trainable_vars = {
|
||||||
v.name: v
|
v.name: v
|
||||||
for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
}
|
}
|
||||||
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
|
self.assertCountEqual(expected_var_names, trainable_vars.keys())
|
||||||
bias = trainable_vars['linear_model/bias_weights:0']
|
bias = trainable_vars['linear_model/bias_weights:0']
|
||||||
embedding_weights = trainable_vars['aaa_bbb_shared_embedding:0']
|
embedding_weights = trainable_vars['aaa_bbb_shared_embedding:0']
|
||||||
linear_weights_a = trainable_vars[
|
linear_weights_a = trainable_vars[
|
||||||
|
@ -7420,18 +7483,18 @@ class SharedEmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
# Assert expected embedding variable and lookups.
|
# Assert expected embedding variable and lookups.
|
||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'],
|
['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'],
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
for v in global_vars:
|
for v in global_vars:
|
||||||
self.assertIsInstance(v, variables_lib.Variable)
|
self.assertIsInstance(v, variables_lib.Variable)
|
||||||
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
if trainable:
|
if trainable:
|
||||||
self.assertItemsEqual(
|
self.assertCountEqual(
|
||||||
['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'],
|
['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'],
|
||||||
tuple([v.name for v in trainable_vars]))
|
tuple([v.name for v in trainable_vars]))
|
||||||
else:
|
else:
|
||||||
self.assertItemsEqual([], tuple([v.name for v in trainable_vars]))
|
self.assertCountEqual([], tuple([v.name for v in trainable_vars]))
|
||||||
shared_embedding_vars = global_vars
|
shared_embedding_vars = global_vars
|
||||||
|
|
||||||
self.evaluate(variables_lib.global_variables_initializer())
|
self.evaluate(variables_lib.global_variables_initializer())
|
||||||
|
|
Loading…
Reference in New Issue