Merge pull request #19388 from nrstott/pandas_input_fn_y_as_df
Pandas input fn accepts DataFrame for Y
This commit is contained in:
commit
16a965c5c9
@ -18,6 +18,8 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import six
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.python.estimator.inputs.queues import feeding_functions
|
||||
@ -35,6 +37,22 @@ except ImportError:
|
||||
HAS_PANDAS = False
|
||||
|
||||
|
||||
def _get_unique_target_key(features, target_column_name):
|
||||
"""Returns a key that does not exist in the input DataFrame `features`.
|
||||
|
||||
Args:
|
||||
features: DataFrame
|
||||
target_column_name: Name of the target column as a `str`
|
||||
|
||||
Returns:
|
||||
A unique key that can be used to insert the target into
|
||||
features.
|
||||
"""
|
||||
if target_column_name in features:
|
||||
target_column_name += '_' + str(uuid.uuid4())
|
||||
return target_column_name
|
||||
|
||||
|
||||
@estimator_export('estimator.inputs.pandas_input_fn')
|
||||
def pandas_input_fn(x,
|
||||
y=None,
|
||||
@ -50,7 +68,7 @@ def pandas_input_fn(x,
|
||||
|
||||
Args:
|
||||
x: pandas `DataFrame` object.
|
||||
y: pandas `Series` object. `None` if absent.
|
||||
y: pandas `Series` object or `DataFrame`. `None` if absent.
|
||||
batch_size: int, size of batches to return.
|
||||
num_epochs: int, number of epochs to iterate over data. If not `None`,
|
||||
read attempts that would exceed this value will raise `OutOfRangeError`.
|
||||
@ -60,7 +78,8 @@ def pandas_input_fn(x,
|
||||
num_threads: Integer, number of threads used for reading and enqueueing. In
|
||||
order to have predicted and repeatable order of reading and enqueueing,
|
||||
such as in prediction and evaluation mode, `num_threads` should be 1.
|
||||
target_column: str, name to give the target column `y`.
|
||||
target_column: str, name to give the target column `y`. This parameter
|
||||
is not used when `y` is a `DataFrame`.
|
||||
|
||||
Returns:
|
||||
Function, that has signature of ()->(dict of `features`, `target`)
|
||||
@ -79,6 +98,9 @@ def pandas_input_fn(x,
|
||||
'(it is recommended to set it as True for training); '
|
||||
'got {}'.format(shuffle))
|
||||
|
||||
if not isinstance(target_column, six.string_types):
|
||||
raise TypeError('target_column must be a string type')
|
||||
|
||||
x = x.copy()
|
||||
if y is not None:
|
||||
if target_column in x:
|
||||
@ -88,7 +110,13 @@ def pandas_input_fn(x,
|
||||
if not np.array_equal(x.index, y.index):
|
||||
raise ValueError('Index for x and y are mismatched.\nIndex for x: %s\n'
|
||||
'Index for y: %s\n' % (x.index, y.index))
|
||||
x[target_column] = y
|
||||
if isinstance(y, pd.DataFrame):
|
||||
y_columns = [(column, _get_unique_target_key(x, column))
|
||||
for column in list(y)]
|
||||
target_column = [v for _, v in y_columns]
|
||||
x[target_column] = y
|
||||
else:
|
||||
x[target_column] = y
|
||||
|
||||
# TODO(mdan): These are memory copies. We probably don't need 4x slack space.
|
||||
# The sizes below are consistent with what I've seen elsewhere.
|
||||
@ -118,7 +146,12 @@ def pandas_input_fn(x,
|
||||
features = features[1:]
|
||||
features = dict(zip(list(x.columns), features))
|
||||
if y is not None:
|
||||
target = features.pop(target_column)
|
||||
if isinstance(target_column, list):
|
||||
keys = [k for k, _ in y_columns]
|
||||
values = [features.pop(column) for column in target_column]
|
||||
target = {k: v for k, v in zip(keys, values)}
|
||||
else:
|
||||
target = features.pop(target_column)
|
||||
return features, target
|
||||
return features
|
||||
return input_fn
|
||||
|
@ -47,6 +47,16 @@ class PandasIoTest(test.TestCase):
|
||||
y = pd.Series(np.arange(-32, -28), index=index)
|
||||
return x, y
|
||||
|
||||
def makeTestDataFrameWithYAsDataFrame(self):
|
||||
index = np.arange(100, 104)
|
||||
a = np.arange(4)
|
||||
b = np.arange(32, 36)
|
||||
a_label = np.arange(10, 14)
|
||||
b_label = np.arange(50, 54)
|
||||
x = pd.DataFrame({'a': a, 'b': b}, index=index)
|
||||
y = pd.DataFrame({'a_target': a_label, 'b_target': b_label}, index=index)
|
||||
return x, y
|
||||
|
||||
def callInputFnOnce(self, input_fn, session):
|
||||
results = input_fn()
|
||||
coord = coordinator.Coordinator()
|
||||
@ -65,6 +75,19 @@ class PandasIoTest(test.TestCase):
|
||||
pandas_io.pandas_input_fn(
|
||||
x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
|
||||
|
||||
def testPandasInputFn_RaisesWhenTargetColumnIsAList(self):
|
||||
if not HAS_PANDAS:
|
||||
return
|
||||
|
||||
x, y = self.makeTestDataFrame()
|
||||
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'target_column must be a string type'):
|
||||
pandas_io.pandas_input_fn(x, y, batch_size=2,
|
||||
shuffle=False,
|
||||
num_epochs=1,
|
||||
target_column=['one', 'two'])
|
||||
|
||||
def testPandasInputFn_NonBoolShuffle(self):
|
||||
if not HAS_PANDAS:
|
||||
return
|
||||
@ -90,6 +113,53 @@ class PandasIoTest(test.TestCase):
|
||||
self.assertAllEqual(features['b'], [32, 33])
|
||||
self.assertAllEqual(target, [-32, -31])
|
||||
|
||||
def testPandasInputFnWhenYIsDataFrame_ProducesExpectedOutput(self):
|
||||
if not HAS_PANDAS:
|
||||
return
|
||||
with self.test_session() as session:
|
||||
x, y = self.makeTestDataFrameWithYAsDataFrame()
|
||||
input_fn = pandas_io.pandas_input_fn(
|
||||
x, y, batch_size=2, shuffle=False, num_epochs=1)
|
||||
|
||||
features, targets = self.callInputFnOnce(input_fn, session)
|
||||
|
||||
self.assertAllEqual(features['a'], [0, 1])
|
||||
self.assertAllEqual(features['b'], [32, 33])
|
||||
self.assertAllEqual(targets['a_target'], [10, 11])
|
||||
self.assertAllEqual(targets['b_target'], [50, 51])
|
||||
|
||||
def testPandasInputFnYIsDataFrame_HandlesOverlappingColumns(self):
|
||||
if not HAS_PANDAS:
|
||||
return
|
||||
with self.test_session() as session:
|
||||
x, y = self.makeTestDataFrameWithYAsDataFrame()
|
||||
y = y.rename(columns={'a_target': 'a', 'b_target': 'b'})
|
||||
input_fn = pandas_io.pandas_input_fn(
|
||||
x, y, batch_size=2, shuffle=False, num_epochs=1)
|
||||
|
||||
features, targets = self.callInputFnOnce(input_fn, session)
|
||||
|
||||
self.assertAllEqual(features['a'], [0, 1])
|
||||
self.assertAllEqual(features['b'], [32, 33])
|
||||
self.assertAllEqual(targets['a'], [10, 11])
|
||||
self.assertAllEqual(targets['b'], [50, 51])
|
||||
|
||||
def testPandasInputFnYIsDataFrame_HandlesOverlappingColumnsInTargets(self):
|
||||
if not HAS_PANDAS:
|
||||
return
|
||||
with self.test_session() as session:
|
||||
x, y = self.makeTestDataFrameWithYAsDataFrame()
|
||||
y = y.rename(columns={'a_target': 'a', 'b_target': 'a_n'})
|
||||
input_fn = pandas_io.pandas_input_fn(
|
||||
x, y, batch_size=2, shuffle=False, num_epochs=1)
|
||||
|
||||
features, targets = self.callInputFnOnce(input_fn, session)
|
||||
|
||||
self.assertAllEqual(features['a'], [0, 1])
|
||||
self.assertAllEqual(features['b'], [32, 33])
|
||||
self.assertAllEqual(targets['a'], [10, 11])
|
||||
self.assertAllEqual(targets['a_n'], [50, 51])
|
||||
|
||||
def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
|
||||
if not HAS_PANDAS:
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user