Inflow: establish Column, Transform, and DataFrame abstractions.

Part of a series of CLs setting up the minimal Inflow.
Change: 123072135
This commit is contained in:
David Soergel 2016-05-23 21:06:23 -08:00 committed by TensorFlower Gardener
parent 977c933de8
commit 26f54a9fcd
11 changed files with 963 additions and 0 deletions

View File

@ -75,6 +75,42 @@ py_test(
],
)
py_test(
name = "test_dataframe",
size = "small",
srcs = ["python/learn/tests/dataframe/test_dataframe.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
],
)
py_test(
name = "test_column",
size = "small",
srcs = ["python/learn/tests/dataframe/test_column.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
],
)
py_test(
name = "test_transform",
size = "small",
srcs = ["python/learn/tests/dataframe/test_transform.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
],
)
py_test(
name = "test_early_stopping",
size = "medium",

View File

@ -29,6 +29,7 @@ from tensorflow.contrib.learn.python.learn import ops
from tensorflow.contrib.learn.python.learn import preprocessing
from tensorflow.contrib.learn.python.learn import utils
# pylint: disable=wildcard-import
from tensorflow.contrib.learn.python.learn.dataframe import *
from tensorflow.contrib.learn.python.learn.estimators import *
from tensorflow.contrib.learn.python.learn.graph_actions import evaluate
from tensorflow.contrib.learn.python.learn.graph_actions import infer

View File

@ -0,0 +1,26 @@
"""DataFrames for ingesting and preprocessing data."""
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.dataframe.column import Column
from tensorflow.contrib.learn.python.learn.dataframe.column import TransformedColumn
from tensorflow.contrib.learn.python.learn.dataframe.dataframe import DataFrame
from tensorflow.contrib.learn.python.learn.dataframe.transform import parameter
from tensorflow.contrib.learn.python.learn.dataframe.transform import Transform
__all__ = ['Column', 'TransformedColumn', 'DataFrame', 'parameter', 'Transform']

View File

@ -0,0 +1,93 @@
"""A Column represents a deferred Tensor computation in a DataFrame."""
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from abc import ABCMeta
class Column(object):
"""A single output column.
Represents the deferred construction of a graph that computes the column
values.
Note every `Column` should be a `TransformedColumn`, except when mocked.
"""
__metaclass__ = ABCMeta
def build(self, cache):
"""Returns a Tensor."""
raise NotImplementedError()
class TransformedColumn(Column):
"""A `Column` that results from applying a `Transform` to a list of inputs."""
def __init__(self, input_columns, transform, output_name):
super(TransformedColumn, self).__init__()
self._input_columns = input_columns
self._transform = transform
self._output_name = output_name
if output_name is None:
raise ValueError("output_name must be provided")
if len(input_columns) != transform.input_valency:
raise ValueError("Expected %s input Columns but received %s." %
(transform.input_valency, len(input_columns)))
self._repr = TransformedColumn.make_repr(
self._input_columns, self._transform, self._output_name)
def build(self, cache=None):
if cache is None:
cache = {}
all_outputs = self._transform.apply_transform(self._input_columns, cache)
return getattr(all_outputs, self._output_name)
def __repr__(self):
return self._repr
# Note we need to generate column reprs from Transform, without needing the
# columns themselves. So we just make this public. Alternatively we could
# create throwaway columns just in order to call repr() on them.
@staticmethod
def make_repr(input_columns, transform, output_name):
"""Generate a key for caching Tensors produced for a TransformedColumn.
Generally we a need a deterministic unique key representing which transform
was applied to which inputs, and which output was selected.
Args:
input_columns: the input `Columns` for the `Transform`
transform: the `Transform` being applied
output_name: the name of the specific output from the `Transform` that is
to be cached
Returns:
A string suitable for use as a cache key for Tensors produced via a
TransformedColumn
"""
input_column_keys = [repr(column) for column in input_columns]
input_column_keys_joined = ", ".join(input_column_keys)
return "%s(%s)[%s]" % (
repr(transform), input_column_keys_joined, output_name)

View File

@ -0,0 +1,124 @@
"""A DataFrame is a container for ingesting and preprocessing data."""
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from abc import ABCMeta
import collections
from .column import Column
from .transform import Transform
class DataFrame(object):
"""A DataFrame is a container for ingesting and preprocessing data."""
__metaclass__ = ABCMeta
def __init__(self):
self._columns = {}
def columns(self):
"""Set of the column names."""
return frozenset(self._columns.keys())
def __len__(self):
"""The number of columns in the DataFrame."""
return len(self._columns)
def assign(self, **kwargs):
"""Adds columns to DataFrame.
Args:
**kwargs: assignments of the form key=value where key is a string
and value is an `inflow.Series`, a `pandas.Series` or a numpy array.
Raises:
TypeError: keys are not strings.
TypeError: values are not `inflow.Series`, `pandas.Series` or
`numpy.ndarray`.
TODO(jamieas): pandas assign method returns a new DataFrame. Consider
switching to this behavior, changing the name or adding in_place as an
argument.
"""
for k, v in kwargs.items():
if not isinstance(k, str):
raise TypeError("The only supported type for keys is string; got %s" %
type(k))
if isinstance(v, Column):
s = v
elif isinstance(v, Transform) and v.input_valency() == 0:
s = v()
# TODO(jamieas): hook up these special cases again
# TODO(soergel): can these special cases be generalized?
# elif isinstance(v, pd.Series):
# s = series.NumpySeries(v.values)
# elif isinstance(v, np.ndarray):
# s = series.NumpySeries(v)
else:
raise TypeError(
"Column in assignment must be an inflow.Column, pandas.Series or a"
" numpy array; got type '%s'." % type(v).__name__)
self._columns[k] = s
def select(self, keys):
"""Returns a new DataFrame with a subset of columns.
Args:
keys: A list of strings. Each should be the name of a column in the
DataFrame.
Returns:
A new DataFrame containing only the specified columns.
"""
result = type(self)()
for key in keys:
result[key] = self._columns[key]
return result
def __getitem__(self, key):
"""Indexing functionality for DataFrames.
Args:
key: a string or an iterable of strings.
Returns:
A Series or list of Series corresponding to the given keys.
"""
if isinstance(key, str):
return self._columns[key]
elif isinstance(key, collections.Iterable):
for i in key:
if not isinstance(i, str):
raise TypeError("Expected a String; entry %s has type %s." %
(i, type(i).__name__))
return [self.__getitem__(i) for i in key]
raise TypeError(
"Invalid index: %s of type %s. Only strings or lists of strings are "
"supported." % (key, type(key)))
def __setitem__(self, key, value):
if isinstance(key, str):
key = [key]
if isinstance(value, Column):
value = [value]
self.assign(**dict(zip(key, value)))
def build(self, cache=None):
if cache is None:
cache = {}
tensors = {name: c.build(cache) for name, c in self._columns.items()}
return tensors

View File

@ -0,0 +1,287 @@
"""A Transform takes a list of `Column` and returns a namedtuple of `Column`."""
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from abc import ABCMeta
from abc import abstractmethod
from abc import abstractproperty
import collections
import inspect
from .column import Column
from .column import TransformedColumn
def _make_list_of_column(x):
"""Converts `x` into a list of `Column` if possible.
Args:
x: a `Column`, a list of `Column` or `None`.
Returns:
`x` if it is a list of Column, `[x]` if `x` is a `Column`, `[]` if x is
`None`.
Raises:
TypeError: `x` is not a `Column` a list of `Column` or `None`.
"""
if x is None:
return []
elif isinstance(x, Column):
return [x]
elif isinstance(x, (list, tuple)):
for i, y in enumerate(x):
if not isinstance(y, Column):
raise TypeError(
"Expected a tuple or list of Columns; entry %s has type %s." %
(i, type(y).__name__))
return list(x)
raise TypeError("Expected a Column or list of Column; got %s" %
type(x).__name__)
def _make_tuple_of_string(x):
"""Converts `x` into a list of `str` if possible.
Args:
x: a `str`, a list of `str`, a tuple of `str`, or `None`.
Returns:
`x` if it is a tuple of str, `tuple(x)` if it is a list of str,
`(x)` if `x` is a `str`, `()` if x is `None`.
Raises:
TypeError: `x` is not a `str`, a list or tuple of `str`, or `None`.
"""
if x is None:
return ()
elif isinstance(x, str):
return (x,)
elif isinstance(x, (list, tuple)):
for i, y in enumerate(x):
if not isinstance(y, str):
raise TypeError(
"Expected a tuple or list of strings; entry %s has type %s." %
(i, type(y).__name__))
return x
raise TypeError("Expected a string or list of strings or tuple of strings; " +
"got %s" % type(x).__name__)
def parameter(func):
"""Tag functions annotated with `@parameter` for later retrieval.
Note that all `@parameter`s are automatically `@property`s as well.
Args:
func: the getter function to tag and wrap
Returns:
A `@property` whose getter function is marked with is_parameter = True
"""
func.is_parameter = True
return property(func)
class Transform(object):
"""A function from a list of `Column` to a namedtuple of `Column`.
Transforms map zero or more columns of a DataFrame to new columns.
"""
__metaclass__ = ABCMeta
def __init__(self):
self._return_type = None
@abstractproperty
def name(self):
"""Name of the transform."""
raise NotImplementedError()
def parameters(self):
"""A dict of names to values of properties marked with `@parameter`."""
property_param_names = [name
for name, func in inspect.getmembers(type(self))
if (hasattr(func, "fget") and hasattr(
getattr(func, "fget"), "is_parameter"))]
return {name: getattr(self, name) for name in property_param_names}
@abstractproperty
def input_valency(self):
"""The number of `Column`s that the `Transform` should expect as input.
`None` indicates that the transform can take a variable number of inputs.
This function should depend only on `@parameter`s of this `Transform`.
Returns:
The number of expected inputs.
"""
raise NotImplementedError()
@property
def output_names(self):
"""The names of `Column`s output by the `Transform`.
This function should depend only on `@parameter`s of this `Transform`.
Returns:
A tuple of names of outputs provided by this Transform.
"""
return _make_tuple_of_string(self._output_names)
@abstractproperty
def _output_names(self):
"""The names of `Column`s output by the `Transform`.
This function should depend only on `@parameter`s of this `Transform`.
Returns:
Names of outputs provided by this Transform, as a string, tuple, or list.
"""
raise NotImplementedError()
@property
def return_type(self):
"""Provides a namedtuple type which will be used for output.
A Transform generates one or many outputs, named according to
_output_names. This method creates (and caches) a namedtuple type using
those names as the keys. The Transform output is then generated by
instantiating an object of this type with corresponding values.
Note this output type is used both for `__call__`, in which case the
values are `TransformedColumn`s, and for `apply_transform`, in which case
the values are `Tensor`s.
Returns:
A namedtuple type fixing the order and names of the outputs of this
transform.
"""
if self._return_type is None:
# TODO(soergel): pylint 3 chokes on this, but it is legit and preferred.
# return_type_name = "%sReturnType" % type(self).__name__
return_type_name = "ReturnType"
self._return_type = collections.namedtuple(return_type_name,
self.output_names)
return self._return_type
def _check_output_tensors(self, output_tensors):
"""Helper for `build(...)`; verifies the output of `_build_transform`.
Args:
output_tensors: value returned by a call to `_build_transform`.
Raises:
TypeError: `transform_output` is not a list.
ValueError: `transform_output` does not match `output_names`.
"""
if not isinstance(output_tensors, self.return_type):
raise TypeError(
"Expected a NamedTuple of Tensors with elements %s; got %s." %
(self.output_names, type(output_tensors).__name__))
def __call__(self, input_columns=None):
"""Apply this `Transform` to the provided `Column`s, producing 'Column's.
Args:
input_columns: None, a `Column`, or a list of input `Column`s, acting as
positional arguments.
Returns:
A namedtuple of the output Columns.
Raises:
ValueError: `input_columns` does not have expected length
"""
input_columns = _make_list_of_column(input_columns)
if len(input_columns) != self.input_valency:
raise ValueError("Expected %s input Columns but received %s." %
(self.input_valency, len(input_columns)))
output_columns = [TransformedColumn(input_columns, self, output_name)
for output_name in self.output_names]
# pylint: disable=not-callable
return self.return_type(*output_columns)
def apply_transform(self, input_columns, cache=None):
"""Apply this `Transform` to the provided `Column`s, producing 'Tensor's.
Args:
input_columns: None, a `Column`, or a list of input `Column`s, acting as
positional arguments.
cache: a dict from Column reprs to Tensors.
Returns:
A namedtuple of the output Tensors.
Raises:
ValueError: `input_columns` does not have expected length
"""
# pylint: disable=not-callable
if cache is None:
cache = {}
if len(input_columns) != self.input_valency:
raise ValueError("Expected %s input Columns but received %s." %
(self.input_valency, len(input_columns)))
input_tensors = [input_column.build(cache)
for input_column in input_columns]
# Note we cache each output individually, not just the entire output
# tuple. This allows using the graph as the cache, since it can sensibly
# cache only individual Tensors.
output_reprs = [TransformedColumn.make_repr(input_columns, self,
output_name)
for output_name in self.output_names]
output_tensors = [cache.get(output_repr) for output_repr in output_reprs]
if None in output_tensors:
result = self._apply_transform(input_tensors)
for output_name, output_repr in zip(self.output_names, output_reprs):
cache[output_repr] = getattr(result, output_name)
else:
result = self.return_type(*output_tensors)
self._check_output_tensors(result)
return result
@abstractmethod
def _apply_transform(self, input_tensors):
"""Applies the transformation to the `transform_input`.
Args:
input_tensors: a list of Tensors representing the input to
the Transform.
Returns:
A namedtuple of Tensors representing the transformed output.
"""
raise NotImplementedError()
def __str__(self):
return self.name
def __repr__(self):
parameters_sorted = ["%s: %s" % (repr(k), repr(v))
for k, v in sorted(self.parameters().items())]
parameters_joined = ", ".join(parameters_sorted)
return "%s({%s})" % (self.name, parameters_joined)

View File

@ -0,0 +1,18 @@
"""Tests for DataFrames."""
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@ -0,0 +1,122 @@
"""Mock DataFrame constituents for testing."""
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from abc import ABCMeta
from tensorflow.contrib.learn.python import learn
class MockColumn(learn.Column):
"""A mock column for use in testing."""
def __init__(self, cachekey, mock_tensors):
super(MockColumn, self).__init__()
self._cachekey = cachekey
self._mock_tensors = mock_tensors
def build(self, cache):
return self._mock_tensors
def __repr__(self):
return self._cachekey
class MockTransform(learn.Transform):
"""A mock transform for use in testing."""
__metaclass__ = ABCMeta
def __init__(self, param_one, param_two):
super(MockTransform, self).__init__()
self._param_one = param_one
self._param_two = param_two
@property
def name(self):
return "MockTransform"
@learn.parameter
def param_one(self):
return self._param_one
@learn.parameter
def param_two(self):
return self._param_two
@property
def input_valency(self):
return 1
class MockZeroOutputTransform(MockTransform):
"""A mock transform for use in testing."""
_mock_output_names = []
def __init__(self, param_one, param_two):
super(MockZeroOutputTransform, self).__init__(param_one, param_two)
@property
def _output_names(self):
return MockZeroOutputTransform._mock_output_names
def _apply_transform(self, input_tensors):
# pylint: disable=not-callable
return self.return_type()
class MockOneOutputTransform(MockTransform):
"""A mock transform for use in testing."""
_mock_output_names = ["out1"]
def __init__(self, param_one, param_two):
super(MockOneOutputTransform, self).__init__(param_one, param_two)
@property
def _output_names(self):
return MockOneOutputTransform._mock_output_names
def _apply_transform(self, input_tensors):
# pylint: disable=not-callable
return self.return_type("Fake Tensor 1")
class MockTwoOutputTransform(MockTransform):
"""A mock transform for use in testing."""
_mock_output_names = ["out1", "out2"]
@learn.parameter
def param_three(self):
return self._param_three
def __init__(self, param_one, param_two, param_three):
super(MockTwoOutputTransform, self).__init__(param_one, param_two)
self._param_three = param_three
@property
def _output_names(self):
return MockTwoOutputTransform._mock_output_names
def _apply_transform(self, input_tensors):
# pylint: disable=not-callable
return self.return_type("Fake Tensor 1", "Fake Tensor 2")

View File

@ -0,0 +1,68 @@
"""Tests of the Column class."""
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.learn.python import learn
from tensorflow.contrib.learn.python.learn.tests.dataframe import mocks
class TransformedColumnTest(tf.test.TestCase):
"""Test of `TransformedColumn`."""
def test_repr(self):
col = learn.TransformedColumn(
[mocks.MockColumn("foobar", [])],
mocks.MockTwoOutputTransform("thb", "nth", "snt"), "qux")
# note params are sorted by name
expected = ("MockTransform({'param_one': 'thb', 'param_three': 'snt', "
"'param_two': 'nth'})"
"(foobar)[qux]")
self.assertEqual(expected, repr(col))
def test_build_no_output(self):
def create_no_output_column():
return learn.TransformedColumn(
[mocks.MockColumn("foobar", [])],
mocks.MockZeroOutputTransform("thb", "nth"), None)
self.assertRaises(ValueError, create_no_output_column)
def test_build_single_output(self):
col = learn.TransformedColumn(
[mocks.MockColumn("foobar", [])],
mocks.MockOneOutputTransform("thb", "nth"), "out1")
result = col.build()
expected = "Fake Tensor 1"
self.assertEqual(expected, result)
def test_build_multiple_output(self):
col = learn.TransformedColumn(
[mocks.MockColumn("foobar", [])],
mocks.MockTwoOutputTransform("thb", "nth", "snt"), "out2")
result = col.build()
expected = "Fake Tensor 2"
self.assertEqual(expected, result)
if __name__ == "__main__":
tf.test.main()

View File

@ -0,0 +1,97 @@
"""Tests of the DataFrame class."""
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.learn.python import learn
from tensorflow.contrib.learn.python.learn.tests.dataframe import mocks
def setup_test_df():
"""Create a dataframe populated with some test columns."""
df = learn.DataFrame()
df["a"] = learn.TransformedColumn(
[mocks.MockColumn("foobar", [])],
mocks.MockTwoOutputTransform("iue", "eui", "snt"), "out1")
df["b"] = learn.TransformedColumn(
[mocks.MockColumn("foobar", [])],
mocks.MockTwoOutputTransform("iue", "eui", "snt"), "out2")
df["c"] = learn.TransformedColumn(
[mocks.MockColumn("foobar", [])],
mocks.MockTwoOutputTransform("iue", "eui", "snt"), "out1")
return df
class DataFrameTest(tf.test.TestCase):
"""Test of `DataFrame`."""
def test_create(self):
df = setup_test_df()
self.assertEqual(df.columns(), frozenset(["a", "b", "c"]))
def test_select(self):
df = setup_test_df()
df2 = df.select(["a", "c"])
self.assertEqual(df2.columns(), frozenset(["a", "c"]))
def test_get_item(self):
df = setup_test_df()
c1 = df["b"]
self.assertEqual("Fake Tensor 2", c1.build())
def test_set_item_column(self):
df = setup_test_df()
self.assertEqual(3, len(df))
col1 = mocks.MockColumn("QuackColumn", [])
df["quack"] = col1
self.assertEqual(4, len(df))
col2 = df["quack"]
self.assertEqual(col1, col2)
def test_set_item_column_multi(self):
df = setup_test_df()
self.assertEqual(3, len(df))
col1 = mocks.MockColumn("QuackColumn", [])
col2 = mocks.MockColumn("MooColumn", [])
df["quack", "moo"] = [col1, col2]
self.assertEqual(5, len(df))
col3 = df["quack"]
self.assertEqual(col1, col3)
col4 = df["moo"]
self.assertEqual(col2, col4)
def test_set_item_pandas(self):
# TODO(jamieas)
pass
def test_set_item_numpy(self):
# TODO(jamieas)
pass
def test_build(self):
df = setup_test_df()
result = df.build()
expected = {"a": "Fake Tensor 1",
"b": "Fake Tensor 2",
"c": "Fake Tensor 1"}
self.assertEqual(expected, result)
if __name__ == "__main__":
tf.test.main()

View File

@ -0,0 +1,91 @@
"""Tests of the Transform class."""
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.learn.python import learn
from tensorflow.contrib.learn.python.learn.dataframe.transform import _make_list_of_column
from tensorflow.contrib.learn.python.learn.tests.dataframe import mocks
class TransformTest(tf.test.TestCase):
"""Tests of the Transform class."""
def test_make_list_of_column(self):
col1 = mocks.MockColumn("foo", [])
col2 = mocks.MockColumn("bar", [])
self.assertEqual([], _make_list_of_column(None))
self.assertEqual([col1], _make_list_of_column(col1))
self.assertEqual([col1], _make_list_of_column([col1]))
self.assertEqual([col1, col2], _make_list_of_column([col1, col2]))
self.assertEqual([col1, col2], _make_list_of_column((col1, col2)))
def test_cache(self):
z = mocks.MockColumn("foobar", [])
t = mocks.MockTwoOutputTransform("thb", "nth", "snt")
cache = {}
t.apply_transform([z], cache)
self.assertEqual(2, len(cache))
expected_keys = [
"MockTransform("
"{'param_one': 'thb', 'param_three': 'snt', 'param_two': 'nth'})"
"(foobar)[out1]",
"MockTransform("
"{'param_one': 'thb', 'param_three': 'snt', 'param_two': 'nth'})"
"(foobar)[out2]"]
self.assertEqual(expected_keys, sorted(cache.keys()))
def test_parameters(self):
t = mocks.MockTwoOutputTransform("a", "b", "c")
self.assertEqual({"param_one": "a", "param_three": "c", "param_two": "b"},
t.parameters())
def test_parameters_inherited_combined(self):
t = mocks.MockTwoOutputTransform("thb", "nth", "snt")
expected = {"param_one": "thb", "param_two": "nth", "param_three": "snt"}
self.assertEqual(expected, t.parameters())
def test_return_type(self):
t = mocks.MockTwoOutputTransform("a", "b", "c")
rt = t.return_type
self.assertEqual("ReturnType", rt.__name__)
self.assertEqual(("out1", "out2"), rt._fields)
def test_call(self):
t = mocks.MockTwoOutputTransform("a", "b", "c")
# MockTwoOutputTransform has input valency 1
input1 = mocks.MockColumn("foobar", [])
out1, out2 = t([input1]) # pylint: disable=not-callable
self.assertEqual(learn.TransformedColumn, type(out1))
# self.assertEqual(out1.transform, t)
# self.assertEqual(out1.output_name, "output1")
self.assertEqual(learn.TransformedColumn, type(out2))
# self.assertEqual(out2.transform, t)
# self.assertEqual(out2.output_name, "output2")
if __name__ == "__main__":
tf.test.main()