From 7b81a79dbb18d2ef3a73ebdd24bad253dfa5b76c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Mar 2019 12:57:09 -0800 Subject: [PATCH] Move common feature column helper functions to separate file. PiperOrigin-RevId: 236707892 --- tensorflow/contrib/feature_column/BUILD | 1 + .../feature_column/sequence_feature_column.py | 3 +- tensorflow/python/feature_column/BUILD | 16 ++ .../python/feature_column/feature_column.py | 162 ++--------------- .../feature_column/feature_column_v2.py | 165 +++--------------- .../feature_column/sequence_feature_column.py | 3 +- tensorflow/python/feature_column/utils.py | 154 ++++++++++++++++ 7 files changed, 213 insertions(+), 291 deletions(-) create mode 100644 tensorflow/python/feature_column/utils.py diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 0a9199d61f3..edd6f36e07c 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -33,6 +33,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python:variable_scope", "//tensorflow/python/feature_column:feature_column_py", + "//tensorflow/python/feature_column:utils", ], ) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index 9b3a5c58aaa..64df44fe436 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -23,6 +23,7 @@ import collections from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import utils as fc_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -506,7 +507,7 @@ class _SequenceNumericColumn( # sequence length is not affected. num_elements = (self._variable_shape.num_elements() if sp_tensor.shape.ndims == 2 else 1) - seq_length = fc._sequence_length_from_sparse_tensor( + seq_length = fc_utils.sequence_length_from_sparse_tensor( sp_tensor, num_elements=num_elements) return fc._SequenceDenseColumn.TensorSequenceLengthPair( diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index 8caf46e3fa7..d696f7cb462 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -23,6 +23,7 @@ py_library( srcs = ["feature_column.py"], srcs_version = "PY2AND3", deps = [ + ":utils", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:control_flow_ops", @@ -57,6 +58,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":feature_column", + ":utils", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:control_flow_ops", @@ -172,6 +174,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":feature_column_v2", + ":utils", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:dtypes", @@ -184,6 +187,19 @@ py_library( ], ) +py_library( + name = "utils", + srcs = ["utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:util", + ], +) + tf_py_test( name = "sequence_feature_column_test", srcs = ["sequence_feature_column_test.py"], diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 42a07cd9275..8c0300d204f 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -138,8 +138,8 @@ import math import numpy as np import six - from tensorflow.python.eager import context +from tensorflow.python.feature_column import utils as fc_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib @@ -982,13 +982,14 @@ def _numeric_column(key, if not (dtype.is_integer or dtype.is_floating): raise ValueError('dtype must be convertible to float. ' 'dtype: {}, key: {}'.format(dtype, key)) - default_value = _check_default_value(shape, default_value, dtype, key) + default_value = fc_utils.check_default_value( + shape, default_value, dtype, key) if normalizer_fn is not None and not callable(normalizer_fn): raise TypeError( 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) - _assert_key_is_string(key) + fc_utils.assert_key_is_string(key) return _NumericColumn( key, shape=shape, @@ -1080,19 +1081,6 @@ def _bucketized_column(source_column, boundaries): return _BucketizedColumn(source_column, tuple(boundaries)) -def _assert_string_or_int(dtype, prefix): - if (dtype != dtypes.string) and (not dtype.is_integer): - raise ValueError( - '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype)) - - -def _assert_key_is_string(key): - if not isinstance(key, six.string_types): - raise ValueError( - 'key must be a string. Got: type {}. Given key: {}.'.format( - type(key), key)) - - def _categorical_column_with_hash_bucket(key, hash_bucket_size, dtype=dtypes.string): @@ -1145,8 +1133,8 @@ def _categorical_column_with_hash_bucket(key, 'hash_bucket_size: {}, key: {}'.format( hash_bucket_size, key)) - _assert_key_is_string(key) - _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + fc_utils.assert_key_is_string(key) + fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) return _HashedCategoricalColumn(key, hash_bucket_size, dtype) @@ -1259,8 +1247,8 @@ def _categorical_column_with_vocabulary_file(key, if num_oov_buckets < 0: raise ValueError('Invalid num_oov_buckets {} in {}.'.format( num_oov_buckets, key)) - _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) - _assert_key_is_string(key) + fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + fc_utils.assert_key_is_string(key) return _VocabularyFileCategoricalColumn( key=key, vocabulary_file=vocabulary_file, @@ -1367,7 +1355,7 @@ def _categorical_column_with_vocabulary_list(key, if num_oov_buckets < 0: raise ValueError('Invalid num_oov_buckets {} in {}.'.format( num_oov_buckets, key)) - _assert_string_or_int( + fc_utils.assert_string_or_int( vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key)) if dtype is None: dtype = vocabulary_dtype @@ -1375,8 +1363,8 @@ def _categorical_column_with_vocabulary_list(key, raise ValueError( 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format( dtype, vocabulary_dtype, key)) - _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) - _assert_key_is_string(key) + fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + fc_utils.assert_key_is_string(key) return _VocabularyListCategoricalColumn( key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype, @@ -1445,7 +1433,7 @@ def _categorical_column_with_identity(key, num_buckets, default_value=None): raise ValueError( 'default_value {} not in range [0, {}), column_name {}'.format( default_value, num_buckets, key)) - _assert_key_is_string(key) + fc_utils.assert_key_is_string(key) return _IdentityCategoricalColumn( key=key, num_buckets=num_buckets, default_value=default_value) @@ -2495,7 +2483,7 @@ class _EmbeddingColumn( trainable=trainable) sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access - sequence_length = _sequence_length_from_sparse_tensor( + sequence_length = fc_utils.sequence_length_from_sparse_tensor( sparse_tensors.id_tensor) return _SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=sequence_length) @@ -2637,25 +2625,12 @@ class _SharedEmbeddingColumn( weight_collections=weight_collections, trainable=trainable) sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access - sequence_length = _sequence_length_from_sparse_tensor( + sequence_length = fc_utils.sequence_length_from_sparse_tensor( sparse_tensors.id_tensor) return _SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=sequence_length) -def _create_tuple(shape, value): - """Returns a tuple with given shape and filled with value.""" - if shape: - return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])]) - return value - - -def _as_tuple(value): - if not nest.is_sequence(value): - return value - return tuple([_as_tuple(v) for v in value]) - - def _check_shape(shape, key): """Returns shape if it's valid, raises error otherwise.""" assert shape is not None @@ -2672,82 +2647,6 @@ def _check_shape(shape, key): return shape -def _is_shape_and_default_value_compatible(default_value, shape): - """Verifies compatibility of shape and default_value.""" - # Invalid condition: - # * if default_value is not a scalar and shape is empty - # * or if default_value is an iterable and shape is not empty - if nest.is_sequence(default_value) != bool(shape): - return False - if not shape: - return True - if len(default_value) != shape[0]: - return False - for i in range(shape[0]): - if not _is_shape_and_default_value_compatible(default_value[i], shape[1:]): - return False - return True - - -def _check_default_value(shape, default_value, dtype, key): - """Returns default value as tuple if it's valid, otherwise raises errors. - - This function verifies that `default_value` is compatible with both `shape` - and `dtype`. If it is not compatible, it raises an error. If it is compatible, - it casts default_value to a tuple and returns it. `key` is used only - for error message. - - Args: - shape: An iterable of integers specifies the shape of the `Tensor`. - default_value: If a single value is provided, the same value will be applied - as the default value for every item. If an iterable of values is - provided, the shape of the `default_value` should be equal to the given - `shape`. - dtype: defines the type of values. Default value is `tf.float32`. Must be a - non-quantized, real integer or floating point type. - key: Column name, used only for error messages. - - Returns: - A tuple which will be used as default value. - - Raises: - TypeError: if `default_value` is an iterable but not compatible with `shape` - TypeError: if `default_value` is not compatible with `dtype`. - ValueError: if `dtype` is not convertible to `tf.float32`. - """ - if default_value is None: - return None - - if isinstance(default_value, int): - return _create_tuple(shape, default_value) - - if isinstance(default_value, float) and dtype.is_floating: - return _create_tuple(shape, default_value) - - if callable(getattr(default_value, 'tolist', None)): # Handles numpy arrays - default_value = default_value.tolist() - - if nest.is_sequence(default_value): - if not _is_shape_and_default_value_compatible(default_value, shape): - raise ValueError( - 'The shape of default_value must be equal to given shape. ' - 'default_value: {}, shape: {}, key: {}'.format( - default_value, shape, key)) - # Check if the values in the list are all integers or are convertible to - # floats. - is_list_all_int = all( - isinstance(v, int) for v in nest.flatten(default_value)) - is_list_has_float = any( - isinstance(v, float) for v in nest.flatten(default_value)) - if is_list_all_int: - return _as_tuple(default_value) - if is_list_has_float and dtype.is_floating: - return _as_tuple(default_value) - raise TypeError('default_value must be compatible with dtype. ' - 'default_value: {}, dtype: {}, key: {}'.format( - default_value, dtype, key)) - - class _HashedCategoricalColumn( _CategoricalColumn, collections.namedtuple('_HashedCategoricalColumn', @@ -2767,7 +2666,7 @@ class _HashedCategoricalColumn( if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor): raise ValueError('SparseColumn input must be a SparseTensor.') - _assert_string_or_int( + fc_utils.assert_string_or_int( input_tensor.dtype, prefix='column_name: {} input_tensor'.format(self.key)) @@ -2822,7 +2721,7 @@ class _VocabularyFileCategoricalColumn( 'key: {}, column dtype: {}, tensor dtype: {}'.format( self.key, self.dtype, input_tensor.dtype)) - _assert_string_or_int( + fc_utils.assert_string_or_int( input_tensor.dtype, prefix='column_name: {} input_tensor'.format(self.key)) @@ -2874,7 +2773,7 @@ class _VocabularyListCategoricalColumn( 'key: {}, column dtype: {}, tensor dtype: {}'.format( self.key, self.dtype, input_tensor.dtype)) - _assert_string_or_int( + fc_utils.assert_string_or_int( input_tensor.dtype, prefix='column_name: {} input_tensor'.format(self.key)) @@ -3210,7 +3109,7 @@ class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn, # representation created by _transform_feature. dense_tensor = inputs.get(self) sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access - sequence_length = _sequence_length_from_sparse_tensor( + sequence_length = fc_utils.sequence_length_from_sparse_tensor( sparse_tensors.id_tensor) return _SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=sequence_length) @@ -3242,31 +3141,6 @@ def _verify_static_batch_size_equality(tensors, columns): expected_batch_size, tensors[i].shape.dims[0])) -def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): - """Returns a [batch_size] Tensor with per-example sequence length.""" - with ops.name_scope(None, 'sequence_length') as name_scope: - row_ids = sp_tensor.indices[:, 0] - column_ids = sp_tensor.indices[:, 1] - # Add one to convert column indices to element length - column_ids += array_ops.ones_like(column_ids) - # Get the number of elements we will have per example/row - seq_length = math_ops.segment_max(column_ids, segment_ids=row_ids) - - # The raw values are grouped according to num_elements; - # how many entities will we have after grouping? - # Example: orig tensor [[1, 2], [3]], col_ids = (0, 1, 1), - # row_ids = (0, 0, 1), seq_length = [2, 1]. If num_elements = 2, - # these will get grouped, and the final seq_length is [1, 1] - seq_length = math_ops.cast( - math_ops.ceil(seq_length / num_elements), dtypes.int64) - - # If the last n rows do not have ids, seq_length will have shape - # [batch_size - n]. Pad the remaining values with zeros. - n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1] - padding = array_ops.zeros(n_pad, dtype=seq_length.dtype) - return array_ops.concat([seq_length, padding], axis=0, name=name_scope) - - class _SequenceCategoricalColumn( _CategoricalColumn, collections.namedtuple( diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index 3b9f527061b..5b7d794c398 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -137,6 +137,7 @@ import six from tensorflow.python.eager import context from tensorflow.python.feature_column import feature_column as fc_old +from tensorflow.python.feature_column import utils as fc_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib @@ -1318,13 +1319,14 @@ def numeric_column(key, if not (dtype.is_integer or dtype.is_floating): raise ValueError('dtype must be convertible to float. ' 'dtype: {}, key: {}'.format(dtype, key)) - default_value = _check_default_value(shape, default_value, dtype, key) + default_value = fc_utils.check_default_value( + shape, default_value, dtype, key) if normalizer_fn is not None and not callable(normalizer_fn): raise TypeError( 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) - _assert_key_is_string(key) + fc_utils.assert_key_is_string(key) return NumericColumn( key, shape=shape, @@ -1418,19 +1420,6 @@ def bucketized_column(source_column, boundaries): return BucketizedColumn(source_column, tuple(boundaries)) -def _assert_string_or_int(dtype, prefix): - if (dtype != dtypes.string) and (not dtype.is_integer): - raise ValueError( - '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype)) - - -def _assert_key_is_string(key): - if not isinstance(key, six.string_types): - raise ValueError( - 'key must be a string. Got: type {}. Given key: {}.'.format( - type(key), key)) - - @tf_export('feature_column.categorical_column_with_hash_bucket') def categorical_column_with_hash_bucket(key, hash_bucket_size, @@ -1484,8 +1473,8 @@ def categorical_column_with_hash_bucket(key, 'hash_bucket_size: {}, key: {}'.format( hash_bucket_size, key)) - _assert_key_is_string(key) - _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + fc_utils.assert_key_is_string(key) + fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) return HashedCategoricalColumn(key, hash_bucket_size, dtype) @@ -1690,8 +1679,8 @@ def categorical_column_with_vocabulary_file_v2(key, if num_oov_buckets < 0: raise ValueError('Invalid num_oov_buckets {} in {}.'.format( num_oov_buckets, key)) - _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) - _assert_key_is_string(key) + fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + fc_utils.assert_key_is_string(key) return VocabularyFileCategoricalColumn( key=key, vocabulary_file=vocabulary_file, @@ -1799,7 +1788,7 @@ def categorical_column_with_vocabulary_list(key, if num_oov_buckets < 0: raise ValueError('Invalid num_oov_buckets {} in {}.'.format( num_oov_buckets, key)) - _assert_string_or_int( + fc_utils.assert_string_or_int( vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key)) if dtype is None: dtype = vocabulary_dtype @@ -1807,8 +1796,8 @@ def categorical_column_with_vocabulary_list(key, raise ValueError( 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format( dtype, vocabulary_dtype, key)) - _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) - _assert_key_is_string(key) + fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + fc_utils.assert_key_is_string(key) return VocabularyListCategoricalColumn( key=key, @@ -1881,7 +1870,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None): raise ValueError( 'default_value {} not in range [0, {}), column_name {}'.format( default_value, num_buckets, key)) - _assert_key_is_string(key) + fc_utils.assert_key_is_string(key) return IdentityCategoricalColumn( key=key, number_buckets=num_buckets, default_value=default_value) @@ -3166,7 +3155,7 @@ class EmbeddingColumn( transformation_cache, state_manager) dense_tensor = self._get_dense_tensor_internal(sparse_tensors, state_manager) - sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access + sequence_length = fc_utils.sequence_length_from_sparse_tensor( sparse_tensors.id_tensor) return SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=sequence_length) @@ -3192,7 +3181,7 @@ class EmbeddingColumn( sparse_tensors, weight_collections=weight_collections, trainable=trainable) - sequence_length = _sequence_length_from_sparse_tensor( + sequence_length = fc_utils.sequence_length_from_sparse_tensor( sparse_tensors.id_tensor) return SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=sequence_length) @@ -3376,7 +3365,7 @@ class SharedEmbeddingColumn( state_manager) sparse_tensors = self.categorical_column.get_sparse_tensors( transformation_cache, state_manager) - sequence_length = _sequence_length_from_sparse_tensor( + sequence_length = fc_utils.sequence_length_from_sparse_tensor( sparse_tensors.id_tensor) return SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=sequence_length) @@ -3402,19 +3391,6 @@ class SharedEmbeddingColumn( raise NotImplementedError() -def _create_tuple(shape, value): - """Returns a tuple with given shape and filled with value.""" - if shape: - return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])]) - return value - - -def _as_tuple(value): - if not nest.is_sequence(value): - return value - return tuple([_as_tuple(v) for v in value]) - - def _check_shape(shape, key): """Returns shape if it's valid, raises error otherwise.""" assert shape is not None @@ -3431,82 +3407,6 @@ def _check_shape(shape, key): return shape -def _is_shape_and_default_value_compatible(default_value, shape): - """Verifies compatibility of shape and default_value.""" - # Invalid condition: - # * if default_value is not a scalar and shape is empty - # * or if default_value is an iterable and shape is not empty - if nest.is_sequence(default_value) != bool(shape): - return False - if not shape: - return True - if len(default_value) != shape[0]: - return False - for i in range(shape[0]): - if not _is_shape_and_default_value_compatible(default_value[i], shape[1:]): - return False - return True - - -def _check_default_value(shape, default_value, dtype, key): - """Returns default value as tuple if it's valid, otherwise raises errors. - - This function verifies that `default_value` is compatible with both `shape` - and `dtype`. If it is not compatible, it raises an error. If it is compatible, - it casts default_value to a tuple and returns it. `key` is used only - for error message. - - Args: - shape: An iterable of integers specifies the shape of the `Tensor`. - default_value: If a single value is provided, the same value will be applied - as the default value for every item. If an iterable of values is - provided, the shape of the `default_value` should be equal to the given - `shape`. - dtype: defines the type of values. Default value is `tf.float32`. Must be a - non-quantized, real integer or floating point type. - key: Column name, used only for error messages. - - Returns: - A tuple which will be used as default value. - - Raises: - TypeError: if `default_value` is an iterable but not compatible with `shape` - TypeError: if `default_value` is not compatible with `dtype`. - ValueError: if `dtype` is not convertible to `tf.float32`. - """ - if default_value is None: - return None - - if isinstance(default_value, int): - return _create_tuple(shape, default_value) - - if isinstance(default_value, float) and dtype.is_floating: - return _create_tuple(shape, default_value) - - if callable(getattr(default_value, 'tolist', None)): # Handles numpy arrays - default_value = default_value.tolist() - - if nest.is_sequence(default_value): - if not _is_shape_and_default_value_compatible(default_value, shape): - raise ValueError( - 'The shape of default_value must be equal to given shape. ' - 'default_value: {}, shape: {}, key: {}'.format( - default_value, shape, key)) - # Check if the values in the list are all integers or are convertible to - # floats. - is_list_all_int = all( - isinstance(v, int) for v in nest.flatten(default_value)) - is_list_has_float = any( - isinstance(v, float) for v in nest.flatten(default_value)) - if is_list_all_int: - return _as_tuple(default_value) - if is_list_has_float and dtype.is_floating: - return _as_tuple(default_value) - raise TypeError('default_value must be compatible with dtype. ' - 'default_value: {}, dtype: {}, key: {}'.format( - default_value, dtype, key)) - - class HashedCategoricalColumn( CategoricalColumn, fc_old._CategoricalColumn, # pylint: disable=protected-access @@ -3539,7 +3439,7 @@ class HashedCategoricalColumn( if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor): raise ValueError('SparseColumn input must be a SparseTensor.') - _assert_string_or_int( + fc_utils.assert_string_or_int( input_tensor.dtype, prefix='column_name: {} input_tensor'.format(self.key)) @@ -3651,7 +3551,7 @@ class VocabularyFileCategoricalColumn( 'key: {}, column dtype: {}, tensor dtype: {}'.format( self.key, self.dtype, input_tensor.dtype)) - _assert_string_or_int( + fc_utils.assert_string_or_int( input_tensor.dtype, prefix='column_name: {} input_tensor'.format(self.key)) @@ -3763,7 +3663,7 @@ class VocabularyListCategoricalColumn( 'key: {}, column dtype: {}, tensor dtype: {}'.format( self.key, self.dtype, input_tensor.dtype)) - _assert_string_or_int( + fc_utils.assert_string_or_int( input_tensor.dtype, prefix='column_name: {} input_tensor'.format(self.key)) @@ -4426,7 +4326,7 @@ class IndicatorColumn( dense_tensor = transformation_cache.get(self, state_manager) sparse_tensors = self.categorical_column.get_sparse_tensors( transformation_cache, state_manager) - sequence_length = _sequence_length_from_sparse_tensor( + sequence_length = fc_utils.sequence_length_from_sparse_tensor( sparse_tensors.id_tensor) return SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=sequence_length) @@ -4455,7 +4355,7 @@ class IndicatorColumn( # representation created by _transform_feature. dense_tensor = inputs.get(self) sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access - sequence_length = _sequence_length_from_sparse_tensor( + sequence_length = fc_utils.sequence_length_from_sparse_tensor( sparse_tensors.id_tensor) return SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=sequence_length) @@ -4509,31 +4409,6 @@ def _verify_static_batch_size_equality(tensors, columns): expected_batch_size, batch_size)) -def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): - """Returns a [batch_size] Tensor with per-example sequence length.""" - with ops.name_scope(None, 'sequence_length') as name_scope: - row_ids = sp_tensor.indices[:, 0] - column_ids = sp_tensor.indices[:, 1] - # Add one to convert column indices to element length - column_ids += array_ops.ones_like(column_ids) - # Get the number of elements we will have per example/row - seq_length = math_ops.segment_max(column_ids, segment_ids=row_ids) - - # The raw values are grouped according to num_elements; - # how many entities will we have after grouping? - # Example: orig tensor [[1, 2], [3]], col_ids = (0, 1, 1), - # row_ids = (0, 0, 1), seq_length = [2, 1]. If num_elements = 2, - # these will get grouped, and the final seq_length is [1, 1] - seq_length = math_ops.cast( - math_ops.ceil(seq_length / num_elements), dtypes.int64) - - # If the last n rows do not have ids, seq_length will have shape - # [batch_size - n]. Pad the remaining values with zeros. - n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1] - padding = array_ops.zeros(n_pad, dtype=seq_length.dtype) - return array_ops.concat([seq_length, padding], axis=0, name=name_scope) - - class SequenceCategoricalColumn( CategoricalColumn, fc_old._SequenceCategoricalColumn, # pylint: disable=protected-access diff --git a/tensorflow/python/feature_column/sequence_feature_column.py b/tensorflow/python/feature_column/sequence_feature_column.py index bc58c413fef..7e31497db87 100644 --- a/tensorflow/python/feature_column/sequence_feature_column.py +++ b/tensorflow/python/feature_column/sequence_feature_column.py @@ -26,6 +26,7 @@ import collections from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.feature_column import utils as fc_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -564,7 +565,7 @@ class SequenceNumericColumn( num_elements = self.variable_shape.num_elements() else: num_elements = 1 - seq_length = fc._sequence_length_from_sparse_tensor( + seq_length = fc_utils.sequence_length_from_sparse_tensor( sp_tensor, num_elements=num_elements) return fc.SequenceDenseColumn.TensorSequenceLengthPair( diff --git a/tensorflow/python/feature_column/utils.py b/tensorflow/python/feature_column/utils.py new file mode 100644 index 00000000000..0dd17aadc29 --- /dev/null +++ b/tensorflow/python/feature_column/utils.py @@ -0,0 +1,154 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines functions common to multiple feature column files.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.util import nest + + +def sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): + """Returns a [batch_size] Tensor with per-example sequence length.""" + with ops.name_scope(None, 'sequence_length') as name_scope: + row_ids = sp_tensor.indices[:, 0] + column_ids = sp_tensor.indices[:, 1] + # Add one to convert column indices to element length + column_ids += array_ops.ones_like(column_ids) + # Get the number of elements we will have per example/row + seq_length = math_ops.segment_max(column_ids, segment_ids=row_ids) + + # The raw values are grouped according to num_elements; + # how many entities will we have after grouping? + # Example: orig tensor [[1, 2], [3]], col_ids = (0, 1, 1), + # row_ids = (0, 0, 1), seq_length = [2, 1]. If num_elements = 2, + # these will get grouped, and the final seq_length is [1, 1] + seq_length = math_ops.cast( + math_ops.ceil(seq_length / num_elements), dtypes.int64) + + # If the last n rows do not have ids, seq_length will have shape + # [batch_size - n]. Pad the remaining values with zeros. + n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1] + padding = array_ops.zeros(n_pad, dtype=seq_length.dtype) + return array_ops.concat([seq_length, padding], axis=0, name=name_scope) + + +def assert_string_or_int(dtype, prefix): + if (dtype != dtypes.string) and (not dtype.is_integer): + raise ValueError( + '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype)) + + +def assert_key_is_string(key): + if not isinstance(key, six.string_types): + raise ValueError( + 'key must be a string. Got: type {}. Given key: {}.'.format( + type(key), key)) + + +def check_default_value(shape, default_value, dtype, key): + """Returns default value as tuple if it's valid, otherwise raises errors. + + This function verifies that `default_value` is compatible with both `shape` + and `dtype`. If it is not compatible, it raises an error. If it is compatible, + it casts default_value to a tuple and returns it. `key` is used only + for error message. + + Args: + shape: An iterable of integers specifies the shape of the `Tensor`. + default_value: If a single value is provided, the same value will be applied + as the default value for every item. If an iterable of values is + provided, the shape of the `default_value` should be equal to the given + `shape`. + dtype: defines the type of values. Default value is `tf.float32`. Must be a + non-quantized, real integer or floating point type. + key: Column name, used only for error messages. + + Returns: + A tuple which will be used as default value. + + Raises: + TypeError: if `default_value` is an iterable but not compatible with `shape` + TypeError: if `default_value` is not compatible with `dtype`. + ValueError: if `dtype` is not convertible to `tf.float32`. + """ + if default_value is None: + return None + + if isinstance(default_value, int): + return _create_tuple(shape, default_value) + + if isinstance(default_value, float) and dtype.is_floating: + return _create_tuple(shape, default_value) + + if callable(getattr(default_value, 'tolist', None)): # Handles numpy arrays + default_value = default_value.tolist() + + if nest.is_sequence(default_value): + if not _is_shape_and_default_value_compatible(default_value, shape): + raise ValueError( + 'The shape of default_value must be equal to given shape. ' + 'default_value: {}, shape: {}, key: {}'.format( + default_value, shape, key)) + # Check if the values in the list are all integers or are convertible to + # floats. + is_list_all_int = all( + isinstance(v, int) for v in nest.flatten(default_value)) + is_list_has_float = any( + isinstance(v, float) for v in nest.flatten(default_value)) + if is_list_all_int: + return _as_tuple(default_value) + if is_list_has_float and dtype.is_floating: + return _as_tuple(default_value) + raise TypeError('default_value must be compatible with dtype. ' + 'default_value: {}, dtype: {}, key: {}'.format( + default_value, dtype, key)) + + +def _create_tuple(shape, value): + """Returns a tuple with given shape and filled with value.""" + if shape: + return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])]) + return value + + +def _as_tuple(value): + if not nest.is_sequence(value): + return value + return tuple([_as_tuple(v) for v in value]) + + +def _is_shape_and_default_value_compatible(default_value, shape): + """Verifies compatibility of shape and default_value.""" + # Invalid condition: + # * if default_value is not a scalar and shape is empty + # * or if default_value is an iterable and shape is not empty + if nest.is_sequence(default_value) != bool(shape): + return False + if not shape: + return True + if len(default_value) != shape[0]: + return False + for i in range(shape[0]): + if not _is_shape_and_default_value_compatible(default_value[i], shape[1:]): + return False + return True