Fix bug re DataFrame->FeatureColumn

Change: 128989209
This commit is contained in:
David Soergel 2016-08-01 07:49:30 -08:00 committed by TensorFlower Gardener
parent d713ac4889
commit 52c0418614

View File

@ -1337,7 +1337,7 @@ def crossed_column(columns, hash_bucket_size, combiner="sum",
class DataFrameColumn(_FeatureColumn,
collections.namedtuple("DataFrameColumn",
["name", "series"])):
["column_name", "series"])):
"""Represents a feature column produced from a `DataFrame`.
Instances of this class are immutable. A `DataFrame` column may be dense or
@ -1345,13 +1345,17 @@ class DataFrameColumn(_FeatureColumn,
batch_size.
Args:
name: a name for this column
column_name: a name for this column
series: a `Series` to be wrapped, which has already had its base features
substituted with `PredefinedSeries`.
"""
def __new__(cls, name, series):
return super(DataFrameColumn, cls).__new__(cls, name, series)
def __new__(cls, column_name, series):
return super(DataFrameColumn, cls).__new__(cls, column_name, series)
@property
def name(self):
return self.column_name
@property
def config(self):
@ -1379,7 +1383,17 @@ class DataFrameColumn(_FeatureColumn,
input_tensor,
weight_collections=None,
trainable=True):
return input_tensor
# DataFrame typically provides Tensors of shape [batch_size],
# but Estimator requires shape [batch_size, 1]
dims = input_tensor.get_shape().ndims
if dims == 0:
raise ValueError(
"Can't build input layer from tensor of shape (): {}".format(
self.column_name))
elif dims == 1:
return array_ops.expand_dims(input_tensor, 1)
else:
return input_tensor
# TODO(soergel): This mirrors RealValuedColumn for now, but should become
# better abstracted with less code duplication when we add other kinds.