Exclude the index column in pandas_input_fn - it causes errors when used with infer_real_valued_columns_from_input_fn and canned estimators.
Change: 142719058
This commit is contained in:
parent
a439d9975e
commit
65a983af2c
@ -122,9 +122,14 @@ def extract_pandas_labels(labels):
|
||||
return labels
|
||||
|
||||
|
||||
def pandas_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=True,
|
||||
queue_capacity=1000, num_threads=1, target_column='target',
|
||||
index_column='index'):
|
||||
def pandas_input_fn(x,
|
||||
y=None,
|
||||
batch_size=128,
|
||||
num_epochs=1,
|
||||
shuffle=True,
|
||||
queue_capacity=1000,
|
||||
num_threads=1,
|
||||
target_column='target'):
|
||||
"""Returns input function that would feed Pandas DataFrame into the model.
|
||||
|
||||
Note: `y`'s index must match `x`'s index.
|
||||
@ -140,7 +145,6 @@ def pandas_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=True,
|
||||
roughly to the size of `x`.
|
||||
num_threads: int, number of threads used for reading and enqueueing.
|
||||
target_column: str, name to give the target column `y`.
|
||||
index_column: str, name of the index column.
|
||||
|
||||
Returns:
|
||||
Function, that has signature of ()->(dict of `features`, `target`)
|
||||
@ -183,7 +187,10 @@ def pandas_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=True,
|
||||
features = queue.dequeue_many(batch_size)
|
||||
else:
|
||||
features = queue.dequeue_up_to(batch_size)
|
||||
features = dict(zip([index_column] + list(x.columns), features))
|
||||
assert len(features) == len(x.columns) + 1, ('Features should have one '
|
||||
'extra element for the index.')
|
||||
features = features[1:]
|
||||
features = dict(zip(list(x.columns), features))
|
||||
if y is not None:
|
||||
target = features.pop(target_column)
|
||||
return features, target
|
||||
|
@ -57,7 +57,7 @@ class PandasIoTest(tf.test.TestCase):
|
||||
x, _ = self.makeTestDataFrame()
|
||||
y_noindex = pd.Series(np.arange(-32, -28))
|
||||
with self.assertRaises(ValueError):
|
||||
failing_input_fn = pandas_io.pandas_input_fn(
|
||||
pandas_io.pandas_input_fn(
|
||||
x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
|
||||
|
||||
def testPandasInputFn_ProducesExpectedOutputs(self):
|
||||
@ -70,7 +70,6 @@ class PandasIoTest(tf.test.TestCase):
|
||||
|
||||
features, target = self.callInputFnOnce(input_fn, session)
|
||||
|
||||
self.assertAllEqual(features['index'], [100, 101])
|
||||
self.assertAllEqual(features['a'], [0, 1])
|
||||
self.assertAllEqual(features['b'], [32, 33])
|
||||
self.assertAllEqual(target, [-32, -31])
|
||||
@ -85,10 +84,21 @@ class PandasIoTest(tf.test.TestCase):
|
||||
|
||||
features = self.callInputFnOnce(input_fn, session)
|
||||
|
||||
self.assertAllEqual(features['index'], [100, 101])
|
||||
self.assertAllEqual(features['a'], [0, 1])
|
||||
self.assertAllEqual(features['b'], [32, 33])
|
||||
|
||||
def testPandasInputFn_ExcludesIndex(self):
  """Checks that the features from pandas_input_fn have no 'index' key.

  Builds an input function from the test DataFrame and target, evaluates it
  once, and asserts that 'index' is not among the returned features.
  Returns without asserting anything when HAS_PANDAS is falsy.
  """
  if not HAS_PANDAS:
    return
  with self.test_session() as session:
    x, y = self.makeTestDataFrame()
    input_fn = pandas_io.pandas_input_fn(
        x, y, batch_size=2, shuffle=False, num_epochs=1)

    features, _ = self.callInputFnOnce(input_fn, session)

    # assertNotIn names the offending key and container on failure, whereas
    # assertFalse('index' in features) would only report "True is not false".
    self.assertNotIn('index', features)
|
||||
|
||||
def assertInputsCallableNTimes(self, input_fn, session, n):
|
||||
inputs = input_fn()
|
||||
coord = tf.train.Coordinator()
|
||||
|
Loading…
Reference in New Issue
Block a user