Adds NumpySource and PandasSource, which read pydata into Columns.
Change: 123445673
This commit is contained in:
parent
09e75d408b
commit
c922ed63e3
@ -149,6 +149,18 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_in_memory_source",
|
||||
size = "small",
|
||||
srcs = ["python/learn/tests/dataframe/test_in_memory_source.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_early_stopping",
|
||||
size = "medium",
|
||||
|
@ -27,7 +27,8 @@ from tensorflow.contrib.learn.python.learn.dataframe.transform import parameter
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.transform import Transform
|
||||
|
||||
# Transforms
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import NumpySource
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import PandasSource
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms.reader_source import ReaderSource
|
||||
|
||||
|
||||
__all__ = ['Column', 'TransformedColumn', 'DataFrame', 'parameter', 'Transform']
|
||||
|
@ -21,8 +21,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.learn.python.learn.dataframe.queues.feeding_queue_runner as fqr
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.queues import feeding_queue_runner as fqr
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.training import queue_runner
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
try:
|
||||
@ -111,12 +115,13 @@ def enqueue_data(data,
|
||||
"""
|
||||
# TODO(jamieas): create multithreaded version of enqueue_data.
|
||||
if isinstance(data, np.ndarray):
|
||||
dtypes = [tf.int64, tf.as_dtype(data.dtype)]
|
||||
types = [dtypes.int64, dtypes.as_dtype(data.dtype)]
|
||||
shapes = [(), data.shape[1:]]
|
||||
get_feed_fn = _ArrayFeedFn
|
||||
elif HAS_PANDAS and isinstance(data, pd.DataFrame):
|
||||
dtypes = [tf.as_dtype(dt) for dt in [data.index.dtype] + list(data.dtypes)]
|
||||
shapes = [() for _ in dtypes]
|
||||
types = [dtypes.as_dtype(dt)
|
||||
for dt in [data.index.dtype] + list(data.dtypes)]
|
||||
shapes = [() for _ in types]
|
||||
get_feed_fn = _PandasFeedFn
|
||||
else:
|
||||
raise TypeError(
|
||||
@ -124,22 +129,22 @@ def enqueue_data(data,
|
||||
"installed; got {}".format(
|
||||
type(data).__name__))
|
||||
|
||||
placeholders = [tf.placeholder(*type_and_shape)
|
||||
for type_and_shape in zip(dtypes, shapes)]
|
||||
placeholders = [array_ops.placeholder(*type_and_shape)
|
||||
for type_and_shape in zip(types, shapes)]
|
||||
if shuffle:
|
||||
min_after_dequeue = (capacity / 4 if min_after_dequeue is None else
|
||||
min_after_dequeue)
|
||||
queue = tf.RandomShuffleQueue(capacity,
|
||||
min_after_dequeue,
|
||||
dtypes=dtypes,
|
||||
shapes=shapes,
|
||||
seed=seed)
|
||||
queue = data_flow_ops.RandomShuffleQueue(capacity,
|
||||
min_after_dequeue,
|
||||
dtypes=types,
|
||||
shapes=shapes,
|
||||
seed=seed)
|
||||
else:
|
||||
queue = tf.FIFOQueue(capacity, dtypes=dtypes, shapes=shapes)
|
||||
queue = data_flow_ops.FIFOQueue(capacity, dtypes=types, shapes=shapes)
|
||||
enqueue_op = queue.enqueue(placeholders)
|
||||
feed_fn = get_feed_fn(placeholders, data)
|
||||
queue_runner = fqr.FeedingQueueRunner(queue=queue,
|
||||
enqueue_ops=[enqueue_op],
|
||||
feed_fn=feed_fn)
|
||||
tf.train.add_queue_runner(queue_runner)
|
||||
runner = fqr.FeedingQueueRunner(queue=queue,
|
||||
enqueue_ops=[enqueue_op],
|
||||
feed_fn=feed_fn)
|
||||
queue_runner.add_queue_runner(runner)
|
||||
return queue
|
||||
|
@ -0,0 +1,122 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Sources for numpy arrays and pandas DataFrames."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import transform
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.queues import feeding_functions
|
||||
|
||||
|
||||
class BaseInMemorySource(transform.Transform):
  """Abstract parent class for NumpySource and PandasSource.

  Feeds an in-memory dataset through a queue created by
  `feeding_functions.enqueue_data` and emits batches dequeued from that queue.
  Subclasses provide `name` and `_output_names`.
  """

  def __init__(self,
               data,
               batch_size=None,
               queue_capacity=None,
               shuffle=False,
               min_after_dequeue=None,
               seed=None):
    """Initializes the source.

    Args:
      data: The in-memory data to feed; passed unchanged to
        `feeding_functions.enqueue_data`.
      batch_size: Number of elements dequeued per batch. Defaults to 1.
      queue_capacity: Capacity of the feeding queue. Defaults to ten times the
        effective batch size.
      shuffle: Forwarded to `enqueue_data` as its `shuffle` argument.
      min_after_dequeue: Forwarded to `enqueue_data`. When omitted, the
        `batch_size` argument as given by the caller is used (which may be
        `None`, leaving the choice to `enqueue_data`).
      seed: Forwarded to `enqueue_data` as its `seed` argument.
    """
    super(BaseInMemorySource, self).__init__()
    self._data = data
    self._batch_size = 1 if batch_size is None else batch_size
    # The capacity default must key off `queue_capacity`, not `batch_size`;
    # otherwise an explicit capacity is ignored and the queue can hold only a
    # single batch whenever a batch size is supplied.
    if queue_capacity is None:
      self._queue_capacity = self._batch_size * 10
    else:
      self._queue_capacity = queue_capacity
    self._shuffle = shuffle
    self._min_after_dequeue = (batch_size if min_after_dequeue is None
                               else min_after_dequeue)
    self._seed = seed

  @transform.parameter
  def data(self):
    """The in-memory data this source reads from."""
    return self._data

  @transform.parameter
  def batch_size(self):
    """Number of elements dequeued per batch."""
    return self._batch_size

  @transform.parameter
  def queue_capacity(self):
    """Capacity requested for the feeding queue."""
    return self._queue_capacity

  @transform.parameter
  def shuffle(self):
    """The `shuffle` flag forwarded to `enqueue_data`."""
    return self._shuffle

  @transform.parameter
  def min_after_dequeue(self):
    """The `min_after_dequeue` value forwarded to `enqueue_data`."""
    return self._min_after_dequeue

  @transform.parameter
  def seed(self):
    """The seed forwarded to `enqueue_data`."""
    return self._seed

  @property
  def input_valency(self):
    """This transform consumes no input columns."""
    return 0

  def _apply_transform(self, transform_input):
    """Builds the feeding queue and dequeues one batch from it.

    Args:
      transform_input: Unused; this transform has an input valency of 0.

    Returns:
      `self.return_type` constructed from the dequeued tensors.
    """
    # The seed is passed by keyword so that a seeded source actually yields a
    # seeded queue instead of silently dropping the value.
    queue = feeding_functions.enqueue_data(
        self.data,
        self.queue_capacity,
        self.shuffle,
        self.min_after_dequeue,
        seed=self.seed)

    dequeued = queue.dequeue_many(self.batch_size)

    # TODO(jamieas): dequeue and dequeue_many will soon return a list regardless
    # of the number of enqueued tensors. Remove the following once that change
    # is in place.
    if not isinstance(dequeued, (tuple, list)):
      dequeued = (dequeued,)
    # pylint: disable=not-callable
    return self.return_type(*dequeued)
|
||||
|
||||
|
||||
class NumpySource(BaseInMemorySource):
  """Zero-input Transform exposing a numpy array as `index` and `value`."""

  @property
  def name(self):
    """Identifier of this transform."""
    return "NumpySource"

  @property
  def _output_names(self):
    """Names of the outputs, in order: the index first, then the value."""
    return "index", "value"
|
||||
|
||||
|
||||
class PandasSource(BaseInMemorySource):
  """Zero-input Transform that turns the columns of a DataFrame into Series."""

  def __init__(self,
               dataframe,
               batch_size=None,
               queue_capacity=None,
               shuffle=False,
               min_after_dequeue=None,
               seed=None):
    """Initializes the source from a pandas DataFrame.

    Args:
      dataframe: DataFrame to read; it must not have a column named `index`,
        because that name is used for the output carrying the index.
      batch_size: Forwarded to `BaseInMemorySource`.
      queue_capacity: Forwarded to `BaseInMemorySource`.
      shuffle: Forwarded to `BaseInMemorySource`.
      min_after_dequeue: Forwarded to `BaseInMemorySource`.
      seed: Forwarded to `BaseInMemorySource`.

    Raises:
      ValueError: If `dataframe` has a column named `index`.
    """
    index_is_taken = "index" in dataframe.columns
    if index_is_taken:
      raise ValueError("Column name `index` is reserved.")
    super(PandasSource, self).__init__(
        dataframe, batch_size, queue_capacity, shuffle, min_after_dequeue,
        seed)

  @property
  def name(self):
    """Identifier of this transform."""
    return "PandasSource"

  @property
  def _output_names(self):
    """Output names: `index` followed by the DataFrame's column names."""
    column_names = self._data.columns.tolist()
    return ("index",) + tuple(column_names)
|
@ -0,0 +1,102 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests NumpySource and PandasSource."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms import in_memory_source
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
try:
|
||||
import pandas as pd
|
||||
HAS_PANDAS = True
|
||||
except ImportError:
|
||||
HAS_PANDAS = False
|
||||
|
||||
|
||||
def get_rows(array, row_indices):
  """Stacks `array[i]` for each `i` in `row_indices` into one 2-D array."""
  picked = tuple(array[idx] for idx in row_indices)
  return np.vstack(picked)
|
||||
|
||||
|
||||
class NumpySourceTestCase(tf.test.TestCase):
  """Checks that NumpySource yields array rows in order, wrapping around."""

  def testNumpySource(self):
    """Dequeues many batches and compares each against the source array."""
    batch_size = 3
    iterations = 1000
    array = np.arange(32).reshape([16, 2])
    numpy_source = in_memory_source.NumpySource(array, batch_size)
    index_column = numpy_source().index
    value_column = numpy_source().value
    cache = {}
    with tf.Graph().as_default():
      value_tensor = value_column.build(cache)
      index_tensor = index_column.build(cache)
      with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Stop and join the feeding threads even when an assertion fails, so a
        # failing test does not leave queue runners behind.
        try:
          for i in range(iterations):
            # Row positions for batch `i`; the modulo mirrors the wrap-around
            # the test expects once the array is exhausted.
            expected_index = [
                j % array.shape[0]
                for j in range(batch_size * i, batch_size * (i + 1))
            ]
            expected_value = get_rows(array, expected_index)
            actual_index, actual_value = sess.run([index_tensor, value_tensor])
            np.testing.assert_array_equal(expected_index, actual_index)
            np.testing.assert_array_equal(expected_value, actual_value)
        finally:
          coord.request_stop()
          coord.join(threads)
|
||||
|
||||
|
||||
class PandasSourceTestCase(tf.test.TestCase):
  """Checks that PandasSource yields DataFrame rows in order, wrapping around."""

  def testPandasFeeding(self):
    """Dequeues many batches and compares each against the DataFrame."""
    if not HAS_PANDAS:
      # Report the test as skipped instead of letting it pass silently.
      self.skipTest("pandas is not installed")
    batch_size = 3
    iterations = 1000
    index = np.arange(100, 132)
    a = np.arange(32)
    b = np.arange(32, 64)
    dataframe = pd.DataFrame({"a": a, "b": b}, index=index)
    pandas_source = in_memory_source.PandasSource(dataframe, batch_size)
    pandas_columns = pandas_source()
    cache = {}
    with tf.Graph().as_default():
      pandas_tensors = [col.build(cache) for col in pandas_columns]
      with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Stop and join the feeding threads even when an assertion fails, so a
        # failing test does not leave queue runners behind.
        try:
          for i in range(iterations):
            # Positional row numbers for batch `i`, wrapping past the end.
            indices = [j % dataframe.shape[0]
                       for j in range(batch_size * i, batch_size * (i + 1))]
            expected_df_indices = dataframe.index[indices]
            expected_rows = dataframe.iloc[indices]
            actual_value = sess.run(pandas_tensors)
            # The first output carries the index; the data columns follow.
            np.testing.assert_array_equal(expected_df_indices, actual_value[0])
            for col_num, col in enumerate(dataframe.columns):
              np.testing.assert_array_equal(expected_rows[col].values,
                                            actual_value[col_num + 1])
        finally:
          coord.request_stop()
          coord.join(threads)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
Loading…
x
Reference in New Issue
Block a user