Fix bug re DataFrame->FeatureColumn
Change: 128989209
This commit is contained in:
parent
d713ac4889
commit
52c0418614
@ -1337,7 +1337,7 @@ def crossed_column(columns, hash_bucket_size, combiner="sum",
|
||||
|
||||
class DataFrameColumn(_FeatureColumn,
|
||||
collections.namedtuple("DataFrameColumn",
|
||||
["name", "series"])):
|
||||
["column_name", "series"])):
|
||||
"""Represents a feature column produced from a `DataFrame`.
|
||||
|
||||
Instances of this class are immutable. A `DataFrame` column may be dense or
|
||||
@ -1345,13 +1345,17 @@ class DataFrameColumn(_FeatureColumn,
|
||||
batch_size.
|
||||
|
||||
Args:
|
||||
name: a name for this column
|
||||
column_name: a name for this column
|
||||
series: a `Series` to be wrapped, which has already had its base features
|
||||
substituted with `PredefinedSeries`.
|
||||
"""
|
||||
|
||||
def __new__(cls, name, series):
|
||||
return super(DataFrameColumn, cls).__new__(cls, name, series)
|
||||
def __new__(cls, column_name, series):
|
||||
return super(DataFrameColumn, cls).__new__(cls, column_name, series)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.column_name
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
@ -1379,7 +1383,17 @@ class DataFrameColumn(_FeatureColumn,
|
||||
input_tensor,
|
||||
weight_collections=None,
|
||||
trainable=True):
|
||||
return input_tensor
|
||||
# DataFrame typically provides Tensors of shape [batch_size],
|
||||
# but Estimator requires shape [batch_size, 1]
|
||||
dims = input_tensor.get_shape().ndims
|
||||
if dims == 0:
|
||||
raise ValueError(
|
||||
"Can't build input layer from tensor of shape (): {}".format(
|
||||
self.column_name))
|
||||
elif dims == 1:
|
||||
return array_ops.expand_dims(input_tensor, 1)
|
||||
else:
|
||||
return input_tensor
|
||||
|
||||
# TODO(soergel): This mirrors RealValuedColumn for now, but should become
|
||||
# better abstracted with less code duplication when we add other kinds.
|
||||
|
Loading…
Reference in New Issue
Block a user