Move common feature column helper functions to separate file.
PiperOrigin-RevId: 236707892
commit 7b81a79dbb (parent d79bd989b9)
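
The change is mechanical at every call site: the private helpers (_assert_key_is_string, _assert_string_or_int, _check_default_value, _sequence_length_from_sparse_tensor, plus the _create_tuple / _as_tuple / _is_shape_and_default_value_compatible internals) are deleted from feature_column.py and feature_column_v2.py and re-exported without the leading underscore from the new utils module. A minimal sketch of the migration pattern, assuming this revision of TensorFlow ('price' is an illustrative key, not taken from the diff):

    # Before this commit, each feature column file carried private copies:
    #   _assert_key_is_string('price')
    #   seq_length = _sequence_length_from_sparse_tensor(sp_tensor)
    # After it, every file imports the shared module instead:
    from tensorflow.python.feature_column import utils as fc_utils

    fc_utils.assert_key_is_string('price')  # passes silently for a str key

The diff below applies exactly this substitution, file by file.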
@@ -33,6 +33,7 @@ py_library(
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/feature_column:feature_column_py",
+        "//tensorflow/python/feature_column:utils",
     ],
 )

@@ -23,6 +23,7 @@ import collections


 from tensorflow.python.feature_column import feature_column as fc
+from tensorflow.python.feature_column import utils as fc_utils
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -506,7 +507,7 @@ class _SequenceNumericColumn(
     # sequence length is not affected.
     num_elements = (self._variable_shape.num_elements()
                     if sp_tensor.shape.ndims == 2 else 1)
-    seq_length = fc._sequence_length_from_sparse_tensor(
+    seq_length = fc_utils.sequence_length_from_sparse_tensor(
         sp_tensor, num_elements=num_elements)

     return fc._SequenceDenseColumn.TensorSequenceLengthPair(
@@ -23,6 +23,7 @@ py_library(
     srcs = ["feature_column.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":utils",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:check_ops",
         "//tensorflow/python:control_flow_ops",
@@ -57,6 +58,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":feature_column",
+        ":utils",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:check_ops",
         "//tensorflow/python:control_flow_ops",
@@ -172,6 +174,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":feature_column_v2",
+        ":utils",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:check_ops",
         "//tensorflow/python:dtypes",
@@ -184,6 +187,19 @@ py_library(
     ],
 )

+py_library(
+    name = "utils",
+    srcs = ["utils.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:util",
+    ],
+)
+
 tf_py_test(
     name = "sequence_feature_column_test",
     srcs = ["sequence_feature_column_test.py"],
@@ -138,8 +138,8 @@ import math
 import numpy as np
 import six

-
 from tensorflow.python.eager import context
+from tensorflow.python.feature_column import utils as fc_utils
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -982,13 +982,14 @@ def _numeric_column(key,
   if not (dtype.is_integer or dtype.is_floating):
     raise ValueError('dtype must be convertible to float. '
                      'dtype: {}, key: {}'.format(dtype, key))
-  default_value = _check_default_value(shape, default_value, dtype, key)
+  default_value = fc_utils.check_default_value(
+      shape, default_value, dtype, key)

   if normalizer_fn is not None and not callable(normalizer_fn):
     raise TypeError(
         'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

-  _assert_key_is_string(key)
+  fc_utils.assert_key_is_string(key)
   return _NumericColumn(
       key,
       shape=shape,
@@ -1080,19 +1081,6 @@ def _bucketized_column(source_column, boundaries):
   return _BucketizedColumn(source_column, tuple(boundaries))


-def _assert_string_or_int(dtype, prefix):
-  if (dtype != dtypes.string) and (not dtype.is_integer):
-    raise ValueError(
-        '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype))
-
-
-def _assert_key_is_string(key):
-  if not isinstance(key, six.string_types):
-    raise ValueError(
-        'key must be a string. Got: type {}. Given key: {}.'.format(
-            type(key), key))
-
-
 def _categorical_column_with_hash_bucket(key,
                                          hash_bucket_size,
                                          dtype=dtypes.string):
@@ -1145,8 +1133,8 @@ def _categorical_column_with_hash_bucket(key,
                      'hash_bucket_size: {}, key: {}'.format(
                          hash_bucket_size, key))

-  _assert_key_is_string(key)
-  _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+  fc_utils.assert_key_is_string(key)
+  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))

   return _HashedCategoricalColumn(key, hash_bucket_size, dtype)

@@ -1259,8 +1247,8 @@ def _categorical_column_with_vocabulary_file(key,
   if num_oov_buckets < 0:
     raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
         num_oov_buckets, key))
-  _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
-  _assert_key_is_string(key)
+  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+  fc_utils.assert_key_is_string(key)
   return _VocabularyFileCategoricalColumn(
       key=key,
       vocabulary_file=vocabulary_file,
@@ -1367,7 +1355,7 @@ def _categorical_column_with_vocabulary_list(key,
   if num_oov_buckets < 0:
     raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
         num_oov_buckets, key))
-  _assert_string_or_int(
+  fc_utils.assert_string_or_int(
       vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
   if dtype is None:
     dtype = vocabulary_dtype
@@ -1375,8 +1363,8 @@ def _categorical_column_with_vocabulary_list(key,
     raise ValueError(
         'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
             dtype, vocabulary_dtype, key))
-  _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
-  _assert_key_is_string(key)
+  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+  fc_utils.assert_key_is_string(key)

   return _VocabularyListCategoricalColumn(
       key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype,
@@ -1445,7 +1433,7 @@ def _categorical_column_with_identity(key, num_buckets, default_value=None):
     raise ValueError(
         'default_value {} not in range [0, {}), column_name {}'.format(
             default_value, num_buckets, key))
-  _assert_key_is_string(key)
+  fc_utils.assert_key_is_string(key)
   return _IdentityCategoricalColumn(
       key=key, num_buckets=num_buckets, default_value=default_value)

@@ -2495,7 +2483,7 @@ class _EmbeddingColumn(
         trainable=trainable)

    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
-    sequence_length = _sequence_length_from_sparse_tensor(
+    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return _SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)
@@ -2637,25 +2625,12 @@ class _SharedEmbeddingColumn(
         weight_collections=weight_collections,
         trainable=trainable)
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
-    sequence_length = _sequence_length_from_sparse_tensor(
+    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return _SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)


-def _create_tuple(shape, value):
-  """Returns a tuple with given shape and filled with value."""
-  if shape:
-    return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])])
-  return value
-
-
-def _as_tuple(value):
-  if not nest.is_sequence(value):
-    return value
-  return tuple([_as_tuple(v) for v in value])
-
-
 def _check_shape(shape, key):
   """Returns shape if it's valid, raises error otherwise."""
   assert shape is not None
@@ -2672,82 +2647,6 @@ def _check_shape(shape, key):
   return shape


-def _is_shape_and_default_value_compatible(default_value, shape):
-  """Verifies compatibility of shape and default_value."""
-  # Invalid condition:
-  #  * if default_value is not a scalar and shape is empty
-  #  * or if default_value is an iterable and shape is not empty
-  if nest.is_sequence(default_value) != bool(shape):
-    return False
-  if not shape:
-    return True
-  if len(default_value) != shape[0]:
-    return False
-  for i in range(shape[0]):
-    if not _is_shape_and_default_value_compatible(default_value[i], shape[1:]):
-      return False
-  return True
-
-
-def _check_default_value(shape, default_value, dtype, key):
-  """Returns default value as tuple if it's valid, otherwise raises errors.
-
-  This function verifies that `default_value` is compatible with both `shape`
-  and `dtype`. If it is not compatible, it raises an error. If it is compatible,
-  it casts default_value to a tuple and returns it. `key` is used only
-  for error message.
-
-  Args:
-    shape: An iterable of integers specifies the shape of the `Tensor`.
-    default_value: If a single value is provided, the same value will be applied
-      as the default value for every item. If an iterable of values is
-      provided, the shape of the `default_value` should be equal to the given
-      `shape`.
-    dtype: defines the type of values. Default value is `tf.float32`. Must be a
-      non-quantized, real integer or floating point type.
-    key: Column name, used only for error messages.
-
-  Returns:
-    A tuple which will be used as default value.
-
-  Raises:
-    TypeError: if `default_value` is an iterable but not compatible with `shape`
-    TypeError: if `default_value` is not compatible with `dtype`.
-    ValueError: if `dtype` is not convertible to `tf.float32`.
-  """
-  if default_value is None:
-    return None
-
-  if isinstance(default_value, int):
-    return _create_tuple(shape, default_value)
-
-  if isinstance(default_value, float) and dtype.is_floating:
-    return _create_tuple(shape, default_value)
-
-  if callable(getattr(default_value, 'tolist', None)):  # Handles numpy arrays
-    default_value = default_value.tolist()
-
-  if nest.is_sequence(default_value):
-    if not _is_shape_and_default_value_compatible(default_value, shape):
-      raise ValueError(
-          'The shape of default_value must be equal to given shape. '
-          'default_value: {}, shape: {}, key: {}'.format(
-              default_value, shape, key))
-    # Check if the values in the list are all integers or are convertible to
-    # floats.
-    is_list_all_int = all(
-        isinstance(v, int) for v in nest.flatten(default_value))
-    is_list_has_float = any(
-        isinstance(v, float) for v in nest.flatten(default_value))
-    if is_list_all_int:
-      return _as_tuple(default_value)
-    if is_list_has_float and dtype.is_floating:
-      return _as_tuple(default_value)
-  raise TypeError('default_value must be compatible with dtype. '
-                  'default_value: {}, dtype: {}, key: {}'.format(
-                      default_value, dtype, key))
-
-
 class _HashedCategoricalColumn(
     _CategoricalColumn,
     collections.namedtuple('_HashedCategoricalColumn',
@@ -2767,7 +2666,7 @@ class _HashedCategoricalColumn(
     if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
       raise ValueError('SparseColumn input must be a SparseTensor.')

-    _assert_string_or_int(
+    fc_utils.assert_string_or_int(
         input_tensor.dtype,
         prefix='column_name: {} input_tensor'.format(self.key))

@@ -2822,7 +2721,7 @@ class _VocabularyFileCategoricalColumn(
           'key: {}, column dtype: {}, tensor dtype: {}'.format(
               self.key, self.dtype, input_tensor.dtype))

-    _assert_string_or_int(
+    fc_utils.assert_string_or_int(
         input_tensor.dtype,
         prefix='column_name: {} input_tensor'.format(self.key))

@@ -2874,7 +2773,7 @@ class _VocabularyListCategoricalColumn(
           'key: {}, column dtype: {}, tensor dtype: {}'.format(
               self.key, self.dtype, input_tensor.dtype))

-    _assert_string_or_int(
+    fc_utils.assert_string_or_int(
         input_tensor.dtype,
         prefix='column_name: {} input_tensor'.format(self.key))

@@ -3210,7 +3109,7 @@ class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn,
     # representation created by _transform_feature.
     dense_tensor = inputs.get(self)
     sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
-    sequence_length = _sequence_length_from_sparse_tensor(
+    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
         sparse_tensors.id_tensor)
     return _SequenceDenseColumn.TensorSequenceLengthPair(
         dense_tensor=dense_tensor, sequence_length=sequence_length)
@@ -3242,31 +3141,6 @@ def _verify_static_batch_size_equality(tensors, columns):
             expected_batch_size, tensors[i].shape.dims[0]))


-def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
-  """Returns a [batch_size] Tensor with per-example sequence length."""
-  with ops.name_scope(None, 'sequence_length') as name_scope:
-    row_ids = sp_tensor.indices[:, 0]
-    column_ids = sp_tensor.indices[:, 1]
-    # Add one to convert column indices to element length
-    column_ids += array_ops.ones_like(column_ids)
-    # Get the number of elements we will have per example/row
-    seq_length = math_ops.segment_max(column_ids, segment_ids=row_ids)
-
-    # The raw values are grouped according to num_elements;
-    # how many entities will we have after grouping?
-    # Example: orig tensor [[1, 2], [3]], col_ids = (0, 1, 1),
-    # row_ids = (0, 0, 1), seq_length = [2, 1]. If num_elements = 2,
-    # these will get grouped, and the final seq_length is [1, 1]
-    seq_length = math_ops.cast(
-        math_ops.ceil(seq_length / num_elements), dtypes.int64)
-
-    # If the last n rows do not have ids, seq_length will have shape
-    # [batch_size - n]. Pad the remaining values with zeros.
-    n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
-    padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
-    return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
-
-
 class _SequenceCategoricalColumn(
     _CategoricalColumn,
     collections.namedtuple(
@@ -137,6 +137,7 @@ import six

 from tensorflow.python.eager import context
 from tensorflow.python.feature_column import feature_column as fc_old
+from tensorflow.python.feature_column import utils as fc_utils
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -1318,13 +1319,14 @@ def numeric_column(key,
   if not (dtype.is_integer or dtype.is_floating):
     raise ValueError('dtype must be convertible to float. '
                      'dtype: {}, key: {}'.format(dtype, key))
-  default_value = _check_default_value(shape, default_value, dtype, key)
+  default_value = fc_utils.check_default_value(
+      shape, default_value, dtype, key)

   if normalizer_fn is not None and not callable(normalizer_fn):
     raise TypeError(
         'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

-  _assert_key_is_string(key)
+  fc_utils.assert_key_is_string(key)
   return NumericColumn(
       key,
       shape=shape,
@@ -1418,19 +1420,6 @@ def bucketized_column(source_column, boundaries):
   return BucketizedColumn(source_column, tuple(boundaries))


-def _assert_string_or_int(dtype, prefix):
-  if (dtype != dtypes.string) and (not dtype.is_integer):
-    raise ValueError(
-        '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype))
-
-
-def _assert_key_is_string(key):
-  if not isinstance(key, six.string_types):
-    raise ValueError(
-        'key must be a string. Got: type {}. Given key: {}.'.format(
-            type(key), key))
-
-
 @tf_export('feature_column.categorical_column_with_hash_bucket')
 def categorical_column_with_hash_bucket(key,
                                         hash_bucket_size,
@@ -1484,8 +1473,8 @@ def categorical_column_with_hash_bucket(key,
                      'hash_bucket_size: {}, key: {}'.format(
                          hash_bucket_size, key))

-  _assert_key_is_string(key)
-  _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+  fc_utils.assert_key_is_string(key)
+  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))

   return HashedCategoricalColumn(key, hash_bucket_size, dtype)

@@ -1690,8 +1679,8 @@ def categorical_column_with_vocabulary_file_v2(key,
   if num_oov_buckets < 0:
     raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
         num_oov_buckets, key))
-  _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
-  _assert_key_is_string(key)
+  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+  fc_utils.assert_key_is_string(key)
   return VocabularyFileCategoricalColumn(
       key=key,
       vocabulary_file=vocabulary_file,
@@ -1799,7 +1788,7 @@ def categorical_column_with_vocabulary_list(key,
   if num_oov_buckets < 0:
     raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
         num_oov_buckets, key))
-  _assert_string_or_int(
+  fc_utils.assert_string_or_int(
       vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
   if dtype is None:
     dtype = vocabulary_dtype
@@ -1807,8 +1796,8 @@ def categorical_column_with_vocabulary_list(key,
     raise ValueError(
         'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
             dtype, vocabulary_dtype, key))
-  _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
-  _assert_key_is_string(key)
+  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+  fc_utils.assert_key_is_string(key)

   return VocabularyListCategoricalColumn(
       key=key,
@@ -1881,7 +1870,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None):
     raise ValueError(
         'default_value {} not in range [0, {}), column_name {}'.format(
             default_value, num_buckets, key))
-  _assert_key_is_string(key)
+  fc_utils.assert_key_is_string(key)
   return IdentityCategoricalColumn(
       key=key, number_buckets=num_buckets, default_value=default_value)

@@ -3166,7 +3155,7 @@ class EmbeddingColumn(
         transformation_cache, state_manager)
     dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
                                                    state_manager)
-    sequence_length = fc_old._sequence_length_from_sparse_tensor(  # pylint: disable=protected-access
+    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
         sparse_tensors.id_tensor)
     return SequenceDenseColumn.TensorSequenceLengthPair(
         dense_tensor=dense_tensor, sequence_length=sequence_length)
@@ -3192,7 +3181,7 @@ class EmbeddingColumn(
         sparse_tensors,
         weight_collections=weight_collections,
         trainable=trainable)
-    sequence_length = _sequence_length_from_sparse_tensor(
+    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
         sparse_tensors.id_tensor)
     return SequenceDenseColumn.TensorSequenceLengthPair(
         dense_tensor=dense_tensor, sequence_length=sequence_length)
@@ -3376,7 +3365,7 @@ class SharedEmbeddingColumn(
         state_manager)
     sparse_tensors = self.categorical_column.get_sparse_tensors(
         transformation_cache, state_manager)
-    sequence_length = _sequence_length_from_sparse_tensor(
+    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
         sparse_tensors.id_tensor)
     return SequenceDenseColumn.TensorSequenceLengthPair(
         dense_tensor=dense_tensor, sequence_length=sequence_length)
@@ -3402,19 +3391,6 @@ class SharedEmbeddingColumn(
     raise NotImplementedError()


-def _create_tuple(shape, value):
-  """Returns a tuple with given shape and filled with value."""
-  if shape:
-    return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])])
-  return value
-
-
-def _as_tuple(value):
-  if not nest.is_sequence(value):
-    return value
-  return tuple([_as_tuple(v) for v in value])
-
-
 def _check_shape(shape, key):
   """Returns shape if it's valid, raises error otherwise."""
   assert shape is not None
@@ -3431,82 +3407,6 @@ def _check_shape(shape, key):
   return shape


-def _is_shape_and_default_value_compatible(default_value, shape):
-  """Verifies compatibility of shape and default_value."""
-  # Invalid condition:
-  #  * if default_value is not a scalar and shape is empty
-  #  * or if default_value is an iterable and shape is not empty
-  if nest.is_sequence(default_value) != bool(shape):
-    return False
-  if not shape:
-    return True
-  if len(default_value) != shape[0]:
-    return False
-  for i in range(shape[0]):
-    if not _is_shape_and_default_value_compatible(default_value[i], shape[1:]):
-      return False
-  return True
-
-
-def _check_default_value(shape, default_value, dtype, key):
-  """Returns default value as tuple if it's valid, otherwise raises errors.
-
-  This function verifies that `default_value` is compatible with both `shape`
-  and `dtype`. If it is not compatible, it raises an error. If it is compatible,
-  it casts default_value to a tuple and returns it. `key` is used only
-  for error message.
-
-  Args:
-    shape: An iterable of integers specifies the shape of the `Tensor`.
-    default_value: If a single value is provided, the same value will be applied
-      as the default value for every item. If an iterable of values is
-      provided, the shape of the `default_value` should be equal to the given
-      `shape`.
-    dtype: defines the type of values. Default value is `tf.float32`. Must be a
-      non-quantized, real integer or floating point type.
-    key: Column name, used only for error messages.
-
-  Returns:
-    A tuple which will be used as default value.
-
-  Raises:
-    TypeError: if `default_value` is an iterable but not compatible with `shape`
-    TypeError: if `default_value` is not compatible with `dtype`.
-    ValueError: if `dtype` is not convertible to `tf.float32`.
-  """
-  if default_value is None:
-    return None
-
-  if isinstance(default_value, int):
-    return _create_tuple(shape, default_value)
-
-  if isinstance(default_value, float) and dtype.is_floating:
-    return _create_tuple(shape, default_value)
-
-  if callable(getattr(default_value, 'tolist', None)):  # Handles numpy arrays
-    default_value = default_value.tolist()
-
-  if nest.is_sequence(default_value):
-    if not _is_shape_and_default_value_compatible(default_value, shape):
-      raise ValueError(
-          'The shape of default_value must be equal to given shape. '
-          'default_value: {}, shape: {}, key: {}'.format(
-              default_value, shape, key))
-    # Check if the values in the list are all integers or are convertible to
-    # floats.
-    is_list_all_int = all(
-        isinstance(v, int) for v in nest.flatten(default_value))
-    is_list_has_float = any(
-        isinstance(v, float) for v in nest.flatten(default_value))
-    if is_list_all_int:
-      return _as_tuple(default_value)
-    if is_list_has_float and dtype.is_floating:
-      return _as_tuple(default_value)
-  raise TypeError('default_value must be compatible with dtype. '
-                  'default_value: {}, dtype: {}, key: {}'.format(
-                      default_value, dtype, key))
-
-
 class HashedCategoricalColumn(
     CategoricalColumn,
     fc_old._CategoricalColumn,  # pylint: disable=protected-access
@@ -3539,7 +3439,7 @@ class HashedCategoricalColumn(
     if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
       raise ValueError('SparseColumn input must be a SparseTensor.')

-    _assert_string_or_int(
+    fc_utils.assert_string_or_int(
         input_tensor.dtype,
         prefix='column_name: {} input_tensor'.format(self.key))

@@ -3651,7 +3551,7 @@ class VocabularyFileCategoricalColumn(
           'key: {}, column dtype: {}, tensor dtype: {}'.format(
               self.key, self.dtype, input_tensor.dtype))

-    _assert_string_or_int(
+    fc_utils.assert_string_or_int(
         input_tensor.dtype,
         prefix='column_name: {} input_tensor'.format(self.key))

@@ -3763,7 +3663,7 @@ class VocabularyListCategoricalColumn(
           'key: {}, column dtype: {}, tensor dtype: {}'.format(
               self.key, self.dtype, input_tensor.dtype))

-    _assert_string_or_int(
+    fc_utils.assert_string_or_int(
         input_tensor.dtype,
         prefix='column_name: {} input_tensor'.format(self.key))

@@ -4426,7 +4326,7 @@ class IndicatorColumn(
     dense_tensor = transformation_cache.get(self, state_manager)
     sparse_tensors = self.categorical_column.get_sparse_tensors(
         transformation_cache, state_manager)
-    sequence_length = _sequence_length_from_sparse_tensor(
+    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
         sparse_tensors.id_tensor)
     return SequenceDenseColumn.TensorSequenceLengthPair(
         dense_tensor=dense_tensor, sequence_length=sequence_length)
@@ -4455,7 +4355,7 @@ class IndicatorColumn(
     # representation created by _transform_feature.
     dense_tensor = inputs.get(self)
     sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
-    sequence_length = _sequence_length_from_sparse_tensor(
+    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
         sparse_tensors.id_tensor)
     return SequenceDenseColumn.TensorSequenceLengthPair(
         dense_tensor=dense_tensor, sequence_length=sequence_length)
@@ -4509,31 +4409,6 @@ def _verify_static_batch_size_equality(tensors, columns):
             expected_batch_size, batch_size))


-def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
-  """Returns a [batch_size] Tensor with per-example sequence length."""
-  with ops.name_scope(None, 'sequence_length') as name_scope:
-    row_ids = sp_tensor.indices[:, 0]
-    column_ids = sp_tensor.indices[:, 1]
-    # Add one to convert column indices to element length
-    column_ids += array_ops.ones_like(column_ids)
-    # Get the number of elements we will have per example/row
-    seq_length = math_ops.segment_max(column_ids, segment_ids=row_ids)
-
-    # The raw values are grouped according to num_elements;
-    # how many entities will we have after grouping?
-    # Example: orig tensor [[1, 2], [3]], col_ids = (0, 1, 1),
-    # row_ids = (0, 0, 1), seq_length = [2, 1]. If num_elements = 2,
-    # these will get grouped, and the final seq_length is [1, 1]
-    seq_length = math_ops.cast(
-        math_ops.ceil(seq_length / num_elements), dtypes.int64)
-
-    # If the last n rows do not have ids, seq_length will have shape
-    # [batch_size - n]. Pad the remaining values with zeros.
-    n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
-    padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
-    return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
-
-
 class SequenceCategoricalColumn(
     CategoricalColumn,
     fc_old._SequenceCategoricalColumn,  # pylint: disable=protected-access
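
Of the moved helpers, check_default_value is the only one with non-trivial semantics: a scalar default is recursively expanded into nested tuples matching the requested shape. A minimal sketch, assuming this revision of TensorFlow ('price' is an illustrative column name, not taken from the diff):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.feature_column import utils as fc_utils

    # A scalar int default is broadcast to the requested shape as nested
    # tuples, via the private _create_tuple helper in the new utils file.
    print(fc_utils.check_default_value(
        shape=(2, 3), default_value=1, dtype=dtypes.float32, key='price'))
    # ((1, 1, 1), (1, 1, 1))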
@@ -26,6 +26,7 @@ import collections


 from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.feature_column import utils as fc_utils
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -564,7 +565,7 @@ class SequenceNumericColumn(
       num_elements = self.variable_shape.num_elements()
     else:
       num_elements = 1
-    seq_length = fc._sequence_length_from_sparse_tensor(
+    seq_length = fc_utils.sequence_length_from_sparse_tensor(
         sp_tensor, num_elements=num_elements)

     return fc.SequenceDenseColumn.TensorSequenceLengthPair(
tensorflow/python/feature_column/utils.py (new file, 154 lines)
@@ -0,0 +1,154 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines functions common to multiple feature column files."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util import nest
+
+
+def sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
+  """Returns a [batch_size] Tensor with per-example sequence length."""
+  with ops.name_scope(None, 'sequence_length') as name_scope:
+    row_ids = sp_tensor.indices[:, 0]
+    column_ids = sp_tensor.indices[:, 1]
+    # Add one to convert column indices to element length
+    column_ids += array_ops.ones_like(column_ids)
+    # Get the number of elements we will have per example/row
+    seq_length = math_ops.segment_max(column_ids, segment_ids=row_ids)
+
+    # The raw values are grouped according to num_elements;
+    # how many entities will we have after grouping?
+    # Example: orig tensor [[1, 2], [3]], col_ids = (0, 1, 1),
+    # row_ids = (0, 0, 1), seq_length = [2, 1]. If num_elements = 2,
+    # these will get grouped, and the final seq_length is [1, 1]
+    seq_length = math_ops.cast(
+        math_ops.ceil(seq_length / num_elements), dtypes.int64)
+
+    # If the last n rows do not have ids, seq_length will have shape
+    # [batch_size - n]. Pad the remaining values with zeros.
+    n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
+    padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
+    return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
+
+
+def assert_string_or_int(dtype, prefix):
+  if (dtype != dtypes.string) and (not dtype.is_integer):
+    raise ValueError(
+        '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype))
+
+
+def assert_key_is_string(key):
+  if not isinstance(key, six.string_types):
+    raise ValueError(
+        'key must be a string. Got: type {}. Given key: {}.'.format(
+            type(key), key))
+
+
+def check_default_value(shape, default_value, dtype, key):
+  """Returns default value as tuple if it's valid, otherwise raises errors.
+
+  This function verifies that `default_value` is compatible with both `shape`
+  and `dtype`. If it is not compatible, it raises an error. If it is compatible,
+  it casts default_value to a tuple and returns it. `key` is used only
+  for error message.
+
+  Args:
+    shape: An iterable of integers specifies the shape of the `Tensor`.
+    default_value: If a single value is provided, the same value will be applied
+      as the default value for every item. If an iterable of values is
+      provided, the shape of the `default_value` should be equal to the given
+      `shape`.
+    dtype: defines the type of values. Default value is `tf.float32`. Must be a
+      non-quantized, real integer or floating point type.
+    key: Column name, used only for error messages.
+
+  Returns:
+    A tuple which will be used as default value.
+
+  Raises:
+    TypeError: if `default_value` is an iterable but not compatible with `shape`
+    TypeError: if `default_value` is not compatible with `dtype`.
+    ValueError: if `dtype` is not convertible to `tf.float32`.
+  """
+  if default_value is None:
+    return None
+
+  if isinstance(default_value, int):
+    return _create_tuple(shape, default_value)
+
+  if isinstance(default_value, float) and dtype.is_floating:
+    return _create_tuple(shape, default_value)
+
+  if callable(getattr(default_value, 'tolist', None)):  # Handles numpy arrays
+    default_value = default_value.tolist()
+
+  if nest.is_sequence(default_value):
+    if not _is_shape_and_default_value_compatible(default_value, shape):
+      raise ValueError(
+          'The shape of default_value must be equal to given shape. '
+          'default_value: {}, shape: {}, key: {}'.format(
+              default_value, shape, key))
+    # Check if the values in the list are all integers or are convertible to
+    # floats.
+    is_list_all_int = all(
+        isinstance(v, int) for v in nest.flatten(default_value))
+    is_list_has_float = any(
+        isinstance(v, float) for v in nest.flatten(default_value))
+    if is_list_all_int:
+      return _as_tuple(default_value)
+    if is_list_has_float and dtype.is_floating:
+      return _as_tuple(default_value)
+  raise TypeError('default_value must be compatible with dtype. '
+                  'default_value: {}, dtype: {}, key: {}'.format(
+                      default_value, dtype, key))
+
+
+def _create_tuple(shape, value):
+  """Returns a tuple with given shape and filled with value."""
+  if shape:
+    return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])])
+  return value
+
+
+def _as_tuple(value):
+  if not nest.is_sequence(value):
+    return value
+  return tuple([_as_tuple(v) for v in value])
+
+
+def _is_shape_and_default_value_compatible(default_value, shape):
+  """Verifies compatibility of shape and default_value."""
+  # Invalid condition:
+  #  * if default_value is not a scalar and shape is empty
+  #  * or if default_value is an iterable and shape is not empty
+  if nest.is_sequence(default_value) != bool(shape):
+    return False
+  if not shape:
+    return True
+  if len(default_value) != shape[0]:
+    return False
+  for i in range(shape[0]):
+    if not _is_shape_and_default_value_compatible(default_value[i], shape[1:]):
+      return False
+  return True
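
As a sanity check on the relocated sequence helper, a small sketch that reproduces the worked example from the function's own comments (assumes graph mode with a TF 1.x Session, as was current at this revision):

    import tensorflow as tf
    from tensorflow.python.feature_column import utils as fc_utils

    # The ragged input [[1, 2], [3]] encoded as a SparseTensor.
    sp = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                         values=[1, 2, 3],
                         dense_shape=[2, 2])

    lengths = fc_utils.sequence_length_from_sparse_tensor(sp)
    grouped = fc_utils.sequence_length_from_sparse_tensor(sp, num_elements=2)

    with tf.Session() as sess:
      print(sess.run(lengths))  # [2 1]
      print(sess.run(grouped))  # [1 1]: each pair of values is one element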