Exclude the index column from the features returned by pandas_input_fn; the index column causes errors when the input function is used with infer_real_valued_columns_from_input_fn and canned estimators.

Change: 142719058
This commit is contained in:
A. Unique TensorFlower 2016-12-21 17:35:57 -08:00 committed by TensorFlower Gardener
parent a439d9975e
commit 65a983af2c
2 changed files with 25 additions and 8 deletions

View File

@ -122,9 +122,14 @@ def extract_pandas_labels(labels):
return labels
def pandas_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=True,
queue_capacity=1000, num_threads=1, target_column='target',
index_column='index'):
def pandas_input_fn(x,
y=None,
batch_size=128,
num_epochs=1,
shuffle=True,
queue_capacity=1000,
num_threads=1,
target_column='target'):
"""Returns input function that would feed Pandas DataFrame into the model.
Note: `y`'s index must match `x`'s index.
@ -140,7 +145,6 @@ def pandas_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=True,
roughly to the size of `x`.
num_threads: int, number of threads used for reading and enqueueing.
target_column: str, name to give the target column `y`.
index_column: str, name of the index column.
Returns:
Function, that has signature of ()->(dict of `features`, `target`)
@ -183,7 +187,10 @@ def pandas_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=True,
features = queue.dequeue_many(batch_size)
else:
features = queue.dequeue_up_to(batch_size)
features = dict(zip([index_column] + list(x.columns), features))
assert len(features) == len(x.columns) + 1, ('Features should have one '
'extra element for the index.')
features = features[1:]
features = dict(zip(list(x.columns), features))
if y is not None:
target = features.pop(target_column)
return features, target

View File

@ -57,7 +57,7 @@ class PandasIoTest(tf.test.TestCase):
x, _ = self.makeTestDataFrame()
y_noindex = pd.Series(np.arange(-32, -28))
with self.assertRaises(ValueError):
failing_input_fn = pandas_io.pandas_input_fn(
pandas_io.pandas_input_fn(
x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
def testPandasInputFn_ProducesExpectedOutputs(self):
@ -70,7 +70,6 @@ class PandasIoTest(tf.test.TestCase):
features, target = self.callInputFnOnce(input_fn, session)
self.assertAllEqual(features['index'], [100, 101])
self.assertAllEqual(features['a'], [0, 1])
self.assertAllEqual(features['b'], [32, 33])
self.assertAllEqual(target, [-32, -31])
@ -85,10 +84,21 @@ class PandasIoTest(tf.test.TestCase):
features = self.callInputFnOnce(input_fn, session)
self.assertAllEqual(features['index'], [100, 101])
self.assertAllEqual(features['a'], [0, 1])
self.assertAllEqual(features['b'], [32, 33])
def testPandasInputFn_ExcludesIndex(self):
  """Asserts that 'index' is not a key of the features from pandas_input_fn.

  Builds an input function from the test DataFrame and its labels, invokes it
  once through `callInputFnOnce`, and checks the returned features for an
  'index' entry.
  """
  if not HAS_PANDAS:
    # Nothing to check when HAS_PANDAS is falsy; return without asserting.
    return
  with self.test_session() as session:
    frame, labels = self.makeTestDataFrame()
    build_inputs = pandas_io.pandas_input_fn(
        frame, labels, batch_size=2, shuffle=False, num_epochs=1)
    # The second element returned alongside the features is not examined.
    batch_features, _ = self.callInputFnOnce(build_inputs, session)
    index_present = 'index' in batch_features
    self.assertFalse(index_present)
def assertInputsCallableNTimes(self, input_fn, session, n):
inputs = input_fn()
coord = tf.train.Coordinator()