Added unknown-rank support to the dense-to-sparse conversion within categorical columns. Uses the same logic as contrib.layers.dense_to_sparse.

PiperOrigin-RevId: 178305360
This commit is contained in:
Mustafa Ispir 2017-12-07 15:28:37 -08:00 committed by TensorFlower Gardener
parent f37380b064
commit 51fa3f7fef
2 changed files with 46 additions and 47 deletions

View File

@ -1988,29 +1988,26 @@ def _to_sparse_input(input_tensor, ignore_value=None):
if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
return input_tensor return input_tensor
with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)): with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
input_rank = input_tensor.get_shape().ndims
if input_rank is None:
# TODO(b/32318825): Implement dense_to_sparse_tensor for undefined rank.
raise ValueError('Undefined input_tensor shape.')
if ignore_value is None: if ignore_value is None:
ignore_value = '' if input_tensor.dtype == dtypes.string else -1 if input_tensor.dtype == dtypes.string:
dense_shape = math_ops.cast(array_ops.shape(input_tensor), dtypes.int64) # Special case: TF strings are converted to numpy objects by default.
indices = array_ops.where(math_ops.not_equal( ignore_value = ''
input_tensor, math_ops.cast(ignore_value, input_tensor.dtype))) elif input_tensor.dtype.is_integer:
# Flattens the tensor and indices for use with gather. ignore_value = -1 # -1 has a special meaning of missing feature
flat_tensor = array_ops.reshape(input_tensor, [-1]) else:
flat_indices = indices[:, input_rank - 1] # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
# Computes the correct flattened indices for 2d (or higher) tensors. # constructing a new numpy object of the given type, which yields the
if input_rank > 1: # default value for that type.
higher_dims = indices[:, :input_rank - 1] ignore_value = input_tensor.dtype.as_numpy_dtype()
shape_offsets = array_ops.stack( ignore_value = math_ops.cast(
_shape_offsets(array_ops.unstack(dense_shape)[1:])) ignore_value, input_tensor.dtype, name='ignore_value')
offsets = math_ops.reduce_sum( indices = array_ops.where(
math_ops.multiply(higher_dims, shape_offsets), math_ops.not_equal(input_tensor, ignore_value), name='indices')
reduction_indices=[1]) return sparse_tensor_lib.SparseTensor(
flat_indices = math_ops.add(flat_indices, offsets) indices=indices,
values = array_ops.gather(flat_tensor, flat_indices) values=array_ops.gather_nd(input_tensor, indices, name='values'),
return sparse_tensor_lib.SparseTensor(indices, values, dense_shape) dense_shape=array_ops.shape(
input_tensor, out_type=dtypes.int64, name='dense_shape'))
def _clean_feature_columns(feature_columns): def _clean_feature_columns(feature_columns):

View File

@ -1650,8 +1650,9 @@ class LinearModelTest(test.TestCase):
indices=((0,), (1,)), indices=((0,), (1,)),
values=('sedan', 'hardtop'), values=('sedan', 'hardtop'),
dense_shape=(2,)) dense_shape=(2,))
country_data = np.array(['US', 'CA'])
net = fc.linear_model(features, [price_buckets, body_style]) net = fc.linear_model(features, [price_buckets, body_style, country])
bias = get_linear_model_bias() bias = get_linear_model_bias()
price_buckets_var = get_linear_model_column_var(price_buckets) price_buckets_var = get_linear_model_column_var(price_buckets)
body_style_var = get_linear_model_column_var(body_style) body_style_var = get_linear_model_column_var(body_style)
@ -1660,15 +1661,14 @@ class LinearModelTest(test.TestCase):
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
sess.run(bias.assign([5.])) sess.run(bias.assign([5.]))
self.assertAllClose( self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
[[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(
sess.run(net, feed_dict={ net,
features['price']: price_data, feed_dict={
features['body-style']: body_style_data})) features['price']: price_data,
features['body-style']: body_style_data,
# Dense categorical_column with unknown shape is not allowed. features['country']: country_data
with self.assertRaisesRegexp(ValueError, 'Undefined input_tensor shape.'): }))
fc.linear_model(features, [price_buckets, body_style, country])
def test_with_rank_0_feature(self): def test_with_rank_0_feature(self):
price = fc.numeric_column('price') price = fc.numeric_column('price')
@ -2119,9 +2119,9 @@ class FunctionalInputLayerTest(test.TestCase):
def test_with_1d_unknown_shape_sparse_tensor(self): def test_with_1d_unknown_shape_sparse_tensor(self):
embedding_values = ( embedding_values = (
(1., 2., 3., 4., 5.), # id 0 (1., 2.), # id 0
(6., 7., 8., 9., 10.), # id 1 (6., 7.), # id 1
(11., 12., 13., 14., 15.) # id 2 (11., 12.) # id 2
) )
def _initializer(shape, dtype, partition_info): def _initializer(shape, dtype, partition_info):
del shape, dtype, partition_info del shape, dtype, partition_info
@ -2138,8 +2138,8 @@ class FunctionalInputLayerTest(test.TestCase):
# embedded_body_style has 5 dims in input_layer. # embedded_body_style has 5 dims in input_layer.
country = fc.categorical_column_with_vocabulary_list( country = fc.categorical_column_with_vocabulary_list(
'country', vocabulary_list=['US', 'JP', 'CA']) 'country', vocabulary_list=['US', 'JP', 'CA'])
embedded_country = fc.embedding_column(country, dimension=5, embedded_country = fc.embedding_column(
initializer=_initializer) country, dimension=2, initializer=_initializer)
# Provides 1-dim tensor and dense tensor. # Provides 1-dim tensor and dense tensor.
features = { features = {
@ -2157,22 +2157,24 @@ class FunctionalInputLayerTest(test.TestCase):
indices=((0,), (1,)), indices=((0,), (1,)),
values=('sedan', 'hardtop'), values=('sedan', 'hardtop'),
dense_shape=(2,)) dense_shape=(2,))
country_data = np.array([['US'], ['CA']])
# Dense categorical_column with unknown shape is not allowed. net = fc.input_layer(features,
with self.assertRaisesRegexp(ValueError, 'Undefined input_tensor shape.'): [price, one_hot_body_style, embedded_country])
fc.input_layer(features, [price, one_hot_body_style, embedded_country]) self.assertEqual(1 + 3 + 2, net.shape[1])
net = fc.input_layer(features, [price, one_hot_body_style])
self.assertEqual(1 + 3, net.shape[1])
with _initialized_session() as sess: with _initialized_session() as sess:
# Each row is formed by concatenating `embedded_body_style`, # Each row is formed by concatenating `embedded_body_style`,
# `one_hot_body_style`, and `price` in order. # `one_hot_body_style`, and `price` in order.
self.assertAllEqual( self.assertAllEqual(
[[0., 0., 1., 11.], [1., 0., 0., 12.]], [[0., 0., 1., 1., 2., 11.], [1., 0., 0., 11., 12., 12.]],
sess.run(net, feed_dict={ sess.run(
features['price']: price_data, net,
features['body-style']: body_style_data})) feed_dict={
features['price']: price_data,
features['body-style']: body_style_data,
features['country']: country_data
}))
def test_with_rank_0_feature(self): def test_with_rank_0_feature(self):
# price has 1 dimension in input_layer # price has 1 dimension in input_layer