Merge pull request #19388 from nrstott/pandas_input_fn_y_as_df

Pandas input fn accepts DataFrame for Y
This commit is contained in:
Yifei Feng 2018-07-02 15:07:51 -07:00 committed by GitHub
commit 16a965c5c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 107 additions and 4 deletions

View File

@ -18,6 +18,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import uuid
import numpy as np
from tensorflow.python.estimator.inputs.queues import feeding_functions
@ -35,6 +37,22 @@ except ImportError:
HAS_PANDAS = False
def _get_unique_target_key(features, target_column_name):
"""Returns a key that does not exist in the input DataFrame `features`.
Args:
features: DataFrame
target_column_name: Name of the target column as a `str`
Returns:
A unique key that can be used to insert the target into
features.
"""
if target_column_name in features:
target_column_name += '_' + str(uuid.uuid4())
return target_column_name
@estimator_export('estimator.inputs.pandas_input_fn')
def pandas_input_fn(x,
y=None,
@ -50,7 +68,7 @@ def pandas_input_fn(x,
Args:
x: pandas `DataFrame` object.
y: pandas `Series` object. `None` if absent.
y: pandas `Series` object or `DataFrame`. `None` if absent.
batch_size: int, size of batches to return.
num_epochs: int, number of epochs to iterate over data. If not `None`,
read attempts that would exceed this value will raise `OutOfRangeError`.
@ -60,7 +78,8 @@ def pandas_input_fn(x,
num_threads: Integer, number of threads used for reading and enqueueing. In
order to have predicted and repeatable order of reading and enqueueing,
such as in prediction and evaluation mode, `num_threads` should be 1.
target_column: str, name to give the target column `y`.
target_column: str, name to give the target column `y`. This parameter
is not used when `y` is a `DataFrame`.
Returns:
Function, that has signature of ()->(dict of `features`, `target`)
@ -79,6 +98,9 @@ def pandas_input_fn(x,
'(it is recommended to set it as True for training); '
'got {}'.format(shuffle))
if not isinstance(target_column, six.string_types):
raise TypeError('target_column must be a string type')
x = x.copy()
if y is not None:
if target_column in x:
@ -88,7 +110,13 @@ def pandas_input_fn(x,
if not np.array_equal(x.index, y.index):
raise ValueError('Index for x and y are mismatched.\nIndex for x: %s\n'
'Index for y: %s\n' % (x.index, y.index))
x[target_column] = y
if isinstance(y, pd.DataFrame):
y_columns = [(column, _get_unique_target_key(x, column))
for column in list(y)]
target_column = [v for _, v in y_columns]
x[target_column] = y
else:
x[target_column] = y
# TODO(mdan): These are memory copies. We probably don't need 4x slack space.
# The sizes below are consistent with what I've seen elsewhere.
@ -118,7 +146,12 @@ def pandas_input_fn(x,
features = features[1:]
features = dict(zip(list(x.columns), features))
if y is not None:
target = features.pop(target_column)
if isinstance(target_column, list):
keys = [k for k, _ in y_columns]
values = [features.pop(column) for column in target_column]
target = {k: v for k, v in zip(keys, values)}
else:
target = features.pop(target_column)
return features, target
return features
return input_fn

View File

@ -47,6 +47,16 @@ class PandasIoTest(test.TestCase):
y = pd.Series(np.arange(-32, -28), index=index)
return x, y
def makeTestDataFrameWithYAsDataFrame(self):
index = np.arange(100, 104)
a = np.arange(4)
b = np.arange(32, 36)
a_label = np.arange(10, 14)
b_label = np.arange(50, 54)
x = pd.DataFrame({'a': a, 'b': b}, index=index)
y = pd.DataFrame({'a_target': a_label, 'b_target': b_label}, index=index)
return x, y
def callInputFnOnce(self, input_fn, session):
results = input_fn()
coord = coordinator.Coordinator()
@ -65,6 +75,19 @@ class PandasIoTest(test.TestCase):
pandas_io.pandas_input_fn(
x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
def testPandasInputFn_RaisesWhenTargetColumnIsAList(self):
if not HAS_PANDAS:
return
x, y = self.makeTestDataFrame()
with self.assertRaisesRegexp(TypeError,
'target_column must be a string type'):
pandas_io.pandas_input_fn(x, y, batch_size=2,
shuffle=False,
num_epochs=1,
target_column=['one', 'two'])
def testPandasInputFn_NonBoolShuffle(self):
if not HAS_PANDAS:
return
@ -90,6 +113,53 @@ class PandasIoTest(test.TestCase):
self.assertAllEqual(features['b'], [32, 33])
self.assertAllEqual(target, [-32, -31])
def testPandasInputFnWhenYIsDataFrame_ProducesExpectedOutput(self):
if not HAS_PANDAS:
return
with self.test_session() as session:
x, y = self.makeTestDataFrameWithYAsDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
features, targets = self.callInputFnOnce(input_fn, session)
self.assertAllEqual(features['a'], [0, 1])
self.assertAllEqual(features['b'], [32, 33])
self.assertAllEqual(targets['a_target'], [10, 11])
self.assertAllEqual(targets['b_target'], [50, 51])
def testPandasInputFnYIsDataFrame_HandlesOverlappingColumns(self):
if not HAS_PANDAS:
return
with self.test_session() as session:
x, y = self.makeTestDataFrameWithYAsDataFrame()
y = y.rename(columns={'a_target': 'a', 'b_target': 'b'})
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
features, targets = self.callInputFnOnce(input_fn, session)
self.assertAllEqual(features['a'], [0, 1])
self.assertAllEqual(features['b'], [32, 33])
self.assertAllEqual(targets['a'], [10, 11])
self.assertAllEqual(targets['b'], [50, 51])
def testPandasInputFnYIsDataFrame_HandlesOverlappingColumnsInTargets(self):
if not HAS_PANDAS:
return
with self.test_session() as session:
x, y = self.makeTestDataFrameWithYAsDataFrame()
y = y.rename(columns={'a_target': 'a', 'b_target': 'a_n'})
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
features, targets = self.callInputFnOnce(input_fn, session)
self.assertAllEqual(features['a'], [0, 1])
self.assertAllEqual(features['b'], [32, 33])
self.assertAllEqual(targets['a'], [10, 11])
self.assertAllEqual(targets['a_n'], [50, 51])
def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
if not HAS_PANDAS:
return