Exclude the index column from the features returned by pandas_input_fn - the index column causes errors when the input function is used with infer_real_valued_columns_from_input_fn and canned estimators.
Change: 142719058
This commit is contained in:
parent
a439d9975e
commit
65a983af2c
@ -122,9 +122,14 @@ def extract_pandas_labels(labels):
|
|||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
def pandas_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=True,
|
def pandas_input_fn(x,
|
||||||
queue_capacity=1000, num_threads=1, target_column='target',
|
y=None,
|
||||||
index_column='index'):
|
batch_size=128,
|
||||||
|
num_epochs=1,
|
||||||
|
shuffle=True,
|
||||||
|
queue_capacity=1000,
|
||||||
|
num_threads=1,
|
||||||
|
target_column='target'):
|
||||||
"""Returns input function that would feed Pandas DataFrame into the model.
|
"""Returns input function that would feed Pandas DataFrame into the model.
|
||||||
|
|
||||||
Note: `y`'s index must match `x`'s index.
|
Note: `y`'s index must match `x`'s index.
|
||||||
@ -140,7 +145,6 @@ def pandas_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=True,
|
|||||||
roughly to the size of `x`.
|
roughly to the size of `x`.
|
||||||
num_threads: int, number of threads used for reading and enqueueing.
|
num_threads: int, number of threads used for reading and enqueueing.
|
||||||
target_column: str, name to give the target column `y`.
|
target_column: str, name to give the target column `y`.
|
||||||
index_column: str, name of the index column.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Function, that has signature of ()->(dict of `features`, `target`)
|
Function, that has signature of ()->(dict of `features`, `target`)
|
||||||
@ -183,7 +187,10 @@ def pandas_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=True,
|
|||||||
features = queue.dequeue_many(batch_size)
|
features = queue.dequeue_many(batch_size)
|
||||||
else:
|
else:
|
||||||
features = queue.dequeue_up_to(batch_size)
|
features = queue.dequeue_up_to(batch_size)
|
||||||
features = dict(zip([index_column] + list(x.columns), features))
|
assert len(features) == len(x.columns) + 1, ('Features should have one '
|
||||||
|
'extra element for the index.')
|
||||||
|
features = features[1:]
|
||||||
|
features = dict(zip(list(x.columns), features))
|
||||||
if y is not None:
|
if y is not None:
|
||||||
target = features.pop(target_column)
|
target = features.pop(target_column)
|
||||||
return features, target
|
return features, target
|
||||||
|
@ -57,7 +57,7 @@ class PandasIoTest(tf.test.TestCase):
|
|||||||
x, _ = self.makeTestDataFrame()
|
x, _ = self.makeTestDataFrame()
|
||||||
y_noindex = pd.Series(np.arange(-32, -28))
|
y_noindex = pd.Series(np.arange(-32, -28))
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
failing_input_fn = pandas_io.pandas_input_fn(
|
pandas_io.pandas_input_fn(
|
||||||
x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
|
x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
|
||||||
|
|
||||||
def testPandasInputFn_ProducesExpectedOutputs(self):
|
def testPandasInputFn_ProducesExpectedOutputs(self):
|
||||||
@ -70,7 +70,6 @@ class PandasIoTest(tf.test.TestCase):
|
|||||||
|
|
||||||
features, target = self.callInputFnOnce(input_fn, session)
|
features, target = self.callInputFnOnce(input_fn, session)
|
||||||
|
|
||||||
self.assertAllEqual(features['index'], [100, 101])
|
|
||||||
self.assertAllEqual(features['a'], [0, 1])
|
self.assertAllEqual(features['a'], [0, 1])
|
||||||
self.assertAllEqual(features['b'], [32, 33])
|
self.assertAllEqual(features['b'], [32, 33])
|
||||||
self.assertAllEqual(target, [-32, -31])
|
self.assertAllEqual(target, [-32, -31])
|
||||||
@ -85,10 +84,21 @@ class PandasIoTest(tf.test.TestCase):
|
|||||||
|
|
||||||
features = self.callInputFnOnce(input_fn, session)
|
features = self.callInputFnOnce(input_fn, session)
|
||||||
|
|
||||||
self.assertAllEqual(features['index'], [100, 101])
|
|
||||||
self.assertAllEqual(features['a'], [0, 1])
|
self.assertAllEqual(features['a'], [0, 1])
|
||||||
self.assertAllEqual(features['b'], [32, 33])
|
self.assertAllEqual(features['b'], [32, 33])
|
||||||
|
|
||||||
|
def testPandasInputFn_ExcludesIndex(self):
|
||||||
|
if not HAS_PANDAS:
|
||||||
|
return
|
||||||
|
with self.test_session() as session:
|
||||||
|
x, y = self.makeTestDataFrame()
|
||||||
|
input_fn = pandas_io.pandas_input_fn(
|
||||||
|
x, y, batch_size=2, shuffle=False, num_epochs=1)
|
||||||
|
|
||||||
|
features, _ = self.callInputFnOnce(input_fn, session)
|
||||||
|
|
||||||
|
self.assertFalse('index' in features)
|
||||||
|
|
||||||
def assertInputsCallableNTimes(self, input_fn, session, n):
|
def assertInputsCallableNTimes(self, input_fn, session, n):
|
||||||
inputs = input_fn()
|
inputs = input_fn()
|
||||||
coord = tf.train.Coordinator()
|
coord = tf.train.Coordinator()
|
||||||
|
Loading…
Reference in New Issue
Block a user